题解: 考虑每个第 k 大贡献多少个集合 i∑i∗S∑[kth(S)=i]=i∑S∑[kth(S)≥i] 那么我们可以枚举每一个 i,然后算出第 k 大 ≥i 的个数 这个个数等价于 ≥i 的至少有 k 个的个数 那么我们可以令 dpu,j 表示到 u 有 j 个 ≥i 的个数 写成生成函数 Fu(x)=(1orx)∏(Fv(x)+1),求的是 Gu(x)=∑Gv(x)+Fu(x) 本来以为要 dsuontree 然后分治 fft 的,好像可以两个 log,突然发现模数不对 其实可以考虑求出 n+1 个点值,然后拉个朗日插值,其中拉个朗日插值的复杂度可以做到 O(n2) 暴力求 n+1 个点值的复杂度是 O(n3),但是我们可以把枚举 i 这一步省掉 用线段树维护每一个 i 的答案,那么初始化就是一段赋 1 一段赋 x 维护一个系数类 (a,b,c,d) 表示把 (f,g) 变成 (af+b,cf+d+g) 就可以表示所有转移 并支持结合律 考虑线段树合并把每个 dp 值都搞上去,当合并到一棵树的叶子结点也就是下面一个区间全部是 0 的时候,整个区间对当前的贡献可以通过这个系数类对区间打一个乘法标记,这样就支持合并了 复杂度 O(n2logn)
#include<bits/stdc++.h>#define cs const
using namespace std;
int read(){
int cnt = 0, f = 1; char ch = 0;
while(!isdigit(ch)){ ch = getchar(); if(ch =='-') f = -1;}
while(isdigit(ch)) cnt = cnt*10 + (ch-'0'), ch = getchar();return cnt * f;}
cs int N = 1680;
cs int Mod = 64123;
int add(int a, int b){return a + b >= Mod ? a + b - Mod : a + b;}
int mul(int a, int b){return 1ll * a * b % Mod;}
int dec(int a, int b){return a - b < 0 ? a - b + Mod : a - b;}
int ksm(int a, int b){ int ans = 1; for(;b;b>>=1,a=mul(a,a)) if(b&1) ans = mul(ans, a);return ans;}
void Add(int &a, int b){ a = add(a, b);}
void Mul(int &a, int b){ a = mul(a, b);}
void Dec(int &a, int b){ a = dec(a, b);}
int sgn(int a){return a & 1 ? Mod - 1 : 1;}
int n, k, W, d[N];
vector<int> G[N];#define pb push_back#define poly vector<int>
int rt[N];
struct Coef{
int a, b, c, d;
Coef(int _a=0, int _b=0, int _c=0, int _d=0){ a=_a; b=_b; c=_c; d=_d;}
Coef operator * (cs Coef &A){return Coef(mul(a,A.a),add(mul(b,A.a),A.b),add(mul(a,A.c),c),add(mul(b,A.c),add(d,A.d)));}
void operator *=(cs Coef &A){ *this = *this * A;}};
namespace SGT{
cs int N = ::N * 200;
int nd; int ls[N], rs[N]; Coef vl[N];
int newnode(){ int x = ++nd; ls[x]= rs[x]= 0; vl[x]= Coef(1,0,0,0);return x;}#define mid ((l+r)>>1)
void down(int x){
if(!ls[x]) ls[x]= newnode();
if(!rs[x]) rs[x]= newnode();
vl[ls[x]] *= vl[x];
vl[rs[x]] *= vl[x];
vl[x]= Coef(1,0,0,0);}
void modify(int &x, int l, int r, int L, int R, Coef coe){
if(!x) x = newnode(); if(L<=l && r<=R){ vl[x] *= coe;return;} down(x);
if(L<=mid) modify(ls[x], l, mid, L, R, coe);
if(R>mid) modify(rs[x], mid+1, r, L, R, coe);}
void merge(int &x, int y){
if(!x ||!y){ x |= y;return;}
if(!ls[x]&&!rs[x]) swap(x, y);
if(!ls[y]&&!rs[y]){
vl[x] *= Coef(vl[y].b,0,0,0);
vl[x] *= Coef(1,0,0,vl[y].d);return;} down(x); down(y);
merge(ls[x], ls[y]);
merge(rs[x], rs[y]);}
int query(int x, int l, int r){
if(l == r)return vl[x].d; down(x);return add(query(ls[x],l,mid), query(rs[x],mid+1,r));}}
void dfs(int u, int fa, int c){
SGT::modify(rt[u],1,W,1,W,Coef(0,1,0,0));
for(int v: G[u]) if(v ^ fa){
dfs(v, u, c);
SGT::merge(rt[u], rt[v]);}
SGT::modify(rt[u],1,W,1,d[u],Coef(c,0,0,0));
SGT::modify(rt[u],1,W,1,W,Coef(1,0,1,0));
SGT::modify(rt[u],1,W,1,W,Coef(1,1,0,0));}
int y[N], inv[N];
poly operator * (poly a, poly b){
int deg = a.size() + b.size() - 1; poly c(deg,0);
for(int i = 0; i < a.size(); i++)
for(int j = 0; j < b.size(); j++)
Add(c[i+j], mul(a[i],b[j]));return c;}
poly operator * (poly a, int coe){
for(int i = 0; i < a.size(); i++) Mul(a[i],coe);return a;}
poly operator + (poly a, poly b){
int deg = max(a.size(),b.size()); a.resize(deg); b.resize(deg);
for(int i = 0; i < deg; i++) Add(a[i],b[i]);return a;}
poly operator / (poly a, poly b){
int deg = a.size() - b.size() + 1; poly c(deg, 0);
for(int i =(int)a.size() - 1; i > 0; i--){
c[i-1]= a[i]; Dec(a[i-1], mul(b[0], a[i]));}return c;}
int work(int n){
inv[0]= inv[1]= 1;
for(int i = 2; i <= n; i++) inv[i]= mul(Mod-Mod/i, inv[Mod%i]);
for(int i = 2; i <= n; i++) Mul(inv[i],inv[i-1]);
poly f, g; f.pb(1); g.pb(0); g.pb(1);
for(int i = 1; i <= n; i++) Dec(g[0],1), f = f * g;
poly ans;
for(int i = 1; i <= n; i++){
int coe = mul(mul(sgn(n-i),y[i]),mul(inv[i-1],inv[n-i]));
poly res; res.pb(dec(0,i)); res.pb(1);
ans = ans + (f / res) * coe;}
int sm = 0;
for(int i = k; i < ans.size(); i++) Add(sm, ans[i]);return sm;}
int main(){
n = read(), k = read(), W = read();
for(int i = 1; i <= n; i++) d[i]= read();
for(int i = 1, x, y; i < n; i++)
x = read(), y = read(), G[x].pb(y), G[y].pb(x);
for(int c = 1; c <= n + 1; c++){
dfs(1,0,c);
y[c]= SGT::query(rt[1],1,W); SGT::nd = 0;
memset(rt, 0, sizeof(int)*(n+1));} cout << work(n+1);return 0;}