bzoj 5395[Ynoi2016]谁的梦 set+map

定义一个序列的权值为不同数字的个数,例如 [1,2,3,3] 权值为 3
现在有n个序列,我们在每个序列里面选一个连续非空子串,拼接起来,求所有选法得到的序列的权值之和
如果一个序列能通过多种方法被选择出来,那么计算多次
本题带修改操作,格式请参考输入格式
由于结果可能过大,请输出答案 mod 19260817 的结果
Input
第一行两个数 n,m,表示有n个序列,m次修改
然后n个数,第i个数是leni,表示第i个序列的长度
之后n行,每行leni个数,表示第i个序列
之后m行,每行三个数x,y,z表示将第x个序列的第y个元素改为z
1 <= n,m <= 100000,序列中的元素均为 32 位整型数,leni的和 <= 100000
共53组数据
Output
输出m + 1行,依次表示初始局面以及每次修改后的答案。

solution
这是一道好题!
考虑对颜色分开考虑贡献。补集转化为颜色不贡献的所有区间。
即每个序列颜色的相邻位置中选区间。再把每个序列的答案乘起来。
然而要用set维护每个颜色在每个序列的所有出现位置来支持修改。
并且还要开map维护第i个序列中颜色j对应的set标号,总之细节挺多!
map可以直接开二维的,听说也是一个log
有个地方忘了取模调了半小时QWQ

#include<bits/stdc++.h>
using namespace std;
#define maxn 200020
#define rep(i,l,r) for(register int i = l ; i <= r ; i++)
#define repd(i,r,l) for(register int i = r ; i >= l ; i--)
#define inf 1e8

typedef long long ll;
const ll mod = 19260817;
struct node{
    int i,c;
    node(){};
    node(int x,int y):i(x),c(y){};
    bool operator < (node a)const{
        if ( c == a.c ) return i < a.i;
        return c < a.c;
    }
};
struct node2{
    int x,y,z;
}que[maxn];
set <int> s[maxn];
set <int>::iterator it,it2,it3;
int a[maxn],b[maxn],len[maxn],c[maxn],n,m,tot,cnt,vis[maxn],num[maxn];
ll tmp,sum[maxn],ans,inv[maxn];
vector <int> vec[maxn],col[maxn];

map <node,int> mp;
map <node,ll> rec;

inline ll cal(int len){
    return ((ll)len * (len + 1) / 2) % mod;
}
inline ll power(ll x,ll y){
    ll res = 1;
    while ( y ){
        if ( y & 1 ) res = res * x % mod;
        x = x * x % mod;
        y >>= 1;
    }
    return res;
}
void pre(){
    sort(c + 1,c + tot + 1);
    rep(i,1,n) rep(j,0,len[i] - 1){
        vec[i][j] = lower_bound(c + 1,c + tot + 1,vec[i][j]) - c;
        col[vec[i][j]].push_back(i);
    }
//  rep(i,1,n){
//      rep(j,0,len[i] - 1) cout<<vec[i][j]<<" ";
//      cout<<endl;
//  }
    rep(i,1,m) que[i].z = lower_bound(c + 1,c + tot + 1,que[i].z) - c;
    rep(i,1,n){
        rep(j,0,len[i] - 1){
            int id = mp[node(i,vec[i][j])];
            if ( !id ) id = mp[node(i,vec[i][j])] = ++cnt;
            s[id].insert(j + 1);
        }
    }
    tmp = 1;
    rep(i,1,n) tmp = tmp * cal(len[i]) % mod;
    //inv[0] = 1;
    //rep(i,1,100000) inv[i] = inv[mod % i] * (mod - mod / i) % mod;
    rep(i,1,tot){
        sum[i] = 1;
        sort(col[i].begin(),col[i].end());
        col[i].erase(unique(col[i].begin(),col[i].end()),col[i].end());
        for(register int j = 0 ; j < col[i].size() ; j++){
            ll cur = 0; int id = mp[node(col[i][j],i)],last = 0;        
            for (it = s[id].begin() ; it != s[id].end() ; ++it){
            //  cout<<*it<<" ";
                cur = (cur + cal(*it - last - 1)) % mod;
                last = *it;
            }
            cur = (cur + cal(len[col[i][j]] - last)) % mod;
            //cout<<endl<<cur<<endl;
            if ( !cur ) num[i]++ , sum[i] = sum[i] * power(cal(len[col[i][j]]),mod - 2) % mod;
            else sum[i] = sum[i] * cur % mod * power(cal(len[col[i][j]]),mod - 2) % mod;
            rec[node(col[i][j],i)] = cur;
        }
        sum[i] = sum[i] * tmp % mod; 
        ans = (tmp - sum[i] * (num[i] ? 0 : 1) + ans) % mod;
    }
//  rep(i,1,tot) cout<<sum[i]<<" ";
    //cout<<endl;
    ans = (ans % mod + mod) % mod;
    printf("%lld\n",ans);
}
void solve(){
    rep(i,1,m){
        int x = que[i].x , y = que[i].y , z = que[i].z , c = vec[x][y - 1];
        if ( c == z ){ printf("%lld\n",ans); continue; }
        vec[x][y - 1] = z;
        node cur = node(x,c);
        int id = mp[cur]; ll res = rec[cur]; 
        //==================delete c==========================================
        ans -= tmp - sum[c] * (num[c] ? 0 : 1);
        //判断改行贡献是否为0,sum[i]记录非0行的贡献
        if ( !res ) num[c]--;
        else sum[c] = sum[c] * power(res,mod - 2) % mod;
        it = s[id].find(y); int pos = *it; it2 = it++;
        if ( it2 == s[id].begin() ){
            if ( it != s[id].end() ){
                int ppos = *(it);
                res -= cal(pos - 1) + cal(ppos - pos - 1);
                res += cal(ppos - 1);   
            }
            else{
                res -= cal(pos - 1) + cal(len[x] - pos);
                res += cal(len[x]);
            }
        }   
        else{
            if ( it != s[id].end() ){
                int ppos = *(it) ,pre = *(--it2);
                res -= cal(ppos - pos - 1) + cal(pos - pre - 1);
                res += cal(ppos - pre - 1); 
            }
            else{
                int pre = *(--it2);
                res -= cal(len[x] - pos) + cal(pos - pre - 1);
                res += cal(len[x] - pre);   
            }
        }
        res = (res % mod + mod) % mod;
        s[id].erase(y) , rec[cur] = res;
        sum[c] = sum[c] * res % mod;
        ans += tmp - sum[c] * (num[c] ? 0 : 1);
        //===============================================================
        //insert z
        cur = node(x,z);
        ans -= tmp - sum[z] * (num[z] ? 0 : 1);
        if ( !mp[cur] ){
            mp[cur] = ++cnt;
            sum[z] = sum[z] * power(cal(len[x]),mod - 2) % mod;
            res = (cal(y - 1) + cal(len[x] - y)) % mod;
            s[cnt].insert(y); 
        }
        else{
            res = rec[cur]; id = mp[cur];
                //判断改行贡献是否为0,sum[i]记录非0行的贡献
            if ( !res ) num[z]--;
            else sum[z] = sum[z] * power(res,mod - 2) % mod;
            it = s[id].lower_bound(y);
        //  for(it3 = s[id].begin() ; it3 != s[id].end() ; ++it3) cout<<*it3<<" ";
        //  cout<<endl<<s[id].size()<<" ";
            if ( !s[id].size() ){
                res = res - cal(len[x]) + cal(y - 1) + cal(len[x] - y);
            }
            else if ( it == s[id].end() ){
                pos = *(--it);
                res = res - cal(len[x] - pos) + cal(y - pos - 1) + cal(len[x] - y);
            }
            else if ( it == s[id].begin() ){
                pos = *(it);
                res = res - cal(pos - 1) + cal(y - 1) + cal(pos - y - 1);
            }
            else{   
                int pos2 = *it , pos = *(--it);
                res = res - cal(pos2 - pos - 1) + cal(pos2 - y - 1) + cal(y - pos - 1);
            }
            s[id].insert(y);
        }
        res = (res % mod + mod) % mod , rec[cur] = res;
        if ( !res ) num[z]++;
        else sum[z] = sum[z] * res % mod;
        ans += tmp - sum[z] * (num[z] ? 0 : 1);
        ans = (ans % mod + mod) % mod;
        printf("%lld\n",ans);
    //  cout<<sum[c]<<" "<<sum[z]<<endl;
    }
}
int main(){
    freopen("input.txt","r",stdin);
    scanf("%d %d",&n,&m);
    rep(i,1,n) scanf("%d",&len[i]);
    rep(i,1,n) rep(j,1,len[i]){ int x; scanf("%d",&x) , vec[i].push_back(x) , c[++tot] = x; }
    rep(i,1,m){
        int x,y,z;
        scanf("%d %d %d",&x,&y,&z);
        que[i] = (node2){x,y,z};
        c[++tot] = z;
    }
    pre();
    solve();
    return 0;
}

猜你喜欢

转载自blog.csdn.net/weixin_42484877/article/details/81292768