算法导论 · 分治法 · strassen矩阵乘法

  • 算法说明
    strassen矩阵乘法是将每个矩阵分成4块,分别按照公式计算,其中再遇到矩阵乘法时,递归调用。公式如下:
    在这里插入图片描述
  • 源代码
#include <cstdio>
const int n = 2;

void strassen(int **a, int **b, int **c, int m) {
	if(m == 1) {
		 c[0][0] = a[0][0] * b[0][0];
		 return;
	}
	//申请变量空间 
	int **a1, **a2, **a3, **a4, **b1, **b2, **b3, **b4, **c1, **c2, **c3, **c4;
	int **m1, **m2, **m3, **m4, **m5, **m6, **m7, mid = n / 2;
	a1 = new int * [mid], a2 = new int * [mid], a3 = new int * [mid], a4 = new int * [mid];
	b1 = new int * [mid], b2 = new int * [mid], b3 = new int * [mid], b4 = new int * [mid];
	c1 = new int * [mid], c2 = new int * [mid], c3 = new int * [mid], c4 = new int * [mid];
	m1 = new int * [mid], m2 = new int * [mid], m3 = new int * [mid], m4 = new int * [mid], m5 = new int * [mid], m6 = new int * [mid], m7 = new int * [mid];
	for(int i = 0; i < n; i++) {
		a1[i] = new int[mid], a2[i] = new int[mid], a3[i] = new int[mid], a4[i] = new int[mid];
		b1[i] = new int[mid], b2[i] = new int[mid], b3[i] = new int[mid], b4[i] = new int[mid];
		c1[i] = new int[mid], c2[i] = new int[mid], c3[i] = new int[mid], c4[i] = new int[mid];
		m1[i] = new int[mid], m2[i] = new int[mid], m3[i] = new int[mid], m4[i] = new int[mid], m5[i] = new int[mid], m6[i] = new int[mid], m7[i] = new int[mid];
	}
	
	//开始计算
	for(int i = 0; i < mid; i++) { //划分矩阵 
		for(int j = 0; j < mid; j++) {
			a1[i][j] = a[i][j];
			a2[i][j] = a[i + mid][j];
			a3[i][j] = a[i][j + mid];
			a4[i][j] = a[i + mid][j + mid];
			b1[i][j] = b[i][j];
			b2[i][j] = b[i + mid][j];
			b3[i][j] = b[i][j + mid];
			b4[i][j] = b[i + mid][j + mid];
		}
	}
	
	for(int i = 0; i < mid; i++) {  
		for(int j = 0; j < mid; j++) { //临时变量 
			m1[i][j] = a1[i][j] + a4[i][j]; //计算m1用的 
			m2[i][j] = b1[i][j] + b4[i][j];
			
			m3[i][j] = a3[i][j] + a4[i][j];//计算m2用的
			
			m4[i][j] = b2[i][j] - b4[i][j];//计算m3用的
			
			m5[i][j] = b3[i][j] - b1[i][j];//计算m4用的
			
			m6[i][j] = a1[i][j] + a2[i][j];//计算m5用的
			
			m7[i][j] = a3[i][j] - a1[i][j];//计算m6用的
			c1[i][j] = b1[i][j] + b2[i][j];
			
			c2[i][j] = a2[i][j] - a4[i][j];//计算m7用的
			c3[i][j] = b3[i][j] + b4[i][j];
		}
	}
//	printf("\n%d ", m1[0][0]);
//	printf("%d\n ", m2[0][0]);
	strassen(m1, m2, m1, mid); //计算m1
	strassen(m3, b1, m2, mid); //计算m2
	strassen(a1, m4, m3, mid); //计算m3
	strassen(a4, m5, m4, mid); //计算m4
	strassen(m6, b4, m5, mid); //计算m5
	strassen(m7, c1, m6, mid); //计算m6
	strassen(c2, c3, m7, mid); //计算m7
//	printf("\n%d ", m1[0][0]);
//	printf("%d ", m2[0][0]);
//	printf("%d ", m3[0][0]); 
//	printf("%d ", m4[0][0]);
//	printf("%d ", m5[0][0]);
//	printf("%d ", m6[0][0]);
//	printf("%d \n", m7[0][0]);
	for(int i = 0; i < mid; i++) { //计算c的分量矩阵 
		for(int j = 0; j < mid; j++) { 
			c1[i][j] = m1[i][j] + m4[i][j] - m5[i][j] + m7[i][j]; 
			c2[i][j] = m3[i][j] + m5[i][j];
			c3[i][j] = m2[i][j] + m4[i][j];
			c4[i][j] = m1[i][j] + m3[i][j] - m2[i][j] + m6[i][j];
		}
	}
	
	for(int i = 0; i < mid; i++) { //合并矩阵 
		for(int j = 0; j < mid; j++) {
			c[i][j] = c1[i][j];
			c[i + mid][j] = c2[i][j];
			c[i][j + mid] = c3[i][j];
			c[i + mid][j + mid] = c4[i][j];
		}
	}
	delete a1, a2, a3, a4, b1, b2, b3, b4, c1, c2, c3, c4;
	delete m1, m2, m3, m4, m5, m6, m7; 
}
int main() {
	int **a, **b, **c, tempNum = 1;
	a = new int * [n], b = new int * [n], c = new int * [n];
	for(int i = 0; i < n; i++) {
		a[i] = new int[n];
		b[i] = new int[n];
		c[i] = new int[n];
	}
	for(int i = 0; i < n; i++) { //初始化a[n][n] = {{1, 2}, {3, 4}}, b[n][n] = {{5, 6}, {7, 8}}
		for(int j = 0; j < n; j++) {
			a[i][j] = tempNum;
			b[i][j] = tempNum++ + 4;
		}
	} 
	for(int i = 0; i < n; i++) {
		for(int j = 0; j < n; j++) {
			printf("%d ", a[i][j]);
		}
		printf("\n");
	}
	for(int i = 0; i < n; i++) {
		for(int j = 0; j < n; j++) {
			printf("%d ", b[i][j]);
		}
		printf("\n");
	}
	strassen(a, b, c, n);
	
	//output
	for(int i = 0; i < n; i++) {
		for(int j = 0; j < n; j++) {
			printf("%d ", c[i][j]);
		}
		printf("\n");
	} 
	delete a, b, c;
	return 0;
} 
  • 运行结果
    在这里插入图片描述
发布了77 篇原创文章 · 获赞 40 · 访问量 2万+

猜你喜欢

转载自blog.csdn.net/y_dd6011/article/details/97429095