BZOJ4820 [SDOI2017] 硬币游戏

2018 年 12 月 26 日发布.

Description

给定 $n$ 个长为 $m$ 的由 HT 组成的串(表示硬币的正反面)。

不停地丢一枚均匀的硬币直到硬币序列的最后 $m$ 位是 $n$ 个字符串之一为止。

对每个串求最后以它结尾的概率。$n,m\leqslant300$。

Solution

经典做法是建出 AC 自动机之后直接在上面跑高消。显然 $O((nm)^3)$ 不可能过得去。

以下所有大写字母(除 P,H,T 之外)均表示集合;小写字母表示字符串。

令 $S_i$ 表示所有以第 $i$ 个字符串结尾的硬币序列的集合,$N$ 表示游戏还没有结束的硬币序列的集合。

比如两个字符串分别为 TTTHHT,那么(\_ 表示空串):

N={\_,H,T,HH,HT,TH,TT,HHH,HTH,HTT,THH,THT,TTH,HHHH,HTHH,HTHT,HTTH,THHH,THTH,THTT,TTHH,TTHT...}

S1={TTT,HTTT,THTTT,HTHTTT,TTHTTT,...}

S2={HHT,HHHT,THHT,HHHHT,HTHHT,THHHT,TTHHT,...}

再定义 $s_0S=\{s_0s\mid s\in S\}$,即把 $S$ 中所有的字符串开头拼接一个 $s_0$;$Ss_0$ 同理。

那么我们有如下几个“集合方程”(仍以 TTTHHT 为例):

$$\begin{aligned} \{\_\}+HN+TN&=N+S_1+S_2\\ NTTT&=S1+S1T+S1TT+S2TT\\ NHTT&=S2\\ \end{aligned}$$

$+$ 表示不交并。

第一个式子的左边和右边都是所有可能出现的状态。

第二个式子是说,如果在某个没有结束的状态后面添加一个 TTT,可能会以 TTT 结尾,可能会以 TTT 结尾后多出一个 T,等等。显然这些情况都是不相交的。

第三个式子同上,但是这时候不会出现多出来的情况,因为在任何未结束的状态上拼一个 HH 都不会结束。

对于任意的字符串,我们可以得知

$$\begin{aligned} \{\_\}+HN+TN&=N+\sum_{i=1}^nS_i\\ Ns_i&=\sum_{\substack{1\leqslant k\leqslant n\\1\leqslant j\leqslant m\\pre(s_i,j)=suf(s_k,j)}}S_ksuf(s_i,m-j) \end{aligned}$$

其中 $pre(s,i),suf(s,i)$ 分别表示 $s$ 的长为 $i$ 的前缀和后缀。

第二个式子是说,如果第 $i$ 个串的长为 $j$ 的前缀等于第 $k$ 个串的长为 $j$ 的后缀,那么就有可能在添加 $pre(s_i,j)$ 之后直接结束,多出来一个 $suf(s_i,m-j)$。

接下来考虑怎么计算答案。

考虑令 $P(s)$ 表示 $s$ 出现的概率(对于均匀的硬币,实际上就是 $2^{-length(s)}$)。

现在再令 $P(N)$ 表示 $\sum_{s\in S_i}P(s)$,$P(S_i)$ 同理。

显然最后以第 $i$ 个串结尾的概率就是 $P(S_i)$,即所有以第 $i$ 个串结尾的状态的概率之和,并且 $P(s_0S)=P(Ss_0)=P(S)P(s_0)$。

于是上面的方程相当于说

$$\begin{aligned} 1+P(N)&=P(N)+\sum_{i=1}^nP(S_i)\\ 2^{-m}P(N)&=\sum_{\substack{1\leqslant k\leqslant n\\1\leqslant j\leqslant m\\pre(s_i,j)=suf(s_k,j)}}2^{j-m}P(S_k) \end{aligned}$$

$$\begin{aligned} 1&=\sum_{i=1}^nP(S_i)\\ P(N)&=\sum_{\substack{1\leqslant k\leqslant n\\1\leqslant j\leqslant m\\pre(s_i,j)=suf(s_k,j)}}2^jP(S_k) \end{aligned}$$

建出 AC 自动机之后就可以知道 $s_i$ 的后缀等于哪些串的哪些前缀,建出方程组之后高消即可。

Code

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <vector>

const int N = 305;
const int NN = N * N;

int ch[NN][2], len[NN], cnt, fail[NN], que[NN], e[N];
std::vector<int> V[NN];

int n, m;
char s[N][N];
double A[N][N], p2[N];

inline double abs(double x) { return x > 0 ? x : -x; }

void Gauss() {
  for (int i = 0; i <= n; ++i) {
    int t = i;
    for (int j = i; j <= n; ++j)
      if (abs(A[j][i]) > abs(A[t][i])) t = j;
    if (i != t)
      for (int j = i; j <= n + 1; ++j)
        std::swap(A[i][j], A[t][j]);
    for (int j = n + 1; j >= i; --j)
      A[i][j] /= A[i][i];
    for (int j = 0; j <= n; ++j)
      if (i != j)
        for (int k = n + 1; k >= i; --k)
          A[j][k] -= A[j][i] * A[i][k];
  }
}

int main() {
  scanf("%d%d", &n, &m);
  for (int i = 0; i < n; ++i) {
    scanf("%s", s[i]);
    int o = 0;
    for (int j = 0; j < m; ++j) {
      int t = s[i][j] == 'T';
      if (!ch[o][t]) len[ch[o][t] = ++cnt] = j + 1;
      V[o = ch[o][t]].push_back(i);
    }
    e[i] = o;
  }
  int hd = 0, tl = 0;
  if (ch[0][0]) que[tl++] = ch[0][0];
  if (ch[0][1]) que[tl++] = ch[0][1];
  while (hd < tl) {
    int x = que[hd++];
    (ch[x][0] ? fail[que[tl++] = ch[x][0]] : ch[x][0]) = ch[fail[x]][0];
    (ch[x][1] ? fail[que[tl++] = ch[x][1]] : ch[x][1]) = ch[fail[x]][1];
  }
  p2[0] = 1.0;
  for (int i = 1; i <= m; ++i) p2[i] = p2[i - 1] * 2;
  for (int i = 0; i < n; ++i) A[i][n] = -1;
  for (int i = 0; i < n; ++i)
    for (int t = e[i]; t; t = fail[t])
      for (std::vector<int>::iterator it = V[t].begin(); it != V[t].end(); ++it)
        A[*it][i] += p2[len[t]];
  for (int i = 0; i < n; ++i) A[n][i] = 1.0;
  A[n][n + 1] = 1.0;
  Gauss();
  for (int i = 0; i < n; ++i) printf("%.10lf\n", A[i][n + 1]);
  return 0;
}