Tree POJ - 1741 点分治

Given a tree with n vertices, each edge has a length (a positive integer less than 1001).
Define dist(u,v) = the minimum distance between nodes u and v.
Given an integer k, a pair (u,v) of vertices is called valid if and only if dist(u,v) does not exceed k.
Write a program that counts how many valid pairs there are in a given tree.
Input
The input contains several test cases. The first line of each test case contains two integers n, k. (n<=10000) The following n-1 lines each contains three integers u,v,l, which means there is an edge between node u and v of length l.
The last test case is followed by two zeros.
Output
For each test case output the answer on a single line.
Sample Input
5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0
Sample Output
8

给定一棵树,求边权值和<= k 的路径数;

点分治的基础题;

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstdlib>
#include<cstring>
#include<string>
#include<cmath>
#include<map>
#include<set>
#include<vector>
#include<queue>
#include<string>
#include<bitset>
#include<ctime>
#include<deque>
#include<stack>
#include<functional>
#include<sstream>
using namespace std;
// Capacity of the per-vertex arrays declared below (the problem statement bounds n by 10000).
#define maxn 200005
// Large sentinel values; neither is used by the code in this file.
#define inf 0x3f3f3f3f
#define INF 0x7fffffff
// Shorthand integer type aliases; not used by the code in this file.
typedef long long  ll;
typedef unsigned long long ull;
// Zero-fills a whole array. x must be a real array (not a pointer) so that sizeof(x) is its full size.
#define ms(x) memset(x,0,sizeof(x))
// Modulus constant; not used by the code in this file.
const long long int mod = 1e9 + 7;

// Reads the next decimal integer from stdin using getchar().
//
// Every character before the first digit is skipped; if a '-' is seen among
// the skipped characters the result is negated. Digits are then accumulated
// until the first non-digit character, which is consumed.
//
// Returns the parsed value, or 0 when end of input is reached before any
// digit was found.
inline int read()
{
    int x = 0, k = 1;
    // getchar() returns int: storing it in a char makes EOF indistinguishable
    // from a real character and turned the skip loop below into an infinite
    // loop at end of input.
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) { if (c == '-')k = -1; c = getchar(); }
    while (c >= '0' && c <= '9')x = x * 10 + (c - '0'), c = getchar();
    return x * k;
}

// One entry of an adjacency list: the neighbouring vertex and the length of
// the edge leading to it.
struct node {
    int l, v;  // l: edge length, v: vertex at the other end of the edge

    // Default-constructed entries are zeroed rather than left indeterminate.
    node() : l(0), v(0) {}

    // Note the parameter order (vertex first, then length). The initializer
    // list follows the declaration order of the members (l, then v), which is
    // the order in which they are actually initialised.
    node(int v, int l) : l(l), v(v) {}
};

// n: number of vertices of the current test case; k: distance limit used when counting pairs in cal().
int n, k;
// Best centroid candidate found by getrot(); callers reset it to 0 before each search.
int root;
// s[v]: subtree size of v as last computed by getrot()/getdep().
// f[v]: size of the largest piece left when v is removed from its component (set by getrot()).
//       f[0] is preloaded with the component size so that any real vertex beats the sentinel root 0.
int s[maxn], f[maxn];
// d[v]: distance of v from the start vertex of the current cal() call (plus that call's initial offset).
int d[maxn];
// vis[v] != 0 once v has been processed as a centroid in sol(); such vertices are skipped by all traversals.
int vis[maxn];
// Scratch list of the distances collected by getdep() for the current cal() call.
vector<int>dep;
// Size of the component currently being searched by getrot().
int sz;
// Running answer: number of valid pairs counted so far for the current test case.
int ans;
// Adjacency lists; G[u] holds one node entry per edge incident to u.
vector<node>G[maxn];

void getrot(int now, int fa) {
    int u;
    s[now] = 1; f[now] = 0;
    for (int i = 0; i < G[now].size(); i++) {
        if ((u = G[now][i].v) != fa && !vis[u]) {
            getrot(u, now);
            s[now] += s[u];
            f[now] = max(f[now], s[u]);
        }
    }
    f[now] = max(f[now], sz - s[now]);
    if (f[now] < f[root])root = now;
}

// Depth-first walk over the unvisited vertices reachable from `now`.
//
// now: vertex being explored; fa: the vertex we arrived from.
//
// Appends d[now] to `dep`, assigns d[] for every vertex reached (parent's
// distance plus edge length) and recomputes the subtree sizes s[] along the way.
// d[now] must already be set by the caller for the start vertex.
void getdep(int now, int fa) {
    dep.push_back(d[now]);
    s[now] = 1;
    const int degree = static_cast<int>(G[now].size());
    for (int idx = 0; idx < degree; ++idx) {
        const int child = G[now][idx].v;
        if (child == fa || vis[child]) continue;
        d[child] = d[now] + G[now][idx].l;
        getdep(child, now);
        s[now] += s[child];
    }
}

// Counts unordered pairs of vertices in the unvisited component reachable from
// `now` whose collected distances sum to at most k.
//
// now:  start vertex of the traversal.
// init: distance assigned to `now` itself before the traversal starts.
//
// Returns the number of such pairs. Overwrites `dep`, and (through getdep)
// d[] and s[] for every vertex reached.
int cal(int now, int init) {
    d[now] = init;
    dep.clear();
    getdep(now, 0);
    sort(dep.begin(), dep.end());

    // Two-pointer sweep over the sorted distances: whenever the smallest and
    // largest remaining values fit within k, the smallest one pairs with every
    // element up to and including the largest.
    int pairs = 0;
    int lo = 0;
    int hi = dep.size() - 1;
    while (lo < hi) {
        if (dep[lo] + dep[hi] > k) {
            --hi;
            continue;
        }
        pairs += hi - lo;
        ++lo;
    }
    return pairs;
}


// Processes the component whose centroid is `now` and recurses into the
// components that remain once `now` is removed.
//
// Adds to `ans` the pairs counted by cal() from `now`, then for every
// unvisited neighbour subtracts the pairs counted by cal() inside that
// neighbour's subtree (started at the connecting edge's length), finds that
// subtree's centroid with getrot() and recurses on it.
void sol(int now) {
    ans += cal(now, 0);
    vis[now] = 1;
    const int degree = static_cast<int>(G[now].size());
    for (int idx = 0; idx < degree; ++idx) {
        const int child = G[now][idx].v;
        if (vis[child]) continue;
        // cal() must run first: it recomputes s[], and s[child] is read just below.
        ans -= cal(child, G[now][idx].l);
        sz = s[child];
        f[0] = sz;   // sentinel value that every real vertex undercuts
        root = 0;
        getrot(child, 0);
        sol(root);
    }
}



// Reads test cases until the terminating "0 0" line (or end of input) and
// prints, for each one, the number of vertex pairs whose distance is at most k.
int main()
{
    // All input goes through scanf and all output through printf. The previous
    // version read n and k with cin after sync_with_stdio(false) and the edges
    // with a getchar-based reader; unsynchronised cin buffers stdin on its own,
    // so the edge data could be consumed by cin and never reach getchar.
    //
    // A case with k == 0 is valid (its answer is 0) and must not end the input,
    // so the loop only stops when n is not positive: a tree has at least one
    // vertex, hence n == 0 appears only on the terminator line.
    while (scanf("%d %d", &n, &k) == 2 && n > 0) {
        for (int i = 0; i <= n; i++)G[i].clear();
        memset(vis, 0, sizeof(vis));
        for (int i = 1; i < n; i++) {
            int u, v, l;
            if (scanf("%d %d %d", &u, &v, &l) != 3) return 0;  // truncated input
            G[u].push_back(node(v, l));
            G[v].push_back(node(u, l));
        }
        // Prime the centroid search for the whole tree: component size n,
        // sentinel root 0 with f[0] larger than any real vertex can get.
        f[0] = sz = n;
        ans = 0;
        getrot(1, root = 0);
        sol(root);
        printf("%d\n", ans);
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_40273481/article/details/81780856
今日推荐