洛谷 P3384 [模板] 树链剖分

传送门

树链剖分

本质上,树链剖分是一种将树肢解成平摊开来,再使用线段树对其进行维护的神奇算法。

我们需要通过两次 \(dfs\) ,预处理一些我们需要的东西,这里是第一次:

  1. 树上每个节点的父亲,这个不必多说,方便后续找 \(LCA\) 时的上跳过程;

  2. 树上每个节点的深度,这个也不必多说,方便后续决定对哪个节点进行操作;

  3. 每个节点的子树大小,同时标记每个节点的重儿子

何为重儿子? 顾名思义,对于一个节点,他的所有儿子中子树大小最大的那个儿子就是重儿子。其他的儿子我们就称其为轻儿子。

对于一条树边,连接重儿子的边我们叫他重边,连接轻儿子的边我们叫他轻边(又称轻链),重边连成的链我们叫他重链。

这里是第二次:

  1. 每个节点的 \(dfs\) 序,将树肢解后节点就按这个顺序平摊,值得注意的是每次要先遍历该节点的重儿子,回溯之后再遍历其他出点;

  2. \(dfs\) 编号所对应的节点编号,方便逆向访问到这个节点;

  3. 对于每各节点,我们记录顺着该节点所在的链向上走所能到达的最上方的节点(链顶);

形象地说,重链是我们修建的高速公路,在其上我们可以直接快速达到一条链的最顶部,而轻链则不同,在其上只能一步一步向上跳跃(也可以说,轻链的链顶就是当前点的父亲,轻链是长度为1
的重链)。

实际上,所谓轻链重链并没有什么特殊意义,其本质是将一棵树剖成几条链的一个较为方便的策略。

我们容易知道,树上任意两个点的最短路径都可以通过上文提到的轻重链来达到。对于每个链顶深度较大的点,我们让他跳跃到他链顶的父亲的位置,并对他途径的链进行区间操作

这个时候,你就会发现将重链 \(dfs\) 序连续的妙处所在了:

每一个重链都处在一段连续的区间上,我们可以统一对其进行处理。

而对于轻链,由于轻链长度只有一,所以不在乎所处区间是否连续。

重复上述操作,直到两个点的链顶相同为止。通过这样的步骤,我们将树上两点间的最短路径拆分成了数个区间,然后转化为了区间操作。

那么对子树的操作如何实现呢?

不难发现,每一个节点的子树在区间上都连续,理由很简单,只有当前子树递归完毕之后才会访问另一颗子树。同时,我们也可以得到这个区间的左右端点,若设当前节点的 \(dfs\) 序为 \(x\)

\(left node : x , right node : x + size[x] - 1\)

然后,对这个区间进行区间操作即可。

怎么样,是不是觉得非常简单?

当然,在代码实现的过程中,依旧有一些小小的细节值得注意:

  1. 线段树,不用我说,写错了就拖出去枪毙十分钟(笔者至少被枪毙了半小时);

  2. 对于子树操作(等价于区间操作),参数是点的 \(dfs\) 序,而对于最短路操作,参数则是点原本的序号,务必要搞清楚 \(dfs\) 序与原本点的编号的异同;

  3. 先两次 \(dfs\) ,执行完毕后再建树(这问题太蠢我不忍直视);

  4. \(dfs1\) 中注意不要又跑回父亲节点了,\(dfs2\) 中注意到了叶子节点要及时终止函数。

一时间就想到这么多,希望能对大家有所帮助。

以下提供模板代码,为了锻炼读者的代码阅读能力(我懒),没有加任何注释,愿各位食用愉快(逃)

模板代码

#include<iostream>
#include<cctype>
#include<cstdio>
using namespace std;
typedef long long ll; 
const int maxn = 50005;
ll read(){
    ll re = 0,ch = getchar();
    while(!isdigit(ch)) ch = getchar();
    while(isdigit(ch)) re = (re<<1) + (re<<3) + ch - '0',ch = getchar();
    return re;
}
int n,m,r,p;
struct edge{
    int v,nxt;
}e[maxn<<1];
int h[maxn],cnt;
void addedge(int u,int v){
    e[++cnt].v = v;
    e[cnt].nxt = h[u];
    h[u] = cnt;
}
int fa[maxn],sz[maxn],son[maxn],dfn[maxn],rev[maxn],val[maxn],dis[maxn],top[maxn];
void dfs1(int u,int f){
    dis[u] = dis[f] + 1;
    fa[u] = f;
    sz[u] = 1;
    for(int i = h[u];i;i = e[i].nxt){
        if(e[i].v != f){
            dfs1(e[i].v,u);
            sz[u] += sz[e[i].v];
            if(sz[e[i].v] > sz[son[u]]) son[u] = e[i].v;
        }
    }
}
void dfs2(int u,int topf){
    dfn[u] = ++cnt;
    rev[cnt] = u;
    top[u] = topf;
    if(!son[u]) return;
    dfs2(son[u],topf);
    for(int i = h[u];i;i = e[i].nxt)
        if(!dfn[e[i].v]) dfs2(e[i].v,e[i].v);
}
struct node{
    int l,r;
    ll sum,add;
    #define l(x) t[x].l 
    #define r(x) t[x].r
    #define sum(x) t[x].sum
    #define add(x) t[x].add
    #define mid(x) (t[x].r + t[x].l >> 1)
}t[maxn<<2];
void pushdown(int x){
    if(add(x)){
        sum(x<<1) += add(x) * (mid(x) - l(x) + 1);
        sum(x<<1|1) += add(x) * (r(x) - mid(x));
        add(x<<1) += add(x);
        add(x<<1|1) += add(x);
        add(x) = 0;
    }
}
void pushup(int x){
    sum(x) = (sum(x<<1) % p + sum(x<<1|1) % p) % p;
}
void build(int x,int l,int r){
    l(x) = l;
    r(x) = r;
    if(l == r){
        sum(x) = val[rev[l]];
        return;
    }
    build(x<<1,l,mid(x));
    build(x<<1|1,mid(x) + 1,r);
    pushup(x);
}
void modify(int x,int l,int r,int v){
    if(l <= l(x) && r >= r(x)){
        sum(x) += (r(x) - l(x) + 1) * v;
        add(x) += v;
        sum(x) %= p;
        add(x) %= p;
        return;
    }
    pushdown(x);
    if(l <= mid(x)) modify(x<<1,l,r,v);
    if(r > mid(x)) modify(x<<1|1,l,r,v);
    pushup(x); 
}
ll quiry(int x,int l,int r){
    ll ans = 0;
    if(l <= l(x) && r >= r(x))
        return sum(x);  
    pushdown(x);
    if(l <= mid(x)) ans += quiry(x<<1,l,r);
    if(r > mid(x)) ans += quiry(x<<1|1,l,r);
    return ans % p;
}
void tadd(int x,int y,int v){
    while(top[x] != top[y]){
        if(dis[top[x]] < dis[top[y]]) swap(x,y);
        modify(1,dfn[top[x]],dfn[x],v);
        x = fa[top[x]];
    }
    if(dis[x] > dis[y]) swap(x,y);
    modify(1,dfn[x],dfn[y],v);
}
ll task(int x,int y){
    ll ans = 0;
    while(top[x] != top[y]){
        if(dis[top[x]] < dis[top[y]]) swap(x,y);
        ans += quiry(1,dfn[top[x]],dfn[x]);
        ans %= p;
        x = fa[top[x]];
    }
    if(dis[x] > dis[y]) swap(x,y);
    ans += quiry(1,dfn[x],dfn[y]);
    return ans % p;
}
int main(){
    n = read(),m = read(),r = read(),p = read();
    for(int i = 1;i <= n;i++) val[i] = read();
    for(int i = 1;i < n;i++){
        int u = read(),v = read();
        addedge(u,v);
        addedge(v,u);
    }
    cnt = 0; 
    dfs1(r,0);
    dfs2(r,0);
    build(1,1,n);
    for(int i = 1;i <= m;i++){
        int op = read(),x,y,z;
        if(op == 1){
            x = read(),y = read(),z = read();
            tadd(x,y,z);        
        }
        if(op == 2){
            x = read(),y = read(); 
            printf("%lld\n",task(x,y));
        }
        if(op == 3){
            x = read(),z = read();
            modify(1,dfn[x],dfn[x] + sz[x] - 1,z);
        }
        if(op == 4){
            x = read();
            printf("%lld\n",quiry(1,dfn[x],dfn[x] + sz[x] - 1));
        }
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/mysterious-garden/p/9859599.html