多项式乘法与快速傅里叶变换

多项式乘法与快速傅里叶变换

问题介绍

试想这样一个问题,求两个多项式

f ( x ) = i = 0 n 1 a i x i

g ( x ) = i = 0 m 1 b i x i

的乘积
f ( x ) g ( x ) = i = 0 n + m 2 j + k = i ( a j + b k ) x i

使用传统的方法至少需要 O ( n 2 ) 的复杂度,下面介绍快速傅里叶变换,将这个过程加速到 O ( n log n ) .

问题的快速解法

首先考虑如何用其他方式表示多项式 f ( x ) = i = 0 n 1 a i x i .

任取n个不同的数(可以是整数、实数,甚至是复数)

x 0 , x 1 , , x n 1

将其代入 f ( x ) 中,就得到一个线性方程组
{   f ( x 0 ) = y 0   f ( x 1 ) = y 1     f ( x n 1 ) = y n 1

只要 n 足够大,就能够唯一地确定一个多项式,换言之,上述方程组可以表示一个多项式,将这两种多项式的表示方法分别称为系数表示和点值表示.

利用快速傅里叶变换来求多项式乘积的总体思路是

  1. 选取合适的 n 个不同的数 x 0 , x 1 , , x n 1
  2. 将多项式 f ( x ) g ( x ) 转化为点值表示(称为离散傅里叶变换,简称 D F T )
  3. 计算 f ( x ) g ( x ) 的点值表示
  4. f ( x ) g ( x ) 转化为系数表示(称为逆离散傅里叶变换,简称 D F T 1 )

下面本文将分步讲解上述过程.

1. 选取合适的 n 个不同的数 x 0 , x 1 , , x n 1

我们选取复数域上 1 n n 个不同的值(或称 n n 次单位复根)作为 x 0 , x 1 , , x n 1 的值,即

x k = ω n k = e 2 k π i n , k = 0 , 2 , , n 1

至于指数形式的复数 e 2 k π i n ,用大家所熟知的欧拉公式即可求得其代数形式
e i θ = cos θ + i sin θ

经过简单计算可知
ω n k + m n = cos ( 2 k π n + 2 π m ) + i sin ( 2 k π n + 2 π m ) = cos 2 k π n + i sin 2 k π n = ω n k , m Z

n 为偶数时
( ω n k ) 2 = ( e 2 k π i n ) 2 = e 2 k π i n / 2 = ω n / 2 k = ω n / 2 k   m o d   n / 2

其中 a   m o d   b a 除以 b 的余数,上述两等式将在后文中使用.

我们为什么要费尽周折选取如此复杂的点呢?是为了使用快速傅里叶变换.

2. 将多项式 f ( x ) g ( x ) 转化为点值表示

考虑多项式 f ( x ) = i = 0 n 1 a i x i ,当 n = 2 m , m Z + 时(当不满足该条件时,向 f ( x ) 补充系数为0的高次项来扩大 n 使其满足该条件),将其化为两个多项式

f [ 0 ] ( x ) = a 0 + a 2 x + + a n 2 x n 2 2

f [ 1 ] ( x ) = a 1 + a 3 x + + a n 1 x n 2 2

则有
f ( x ) = f [ 0 ] ( x 2 ) + x f [ 1 ] ( x 2 )

进而
f ( ω n k ) = f [ 0 ] ( ω n / 2 k   m o d   n / 2 ) + ω n k f [ 1 ] ( ω n / 2 k   m o d   n / 2 )

也就是说,要求 f ( x ) n 个不同点处的值,只需要求 f [ 0 ] ( x ) f [ 1 ] ( x ) n 2 个不同点处的值,由于 n = 2 m , m Z + ,可对 f [ 0 ] ( x ) f [ 1 ] ( x ) 重复进行上述过程,最终经过 m 步后得到 n 个函数
f [ 0 ] ( x ) = a 0 , f [ 1 ] ( x ) = a 1 , , f [ n 1 ] ( x ) = a n 1

之后回推得到 f ( x ) 的点值表示,上述过程就是快速傅里叶变换的过程,复杂度为 O ( n m ) O ( n log n ) .

当然,还需要对 g ( x ) 进行同样的变换.

3. 计算 f ( x ) g ( x ) 的点值表示

点值表示的优点是可以快速地求出两个选取了相同点值的多项式的乘积,例如多项式

{   f ( x 0 ) = y 0   f ( x 1 ) = y 1     f ( x n 1 ) = y n 1

与多项式
{   g ( x 0 ) = y 0   g ( x 1 ) = y 1     g ( x n 1 ) = y n 1

的乘积
{   f ( x 0 ) g ( x 0 ) = y 0 z 0   f ( x 1 ) g ( x 1 ) = y 1 z 1     f ( x n 1 ) g ( x n 1 ) = y n 1 z n 1

只需要 O ( n ) 的复杂度即可求得.

4. 将 f ( x ) g ( x ) 转化为系数表示

下面以 f ( x ) 为例,讲解如何将多项式从点值表示转化为系数表示,此过程又称多项式的插值.

f ( x ) 的点值表示写成矩阵形式 Y = V n A

(   y 0   y 1   y 2   y 3     y n 1 ) = (   1 1 1 1 1   1 ω n 1 ω n 2 ω n 3 ω n n 1   1 ω n 2 ω n 4 ω n 6 ω n 2 ( n 1 )   1 ω n 3 ω n 6 ω n 9 ω n 3 ( n 1 )     1 ω n n 1 ω n 2 ( n 1 ) ω n 3 ( n 1 ) ω n ( n 1 ) ( n 1 ) ) (   a 0   a 1   a 2   a 3     a n 1 )

此处矩阵 V n 中的1可视为 ω n 0 .

现在我们已知的是 Y V n ,要求的是 A V n 是一范德蒙德矩阵,可求得其逆矩阵

V n 1 = 1 n (   1 1 1 1 1   1 ω n 1 ω n 2 ω n 3 ω n ( n 1 )   1 ω n 2 ω n 4 ω n 6 ω n 2 ( n 1 )     1 ω n ( n 1 ) ω n 2 ( n 1 ) ω n 3 ( n 1 ) ω n ( n 1 ) ( n 1 ) )

因此 A = V n 1 Y

(   a 0   a 1   a 2   a 3     a n 1 ) = 1 n (   1 1 1 1 1   1 ω n 1 ω n 2 ω n 3 ω n ( n 1 )   1 ω n 2 ω n 4 ω n 6 ω n 2 ( n 1 )   1 ω n 3 ω n 6 ω n 9 ω n 3 ( n 1 )     1 ω n ( n 1 ) ω n 2 ( n 1 ) ω n 3 ( n 1 ) ω n ( n 1 ) ( n 1 ) ) (   y 0   y 1   y 2   y 3     y n 1 )

也就是说,只需将 Y A 对换,将 ω n k 换成 ω n k ,再乘上系数 1 n ,进行类似步骤2的变换,即可进行逆快速傅里叶变换,算法复杂度同样为 O ( n log n ) .

按照上述方法将 f ( x ) g ( x ) 转化为系数表示,本题得解.

代码实现

下面给出计算整系数多项式乘积的C++代码

#include <bits/stdc++.h>
using namespace std;
const double pi=acos(-1.0);
struct cpx
{
    double x,y;
    cpx(double x=0.0,double y=0.0){this->x=x;this->y=y;}
    cpx operator + (const cpx &b)const{return cpx(x+b.x,y+b.y);}
    cpx operator - (const cpx &b)const{return cpx(x-b.x,y-b.y);}
    cpx operator * (const cpx &b)const{return cpx(x*b.x-y*b.y,b.x*y+x*b.y);}
};
inline void Rader(cpx F[],int len)
{
    int j=len>>1;
    for(int i=1;i<len-1;i++)
    {
        if(i<j)swap(F[i],F[j]);
        int k=len>>1;
        while(j>=k)
        {
            j-=k;
            k>>=1;
        }
        if(j<k)j+=k;
    }
}
inline cpx w(int n,int k)
{
    return cpx(cos(2*k*pi/n),sin(2*k*pi/n));
}
cpx temp[10005];
inline void FFT(cpx f[],int len,int flag)
{
    Rader(f,len);
    int wei=-1,tt=len;
    while(tt)
    {
        wei++;
        tt>>=1;
    }
    for(int it=1;it<=wei;it++)
    {
        for(int i=0;i<len;i++)
        {
            int x=-1>>it<<it;
            temp[i]=f[(i&x)+(i&~x>>1)]+w(1<<it,-flag*(i&~x))*f[((i>>it<<1|1)<<it-1)+(i&~x>>1)];
        }
        for(int i=0;i<len;i++)f[i]=temp[i];
    }
    if(flag==-1)for(int i=0;i<len;i++)f[i].x/=len;
}
inline void Convolution(cpx f[],int n,cpx g[],int m)
{
    int len=1;
    while(len<2*max(n,m))len<<=1;
    for(int i=n;i<len;i++)f[i]=cpx(0.0,0.0);
    FFT(f,len,1);
    for(int i=m;i<len;i++)g[i]=cpx(0.0,0.0);
    FFT(g,len,1);
    for(int i=0;i<len;i++)f[i]=f[i]*g[i];
    FFT(f,len,-1);
}
cpx f[1005],g[1005];
int n,m;
int main()
{
    while(~scanf("%d",&n))
    {
        for(int i=0;i<n;i++)
        {
            f[i]=cpx(0.0,0.0);
            scanf("%lf",&f[i].x);
        }
        scanf("%d",&m);
        for(int i=0;i<m;i++)
        {
            g[i]=cpx(0.0,0.0);
            scanf("%lf",&g[i].x);
        }
        Convolution(f,n,g,m);
        for(int i=0;i<=n+m;i++)printf(" %.f",f[i].x);
    }
    return 0;
}

例:

( 2 x 3 + x + 1 ) ( 6 x 2 + 2 x + 3 ) = 12 x 5 + 4 x 4 + 12 x 3 + 8 x 2 + 5 x + 3

Input:
4
1 1 0 2
3
3 2 6
Output:
3 5 8 12 4 12 0 0

猜你喜欢

转载自blog.csdn.net/qq_39515621/article/details/80585497