BZOJ5250 [多省联考2018] 秘密袭击

Description

一棵 \(n\) 个点的树,第 \(i\) 个点有点权 \(d_i\) 。给定一个数 \(k\) ,求所有 [ 大小不小于 \(k\) 的连通块中的第 \(k\) 大的点权 ] 的和。

\(k \leqslant n \leqslant 1666\) ,点权最大值 \(W \leqslant 1666\)

Solution

如果我们对于每一个 \(w=1\dots W\) 都求出它作为第 \(k\) 大的方案数,那就做完了。

显然,第 \(k\) 大的数是 \(w\) 的方案数,相当于其大于等于 \(w\) 的方案数,减去其大于等于 \(w+1\) 的方案数。

\(a_i\) 表示第 \(k\) 大的数大于等于 \(i\) 的方案数,则答案即为 \(\sum_{i=1}^W i(a_i-a_{i+1})\) ,容易发现它就等于 \(\sum_{i=1}^W a_i\)

先考虑暴力 dp 。考虑枚举 \(i\) 之后如何求 \(a_i\)

显然,第 \(k\) 大的数大于等于 \(i\) ,等价于大于等于 \(i\) 的个数至少有 \(k\) 个。

我们令 \(f_{a,b}\) 表示在 \(a\) 的子树中选出 \(b\) 一个连通块,连通块必须包含 \(a\) 且大于等于 \(i\) 的数恰好有 \(b\) 个的方案数。

转移时在树上跑一遍背包即可。求 \(a_i\) 就可以直接把所有 \(f_{a, b} (b \ge k)\) 加起来。卡一下常就能AC了

考虑如何加速 dp 。

\(F_a(x)=\sum_b f_{a, b}x^b\) ,即 \(f_a\) 的生成函数。

树上背包的卷积形式相当于

\[F_a(x)=\left(\prod_{b\in son_a}(1+F_b(x))\right)*\begin{cases}1 & d_a \ge i\\x&d_a\lt i\end{cases}\]

求答案时可以令 \(G_a(x)\) 表示 \(a\) 子树里所有 \(F_a(x)\) 的和,然后求 \(G_{root}(x)\) 的第 \(k\dots n\) 项和。

如果我们只需要给定 \(x_0\) ,求出 \(F_a(x_0), G_a(x_0)\) ,那么可以直接 \(O(n)\) DP 。

那么显然我们可以利用拉格朗日插值法,通过 \(n+1\) 次 DP 插值出答案。然而这并没有优化复杂度甚至还会更慢

那么我们看一看目前的 DP 代码,它大概长这个样:

1
2
3
4
5
6
7
8
9
DP(a, i, x0)
(f, g) = (1, 0)
for b in son(a)
(fb, gb) = DP(b, i, x0)
(f, g) = (f * (1 + fb), g + gb)
if d[a] >= i
(f, g) = (f * x0, g)
(f, g) = (f, g + f)
return (f, g)

对于每一个 \(x_0\) ,最后的答案好像没有什么关联。那么能不能通过 \(i\) 这一维来优化呢?

假设我们只枚举 \(x_0\) ,而每次 DP 时求出 \(i\) 取每一个值的时候的答案。

  • (f, g) = (1, 0) -> 整体赋值
  • (f, g) = (f * (1 + fb), g + gb) -> 对应项合并
  • if d[a] >= i / (f, g) = (f * x0, g) -> 第 \(1\) 到第 \(d_a\) 项打标记
  • (f, g) = (f, g + f) -> 整体打标记

所以说可以利用线段树维护 DP 值。

打标记时 xjb搞 一番推导之后我们得出标记是这么个形式: (f, g) -> (x + y*f, g + z + w*f)

完了。线段树合并为什么不讲?因为讲不明白,自己看代码好了。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>

inline int readInt() {
int ans = 0, c;
while (!isdigit(c = getchar()));
do ans = ans * 10 + c - '0';
while (isdigit(c = getchar()));
return ans;
}

const int N = 2000;
const int mod = 64123;
typedef long long LL;

int n, k, W, d[N];
int pre[N], nxt[N * 2], to[N * 2], cnt = 0;

inline void addEdge(int x, int y) {
nxt[cnt] = pre[x];
to[pre[x] = cnt++] = y;
nxt[cnt] = pre[y];
to[pre[y] = cnt++] = x;
}

const int M = N * N * 4;

int lc[M], rc[M], pool[M], cnt2;

struct Msg{
// (a, b) -> (x + y * a, b + z + w * a)
int x, y, z, w;
Msg(int x = 0, int y = 1, int z = 0, int w = 0) : x(x), y(y), z(z), w(w) {}
inline friend Msg operator+(const Msg &a, const Msg &b) {
// (a, b)
// -> (x1 + y1 * a, b + z1 + w1 * a)
// -> (x2 + y2 * x1 + y2 * y1 * a, b + z1 + w1 * a + z2 + w2 * x1 + w2 * y1 * a)
// = ((x2 + y2 * x1) + (y2 * y1) * a, b + (z1 + z2 + w2 * x1) + (w1 + w2 * y1) * a)
return Msg(
((LL)a.x * b.y + b.x) % mod,
(LL)a.y * b.y % mod,
(a.z + b.z + (LL)b.w * a.x) % mod,
(a.w + (LL)b.w * a.y) % mod);
}
inline Msg& operator+=(const Msg &a) { return *this = *this + a; }
} msgv[M];

inline int newNode() {
int t = pool[cnt2++];
lc[t] = rc[t] = 0; msgv[t] = Msg();
return t;
}

inline void delNode(int t) { pool[--cnt2] = t; }

void pushd(int o) {
if (lc[o] == 0) lc[o] = newNode();
if (rc[o] == 0) rc[o] = newNode();
msgv[lc[o]] += msgv[o];
msgv[rc[o]] += msgv[o];
msgv[o] = Msg();
}

void modify(int o, int l, int r, int L, int R, const Msg &m) {
if (r < L || R < l) return;
if (L <= l && r <= R) {
msgv[o] += m;
} else {
pushd(o);
int mid = (l + r) >> 1;
modify(lc[o], l, mid, L, R, m);
modify(rc[o], mid + 1, r, L, R, m);
}
}

void delTree(int o) {
if (o != 0) {
delTree(lc[o]);
delTree(rc[o]);
delNode(o);
}
}

int sumv(int o, int l, int r) {
if (l == r) {
return msgv[o].z;
} else {
pushd(o);
int mid = (l + r) >> 1;
return (sumv(lc[o], l, mid) + sumv(rc[o], mid + 1, r)) % mod;
}
}

void merge(int &x, int y) {
if (y == 0) return;
if (x == 0) { x = y; return; }
if (lc[x] == 0 && rc[x] == 0) std::swap(x, y);
if (lc[y] == 0 && rc[y] == 0) {
msgv[x].z = (msgv[x].z + msgv[y].z) % mod;
msgv[x].y = (LL)msgv[x].y * msgv[y].x % mod;
msgv[x].x = (LL)msgv[x].x * msgv[y].x % mod;
} else {
pushd(x); pushd(y);
merge(lc[x], lc[y]);
merge(rc[x], rc[y]);
}
delNode(y);
}

int xx, ansv;

int dp(int x, int fa) {
int ans = newNode();
modify(ans, 1, W, 1, W, Msg(1, 1));
for (int i = pre[x]; i >= 0; i = nxt[i])
if (to[i] != fa) merge(ans, dp(to[i], x));
modify(ans, 1, W, 1, d[x], Msg(0, xx));
modify(ans, 1, W, 1, W, Msg(0, 1, 0, 1));
modify(ans, 1, W, 1, W, Msg(1, 1));
return ans;
}

int A[N][N], fac[N], ifac[N], B[N];

void solve(int l, int r) {
if (l == r) {
int t = (LL)ifac[l] * ifac[n - l] * ((n - l) % 2 == 1 ? -1 : 1) % mod;
for (int i = 0; i <= n; ++i)
A[l][i] = (LL)A[l][i] * t % mod;
return;
}
int mid = (l + r) >> 1;
memcpy(A[mid + 1], A[l], sizeof(A[l]));
for (int i = l; i <= mid; ++i) {
for (int j = n; j > 0; --j)
A[mid + 1][j] = (A[mid + 1][j - 1] - (LL)i * A[mid + 1][j]) % mod;
A[mid + 1][0] = -(LL)i * A[mid + 1][0] % mod;
}
for (int i = mid + 1; i <= r; ++i) {
for (int j = n; j > 0; --j)
A[l][j] = (A[l][j - 1] - (LL)i * A[l][j]) % mod;
A[l][0] = -(LL)i * A[l][0] % mod;
}
solve(l, mid);
solve(mid + 1, r);
}

int main() {
n = readInt();
ifac[0] = ifac[1] = 1;
for (int i = 2; i <= n; ++i) ifac[i] = - (LL)(mod / i) * ifac[mod % i] % mod;
for (int i = 2; i <= n; ++i) ifac[i] = (LL)ifac[i] * ifac[i - 1] % mod;
A[0][0] = 1;
solve(0, n);

k = readInt();
for (int i = 0; i <= n; ++i) {
for (int j = k; j <= n; ++j) B[i] += A[i][j];
B[i] %= mod;
}

W = readInt();
for (int i = 1; i <= n; ++i) d[i] = readInt();
memset(pre, -1, sizeof pre);
for (int i = 1; i < n; ++i) addEdge(readInt(), readInt());

for (int i = 0; i < M - 1; ++i) pool[i] = i + 1;

int ans = 0;
for (int i = 0; i <= n; ++i) {
xx = i;
int t = dp(1, 0);
ans = (ans + (LL)sumv(t, 1, W) * B[i]) % mod;
delTree(t);
}
printf("%d\n", (ans + mod) % mod);
return 0;
}