BZOJ5119 [清华集训2017] 生成树计数

2018 年 03 月 17 日发布.

Description

一个$s$个点的图,目前被$s-n$条边连成了$n$个连通块,第$i$个连通块大小为$a_i$。要求你再连$n-1$条边把它变成一棵树。对于每一棵树,若第$i$个连通块连出了$d_i$条边,其价值为

$val(T)=\left(\prod_{i=1}^nd_i^m\right)\left(\sum_{i=1}^nd_i^m\right)$

求所有生成树的价值和$\bmod998244353$。$n \leqslant 30000, m \leqslant 30$。

Note: BZOJ 上出了问题(据说评测出了问题,页面上的数据下载下来本地 50s ),可以去 LOJ 交。

Solution

首先,“连通块的大小”只是起到一个系数的作用,如果在第 $i$ 个连通块上连了 $d_i$ 条边,那么方案数要乘 $a_i^{d_i}$ 。

所以说 $$ \begin{aligned} ans&=\sum_T\left(\prod_{i=1}^na_i^{d_i}d_i^m\right)\left(\sum_{i=1}^nd_i^m\right)\\ &=\sum_{i=1}^n\sum_T\left(\prod_{i=1}^na_i^{d_i}d_i^m\right)d_i^m\\ &=\sum_{i=1}^n\sum_Ta_i^{d_i}d_i^{2m}\left(\prod_{j\neq i}a_j^{d_j}d_j^m\right) \end{aligned} $$ 推不下去了,因为“枚举所有树”太难。

我们考虑整个式子跟树的结构无关而只跟每个点的度数有关,于是可以想到 prufer 序列。

在 prufer 序列中,如果某个点的编号出现了 $k$ 次,那么它的度数即为 $k+1$ 。

于是枚举所有树可以变成枚举所有长为 $n-2$ ,元素为 $1\dots n$ 的序列;令 $k_i$ 表示序列中 $i$ 这个点出现的次数,则 $d_i=k_i+1$ 。显然如果我固定所有 $k$ 之后其方案数为 $\cfrac{(n-2)!}{\prod_i(k_i!)}$ ,所以答案即为 $$ \begin{aligned} ans&=\sum_{i=1}^n\sum_Ta_i^{d_i}d_i^{2m}\left(\prod_{j\neq i}a_i^{d_i}d_i^m\right)\\ &=\sum_{i=1}^n\sum_{k_1+k_2+\dots+k_n=n-2}\frac{(n-2)!}{\prod_i(k_i!)}a_i^{k_i+1}(k_i+1)^{2m}\left(\prod_{j\neq i}a_j^{k_j+1}(k_j+1)^m\right)\\ \end{aligned} $$ 可以发现这就是一个排列问题(枚举 $i$ 之后令第 $i$ 个点选 $k$ 个位置的价值为 $a_i^{k+1}(k+1)^{2m}$ ,其它第 $j$ 个点选 $k$ 的价值为 $a_j^{k+1}(k+1)^m$ ),于是可以想到指数型生成函数。

也即 $$ \begin{aligned} F(x)&=\sum_{i=1}^nA_i(x)\prod_{j\neq i}B_j(x)\\ A_i(x)&=\sum_ka_i^{k+1}(k+1)^{2m}\frac{x^k}{k!}\\ B_i(x)&=\sum_ka_i^{k+1}(k+1)^m\frac{x^k}{k!} \end{aligned} $$ 最终答案即为 $F$ 第 $n-2$ 项的系数乘上 $(n-2)!$ (别忘了这是指数型生成函数)。

我们发现这样复杂度是 $O(n^2\log n)$ ,无法忍受。

考虑如何简化 $A_i(x)$ 。考虑到 $m$ 相比于 $n$ 很小,根据第二类 Stirling 数的性质 $a^m=\sum_{i=0}^mS(m,i)a^{\underline i}$,我们有 $$ \begin{aligned} A_i(x)&=\sum_ka_i^{k+1}(k+1)^{2m}\frac{x^k}{k!}\\ T(x)=\int A_i(x)\mathrm{d}x&=\sum_ka_i^{k+1}(k+1)^{2m}\frac{x^{k+1}}{(k+1)!}\\ &=\sum_ka_i^kk^{2m}\frac{x^k}{k!}\\ &=\sum_{j=0}^{2m}S(2m,j)\sum_ka_i^kk^{\underline j}\frac{x^k}{k!}\\ &=\sum_{j=0}^{2m}S(2m,j)\sum_ka_i^k\frac{x^k}{(k-j)!}\\ &=\sum_{j=0}^{2m}S(2m,j)a_i^jx^je^{a_ix}\\ A_i(x)=\frac{\mathrm{d}T(x)}{\mathrm{d}x}&=\sum_{j=0}^{2m}\left[S(2m,j)ja_i^jx^{j-1}e^{a_ix}+S(2m,j)a_i^{j+1}x^je^{a_ix}\right]\\ &=e^{a_ix}\sum_{j=0}^{2m}\left[S(2m,j+1)(j+1)a_i^{j+1}+S(2m,j)a_i^{j+1}\right]x^j\\ &=e^{a_ix}\sum_{j=0}^{2m}S(2m+1,j+1)a_i^{j+1}x^j\\ \end{aligned} $$ 最后一步是由于第二类 Stirling 数的递推公式 $S(i,j)=S(i-1,j)j+S(i-1,j-1)$ 。

同理有$B_i(x)=e^{a_ix}\sum_{j=0}^{m}S(m+1,j+1)a_i^{j+1}x^j$。

所以 $$ \begin{aligned} F(x)&=\sum_{i=1}^nA_i(x)\prod_{j\neq i}B_j(x)\\ &=\sum_{i=1}^n\left(e^{a_ix}\sum_{k=0}^{2m}S(2m+1,k+1)a_i^{k+1}x^k\right)\prod_{j \neq i}\left(e^{a_jx}\sum_{k=0}^{m}S(m+1,k+1)a_j^{k+1}x^k\right)\\ &=e^{sx}\sum_{i=1}^n\left(\sum_{k=0}^{2m}S(2m+1,k+1)a_i^{k+1}x^k\right)\prod_{j \neq i}\left(\sum_{k=0}^{m}S(m+1,k+1)a_j^{k+1}x^k\right) \end{aligned} $$ 于是所有要乘起来的式子都变成了次数不超过 $2m$ 的多项式。

可以利用分治 NTT,分治过程求出 $B(x)=\prod_iB_i(x)$ 和 $A(x)=\sum_iA_i(x)\prod_{j\neq i}B_j(x)$ 。

时间复杂度 $O(nm\log^2 n)$ 。

分治时可以仅保留前 $n-2$ 项。

Code

Note: 代码中 $A(x), B(x)$ 和 Solution 中的 $A(x), B(x)$ 反了过来。

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <vector>

const int N = 30050;
const int M = 35;
const int mod = 998244353;
const int g = 3;

typedef long long LL;
typedef std::vector<LL> VLL;

int n, m;
LL a[N], fac[N], ifac[N], inv[N];

inline LL pow_mod(LL x, int b) {
  LL ans = 1;
  for ((b += mod - 1) %= (mod - 1); b; b >>= 1, (x *= x) %= mod)
    if ((b & 1) != 0) (ans *= x) %= mod;
  return ans;
}

void NTT(LL *A, int len, int opt) {
  for (int i = 1, j = 0; i < len; ++i) {
    int k = len;
    do j ^= (k >>= 1); while ((j & k) == 0);
    if (i < j) std::swap(A[i], A[j]);
  }
  for (int h = 2; h <= len; h <<= 1) {
    LL wn = pow_mod(g, (mod - 1) / h * opt);
    for (int j = 0; j < len; j += h) {
      LL w = 1LL;
      for (int i = j; i < j + (h >> 1); ++i) {
        LL _tmp1 = A[i], _tmp2 = A[i + (h >> 1)] * w % mod;
        A[i] = (_tmp1 + _tmp2) % mod;
        A[i + (h >> 1)] = (_tmp1 - _tmp2) % mod;
        (w *= wn) %= mod;
      }
    }
  }

  if (opt == -1)
    for (int i = 0; i < len; ++i)
      (A[i] *= -(mod - 1) / len) %= mod;
}

LL S[2 * M][2 * M];

struct PVLL{
  VLL A, B;
  PVLL() : A(0), B(0) {}
};

inline void Copy(const VLL &x, LL *A, int len) {
  for (int i = 0; i < len; ++i) A[i] = (i < x.size() ? x[i] : 0);
}

PVLL operator*(const PVLL &x, const PVLL &y) {
  static LL T1[N * M * 4], T2[N * M * 4], T3[N * M * 4];
  PVLL ans;
  int len = 1,
      ansl = std::max(x.A.size() + y.B.size(), x.B.size() + y.A.size());
  while (len < ansl - 1) len <<= 1;

  Copy(x.A, T1, len);
  NTT(T1, len, 1);
  Copy(y.A, T2, len);
  NTT(T2, len, 1);

  for (int i = 0; i < len; ++i) T3[i] = T1[i] * T2[i] % mod;
  NTT(T3, len, -1);
  ans.A.resize(std::min<int>(n, x.A.size() + y.A.size() - 1));
  for (int i = 0; i < std::min<int>(n, x.A.size() + y.A.size() - 1); ++i)
    (ans.A[i] = T3[i]) %= mod;

  Copy(x.B, T3, len);
  NTT(T3, len, 1);
  for (int i = 0; i < len; ++i) (T3[i] *= T2[i]) %= mod;
  NTT(T3, len, -1);
  ans.B.resize(std::min(n, ansl - 1));
  for (int i = 0; i < std::min(n, ansl - 1); ++i)
    ans.B[i] = T3[i];

  Copy(y.B, T3, len);
  NTT(T3, len, 1);
  for (int i = 0; i < len; ++i) (T3[i] *= T1[i]) %= mod;
  NTT(T3, len, -1);
  for (int i = 0; i < std::min(n, ansl - 1); ++i)
    (ans.B[i] += T3[i]) %= mod;
  return ans;
}

PVLL Solve(int l, int r) {
  if (l == r - 1) {
    PVLL ans;
    ans.A.resize(m + 1);
    ans.B.resize(2 * m + 1);
    LL pa = a[l];
    for (int i = 0; i <= 2 * m; ++i, (pa *= a[l]) %= mod) {
      if (i <= m) ans.A[i] = (LL)pa * S[m + 1][i + 1] % mod;
      ans.B[i] = (LL)pa * S[2 * m + 1][i + 1] % mod;
    }
    return ans;
  }
  int mid = (l + r + 1) >> 1;
  return Solve(l, mid) * Solve(mid, r);
}

LL ps[N];

int main() {
  scanf("%d%d", &n, &m);

  fac[0] = fac[1] = ifac[0] = ifac[1] = inv[1] = 1;
  for (int i = 2; i <= n; ++i) {
    fac[i] = fac[i - 1] * i % mod;
    inv[i] = -(mod / i) * inv[mod % i] % mod;
    ifac[i] = ifac[i - 1] * inv[i] % mod;
  }

  S[0][0] = 1;
  for (int i = 1; i <= 2 * m + 1; ++i)
    for (int j = 1; j <= 2 * m + 1; ++j)
      S[i][j] = (S[i - 1][j - 1] + S[i - 1][j] * j) % mod;

  int s = 0;
  for (int i = 0; i < n; ++i) {
    scanf("%lld", &a[i]);
    (s += a[i]) %= mod;
  }
  PVLL res = Solve(0, n);
  LL ans = 0;
  ps[0] = 1;
  for (int i = 1; i <= n - 2; ++i) ps[i] = ps[i - 1] * s % mod;
  for (int i = 0; i <= n - 2; ++i)
    (ans += res.B[i] * ps[n - 2 - i] % mod * ifac[n - 2 - i] % mod) %= mod;
  printf("%lld\n", (ans * fac[n - 2] % mod + mod) % mod);
  return 0;
}