codeforces E. The Child and Binary Tree(多项式求逆,多项式开根,dp)

题面:https://codeforces.com/problemset/problem/438/E

题解:设\[f(n)\]表示权值和为n的二叉数有多少个\[g(n)\]表示集合里有没有权值为n的数

则:当n=0时,\[f(n) = 1\]

当n!=0时,先枚举根节点的权值,然后枚举左右子树的个数

\[f(n) = \sum\limits_{i = 0}^n {g(i)\sum\limits_{j = 0}^{n - i} {f(j){\rm{\cdot}}f(n - i - j)} } \]

令\[F(x) = \sum\limits_{i = 0}^\infty  {f(i){\rm{\cdot}}{x^i}} \],\[G(x) = \sum\limits_{i = 0}^\infty  {g(i){\rm{\cdot}}{x^i}} \]

则\[F = G{\rm{*}}{F^2}{\rm{ + }}1\]解这个方程

有解:\[\frac{{1{\rm{ \pm }}\sqrt {1 - 4{\rm{G}}} }}{{2G}}\]

舍去正号的解,则答案为\[\frac{{1 - \sqrt {1 - 4{\rm{G}}} }}{{2G}}\]的系数,接下来就是多项式求逆与开根

#include<bits/stdc++.h>
#define ms(x) memset(x,0,sizeof(x))
#define sws ios::sync_with_stdio(false)
using namespace std;
typedef long long ll;
const int maxn=4e5+5;
const double pi=acos(-1.0);
const ll mod=998244353;///通常情况下的模数,
const ll g=3;///模数的原根998244353,1004535809,469762049
ll qpow(ll a,ll n,ll p){
    ll ans=1;
    while(n){
        if(n&1) ans=ans*a%p;
        n>>=1;
        a=a*a%p;
    }
    return ans;
}
int rev[maxn];
int inv2;
void ntt(int a[],int n,int len,int pd){
    rev[0]=0;
    for(int i=1;i<n;i++){
        rev[i]=(rev[i>>1]>>1 | ((i&1)<<(len-1)));
        if(i<rev[i]) swap(a[i],a[rev[i]]);
    }
    for(int mid=1;mid<n;mid<<=1){
        ll wn=qpow(g,(mod-1)/(mid*2),mod);///原根代替单位根
        if(pd==-1) wn=qpow(wn,mod-2,mod);///逆变换则改成逆元
        for(int j=0;j<n;j+=2*mid){
            ll w=1;
            for(int k=0;k<mid;k++){
                ll x=a[j+k],y=w*a[j+k+mid]%mod;
                a[j+k]=(x+y)%mod;
                a[j+k+mid]=(x-y+mod)%mod;
                w=w*wn%mod;
            }
        }
    }
    if(pd==-1){
        ll inv=qpow(n,mod-2,mod);
        for(int i=0;i<n;i++){
            a[i]=a[i]*inv%mod;

        }
    }
}
int A[maxn],B[maxn];
void solve(int *a,int *b,int n){
    int len=0,up=1;
    while(up<n) up<<=1,len++;
    ntt(a,up,len,1);
    ntt(b,up,len,1);
    for(int i=0;i<up;i++) a[i]=1ll*a[i]*b[i]%mod*b[i]%mod;
    ntt(a,up,len,-1);
}
void Inv(int *a,int *b,ll n){
    if(n==1){
        b[0]=qpow(a[0],mod-2,mod);
        return;
    }
    Inv(a,b,n>>1);
    for(int i=0;i<n;i++) A[i]=a[i],B[i]=b[i];
    solve(A,B,n<<1);
    for(int i=0;i<n;i++) b[i]=(2ll*b[i]%mod-A[i]+mod)%mod;
    for(int i=0;i<=2*n;i++) A[i]=B[i]=0;
}
int x[maxn],y[maxn],C[maxn],D[maxn],in[maxn];
void mul(int *a,int *b,int n){
      int len=0,up=1;
    while(up<n) up<<=1,len++;
    ntt(a,up,len,1);
    ntt(b,up,len,1);
    for(int i=0;i<up;i++) a[i]=1ll*a[i]*b[i]%mod;
    ntt(a,up,len,-1);
}
void Sqrt(int *a,int *b,ll n){
    if(n==1){
        b[0]=a[0];
        return;
    }
    Sqrt(a,b,n>>1);
    for(int i=0;i<n;i++) C[i]=a[i];
    Inv(b,D,n);
    mul(D,C,n<<1);
    for(int i=0;i<n;i++)b[i]=1ll*(b[i]+D[i])%mod*inv2%mod;
    for(int i=0;i<=n*2;i++) C[i]=D[i]=0;
    
}
int main(){
    int n,m;
    inv2=qpow(2,mod-2,mod);
    sws;
    cin>>n>>m;
    int up=0;
    for(int i=0;i<n;i++) {
        int c;
        cin>>c;
        up=max(c,up);
        x[c]=(-4+mod)%mod;
    }
    up=m+1;
    x[0]=(1-x[0]+mod)%mod;
    int len=1;
    while(len<up) len<<=1;
    Sqrt(x,y,len);
    y[0]++;
    Inv(y,in,len);
    for(int i=1;i<=m;i++) cout<<2*in[i]%mod<<endl; 
}

猜你喜欢

转载自www.cnblogs.com/azznaz/p/11531053.html
今日推荐