线段树区间开方

概述

  线段树是实用的数据结构,支持所有符合结合律的运算的区间操作。但开方不符合结合律,怎么用线段树维护呢?其实线段树本身无法支持,但还是有方法在有限时间内维护的,就是利用数在经历多次开方后会趋向于统一的性质优化运算。

数据结构

  区间开方线段树的思想如上所示,就是利用开方运算的性质,在最初几组数据里暴力计算,后期根据连续统一的序列进行简便运算。对于两种常见的题目,有两种做法。

仅区间开方

  我们可以发现任何数在经历过很少的几次开方后就会等于0,于是我们很容易想到将最初的几次操作暴力计算,并记录最大值。由于$10^18$在经历6次开方和向下取整之后就等于1了,所以暴力最多算$6N$次,复杂度为$\Theta(N)$。而在此之后,我们遇到最大值不超过1的区间就可以不进行运算了。($\sqrt{1}=1 , \sqrt{0}=0$)

  至于实现,比普通的线段树还要简单,不需要lazy_tag,只需要维护区间最大值和所需操作即可,在update函数里加一个特判就可以优化计算。

例题:Luogu P4145 上帝造题的七分钟2 / 花神游历各国 

#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int INF=1e9+7,MAXN=1e5+10,MAXNODE=MAXN*4;
int N,M;
LL tmp[MAXN],sum[MAXNODE],maxv[MAXNODE];
inline void push_up(int x){
    sum[x]=sum[x<<1]+sum[x<<1|1];
    maxv[x]=max(maxv[x<<1],maxv[x<<1|1]);
}
void init(int x,int l,int r){
    if(l==r){
        sum[x]=maxv[x]=tmp[l];
        return;
    }
    int mid=(l+r)>>1;
    init(x<<1,l,mid);
    init(x<<1|1,mid+1,r);
    push_up(x);
}
LL query(int x,int l,int r,int ql,int qr){
    if(ql<=l&&r<=qr)
        return sum[x];
    int mid=(l+r)>>1;
    LL ret=0;
    if(ql<=mid)
        ret+=query(x<<1,l,mid,ql,qr);
    if(mid<qr)
        ret+=query(x<<1|1,mid+1,r,ql,qr);
    return ret;
}
void update(int x,int l,int r,int ql,int qr){
    if(l==r){
        maxv[x]=sqrt(maxv[x]);
        sum[x]=sqrt(sum[x]);
        return;
    }
    if(maxv[x]<=1)
        return;
    int mid=(l+r)>>1;
    if(ql<=mid&&maxv[x<<1]>1)
        update(x<<1,l,mid,ql,qr);
    if(mid<qr&&maxv[x<<1|1]>1)
        update(x<<1|1,mid+1,r,ql,qr);
    push_up(x);
}
int main(){
    scanf("%d",&N);
    for(int i=1;i<=N;i++)
        scanf("%lld",tmp+i);
    init(1,1,N);
    scanf("%d",&M);
    for(int i=1,k,l,r;i<=M;i++){
        scanf("%d%d%d",&k,&l,&r);
        if(l>r)
            swap(l,r);
        if(k)
            printf("%lld\n",query(1,1,N,l,r));
        else
            update(1,1,N,l,r);
    }
    return 0;
}

注意:支持其他操作的线段树不支持此问题,因为这里没有区间加,会出现许多的0,相等的区间就会非常少。

支持其它操作

  开方和向下取整结合时,会有和|、&运算类似的使数据趋同的作用。所以我们只需要暴力执行前几次操作,对于后面的,只需要特判一段区间是否相等,如果相等则将区间和等指标直接开方,毕竟趋同是不可逆的。

例题:基础数据结构练习题xiaowuga大佬的代码

#include<bits/stdc++.h>
using namespace std;
const int N=1e5+7;
long long a[N];
struct node{
    int l,r;
    long long maxx,minn,sum;
    long long lazy;
    void up(long long val){
        maxx+=val;minn+=val;sum+=(r-l+1)*1ll*val;
        lazy+=val;
    }
}tree[8*N];
void push_up(int x){
    tree[x].maxx=max(tree[x<<1].maxx,tree[x<<1|1].maxx);
    tree[x].minn=min(tree[x<<1].minn,tree[x<<1|1].minn);
    tree[x].sum=tree[x<<1].sum+tree[x<<1|1].sum;
}
void push_down(int x){
    long long val=tree[x].lazy;
    if(val){
        tree[x<<1].up(val);
        tree[x<<1|1].up(val);
        tree[x].lazy=0;
    }
}
void build(int x,int l,int r){
    tree[x].l=l;  tree[x].r=r; 
    tree[x].lazy=tree[x].sum=0;
    if(l==r){
        tree[x].minn=tree[x].maxx=tree[x].sum=a[l];
        return;
    }
    int m=(l+r)/2;
    build(x<<1,l,m);
    build(x<<1|1,m+1,r);
    push_up(x);
}
void updata(int x,int l,int r,long long val){
    int L=tree[x].l,R=tree[x].r;
    if(l<=L&&R<=r){
        tree[x].up(val);return;
    }
    int m=(L+R)/2;
    push_down(x);
    if(l<=m) updata(x<<1,l,r,val);
    if(r>m) updata(x<<1|1,l,r,val);
    push_up(x);
}
void Sqrt(int x,int l,int r){
    push_down(x);
    int L=tree[x].l,R=tree[x].r;
    if(l<=L&&R<=r){
        if(tree[x].maxx==tree[x].minn){
            long long t=(long long)sqrt(tree[x].maxx);
            updata(x,L,R,t-tree[x].maxx);
            return;
        }
        else if(tree[x].minn+1==tree[x].maxx){
            long long t1=(long long)sqrt(tree[x].minn); 
            long long t2=(long long)sqrt(tree[x].maxx); 
            if(t1+1==t2){
                updata(x,L,R,t2-tree[x].maxx);
                return;
            }
        }
    }
    int m=(L+R)/2;
    if(l<=m) Sqrt(x<<1,l,r);
    if(r>m) Sqrt(x<<1|1,l,r);
    push_up(x);
}
long long query(int x,int l,int r){
    push_down(x);
    int L=tree[x].l,R=tree[x].r;
    if(l<=L&&R<=r){
        return tree[x].sum;
    }
    int m=(L+R)/2;
    long long ans=0;
    if(l<=m) ans+=query(x<<1,l,r);
    if(r>m) ans+=query(x<<1|1,l,r);
    push_up(x);
    return ans;
}
int main(){
    int n,m;
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++) scanf("%lld",&a[i]);
    build(1,1,n);
    while(m--){
        int op,l,r;
        scanf("%d%d%d",&op,&l,&r);
        if(op==1){
            long long val;
            scanf("%lld",&val);
            updata(1,l,r,val);
        }
        else if(op==2){
            Sqrt(1,l,r);
        }
        else{
            printf("%lld\n",query(1,l,r));
        }
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/guoshaoyang/p/11228109.html
今日推荐