[洛谷P3384] [模板] 树链剖分

题目传送门

显然是一道模板题。

然而索引出现了错误,狂wa不止。

感谢神犇Dr_J指正。%%%orz。

建线段树的时候,第44行。

把sum[p]=bv[pos[l]]%mod;打成了sum[p]=bv[in[l]]%mod;

忘了要用反映射搞一下......

树链剖分,从每个节点的儿子中,找出子树最大的一个作为重儿子。

然后以此将树链分成轻链和重链。

之后dfs一遍求出树链剖分序。

树链剖分序不仅保证子树内节点的编号在序列上连续,还保证一条重链上的节点的编号连续。

用一个线段树维护一下。

更改/询问子树的时候就是区间修改/查询区间和。

更改/询问树链的时候就是一步一步跳重链,每次一个区间修改/查询区间和。

  1 #include<cstdio>
  2 #include<cstring>
  3 #include<algorithm>
  4 #define ll long long
  5 using namespace std;
  6 
  7 int n,m,root;
  8 ll mod;
  9 ll bv[100005];
 10 int hd[100005],to[200005],nx[200005],ec;
 11 int sz[100005],f[100005],d[100005];
 12 int son[100005],tp[100005];
 13 int in[100005],out[100005],pc,pos[100005];
 14 int lb[400005],rb[400005];
 15 ll sum[400005],lz[400005];
 16 
 17 void edge(int af,int at)
 18 {
 19     to[++ec]=at;
 20     nx[ec]=hd[af];
 21     hd[af]=ec;
 22 }
 23 
 24 void pushup(int p)
 25 {
 26     sum[p]=((sum[p<<1]+sum[p<<1|1])%mod+mod)%mod;
 27 }
 28 
 29 void pushdown(int p)
 30 {
 31     if(!lz[p])return;
 32     sum[p<<1]=((sum[p<<1]+(rb[p<<1]-lb[p<<1]+1)%mod*lz[p]%mod)%mod+mod)%mod;
 33     sum[p<<1|1]=((sum[p<<1|1]+(rb[p<<1|1]-lb[p<<1|1]+1)%mod*lz[p]%mod)%mod+mod)%mod;
 34     lz[p<<1]=((lz[p<<1]+lz[p])%mod+mod)%mod;
 35     lz[p<<1|1]=((lz[p<<1|1]+lz[p])%mod+mod)%mod;
 36     lz[p]=0;
 37 }
 38 
 39 void build(int p,int l,int r)
 40 {
 41     lb[p]=l,rb[p]=r;
 42     if(l==r)
 43     {
 44         sum[p]=bv[pos[l]]%mod;
 45         return;
 46     }
 47     int mid=(l+r)>>1;
 48     build(p<<1,l,mid);
 49     build(p<<1|1,mid+1,r);
 50     pushup(p);
 51 }
 52 
 53 void add(int p,int l,int r,ll v)
 54 {
 55     if(lb[p]>=l&&rb[p]<=r)
 56     {
 57         sum[p]=(sum[p]+(rb[p]-lb[p]+1)%mod*v%mod)%mod;
 58         lz[p]=(lz[p]+v)%mod;
 59         return;
 60     }
 61     pushdown(p);
 62     int mid=(lb[p]+rb[p])>>1;
 63     if(l<=mid)add(p<<1,l,r,v);
 64     if(r>mid)add(p<<1|1,l,r,v);
 65     pushup(p);
 66 }
 67 
 68 ll query(int p,int l,int r)
 69 {
 70     if(lb[p]>=l&&rb[p]<=r)return sum[p];
 71     pushdown(p);
 72     int mid=(lb[p]+rb[p])>>1;
 73     ll ret=0;
 74     if(l<=mid)ret=((ret+query(p<<1,l,r))%mod+mod)%mod;
 75     if(r>mid)ret=((ret+query(p<<1|1,l,r))%mod+mod)%mod;
 76     return ret;
 77 }
 78 
 79 void pre(int p,int fa)
 80 {
 81     d[p]=d[fa]+1,f[p]=fa,sz[p]=1;
 82     for(int i=hd[p];i;i=nx[i])
 83     {
 84         if(to[i]==fa)continue;
 85         pre(to[i],p);
 86         sz[p]+=sz[to[i]];
 87         if(sz[to[i]]>sz[son[p]])son[p]=to[i];
 88     }
 89 }
 90 
 91 void dfs(int p)
 92 {
 93     in[p]=++pc;
 94     pos[pc]=p;
 95     if(p==son[f[p]])tp[p]=tp[f[p]];
 96     else tp[p]=p;
 97     if(son[p])dfs(son[p]);
 98     for(int i=hd[p];i;i=nx[i])
 99         if(to[i]!=f[p]&&to[i]!=son[p])dfs(to[i]);
100     out[p]=pc;
101 }
102 
103 int main()
104 {
105     scanf("%d%d%d%lld",&n,&m,&root,&mod);
106     for(int i=1;i<=n;i++)scanf("%lld",&bv[i]);
107     for(int i=1;i<n;i++)
108     {
109         int ff,tt;
110         scanf("%d%d",&ff,&tt);
111         edge(ff,tt),edge(tt,ff);
112     }
113     pre(root,root);
114     dfs(root);
115     build(1,1,n);
116     for(int i=1;i<=m;i++)
117     {
118         int op;
119         scanf("%d",&op);
120         if(op==1)
121         {
122             int x,y;
123             ll z;
124             scanf("%d%d%lld",&x,&y,&z);
125             z%=mod;
126             while(tp[x]!=tp[y])
127             {
128                 if(d[tp[x]]<d[tp[y]])swap(x,y);
129                 add(1,in[tp[x]],in[x],z);
130                 x=f[tp[x]];
131             }
132             if(d[x]>d[y])swap(x,y);
133             add(1,in[x],in[y],z);
134         }
135         if(op==2)
136         {
137             int x,y;
138             scanf("%d%d",&x,&y);
139             ll ans=0;
140             while(tp[x]!=tp[y])
141             {
142                 if(d[tp[x]]<d[tp[y]])swap(x,y);
143                 ans=((ans+query(1,in[tp[x]],in[x]))%mod+mod)%mod;
144                 x=f[tp[x]];
145             }
146             if(d[x]>d[y])swap(x,y);
147             ans=((ans+query(1,in[x],in[y]))%mod+mod)%mod;
148             printf("%lld\n",ans);
149         }
150         if(op==3)
151         {
152             int x;
153             ll z;
154             scanf("%d%lld",&x,&z);
155             z%=mod;
156             add(1,in[x],out[x],z);
157         }
158         if(op==4)
159         {
160             int x;
161             scanf("%d",&x);
162             ll ans=query(1,in[x],out[x]);
163             printf("%lld\n",ans);
164         }
165     }
166     return 0;
167 }

猜你喜欢

转载自www.cnblogs.com/eternhope/p/9726209.html