BZOJ5340 [CTSC2018] 假面

2018 年 05 月 18 日发布.

Description

有一列数 $a_i$ ,有 $Q$ 个操作:

  1. 对于某个 $i$ ,以 $p$ 的概率使 $a_i$ 减一(若 $a_i=0$ 则忽略);
  2. 指定某些 $a_i$ ,从这些 $a_i$ 中的所有大于零的位置随机选一个,对于每个 $a_i$ 你需要求出选到它的概率。

你还要求出每个 $a_i$ 执行完所有操作后的期望。

令 $C$ 为 2 操作的数量:

$n\leqslant200,Q\leqslant2\times10^5,C\leqslant10^3;$ 初始时 $a_i\leqslant100$ 。

时限 6s 。

Solution

我们维护一个 $f_{i,j}$ 表示到当前操作时 $a_i=j$ 的概率。显然我可以对每个 1 操作 $O(a_i)$ 更新这个值。

最后求期望很简单。我们考虑如何做 2 操作。

显然 2 操作的答案只和这些 $a_i$ 是否等于 $0$ 有关(而和它们的值无关)。我们令 $p_i=f_{i,0},q_i=1-p_i$ 。再令 $f_j$ 表示不考虑 $i$ 时有 $j$ 个 $a$ 非零的概率,那么显然选到 $i$ 的概率是 $$ q_i\sum_j\frac{f_j}{j+1} $$ 那么我们考虑如何对于所有的 $i$ 快速地求出 $f_j$ 。

首先,如果我们去掉“不考虑 $i$ ”的限制,那么显然 DP 一遍就能得到答案,即令 $f_{k,j}$ 表示前 $k$ 个里有 $j$ 个非零的概率,则 $f_{k,j}=p_kf_{k-1,j}+q_kf_{k-1,j-1}$ 。

如果 $i$ 是最后一个数(假设一共有 $m$ 个数),那么答案就是 $q_i\sum_j\frac{f_{m-1,j}}{j+1}$ 。

但是我们可以发现 $f_{m,j}$ 是转移顺序无关的,且 $f_{m-1,j}=\cfrac{f_{m,j}-f_{m-1,j-1}q_k}{p_k}$ 。

所以我们可以直接算出 $f_{k,j}$ ,然后对于每个 $i$ ,把它看成第 $m$ 个数,利用上面的公式倒推出 $f_{k-1,j}$ 即可。

单次 2 操作复杂度是 $O(m^2)=O(n^2)$ 的。

总时间复杂度 $O(Cn^2+Q\max a_i)$ 。

Code

#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>

namespace _rqy{
const int N = 205;
const int M = 105;
const int mod = 998244353;

typedef long long LL;

LL pow_mod(LL a, LL p) {
  LL ans = 1;
  for (; p > 0; p >>= 1, (a *= a) %= mod)
    if (p & 1) (ans *= a) %= mod;
  return ans;
}

inline char getChar() {
  const int L = 10000000;
  static char s[L], *p = s, *end = s;
  if (p == end) {
    if (feof(stdin)) return EOF;
    end = s + fread(p = s, 1, L, stdin);
  }
  return *(p++);
}

int readInt() {
  int ans = 0, c;
  while (!isdigit(c = getChar()));
  do ans = ans * 10 + c - '0';
  while (isdigit(c = getChar()));
  return ans;
}

int n, m[N];
LL p[N][M], inv[N];

void main() {
  n = readInt();
  inv[1] = 1;
  for (int i = 2; i <= n; ++i)
    inv[i] = -(mod / i) * inv[mod % i] % mod;
  for (int i = 1; i <= n; ++i) {
    m[i] = readInt();
    for (int j = 0; j < m[i]; ++j)
      p[i][j] = 0;
    p[i][m[i]] = 1;
  }
  int q = readInt();
  while (q--) {
    if (readInt() == 0) {
      int i = readInt();
      LL u = readInt();
      LL v = readInt();
      LL x = u * pow_mod(v, mod - 2) % mod, y = 1 - x;
      LL tmp = p[i][0] * x % mod;
      for (int j = 0; j <= m[i]; ++j)
        p[i][j] = (x * p[i][j + 1] + y * p[i][j]) % mod;
      (p[i][0] += tmp) %= mod;
    } else {
      static int t[N];
      static LL f[N];
      int k = readInt(), s = 0;
      f[0] = 1;
      for (int i = 1; i < k; ++i) f[i] = 0;
      for (int i = 0; i < k; ++i) {
        int j = t[i] = readInt();
        LL y = p[j][0], x = 1 - y;
        if (y == 0) {
          ++s;
          continue;
        }
        for (int t = k; t >= 0; --t) {
          (f[t + 1] += f[t] * x) %= mod;
          (f[t] *= y) %= mod;
        }
      }
      for (int i = 0; i < k; ++i) {
        int j = t[i];
        LL y = p[j][0], x = 1 - y, iy = pow_mod(y, mod - 2), ans = 0;
        if (y == 0) {
          for (int t = 0; t < k; ++t)
            (ans += f[t] * inv[t + s]) %= mod;
        } else {
          for (int t = 0; t < k; ++t) {
            (f[t] *= iy) %= mod;
            (f[t + 1] -= f[t] * x) %= mod;
          }

          for (int t = 0; t < k; ++t)
            (ans += f[t] * inv[t + s + 1]) %= mod;

          for (int t = k - 1; t >= 0; --t) {
            (f[t + 1] += f[t] * x) %= mod;
            (f[t] *= y) %= mod;
          }
        }
        printf("%lld%c", (ans * x % mod + mod) % mod,
                i == k - 1 ? '\n' : ' ');
        }
      }
    }
    for (int i = 1; i <= n; ++i) {
      LL ans = 0;
      for (int j = 0; j <= m[i]; ++j)
        (ans += p[i][j] * j) %= mod;
      printf("%lld%c", (ans + mod) % mod, i == n ? '\n' : ' ');
    }
  }
};

int main() {
  //freopen("faceless.in", "r", stdin);
  //freopen("faceless.out", "w", stdout);
  _rqy::main();
  return 0;
}