【ZJOI2017】仙人掌
参考博客:https://www.cnblogs.com/wfj2048/p/6636028.html
我们先求出\(dfs\)树(就是\(dfs\)一遍),然后问题就变成了树形\(DP\)。
我们先判断无解:就用定义来判断,如果一条边出现在多个环里面就无解。
然后我们将所有在环上的边拆了,因为这些边不可能再出现在一个新的环中。于是我们得到了一个森林。
我们设\(f_v\)表示以\(v\)为根的树得到仙人掌的方案数。\(ans=\prod_{v\ is\ root} f_v\)。
首先我们知道\(f_v=\prod f_{son_v}\)。也就是所有子树\(f\)之积。但是子树之间也可以有边相连。
我们观察子树之间的连边:一个子树最多只会连出去一条边,也就是只会两两配对连边。我们设\(g_v\)表示有\(v\)个子树两两之间连边的方案数,则:
\[ g_v=g_{v-1}+(v-1)*g_{v-2} \]
这个转移的意义是考虑第\(v\)个点连不连边,如果不连,方案数就是其他\(v-1\)个点连边的方案;如果连,就从之前的\(v-1\)个点中任选一个相连,剩下的再连边。
设\(|sn_v|\)表示\(v\)的儿子个数。则\(f_v=g_{|sn_v|}*\prod f_{son_v}\)。
我的理解是:假设\(v\)的两个儿子\(a,b\)的子树之间要连边,设\(e_a,e_b\)分别表示\(a\)到\(v\)和\(b\)到\(v\)的边,那么我们一定是将之前覆盖了\(e_a\)和\(e_b\)的两个环的起点(深度最深的那个点,如果没有,就是\(a,b\))连接在一起。所以答案是\(g_{|sn_v|}\)乘上所有子树中连边的方案数。
但是\(v\)的子树中还可以向\(v\)的父亲连边。我们就把\(v\)的父亲也看做一个子节点。所以:
\[ \begin{cases} f_v=g_{|sn_v|}*\prod f_{sn_v}\ (v\ is\ not\ root)\\ f_v=g_{|sn_v|+1}*\prod f_{sn_v}\ (v\ is\ root) \end{cases} \]
代码:
#include<bits/stdc++.h>
#define ll long long
#define N 1000005
using namespace std;
inline int Get() {int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}while('0'<=ch&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}return x*f;}
const ll mod=998244353;
int n,m;
int x[N],y[N];
struct graph {
int cnt,to[N<<2],nxt[N<<2];
int h[N<<1];
void add(int i,int j) {
to[++cnt]=j;
nxt[cnt]=h[i];
h[i]=cnt;
}
void Init() {
cnt=0;
for(int i=1;i<=n;i++) h[i]=0;
}
}s;
ll f[N],g[N];
int dfn[N],id;
int dep[N],fa[N];
int tim[N];
void dfs(int v) {
dfn[v]=++id;
for(int i=s.h[v];i;i=s.nxt[i]) {
int to=s.to[i];
if(dfn[to]) continue ;
fa[to]=v;
dep[to]=dep[v]+1;
dfs(to);
}
}
bool cmp(int a,int b) {return dep[a]<dep[b];}
bool vis[N];
void DP(int v,int flag) {
vis[v]=1;
f[v]=1;
int sn=0;
for(int i=s.h[v];i;i=s.nxt[i]) {
int to=s.to[i];
if(tim[to]>1) continue ;
if(to==fa[v]||fa[to]!=v) continue ;
sn++;
DP(to,0);
f[v]=f[v]*f[to]%mod;
}
if(flag) f[v]=f[v]*g[sn]%mod;
else f[v]=f[v]*g[sn+1]%mod;
}
int st[N];
void work() {
for(int i=1;i<=m;i++) {
int a=x[i],b=y[i];
if(dfn[a]<dfn[b]) swap(a,b);
while(a!=b) {
tim[a]++;
if(tim[a]>2) {
cout<<0<<"\n";
return ;
}
a=fa[a];
}
}
for(int i=1;i<=n;i++) st[i]=i;
sort(st+1,st+1+n,cmp);
ll ans=1;
for(int i=1;i<=n;i++) {
int now=st[i];
if(vis[now]) continue ;
DP(now,1);
ans=ans*f[now]%mod;
}
cout<<ans<<"\n";
}
void Init() {
s.Init();
for(int i=1;i<=n;i++) vis[i]=tim[i]=fa[i]=dfn[i]=dep[i]=0;
id=0;
}
int main() {
g[0]=g[1]=1;
for(int i=2;i<=500005;i++) g[i]=(g[i-1]+g[i-2]*(i-1))%mod;
int T=Get();
while(T--) {
n=Get(),m=Get();
Init();
for(int i=1;i<=m;i++) {
x[i]=Get(),y[i]=Get();
s.add(x[i],y[i]),s.add(y[i],x[i]);
}
dep[1]=1;
dfs(1);
work();
}
return 0;
}