LOJ6391 [THUPC2018] Tommy神的树

Description

一棵 \(n\) 个结点的树。初始时 \(a\)\(b\) 是黑的,其他点是白的。

每次可以把某个黑点染成红的并把与它相邻的白点染成黑的。

问把结点染红的顺序有多少种。

\(1\leqslant a,b\leqslant n\leqslant234567\)

答案对 \(998244353\) 取模,时限 10s 。

Solution

首先我们来考虑只有一个起点的情况。我们以这个点为根,进行树形 DP 可得( \(s_i\)\(i\) 的子树大小)

\[\begin{aligned} f_i &= (s_i - 1)!\prod_{j\in son_i}\frac{f_j}{s_j!}\\ \frac{f_i}{s_i!} &= s_i\prod_{j\in son_i}\frac{f_j}{s_j!} \end{aligned}\]

可以发现整棵树的方案数就是 \(n!/(\prod_is_i)\)

考虑有两个起点 \(a\)\(b\) 。可以看成新建了一个点 \(s\) ,其与 \(a,b\) 相连;最开始只有 \(s\) 是黑色的。这是一棵基环树。

我们拿出 \(a\)\(b\) 路径上的所有点 \((s=v_{-1},)a=v_0,v_1,\dots,v_m=b(,v_{m+1}=s)\)

枚举这条路径上最后一个被染红的点 \(v_i\) 。它既可以看做从左边染过来的又可以看做从右边染过来的。

如果我们把它看做从左边染过来的,那么相当于我们把 \(v_i-v_{i+1}\) 这条边(如果存在的话)切断了。同理如果我们把它看做从右边染过来的,那么相当于把 \(v_{i-1}-v_i\) 这条边切断了。

也就是说,如果我们切断了 \(v_i-v_{i+1}\) 这条边,那么现在的方案对应了 \(v_i\) 最后一个染色的方案和 \(v_{i+1}\) 最后一个染色的方案。

那么枚举切断了哪一条边。这个时候基环树变成了一棵树,其答案是 \(n!/(\prod_is_i)\) 。把所有答案都加起来除以 2 ,就可以得到正确答案(因为每个方案都算了两遍)。

考虑如何计算每种方案的 \(\prod_is_i\) 。显然除了 \(a,b\) 之间路径上的点之外所有点的 \(s\) 都不会变化,所以只需要考虑 \(a,b\) 路径上的点,即 \(\prod_{i=0}^ms_{v_i}\)

\(b_i\) 表示 \(v_i\) 这个点除去 \(v_{i-1},v_{i+1}\) 之外的子树大小(就是挂在链上的子树大小),那么如果割去 \(v_j-v_{j+1}\) 的边,那么 \(v_i\) 的子树大小为 \(\sum_{k=j+1}^ib_k\quad(i>j)\) 或者 \(\sum_{k=i}^kb_k\quad(i<j)\) 。如果我们取 \(b\) 的前缀和 \(x\) ,那么这就对应了 \(|s_i-s_j|\) 。也就是说,我们需要对于每个 \(j\) 计算 \(\prod_{i\neq j}|x_i-x_j|=(-1)^{m-j}\prod_{i\neq j}(x_j-x_i)\)

\(\prod_{i\neq j}(x_j-x_i)=f(x_j)\) ,其中 \(f(x)=\sum_i\prod_{j\neq i}(x-x_j)=\frac{\mathrm{d}}{\mathrm{d}x}\left(\prod_i(x-x_i)\right)\) 可以分治 NTT 求得。于是套多项式多点求值模板即可,复杂度 \(O(n\log^2n)\)

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <vector>

typedef long long LL;
typedef std::vector<LL> VLL;
const int mod = 998244353;

LL pow_mod(LL a, LL b) {
LL ans = 1;
if ((b %= mod - 1) < 0) b += mod - 1;
for (a %= mod; b > 0; b >>= 1, (a *= a) %= mod)
if (b & 1) (ans *= a) %= mod;
return ans;
}

namespace Solve1{
const int g = 3;
const int N = 500000;

void NTT(LL *A, int len, int opt) {
for (int i = 1, j = 0; i < len; ++i) {
for (int k = len; ~j & k; j ^= (k >>= 1));
if (i < j) std::swap(A[i], A[j]);
}
for (int h = 2; h <= len; h <<= 1) {
LL wn = pow_mod(g, (mod - 1) / h * opt);
for (int j = 0; j < len; j += h) {
LL w = 1;
for (int i = j; i < j + (h >> 1); ++i) {
LL _t1 = A[i], _t2 = A[i + (h >> 1)] * w % mod;
A[i] = (_t1 + _t2) % mod;
A[i + (h >> 1)] = (_t1 - _t2) % mod;
(w *= wn) %= mod;
}
}
}
if (opt == -1)
for (int i = 0, v = -(mod - 1) / len; i < len; ++i)
(A[i] *= v) %= mod;
}

inline int readInt() {
int ans = 0, c, f = 1;
while (!isdigit(c = getchar()))
if (c == '-') f *= -1;
do ans = ans * 10 + c - '0';
while (isdigit(c = getchar()));
return ans * f;
}

void Conv(const VLL &A, const VLL &B, VLL &ans) {
static LL tA[N], tB[N];
int n = A.size(), m = B.size(), len = 1;
while (len < n + m) len <<= 1;
for (int i = 0; i < n; ++i) tA[i] = A[i];
for (int i = n; i < len; ++i) tA[i] = 0;
for (int i = 0; i < m; ++i) tB[i] = B[i];
for (int i = m; i < len; ++i) tB[i] = 0;
NTT(tA, len, 1); NTT(tB, len, 1);
for (int i = 0; i < len; ++i)
(tA[i] *= tB[i]) %= mod;
NTT(tA, len, -1);
ans.resize(n + m - 1);
for (int i = 0; i < n + m - 1; ++i) ans[i] = tA[i];
}

void PolyInv(const LL *A, int n, LL *B) {
if (n == 1) {
B[0] = pow_mod(A[0], mod - 2);
return;
}
static LL tA[N], tB[N];
int m = (n + 1) / 2, len = 1;
while (len < n * 2) len <<= 1;
PolyInv(A, m, B);
for (int i = 0; i < n; ++i) tA[i] = A[i];
for (int i = n; i < len; ++i) tA[i] = 0;
for (int i = 0; i < m; ++i) tB[i] = B[i];
for (int i = m; i < len; ++i) tB[i] = 0;
NTT(tA, len, 1); NTT(tB, len, 1);
for (int i = 0; i < len; ++i)
(tB[i] *= (2 - tA[i] * tB[i] % mod)) %= mod;
NTT(tB, len, -1);
for (int i = 0; i < n; ++i) B[i] = tB[i];
}

void PolyMod(const LL *A, int n, const LL *B, int m, LL *C) {
if (n < m) {
for (int i = 0; i < n; ++i) C[i] = A[i];
return;
}
static LL t1[N], t2[N];
int t = n - m + 1;
for (int i = 0; i < m; ++i) t2[i] = B[m - i - 1];
for (int i = m; i < t; ++i) t2[i] = 0;
PolyInv(t2, t, t1);
int len = 1;
while (len < 2 * t) len <<= 1;
for (int i = t; i < len; ++i) t1[i] = t2[i] = 0;
for (int i = 0; i < t; ++i) t2[i] = A[n - i - 1];
NTT(t1, len, 1); NTT(t2, len, 1);
for (int i = 0; i < len; ++i) (t1[i] *= t2[i]) %= mod;
NTT(t1, len, -1);
len = 1;
while (len < n) len <<= 1;
for (int i = 0; i < t - i - 1; ++i) std::swap(t1[i], t1[t - i - 1]);
for (int i = t; i < len; ++i) t1[i] = 0;
for (int i = 0; i < m; ++i) t2[i] = B[i];
for (int i = m; i < len; ++i) t2[i] = 0;
NTT(t1, len, 1); NTT(t2, len, 1);
for (int i = 0; i < len; ++i) (t1[i] *= t2[i]) %= mod;
NTT(t1, len, -1);
for (int i = 0; i < m - 1; ++i) C[i] = (A[i] - t1[i]) % mod;
}

void Mod(const VLL &A, const VLL &B, VLL &C) {
static LL tA[N], tB[N], tC[N];
int n = A.size(), m = B.size();
for (int i = 0; i < n; ++i) tA[i] = A[i];
for (int i = 0; i < m; ++i) tB[i] = B[i];
PolyMod(tA, n, tB, m, tC);
int s = std::min(m - 1, n);
C.resize(s);
for (int i = 0; i < s; ++i) C[i] = tC[i];
}

int n, cnt;
VLL A[N], B[N];
LL x[N], y[N];

void Solve1(int t, int l, int r) {
if (l == r) {
A[t].clear();
A[t].push_back(-x[l]);
A[t].push_back(1);
} else {
int mid = (l + r) >> 1, L = ++cnt, R = ++cnt;
Solve1(L, l, mid);
Solve1(R, mid + 1, r);
Conv(A[L], A[R], A[t]);
}
}

void Solve2(int t, int l, int r) {
if (l == r) {
y[l] = B[t][0];
} else {
int mid = (l + r) >> 1, L = ++cnt, R = ++cnt;
Mod(B[t], A[L], B[L]);
Mod(B[t], A[R], B[R]);
Solve2(L, l, mid);
Solve2(R, mid + 1, r);
}
}

void Solve() {
Solve1(cnt = 0, 0, n - 1);
B[0].resize(n);
for (int i = 0; i < n; ++i)
B[0][i] = A[0][i + 1] * (i + 1) % mod;
Solve2(cnt = 0, 0, n - 1);
}
};

const int N = 300050;

int n, a, b, pre[N], nxt[N * 2], to[N * 2], cnt;
int siz[N], num[N], m;
bool on[N];

inline void addEdge(int x, int y) {
nxt[cnt] = pre[x];
to[pre[x] = cnt++] = y;
}

bool dfs(int x, int fa) {
if (x == b) return on[num[m++] = x] = true;
for (LL i = pre[x]; i >= 0; i = nxt[i])
if (to[i] != fa && dfs(to[i], x))
return on[num[m++] = x] = true;
return false;
}

int dfs2(int x, int fa) {
siz[x] = 1;
for (LL i = pre[x]; i >= 0; i = nxt[i])
if (to[i] != fa && !on[to[i]])
siz[x] += dfs2(to[i], x);
return siz[x];
}

int main() {
scanf("%d%d%d", &n, &a, &b);
memset(pre, -1, sizeof pre);
for (int i = 1, x, y; i < n; ++i) {
scanf("%d%d", &x, &y);
addEdge(x, y);
addEdge(y, x);
}
dfs(a, 0);
Solve1::n = m + 1;
Solve1::x[0] = 0;
for (int i = 0; i < m; ++i)
Solve1::x[i + 1] = Solve1::x[i] + dfs2(num[i], 0);
Solve1::Solve();
LL ans = 0;
for (int i = 0; i <= m; ++i)
(ans += ((m - i) & 1 ? -1 : 1) * pow_mod(Solve1::y[i], mod - 2)) %= mod;
for (int i = 1; i <= n; ++i)
(ans *= i) %= mod;
for (int i = 1; i <= n; ++i) if (!on[i])
(ans *= pow_mod(siz[i], mod - 2)) %= mod;
printf("%lld\n", (ans * pow_mod(2, mod - 2) % mod + mod) % mod);
return 0;
}