前置芝士
P3806 【模板】点分治1 。不过数据真是水的可以,第一次我数组开小,过了;第二次我分治的时候没找中心,还是过了……所以也可以做P4149 Race 。
题意
和点分治模板很像:求树上距离小于等于 k k k的路径数量。(把模板的等于改成了小于等于,并且需要统计路径数量)
分析
由于题目变成了小于等于,那么我们就不能再用原来那套开桶的办法了。于是我们考虑把当前根的所有子树中的节点拉出来统计方案。
具体方法是把所有子树中的节点按照到根的距离排序,然后两端开始双指针 O ( n ) O(n) O(n)扫描。(不会双指针的自行 G o o g l e Google Google)
于是你自信满满地打了一发,却发现爆 0 0 0了!!!
实际上是由于有一些情况没有考虑到。
由于普通点分治的时候是对根的每一棵子树分别进行答案统计,也就意味着统计的路径都是不同子树间的,但我们这次把所有点都拉出来排序了,也就不能保证这条路径横跨两棵子树。于是就会产生下图的情况:
途中蓝色路径虽然经过科根节点,但显然是不合法的路径,需要减去。
这里用到一点小小的容斥,由于这种不合法的路径一定在根节点的同一棵子树内,于是我们先计算出经过这个子树的根的路径数量(上图中子树的根是1),然后减去即可。而对于子树中出现的不合法情况,我们接着用同样的方法容斥即可。
注意点:容斥的时候并不是真正统计子节点答案的时候,计算之前需要把儿子的 d i s dis dis值设为当前根节点到它的边权,因为这样的不合法路径一定会到达根节点,在统计的时候必须把经过两次的那条多余边减去。
代码
#include <bits/stdc++.h>
#define MAX 100005
#define ll long long
#define INF 0x3f3f3f3f
using namespace std;
int n, k, cnt, rt, sum, tot;
int head[MAX], vet[MAX], Next[MAX], cost[MAX];
int dis[MAX], d[MAX], mx[MAX], sz[MAX], vis[MAX];
ll ans;
void add(int x, int y, int w){
cnt++;
Next[cnt] = head[x];
head[x] = cnt;
vet[cnt] = y;
cost[cnt] = w;
}
void getrt(int x, int fa){
//找重心(点分治模板)
sz[x] = 1, mx[x] = 0;
for (int i = head[x]; i; i = Next[i]) {
int v = vet[i];
if(v == fa || vis[v]) continue;
getrt(v, x);
sz[x] += sz[v];
mx[x] = max(mx[x], sz[v]);
}
mx[x] = max(mx[x], sum-sz[x]);
if(mx[x] < mx[rt]) rt = x;
}
void getdis(int x, int fa){
//处理距离(模板)
d[++tot] = dis[x];
for (int i = head[x]; i; i = Next[i]) {
int v = vet[i];
if(v == fa || vis[v]) continue;
dis[v] = dis[x]+cost[i];
getdis(v, x);
}
}
ll calc(int x, int w){
//计算贡献
tot = 0;
dis[x] = w;
getdis(x, 0); //处理出距离并排序
sort(d+1, d+tot+1);
int l = 1, r = tot;
ll res = 0;
while(l < r){
//双指针,从两头开始扫
if(d[l]+d[r] <= k){
res += r-l;
l++;
}
else r--;
}
return res;
}
void solve(int x){
vis[x] = 1;
ans += calc(x, 0);
for (int i = head[x]; i; i = Next[i]) {
int v = vet[i];
if(vis[v]) continue;
ans -= calc(v, cost[i]); //容斥,把儿子节点dis初值设为边权
sum = sz[v];
rt = 0, mx[rt] = INF;
getrt(v, 0);
solve(rt);
}
}
int main()
{
cin >> n;
int x, y, w;
for (int i = 1; i < n; ++i) {
scanf("%d%d%d", &x, &y, &w);
add(x, y, w);
add(y, x, w);
}
cin >> k;
mx[rt] = sum = n;
getrt(1, 0);
solve(rt);
cout << ans << endl;
return 0;
}