hdu 5519 Kykneion asma 2015沈阳现场(状压dp + 容斥 + 组合数)

题意

求只用 \(\{0, 1, 2, 3, 4\}\) 五个数每个数最多使用 \(a_i\) 次且没有前导零所能组成的 \(n\) 位数的数量,答案模 \(1e9+7\)

思路

貌似可以母函数+FFT (但是这两个我都不会

也可以用状压dp + 容斥原理来求解。

这个题求合法数不好求,所以可以计算总数减非法数来求答案。

设计状态 \(dp(i, j, s)\) 表示:

最终有 \(j\) 个数字使用次数会超过对应上限,到第 \(i\) 位时,超出上限次数的状态为 \(s\)

我们可以得到一个显然的转移方程:

\(dp[i][j][s] = dp[i-1][j][s] \times (5-(j-num[s])) + \sum_{i-a[k]-1} dp[i-a[k]-1][j][s \wedge (1<<k)] \times \dbinom{i-1}{a[k]}\)

最后容斥一下就好。

对于前导零的处理方法:如果有零,则 a[0]--, n-- 再dp一遍,两次答案相减即可。

#include <bits/stdc++.h>

using namespace std;

typedef long long ll;
const int maxn = 7010;
const int maxm = (1<<5)+10;
const int all = maxm-11;
const int mod = 1e9+7;

inline void add(ll &x, ll y) { x += y; if(x >= mod) x -= mod; }

ll fac[maxn], inv[maxn], num[maxm];
ll ksm(ll a, ll n) {
    ll res = 1;
    while(n) {
        if(n & 1) res = res*a%mod;
        a = a*a%mod;
        n >>= 1;
    } return res;
}
void init() {
    fac[0] = 1; for (int i = 1; i < maxn; ++i) fac[i] = fac[i-1]*i%mod;
    inv[maxn-1] = ksm(fac[maxn-1], mod-2);
    for (int i = maxn-2; i >= 0; --i) inv[i] = inv[i+1]*(i+1)%mod;
    for (int mask = 0; mask < maxm; ++mask) num[mask] = num[mask>>1] + (mask&1);
}
ll C(int n, int m) { return fac[n]*inv[m]%mod*inv[n-m]%mod; }

ll dp[maxn][6][maxm], a[10];
int n, T, cas;

ll DP() {
    memset(dp, 0, sizeof(dp));
    for (int i = 1; i <= 5; ++i) dp[0][i][0] = 1;
    for (int j = 1; j <= 5; ++j) {
        for (int i = 1; i <= n; ++i) {
            for (int mask = 0; mask <= all; ++mask) {
                if(num[mask] > j) continue;
                dp[i][j][mask] = dp[i-1][j][mask]*(5-(j-num[mask]))%mod;
                for (int k = 0; k < 5; ++k) {
                    if((mask & (1<<k)) && i>a[k])
                        add(dp[i][j][mask], dp[i-a[k]-1][j][mask^(1<<k)]*C(i-1, a[k])%mod);
                }
            }
        }
    }
    ll res = ksm(5, n);
    for (int mask = 1; mask <= all; ++mask) {
        if(num[mask]&1) {
            res = res-dp[n][num[mask]][mask];
            if(res < 0) res += mod;
        } else {
            res = res + dp[n][num[mask]][mask];
            if(res >= mod) res -= mod;
        }
    }
    return res;
}

int main() {
    init();
    scanf("%d", &T);
    while(T--) {
        scanf("%d", &n);
        for (int i = 0; i < 5; ++i) scanf("%lld", a+i);
        ll ans = DP();
        if(a[0]) {
            --a[0], --n;
            ans -= DP();
            if(ans < 0) ans += mod;
        }
        printf("Case #%d: %lld\n", ++cas, ans);
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/acerkoo/p/11605414.html