本文参考自: 原文地址
4.2.1 矩阵乘法的暴力解法
#暴力解法
def matrix_multiply(a, b):
    """Multiply two matrices with the schoolbook row-by-column method.

    Args:
        a: Left operand as a list of rows (n x m).
        b: Right operand as a list of rows (m x p).

    Returns:
        A new n x p list of lists holding the product ``a * b``.
        An empty ``a`` yields an empty list.

    Raises:
        ValueError: If a row of ``a`` does not have exactly ``len(b)`` entries.
    """
    if not a:
        return []
    inner = len(b)
    if any(len(row) != inner for row in a):
        raise ValueError("number of columns of a must equal number of rows of b")
    # Transpose b once so that each output entry is a plain dot product
    # between a row of a and a column of b.
    columns = list(zip(*b))
    return [
        [sum(x * y for x, y in zip(row, col)) for col in columns]
        for row in a
    ]
这里假定a和b都是n阶方阵。暴力解法使用三层嵌套的for循环,每层循环n次,因此总共需要花费Θ(n^3)时间。
4.2.2 矩阵乘法的简单分治法
算法策略:(前提:假定A、B都是n阶方阵,且n为2的幂)
(1)基本思路:计算C=A*B时,将C,A,B矩阵进行分块操作,对每个分块的矩阵进行乘法运算,运算完毕后重新对得到的C11,C12,C21,C22进行组合操作。
(2)确定递归终止条件:当分块矩阵得到的阶数为1 时,得到的C即是A和B中两个元素的乘积。
def division(a):
    """Split a square matrix into its four quadrants.

    Args:
        a: Matrix as a list of rows. With ``half = len(a) // 2``, only the
            leading ``2 * half`` rows and columns are used, so for an odd
            order the last row and column are dropped.

    Returns:
        Tuple ``(a11, a12, a21, a22)`` of new ``half x half`` lists:
        top-left, top-right, bottom-left and bottom-right quadrant.
    """
    half = len(a) // 2
    upper = a[:half]
    lower = a[half:2 * half]
    a11 = [list(row[:half]) for row in upper]
    a12 = [list(row[half:2 * half]) for row in upper]
    a21 = [list(row[:half]) for row in lower]
    a22 = [list(row[half:2 * half]) for row in lower]
    return (a11, a12, a21, a22)
def matrix_combination(a11, a12, a21, a22):
    """Assemble four equally sized quadrants into one matrix.

    Args:
        a11: Top-left quadrant as a list of rows.
        a12: Top-right quadrant.
        a21: Bottom-left quadrant.
        a22: Bottom-right quadrant.

    Returns:
        A new list of rows in which each row of the upper half is a row of
        ``a11`` followed by the matching row of ``a12``, and each row of the
        lower half is a row of ``a21`` followed by the matching row of ``a22``.
    """
    # Concatenating rows keeps the column offset implicit, which avoids the
    # hard-coded "j - 2" offset that only worked for quadrants of order 1 or 2.
    top = [list(left) + list(right) for left, right in zip(a11, a12)]
    bottom = [list(left) + list(right) for left, right in zip(a21, a22)]
    return top + bottom
def matrix_add(a, b):
    """Return the element-wise sum of two matrices.

    Args:
        a: First operand as a list of rows.
        b: Second operand as a list of rows, expected to have the same shape.

    Returns:
        A new list of rows where each entry is ``a[i][j] + b[i][j]``.
        Rows and columns are paired positionally, so rectangular operands
        work as well as square ones.
    """
    return [
        [x + y for x, y in zip(row_a, row_b)]
        for row_a, row_b in zip(a, b)
    ]
def matrix_devision_multiply(a, b):
    """Multiply two square matrices with the simple divide-and-conquer scheme.

    Both operands are split into quadrants, the eight quadrant products are
    computed recursively, and the four result quadrants are summed and
    reassembled.

    Args:
        a: Left operand, a square matrix whose order is a power of two.
        b: Right operand with the same order as ``a``.

    Returns:
        A new list of rows holding ``a * b``; an empty list for empty input.

    Raises:
        ValueError: If the order of ``a`` is not a power of two.
    """
    n = len(a)
    if n == 0:
        # Guard: splitting an empty matrix yields empty quadrants again,
        # so recursing would never terminate.
        return []
    if n == 1:
        return [[a[0][0] * b[0][0]]]
    if n & (n - 1):
        # division() drops the last row/column of odd-order matrices, which
        # would silently produce a truncated result.
        raise ValueError("matrix order must be a power of two")

    # division() returns the quadrants as (x11, x12, x21, x22); the unpacking
    # order must match, otherwise the off-diagonal quadrants are swapped.
    a11, a12, a21, a22 = division(a)
    b11, b12, b21, b22 = division(b)

    c11 = matrix_add(matrix_devision_multiply(a11, b11),
                     matrix_devision_multiply(a12, b21))
    c12 = matrix_add(matrix_devision_multiply(a11, b12),
                     matrix_devision_multiply(a12, b22))
    c21 = matrix_add(matrix_devision_multiply(a21, b11),
                     matrix_devision_multiply(a22, b21))
    c22 = matrix_add(matrix_devision_multiply(a21, b12),
                     matrix_devision_multiply(a22, b22))
    return matrix_combination(c11, c12, c21, c22)
# Demo: square a 4 x 4 sample matrix with the divide-and-conquer routine.
a = [
    [1, 1, 1, 1],
    [1, 1, 1, 1],
    [2, 2, 2, 2],
    [2, 2, 2, 2],
]
b = a  # both operands refer to the same matrix object
print(matrix_devision_multiply(a, b))
4.2.3矩阵的Strassen算法
在简单分治法的思想上,为进一步减少递归树的分枝,在递归函数中只进行7次而不是8次的矩阵的乘法,而减少一次乘法的代价是增加额外的几次矩阵加法运算。
def matrix_strassen(a, b):
    """Multiply two square matrices with Strassen's algorithm.

    Each level performs seven recursive multiplications on quadrant sums and
    differences instead of the eight used by the plain divide-and-conquer
    scheme.

    Args:
        a: Left operand, a square matrix whose order is a power of two.
        b: Right operand with the same order as ``a``.

    Returns:
        A new list of rows holding ``a * b``; an empty list for empty input.

    Raises:
        ValueError: If the order of ``a`` is not a power of two.
    """
    n = len(a)
    if n == 0:
        # Guard: splitting an empty matrix yields empty quadrants again,
        # so recursing would never terminate.
        return []
    if n == 1:
        return [[a[0][0] * b[0][0]]]
    if n & (n - 1):
        # division() drops the last row/column of odd-order matrices, which
        # would silently produce a truncated result.
        raise ValueError("matrix order must be a power of two")

    a11, a12, a21, a22 = division(a)
    b11, b12, b21, b22 = division(b)

    # matrix_add_sub(x, y, 1) computes x + y, matrix_add_sub(x, y, 0) computes x - y.
    s1 = matrix_add_sub(b12, b22, 0)
    s2 = matrix_add_sub(a11, a12, 1)
    s3 = matrix_add_sub(a21, a22, 1)
    s4 = matrix_add_sub(b21, b11, 0)
    s5 = matrix_add_sub(a11, a22, 1)
    s6 = matrix_add_sub(b11, b22, 1)
    s7 = matrix_add_sub(a12, a22, 0)
    s8 = matrix_add_sub(b21, b22, 1)
    s9 = matrix_add_sub(a11, a21, 0)
    s10 = matrix_add_sub(b11, b12, 1)

    # The seven recursive products.
    p1 = matrix_strassen(a11, s1)
    p2 = matrix_strassen(s2, b22)
    p3 = matrix_strassen(s3, b11)
    p4 = matrix_strassen(a22, s4)
    p5 = matrix_strassen(s5, s6)
    p6 = matrix_strassen(s7, s8)
    p7 = matrix_strassen(s9, s10)

    # c11 = p5 + p4 - p2 + p6
    c11 = matrix_add_sub(matrix_add_sub(matrix_add_sub(p5, p4, 1), p2, 0), p6, 1)
    # c12 = p1 + p2
    c12 = matrix_add_sub(p1, p2, 1)
    # c21 = p3 + p4
    c21 = matrix_add_sub(p3, p4, 1)
    # c22 = p5 + p1 - p3 - p7
    c22 = matrix_add_sub(matrix_add_sub(matrix_add_sub(p5, p1, 1), p3, 0), p7, 0)
    return matrix_combination(c11, c12, c21, c22)
#矩阵的Strassen算法
def division(a):
    """Split a square matrix into its four quadrants.

    Args:
        a: Matrix as a list of rows. With ``half = len(a) // 2``, only the
            leading ``2 * half`` rows and columns are used, so for an odd
            order the last row and column are dropped.

    Returns:
        Tuple ``(a11, a12, a21, a22)`` of new ``half x half`` lists:
        top-left, top-right, bottom-left and bottom-right quadrant.
    """
    half = len(a) // 2
    upper = a[:half]
    lower = a[half:2 * half]
    a11 = [list(row[:half]) for row in upper]
    a12 = [list(row[half:2 * half]) for row in upper]
    a21 = [list(row[:half]) for row in lower]
    a22 = [list(row[half:2 * half]) for row in lower]
    return (a11, a12, a21, a22)
def matrix_add_sub(a, b, keys):
    """Return the element-wise sum or difference of two matrices.

    Args:
        a: First operand as a list of rows.
        b: Second operand as a list of rows, expected to have the same shape.
        keys: Mode flag. ``1`` selects addition (``a + b``); any other value
            selects subtraction (``a - b``).

    Returns:
        A new list of rows with the combined entries. Rows and columns are
        paired positionally, so rectangular operands work as well as square
        ones.
    """
    if keys == 1:
        return [
            [x + y for x, y in zip(row_a, row_b)]
            for row_a, row_b in zip(a, b)
        ]
    return [
        [x - y for x, y in zip(row_a, row_b)]
        for row_a, row_b in zip(a, b)
    ]
def matrix_combination(a11, a12, a21, a22):
    """Assemble four equally sized quadrants into one matrix.

    Args:
        a11: Top-left quadrant as a list of rows.
        a12: Top-right quadrant.
        a21: Bottom-left quadrant.
        a22: Bottom-right quadrant.

    Returns:
        A new list of rows in which each row of the upper half is a row of
        ``a11`` followed by the matching row of ``a12``, and each row of the
        lower half is a row of ``a21`` followed by the matching row of ``a22``.
    """
    # Row concatenation replaces the per-element four-way quadrant test.
    top = [list(left) + list(right) for left, right in zip(a11, a12)]
    bottom = [list(left) + list(right) for left, right in zip(a21, a22)]
    return top + bottom
# Demo: square a 4 x 4 sample matrix with Strassen's algorithm.
a = [
    [1, 1, 1, 1],
    [1, 1, 1, 1],
    [2, 2, 2, 2],
    [2, 2, 2, 2],
]
b = a  # both operands refer to the same matrix object
print(matrix_strassen(a, b))
4.2.4 修改Strassen算法,使之适应矩阵规模n不是2的幂的情况。
具体思路是将阶数不是2的幂的矩阵扩展成阶数为2的幂的矩阵,在多出的行和列上补0;计算结果重新组合成c后,再舍去c矩阵多出的行和列。因此在Strassen程序的基础上增加了matrix_expand和matrix_shrink函数。在主程序中,首先对输入矩阵A、B的阶数进行判断:如果是2的幂,则不用进行任何额外操作,直接使用普通的Strassen算法;如果不是2的幂,则先对A、B进行矩阵扩展,得到计算结果后再进行矩阵收缩。
#coding UTF-8
#矩阵的strassen算法
from math import *
def matrix_strassen(a, b):
    """Multiply two square matrices with Strassen's algorithm.

    Each level performs seven recursive multiplications on quadrant sums and
    differences instead of the eight used by the plain divide-and-conquer
    scheme.

    Args:
        a: Left operand, a square matrix whose order is a power of two.
        b: Right operand with the same order as ``a``.

    Returns:
        A new list of rows holding ``a * b``; an empty list for empty input.

    Raises:
        ValueError: If the order of ``a`` is not a power of two.
    """
    n = len(a)
    if n == 0:
        # Guard: splitting an empty matrix yields empty quadrants again,
        # so recursing would never terminate.
        return []
    if n == 1:
        return [[a[0][0] * b[0][0]]]
    if n & (n - 1):
        # division() drops the last row/column of odd-order matrices, which
        # would silently produce a truncated result.
        raise ValueError("matrix order must be a power of two")

    a11, a12, a21, a22 = division(a)
    b11, b12, b21, b22 = division(b)

    # matrix_add_sub(x, y, 1) computes x + y, matrix_add_sub(x, y, 0) computes x - y.
    s1 = matrix_add_sub(b12, b22, 0)
    s2 = matrix_add_sub(a11, a12, 1)
    s3 = matrix_add_sub(a21, a22, 1)
    s4 = matrix_add_sub(b21, b11, 0)
    s5 = matrix_add_sub(a11, a22, 1)
    s6 = matrix_add_sub(b11, b22, 1)
    s7 = matrix_add_sub(a12, a22, 0)
    s8 = matrix_add_sub(b21, b22, 1)
    s9 = matrix_add_sub(a11, a21, 0)
    s10 = matrix_add_sub(b11, b12, 1)

    # The seven recursive products.
    p1 = matrix_strassen(a11, s1)
    p2 = matrix_strassen(s2, b22)
    p3 = matrix_strassen(s3, b11)
    p4 = matrix_strassen(a22, s4)
    p5 = matrix_strassen(s5, s6)
    p6 = matrix_strassen(s7, s8)
    p7 = matrix_strassen(s9, s10)

    # c11 = p5 + p4 - p2 + p6
    c11 = matrix_add_sub(matrix_add_sub(matrix_add_sub(p5, p4, 1), p2, 0), p6, 1)
    # c12 = p1 + p2
    c12 = matrix_add_sub(p1, p2, 1)
    # c21 = p3 + p4
    c21 = matrix_add_sub(p3, p4, 1)
    # c22 = p5 + p1 - p3 - p7
    c22 = matrix_add_sub(matrix_add_sub(matrix_add_sub(p5, p1, 1), p3, 0), p7, 0)
    return matrix_combination(c11, c12, c21, c22)
def matrix_expand(a):
    """Zero-pad a square matrix up to the next power-of-two order.

    Args:
        a: Square matrix as a list of rows.

    Returns:
        A new ``p x p`` list of rows, where ``p`` is the smallest power of
        two that is at least ``len(a)``. The top-left ``len(a) x len(a)``
        corner holds the entries of ``a`` and every other entry is 0.
        An empty input yields an empty list.
    """
    n = len(a)
    if n == 0:
        return []
    # Integer arithmetic gives the exact next power of two; the float form
    # ceil(log(n, 2)) can round up one step too far for large exact powers.
    p = 1 << (n - 1).bit_length()
    pad = p - n
    expanded = [list(row[:n]) + [0] * pad for row in a]
    expanded.extend([0] * p for _ in range(pad))
    return expanded
def matrix_shrink(a, b):
    """Cut a matrix down to the order of a reference matrix.

    Args:
        a: Matrix to shrink, as a list of rows.
        b: Reference matrix; only ``len(b)`` is used, as the target order.

    Returns:
        A new list holding the top-left ``len(b) x len(b)`` corner of ``a``.
    """
    n = len(b)
    return [list(row[:n]) for row in a[:n]]
def division(a):
    """Split a square matrix into its four quadrants.

    Args:
        a: Matrix as a list of rows. With ``half = len(a) // 2``, only the
            leading ``2 * half`` rows and columns are used, so for an odd
            order the last row and column are dropped.

    Returns:
        Tuple ``(a11, a12, a21, a22)`` of new ``half x half`` lists:
        top-left, top-right, bottom-left and bottom-right quadrant.
    """
    half = len(a) // 2
    upper = a[:half]
    lower = a[half:2 * half]
    a11 = [list(row[:half]) for row in upper]
    a12 = [list(row[half:2 * half]) for row in upper]
    a21 = [list(row[:half]) for row in lower]
    a22 = [list(row[half:2 * half]) for row in lower]
    return (a11, a12, a21, a22)
def matrix_add_sub(a, b, keys):
    """Return the element-wise sum or difference of two matrices.

    Args:
        a: First operand as a list of rows.
        b: Second operand as a list of rows, expected to have the same shape.
        keys: Mode flag. ``1`` selects addition (``a + b``); any other value
            selects subtraction (``a - b``).

    Returns:
        A new list of rows with the combined entries. Rows and columns are
        paired positionally, so rectangular operands work as well as square
        ones.
    """
    if keys == 1:
        return [
            [x + y for x, y in zip(row_a, row_b)]
            for row_a, row_b in zip(a, b)
        ]
    return [
        [x - y for x, y in zip(row_a, row_b)]
        for row_a, row_b in zip(a, b)
    ]
def matrix_combination(a11, a12, a21, a22):
    """Assemble four equally sized quadrants into one matrix.

    Args:
        a11: Top-left quadrant as a list of rows.
        a12: Top-right quadrant.
        a21: Bottom-left quadrant.
        a22: Bottom-right quadrant.

    Returns:
        A new list of rows in which each row of the upper half is a row of
        ``a11`` followed by the matching row of ``a12``, and each row of the
        lower half is a row of ``a21`` followed by the matching row of ``a22``.
    """
    # Row concatenation replaces the per-element four-way quadrant test.
    top = [list(left) + list(right) for left, right in zip(a11, a12)]
    bottom = [list(left) + list(right) for left, right in zip(a21, a22)]
    return top + bottom
# Demo: square a 5 x 5 sample matrix, padding it to a power-of-two order
# when necessary and trimming the product back afterwards.
a = [
    [1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1],
    [2, 2, 2, 2, 2],
    [2, 2, 2, 2, 2],
    [3, 3, 3, 3, 3],
]
b = a  # both operands refer to the same matrix object
n = len(a)
# Exact integer test for "n is a power of two"; comparing log(n, 2) with its
# floor depends on floating-point rounding.
if (n & (n - 1)) == 0:
    print(matrix_strassen(a, b))
else:
    print(matrix_shrink(matrix_strassen(matrix_expand(a), matrix_expand(b)), a))
此种算法是自己能想到的最简单的思路,但是增加了计算量,例如5阶的方阵会转化为8阶的方阵进行计算,增加了不必要的繁琐的0和0的乘法。自己现在也没在网上看相关的资料,能力有限,望诸位看官海涵。
参考文献:
1.算法导论 机械工业出版社 第四章第二节 矩阵的Strassen算法。