codeJan和树 (dfs+倍增)

链接:https://www.nowcoder.com/acm/contest/81/D
来源:牛客网

题目描述
codeJan有一天脑洞大开,想到一个有趣的问题。给一个固定根为1号结点的树,定义一个子树的beauty是这个子树的根节点到所有这棵树上其他节点的距离和,叶子节点的beauty是0。定义一个子树的sub-beauty是这个子树的beauty值减去这个子树的某一个子树(不包括自身)的beauty值。显然一个子树的beauty值是唯一的,而sub-beauty值可以有很多个。codeJan想要知道所有子树的所有sub-beauty中不超过m的最大值。
输入描述:

第一行是一个T≤20代表测试组数。每组测试的第一行包含两个正整数是n,m(n≤105,m≤108),接下来n−1
行每行包含三个正整数a b d,分别表示a结点和b结点之间的距离是d,a,b∈[1,n],1≤d≤103。请注意每棵树的根节点都是1号结点,并且保证输入合法。

输出描述:

对于每组测试样例输出一个整数表示所有子树sub-beauty中不超过m的最大值。如果所有子树的sub-beauty都大于m,输出-1。

示例1
输入

3
3 4
1 2 1
1 3 2
3 4
1 2 3
1 3 2
4 6
1 2 2
2 3 5
3 4 2

输出

3
-1
6

思路:
不从每个祖先树考虑子树,而是从子树考虑祖先树,然后用倍增查找下就行
accode

#include<bits/stdc++.h>
#define LL long long
#define INF 0x3f3f3f3f
using namespace std;
const int maxn = 1e5+5;
int n,m;
LL dp[maxn];
int fa[maxn];
int bz[maxn][31];
int head[maxn];
int tot;
int dep[maxn];
LL mx[maxn];
struct node
{
    int v;
    int net;
    LL va;
}E[maxn*2];
void init()
{
    memset(head,-1,sizeof(head));
    tot = 0;
}
void build(int u,int v,LL va)
{
    E[tot].v = v;
    E[tot].va = va;
    E[tot].net = head[u];
    head[u] = tot++;
}
int dfs(int u,int deep,int pa)
{
   // cout<<"fwfw"<<endl;
    dep[u] = deep;
    int cnt = 0;
    for(int i = head[u];~i;i = E[i].net){
        int v = E[i].v;
        if(v==pa) continue;
        fa[v] = u;
        int tmp = dfs(v,deep+1,u);
        dp[u] += dp[v]+tmp*E[i].va;
        cnt+=tmp;
    }
    if(cnt==0){
        dp[u] = 0;
        return 1;
    }
    return cnt+1;
}
void BZ()
{
    for(int i = 1;i<=n;i++){
        bz[i][0] = fa[i];
    }
    for(int i = 1;i<31;i++){
        for(int j = 1;j<=n;j++){
            bz[j][i] = bz[bz[j][i-1]][i-1];
        }
    }
}
LL getans(int x)
{
    LL ret = 0;
    int xx = x;
    for(int i = 30;i>=0;i--){
        if(bz[xx][i] == 0) continue;
        if(dp[bz[xx][i]]-dp[x]<=m){
            xx = bz[xx][i];
        }
    }
    if(xx==x) return -1;
    return dp[xx]-dp[x];
}
int t;
int main()
{
    scanf("%d",&t);
    while(t--){
        init();
        memset(bz,0,sizeof(bz));
        memset(mx,-1,sizeof(mx));
        memset(dp,0,sizeof(dp));
        scanf("%d%d",&n,&m);
        fa[1] = 0;
        for(int i = 0;i<n-1;i++){
            int u,v;
            LL d;
            scanf("%d%d%lld",&u,&v,&d);
            build(u,v,d);
            build(v,u,d);
            //fa[v] = u;
        }
        dfs(1,0,-1);
        BZ();
        for(int i = 2;i<=n;i++){
            mx[i] = getans(i);
        }
        LL ans = -1;
        for(int i = 1;i<=n;i++){
            ans = max(ans,mx[i]);
        }
        printf("%lld\n",ans);
    }
}

猜你喜欢

转载自blog.csdn.net/w571523631/article/details/80365405