【题解】P4178 Tree

前置芝士

P3806 【模板】点分治1 。不过数据真是水的可以,第一次我数组开小,过了;第二次我分治的时候没找中心,还是过了……所以也可以做P4149 Race

题意

和点分治模板很像:求树上距离小于等于 k k k的路径数量。(把模板的等于改成了小于等于,并且需要统计路径数量)

分析

由于题目变成了小于等于,那么我们就不能再用原来那套开桶的办法了。于是我们考虑把当前根的所有子树中的节点拉出来统计方案。

具体方法是把所有子树中的节点按照到根的距离排序,然后两端开始双指针 O ( n ) O(n) O(n)扫描。(不会双指针的自行 G o o g l e Google Google


于是你自信满满地打了一发,却发现爆 0 0 0了!!!

实际上是由于有一些情况没有考虑到。

由于普通点分治的时候是对根的每一棵子树分别进行答案统计,也就意味着统计的路径都是不同子树间的,但我们这次把所有点都拉出来排序了,也就不能保证这条路径横跨两棵子树。于是就会产生下图的情况:

V24JYV.png

途中蓝色路径虽然经过科根节点,但显然是不合法的路径,需要减去。

这里用到一点小小的容斥,由于这种不合法的路径一定在根节点的同一棵子树内,于是我们先计算出经过这个子树的根的路径数量(上图中子树的根是1),然后减去即可。而对于子树中出现的不合法情况,我们接着用同样的方法容斥即可。

注意点:容斥的时候并不是真正统计子节点答案的时候,计算之前需要把儿子的 d i s dis dis值设为当前根节点到它的边权,因为这样的不合法路径一定会到达根节点,在统计的时候必须把经过两次的那条多余边减去。

代码

#include <bits/stdc++.h>
#define MAX 100005
#define ll long long
#define INF 0x3f3f3f3f
using namespace std;

int n, k, cnt, rt, sum, tot;
int head[MAX], vet[MAX], Next[MAX], cost[MAX];
int dis[MAX], d[MAX], mx[MAX], sz[MAX], vis[MAX];
ll ans;

void add(int x, int y, int w){
    
    
    cnt++;
    Next[cnt] = head[x];
    head[x] = cnt;
    vet[cnt] = y;
    cost[cnt] = w;
}

void getrt(int x, int fa){
    
    		//找重心(点分治模板)
    sz[x] = 1, mx[x] = 0;
    for (int i = head[x]; i; i = Next[i]) {
    
    
        int v = vet[i];
        if(v == fa || vis[v]) continue;
        getrt(v, x);
        sz[x] += sz[v];
        mx[x] = max(mx[x], sz[v]);
    }
    mx[x] = max(mx[x], sum-sz[x]);
    if(mx[x] < mx[rt]) rt = x;
}

void getdis(int x, int fa){
    
    		//处理距离(模板)
    d[++tot] = dis[x];
    for (int i = head[x]; i; i = Next[i]) {
    
    
        int v = vet[i];
        if(v == fa || vis[v]) continue;
        dis[v] = dis[x]+cost[i];
        getdis(v, x);
    }
}

ll calc(int x, int w){
    
    		//计算贡献
    tot = 0;
    dis[x] = w;
    getdis(x, 0);		//处理出距离并排序
    sort(d+1, d+tot+1);
    int l = 1, r = tot;
    ll res = 0;
    while(l < r){
    
    		//双指针,从两头开始扫
        if(d[l]+d[r] <= k){
    
    
            res += r-l;
            l++;
        }
        else r--;
    }
    return res;
}

void solve(int x){
    
    
    vis[x] = 1;
    ans += calc(x, 0);
    for (int i = head[x]; i; i = Next[i]) {
    
    
        int v = vet[i];
        if(vis[v]) continue;
        ans -= calc(v, cost[i]);		//容斥,把儿子节点dis初值设为边权
        sum = sz[v];
        rt = 0, mx[rt] = INF;
        getrt(v, 0);
        solve(rt);
    }
}

int main()
{
    
    
    cin >> n;
    int x, y, w;
    for (int i = 1; i < n; ++i) {
    
    
        scanf("%d%d%d", &x, &y, &w);
        add(x, y, w);
        add(y, x, w);
    }
    cin >> k;
    mx[rt] = sum = n;
    getrt(1, 0);
    solve(rt);

    cout << ans << endl;

    return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_30115697/article/details/91490399