BZOJ4734 [清华集训2016] 如何优雅地求和

2018 年 12 月 26 日发布.

Description

已知 $m$ 次多项式函数 $f$ 在 $0,1,\dots,m$ 处的取值 $f(0),f(1),\dots,f(m)$,给定 $n,x$,求

$$\left(\sum_{k=0}^n\binom nkf(k)x^k(1-x)^{n-k}\right)\bmod998244353$$

$n\leqslant10^9,m\leqslant2\times10^4$。

Solution

考虑把 $f$ 表示成下降幂的和:$f(x)=\sum_{i=0}^m f_ix_{(i)}$,其中 $x_{(i)}=\prod_{k=0}^{i-1}(x-k)=\frac{x!}{(x-i)!}$。

那么

$$\begin{aligned} &\sum_{k=0}^n\binom nkf(k)x^k(1-x)^{n-k}\\ =&\sum_{i=0}^mf_i\sum_{k=i}^n\frac{n!}{k!(n-k)!}\cdot\frac{k!}{(k-i)!}x^k(1-x)^{n-k}\\ =&\sum_{i=0}^mf_in_{(i)}x^i\sum_{k=i}^n\binom{n-i}{k-i}x^{k-i}(1-x)^{n-k}\\ =&\sum_{i=0}^mf_in_{(i)}x^i \end{aligned}$$

剩下的问题就是如何求 $f_i$。

考虑到

$$\begin{aligned} &\sum_nf(n)\frac{x^n}{n!}\\ =&\sum_{i=0}^mf_i\sum_{n=i}^{\infty}\frac{x^n}{(n-i)!}\\ =&\left(\sum_{i=0}^mf_ix^i\right)e^x \end{aligned}$$

只需要求 $\sum_n f(n)\frac{x^n}{n!}$ 和 $e^{-x}$ 的卷积前 $m+1$ 项即可。

复杂度 $O(m\log m)$。

Code

#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>

// Shorthand for long long; the modular arithmetic below is carried out in this type.
typedef long long LL;

// Reads the next decimal integer from standard input.
//
// Every character before the first digit is skipped; each '-' seen while
// skipping flips the sign of the result. Digits are then accumulated until the
// first non-digit character, which is consumed as well.
//
// Returns the parsed value, or 0 if end of input is reached before any digit
// was found (isdigit(EOF) is false, so without this check the skip loop would
// spin forever on exhausted input).
int readInt() {
  int ans = 0, c, f = 1;
  while (!isdigit(c = getchar())) {
    if (c == EOF) return 0;
    if (c == '-') f *= -1;
  }
  do ans = ans * 10 + c - '0';
  while (isdigit(c = getchar()));
  return ans * f;
}

// Capacity of every fixed-size global array in this file.
const int N = 100000;
// Modulus applied to all arithmetic in this file.
const int mod = 998244353;
// Base handed to pow_mod in InitNTT to derive the root of unity.
const int g = 3;

// Problem parameters, assigned in main.
int n;
int m;
// Transform length; InitNTT sets it to a power of two.
int len;
// Bit-reversal permutation table, filled by InitNTT.
int rev[N];
// omega[i] is the i-th power of the root chosen in InitNTT.
LL omega[N];
// iomega[0] is 1 and iomega[i] equals omega[len - i] for 0 < i < len.
LL iomega[N];

// Computes a^b modulo `mod` by binary exponentiation.
//
// The base is reduced with % before use, so a negative `a` may produce a
// negative result. The loop runs until b reaches zero; b is assumed to be
// non-negative.
LL pow_mod(LL a, LL b) {
  LL result = 1;
  LL base = a % mod;
  while (b) {
    if (b & 1) result = result * base % mod;
    base = base * base % mod;
    b >>= 1;
  }
  return result;
}

void InitNTT(int n) {
  int k = 0;
  len = 1;
  while (len < n) len <<= 1, ++k;
  for (int i = 1; i < len; ++i)
    rev[i] = ((i & 1) << (k - 1)) | (rev[i >> 1] >> 1);
  omega[0] = iomega[0] = 1;
  LL w = pow_mod(g, (mod - 1) / len);
  for (int i = 1; i < len; ++i)
    iomega[len - i] = omega[i] = omega[i - 1] * w % mod;
}

// In-place number-theoretic transform of the first `len` entries of A.
//
// `omega` is the table of twiddle factors to use. When it is the global
// `iomega` table (compared by pointer), every entry is additionally multiplied
// by -(mod - 1) / len, which is the inverse of len modulo mod because
// len * (-(mod - 1) / len) = 1 - mod.
//
// Results are reduced with % only, so entries may be negative.
void NTT(LL *A, const LL *omega) {
  // Reorder the input into bit-reversed order.
  for (int i = 0; i < len; ++i) {
    const int j = rev[i];
    if (i < j) std::swap(A[i], A[j]);
  }

  // Butterfly passes over blocks of size 2, 4, ..., len.
  for (int h = 2; h <= len; h <<= 1) {
    const int half = h >> 1;
    const int stride = len / h;
    for (int start = 0; start < len; start += h) {
      for (int k = 0; k < half; ++k) {
        LL &lo = A[start + k];
        LL &hi = A[start + k + half];
        const LL even = lo;
        const LL odd = hi * omega[k * stride] % mod;
        lo = (even + odd) % mod;
        hi = (even - odd) % mod;
      }
    }
  }

  if (omega == ::iomega) {
    const int scale = -(mod - 1) / len;
    for (int i = 0; i < len; ++i) A[i] = A[i] * scale % mod;
  }
}

// inv[i]: modular inverse of i, filled in main.
LL inv[N];
// ifac[i]: product of inv[1..i], i.e. the inverse of i!, filled in main.
LL ifac[N];
// A[i]: i-th input value multiplied by ifac[i]; transformed in place by NTT.
LL A[N];
// G[i]: ifac[i] with alternating sign; later overwritten by the product with A.
LL G[N];

int main() {
  n = readInt();
  m = readInt() + 1;
  LL x = readInt();
  inv[1] = ifac[0] = ifac[1] = 1;
  for (int i = 2; i < m; ++i)
    ifac[i] = ifac[i - 1] * (inv[i] = -(mod / i) * inv[mod % i] % mod) % mod;
  for (int i = 0; i < m; ++i) {
    A[i] = readInt() * ifac[i] % mod;
    G[i] = i & 1 ? -ifac[i] : ifac[i];
  }
  InitNTT(m * 2);
  NTT(A, omega); NTT(G, omega);
  for (int i = 0; i < len; ++i) G[i] = G[i] * A[i] % mod;
  NTT(G, iomega);
  LL t = 1, ans = 0;
  for (int i = 0; i < m; ++i) {
    ans = (ans + G[i] * t) % mod;
    t = t * x % mod * (n - i) % mod;
  }
  printf("%lld\n", (ans + mod) % mod);
  return 0;
}