矩阵相乘提高版

Description

在科学计算中经常要计算矩阵的乘积。矩阵A和B可乘的条件是矩阵A的列数等于矩阵B的行数。若A是一个m行、n列的矩阵(简称m×n的矩阵),B是一个n行、p列的矩阵(简称n×p的矩阵),则其乘积C=A×B是一个m×p的矩阵。其标准计算公式为: 
          
  由该公式知计算C=A×B总共需要进行m×n×p次的数乘法,我们将两个矩阵乘法次数定义为矩阵想乘的时间复杂度。 
  矩阵乘法满足矩阵乘法满足结合律(但不满足交换律),即对于D=A×B×C,可以有如下两种计算方式: 
   
   1、D=(A×B)×C   
   2、D=A×(B×C) 

  假设A是个10×100的矩阵、B是个100×5的矩阵,C是个5×50的矩阵,那么: 

   ●按照第一种计算方法得到的时间复杂度是:10×100×5+10×5×50=7500; 
   ●按照第一种计算方法得到的时间复杂度是:100×5×50+10×100×50=75000; 

  所以不同的计算顺序得到的时间复杂度是不一样的,现在的问题是:顺序给出n个矩阵的大小,请计算出它们的乘积的最小的时间复杂度。

Input

  第一行输入一个正整数n,表示有n个矩阵。 
  接下来m行每行两个正整数Xi,Yi,其中第i行的两个数表示第i个矩阵的规模为Xi行、Yi列。所有的Xi、Yi<=100。 
  输入数据保证这些矩阵可以相乘。

Output

  输出仅一行为一个整数表示n个矩阵连乘最小时间复杂度。

Sample Input

3
10 100
100 5
5 50

Sample Output

7500

Hint

数据范围:n<=100

因为若两个矩阵可以相乘,那么$A[i].y=B[i].x$,故可以把$n$个矩阵的序列拉成$2n-1$数的序列

简单的区间DP,设f[i][j]为$[i,j]$为一个矩阵的最小代价

$f[i][j]=min(f[i][k]+f[k][j]+a[i]*a[k]*a[j],f[i][j])$

长度为1、2的区间初始为0,长度为3的区间暴力乘一下,其他为INF

#include<iostream>
#include<iomanip>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<algorithm>
using namespace std;
#define INF 0x3f3f3f3f
#define int long long
inline char gc() {
//    return getchar();
    static char buf[100000],*p1=buf,*p2=buf;
    return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
const int mod=19260817;
inline int read() {
    char ch;
    int bj=0;
    while(!isdigit(ch=gc()))
        bj|=(ch=='-');
    int res=ch^(3<<4);
    while(isdigit(ch=gc()))
        res=((res<<1)+(res<<3)+(ch^(3<<4)))%mod;
    return bj?-res:res;
}
void printnum(int x) {
    if(x>9)printnum(x/10);
    putchar(x%10+'0');
}
inline void print(int x,char ch) {
    if(x<0) {
        putchar('-');
        x=-x;
    }
    printnum(x);
    putchar(ch);
}
int n,a[205],tot,f[205][205];
signed main() {
    n=read();
    memset(f,0x3f,sizeof(f));
    a[++tot]=read(),read();
    for(int i=2; i<n; i++)a[++tot]=read(),read();
    a[++tot]=read(),a[++tot]=read();
    for(int i=1; i<=2; i++)
        for(int j=1; j<=n; j++)f[j][j+i-1]=0;
    for(int i=1; i+2<=tot; i++)f[i][i+2]=a[i]*a[i+1]*a[i+2];
    for(int len=4; len<=tot; len++) {
        for(int i=1; i<=tot-len+1; i++) {
            int j=i+len-1;
            for(int k=i+1; k<j; k++)f[i][j]=min(f[i][k]+f[k][j]+a[i]*a[k]*a[j],f[i][j]);
        }
    }
    print(f[1][tot],'\n');
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/soledadstar/p/11742288.html