【洛谷】5205 【模板】多项式开根

总结啥的就放到多项式入门里了,好多细节需要注意~ 

code: 

#include <bits/stdc++.h>   
#define ll long long 
#define setIO(s) freopen(s".in","r",stdin) 
using namespace std;    
const int mod=998244353,G=3,N=1000003;        
int A[N],B[N],f[N],g[N],inv2,C[N],D[N];     
inline int qpow(int x,int y) 
{
    int tmp=1; 
    for(;y;y>>=1,x=(ll)x*x%mod)   if(y&1)    tmp=1ll*tmp*x%mod;  
    return tmp; 
} 
inline int INV(int x) { return qpow(x,mod-2); }       
void NTT(int *a,int len,int flag) 
{
    int i,j,k,mid;  
    for(i=k=0;i<len;++i) 
    {
        if(i>k)     swap(a[i],a[k]);                                                           
        for(j=len>>1;(k^=j)<j;j>>=1);  
    }  
    for(mid=1;mid<len;mid<<=1) 
    {
        int wn=qpow(G,(mod-1)/(mid<<1));   
        if(flag==-1)    wn=INV(wn);  
        for(i=0;i<len;i+=mid<<1) 
        {
            int w=1;   
            for(j=0;j<mid;++j) 
            {
                int x=a[i+j], y=1ll*w*a[i+j+mid]%mod;   
                a[i+j]=1ll*(x+y)%mod,  a[i+j+mid]=1ll*(x-y+mod)%mod;   
                w=1ll*w*wn%mod;   
            }
        }
    }
    if(flag==-1)  
    {
        int rev=INV(len);   
        for(i=0;i<len;++i)    a[i]=1ll*a[i]*rev%mod;   
    }
} 
void getinv(int *a,int *b,int len) 
{   
    if(len==1) { b[0]=INV(a[0]);  return; }         
    getinv(a,b,len>>1);       
    int i,j; 
    for(i=0;i<(len<<1);++i)   C[i]=D[i]=0;                   
    for(i=0;i<len;++i)        C[i]=a[i], D[i]=b[i];    
    NTT(C,len<<1,1); 
    NTT(D,len<<1,1);   
    for(i=0;i<(len<<1);++i)   C[i]=1ll*C[i]*D[i]%mod*D[i]%mod;  
    NTT(C,len<<1,-1);       
    for(i=0;i<len;++i)        b[i]=((b[i]<<1)%mod-C[i]+mod)%mod;                                                                      
}         
void getsqrt(int *a,int *b,int len) 
{
    if(len==1)   { b[0]=1; return;  }               
    getsqrt(a,b,len>>1);                         
    int i,j;  
    for(i=0;i<(len<<1);++i)    A[i]=B[i]=0;   
    getinv(b,B,len);                                             
    for(i=0;i<len;++i)         A[i]=a[i];            
    NTT(A,len<<1,1);                               
    NTT(B,len<<1,1);   
    for(i=0;i<(len<<1);++i)    A[i]=1ll*A[i]*B[i]%mod;         
    NTT(A,len<<1,-1);          
    for(i=0;i<len;++i)         b[i]=1ll*(b[i]+A[i])%mod*inv2%mod;                                               
}     
int main() 
{
    // setIO("input");   
    int n,i,j,lim=1;  
    inv2=INV(2);             
    scanf("%d",&n);          
    for(i=0;i<n;++i)    scanf("%d",&f[i]);         
    while(lim<n)  lim<<=1;     
    getsqrt(f,g,lim);                                    
    for(i=0;i<n;++i)    printf("%d ",g[i]);   
    return 0;   
}   

  

猜你喜欢

转载自www.cnblogs.com/guangheli/p/11909987.html