九省联考2018 林克卡特树

版权声明:随意转载,愿意的话提一句作者就好了 https://blog.csdn.net/stone41123/article/details/84061375

Link

Difficulty

算法难度7,思维难度7,代码难度5

Description

给定一棵 n n 个点的树,边带权值,要求你选出 k + 1 k+1 条链,使得权值和最大。

1 k < n 3 × 1 0 5 v 1 0 6 1\le k<n\le 3\times 10^5,|v|\le 10^6

Solution

前面的小部分分我就不说了,说一下和正解有极大联系的60分的树形dp吧。

首先我们考虑设计dp状态。

第一想法是 d p ( i , j ) dp(i,j) 代表在 i i 的子树中选了 j j 条链的最大价值,看起来非常美好。

但是仔细想想发现没法写状态转移方程,因为不知道到底能不能和儿子连边,也不知道连边会发生什么事。

这样我们就发现我们还要记录一下每个点的连边状态。

d p ( i , j , 0 / 1 / 2 ) dp(i,j,0/1/2) 代表在 i i 的子树中完整地选了 j j 条链的最大价值, 0 / 1 / 2 0/1/2 代表点 i i 的度数。

首先初始状态:

  1. d p ( i , 0 , 0 ) = 0 dp(i,0,0)=0 ,代表这个点可以不选。
  2. d p ( i , 0 , 1 ) = 0 dp(i,0,1)=0 ,代表这个点可以作为链最下面的点向上连。
  3. d p ( i , 1 , 2 ) = 0 dp(i,1,2)=0 ,代表这个点可以单独作为一条链,至于为什么要有这个状态,只需要想一下极端情况 k + 1 = n k+1=n 时,合法答案是什么样子的就可以了。
  4. 其他都为负无穷,也就是不合法

考虑转移状态,将儿子 u u 的状态合并到点 x x :( i : k 1 i:k\to 1 代表 i i k k 枚举到 1 1 ,下面不再描述)

  1. d p ( x , i , 0 ) = m a x ( d p ( x , i , 0 ) , d p ( x , i j , 0 ) + d p ( u , j , 0 ) ) dp(x,i,0)=max(dp(x,i,0),dp(x,i-j,0)+dp(u,j,0)) ,其中 i : k 1 , j : 1 k i:k\to 1,j:1\to k

    代表不选或者选 j j 条链。

  2. d p ( x , i , 1 ) = m a x ( d p ( x , i , 1 ) , d p ( x , i j , 1 ) + d p ( u , j , 0 ) , d p ( x , i j , 0 ) + d p ( u , j , 0 ) + v a l ( x , u ) ) dp(x,i,1)=max(dp(x,i,1),dp(x,i-j,1)+dp(u,j,0),dp(x,i-j,0)+dp(u,j,0)+val(x,u)) ,其中 i : k 1 , j : 1 k i:k\to 1,j:1\to k

    代表不选,选 j j 条链,或者选 j j 条链并且选这条边。

  3. d p ( x , i , 2 ) = m a x ( d p ( x , i , 2 ) , d p ( x , i j , 2 ) + d p ( u , j , 0 ) , d p ( x , i j , 1 ) + d p ( u , j 1 , 1 ) + v a l ( x , u ) ) dp(x,i,2)=max(dp(x,i,2),dp(x,i-j,2)+dp(u,j,0),dp(x,i-j,1)+dp(u,j-1,1)+val(x,u)) ,其中 i : k 1 , j : 1 k i:k\to 1,j:1\to k

    代表不选,选 j j 条链,或者选 j 1 j-1 条链并且选这条边增加一条链。

  4. d p ( x , i , 1 ) = m a x ( d p ( x , i , 1 ) , d p ( x , i , 0 ) + d p ( u , 0 , 0 ) + v a l ( x , u ) ) dp(x,i,1)=max(dp(x,i,1),dp(x,i,0)+dp(u,0,0)+val(x,u)) ,其中 i k 1 i:k\to 1

    这个看起来跟上面的第二个转移的重复了,事实上并没有,因为这个转移既合法,第二个转移又转移不到。

  5. d p ( x , 0 , 1 ) = m a x ( d p ( x , 0 , 1 ) , d p ( u , 0 , 1 ) + v a l ( x , u ) ) dp(x,0,1)=max(dp(x,0,1),dp(u,0,1)+val(x,u))

    代表从下面连上来,同样是第二个转移没有转移到的。

  6. d p ( x , i , 0 ) = m a x ( d p ( x , i , 0 ) , d p ( x , i 1 , 1 ) , d p ( x , i , 2 ) ) dp(x,i,0)=max(dp(x,i,0),dp(x,i-1,1),dp(x,i,2)) ,其中 i 1 k i:1\to k

    代表不选,在这里停止这条链并计入总数,或者把那两个度数去掉。

转移方程大概就是这些了,dp的顺序呀,细节呀,就看我的代码吧。

这样的话复杂度有些玄学(调循环边界的话),我不太会算,反正只能有 45 45 分,会TLE。

本来想把这个dp放到dfs序上说不定就可以到 O ( n k ) O(nk) 了,后来发现我不会QAQ

这个dp必须先写一下,因为凸优化的代码就是在dp的基础上改的。

拿到45分之后,我们来看这题正解吧。

凸优化

凸优化就是针对凸函数求极值的优化。

我们这里不直接探究它的定义及一般情况,我们直接来看这个题,通过这个题来理解凸优化。

首先,通过打表可以发现,答案的函数是上凸的,对于样例来说画出来是这样的:


虽然图像有点儿尖,但是它确实是上凸的。

怎么直接判断一个题的答案是否上凸呢?

我们可以感性判断,比如对于这个题,假如只能选一条链的话,一定是选最长的,选两条的话,增长的就没有第一条那么多了,因为最长的已经选过了,这样来看,增长只会越来越慢,所以它是凸函数。

现在我们知道它是凸函数了,应该怎么做呢?

我们二分一个权值 m i d mid ,代表选一条链需要付出的代价,然后我们去掉选多少条链那一维,还按照原来的dp做。

这样子相当于我们拿 y = m i d × x y=mid\times x 的直线去切答案函数,在这个基础上求极值。

但是我们发现这样求得极值之后,无法判断下一次 m i d mid 变小还是变大。

我们同样可以发现,切了之后的可以取得极值的点是一段连续的区间。

因此,在此基础上我们再记录取得极值的最小的 k k 是多少,也就是区间的左端点是多少。

假如题目中的 k k 等于左端点的话,直接输出答案。

假如题目中的 k k 一定不在这个区间内(左端点大于 k k ),则令 l = m i d + 1 l=mid+1 ,让选的代价变大,左端点减小。

假如题目中的 k k 有可能在这个区间内(左端点小于 k k ),则令 r = m i d r=mid ,让选的代价变小,左端点增大。

最后令 m i d = l mid=l ,再做一次得到最终答案,并且把那个选的代价加回来,就好了。

感性理解一下这个过程,感觉挺对的QAQ

然后这个做法就叫凸优化啦,是不是感觉也没什么难的?

时间复杂度 O ( n l o g V ) O(nlogV) ,还有树形dp常数挺大,所以跑得比较慢。

#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<iostream>
#include<algorithm>
#define LL long long
using namespace std;
inline int read(){
    int x=0,f=1;char ch=' ';
    while(ch<'0' || ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0' && ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
    return f==1?x:-x;
}
const int N=3e5+5,K=105;
const LL inf=1e18;
int n,k,tot;
int head[N],to[N<<1],Next[N<<1],val[N<<1];
struct data{
    LL x,y;
    data(){}
    data(LL _x,LL _y):x(_x),y(_y){}
    inline bool operator < (const data& b) const {
        if(x==b.x)return y>b.y;
        return x<b.x;
    }
    inline data operator + (const data& b) const {return data(x+b.x,y+b.y);}
    inline data operator + (LL b) const {return data(x+b,y);}
}dp[N][3];
inline void addedge(int x,int y,int l){
    to[++tot]=y;
    Next[tot]=head[x];
    head[x]=tot;
    val[tot]=l;
}
LL mid;
inline void dfs(int x,int fa){
    dp[x][0]=data(0,0);
    dp[x][1]=data(0,0);
    dp[x][2]=max(data(0,0),data(-mid,1));
    for(int i=head[x];i;i=Next[i]){
        int u=to[i];
        if(u==fa)continue;
        dfs(u,x);
        dp[x][2]=max(dp[x][2],max(dp[x][2]+dp[u][0],dp[x][1]+dp[u][1]+val[i]+data(-mid,1)));
        dp[x][1]=max(dp[x][1],max(dp[x][1]+dp[u][0],dp[x][0]+dp[u][1]+val[i]));
        dp[x][0]=max(dp[x][0],dp[x][0]+dp[u][0]);
    }
    dp[x][0]=max(dp[x][0],data(0,0));
    dp[x][0]=max(dp[x][0],max(dp[x][1]+data(-mid,1),dp[x][2]));
}
int main(){
    n=read();k=read()+1;
    for(int i=1;i<n;++i){
        int x=read(),y=read(),l=read();
        addedge(x,y,l);addedge(y,x,l);
    }
    LL l=-1e12,r=1e12;
    while(l<r){
        mid=(double)(l+r)/2-0.5;
        dfs(1,0);
        if(dp[1][0].y==k){
            printf("%lld\n",dp[1][0].x+k*mid);
            return 0;
        }
        else if(dp[1][0].y>k)l=mid+1;
        else r=mid;
    }
    mid=l;
    dfs(1,0);
    printf("%lld\n",dp[1][0].x+k*mid);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/stone41123/article/details/84061375