[Luogu 3401] 洛谷树

Description

有一棵树,要求支持

  1. 查询两点间简单路径的所有子链的异或和的和
  2. 修改某条边的权值

Solution

这种树上异或问题首先应该想到对于每个点存下一个前缀异或和表示这个点到根节点路径的异或和。那么两点之间路径的异或和就等于这两点的前缀和再异或起来。

于是操作一变成了:有k个点,每个点有权值,问\(\sum \limits_{i=1}^k\sum\limits_{j=i+1}^k val[i]\oplus val[j]\)

由于是异或运算,我们按位考虑。

对于二进制位 \(p\),假设这 \(k\) 个数中有 \(x\) 个的第 \(p\) 位为1,剩下的为 \(0\),那么对答案有贡献的实际上就只有 \(x\times (k-x)\) 个点对,也就是说只有这么多点对异或起来的值为 \(1\)。这启示我们对于每个二进制位,都找到多少位是0,多少位是1,把他们乘起来就好了。

Code

#include<cstdio>
#include<cctype>
#include<cstring>
#include<iostream>
#include<algorithm>
#define N 30005
using std::min;
using std::max;
using std::swap;
#define int long long
#define ls cur<<1,l,mid,ql,qr
#define rs cur<<1|1,mid+1,r,ql,qr

int head[N],dfn[N],top[N],d[N];
int n,m,cnt,tot,lazy[N<<2],cme[N];
int sze[N],son[N],dis[N],fs[N],fa[N];

struct Edge{
    int to,nxt,dis;
}edge[N<<1];

struct Node{
    int a[12][2];

    friend Node operator+(Node x,Node y){
        Node z;memset(z.a,0,sizeof z.a);
        for(int i=1;i<=10;i++){
            z.a[i][0]=x.a[i][0]+y.a[i][0];
            z.a[i][1]=x.a[i][1]+y.a[i][1];
        } return z;
    }
}sum[N<<2];

void add(int x,int y,int z){
    edge[++cnt].to=y;
    edge[cnt].nxt=head[x];
    edge[cnt].dis=z;
    head[x]=cnt;
}

inline int getint(){
    int X=0;int w=0;char ch=0;
    while(!isdigit(ch))w|=ch=='-',ch=getchar();
    while( isdigit(ch))X=(X<<3)+(X<<1)+(ch^48),ch=getchar();
    return w?-X:X;
}

void dfs(int now){
    sze[now]=1;
    for(int i=head[now];i;i=edge[i].nxt){
        int to=edge[i].to;
        if(sze[to]) continue;
        d[to]=d[now]+1;
        dis[to]=dis[now]^edge[i].dis;cme[to]=edge[i].dis;
        dfs(to);sze[now]+=sze[to];fa[to]=now;
        if(sze[to]>sze[son[now]])
            son[now]=to;
    }
}

void dfs2(int now,int low){
    dfn[now]=++tot;fs[tot]=now;top[now]=low;
    if(son[now])
        dfs2(son[now],low);
    for(int i=head[now];i;i=edge[i].nxt){
        int to=edge[i].to;
        if(dfn[to]) continue;
        dfs2(to,to);
    }
}

void pushup(int cur){
    sum[cur]=sum[cur<<1]+sum[cur<<1|1];
}

void build(int cur,int l,int r){
    if(l==r){
        int now=dis[fs[l]];
        for(int i=1;i<=10;i++){
            if(now>>i-1&1)
                sum[cur].a[i][1]++;
            else sum[cur].a[i][0]++;
        } return;
    }
    int mid=l+r>>1;
    build(cur<<1,l,mid);build(cur<<1|1,mid+1,r);
    pushup(cur);
}

void pushdown(int cur){
    if(!lazy[cur]) return;
    for(int i=1;i<=10;i++){
        if(lazy[cur]>>i-1&1){
            swap(sum[cur<<1].a[i][0],sum[cur<<1].a[i][1]);
            swap(sum[cur<<1|1].a[i][0],sum[cur<<1|1].a[i][1]);
        }
    }
    lazy[cur<<1]^=lazy[cur];lazy[cur<<1|1]^=lazy[cur];lazy[cur]=0;
}

Node query(int cur,int l,int r,int ql,int qr){
    if(ql<=l and r<=qr)
        return sum[cur];
    int mid=l+r>>1;pushdown(cur);
    Node z;memset(z.a,0,sizeof z.a);
    if(ql<=mid)
        z=z+query(ls);
    if(mid<qr)
        z=z+query(rs);
    return z;
}

int ask(int x,int y){
    Node z;memset(z.a,0,sizeof z.a);
    while(top[x]!=top[y]){
        // printf("X=%lld,y=%lld\n",x,y);
        if(d[top[x]]<d[top[y]])
            swap(x,y);
        z=z+query(1,1,n,dfn[top[x]],dfn[x]);
        x=fa[top[x]];
    }
    if(d[x]<d[y]) swap(x,y);
    z=z+query(1,1,n,dfn[y],dfn[x]);
    int ans=0;
    for(int i=1;i<=10;i++)
        ans+=(1<<i-1)*z.a[i][0]*z.a[i][1];
    return ans;
}

void modify(int cur,int l,int r,int ql,int qr,int z){
    if(ql<=l and r<=qr){
        for(int i=1;i<=10;i++){
            if(z>>i-1&1)
                swap(sum[cur].a[i][0],sum[cur].a[i][1]);
        }
        lazy[cur]^=z;return;
    }
    pushdown(cur);int mid=l+r>>1;
    if(ql<=mid)
        modify(ls,z);
    if(mid<qr)
        modify(rs,z);
    pushup(cur);
}

signed main(){
    n=getint(),m=getint();
    for(int i=1;i<n;i++){
        int x=getint(),y=getint(),z=getint();
        add(x,y,z);add(y,x,z);
    }
    d[1]=1;dfs(1);dfs2(1,1);build(1,1,n);
    while(m--){
        if(getint()==1){
            int x=getint(),y=getint();
            printf("%lld\n",ask(x,y));
        } else{
            int x=getint(),y=getint(),z=getint();
            if(d[x]<d[y]) swap(x,y);
            modify(1,1,n,dfn[x],dfn[x]+sze[x]-1,cme[x]^z);
            cme[x]=z;
        }
    } return 0;
}

猜你喜欢

转载自www.cnblogs.com/YoungNeal/p/9569162.html