[Nowcoder 2018ACM多校第十场I] Rikka with Zombies

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u013578420/article/details/82150662

题目大意:
给你一棵节点数为n的无向树, 每条边上有一个栅栏, 等概率的出现 [ l i , r i ] 的高度, 有m个僵尸, 出生在 a i 点, 可以闯过低于 h i 高度的栅栏。 一个点是安全的, 当且仅当它不会被任一僵尸到达。 求树中至少有一个安全点的概率。T组数据。 ( T 5 , n , m 2000 , 1 a i n , 1 l i r i   10 9 , 1 h i 10 9 )

题目思路:
将概率转换为求方案数, 至少一个转换为求一个都没有
现在求整棵树不安全的方案数
考虑树形dp, 设f[i][j]表示子树i内所有点都是不安全的, 子树外皆有可能, 能达到点i的最大僵尸为j的方案数
初始值: i点的最大僵尸至少为出生在该点的最大僵尸
即 设点i出生的最大僵尸为k(没僵尸则为1), 则对于所有j>=k, f[i][j] = 1
考虑子树合并, f[u][a]与f[v][b] v是u的一个孩子
if a == b //说明a可肯定能跨过(u,v)这条边
f[u][a] += f[u][a] * f[v][b] * (a 能跨过(u,v))
if a < b && a 一定不在子树v内 && b 一定在子树v内
f[u][a] += f[u][a] * f[v][b] * (b 不能跨过(u,v))
if a > b && a 一定不在子树v内 && b 一定在子树v内
f[u][a] += f[u][a] * f[v][b] * (a 不能跨过(u, v))

后两种情况可以分别前缀和求即可。
时间复杂度O(nm)

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <queue>
#include <bitset>

#define pi pair<int, int>
#define fi first
#define se second
#define mp make_pair
#define ll long long

const int N = (int)2020;
const int mo = 998244353;

using namespace std;

int gi(){
    char c = getchar(); int ret = 0;
    while (!isdigit(c)) c = getchar();
    while (isdigit(c)){
        ret = ret * 10 + c - '0';
        c = getchar();
    }
    return ret;
}

ll pw(ll x, int k){
    ll ret = 1;
    for (; k; k >>= 1, x = x * x % mo)
        if (k & 1) ret = ret * x % mo;
    return ret;
}

int n, m; pi A[N];
int cnt, lst[N], nxt[N * 2], to[N * 2], L[N * 2], R[N * 2];
bitset<N> in[N]; ll f[N][N], tmp[N], ans, all;
void add(int u, int v, int a, int b){
    nxt[++ cnt] = lst[u]; lst[u] = cnt; to[cnt] = v; L[cnt] = a; R[cnt] = b;
    nxt[++ cnt] = lst[v]; lst[v] = cnt; to[cnt] = u; L[cnt] = a; R[cnt] = b;
}

void dfs(int u, int pre){
    for (int i = 1; i <= m; i ++)
        if (A[i].se == u) in[u][i] = 1;
    for (int j = lst[u]; j; j = nxt[j]){
        int v = to[j];
        if (v == pre) continue;
        dfs(v, u);
        in[u] |= in[v];
    }
}

void dp(int u, int pre){
    int pos = 1; ll sum;
    for (int i = 1; i <= m; i ++)
        if (A[i].se == u) pos = i;
    for (int i = pos; i <= m; i ++) f[u][i] = 1;

    for (int j = lst[u]; j; j = nxt[j]){
        int v = to[j];
        if (v == pre) continue;
        dp(v, u);

        memcpy(tmp, f[u], sizeof(f[u]));
        memset(f[u], 0, sizeof(f[u]));

        sum = 0;
        for (int a = 1; a <= m; a ++){
            int k = max(min(A[a].fi - 1, R[j]) - L[j] + 1, 0);

            (f[u][a] += tmp[a] * f[v][a] % mo * k % mo) %= mo;
            if (in[v][a]) (sum += f[v][a]) %= mo;
            else (f[u][a] += tmp[a] * sum % mo * (R[j] - L[j] + 1 - k) % mo) %= mo;
        }

        sum = 0;
        for (int a = m; a >= 1; a --){
            int k = max(min(A[a].fi - 1, R[j]) - L[j] + 1, 0);

            if (in[v][a]) (sum += f[v][a] * (R[j] - L[j] + 1 - k) % mo) %= mo;
            else (f[u][a] += tmp[a] * sum % mo) %= mo;
        }

    }
}

int main()
{

    int T = gi();
    while (T --){
        cnt = ans = 0; all = 1;
        memset(lst, 0, sizeof(lst));
        memset(in, 0, sizeof(in));
        memset(f, 0, sizeof(f));

        n = gi(); m = gi();
        for (int i = 1; i < n; i ++){
            int u, v, a, b;
            u = gi(), v = gi(), a = gi(), b = gi();
            add(u, v, a, b); all = all * (b - a + 1) % mo;
        }
        for (int i = 1; i <= m; i ++)
            A[i].se = gi(), A[i].fi = gi();

        sort(A + 1, A + m + 1);
        dfs(1, 0);

        dp(1, 0);
        for (int i = 1; i <= m; i ++)
            ans = (ans + f[1][i]) % mo;
        ans = (all - ans) % mo;
        ans = ans * pw(all, mo - 2) % mo;
        if (ans < 0) ans += mo;
        printf("%lld\n", ans);
    }

    return 0;
}

猜你喜欢

转载自blog.csdn.net/u013578420/article/details/82150662
今日推荐