Tree
Time Limit: 1000MS | Memory Limit: 30000K | |
Total Submissions: 26911 | Accepted: 8953 |
Description
Give a tree with n vertices,each edge has a length(positive integer less than 1001).
Define dist(u,v)=The min distance between node u and v.
Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k.
Write a program that will count how many pairs which are valid for a given tree.
Define dist(u,v)=The min distance between node u and v.
Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k.
Write a program that will count how many pairs which are valid for a given tree.
Input
The input contains several test cases. The first line of each test case contains two integers n, k. (n<=10000) The following n-1 lines each contains three integers u,v,l, which means there is an edge between node u and v of length l.
The last test case is followed by two zeros.
The last test case is followed by two zeros.
Output
For each test case output the answer on a single line.
Sample Input
5 4 1 2 3 1 3 1 1 4 2 3 5 1 0 0
Sample Output
8
Source
【思路】
给定一棵树,边上有权,要求的是树上距离不超过k的点对数。
我们假设一棵有根树,维护树中的点到根的距离d,树中的任何路径必然有两种,1、经过根,2、不经过根。经过根的那些路径,其两端点u、v必然满足d[u] + d[v] <= k,不经过根的路径,也可以设根为子树中某一点来递归处理,问题解决。
子树中的路径加和判断可以通过O(NlogN)排序然后O(N)算出,那么我们的首要任务就变成了尽可能减少排序的次数,树形结构决定了排序次数和树的层数一致,所以也就需要通过寻找树的重心来减少树的层数,使整个算法复杂度在O(NlogN*logN)级别。
【代码】
//************************************************************************ // File Name: POJ_1741.cpp // Author: Shili_Xu // E-Mail: [email protected] // Created Time: 2018年02月26日 星期一 21时49分18秒 //************************************************************************ #include <cstdio> #include <cstring> #include <algorithm> #include <vector> using namespace std; const int MAXN = 10005; struct edge { int to, len; edge(int _to, int _len) : to(_to), len(_len) {} }; int n, k, root, size; int sz[MAXN], mxson[MAXN], d[MAXN]; bool visited[MAXN]; vector<edge> g[MAXN]; vector<int> dist; void get_root(int u, int fa) { sz[u] = 1; mxson[u] = 0; for (int i = 0; i < g[u].size(); i++) { int v = g[u][i].to; if (v != fa && !visited[v]) { get_root(v, u); sz[u] += sz[v]; mxson[u] = max(mxson[u], sz[v]); } } mxson[u] = max(mxson[u], size - sz[u]); if (mxson[u] < mxson[root]) root = u; } void get_dist(int u, int fa) { dist.push_back(d[u]); for (int i = 0; i < g[u].size(); i++) { int v = g[u][i].to; if (v != fa && !visited[v]) { d[v] = d[u] + g[u][i].len; get_dist(v, u); } } } int cal(int u, int base) { dist.clear(); d[u] = base; get_dist(u, 0); sort(dist.begin(), dist.end()); int ans = 0, l = 0, r = dist.size() - 1; while (l < r) { if (dist[l] + dist[r] <= k) ans += (r - l), l++; else r--; } return ans; } int work(int u) { visited[u] = true; int ans = 0; ans += cal(u, 0); for (int i = 0; i < g[u].size(); i++) { int v = g[u][i].to; if (!visited[v]) { ans -= cal(v, g[u][i].len); root = 0; size = mxson[0] = sz[v]; get_root(v, 0); ans += work(root); } } return ans; } int main() { while (scanf("%d %d", &n, &k) == 2 && n != 0 && k != 0) { for (int i = 1; i <= n; i++) g[i].clear(); for (int i = 1; i <= n - 1; i++) { int a, b, c; scanf("%d %d %d", &a, &b, &c); g[a].push_back(edge(b, c)); g[b].push_back(edge(a, c)); } int ans = 0; root = 0; size = mxson[0] = n; memset(visited, false, sizeof(visited)); get_root(1, 0); ans = work(root); printf("%d\n", ans); } return 0; }