HDU 5956 The Elder (树形DP + 斜率优化)

题目链接

容易看出来是一个树形dp,并且有一个非常显然的状态转移方程:

dp[u]=min\left \{ dp[v]+(dis[u]-dis[v])^{2}+p \right \},其中v是树上从u到根节点路径上的点。

但是显然这样的时间复杂度在树退化成链的时候会达到O\left ( N^{2} \right ),需要想办法来进行优化。尝试进行变形:

如果状态v和w都可以转移到状态u,那么在这种情况下,从状态v转移会更优:

dp[v]+(dis[u]-dis[v])^{2}+p<dp[w]+(dis[u]-dis[w])^{2}+p

dp[v]+(dis[v])^{2}-2*dis[u]*dis[v]<dp[w]+(dis[w])^{2}-2*dis[u]*dis[w]

dp[v]+(dis[v])^{2}-dp[w]-(dis[w])^{2}<2*dis[u]*(dis[v]-dis[w])

\frac{dp[v]+(dis[v])^{2}-dp[w]-(dis[w])^{2}}{dis[v]-dis[w]}<2*dis[u]

\text{ let } f[x]= dp[x]+(dis[x])^{2}\text{, }\frac{f[v]-f[w]}{dis[v]-dis[w]}<2*dis[u]

我们发现上式变成了一个斜率的形式。考虑将(dis[i], f[i])的点绘制出来,如果出现了下面的情况:

,那么通过枚举各种情况,我们可以分析出来 j 处必不可能是较优的点。

也就是说,有可能作为最优解进行转移的状态,它们的点必然是在一个下凸壳上的。

每次在得到一个新的状态的时候,由于dis[]的单调性,它的位置必然是在这个半凸壳的右端处。由于dis[]的单调性,f[]也是满足单调递增,这样就可以用一个单调队列来维护半凸壳上的点。

对于每个新的状态u,具体的维护方法为:

1. 检查队头的两个元素q[l]和q[l+1],通过上面的斜率检查,如果q[l+1]比q[l]更优,那么就把q[l]出队。

2. 直接取队头的元素为目标状态,进行状态转移,计算出f[u]。

3. 将u插入队尾。插入之前需要检查三个状态q[r-1], q[r], u是否满足斜率单调递增,若不满足则将q[r]出队。

这样就将整个DP的时间复杂度优化到了O(N)

需要注意的是,由于每个节点可能有多个子节点,因此每次转移之后要将队尾恢复为原来的元素。

#include <cstdio>
#include <cstring>
#include <cmath>
#include <cstdlib>
#include <queue>
#include <map>
#include <vector>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long ll;
const int maxn = 100050;
const ll INF = (1LL << 62) - 1;
const double eps = 1e-8;
const ll mod = 2147493647;
const double pi = acos(-1.0);

int t, n, a, b, l, r;
int no, head[maxn], q[maxn];
ll dp[maxn], dis[maxn], ans, m, x;
struct node
{
    int to, nxt;
    ll w;
}e[maxn << 1];
void add(int a, int b, ll x)
{
    e[no].to = b;
    e[no].nxt = head[a];
    e[no].w = x;
    head[a] = no++;
}

ll gety(int u, int v)
{
    return dp[u] + dis[u]*dis[u] - dp[v] - dis[v]*dis[v];
}
ll getx(int u, int v) {return dis[u] - dis[v];}
ll getdp(int u, int v) {return dp[v] + m + (dis[u] - dis[v])*(dis[u] - dis[v]);}

void pre(int u, int fa)
{
    for(int i = head[u];i != -1;i = e[i].nxt)
    {
        int v = e[i].to;
        if(v == fa) continue;
        dis[v] = dis[u] + e[i].w;
        pre(v, u);
    }
}

void dfs(int u, int fa, int l, int r)
{
    int pre = -1;
    while(l < r && gety(q[l+1], q[l]) <= 2*dis[u]*getx(q[l+1], q[l])) l++;
    dp[u] = min(dp[u], getdp(u, q[l]));
    while(l < r && getx(u, q[r])*gety(q[r], q[r-1]) >= getx(q[r], q[r-1])*gety(u, q[r])) r--;
    pre = q[++r], q[r] = u;
    ans = max(ans, dp[u]);
    for(int i = head[u];i != -1;i = e[i].nxt)
    {
        int v = e[i].to;
        if(v == fa) continue;
        dfs(v, u, l, r);
    }
    if(pre != -1) q[r] = pre;
}

void init()
{
    no = r = 0, l = 1;
    memset(dis, 0, sizeof(dis));
    memset(head, -1, sizeof(head));
    ans = dp[1] = q[0] = 0;
}

int main()
{
    scanf("%d", &t);
    while(t--)
    {
        scanf("%d%lld", &n, &m);
        init();
        for(int i = 1;i < n;i++)
        {
            scanf("%d%d%lld", &a, &b, &x);
            add(a, b, x), add(b, a, x);
        }
        pre(1, -1);
        for(int i = 1;i <= n;i++)
            dp[i] = dis[i]*dis[i];
        dfs(1, -1, 1, 0);
        printf("%lld\n", ans);
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/NPU_SXY/article/details/82459596