算法题03:分治与递归:三种矩阵乘法(蛮力法,朴素分治法、Strassen法)

问题:
A A B B 是两个 n × n n\times n 阶矩阵,求它们的乘积矩阵C。这里,假设 n n 2 2 的幂次方。

一、问题分析(模型、算法设计和正确性证明等)

​ 实验要求使用分治法解决n阶矩阵(n是2的幂次方)相乘问题,因为n是2的幂次方,可以使用朴素分块矩阵乘法或者 Strassen 法,这里两种都尝试一下,顺便连蛮力法也放进去。

二、复杂度分析

蛮力法伪代码

for i = 1 to n do:
	for j = 1 to n do:
		for k = 1 to n do:
			C[i][j] = C[i][j] + A[i][k]・B[k][j]

显然时间复杂度为 O ( n 3 ) O(n^3) .

朴素分块矩阵乘法伪代码

Divide_And_Conquer(int[][]A,int[][]B,int n){

    int [][]C = new int[n][n];		//定义一个新矩阵存放结果
    if n==1:
        C11=A11*B11;
    else Divide A, B and C as in 4 equation:
    	n /= 2;
        C11=Divide_And_Conquer(A11,B11,n) + Divide_And_Conquer(A12,B21,n);
        C22=Divide_And_Conquer(A11,B12,n) + Divide_And_Conquer(A12,B22,n);
        C21=Divide_And_Conquer(A21,B11,n) + Divide_And_Conquer(A22,B21,n);
        C22=Divide_And_Conquer(A21,B22,n) + Divide_And_Conquer(A22,B22,n);

    return C;
}

因为n/2 * n/2 的矩阵乘法进行了8次, n/2 * n/2的矩阵加法进行了4次所以复杂度为:

T ( n ) = { 1 n = 1 8 T ( n / 2 ) + 4 ( n 2 ) 2 n > 1 O ( n 3 ) T(n)=\left\{\begin{array}{ll} 1 & n=1 \\ 8 T(n / 2)+4\left(\frac{n}{2}\right)^{2} & n>1 \end{array} \Longrightarrow O\left(n^{3}\right)\right.

Strassen 法伪代码:

Strassen_DAC(int [][]A, int [][]B,int n){
    int [][]C = new int[n][n];		//定义结果矩阵
    if n==1:
    	C11 = A11*B11;
    else Divide A, B, and C as in 4 equation:
    	n /= 2;
    	int [][]M1,M2,M3,M4,M5,M6,M7 = new int[n][n];
    	M1 = Strassen_DAC(A11, B12-B22, n);
		M2 = Strassen_DAC(A11+A12, B22, n);
    	M3 = Strassen_DAC(A21+A22, B11, n);
    	M4 = Strassen_DAC(A22, B21-B11, n);
    	M5 = Strassen_DAC(A11+A12, B11+B12, n);
    	M6 = Strassen_DAC(A12-A22, B21+B22, n);
    	M7 = Strassen_DAC(A11-A21, B11+B12, n);
    	C11 = M5 + M4 - M2 + M6;
    	C12 = M1 + M2;
    	C21 = M3 + M4;
    	C22 = M5 + M1 -M3 -M7;
    return C;
}

从伪代码明显可以看出,程序执行了7次n/2 * n/2的矩阵乘法,以及 18次n/2 *n/2的矩阵加减运算,所以复杂度为:

T ( n ) = { 1 n = 1 7 T ( n / 2 ) + 18 ( n 2 ) 2 n > 1 O ( n log 3 ) T(n)=\left\{\begin{array}{ll}1 & n=1 \\ 7 T(n / 2)+18\left(\frac{n}{2}\right)^{2} & n>1\end{array} \Longrightarrow O\left(n^{\log 3}\right)\right.

三、程序实现和测试过程和结果(主要描述出现的问题和解决方法)

​ 算法的思路倒是不难,难的是具体实现的时候,矩阵的分块操作,容易绕不清楚。而且其中还有矩阵的加减运算,得出的结果最后还要把 C 11 , C 12 , C 21 , C 22 C11, C12, C21, C22 合并成为一个矩阵 C C ,这些在伪代码里都没有给出来,但是复杂繁琐容易出bug的正是这些细节。

​ 实验中使用Java语言编写将三种方法放到同一个类中,一下为类中的各个方法:

在这里插入图片描述

源码:

package root;

/**
 * @author 宇智波Akali
 * 这是三种矩阵乘法
 * @date 2020.3.18
 */
public class Try {
	//创建一个随机数构成的n*n矩阵
	public static int[][] initializationMatrix(int n){
		int[][] result = new int[n][n];//创建一个n*n矩阵
		for(int i = 0;i < n;i++){
			for(int j = 0;j < n;j++){
				result[i][j] = (int)(Math.random()*10); //随机生成1~10之间的数
			}
		} 
		return result; 
	}

	//蛮力法求矩阵相乘
	public static int[][] BruteForce(int[][] p,int[][] q,int n){
		int[][] result = new int[n][n];
		for(int i=0;i<n;i++){
			for(int j=0;j<n;j++){
				result[i][j] = 0;
				for(int k=0;k<n;k++){
					result[i][j] += p[i][k]*q[k][j];
				}
			}
		}  
		return result;
	}

	//分治法求矩阵相乘
	public static int[][] DivideAndConquer(int[][] p,int[][] q,int n){
		int[][] result = new int[n][n];//创建一个n*n矩阵
		//当n为2时,用蛮力法求矩阵相乘,返回结果结果
		if(n == 2){
			result = BruteForce(p,q,n); 
			return result;
		}
	 
		//当n大于3时,采用分治法,递归求最终结果
		if(n > 2){
			int m = n/2;
			
			//将矩阵p分成四块
			int[][] p1 = QuarterMatrix(p,n,1);
			int[][] p2 = QuarterMatrix(p,n,2);
			int[][] p3 = QuarterMatrix(p,n,3);
			int[][] p4 = QuarterMatrix(p,n,4);
			
			//将矩阵q分成四块
			int[][] q1 = QuarterMatrix(q,n,1);
			int[][] q2 = QuarterMatrix(q,n,2);
			int[][] q3 = QuarterMatrix(q,n,3);
			int[][] q4 = QuarterMatrix(q,n,4);
			
			//将结果矩阵分成同等大小的四块
			int[][] result1 = QuarterMatrix(result,n,1);
			int[][] result2 = QuarterMatrix(result,n,2);
			int[][] result3 = QuarterMatrix(result,n,3);
			int[][] result4 = QuarterMatrix(result,n,4);
		
			//最关键的步骤,递归调用DivideAndConquer()函数,并用公式相加
			result1 = AddMatrix(DivideAndConquer(p1,q1,m),DivideAndConquer(p2,q3,m),m);//y=ae+bg
			result2 = AddMatrix(DivideAndConquer(p1,q2,m),DivideAndConquer(p2,q4,m),m);//s=af+bh
			result3 = AddMatrix(DivideAndConquer(p3,q1,m),DivideAndConquer(p4,q3,m),m);//t=ce+dg
			result4 = AddMatrix(DivideAndConquer(p3,q2,m),DivideAndConquer(p4,q4,m),m);//u=cf+dh
			
			//合并,将四块小矩阵合成整体
			result = TogetherMatrix(result1,result2,result3,result4,m);//把分成的四个小矩阵合并成一个大矩阵
		}
		return result;
	}
	
	//strassen法
	public static int[][] Strassen(int[][] p,int[][] q,int n){
		int[][] result = new int[n][n];//创建一个n*n矩阵
		if( n == 2){
			result = BruteForce(p,q,n);
			return result;
		}
		int m = n/2;
		
		//将矩阵p分成四块
		int[][] p1 = QuarterMatrix(p,n,1);
		int[][] p2 = QuarterMatrix(p,n,2);
		int[][] p3 = QuarterMatrix(p,n,3);
		int[][] p4 = QuarterMatrix(p,n,4);
		
		//将矩阵q分成四块
		int[][] q1 = QuarterMatrix(q,n,1);
		int[][] q2 = QuarterMatrix(q,n,2);
		int[][] q3 = QuarterMatrix(q,n,3);
		int[][] q4 = QuarterMatrix(q,n,4);
				int[][] m1 = DivideAndConquer(AddMatrix(p1,p4,m),AddMatrix(q1,q4,m),m);
		int[][] m2 = Strassen(AddMatrix(p3,p4,m),q1,m);
		int[][] m3 = Strassen(p1,ReduceMatrix(q2,q4,m),m);
		int[][] m4 = Strassen(p4,ReduceMatrix(q3,q1,m),m);
		int[][] m5 = Strassen(AddMatrix(p1,p2,m),q4,m);
		int[][] m6 = Strassen(ReduceMatrix(p3,p1,m),AddMatrix(q1,q2,m),m);
		int[][] m7 = Strassen(ReduceMatrix(p2,p4,m),AddMatrix(q3,q4,m),m);
		
		//将结果矩阵分成同等大小的四块
		int[][] result1 = QuarterMatrix(result,n,1);
		int[][] result2 = QuarterMatrix(result,n,2);
		int[][] result3 = QuarterMatrix(result,n,3);
		int[][] result4 = QuarterMatrix(result,n,4);
	
		result1 = AddMatrix(ReduceMatrix(AddMatrix(m1,m4,m),m5,m),m7,m);
		result2 = AddMatrix(m3,m5,m);
		result3 = AddMatrix(m2,m4,m);
		result4 = AddMatrix(AddMatrix(ReduceMatrix(m1,m2,m),m3,m),m6,m);
		
		result = TogetherMatrix(result1,result2,result3,result4,m);//把分成的四个小矩阵合并成一个大矩阵
		
		return result;
	}
	
	
	
	
	//获取矩阵的四分之一,number用来确定返回哪一个四分之一
	public static int[][] QuarterMatrix(int[][] p,int n,int number){
		int rows = n/2;  //行数减半
		int cols = n/2;  //列数减半
		int[][] result = new int[rows][cols];
		switch(number){
		//左上
		case 1 :
		{
			for(int i=0;i<rows;i++)
				for(int j=0;j<cols;j++)
					result[i][j] = p[i][j];
			break;
		}
		//右上
		case 2 :
		{
			for(int i=0;i<rows;i++)
				for(int j=0;j<n-cols;j++)
					result[i][j] = p[i][j+cols];
			break;
		}
		//左下
		case 3 :
		{
			for(int i=0;i<n-rows;i++)
				for(int j=0;j<cols;j++)
					result[i][j] = p[i+rows][j];
			break;
		}
		//右下
		case 4 :
		{
			for(int i=0;i<n-rows;i++)
				for(int j=0;j<n-cols;j++)
					result[i][j] = p[i+rows][j+cols];
			break;
		}
		default:
			break;
		}
	
		return result;
	}

	//把均分为四分之一的矩阵,合成一个矩阵
	public static int[][] TogetherMatrix(int[][] a,int[][] b,int[][] c,int[][] d,int n){
		int[][] result = new int[2*n][2*n];
		for(int i=0;i<2*n;i++){
			for(int j=0;j<2*n;j++){
				if(i<n){
					if(j<n)
						result[i][j] = a[i][j];
					else
						result[i][j] = b[i][j-n];
				}else{
					if(j<n)
						result[i][j] = c[i-n][j];
					else
						result[i][j] = d[i-n][j-n];
				}
			}
		}
		return result;
	}


	//求两个矩阵相加结果
	public static int[][] AddMatrix(int[][] p,int[][] q,int n){
		int[][] result = new int[n][n];
		for(int i=0;i<n;i++){
			for(int j=0;j<n;j++){
				result[i][j] = p[i][j]+q[i][j];
			}
		}
		return result;
	}
	
	//求两个矩阵相减结果
	public static int[][] ReduceMatrix(int[][] p,int[][] q,int n){
		int[][] result = new int[n][n];
		for(int i=0;i<n;i++){
			for(int j=0;j<n;j++){
				result[i][j] = p[i][j]-q[i][j];
			}
		}
		return result;
	}
	
	//输出矩阵的函数
	public static void PrintfMatrix(int[][] matrix,int n){
		for(int i=0;i<n;i++){
			for(int j=0;j<n;j++)
				System.out.printf("% 4d",matrix[i][j]);
			System.out.println();
		}
	
	}

	public static void main(String args[]){
		int[][] p = initializationMatrix(8);
		int[][] q = initializationMatrix(8);
		//输出生成的两个矩阵
		System.out.println("p:");
		PrintfMatrix(p,8);
		System.out.println();
		System.out.println("q:");
		PrintfMatrix(q,8);
 
		//输出分治法矩阵相乘后的结果
		int[][] bru_result = BruteForce(p,q,8);//新建一个矩阵存放矩阵相乘结果
		System.out.println();
		System.out.println("\nA*B(蛮力法):");
		PrintfMatrix(bru_result,8);
		
		//输出分治法矩阵相乘后的结果
		int[][] dac_result = DivideAndConquer(p,q,8);//新建一个矩阵存放矩阵相乘结果
		System.out.println();
		System.out.println("A*B(分治法):");
		PrintfMatrix(dac_result,8);
		
		//输出strassen法矩阵相乘后的结果
		int[][] stra_result = Strassen(p,q,8);//新建一个矩阵存放矩阵相乘结果
		System.out.println("\nA*B(strassen法):");
		PrintfMatrix(stra_result,8);
		
	}
 
}

运行结果:
运行结果

发布了27 篇原创文章 · 获赞 19 · 访问量 4535

猜你喜欢

转载自blog.csdn.net/qq_43617268/article/details/104948499