LOJ565 [LR10] mathematican 的二进制
Description
一个初始为 0 的二进制数,有 $m$ 次操作。
第 $i$ 次操作是将这个二进制数加上 $2^{a_i}$ 。这个操作以 $p_i$ 的概率执行。
如果某次操作执行了并且修改了二进制数的 $k$ 位,那么它会带来 $k$ 的代价。
问代价和的期望,答案对 $998244353$ 取模。
$n=\max a_i\leqslant10^5, m\leqslant2\times10^5$
Solution
首先可以发现:
把一个二进制数 $y$ 增加 $2^x$ 后,带来的代价(即修改的位数)是 $2+count(y)-count(y+2^x)$ ,其中 $count(t)$ 表示 $t$ 的二进制表示中 1 的个数。
这个结论比较显然:如果修改了 $t$ 位,说明 $t-1$ 个 1 变成了 0 并且一个 0 变成了 1 。
由此容易得出:如果进行了 $k$ 次操作把二进制数从 $0$ 变成了 $y$ ,那么代价和即为 $2k-count(y)$ 。
由期望的可加性,只需要求出 $2E(k)-E(count(y))$ 即可,而前者显然就是 $2\sum_{i=1}^np_i$ 。
对于 $E(count(y))$ ,只需要求出其每一位是 1 的概率求和即可。
如果我们要计算 $y$ 的第 $t$ 位是 1 的概率,那么显然所有 $a_i>t$ 的操作都可以忽略掉。
我们记 $f_{t, j}$ 表示只考虑所有 $a_i\leqslant t$ 的操作时 $\lfloor y/2^t\rfloor=j$ 的概率,则显然有 $E(count(y)) = \sum_t\sum_jf_{t,2j+1}$ 。
考虑如何求出 $f$ 。
如果所有 $a_i<t$ 的操作执行之后 $\lfloor y/2^{t-1}\rfloor=j^\prime$ ,并且 $a_i=t$ 的操作执行了 $k$ 个,那么显然有 $\lfloor y/2^t\rfloor=k+\lfloor j^\prime/2\rfloor$ 。由此有
$$f_{t,j}=\sum_{\lfloor j^\prime/2\rfloor+k=j}f_{t-1,j^\prime}g_{t,k}$$
其中 $g_{t,k}$ 为 $a_i=t$ 的操作执行了 $k$ 个的概率,可以分治 NTT 求出。 DP 转移亦可以 NTT 优化。
这样的复杂度是 $O(\sum_i \left(m_i\log^2 m_i+p_i\log p_i\right))$ ,其中 $m_i$ 表示 $a_j=i$ 的 $j$ 的个数,$p_i=\left\lfloor\sum_{j\leqslant i} \frac{m_j}{2^{i-j}}\right\rfloor$ 表示 $f_{i,t}$ 中最大的 $t$ 。
有
$$\begin{aligned} &\quad O\left(\sum_i \left(m_i\log^2 m_i+p_i\log p_i\right)\right)\\ &=O\left(\sum_im_i\log^2m+\sum_ip_i\log m\right)\\ &=O\left(m\log^2m+\sum_i\sum_{j\leqslant i}\frac{m_j}{2^{i-j}}\log m\right)\\ &=O\left(m\log^2m+\sum_j m_j(\sum_i2^{-i})\log m\right)\\ &=O\left(m\log^2m+m\log m\right)\\ &=O\left(m\log^2m\right) \end{aligned}$$
对于 $m=2\times10^5$ 可过。
Code
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <vector>
typedef long long LL;
typedef std::vector<LL> VLL;
const int N = 200050;
const int mod = 998244353;
const int g = 3;
int readInt() {
int ans = 0, c, f = 1;
while (!isdigit(c = getchar()))
if (c == '-') f *= -1;
do ans = ans * 10 + c - '0';
while (isdigit(c = getchar()));
return ans * f;
}
(LL a, LL b) {
LL pow_modif ((b %= mod - 1) < 0) b += mod - 1;
= 1;
LL ans for (a %= mod; b > 0; b >>= 1, a = a * a % mod)
if (b & 1) ans = ans * a % mod;
return ans;
}
void NTT(LL *A, int len, int opt) {
for (int i = 1, j = 0; i < len; ++i) {
for (int k = len; ~j & k; j ^= (k >>= 1));
if (i < j) std::swap(A[i], A[j]);
}
for (int h = 2; h <= len; h <<= 1) {
= pow_mod(g, (mod - 1) / h * opt);
LL wn for (int j = 0; j < len; j += h) {
= 1;
LL w for (int i = j; i < j + (h >> 1); ++i) {
= A[i], _t2 = A[i + (h >> 1)] * w % mod;
LL _t1 [i] = (_t1 + _t2) % mod;
A[i + (h >> 1)] = (_t1 - _t2) % mod;
A= w * wn % mod;
w }
}
}
if (opt == -1)
for (int i = 0, v = -(mod - 1) / len; i < len; ++i)
[i] = A[i] * v % mod;
A}
void Conv(const VLL &A, const VLL &B, VLL &C) {
static LL tA[N * 4], tB[N * 4];
int n = A.size(), m = B.size(), t = n + m - 1;
int len = 1;
while (len < t) len <<= 1;
for (int i = 0; i < n; ++i) tA[i] = A[i];
for (int i = n; i < len; ++i) tA[i] = 0;
for (int i = 0; i < m; ++i) tB[i] = B[i];
for (int i = m; i < len; ++i) tB[i] = 0;
(tA, len, 1);
NTT(tB, len, 1);
NTTfor (int i = 0; i < len; ++i)
[i] = tA[i] * tB[i] % mod;
tA(tA, len, -1);
NTT= VLL();
VLL tmp std::swap(C, tmp);
for (int i = 0; i < t; ++i)
.push_back(tA[i]);
C}
int cnt;
[N * 2];
VLL S
void _solve(const VLL &v, int l, int r, VLL &s) {
if (l == r) {
.clear();
s.push_back(1 - v[l]);
s.push_back(v[l]);
s} else {
int mid = (l + r) >> 1;
int L = cnt++, R = cnt++;
(v, l, mid, S[L]);
_solve(v, mid + 1, r, S[R]);
_solve(S[L], S[R], s);
Conv}
}
inline void Solve(const VLL &v, VLL &ans) {
if (v.size() == 0) {
.clear();
ans.push_back(1);
ansreturn;
}
= 0;
cnt (v, 0, v.size() - 1, ans);
_solve}
int n, m;
[N], F, G, A[N];
VLL V
int main() {
= readInt(); m = readInt();
n = 0;
LL S while (m--) {
, p;
LL a= readInt();
a = readInt();
p = p * pow_mod(readInt(), mod - 2) % mod;
p = (S + p) % mod;
S [a].push_back(p);
V}
.push_back(1);
F= S * 2 % mod;
S for (int i = 0; i <= n + 20; ++i) {
int t = F.size();
.resize((t + 1) / 2);
Gstd::fill(G.begin(), G.end(), 0);
for (int i = 0; i < t; ++i) G[i / 2] += F[i];
(V[i], F);
Solve(F, G, F);
Convfor (int i = 1; i < F.size(); i += 2)
= (S - F[i]) % mod;
S }
("%lld\n", (S + mod) % mod);
printfreturn 0;
}