Description
一个s个点的图,目前被s−n条边连成了n个连通块,第i个连通块大小为ai。要求你再连n−1条边把它变成一棵树。对于每一棵树,若第i个连通块连出了di条边,其价值为
val(T)=(∏i=1ndim)(∑i=1ndim)
求所有生成树的价值和mod998244353。n⩽30000,m⩽30。
Note: BZOJ 上出了问题(据说评测出了问题,页面上的数据下载下来本地 50s ),可以去 LOJ 交。
Solution
首先,“连通块的大小”只是起到一个系数的作用,如果在第 i 个连通块上连了 di 条边,那么方案数要乘 aidi 。
所以说
ans=T∑(i=1∏naididim)(i=1∑ndim)=i=1∑nT∑(i=1∏naididim)dim=i=1∑nT∑aididi2m⎝⎜⎛j=i∏ajdjdjm⎠⎟⎞
推不下去了,因为“枚举所有树”太难。
我们考虑整个式子跟树的结构无关而只跟每个点的度数有关,于是可以想到 prufer 序列。
在 prufer 序列中,如果某个点的编号出现了 k 次,那么它的度数即为 k+1 。
于是枚举所有树可以变成枚举所有长为 n−2 ,元素为 1…n 的序列;令 ki 表示序列中 i 这个点出现的次数,则 di=ki+1 。显然如果我固定所有 k 之后其方案数为 ∏i(ki!)(n−2)! ,所以答案即为
ans=i=1∑nT∑aididi2m⎝⎜⎛j=i∏aididim⎠⎟⎞=i=1∑nk1+k2+⋯+kn=n−2∑∏i(ki!)(n−2)!aiki+1(ki+1)2m⎝⎜⎛j=i∏ajkj+1(kj+1)m⎠⎟⎞
可以发现这就是一个排列问题(枚举 i 之后令第 i 个点选 k 个位置的价值为 aik+1(k+1)2m ,其它第 j 个点选 k 的价值为 ajk+1(k+1)m ),于是可以想到指数型生成函数。
也即
F(x)Ai(x)Bi(x)=i=1∑nAi(x)j=i∏Bj(x)=k∑aik+1(k+1)2mk!xk=k∑aik+1(k+1)mk!xk
最终答案即为 F 第 n−2 项的系数乘上 (n−2)! (别忘了这是指数型生成函数)。
我们发现这样复杂度是 O(n2logn) ,无法忍受。
考虑如何简化 Ai(x) 。考虑到 m 相比于 n 很小,根据第二类 Stirling 数的性质 am=∑i=0mS(m,i)ai,我们有
Ai(x)T(x)=∫Ai(x)dxAi(x)=dxdT(x)=k∑aik+1(k+1)2mk!xk=k∑aik+1(k+1)2m(k+1)!xk+1=k∑aikk2mk!xk=j=0∑2mS(2m,j)k∑aikkjk!xk=j=0∑2mS(2m,j)k∑aik(k−j)!xk=j=0∑2mS(2m,j)aijxjeaix=j=0∑2m[S(2m,j)jaijxj−1eaix+S(2m,j)aij+1xjeaix]=eaixj=0∑2m[S(2m,j+1)(j+1)aij+1+S(2m,j)aij+1]xj=eaixj=0∑2mS(2m+1,j+1)aij+1xj
最后一步是由于第二类 Stirling 数的递推公式 S(i,j)=S(i−1,j)j+S(i−1,j−1) 。
同理有Bi(x)=eaix∑j=0mS(m+1,j+1)aij+1xj。
所以
F(x)=i=1∑nAi(x)j=i∏Bj(x)=i=1∑n(eaixk=0∑2mS(2m+1,k+1)aik+1xk)j=i∏(eajxk=0∑mS(m+1,k+1)ajk+1xk)=esxi=1∑n(k=0∑2mS(2m+1,k+1)aik+1xk)j=i∏(k=0∑mS(m+1,k+1)ajk+1xk)
于是所有要乘起来的式子都变成了次数不超过 2m 的多项式。
可以利用分治 NTT,分治过程求出 B(x)=∏iBi(x) 和 A(x)=∑iAi(x)∏j=iBj(x) 。
时间复杂度 O(nmlog2n) 。
分治时可以仅保留前 n−2 项。
Code
Note: 代码中 A(x),B(x) 和 Solution 中的 A(x),B(x) 反了过来。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
| #include <algorithm> #include <cstdio> #include <cstring> #include <vector>
const int N = 30050; const int M = 35; const int mod = 998244353; const int g = 3;
typedef long long LL; typedef std::vector<LL> VLL;
int n, m; LL a[N], fac[N], ifac[N], inv[N];
inline LL pow_mod(LL x, int b) { LL ans = 1; for ((b += mod - 1) %= (mod - 1); b; b >>= 1, (x *= x) %= mod) if ((b & 1) != 0) (ans *= x) %= mod; return ans; }
void NTT(LL *A, int len, int opt) { for (int i = 1, j = 0; i < len; ++i) { int k = len; do j ^= (k >>= 1); while ((j & k) == 0); if (i < j) std::swap(A[i], A[j]); } for (int h = 2; h <= len; h <<= 1) { LL wn = pow_mod(g, (mod - 1) / h * opt); for (int j = 0; j < len; j += h) { LL w = 1LL; for (int i = j; i < j + (h >> 1); ++i) { LL _tmp1 = A[i], _tmp2 = A[i + (h >> 1)] * w % mod; A[i] = (_tmp1 + _tmp2) % mod; A[i + (h >> 1)] = (_tmp1 - _tmp2) % mod; (w *= wn) %= mod; } } }
if (opt == -1) for (int i = 0; i < len; ++i) (A[i] *= -(mod - 1) / len) %= mod; }
LL S[2 * M][2 * M];
struct PVLL{ VLL A, B; PVLL() : A(0), B(0) {} };
inline void Copy(const VLL &x, LL *A, int len) { for (int i = 0; i < len; ++i) A[i] = (i < x.size() ? x[i] : 0); }
PVLL operator*(const PVLL &x, const PVLL &y) { static LL T1[N * M * 4], T2[N * M * 4], T3[N * M * 4]; PVLL ans; int len = 1, ansl = std::max(x.A.size() + y.B.size(), x.B.size() + y.A.size()); while (len < ansl - 1) len <<= 1;
Copy(x.A, T1, len); NTT(T1, len, 1); Copy(y.A, T2, len); NTT(T2, len, 1);
for (int i = 0; i < len; ++i) T3[i] = T1[i] * T2[i] % mod; NTT(T3, len, -1); ans.A.resize(std::min<int>(n, x.A.size() + y.A.size() - 1)); for (int i = 0; i < std::min<int>(n, x.A.size() + y.A.size() - 1); ++i) (ans.A[i] = T3[i]) %= mod;
Copy(x.B, T3, len); NTT(T3, len, 1); for (int i = 0; i < len; ++i) (T3[i] *= T2[i]) %= mod; NTT(T3, len, -1); ans.B.resize(std::min(n, ansl - 1)); for (int i = 0; i < std::min(n, ansl - 1); ++i) ans.B[i] = T3[i];
Copy(y.B, T3, len); NTT(T3, len, 1); for (int i = 0; i < len; ++i) (T3[i] *= T1[i]) %= mod; NTT(T3, len, -1); for (int i = 0; i < std::min(n, ansl - 1); ++i) (ans.B[i] += T3[i]) %= mod; return ans; }
PVLL Solve(int l, int r) { if (l == r - 1) { PVLL ans; ans.A.resize(m + 1); ans.B.resize(2 * m + 1); LL pa = a[l]; for (int i = 0; i <= 2 * m; ++i, (pa *= a[l]) %= mod) { if (i <= m) ans.A[i] = (LL)pa * S[m + 1][i + 1] % mod; ans.B[i] = (LL)pa * S[2 * m + 1][i + 1] % mod; } return ans; } int mid = (l + r + 1) >> 1; return Solve(l, mid) * Solve(mid, r); }
LL ps[N];
int main() { scanf("%d%d", &n, &m);
fac[0] = fac[1] = ifac[0] = ifac[1] = inv[1] = 1; for (int i = 2; i <= n; ++i) { fac[i] = fac[i - 1] * i % mod; inv[i] = -(mod / i) * inv[mod % i] % mod; ifac[i] = ifac[i - 1] * inv[i] % mod; }
S[0][0] = 1; for (int i = 1; i <= 2 * m + 1; ++i) for (int j = 1; j <= 2 * m + 1; ++j) S[i][j] = (S[i - 1][j - 1] + S[i - 1][j] * j) % mod;
int s = 0; for (int i = 0; i < n; ++i) { scanf("%lld", &a[i]); (s += a[i]) %= mod; } PVLL res = Solve(0, n); LL ans = 0; ps[0] = 1; for (int i = 1; i <= n - 2; ++i) ps[i] = ps[i - 1] * s % mod; for (int i = 0; i <= n - 2; ++i) (ans += res.B[i] * ps[n - 2 - i] % mod * ifac[n - 2 - i] % mod) %= mod; printf("%lld\n", (ans * fac[n - 2] % mod + mod) % mod); return 0; }
|