Codeforces997D Cycles in product 【FFT】【树形DP】

题目大意:

给两个树,求环的个数。

题目分析:

出题人摆错题号系列。

通过画图很容易就能想到把新图拆在两个树上,在树上游走成环。

考虑DP状态F,G,T。F表示最终答案,T表示儿子不考虑父亲,G表示父亲不考虑儿子。T通过从下往上做NTT,G通过从上往下做NTT。F顺便做NTT。

最后做一下拼接就行。

代码:

  1 #include<bits/stdc++.h>
  2 using namespace std;
  3 
  4 const int maxn = 4020;
  5 const int mod = 998244353;
  6 const int gg = 3;
  7 
  8 int n[2],k;
  9 
 10 vector <int> g[2][maxn];
 11 
 12 int f[2][maxn][80],gi[2][maxn][80],T[2][maxn][80];
 13 
 14 int C[100][100];
 15 
 16 int fast_pow(int now,int pw){
 17     int ans = 1,bit = 1,dt = now;
 18     while(bit <= pw){
 19         if(bit & pw){ans = (1ll*ans*dt)%mod;}
 20         bit <<=1; dt = (1ll*dt*dt)%mod;
 21     }
 22     return ans;
 23 }
 24 
 25 void read(){
 26     scanf("%d%d%d",&n[0],&n[1],&k);
 27     for(int i=1;i<n[0];i++){
 28         int u,v; scanf("%d%d",&u,&v);
 29         g[0][u].push_back(v); g[0][v].push_back(u);
 30     }
 31     for(int i=1;i<n[1];i++){
 32         int u,v; scanf("%d%d",&u,&v);
 33         g[1][u].push_back(v); g[1][v].push_back(u);
 34     }
 35 }
 36 
 37 int ord[260];
 38 
 39 void NTT(int *d,int len,int dr){
 40     for(int i=0;i<len;i++) if(ord[i] < i) swap(d[i],d[ord[i]]);
 41     for(int i=1;i<len;i<<=1){
 42         int wn = fast_pow(gg,(mod-1)/(2*i));
 43         if(dr == -1) wn = fast_pow(wn,mod-2);
 44         for(int j=0;j<len;j+=(i<<1)){
 45             for(int k=0,w=1;k<i;k++,w = (1ll*w*wn)%mod){
 46                 int x = d[j+k],y = (1ll*w*d[j+k+i])%mod;
 47                 d[j+k] = (x+y)%mod;
 48                 d[j+k+i] = (x-y+mod)%mod;
 49             }
 50         }
 51     }
 52     if(dr == -1){
 53         int iv = fast_pow(len,mod-2);
 54         for(int i=0;i<len;i++) d[i] = (1ll*d[i]*iv)%mod;
 55     }
 56 }
 57 
 58 int A[260],B[260];
 59 int fi[260],A0[260];
 60 
 61 void INV(){
 62     int len = 1,bit = 0; while(len <= k) len<<=1,bit++;
 63     memset(A0,0,sizeof(A0));memset(fi,0,sizeof(fi));
 64     A0[0] = 1;
 65     for(int i=2,j=1;i<=len;i<<=1,j++){
 66         for(int k=0;k<i;k++) fi[k] = A[k];
 67         int rl = i*2,rb = j+1;
 68         for(int k=0;k<rl;k++) ord[k] = (ord[k>>1]>>1) + ((k&1)<<rb-1);
 69         NTT(A0,rl,1); NTT(fi,rl,1);
 70         for(int k=0;k<rl;k++){
 71             A0[k] = (2*A0[k]-(1ll*fi[k]*A0[k]%mod)*A0[k]%mod)%mod;
 72             if(A0[k] < 0) A0[k] += mod;
 73         }
 74         NTT(A0,rl,-1);
 75         for(int k=i;k<rl;k++) A0[k] = fi[k] = 0;
 76     }
 77     for(int i=0;i<=k;i++) A[i] = A0[i];
 78 }
 79 
 80 void dfs1(int kd,int now,int fa){
 81     for(auto it:g[kd][now]){
 82         if(it == fa) continue;
 83         dfs1(kd,it,now);
 84     }
 85     memset(A,0,sizeof(A));
 86     for(auto it:g[kd][now]){
 87         if(it == fa) continue;
 88         for(int i=0;i<=k-2;i+=2) A[i+2] = (A[i+2]+T[kd][it][i])%mod;
 89     }
 90     A[0] -= 1; if(A[0] < 0) A[0] += mod;
 91     for(int i=0;i<=k;i++){ A[i] *= -1; if(A[i] < 0) A[i] += mod;}
 92     INV();
 93     for(int i=0;i<=k;i++) T[kd][now][i] = A[i];
 94 }
 95 
 96 void dfs2(int kd,int now,int fa){
 97     memset(B,0,sizeof(B));
 98     for(auto it:g[kd][now]){
 99         if(it == fa) continue;
100         for(int i=0;i<=k-2;i+=2) B[i+2] = (B[i+2]+T[kd][it][i])%mod;
101     }
102     for(int i=0;i<=k-2;i+=2) B[i+2] = (B[i+2]+gi[kd][now][i])%mod;
103     for(auto it:g[kd][now]){
104         if(it == fa) continue;
105         for(int i=0;i<=k-2;i+=2) B[i+2] = (B[i+2]+mod-T[kd][it][i])%mod;
106         memset(A,0,sizeof(A));
107         for(int i=0;i<=k;i++) A[i] = (mod-B[i])%mod; A[0] = (1-A[0]+mod)%mod;
108         INV(); for(int i=0;i<=k;i++) gi[kd][it][i] = A[i];
109         for(int i=0;i<=k-2;i+=2) B[i+2] = (B[i+2]+T[kd][it][i])%mod;
110     }
111     memset(A,0,sizeof(A));
112     for(int i=0;i<=k;i++) A[i] = (mod-B[i])%mod; A[0] = (1-A[0]+mod)%mod;
113     INV(); for(int i=0;i<=k;i++) f[kd][now][i] = A[i];
114     for(auto it:g[kd][now]){
115         if(it == fa) continue;
116         dfs2(kd,it,now);
117     }
118 }
119 
120 void solve(int kd){
121     dfs1(kd,1,0);
122     dfs2(kd,1,0);
123 }
124 
125 void work(){
126     solve(0);
127     solve(1);
128     for(int i=1;i<=k;i++){
129         C[i][0] = C[i][i] = 1;
130         for(int j=1;j<i;j++) C[i][j] = (C[i-1][j] + C[i-1][j-1]) % mod;
131     }
132     int ans = 0;
133     for(int i=0;i<=k;i++){
134         int s1 = 0,s2 = 0;
135         for(int j=1;j<=n[0];j++) s1 += f[0][j][i],s1 %= mod;
136         for(int j=1;j<=n[1];j++) s2 += f[1][j][k-i],s2 %= mod;
137         int pp = (1ll*s1*s2)%mod;pp = (1ll*C[k][i]*pp)%mod;
138         ans += pp; ans %= mod;
139     }
140     printf("%d",ans);
141 }
142 
143 int main(){
144     read();
145     work();
146     return 0;
147 }

猜你喜欢

转载自www.cnblogs.com/Menhera/p/9277561.html
今日推荐