6789. 2020.08.09【NOI2020】模拟T1 高三

题目

求长度为 n n n的,每个位置的取值为 [ 1 , k ] [1,k] [1,k]之间的整数的,连续上升子段长度不超过 m m m的序列的个数。

n ≤ 1 e 9 n\le 1e9 n1e9

m , k ≤ 5 e 4 m,k\le 5e4 m,k5e4


正解

神仙套路题……

首先可以想到一个DP:设 f n f_n fn表示长度为 n n n的序列的答案。方程: f n = ∑ f n − i ( k i ) f_n=\sum f_{n-i}\binom{k}{i} fn=fni(ik)

这个DP显然是错误的,因为可能会有两个连续的序列合在一起。

尝试去用一些奇技淫巧来把它容斥掉:状态转移的时候乘上一个系数 g i g_i gi

对于一个极长长度为 l e n len len的的连续序列,它的真实贡献为 [ 1 ≤ l e n ≤ m ] [1\le len \le m] [1lenm]。设 F ( x ) = ∑ [ 1 ≤ i ≤ m ] x i F(x)=\sum[1\le i\le m]x^i F(x)=[1im]xi

考虑这样一个序列会被计算多少次。于是有 ∑ i ≥ 1 G i ( x ) = F ( x ) \sum_{i\ge 1} G^i(x)=F(x) i1Gi(x)=F(x),解一下方程就可以得到 G ( x ) G(x) G(x)

g i = [ x i ] G ( x ) g_i=[x^i]G(x) gi=[xi]G(x),方程变成了 f n = ∑ f n − i g i ( k i ) f_n=\sum f_{n-i}g_i\binom{k}{i} fn=fnigi(ik)

然后它就对了…………

后面就是一个常系数线性递推的事情了。


代码

using namespace std;
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cassert>
#define N 524288
#define ll long long
#define mo 998244353
#define check(x) printf("%d\n",x);
ll qpow(ll x,ll y=mo-2){
    
    
	ll r=1;
	for (;y;y>>=1,x=x*x%mo)
		if (y&1)
			r=r*x%mo;
	return r;
}
int nN,re[N];
void setlen(int n){
    
    
	int bit=0;
	for (nN=1;nN<=n;nN<<=1,++bit);
	for (int i=1;i<nN;++i)
		re[i]=re[i>>1]>>1|(i&1)<<bit-1;
}
void clear(int A[],int n){
    
    memset(A,0,sizeof(int)*n);}
void copy(int A[],int a[],int n){
    
    clear(A,nN);for (int i=0;i<=n;++i) A[i]=a[i];}
void dft(int A[],int flag){
    
    
	for (int i=0;i<nN;++i)
		if (i<re[i])
			swap(A[i],A[re[i]]);
	static int wnk[N];
	for (int i=1;i<nN;i<<=1){
    
    
		ll wn=qpow(3,flag==1?(mo-1)/(2*i):mo-1-(mo-1)/(2*i));
		wnk[0]=1;
		for (int k=1;k<i;++k)
			wnk[k]=wnk[k-1]*wn%mo;
		for (int j=0;j<nN;j+=i<<1)
			for (int k=0;k<i;++k){
    
    
				ll x=A[j+k],y=(ll)A[j+k+i]*wnk[k];
				A[j+k]=(x+y)%mo;
				A[j+k+i]=(x-y)%mo;
			}
	}
	if (flag==-1)
		for (int i=0,invn=qpow(nN);i<nN;++i)
			A[i]=(ll)A[i]*invn%mo;
	for (int i=0;i<nN;++i)
		A[i]=(A[i]+mo)%mo;
}
void multi(int c[],int a[],int b[],int n,int an=-1,int bn=-1){
    
    
	if (an==-1) an=n-1;
	if (bn==-1) bn=n-1;
	static int A[N],B[N],C[N];
	setlen(an+bn);
	copy(A,a,an),dft(A,1);
	if (a==b)
		for (int i=0;i<nN;++i)
			C[i]=(ll)A[i]*A[i]%mo;
	else{
    
    
		copy(B,b,bn),dft(B,1);
		for (int i=0;i<nN;++i)
			C[i]=(ll)A[i]*B[i]%mo;
	}
	dft(C,-1);
	for (int i=0;i<=min(n-1,an+bn);++i)
		c[i]=C[i];
}
void getinv(int c[],int a[],int n){
    
    
	static int b[N],g[N];
	int nn=1;for (;nn<n;nn<<=1);
	clear(b,nn),clear(g,nn);
	b[0]=qpow(a[0]);
	for (int i=1;i<n;i<<=1){
    
    
		multi(g,b,b,i*2,i-1,i-1);
		multi(g,g,a,i*2,i*2-1,min(n,i*2-1));
		for (int j=0;j<i*2;++j)
			b[j]=(2ll*b[j]-g[j]+mo)%mo;
	}
	for (int i=0;i<n;++i)
		c[i]=b[i];
}
void getrev(int A[],int a[],int n){
    
    for (int i=0;i<=n;++i) A[i]=a[n-i];}
void getdiv(int c[],int a[],int b[],int n,int m){
    
    
	static int A[N],B[N],C[N];
	clear(B,n-m+1),clear(A,n-m+1);
	getrev(A,a,n),getrev(B,b,m);
	getinv(B,B,n-m+1);
	multi(C,A,B,n-m+1);
	getrev(c,C,n-m);
}
void getmod(int c[],int a[],int b[],int n,int m){
    
    
	static int D[N];
	getdiv(D,a,b,n,m);
	multi(D,D,b,n,n-m,m);
	for (int i=0;i<m;++i)
		c[i]=(a[i]-D[i]+mo)%mo;
}
int n,m,k;
int fac[N],ifac[N];
void initC(int n){
    
    
	fac[0]=1;
	for (int i=1;i<=n;++i)
		fac[i]=(ll)fac[i-1]*i%mo;
	ifac[n]=qpow(fac[n]);
	for (int i=n-1;i>=0;--i)
		ifac[i]=(ll)ifac[i+1]*(i+1)%mo;
}
ll C(int m,int n){
    
    return m<n?0:m==n?1:(ll)fac[m]*ifac[n]%mo*ifac[m-n]%mo;}
int f[N],g[N],a[N];
int q[N],mx;
void chang(int n){
    
    
	if (n==0){
    
    
		q[mx=0]=1;
		return;
	}
	if (n&1){
    
    
		chang(n-1);
		for (int i=mx;i>=0;--i)
			q[i+1]=q[i];
		q[0]=0;
		if (mx+1<k)
			mx++;
		else{
    
    
			getmod(q,q,g,mx+1,k);
			mx=k-1;
		}
	}
	else{
    
    
		chang(n>>1);
		multi(q,q,q,2*mx+1,mx,mx);
		if (2*mx<k)
			mx*=2;
		else{
    
    
			getdiv(f,q,g,mx*2,k);
			getmod(q,q,g,mx*2,k);
			mx=k-1;
		}
	}
}
int main(){
    
    
	//freopen("in.txt","r",stdin);
	freopen("senior.in","r",stdin);
	freopen("senior.out","w",stdout);
	scanf("%d%d%d",&n,&m,&k);
	initC(k);
	for (int i=1;i<=m;++i)
		f[i]=1;
	g[0]=1;
	for (int i=1;i<=m;++i)
		g[i]=f[i];
	getinv(g,g,k+1);
	multi(g,f,g,k+1);	
	
	for (int i=1;i<=k;++i)
		a[i]=(ll)g[i]*C(k,i)%mo;
	g[k]=1;
	for (int i=1;i<=k;++i)
		g[k-i]=(mo-a[i])%mo;
	chang(n-(-k+1));
	ll ans=q[k-1];
	printf("%lld\n",ans);
	/*
	dp[0]=1;
	for (int i=1;i<=n;++i){
		for (int j=1;j<=k && i-j>=0;++j)
			(dp[i]+=dp[i-j]*C(k,j)%mo*g[j])%=mo;
	}
	*/
	return 0;
}

猜你喜欢

转载自blog.csdn.net/A1847225889/article/details/107914804