dtoj#3871. game

题目描述:

给定一棵 n 个点的树。

每次等概率选定一个联通块,将该联通内的所有点都捶一遍。再从选定的联通块中随机选取一个点,删掉该点及其连边。

反复操作,直至没有剩余点,求所有点被捶次数的期望 ×n! ,答案对 109+7 取模。

算法标签:点分治,fft

思路:

考虑一个点A会在另一个点B被选中时被锤到,当且仅当A和B路径上的点被锤到的顺序排列,B在最前,所以被锤到的概率是(B到A路径上的点数包括A,B)分之1。

用点分治和fft维护每一种路径长度路径数。

以下代码:

#include<bits/stdc++.h>
#define il inline
#define LL long long
#define db double
#define pi acos(-1)
#define _(d) while(d(isdigit(ch=getchar())))
using namespace std;
const int N=2e5+5,p=1e9+7;
bool vis[N];int rt,d[N],ans,res[N];
int v[N],po[N],t,l,size,sz[N],mx[N],a[N];
int n,head[N],ne[N<<1],to[N<<1],maxd,cnt,jc[N],ny[N];
struct cp{
    db r,i;
    cp(){};cp(db _r,db _i){r=_r;i=_i;}
    friend cp operator+(cp t1,cp t2){return cp(t1.r+t2.r,t1.i+t2.i);}
    friend cp operator-(cp t1,cp t2){return cp(t1.r-t2.r,t1.i-t2.i);}
    friend cp operator*(cp t1,cp t2){return cp(t1.r*t2.r-t1.i*t2.i,t1.i*t2.r+t2.i*t1.r);}
    il void clear(){r=i=0;}
}b[N];
il int read(){
   int x,f=1;char ch;
   _(!)ch=='-'?f=-1:f;x=ch^48;
   _()x=(x<<1)+(x<<3)+(ch^48);
   return f*x;
}
il void ins(int x,int y){
    ne[++cnt]=head[x];
    head[x]=cnt;to[cnt]=y;
}
il int mu(int x,int y){
    if(x+y>=p)return x+y-p;
    return x+y;
}
il int ksm(LL a,int y){
    LL b=1;
    while(y){
        if(y&1)b=b*a%p;
        a=a*a%p;y>>=1;
    }
    return b;
}
il void dft(cp *x,int o){
    for(int i=0;i<t;i++)if(i<v[i])swap(x[i],x[v[i]]);
    for(int i=1;i<t;i<<=1){
        cp wn=cp(cos(pi/i),o*sin(pi/i));
        for(int j=0;j<t;j+=i<<1){
            cp w=cp(1,0);
            for(int k=0;k<i;k++,w=w*wn){
                cp A=x[j+k],B=x[i+j+k]*w;
                x[j+k]=A+B;x[i+j+k]=A-B;
            }
        }
    }
    if(o<0)for(int i=0;i<t;i++)x[i].r/=t;
}
il void getrt(int x,int fa){
    sz[x]=1;mx[x]=0;
    for(int i=head[x];i;i=ne[i]){
        if(fa==to[i]||vis[to[i]])continue;
        getrt(to[i],x);sz[x]+=sz[to[i]];
        if(sz[to[i]]>mx[x])mx[x]=sz[to[i]];
    }
    if(size-sz[x]>mx[x])mx[x]=size-sz[x];
    if(mx[rt]>mx[x])rt=x;
}
il void dfs(int x,int fa){
    if(d[x]>maxd)maxd=d[x];a[d[x]]++;
    for(int i=head[x];i;i=ne[i]){
        if(fa==to[i]||vis[to[i]])continue;
        d[to[i]]=d[x]+1;dfs(to[i],x);
    }
}
il void cal(int x,int vv){
    t=1;l=0;
    while(t<=(maxd<<1))t<<=1,l++;
    for(int i=0;i<t;i++)v[i]=(v[i>>1]>>1)|((i&1)<<l-1);
    for(int i=0;i<=maxd;i++)b[i].r=a[i];
    dft(b,1);
    for(int i=0;i<t;i++)b[i]=b[i]*b[i];
    
    
    dft(b,-1);
    
    for(int i=0;i<=(maxd<<1);i++)res[i+1]=mu(res[i+1],((LL)(b[i].r+.5)*vv+p)%p);
    for(int i=0;i<=maxd;i++)a[i]=0;maxd=0;
    for(int i=0;i<t;i++)b[i].clear();
}
il void solve(int x){
    vis[x]=1;d[x]=0;
    dfs(x,0);cal(x,1);
    for(int i=head[x];i;i=ne[i]){
        if(vis[to[i]])continue;
        d[to[i]]=1;
        dfs(to[i],0);cal(to[i],-1);
    }
    for(int i=head[x];i;i=ne[i]){
        if(vis[to[i]])continue;
        size=sz[to[i]];rt=0;
        getrt(to[i],x);solve(rt);
    }
}
int main()
{
    n=read();mx[0]=n;
    for(int i=1;i<n;i++){
        int x=read(),y=read();
        ins(x,y);ins(y,x);
    }
    jc[0]=1;for(int i=1;i<=n;i++)jc[i]=1ll*i*jc[i-1]%p;
    for(int i=1;i<=n;i++)ny[i]=ksm(i,p-2);
    t=1;l=0;while(t<=n)t<<=1,l++;
    size=n;getrt(1,0);solve(rt);
    for(int i=1;i<=n;i++)ans=mu(ans,1ll*res[i]*ny[i]%p);
    printf("%d\n",1ll*ans*jc[n]%p);
    return 0;
}
View Code

猜你喜欢

转载自www.cnblogs.com/Jessie-/p/10410116.html
今日推荐