BZOJ5250 [多省联考2018] 秘密袭击
Description
一棵 $n$ 个点的树,第 $i$ 个点有点权 $d_i$ 。给定一个数 $k$ ,求所有 [ 大小不小于 $k$ 的连通块中的第 $k$ 大的点权 ] 的和。
$k \leqslant n \leqslant 1666$ ,点权最大值 $W \leqslant 1666$ 。
Solution
如果我们对于每一个 $w=1\dots W$ 都求出它作为第 $k$ 大的方案数,那就做完了。
显然,第 $k$ 大的数是 $w$ 的方案数,相当于其大于等于 $w$ 的方案数,减去其大于等于 $w+1$ 的方案数。
令 $a_i$ 表示第 $k$ 大的数大于等于 $i$ 的方案数,则答案即为 $\sum_{i=1}^W i(a_i-a_{i+1})$ ,容易发现它就等于 $\sum_{i=1}^W a_i$ 。
先考虑暴力 dp 。考虑枚举 $i$ 之后如何求 $a_i$ 。
显然,第 $k$ 大的数大于等于 $i$ ,等价于大于等于 $i$ 的个数至少有 $k$ 个。
我们令 $f_{a,b}$ 表示在 $a$ 的子树中选出 $b$ 一个连通块,连通块必须包含 $a$ 且大于等于 $i$ 的数恰好有 $b$ 个的方案数。
转移时在树上跑一遍背包即可。求 $a_i$ 就可以直接把所有 $f_{a, b} (b \ge k)$ 加起来。卡一下常就能AC了
考虑如何加速 dp 。
令 $F_a(x)=\sum_b f_{a, b}x^b$ ,即 $f_a$ 的生成函数。
树上背包的卷积形式相当于
$$F_a(x)=\left(\prod_{b\in son_a}(1+F_b(x))\right)*\begin{cases}1 & d_a \ge i\\x&d_a\lt i\end{cases}$$
求答案时可以令 $G_a(x)$ 表示 $a$ 子树里所有 $F_a(x)$ 的和,然后求 $G_{root}(x)$ 的第 $k\dots n$ 项和。
如果我们只需要给定 $x_0$ ,求出 $F_a(x_0), G_a(x_0)$ ,那么可以直接 $O(n)$ DP 。
那么显然我们可以利用拉格朗日插值法,通过 $n+1$ 次 DP 插值出答案。然而这并没有优化复杂度甚至还会更慢
那么我们看一看目前的 DP 代码,它大概长这个样:
DP(a, i, x0)
(f, g) = (1, 0)
for b in son(a)
(fb, gb) = DP(b, i, x0)
(f, g) = (f * (1 + fb), g + gb)
if d[a] >= i
(f, g) = (f * x0, g)
(f, g) = (f, g + f)
return (f, g)
对于每一个 $x_0$ ,最后的答案好像没有什么关联。那么能不能通过 $i$ 这一维来优化呢?
假设我们只枚举 $x_0$ ,而每次 DP 时求出 $i$ 取每一个值的时候的答案。
(f, g) = (1, 0)
-> 整体赋值(f, g) = (f * (1 + fb), g + gb)
-> 对应项合并if d[a] >= i / (f, g) = (f * x0, g)
-> 第 $1$ 到第 $d_a$ 项打标记(f, g) = (f, g + f)
-> 整体打标记
所以说可以利用线段树维护 DP 值。
打标记时 xjb搞 一番推导之后我们得出标记是这么个形式: (f, g) -> (x + y*f, g + z + w*f)
。
完了。线段树合并为什么不讲?因为讲不明白,自己看代码好了。
Code
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
inline int readInt() {
int ans = 0, c;
while (!isdigit(c = getchar()));
do ans = ans * 10 + c - '0';
while (isdigit(c = getchar()));
return ans;
}
const int N = 2000;
const int mod = 64123;
typedef long long LL;
int n, k, W, d[N];
int pre[N], nxt[N * 2], to[N * 2], cnt = 0;
inline void addEdge(int x, int y) {
[cnt] = pre[x];
nxt[pre[x] = cnt++] = y;
to[cnt] = pre[y];
nxt[pre[y] = cnt++] = x;
to}
const int M = N * N * 4;
int lc[M], rc[M], pool[M], cnt2;
struct Msg{
// (a, b) -> (x + y * a, b + z + w * a)
int x, y, z, w;
(int x = 0, int y = 1, int z = 0, int w = 0) : x(x), y(y), z(z), w(w) {}
Msginline friend Msg operator+(const Msg &a, const Msg &b) {
// (a, b)
// -> (x1 + y1 * a, b + z1 + w1 * a)
// -> (x2 + y2 * x1 + y2 * y1 * a, b + z1 + w1 * a + z2 + w2 * x1 + w2 * y1 * a)
// = ((x2 + y2 * x1) + (y2 * y1) * a, b + (z1 + z2 + w2 * x1) + (w1 + w2 * y1) * a)
return Msg(
((LL)a.x * b.y + b.x) % mod,
(LL)a.y * b.y % mod,
(a.z + b.z + (LL)b.w * a.x) % mod,
(a.w + (LL)b.w * a.y) % mod);
}
inline Msg& operator+=(const Msg &a) { return *this = *this + a; }
} msgv[M];
inline int newNode() {
int t = pool[cnt2++];
[t] = rc[t] = 0; msgv[t] = Msg();
lcreturn t;
}
inline void delNode(int t) { pool[--cnt2] = t; }
void pushd(int o) {
if (lc[o] == 0) lc[o] = newNode();
if (rc[o] == 0) rc[o] = newNode();
[lc[o]] += msgv[o];
msgv[rc[o]] += msgv[o];
msgv[o] = Msg();
msgv}
void modify(int o, int l, int r, int L, int R, const Msg &m) {
if (r < L || R < l) return;
if (L <= l && r <= R) {
[o] += m;
msgv} else {
(o);
pushdint mid = (l + r) >> 1;
(lc[o], l, mid, L, R, m);
modify(rc[o], mid + 1, r, L, R, m);
modify}
}
void delTree(int o) {
if (o != 0) {
(lc[o]);
delTree(rc[o]);
delTree(o);
delNode}
}
int sumv(int o, int l, int r) {
if (l == r) {
return msgv[o].z;
} else {
(o);
pushdint mid = (l + r) >> 1;
return (sumv(lc[o], l, mid) + sumv(rc[o], mid + 1, r)) % mod;
}
}
void merge(int &x, int y) {
if (y == 0) return;
if (x == 0) { x = y; return; }
if (lc[x] == 0 && rc[x] == 0) std::swap(x, y);
if (lc[y] == 0 && rc[y] == 0) {
[x].z = (msgv[x].z + msgv[y].z) % mod;
msgv[x].y = (LL)msgv[x].y * msgv[y].x % mod;
msgv[x].x = (LL)msgv[x].x * msgv[y].x % mod;
msgv} else {
(x); pushd(y);
pushd(lc[x], lc[y]);
merge(rc[x], rc[y]);
merge}
(y);
delNode}
int xx, ansv;
int dp(int x, int fa) {
int ans = newNode();
(ans, 1, W, 1, W, Msg(1, 1));
modifyfor (int i = pre[x]; i >= 0; i = nxt[i])
if (to[i] != fa) merge(ans, dp(to[i], x));
(ans, 1, W, 1, d[x], Msg(0, xx));
modify(ans, 1, W, 1, W, Msg(0, 1, 0, 1));
modify(ans, 1, W, 1, W, Msg(1, 1));
modifyreturn ans;
}
int A[N][N], fac[N], ifac[N], B[N];
void solve(int l, int r) {
if (l == r) {
int t = (LL)ifac[l] * ifac[n - l] * ((n - l) % 2 == 1 ? -1 : 1) % mod;
for (int i = 0; i <= n; ++i)
[l][i] = (LL)A[l][i] * t % mod;
Areturn;
}
int mid = (l + r) >> 1;
(A[mid + 1], A[l], sizeof(A[l]));
memcpyfor (int i = l; i <= mid; ++i) {
for (int j = n; j > 0; --j)
[mid + 1][j] = (A[mid + 1][j - 1] - (LL)i * A[mid + 1][j]) % mod;
A[mid + 1][0] = -(LL)i * A[mid + 1][0] % mod;
A}
for (int i = mid + 1; i <= r; ++i) {
for (int j = n; j > 0; --j)
[l][j] = (A[l][j - 1] - (LL)i * A[l][j]) % mod;
A[l][0] = -(LL)i * A[l][0] % mod;
A}
(l, mid);
solve(mid + 1, r);
solve}
int main() {
= readInt();
n [0] = ifac[1] = 1;
ifacfor (int i = 2; i <= n; ++i) ifac[i] = - (LL)(mod / i) * ifac[mod % i] % mod;
for (int i = 2; i <= n; ++i) ifac[i] = (LL)ifac[i] * ifac[i - 1] % mod;
[0][0] = 1;
A(0, n);
solve
= readInt();
k for (int i = 0; i <= n; ++i) {
for (int j = k; j <= n; ++j) B[i] += A[i][j];
[i] %= mod;
B}
= readInt();
W for (int i = 1; i <= n; ++i) d[i] = readInt();
(pre, -1, sizeof pre);
memsetfor (int i = 1; i < n; ++i) addEdge(readInt(), readInt());
for (int i = 0; i < M - 1; ++i) pool[i] = i + 1;
int ans = 0;
for (int i = 0; i <= n; ++i) {
= i;
xx int t = dp(1, 0);
= (ans + (LL)sumv(t, 1, W) * B[i]) % mod;
ans (t);
delTree}
("%d\n", (ans + mod) % mod);
printfreturn 0;
}