_rqy's Blog

一只猫猫,想成为天才少女数学家!

0%

BZOJ5119 [清华集训2017] 生成树计数

Description

一个ss个点的图,目前被sns-n条边连成了nn个连通块,第ii个连通块大小为aia_i。要求你再连n1n-1条边把它变成一棵树。对于每一棵树,若第ii个连通块连出了did_i条边,其价值为

val(T)=(i=1ndim)(i=1ndim)val(T)=\left(\prod_{i=1}^nd_i^m\right)\left(\sum_{i=1}^nd_i^m\right)

求所有生成树的价值和mod998244353\bmod998244353n30000,m30n \leqslant 30000, m \leqslant 30

Note: BZOJ 上出了问题(据说评测出了问题,页面上的数据下载下来本地 50s ),可以去 LOJ 交。

Solution

首先,“连通块的大小”只是起到一个系数的作用,如果在第 ii 个连通块上连了 did_i 条边,那么方案数要乘 aidia_i^{d_i}

所以说

ans=T(i=1naididim)(i=1ndim)=i=1nT(i=1naididim)dim=i=1nTaididi2m(jiajdjdjm)\begin{aligned} ans&=\sum_T\left(\prod_{i=1}^na_i^{d_i}d_i^m\right)\left(\sum_{i=1}^nd_i^m\right)\\ &=\sum_{i=1}^n\sum_T\left(\prod_{i=1}^na_i^{d_i}d_i^m\right)d_i^m\\ &=\sum_{i=1}^n\sum_Ta_i^{d_i}d_i^{2m}\left(\prod_{j\neq i}a_j^{d_j}d_j^m\right) \end{aligned}

推不下去了,因为“枚举所有树”太难。

我们考虑整个式子跟树的结构无关而只跟每个点的度数有关,于是可以想到 prufer 序列。

在 prufer 序列中,如果某个点的编号出现了 kk 次,那么它的度数即为 k+1k+1

于是枚举所有树可以变成枚举所有长为 n2n-2 ,元素为 1n1\dots n 的序列;令 kik_i 表示序列中 ii 这个点出现的次数,则 di=ki+1d_i=k_i+1 。显然如果我固定所有 kk 之后其方案数为 (n2)!i(ki!)\cfrac{(n-2)!}{\prod_i(k_i!)} ,所以答案即为

ans=i=1nTaididi2m(jiaididim)=i=1nk1+k2++kn=n2(n2)!i(ki!)aiki+1(ki+1)2m(jiajkj+1(kj+1)m)\begin{aligned} ans&=\sum_{i=1}^n\sum_Ta_i^{d_i}d_i^{2m}\left(\prod_{j\neq i}a_i^{d_i}d_i^m\right)\\ &=\sum_{i=1}^n\sum_{k_1+k_2+\dots+k_n=n-2}\frac{(n-2)!}{\prod_i(k_i!)}a_i^{k_i+1}(k_i+1)^{2m}\left(\prod_{j\neq i}a_j^{k_j+1}(k_j+1)^m\right)\\ \end{aligned}

可以发现这就是一个排列问题(枚举 ii 之后令第 ii 个点选 kk 个位置的价值为 aik+1(k+1)2ma_i^{k+1}(k+1)^{2m} ,其它第 jj 个点选 kk 的价值为 ajk+1(k+1)ma_j^{k+1}(k+1)^m ),于是可以想到指数型生成函数。

也即

F(x)=i=1nAi(x)jiBj(x)Ai(x)=kaik+1(k+1)2mxkk!Bi(x)=kaik+1(k+1)mxkk!\begin{aligned} F(x)&=\sum_{i=1}^nA_i(x)\prod_{j\neq i}B_j(x)\\ A_i(x)&=\sum_ka_i^{k+1}(k+1)^{2m}\frac{x^k}{k!}\\ B_i(x)&=\sum_ka_i^{k+1}(k+1)^m\frac{x^k}{k!} \end{aligned}

最终答案即为 FFn2n-2 项的系数乘上 (n2)!(n-2)! (别忘了这是指数型生成函数)。

我们发现这样复杂度是 O(n2logn)O(n^2\log n) ,无法忍受。

考虑如何简化 Ai(x)A_i(x) 。考虑到 mm 相比于 nn 很小,根据第二类 Stirling 数的性质 am=i=0mS(m,i)aia^m=\sum_{i=0}^mS(m,i)a^{\underline i},我们有

Ai(x)=kaik+1(k+1)2mxkk!T(x)=Ai(x)dx=kaik+1(k+1)2mxk+1(k+1)!=kaikk2mxkk!=j=02mS(2m,j)kaikkjxkk!=j=02mS(2m,j)kaikxk(kj)!=j=02mS(2m,j)aijxjeaixAi(x)=dT(x)dx=j=02m[S(2m,j)jaijxj1eaix+S(2m,j)aij+1xjeaix]=eaixj=02m[S(2m,j+1)(j+1)aij+1+S(2m,j)aij+1]xj=eaixj=02mS(2m+1,j+1)aij+1xj\begin{aligned} A_i(x)&=\sum_ka_i^{k+1}(k+1)^{2m}\frac{x^k}{k!}\\ T(x)=\int A_i(x)\mathrm{d}x&=\sum_ka_i^{k+1}(k+1)^{2m}\frac{x^{k+1}}{(k+1)!}\\ &=\sum_ka_i^kk^{2m}\frac{x^k}{k!}\\ &=\sum_{j=0}^{2m}S(2m,j)\sum_ka_i^kk^{\underline j}\frac{x^k}{k!}\\ &=\sum_{j=0}^{2m}S(2m,j)\sum_ka_i^k\frac{x^k}{(k-j)!}\\ &=\sum_{j=0}^{2m}S(2m,j)a_i^jx^je^{a_ix}\\ A_i(x)=\frac{\mathrm{d}T(x)}{\mathrm{d}x}&=\sum_{j=0}^{2m}\left[S(2m,j)ja_i^jx^{j-1}e^{a_ix}+S(2m,j)a_i^{j+1}x^je^{a_ix}\right]\\ &=e^{a_ix}\sum_{j=0}^{2m}\left[S(2m,j+1)(j+1)a_i^{j+1}+S(2m,j)a_i^{j+1}\right]x^j\\ &=e^{a_ix}\sum_{j=0}^{2m}S(2m+1,j+1)a_i^{j+1}x^j\\ \end{aligned}

最后一步是由于第二类 Stirling 数的递推公式 S(i,j)=S(i1,j)j+S(i1,j1)S(i,j)=S(i-1,j)j+S(i-1,j-1)

同理有Bi(x)=eaixj=0mS(m+1,j+1)aij+1xjB_i(x)=e^{a_ix}\sum_{j=0}^{m}S(m+1,j+1)a_i^{j+1}x^j

所以

F(x)=i=1nAi(x)jiBj(x)=i=1n(eaixk=02mS(2m+1,k+1)aik+1xk)ji(eajxk=0mS(m+1,k+1)ajk+1xk)=esxi=1n(k=02mS(2m+1,k+1)aik+1xk)ji(k=0mS(m+1,k+1)ajk+1xk)\begin{aligned} F(x)&=\sum_{i=1}^nA_i(x)\prod_{j\neq i}B_j(x)\\ &=\sum_{i=1}^n\left(e^{a_ix}\sum_{k=0}^{2m}S(2m+1,k+1)a_i^{k+1}x^k\right)\prod_{j \neq i}\left(e^{a_jx}\sum_{k=0}^{m}S(m+1,k+1)a_j^{k+1}x^k\right)\\ &=e^{sx}\sum_{i=1}^n\left(\sum_{k=0}^{2m}S(2m+1,k+1)a_i^{k+1}x^k\right)\prod_{j \neq i}\left(\sum_{k=0}^{m}S(m+1,k+1)a_j^{k+1}x^k\right) \end{aligned}

于是所有要乘起来的式子都变成了次数不超过 2m2m 的多项式。

可以利用分治 NTT,分治过程求出 B(x)=iBi(x)B(x)=\prod_iB_i(x)A(x)=iAi(x)jiBj(x)A(x)=\sum_iA_i(x)\prod_{j\neq i}B_j(x)

时间复杂度 O(nmlog2n)O(nm\log^2 n)

分治时可以仅保留前 n2n-2 项。

Code

Note: 代码中 A(x),B(x)A(x), B(x)Solution 中的 A(x),B(x)A(x), B(x) 反了过来。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <vector>

const int N = 30050;
const int M = 35;
const int mod = 998244353;
const int g = 3;

typedef long long LL;
typedef std::vector<LL> VLL;

int n, m;
LL a[N], fac[N], ifac[N], inv[N];

inline LL pow_mod(LL x, int b) {
LL ans = 1;
for ((b += mod - 1) %= (mod - 1); b; b >>= 1, (x *= x) %= mod)
if ((b & 1) != 0) (ans *= x) %= mod;
return ans;
}

void NTT(LL *A, int len, int opt) {
for (int i = 1, j = 0; i < len; ++i) {
int k = len;
do j ^= (k >>= 1); while ((j & k) == 0);
if (i < j) std::swap(A[i], A[j]);
}
for (int h = 2; h <= len; h <<= 1) {
LL wn = pow_mod(g, (mod - 1) / h * opt);
for (int j = 0; j < len; j += h) {
LL w = 1LL;
for (int i = j; i < j + (h >> 1); ++i) {
LL _tmp1 = A[i], _tmp2 = A[i + (h >> 1)] * w % mod;
A[i] = (_tmp1 + _tmp2) % mod;
A[i + (h >> 1)] = (_tmp1 - _tmp2) % mod;
(w *= wn) %= mod;
}
}
}

if (opt == -1)
for (int i = 0; i < len; ++i)
(A[i] *= -(mod - 1) / len) %= mod;
}

LL S[2 * M][2 * M];

struct PVLL{
VLL A, B;
PVLL() : A(0), B(0) {}
};

inline void Copy(const VLL &x, LL *A, int len) {
for (int i = 0; i < len; ++i) A[i] = (i < x.size() ? x[i] : 0);
}

PVLL operator*(const PVLL &x, const PVLL &y) {
static LL T1[N * M * 4], T2[N * M * 4], T3[N * M * 4];
PVLL ans;
int len = 1,
ansl = std::max(x.A.size() + y.B.size(), x.B.size() + y.A.size());
while (len < ansl - 1) len <<= 1;

Copy(x.A, T1, len);
NTT(T1, len, 1);
Copy(y.A, T2, len);
NTT(T2, len, 1);

for (int i = 0; i < len; ++i) T3[i] = T1[i] * T2[i] % mod;
NTT(T3, len, -1);
ans.A.resize(std::min<int>(n, x.A.size() + y.A.size() - 1));
for (int i = 0; i < std::min<int>(n, x.A.size() + y.A.size() - 1); ++i)
(ans.A[i] = T3[i]) %= mod;

Copy(x.B, T3, len);
NTT(T3, len, 1);
for (int i = 0; i < len; ++i) (T3[i] *= T2[i]) %= mod;
NTT(T3, len, -1);
ans.B.resize(std::min(n, ansl - 1));
for (int i = 0; i < std::min(n, ansl - 1); ++i)
ans.B[i] = T3[i];

Copy(y.B, T3, len);
NTT(T3, len, 1);
for (int i = 0; i < len; ++i) (T3[i] *= T1[i]) %= mod;
NTT(T3, len, -1);
for (int i = 0; i < std::min(n, ansl - 1); ++i)
(ans.B[i] += T3[i]) %= mod;
return ans;
}

PVLL Solve(int l, int r) {
if (l == r - 1) {
PVLL ans;
ans.A.resize(m + 1);
ans.B.resize(2 * m + 1);
LL pa = a[l];
for (int i = 0; i <= 2 * m; ++i, (pa *= a[l]) %= mod) {
if (i <= m) ans.A[i] = (LL)pa * S[m + 1][i + 1] % mod;
ans.B[i] = (LL)pa * S[2 * m + 1][i + 1] % mod;
}
return ans;
}
int mid = (l + r + 1) >> 1;
return Solve(l, mid) * Solve(mid, r);
}

LL ps[N];

int main() {
scanf("%d%d", &n, &m);

fac[0] = fac[1] = ifac[0] = ifac[1] = inv[1] = 1;
for (int i = 2; i <= n; ++i) {
fac[i] = fac[i - 1] * i % mod;
inv[i] = -(mod / i) * inv[mod % i] % mod;
ifac[i] = ifac[i - 1] * inv[i] % mod;
}

S[0][0] = 1;
for (int i = 1; i <= 2 * m + 1; ++i)
for (int j = 1; j <= 2 * m + 1; ++j)
S[i][j] = (S[i - 1][j - 1] + S[i - 1][j] * j) % mod;

int s = 0;
for (int i = 0; i < n; ++i) {
scanf("%lld", &a[i]);
(s += a[i]) %= mod;
}
PVLL res = Solve(0, n);
LL ans = 0;
ps[0] = 1;
for (int i = 1; i <= n - 2; ++i) ps[i] = ps[i - 1] * s % mod;
for (int i = 0; i <= n - 2; ++i)
(ans += res.B[i] * ps[n - 2 - i] % mod * ifac[n - 2 - i] % mod) %= mod;
printf("%lld\n", (ans * fac[n - 2] % mod + mod) % mod);
return 0;
}