[九省联考2018]秘密袭击coat

题目描述

Luogu
题目大意:给一棵\(n\)个点的树,求所有联通块中第\(K\)大的权值\(W_k\)之和。
数据范围:\(K\leq n\leq 1666\) , \(W_{max}\leq 1666\),答案对\(64123\)取模,时限\(7sec\)

题解

\(Ans = \sum_{S} Kth\ of\ S = \sum_{v = 1}^W v\sum_{S} [Kth\ of\ S\ = v]\)
\(Ans = \sum_{v = 1}^W \sum_{S} [Kth\ of\ S \ge v]\)
我们令\(cnt(S,v)\)表示\(S\)中权值大于等于\(v\)的节点个数。
\(Ans = \sum_{v=1}^W \sum_{S} [cnt(S,v)\ge K]\)
然后就可以设计一个\(dp\)了,设\(f_{u,v,j}\)表示\(u\)为根的联通块中,\(W\ge v\)联通块数。
转移显然:\(f_{u,v,j} = \prod_{son_i} f_{son_i,v,k_{son_i}}\),其中\(\sum_{son_i} k_{son_i} = j - [W_u\ge v]\)
根据上述可得:\(Ans = \sum_{u} \sum_{v=1}^W \sum_{j=K}^n f_{u,v,j}\)
卡一下\(j\)那一维,用树上\(lca\)那套分析一下,复杂度就是严格\(O(n^2W)\)的。
然后竟然就能成功AC原题数据了qaq......
不管了。
可以注意到\(j\)那一维是一个背包,所以自然就能想到生成函数。
\(F_{u,v} = \sum_{j=0}^n f_{u,v,j} x^j\),设\(G_{u,v} = \sum_{s\in Tree_u} F_{s,v} = \sum_{j=0}^n g_{u,v,j}x^j\),那么\(F\)的转移就是一个卷积了。
即初始化后,\(F_{u,v} = \prod_{son_i} F_{son_i,v}\)
\(Ans = \sum_{u} \sum_{v=1}^W \sum_{j=K}^n f_{u,v,j} = \sum_{v=1}^W \sum_{j=K}^n g_{1,v,j} = \sum_{j=K}^n \sum_{v=1}^W g_{1,v,j}\)
我们的目标即求\(\sum_{v=1}^W G_{1,v}\)的每一项系数。
每次转移都卷积显然是傻子。
熟悉\(FFT\)原理的童鞋都知道先用点值表示,最后再拉格朗日插值回去即可得到每一项的系数。
外部枚举\(x = 1,2...n+1\),下面我们来考虑如何计算\(F\)\(G\)的点值表示。
注意到由于转了点值表示,所以多项式乘法是对位乘法,就可以用线段树合并维护了。
线段树每个叶子节点\(v\)维护点值\(F_{u,v},G_{u,v}\),我们在每个点\(u\)要干这些事:

  • 初始化:把区间\([1,W_u]\)\(F\)加上\(x\),把区间\((W_u,W_{max}]\)\(F\)加上\(1\)
  • 合并:把\(F_{son_i}\)对位相乘,\(G_{son_i}\)对位相加。
  • 结束:把\(G_{u}\)加上\(F_u\),把\(F_u\)\(1\),便于下次转移。

维护一个标记\((a,b,c,d)\),表示\((F,G)\ \to\ F(aF + b , cF + d + G)\)
那么初始化对应标记\((1,x,0,0)\)\((1,1,0,0)\)。结束对应标记\((1,1,1,0)\)
合并的时候,线段树合并,设合并\(x\)\(y\)
若其中一个点(以\(y\)为例)没有儿子了,也就是说下面的节点的\((F,G)\)都是一样的了。
此时\(F_y = b\)\(G_y = d\),对应\((F_x,G_x)\to (F_xF_y,G_x+G_y)\),修改\(x\)的标记,然后\(return\)即可。
最后把根节点的线段树遍历一遍,标记都放下去后叶子节点\(v\)的标记中的\(d\)\(G_{1,v}\)
最后套一下拉格朗日公式:\(H(x) = \sum_{i=1}^{n+1} H(i) \prod_{j\neq i} \frac{x-j}{i-j}\)
你说每次算\(\prod(x - j)\)\(O(n^2)\)的?
多项式除法了解一下蟹蟹qwq......先别管\(j\neq i\)就行了。
复杂度\(O(n^2logW)\)被暴力吊着打,代码其实挺短的。

实现代码

#include<bits/stdc++.h>
#define IL inline
#define _ 2005
#define ll long long
#define ld long double
using namespace std ; 

IL ll gi(){
    ll data = 0 , m = 1; char ch = 0 ; 
    while((ch != '-') && (ch < '0' || ch > '9')) ch = getchar() ; 
    if(ch == '-'){m = 0 ; ch = getchar() ; }
    while(ch >= '0' && ch <= '9'){data = (data<<1) + (data<<3) + (ch^48) ; ch = getchar() ; }
    return (m) ? data : -data ; 
} 

#define mod 64123 

int n , K , W , oo , stk[_ * _] , ans[_] , f[_] , fz[_] , H[_] , Ans , inv[mod] , rt[_] , val[_] ;  

struct _Edge{
    int to , next ; 
}Edge[_ << 1] ; int head[_] , CNT ; 
IL void AddEdge(int u , int v) {
    Edge[++ CNT] = (_Edge){v , head[u]} ; head[u] = CNT ; return ; 
}

struct Target {
    int a , b , c , d ; 
    IL Target() {a = 1 ; b = 0 ; c = 0 ; d = 0 ; } 
    IL Target(int s1,int s2,int s3,int s4) {a = s1 ; b = s2 ; c = s3 ; d = s4 ; }
} ;
IL Target operator + (Target A , Target B) {
    Target C ; 
    C.a = 1ll * A.a * B.a % mod ; C.b = (1ll * B.a * A.b % mod + B.b) % mod ; 
    C.c = (1ll * A.a * B.c % mod + A.c) % mod ; 
    C.d = (1ll * A.b * B.c % mod + A.d + B.d) % mod ; 
    return C ; 
}
struct Node {
    int ls , rs ; Target tag ; 
    IL Node(){ls = rs = 0 ; tag = Target() ; return ; }
}t[_ * _] ; 

IL int NewNode() {
    if(stk[0]) {t[stk[stk[0]]] = Node() ; return stk[stk[0] --] ; }
    else {t[++oo] = Node() ; return oo ; }
}
void PushDown(int o) {
    if(!t[o].ls) t[o].ls = NewNode() ; if(!t[o].rs) t[o].rs = NewNode() ; 
    t[t[o].ls].tag = t[t[o].ls].tag + t[o].tag ; 
    t[t[o].rs].tag = t[t[o].rs].tag + t[o].tag ; 
    t[o].tag = Target() ; 
    return ;  
}
void Insert(int &o , int l , int r , int ql , int qr , Target E) {
    if(!o) o = NewNode() ; if(ql <= l && r <= qr) {t[o].tag = t[o].tag + E ; return ; }
    int mid = (l + r) >> 1 ; 
    PushDown(o) ; 
    if(ql <= mid) Insert(t[o].ls , l , mid , ql , qr , E) ; 
    if(qr  > mid) Insert(t[o].rs , mid + 1 , r , ql , qr , E) ; 
    return ; 
}
int Merge(int o , int os) {
    if(!o || !os) return o + os ; 
    if(!t[o].ls && !t[o].rs) swap(o , os) ; 
    if(!t[os].ls && !t[os].rs) {
        t[o].tag.a = 1ll * t[o].tag.a * t[os].tag.b % mod ; 
        t[o].tag.b = 1ll * t[o].tag.b * t[os].tag.b % mod ; 
        t[o].tag.d = (t[o].tag.d + t[os].tag.d) % mod ; 
        stk[++stk[0]] = os ; 
        return o ; 
    }
    PushDown(o) ; PushDown(os) ;
    t[o].ls = Merge(t[o].ls , t[os].ls) ; 
    t[o].rs = Merge(t[o].rs , t[os].rs) ; 
    stk[++stk[0]] = os ; 
    return o ; 
}

void Dfs(int u , int From , int x) {
    Insert(rt[u] , 1 , W , 1 , val[u] , Target(1 , x , 0 , 0)) ; 
    if(val[u] + 1 <= W) Insert(rt[u] , 1 , W , val[u] + 1 , W , Target(1 , 1 , 0 , 0)) ; 
    for(int e = head[u] ; e ; e = Edge[e].next) {
        int v = Edge[e].to ; if(v == From) continue ; 
        Dfs(v , u , x) ;
        rt[u] = Merge(rt[u] , rt[v]) ;  
    }
    t[rt[u]].tag = t[rt[u]].tag + Target(1 , 1 , 1 , 0) ; return ; 
}

void GetAns(int o , int l , int r , int x) {
    if(l == r) {H[x] = (H[x] + t[o].tag.d) % mod ; return ; }
    PushDown(o) ;
    int mid = (l + r) >> 1 ; 
    GetAns(t[o].ls , l , mid , x) ; GetAns(t[o].rs , mid + 1 , r , x) ; 
    return ;
}
IL void Solve(int x) {
    oo = 0 ; stk[0] = 0 ; for(int i = 1; i <= n; i ++) rt[i] = 0 ; 
    Dfs(1 , 0 , x) ;
    H[x] = 0 ; GetAns(rt[1] , 1 , W , x) ; 
}

IL void Lagrange() {
    inv[0] = 1 ; inv[1] = 1 ; for(int i = 2; i < mod; i ++) inv[i] = 1ll * inv[mod % i] * (mod - mod / i) % mod ; 
    fz[0] = 1 ; 
    for(int i = 1; i <= n + 1; i ++) 
        for(int j = n + 1; j >= 0; j --) 
            if(j) fz[j] = (fz[j - 1] + 1ll * fz[j] * (mod - i) % mod) % mod ;  else fz[j] = 1ll * fz[j] * (mod - i) % mod ;  
    for(int i = 1; i <= n + 1; i ++) {
        int coef = 1 ; 
        for(int j = 1; j <= n + 1; j ++) if(i != j) coef = 1ll * coef * (i + mod - j) % mod ; 
        for(int j = 0; j <= n + 1; j ++) f[j] = fz[j] ;
        for(int j = 0; j <= n + 1; j ++) {
            if(j) f[j] = (f[j] - f[j - 1] + mod) % mod ;
            f[j] = 1ll * inv[mod - i] * f[j] % mod ; 
        }
        coef = 1ll * inv[coef] * H[i] % mod ; 
        for(int j = 0; j <= n; j ++) ans[j] = (ans[j] + 1ll * coef * f[j] % mod) % mod ;  
    }
    return ; 
}

int main() {
    n = gi() ; K = gi() ; W = gi() ; 
    for(int i = 1; i <= n; i ++) val[i] = gi() ; 
    for(int i = 1,u,v; i < n; i ++) u = gi() , v = gi() , AddEdge(u , v) , AddEdge(v , u) ; 
    Solve(1) ; 
    for(int i = 1; i <= n + 1; i ++) Solve(i) ; 
    //for(int i = 1; i <= n + 1; i ++) cout << "H("<<i<<") = " << H[i] << endl ; 
    Lagrange() ; 
    Ans = 0 ; 
    for(int j = K; j <= n; j ++) Ans = (Ans + ans[j]) % mod ; 
    cout << Ans << endl ; 
    return 0 ; 
}

所以所谓的整体DP到底是啥啊,根本没看到什么虚树的影子啊?

猜你喜欢

转载自www.cnblogs.com/GuessYCB/p/10355766.html