二逼平衡树——树套树(线段树套Splay平衡树)

题面

  Bzoj3196

解析

  线段树和Splay两棵树套在一起,常数直逼inf,但最终侥幸过了

  思路还是比较简单, 在原数组维护一个下标线段树,再在每一个线段树节点,维护一个对应区间的权值Splay。简单说一下操作:

 0.提取区间

  这个操作是1、2、4、5操作的基础,其实也比较容易实现,在线段树中跑一跑,如果询问区间包含了节点覆盖的区间,就在一个数组中存一下节点的编号,然后返回就行

 1.查询区间内k的排名

  提取区间,找到区间内所有的Splay, 分别比k小的数的个数,相加后加一即可

 2.查询区间第k大

  提取区间,找到区间内所有的Splay,因为有多棵Splay,这个操作显然不能像一般的Splay一样查询,只能二分答案,转化为操作一,再check排名就行

 3.单点修改

  先在线段树中向下找到包含pos的所有节点,将这些节点对应的Splay中原值删去, 加入新值即可

 4.求前驱

  同样找到区间内的所有Splay,分别进入找前驱,输出最大的前驱即可。但k可能不在Splay中,于是我们先要找到大于等于k的权值最小的节点,将它旋转到根,再找前驱,但由于可能没有前驱,就在每棵Splay插入-inf, 查询排名时记得减去即可

 5.求后继

  操作与前驱类似,找到小于等于k的权值最大的节点,旋转到根,找后继,记得每棵Splay插入inf就行

  大部分操作都是自己YY的,可能不是很优秀,但我这个又臭又长的代码并没有调试多久,也是不容易啊。

 代码(340行):

#include<bits/stdc++.h>
using namespace std;
const int maxn = 100005, inf = 2147483647;

template<class T> void read(T &re)
{
    re=0;
    T sign=1;
    char tmp;
    while((tmp=getchar())&&(tmp<'0'||tmp>'9')) if(tmp=='-') sign=-1;
    re=tmp-'0';
    while((tmp=getchar())&&(tmp>='0'&&tmp<='9')) re=(re<<3)+(re<<1)+(tmp-'0');
    re*=sign;
}

int n, m, root[maxn<<1], rt, a[maxn];
int tot, cnt, lson[maxn<<1], rson[maxn<<1], stak[maxn], top, s[maxn], snum;

struct Splay_tree{
    int fa, s[2], val, siz, num;
}tr[maxn * 20];

void update(int x)
{
    int ls = tr[x].s[0], rs = tr[x].s[1];
    tr[x].siz = tr[ls].siz + tr[rs].siz + tr[x].num;
}

void Rotate(int x)
{
    int y = tr[x].fa, z = tr[y].fa, k = (tr[y].s[1] == x), w = (tr[z].s[1] == y), son = tr[x].s[k^1];
    tr[y].s[k] = son;tr[son].fa = y;
    tr[x].s[k^1] = y;tr[y].fa = x;
    tr[z].s[w] = x;tr[x].fa = z;
    update(y);update(x);
}

void Splay(int x, int to, int id)
{
    int y, z;
    while(tr[x].fa != to)
    {
        y = tr[x].fa;z = tr[y].fa;
        if(z != to)
            Rotate((tr[y].s[0] == x) ^ (tr[z].s[0] == y)? x: y);
        Rotate(x);
    }
    if(!to)
        root[id] = x;
}

void Insert(int x, int v)
{
    int now = root[x], ff = 0;
    while(now)
    {
        ff = now;
        tr[now].siz ++;
        if(tr[now].val == v)    break;
        now = tr[now].s[v>tr[now].val];
    }
    if(now)
        tr[now].num ++;
    else
    {
        if(snum)
            now = s[snum--];
        else
            now = ++cnt;
        tr[now].val = v;
        tr[ff].s[v>tr[ff].val] = now;
        tr[now].fa = ff;
        tr[now].num = 1;
        tr[now].siz = 1;
        tr[now].s[0] = tr[now].s[1] = 0;
    }
    Splay(now, 0, x);
}

void build(int &x, int l, int r)
{
    x = ++tot;
    root[x] = ++cnt;
    tr[root[x]].val = -inf;
    tr[root[x]].siz = tr[root[x]].num = 1;
    tr[0].s[1] = root[x];
    Insert(x, inf);
    for(int i = l; i <= r; ++i)
        Insert(x, a[i]);
    if(l == r)    return ;
    int mid = (l + r)>>1;
    build(lson[x], l, mid);
    build(rson[x], mid + 1, r); 
}

void Extract(int x, int l, int r, int L, int R)
{
    if(l <= L && R <= r)
    {
        stak[++top] = x;
        return ;
    }
    int mid = (L + R)>>1;
    if(l <= mid)
        Extract(lson[x], l, r, L, mid);
    if(mid < r)
        Extract(rson[x], l, r, mid + 1, R);
}

int Queryrk(int id, int x)
{
    int now = root[id], ret = 0;
    while(now)
    {
        int ls = tr[now].s[0], rs = tr[now].s[1];
        if(x < tr[now].val)
        {
            if(ls)
                now = ls;
            else
                break;
        }    
        else if(x == tr[now].val)
        {
            ret += tr[ls].siz ;
            break;
        }
        else
        {
            ret += tr[ls].siz + tr[now].num;
            if(rs)
                now = rs;
            else
                break;
        }
    }
    tr[0].s[1] = root[id];
    Splay(now, 0, id);
    return ret;
}

int work1(int x)
{
    int ret = -top;
    while(top)
    {
        ret += Queryrk(stak[top], x);
        top --;
    }
    return ret + 1;
}

int check(int x)
{
    int ret = 0;
    for(int i = 1; i <= top; ++i)
        ret += Queryrk(stak[i], x);
    return ret + 1;
}

int work2(int x)
{
    x += top;
    int l = 0, r = 1e8, mid, ret = 0;
    while(l <= r)
    {
        mid = (l + r)>>1;
        if(check(mid) <= x)
            ret = mid, l = mid + 1;
        else
            r = mid - 1;
    }
    top = 0;
    return ret;
}

int Find(int now, int x)
{
    while(1)
    {
        int ls = tr[now].s[0], rs = tr[now].s[1];
        if(tr[now].val == x)    return now;
        if(x < tr[now].val)    now = ls;
        else    now = rs;
    }
}

int Querypre(int now)
{
    now = tr[now].s[0];
    while(tr[now].s[1])    now = tr[now].s[1];
    return now;
}

int Querynxt(int now)
{
    now = tr[now].s[1];
    while(tr[now].s[0])    now = tr[now].s[0];
    return now;
}

void Modify(int x, int pos, int l, int r, int v)
{
    tr[0].s[1] = root[x];
    int y = Find(root[x], a[pos]);
    Splay(y, 0, x);
    int pre = Querypre(root[x]);
    int nxt = Querynxt(root[x]);
    Splay(pre, 0, x);Splay(nxt, pre, x);
    if(tr[y].num > 1)
    {
        tr[y].num --;
        tr[y].siz --;
    }
    else
    {
        s[++snum] = y;
        tr[nxt].s[0] = 0;
    }
    update(nxt);update(pre);
    Insert(x, v);
    if(l == r)
    {
        a[pos] = v;
        return ;
    }
    int mid = (l + r)>>1;
    if(pos <= mid)
        Modify(lson[x], pos, l, mid, v);
    else
        Modify(rson[x], pos, mid + 1, r, v);
}

int Findmx(int now, int x)
{
    int ret = 0;
    while(now)
    {
        int ls = tr[now].s[0], rs = tr[now].s[1];
        if(tr[now].val == x)
            return now;
        if(tr[now].val > x)
        {
            ret = now;
            now = ls;
        }
        else
            now = rs;
    }
    return ret;
}

int Findmn(int now, int x)
{
    int ret = 0;
    while(now)
    {
        int ls = tr[now].s[0], rs = tr[now].s[1];
        if(tr[now].val == x)
            return now;
        if(tr[now].val < x)
        {
            ret = now;
            now = rs;
        }
        else
            now = ls;
    }
    return ret;
}

int work3(int x)
{
    int ret = -inf;
    while(top)
    {
        tr[0].s[1] = root[stak[top]];
        int now = Findmx(root[stak[top]], x);
        Splay(now, 0, stak[top]);
        int pre = Querypre(root[stak[top]]);
        ret = max(ret, tr[pre].val);
        top--;
    }
    return ret;
}

int work4(int x)
{
    int ret = inf;
    while(top)
    {
        tr[0].s[1] = root[stak[top]];
        int now = Findmn(root[stak[top]], x);
        Splay(now, 0, stak[top]);
        int nxt = Querynxt(root[stak[top]]);
        ret = min(ret, tr[nxt].val);
        top--;
    }
    return ret;
}

int main()
{
    read(n);read(m);
    for(int i = 1; i <= n; ++i)
        read(a[i]);
    build(rt, 1, n);
    for(int i = 1; i <= m; ++i)
    {
        int opt, l, r, k, pos;
        read(opt);
        if(opt == 1)
        {
            read(l);read(r);read(k);
            Extract(rt, l, r, 1, n);
            printf("%d\n", work1(k));
        }
        else if(opt == 2)
        {
            read(l);read(r);read(k);
            Extract(rt, l, r, 1, n);
            printf("%d\n", work2(k));
        }
        else if(opt == 3)
        {
            read(pos);read(k);
            Modify(rt, pos, 1, n, k);
        }
        else if(opt == 4)
        {
            read(l);read(r);read(k);
            Extract(rt, l, r, 1, n);
            printf("%d\n", work3(k));
        }
        else
        {
            read(l);read(r);read(k);
            Extract(rt, l, r, 1, n);
            printf("%d\n", work4(k));
        }
    }
    return 0;
}
View Code

猜你喜欢

转载自www.cnblogs.com/Joker-Yza/p/11243378.html