P3380 【模板】二逼平衡树(树套树) 线段树套平衡树

\(\color{#0066ff}{ 题目描述 }\)

您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:

  1. 查询k在区间内的排名
  2. 查询区间内排名为k的值
  3. 修改某一位值上的数值
  4. 查询k在区间内的前驱(前驱定义为严格小于x,且最大的数,若不存在输出-2147483647)
  5. 查询k在区间内的后继(后继定义为严格大于x,且最小的数,若不存在输出2147483647)

\(\color{#0066ff}{输入格式}\)

第一行两个数 n,m 表示长度为n的有序序列和m个操作

第二行有n个数,表示有序序列

下面有m行,opt表示操作标号

若opt=1 则为操作1,之后有三个数l,r,k 表示查询k在区间[l,r]的排名

若opt=2 则为操作2,之后有三个数l,r,k 表示查询区间[l,r]内排名为k的数

若opt=3 则为操作3,之后有两个数pos,k 表示将pos位置的数修改为k

若opt=4 则为操作4,之后有三个数l,r,k 表示查询区间[l,r]内k的前驱

若opt=5 则为操作5,之后有三个数l,r,k 表示查询区间[l,r]内k的后继

\(\color{#0066ff}{输出格式}\)

对于操作1,2,4,5各输出一行,表示查询结果

\(\color{#0066ff}{输入样例}\)

9 6
4 2 2 1 9 4 0 1 1
2 1 4 3
3 4 10
2 1 4 3
1 2 5 9
4 3 9 5
5 2 8 5

\(\color{#0066ff}{输出样例}\)

2
4
3
4
9

\(\color{#0066ff}{数据范围与提示}\)

时空限制:2s,128M

\(n,m \leq 5\cdot {10}^4\)保证有序序列所有值在任何时刻满足 \([0, {10} ^8]\)

\(\color{#0066ff}{ 题解 }\)

可以线段树套平衡树

对于操作1,线段树每个区间在平衡树上找比k小的数的个数,加起来再加1就是排名

对于操作2,可以二分答案,然后通过操作1来判断\(O(log^3n)\)

对于操作3,相当于删除再插入,注意线段树整个一条链都要改

对于操作4,5,线段树子区间答案取max和min即可

#include<bits/stdc++.h>
#define LL long long
LL in() {
    char ch; LL x = 0, f = 1;
    while(!isdigit(ch = getchar()))(ch == '-') && (f = -f);
    for(x = ch ^ 48; isdigit(ch = getchar()); x = (x << 1) + (x << 3) + (ch ^ 48));
    return x * f;
}
const int maxn = 5e4 + 10;
const int inf = 0x7fffffff;
struct Splay {
protected:
    struct node {
        node *ch[2], *fa;
        int val, siz;
        node(node *fa = NULL, int val = 0, int siz = 0): fa(fa), val(val), siz(siz) { ch[0] = ch[1] = NULL; }
        void upd() { siz = (ch[0]? ch[0]->siz : 0) + (ch[1]? ch[1]->siz : 0) + 1; }
        bool isr() { return this == fa->ch[1]; }
        int rk() { return ch[0]? ch[0]->siz + 1 : 1; }
    }*root;
    void rot(node *x) {
        node *y = x->fa, *z = y->fa;
        bool k = x->isr(); node *w = x->ch[!k];
        if(y != root) z->ch[y->isr()] = x;
        else root = x;
        x->ch[!k] = y, y->ch[k] = w;
        y->fa = x, x->fa = z;
        if(w) w->fa = y;
        y->upd(), x->upd();
    }
    void splay(node *o) {
        while(o != root) {
            if(o->fa != root) rot(o->isr() ^ o->fa->isr()? o : o->fa);
            rot(o);
        }
    }
    node *merge(node *x, node *y, node *fa) {
        if(x) x->fa = fa;
        if(y) y->fa = fa;
        if(!x || !y) return x? x : y;
        if(rand() & 1) return x->ch[1] = merge(x->ch[1], y, x), x->upd(), x;
        else return y->ch[0] = merge(x, y->ch[0], y), y->upd(), y;
    }
public:
    int rnk(int val) {
        node *o = root, *lst = root; int rank = 0;
        while(o) {
            lst = o;
            if(val > o->val) rank += o->rk(), o = o->ch[1];
            else o = o->ch[0];
        }
        return splay(lst), rank;
    }
    int kth(int k) {
        node *o = root;
        while(o->rk() != k) {
            if(k > o->rk()) k -= o->rk(), o = o->ch[1];
            else o = o->ch[0];
        }
        return splay(o), o->val;
    }
    int pre(int val) {
        node *o = root, *lst = root;
        while(o) {
            if(o->val < val) lst = o, o = o->ch[1];
            else o = o->ch[0];
        }
        return splay(lst), lst->val;
    }
    int nxt(int val) {
        node *o = root, *lst = root;
        while(o) {
            if(o->val > val) lst = o, o = o->ch[0];
            else o = o->ch[1];
        }
        return splay(lst), lst->val;
    }
    void ins(int val) {
        if(!root) return (void)(root = new node(NULL, val, 1));
        node *o = root, *fa = NULL;
        while(o) fa = o, o = o->ch[val > o->val];
        fa->ch[val > fa->val] = o = new node(fa, val, 1);
        splay(o);
    }
    void del(int val) {
        node *o = root;
        while(o->val != val) o = o->ch[val > o->val];
        if(!o) return;
        splay(o);
        root = merge(o->ch[0], o->ch[1], NULL);
        delete o;
    }
};
struct SGT {
private:
    struct node {
        int l, r;
        node *ch[2];
        Splay *s;
        node(int l = 0, int r = 0, Splay *s = NULL): l(l), r(r), s(s) { ch[0] = ch[1] = NULL; }
    }*root;
    void build(node *&o, int l, int r, int *a) {
        o = new node(l, r, new Splay());
        for(int i = l; i <= r; i++) o->s->ins(a[i]);
        if(l == r) return;
        int mid = (l + r) >> 1;
        build(o->ch[0], l, mid, a), build(o->ch[1], mid + 1, r, a);
    }
    int rnk(node *o, int l, int r, int val) {
        if(o->r < l || o->l > r) return 0;
        if(l <= o->l && o->r <= r) return o->s->rnk(val);
        return rnk(o->ch[0], l, r, val) + rnk(o->ch[1], l, r, val);
    }
    int pre(node *o, int l, int r, int val) {
        if(o->r < l || o->l > r) return inf;
        if(l <= o->l && o->r <= r) return o->s->pre(val);
        int ans = -inf;
        int L = pre(o->ch[0], l, r, val);
        int R = pre(o->ch[1], l, r, val);
        if(L < val) ans = std::max(ans, L);
        if(R < val) ans = std::max(ans, R);
        return ans;
    }
    int nxt(node *o, int l, int r, int val) {
        if(o->r < l || o->l > r) return -inf;
        if(l <= o->l && o->r <= r) return o->s->nxt(val);
        int ans = inf;
        int L = nxt(o->ch[0], l, r, val);
        int R = nxt(o->ch[1], l, r, val);
        if(L > val) ans = std::min(ans, L);
        if(R > val) ans = std::min(ans, R);
        return ans;
    }
    void change(node *o, int pos, int val, int old) {
        if(o->r < pos || o->l > pos) return;
        o->s->del(old);
        o->s->ins(val);
        if(o->l == o->r) return;
        change(o->ch[0], pos, val, old);
        change(o->ch[1], pos, val, old);
    }
public:
    void build(int *a, int l, int r) { build(root, l, r, a); }
    int rnk(int val, int l, int r) { return rnk(root, l, r, val) + 1; }
    int kth(int k, int L, int R) {
        int l = 0, r = 1e8, ans = 0;
        while(l <= r) {
            int mid = (l + r) >> 1;
            if(rnk(mid, L, R) <= k) ans = mid, l = mid + 1;
            else r = mid - 1;
        }
        return ans;
    }
    void change(int pos, int old, int now) { change(root, pos, now, old); }
    int pre(int val, int l, int r) { return pre(root, l, r, val); }
    int nxt(int val, int l, int r) { return nxt(root, l, r, val); }
}v;
int a[maxn];
int main() {
    int p, l, r, k, n = in(), m = in();
    for(int i = 1; i <= n; i++) a[i] = in();
    v.build(a, 1, n);
    while(m --> 0) {
        p = in();
        if(p == 1) l = in(), r = in(), k = in(), printf("%d\n", v.rnk(k, l, r));
        if(p == 2) l = in(), r = in(), k = in(), printf("%d\n", v.kth(k, l, r));
        if(p == 3) l = in(), k = in(), v.change(l, a[l], k), a[l] = k;
        if(p == 4) l = in(), r = in(), k = in(), printf("%d\n", v.pre(k, l, r));
        if(p == 5) l = in(), r = in(), k = in(), printf("%d\n", v.nxt(k, l, r));
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/olinr/p/10333100.html