bzoj4034

树链剖分裸题

唯一需要注意的是数据范围：区间和（以及累加的懒标记）会超出 int 的表示范围，所以线段树中的 sum 和 tag 都要用 long long 存储，做“区间长度 × 增量”的乘法时也要先转成 (long long)，否则会溢出。

#include<cstdio>
#include<cctype>
// Capacity for node-indexed arrays; nodes are numbered from 1, so this must exceed n.
const int maxn=100001;
using namespace std;

// n, m: node and operation counts.
// cnt: edge counter while the tree is read, later reused as the DFS position counter.
int n,m,cnt;
// Heavy-light decomposition data, indexed by node:
//   fa  - parent          dep - depth           siz - subtree size
//   son - heavy child (0 when the node has no children)
//   top - head of the node's heavy chain        id  - position in DFS order
int fa[maxn],dep[maxn],siz[maxn],son[maxn],top[maxn],id[maxn];
// val[v]: input value of node v.  a[p]: value of the node at DFS position p.
int val[maxn],a[maxn];

// Segment tree over DFS positions: each node covers [l, r], keeps the range
// sum and a pending lazy add (tag). 4x capacity for the recursive layout.
struct data{int l,r;long long sum,tag;}tr[maxn<<2];

// Forward-star adjacency lists; every undirected edge is stored twice.
int head[maxn],to[maxn<<1],nex[maxn<<1];

// Accumulator for path-sum queries.
long long ans;

// Reads one decimal integer from stdin into x.
// Non-digit characters before the number are skipped; a '-' among them makes
// the result negative. If end of input is reached before any digit, x becomes 0
// (the previous version looped forever in that case).
void read(int &x){
    // int, not char: keeps EOF distinguishable from a data byte and guarantees
    // isdigit() receives an unsigned-char value or EOF (anything else is UB).
    int ch=getchar();
    int f=1;
    x=0;
    while(ch!=EOF&&!isdigit(ch)){
        if(ch=='-')f=-1;
        ch=getchar();
    }
    while(isdigit(ch)){ // isdigit(EOF) is false, so this also stops at end of input
        x=x*10+(ch-'0');
        ch=getchar();
    }
    x*=f;
}
// Appends the directed edge u -> v to u's forward-star list: head[u] points at
// u's newest edge and nex[] chains back to older ones. Edge ids start at 1, so
// id 0 terminates a list.
void addedge(int u,int v){
    cnt++;
    to[cnt]=v;
    nex[cnt]=head[u];
    head[u]=cnt;
}

// First heavy-light pass over the subtree rooted at x, whose parent is f
// (main passes the root as its own parent).
// Records fa[], dep[] and siz[] for every node, and sets son[x] to the child
// with the largest subtree; on ties the first such child wins.
// Recursion depth equals the height of the tree.
void dfs1(int x,int f){
    fa[x]=f;
    dep[x]=dep[f]+1;
    siz[x]=1;
    int heaviest=-1;
    for(int e=head[x];e!=0;e=nex[e]){
        int y=to[e];
        if(y!=f){
            dfs1(y,x);
            siz[x]+=siz[y];
            if(siz[y]>heaviest){
                heaviest=siz[y];
                son[x]=y;
            }
        }
    }
}

// Second heavy-light pass: numbers the nodes in DFS order using the global cnt.
// Node x receives position id[x], its value is copied to a[id[x]], and top[x]
// is set to topf, the head of x's heavy chain.
// The heavy child is visited first so it extends the current chain. Every other
// child starts a new chain headed by itself.
// As a result each chain, and each subtree, occupies a contiguous run of positions.
void dfs2(int x,int topf){
    id[x]=++cnt;
    a[id[x]]=val[x];
    top[x]=topf;
    if(son[x]==0)return; // no children
    dfs2(son[x],topf);
    for(int e=head[x];e;e=nex[e]){
        int y=to[e];
        if(y!=fa[x]&&y!=son[x])dfs2(y,y);
    }
}

// Builds segment-tree node `now` covering positions [l, r] of a[]. It records
// the node's range and the sum of a[] over it, recursing into children
// 2*now and 2*now+1. Lazy tags are not touched here.
void buildtr(int now,int l,int r){
    tr[now].l=l;
    tr[now].r=r;
    if(l!=r){
        int mid=(l+r)>>1;
        int lc=now<<1,rc=now<<1|1;
        buildtr(lc,l,mid);
        buildtr(rc,mid+1,r);
        tr[now].sum=tr[lc].sum+tr[rc].sum;
    }else{
        tr[now].sum=a[l];
    }
}

void pushdown(int now){
    if(tr[now].l==tr[now].r||tr[now].tag==0)return;
    tr[now<<1].tag+=tr[now].tag;tr[now<<1|1].tag+=tr[now].tag;
    tr[now<<1].sum+=(tr[now<<1].r-tr[now<<1].l+1)*tr[now].tag;
    tr[now<<1|1].sum+=(tr[now<<1|1].r-tr[now<<1|1].l+1)*tr[now].tag;
    tr[now].tag=0;
}
 // Adds ad to every position in [l, r] under segment-tree node now.
 // A node fully inside [l, r] absorbs the update into its sum and lazy tag.
 // Otherwise its pending tag is pushed down, the overlapping children are
 // updated, and the sum is recomputed.
 // The product is widened to long long because range sums exceed int.
 // Expects [l, r] to lie within the tree's overall range.
 void addtr(int now,int l,int r,int ad){
     if(l<=tr[now].l&&tr[now].r<=r){
         tr[now].sum+=(long long)(tr[now].r-tr[now].l+1)*ad;
         tr[now].tag+=ad;
         return;
     }
     pushdown(now);
     int mid=(tr[now].l+tr[now].r)>>1;
     int lc=now<<1,rc=now<<1|1;
     if(l<=mid)addtr(lc,l,r,ad);
     if(r>mid)addtr(rc,l,r,ad);
     tr[now].sum=tr[lc].sum+tr[rc].sum;
 }
// Returns the sum of positions [l, r] under segment-tree node now, pushing
// pending tags down on the way. The range is clipped to each child as the
// recursion descends, so [l, r] must lie inside node now's range.
long long query(int now,int l,int r){
    if(l<=tr[now].l&&tr[now].r<=r)return tr[now].sum;
    pushdown(now);
    int mid=(tr[now].l+tr[now].r)>>1;
    if(r<=mid)return query(now<<1,l,r);
    if(l>mid)return query(now<<1|1,l,r);
    long long left=query(now<<1,l,mid);
    return left+query(now<<1|1,mid+1,r);
}

// Adds to the global ans the sum of node values on the path from x up to y.
// Only valid when y is an ancestor of x; main always passes the root.
// x climbs one heavy chain at a time, adding the segment [top[x], x] each time,
// until it is on y's chain. Finally the segment [y, x] is added.
void queryway(int x,int y){
    for(;top[x]!=top[y];x=fa[top[x]])
        ans+=query(1,id[top[x]],id[x]);
    ans+=query(1,id[y],id[x]);
}

int main(){
    read(n);read(m);
    for(int i=1;i<=n;i++)read(val[i]);
    for(int i=1;i<n;i++){
        int u,v;read(u);read(v);
        addedge(u,v);addedge(v,u);
    }
    cnt=0;
    dfs1(1,1);dfs2(1,1);
    buildtr(1,1,n);
    for(int i=1;i<=m;i++){
        int opt,x,y;
        read(opt);read(x);
        switch(opt){
            case 1:read(y);addtr(1,id[x],id[x],y);break;
            case 2:read(y);addtr(1,id[x],id[x]+siz[x]-1,y);break;
            case 3:ans=0;queryway(x,1);printf("%lld\n",ans);break;
        }
    }
}

猜你喜欢

转载自www.cnblogs.com/MikuKnight/p/9014196.html