Algorithm —— 矩阵乘法的Strassen算法(六)

Algorithm —— 矩阵乘法的Strassen算法


算法导论借助暴力求两NxN矩阵乘积的问题,引出了Strassen算法。下面的代码实现分别对应了书中暴力求解法、分治求解法和Strassen求解法的实现,具体如下文所示。

关于这部分内容的伪代码及说明,可以参看《算法导论》4.2章节。

根据矩阵的乘法知识,两个NxN的矩阵A和B相乘的结果矩阵C的暴力算法是:

	/**
	 * Multiplies two NxN matrices with the schoolbook triple loop.
	 *
	 * @param A
	 *            left operand; its row count N fixes the order of the result
	 * @param B
	 *            right operand, read as an NxN matrix
	 * @return a newly allocated NxN matrix holding the product of A and B
	 */
	public static int[][] squareMatrixMultiply(int[][] A, int[][] B) {
		final int n = A.length;
		final int[][] product = new int[n][n];

		for (int row = 0; row < n; row++) {
			final int[] left = A[row];
			final int[] out = product[row];
			for (int col = 0; col < n; col++) {
				// Dot product of row `row` of A with column `col` of B.
				int sum = 0;
				for (int k = 0; k < n; k++) {
					sum += left[k] * B[k][col];
				}
				out[col] = sum;
			}
		}
		return product;
	}
加入了分治思想的两个NxN的矩阵A和B得到相乘结果矩阵C的算法是:

/**
	 * Multiplies two NxN matrices with the plain divide-and-conquer scheme:
	 * each operand is split into four quadrants and the product is assembled
	 * from eight recursive half-size products.
	 *
	 * <p>An odd order greater than one is handled by embedding both operands in
	 * zero-padded (N+1)x(N+1) matrices, multiplying those, and cropping the
	 * result, so any N (including 0) is accepted.
	 *
	 * @param A
	 *            left operand; its row count N fixes the order of the result
	 * @param B
	 *            right operand, read as an NxN matrix
	 * @return a newly allocated NxN matrix holding the product of A and B
	 */
	public static int[][] martixMultiplyRecursive(int[][] A, int[][] B) {
		int rows = A.length;
		int[][] C = new int[rows][rows];
		if (rows == 0) {
			// Recursing on 0x0 quadrants would never reach the base case.
			return C;
		}
		if (rows == 1) {
			C[0][0] = A[0][0] * B[0][0];
			return C;
		}
		if (rows % 2 != 0) {
			// rows / 2 would drop the last row and column, so pad to an even order first.
			int paddedRows = rows + 1;
			int[][] paddedA = new int[paddedRows][paddedRows];
			int[][] paddedB = new int[paddedRows][paddedRows];
			copySubMatrixByParamFromSrcToDest(A, 0, rows, 0, rows, paddedA);
			copySubMatrixByParamFromSrcToDest(B, 0, rows, 0, rows, paddedB);
			int[][] paddedC = martixMultiplyRecursive(paddedA, paddedB);
			copyMatrixbyParamFromSrcToSubMatrix(paddedC, 0, rows, 0, rows, C);
			return C;
		}

		int half = rows / 2;

		int[][] A11 = new int[half][half];
		int[][] A12 = new int[half][half];
		int[][] A21 = new int[half][half];
		int[][] A22 = new int[half][half];

		copyMatrixbyParamFromSrcToSubMatrix(A, 0, half, 0, half, A11);
		copyMatrixbyParamFromSrcToSubMatrix(A, 0, half, half, half, A12);
		copyMatrixbyParamFromSrcToSubMatrix(A, half, half, 0, half, A21);
		copyMatrixbyParamFromSrcToSubMatrix(A, half, half, half, half, A22);

		int[][] B11 = new int[half][half];
		int[][] B12 = new int[half][half];
		int[][] B21 = new int[half][half];
		int[][] B22 = new int[half][half];

		copyMatrixbyParamFromSrcToSubMatrix(B, 0, half, 0, half, B11);
		copyMatrixbyParamFromSrcToSubMatrix(B, 0, half, half, half, B12);
		copyMatrixbyParamFromSrcToSubMatrix(B, half, half, 0, half, B21);
		copyMatrixbyParamFromSrcToSubMatrix(B, half, half, half, half, B22);

		int[][] C11 = new int[half][half];
		int[][] C12 = new int[half][half];
		int[][] C21 = new int[half][half];
		int[][] C22 = new int[half][half];

		// Cij = Ai1 x B1j + Ai2 x B2j, with every sub-product computed recursively.
		squareMatrixElementAdd(martixMultiplyRecursive(A11, B11), martixMultiplyRecursive(A12, B21), C11);
		squareMatrixElementAdd(martixMultiplyRecursive(A11, B12), martixMultiplyRecursive(A12, B22), C12);
		squareMatrixElementAdd(martixMultiplyRecursive(A21, B11), martixMultiplyRecursive(A22, B21), C21);
		squareMatrixElementAdd(martixMultiplyRecursive(A21, B12), martixMultiplyRecursive(A22, B22), C22);

		// Stitch the four quadrants back into the full result.
		copySubMatrixByParamFromSrcToDest(C11, 0, half, 0, half, C);
		copySubMatrixByParamFromSrcToDest(C12, 0, half, half, half, C);
		copySubMatrixByParamFromSrcToDest(C21, half, half, 0, half, C);
		copySubMatrixByParamFromSrcToDest(C22, half, half, half, half, C);

		return C;
	}

	/**
	 * Copies a lenI x lenJ window of {@code src}, whose top-left corner is at
	 * (startI, startJ), into the top-left corner of {@code dest}. Used to
	 * extract one quadrant of a larger matrix.
	 *
	 * @param src
	 *            matrix to read from
	 * @param startI
	 *            first source row of the window
	 * @param lenI
	 *            number of rows to copy
	 * @param startJ
	 *            first source column of the window
	 * @param lenJ
	 *            number of columns to copy
	 * @param dest
	 *            matrix that receives the window at rows 0..lenI-1, columns 0..lenJ-1
	 */
	public static void copyMatrixbyParamFromSrcToSubMatrix(int[][] src, int startI, int lenI, int startJ, int lenJ,
			int[][] dest) {
		if (lenJ <= 0) {
			// No columns requested: nothing is read or written.
			return;
		}
		for (int r = 0; r < lenI; r++) {
			final int[] target = dest[r];
			final int[] source = src[startI + r];
			for (int c = 0; c < lenJ; c++) {
				target[c] = source[startJ + c];
			}
		}
	}

	/**
	 * Writes the top-left lenI x lenJ part of {@code src} into {@code dest},
	 * placing it so that its top-left corner lands at (startI, startJ). Used to
	 * put one quadrant back into a larger matrix.
	 *
	 * @param src
	 *            matrix to read from, starting at row 0, column 0
	 * @param startI
	 *            first destination row
	 * @param lenI
	 *            number of rows to copy
	 * @param startJ
	 *            first destination column
	 * @param lenJ
	 *            number of columns to copy
	 * @param dest
	 *            matrix that receives the values
	 */
	public static void copySubMatrixByParamFromSrcToDest(int[][] src, int startI, int lenI, int startJ, int lenJ,
			int[][] dest) {
		if (lenJ <= 0) {
			// No columns requested: nothing is read or written.
			return;
		}
		for (int r = 0; r < lenI; r++) {
			final int[] target = dest[startI + r];
			final int[] source = src[r];
			for (int c = 0; c < lenJ; c++) {
				target[startJ + c] = source[c];
			}
		}
	}

	/**
	 * Adds two matrices element by element.
	 *
	 * @param srcA
	 *            first addend; its row and column counts drive the iteration
	 * @param srcB
	 *            second addend
	 * @param dest
	 *            receives srcA + srcB; may be the same array as either addend
	 */
	public static void squareMatrixElementAdd(int[][] srcA, int[][] srcB, int[][] dest) {
		final int rowCount = srcA.length;
		for (int r = 0; r < rowCount; r++) {
			final int width = srcA[r].length;
			for (int c = 0; c < width; c++) {
				dest[r][c] = srcA[r][c] + srcB[r][c];
			}
		}
	}

	/**
	 * Prints a matrix to standard output, one row per line, with every element
	 * followed by a single space.
	 *
	 * @param matrix
	 *            matrix to print
	 */
	public static void displaySquare(int matrix[][]) {
		for (int[] row : matrix) {
			for (int c = 0; c < row.length; c++) {
				System.out.print(row[c] + " ");
			}
			System.out.println();
		}
	}
使用Strassen算法求两方阵矩阵的积的算法实现代码是:

/**
	 * Multiplies two NxN matrices with Strassen's algorithm: each operand is
	 * split into four quadrants, ten quadrant sums/differences S1..S10 are
	 * formed, and only seven recursive half-size products P1..P7 are needed to
	 * assemble the result.
	 *
	 * <p>An odd order greater than one is handled by embedding both operands in
	 * zero-padded (N+1)x(N+1) matrices, multiplying those, and cropping the
	 * result, so any N (including 0) is accepted.
	 *
	 * @param A
	 *            left operand; its row count N fixes the order of the result
	 * @param B
	 *            right operand, read as an NxN matrix
	 * @return a newly allocated NxN matrix holding the product of A and B
	 */
	public static int[][] strassenMartixMultiplyRecursive(int[][] A, int[][] B) {
		int rows = A.length;
		int[][] C = new int[rows][rows];
		if (rows == 0) {
			// Recursing on 0x0 quadrants would never reach the base case.
			return C;
		}
		if (rows == 1) {
			C[0][0] = A[0][0] * B[0][0];
			return C;
		}
		if (rows % 2 != 0) {
			// rows / 2 would drop the last row and column, so pad to an even order first.
			int paddedRows = rows + 1;
			int[][] paddedA = new int[paddedRows][paddedRows];
			int[][] paddedB = new int[paddedRows][paddedRows];
			copySubMatrixByParamFromSrcToDest(A, 0, rows, 0, rows, paddedA);
			copySubMatrixByParamFromSrcToDest(B, 0, rows, 0, rows, paddedB);
			int[][] paddedC = strassenMartixMultiplyRecursive(paddedA, paddedB);
			copyMatrixbyParamFromSrcToSubMatrix(paddedC, 0, rows, 0, rows, C);
			return C;
		}

		int half = rows / 2;

		int[][] A11 = new int[half][half];
		int[][] A12 = new int[half][half];
		int[][] A21 = new int[half][half];
		int[][] A22 = new int[half][half];

		copyMatrixbyParamFromSrcToSubMatrix(A, 0, half, 0, half, A11);
		copyMatrixbyParamFromSrcToSubMatrix(A, 0, half, half, half, A12);
		copyMatrixbyParamFromSrcToSubMatrix(A, half, half, 0, half, A21);
		copyMatrixbyParamFromSrcToSubMatrix(A, half, half, half, half, A22);

		int[][] B11 = new int[half][half];
		int[][] B12 = new int[half][half];
		int[][] B21 = new int[half][half];
		int[][] B22 = new int[half][half];

		copyMatrixbyParamFromSrcToSubMatrix(B, 0, half, 0, half, B11);
		copyMatrixbyParamFromSrcToSubMatrix(B, 0, half, half, half, B12);
		copyMatrixbyParamFromSrcToSubMatrix(B, half, half, 0, half, B21);
		copyMatrixbyParamFromSrcToSubMatrix(B, half, half, half, half, B22);

		int[][] S1 = new int[half][half];
		int[][] S2 = new int[half][half];
		int[][] S3 = new int[half][half];
		int[][] S4 = new int[half][half];
		int[][] S5 = new int[half][half];
		int[][] S6 = new int[half][half];
		int[][] S7 = new int[half][half];
		int[][] S8 = new int[half][half];
		int[][] S9 = new int[half][half];
		int[][] S10 = new int[half][half];

		squareMatrixElementSub(B12, B22, S1);// S1 = B12 - B22
		squareMatrixElementAdd(A11, A12, S2);// S2 = A11 + A12
		squareMatrixElementAdd(A21, A22, S3);// S3 = A21 + A22
		squareMatrixElementSub(B21, B11, S4);// S4 = B21 - B11
		squareMatrixElementAdd(A11, A22, S5);// S5 = A11 + A22
		squareMatrixElementAdd(B11, B22, S6);// S6 = B11 + B22
		squareMatrixElementSub(A12, A22, S7);// S7 = A12 - A22
		squareMatrixElementAdd(B21, B22, S8);// S8 = B21 + B22
		squareMatrixElementSub(A11, A21, S9);// S9 = A11 - A21
		squareMatrixElementAdd(B11, B12, S10);// S10 = B11 + B12

		// The seven recursive products; each call returns a fresh matrix.
		int[][] P1 = strassenMartixMultiplyRecursive(A11, S1);// P1 = A11 X S1
		int[][] P2 = strassenMartixMultiplyRecursive(S2, B22);// P2 = S2 X B22
		int[][] P3 = strassenMartixMultiplyRecursive(S3, B11);// P3 = S3 X B11
		int[][] P4 = strassenMartixMultiplyRecursive(A22, S4);// P4 = A22 X S4
		int[][] P5 = strassenMartixMultiplyRecursive(S5, S6);// P5 = S5 X S6
		int[][] P6 = strassenMartixMultiplyRecursive(S7, S8);// P6 = S7 X S8
		int[][] P7 = strassenMartixMultiplyRecursive(S9, S10);// P7 = S9 X S10

		int[][] C11 = new int[half][half];
		int[][] C12 = new int[half][half];
		int[][] C21 = new int[half][half];
		int[][] C22 = new int[half][half];

		// C11 = P5 + P4 - P2 + P6 (accumulated in place; the helpers work element by element)
		squareMatrixElementAdd(P5, P4, C11);
		squareMatrixElementSub(C11, P2, C11);
		squareMatrixElementAdd(C11, P6, C11);

		// C12 = P1 + P2
		squareMatrixElementAdd(P1, P2, C12);

		// C21 = P3 + P4
		squareMatrixElementAdd(P3, P4, C21);

		// C22 = P5 + P1 - P3 - P7
		squareMatrixElementAdd(P5, P1, C22);
		squareMatrixElementSub(C22, P3, C22);
		squareMatrixElementSub(C22, P7, C22);

		// Stitch the four quadrants back into the full result.
		copySubMatrixByParamFromSrcToDest(C11, 0, half, 0, half, C);
		copySubMatrixByParamFromSrcToDest(C12, 0, half, half, half, C);
		copySubMatrixByParamFromSrcToDest(C21, half, half, 0, half, C);
		copySubMatrixByParamFromSrcToDest(C22, half, half, half, half, C);

		return C;
	}

	/**
	 * Copies a lenI x lenJ window of {@code src}, whose top-left corner is at
	 * (startI, startJ), into the top-left corner of {@code dest}. Used to
	 * extract one quadrant of a larger matrix.
	 *
	 * @param src
	 *            matrix to read from
	 * @param startI
	 *            first source row of the window
	 * @param lenI
	 *            number of rows to copy
	 * @param startJ
	 *            first source column of the window
	 * @param lenJ
	 *            number of columns to copy
	 * @param dest
	 *            matrix that receives the window at rows 0..lenI-1, columns 0..lenJ-1
	 */
	public static void copyMatrixbyParamFromSrcToSubMatrix(int[][] src, int startI, int lenI, int startJ, int lenJ,
			int[][] dest) {
		if (lenJ <= 0) {
			// No columns requested: nothing is read or written.
			return;
		}
		int r = 0;
		while (r < lenI) {
			final int[] target = dest[r];
			final int[] source = src[startI + r];
			for (int c = 0; c < lenJ; c++) {
				target[c] = source[startJ + c];
			}
			r++;
		}
	}

	/**
	 * Writes the top-left lenI x lenJ part of {@code src} into {@code dest},
	 * placing it so that its top-left corner lands at (startI, startJ). Used to
	 * put one quadrant back into a larger matrix.
	 *
	 * @param src
	 *            matrix to read from, starting at row 0, column 0
	 * @param startI
	 *            first destination row
	 * @param lenI
	 *            number of rows to copy
	 * @param startJ
	 *            first destination column
	 * @param lenJ
	 *            number of columns to copy
	 * @param dest
	 *            matrix that receives the values
	 */
	public static void copySubMatrixByParamFromSrcToDest(int[][] src, int startI, int lenI, int startJ, int lenJ,
			int[][] dest) {
		if (lenJ <= 0) {
			// No columns requested: nothing is read or written.
			return;
		}
		int r = 0;
		while (r < lenI) {
			final int[] target = dest[startI + r];
			final int[] source = src[r];
			for (int c = 0; c < lenJ; c++) {
				target[startJ + c] = source[c];
			}
			r++;
		}
	}

	/**
	 * Adds two matrices element by element.
	 *
	 * @param srcA
	 *            first addend; its row and column counts drive the iteration
	 * @param srcB
	 *            second addend
	 * @param dest
	 *            receives srcA + srcB; may be the same array as either addend
	 */
	public static void squareMatrixElementAdd(int[][] srcA, int[][] srcB, int[][] dest) {
		int r = 0;
		while (r < srcA.length) {
			int c = 0;
			while (c < srcA[r].length) {
				dest[r][c] = srcA[r][c] + srcB[r][c];
				c++;
			}
			r++;
		}
	}

	/**
	 * Subtracts one matrix from another element by element.
	 *
	 * @param srcA
	 *            minuend; its row and column counts drive the iteration
	 * @param srcB
	 *            subtrahend
	 * @param dest
	 *            receives srcA - srcB; may be the same array as either operand
	 */
	public static void squareMatrixElementSub(int[][] srcA, int[][] srcB, int[][] dest) {
		final int rowCount = srcA.length;
		for (int r = 0; r < rowCount; r++) {
			final int width = srcA[r].length;
			for (int c = 0; c < width; c++) {
				dest[r][c] = srcA[r][c] - srcB[r][c];
			}
		}
	}
最后本文中涉及到的完整测试如下:

public class StrassenAlgor {

	static int[][] A = { 
							{ 1, 2, 2, 1 }, 
							{ 1, 2, 2, 1 }, 
							{ 1, 2, 2, 1 }, 
							{ 1, 2, 2, 1 } 
					   };
	static int[][] B = { 
							{ 1, 2, 2, 1 }, 
							{ 1, 2, 2, 1 }, 
							{ 1, 2, 3, 1 }, 
							{ 1, 2, 2, 1 } 
					   };

	public static void main(String[] args) {

		System.out.println("使用暴力迭代形式的方阵矩阵求积");
		int[][] C = martixMultiplyRecursive(A, B);
		displaySquare(C);

		System.out.println("使用分治思想的普通形式的方阵矩阵求积");
		int[][] C1 = martixMultiplyRecursive(A, B);
		displaySquare(C1);

		System.out.println("Strassen 方阵求积");
		int[][] C2 = strassenMartixMultiplyRecursive(A, B);
		displaySquare(C2);

	}

	/**
	 * 一般的暴力矩阵乘法运算;矩阵A和B都是NxN的方阵
	 * 
	 * @param A
	 *            参加运算的矩阵之一A
	 * @param B
	 *            参加运算的矩阵之一B
	 * @return 矩阵A和B相乘得到的矩阵C
	 */
	public static int[][] squareMatrixMultiply(int[][] A, int[][] B) {

		int rows = A.length;
		int[][] C = new int[rows][rows];

		for (int i = 0; i < rows; i++) {
			for (int j = 0; j < rows; j++) {
				C[i][j] = 0;
				for (int k = 0; k < rows; k++) {
					C[i][j] = C[i][j] + A[i][k] * B[k][j];
				}
			}
		}
		return C;
	}

	/**
	 * 使用分治算法的NxN矩阵乘法运算
	 * 
	 * @param A
	 *            参加运算的矩阵之一A
	 * @param B
	 *            参加运算的矩阵之一B
	 * @return
	 */
	public static int[][] martixMultiplyRecursive(int[][] A, int[][] B) {
		int rows = A.length;
		int[][] C = new int[rows][rows];
		if (rows == 1) {
			C[0][0] = A[0][0] * B[0][0];
		} else {
			int[][] A11 = new int[rows / 2][rows / 2];
			int[][] A12 = new int[rows / 2][rows / 2];
			int[][] A21 = new int[rows / 2][rows / 2];
			int[][] A22 = new int[rows / 2][rows / 2];

			copyMatrixbyParamFromSrcToSubMatrix(A, 0, rows / 2, 0, rows / 2, A11);
			copyMatrixbyParamFromSrcToSubMatrix(A, 0, rows / 2, rows / 2, rows / 2, A12);
			copyMatrixbyParamFromSrcToSubMatrix(A, rows / 2, rows / 2, 0, rows / 2, A21);
			copyMatrixbyParamFromSrcToSubMatrix(A, rows / 2, rows / 2, rows / 2, rows / 2, A22);

			int[][] B11 = new int[rows / 2][rows / 2];
			int[][] B12 = new int[rows / 2][rows / 2];
			int[][] B21 = new int[rows / 2][rows / 2];
			int[][] B22 = new int[rows / 2][rows / 2];

			copyMatrixbyParamFromSrcToSubMatrix(B, 0, rows / 2, 0, rows / 2, B11);
			copyMatrixbyParamFromSrcToSubMatrix(B, 0, rows / 2, rows / 2, rows / 2, B12);
			copyMatrixbyParamFromSrcToSubMatrix(B, rows / 2, rows / 2, 0, rows / 2, B21);
			copyMatrixbyParamFromSrcToSubMatrix(B, rows / 2, rows / 2, rows / 2, rows / 2, B22);

			int[][] C11 = new int[rows / 2][rows / 2];
			int[][] C12 = new int[rows / 2][rows / 2];
			int[][] C21 = new int[rows / 2][rows / 2];
			int[][] C22 = new int[rows / 2][rows / 2];

			squareMatrixElementAdd(squareMatrixMultiply(A11, B11), squareMatrixMultiply(A12, B21), C11);
			squareMatrixElementAdd(squareMatrixMultiply(A11, B12), squareMatrixMultiply(A12, B22), C12);
			squareMatrixElementAdd(squareMatrixMultiply(A21, B11), squareMatrixMultiply(A22, B21), C21);
			squareMatrixElementAdd(squareMatrixMultiply(A21, B12), squareMatrixMultiply(A22, B22), C22);

			// 将C11/C12/C21/C22四个子矩阵合并为最终的结果C矩阵
			copySubMatrixByParamFromSrcToDest(C11, 0, rows / 2, 0, rows / 2, C);
			copySubMatrixByParamFromSrcToDest(C12, 0, rows / 2, rows / 2, rows / 2, C);
			copySubMatrixByParamFromSrcToDest(C21, rows / 2, rows / 2, 0, rows / 2, C);
			copySubMatrixByParamFromSrcToDest(C22, rows / 2, rows / 2, rows / 2, rows / 2, C);

		}
		return C;
	}

	/**
	 * Strassen算法的NxN矩阵乘法运算
	 * 
	 * @param A
	 *            参加运算的矩阵之一A
	 * @param B
	 *            参加运算的矩阵之一B
	 * @return
	 */
	public static int[][] strassenMartixMultiplyRecursive(int[][] A, int[][] B) {
		int rows = A.length;
		int[][] C = new int[rows][rows];
		if (rows == 1) {
			C[0][0] = A[0][0] * B[0][0];
		} else {
			int[][] A11 = new int[rows / 2][rows / 2];
			int[][] A12 = new int[rows / 2][rows / 2];
			int[][] A21 = new int[rows / 2][rows / 2];
			int[][] A22 = new int[rows / 2][rows / 2];

			copyMatrixbyParamFromSrcToSubMatrix(A, 0, rows / 2, 0, rows / 2, A11);
			copyMatrixbyParamFromSrcToSubMatrix(A, 0, rows / 2, rows / 2, rows / 2, A12);
			copyMatrixbyParamFromSrcToSubMatrix(A, rows / 2, rows / 2, 0, rows / 2, A21);
			copyMatrixbyParamFromSrcToSubMatrix(A, rows / 2, rows / 2, rows / 2, rows / 2, A22);

			int[][] B11 = new int[rows / 2][rows / 2];
			int[][] B12 = new int[rows / 2][rows / 2];
			int[][] B21 = new int[rows / 2][rows / 2];
			int[][] B22 = new int[rows / 2][rows / 2];

			copyMatrixbyParamFromSrcToSubMatrix(B, 0, rows / 2, 0, rows / 2, B11);
			copyMatrixbyParamFromSrcToSubMatrix(B, 0, rows / 2, rows / 2, rows / 2, B12);
			copyMatrixbyParamFromSrcToSubMatrix(B, rows / 2, rows / 2, 0, rows / 2, B21);
			copyMatrixbyParamFromSrcToSubMatrix(B, rows / 2, rows / 2, rows / 2, rows / 2, B22);

			int[][] S1 = new int[rows / 2][rows / 2];
			int[][] S2 = new int[rows / 2][rows / 2];
			int[][] S3 = new int[rows / 2][rows / 2];
			int[][] S4 = new int[rows / 2][rows / 2];
			int[][] S5 = new int[rows / 2][rows / 2];
			int[][] S6 = new int[rows / 2][rows / 2];
			int[][] S7 = new int[rows / 2][rows / 2];
			int[][] S8 = new int[rows / 2][rows / 2];
			int[][] S9 = new int[rows / 2][rows / 2];
			int[][] S10 = new int[rows / 2][rows / 2];

			squareMatrixElementSub(B12, B22, S1);// S1 = B12 - B22
			squareMatrixElementAdd(A11, A12, S2);// S2 = A11 + A12
			squareMatrixElementAdd(A21, A22, S3);// S3 = A21 + A22
			squareMatrixElementSub(B21, B11, S4);// S4 = B21 - B11
			squareMatrixElementAdd(A11, A22, S5);// S5 = A11 + A22
			squareMatrixElementAdd(B11, B22, S6);// S6 = B11 + B22
			squareMatrixElementSub(A12, A22, S7);// S7 = A12 - A22
			squareMatrixElementAdd(B21, B22, S8);// S8 = B21 + B22
			squareMatrixElementSub(A11, A21, S9);// S9 = A11 - A21
			squareMatrixElementAdd(B11, B12, S10);// S10 = B11 + B12

			int[][] P1 = new int[rows / 2][rows / 2];
			int[][] P2 = new int[rows / 2][rows / 2];
			int[][] P3 = new int[rows / 2][rows / 2];
			int[][] P4 = new int[rows / 2][rows / 2];
			int[][] P5 = new int[rows / 2][rows / 2];
			int[][] P6 = new int[rows / 2][rows / 2];
			int[][] P7 = new int[rows / 2][rows / 2];

			P1 = strassenMartixMultiplyRecursive(A11, S1); // P1 = A11 X S1
			P2 = strassenMartixMultiplyRecursive(S2, B22);// P2 = S2 X B22
			P3 = strassenMartixMultiplyRecursive(S3, B11);// P3 = S3 X B11
			P4 = strassenMartixMultiplyRecursive(A22, S4);// P4 = A22 X S4
			P5 = strassenMartixMultiplyRecursive(S5, S6);// P5 = S5 X S6
			P6 = strassenMartixMultiplyRecursive(S7, S8);// P6 = S7 X S8
			P7 = strassenMartixMultiplyRecursive(S9, S10);// P7 = S9 X S10

			int[][] C11 = new int[rows / 2][rows / 2];
			int[][] C12 = new int[rows / 2][rows / 2];
			int[][] C21 = new int[rows / 2][rows / 2];
			int[][] C22 = new int[rows / 2][rows / 2];

			int[][] temp = new int[rows / 2][rows / 2];

			// C11 = P5 + P4 - P2 + P6
			squareMatrixElementAdd(P5, P4, temp);
			squareMatrixElementSub(temp, P2, temp);
			squareMatrixElementAdd(temp, P6, C11);

			// C12 = P1 + P2
			squareMatrixElementAdd(P1, P2, C12);

			// C21 = P3 + P4
			squareMatrixElementAdd(P3, P4, C21);

			// C22 = P5 + P1 - P3 -P7
			squareMatrixElementAdd(P5, P1, temp);
			squareMatrixElementSub(temp, P3, temp);
			squareMatrixElementSub(temp, P7, C22);

			// 将C11/C12/C21/C22四个子矩阵合并为最终的结果C矩阵
			copySubMatrixByParamFromSrcToDest(C11, 0, rows / 2, 0, rows / 2, C);
			copySubMatrixByParamFromSrcToDest(C12, 0, rows / 2, rows / 2, rows / 2, C);
			copySubMatrixByParamFromSrcToDest(C21, rows / 2, rows / 2, 0, rows / 2, C);
			copySubMatrixByParamFromSrcToDest(C22, rows / 2, rows / 2, rows / 2, rows / 2, C);

		}
		return C;
	}

	/**
	 * 将一个NxN的大矩阵分解成4个N/2xN/2的子矩阵
	 * 
	 */
	public static void copyMatrixbyParamFromSrcToSubMatrix(int[][] src, int startI, int lenI, int startJ, int lenJ,
			int[][] dest) {
		for (int i = 0; i < lenI; i++)
			for (int j = 0; j < lenJ; j++) {
				dest[i][j] = src[startI + i][startJ + j];
			}
	}

	/**
	 * 将4个N/2xN/2的子矩阵合并成一个NxN的大矩阵
	 * 
	 */
	public static void copySubMatrixByParamFromSrcToDest(int[][] src, int startI, int lenI, int startJ, int lenJ,
			int[][] dest) {
		for (int i = 0; i < lenI; i++)
			for (int j = 0; j < lenJ; j++) {
				dest[startI + i][startJ + j] = src[i][j];
			}
	}

	/**
	 * NxN矩阵加法
	 * 
	 * @param srcA
	 *            加法源矩阵之一
	 * @param srcB
	 *            加法源矩阵之二
	 * @param dest
	 *            矩阵加法结果
	 */
	public static void squareMatrixElementAdd(int[][] srcA, int[][] srcB, int[][] dest) {
		for (int i = 0; i < srcA.length; i++)
			for (int j = 0; j < srcA[i].length; j++)
				dest[i][j] = srcA[i][j] + srcB[i][j];
	}

	/**
	 * NxN矩阵减法
	 * 
	 * @param srcA
	 *            减法源矩阵之一
	 * @param srcB
	 *            减法源矩阵之二
	 * @param dest
	 *            矩阵减法结果
	 */
	public static void squareMatrixElementSub(int[][] srcA, int[][] srcB, int[][] dest) {
		for (int i = 0; i < srcA.length; i++)
			for (int j = 0; j < srcA[i].length; j++)
				dest[i][j] = srcA[i][j] - srcB[i][j];
	}

	/**
	 * 打印NxN矩阵
	 * 
	 */
	public static void displaySquare(int[][] matrix) {
		for (int i = 0; i < matrix.length; i++) {
			for (int j : matrix[i]) {
				System.out.print(j + " ");
			}
			System.out.println();
		}
	}
}
输出如下:

使用暴力迭代形式的方阵矩阵求积
6 12 14 6 
6 12 14 6 
6 12 14 6 
6 12 14 6 
使用分治思想的普通形式的方阵矩阵求积
6 12 14 6 
6 12 14 6 
6 12 14 6 
6 12 14 6 
Strassen 方阵求积
6 12 14 6 
6 12 14 6 
6 12 14 6 
6 12 14 6 


PS:

最后一直在想为什么Strassen算法会降低方阵求积的复杂度;最后在知乎上看到一个答案,感觉解释的挺好;现贴出来,一起分享下:

"strassen算法的关键不在于是乘法还是加法,而是在于算法内部递归调用的次数。strassen算法的关键在于内部递归调用的次数减少了1(从普通的8次变为特殊的7次)。这里的一个结论就是递归算法中递归调用次数少,时间复杂度低。这很容易理解,在算法导论中用了“茂盛”度来描述这一时间复杂度在递归算法中的变化。所以strassen算法的关键在于,递归调用的次数怎么从8次减少一次的。反推理解一下,这说明8次递归调用中有一次是冗余的,即第8次递归乘法的结果信息已经包含在了前7次的结果里,前7次的计算结果通过线性组合就能得到第8个递归的结果了。而该线性组合的时间复杂度低于该算法本身(即一次递归调用)的时间复杂度。做一个结论。但凡是能够优化时间复杂度的算法,高复杂度的算法中必然是有一些计算是冗余的,如能用更少的计算代替冗余,就能提高效率。(因为算法递归的刚好是乘法,所以此处看起来似乎是重点放在了乘法上)

至于为什么传统矩阵相乘算法中有冗余计算,也尝试分析一下:

冗余的根本原因应该在于基本的乘法分配律a*(b+c)=a*b+a*c。同样的计算结果,前一种(等号前)方法计算需要2次基本运算,而后一种(等号后)方法需要3次。(假设乘法运算和加法运算是同等开销的基本运算)。而一般的矩阵乘法算法中是大量的单步乘法运算后求和,即采用的是上述等号右边的计算式。如果能有一种方法,将乘法运算中的相同因子提到前边来,运用上述乘法分配律转换计算形式,那么就能提高计算效率。这应该就是strassen算法的本质。看strassen算法的过程,就是先将一部分子矩阵进行加(减)运算,再进行乘法运算。其实就是构造了上述分配律的左式

作者:知乎用户
链接:https://www.zhihu.com/question/28558331/answer/146497271
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。"

猜你喜欢

转载自blog.csdn.net/csdn_of_coder/article/details/80058361