[PKUWC2018]Slay the Spire

题目

我竟然能做出九老师的组合计数题,尽管这题很呆

我们先考虑一个简单的问题,如果给定你要选出来的\(m\)张卡牌,如何做到攻击伤害最高

非常显然,因为保证了强化牌上的数字大于\(1\),所以我们优先选择那些强化牌,毕竟最小的一张只能翻两倍的强化牌都要比选择攻击牌好;选完强化牌之后剩下的攻击牌自然是越大越好

当然了,我们也不可能选出\(k\)张强化牌来,这样什么伤害都造不成,所以如果强化牌数量大于\(k\),我们就选择前\(k-1\)大的强化牌,配上最大的攻击牌

我们发现给定的强化牌和攻击牌都是无序的,我们先排序再说

之后我们就可以大力\(dp\)

\(dp_{i,j}\)表示前\(i\)张强化牌里选出\(j\)张所有强化牌乘积的和,\(f_{i,j}\)表示前\(i\)张强化牌里选择了\(j\)张且第\(i\)张一定被选择的乘积和

显然有转移

\[dp_{i,j}=dp_{i-1,j}+dp_{i-1,j-1}\times a_i\]

\[f_{i,j}=dp_{i-1,j-1}\times a_i\]

攻击牌这边我们设\(h_{i,j}\)表示前\(i\)张攻击牌里选择了\(j\)张的所有方案的和,\(g_{i,j}\)表示强迫选择第\(i\)

也有转移

\[h_{i,j}=h_{i-1,j}+h_{i-1,j-1}+C_{i-1}^{j-1}\times b_i\]

\[g_{i,j}=h_{i-1,j-1}+C_{i-1}^{j-1}\times b_i\]

现在我们考虑枚举选择了多少张强化牌

设选择了\(i\)张强化牌,自然也就需要选择\(m-i\)张攻击牌

如果\(i<k\),那么这些强化牌都是要直接用的,选择的\(m-i\)张攻击牌里能被用于打出的\(k\)张的也就只有\(k-i\)张,而剩下的\(m-i-(k-i)=m-k\)张牌我们在后面随便选一下

于是答案就是

\[\sum_{j=0}^nf_{j,i}\times \sum_{j=0}^ng_{j,k-i}C_{n-j}^{m-k}\]

如果\(i>=k\),那么我们也只能选择\(k-1\)张强化牌,剩下的多选出来的\(i-k+1\)张我们还是从后面的位置里选出来,实际上需要的攻击牌也就只有\(1\)张,需要额外多选的攻击牌也就只有\(m-i-1\)

答案就是

\[\sum_{j=0}^nf_{j,k-1} C_{n-j}^{i-k+1}\times \sum_{j=0}^ng_{j,1}C_{n-j}^{m-i-1}\]

代码

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define re register
#define LL long long
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
inline int read() {
    char c=getchar();int x=0;while(c<'0'||x>'9') c=getchar();
    while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
const int mod=998244353;
const int maxn=3005;
int c[maxn][maxn];
int dp[maxn][maxn],f[maxn][maxn];
int g[maxn][maxn],h[maxn][maxn];
int T,n,m,k;
int a[maxn],b[maxn];
inline int C(int n,int m) {if(m>n) return 0;return c[n][m];}
inline int cmp(int A,int B) {return A>B;}
int main() {
    T=read();int M=maxn-5;
    for(re int i=0;i<=M;i++) c[i][0]=c[i][i]=1;
    for(re int i=1;i<=M;i++)
        for(re int j=1;j<i;j++)
            c[i][j]=(c[i-1][j-1]+c[i-1][j])%mod;
    while(T--) {
        n=read(),m=read(),k=read();
        for(re int i=1;i<=n;i++) a[i]=read();
        for(re int i=1;i<=n;i++) b[i]=read();
        std::sort(a+1,a+n+1,cmp);std::sort(b+1,b+n+1,cmp);
        dp[0][0]=1;
        for(re int i=1;i<=n;i++)
            for(re int j=0;j<=i;j++) {
                dp[i][j]=dp[i-1][j];
                if(j) dp[i][j]=(dp[i][j]+1ll*dp[i-1][j-1]*a[i]%mod)%mod;
            }
        f[0][0]=1;
        for(re int i=1;i<=n;i++) 
            for(re int j=1;j<=i;j++) 
                f[i][j]=1ll*dp[i-1][j-1]*a[i]%mod;
        h[0][0]=0;
        for(re int i=1;i<=n;i++)
            for(re int j=0;j<=i;j++) {
                h[i][j]=h[i-1][j];
                if(j) h[i][j]=(h[i][j]+1ll*C(i-1,j-1)*b[i]%mod+h[i-1][j-1])%mod;
            }
        for(re int i=1;i<=n;i++)
            for(re int j=1;j<=i;j++)
                g[i][j]=(1ll*C(i-1,j-1)*b[i]%mod+h[i-1][j-1])%mod;
        int ans=0;
        for(re int i=0;i<m;i++) {
            int now=0;
            for(re int j=0;j<=n;j++) {
                if(i<=k-1) now=(now+f[j][i])%mod;
                    else now=(now+1ll*f[j][k-1]*C(n-j,i-k+1)%mod)%mod;
            }
            int tot=0,res=0;
            if(i<=k-1) res=k-i;else res=1;
            for(re int j=0;j<=n;j++) 
                tot=(tot+1ll*g[j][res]*C(n-j,m-i-res)%mod)%mod;
            ans=(ans+1ll*now*tot%mod)%mod;
        }
        printf("%d\n",ans);
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/asuldb/p/10758306.html
今日推荐