AH/HNOI2017 单旋 (splay)

AH/HNOI2017 单旋 (splay)

前言

为什么说这道题叫做\(\text{splay}​\)就不能用\(\text{splay}​\)做?

不就是一个加一减一序列嘛干嘛用线段树\(+\text{set}\)做多烦啊

本人是一个\(\text{STL}\)的憎恨者,所以想出了一个可以使用\(\text {splay}\)实现的方法。

不过必须提一句……卡带着常数毁灭\(\text H\)国不是成功了吗

题意

要模拟一个只通过把\(x\)节点转到\(x\)节点的父亲的\(\text {splay}\)操作,并支持删根

怎么解

有一样东西普通的\(\text{splay}\)操作无法完成,那就是邪恶的深度。

因为普通的\(\text{splay}\)会把深度转乱,所以深度无法维护。

但是有一个很重要的发现:单旋最小值不会改变树的形状

这时候,我们就可以把问题转换一下:

【敲黑板】\(\text{The next part is only for very smart students.}\)

我们可以把深度存储在一个序列之中,维护这个序列。

当我们执行\(\text{splay}\)操作的时候,最小值原来的右子树成为了他父亲的左子树,相应的最小值的深度不变,其它结点深度均\(+1\)

我们需要维护以下操作:

  1. 查询
  2. 区间加一
  3. 前驱后继
  4. 插入

那么,喜欢线段树的同志们请注意了,你们维护\(3\)\(4\)是很困难的。

\(\text{splay}\)闪亮登场。

上代码

本人代码风格极差,望谅解。(太匆忙了)。

喜欢线段树的同志们自己实现一下。

\(\text{splay}\)的代码:

//By Zhengjiarui, Copyright @2019, All rights preserved.
//Do not copy this code
//AHOI/HNOI2017 splay
//solution: use splay to stimulate the spaly tree
#include <bits/stdc++.h>
#define ls c[x][0]
#define rs c[x][1]
#define rep(i, l, r) for (int i = l; i <= r; i++)
using namespace std;

template <typename T>
inline void rd(T &x) {
    int t;
    char ch;
    for (t = 0; !isdigit(ch = getchar()); t = (ch == '-'));
    for (x = ch - '0'; isdigit(ch = getchar()); x = x * 10 + ch - '0');
    if (t)
        x = -x;
}

const int inf = 2000000000, N = 100100;
int n, cnt, rt, Q, fa[N], v[N], sz[N], c[N][2], s[N], dep[N], mn[N];

void push(int x) {
    v[ls] += v[x];
    dep[ls] += v[x];
    mn[ls] += v[x];
    v[rs] += v[x];
    dep[rs] += v[x];
    mn[rs] += v[x];
    v[x] = 0;
}

void upd(int x) {
    sz[x] = sz[ls] + sz[rs] + 1;
    mn[x] = dep[x];
    if (ls)
        mn[x] = min(mn[x], mn[ls]);
    if (rs)
        mn[x] = min(mn[x], mn[rs]);
}

void rot(int &rt, int x) {
    int y = fa[x], z = fa[y], w = (c[y][1] == x);
    if (y == rt)
        rt = x;
    else
        c[z][c[z][1] == y] = x;
    fa[x] = z;
    fa[y] = x;
    fa[c[x][w ^ 1]] = y;
    c[y][w] = c[x][w ^ 1];
    c[x][w ^ 1] = y;
    upd(y);
}

void splay(int &rt, int x) {
    while (x != rt) {
        int y = fa[x], z = fa[y];
        if (y != rt) {
            if ((c[z][1] == y) ^ (c[y][1] == x))
                rot(rt, x);
            else
                rot(rt, y);
        }
        rot(rt, x);
    }
    upd(x);
}

void ins(int &x, int S, int d, int lst) {
    if (!x) {
        x = ++cnt, s[cnt] = S;
        dep[cnt] = mn[cnt] = d;
        sz[cnt] = 1;
        fa[cnt] = lst;
        return;
    }
    ins(c[x][S > s[x]], S, d, x);
    upd(x);
}

int getpre(int x, int S) {
    if (!x)
        return 0;
    if (v[x])
        push(x);
    if (s[x] > S)
        return getpre(c[x][0], S);
    Q = getpre(c[x][1], S);
    if (Q)
        return Q;
    else
        return x;
}

int getnxt(int x, int S) {
    if (!x)
        return 0;
    if (v[x])
        push(x);
    if (s[x] < S)
        return getnxt(rs, S);
    Q = getnxt(ls, S);
    if (Q)
        return Q;
    else
        return x;
}

int find(int x, int k) {
    if (v[x])
        push(x);
    if (sz[ls] + 1 == k)
        return x;
    if (sz[ls] + 1 < k)
        return find(rs, k - sz[ls] - 1);
    return find(ls, k);
}

int getl(int x, int d) {
    if (!x)
        return 0;
    if (v[x])
        push(x);
    if (min(mn[ls], dep[x]) >= d)
        return getl(rs, d) + sz[ls] + 1;
    else
        return getl(ls, d);
}

int getr(int x, int d) {
    if (!x)
        return 0;
    if (v[x])
        push(x);
    if (min(mn[rs], dep[x]) >= d)
        return getr(ls, d) + sz[rs] + 1;
    else
        return getr(rs, d);
}

int split(int l, int r) {
    int t1 = find(rt, l - 1), t2 = find(rt, r + 1);
    splay(rt, t1);
    splay(c[rt][1], t2);
    return c[c[rt][1]][0];
}

void mdf(int l, int r, int ad) {
    int y = split(l, r);
    v[y] += ad;
    mn[y] += ad;
    dep[y] += ad;
}

void change(int x, int S) {
    if (v[x])
        push(x);
    if (s[x] == S)
        dep[x] = 1;
    else
        change(c[x][S > s[x]], S);
    upd(x);
}

int main() {
    freopen("splay.in","r",stdin);
    freopen("splay.out","w",stdout);
    rd(n);
    ins(rt, -inf, inf, 0);
    ins(rt, inf, inf, 0);
    mn[0] = inf;
    rep(i, 1, n) {
        int op, x;
        rd(op);
        if (op == 1) {
            rd(x);
            int t1 = getpre(rt, x), t2 = getnxt(rt, x);
            int D = max(t1 > 2 ? dep[t1] : 0, t2 > 2 ? dep[t2] : 0) + 1;
            ins(rt, x, D, 0);
            splay(rt, cnt);
            printf("%d\n", D);
        }
        if (!(op & 1)) {
            int x = find(rt, 2), y = min(getl(rt, dep[x]), sz[rt] - 1) - 1;
            printf("%d\n", dep[x]);
            mdf(2, sz[rt] - 1, 1);
            if (y > 1)
                mdf(2, y + 1, -1);
            change(rt, s[x]);
        }
        if ((op & 1) && (op > 1)) {
            int x = find(rt, sz[rt] - 1), y = min(getr(rt, dep[x]), sz[rt] - 1) - 1;
            printf("%d\n", dep[x]);
            mdf(2, sz[rt] - 1, 1);
            if (y > 1)
                mdf(sz[rt] - y, sz[rt] - 1, -1);
            change(rt, s[x]);
        }
        if (op >= 4) {
            if (op == 4)
                splay(rt, find(rt, 2));
            else
                splay(rt, find(rt, sz[rt] - 1));
            int l = (op == 5), r = l ^ 1, y = c[rt][l];
            c[y][r] = c[rt][r];
            fa[y] = 0;
            fa[c[rt][r]] = y;
            rt = y;
            v[rt] -= 1;
            upd(rt);
        }
    }
    return 0;
}

大佬博客里的线段树代码(不是我写的)

#include<bits/stdc++.h>
#define N 200010
using namespace std;
int m,tp,root;
int opt[N],v[N],q[N],ch[N][2],fa[N],dep[N*2];
set<int>st;
int read()
{
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f*=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
void down(int rt)
{
    if(dep[rt])
    {
    dep[rt<<1]+=dep[rt];
    dep[rt<<1|1]+=dep[rt];
    dep[rt]=0;
    }
}
void modify(int rt,int l,int r,int pos,int val)
{
    if(l==r){dep[rt]=val;return ;}
    down(rt);
    int mid=(l+r)>>1;
    if(pos<=mid)modify(rt<<1,l,mid,pos,val);
    else modify(rt<<1|1,mid+1,r,pos,val);
}
int query(int rt,int l,int r,int pos)
{
    if(l==r)return dep[rt];
    down(rt);
    int mid=(l+r)>>1;
    if(pos<=mid)return query(rt<<1,l,mid,pos);
    else return query(rt<<1|1,mid+1,r,pos);
}
void update(int rt,int l,int r,int L,int R,int k)
{
    if(L<=l&&R>=r){dep[rt]+=k;return ;}
    down(rt);
    int mid=(l+r)>>1;
    if(L<=mid)update(rt<<1,l,mid,L,R,k);
    if(R>mid)update(rt<<1|1,mid+1,r,L,R,k);
}
int insert(int x)
{
    set<int>::iterator it=st.insert(x).first;//定义前向迭代器,下标从插入元素开始
    if(!root){root=x;modify(1,1,tp,x,1);return 1;}//空树,插入结点深度为1
    if(it!=st.begin())//如果插入元素有前驱
    {
    if(!ch[*--it][1])ch[fa[x]=*it][1]=x;//如果前驱没有右儿子
    it++;
    }
    if(!fa[x])ch[fa[x]=*++it][0]=x;//要么成为后继的左儿子
    int deep=query(1,1,tp,fa[x])+1;
    modify(1,1,tp,x,deep);
    return deep;
}
int findmax()
{
    int x=*st.rbegin(),res=query(1,1,tp,x);
    if(x==root)return 1;
    if(x-1>=fa[x]+1)update(1,1,tp,fa[x]+1,x-1,-1);//x右树的深度都不变,所以先减1
    update(1,1,tp,1,tp,1);
    ch[fa[x]][1]=ch[x][0];
    fa[ch[x][0]]=fa[x];
    ch[fa[root]=x][0]=root;
    root=x;
    modify(1,1,tp,x,1);
    return res;
}
int findmin()
{
    int x=*st.begin(),res=query(1,1,tp,x);//找到最小值并询问深度
    if(x==root)return 1;
    if(x+1<=fa[x]-1)//x有右子数
    update(1,1,tp,x+1,fa[x]-1,-1);//x右树的深度都不变,所以先减1
    update(1,1,tp,1,tp,1);
    ch[fa[x]][0]=ch[x][1];
    fa[ch[x][1]]=fa[x];
    ch[fa[root]=x][1]=root;
    root=x;
    modify(1,1,tp,x,1);
    return res;
}
void delmax()
{
    printf("%d\n",findmax());
    update(1,1,tp,1,tp,-1);
    st.erase(root);
    root=ch[root][0];
    fa[root]=0;
}
void delmin()
{
    printf("%d\n",findmin());
    update(1,1,tp,1,tp,-1);
    st.erase(root);
    root=ch[root][1];
    fa[root]=0;
}
int main()
{
    m=read();
    for(int i=1;i<=m;i++)
    {
    opt[i]=read();
    if(opt[i]==1)q[++tp]=v[i]=read();
    }
    sort(q+1,q+1+tp);
    for(int i=1;i<=m;i++)
    if(opt[i]==1)v[i]=lower_bound(q+1,q+1+tp,v[i])-q;//将插入的值离散化
    for(int i=1;i<=m;i++)
    {
    if(opt[i]==1){printf("%d\n",insert(v[i]));}
    else if(opt[i]==2)printf("%d\n",findmin());
    else if(opt[i]==3)printf("%d\n",findmax());
    else if(opt[i]==4)delmin();
    else if(opt[i]==5)delmax();
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/JerryZheng2005/p/10527250.html