LOJ565 [LR10] mathematican 的二进制

2018 年 06 月 25 日发布.

Description

一个初始为 0 的二进制数,有 mm 次操作。

ii 次操作是将这个二进制数加上 2ai2^{a_i} 。这个操作以 pip_i 的概率执行。

如果某次操作执行了并且修改了二进制数的 kk 位,那么它会带来 kk 的代价。

问代价和的期望,答案对 998244353998244353 取模。

n=maxai105,m2×105n=\max a_i\leqslant10^5, m\leqslant2\times10^5

Solution

首先可以发现:

把一个二进制数 yy 增加 2x2^x 后,带来的代价(即修改的位数)是 2+count(y)count(y+2x)2+count(y)-count(y+2^x) ,其中 count(t)count(t) 表示 tt 的二进制表示中 1 的个数。

这个结论比较显然:如果修改了 tt 位,说明 t1t-1 个 1 变成了 0 并且一个 0 变成了 1 。

由此容易得出:如果进行了 kk 次操作把二进制数从 00 变成了 yy ,那么代价和即为 2kcount(y)2k-count(y)

由期望的可加性,只需要求出 2E(k)E(count(y))2E(k)-E(count(y)) 即可,而前者显然就是 2i=1npi2\sum_{i=1}^np_i

对于 E(count(y))E(count(y)) ,只需要求出其每一位是 1 的概率求和即可。

如果我们要计算 yy 的第 tt 位是 1 的概率,那么显然所有 ai>ta_i>t 的操作都可以忽略掉。

我们记 ft,jf_{t, j} 表示只考虑所有 aita_i\leqslant t 的操作时 y/2t=j\lfloor y/2^t\rfloor=j 的概率,则显然有 E(count(y))=tjft,2j+1E(count(y)) = \sum_t\sum_jf_{t,2j+1}

考虑如何求出 ff

如果所有 ai<ta_i<t 的操作执行之后 y/2t1=j\lfloor y/2^{t-1}\rfloor=j^\prime ,并且 ai=ta_i=t 的操作执行了 kk 个,那么显然有 y/2t=k+j/2\lfloor y/2^t\rfloor=k+\lfloor j^\prime/2\rfloor 。由此有

ft,j=j/2+k=jft1,jgt,kf_{t,j}=\sum_{\lfloor j^\prime/2\rfloor+k=j}f_{t-1,j^\prime}g_{t,k}

其中 gt,kg_{t,k}ai=ta_i=t 的操作执行了 kk 个的概率,可以分治 NTT 求出。 DP 转移亦可以 NTT 优化。

这样的复杂度是 O(i(milog2mi+pilogpi))O(\sum_i \left(m_i\log^2 m_i+p_i\log p_i\right)) ,其中 mim_i 表示 aj=ia_j=ijj 的个数,pi=jimj2ijp_i=\left\lfloor\sum_{j\leqslant i} \frac{m_j}{2^{i-j}}\right\rfloor 表示 fi,tf_{i,t} 中最大的 tt

O(i(milog2mi+pilogpi))=O(imilog2m+ipilogm)=O(mlog2m+ijimj2ijlogm)=O(mlog2m+jmj(i2i)logm)=O(mlog2m+mlogm)=O(mlog2m)\begin{aligned} &\quad O\left(\sum_i \left(m_i\log^2 m_i+p_i\log p_i\right)\right)\\ &=O\left(\sum_im_i\log^2m+\sum_ip_i\log m\right)\\ &=O\left(m\log^2m+\sum_i\sum_{j\leqslant i}\frac{m_j}{2^{i-j}}\log m\right)\\ &=O\left(m\log^2m+\sum_j m_j(\sum_i2^{-i})\log m\right)\\ &=O\left(m\log^2m+m\log m\right)\\ &=O\left(m\log^2m\right) \end{aligned}

对于 m=2×105m=2\times10^5 可过。

Code

#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <vector>

typedef long long LL;
typedef std::vector<LL> VLL;

const int N = 200050;
const int mod = 998244353;
const int g = 3;

int readInt() {
  int ans = 0, c, f = 1;
  while (!isdigit(c = getchar()))
    if (c == '-') f *= -1;
  do ans = ans * 10 + c - '0';
  while (isdigit(c = getchar()));
  return ans * f;
}

LL pow_mod(LL a, LL b) {
  if ((b %= mod - 1) < 0) b += mod - 1;
  LL ans = 1;
  for (a %= mod; b > 0; b >>= 1, a = a * a % mod)
    if (b & 1) ans = ans * a % mod;
  return ans;
}

void NTT(LL *A, int len, int opt) {
  for (int i = 1, j = 0; i < len; ++i) {
    for (int k = len; ~j & k; j ^= (k >>= 1));
    if (i < j) std::swap(A[i], A[j]);
  }
  for (int h = 2; h <= len; h <<= 1) {
    LL wn = pow_mod(g, (mod - 1) / h * opt);
    for (int j = 0; j < len; j += h) {
      LL w = 1;
      for (int i = j; i < j + (h >> 1); ++i) {
        LL _t1 = A[i], _t2 = A[i + (h >> 1)] * w % mod;
        A[i] = (_t1 + _t2) % mod;
        A[i + (h >> 1)] = (_t1 - _t2) % mod;
        w = w * wn % mod;
      }
    }
  }
  if (opt == -1)
    for (int i = 0, v = -(mod - 1) / len; i < len; ++i)
      A[i] = A[i] * v % mod;
}

void Conv(const VLL &A, const VLL &B, VLL &C) {
  static LL tA[N * 4], tB[N * 4];
  int n = A.size(), m = B.size(), t = n + m - 1;
  int len = 1;
  while (len < t) len <<= 1;
  for (int i = 0; i < n; ++i) tA[i] = A[i];
  for (int i = n; i < len; ++i) tA[i] = 0;
  for (int i = 0; i < m; ++i) tB[i] = B[i];
  for (int i = m; i < len; ++i) tB[i] = 0;
  NTT(tA, len, 1);
  NTT(tB, len, 1);
  for (int i = 0; i < len; ++i)
    tA[i] = tA[i] * tB[i] % mod;
  NTT(tA, len, -1);
  VLL tmp = VLL();
  std::swap(C, tmp);
  for (int i = 0; i < t; ++i)
    C.push_back(tA[i]);
}

int cnt;
VLL S[N * 2];

void _solve(const VLL &v, int l, int r, VLL &s) {
  if (l == r) {
    s.clear();
    s.push_back(1 - v[l]);
    s.push_back(v[l]);
  } else {
    int mid = (l + r) >> 1;
    int L = cnt++, R = cnt++;
    _solve(v, l, mid, S[L]);
    _solve(v, mid + 1, r, S[R]);
    Conv(S[L], S[R], s);
  }
}

inline void Solve(const VLL &v, VLL &ans) {
  if (v.size() == 0) {
    ans.clear();
    ans.push_back(1);
    return;
  }
  cnt = 0;
  _solve(v, 0, v.size() - 1, ans);
}

int n, m;
VLL V[N], F, G, A[N];

int main() {
  n = readInt(); m = readInt();
  LL S = 0;
  while (m--) {
    LL a, p;
    a = readInt();
    p = readInt();
    p = p * pow_mod(readInt(), mod - 2) % mod;
    S = (S + p) % mod;
    V[a].push_back(p);
  }
  F.push_back(1);
  S = S * 2 % mod;
  for (int i = 0; i <= n + 20; ++i) {
    int t = F.size();
    G.resize((t + 1) / 2);
    std::fill(G.begin(), G.end(), 0);
    for (int i = 0; i < t; ++i) G[i / 2] += F[i];
    Solve(V[i], F);
    Conv(F, G, F);
    for (int i = 1; i < F.size(); i += 2)
      S = (S - F[i]) % mod;
  }
  printf("%lld\n", (S + mod) % mod);
  return 0;
}