树链剖分模板(洛谷P3384 )

前几天学了学树链剖分的模板题
总的来说树链剖分还是比较容易理解的
主要思想就是:把树上的若干串,离散到区间上
怎么把一棵树分割成若干串呢?
这时候要引出几个定义:
重儿子:就是所有儿子中,子树中节点数量最多的儿子
重链:连续重儿子连接成的链
除了重儿子以外都是轻儿子

截止目前我们现在可以得出几个结论:
每个非叶子节点一定是重链的一部分(因为有儿子就肯定有重儿子,自己要么是重链的首端,要么是重链的中间节点)
叶子结点要么是重链的一部分,要么自己独自成为一个重链(就一个顶点的重链)

综上我们得出:一颗树可以被分割成若干个重链!

我看其他博客中写的什么轻儿子,轻链,根本没什么用,哪里有什么轻链可言?我们不妨把那些不在重链上节点,划分为只有一个节点的重链!
在剖分过程中,要计算如下7个值:
fa[u]:u在树中的父亲
dep[u]:u节点的深度
size[u]:u的子树节点数(子树大小)
son[u]:u的重儿子
top[u]:u所在重链的顶部节点
id[u]:节点u在区间中的新编号
rk[x]:新编号对应的节点编号!

#include<iostream>
#include<cstdio>
#define ll long long
using namespace std;
const int maxn = 1e5 + 10;

struct Node
{
    int l,r;
    //lazy 表示的是 区间中每个节点需要加的数值
    long long sum,lazy;
}node[400100];

int point[2*maxn] ,next1[2*maxn],first[2*maxn];

int k,n,m,root,mod,v[maxn],cnt,fa[maxn],dep[maxn],size[maxn],son[maxn]
,top[maxn],id[maxn],rk[maxn];

void addEdge(int,int);
void dfs1(int);
void dfs2(int,int);
void pushup(int);
void pushdown(int);
void upGrade(int,int,int,int);
void build(int,int,int);
ll query(int,int,int);

void addEdge(int s,int t)
{
    point[++k] = t;next1[k] = first[s];first[s] = k;
}
//更新 size[u] dep[u] fa[u] son[u]
void dfs1(int now)
{
    size[now] = 1;
    dep[now] = dep[fa[now]] + 1;
    int k = first[now];
    while(k != 0)
    {
        int new1 = point[k];
        if(new1 != fa[now])
        {
            fa[new1] = now;
            dfs1(new1);
            size[now] += size[new1];
            if(size[son[now]] < size[new1])//更新重儿子
                son[now] = new1;
        }
        k = next1[k];
    }
}
//更新 top[u] id[u] rk[u]
void dfs2(int now,int tp)
{
    top[now] = tp;
    rk[cnt] = now;
    id[now] = cnt++;

    if(son[now])
        dfs2(son[now],tp);
    for(int k = first[now];k != 0;k = next1[k])
        if(point[k] != fa[now] && point[k] != son[now])
            dfs2(point[k],point[k]);
}

inline void updates(int x,int y,int c)
{
    while(top[x]!=top[y])
    {
        if(dep[top[x]]<dep[top[y]])
            swap(x,y);
        upGrade(id[top[x]],id[x],c,1);
        x=fa[top[x]];
    }
    if(id[x]>id[y])
        swap(x,y);
    upGrade(id[x],id[y],c,1);
}
inline ll sum(int x,int y)
{
    ll ans = 0;
    while(top[x]!=top[y])
    {
        if(dep[top[x]]<dep[top[y]])
            swap(x,y);
        ans += query(id[top[x]],id[x],1);
        ans %= mod;
        x=fa[top[x]];
    }
    if(id[x]>id[y])
        swap(x,y);
   ans +=  query(id[x],id[y],1);
   return  ans%mod;
}
void pushup(int root)
{
    node[root].sum = (node[root*2].sum + node[root*2+1].sum)%mod;
}
void pushdown(int root)
{
    if(node[root].lazy == 0)
        return;

    (node[root*2].sum += (node[root].lazy)*(node[root*2].r - node[root*2].l + 1)) %= mod;
    (node[root*2+1].sum += (node[root].lazy)*(node[root*2+1].r - node[root*2+1].l + 1)) %= mod;

    (node[root*2].lazy += node[root].lazy) %= mod;
    (node[root*2+1].lazy += node[root].lazy) %= mod;
    node[root].lazy = 0;

}
void upGrade(int la,int rb,int chg,int root)
{
    if(la <= node[root].l && rb >= node[root].r)
    {
        (node[root].sum += (node[root].r - node[root].l+1)*chg)%mod;
        (node[root].lazy += chg)%mod;
        return;
    }
    pushdown(root);
    int mid = node[root].l + node[root].r >> 1;
    if(la <= mid )
    {
        upGrade(la,rb,chg,root*2);
    }
    if(rb >= mid+1)
    {
        upGrade(la,rb,chg,root*2+1);
    }
    pushup(root);
}
void build(int l,int r,int root)
{
    node[root].l = l;
    node[root].r = r;
    if(l == r)
    {
        node[root].sum = v[rk[l]];
        return;
    }
    int mid = l+r >> 1;
    build(l,mid,root*2);
    build(mid+1,r,root*2+1);

    pushup(root);
}
long long  query(int la,int rb,int root)
{
    if(la <= node[root].l && rb >= node[root].r)
    {
        return node[root].sum;
    }
    pushdown(root);
    int mid = node[root].l + node[root].r >> 1;
    long long ans = 0;
    if(la <= mid )
    {
        ans += query(la,rb,root*2);
        ans %= mod;
    }
    if(rb > mid )
    {
        ans += query(la,rb,root*2 + 1);
        ans %= mod;
    }
    return  ans%mod;
}

int main()
{
    scanf("%d%d%d%d",&n,&m,&root,&mod);
    for(int i = 1; i <= n; i++)
        scanf("%d",&v[i]);
    int x,y;
    for(int i = 1;i <= n-1; i++)
    {
        scanf("%d%d",&x,&y);
        addEdge(x,y);
        addEdge(y,x);
    }
    dfs1(root);
    cnt = 1;//每个节点的id
    dfs2(root,root);

    build(1,n,1);
  //  cout<<"&&"<<node[1].sum<<"&&";

    for(int i = 1;i <= m; i++)
    {
        int op,x,y,k;
        scanf("%d",&op);

        if(op == 1)
        {
            scanf("%d%d%d",&x,&y,&k);
            updates(x,y,k);
        }
        if(op == 2)
        {
            scanf("%d%d",&x,&y);
            printf("%lld\n",sum(x,y));
        }
        if(op == 3)
        {
            scanf("%d%d",&x,&y);
            upGrade(id[x],id[x] + size[x] - 1,y,1);
        }
        if(op == 4)
        {
            scanf("%d",&x);
            printf("%lld\n",query(id[x],id[x]+size[x]-1,1));
        }
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/weixin_43912833/article/details/98648666