【题解】洛谷P3373 线段树2 && codevs P2216 行星序列(线段树)

既有加法又有乘法的线段树模型,第一感觉是开两个数组来记录标记,然后写两个pushdown函数,对不同情况进行操作,判断儿子节点有没有加法或乘法的标记啦。。。总之很复杂。经过各种情况的考虑和优化过后,我们拿到了…………70分。

那我们就需要想想更好的做法了。两个pushdown函数真的必要吗?答案是否定的。我们定义对儿子的sum计算为父亲的sum*乘法标记+加法标记*区间长度。然后如果对加法或乘法标记改变其值,将乘法标记乘自身,加法标记乘乘法标记再加上自己,原因同单独写加法或乘法。对加标记也进行类似的操作。注意取模,注意在建树时将乘法标记初始化为1,注意tmp为1和为2时代表的含义,传7个参数进去。。。这道题就可以解决了。sum和两个标记数组记得开大一点 ,maxn*8就足够了。

#include<cstdio>
#include<iostream>
#define ll long long
using namespace std;
const int maxn=100010;
int n,m,mod; 
int in[maxn],sum[maxn*8];
int plu[maxn*8],mul[maxn*8];
void build(int now,int l,int r)
{
	mul[now]=1;
	if(l==r)
	{
		sum[now]=in[l];
		return ;
	}
	int mid=(l+r)/2;
	build(now*2,l,mid);
	build(now*2+1,mid+1,r);
	sum[now]=(sum[now*2]+sum[now*2+1])%mod;
}
void update(int now)
{
	sum[now]=(sum[now*2]+sum[now*2+1])%mod;
}
void pushdown(int now,int l,int r)
{
	if(l==r) return;
	int mid=(l+r)/2;
	sum[now*2]=((ll)sum[now*2]*mul[now]%mod+(ll)(mid-l+1)*plu[now]%mod)%mod;
	mul[now*2]=(ll)mul[now*2]*mul[now]%mod;
	plu[now*2]=(ll)plu[now*2]*mul[now]%mod;
	plu[now*2]=(plu[now*2]+plu[now])%mod;
	
	sum[now*2+1]=((ll)sum[now*2+1]*mul[now]%mod+(ll)(r-mid)*plu[now]%mod)%mod;
	mul[now*2+1]=(ll)mul[now*2+1]*mul[now]%mod;
	plu[now*2+1]=(ll)plu[now*2+1]*mul[now]%mod;
	plu[now*2+1]=(plu[now*2+1]+plu[now])%mod;
	
	mul[now]=1;
	plu[now]=0;
}
void addnum(int now,int l,int r,int x,int y,ll k1,ll k2)
{
	int mid=(l+r)/2;
	if(x<=l&&y>=r)
	{
		sum[now]=((ll)sum[now]*k2%mod+(ll)(r-l+1)*k1%mod)%mod;
		mul[now]=(ll)mul[now]*k2%mod;
		plu[now]=(ll)plu[now]*k2%mod;
		plu[now]=(plu[now]+k1)%mod;
		return;
	} 
	pushdown(now,l,r);
	if(x<=mid) addnum(now*2,l,mid,x,y,k1,k2);
	if(mid+1<=y) addnum(now*2+1,mid+1,r,x,y,k1,k2);
	update(now);
}
ll getsum(int now,int l,int r,int x,int y)
{
	pushdown(now,l,r);
	int mid=(l+r)/2;
	if(x<=l&&y>=r)
	{
		return sum[now];
	}
	ll ans=0;
	if(x<=mid) ans=(ans+getsum(now*2,l,mid,x,y))%mod;
	if(mid+1<=y) ans=(ans+getsum(now*2+1,mid+1,r,x,y))%mod;
	update(now);
	return ans;
}
int main()
{
	scanf("%d%d",&n,&mod);
	for(int i=1;i<=n;i++)
	{
		scanf("%d",&in[i]);
	}
	build(1,1,n);
	scanf("%d",&m);
	for(int i=1;i<=m;i++)
	{
		int tmp,x,y,z;
		scanf("%d",&tmp);
		if(tmp==1)
		{
			scanf("%d%d%d",&x,&y,&z);
			addnum(1,1,n,x,y,0,z);
		}
		if(tmp==2)
		{
			scanf("%d%d%d",&x,&y,&z);
			addnum(1,1,n,x,y,z,1);
		}
		if(tmp==3)
		{
			scanf("%d%d",&x,&y);
			printf("%lld\n",getsum(1,1,n,x,y)%mod);
		}
	}
	return 0;
} 

猜你喜欢

转载自blog.csdn.net/Rem_Inory/article/details/81274470