n阶方阵乘法:Strassen算法

原理:朴素的分块矩阵乘法需要递归地进行8次子矩阵乘法,时间复杂度为 $\Theta(n^{\lg 8}) = \Theta(n^3)$;Strassen算法改进后仅需要7次子矩阵乘法,时间复杂度为 $\Theta(n^{\lg 7})$
具体推导见《算法导论》中利用主定理推导时间复杂度的部分

def matrix_divide(A):
    """Split a square matrix into its four quadrants.

    The quadrant size is ``len(A) // 2``. For an odd order the last row
    and last column of ``A`` are therefore not part of any quadrant.

    Args:
        A: Matrix given as a list of rows; indexed as ``A[row][col]``.

    Returns:
        A tuple ``(A11, A12, A21, A22)`` of freshly built lists holding the
        top-left, top-right, bottom-left and bottom-right quadrants.
    """
    half = len(A) // 2
    span = range(half)

    def quadrant(row_offset, col_offset):
        # Copy the half x half window whose top-left corner is at the offsets.
        return [
            [A[row_offset + r][col_offset + c] for c in span]
            for r in span
        ]

    top_left = quadrant(0, 0)
    top_right = quadrant(0, half)
    bottom_left = quadrant(half, 0)
    bottom_right = quadrant(half, half)
    return top_left, top_right, bottom_left, bottom_right

def matrix_add(A, B):
    """Return the element-wise sum of two matrices of identical shape.

    Works for any rectangular shape, not only square matrices.

    Args:
        A: Left operand, a list of rows.
        B: Right operand, a list of rows with the same shape as ``A``.

    Returns:
        A new list of lists where entry ``[i][j]`` is ``A[i][j] + B[i][j]``.

    Raises:
        ValueError: If ``A`` and ``B`` differ in row count or row length.
    """
    if len(A) != len(B):
        raise ValueError("matrix_add: operands have different row counts")
    total = []
    for row_a, row_b in zip(A, B):
        # zip() would silently truncate, so reject unequal rows explicitly.
        if len(row_a) != len(row_b):
            raise ValueError("matrix_add: operands have different row lengths")
        total.append([a + b for a, b in zip(row_a, row_b)])
    return total

def matrix_sub(A, B):
    """Return the element-wise difference of two matrices of identical shape.

    Works for any rectangular shape, not only square matrices.

    Args:
        A: Minuend, a list of rows.
        B: Subtrahend, a list of rows with the same shape as ``A``.

    Returns:
        A new list of lists where entry ``[i][j]`` is ``A[i][j] - B[i][j]``.

    Raises:
        ValueError: If ``A`` and ``B`` differ in row count or row length.
    """
    if len(A) != len(B):
        raise ValueError("matrix_sub: operands have different row counts")
    difference = []
    for row_a, row_b in zip(A, B):
        # zip() would silently truncate, so reject unequal rows explicitly.
        if len(row_a) != len(row_b):
            raise ValueError("matrix_sub: operands have different row lengths")
        difference.append([a - b for a, b in zip(row_a, row_b)])
    return difference


def matrix_merge(C11, C12, C21, C22):
    """Assemble four equally sized square quadrants into one matrix.

    Args:
        C11: Top-left quadrant; its row count fixes the quadrant size.
        C12: Top-right quadrant.
        C21: Bottom-left quadrant.
        C22: Bottom-right quadrant.

    Returns:
        A new square matrix of order ``2 * len(C11)`` with the quadrants
        placed in their respective corners.
    """
    half = len(C11)
    span = range(half)
    # Each output row is the left quadrant's row followed by the right one's.
    upper = [
        [C11[r][c] for c in span] + [C12[r][c] for c in span]
        for r in span
    ]
    lower = [
        [C21[r][c] for c in span] + [C22[r][c] for c in span]
        for r in span
    ]
    return upper + lower


def strassen(A, B):
    """Multiply two square matrices using Strassen's seven-product recursion.

    Each level splits the operands into quadrants, forms ten quadrant
    sums/differences (S1..S10), multiplies seven pairs recursively (P1..P7)
    and combines those products into the quadrants of the result.

    Args:
        A: Left operand, an n x n matrix given as a list of n rows.
        B: Right operand, an n x n matrix of the same order.

    Returns:
        A new n x n list of lists holding the product ``A * B``. Empty
        operands yield an empty list.

    Raises:
        ValueError: If ``A`` and ``B`` are not square matrices of the
            same order.
    """
    n = len(A)
    if (
        len(B) != n
        or any(len(row) != n for row in A)
        or any(len(row) != n for row in B)
    ):
        raise ValueError("strassen expects two square matrices of the same order")

    # Without this guard an empty operand would recurse without terminating.
    if n == 0:
        return []
    if n == 1:
        return [[A[0][0] * B[0][0]]]

    if n % 2:
        # matrix_divide uses floor(n / 2) as the quadrant size, so an odd
        # order would lose its last row and column. Pad both operands with
        # one zero row and column, multiply, then trim the product back to
        # n x n; the padding contributes only zeros to the kept entries.
        padded_a = [list(row) + [0] for row in A] + [[0] * (n + 1)]
        padded_b = [list(row) + [0] for row in B] + [[0] * (n + 1)]
        padded_product = strassen(padded_a, padded_b)
        return [row[:n] for row in padded_product[:n]]

    A11, A12, A21, A22 = matrix_divide(A)
    B11, B12, B21, B22 = matrix_divide(B)

    S1 = matrix_sub(B12, B22)
    S2 = matrix_add(A11, A12)
    S3 = matrix_add(A21, A22)
    S4 = matrix_sub(B21, B11)
    S5 = matrix_add(A11, A22)
    S6 = matrix_add(B11, B22)
    S7 = matrix_sub(A12, A22)
    S8 = matrix_add(B21, B22)
    S9 = matrix_sub(A11, A21)
    S10 = matrix_add(B11, B12)

    P1 = strassen(A11, S1)
    P2 = strassen(S2, B22)
    P3 = strassen(S3, B11)
    P4 = strassen(A22, S4)
    P5 = strassen(S5, S6)
    P6 = strassen(S7, S8)
    P7 = strassen(S9, S10)

    # C11 = P5 + P4 - P2 + P6
    C11 = matrix_add(P5, matrix_sub(P4, matrix_sub(P2, P6)))
    # C12 = P1 + P2
    C12 = matrix_add(P1, P2)
    # C21 = P3 + P4
    C21 = matrix_add(P3, P4)
    # C22 = P5 + P1 - P3 - P7
    C22 = matrix_add(P5, matrix_sub(P1, matrix_add(P3, P7)))

    return matrix_merge(C11, C12, C21, C22)
def main():
    """Print the Strassen product of two fixed 4 x 4 demo matrices."""
    # Row k of the left matrix is four copies of k (1..4); the right matrix
    # follows the same pattern with the values 5..8.
    left = [[value] * 4 for value in range(1, 5)]
    right = [[value] * 4 for value in range(5, 9)]
    print(strassen(left, right))
# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

猜你喜欢

转载自www.cnblogs.com/vito_wang/p/10806816.html