BZOJ4911 [SDOI2017] 切树游戏

Description

一棵 \(n\) 个结点的树,每个点有点权 \(a_i\) 。有 \(q\) 个操作:

  • Change x y, 表示把编号为 \(x\) 的点的点权改为 \(y\);
  • Query k, 表示询问点权异或和为 \(k\) 的连通块个数 \(\bmod 10007\)

\(n\leqslant30000,a_i\in[0, m),m\leqslant128,q\leqslant30000\)

\(m\)\(2\) 的幂。

Solution

先考虑没有修改时怎么做。

由于 \(m\leqslant128\) ,我们可以求出每一种答案的值。 DP 即可。

\(f_{i,j}\) 表示以 \(i\) 为根的子树中包含 \(i\) 的异或和为 \(j\) 的连通块个数,则显然有

\[F_i=x^{a_i}\prod_{k\in son_i}(F_k+1)\]

其中乘法为异或卷积,乘 \(x^{a_i}\) 表示下标对 \(a_i\) 取模(取 vfk 集合幂级数的含义)。这样就是 \(O(nm^2)\) 了。

考虑快速沃尔什-哈达玛变换,即

\[\mathcal F(f)_i=\sum_{j=0}^{m-1}(-1)^{count(i\&j)}f_j\]

其中 \(\&\) 为按位与, \(count\) 统计二进制数中 \(1\) 的个数。容易证明若 \(a=b*c\) (异或卷积)则

\[\mathcal F(a)=\mathcal F(b)\cdot\mathcal F(c)\]

其中 \(\cdot\) 表示点乘,即逐位相乘。且我们有 \(\mathcal F(\mathcal F(f))_i=m\times f_i\) ,于是逆变换很容易进行。

那么我们可以 DP 初值时就 FWT 一下,计算时全部直接点乘( \(+1\) 现在要在每一位上都 \(+1\) )。最后求出所有 \(\mathcal F(f_i)\) 的和再变换回来。复杂度是 \(O(nm+m\log m)\)

再考虑有修改时怎么做。

考虑树剖维护 DP 。即树剖之后,每个点维护一个 \(H_i\) 表示只考虑轻儿子和自己的时候的 DP 值(也就相当于把重儿子切下来)。令 \(w_i\) 表示 \(i\) 的重儿子,则

\[F_i=H_i(1+F_{w_i})=H_i+H_iH_{w_i}(1+F_{w_{w_i}})=H_i+H_iH_{w_i}+\dots+H_i\dots H_{w_{\ddots_i}}\]

由于一条重链上的点在线段树中是连续的,所以 \(F_i\) 就等于这个区间的所有前缀的积的和。

同理,一条重链上的所有 \(F\) 的和就是这个区间内的所有子区间的积的和。

那么考虑修改时,修改一个点会影响到 \(O(\log n)\) 条重链。直接在线段树上修改查询即可。

一个细节:去掉某个轻儿子对其父亲的影响时,可能会除零。所以说要维护每个点所有轻儿子中零的个数和非零部分的乘积。

时间复杂度 \(O(nm+qm\log n)\)

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
#include <algorithm>
#include <cstdio>
#include <cstring>

const int N = 30050;
const int M = 128;
const int mod = 10007;

int n, m;
int inv[mod];
int count[M];

struct Int{
int a, p; // a * mod^p
Int(int t = 0) : a(t == 0 ? 1 : t), p(t == 0) {}
Int(int a, int p) : a(a), p(p) {}
inline operator int() const { return p == 0 ? a : 0; }
Int operator+(int x) const {
if ((x %= mod) != 0)
return Int(p > 0 ? x : (a + x) % mod, 0);
else
return *this;
}
Int &operator*=(int x) {
if ((x %= mod) == 0) ++p;
else (a *= x) %= mod;
return *this;
}
Int &operator/=(int x) {
if ((x %= mod) == 0) --p;
else (a *= inv[(x + mod) % mod]) %= mod;
return *this;
}
};

Int H[N][M];
int F[N][M];

struct Msg{
int sum[M], suf[M], pre[M], mul[M];
Msg() { Init(); }
Msg(Int H[M]) {
memset(sum, 0, sizeof sum);
memset(suf, 0, sizeof suf);
memset(pre, 0, sizeof pre);
memset(mul, 0, sizeof mul);
for (int i = 0; i < m; ++i)
sum[i] = suf[i] = pre[i] = mul[i] = H[i];
}
void Init() {
memset(sum, 0, sizeof sum);
memset(suf, 0, sizeof suf);
memset(pre, 0, sizeof pre);
for (int i = 0; i < m; ++i)
mul[i] = 1;
}
friend Msg operator*(const Msg &l, const Msg &r) {
Msg ans;
for (int i = 0; i < m; ++i) {
ans.sum[i] = (l.sum[i] + r.sum[i] + l.suf[i] * r.pre[i]) % mod;
ans.pre[i] = (l.pre[i] + l.mul[i] * r.pre[i]) % mod;
ans.suf[i] = (r.suf[i] + r.mul[i] * l.suf[i]) % mod;
ans.mul[i] = l.mul[i] * r.mul[i] % mod;
}
return ans;
}
}msgv[N * 4], tmp;

int node[N];

void upd(int o, int l, int r) {
if (l == r)
msgv[o] = Msg(H[node[l]]);
else
msgv[o] = msgv[o << 1] * msgv[o << 1 | 1];
}

void Build(int o, int l, int r) {
if (l != r) {
int mid = (l + r) >> 1;
Build(o << 1, l, mid);
Build(o << 1 | 1, mid + 1, r);
}
upd(o, l, r);
}

void Query(int o, int l, int r, int L, int R) {
if (R < l || r < L) return;
if (L <= l && r <= R) return void(tmp = tmp * msgv[o]);
int mid = (l + r) >> 1;
Query(o << 1, l, mid, L, R);
Query(o << 1 | 1, mid + 1, r, L, R);
}

void Modify(int o, int l, int r, int x) {
if (r < x || l > x) return;
if (l != r) {
int mid = (l + r) >> 1;
Modify(o << 1, l, mid, x);
Modify(o << 1 | 1, mid + 1, r, x);
}
upd(o, l, r);
}

int siz[N], fa[N], pos[N], top[N], bottom[N], son[N];
int pre[N], to[N * 2], nxt[N * 2], val[N];

inline void addEdge(int x, int y) {
static int cnt = 0;
nxt[cnt] = pre[x];
to[pre[x] = cnt++] = y;
nxt[cnt] = pre[y];
to[pre[y] = cnt++] = x;
}

int dfs0(int x, int f) {
fa[x] = f;
son[x] = 0;
siz[x] = 1;
for (int i = pre[x]; i != -1; i = nxt[i])
if (to[i] != f) {
siz[x] += dfs0(to[i], x);
if (siz[to[i]] > siz[son[x]]) son[x] = to[i];
}
return siz[x];
}

int Ans[M];

int dfs1(int x, int tp) {
static int cnt = 0;
node[pos[x] = ++cnt] = x;
top[x] = tp;
if (son[x] != 0) bottom[x] = dfs1(son[x], tp);
else bottom[x] = x;

for (int i = 0; i < m; ++i) H[x][i] = 1;

for (int i = pre[x]; i != -1; i = nxt[i])
if (to[i] != fa[x] && to[i] != son[x]) {
dfs1(to[i], to[i]);
for (int j = 0; j < m; ++j)
H[x][j] *= 1 + F[to[i]][j];
}

for (int i = 0; i < m; ++i)
if (count[i & val[x]] & 1)
H[x][i] *= -1;

for (int i = 0; i < m; ++i) F[x][i] = H[x][i];
if (son[x])
for (int i = 0; i < m; ++i)
(F[x][i] *= 1 + F[son[x]][i]) %= mod;
for (int i = 0; i < m; ++i)
(Ans[i] += F[x][i]) %= mod;

return bottom[x];
}

void Cut(int x) {
if (x == 0) return;
Modify(1, 1, n, pos[x]);
int y = top[x];
Cut(fa[y]);
tmp.Init();
Query(1, 1, n, pos[y], pos[bottom[y]]);
for (int i = 0; i < m; ++i) (Ans[i] -= tmp.sum[i]) %= mod;
for (int i = 0; i < m; ++i) H[fa[y]][i] /= 1 + tmp.pre[i];
}

void Link(int x) {
if (x == 0) return;
Modify(1, 1, n, pos[x]);
int y = top[x];
tmp.Init();
Query(1, 1, n, pos[y], pos[bottom[y]]);
for (int i = 0; i < m; ++i) (Ans[i] += tmp.sum[i]) %= mod;
for (int i = 0; i < m; ++i) H[fa[y]][i] *= 1 + tmp.pre[i];
Link(fa[y]);
}

void Modify(int x, int v) {
Cut(x);

for (int i = 0; i < m; ++i)
if (count[i & val[x]] & 1)
H[x][i] *= -1;

val[x] = v;

for (int i = 0; i < m; ++i)
if (count[i & val[x]] & 1)
H[x][i] *= -1;

Link(x);
}

char s[100];

int main() {
scanf("%d%d", &n, &m);

for (int i = 1; i < m; ++i)
count[i] = count[i & (i - 1)] + 1;
inv[1] = 1;
for (int i = 2; i < mod; ++i)
inv[i] = -(mod / i) * inv[mod % i] % mod;

for (int i = 1; i <= n; ++i)
scanf("%d", &val[i]);
memset(pre, -1, sizeof pre);
for (int i = 1, x, y; i < n; ++i) {
scanf("%d%d", &x, &y);
addEdge(x, y);
}

dfs0(1, 0);
dfs1(1, 1);
Build(1, 1, n);

int q;
scanf("%d", &q);
while (q--) {
int x;
scanf("%s%d", s, &x);
if (*s == 'C') {
int y;
scanf("%d", &y);
Modify(x, y);
} else {
int ans = 0;
for (int i = 0; i < m; ++i)
(ans += (count[i & x] & 1 ? -1 : 1) * Ans[i]) %= mod;
printf("%d\n", (ans * inv[m] % mod + mod) % mod);
}
}

return 0;
}