【洛谷】P4705 玩游戏-生成函数

版权声明:欢迎转载(请附带原链接)ヾ(๑╹◡╹)ノ" https://blog.csdn.net/corsica6/article/details/84770919

传送门:luoguP4705


题解

t t 次价值的期望: 1 n m i = 1 n j = 1 m ( a i + b j ) t \dfrac{1}{nm}\sum\limits_{i=1}^n\sum\limits_{j=1}^m\sum(a_i+b_j)^t

二项式定理展开一下:
t ! n m k = 0 t 1 k ! i = 1 n a i k 1 ( t k ) ! j = 1 m b j t k \dfrac{t!}{nm}\sum\limits_{k=0}^t\dfrac {1}{k!}\sum\limits_{i=1}^na_i^k\dfrac{1}{(t-k)!}\sum\limits_{j=1}^mb_j^{t-k}

所以需要构造的生成函数 F ( x ) , G ( x ) F(x),G(x) 的第 i i 项系数分别为 j = 1 n a j i , j = 1 m b j i \sum\limits_{j=1}^na_j^i,\sum\limits_{j=1}^mb_j^i

单独考虑 a j a_j F F 的贡献: 1 + a i x + a i 2 x + . . . 1+a_ix+a_i^2x+... ,生成函数为 1 1 a j x \dfrac1{1-a_jx}
所以 F ( x ) = i = 1 n 1 1 a i x F(x)=\sum\limits_{i=1}^n\dfrac1{1-a_ix}

1 a i x 1-a_ix 在分母不好看,积分转成 l n ln ,得到 F ( x ) = i = 1 n ln ( 1 a i x ) F(x)=\sum\limits_{i=1}^n\ln'(1-a_ix)

还是不好看。。。再把求导转到整个函数外面:

F ( x ) = ( ln ( i = 1 n ( 1 a i x ) ) ) F'(x)=(\ln(\prod\limits_{i=1}^n(1-a_ix)))'

F = x F + n F=-xF'+n

G G 同理,然后套多项式模板即可。


代码

#include<bits/stdc++.h>
#define mem(f) memset(f,0,sizeof(f))
using namespace std;
typedef long long ll;
const int N=1e5+10,M=2e6+10,mod=998244353,gen=3;

int n,m,bs,ivg,a[M],b[M],f[M],g[M];
int rv[M],s[20][M],frac[N],nv[N];

char cp,OS[100];
inline void rd(int &x)
{
    cp=getchar();x=0;
    for(;!isdigit(cp);cp=getchar());
    for(;isdigit(cp);cp=getchar()) x=(x<<3)+(x<<1)+(cp^48);
}

inline void ot(int x)
{
	int re=0;OS[0]='\n';
	for(;(!re)||x;x/=10) OS[++re]='0'+x%10;
	for(;~re;--re) putchar(OS[re]);
}

inline int ad(int x,int y){x+=y;return x>=mod?x-mod:x;}
inline int dc(int x,int y){x-=y;return x<0?x+mod:x;}

inline int fp(int x,int y)
{
    int re=1;
    for(;y;y>>=1,x=(ll)x*x%mod)
      if(y&1) re=(ll)re*x%mod;
    return re;
}

inline void ntt(int *e,int pr,int n)
{
    int i,j,k,ix,iy,ori,pd,g=pr?gen:ivg;
    for(i=1;i<n;++i) if(i<rv[i]) swap(e[i],e[rv[i]]);
    for(i=1;i<n;i<<=1){
        ori=fp(g,(mod-1)/(i<<1));
        for(j=0;j<n;j+=(i<<1))
            for(pd=1,k=0;k<i;++k,pd=(ll)pd*ori%mod){
                ix=e[j+k];iy=(ll)pd*e[j+k+i]%mod;
                e[j+k]=ad(ix,iy);e[j+k+i]=dc(ix,iy);
            }
    }
    if(pr) return;
    g=fp(n,mod-2);for(i=0;i<n;++i) e[i]=(ll)e[i]*g%mod;
}

inline void init(int n,int &len)
{
	int i,L=0;n<<=1;
	for(len=1;len<n;len<<=1) L++;
	for(i=1;i<len;++i) rv[i]=(rv[i>>1]>>1)|((i&1)<<(L-1));
}

inline void mul(int *f,int *g,int len)
{
	ntt(f,1,len);ntt(g,1,len);
	for(int i=0;i<len;++i) f[i]=(ll)f[i]*g[i]%mod;
	ntt(f,0,len);
}

inline void cal(int dep,int l,int r)
{
    if(l>r) return;
    if(l==r) {s[dep][0]=1;s[dep][1]=(mod-a[l])%mod;return;}
    int mid=(l+r)>>1,i,len,lim=mid-l+1;cal(dep+1,l,mid);
    for(i=0;i<=lim;++i) s[dep][i]=s[dep+1][i];cal(dep+1,mid+1,r);
	init(r-l+2,len);
	fill(s[dep]+lim+1,s[dep]+len,0);
    fill(s[dep+1]+r-mid+1,s[dep+1]+len,0);
    mul(s[dep],s[dep+1],len);
}

inline void gtder(int n,int *f,int *g)
{for(int i=1;i<n;++i) f[i-1]=(ll)g[i]*i%mod;f[n-1]=0;}

void gtinv(int n,int *f,int *g)
{
    if(n==1) {f[0]=fp(g[0],mod-2);return;}
    gtinv((n+1)>>1,f,g);int i,j,len;static int cont[M];
    for(i=0;i<n;++i) cont[i]=g[i];init(n,len);
	fill(cont+n,cont+len,0);ntt(f,1,len);ntt(cont,1,len);
	for(i=0;i<len;++i) f[i]=(ll)f[i]*dc(2,(ll)cont[i]*f[i]%mod)%mod;
	ntt(f,0,len);fill(f+n,f+len,0);
}

inline void gtln(int n,int *f,int *g)
{
    int i,j,len=1,L=0;static int der[M],nv[M];
    for(;len<n+n;len<<=1) L++;fill(der,der+len,0);fill(nv,nv+len,0);
    gtder(n,der,g);gtinv(n,nv,g);
    for(i=1;i<len;++i) rv[i]=(rv[i>>1]>>1)|((i&1)<<(L-1));
    mul(der,nv,len);
	for(i=0;i<n;++i) f[i]=der[i];fill(f+n,f+len,0);
}

int main(){
    int i,j,len,rev;
    rd(n);rd(m);ivg=fp(gen,mod-2);rev=fp((ll)n*m%mod,mod-2);
    for(i=1;i<=n;++i) rd(f[i]);for(i=1;i<=m;++i) rd(g[i]);
    
    frac[0]=frac[1]=nv[0]=nv[1]=1;
    for(rd(bs),i=2;i<=bs;++i)
      frac[i]=(ll)frac[i-1]*i%mod,nv[i]=(ll)(mod-mod/i)*nv[mod%i]%mod;
    for(i=2;i<=bs;++i) nv[i]=(ll)nv[i-1]*nv[i]%mod;

    len=max(bs,max(n,m))+1;
    memcpy(a,f,(n+2)<<2);cal(0,1,n);f[0]=n;
    gtln(len,b+1,s[0]);for(i=bs;i;--i) f[i]=(ll)nv[i]*(mod-b[i])%mod;
    
	mem(s[0]);mem(b);
    memcpy(a,g,(m+2)<<2);cal(0,1,m);g[0]=m;
    gtln(len,b+1,s[0]);for(i=bs;i;--i) g[i]=(ll)nv[i]*(mod-b[i])%mod;
    
    init(bs+1,len);
	fill(f+bs+1,f+len,0);fill(g+bs+1,g+len,0);////记得清空,不然只有66pts 
	mul(f,g,len);
    for(i=1;i<=bs;++i) ot((ll)frac[i]*f[i]%mod*(ll)rev%mod);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/corsica6/article/details/84770919