当前位置:编程学习 > C/C++ >>

矩阵乘法 之 strassen 算法

一般情况下矩阵乘法需要三个for循环,时间复杂度为O(n^3),现在我们将矩阵分块如图:( 来自MIT算法导论 )
一般算法需要八次乘法
r = a * e + b * g ;
s = a * f  + b * h ;
t = c * e + d  * g; 
u = c * f + d * h;
 
strassen将其变成7次乘法,因为大家都知道乘法比加减法消耗更多,所有时间复杂更高!
strassen的处理是:
令:
p1 = a * ( f - h )
p2 = ( a + b ) *  h
p3 = ( c +d ) * e
p4 = d *  ( g - e )
p5 = ( a + d ) * ( e + h )
p6 =  ( b - d ) * ( g + h ) 
p7 = ( a - c ) * ( e + f )
 
那么我们可以知道:
r  = p5 + p4 + p6 - p2
s = p1 + p2
t = p3 + p4
u = p5 + p1 - p3 - p7
 
我们可以看到上面只有7次乘法和多次加减法,最终达到降低复杂度为O( n^lg7 ) ~= O( n^2.81 );
代码实现如下:
[cpp]  
// strassen 算法:将矩阵相乘的复杂度降到O(n^lg7) ~= O(n^2.81)  
// 原理是将8次乘法减少到7次的处理  
// 现在理论上的最好的算法是O(n^2,367),仅仅是理论上的而已  
//  
//  
// 下面的代码仅仅是简单的实例而已,不必较真哦,呵呵~  
// 下面的空间可以优化的,此处就不麻烦了~  
  
#include <stdio.h>  
  
#define  N  10  
  
//matrix + matrix  
void plus( int t[N/2][N/2], int r[N/2][N/2], int s[N/2][N/2] )  
{  
    int i, j;  
    for( i = 0; i < N / 2; i++ )  
    {  
        for( j = 0; j < N / 2; j++ )  
        {  
            t[i][j] = r[i][j] + s[i][j];  
        }  
    }  
}  
  
//matrix - matrix  
void minus( int t[N/2][N/2], int r[N/2][N/2], int s[N/2][N/2] )  
{  
    int i, j;  
    for( i = 0; i < N / 2; i++ )  
    {  
        for( j = 0; j < N / 2; j++ )  
        {  
            t[i][j] = r[i][j] - s[i][j];  
        }  
    }  
}  
  
//matrix * matrix  
void mul( int t[N/2][N/2], int r[N/2][N/2], int s[N/2][N/2]  )  
{  
    int i, j, k;  
    for( i = 0; i < N / 2; i++ )  
    {  
        for( j = 0; j < N / 2; j++ )  
        {  
            t[i][j] = 0;  
            for( k = 0; k < N / 2; k++ )  
            {  
                t[i][j] += r[i][k] * s[k][j];  
            }  
        }  
    }  
}  
  
int main()  
{  
    int i, j, k;  
    int mat[N][N];  
    int m1[N][N];  
    int m2[N][N];  
    int a[N/2][N/2],b[N/2][N/2],c[N/2][N/2],d[N/2][N/2];  
    int e[N/2][N/2],f[N/2][N/2],g[N/2][N/2],h[N/2][N/2];  
    int p1[N/2][N/2],p2[N/2][N/2],p3[N/2][N/2],p4[N/2][N/2];  
    int p5[N/2][N/2],p6[N/2][N/2],p7[N/2][N/2];  
    int r[N/2][N/2], s[N/2][N/2], t[N/2][N/2], u[N/2][N/2], t1[N/2][N/2], t2[N/2][N/2];  
  
  
    printf("\nInput the first matrix...:\n");  
    for( i = 0; i < N; i++ )  
    {  
        for( j = 0; j < N; j++ )  
        {  
            scanf("%d", &m1[i][j]);  
        }  
    }  
  
    printf("\nInput the second matrix...:\n");  
    for( i = 0; i < N; i++ )  
    {  
        for( j = 0; j < N; j++ )  
        {  
            scanf("%d", &m2[i][j]);  
        }  
    }  
  
    // a b c d e f g h  
    for( i = 0; i < N / 2; i++ )  
    {  
        for( j = 0; j < N / 2; j++ )  
        {  
            a[i][j] = m1[i][j];  
            b[i][j] = m1[i][j + N / 2];  
            c[i][j] = m1[i + N / 2][j];  
            d[i][j] = m1[i + N / 2][j + N / 2];  
            e[i][j] = m2[i][j];  
            f[i][j] = m2[i][j + N / 2];  
            g[i][j] = m2[i + N / 2][j];  
            h[i][j] = m2[i + N / 2][j + N / 2];  
        }  
    }  
      
    //p1  
    minus( r, f, h );  
    mul( p1, a, r );   
  
    //p2  
    plus( r, a, b );  
    mul( p2, r, h );  
  
    //p3  
    plus( r, c, d );  
    mul( p3, r, e );  
  
    //p4  
    minus( r, g, e );  
    mul( p4, d, r );  
  
    //p5  
    plus( r, a, d );  
    plus( s, e, f );  
    mul( p5, r, s );  
  
    //p6  
    minus( r, b, d );  
    plus( s, g, h );  
    mul( p6, r, s );  
  
    //p7  
    minus( r, a, c );  
    plus( s, e, f );  
    mul( p7, r, s );  
  
    //r = p5 + p4 - p2 + p6  
    plus( t1, p5, p4 );  
补充:软件开发 , C++ ,
CopyRight © 2022 站长资源库 编程知识问答 zzzyk.com All Rights Reserved
部分文章来自网络,