Skip to content

FHQ Treap

约 166 个字 129 行代码 预计阅读时间 2 分钟

P3369 【模板】普通平衡树 - 洛谷
Treap,支持:
1. 向 M 中插入一个数 x。
2. 从 M 中删除一个数 x(若有多个相同的数,应只删除一个)。
3. 查询 M 中有多少个数比 x 小,并且将得到的答案加一。
4. 查询如果将 M 从小到大排列后,排名位于第 x 位的数。
5. 查询 M 中 x 的前驱(前驱定义为小于 x,且最大的数)。
6. 查询 M 中 x 的后继(后继定义为大于 x,且最小的数)。

模板

C++
#include <bits/stdc++.h>
using namespace std;
#define endl '\n'
#define INF 0x3f3f3f3f
#define LL long long
const int MAXN = 1e5 + 5;
const double EPS = 1e-9;
struct node {
    int l, r;
    int val;
    int key;
    int size;
} tr[MAXN];
int root, idx;
void newnode(int &x, int v)
{
    x = ++idx;
    tr[x].val = v;
    tr[x].key = rand();
    tr[x].size = 1;
}
void pushup(int p) { tr[p].size = tr[tr[p].l].size + tr[tr[p].r].size + 1; }
void split(int p, int v, int &x, int &y)
{
    if (!p) {
        x = 0, y = 0;
        return;
    }
    if (tr[p].val <= v) {
        x = p;
        split(tr[x].r, v, tr[x].r, y);
        pushup(x);
    } else {
        y = p;
        split(tr[y].l, v, x, tr[y].l);
        pushup(y);
    }
}
int merge(int x, int y)
{
    if (!x || !y) {
        return x + y;
    }
    if (tr[x].key < tr[y].key) {
        tr[x].r = merge(tr[x].r, y);
        pushup(x);
        return x;
    } else {
        tr[y].l = merge(x, tr[y].l);
        pushup(y);
        return y;
    }
}
void insert(int v)
{
    int x, y, z;
    split(root, v, x, y);
    newnode(z, v);
    root = merge(merge(x, z), y);
}
void del(int v)
{
    int x, y, z;
    split(root, v, x, z);
    split(x, v - 1, x, y);
    y = merge(tr[y].l, tr[y].r);
    root = merge(merge(x, y), z);
}
int getrank(int v)
{
    int x, y;
    split(root, v - 1, x, y);
    int ans = tr[x].size + 1;
    root = merge(x, y);
    return ans;
}
int getval(int root, int v)
{
    if (v == tr[tr[root].l].size + 1) {
        return tr[root].val;
    } else if (v <= tr[tr[root].l].size) {
        return getval(tr[root].l, v);
    } else {
        return getval(tr[root].r, v - tr[tr[root].l].size - 1);
    }
}
int getpre(int v)
{
    int x, y, s, ans;
    split(root, v - 1, x, y);
    s = tr[x].size;
    ans = getval(x, s);
    root = merge(x, y);
    return ans;
}
int getnxt(int v)
{
    int x, y, ans;
    split(root, v, x, y);
    ans = getval(y, 1);
    root = merge(x, y);
    return ans;
}
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    int n;
    cin >> n;
    while (n--) {
        int op, v;
        cin >> op >> v;

        if (op == 1) {
            insert(v);
        } else if (op == 2) {
            del(v);
        } else if (op == 3) {
            cout << getrank(v) << endl;
        } else if (op == 4) {
            cout << getval(root, v) << endl;
        } else if (op == 5) {
            cout << getpre(v) << endl;
        } else {
            cout << getnxt(v) << endl;
        }
    }
    return 0;
}