[Nowcoder 2018ACM多校第一场H] Longest Path

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u013578420/article/details/81200654

题目大意:
给你一棵n个节点的树, 带有边权c[i], 定义路径{ e 1 , e 2 e k }的费用是 ( e 1 e 2 ) 2 + ( e 2 e 3 ) 2 + + ( e k 1 e k ) 2 。 求每个节点距自身最远点的距离。 ( n 10 5 , c i 10 5 , n 10 6 )

题目思路:
类似于求树的直径做树形dp, 先选定1为根, 用f[i]表示向下走的答案, g[i]表示向上走的答案。 由于费用的更新需要用到两条边, 故扩展一下用f_ch[u][v]表示从u往下走第一步到v的答案, v是u的孩子, 这样复杂度还是O(n)的解决f。
然后考虑向上走的情况g。 考虑已经求出了g[u], 现在要用u来求出他的所有孩子g[v], 对于一个点v来说,

g [ v ] = max ( g [ u ] + ( e ( u , f a [ u ] ) e ( u , v ) ) 2 , max x ! = v , x s o n [ u ] { f _ c h [ u ] [ x ] + ( e ( v , u ) e ( u , x ) ) 2 } )

对于第二个max是个经典的dp斜率优化的问题, 将e(u, v)排序后, 维护上凸包+单调队列, 正着做一遍反着做一遍即可。

PS: 关于dp斜率优化
考虑dp: f [ i ] = max { f [ j ] + ( e [ i ] e [ j ] ) 2 }
对与某个转移j, 将式子移项, 分离变量, 只和i有关的部分、 只和j有关的部分、 和i,j均有关的部分。

( f [ i ] e [ i ] 2 ) = ( f [ j ] + e [ j ] 2 ) ( 2 e [ i ] ) e [ j ]

f [ i ] e [ i ] 2 看作截距b, f [ j ] + e [ j ] 2 看作y, 2 e [ i ] 看作斜率k, e [ j ] 看作x。
上式可以看作线性函数b = y - kx。
每个j对应一个坐标(x,y), 一系列的j在图上就是一些点, 对于一个i就是一个询问, 每个i对应一个斜率k, 每个i求一个斜率为k的经过图中某个点的最大截距。
这里是取最大值故维护上凸包(取min则维护下凸包), 在本题中, 考虑将e[i]从小到打排序, 先正过来求一遍, 即每个i都会考虑一遍小于它的j。 维护一个单调队列, 对于询问i, 由于询问的斜率是递增的, 按上凸包顺时针方向看, 相邻点构成的斜率递减, 询问i的取最大值的点满足其向下一个点斜率小于询问i的斜率,向上一个点的斜率大于询问i的斜率, 又考虑到询问i的斜率是递增的来询问的, 凸包上的点也是按x坐标递增来加入的, 故应从单调队列的尾端扫描, 根据斜率的比较关系, 斜率越大的询问取最大值的点越靠前, 如果队尾的上一个点由于队尾, 说明对于一个更大的斜率也会优于队尾的, 故弹出队尾元素。 再将i对应的坐标点加入凸包中, 可以用向量叉积判断凸包走向来决定是否删除队尾元素 。 反向求一遍同理。

Code:

#include <map>
#include <set>
#include <map>
#include <bitset>
#include <cmath>
#include <queue>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>

#define ll long long
#define db double
#define pw(x) ((x) * (x))
#define fi first
#define se second
#define mp(x, y) make_pair(x, y)

using namespace std;

const int N = (int)1e5 + 10;

int n;
int cnt, lst[N], nxt[N * 2], to[N * 2]; ll c[N * 2], pre[N];
map<int, ll> f_ch[N];
map<int, ll> :: iterator it;
ll f[N], g[N];

void add(int u, int v, int w){
    nxt[++ cnt] = lst[u]; lst[u] = cnt; to[cnt] = v; c[cnt] = w;
    nxt[++ cnt] = lst[v]; lst[v] = cnt; to[cnt] = u; c[cnt] = w;
}

void dfs(int u, int fa){
    for (int j = lst[u]; j; j = nxt[j]){
        int v = to[j];
        if (v == fa) continue;
        pre[v] = c[j];
        dfs(v, u);

        ll &x = f_ch[u][v];
        x = 0;
        for (it = f_ch[v].begin(); it != f_ch[v].end(); it ++){
            x = max(x, it->se + pw(c[j] - pre[it->fi]));
            f[u] = max(f[u], x);
        }

    }
}

pair <ll, int > tmp[N]; int sz;
pair <ll, ll > que[N]; int head, tail;

pair <ll, ll> operator-(pair<ll, ll> a, pair<ll, ll>  b){
    return mp(a.fi-b.fi, a.se-b.se);
}
ll operator*(pair<ll, ll> a, pair<ll, ll>  b){
    return a.fi * b.se - a.se * b.fi;
}


ll cross(pair<ll, ll> a, pair<ll, ll> b, pair<ll, ll> c){
    return (a - b) * (b - c);
}

ll count(pair<ll, ll > x, ll e){
    return x.se-2*e*x.fi+pw(e);
}

void dfs2(int u, int fa){
    sz = 0;
    for (int j = lst[u]; j; j = nxt[j]){
        int v = to[j];
        if (v == fa) continue;
        tmp[++ sz] = mp(pre[v], v);
    }

    sort(tmp + 1, tmp + sz + 1);
    que[head = tail = 1] = mp(tmp[1].fi, pw(tmp[1].fi) + f_ch[u][tmp[1].se]);

    for (int i = 2; i <= sz; i ++){
        int v = tmp[i].se; ll e = tmp[i].fi;
        while (head < tail && count(que[tail], e) <= count(que[tail - 1], e)) tail --;
        g[v] = max(g[v], count(que[tail], e));

        pair<ll, ll> p = mp(e, f_ch[u][v] + pw(e));
        while (head < tail && cross(p, que[tail], que[tail - 1]) <= 0)
            tail --;
        que[++ tail] = p;
    }

    que[head = tail = 1] = mp(tmp[sz].fi, pw(tmp[sz].fi) + f_ch[u][tmp[sz].se]);
    for (int i = sz - 1; i >= 1; i --){
        int v = tmp[i].se; ll e = tmp[i].fi;
        while (head < tail && count(que[tail], e) <= count(que[tail - 1], e)) tail --;
        g[v] = max(g[v], count(que[tail], e));

        pair<ll, ll> p = mp(e, f_ch[u][v] + pw(e));
        while (head < tail && cross(p, que[tail], que[tail - 1]) >= 0)
            tail --;
        que[++ tail] = p;
    }

    if (fa){
        for (int i = 1; i <= sz; i ++){
            int v = tmp[i].se; ll e = tmp[i].fi;
            g[v] = max(g[v], g[u] + pw(pre[u] - e));
        }
    }

    for (int j = lst[u]; j; j = nxt[j]){
        int v = to[j];
        if (v == fa) continue;
        dfs2(v, u);
    }

}

int getint(){
    int ret = 0; char c = getchar();
    while (c > '9' || c < '0') c = getchar();
    while (c <= '9' && c >= '0'){
        ret = ret * 10 + c - '0';
        c = getchar();
    }
    return ret;
}

int main(){
    while (scanf("%d", &n) != EOF){
        for (int i = 2, u, v, w; i <= n; i ++){
            u = getint(), v = getint(), w = getint();

            add(u, v, w);
        }

        dfs(1, 0);

        dfs2(1, 0);

        for (int i = 1; i <= n; i ++){
            printf("%lld\n", max(f[i], g[i]));
            lst[i] = 0;
            f[i] = g[i] = 0;
            f_ch[i].clear();
        }
        cnt = 0;
    }

    return 0;
}

猜你喜欢

转载自blog.csdn.net/u013578420/article/details/81200654
今日推荐