LOJ2269. 「SDOI2017」切树游戏 [FWT,动态DP]

LOJ

思路

显然是要DP的。设\(dp_{u,i}\)表示\(u\)子树内一个包含\(u\)的连通块异或出\(i\)的方案数,发现转移可以用FWT优化,写成生成函数就是这样的:
\[ dp_{u}=x^{val_u}\prod (dp_v+1) \]
最后答案是所有DP值的和,于是获得了朴素的\(O(nmQ)\)的做法。(中间运算全部用点值表示)

显然是要用动态DP优化的,我们另外记一个\(S_u\)表示子树的DP值和自己的DP值的和,写成矩阵的形式,就是
\[ \left[\begin{matrix} dp_u\\S_u\\1 \end{matrix}\right] = \left[\begin{matrix} dp'_u&0&dp'_u\\dp'_u&1&S'_u\\0&0&1 \end{matrix}\right] \times \left[\begin{matrix} dp_v\\S_v\\1 \end{matrix}\right] \]
(转移的意义:\(dp_u=dp'_u+dp'_udp_v,S_u=S'_u-dp'_u+dp_u+S_v\)

当然,这个只能用来做一个儿子,在多个儿子的时候还是不能直接把矩阵乘起来的。

考虑链剖,把轻儿子的信息合并在一起存在矩阵里,用矩阵加速重链上的转移,就对了。

怎么修改呢?

对于要修改的这个点,看上面的转移方程,发现只有\(x^{val_u}\)有改变,于是除掉之前的改成新的就可以了。

对于上面的点,要改变的是后面的\(dp_v+1\),所以也是把原来的除掉换成新的。

除的时候可能会除0,所以数字要换成\(x\times mod^y\)的记录形式来做除法。这东西一旦有两个数相加就炸了,但你发现只有\(dp_u\)需要做除法,而它只和乘除有关,所以没有问题。

但你这样常数又炸了,所以你还需要发现矩阵\(\left[\begin{matrix} a&0&b\\c&1&d\\0&0&1 \end{matrix}\right]\)的乘法有封闭性,所以只要维护四个值,就快了。

一开始DP的时候有一个细节。转移的时候一定不能把乘法的括号拆开,不然你就没了。

输出答案的时候有一个细节。正常的转移矩阵应该是后面要乘一个向量才能得到正确解,但你发现转移矩阵右边一列就是那个向量,所以可以直接把所有矩阵乘在一起之后用右边一列的信息。注意此时左边\(a,c\)两个元素已经不知道是什么东西了。(初学动态DP的时候在这上面蒙了好久)

(这题就当是复习动态DP吧,毕竟和FWT没有太大关系……)

代码

由于我空间炸了,需要用short存东西……

#include<bits/stdc++.h>
clock_t t=clock();
namespace my_std{
    using namespace std;
    #define pii pair<int,int>
    #define fir first
    #define sec second
    #define MP make_pair
    #define rep(i,x,y) for (int i=(x);i<=(y);i++)
    #define drep(i,x,y) for (int i=(x);i>=(y);i--)
    #define go(x) for (int i=head[x];i;i=edge[i].nxt)
    #define templ template<typename T>
    #define sz 30303
    #define SS 130
    #define mod 10007
    typedef long long ll;
    typedef double db;
    mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
    templ inline T rnd(T l,T r) {return uniform_int_distribution<T>(l,r)(rng);}
    templ inline bool chkmax(T &x,T y){return x<y?x=y,1:0;}
    templ inline bool chkmin(T &x,T y){return x>y?x=y,1:0;}
    templ inline void read(T& t)
    {
        t=0;char f=0,ch=getchar();double d=0.1;
        while(ch>'9'||ch<'0') f|=(ch=='-'),ch=getchar();
        while(ch<='9'&&ch>='0') t=t*10+ch-48,ch=getchar();
        if(ch=='.'){ch=getchar();while(ch<='9'&&ch>='0') t+=d*(ch^48),d*=0.1,ch=getchar();}
        t=(f?-t:t);
    }
    template<typename T,typename... Args>inline void read(T& t,Args&... args){read(t); read(args...);}
    char __sr[1<<21],__z[20];int __C=-1,__zz=0;
    inline void Ot(){fwrite(__sr,1,__C+1,stdout),__C=-1;}
    inline void print(register int x)
    {
        if(__C>1<<20)Ot();if(x<0)__sr[++__C]='-',x=-x;
        while(__z[++__zz]=x%10+48,x/=10);
        while(__sr[++__C]=__z[__zz],--__zz);__sr[++__C]='\n';
    }
    void file()
    {
        #ifdef NTFOrz
        freopen("a.in","r",stdin);
        #endif
    }
    inline void chktime()
    {
        #ifndef ONLINE_JUDGE
        cout<<(clock()-t)/1000.0<<'\n';
        #endif
    }
    ll ksm(ll x,int y){ll ret=1;for (;y;y>>=1,x=x*x%mod) if (y&1) ret=ret*x%mod;return ret;}
//  inline ll mul(ll a,ll b){ll d=(ll)(a*(double)b/mod+0.5);ll ret=a*b-d*mod;if (ret<0) ret+=mod;return ret;}
}
using namespace my_std;

int n,m,mm;

int inv[mod+5];
struct Int{short a,z;short v(){return z?0:a;}};
#define Int(x,y) ((Int){x,y})
Int operator + (Int a,Int b){return Int((a.v()+b.v())%mod,0);}
Int operator - (Int a,Int b){return Int((a.v()-b.v()+mod)%mod,0);};
Int operator * (Int a,Int b){ if (b.v()) a.a=1ll*a.a*b.a%mod; else a.z++; return a; }
Int operator / (Int a,Int b){ if (b.v()) a.a=1ll*a.a*inv[b.a]%mod; else a.z--; return a; }
struct Array
{
    Int a[SS];
    const Array operator + (const Array &x) const {Array ret;rep(i,0,m-1) ret.a[i]=a[i]+x.a[i];return ret;}
    const Array operator - (const Array &x) const {Array ret;rep(i,0,m-1) ret.a[i]=a[i]-x.a[i];return ret;}
    const Array operator * (const Array &x) const {Array ret;rep(i,0,m-1) ret.a[i]=a[i]*x.a[i];return ret;}
    const Array operator / (const Array &x) const {Array ret;rep(i,0,m-1) ret.a[i]=a[i]/x.a[i];return ret;}
}fwt[SS];
struct Matrix
{
    Array a,b,c,d;
    const Matrix operator * (const Matrix &x) const {return (Matrix){a*x.a,a*x.b+b,c*x.a+x.c,c*x.b+d+x.d};}
};
void FWT(Array &a,int type)
{
    Int p,q,I=Int(ksm(2,mod-2),0);
    rep(i,0,mm-1)
        for (int mid=1<<i,j=0;j<m;j+=mid<<1)
            rep(k,0,mid-1)
            {
                p=a.a[j+k],q=a.a[j+k+mid];
                if (type==1) a.a[j+k]=p+q,a.a[j+k+mid]=p-q;
                else a.a[j+k]=(p+q)*I,a.a[j+k+mid]=(p-q)*I;
            }
}

int val[sz];
struct hh{int t,nxt;}edge[sz<<1];
int head[sz],ecnt;
void make_edge(int f,int t)
{
    edge[++ecnt]=(hh){t,head[f]};
    head[f]=ecnt;
    edge[++ecnt]=(hh){f,head[t]};
    head[t]=ecnt;
}

int dfn[sz],pre[sz],size[sz],son[sz],top[sz],bot[sz],fa[sz],T;
#define v edge[i].t
void dfs1(int x,int f)
{
    size[x]=1,fa[x]=f;
    go(x) if (v!=f)
    {
        dfs1(v,x);
        size[x]+=size[v];
        if (size[v]>size[son[x]]) son[x]=v;
    }
}
void dfs2(int x,int fa,int tp)
{
    pre[dfn[bot[top[x]=tp]=x]=++T]=x;
    if (son[x]) dfs2(son[x],x,tp);
    go(x) if (v!=fa&&v!=son[x]) dfs2(v,x,v);
}
Array dp[sz],S[sz];
void dfs(int x,int fa)
{
    dp[x]=S[x]=fwt[val[x]];
    go(x) if (v!=fa)
    {
        dfs(v,x);
        S[x]=S[x]+dp[x]*dp[v]+S[v];
        dp[x]=dp[x]+dp[x]*dp[v];
    }
}
#undef v

Matrix tr[sz<<2],tmp[sz];
#define lson k<<1,l,mid
#define rson k<<1|1,mid+1,r
void build(int k,int l,int r)
{
    if (l==r)
    {
        int x=pre[l];Array f,s;f=s=fwt[val[x]];
        #define v edge[i].t
        go(x) if (v!=fa[x]&&v!=son[x]) s=s+f*dp[v]+S[v],f=f*(fwt[0]+dp[v]);
        #undef v
        tr[k]=tmp[l]=(Matrix){f,f,f,s};
        return;
    }
    int mid=(l+r)>>1;
    build(lson),build(rson);
    tr[k]=tr[k<<1]*tr[k<<1|1];
}
void modify(int k,int l,int r,int x)
{
    if (l==r) return (void)(tr[k]=tmp[l]);
    int mid=(l+r)>>1;
    if (x<=mid) modify(lson,x);
    else modify(rson,x);
    tr[k]=tr[k<<1]*tr[k<<1|1];
}
Matrix query(int k,int l,int r,int x,int y)
{
    if (x<=l&&r<=y) return tr[k];
    int mid=(l+r)>>1;
    if (y<=mid) return query(lson,x,y);
    if (x>mid) return query(rson,x,y);
    return query(lson,x,y)*query(rson,x,y);
}
#undef lson
#undef rson 

void modify(int x,int w)
{
    Array p=tmp[dfn[x]].a,s=tmp[dfn[x]].d-p;
    p=p/fwt[val[x]];p=p*fwt[w];val[x]=w;
    tmp[dfn[x]].a=tmp[dfn[x]].b=tmp[dfn[x]].c=p;tmp[dfn[x]].d=s+p;
    while (233)
    {
        Matrix a=query(1,1,n,dfn[top[x]],dfn[bot[top[x]]]);
        modify(1,1,n,dfn[x]);
        Matrix b=query(1,1,n,dfn[top[x]],dfn[bot[top[x]]]);
        x=fa[top[x]]; if (!x) return;
        Matrix &M=tmp[dfn[x]];
        Array f0=M.a,s0=M.d;s0=s0-f0;
        f0=f0/(fwt[0]+a.b);
        f0=f0*(fwt[0]+b.b);
        s0=s0-a.d+b.d;
        tmp[dfn[x]]=(Matrix){f0,f0,f0,s0+f0};
    }
}

int main()
{
    file();
    rep(i,1,mod-1) inv[i]=ksm(i,mod-2);
    read(n,m);mm=log2(m);
    rep(i,0,m-1) fwt[i].a[i]=Int(1,0),FWT(fwt[i],1);
    rep(i,1,n) read(val[i]);
    int x,y;
    rep(i,1,n-1) read(x,y),make_edge(x,y);
    dfs1(1,0),dfs2(1,0,1),dfs(1,0),build(1,1,n);
    int Q;read(Q);char s[15];
    while (Q--)
    {
        cin>>s;
        if (s[0]=='C') read(x,y),modify(x,y);
        else
        {
            read(x);
            Array ans=query(1,1,n,1,dfn[bot[1]]).d;
            FWT(ans,-1);
            printf("%d\n",ans.a[x].v());
        }
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/p-b-p-b/p/11403109.html