设一个二维数组
c,不能为0;
令
c[0][j]=c[1][j]
再取一个数组
A和数组
B,可以为0
令
a[i]=A[i]+c[0][i]
b[i]=B[i]+c[1][i]
如给定
N=4,M=4,k=2
你可以给
c数组赋值为
c[0][1]=c[0][2]=c[1][1]=c[1][2]=1;
那么你的
A数组可以赋值为
A[1]=2,A[2]=0
B数组可以赋值为
B[1]=0,B[3]=2;
这样构造出一种
ab数组便是
a={
3,
1}
b={
1,
3}
如果确定了
c数组,且
i=1∑kc[0][i]=r,那么
A,B数组的组合数**(注意A,B数组中元素都可为0**)便是
C(N−r+k−1,k−1)∗C(M−r+k−1,k−1),
C(n,m)表示
n个取
m个的组合数(因为
sum(c[0])+sum(A)=N且sum(c[1])+sum(B)=M才是合法的)
对这样的c数组计数(即上面的组合数乘积)相当于对于所有任意
i=1...k的都有
min(ai,bi)>=c[0][i]的ab数组都计了一次数
请往下看
考虑一个确定
a,b数列的
MIN数列,其中
MIN[i]=min(a[i],b[i])
显然这个
a,b数列产生的贡献是
i=1∏=kMIN[i](这个是定义的)
**接下来来讨论一个确定的
MIN数组
MINx,并且
a,b也是确定的 **
我们尝试把乘法拆成计数贡献。
那么显然对于特定的
MINx数组,其应该要被计数的次数便是
i=1∏kMINx[i].
那么我们把这个计数分到每一个特定的
c数组和
A,B组合中去便可以了。
想一下,什么样的
c数组跟
A,B数组组成的
a,b数组会形成这个
MINx数组(注意,a,b数组也是特定的了)呢。
显然是对于对于任意
i=1..k都有
c[0][i]<=MINX[i]
这样的
c数组跟
A,B的组合形成的
a,b数组的
MIN数组中一定会有跟
MINx数组一样的。有且仅有一个
这样的
c数组会有多少个呢?
显然会有
i=1∏kMINx[i]个,因为你只要
c[0]数组的每一位数都对应小于等于
MINX就行了。
也就是说问题完美的转成计数问题了,即是转换成所有
c数组产生的
A,B数组的种数之和。
容易知道对于和相同的
c数组对应的
A,B数组个数是一样的,相同和如
r的种数有
C(r−1,k−1)插板法
以
N=3,M=3,K=2为例
若
c[0]={
1,
1},则对应的
A,B数组组合有:
A1={
0,
1},
B1={
0,1},则
a1=1,2,
b1=1,2,
MIN1=1,2
A2=0,1,B2=1,0,则
a2=1,2,
b2=2,1,
MIN2=1,1
接下来有
a3=2,1,
b3=1,2,
MIN3=1,1
a4=2,1,
b4=2,1,
MIN=2,1
这样对这个
c计数的话就相当于对以上四种
a,b都计了一次数了
手动模拟一下
c的其他数组对照上面的讲解就会迎刃而解了。
还是看不懂就是我不会解释
代码:
#include<bits/stdc++.h>
#define ll long long
#define endl '\n'
#define rep(i,l,r) for(int i=l;i<=r;i++)
#define per(i,r,l) for(int i=r;i>=l;i--)
const int MX=1e6+7;
const int mod=998244353;
const double pi=3.1415926535897932384;
double isp=1e-13;
using namespace std;
ll qpow(ll a,ll b,ll MOD=mod){for(ll ans=1;;a=a*a%MOD,b>>=1){if(b&1)ans=ans*a%MOD;if(!b)return ans;}}
ll inv(ll a,ll MOD=mod){return qpow(a,MOD-2,MOD);}//要求MOD为质数
ll exgcd(ll a,ll b,ll &x,ll &y){if(b==0){x=1,y=0;return a;}ll ret=exgcd(b,a%b,y,x);y-=a/b*x;return ret;}
ll getInv(int a,int mod){ll x,y;ll d=exgcd(a,mod,x,y);return d==1?(x%mod+mod)%mod:-1;}//求a在mod下的逆元,不存在逆元返回-1,不要求MOD为质数
ll p[MX],np[MX],in[MX];
ll C(ll n,ll m)
{
return p[n]*np[m]%mod*np[n-m]%mod;
}
int main()
{
ios::sync_with_stdio(0),cin.tie(0);
p[0]=np[0]=in[1]=1;
for(int i=1;i<MX;i++)p[i]=p[i-1]*i%mod;
for(int i=2;i<MX;i++)in[i]=mod-(mod/i)*in[mod%i]%mod;
for(int i=1;i<MX;i++)np[i]=np[i-1]*in[i]%mod;
int t;
cin>>t;
while(t--)
{
int n,m,k;
cin>>n>>m>>k;
ll ans=0;
for(int i=k;i<=min(n,m);i++)
{
ans=(ans+1ll*C(i-1,k-1)*C(n-i+k-1,k-1)%mod*C(m-i+k-1,k-1))%mod;
}
cout<<ans<<endl;
}
return 0;
}