LOJ6391 [THUPC2018] Tommy神的树
Description
一棵 $n$ 个结点的树。初始时 $a$ 和 $b$ 是黑的,其他点是白的。
每次可以把某个黑点染成红的并把与它相邻的白点染成黑的。
问把结点染红的顺序有多少种。
$1\leqslant a,b\leqslant n\leqslant234567$
答案对 $998244353$ 取模,时限 10s 。
Solution
首先我们来考虑只有一个起点的情况。我们以这个点为根,进行树形 DP 可得( $s_i$ 是 $i$ 的子树大小)
$$\begin{aligned} f_i &= (s_i - 1)!\prod_{j\in son_i}\frac{f_j}{s_j!}\\ \frac{f_i}{s_i!} &= s_i\prod_{j\in son_i}\frac{f_j}{s_j!} \end{aligned}$$
可以发现整棵树的方案数就是 $n!/(\prod_is_i)$ 。
考虑有两个起点 $a$ 和 $b$ 。可以看成新建了一个点 $s$ ,其与 $a,b$ 相连;最开始只有 $s$ 是黑色的。这是一棵基环树。
我们拿出 $a$ 到 $b$ 路径上的所有点 $(s=v_{-1},)a=v_0,v_1,\dots,v_m=b(,v_{m+1}=s)$ 。
枚举这条路径上最后一个被染红的点 $v_i$ 。它既可以看做从左边染过来的又可以看做从右边染过来的。
如果我们把它看做从左边染过来的,那么相当于我们把 $v_i-v_{i+1}$ 这条边(如果存在的话)切断了。同理如果我们把它看做从右边染过来的,那么相当于把 $v_{i-1}-v_i$ 这条边切断了。
也就是说,如果我们切断了 $v_i-v_{i+1}$ 这条边,那么现在的方案对应了 $v_i$ 最后一个染色的方案和 $v_{i+1}$ 最后一个染色的方案。
那么枚举切断了哪一条边。这个时候基环树变成了一棵树,其答案是 $n!/(\prod_is_i)$ 。把所有答案都加起来除以 2 ,就可以得到正确答案(因为每个方案都算了两遍)。
考虑如何计算每种方案的 $\prod_is_i$ 。显然除了 $a,b$ 之间路径上的点之外所有点的 $s$ 都不会变化,所以只需要考虑 $a,b$ 路径上的点,即 $\prod_{i=0}^ms_{v_i}$ 。
令 $b_i$ 表示 $v_i$ 这个点除去 $v_{i-1},v_{i+1}$ 之外的子树大小(就是挂在链上的子树大小),那么如果割去 $v_j-v_{j+1}$ 的边,那么 $v_i$ 的子树大小为 $\sum_{k=j+1}^ib_k\quad(i>j)$ 或者 $\sum_{k=i}^kb_k\quad(i<j)$ 。如果我们取 $b$ 的前缀和 $x$ ,那么这就对应了 $|s_i-s_j|$ 。也就是说,我们需要对于每个 $j$ 计算 $\prod_{i\neq j}|x_i-x_j|=(-1)^{m-j}\prod_{i\neq j}(x_j-x_i)$ 。
而 $\prod_{i\neq j}(x_j-x_i)=f(x_j)$ ,其中 $f(x)=\sum_i\prod_{j\neq i}(x-x_j)=\frac{\mathrm{d}}{\mathrm{d}x}\left(\prod_i(x-x_i)\right)$ 可以分治 NTT 求得。于是套多项式多点求值模板即可,复杂度 $O(n\log^2n)$ 。
Code
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <vector>
typedef long long LL;
typedef std::vector<LL> VLL;
const int mod = 998244353;
(LL a, LL b) {
LL pow_mod= 1;
LL ans if ((b %= mod - 1) < 0) b += mod - 1;
for (a %= mod; b > 0; b >>= 1, (a *= a) %= mod)
if (b & 1) (ans *= a) %= mod;
return ans;
}
namespace Solve1{
const int g = 3;
const int N = 500000;
void NTT(LL *A, int len, int opt) {
for (int i = 1, j = 0; i < len; ++i) {
for (int k = len; ~j & k; j ^= (k >>= 1));
if (i < j) std::swap(A[i], A[j]);
}
for (int h = 2; h <= len; h <<= 1) {
= pow_mod(g, (mod - 1) / h * opt);
LL wn for (int j = 0; j < len; j += h) {
= 1;
LL w for (int i = j; i < j + (h >> 1); ++i) {
= A[i], _t2 = A[i + (h >> 1)] * w % mod;
LL _t1 [i] = (_t1 + _t2) % mod;
A[i + (h >> 1)] = (_t1 - _t2) % mod;
A(w *= wn) %= mod;
}
}
}
if (opt == -1)
for (int i = 0, v = -(mod - 1) / len; i < len; ++i)
(A[i] *= v) %= mod;
}
inline int readInt() {
int ans = 0, c, f = 1;
while (!isdigit(c = getchar()))
if (c == '-') f *= -1;
do ans = ans * 10 + c - '0';
while (isdigit(c = getchar()));
return ans * f;
}
void Conv(const VLL &A, const VLL &B, VLL &ans) {
static LL tA[N], tB[N];
int n = A.size(), m = B.size(), len = 1;
while (len < n + m) len <<= 1;
for (int i = 0; i < n; ++i) tA[i] = A[i];
for (int i = n; i < len; ++i) tA[i] = 0;
for (int i = 0; i < m; ++i) tB[i] = B[i];
for (int i = m; i < len; ++i) tB[i] = 0;
(tA, len, 1); NTT(tB, len, 1);
NTTfor (int i = 0; i < len; ++i)
(tA[i] *= tB[i]) %= mod;
(tA, len, -1);
NTT.resize(n + m - 1);
ansfor (int i = 0; i < n + m - 1; ++i) ans[i] = tA[i];
}
void PolyInv(const LL *A, int n, LL *B) {
if (n == 1) {
[0] = pow_mod(A[0], mod - 2);
Breturn;
}
static LL tA[N], tB[N];
int m = (n + 1) / 2, len = 1;
while (len < n * 2) len <<= 1;
(A, m, B);
PolyInvfor (int i = 0; i < n; ++i) tA[i] = A[i];
for (int i = n; i < len; ++i) tA[i] = 0;
for (int i = 0; i < m; ++i) tB[i] = B[i];
for (int i = m; i < len; ++i) tB[i] = 0;
(tA, len, 1); NTT(tB, len, 1);
NTTfor (int i = 0; i < len; ++i)
(tB[i] *= (2 - tA[i] * tB[i] % mod)) %= mod;
(tB, len, -1);
NTTfor (int i = 0; i < n; ++i) B[i] = tB[i];
}
void PolyMod(const LL *A, int n, const LL *B, int m, LL *C) {
if (n < m) {
for (int i = 0; i < n; ++i) C[i] = A[i];
return;
}
static LL t1[N], t2[N];
int t = n - m + 1;
for (int i = 0; i < m; ++i) t2[i] = B[m - i - 1];
for (int i = m; i < t; ++i) t2[i] = 0;
(t2, t, t1);
PolyInvint len = 1;
while (len < 2 * t) len <<= 1;
for (int i = t; i < len; ++i) t1[i] = t2[i] = 0;
for (int i = 0; i < t; ++i) t2[i] = A[n - i - 1];
(t1, len, 1); NTT(t2, len, 1);
NTTfor (int i = 0; i < len; ++i) (t1[i] *= t2[i]) %= mod;
(t1, len, -1);
NTT= 1;
len while (len < n) len <<= 1;
for (int i = 0; i < t - i - 1; ++i) std::swap(t1[i], t1[t - i - 1]);
for (int i = t; i < len; ++i) t1[i] = 0;
for (int i = 0; i < m; ++i) t2[i] = B[i];
for (int i = m; i < len; ++i) t2[i] = 0;
(t1, len, 1); NTT(t2, len, 1);
NTTfor (int i = 0; i < len; ++i) (t1[i] *= t2[i]) %= mod;
(t1, len, -1);
NTTfor (int i = 0; i < m - 1; ++i) C[i] = (A[i] - t1[i]) % mod;
}
void Mod(const VLL &A, const VLL &B, VLL &C) {
static LL tA[N], tB[N], tC[N];
int n = A.size(), m = B.size();
for (int i = 0; i < n; ++i) tA[i] = A[i];
for (int i = 0; i < m; ++i) tB[i] = B[i];
(tA, n, tB, m, tC);
PolyModint s = std::min(m - 1, n);
.resize(s);
Cfor (int i = 0; i < s; ++i) C[i] = tC[i];
}
int n, cnt;
[N], B[N];
VLL A[N], y[N];
LL x
void Solve1(int t, int l, int r) {
if (l == r) {
[t].clear();
A[t].push_back(-x[l]);
A[t].push_back(1);
A} else {
int mid = (l + r) >> 1, L = ++cnt, R = ++cnt;
(L, l, mid);
Solve1(R, mid + 1, r);
Solve1(A[L], A[R], A[t]);
Conv}
}
void Solve2(int t, int l, int r) {
if (l == r) {
[l] = B[t][0];
y} else {
int mid = (l + r) >> 1, L = ++cnt, R = ++cnt;
(B[t], A[L], B[L]);
Mod(B[t], A[R], B[R]);
Mod(L, l, mid);
Solve2(R, mid + 1, r);
Solve2}
}
void Solve() {
(cnt = 0, 0, n - 1);
Solve1[0].resize(n);
Bfor (int i = 0; i < n; ++i)
[0][i] = A[0][i + 1] * (i + 1) % mod;
B(cnt = 0, 0, n - 1);
Solve2}
};
const int N = 300050;
int n, a, b, pre[N], nxt[N * 2], to[N * 2], cnt;
int siz[N], num[N], m;
bool on[N];
inline void addEdge(int x, int y) {
[cnt] = pre[x];
nxt[pre[x] = cnt++] = y;
to}
bool dfs(int x, int fa) {
if (x == b) return on[num[m++] = x] = true;
for (LL i = pre[x]; i >= 0; i = nxt[i])
if (to[i] != fa && dfs(to[i], x))
return on[num[m++] = x] = true;
return false;
}
int dfs2(int x, int fa) {
[x] = 1;
sizfor (LL i = pre[x]; i >= 0; i = nxt[i])
if (to[i] != fa && !on[to[i]])
[x] += dfs2(to[i], x);
sizreturn siz[x];
}
int main() {
("%d%d%d", &n, &a, &b);
scanf(pre, -1, sizeof pre);
memsetfor (int i = 1, x, y; i < n; ++i) {
("%d%d", &x, &y);
scanf(x, y);
addEdge(y, x);
addEdge}
(a, 0);
dfs::n = m + 1;
Solve1::x[0] = 0;
Solve1for (int i = 0; i < m; ++i)
::x[i + 1] = Solve1::x[i] + dfs2(num[i], 0);
Solve1::Solve();
Solve1= 0;
LL ans for (int i = 0; i <= m; ++i)
(ans += ((m - i) & 1 ? -1 : 1) * pow_mod(Solve1::y[i], mod - 2)) %= mod;
for (int i = 1; i <= n; ++i)
(ans *= i) %= mod;
for (int i = 1; i <= n; ++i) if (!on[i])
(ans *= pow_mod(siz[i], mod - 2)) %= mod;
("%lld\n", (ans * pow_mod(2, mod - 2) % mod + mod) % mod);
printfreturn 0;
}