Address
Solution - Step 1
- 首先我们看到非树边最多
11 条
- 很容易想到暴力枚举每条非树边的两个端点是否在独立集内
- 然后树上 DP 求独立集个数
- 复杂度
O(3m−n×n)
- 然后我们发现对于一条非树边
(u,v) ,如果强制
u 不在独立集内,那么
v 点没有再去强制的必要
- 所以只需要枚举
2 种情况
- (1)
u 在独立集内,
v 不在独立集内
- (2)
u 不在独立集内
- 复杂度
O(2m−n×n)
常数优秀的程序能够水过
Solution - Step 2
- 考虑把 DP 的复杂度优化掉
- 发现这个 DP 与对整棵树进行 DP 的区别仅仅是强制了
O(m−n) 个点选或不选
- 考虑虚树
- 把所有非树边的端点放在一起建立虚树
- 注意:为了消除虚树根的子树之外点的影响,需要把整棵树的根节点加入关键点集合
- 然后在虚树上 DP ,转移方程为
-
f[u][0]=h[u][0]×v∈son[u]∏(f[v][0]×g[(u,v)][0][0]+f[v][1]×g[(u,v)][0][1]])
-
f[u][1]=h[u][1]×v∈son[u]∏(f[v][0]×g[(u,v)][1][0]+f[v][1]×g[(u,v)][1][1])
-
f[u][0] 为虚树上
u 的子树内不选
u 的独立集个数
-
f[u][1] 为虚树上
u 的子树内不选
u 的独立集个数
-
son[u] 为虚树上
u 的子节点集合
-
g[(u,v)][x][y] 为原树上
u 到
v 的路径(虚树上
u 是
v 的父亲节点),
u 的选择状态为
x ,
v 的选择状态为
y ,设原树上
w 是
u 的子节点且是
v 的祖先,就
w 的子树内但不在
v 的子树内的所有点与
u 和
v 构成独立集的方案数,画个图长这样
- 即为强制
u 的状态为
x ,
v 的状态为
y 时,黄色部分构成独立集的方案数
-
h[u][0/1] 表示
u 的子树内,除去所有
v 的子树(
v 满足其子树内存在虚树点),
u 不选 / 选的方案数
- 如上图,紫色点为虚树点,特殊地,
u 为虚树点
-
g 可以建出虚树之后大力预处理,
h 可以在树上 DP 求出
- 这样我们就能
O(m−n) 实现 DP 了
- 复杂度
O(nlogn+2m−n×(m−n))
Code
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define Edge(u) for (int e = adj[u], v = go[e]; e; e = nxt[e], v = go[e])
#define Tree(u) for (int e = adj[u], v; e; e = nxt[e]) if ((v = go[e]) != fu)
#define Vir(u) for (int e = adj2[u], v = go2[e]; e; e = nxt2[e], v = go2[e])
inline int read()
{
int res = 0; bool bo = 0; char c;
while (((c = getchar()) < '0' || c > '9') && c != '-');
if (c == '-') bo = 1; else res = c - 48;
while ((c = getchar()) >= '0' && c <= '9')
res = (res << 3) + (res << 1) + (c - 48);
return bo ? ~res + 1 : res;
}
template <class T>
inline void Swap(T &a, T &b) {a ^= b; b ^= a; a ^= b;}
const int N = 1e5 + 5, M = 2e5 + 100, LogN = 20, ZZQ = 998244353;
int n, m, ecnt, nxt[M], adj[N], go[M], tot, U[N], V[N], FA[N],
virn, vir[N], T, dfn[N], fa[N][LogN], dep[N], stk[N], top, virfa[N],
f[N][2], val[N][2][2], col[N], ans, ecnt2, nxt2[N], adj2[N], go2[N],
fh[N][2], exc[N][2], tt, uv[N];
bool vis[N];
void add_edge(int u, int v)
{
nxt[++ecnt] = adj[u]; adj[u] = ecnt; go[ecnt] = v;
nxt[++ecnt] = adj[v]; adj[v] = ecnt; go[ecnt] = u;
}
void add_edge2(int u, int v)
{
nxt2[++ecnt2] = adj2[u]; adj2[u] = ecnt2; go2[ecnt2] = v;
}
void dfs(int u, int fu)
{
dep[u] = dep[fa[u][0] = fu] + 1;
for (int i = 0; i <= 15; i++) fa[u][i + 1] = fa[fa[u][i]][i];
dfn[u] = ++T;
f[u][0] = f[u][1] = 1;
for (int e = adj[u], v; e; e = nxt[e])
{
if ((v = go[e]) == fu) continue;
dfs(v, u);
f[u][0] = 1ll * f[u][0] * (f[v][0] + f[v][1]) % ZZQ;
f[u][1] = 1ll * f[u][1] * f[v][0] % ZZQ;
}
}
int get_0(int u, int _u)
{
int res = 1;
for (int e = adj[u], v = go[e]; e; e = nxt[e], v = go[e])
if (v != fa[u][0] && v != _u)
res = 1ll * res * (f[v][0] + f[v][1]) % ZZQ;
return res;
}
int get_1(int u, int _u)
{
int res = 1;
for (int e = adj[u], v = go[e]; e; e = nxt[e], v = go[e])
if (v != fa[u][0] && v != _u)
res = 1ll * res * f[v][0] % ZZQ;
return res;
}
int lca(int u, int v)
{
if (dep[u] < dep[v]) Swap(u, v);
for (int i = 16; i >= 0; i--)
{
if (dep[fa[u][i]] >= dep[v]) u = fa[u][i];
if (u == v) return u;
}
for (int i = 16; i >= 0; i--)
if (fa[u][i] != fa[v][i])
u = fa[u][i], v = fa[v][i];
return fa[u][0];
}
int orz(int u, int v)
{
for (int i = 16; i >= 0; i--)
if ((dep[v] - dep[u] - 1 >> i) & 1)
v = fa[v][i];
return v;
}
inline bool comp(int a, int b)
{
return dfn[a] < dfn[b];
}
int cx(int x)
{
if (FA[x] != x) FA[x] = cx(FA[x]);
return FA[x];
}
bool zm(int x, int y)
{
int ix = cx(x), iy = cx(y);
if (ix != iy) return FA[iy] = ix, 1;
return 0;
}
void build_virtree()
{
std::sort(vir + 1, vir + virn + 1, comp);
int tn = virn;
for (int i = 1; i <= tn; i++)
{
int u = vir[i];
if (!top) {virfa[stk[++top] = u] = 0; continue;}
int w = lca(stk[top], u);
while (dep[w] < dep[stk[top]])
{
if (dep[w] > dep[stk[top - 1]]) virfa[stk[top]] = w;
top--;
}
if (dep[w] > dep[stk[top]])
virfa[w] = stk[top], stk[++top] = w, vir[++virn] = w;
virfa[stk[++top] = u] = w;
}
std::sort(vir + 1, vir + virn + 1, comp);
for (int i = 2; i <= virn; i++)
{
add_edge2(virfa[vir[i]], vir[i]);
for (int c = 0; c <= 1; c++)
{
int u = fa[vir[i]][0], fau = virfa[vir[i]],
lst = vir[i], r0 = 0, r1 = 0;
(c ? r1 : r0) = 1;
while ("I AK IOI")
{
int n0, n1;
if (u != fau) n0 = 1ll * (r0 + r1) * get_0(u, lst) % ZZQ,
n1 = 1ll * r0 * get_1(u, lst) % ZZQ;
else n0 = (r0 + r1) % ZZQ, n1 = r0;
r0 = n0; r1 = n1;
if (u == fau) break;
lst = u; u = fa[u][0];
}
val[vir[i]][c][0] = r0; val[vir[i]][c][1] = r1;
}
}
for (int i = 1; i <= virn; i++)
{
int u = vir[i]; tt = 0;
for (int e = adj2[u], v = go2[e]; e; e = nxt2[e], v = go2[e])
vis[uv[++tt] = orz(u, v)] = 1;
exc[u][0] = exc[u][1] = 1;
for (int e = adj[u], v = go[e]; e; e = nxt[e], v = go[e])
if (v != fa[u][0] && !vis[v])
{
exc[u][0] = 1ll * exc[u][0] * (f[v][0] + f[v][1]) % ZZQ;
exc[u][1] = 1ll * exc[u][1] * f[v][0] % ZZQ;
}
for (int i = 1; i <= tt; i++) vis[uv[i]] = 0;
}
}
void calc(int u, int fu)
{
fh[u][0] = fh[u][1] = 0;
if (col[u] != 1) fh[u][0] = 1;
if (col[u] != 0) fh[u][1] = 1;
for (int e = adj2[u], v = go2[e]; e; e = nxt2[e], v = go2[e])
{
calc(v, u);
fh[u][0] = 1ll * fh[u][0] * ((1ll * val[v][0][0] * fh[v][0]
+ 1ll * val[v][1][0] * fh[v][1]) % ZZQ) % ZZQ;
fh[u][1] = 1ll * fh[u][1] * ((1ll * val[v][0][1] * fh[v][0]
+ 1ll * val[v][1][1] * fh[v][1]) % ZZQ) % ZZQ;
}
fh[u][0] = 1ll * fh[u][0] * exc[u][0] % ZZQ;
fh[u][1] = 1ll * fh[u][1] * exc[u][1] % ZZQ;
if (u == 1) ans = (ans + (fh[u][0] + fh[u][1]) % ZZQ) % ZZQ;
}
void DFS(int dep)
{
if (dep == tot + 1) return calc(vir[1], 0);
int x = col[U[dep]], y = col[V[dep]];
if (x != 0 && y != 1)
{
col[U[dep]] = 1; col[V[dep]] = 0;
DFS(dep + 1);
col[U[dep]] = x; col[V[dep]] = y;
}
if (x != 1) col[U[dep]] = 0, DFS(dep + 1), col[U[dep]] = x;
}
int main()
{
int x, y, tn = 0;
n = read(); m = read();
for (int i = 1; i <= n; i++) FA[i] = i;
while (m--)
{
x = read(); y = read();
if (zm(x, y)) add_edge(x, y);
else U[++tot] = x, V[tot] = y,
vir[++tn] = x, vir[++tn] = y;
}
vir[++tn] = 1;
std::sort(vir + 1, vir + tn + 1);
virn = std::unique(vir + 1, vir + tn + 1) - vir - 1;
dfs(1, 0);
memset(col, -1, sizeof(col));
build_virtree();
DFS(1);
std::cout << ans << std::endl;
return 0;
}