矩阵快速幂优化递推式 例:斐波那契数列

首先是一点基础知识:

① 矩阵相乘的规则:矩阵与矩阵相乘 第一个矩阵的列数必须等于第二个矩阵的行数 假如第一个是m*n的矩阵 第二个是n*p的矩  阵则结果就是m*p的矩阵且得出来的矩阵中元素具有以下特点:

第一行第一列元素为第一个矩阵的第一行的每个元素和第二个矩阵的第一列的每个元素乘积的和 以此类推 第i行第j列的元素就是第一个矩阵的第i行的每个元素与第二个矩阵第j列的每个元素的乘积的和。

② 单位矩阵: n*n的矩阵 mat ( i , i )=1; 任何一个矩阵乘以单位矩阵就是它本身 n*单位矩阵=n, 可以把单位矩阵等价为整数1。(单位矩阵用在矩阵快速幂中)

例如下图就是一个7*7的单位矩阵:

矩阵及其乘法的实现:
 http://blog.csdn.net/g_congratulation/article/details/52734281


对于矩阵乘法与递推式之间的关系:

如:在斐波那契数列之中

fi[i] = 1*fi[i-1]+1*fi[i-2] fi[i-1] = 1*f[i-1] + 0*f[i-2];

所以

矩阵快速幂:

因为矩阵乘法满足结合律,原因如下


所以,我们可以用类似数字快速幂的算法来解决矩阵快速幂。(前提:矩阵为n*n的矩阵,原因见矩阵乘法定义)

代码

[cpp]  view plain  copy
  1. Matrix fast_pow(Matrix a, int x) {  
  2.     Matrix ans;  
  3.     ans.x = a.x;  
  4.     for(int i = 0; i < ans.x; i++)  
  5.         ans.a[i][i] = 1;  
  6.     while(x) {  
  7.         if(x&1)   
  8.             ans = ans*a;  
  9.         a = a*a;  
  10.         x >>= 1;  
  11.     }  
  12.     return ans;  
  13. }  

用矩阵快速幂求斐波那契数列的第N项的代码:

[cpp]  view plain  copy
  1. #include<cstdio>  
  2. #include<algorithm>  
  3. #include<cstring>  
  4. #include<iostream>  
  5. using namespace std;  
  6.   
  7. const int M = 1e9+7;  
  8.   
  9. struct Matrix {  
  10.     long long a[2][2];  
  11.     Matrix() {  
  12.         memset(a, 0, sizeof(a));  
  13.     }  
  14.     Matrix operator * (const Matrix y) {  
  15.         Matrix ans;  
  16.         for(int i = 0; i <= 1; i++)  
  17.             for(int j = 0; j <= 1; j++)    
  18.                 for(int k = 0; k <= 1; k++)    
  19.                     ans.a[i][j] += a[i][k]*y.a[k][j];  //乘完再加到a[i][j]不会影响到最终结果
  20.         for(int i = 0; i <= 1; i++)  
  21.             for(int j = 0; j <= 1; j++)  
  22.                 ans.a[i][j] %= M;  
  23.         return ans;  
  24.     }  
  25.     void operator = (const Matrix b) {  
  26.         for(int i = 0; i <= 1; i++)  
  27.             for(int j = 0; j <= 1; j++)  
  28.                 a[i][j] = b.a[i][j];  
  29.     }  
  30. };  
  31.   
  32. int solve(long long x) {  
  33.     Matrix ans, trs;  
  34.     ans.a[0][0] = ans.a[1][1] = 1;  
  35.     trs.a[0][0] = trs.a[1][0] = trs.a[0][1] = 1;  
  36.     while(x) {  
  37.         if(x&1)   
  38.             ans = ans*trs;  
  39.         trs = trs*trs;  
  40.         x >>= 1;  
  41.     }  
  42.     return ans.a[0][0];  
  43. }  
  44.   
  45. int main() {  
  46.     int n;  
  47.     scanf("%d", &n);  
  48.     cout << solve(n-1) << endl;  
  49.     return 0;  
  50. }  

猜你喜欢

转载自blog.csdn.net/qq_27151549/article/details/80070102
今日推荐