Given a tree with n vertices, where each edge has a length (a positive integer less than 1001).
Define dist(u, v) as the minimum distance between vertices u and v.
Given an integer k, a pair (u, v) of vertices is called valid if and only if dist(u, v) does not exceed k.
Write a program that counts how many pairs are valid for a given tree.
Input
The input contains several test cases. The first line of each test case contains two integers n and k (n <= 10000). Each of the following n-1 lines contains three integers u, v, l, meaning there is an edge of length l between vertices u and v.
The last test case is followed by two zeros.
Output
For each test case output the answer on a single line.
Sample Input
5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0
Sample Output
8
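Given a tree, count the paths (vertex pairs) whose total edge weight is <= k;
a basic exercise in centroid decomposition (divide-and-conquer on tree vertices).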
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstdlib>
#include<cstring>
#include<string>
#include<cmath>
#include<map>
#include<set>
#include<vector>
#include<queue>
#include<string>
#include<bitset>
#include<ctime>
#include<deque>
#include<stack>
#include<functional>
#include<sstream>
using namespace std;
// Capacity of the per-vertex arrays below; vertices are numbered 1..n, so n must stay below maxn.
#define maxn 200005
// Generic contest boilerplate: inf, INF, ll, ull and mod are not referenced elsewhere in this file.
#define inf 0x3f3f3f3f
#define INF 0x7fffffff
typedef long long ll;
typedef unsigned long long ull;
// Zero-fills a whole array (used by main() to reset vis[] between test cases).
#define ms(x) memset(x,0,sizeof(x))
const long long int mod = 1e9 + 7;
// Reads one decimal integer from stdin using getchar().
// Any non-digit characters before the number are skipped; a '-' seen while
// skipping makes the result negative. If end of input is reached before a
// digit appears, 0 is returned instead of looping forever.
inline int read()
{
    int x = 0;
    int sign = 1;          // named `sign` so it does not shadow the global k
    int c = getchar();     // int, not char: must be able to hold EOF
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') sign = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * sign;
}
// Adjacency-list entry: the neighbouring vertex `v` and the length `l` of the
// edge leading to it. Members default to 0 so a default-constructed entry is
// never read uninitialized.
struct node {
    int l = 0, v = 0;
    node() {}
    // Parameter order (v, l) is kept for existing callers; the initializer list
    // follows the member declaration order (l, v) to avoid -Wreorder.
    node(int v, int l) : l(l), v(v) {}
};
// n: vertices in the current tree; k: maximum distance for a pair to count as valid.
int n, k;
// Best centroid candidate found by getrot(); callers reset it to the sentinel 0.
int root;
// s[v]: size of v's subtree in the latest DFS.
// f[v]: size of the largest piece left if v were removed (f[0] = sz serves as the sentinel bound).
int s[maxn], f[maxn];
// d[v]: distance from the vertex cal() started at, offset by cal()'s `init`.
int d[maxn];
// vis[v] = 1 once v has been processed as a centroid and cut out of the tree.
int vis[maxn];
// Distances gathered by getdep() for the current cal() invocation.
vector<int>dep;
// Size of the component currently searched for its centroid.
int sz;
// Running number of valid pairs for the current test case.
int ans;
// Adjacency lists: G[u] holds one node(neighbour, edge length) per incident edge.
vector<node>G[maxn];
void getrot(int now, int fa) {
int u;
s[now] = 1; f[now] = 0;
for (int i = 0; i < G[now].size(); i++) {
if ((u = G[now][i].v) != fa && !vis[u]) {
getrot(u, now);
s[now] += s[u];
f[now] = max(f[now], s[u]);
}
}
f[now] = max(f[now], sz - s[now]);
if (f[now] < f[root])root = now;
}
void getdep(int now, int fa) {
int u;
dep.push_back(d[now]);
s[now] = 1;
for (int i = 0; i < G[now].size(); i++) {
if ((u = G[now][i].v) != fa && !vis[u]) {
d[u] = d[now] + G[now][i].l;
getdep(u,now);
s[now] += s[u];
}
}
}
// Collects the distances of all vertices reachable from `now` (not crossing
// removed vertices), with d[now] set to `init`. Returns how many unordered
// pairs of those vertices have distances summing to at most k.
// sol() calls cal(c, 0) on a centroid c, which also counts pairs lying in the
// same child subtree. It then subtracts cal(child, edge length) for each child
// to cancel exactly those pairs.
int cal(int now, int init) {
    d[now] = init;
    dep.clear();
    getdep(now, 0);
    sort(dep.begin(), dep.end());
    int res = 0;
    int lo = 0;
    int hi = static_cast<int>(dep.size()) - 1;
    // Two pointers over the sorted distances: when dep[lo] + dep[hi] <= k,
    // every index in (lo, hi] pairs validly with lo.
    while (lo < hi) {
        if (dep[lo] + dep[hi] <= k) {
            res += hi - lo;
            ++lo;
        } else {
            --hi;
        }
    }
    return res;
}
void sol(int now) {
//ans = 0;
ans += cal(now, 0);
vis[now] = 1;
int u;
for (int i = 0; i < G[now].size(); i++) {
if (!vis[u = G[now][i].v]) {
ans -= cal(u, G[now][i].l);
f[0] = sz = s[u];
getrot(u,root=0);
sol(root);
}
}
}
// Processes test cases until a case with n == 0 (the "0 0" terminator line).
// Prints the number of vertex pairs at distance <= k for each tree.
// A test with k == 0 is still processed (its answer is 0); it no longer ends
// the input early.
int main()
{
    // All input goes through read() (C stdio getchar()). Do not mix it with
    // std::cin: after ios::sync_with_stdio(false) the two use independent
    // buffers, so edge data read ahead by cin would never reach read().
    while (true) {
        n = read();
        k = read();
        if (n == 0) break;
        for (int i = 0; i <= n; i++) G[i].clear();
        ms(vis);
        for (int i = 1; i < n; i++) {
            int u = read();
            int v = read();
            int l = read();
            G[u].push_back(node(v, l));
            G[v].push_back(node(u, l));
        }
        // The first centroid search covers the whole tree; f[0] is the sentinel bound.
        f[0] = sz = n;
        ans = 0;
        root = 0;
        getrot(1, 0);
        sol(root);
        cout << ans << '\n';
    }
    return 0;
}