BZOJ4820 [SDOI2017] 硬币游戏
Description
给定 $n$ 个长为 $m$ 的由 H
和 T
组成的串(表示硬币的正反面)。
不停地丢一枚均匀的硬币直到硬币序列的最后 $m$ 位是 $n$ 个字符串之一为止。
对每个串求最后以它结尾的概率。$n,m\leqslant300$。
Solution
经典做法是建出 AC 自动机之后直接在上面跑高消。显然 $O((nm)^3)$ 不可能过得去。
以下所有大写字母(除 P,H,T
之外)均表示集合;小写字母表示字符串。
令 $S_i$ 表示所有以第 $i$ 个字符串结尾的硬币序列的集合,$N$ 表示游戏还没有结束的硬币序列的集合。
比如两个字符串分别为 TTT
,HHT
,那么(\_
表示空串):
N={\_,H,T,HH,HT,TH,TT,HHH,HTH,HTT,THH,THT,TTH,HHHH,HTHH,HTHT,HTTH,THHH,THTH,THTT,TTHH,TTHT...}
S1={TTT,HTTT,THTTT,HTHTTT,TTHTTT,...}
S2={HHT,HHHT,THHT,HHHHT,HTHHT,THHHT,TTHHT,...}
再定义 $s_0S=\{s_0s\mid s\in S\}$,即把 $S$ 中所有的字符串开头拼接一个 $s_0$;$Ss_0$ 同理。
那么我们有如下几个“集合方程”(仍以 TTT
与 HHT
为例):
$$\begin{aligned} \{\_\}+HN+TN&=N+S_1+S_2\\ NTTT&=S1+S1T+S1TT+S2TT\\ NHTT&=S2\\ \end{aligned}$$
$+$ 表示不交并。
第一个式子的左边和右边都是所有可能出现的状态。
第二个式子是说,如果在某个没有结束的状态后面添加一个 TTT
,可能会以 TTT
结尾,可能会以 TTT
结尾后多出一个 T
,等等。显然这些情况都是不相交的。
第三个式子同上,但是这时候不会出现多出来的情况,因为在任何未结束的状态上拼一个 HH
都不会结束。
对于任意的字符串,我们可以得知
$$\begin{aligned} \{\_\}+HN+TN&=N+\sum_{i=1}^nS_i\\ Ns_i&=\sum_{\substack{1\leqslant k\leqslant n\\1\leqslant j\leqslant m\\pre(s_i,j)=suf(s_k,j)}}S_ksuf(s_i,m-j) \end{aligned}$$
其中 $pre(s,i),suf(s,i)$ 分别表示 $s$ 的长为 $i$ 的前缀和后缀。
第二个式子是说,如果第 $i$ 个串的长为 $j$ 的前缀等于第 $k$ 个串的长为 $j$ 的后缀,那么就有可能在添加 $pre(s_i,j)$ 之后直接结束,多出来一个 $suf(s_i,m-j)$。
接下来考虑怎么计算答案。
考虑令 $P(s)$ 表示 $s$ 出现的概率(对于均匀的硬币,实际上就是 $2^{-length(s)}$)。
现在再令 $P(N)$ 表示 $\sum_{s\in S_i}P(s)$,$P(S_i)$ 同理。
显然最后以第 $i$ 个串结尾的概率就是 $P(S_i)$,即所有以第 $i$ 个串结尾的状态的概率之和,并且 $P(s_0S)=P(Ss_0)=P(S)P(s_0)$。
于是上面的方程相当于说
$$\begin{aligned} 1+P(N)&=P(N)+\sum_{i=1}^nP(S_i)\\ 2^{-m}P(N)&=\sum_{\substack{1\leqslant k\leqslant n\\1\leqslant j\leqslant m\\pre(s_i,j)=suf(s_k,j)}}2^{j-m}P(S_k) \end{aligned}$$
即
$$\begin{aligned} 1&=\sum_{i=1}^nP(S_i)\\ P(N)&=\sum_{\substack{1\leqslant k\leqslant n\\1\leqslant j\leqslant m\\pre(s_i,j)=suf(s_k,j)}}2^jP(S_k) \end{aligned}$$
建出 AC 自动机之后就可以知道 $s_i$ 的后缀等于哪些串的哪些前缀,建出方程组之后高消即可。
Code
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <vector>
const int N = 305;
const int NN = N * N;
int ch[NN][2], len[NN], cnt, fail[NN], que[NN], e[N];
std::vector<int> V[NN];
int n, m;
char s[N][N];
double A[N][N], p2[N];
inline double abs(double x) { return x > 0 ? x : -x; }
void Gauss() {
for (int i = 0; i <= n; ++i) {
int t = i;
for (int j = i; j <= n; ++j)
if (abs(A[j][i]) > abs(A[t][i])) t = j;
if (i != t)
for (int j = i; j <= n + 1; ++j)
std::swap(A[i][j], A[t][j]);
for (int j = n + 1; j >= i; --j)
[i][j] /= A[i][i];
Afor (int j = 0; j <= n; ++j)
if (i != j)
for (int k = n + 1; k >= i; --k)
[j][k] -= A[j][i] * A[i][k];
A}
}
int main() {
("%d%d", &n, &m);
scanffor (int i = 0; i < n; ++i) {
("%s", s[i]);
scanfint o = 0;
for (int j = 0; j < m; ++j) {
int t = s[i][j] == 'T';
if (!ch[o][t]) len[ch[o][t] = ++cnt] = j + 1;
[o = ch[o][t]].push_back(i);
V}
[i] = o;
e}
int hd = 0, tl = 0;
if (ch[0][0]) que[tl++] = ch[0][0];
if (ch[0][1]) que[tl++] = ch[0][1];
while (hd < tl) {
int x = que[hd++];
(ch[x][0] ? fail[que[tl++] = ch[x][0]] : ch[x][0]) = ch[fail[x]][0];
(ch[x][1] ? fail[que[tl++] = ch[x][1]] : ch[x][1]) = ch[fail[x]][1];
}
[0] = 1.0;
p2for (int i = 1; i <= m; ++i) p2[i] = p2[i - 1] * 2;
for (int i = 0; i < n; ++i) A[i][n] = -1;
for (int i = 0; i < n; ++i)
for (int t = e[i]; t; t = fail[t])
for (std::vector<int>::iterator it = V[t].begin(); it != V[t].end(); ++it)
[*it][i] += p2[len[t]];
Afor (int i = 0; i < n; ++i) A[n][i] = 1.0;
[n][n + 1] = 1.0;
A();
Gaussfor (int i = 0; i < n; ++i) printf("%.10lf\n", A[i][n + 1]);
return 0;
}