[CF1083D]The Fair Nut’s getting crazy[单调栈+线段树]

题意

给定一个长度为 \(n\) 的序列 \(\{a_i\}\)。你需要从该序列中选出两个非空的子段,这两个子段满足: - 两个子段非包含关系。 - 两个子段存在交。 - 位于两个子段交中的元素在每个子段中只能出现一次。 求共有多少种不同的子段选择方案。输出总方案数对 \(10^9 + 7\) 取模后的结果。 需要注意的是,选择子段 \([a, b]\)\([c, d]\) 与选择子段 \([c, d]\)\([a, b]\) 被视为是相同的两种方案。 \(1 \leq n \leq 10^5, -10^9 \leq a_i \leq 10^9\)

分析

  • 考虑枚举一个区间 \([b,c]\) 作为交,记录 \(L_i,R_i\) 表示距离 \(i\) 最近的和 \(i\) 颜色相同的位置。
  • 有: \(a\in[\max\limits_{i=b}^c{L_i},b),d\in(c,\min\limits_{i=b}^c{R_i}]\)
  • 记录可以取到的左端点的最小值(满足区间中不存在两个相同的数) \(pos\)\(mi, mx\) 分别表示 \([j,i]\)\(R\) 的极小值和 \(L\) 的极大值。
  • 考虑从左到右枚举交区间的右端点 \(i\) ,用单调栈维护每个位置的 \(mi, mx\) 。容易得到以 \(i\) 为交区间的右端点的方案数为 \(\sum_{j=pos}^i(mi_j-i)(j-mx_j)​\),拆开然后用线段树分别维护。

  • 总时间复杂度为 \(O(nlogn)\)

代码

#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
#define go(u) for(int i = head[u], v = e[i].to; i; i=e[i].lst, v=e[i].to)
#define rep(i, a, b) for(int i = a; i <= b; ++i)
#define pb push_back
#define re(x) memset(x, 0, sizeof x)
inline int gi() {
    int x = 0,f = 1;
    char ch = getchar();
    while(!isdigit(ch)) { if(ch == '-') f = -1; ch = getchar();}
    while(isdigit(ch)) { x = (x << 3) + (x << 1) + ch - 48; ch = getchar();}
    return x * f;
}
template <typename T> inline void Max(T &a, T b){if(a < b) a = b;}
template <typename T> inline void Min(T &a, T b){if(a > b) a = b;}
const int N = 1e5 + 7, mod = 1e9 + 7;
int n, vc;
LL ans;
int lst[N], L[N], R[N], V[N], a[N];
int st1[N], st2[N], tp1, tp2;
#define Ls o << 1
#define Rs (o << 1 | 1)
LL s1(int n) {
    return 1ll * n * (n + 1) / 2;
}
LL ami[N << 2], amx[N << 2];
struct data {
    LL mi, mx, smi, tm;
    data operator +(const data &rhs) const {
        return (data){ (mi + rhs.mi) % mod, (mx + rhs.mx) % mod, (smi + rhs.smi) % mod, (tm + rhs.tm) % mod};
    }
}t[N << 2];
void add(LL &a, LL b) {
    a += b;if(a >= mod) a -= mod;
}
void stmi(int l, int r, int o, int v) {
    add(ami[o], v);
    add(t[o].tm, 1ll * v * t[o].mx % mod);
    add(t[o].mi, 1ll * (r - l + 1) * v % mod);
    add(t[o].smi, (s1(r) - s1(l - 1)) % mod * v % mod);
}
void stmx(int l, int r, int o, int v) {
    add(amx[o], v);
    add(t[o].tm, 1ll * v * t[o].mi % mod);
    add(t[o].mx, 1ll * (r - l + 1) * v % mod);
}
void pushdown(int l, int r, int o) {
    int mid = l + r >> 1;
    if(ami[o]) {
        stmi(l, mid, Ls, ami[o]);
        stmi(mid + 1, r, Rs, ami[o]);
    }
    if(amx[o]) {
        stmx(l, mid, Ls, amx[o]);
        stmx(mid + 1, r, Rs, amx[o]);
    }
    ami[o] = amx[o] = 0;
}
void pushup(int o) {
    t[o] = t[Ls] + t[Rs];
}
void modify(int L, int R, int l, int r, int o, int v, int opt) {
    if(L <= l && r <= R) {
        if(!opt) stmi(l, r, o, v);
        else stmx(l, r, o, v);
        return;
    }
    pushdown(l, r, o);int mid = l + r >> 1;
    if(L <= mid) modify(L, R, l, mid, Ls, v, opt);
    if(R > mid)  modify(L, R, mid + 1, r, Rs, v, opt);
    pushup(o);
}
data query(int L, int R, int l, int r, int o) {
    if(L <= l && r <= R) return t[o];
    pushdown(l, r, o);int mid = l + r >> 1;
    if(R <= mid) return query(L, R, l, mid, Ls);
    if(L > mid)  return query(L, R, mid + 1, r, Rs);
    return query(L, R, l, mid, Ls) + query(L, R, mid + 1, r, Rs);
}
int main() {
    n = gi();
    rep(i, 1, n) a[i] = gi(), V[i] = a[i];
    sort(V + 1, V + 1 + n);
    vc = unique(V + 1, V + 1 + n) - V - 1;
    rep(i, 1, n) a[i] = lower_bound(V + 1, V + 1 + vc, a[i]) - V;
    rep(i, 1, n) {
        L[i] = lst[a[i]] + 1;
        lst[a[i]] = i;
    }
    rep(i, 1, vc) lst[i] = n + 1;
    for(int i = n; i; --i) {
        R[i] = lst[a[i]] - 1;
        lst[a[i]] = i;
    }
    for(int i = 1, gg = 1; i <= n; ++i) {
        for(; tp1 && L[i] >= L[st1[tp1]]; --tp1) {
            modify(st1[tp1 - 1] + 1, st1[tp1], 1, n, 1, mod - L[st1[tp1]], 1);
        }
        modify(st1[tp1] + 1, i, 1, n, 1, L[i], 1);
        st1[++tp1] = i;
        for(; tp2 && R[i] <= R[st2[tp2]]; --tp2) {
            modify(st2[tp2 - 1] + 1, st2[tp2], 1, n, 1, mod - R[st2[tp2]], 0);
        }
        modify(st2[tp2] + 1, i, 1, n, 1, R[i], 0);
        st2[++tp2] = i;
        
        Max(gg, L[i]);
        data res = query(gg, i, 1, n, 1);
        LL tmp = ((res.smi + i * res.mx % mod - res.tm - (s1(i) - s1(gg - 1)) % mod * i % mod) % mod + mod) % mod;
        add(ans, tmp);
    }
    printf("%lld\n", ans);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/yqgAKIOI/p/10212225.html