[联合集训6-18]不同班级 容斥+分治NTT

我们设 f ( x ) 是至少有 x 个人与本班人匹配的方案数,那么根据容斥就有 A n s = i = 0 m ( 1 ) i f ( i ) ( n i ) !
a i = b i = 1 的时候是经典的错排问题, f ( x ) = ( n x )
对于一般的情况,我们枚举班级 i k i 个人与本班人匹配,就有:

f ( x ) = i = 1 n k i = 0 min ( a i , b i ) ( a i k 1 ) ( b i k 1 ) k 1 ! [ Σ k = x ]

发现这就是一个类似背包的东西,直接分治NTT就能求出所有的 f ( x )

代码:

#include<iostream>
#include<cstdio>
#include<cstring>
#define N 131100
#define ll long long 
#define up(x,y) x=(x+(y))%mod
using namespace std;
const int mod=998244353;
int n,a[N],b[N],c[N],w[20][N],r[N],ans;
ll fac[N],ifac[N],p[N],q[N];
ll ksm(ll a,ll b)
{
    ll r=1;
    for(b=(b+mod-1)%(mod-1);b;b>>=1,a=a*a%mod)
        if(b&1) r=r*a%mod;
    return r;   
}
ll C(ll a,ll b)
{
    return a<b?0:fac[a]*ifac[b]%mod*ifac[a-b]%mod;
}
void ntt(ll a[],int m,int dft)
{
    for(int i=0;i<m;i++)
        r[i]=(r[i>>1]>>1)|((i&1)*(m>>1));
    for(int i=0;i<m;i++)
        if(i<r[i]) swap(a[i],a[r[i]]);      
    for(int i=1;i<m;i<<=1)      
    {
        ll wn=ksm(3,(mod-1)/(i<<1)*dft);
        for(int j=0;j<m;j+=(i<<1))
        {
            ll wk=1;
            for(int k=j;k<j+i;k++)
            {
                ll x=a[k],y=a[k+i]*wk%mod;
                a[k]=(x+y)%mod;
                a[k+i]=(x-y+mod)%mod;
                wk=wk*wn%mod;
            }
        }
    }
    if(dft==-1)
        for(int i=0,inv=ksm(m,mod-2);i<m;i++)
            a[i]=a[i]*inv%mod;
}

void solve(int d,int l,int r,int L,int R)
{
    if(l==r)
    {
        for(int i=L+1;i<=R;i++)
            w[d][i]=C(a[l],i-L)*C(b[l],i-L)%mod*fac[i-L]%mod;
        return ;    
    }
    int mid=(l+r)>>1,m=1;
    while(m<=R-L) m<<=1;
    solve(d+1,l,mid,L,c[mid]);
    solve(d+1,mid+1,r,c[mid],R);
    for(int i=1;i<m;i++)
        p[i]=q[i]=0;
    p[0]=q[0]=1;
    for(int i=L+1;i<=c[mid];i++)
        p[i-L]=w[d+1][i];
    for(int i=c[mid]+1;i<=R;i++)
        q[i-c[mid]]=w[d+1][i];
    ntt(p,m,1);ntt(q,m,1);
    for(int i=0;i<m;i++)
        p[i]=p[i]*q[i]%mod;
    ntt(p,m,-1);
    for(int i=L+1;i<=R;i++)
        w[d][i]=p[i-L];                 
}
int main()
{
    scanf("%d",&n);
    int tot=0;
    for(int i=1;i<=n;i++)
    {
        scanf("%d%d",&a[i],&b[i]);
        c[i]=c[i-1]+min(a[i],b[i]);
        tot+=a[i];
    }
    fac[0]=1;
    for(int i=1;i<=100000;i++)
        fac[i]=fac[i-1]*i%mod;
    ifac[100000]=ksm(fac[100000],mod-2);
    for(int i=99999;i>=0;i--)
        ifac[i]=ifac[i+1]*(i+1)%mod;
    solve(0,1,n,0,c[n]);
    w[0][0]=1;
    for(int i=0;i<=c[n];i++)
        up(ans,w[0][i]*fac[tot-i]%mod*((i&1)?mod-1:1));
    printf("%lld",ans);                 
    return 0;
}

猜你喜欢

转载自blog.csdn.net/dofypxy/article/details/80752640