AIM Tech Round 3 (Div. 1) E. Student's Camp DP

版权声明:xgc原创文章,未经允许不得转载。 https://blog.csdn.net/xgc_woker/article/details/82946771

Description
给你一个n*m的方块,每天靠边都有p的概率被吹掉,问最后每一层都有方块且互相连通的概率,以逆元形式输出。


Sample Input
2 2
1 2
1


Sample Output
937500007


首先考虑设 h [ i ] [ l ] [ r ] h[i][l][r] 为第i行l~r的块都在,其他吹掉,前面合法的概率。
因为每一层是单独分开的吗,你可以预处理出一个pl,pr分别表示1 ~ l-1被吹掉,l保留的概率,r+1 ~ m被吹掉,r保留的概率,就可以算出一个区间没被吹的概率。
只要枚举一个范围转移即可,时间复杂度 O ( n m 4 ) O(nm^4)
考虑引入前缀和优化,设:
f [ i ] [ r ] = l = 1 r h [ i ] [ l ] [ r ] f[i][r]=\sum_{l=1}^rh[i][l][r]
S l [ i ] [ r ] = l = 1 r f [ i ] [ l ] Sl[i][r]=\sum_{l=1}^rf[i][l]
S r Sr 根据相当于 S l Sl 反过来。
那么每次转移的时候我们多维护一个这个,考虑用总数减去不合法的,可得DP方程:
h [ i ] [ l ] [ r ] = p l [ l ] p r [ r ] ( S l [ i 1 ] [ m ] S l [ i 1 ] [ l 1 ] S r [ i 1 ] [ r + 1 ] ) h[i][l][r]=pl[l]*pr[r]*(Sl[i-1][m]-Sl[i-1][l-1]-Sr[i-1][r+1])
这样就能做到 O ( n m 2 ) O(nm^2)
考虑其实可以直接维护 f f
f [ i ] [ r ] = l = 1 r h [ i ] [ l ] [ r ] f[i][r]=\sum_{l=1}^rh[i][l][r]
= l = 1 r p l [ l ] p r [ r ] ( S l [ i 1 ] [ m ] S l [ i 1 ] [ l 1 ] S r [ i 1 ] [ r + 1 ] ) =\sum_{l=1}^rpl[l]*pr[r]*(Sl[i-1][m]-Sl[i-1][l-1]-Sr[i-1][r+1])
= l = 1 r p l [ l ] p r [ r ] S l [ i 1 ] [ m ] p l [ l ] p r [ r ] S l [ i 1 ] [ l 1 ] p l [ l ] p r [ r ] S r [ i 1 ] [ r + 1 ] =\sum_{l=1}^rpl[l]*pr[r]*Sl[i-1][m]-pl[l]*pr[r]*Sl[i-1][l-1]-pl[l]*pr[r]*Sr[i-1][r+1]
于是你可以处理一个 l = 1 r p l [ l ] \sum_{l=1}^rpl[l] l = 1 r p l [ l ] S [ i 1 ] [ l 1 ] \sum_{l=1}^rpl[l]*S[i-1][l-1]
于是转移就可以做到 O ( n m ) O(nm)
由于你前后好像没什么关系,于是我只开了一维。


#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;
typedef long long LL;
const LL mod = 1e9 + 7;
int _min(int x, int y) {return x < y ? x : y;}
int _max(int x, int y) {return x > y ? x : y;}
int read() {
	int s = 0, f = 1; char ch = getchar();
	while(ch < '0' || ch > '9') {if(ch == '-') f = -1; ch = getchar();}
	while(ch >= '0' && ch <= '9') s = s * 10 + ch - '0', ch = getchar();
	return s * f;
}

LL jc[110000], inv[110000];
LL f[1600], suml[110000], sumr[110000], gl[1600];
LL pl[1600], pr[1600], S[1600];

LL pow_mod(LL a, LL k) {
	LL ans = 1;
	while(k) {
		if(k & 1) (ans *= a) %= mod;
		(a *= a) %= mod; k /= 2;
	} return ans;
}

LL C(int n, int m) {return jc[n] * inv[m] % mod * inv[n - m] % mod;}

int main() {
	int n = read(), m = read();
	int a = read(), b = read();
	int k = read();
	LL P = (LL)a * pow_mod(b, mod - 2) % mod;
	suml[0] = 1LL; for(int i = 1; i <= k; i++) suml[i] = suml[i - 1] * P % mod;
	P = (LL)(b - a) * pow_mod(b, mod - 2) % mod;
	sumr[0] = 1LL; for(int i = 1; i <= k; i++) sumr[i] = sumr[i - 1] * P % mod;
	jc[0] = inv[0] = 1; for(int i = 1; i <= k; i++) jc[i] = (LL)jc[i - 1] * i % mod;
	inv[k] = pow_mod(jc[k], mod - 2);
	for(int i = k - 1; i >= 1; i--) inv[i] = (LL)inv[i + 1] * (i + 1) % mod;
	for(int i = 1; i <= m; i++) {
		if(i > k + 1) pl[i] = 0;
		else pl[i] = (suml[i - 1] * sumr[k - i + 1] % mod) * C(k, i - 1) % mod;
		pr[m - i + 1] = pl[i];
	} S[0] = 0; for(int i = 1; i <= m; i++) S[i] = (S[i - 1] + pl[i]) % mod;
	for(int i = 1; i <= m; i++) f[i] = S[i] * pr[i] % mod;
	suml[0] = sumr[0] = sumr[m + 1] = 0;
	for(int i = 1; i <= m; i++) suml[i] = (suml[i - 1] + f[i]) % mod, sumr[m - i + 1] = suml[i];
	for(int i = 1; i <= m; i++) gl[i] = (gl[i - 1] + pl[i]) % mod;
	S[0] = 0; for(int i = 1; i <= m; i++) S[i] = (S[i - 1] + pl[i] * suml[i - 1] % mod) % mod;
	for(int i = 2; i <= n; i++) {
		for(int r = 1; r <= m; r++) {
			f[r] = (gl[r] * pr[r] % mod) * (suml[m] - sumr[r + 1]) % mod;
			(f[r] -= S[r] * pr[r] % mod) %= mod;
			(f[r] += mod) %= mod;
		} for(int r = 1; r <= m; r++) suml[r] = (suml[r - 1] + f[r]) % mod, sumr[m - r + 1] = suml[r];
		for(int r = 1; r <= m; r++) S[r] = (S[r - 1] + pl[r] * suml[r - 1] % mod) % mod;
	} printf("%lld\n", suml[m]);
	return 0;
}

猜你喜欢

转载自blog.csdn.net/xgc_woker/article/details/82946771