算法导论——python实践(4.2矩阵乘法的Strassen算法)

本文参考自: 原文地址

4.2.1 矩阵乘法的暴力解法

#暴力解法
def matrix_multiply(a,b):
    n=len(a)
    c=[[0]*n for i in range(n)]#快速创建n阶初始化方阵
    for i in range (0,n):
        for j in range(0,n):
            c[i][j]=0
            for k in range(0,n):
                c[i][j]+=a[i][k]*b[k][j]
    return c

这里假定a和b都是方阵,如果选择暴力解法,三个for循环,循环次数为n,总共需要花费θ(n^3)时间。

4.2.2 矩阵乘法的简单分治法

算法策略:(前提:假定A,B都是n等于2的次幂的方阵)

(1)基本思路:计算C=A*B时,将C,A,B矩阵进行分块操作,对每个分块的矩阵进行乘法运算,运算完毕后重新对得到的C11,C12,C21,C22进行组合操作。

(2)确定递归终止条件:当分块矩阵得到的阶数为1 时,得到的C即是A和B中两个元素的乘积。

def division(a):    #矩阵分块函数
    n=len(a)//2
    a11=[[0 for i in range(n)]for j in range(n)]
    a12=[[0 for i in range(n)]for j in range(n)]
    a21=[[0 for i in range(n)]for j in range(n)]
    a22=[[0 for i in range(n)]for j in range(n)]
    for i in range(n):
        for j in range(n):
            a11[i][j]=a[i][j]
            a12[i][j]=a[i][j+n]
            a21[i][j]=a[i+n][j]
            a22[i][j]=a[i+n][j+n]
    return (a11,a12,a21,a22)

def matrix_combination(a11,a12,a21,a22):
    n2 = len(a11)
    n=n2*2
    a = [[0 for col in range(n)] for row in range(n)]
    for i in range (0,n):
        for j in range (0,n):
            if i <= (n2-1) and j <= (n2-1):
                a[i][j] = a11[i][j]
            elif i <= (n2-1) and j > (n2-1):
                a[i][j] = a12[i][j-2]
            elif i > (n2-1) and j <= (n2-1):
                a[i][j] = a21[i-n2][j]
            else:
                a[i][j] = a22[i-n2][j-n2]
    return a
def matrix_add(a,b):  #矩阵相加函数
    n = len(a)
    c = [[0 for col in range(n)] for row in range(n)]
    for i in range(0,n):
        for j in range(0,n):
            c[i][j] = a[i][j]+b[i][j]
    return c

def matrix_devision_multiply(a,b):   #矩阵乘法的简单分治法主程序
    n=len(a)
    c = [[0 for col in range(n)] for row in range(n)]#c=[[0]*n for i in range(n)]
    if n==1:
        c[0][0]=a[0][0]*b[0][0]
    else:
        (a11,a21,a12,a22)=division(a)
        (b11,b21,b12,b22)=division(b)
        (c11,c21,c12,c22)=division(c)
        c11=matrix_add(matrix_devision_multiply(a11,b11),matrix_devision_multiply(a12,b21))
        c12=matrix_add(matrix_devision_multiply(a11,b12),matrix_devision_multiply(a12,b22))
        c21=matrix_add(matrix_devision_multiply(a21,b11),matrix_devision_multiply(a22,b21))
        c22=matrix_add(matrix_devision_multiply(a21,b12),matrix_devision_multiply(a22,b22))
        c=matrix_combination(c11,c12,c21,c22)
    return c

a=[[1,1,1,1],[1,1,1,1],[2,2,2,2],[2,2,2,2]]
b=a
print(matrix_devision_multiply(a,b))

4.2.3矩阵的Strassen算法

  在简单分治法的思想上,为进一步减少递归树的分枝,在递归函数中只进行7次而不是8次的矩阵的乘法,而减少一次乘法的代价是增加额外的几次矩阵加法运算。

def matrix_strassen(a,b):
    n=len(a)
    c = [[0 for col in range(n)] for row in range(n)]
    if n==1:
        c[0][0]=a[0][0]*b[0][0]
    else:
        (a11,a12,a21,a22)=division(a)
        (b11,b12,b21,b22)=division(b)
        (c11,c12,c21,c22)=division(c)
        s1=matrix_add_sub(b12,b22,0)
        s2=matrix_add_sub(a11,a12,1)
        s3=matrix_add_sub(a21,a22,1)
        s4=matrix_add_sub(b21,b11,0)
        s5=matrix_add_sub(a11,a22,1)
        s6=matrix_add_sub(b11,b22,1)
        s7=matrix_add_sub(a12,a22,0)
        s8=matrix_add_sub(b21,b22,1)
        s9=matrix_add_sub(a11,a21,0)
        s10=matrix_add_sub(b11,b12,1)
        p1=matrix_strassen(a11,s1)
        p2=matrix_strassen(s2,b22)
        p3=matrix_strassen(s3,b11)
        p4=matrix_strassen(a22,s4)
        p5=matrix_strassen(s5,s6)
        p6=matrix_strassen(s7,s8)
        p7=matrix_strassen(s9,s10)
        c11=matrix_add_sub(matrix_add_sub(matrix_add_sub(p5,p4,1),p2,0),p6,1)
        c12=matrix_add_sub(p1,p2,1)
        c21=matrix_add_sub(p3,p4,1)
        c22=matrix_add_sub(matrix_add_sub(matrix_add_sub(p5,p1,1),p3,0),p7,0)
        c=matrix_combination(c11,c12,c21,c22)
    return c

#矩阵的strssen算法
def division(a):                              #对矩阵进行分解操作
    n=len(a)//2
    a11=[[0 for i in range(n)]for j in range(n)]
    a12=[[0 for i in range(n)]for j in range(n)]
    a21=[[0 for i in range(n)]for j in range(n)]
    a22=[[0 for i in range(n)]for j in range(n)]
    for i in range(n):
        for j in range(n):
            a11[i][j]=a[i][j]
            a12[i][j]=a[i][j+n]
            a21[i][j]=a[i+n][j]
            a22[i][j]=a[i+n][j+n]
    return (a11,a12,a21,a22)

def matrix_add_sub(a,b,keys):
    n = len(a)
    c = [[0 for col in range(n)] for row in range(n)]
    if keys==1:
        for i in range(n):
            for j in range(n):
                c[i][j] = a[i][j]+b[i][j]
    else:
        for i in range(n):
            for j in range(n):
                c[i][j]=a[i][j]-b[i][j]
    return c
def matrix_combination(a11,a12,a21,a22):    #对矩阵进行组合操作
    n2 = len(a11)
    n=n2*2
    a = [[0 for col in range(n)] for row in range(n)]
    for i in range (0,n):
        for j in range (0,n):
            if i <= (n2-1) and j <= (n2-1):
                a[i][j] = a11[i][j]
            elif i <= (n2-1) and j > (n2-1):
                a[i][j] = a12[i][j-n2]
            elif i > (n2-1) and j <= (n2-1):
                a[i][j] = a21[i-n2][j]
            else:
                a[i][j] = a22[i-n2][j-n2]
    return a

a=[[1,1,1,1],[1,1,1,1],[2,2,2,2],[2,2,2,2]]
b=a
print(matrix_strassen(a,b))

4.2.4 修改Strassen算法,使之适应矩阵规模n不是2的幂的情况。

具体思路是将不是2的次幂的矩阵扩展成2的次幂的矩阵,在多出的行和列上添上0元素,在计算结果重新组合成c后,对c矩阵多出的行和列上的0元素舍去。因此在简单分治程序的基础上增加了matrix_expand和matrix_shrink函数。在主函数中,首先对输入矩阵A,B的阶数进行判断,如果是2的次幂则不用进行任何操作,直接用普通的Strassen算法,如果不是2的次幂,先对A,B进行矩阵拓展,在计算得到的结果后进行矩阵缩略。

#coding UTF-8
#矩阵的strassen算法
from math import *
def matrix_strassen(a,b):
    n=len(a)
    c = [[0 for col in range(n)] for row in range(n)]
    if n==1:
        c[0][0]=a[0][0]*b[0][0]
    else:
        (a11,a12,a21,a22)=division(a)
        (b11,b12,b21,b22)=division(b)
        (c11,c12,c21,c22)=division(c)
        s1=matrix_add_sub(b12,b22,0)
        s2=matrix_add_sub(a11,a12,1)
        s3=matrix_add_sub(a21,a22,1)
        s4=matrix_add_sub(b21,b11,0)
        s5=matrix_add_sub(a11,a22,1)
        s6=matrix_add_sub(b11,b22,1)
        s7=matrix_add_sub(a12,a22,0)
        s8=matrix_add_sub(b21,b22,1)
        s9=matrix_add_sub(a11,a21,0)
        s10=matrix_add_sub(b11,b12,1)
        p1=matrix_strassen(a11,s1)
        p2=matrix_strassen(s2,b22)
        p3=matrix_strassen(s3,b11)
        p4=matrix_strassen(a22,s4)
        p5=matrix_strassen(s5,s6)
        p6=matrix_strassen(s7,s8)
        p7=matrix_strassen(s9,s10)
        c11=matrix_add_sub(matrix_add_sub(matrix_add_sub(p5,p4,1),p2,0),p6,1)
        c12=matrix_add_sub(p1,p2,1)
        c21=matrix_add_sub(p3,p4,1)
        c22=matrix_add_sub(matrix_add_sub(matrix_add_sub(p5,p1,1),p3,0),p7,0)
        c=matrix_combination(c11,c12,c21,c22)
    return c

def matrix_expand(a):       #对a,b执行矩阵扩展程序段
    n=len(a)
    m=ceil(log(n,2))
    p=int(pow(2,m))
    c=[[0 for col in range(p)]for row in range(p)]#执行expand模式
    for i in range(p):
        for j in range(p):
            if i>=n or j>=n:
                c[i][j]=0
            else:
                c[i][j]=a[i][j]
    return c
def matrix_shrink(a,b):
    n=len(b)
    c=[[0 for col in range(n)]for row in range(n)]
    for i in range(n):
        for j in range(n):
            c[i][j]=a[i][j]
    return c

def division(a):                              #对矩阵进行分解操作
    n=len(a)//2
    a11=[[0 for i in range(n)]for j in range(n)]
    a12=[[0 for i in range(n)]for j in range(n)]
    a21=[[0 for i in range(n)]for j in range(n)]
    a22=[[0 for i in range(n)]for j in range(n)]
    for i in range(n):
        for j in range(n):
            a11[i][j]=a[i][j]
            a12[i][j]=a[i][j+n]
            a21[i][j]=a[i+n][j]
            a22[i][j]=a[i+n][j+n]
    return (a11,a12,a21,a22)

def matrix_add_sub(a,b,keys):  #矩阵加减程序,keys=1时执行矩阵相加,keys=0时执行矩阵相减
    n = len(a)
    c = [[0 for col in range(n)] for row in range(n)]
    if keys==1:
        for i in range(n):
            for j in range(n):
                c[i][j] = a[i][j]+b[i][j]
    else:
        for i in range(n):
            for j in range(n):
                c[i][j]=a[i][j]-b[i][j]
    return c
def matrix_combination(a11,a12,a21,a22):    #对矩阵进行组合操作
    n2 = len(a11)
    n=n2*2
    a = [[0 for col in range(n)] for row in range(n)]
    for i in range (0,n):
        for j in range (0,n):
            if i <= (n2-1) and j <= (n2-1):
                a[i][j] = a11[i][j]
            elif i <= (n2-1) and j > (n2-1):
                a[i][j] = a12[i][j-n2]
            elif i > (n2-1) and j <= (n2-1):
                a[i][j] = a21[i-n2][j]
            else:
                a[i][j] = a22[i-n2][j-n2]
    return a

a=[[1,1,1,1,1],[1,1,1,1,1],[2,2,2,2,2],[2,2,2,2,2],[3,3,3,3,3]]
b=a
n=len(a)
if not(log(n,2)-floor(log(n,2))):  #如果是2的次幂
    print(matrix_strassen(a,b))
else:
    print(matrix_shrink(matrix_strassen(matrix_expand(a),matrix_expand(b)),a))

此种算法是自己能想到的最简单的思路,但是增加了计算量,例如5阶的方阵会转化为8阶的方阵进行计算,增加了不必要的繁琐的0和0的乘法。自己现在也没在网上看相关的资料,能力有限,望诸位看官海涵。

参考文献:

1.算法导论 机械工业出版社 第四章第二节 矩阵的Strassen算法。


猜你喜欢

转载自blog.csdn.net/running987/article/details/81429957
今日推荐