题意:给一棵树,树上有一些关键节点,选m个点,使得关键节点到这些点中距离的最小值的最大值最小
最大值最小,果断二分答案
我们只需要判定是否存在m个点能够在mid范围内到达所有关键点
暴力:从每个点bfs一遍看看mid范围能是否能覆盖到所有的点,o(n^2logn)
发现可以贪心,一个关键点要么被它的子树内的点管理,要么被它子树外的点管理,于是我们记录个pair/struct
first表示以x为根的子树中目前还没有人管理的关键点距离x的最远的距离,second表示以x为根的子树中选择了的点距离x的最近的距离.
①if(first+second<=mid)以x为根的树是可以自己处理的
②if(first==mid)就意味着必须要选择x这个点了
因为再向上一个点距离就超过mid了,这时候强制选择x这个点,并更新first,second即可
③if(这个点是关键点&&second>mid)就要更新first了
注意出来的时候要特判1(树根)
#include <stdio.h> #include <cstdlib> #include <algorithm> #include <cstring> #include <time.h> #pragma warning(disable:4996) template<typename T> T min(T x, T y) { return x < y ? x : y; } template<typename T> T max(T x, T y) { return x > y ? x : y; } const int MAXN = 300005; const int B = 400; const int INF = 2000000005; struct node { int to; node *next; }; void addnode(node *&head, int to) { node *p = new node; p->to = to; p->next = head; head = p; } int N, M; node *edge[MAXN]; int deep[MAXN], fa[MAXN][25], key[MAXN], cnt; int rank[MAXN], st[MAXN * 2][25], len; int f[MAXN], list[MAXN], num; bool mark[MAXN]; bool cmp(const int u, const int v) { return deep[u] > deep[v]; } int anc(int x, int k) { for (int i = 0; i < 25; i++) if (k&(1 << i)) x = fa[x][i]; return x; } int LCA(int x, int y) { x = rank[x]; y = rank[y]; if (x > y) std::swap(x, y); int len = y - x + 1; int t = 0; while (1 << t <= len) t++; t--; y = y - (1 << t) + 1; return deep[st[x][t]] > deep[st[y][t]]? st[y][t] : st[x][t]; } int dis(int x, int y) { return deep[x] + deep[y] - 2 * deep[LCA(x, y)]; } int nearest(int v) { int x = f[v]; for (int i = 1; i <= num; i++) x = min(x, dis(v, list[i])); return x; } void dfs1(int v) { f[v] = mark[v]? 0: INF; for (node *p = edge[v]; p; p = p->next) if (p->to != fa[v][0]) { dfs1(p->to); f[v] = min(f[v], f[p->to] + 1); } } void dfs2(int v) { f[v] = min(f[v], f[fa[v][0]] + 1); for (node *p = edge[v]; p; p = p->next) if (p->to != fa[v][0]) dfs2(p->to); } void insert(int v) { list[++num] = v; if (num == B) { for (int i = 1; i <= num; i++) mark[list[i]] = true; dfs1(1); dfs2(1); num = 0; } } bool judge(int d) { int i, n = 0; num = 0; memset(mark, 0, sizeof(mark)); memset(f, 63, sizeof(f)); for (i = 1; i <= cnt; i++) { if (nearest(key[i]) > d) { n++; insert(anc(key[i], min(d, deep[key[i]]))); } } return n <= M; } void dfs(int v) { st[++len][0] = v; rank[v] = len; for (int i = 1; i < 25; i++) fa[v][i] = fa[fa[v][i - 1]][i - 1]; for (node *p = edge[v]; p; p = p->next) if (p->to != fa[v][0]) { fa[p->to][0] = v; deep[p->to] = deep[v] + 1; dfs(p->to); st[++len][0] = v; } } void init() { int i, j, u, v; scanf("%d %d", &N, &M); for (i = 1; i <= N; i++) { scanf("%d", &u); if (u) key[++cnt] = i; } for (i = 1; i < N; i++) { scanf("%d %d", &u, &v); addnode(edge[u], v); addnode(edge[v], u); } dfs(1); std::sort(key + 1, key + cnt + 1, cmp); deep[0] = INF; for (i = 1; i < 25; i++) { int r = min(1 << (i - 1), len); for (j = 1; j <= len; j++) { if (r < len) r++; st[j][i] = cmp(st[j][i - 1], st[r][i - 1]) ? st[r][i - 1] : st[j][i - 1]; } } } int main() { int l = -1, r = MAXN; init(); while (r - l > 1) { int mid = (l + r) / 2; if (judge(mid)) r = mid; else l = mid; } printf("%d\n", r); return 0; }