BZOJ 1912(树的直径+LCA)

题面

传送门

分析

显然,如果不加边,每条边都要走2次,总答案为2(n-1)

考虑k=1的朴素情况:

加一条边(a,b),这条边和树上a->b的路径形成一个环,这个环上的边只需要走一遍,所以答案会减少dist(a,b)-1 (a->b的路径少走一边,但是又多加了一条边,最终答案为2*(n-1)-dist(a,b)+1

显然dist(a,b)应该最大,求树上的直径\(L_1\)即可

接着推广到k=2的情况

加第二条边时,又会形成一个环,这个环和第一条边形成的环可能有重叠,重叠部分要走两遍,答案又会增加

因此,我们把直径上的边权取反(1变为-1),在直径取反后的树上再次求直径,设直径为\(L_2\)

答案就是\(2(n-1)-(L_1-1)-(L_2-1)\)

注意两次BFS求直径的方法在负权图上求直径不成立,需要用树形DP

具体方法如下:

1.在原图上用两次BFS求出直径\(L_1\)端点a,b

2.通过LCA和树上差分将直径上的边标记,并取反边权

3.用树形DP求出新直径\(L_2\)

代码

#include<iostream>
#include<cstdio>
#include<queue>
#include<cstring>
#include<cmath>
#define maxlog 32
#define maxn 400005
#define INF 0x7fffffff
using namespace std;
int n,k;
struct edge {
    int from;
    int to;
    int next;
    int len;
} E[maxn*2];
int sz=1;
int head[maxn];
void add_edge(int u,int v,int w) {
    sz++;
    E[sz].from=u;
    E[sz].to=v;
    E[sz].len=w;
    E[sz].next=head[u];
    head[u]=sz;
}

namespace k_1 {
    struct node {
        int x;
        int t;
        node() {

        }
        node(int u,int dis) {
            x=u;
            t=dis;
        }
    };
    int vis[maxn];
    node bfs(int s) {
        node now,ans;
        queue<node>q;
        q.push(node(s,0));
        memset(vis,0,sizeof(vis));
        ans.t=-INF;
        while(!q.empty()) {
            now=q.front();
            q.pop();
            int x=now.x;
            vis[x]=1;
            if(now.t>ans.t) {
                ans.t=now.t;
                ans.x=now.x;
            }
            for(int i=head[x]; i; i=E[i].next) {
                int y=E[i].to;
                if(!vis[y]) {
                    q.push(node(y,now.t+1));
                }
            }
        }
        return ans;
    }

    void solve() {
        node p=bfs(1);
        node q=bfs(p.x);
        printf("%d\n",2*(n-1)-q.t+1);
    }
}

namespace k_2 {
    int log2n;
    struct node {
        int x;
        int t;
        node() {

        }
        node(int u,int dis) {
            x=u;
            t=dis;
        }
    };
    int vis[maxn];
    int mark[maxn];

    node bfs(int s,int tim) {
        node now,ans;
        queue<node>q;
        q.push(node(s,0));
        memset(vis,0,sizeof(vis));
        ans.t=-INF;
        ans.x=s;
        if(n==1) return node(1,0);
        while(!q.empty()) {
            now=q.front();
            q.pop();
            int x=now.x;
            vis[x]=1;
            if(tim==1) {
                if(now.t>=ans.t&&now.x!=s) {
                    ans.t=now.t;
                    ans.x=now.x;
                }
            }else{
                if(now.t>=ans.t) {
                    ans.t=now.t;
                    ans.x=now.x;
                }
            } 
            for(int i=head[x]; i; i=E[i].next) {
                int y=E[i].to;
                if(!vis[y]) {
                    q.push(node(y,now.t+E[i].len));
                }
            }
        }
        return ans;
    }

    int anc[maxn][maxlog];
    int deep[maxn];
    void dfs1(int x,int fa) {
        deep[x]=deep[fa]+1;
        anc[x][0]=fa;
        for(int i=1; i<=log2n; i++) {
            anc[x][i]=anc[anc[x][i-1]][i-1];
        }
        for(int i=head[x]; i; i=E[i].next) {
            int y=E[i].to;
            if(y!=fa) {
                dfs1(y,x);
            }
        }
    }

    void dfs2(int x,int fa) {
        for(int i=head[x]; i; i=E[i].next) {
            int y=E[i].to;
            if(y!=fa) {
                dfs2(y,x);
                mark[x]+=mark[y];
            }
        }
    }

    int lca(int x,int y) {
        if(deep[x]<deep[y]) swap(x,y);
        for(int i=log2n; i>=0; i--) {
            if(deep[anc[x][i]]>=deep[y]) x=anc[x][i];
        }
        if(x==y) return x;
        for(int i=log2n; i>=0; i--) {
            if(anc[x][i]!=anc[y][i]) {
                x=anc[x][i];
                y=anc[y][i];
            }
        }
        return anc[x][0];
    }
    
    int len=0;
    int d[maxn];
    void dp(int x,int fa){
        d[x]=0;
        for(int i=head[x];i;i=E[i].next){
            int y=E[i].to;
            if(y!=fa){
                dp(y,x);
                len=max(len,d[x]+d[y]+E[i].len);
                d[x]=max(d[x],d[y]+E[i].len);
            }
        }
    }
    
    void solve() {
        int u,v;
        int ans=2*(n-1);
        log2n=log2(n)+1;
        dfs1(1,0);
        node p=bfs(1,1);
        node q=bfs(p.x,2);
        ans=ans-q.t+1;

        memset(mark,0,sizeof(mark));
        mark[p.x]++;
        mark[q.x]++;
        mark[lca(p.x,q.x)]-=2;
        dfs2(1,0);

        memset(head,0,sizeof(head));
        memset(E,0,sizeof(E));
        sz=1;
        for(int i=2; i<=n; i++) {
            add_edge(i,anc[i][0],mark[i]==1?-1:1);
            add_edge(anc[i][0],i,mark[i]==1?-1:1);
        }
    
        memset(d,-0x3f,sizeof(d));
        dp(1,0);
        ans=ans-len+1;

        printf("%d\n",ans);
    }
}
int main() {
    int u,v;
    scanf("%d %d",&n,&k);
    sz=1;
    for(int i=1; i<n; i++) {
        scanf("%d %d",&u,&v);
        add_edge(u,v,1);
        add_edge(v,u,1);
    }
    if(k==1) k_1::solve();
    else k_2::solve();
}

猜你喜欢

转载自www.cnblogs.com/birchtree/p/10198021.html