带你吃透分治算法(二)矩阵相乘

分治策略(二)

我们在上一篇博文介绍了分治策略思想以及解决最大子数组的问题。

见:带你吃透分治算法之最大子数组

下面我们来介绍应用分治思想解决“矩阵相乘问题”(虽然用分治策略解决矩阵相乘问题时间复杂度不低,但是它反映的是一种思想,当我们遇到棘手问题不好解决,往往先找到能够解决的办法比找到最优办法更重要)

矩阵乘法

矩阵乘法我们应该都有所了解,若矩阵A和矩阵B均是nxn的方阵,则定义乘积C=A·B中的元素c(i,j)为:
在这里插入图片描述

暴力解法

暴力解法就是按照矩阵乘积定义进行运算,得到最终的C矩阵,时间复杂度为O(n^3)。

实现代码如下:

//矩阵相乘
Matrix MatrixMultip(Matrix A,Matrix B)
{
    Matrix C;
    Zero_Martix(&C);
    if(A.cols == B.rows)
    {
        for(int i=0;i<A.rows;i++)
        {
            for(int j=0;j<B.cols;j++)
            {
                for(int k=0;k<A.cols;k++)
                {
                    C.martix[i][j] += A.martix[i][k]*B.martix[k][j];
                }
            }
        }
        C.rows = A.rows;
        C.cols = B.cols;
        return C;
    }else
    {
        C.cols = -1;
        C.rows = -1;
        return C;
    }
}

其中结构体定义如下:

#define MAX 16
typedef int ElemType;

typedef struct Matrix{
    ElemType martix[MAX][MAX];
    int rows,cols;//行数和列数
}Matrix;

那么如何采用分治思想解决该问题呢?来咯!

分治算法解决方案

我们再来回顾下分治算法思想:将一个规模为N的问题分解为K个规模较小的子问题,这些子问题相互独立且与原问题性质相同。求出子问题的解,就可得到原问题的解。

扫描二维码关注公众号,回复: 11478827 查看本文章

那么如何将它套用到这个问题中呢?我们发现,一个矩阵与其子矩阵的关系就像上文说的原问题与规模较小的子问题的关系一样的。

我们想到矩阵相乘的分块矩阵相乘的原理不就是这样吗!

设三个矩阵A、B、C均为nxn的矩阵,其中n为2的幂(因为再每个分解步骤中,nxn矩阵都被划分为4个(n/2)x(n/2)的子矩阵,如果假定n是2的幂,则只要n≥2即可保证子矩阵规模n/2为整数),因此,若n不是2的幂,则不能使用分治法计算

假定将A、B、C军分解为4个(n/2)x(n/2)的子矩阵,如下图所示:
在这里插入图片描述

则可以将公式改写为:
在这里插入图片描述
即等价于如下四个公式:
在这里插入图片描述

这个时候大神们应该都发现了其中的递归结构,那么接下来就是设计分治算法结构了。

C语言代码实现:

//矩阵相乘——分治策略
Matrix MatrixMultip(Matrix A,Matrix B)
{
    Matrix C;
    Zero_Martix(&C);
    if(A.cols == B.rows)
    {
        for(int i=0;i<A.rows;i++)
        {
            for(int j=0;j<B.cols;j++)
            {
                for(int k=0;k<A.cols;k++)
                {
                    C.martix[i][j] += A.martix[i][k]*B.martix[k][j];
                }
            }
        }
        C.rows = A.rows;
        C.cols = B.cols;
        return C;
    }else
    {
        C.cols = -1;
        C.rows = -1;
        return C;
    }
}

//矩阵相乘——分治策略
Matrix MatrixMultip_DAC(Matrix A,Matrix B)
{
    Matrix C;
    if(A.rows == 2)
    {
        C = MatrixMultip_Divide(A,B); //最小化为2*2矩阵时,直接暴力求解
        return C;
    }
    if(A.rows > 2)
    {
        int m = A.rows/2;
        //求取A矩阵的分块矩阵
        Matrix a1 = QuarterMatrix(A,1);
        Matrix a2 = QuarterMatrix(A,2);
        Matrix a3 = QuarterMatrix(A,3);
        Matrix a4 = QuarterMatrix(A,4);
        //求取B矩阵的分块矩阵
        Matrix b1 = QuarterMatrix(B,1);
        Matrix b2 = QuarterMatrix(B,2);
        Matrix b3 = QuarterMatrix(B,3);
        Matrix b4 = QuarterMatrix(B,4);
        //求取对于C矩阵的分块矩阵
        Matrix res1 = AddTwoMartix(MatrixMultip_DAC(a1,b1),MatrixMultip_DAC(a2,b3));
        Matrix res2 = AddTwoMartix(MatrixMultip_DAC(a1,b2),MatrixMultip_DAC(a2,b4));
        Matrix res3 = AddTwoMartix(MatrixMultip_DAC(a3,b1),MatrixMultip_DAC(a4,b3));
        Matrix res4 = AddTwoMartix(MatrixMultip_DAC(a3,b2),MatrixMultip_DAC(a4,b4));
        //合并为C矩阵
        return CombineMatrix(res1,res2,res3,res4);
    }
}

//将所得C矩阵的分块子矩阵合并
Matrix CombineMatrix(Matrix res1,Matrix res2,Matrix res3,Matrix res4)
{
    Matrix Result;
    int n = res1.rows*2;
    Result.rows = Result.cols = n;
    for(int i=0;i<n;i++)
    {
        for(int j=0;j<n;j++)
        {//
            if(i>=0 && i<n/2 && j<n/2 && j>=0)
                Result.martix[i][j] = res1.martix[i][j];
            else if(i>=0 && i<n/2 && j<n && j>=n/2)
                Result.martix[i][j] = res2.martix[i][j-n/2];
            else if(i<n && i>=n/2 && j<n/2 && j>=0)
                Result.martix[i][j] = res3.martix[i-n/2][j];
            else
                Result.martix[i][j] = res4.martix[i-n/2][j-n/2];
        }
    }
    return Result;
}
//将两个矩阵相加
Matrix AddTwoMartix(Matrix A,Matrix B)
{
    Matrix C;
    C.rows = A.rows;
    C.cols = A.cols;
    for(int i=0;i<A.rows;i++)
    {
        for(int j=0;j<A.cols;j++)
        {
            C.martix[i][j] = A.martix[i][j] + B.martix[i][j];
        }
    }
    return C;
}

//计算2*2矩阵
Matrix MatrixMultip_Divide(Matrix A,Matrix B)//此时两者都已经为2*2矩阵了
{
    Matrix C;
    Zero_Martix(&C);
    if(A.cols == B.rows)
    {
        C.martix[0][0] = A.martix[0][0]*B.martix[0][0]+A.martix[0][1]*B.martix[1][0];
        C.martix[0][1] = A.martix[0][0]*B.martix[0][1]+A.martix[0][1]*B.martix[1][1];
        C.martix[1][0] = A.martix[1][0]*B.martix[0][0]+A.martix[1][1]*B.martix[1][0];
        C.martix[1][1] = A.martix[1][0]*B.martix[0][1]+A.martix[1][1]*B.martix[1][1];
        C.rows = A.rows;
        C.cols = B.cols;
    }else
    {
        C.cols = -1;
        C.rows = -1;
    }
    return C;
}

//获得矩阵的分块矩阵
Matrix QuarterMatrix(Matrix M,int index)
{
    Matrix res;
    res.cols = res.rows = M.rows / 2;
    switch (index)
    {
    case 1: 
    for(int i=0;i<res.rows;i++)
    {
        for(int j=0;j<res.cols;j++)
        {
            res.martix[i][j] = M.martix[i][j];
        }
    }
    break;
    
    case 2: 
    for(int i=0;i<res.rows;i++)
    {
        for(int j=0;j<res.cols;j++)
        {
            res.martix[i][j] = M.martix[i][j+res.cols];
        }
    }
    break;

    case 3: 
    for(int i=0;i<res.rows;i++)
    {
        for(int j=0;j<res.cols;j++)
        {
            res.martix[i][j] = M.martix[i+res.rows][j];
        }
    }
    break;

    case 4: 
    for(int i=0;i<res.rows;i++)
    {
        for(int j=0;j<res.cols;j++)
        {
            res.martix[i][j] = M.martix[i+res.rows][j+res.cols];
        }
    }
    break;

    default:
    break;
    }

    return res;
}
//随机生成矩阵
Matrix CreatMatrix_Random(int row,int col)
{
    Matrix M;
    M.cols = col;
    M.rows = row;
    for(int i=0;i<row;i++)
    {
        for(int j=0;j<col;j++)
        {
            M.martix[i][j] = rand()%12;
        }
    }
    return M;
}

猜你喜欢

转载自blog.csdn.net/qq_42642142/article/details/107546325