版权声明:本文为博主原创文章,喜欢就点个赞吧 https://blog.csdn.net/Anxdada/article/details/81866464
题意:就是一排n个数字, 每个数字代表从该位置出发能弹到的后面第几个位置, 每次询问一个开始位置, 问弹出这个序列至少弹多少次.
思路: 也有分块做法, 这里讲讲LCT, 我们建立一个虚拟点n+1, 如果弹出了这个序列, 我们就向这个虚拟点连边, 否则就向它能弹到的点连边, 所以每次询问我们直接抽出询问点到n+1点的路径, 然后查询n+1的维护的size 即可, size维护的是该点的子树个数. 实际上-1就是ans.
AC Code
const int maxn = 2e5 + 5;
int a[maxn];
struct node {
int fa, son[2], sz, val, lazy;
void init() {
sz = 1; lazy = 0;
fa = son[0] = son[1] = 0;
}
}t[maxn];
bool nroot(int x) {
return t[t[x].fa].son[0] == x || t[t[x].fa].son[1] == x;
}
void pushup(int x) {
t[x].sz = 1;
int l = t[x].son[0], r = t[x].son[1];
if(l) t[x].sz += t[l].sz;
if(r) t[x].sz += t[r].sz;
}
void pushdown(int x) {
if(t[x].lazy) {
t[x].lazy = 0;
swap(t[x].son[0], t[x].son[1]);
t[t[x].son[0]].lazy ^= 1;
t[t[x].son[1]].lazy ^= 1;
}
}
void rot(int x) {
int fa = t[x].fa, gfa = t[t[x].fa].fa;
int k = (x == t[fa].son[1]);
t[fa].son[k] = t[x].son[k^1];
if(nroot(fa)) t[gfa].son[fa == t[gfa].son[1]] = x;
if(t[x].son[k^1]) t[t[x].son[k^1]].fa = fa;
t[x].son[k^1] = fa;
t[fa].fa = x; t[x].fa = gfa;
pushup(fa);
}
int stk[maxn], top;
void splay(int x) {
top = 0; stk[++top] = x;
for(int i = x ; nroot(i) ; i = t[i].fa) {
stk[++top] = t[i].fa;
}
while(top) pushdown(stk[top--]);
for(int fa ; nroot(x) ; rot(x)) {
if(nroot(fa = t[x].fa))
rot((t[x].son[0] == x) ^ (t[fa].son[0] == fa) ? fa : x);
}
pushup(x);
}
int access(int x) {
int i;
for(i = 0 ; x ; x = t[i = x].fa) {
splay(x); t[x].son[1] = i;
if (i) t[i].fa = x; pushup(x);
}
return i;
}
int lca(int x, int y) {
access(x); return access(y);
}
void makeroot(int x){
access(x); splay(x); t[x].lazy ^= 1;
}
int findroot(int x) {
access(x); splay(x);
while(t[x].son[0]) pushdown(x), x = t[x].son[0];
return x;
}
void split(int x, int y) {
makeroot(x); access(y); splay(y);
}
void link(int x, int y) {
makeroot(x);
if(findroot(y) != x) t[x].fa = y;
}
void cut(int x, int y) {
makeroot(x);
if(findroot(y) == x && t[x].fa == y && !t[x].son[1]){
t[x].fa = t[y].son[0] = 0;
pushup(y);
}
}
int query(int x, int y) {
split(x, y);
return t[y].sz;
}
void solve() {
int n; scanf("%d", &n); t[n+1].init();
for (int i = 1 ; i <= n ; i ++) {
scanf("%d", a+i);
t[i].init();
}
for (int i = 1 ; i <= n ; i ++) {
if (i + a[i] <= n) link(i, i+a[i]);
else link(i, n+1);
}
int q; scanf("%d", &q);
while(q--) {
int op, x, y;
scanf("%d%d", &op, &x); ++x;
if (op == 1) {
printf("%d\n", query(x, n+1)-1);
}
else {
scanf("%d", &y);
cut(x, x + a[x] <= n ? x + a[x] : n + 1);
link(x, x + y <= n ? x + y : n + 1);
a[x] = y;
}
}
}