BZOJ5250 [多省联考2018] 秘密袭击

2018 年 04 月 8 日发布.

Description

一棵 $n$ 个点的树,第 $i$ 个点有点权 $d_i$ 。给定一个数 $k$ ,求所有 [ 大小不小于 $k$ 的连通块中的第 $k$ 大的点权 ] 的和。

$k \leqslant n \leqslant 1666$ ,点权最大值 $W \leqslant 1666$ 。

Solution

如果我们对于每一个 $w=1\dots W$ 都求出它作为第 $k$ 大的方案数,那就做完了。

显然,第 $k$ 大的数是 $w$ 的方案数,相当于其大于等于 $w$ 的方案数,减去其大于等于 $w+1$ 的方案数。

令 $a_i$ 表示第 $k$ 大的数大于等于 $i$ 的方案数,则答案即为 $\sum_{i=1}^W i(a_i-a_{i+1})$ ,容易发现它就等于 $\sum_{i=1}^W a_i$ 。

先考虑暴力 dp 。考虑枚举 $i$ 之后如何求 $a_i$ 。

显然,第 $k$ 大的数大于等于 $i$ ,等价于大于等于 $i$ 的个数至少有 $k$ 个。

我们令 $f_{a,b}$ 表示在 $a$ 的子树中选出 $b$ 一个连通块,连通块必须包含 $a$ 且大于等于 $i$ 的数恰好有 $b$ 个的方案数。

转移时在树上跑一遍背包即可。求 $a_i$ 就可以直接把所有 $f_{a, b} (b \ge k)$ 加起来。卡一下常就能AC了

考虑如何加速 dp 。

令 $F_a(x)=\sum_b f_{a, b}x^b$ ,即 $f_a$ 的生成函数。

树上背包的卷积形式相当于

$$F_a(x)=\left(\prod_{b\in son_a}(1+F_b(x))\right)*\begin{cases}1 & d_a \ge i\\x&d_a\lt i\end{cases}$$

求答案时可以令 $G_a(x)$ 表示 $a$ 子树里所有 $F_a(x)$ 的和,然后求 $G_{root}(x)$ 的第 $k\dots n$ 项和。

如果我们只需要给定 $x_0$ ,求出 $F_a(x_0), G_a(x_0)$ ,那么可以直接 $O(n)$ DP 。

那么显然我们可以利用拉格朗日插值法,通过 $n+1$ 次 DP 插值出答案。然而这并没有优化复杂度甚至还会更慢

那么我们看一看目前的 DP 代码,它大概长这个样:

DP(a, i, x0)
  (f, g) = (1, 0)
  for b in son(a)
    (fb, gb) = DP(b, i, x0)
    (f, g) = (f * (1 + fb), g + gb)
  if d[a] >= i
    (f, g) = (f * x0, g)
  (f, g) = (f, g + f)
  return (f, g)

对于每一个 $x_0$ ,最后的答案好像没有什么关联。那么能不能通过 $i$ 这一维来优化呢?

假设我们只枚举 $x_0$ ,而每次 DP 时求出 $i$ 取每一个值的时候的答案。

所以说可以利用线段树维护 DP 值。

打标记时 xjb搞 一番推导之后我们得出标记是这么个形式: (f, g) -> (x + y*f, g + z + w*f)

完了。线段树合并为什么不讲?因为讲不明白,自己看代码好了。

Code

#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>

inline int readInt() {
  int ans = 0, c;
  while (!isdigit(c = getchar()));
  do ans = ans * 10 + c - '0';
  while (isdigit(c = getchar()));
  return ans;
}

const int N = 2000;
const int mod = 64123;
typedef long long LL;

int n, k, W, d[N];
int pre[N], nxt[N * 2], to[N * 2], cnt = 0;

inline void addEdge(int x, int y) {
  nxt[cnt] = pre[x];
  to[pre[x] = cnt++] = y;
  nxt[cnt] = pre[y];
  to[pre[y] = cnt++] = x;
}

const int M = N * N * 4;

int lc[M], rc[M], pool[M], cnt2;

struct Msg{
  // (a, b) -> (x + y * a, b + z + w * a)
  int x, y, z, w;
  Msg(int x = 0, int y = 1, int z = 0, int w = 0) : x(x), y(y), z(z), w(w) {}
  inline friend Msg operator+(const Msg &a, const Msg &b) {
    // (a, b)
    // -> (x1 + y1 * a, b + z1 + w1 * a)
    // -> (x2 + y2 * x1 + y2 * y1 * a, b + z1 + w1 * a + z2 + w2 * x1 + w2 * y1 * a)
    //  = ((x2 + y2 * x1) + (y2 * y1) * a, b + (z1 + z2 + w2 * x1) + (w1 + w2 * y1) * a)
    return Msg(
        ((LL)a.x * b.y + b.x) % mod,
        (LL)a.y * b.y % mod,
        (a.z + b.z + (LL)b.w * a.x) % mod,
        (a.w + (LL)b.w * a.y) % mod);
  }
  inline Msg& operator+=(const Msg &a) { return *this = *this + a; }
} msgv[M];

inline int newNode() {
  int t = pool[cnt2++];
  lc[t] = rc[t] = 0; msgv[t] = Msg();
  return t;
}

inline void delNode(int t) { pool[--cnt2] = t; }

void pushd(int o) {
  if (lc[o] == 0) lc[o] = newNode();
  if (rc[o] == 0) rc[o] = newNode();
  msgv[lc[o]] += msgv[o];
  msgv[rc[o]] += msgv[o];
  msgv[o] = Msg();
}

void modify(int o, int l, int r, int L, int R, const Msg &m) {
  if (r < L || R < l) return;
  if (L <= l && r <= R) {
    msgv[o] += m;
  } else {
    pushd(o);
    int mid = (l + r) >> 1;
    modify(lc[o], l, mid, L, R, m);
    modify(rc[o], mid + 1, r, L, R, m);
  }
}

void delTree(int o) {
  if (o != 0) {
    delTree(lc[o]);
    delTree(rc[o]);
    delNode(o);
  }
}

int sumv(int o, int l, int r) {
  if (l == r) {
    return msgv[o].z;
  } else {
    pushd(o);
    int mid = (l + r) >> 1;
    return (sumv(lc[o], l, mid) + sumv(rc[o], mid + 1, r)) % mod;
  }
}

void merge(int &x, int y) {
  if (y == 0) return;
  if (x == 0) { x = y; return; }
  if (lc[x] == 0 && rc[x] == 0) std::swap(x, y);
  if (lc[y] == 0 && rc[y] == 0) {
    msgv[x].z = (msgv[x].z + msgv[y].z) % mod;
    msgv[x].y = (LL)msgv[x].y * msgv[y].x % mod;
    msgv[x].x = (LL)msgv[x].x * msgv[y].x % mod;
  } else {
    pushd(x); pushd(y);
    merge(lc[x], lc[y]);
    merge(rc[x], rc[y]);
  }
  delNode(y);
}

int xx, ansv;

int dp(int x, int fa) {
  int ans = newNode();
  modify(ans, 1, W, 1, W, Msg(1, 1));
  for (int i = pre[x]; i >= 0; i = nxt[i])
    if (to[i] != fa) merge(ans, dp(to[i], x));
  modify(ans, 1, W, 1, d[x], Msg(0, xx));
  modify(ans, 1, W, 1, W, Msg(0, 1, 0, 1));
  modify(ans, 1, W, 1, W, Msg(1, 1));
  return ans;
}

int A[N][N], fac[N], ifac[N], B[N];

void solve(int l, int r) {
  if (l == r) {
    int t = (LL)ifac[l] * ifac[n - l] * ((n - l) % 2 == 1 ? -1 : 1) % mod;
    for (int i = 0; i <= n; ++i)
      A[l][i] = (LL)A[l][i] * t % mod;
    return;
  }
  int mid = (l + r) >> 1;
  memcpy(A[mid + 1], A[l], sizeof(A[l]));
  for (int i = l; i <= mid; ++i) {
    for (int j = n; j > 0; --j)
      A[mid + 1][j] = (A[mid + 1][j - 1] - (LL)i * A[mid + 1][j]) % mod;
    A[mid + 1][0] = -(LL)i * A[mid + 1][0] % mod;
  }
  for (int i = mid + 1; i <= r; ++i) {
    for (int j = n; j > 0; --j)
      A[l][j] = (A[l][j - 1] - (LL)i * A[l][j]) % mod;
    A[l][0] = -(LL)i * A[l][0] % mod;
  }
  solve(l, mid);
  solve(mid + 1, r);
}

int main() {
  n = readInt();
  ifac[0] = ifac[1] = 1;
  for (int i = 2; i <= n; ++i) ifac[i] = - (LL)(mod / i) * ifac[mod % i] % mod;
  for (int i = 2; i <= n; ++i) ifac[i] = (LL)ifac[i] * ifac[i - 1] % mod;
  A[0][0] = 1;
  solve(0, n);

  k = readInt();
  for (int i = 0; i <= n; ++i) {
    for (int j = k; j <= n; ++j) B[i] += A[i][j];
    B[i] %= mod;
  }

  W = readInt();
  for (int i = 1; i <= n; ++i) d[i] = readInt();
  memset(pre, -1, sizeof pre);
  for (int i = 1; i < n; ++i) addEdge(readInt(), readInt());

  for (int i = 0; i < M - 1; ++i) pool[i] = i + 1;

  int ans = 0;
  for (int i = 0; i <= n; ++i) {
    xx = i;
    int t = dp(1, 0);
    ans = (ans + (LL)sumv(t, 1, W) * B[i]) % mod;
    delTree(t);
  }
  printf("%d\n", (ans + mod) % mod);
  return 0;
}