BZOJ4944 [NOI2017] 泳池

2018 年 03 月 19 日发布.

Description

有一个底边长 $n$ 个单位,高为 $+\infty$ 的矩形网格,其中每个格子都有可能是危险的或者安全的。

每个格子安全的概率为 $p$ ,危险的概率为 $1 - p$ 。任意两个格子是否安全都是独立的。

要选取一个子矩形,要求它与底边相邻且其中所有格子都是安全的。最大化选取子矩形的面积。求答案恰好为 $k$ 的概率。

$n \leqslant 10^9, k \leqslant 1000$ 。

Solution

我们把“恰好为 $k$ 的概率”转化为“小于 $k+1$ 的概率减去小于 $k$ 的概率”。

为了简化描述,下面把某一列从底边开始能够延伸出的最长的安全的格子长度称为其“安全高度”。

考虑计算“答案小于 $k$ 的概率”。

考虑 DP :令 $f_{n, m}$ 表示底边长为 $n$ ,且已知每一行安全高度都不小于 $m$ 时的概率。

$$f_{n, m} = \begin{cases} 0 & nm >= k\\ 1 & n = 0\\ p^nf_{n,m+1}+\sum_{i=0}^{n-1}p^if_{i,m+1}(1-p)f_{n-i-1,m} & \text{otherwise} \end{cases}$$

前两个容易理解,最后一个实际上枚举了左数第一个“安全高度”恰好为 $m$ 的位置。

最后求出 $f_{n, 0}$ 即可。

这样就能拿到 70pts 了。

考虑如何加速转移。

首先,由于 $ij < k$ ,所以满足 $j>0$ 的状态实际上只有 $O(k\log k)$ 个,可以暴力 DP 。

我们把 $f_0$ 拿出来,假设它叫做 $g$ ,并只考虑 $g_n$ 中 $n \geq k$ 时的递推式。

由 $f$ 的递推式,以及 $f_{i,1} = 0(i>=k)$ 可知,

$$g_n = \sum_{i=0}^{k-1}p^i(1-p)f_{i,1}g_{n-i-1}(n\geq k)$$

那么这是一个线性常系数递推方程,可以利用矩阵快速幂做到 $O(k^3\log n)$ 。这样能够拿到 90pts 。

接下来我们考虑优化矩阵快速幂。

假设递推系数为 $a_i$ (即, $g_n=\sum_{i=1}^ka_ig_{n-i}$ ),转移矩阵为 $A$ ,那么 $A^t * (g_0,g_1\dots,g_{k-1})^T = (g_t,g_{t+1},\dots,g_{t+k-1})^T$ 。

将两者结合起来我们有

$$\begin{aligned} &\left(\sum_{i=1}^ka_iA^{k-i}\right)(g_0,g_1\dots,g_{k-1})^T\\ &=\left(\sum_{i=1}^ka_ig_{k-i},\dots,\sum_{i=1}^ka_ig_{2k-1-i}\right)^T\\ &=(g_k,\dots,g_{2k-1})^T\\ &=A^k(g_0,g_1\dots,g_{k-1})^T \end{aligned}$$

注意到这个等式对所有的 $g$ 的初值成立,一番分析之后我们得到 $A^k = \sum_{i=1}^ka_iA^{k-i}$ 。

这就表明所有 $A$ 的幂都可以以 $A^0, A^1,\dots,A^{k-1}$ 线性表出,于是我们转移时只需要保存当前的 $A^t$ 用这 $k$ 个矩阵线性表出的系数即可。那么矩阵乘法就变成了 $O(k^2)$ ,总时间复杂度变成了 $O(k^2\log n)$ 。

100pts 。

Code

#include <algorithm>
#include <cstdio>
#include <cstring>

typedef long long LL;

const int mod = 998244353;
const int K = 1050;

void exgcd(int a, int b, int &x, int &y) {
  if (!b) { x = 1; y = 0; return; }
  exgcd(b, a % b, y, x);
  y -= (a / b) * x;
}

inline int inv(int x) {
  int v;
  exgcd(x, mod, v, x);
  return v;
}

LL _f[2][K], pq[K], g[K];
LL ans[K], tmp[K * 2], x[K];
LL dp[K];

void mul(LL *a, LL *b, int k) {
  memset(tmp, 0, sizeof tmp);
  for (int i = 0; i < k; ++i) if (a[i] != 0)
    for (int j = 0; j < k; ++j)
      (tmp[i + j] += a[i] * b[j]) %= mod;
  for (int i = k * 2; i >= k; --i)
    for (int j = 0; j < k; ++j)
      (tmp[i - j - 1] += tmp[i] * dp[j]) %= mod;
  for (int i = 0; i < k; ++i) a[i] = tmp[i];
}

LL solve(int n, int k, int q) {
  pq[0] = 1;
  for (int i = 1; i <= k; ++i)
    pq[i] = q * pq[i - 1] % mod;

  LL *f = _f[0], *f2 = _f[1];
  std::fill(f2, f2 + k + 1, 0);
  f2[0] = 1;

  for (int j = k; j; --j, std::swap(f, f2)) {
    std::fill(f, f + k + 1, 0);
    for (int i = 0; i * j < k; ++i) {
      f[i] = pq[i] * f2[i] % mod;
      for (int t = 0; t < i; ++t)
        f[i] = (pq[t] * (1 - q) % mod * f2[t] % mod
                * f[i - t - 1] % mod + f[i]) % mod;
    }
  }

  g[0] = 1;
  for (int i = 1; i < k; ++i) {
    g[i] = (1 - q) * g[i - 1] % mod;
    for (int j = 1; j < k && j <= i; ++j) {
      LL t = (LL)pq[j] * f2[j] % mod;
      if (j == i)
        (g[i] += t) %= mod;
      else
        (g[i] += t * (1 - q) % mod * g[i - j - 1] % mod) %= mod;
    }
  }

  dp[0] = 1 - q;
  for (int i = 1; i < k; ++i)
    dp[i] = (LL)pq[i] * f2[i] % mod * (1 - q) % mod;

  memset(ans, 0, sizeof ans);
  memset(x, 0, sizeof x);
  ans[0] = 1;
  x[1] = 1;
  if (k == 1) x[0] = 1 - q;
  for (; n; n >>= 1, mul(x, x, k))
    if (n & 1) mul(ans, x, k);

  LL ansv = 0;
  for (int i = 0; i < k; ++i)
    (ansv += ans[i] * g[i]) %= mod;

  return (ansv + mod) % mod;
}
int main() {
  int n, k, x, y;

  scanf("%d%d%d%d", &n, &k, &x, &y);
  x = (LL)x * inv(y) % mod;

  printf("%lld\n", (solve(n, k + 1, x) - solve(n, k, x) + mod) % mod);
  return 0;
}