斯特拉森(Strassen)算法

斯特拉森(Strassen)算法简介

先借用一下百度百科:
在这里插入图片描述

简单解释

通俗的说,斯特拉森算法把原先普通二阶矩阵相乘需要的8次乘法压缩到7次乘法,而在计算机,乘法运算的耗时远远高于加减运算,所以斯特拉森算法可以将O(d^ 3) 压缩到O(d^2.8)。
需要知道的是,斯特拉森算法只是对矩阵分治的算法而不是单独的乘法算法,分治完成时最后使用的还是普通矩阵乘法,在阶数小于等于32(或者64?看过不同的实验结果)时普通的矩阵乘法会有更快的速度,而随着矩阵的阶不断增加,斯特拉森可以提供更快的速度。

一些拓展

在这里,提供另外一种形式的斯特拉森:
在这里插入图片描述
在这里插入图片描述
同样是通过代数的分解与合并,我们构造出另外一种看起来更复杂的算法。这次我们分解出的项数更多,意味着拓展性更高,于是我们在求矩阵平方时有了新的改进。

矩阵的平方-斯特拉森

在这里插入图片描述
我们只需要沿用之前图中的2.s。然后7次乘法使用上图方法,便可以减少一部分的预运算,而上图中,我们依然有P1P2P3可以递归入经过平方优化的斯特拉森算法,获得更快的速度。

C语言描述斯特拉森标准模板

第一个部分我们先来写核心思路。

int Strassen(int N, int **MatrixA, int **MatrixB, int **MatrixC)
{
    
    int HalfSize = N/2;
    int newSize = N/2;
    int nsize=N/2;
    int i=0,j=0;
    int newLength=0;
    newLength = newSize;
    if ( N <= 32)//choosing the threshold is extremely important, try N<=2 to see the result
    {
    
    
        MUL(MatrixA,MatrixB,MatrixC,N);
    }
    else
    {
    
    int** A11;
        int** A12;
        int** A21;
        int** A22;
        int** B11;
        int** B12;
        int** B21;
        int** B22;
        int** C11;
        int** C12;
        int** C21;
        int** C22;
        int** M1;
        int** M2;
        int** M3;
        int** M4;
        int** M5;
        int** M6;
        int** M7;
        int** AResult;
        int** BResult;
        }//此处省略动态内存的申请,最后也省略了内存的释放
        for (i = 0; i < N / 2; i++)
        {
    
    
            for (j = 0; j < N / 2; j++)
            {
    
    
                A11[i][j] = MatrixA[i][j];
                A12[i][j] = MatrixA[i][j + N / 2];
                A21[i][j] = MatrixA[i + N / 2][j];
                A22[i][j] = MatrixA[i + N / 2][j + N / 2];
                B11[i][j] = MatrixB[i][j];
                B12[i][j] = MatrixB[i][j + N / 2];
                B21[i][j] = MatrixB[i + N / 2][j];
                B22[i][j] = MatrixB[i + N / 2][j + N / 2];
            }
        }
        //here we calculate M1..M7 matrices .
        //M1[][]
        ADD( A11,A22,AResult, HalfSize);
        ADD( B11,B22,BResult, HalfSize);
        Strassen( HalfSize, AResult, BResult, M1 ); //now that we need to multiply this , we use the strassen itself .
        //M2[][]
        ADD( A21,A22,AResult, HalfSize);              //M2=(A21+A22)B11
        Strassen(HalfSize, AResult, B11, M2);       //Mul(AResult,B11,M2);
        //M3[][]
        SUB( B12,B22,BResult, HalfSize);              //M3=A11(B12-B22)
        Strassen(HalfSize, A11, BResult, M3);       //Mul(A11,BResult,M3);
        //M4[][]
        SUB( B21, B11, BResult, HalfSize);           //M4=A22(B21-B11)
        Strassen(HalfSize, A22, BResult, M4);       //Mul(A22,BResult,M4);
        //M5[][]
        ADD( A11, A12, AResult, HalfSize);           //M5=(A11+A12)B22
        Strassen(HalfSize, AResult, B22, M5);       //Mul(AResult,B22,M5);
        //M6[][]
        SUB( A21, A11, AResult, HalfSize);
        ADD( B11, B12, BResult, HalfSize);             //M6=(A21-A11)(B11+B12)
        Strassen( HalfSize, AResult, BResult, M6);    //Mul(AResult,BResult,M6);
        //M7[][]
        SUB(A12, A22, AResult, HalfSize);
        ADD(B21, B22, BResult, HalfSize);             //M7=(A12-A22)(B21+B22)
        Strassen(HalfSize, AResult, BResult, M7);     //Mul(AResult,BResult,M7);
        //C11 = M1 + M4 - M5 + M7;
        ADD( M1, M4, AResult, HalfSize);
        SUB( M7, M5, BResult, HalfSize);
        ADD( AResult, BResult, C11, HalfSize);
        //C12 = M3 + M5;
        ADD( M3, M5, C12, HalfSize);
        //C21 = M2 + M4;
        ADD( M2, M4, C21, HalfSize);
        //C22 = M1 + M3 - M2 + M6;
        ADD( M1, M3, AResult, HalfSize);
        SUB( M6, M2, BResult, HalfSize);
        ADD( AResult, BResult, C22, HalfSize);
        //at this point , we have calculated the c11..c22 matrices, and now we are going to
        //put them together and make a unit matrix which would describe our resulting Matrix.
        for (i = 0; i < N/2 ; i++)
        {
    
    
            for (j = 0 ; j < N/2 ; j++)
            {
    
    
                MatrixC[i][j] = C11[i][j];
                MatrixC[i][j + N / 2] = C12[i][j];
                MatrixC[i + N / 2][j] = C21[i][j];
                MatrixC[i + N / 2][j + N / 2] = C22[i][j];
            }
        }
return 0;
}

可以看到我们不断递归对矩阵做分治拆解,简单来说就是把一个大块矩阵不断做斯特拉森拆解,考虑速度在斯特拉森函数内部只做加减法,当拆解到一定大小的矩阵时,将每个小矩阵投入普通的矩阵乘法运算,至于这个 ‘一定大小’ 在上面的代码中定义为32,很有可能在64达到最快速度。
于是我们有三个运算:
乘:MUL函数,采用标准三次方速度的矩阵乘法;
加:ADD函数,执行普通矩阵加法;
减:SUB函数,执行普通矩阵减法;
如下:

int ADD(int** MatrixA, int** MatrixB, int** MatrixResult, int MatrixSize )
{
    
    int i=0,j=0;
    for ( i = 0; i < MatrixSize; i++)
    {
    
    
        for ( j = 0; j < MatrixSize; j++)
        {
    
    
            MatrixResult[i][j] =  MatrixA[i][j] + MatrixB[i][j];
        }
    }
return 0;
}

int SUB(int** MatrixA, int** MatrixB, int** MatrixResult, int MatrixSize )
{
    
    int i=0,j=0;
    for ( i = 0; i < MatrixSize; i++)
    {
    
    
        for (  j = 0; j < MatrixSize; j++)
        {
    
    
            MatrixResult[i][j] =  MatrixA[i][j] - MatrixB[i][j];
        }
    }
return 0;
}

int MUL( int** MatrixA, int** MatrixB, int** MatrixResult, int MatrixSize )
{
    
    int i=0,j=0,k=0;
    for (i=0;i<MatrixSize ;i++)
    {
    
    
        for ( j=0;j<MatrixSize ;j++)
        {
    
    
            MatrixResult[i][j]=0;
            for (k=0;k<MatrixSize ;k++)
            {
    
    
                MatrixResult[i][j]=MatrixResult[i][j]+MatrixA[i][k]*MatrixB[k][j];
            }
        }
    }return 0;
}

主函数可自行定义输入输出方法:
这里,我们输入的是矩阵的阶数和一个矩阵A,输出矩阵A*A

但这是一个普通的斯特拉森模板,没有针对平方的特殊处理,输入输出操作可自行更改。
PS:我们不建议使用任何全局变量。

int main()
{
    
    int i=0,j=0,size=0;
    int **a,**new1;
  scanf("%d",&size);
  if(size==1)
  {
    
    
      scanf("%d",&i);
      printf("%d",i*i);
      return 0;
  }
   new1= (int **) malloc(sizeof(int *) *size);
    for(i=0;i<size;i++)
        {
    
    new1[i]=(int *)malloc(sizeof(int)*size);}
a= (int **) malloc(sizeof(int *) *size);
    for(i=0;i<size;i++)
        {
    
    a[i]=(int *)malloc(sizeof(int)*size);}
   for(i=0;i<size;i++)
        for(j=0;j<size;j++)
   {
    
    
       scanf("%d",&a[i][j]);
   }
  Strassen(size,a,a,new1);

  for(i=0;i<size;i++){
    
    
for(j=0;j<size;j++){
    
    
printf("%d ",new1[i][j]);
}
printf("\n");
}
  for(i = 0; i <size; i ++)
    {
    
    free(a[i]);
        free(new1[i]);
   }
     free(a); free(new1);

    return 0;
}

至于只针对求平方的代码优化,有机会再写。

猜你喜欢

转载自blog.csdn.net/weixin_43736127/article/details/110496581