【C++】dfs序与树上差分

引入

dfs序就是用递归的方法遍历一棵树的顺序。
这是一个括号序列,有助于解决树上的很多问题。

例1 DFS 序 1

LibreOJ-144
LOJ vjudge

题目描述

这是一道模板题。

给一棵有根树,这棵树由编号为 1... N 1...N N N 个结点组成。根结点的编号为 R R 。每个结点都有一个权值,结点 i i 的权值为 v i v_i
接下来有 M M 组操作,操作分为两类:

  • 1 a x,表示将结点 a a 的权值增加 x x
  • 2 a,表示求结点 a a 的子树上所有结点的权值之和。

输入格式

第一行有三个整数 N , M N,M R R
第二行有 N N 个整数,第 i i 个整数表示 v i v_i
在接下来的 N 1 N-1 行中,每行两个整数,表示一条边。
在接下来的 M M 行中,每行一组操作。

输出格式

对于每组 2 a 操作,输出一个整数,表示「以结点 a a 为根的子树」上所有结点的权值之和。

样例输入 1

10 14 9
12 -6 -4 -3 12 8 9 6 6 2
8 2
2 10
8 6
2 7
7 1
6 3
10 9
2 4
10 5
1 4 -1
2 2
1 7 -1
2 10
1 10 5
2 1
1 7 -5
2 5
1 1 8
2 7
1 8 8
2 2
1 5 5
2 6

样例输出 1

21
34
12
12
23
31
4

数据范围与提示

1 N , M 1 0 6 , 1 R N , 1 0 6 v i , x 1 0 6 . 1⩽N,M⩽10^6, 1⩽R⩽N, −10^6⩽v_i,x⩽10^6.

解析

操作:

  • 1.给某个点增加x

  • 我们只要给对应的序号的数加上x就可以了。

  • 2.询问子树之和

  • 一颗子树的dfs序是连续的,对应了一个连续的 区间,所以我们查询区间和。

  • 维护一个树状数组即可。

代码

#pragma GCC optimize(3,"Ofast","inline")
#pragma G++ optimize(3,"Ofast","inline")

#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>

#define R                  register int
#define re(i,a,b)          for(R i=a; i<=b; i++)
#define ms(i,a)            memset(a,i,sizeof(a))
#define MAX(a,b)           (((a)>(b)) ? (a):(b))
#define MIN(a,b)           (((a)<(b)) ? (a):(b))

using namespace std;

typedef long long LL;

int const N=1000005;

int n,m,r,cnt,sum;
int a[N],tin[N],tout[N],h[N];
LL s[N];

struct Edge{
    int to,nt;
} e[N<<1];

inline void add(int a,int b) {
    e[++cnt].to=b,e[cnt].nt=h[a],h[a]=cnt;
    e[++cnt].to=a,e[cnt].nt=h[b],h[b]=cnt;
}

inline void Add(int x,LL v) {
    while(x<=n) s[x]+=v,x+=x&-x;
}

inline LL getsum(int x) {
    LL ret=0;
    while(x) ret+=s[x],x-=x&-x;
    return ret;
}

void dfs(int x,int fa) {
    tin[x]=++sum;
    for(int i=h[x]; i; i=e[i].nt) {
        int v=e[i].to;
        if(v==fa) continue;
        dfs(v,x);
    }
    tout[x]=sum;
}

int main() {
    scanf("%d%d%d",&n,&m,&r);
    for(int i=1; i<=n; i++) scanf("%d",&a[i]);
    for(int i=1; i<=n-1; i++) {
        int x,y;
        scanf("%d%d",&x,&y);
        add(x,y);
    }
    dfs(r,r);
    for(int i=1; i<=n; i++) Add(tin[i],a[i]);
    while(m--) {
        int k,x,y;
        scanf("%d",&k);
        if(k==1) {
            scanf("%d%d",&x,&y);
            Add(tin[x],y);
        } else {
            scanf("%d",&x);
            int l=tin[x];
            int r=tout[x];
            printf("%lld\n",getsum(r)-getsum(l-1));
        }
    }
    return 0;
}

例2 DFS 序 2

LibreOJ-145
LOJ vjudge

题目描述

这是一道模板题。

给一棵有根树,这棵树由编号为 1... N 1...N N N 个结点组成。根结点的编号为 R R 。每个结点都有一个权值,结点 i i 的权值为 v i v_i
接下来有 M M 组操作,操作分为两类:

  • 1 a x,表示将结点 a a 的子树上所有结点的权值增加 x x
  • 2 a,表示求结点 a a 的子树上所有结点的权值之和。

输入格式

第一行有三个整数 N , M N,M R R
第二行有 N N 个整数,第 i i 个整数表示 v i v_i
在接下来的 N 1 N-1 行中,每行两个整数,表示一条边。
在接下来的 M M 行中,每行一组操作。

输出格式

对于每组 2 a 操作,输出一个整数,表示「以结点 a a 为根的子树」上所有结点的权值之和。

样例输入

10 14 9
12 -6 -4 -3 12 8 9 6 6 2
8 2
2 10
8 6
2 7
7 1
6 3
10 9
2 4
10 5
1 4 -1
2 2
1 7 -1
2 10
1 10 5
2 1
1 7 -5
2 5
1 1 8
2 7
1 8 8
2 2
1 5 5
2 6

样例输出

21
33
16
17
27
76
30

数据范围与提示

1 N , M 1 0 6 , 1 R N , 1 0 6 v i , x 1 0 6 . 1⩽N,M⩽10^6, 1⩽R⩽N, −10^6⩽v_i,x⩽10^6.

解析

操作:

  • 1.子树加
  • 子树对应了一个连续的区间,那么就是一个区间修改
  • 2.子树查询
  • 区间查询
  • 维护一个数组数组

代码

#pragma GCC optimize(3,"Ofast","inline")
#pragma G++ optimize(3,"Ofast","inline")

#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>

#define R                  register int
#define re(i,a,b)          for(R i=a; i<=b; i++)
#define ms(i,a)            memset(a,i,sizeof(a))
#define MAX(a,b)           (((a)>(b)) ? (a):(b))
#define MIN(a,b)           (((a)<(b)) ? (a):(b))

#define lowbit(x)          ((x) & (-x))

using namespace std;

typedef long long LL;

int const N=1000005;

int n,m,r,cnt,sum;
int a[N],tin[N],tout[N],h[N],id[N];
LL s[N],ss[N];

void read(int &x) {
    x=0;
    char c=0;
    int w=0;
    while (!isdigit(c)) w|=c=='-',c=getchar();
    while (isdigit(c)) x=x*10+(c^48),c=getchar();
    if(w) x = -x;
}

struct edge {
    int to, nt;
} e[N << 1];

void add(int a, int b) {
    e[++cnt].to = b;
    e[cnt].nt = h[a];
    h[a] = cnt;
    e[++cnt].to = a;
    e[cnt].nt = h[b];
    h[b] = cnt;
}

void dfs(int x, int fa) {
    tin[x] = ++sum;
    id[sum] = x;
    for (int i = h[x]; i; i = e[i].nt) {
        int v = e[i].to;
        if (v == fa)
            continue;
        dfs(v, x);
    }
    tout[x] = sum;
}

void Add(int x, int v) {
    for (int i = x; i <= n; i += lowbit(i)) {
        s[i] += v;
        ss[i] += (LL)v * (x - 1);
    }
}

LL getsum(int x) {
    LL ret = 0;
    for (int i = x; i; i -= lowbit(i)) {
        ret += (LL)x * s[i];
        ret -= ss[i];
    }
    return ret;
}

int main() {
    read(n);
    read(m);
    read(r);
    for (int i = 1; i <= n; i++) read(a[i]);
    for (int i = 1; i < n; i++) {
        int x, y;
        scanf("%d%d", &x, &y);
        add(x, y);
    }
    dfs(r, r);
    for (int i = 1; i <= n; i++) Add(i, a[id[i]] - a[id[i - 1]]);
    while (m--) {
        int k, x, y;
        scanf("%d", &k);
        if (k == 1) {
            scanf("%d%d", &x, &y);
            int l = tin[x];
            int r = tout[x];
            Add(l, y);
            Add(r + 1, -y);
        } else {
            scanf("%d", &x);
            int l = tin[x];
            int r = tout[x];
            printf("%lld\n", getsum(r) - getsum(l - 1));
        }
    }
    return 0;
}

例3 DFS 序 3,树上差分 1

LibreOJ-146
LOJ vjudge

题目描述

这是一道模板题。

不保证无快读的程序能过。请务必使用快读。

给一棵有根树,这棵树由编号为 1 N 1…N N N 个结点组成。根结点的编号为 R R 。每个结点都有一个权值,结点 i i 的权值为 v i v_i
接下来有 M 组操作,操作分为三类:

  • 1 a b x,表示将「结点 a a 到结点 b b 的简单路径」上所有结点的权值都增加 x x
  • 2 a,表示求结点 a a 的权值。
  • 3 a,表示求 a a 的子树上所有结点的权值之和。

输入格式

第一行有三个整数 N , M N,M R R
第二行有 N N 个整数,第 i i 个整数表示 v i v_i
在接下来的 N 1 N−1 行中,每行两个整数,表示一条边。
在接下来的 M M 行中,每行一组操作。

输出格式

对于每组 2 a 操作,输出一个整数,表示结点 a 的权值。

样例输入 1

10 15 3
4 8 -2 -4 -7 -7 -9 5 2 5
3 9
3 4
4 5
4 8
8 7
3 6
8 2
9 10
2 1
2 5
1 4 7 3
1 7 2 6
1 6 7 -7
2 1
1 10 10 -9
2 4
1 2 9 -8
2 6
1 10 5 -2
1 4 4 6
1 6 1 3
1 1 10 2
1 9 2 0
2 7

样例输出 1

-7
4
-8
-14
-7

样例输入 2

10 17 3
5 1 -7 -9 -5 3 -7 -5 3 3
1 8
8 7
7 6
8 3
6 10
7 2
6 9
1 4
6 5
2 9
1 10 4 -2
2 8
1 1 10 -2
3 5
1 10 6 -3
3 1
1 6 5 9
2 8
1 4 5 1
2 10
1 2 5 6
1 2 6 0
1 2 7 -5
1 4 9 6
1 10 1 0
3 2

样例输出 2

3
-7
-5
-10
-9
-4
2

数据范围与提示

40 40 % 的数据不含操作 3
对于所有数据, 1 N , M 1 0 6 , 1 R N , 1 0 6 v i , x 1 0 6 . 1⩽N,M⩽10^6, 1⩽R⩽N, −10^6⩽v_i,x⩽10^6.

解析

本题可以做树剖,但是树剖的时间复杂度是: O ( n l o g n l o g n ) O(nlogn*logn)

我们可以做树上差分

  • 对于操作1: a到b路径上每个数增加x,我们可 以给a和b打一个+x的标记,lca打一个-x的标记,
  • lca的父亲打一个-x的标记。这样就可以处理操 作2和操作3了。
  • 对于操作2: 查询点的值,就是查询这个子树 的和。
  • 对于操作3:我们考虑子树里面每个修改对答案 的贡献,假设我们要查询以u为根的子树,子树 里面有一个点v,我们考虑v对答案的贡献就是 v a l [ v ] ( d e p [ v ] d e p [ u ] + 1 ) val[v]*(dep[v]-dep[u]+1) ,拆开以后就是 v a l [ v ] d e p [ v ] v a l [ v ] ( d e p [ u ] 1 ) val[v]*dep[v]-val[v]*(dep[u]-1) ,分别维护两个树状数组即可。

代码

#pragma GCC optimize(3,"Ofast","inline")
#pragma G++ optimize(3,"Ofast","inline")

#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>

#define R                  register int
#define re(i,a,b)          for(R i=a; i<=b; i++)
#define ms(i,a)            memset(a,i,sizeof(a))
#define MAX(a,b)           (((a)>(b)) ? (a):(b))
#define MIN(a,b)           (((a)<(b)) ? (a):(b))

#define lowbit(x)          ((x) & (-x))

using namespace std;

typedef long long LL;

int const N=1000005;

inline void read(int &x){
    x=0; 
    char c=0; 
    int w=0;  
    while (!isdigit(c)) w|=c=='-',c=getchar();  
    while (isdigit(c)) x=x*10+(c^48),c=getchar();  
    if(w) x=-x;  
}

struct edge{
    int to,nt;  
} e[N<<1];  

int n,m,rt,cnt,sum;
int dep[N],tin[N],tout[N],h[N],a[N];
int f[N][20];  
LL s[N],ss[N],t[N];    

void add(int a,int b){
    e[++cnt].to=b,e[cnt].nt=h[a],h[a]=cnt; 
    e[++cnt].to=a,e[cnt].nt=h[b],h[b]=cnt; 
}

void dfs(int x,int fa,int d){
    dep[x]=d; 
    tin[x]=++sum; 
    f[x][0]=fa;   
    for(int i=h[x]; i; i=e[i].nt){
        int v=e[i].to;   
        if(v==fa) continue; 
        dfs(v,x,d+1);  
    }
    tout[x]=sum;  
}

void Add(int x,int v){
    for(int i=x;i<=n;i+=lowbit(i))
        t[i]+=v;  
}

int ancestor(int x,int y){
    return tin[x]<=tin[y] && tout[y]<=tout[x]; 
}

int lca(int x,int y){
    if(ancestor(x,y)) return x; 
    if(ancestor(y,x)) return y;  
    for(int i=19; i>=0; i--)   
        if(!ancestor(f[x][i],y))  
            x=f[x][i];  
    return f[x][0]; 
}

LL getsum(int x,LL s[]){
    LL ret=0;  
    for(int i=x; i; i-=lowbit(i))  
        ret+=s[i];  
    return ret;   
}

inline void Add2(int x,LL d,int v){  
    for(int i=x;i<=n;i+=lowbit(i)) {
        s[i]+=v;  
        ss[i]+=d*v; 
    }
}

int main(){
    read(n); 
    read(m); 
    read(rt); 
    for(int i=1; i<=n; i++) read(a[i]); 
    for(int i=1; i<n; i++) {
        int x,y; 
        read(x); 
        read(y); 
        add(x,y); 
    }
    dfs(rt,rt,1);   
    for(int j=1; j<20; j++) for(int i=1; i<=n; i++) f[i][j]=f[f[i][j-1]][j-1]; 
    for(int i=1;i<=n;i++) Add(tin[i],a[i]);  
    while(m--) {
        int k,l,r,x;  
        read(k);  
        if(k==1){
            read(l); 
            read(r); 
            read(x);  
            Add2(tin[l],dep[l],x);  
            Add2(tin[r],dep[r],x);  
            int t=lca(l,r); 
            Add2(tin[t],dep[t],-x);  
            if(t!=rt) Add2(tin[f[t][0]],dep[t]-1,-x);  
        } else if(k==2) {
            read(l);  
            printf("%lld\n",getsum(tout[l],s)-getsum(tin[l]-1,s)+a[l]);  
        } else {
            read(l);  
            LL t1=getsum(tout[l],t)-getsum(tin[l]-1,t);  
            LL t2=getsum(tout[l],ss)-getsum(tin[l]-1,ss);
            LL t3=getsum(tout[l],s)-getsum(tin[l]-1,s);  
            printf("%lld\n",t2-t3*(dep[l]-1)+t1);    
        }
    }
    return 0; 
}

例4 DFS序4

LibreOJ-147
LOJ vjudge

题目描述

这是一道模板题。

本题严重卡常,请务必使用 fread 快读,不保证无快读的程序能过(虽然标程没用快读)。另外,建议使用 Tarjan 或树剖求 LCA。

给一棵有根树,这棵树由编号为 1 N 1…N N N 个结点组成。根结点的编号为 R R 。每个结点都有一个权值,结点 i i 的权值为 v i v_i
接下来有 M M 组操作,操作分为三类:

  • 1 a x,表示将结点 a a 的权值增加 x x
  • 2 a x,表示将 a a 的子树上所有结点的权值增加 x x
  • 3 a b,表示求「结点 a a 到结点 b b 的简单路径」上所有结点的权值之和。

输入格式

第一行有三个整数 N , M N,M R R
第二行有 N N 个整数,第 i i 个整数表示 v i v_i
在接下来的 N 1 N−1 行中,每行两个整数,表示一条边。
在接下来的 M M 行中,每行一组操作。

输出格式

对于每组 3 a b 操作,输出一个整数,表示「结点 a a 到结点 b b 的简单路径」上所有结点的权值之和(含结点 a , b a, b )。

样例输入 1

10 13 5
-2 -7 0 2 -9 -2 -4 9 8 -1
9 8
9 4
9 2
4 10
10 7
10 6
2 1
8 3
7 5
3 8 6
1 7 -8
1 5 -9
1 5 -4
1 4 -2
1 2 -1
3 5 1
1 7 1
3 1 3
1 1 -3
3 10 2
1 1 -8
3 8 4

样例输出 1

16
-37
7
-1
17

样例输入 2

10 16 4
-13 -11 5 4 18 13 14 -8 -8 14
4 1
4 10
10 2
2 8
4 7
1 6
8 5
1 3
2 9
3 5 10
1 5 -5
2 9 -4
3 8 6
1 5 -8
2 8 -5
3 8 7
1 9 0
2 10 -3
3 7 6
2 9 -4
2 8 2
3 4 4
2 1 8
1 6 5
3 8 3

样例输出 2

13
-1
8
18
4
-5

数据范围与提示

40 40 % 的数据不含操作 2
1 N , M 1 0 6 , 1 R N , 1 0 6 v i , x 1 0 6 . 1⩽N,M⩽10^6, 1⩽R⩽N, −10^6⩽v_i,x⩽10^6.

代码

#pragma GCC optimize(3,"Ofast","inline")
#pragma G++ optimize(3,"Ofast","inline")

#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>

#define R                  register int
#define re(i,a,b)          for(R i=a; i<=b; i++)
#define ms(i,a)            memset(a,i,sizeof(a))
#define MAX(a,b)           (((a)>(b)) ? (a):(b))
#define MIN(a,b)           (((a)<(b)) ? (a):(b))

#define lowbit(x)          ((x) & (-x))

using namespace std;

typedef long long LL;

namespace IN {
    #include <cctype>
    #include <cstdio>
    #define bsiz 1000000

    int sta[30];
    char buf[bsiz], pbuf[bsiz], *p = pbuf, *s = buf, *t = buf;

    #define mgetc() (s == t && (t = (s = buf) + fread(buf, 1, bsiz, stdin), s == t) ? EOF : *s++)

    inline int read() {
        register char ch;
        register int res=0, p;
        while (!isdigit(ch = mgetc()) && (ch ^ '-'));
        p = ch == '-' ? ch = mgetc(), -1 : 1;
        while (isdigit(ch)) res = (res << 3) + (res << 1) + (ch ^ 48), ch = mgetc();
        return res*p;
    }

}

const int N=1e6+5;

struct edge{
    int to,nt;  
} e[N<<1];  

int cnt,sum,a[N],h[N],n,m,rt,tin[N],tout[N],f[N][20],dep[N];  
LL s[N],ss[N],val[N],d[N];  

void add(int a,int b){
    e[++cnt].to=b; e[cnt].nt=h[a]; h[a]=cnt;  
    e[++cnt].to=a; e[cnt].nt=h[b]; h[b]=cnt;  
} 

void dfs(int x,int fa,int d,LL tot) {
    tin[x]=++sum;  
    f[x][0]=fa;  
    val[x]=tot; 
    dep[x]=d;  
    for(int i=h[x]; i; i=e[i].nt) {
        int v=e[i].to;  
        if(v==fa) continue;  
        dfs(v,x,d+1,tot+a[v]);  
    }
    tout[x]=sum;  
}

int inline ancestor(int x,int y) {
    return tin[x]<=tin[y] && tout[y]<=tout[x];  
}

int lca(int x,int y) {
    if(ancestor(x,y)) return x; 
    if(ancestor(y,x)) return y;  
    for(int i=19; i>=0; i--)  
        if(!ancestor(f[x][i],y)) 
            x=f[x][i];  
    return f[x][0];  
}

void Add(int x,LL v,LL s[]){
    for(int i=x; i<=n; i+=lowbit(i))  
        s[i]+=v;  
}

LL getsum(int x,LL s[]){
    LL ret=0;  
    for(int i=x; i; i-=lowbit(i)) 
        ret+=s[i];  
    return ret;  
}

int main() {
    n=IN::read();  
    m=IN::read(); 
    rt=IN::read();  
    for(int i=1; i<=n; i++) 
        a[i]=IN::read();  
    for(int i=1; i<n; i++) {
        int x,y;  
        x=IN::read(); 
        y=IN::read();  
        add(x,y);  
    }
    dfs(rt,rt,1,a[rt]);     
    for(int j=1; j<20; j++)  
        for(int i=1; i<=n; i++)  
            f[i][j]=f[f[i][j-1]][j-1];  
    while(m--) {
        int k,a,b,x;  
        k=IN::read();  
        if(k==1) {
            a=IN::read();  
            x=IN::read();
            Add(tin[a],x,d);
            Add(tout[a]+1,-x,d);      
        } else if(k==2) {
            a=IN::read();  
            x=IN::read();
            Add(tin[a],(LL)x*dep[a],s);  
            Add(tout[a]+1,-(LL)x*dep[a],s);  
            Add(tin[a],x,ss);  
            Add(tout[a]+1,-x,ss);  
        } else {
            a=IN::read();  
            b=IN::read(); 
            int t=lca(a,b);  
            LL t1=val[a]+val[b]-val[t];
            if(t!=rt) t1-=val[f[t][0]];  
            LL t2=getsum(tin[a],d)+getsum(tin[b],d)-getsum(tin[t],d);  
            if(t!=rt) t2-=getsum(tin[f[t][0]],d);  
            LL t3=getsum(tin[a],s)+getsum(tin[b],s)-getsum(tin[t],s);  
            if(t!=rt) t3-=getsum(tin[f[t][0]],s ); 
            LL t4=getsum(tin[a],ss)*(dep[a]+1)+getsum(tin[b],ss)*(dep[b]+1)-getsum(tin[t],ss)*(dep[t]+1);  
            if(t!=rt) t4-=getsum(tin[f[t][0]],ss)*(dep[f[t][0]]+1); 
            printf("%lld\n",t1+t2-t3+t4);        
        }
    }
    return 0;  
}
发布了106 篇原创文章 · 获赞 156 · 访问量 4万+

猜你喜欢

转载自blog.csdn.net/Ljnoit/article/details/105041154
今日推荐