[BZOJ5287][Hnoi2018]毒瘤(虚树 + 树形 DP)

Address

Solution - Step 1

  • 首先我们看到非树边最多 11 11
  • 很容易想到暴力枚举每条非树边的两个端点是否在独立集内
  • 然后树上 DP 求独立集个数
  • 复杂度 O ( 3 m n × n ) O(3^{m-n}\times n)
  • 然后我们发现对于一条非树边 ( u , v ) (u,v) ,如果强制 u u 不在独立集内,那么 v v 点没有再去强制的必要
  • 所以只需要枚举 2 2 种情况
  • (1) u u 在独立集内, v v 不在独立集内
  • (2) u u 不在独立集内
  • 复杂度 O ( 2 m n × n ) O(2^{m-n}\times n)
  • 常数优秀的程序能够水过

Solution - Step 2

  • 考虑把 DP 的复杂度优化掉
  • 发现这个 DP 与对整棵树进行 DP 的区别仅仅是强制了 O ( m n ) O(m-n) 个点选或不选
  • 考虑虚树
  • 把所有非树边的端点放在一起建立虚树
  • 注意:为了消除虚树根的子树之外点的影响,需要把整棵树的根节点加入关键点集合
  • 然后在虚树上 DP ,转移方程为
  • f [ u ] [ 0 ] = h [ u ] [ 0 ] × v s o n [ u ] ( f [ v ] [ 0 ] × g [ ( u , v ) ] [ 0 ] [ 0 ] + f [ v ] [ 1 ] × g [ ( u , v ) ] [ 0 ] [ 1 ] ] ) f[u][0]=h[u][0]\times\prod_{v\in son[u]}(f[v][0]\times g[(u,v)][0][0]+f[v][1]\times g[(u,v)][0][1]])
  • f [ u ] [ 1 ] = h [ u ] [ 1 ] × v s o n [ u ] ( f [ v ] [ 0 ] × g [ ( u , v ) ] [ 1 ] [ 0 ] + f [ v ] [ 1 ] × g [ ( u , v ) ] [ 1 ] [ 1 ] ) f[u][1]=h[u][1]\times\prod_{v\in son[u]}(f[v][0]\times g[(u,v)][1][0]+f[v][1]\times g[(u,v)][1][1])
  • f [ u ] [ 0 ] f[u][0] 虚树上 u u 的子树内不选 u u 的独立集个数
  • f [ u ] [ 1 ] f[u][1] 虚树上 u u 的子树内不选 u u 的独立集个数
  • s o n [ u ] son[u] 虚树上 u u 的子节点集合
  • g [ ( u , v ) ] [ x ] [ y ] g[(u,v)][x][y] 原树上 u u v v 的路径(虚树上 u u v v 的父亲节点), u u 的选择状态为 x x v v 的选择状态为 y y ,设原树上 w w u u 的子节点且是 v v 的祖先,就 w w 的子树内但不在 v v 的子树内的所有点与 u u v v 构成独立集的方案数,画个图长这样
    在这里插入图片描述
  • 即为强制 u u 的状态为 x x v v 的状态为 y y 时,黄色部分构成独立集的方案数
  • h [ u ] [ 0 / 1 ] h[u][0/1] 表示 u u 的子树内,除去所有 v v 的子树( v v 满足其子树内存在虚树点), u u 不选 / 选的方案数
    在这里插入图片描述
  • 如上图,紫色点为虚树点,特殊地, u u 为虚树点
  • g g 可以建出虚树之后大力预处理, h h 可以在树上 DP 求出
  • 这样我们就能 O ( m n ) O(m-n) 实现 DP 了
  • 复杂度 O ( n log n + 2 m n × ( m n ) ) O(n\log n+2^{m-n}\times(m-n))

Code

#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define Edge(u) for (int e = adj[u], v = go[e]; e; e = nxt[e], v = go[e])
#define Tree(u) for (int e = adj[u], v; e; e = nxt[e]) if ((v = go[e]) != fu)
#define Vir(u) for (int e = adj2[u], v = go2[e]; e; e = nxt2[e], v = go2[e])

inline int read()
{
	int res = 0; bool bo = 0; char c;
	while (((c = getchar()) < '0' || c > '9') && c != '-');
	if (c == '-') bo = 1; else res = c - 48;
	while ((c = getchar()) >= '0' && c <= '9')
		res = (res << 3) + (res << 1) + (c - 48);
	return bo ? ~res + 1 : res;
}

template <class T>
inline void Swap(T &a, T &b) {a ^= b; b ^= a; a ^= b;}

const int N = 1e5 + 5, M = 2e5 + 100, LogN = 20, ZZQ = 998244353;

int n, m, ecnt, nxt[M], adj[N], go[M], tot, U[N], V[N], FA[N],
virn, vir[N], T, dfn[N], fa[N][LogN], dep[N], stk[N], top, virfa[N],
f[N][2], val[N][2][2], col[N], ans, ecnt2, nxt2[N], adj2[N], go2[N],
fh[N][2], exc[N][2], tt, uv[N];
bool vis[N];

void add_edge(int u, int v)
{
	nxt[++ecnt] = adj[u]; adj[u] = ecnt; go[ecnt] = v;
	nxt[++ecnt] = adj[v]; adj[v] = ecnt; go[ecnt] = u;
}

void add_edge2(int u, int v)
{
	nxt2[++ecnt2] = adj2[u]; adj2[u] = ecnt2; go2[ecnt2] = v;
}

void dfs(int u, int fu)
{
	dep[u] = dep[fa[u][0] = fu] + 1;
	for (int i = 0; i <= 15; i++) fa[u][i + 1] = fa[fa[u][i]][i];
	dfn[u] = ++T;
	f[u][0] = f[u][1] = 1;
	for (int e = adj[u], v; e; e = nxt[e])
	{
		if ((v = go[e]) == fu) continue;
		dfs(v, u);
		f[u][0] = 1ll * f[u][0] * (f[v][0] + f[v][1]) % ZZQ;
		f[u][1] = 1ll * f[u][1] * f[v][0] % ZZQ;
	}
}

int get_0(int u, int _u)
{
	int res = 1;
	for (int e = adj[u], v = go[e]; e; e = nxt[e], v = go[e])
		if (v != fa[u][0] && v != _u)
			res = 1ll * res * (f[v][0] + f[v][1]) % ZZQ;
	return res;
}

int get_1(int u, int _u)
{
	int res = 1;
	for (int e = adj[u], v = go[e]; e; e = nxt[e], v = go[e])
		if (v != fa[u][0] && v != _u)
			res = 1ll * res * f[v][0] % ZZQ;
	return res;
}

int lca(int u, int v)
{
	if (dep[u] < dep[v]) Swap(u, v);
	for (int i = 16; i >= 0; i--)
	{
		if (dep[fa[u][i]] >= dep[v]) u = fa[u][i];
		if (u == v) return u;
	}
	for (int i = 16; i >= 0; i--)
		if (fa[u][i] != fa[v][i])
			u = fa[u][i], v = fa[v][i];
	return fa[u][0];
}

int orz(int u, int v)
{
	for (int i = 16; i >= 0; i--)
		if ((dep[v] - dep[u] - 1 >> i) & 1)
			v = fa[v][i];
	return v;
}

inline bool comp(int a, int b)
{
	return dfn[a] < dfn[b];
}

int cx(int x)
{
	if (FA[x] != x) FA[x] = cx(FA[x]);
	return FA[x];
}

bool zm(int x, int y)
{
	int ix = cx(x), iy = cx(y);
	if (ix != iy) return FA[iy] = ix, 1;
	return 0;
}

void build_virtree()
{
	std::sort(vir + 1, vir + virn + 1, comp);
	int tn = virn;
	for (int i = 1; i <= tn; i++)
	{
		int u = vir[i];
		if (!top) {virfa[stk[++top] = u] = 0; continue;}
		int w = lca(stk[top], u);
		while (dep[w] < dep[stk[top]])
		{
			if (dep[w] > dep[stk[top - 1]]) virfa[stk[top]] = w;
			top--;
		}
		if (dep[w] > dep[stk[top]])
			virfa[w] = stk[top], stk[++top] = w, vir[++virn] = w;
		virfa[stk[++top] = u] = w;
	}
	std::sort(vir + 1, vir + virn + 1, comp);
	for (int i = 2; i <= virn; i++)
	{
		add_edge2(virfa[vir[i]], vir[i]);
		for (int c = 0; c <= 1; c++)
		{
			int u = fa[vir[i]][0], fau = virfa[vir[i]],
				lst = vir[i], r0 = 0, r1 = 0;
			(c ? r1 : r0) = 1;
			while ("I AK IOI")
			{
				int n0, n1;
				if (u != fau) n0 = 1ll * (r0 + r1) * get_0(u, lst) % ZZQ,
					n1 = 1ll * r0 * get_1(u, lst) % ZZQ;
				else n0 = (r0 + r1) % ZZQ, n1 = r0;
				r0 = n0; r1 = n1;
				if (u == fau) break;
				lst = u; u = fa[u][0];
			}
			val[vir[i]][c][0] = r0; val[vir[i]][c][1] = r1;
		}
	}
	for (int i = 1; i <= virn; i++)
	{
		int u = vir[i]; tt = 0;
		for (int e = adj2[u], v = go2[e]; e; e = nxt2[e], v = go2[e])
			vis[uv[++tt] = orz(u, v)] = 1;
		exc[u][0] = exc[u][1] = 1;
		for (int e = adj[u], v = go[e]; e; e = nxt[e], v = go[e])
			if (v != fa[u][0] && !vis[v])
			{
				exc[u][0] = 1ll * exc[u][0] * (f[v][0] + f[v][1]) % ZZQ;
				exc[u][1] = 1ll * exc[u][1] * f[v][0] % ZZQ;
			}
		for (int i = 1; i <= tt; i++) vis[uv[i]] = 0;
	}
}

void calc(int u, int fu)
{
	fh[u][0] = fh[u][1] = 0;
	if (col[u] != 1) fh[u][0] = 1;
	if (col[u] != 0) fh[u][1] = 1;
	for (int e = adj2[u], v = go2[e]; e; e = nxt2[e], v = go2[e])
	{
		calc(v, u);
		fh[u][0] = 1ll * fh[u][0] * ((1ll * val[v][0][0] * fh[v][0]
			+ 1ll * val[v][1][0] * fh[v][1]) % ZZQ) % ZZQ;
		fh[u][1] = 1ll * fh[u][1] * ((1ll * val[v][0][1] * fh[v][0]
			+ 1ll * val[v][1][1] * fh[v][1]) % ZZQ) % ZZQ;
	}
	fh[u][0] = 1ll * fh[u][0] * exc[u][0] % ZZQ;
	fh[u][1] = 1ll * fh[u][1] * exc[u][1] % ZZQ;
	if (u == 1) ans = (ans + (fh[u][0] + fh[u][1]) % ZZQ) % ZZQ;
}

void DFS(int dep)
{
	if (dep == tot + 1) return calc(vir[1], 0);
	int x = col[U[dep]], y = col[V[dep]];
	if (x != 0 && y != 1)
	{
		col[U[dep]] = 1; col[V[dep]] = 0;
		DFS(dep + 1);
		col[U[dep]] = x; col[V[dep]] = y;
	}
	if (x != 1) col[U[dep]] = 0, DFS(dep + 1), col[U[dep]] = x;
}

int main()
{
	int x, y, tn = 0;
	n = read(); m = read();
	for (int i = 1; i <= n; i++) FA[i] = i;
	while (m--)
	{
		x = read(); y = read();
		if (zm(x, y)) add_edge(x, y);
		else U[++tot] = x, V[tot] = y,
			vir[++tn] = x, vir[++tn] = y;
	}
	vir[++tn] = 1;
	std::sort(vir + 1, vir + tn + 1);
	virn = std::unique(vir + 1, vir + tn + 1) - vir - 1;
	dfs(1, 0);
	memset(col, -1, sizeof(col));
	build_virtree();
	DFS(1);
	std::cout << ans << std::endl;
	return 0;
}

猜你喜欢

转载自blog.csdn.net/xyz32768/article/details/86252403