【模板】多项式乘法(FFT)(NTT)

版权声明:欢迎转载(标记出处),写得差还请多指教 https://blog.csdn.net/quan_tum/article/details/82083691

给定一个 $n$ 次多项式 $F(x)$ 和一个 $m$ 次多项式 $G(x)$。

请求出 $F(x)$ 和 $G(x)$ 的卷积,即乘积 $F(x)G(x)$ 的各项系数。

FFT看了很久很久才看懂,总是看懂了后面的就忘记了前面的…累啊

#include<bits/stdc++.h>
#define il inline
#define N 10000005
using namespace std;
// Buffered input: every later getchar() in this program is this macro, not the
// stdio function. When the unread range [p1, p2) is empty it refills buf with
// up to 1<<21 bytes via fread(stdin); if that read returns 0 bytes the macro
// evaluates to EOF, otherwise to the next byte (*p1++).
// NOTE(review): *p1++ is a plain char, so bytes >= 0x80 become negative ints
// where char is signed; input is assumed to be ASCII.
#define getchar()(p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
// buf holds the current chunk of stdin; p1 is the next unread byte, p2 one past the last.
char buf[1<<21],*p1=buf,*p2=buf;
/// Reads one decimal integer from the buffered input.
/// Any non-digit characters before the number are skipped; a '-' among them
/// makes the result negative.
/// @return the parsed value, or 0 if input ends before a digit is found
///         (the previous version looped forever at EOF).
inline int read(){
    int x=0,f=1;
    int c=getchar();  // int, not char, so EOF remains representable
    for(;!isdigit(c);c=getchar()){
        if(c==EOF) return 0;  // no more input: stop instead of spinning on EOF
        if(c=='-') f=-1;
    }
    for(;isdigit(c);c=getchar()) x=x*10+(c-'0');
    return x*f;
}
// Buffered output: pending bytes live in sr[0..C] (C==-1 means empty).
char sr[1<<21];
// Scratch space for one number's digits, least significant first; Z indexes the last one.
char z[20];
int C=-1;
int Z;
/// Writes all pending output in sr to stdout and empties the buffer.
inline void Ot(){fwrite(sr,1,C+1,stdout),C=-1;}
/// Appends the decimal form of x followed by a single space to the output buffer.
/// Flushes first once more than 1<<20 bytes are pending, so the 1<<21-byte
/// buffer always has room for one more number. The magnitude is computed in
/// unsigned arithmetic so INT_MIN is printed correctly (negating it as int
/// overflows, which is undefined behavior).
inline void print(int x){
    if(C>1<<20) Ot();
    unsigned u=static_cast<unsigned>(x);
    if(x<0){
        sr[++C]='-';
        u=0u-u;  // two's-complement magnitude, well defined for unsigned
    }
    do z[++Z]=static_cast<char>('0'+u%10); while(u/=10);
    while(Z) sr[++C]=z[Z--];  // emit digits most significant first
    sr[++C]=' ';
}
// pi in double precision; FFT derives its root-of-unity angles from it.
const double Pi=acos(-1.0);
// Minimal complex number: x is the real part, y the imaginary part.
// The global arrays a and b hold the two polynomials' coefficients in x.
struct A{
    double x,y;
    A(double re=0,double im=0):x(re),y(im){}
}a[N],b[N];
// Complex addition.
A operator+(const A& p,const A& q){return A(p.x+q.x,p.y+q.y);}
// Complex subtraction.
A operator-(const A& p,const A& q){return A(p.x-q.x,p.y-q.y);}
// Complex multiplication: (p.x + i p.y)(q.x + i q.y).
A operator*(const A& p,const A& q){
    const double re=p.x*q.x-p.y*q.y;
    const double im=p.x*q.y+p.y*q.x;
    return A(re,im);
}
// n, m: degrees of the two input polynomials (read in main).
// len: transform size, doubled in main until it exceeds n+m; l counts those
// doublings, so len == 1<<l. r[i]: i with its low l bits reversed, filled in
// main and used by FFT to permute its input.
int n,m,l,r[N],len=1;
/// In-place iterative radix-2 FFT over a[0..len).
/// type == 1 uses roots e^{+i*pi/half}; type == -1 uses their conjugates, giving
/// the inverse transform without the 1/len scaling (main divides by len).
/// Reads the globals len and r (bit-reversal permutation).
void FFT(A *a,int type){
    // Reorder the input into bit-reversed index order.
    for(int i=0;i<len;++i){
        if(i<r[i]) swap(a[i],a[r[i]]);
    }
    // Merge blocks of size `half` into blocks of size `width`.
    for(int half=1;half<len;half*=2){
        const double ang=Pi/half;
        const A step(cos(ang),type*sin(ang));
        const int width=half*2;
        for(int base=0;base<len;base+=width){
            A w(1,0);
            for(int k=0;k<half;++k){
                const A u=a[base+k];
                const A v=w*a[base+half+k];
                a[base+k]=u+v;
                a[base+half+k]=u-v;
                w=w*step;
            }
        }
    }
}
/// Reads degrees n, m and the coefficients of F and G, then prints the
/// n+m+1 coefficients of F*G computed with a floating-point FFT.
int main(){
    n=read(),m=read();
    for(int i=0;i<=n;++i) a[i].x=read();
    for(int i=0;i<=m;++i) b[i].x=read();
    while(len<=n+m) len<<=1,++l;
    // r[0] is already 0 (zero-initialized global). Starting at 1 also avoids
    // shifting by l-1 == -1 (undefined behavior) when len == 1.
    for(int i=1;i<len;++i) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
    FFT(a,1);FFT(b,1);
    for(int i=0;i<len;++i) a[i]=a[i]*b[i];  // only indices [0, len) are transformed
    FFT(a,-1);
    // lround rounds to nearest for negative values too; truncating v+0.5 does not.
    for(int i=0;i<=n+m;++i) print(static_cast<int>(lround(a[i].x/len)));
    Ot();return 0;
}

好像看NTT就没那么累了…

#include<bits/stdc++.h>
#define il inline
#define ll long long
// XOR-swap macro. Caveats: it expands to a bare comma expression (not a
// braced statement), it shadows every later use of the name `swap`, and
// swap(x,x) on the same object sets it to 0. Only use it on distinct objects.
#define swap(x,y) x^=y,y^=x,x^=y
// MAXN: array capacity (3*1e6+10 evaluates to 3000010). len must not exceed it,
// i.e. n+m < 2^21 — TODO confirm against the problem's limits.
// mo = 998244353 = 119*2^23+1 is the NTT prime; G = 3 is the root used for
// forward transforms and Gi = 332748118 is 3^{-1} mod mo (3*Gi == mo+1).
const int MAXN=3*1e6+10,mo=998244353,G=3,Gi=332748118;
// Buffered input: every later getchar() in this program is this macro, not the
// stdio function. When the unread range [p1, p2) is empty it refills buf with
// up to 1<<21 bytes via fread(stdin); if that read returns 0 bytes the macro
// evaluates to EOF, otherwise to the next byte (*p1++).
// NOTE(review): *p1++ is a plain char, so bytes >= 0x80 become negative ints
// where char is signed; input is assumed to be ASCII.
#define getchar() (p1==p2 && (p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
// buf holds the current chunk of stdin; p1 is the next unread byte, p2 one past the last.
char buf[1<<21],*p1=buf,*p2=buf;
/// Reads one decimal integer from the buffered input.
/// Any non-digit characters before the number are skipped; a '-' among them
/// makes the result negative.
/// @return the parsed value, or 0 if input ends before a digit is found
///         (the previous version looped forever at EOF).
inline int read(){
    int x=0,f=1;
    int c=getchar();  // int, not char, so EOF remains representable
    for(;!isdigit(c);c=getchar()){
        if(c==EOF) return 0;  // no more input: stop instead of spinning on EOF
        if(c=='-') f=-1;
    }
    for(;isdigit(c);c=getchar()) x=x*10+(c-'0');
    return x*f;
}
// n, m: degrees of the two input polynomials (read in main).
// len: transform size, doubled in main until it exceeds n+m; l counts those
// doublings, so len == 1<<l. r[i]: i with its low l bits reversed, filled in
// main and used by NTT to permute its input.
int n,m,len=1,l,r[MAXN];
// Coefficients mod mo; transformed in place by NTT, and main leaves the product in a.
ll a[MAXN],b[MAXN];
il ll qpow(ll a,ll k){
    ll base=1;
    while(k){
        if(k&1) base=(base*a)%mo;
        a=(a*a)%mo;k>>=1;
    }
    return base%mo;
}
il void NTT(ll *a,int type){
    for(int i=0;i<len;++i) if(i<r[i]) swap(a[i],a[r[i]]);
    for(int mid=1;mid<len;mid<<=1){  
        ll Wn=qpow(type==1?G:Gi,(mo-1)/(mid<<1));
        for(int j=0;j<len;j+=(mid<<1)){
            ll w=1;
            for(int k=0;k<mid;k++,w=(w*Wn)%mo){
                 int x=a[j+k],y=w*a[j+k+mid]%mo;
                 a[j+k]=(x+y)%mo,
                 a[j+k+mid]=(x-y+mo)%mo;
            }
        }
    }
}
int main(){
    n=read();m=read();
    for(int i=0;i<=n;++i) a[i]=(read()+mo)%mo;
    for(int i=0;i<=m;++i) b[i]=(read()+mo)%mo;
    while(len<=n+m) len<<=1,++l;
    for(int i=0;i<len;++i) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1)); 
    NTT(a,1);NTT(b,1);   
    for(int i=0;i<len;++i) a[i]=(a[i]*b[i])%mo;
    NTT(a,-1);
    ll inv=qpow(len,mo-2);
    for(int i=0;i<=n+m;++i) printf("%d ",(a[i]*inv)%mo);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/quan_tum/article/details/82083691