题目背景
这天又是一场模拟赛。作为机房栋梁的可爱妹子,Wu_Mr
再次 AK 了。而她的好姬友,108oahnew
却直接暴毙。
“咋做嘛,咋做嘛。”108oahnew
向 Wu_Mr
询问这些题的解法。
“我告诉你,但是你晚上得和我搞姬,这一切值得吗?” Wu_Mr
阴阳怪气。
“这……可是……可是……”108oahnew
脸渐渐泛出红色,“倒也不是不行……”
“??!”看着眼前的 108oahnew
,Wu_Mr
有些吃惊,“那,我就给你讲吧。不过今天晚上不行,我要去补考上周的模拟赛,明天吧。不能忘了哦。”说完,她就把 108oahnew
拉到白板前面开始讲题。
在晚上去补考之前,Wu_Mr
怕 108oahnew
太寂寞,就给了她一道练习题。
题目描述
给定一棵 n n n 个节点的树,每个点有一个点权。有 q q q 次操作,每次操作为修改某个点 x x x 的点权为 v v v,称为修改操作(修改操作是永久的)。
你每次可以选择一个点权均大于 0 0 0 的连通块,使这个连通块的每个点权都减一,称为一次删减操作(删除操作对于每次修改操作独立)。

你需要在每次修改操作后,输出使得所有点权全部变为 0 0 0 最少需要的删减操作次数。
第一行两个数 n n n 和 q q q,表示节点数和修改操作数。
接下来一行 n n n 个正整数,第 i i i 个表示编号为 i i i 的点的点权。
接下来 n − 1 n-1 n−1 行,每行两个数 u , v u, v u,v 表示这两个点之间有一条边。
接下来 q q q 行,每行两个数 x , v x, v x,v,表示这次修改操作是把 x x x 点的点权改为 v v v。
对于每个询问,输出一行一个整数表示最少的删减操作次数。
样例输入 1
5 1
1 1 1 1 1
1 2
2 3
3 4
4 5
1 2
样例输出 1
2
样例输入 2
10 5
39 24 1 16 13 31 90 69 7 34
1 2
1 3
1 4
4 5
4 6
6 7
1 8
6 9
6 10
2 96
4 95
8 40
9 26
6 97
样例输出 2
203
244
215
215
155
样例解释 1
你可以先选择整颗树进行删减操作,再选择 1 1 1 号节点进行删减操作。
对于 10 % 10 \% 10% 的数据, n ≤ 10 , a i ≤ 3 , q = 1 n \leq 10,a_i \leq 3,q=1 n≤10,ai≤3,q=1。
对于 30 % 30 \% 30% 的数据, n , q ≤ 1000 n,q \leq 1000 n,q≤1000。
对于另外 20 % 20 \% 20% 的数据,保证输入的树是一条从 1 1 1 到 n n n 的链。
对于 100 % 100 \% 100% 的数据, n , q ≤ 1 0 5 , 1 ≤ a i ≤ 1 0 7 , 1 ≤ x ≤ n , 1 ≤ v ≤ 1 0 7 n,q \leq 10^5, 1 \leq a_i \leq 10^7, 1 \leq x \leq n, 1 \leq v \leq 10^7 n,q≤105,1≤ai≤107,1≤x≤n,1≤v≤107。 x x x 和 v v v 的含义如题面描述。
首先,第一眼想到的是以前做过的一道CCJ的城市
但如果按原本的思路,容易想到 O ( n n l o g n ) O(n\sqrt nlogn) O(nnlogn),即对于修改分块:
度数大于 n \sqrt n n 的点查询时再修改,度数小于 n \sqrt n n 的点直接修改。
代码:
#include<bits/stdc++.h>
#define N 100005
typedef long long ll;
using namespace std;
int read(){
int op=1,sum=0;
char ch=getchar();
while(ch<'0'||ch>'9') {
if(ch=='-') op=-1;ch=getchar();}
while(ch>='0'&&ch<='9') sum=(sum<<3)+(sum<<1)+ch-'0',ch=getchar();
return op*sum;
}
vector<ll> c[N];
int d[N];
int tot,head[N],ver[N<<1],nex[N<<1];
inline void add(int x,int y){
nex[++tot]=head[x];head[x]=tot;ver[tot]=y;
d[y]++;
}
struct node{
int id,c,t;
}ask[N];
int fa[N],sqr;
ll ans,a[N];
vector<ll> t1[N],t2[N];
int len[N];
inline int lowbit(int x){
return x&(-x);}
inline ll get1(int x,int num){
ll now=0;
for(;x;x-=lowbit(x))now+=t1[num][x];
return now;
}
inline void add1(int x,ll val,int num){
if(!x)return ;
for(;x<=len[num];x+=lowbit(x))t1[num][x]+=val;
}
inline ll get2(int x,int num){
ll now=0;
for(;x;x-=lowbit(x))now+=t2[num][x];
return now;
}
inline void add2(int x,ll val,int num){
for(;x<=len[num];x+=lowbit(x))t2[num][x]+=val;
}
vector<int> cl[N],to[N];
inline int pos(int col,int num){
return lower_bound(cl[num].begin(),cl[num].end(),col)-cl[num].begin();
}
struct Pre{
int a,id;
}Pr[N];
bool cmp(Pre x,Pre y){
return x.a>y.a;}
int b[N];
inline ll work(int x,ll c){
int mid=pos(c,x);
ll les=get2(mid-1,x),gra=(c)*(get1(len[x],x)-get1(mid-1,x));
for(int i=0;i<to[x].size();++i){
ll tc=a[to[x][i]];
if(tc<c)les+=tc;
else gra+=(c);
}
return les+gra;
}
inline void ch(int x,ll c){
if(d[x]>sqr){
a[x]=c;return ;}
for(int i=head[x];i;i=nex[i]){
int y=ver[i];
int to=pos(a[x],y);
add1(to,-1,y);add2(to,-a[x],y);
to=pos(c,y);
add1(to,1,y);add2(to,c,y);
}
a[x]=c;
}
int main(){
// freopen("sorry.in","r",stdin);
// freopen("sorry.out","w",stdout);
int n=read(),q=read();
sqr=sqrt(n);
for(int i=1;i<=n;++i){
a[i]=read();
ans=ans+a[i];
c[i].push_back(a[i]);
fa[i]=i;
Pr[i].a=a[i];Pr[i].id=i;
}
for(int i=1;i<n;++i){
int u=read(),v=read();
add(u,v);add(v,u);
}
for(int i=1;i<=q;++i){
ask[i].id=read(),ask[i].c=read(),ask[i].t=i;
c[ask[i].id].push_back(ask[i].c);
}
for(int i=1;i<=n;++i){
cl[i].push_back(0);
t1[i].push_back(0);
t2[i].push_back(0);
for(int j=head[i];j;j=nex[j]){
int y=ver[j];
if(d[y]<=sqr){
for(int k=0;k<c[y].size();++k){
cl[i].push_back(c[y][k]);
}
}else{
to[i].push_back(y);
}
}
sort(cl[i].begin(),cl[i].end());
cl[i].erase(unique(cl[i].begin(),cl[i].end()),cl[i].end());
len[i]=cl[i].size();
--len[i];
for(int j=1;j<=len[i];++j){
t1[i].push_back(0);t2[i].push_back(0);}
for(int j=head[i];j;j=nex[j]){
int y=ver[j];
if(d[y]<=sqr){
add1(pos(a[y],i),1,i);
add2(pos(a[y],i),a[y],i);
}
}
}
sort(Pr+1,Pr+1+n,cmp);
ll cnt=0;
Pr[n+1].a=0;
for(int i=1;i<=n;++i){
int x=Pr[i].id;b[x]=1;
for(int j=head[x];j;j=nex[j]){
int y=ver[j];
if(b[y])cnt++;
}
ans-=cnt*(Pr[i].a-Pr[i+1].a);
}
for(int i=1;i<=q;++i){
int x=ask[i].id;ll C=ask[i].c;
ans+=work(x,a[x]);
ans-=work(x,C);
ans+=C-a[x];
ch(x,C);
printf("%lld\n",ans);
}
return 0;
换一种思路,以为树上联通块很好算,所以直接考虑贪心。维护一块已经处理完的,每次把儿子丢进去,那么增加的就是 m a x ( a [ x ] − a [ f a [ x ] ] , 0 ) max(a[x]-a[fa[x]],0) max(a[x]−a[fa[x]],0)。
这样就很好修改了。
维护方法一种是在线线段树动态开点: O ( n l o g n ) O(nlogn) O(nlogn)
@ c c j ccj ccj
#include<cstdio>
#include<vector>
#include<algorithm>
#include<cstdlib>
using namespace std;
const int INF=1e7;
int n,q,a[100002],fa[100002],rt[100002],cnt;
long long ans;
vector<int>g[100002],t[100002];
typedef struct{
int ls,rs,siz;
long long sum;
}P;
P p[10000002];
void gengxin(int root,int begin,int end,int wz,int z){
if (begin==end)
{
p[root].siz+=z;p[root].sum=(long long)begin*p[root].siz;
return;
}
int mid=(begin+end)/2;
if (wz<=mid)
{
if (!p[root].ls)p[root].ls=++cnt;
gengxin(p[root].ls,begin,mid,wz,z);
}
else
{
if (!p[root].rs)p[root].rs=++cnt;
gengxin(p[root].rs,mid+1,end,wz,z);
}
p[root].sum=(p[p[root].ls].sum+p[p[root].rs].sum);
p[root].siz=(p[p[root].ls].siz+p[p[root].rs].siz);
}
long long cxsum(int root,int begin,int end,int begin2,int end2){
if (begin>end2 || end<begin2 || !root)return 0;
if (begin>=begin2 && end<=end2)return p[root].sum;
int mid=(begin+end)/2;
return cxsum(p[root].ls,begin,mid,begin2,end2)+cxsum(p[root].rs,mid+1,end,begin2,end2);
}
long long cxsiz(int root,int begin,int end,int begin2,int end2){
if (begin>end2 || end<begin2 || !root)return 0;
if (begin>=begin2 && end<=end2)return p[root].siz;
int mid=(begin+end)/2;
return cxsiz(p[root].ls,begin,mid,begin2,end2)+cxsiz(p[root].rs,mid+1,end,begin2,end2);
}
void dfs(int x,int y){
fa[x]=y;rt[x]=++cnt;
for (int i=0;i<g[x].size();i++)
if (g[x][i]!=y)
{
t[x].push_back(g[x][i]);dfs(g[x][i],x);
}
for (int i=0;i<t[x].size();i++)gengxin(rt[x],1,INF,a[t[x][i]],1);
ans+=cxsum(rt[x],1,INF,a[x],INF)-cxsiz(rt[x],1,INF,a[x],INF)*a[x];
}
int main()
{
scanf("%d%d",&n,&q);
for (int i=1;i<=n;i++)scanf("%d",&a[i]);
for (int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
g[u].push_back(v);g[v].push_back(u);
}
dfs(1,0);
while(q--)
{
int x,v;
scanf("%d%d",&x,&v);
ans-=cxsum(rt[x],1,INF,a[x],INF)-cxsiz(rt[x],1,INF,a[x],INF)*a[x];
if (fa[x])
{
gengxin(rt[fa[x]],1,INF,a[x],-1);
if (a[x]>a[fa[x]])ans-=a[x]-a[fa[x]];
}
a[x]=v;
ans+=cxsum(rt[x],1,INF,a[x],INF)-cxsiz(rt[x],1,INF,a[x],INF)*a[x];
if (fa[x])
{
gengxin(rt[fa[x]],1,INF,a[x],1);
if (a[x]>a[fa[x]])ans+=a[x]-a[fa[x]];
}
printf("%lld\n",ans+a[1]);
}
return 0;
}
另一种是离线树状数组,直接把所有可能取值全部扔进树状数组: O ( n l o g n ) O(nlogn) O(nlogn) 但常数小。
#include<bits/stdc++.h>
#define N 100005
typedef long long ll;
using namespace std;
inline int read(){
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){
if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch)){
x=x*10+ch-'0';ch=getchar();}
return x*f;
}
ll a[N];
vector<ll> t1[N],t2[N],c[N];
int len[N];
inline int lowbit(int x){
return x&(-x);}
inline void add1(int x,ll val,int num){
for(;x<=len[num];x+=lowbit(x))t1[num][x]+=val;
}
inline ll ask1(int x,int num){
ll now=0;
for(;x;x-=lowbit(x))now+=t1[num][x];
return now;
}
inline void add2(int x,ll val,int num){
for(;x<=len[num];x+=lowbit(x))t2[num][x]+=val;
}
inline ll ask2(int x,int num){
ll now=0;
for(;x;x-=lowbit(x))now+=t2[num][x];
return now;
}
vector<int> to[N];
int fa[N];
void dfs(int x,int las){
for(int i=0;i<to[x].size();++i){
int y=to[x][i];
if(y==las)continue;
fa[y]=x;
dfs(y,x);
}
}
struct node{
int id,c;
}que[N];
ll ans;
inline ll work(int x,ll C){
ll now=max(C-a[fa[x]],0ll);
int pos=upper_bound(c[x].begin(),c[x].end(),C)-c[x].begin();
now+=ask1(len[x],x)-ask1(pos-1,x)-(ask2(len[x],x)-ask2(pos-1,x))*C;
return now;
}
int main(){
// freopen("data.in","r",stdin);
// freopen("sorry.out","w",stdout);
int n=read(),q=read();
for(int i=1;i<=n;++i)a[i]=read(),c[i].push_back(0);
c[0].push_back(0);
for(int i=1;i<n;++i){
int x=read(),y=read();
to[x].push_back(y);to[y].push_back(x);
}
dfs(1,0);
for(int i=1;i<=q;++i){
que[i].id=read(),que[i].c=read();
c[fa[que[i].id]].push_back(que[i].c);
}
for(int i=1;i<=n;++i){
c[fa[i]].push_back(a[i]);
ans+=max(a[i]-a[fa[i]],0ll);
}
for(int i=1;i<=n;++i){
sort(c[i].begin(),c[i].end());
c[i].erase(unique(c[i].begin(),c[i].end()),c[i].end());
len[i]=c[i].size()-1;
for(int j=0;j<=len[i];++j)t1[i].push_back(0),t2[i].push_back(0);
}
for(int i=1;i<=n;++i){
int pos=lower_bound(c[fa[i]].begin(),c[fa[i]].end(),a[i])-c[fa[i]].begin();
add1(pos,a[i],fa[i]);add2(pos,1,fa[i]);
}
for(int T=1;T<=q;++T){
int x=que[T].id,C=que[T].c;
ans-=work(x,a[x]);
ans+=work(x,C);
int pos=lower_bound(c[fa[x]].begin(),c[fa[x]].end(),a[x])-c[fa[x]].begin();
add1(pos,-a[x],fa[x]);add2(pos,-1,fa[x]);
pos=lower_bound(c[fa[x]].begin(),c[fa[x]].end(),C)-c[fa[x]].begin();
add1(pos,C,fa[x]);add2(pos,1,fa[x]);
a[x]=C;
printf("%lld\n",ans);
}
return 0;
}