【树上点分治1】Tree POJ 1741

原题链接
题意:给定一棵树,要求找出树上两点之间距离不超过k的点对有几个。

非常经典的点分治模板题,我们需要找出每一对经过当前根的点对,因此我们可以在找子树的时候把所有的子节点都丢进一个数组里,然后进行排序,利用尺取法选出点对。当然,这样可能会有重复,因为我们要求的是经过根的点对,但可能会有不经过根的点对也被计算进去,因此,利用容斥原理,在遍历每棵子树的时候减去以当前子节点为根的点对数。
在这里插入图片描述
如果当前A为根节点,我们提前处理出A到所有子节点的距离

A : 0
B : A->B
C : A -> B -> C
D : A -> B -> D
F : A -> F
E : A -> E

若C、D同时被选中,但他们到根节点的路径是有重复的,因此我们可以再减去以B为根,并选中AB边的子树中满足条件的个数。

具体看代码实现

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <queue>
#include <stack>
#include <cmath>
#include <bitset>
#include <map>
using namespace std;
//#define ACM_LOCAL
typedef long long ll;
typedef long double ld;
typedef pair<int, int> PII;
const int N = 1e4 + 5;
const int INF = 0x3f3f3f3f;
const int MOD = 1e9 + 7;
int n, m, cnt, h[N], rt, sz[N], mx[N], vis[N], sum, ans, d[N], dep[N], k;

struct edge{
    
    
    int to, next, vi;
}e[N<<1];

void add(int u, int v, int w) {
    
    
    e[cnt].to = v;
    e[cnt].next = h[u];
    e[cnt].vi = w;
    h[u] = cnt++;
}

void get_rt(int x, int fa) {
    
    
    sz[x] = 1, mx[x] = 0;
    for (int i = h[x]; ~i; i = e[i].next) {
    
    
        int y = e[i].to;
        if (vis[y] || y == fa) continue;
        get_rt(y, x);
        sz[x] += sz[y];
        mx[x] = max(mx[x], sz[y]);
    }
    mx[x] = max(mx[x], sum - sz[x]);
    if (mx[x] < mx[rt]) rt = x;
}

void get_d(int x, int fa) {
    
    
    d[++d[0]] = dep[x];
    for (int i = h[x]; ~i; i = e[i].next) {
    
    
        int y = e[i].to;
        if (vis[y] || y == fa) continue;
        dep[y] = dep[x] + e[i].vi;
        get_d(y, x);
    }
}

int cal(int x, int now) {
    
    
    d[0] = 0, dep[x] = now;
    get_d(x, -1);
    sort(d+1, d+d[0]+1);
    int ans = 0;
    int l = 1, r = d[0];
    while (l < r) {
    
    
        if (d[l] + d[r] > k) r--;
        else ans += r - l, l++;
    }
    return ans;
}

void work(int x) {
    
    
    ans += cal(x, 0);
    vis[x] = 1;
    for (int i = h[x]; ~i; i = e[i].next) {
    
    
        int y = e[i].to;
        if (vis[y]) continue;
        ans -= cal(y, e[i].vi);
        sum = sz[y], rt = 0;
        get_rt(y, -1);
        work(rt);
    }
}

void solve () {
    
    
    while (scanf("%d %d", &n, &k), n) {
    
    
        memset(h, -1, sizeof h);
        memset(vis, 0, sizeof vis);
        cnt = 0;
        for (int i = 1; i <= n-1; i++) {
    
    
            int x, y, z;
            scanf("%d %d %d", &x, &y, &z);
            add(x, y, z);
            add(y, x, z);
        }
        rt = 0, sum = n, mx[0] = INF, ans = 0;
        get_rt(1, -1);
        work(rt);
        printf("%d\n", ans);
    }

}

int main() {
    
    
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
#ifdef ACM_LOCAL
    freopen("input", "r", stdin);
    freopen("output", "w", stdout);
#endif
    solve();
}

猜你喜欢

转载自blog.csdn.net/kaka03200/article/details/109406422