Hello my friend
题目背景:
分析:树型DP
听说这个是套路,我现在才知道是不是可以直接退个役什么的······显然对于黑点而言,我们需要的是期望次数,而对于白点而言,我们需要的是期望概率,先考虑黑点的贡献,对于以1为根的有根树,令f[i]表示从点i到结束的期望经过的黑点个数,deg[i]表示i的度数。那么:
显然最后可以获得k[1]和b[1],显然b[1]就是f[1]了,所以f[1]在过程中可以不用显示维护,只要维护k[i], b[i]即可。
继续考虑白点的贡献,显然就是到这个白点的概率,定义dp[i]表示到点i的概率,显然不可能不经过父亲节点,那么,令g[i]表示,从fa[i]到i的概率,令f[i]表示从i到fa[i]的概率:
显然,f可以比较简单的求出来,有了f之后g也就非常好处理了,所以原题只需要3遍dfs即可,复杂度O(n)。
Source:
/* created by scarlyw */ #include <cstdio> #include <string> #include <algorithm> #include <cstring> #include <iostream> #include <cmath> #include <cctype> #include <vector> #include <set> #include <queue> #include <ctime> #include <bitset> inline char read() { static const int IN_LEN = 1024 * 1024; static char buf[IN_LEN], *s, *t; if (s == t) { t = (s = buf) + fread(buf, 1, IN_LEN, stdin); if (s == t) return -1; } return *s++; } /* template<class T> inline void R(T &x) { static char c; static bool iosig; for (c = read(), iosig = false; !isdigit(c); c = read()) { if (c == -1) return ; if (c == '-') iosig = true; } for (x = 0; isdigit(c); c = read()) x = ((x << 2) + x << 1) + (c ^ '0'); if (iosig) x = -x; } //*/ const int OUT_LEN = 1024 * 1024; char obuf[OUT_LEN], *oh = obuf; inline void write_char(char c) { if (oh == obuf + OUT_LEN) fwrite(obuf, 1, OUT_LEN, stdout), oh = obuf; *oh++ = c; } template<class T> inline void W(T x) { static int buf[30], cnt; if (x == 0) write_char('0'); else { if (x < 0) write_char('-'), x = -x; for (cnt = 0; x; x /= 10) buf[++cnt] = x % 10 + 48; while (cnt) write_char(buf[cnt--]); } } inline void flush() { fwrite(obuf, 1, oh - obuf, stdout); } ///* template<class T> inline void R(T &x) { static char c; static bool iosig; for (c = getchar(), iosig = false; !isdigit(c); c = getchar()) if (c == '-') iosig = true; for (x = 0; isdigit(c); c = getchar()) x = ((x << 2) + x << 1) + (c ^ '0'); if (iosig) x = -x; } //*/ const int MAXN = 100000 + 10; const int mod = 998244353; std::vector<int> edge[MAXN]; int n, x, y, ans; int d[MAXN], f[MAXN], k[MAXN], b[MAXN], g[MAXN], dp[MAXN], c[MAXN], sum[MAXN]; char s[MAXN]; inline void add(int &x, int t) { x += t, (x >= mod) ? (x -= mod) : (0); } inline int mod_pow(int a, int b) { int ans = 1; for (; b; b >>= 1, a = (long long)a * a % mod) if (b & 1) ans = (long long)ans * a % mod; return ans; } inline void add_edge(int x, int y) { edge[x].push_back(y), edge[y].push_back(x), d[x]++, d[y]++; } inline void read_in() { scanf("%d%s", &n, s + 1); for (int i = 1; i <= n; ++i) c[i] = s[i] - '0'; for (int i = 1; i < n; ++i) R(x), R(y), add_edge(x, y); } inline void dfs1(int cur, int fa) { if (d[cur] == 1) { f[cur] = k[cur] = 0, b[cur] = c[cur]; return ; } int cur_k = 1, cur_b = c[cur], cur_f = 0, inv = mod_pow(d[cur], mod - 2); for (int p = 0; p < edge[cur].size(); ++p) { int v = edge[cur][p]; if (v != fa) { dfs1(v, cur), add(cur_k, mod - (long long)inv * k[v] % mod); add(cur_b, (long long)inv * b[v] % mod), add(cur_f, f[v]); } } f[cur] = mod_pow((d[cur] - cur_f + mod) % mod, mod - 2); int x = mod_pow(cur_k, mod - 2); k[cur] = (long long)inv * x % mod, b[cur] = (long long)cur_b * x % mod; } inline void dfs2(int cur, int fa) { if (fa != 0) { int temp = (((long long)d[fa] - sum[fa] + f[cur] - g[fa]) % mod + mod) % mod; g[cur] = mod_pow(temp, mod - 2); } for (int p = 0; p < edge[cur].size(); ++p) { int v = edge[cur][p]; if (v != fa) add(sum[cur], f[v]); } for (int p = 0; p < edge[cur].size(); ++p) { int v = edge[cur][p]; if (v != fa) dfs2(v, cur); } } inline void dfs3(int cur, int fa) { if (c[cur] == 0) add(ans, dp[cur]); for (int p = 0; p < edge[cur].size(); ++p) { int v = edge[cur][p]; if (v != fa) dp[v] = (long long)dp[cur] * g[v] % mod, dfs3(v, cur); } } int main() { freopen("sad.in", "r", stdin); freopen("sad.out", "w", stdout); read_in(); dfs1(1, 0), dfs2(1, 0), dp[1] = 1, dfs3(1, 0); std::cout << (add(ans, b[1]), ans) << '\n'; return 0; } /* 4 1011 1 2 1 3 3 4 */