【模板】树套树(线段树套Splay)

如题，这是一份树套树（线段树套 Splay）的模板代码。

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <cctype>

#define max(x, y) (x > y ? x : y)
#define min(x, y) (x < y ? x : y)

// Reads the next decimal integer from stdin into x.
// Every non-digit character before the number is skipped; if a '-' is seen
// while skipping, the result is negated.
// If the input ends before any digit is found, x is set to 0 and the
// function returns instead of spinning forever.
inline void read(int & x)
{
    x = 0;
    int sign = 1;
    int c = getchar();              // int, so EOF stays distinguishable from data
    while (c != EOF && !isdigit(c))
    {
        if (c == '-') sign = -1;
        c = getchar();
    }
    long long acc = 0;              // wider accumulator: "-2147483648" must not overflow int
    while (isdigit(c))              // isdigit(EOF) is false, so this stops at end of input
    {
        acc = acc * 10 + (c - '0');
        c = getchar();
    }
    x = (int)(sign * acc);
}

// Sentinel value: Suf returns inf when there is no successor; Pre and Getkth
// return -inf when there is no answer.
const int inf = 2147483647;
// Capacity of every array declared below.
const int N = 4001000;
// n, m   : the two integers read first in main (n initial values, m operations).
// tot    : number of Splay nodes allocated so far; node ids start at 1 and
//          0 is used as "no node".
// l, r, x, y, opt : scratch variables filled by main for the current operation.
// z      : declared here but not referenced as a global in the code shown.
// mxa    : largest value stored by main; used as the upper search bound in Segkth.
int n, m, tot, l, r, x, y, z, opt, mxa = -2147483647;
// Per-Splay-node fields, indexed by node id:
//   cnt = multiplicity of the key, faz = parent, val = key,
//   siz = subtree size counting multiplicities, son[u][0/1] = left/right child.
// a[i]    : current value at sequence position i.
// root[k] : root node of the Splay tree that belongs to segment-tree node k.
int cnt[N], faz[N], val[N], siz[N], a[N], son[N][2], root[N];

//=======================================================================
//Splay

// Returns 1 when u is the right child of its parent, 0 otherwise.
inline int Getson(int u)
{
    const int parent = faz[u];
    return son[parent][1] == u ? 1 : 0;
}

// Recomputes siz[u] from both children's sizes plus u's own multiplicity.
inline void Pushup(int u)
{
    const int left = son[u][0];
    const int right = son[u][1];
    siz[u] = siz[left] + siz[right] + cnt[u];
}

// Follows left-child links from u and returns the last node reached.
inline int Getmin(int u)
{
    for (int next = son[u][0]; next != 0; next = son[u][0])
        u = next;
    return u;
}

// Follows right-child links from u and returns the last node reached.
inline int Getmax(int u)
{
    for (int next = son[u][1]; next != 0; next = son[u][1])
        u = next;
    return u;
}

// Searches the Splay tree of segment-tree node rtid for key x.
// Returns the node holding x when it exists; otherwise the last node visited
// on the search path (0 when the tree is empty).
inline int Getx(int rtid, int x)
{
    int last = 0;
    int cur = root[rtid];
    while (cur != 0)
    {
        last = cur;
        if (val[cur] == x) return cur;
        cur = son[cur][x >= val[cur] ? 1 : 0];
    }
    return last;
}

// Rotates u above its parent, preserving in-order key order.
// The subtree on u's inner side moves to the parent, the parent becomes u's
// child, and the grandparent (if any) adopts u in the parent's old slot.
// Sizes are refreshed bottom-up: old parent first, then u.
void Rotate(int u)
{
    const int parent = faz[u];
    const int grand = faz[parent];
    const int side = (son[parent][1] == u) ? 1 : 0;            // which child u is
    const int parentSide = (son[grand][1] == parent) ? 1 : 0;  // which child parent is
    const int inner = son[u][side ^ 1];

    son[parent][side] = inner;
    faz[inner] = parent;
    son[u][side ^ 1] = parent;
    faz[parent] = u;
    faz[u] = grand;
    if (grand != 0) son[grand][parentSide] = u;

    Pushup(parent);
    Pushup(u);
}

// Rotates u upward until its parent is tar.
// When tar is 0, u ends up as the root of the tree owned by segment-tree
// node rtid and root[rtid] is updated accordingly.
void Splay(int rtid, int u, int tar)
{
    for (int parent = faz[u]; parent != tar; parent = faz[u])
    {
        if (faz[parent] != tar)
        {
            // Same side as the parent: rotate the parent first.
            // Opposite sides: rotate u itself (twice in total).
            const bool sameSide = (Getson(u) == Getson(parent));
            Rotate(sameSide ? parent : u);
        }
        Rotate(u);
    }
    if (tar == 0) root[rtid] = u;
}

// Inserts one occurrence of x into the Splay tree of segment-tree node rtid.
// An existing key only gets its multiplicity bumped; otherwise a new leaf is
// allocated. Sizes along the search path are incremented on the way down and
// the touched node is splayed to the root afterwards.
inline void Insert(int rtid, int x)
{
    if (root[rtid] == 0)
    {
        // Empty tree: the new node becomes the root, no splay needed.
        const int fresh = ++tot;
        root[rtid] = fresh;
        val[fresh] = x;
        cnt[fresh] = 1;
        siz[fresh] = 1;
        faz[fresh] = 0;
        son[fresh][0] = 0;
        son[fresh][1] = 0;
        return;
    }

    int cur = root[rtid];
    for (;;)
    {
        ++siz[cur];
        if (val[cur] == x)
        {
            ++cnt[cur];
            break;
        }
        const int dir = x > val[cur] ? 1 : 0;
        const int child = son[cur][dir];
        if (child == 0)
        {
            const int fresh = ++tot;
            val[fresh] = x;
            faz[fresh] = cur;
            son[cur][dir] = fresh;
            son[fresh][0] = 0;
            son[fresh][1] = 0;
            cnt[fresh] = 1;
            siz[fresh] = 1;
            cur = fresh;
            break;
        }
        cur = child;
    }
    Splay(rtid, cur, 0);
}

inline void Delete(int rtid, int x)
{
    int u = root[rtid];
    while (u)
    if (x == val[u]) 
    {
        Splay(rtid, u, 0);
        if (cnt[u] > 1) { --cnt[u], --siz[u]; return; }
        if (!son[u][0] || !son[u][1]) 
        {
            root[rtid] = son[u][0] | son[u][1];
            faz[root[rtid]] = 0;
            return;
        }
        int newrt = Getmin(son[u][1]);
        faz[son[u][1]] = 0,
        faz[son[u][0]] = newrt,
        son[newrt][0] = son[u][0];
        Splay(rtid, newrt, 0);
        return;
    }
    else if (x > val[u]) u = son[u][1];
    else u = son[u][0];
}

// Returns the k-th smallest key (1-based, counting multiplicities) in the
// Splay tree of segment-tree node rtid, or -inf when k is out of range.
inline int Getkth(int rtid, int k)
{
    int u = root[rtid];
    if (k < 1 || siz[u] < k) return -inf;
    while (u)
    {
        const int leftSize = siz[son[u][0]];
        if (leftSize >= k) u = son[u][0];
        else if (leftSize + cnt[u] < k) k -= leftSize + cnt[u], u = son[u][1];
        else return val[u];
    }
    // Reached only if the stored sizes are inconsistent; never fall off the
    // end of a non-void function.
    return -inf;
}

// Returns how many stored elements (counting multiplicities) in the Splay
// tree of segment-tree node rtid are strictly smaller than x. x itself does
// not have to be present.
// An earlier variant (Getx, Splay to root, return the left subtree size) was
// labelled "Wrong Answer" by the author and is not used.
inline int Getrank(int rtid, int x)
{
    int smaller = 0;
    int cur = root[rtid];
    while (cur != 0)
    {
        const int leftSize = siz[son[cur][0]];
        if (x == val[cur]) return smaller + leftSize;
        if (x > val[cur])
        {
            // cur and its whole left subtree are below x.
            smaller += leftSize + cnt[cur];
            cur = son[cur][1];
        }
        else
        {
            cur = son[cur][0];
        }
    }
    return smaller;
}

// Returns the largest key strictly smaller than x in the Splay tree of
// segment-tree node rtid, or -inf when no such key exists (including the
// case of an empty tree).
int Pre(int rtid, int x)
{
    int u = Getx(rtid, x);
    if (u == 0) return -inf;              // empty tree: val[0] is not a real key
    // Getx stopped on x itself or on the last node of the search path; if
    // that node's key is below x it is returned directly.
    if (val[u] < x) return val[u];
    // Otherwise make u the root; the answer is the maximum of its left subtree.
    Splay(rtid, u, 0);
    return son[u][0] == 0 ? -inf : val[Getmax(son[u][0])];
}

// Returns the smallest key strictly greater than x in the Splay tree of
// segment-tree node rtid, or inf when no such key exists (including the
// case of an empty tree).
int Suf(int rtid, int x)
{
    int u = Getx(rtid, x);
    if (u == 0) return inf;               // empty tree: val[0] is not a real key
    // Getx stopped on x itself or on the last node of the search path; if
    // that node's key is above x it is returned directly.
    if (val[u] > x) return val[u];
    // Otherwise make u the root; the answer is the minimum of its right subtree.
    Splay(rtid, u, 0);
    return son[u][1] == 0 ? inf : val[Getmin(son[u][1])];
}

//Splay End
//=======================================================================
//Segment tree Begin

// Shorthands for the segment-tree recursion. They expand textually and rely
// on variables named u, l, r (L, R for MID) and n being visible at the point
// of use.
//   Root        : argument triple for the whole range [1, n] at node 1
//   Lson / Rson : argument triples for the left / right child of node u
//   ls / rs     : ids of the left / right child of node u
//   mid / MID   : midpoint of [l, r] / [L, R]
#define Root 1, 1, n
#define Lson (u << 1), l, mid
#define Rson ((u << 1) | 1), (mid + 1), r
#define ls (u << 1)
#define rs ((u << 1) | 1)
#define mid ((l + r) >> 1)
#define MID ((L + R) >> 1)

// Inserts value x for position p into the Splay tree of every segment-tree
// node on the path from node u (covering [l, r]) down to the leaf of p.
inline void Segins(int u, int l, int r, int p, int x)
{
    Insert(u, x);
    if (l == r) return;
    const int half = (l + r) >> 1;
    if (p > half) Segins(u << 1 | 1, half + 1, r, p, x);
    else Segins(u << 1, l, half, p, x);
}

// Replaces the value at position p with x: on every segment-tree node from u
// (covering [l, r]) down to the leaf of p, the old value a[p] is removed and
// x is inserted. a[p] is overwritten only at the leaf, after all ancestors
// have used the old value.
inline void Segmdf(int u, int l, int r, int p, int x)
{
    Delete(u, a[p]);
    Insert(u, x);
    if (l == r)
    {
        a[p] = x;
        return;
    }
    const int half = (l + r) >> 1;
    if (p <= half) Segmdf(u << 1, l, half, p, x);
    else Segmdf(u << 1 | 1, half + 1, r, p, x);
}

// Counts the elements at positions [x, y] that are strictly smaller than z.
// u is the current segment-tree node and [l, r] the range it covers; the
// query range is split until it matches whole nodes, whose counts are summed.
inline int Segrak(int u, int l, int r, int x, int y, int z)
{
    if (l == x && r == y) return Getrank(u, z);
    const int half = (l + r) >> 1;
    if (y <= half) return Segrak(u << 1, l, half, x, y, z);
    if (x > half) return Segrak(u << 1 | 1, half + 1, r, x, y, z);
    const int fromLeft = Segrak(u << 1, l, half, x, half, z);
    const int fromRight = Segrak(u << 1 | 1, half + 1, r, half + 1, y, z);
    return fromLeft + fromRight;
}

// Returns the largest value strictly smaller than z among positions [x, y],
// or -inf when there is none. u is the current segment-tree node and [l, r]
// the range it covers.
inline int Segpre(int u, int l, int r, int x, int y, int z)
{
    if (l == x && r == y) return Pre(u, z);
    const int half = (l + r) >> 1;
    if (y <= half) return Segpre(u << 1, l, half, x, y, z);
    if (x > half) return Segpre(u << 1 | 1, half + 1, r, x, y, z);
    // Evaluate each half exactly once. Feeding the recursive calls straight
    // into a function-like max macro would expand the winning call twice,
    // repeating its work (and its Splay side effects) at every split level.
    const int fromLeft = Segpre(u << 1, l, half, x, half, z);
    const int fromRight = Segpre(u << 1 | 1, half + 1, r, half + 1, y, z);
    return fromLeft > fromRight ? fromLeft : fromRight;
}

// Returns the smallest value strictly greater than z among positions [x, y],
// or inf when there is none. u is the current segment-tree node and [l, r]
// the range it covers.
inline int Segsuf(int u, int l, int r, int x, int y, int z)
{
    if (l == x && r == y) return Suf(u, z);
    const int half = (l + r) >> 1;
    if (y <= half) return Segsuf(u << 1, l, half, x, y, z);
    if (x > half) return Segsuf(u << 1 | 1, half + 1, r, x, y, z);
    // Evaluate each half exactly once. Feeding the recursive calls straight
    // into a function-like min macro would expand the winning call twice,
    // repeating its work (and its Splay side effects) at every split level.
    const int fromLeft = Segsuf(u << 1, l, half, x, half, z);
    const int fromRight = Segsuf(u << 1 | 1, half + 1, r, half + 1, y, z);
    return fromLeft < fromRight ? fromLeft : fromRight;
}

inline int Segkth(int l, int r, int k)
{
    int cur, L = 0, R = mxa + 1;
    while (L < R)
    {
        cur = Segrak(Root, l, r, MID);
        if (cur < k) L = MID + 1;
        else R = MID;
    }
    return L - 1;
}

// Reads n initial values and m operations from stdin and answers them:
//   1 l r x : rank of x within positions [l, r] (elements below x, plus one)
//   2 l r k : k-th smallest value within positions [l, r]
//   3 p v   : set position p to value v
//   4 l r x : predecessor of x within positions [l, r]
//   5 l r x : successor of x within positions [l, r]
signed main()
{
    read(n), read(m);
    for (int i = 1; i <= n; ++i)
    {
        read(a[i]);
        Segins(Root, i, a[i]);
        if (a[i] > mxa) mxa = a[i];
    }
    for (int i = 1; i <= m; ++i)
    {
        read(opt);
        if (opt == 3)
        {
            read(x), read(y);
            Segmdf(Root, x, y);
            // Segkth searches values up to mxa + 1, so a newly written value
            // above the old maximum must raise the bound.
            if (y > mxa) mxa = y;
        }
        else
        {
            read(l), read(r), read(x);
            if (opt == 1) printf("%d\n", Segrak(Root, l, r, x) + 1);
            if (opt == 2) printf("%d\n", Segkth(l, r, x));
            if (opt == 4) printf("%d\n", Segpre(Root, l, r, x));
            if (opt == 5) printf("%d\n", Segsuf(Root, l, r, x));
        }
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/yanyiming10243247/p/10057812.html