原题链接
题意:给定一棵树,要求找出树上两点之间距离不超过k的点对有几个。
非常经典的点分治模板题,我们需要找出每一对经过当前根的点对,因此我们可以在找子树的时候把所有的子节点都丢进一个数组里,然后进行排序,利用尺取法选出点对。当然,这样可能会有重复,因为我们要求的是经过根的点对,但可能会有不经过根的点对也被计算进去,因此,利用容斥原理,在遍历每棵子树的时候减去以当前子节点为根的点对数。
如果当前A为根节点,我们提前处理出A到所有子节点的距离
A : 0
B : A->B
C : A -> B -> C
D : A -> B -> D
F : A -> F
E : A -> E
若C、D同时被选中,但他们到根节点的路径是有重复的,因此我们可以再减去以B为根,并选中AB边的子树中满足条件的个数。
具体看代码实现
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <queue>
#include <stack>
#include <cmath>
#include <bitset>
#include <map>
using namespace std;
//#define ACM_LOCAL
typedef long long ll;
typedef long double ld;
typedef pair<int, int> PII;
const int N = 1e4 + 5;
const int INF = 0x3f3f3f3f;
const int MOD = 1e9 + 7;
int n, m, cnt, h[N], rt, sz[N], mx[N], vis[N], sum, ans, d[N], dep[N], k;
struct edge{
int to, next, vi;
}e[N<<1];
void add(int u, int v, int w) {
e[cnt].to = v;
e[cnt].next = h[u];
e[cnt].vi = w;
h[u] = cnt++;
}
void get_rt(int x, int fa) {
sz[x] = 1, mx[x] = 0;
for (int i = h[x]; ~i; i = e[i].next) {
int y = e[i].to;
if (vis[y] || y == fa) continue;
get_rt(y, x);
sz[x] += sz[y];
mx[x] = max(mx[x], sz[y]);
}
mx[x] = max(mx[x], sum - sz[x]);
if (mx[x] < mx[rt]) rt = x;
}
void get_d(int x, int fa) {
d[++d[0]] = dep[x];
for (int i = h[x]; ~i; i = e[i].next) {
int y = e[i].to;
if (vis[y] || y == fa) continue;
dep[y] = dep[x] + e[i].vi;
get_d(y, x);
}
}
int cal(int x, int now) {
d[0] = 0, dep[x] = now;
get_d(x, -1);
sort(d+1, d+d[0]+1);
int ans = 0;
int l = 1, r = d[0];
while (l < r) {
if (d[l] + d[r] > k) r--;
else ans += r - l, l++;
}
return ans;
}
void work(int x) {
ans += cal(x, 0);
vis[x] = 1;
for (int i = h[x]; ~i; i = e[i].next) {
int y = e[i].to;
if (vis[y]) continue;
ans -= cal(y, e[i].vi);
sum = sz[y], rt = 0;
get_rt(y, -1);
work(rt);
}
}
void solve () {
while (scanf("%d %d", &n, &k), n) {
memset(h, -1, sizeof h);
memset(vis, 0, sizeof vis);
cnt = 0;
for (int i = 1; i <= n-1; i++) {
int x, y, z;
scanf("%d %d %d", &x, &y, &z);
add(x, y, z);
add(y, x, z);
}
rt = 0, sum = n, mx[0] = INF, ans = 0;
get_rt(1, -1);
work(rt);
printf("%d\n", ans);
}
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
#ifdef ACM_LOCAL
freopen("input", "r", stdin);
freopen("output", "w", stdout);
#endif
solve();
}