Address
LOJ#2473 洛谷P4365
Solution
Part 1
- 题意即为求所有连通块的第
K 大值之和,对
64123 取模。
- 直接求不太好做,考虑对原问题进行一些转化:
ans=S∈T∑kth of S(1)
=i=1∑WiS∈T∑[kth of S=i](2)
=i=1∑WS∈T∑[kth of S≥i](3)
=i=1∑WS∈T∑[cntS[i]≥K](4)
-
T 表示树上的所有连通块集合。
- 第
(2) 步到第
(3) 步是一个常见的转化:对于布尔表达式
[kth of S≥i],每个布尔表达式
[kth of S=i] 恰好被算了
i 次。
- 第
(4) 步中的
cntS[i] 表示连通块
S 中权值
≥i 的结点个数。
- 据此我们可以设
f[u][i][j] 表示在以
u 为根的子树内包含点
u 且恰好有
j 个权值
≥i 的结点的连通块个数。
- 特别地,我们定义选取的连通块大小为空也是一种方案。
- 则最后的答案为:
u=1∑ni=1∑Wj=K∑nf[u][i][j]
- 转移显然为:
f[u][i][j]=v∈sonu∏f[v][i][jv](du≥i,jv≥0∑jv=j−1)
f[u][i][j]=v∈sonu∏f[v][i][jv](du<i,jv≥0∑jv=j)
- 初值
f[u][i][ [du≥i] ]=1,每次枚举子节点
v 的
jv 进行合并。
- 以
u 为根的子树转移结束后,因为连通块大小为空也算一种方案,我们要令
f[u][i][0] 加一。
- 时间复杂度
O(n2W),实现时可以把
i 这一维提到外面枚举,就不需要在状态中记录
i,
只要有足够优秀的常数就能通过所有数据 。
Part 2
- 考虑复杂度正确的做法,即如何优化 DP 的转移。
- 容易发现转移是一个卷积的形式,我们设
F[u][i] 表示
f[u][i][j] 的生成函数,即:
F[u][i]=j=0∑nf[u][i][j]xj
- 我们再另外设
G[u][i]=v∈subtreeu∑F[v][i],答案就变为
G[root][i] (
root 为我们选定的树根) 的后
n−K+1 项的系数之和。
- 转移:
F[u][i]=xv∈sonu∏F[v][i](du≥i)
F[u][i]=v∈sonu∏F[v][i](du<i)
G[u][i]=F[u][i]+v∈sonu∑G[v][i]
- 同样地,以
u 为根的子树转移结束后,我们要令
F[u][i] 的常数项加一。
- 考虑如果我们直接维护这些多项式,不仅复杂度没有保证、常数极大,模数也不是 NTT 模数,无法较为方便地进行卷积,显然是行不通的。
- 换个思路考虑,多项式点值的加法和乘法运算相当方便。我们把所有
G[root][i] 相加合并为一个多项式。现在我们只要求出这个多项式的
n+1 个点值,就可以通过拉格朗日插值公式
O(n2) 还原这个多项式的各项系数,从而得到答案。
- 我们在最外层枚举
x=1→n+1,那么点值关于
i 这一维的初值设定就相当于线段树上的区间修改,转移则可以看做多棵线段树的合并。
- 考虑直接在线段树上打标记维护
F,G 的点值
(f,g)。
- 具体地,我们定义一种对于
(f,g) 的变换
(a,b,c,d) 表示把
(f,g) 变为
(a×f+b,c×f+d+g),直接在线段树上维护这样的标记。
- 标记的合并如下:
(a1,b1,c1,d1)+(a2,b2,c2,d2)
=(a1×a2,a2×b1+b2,a1×c2+c1,b1×c2+d1+d2)
- 则对于每个结点
u 相当于一开始在
[1,du] 打上标记
(1,x,0,0),在
[du+1,W] 打上标记
(1,1,0,0),转移完后再打上标记
(1,1,1,0)。
- 注意动态开点的线段树在标记下传时也要新建结点,并且如果合并时其中一个结点没有左右子节点,我们不能直接标记下传(那样新建的结点数就不止与两棵线段树公共的结点数有关,时间复杂度不正确)。
- 因为
(f,g) 初值都为 0,我们可以直接把这个结点的
b 乘到另一个结点的
a,b 上,把这个结点的
d 加到另一个结点的
d 上,就不用再继续合并下去了。这其实也是我们需要维护
a,c 的原因。
- 最终叶子结点的
(f,g) 即为其标记中的
b 和
d。
- 时间复杂度
O(n2logW),
由于常数原因效率可能还不如暴力做法。
Code
#include <bits/stdc++.h>
template <class T>
inline void read(T &res)
{
char ch;
while (ch = getchar(), !isdigit(ch));
res = ch ^ 48;
while (ch = getchar(), isdigit(ch))
res = res * 10 + ch - 48;
}
using std::vector;
const int mod = 64123;
const int M = 1e6 + 5;
const int N = 2005;
int F[N], rt[N], stk[M], yy[N], d[N];
int top, T, n, K, W, fans;
vector<int> e[N];
inline void add(int &x, int y)
{
x += y;
x >= mod ? x -= mod : 0;
}
inline int quick_pow(int x, int k)
{
int res = 1;
while (k)
{
if (k & 1)
res = 1ll * res * x % mod;
x = 1ll * x * x % mod;
k >>= 1;
}
return res;
}
struct tag
{
int a, b, c, d;
tag() {}
tag(int A, int B, int C, int D):
a(A), b(B), c(C), d(D) {}
inline tag operator + (const tag &x) const
{
return tag(1ll * a * x.a % mod,
(1ll * x.a * b + x.b) % mod,
(1ll * a * x.c + c) % mod,
(1ll * b * x.c + d + x.d) % mod);
}
inline bool operator == (const tag &x) const
{
return a == x.a && b == x.b && c == x.c && d == x.d;
}
};
const tag One = tag(1, 0, 0, 0);
struct node
{
tag x;
int lc, rc;
#define lc(x) tr[x].lc
#define rc(x) tr[x].rc
inline void Clear()
{
x = One;
lc = rc = 0;
}
}tr[M];
inline int new_node()
{
return top ? stk[top--] : ++T;
}
inline void addTag(int &s, const tag &w)
{
if (!s)
tr[s = new_node()].Clear();
tr[s].x = tr[s].x + w;
}
inline void downDate(int &s)
{
if (!s)
tr[s = new_node()].Clear();
if (tr[s].x == One)
return ;
addTag(lc(s), tr[s].x);
addTag(rc(s), tr[s].x);
tr[s].x = One;
}
inline void Modify(int &s, int l, int r, int u, int v, const tag &w)
{
if (l == u && r == v)
return addTag(s, w);
downDate(s);
int mid = l + r >> 1;
if (v <= mid)
Modify(lc(s), l, mid, u, v, w);
else if (u > mid)
Modify(rc(s), mid + 1, r, u, v, w);
else
{
Modify(lc(s), l, mid, u, mid, w);
Modify(rc(s), mid + 1, r, mid + 1, v, w);
}
}
inline void Merge(int &s1, int s2)
{
if (!s1 || !s2)
return (void)(s1 += s2);
if (!lc(s1) && !rc(s1))
std::swap(s1, s2);
if (!lc(s2) && !rc(s2))
{
tr[s1].x.a = 1ll * tr[s1].x.a * tr[s2].x.b % mod;
tr[s1].x.b = 1ll * tr[s1].x.b * tr[s2].x.b % mod;
add(tr[s1].x.d, tr[s2].x.d);
stk[++top] = s2;
return ;
}
downDate(s1);
downDate(s2);
Merge(lc(s1), lc(s2));
Merge(rc(s1), rc(s2));
stk[++top] = s2;
}
inline void Print(int s, int l, int r, int v)
{
if (!s)
return ;
if (l == r)
return add(yy[v], tr[s].x.d);
downDate(s);
int mid = l + r >> 1;
Print(lc(s), l, mid, v);
Print(rc(s), mid + 1, r, v);
}
inline void Lagrange()
{
F[0] = 1;
for (int i = 1; i <= n + 1; ++i)
{
for (int j = i; j >= 1; --j)
F[j] = F[j - 1];
F[0] = 0;
for (int j = 1; j <= i; ++j)
F[j - 1] = (1ll * (mod - i) * F[j] + F[j - 1]) % mod;
}
for (int i = 1; i <= n + 1; ++i)
{
int p = yy[i], q = 1;
for (int j = 1; j <= n + 1; ++j)
if (i != j)
q = 1ll * q * (i - j + mod) % mod;
p = 1ll * p * quick_pow(q, mod - 2) % mod;
for (int j = n + 1; j >= 1; --j)
F[j - 1] = (F[j - 1] + 1ll * i * F[j]) % mod;
for (int j = 1; j <= n + 1; ++j)
F[j - 1] = F[j];
F[n + 1] = 0;
for (int j = K; j <= n; ++j)
fans = (fans + 1ll * p * F[j]) % mod;
for (int j = n + 1; j >= 1; --j)
F[j] = F[j - 1];
for (int j = 1; j <= n + 1; ++j)
F[j - 1] = (1ll * (mod - i) * F[j] + F[j - 1]) % mod;
}
}
inline void Dfs(int x, int Fa, int v)
{
Modify(rt[x], 1, W, 1, d[x], tag(1, v, 0, 0));
if (d[x] < W)
Modify(rt[x], 1, W, d[x] + 1, W, tag(1, 1, 0, 0));
for (int i = 0, im = e[x].size(); i < im; ++i)
{
int y = e[x][i];
if (y == Fa)
continue;
Dfs(y, x, v);
Merge(rt[x], rt[y]);
}
addTag(rt[x], tag(1, 1, 1, 0));
}
inline void solve(int v)
{
for (int i = 1; i <= n; ++i)
rt[i] = 0;
T = top = 0;
Dfs(1, 0, v);
Print(rt[1], 1, W, v);
}
int main()
{
read(n); read(K); read(W);
for (int i = 1; i <= n; ++i)
read(d[i]);
for (int i = 1, x, y; i < n; ++i)
{
read(x); read(y);
e[x].push_back(y);
e[y].push_back(x);
}
for (int i = 1; i <= n + 1; ++i)
solve(i);
Lagrange();
printf("%d\n", fans);
}