7-5 1E. 树与路径(巧妙的树上差分)

在一棵有根树 T 上,任何两点间的最短路径都能够分为两个阶段:

从起点出发,沿着向根的方向走若干条边。

向着终点,沿着离开根的方向走若干条边。

定义一条路径的权值为向上走的边数乘上向下走的边数。特殊地,当起点等于终点的时候,两阶段的边数都是 0;当起点是终点的祖先的时候,第一阶段的边数是 0;当终点是起点的祖先的时候,第二阶段的边数是 0------这三种情况下,路径的权值都是 0。

现在给出一棵 n 个节点的无根树 T 和 m 条路径 (a​i​​ ,bi​​ )。对于每一个 r∈[1,n],你需要计算当 r 是根节点的时候,所有路径的权值和是多少。

输入格式:
第一行输入两个整数 n,m(1≤n,m≤3×10​5​​ )。
接下来 n−1 行每行输入两个整数 u​i​​ ,v​i​​ (1≤u​i​​ ,v​i​​ ≤n),表示树上的一条边。
接下来 m 行每行输入两个整数 a​i​​ ,bi​​ (1≤a​i​ ,b​i​​ ≤n),表示一条路径。
输出格式
输出 n 行每行一个整数,第 i 行表示以 i 为根时,所有路径的权值和。

思路:树上差分一个等差序列,可以化成常数和已知数的形式,真棒

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<string>
#include<cmath>
#include<queue>
#include<cstdlib>
#include<map>
#include<set>
#define ll long long
#define llu unsigned ll
#define ld long double
#define pr make_pair
#define pb push_back
#define x first
#define y second
#define int ll
using namespace std;
const int inf=0x3f3f3f3f;
const ll lnf=0x3f3f3f3f3f3f3f3f;
const int mod=1e9+7;
const int maxn=300100;
int head[maxn],ver[maxn<<1],nt[maxn<<1];
int d[maxn],f[maxn][22],lca[maxn];
int a[maxn],b[maxn],l[maxn];
int rk1[maxn],rk2[maxn],rk3[maxn];
int ans1[maxn];
int n,m,tot,x,y,t;
void add(int x,int y)
{
    ver[++tot]=y,nt[tot]=head[x],head[x]=tot;
}

void dfs1(int x,int fa)
{
    for(int i=head[x];i;i=nt[i])
    {
        int y=ver[i];
        if(y==fa) continue;
        d[y]=d[x]+1;
        f[y][0]=x;
        for(int j=1;j<=t;j++)
            f[y][j]=f[f[y][j-1]][j-1];
        dfs1(y,x);
    }
}

int lc(int x,int y)
{
    if(d[x]>d[y]) swap(x,y);
    for(int i=t;i>=0;i--)
        if(d[f[y][i]]>=d[x]) y=f[y][i];
    if(x==y) return x;
    for(int i=t;i>=0;i--)
        if(f[y][i]!=f[x][i]) y=f[y][i],x=f[x][i];
    return f[x][0];
}

void dfs2(int x,int fa)
{
    for(int i=head[x];i;i=nt[i])
    {
        int y=ver[i];
        if(y==fa) continue;
        dfs2(y,x);
        rk1[x]+=rk1[y];
        rk2[x]+=rk2[y];
        rk3[x]+=rk3[y];
    }
}

void dfs3(int x,int fa)
{
    for(int i=head[x];i;i=nt[i])
    {
        int y=ver[i];
        if(y==fa) continue;
        ans1[y]=ans1[x]-(rk1[y]+rk2[y]*d[y]+rk3[y]);
        dfs3(y,x);
        //cout<<"y:  "<<y<<"  ans1:  "<<ans1[y]<<endl;
    }
}

signed main(void)
{
    scanf("%lld%lld",&n,&m);
    t=log(n)/log(2)+1;

    for(int i=1;i<n;i++)
    {
        scanf("%lld%lld",&x,&y);
        add(x,y),add(y,x);
    }
    d[0]=-1;
    dfs1(1,0);

    int ans=0;
    for(int i=1;i<=m;i++)
    {
        scanf("%lld%lld",&a[i],&b[i]);
        lca[i]=lc(a[i],b[i]);
        l[i]=d[a[i]]+d[b[i]]-2*d[lca[i]];
        ans+=(d[a[i]]-d[lca[i]])*(d[b[i]]-d[lca[i]]);
        rk1[a[i]]+=l[i]-2*d[a[i]];
        rk1[b[i]]+=l[i]-2*d[b[i]];
        rk1[lca[i]]-=l[i]-2*d[a[i]]+l[i]-2*d[b[i]];
        rk2[a[i]]+=2;
        rk2[b[i]]+=2;
        rk2[lca[i]]-=4;
        rk3[a[i]]-=1;
        rk3[b[i]]-=1;
        rk3[lca[i]]+=2;
    }
    dfs2(1,0);
    ans1[1]=ans;
    dfs3(1,0);


    //cout<<ans<<endl;
    //cout<<d[2]<<endl;
    //for(int i=1;i<=n;i++)
    //    printf("rk1:%lld  rk2:%lld  rk3:%lld\n",rk1[i],rk2[i],rk3[i]);
    for(int i=1;i<=n;i++)
        printf("%lld\n",ans1[i]);

    return 0;
}




















发布了36 篇原创文章 · 获赞 11 · 访问量 647

猜你喜欢

转载自blog.csdn.net/weixin_43822647/article/details/103951476