BZOJ4734 [清华集训2016] 如何优雅地求和

Description

已知 \(m\) 次多项式函数 \(f\) 在 \(0,1,\dots,m\) 处的取值 \(f(0),f(1),\dots,f(m)\),给定 \(n,x\),求

\[\left(\sum_{k=0}^n\binom nkf(k)x^k(1-x)^{n-k}\right)\bmod998244353\]

\(n\leqslant10^9,m\leqslant2\times10^4\)

Solution

考虑把 \(f\) 表示成下降幂的和:\(f(x)=\sum_{i=0}^m f_ix_{(i)}\),其中 \(x_{(i)}=\prod_{k=0}^{i-1}(x-k)=\frac{x!}{(x-i)!}\)

那么

\[\begin{aligned} &\sum_{k=0}^n\binom nkf(k)x^k(1-x)^{n-k}\\ =&\sum_{i=0}^mf_i\sum_{k=i}^n\frac{n!}{k!(n-k)!}\cdot\frac{k!}{(k-i)!}x^k(1-x)^{n-k}\\ =&\sum_{i=0}^mf_in_{(i)}x^i\sum_{k=i}^n\binom{n-i}{k-i}x^{k-i}(1-x)^{n-k}\\ =&\sum_{i=0}^mf_in_{(i)}x^i \end{aligned}\]

剩下的问题就是如何求 \(f_i\)

考虑到

\[\begin{aligned} &\sum_nf(n)\frac{x^n}{n!}\\ =&\sum_{i=0}^mf_i\sum_{n=i}^{\infty}\frac{x^n}{(n-i)!}\\ =&\left(\sum_{i=0}^mf_ix^i\right)e^x \end{aligned}\]

只需要求 \(\sum_n f(n)\frac{x^n}{n!}\) 与 \(e^{-x}\) 的卷积的前 \(m+1\) 项即可。

复杂度 \(O(m\log m)\)

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>

// 64-bit signed integer used for all modular arithmetic, so that products of
// two residues below `mod` (< 2^30) do not overflow.
typedef long long LL;

// Reads the next decimal integer from stdin.
//
// Every character before the first digit is skipped; each '-' met while
// skipping flips the sign of the result. The character that terminates the
// digit run is consumed as well.
//
// Returns the parsed value, or 0 if end of input is reached before any digit
// is found (previously this case spun forever, because getchar() keeps
// returning EOF and EOF is never a digit).
int readInt() {
  int c = getchar();
  int sign = 1;
  while (c != EOF && !isdigit(c)) {
    if (c == '-') sign = -sign;
    c = getchar();
  }
  if (c == EOF) return 0;  // input exhausted: nothing to parse

  int value = 0;
  do {
    value = value * 10 + (c - '0');
    c = getchar();
  } while (isdigit(c));
  return value * sign;
}

// Capacity of every fixed-size work array below.
const int N = 100000;
// Prime modulus all results are reduced by.
const int mod = 998244353;
// Base whose power (mod - 1) / len is used as the principal root in InitNTT.
const int g = 3;

// n and m are filled in by main() from the input.
int n;
int m;
// Transform length (a power of two), set by InitNTT.
int len;
// Bit-reversal permutation for the current `len`, filled by InitNTT.
int rev[N];

// omega[i] is the i-th power of the principal len-th root; iomega holds the
// same values in reverse order (iomega[len - i] == omega[i]), i.e. the powers
// of the inverse root.
LL omega[N];
LL iomega[N];

// Computes a^b reduced by `mod` using square-and-multiply.
//
// a is reduced with `%` first, so a negative a yields a result that may be a
// negative representative. b is expected to be non-negative.
LL pow_mod(LL a, LL b) {
  LL base = a % mod;
  LL result = 1;
  while (b) {
    if (b & 1) result = result * base % mod;
    base = base * base % mod;
    b >>= 1;
  }
  return result;
}

void InitNTT(int n) {
int k = 0;
len = 1;
while (len < n) len <<= 1, ++k;
for (int i = 1; i < len; ++i)
rev[i] = ((i & 1) << (k - 1)) | (rev[i >> 1] >> 1);
omega[0] = iomega[0] = 1;
LL w = pow_mod(g, (mod - 1) / len);
for (int i = 1; i < len; ++i)
iomega[len - i] = omega[i] = omega[i - 1] * w % mod;
}

// In-place iterative number-theoretic transform of A[0 .. len).
//
// `omega` is the twiddle table to use: pass the global `omega` for the forward
// transform or the global `iomega` for the inverse one. When (and only when)
// the pointer equals the global `iomega`, the result is additionally scaled by
// the modular inverse of len. InitNTT must have been called beforehand.
//
// Outputs are reduced with `%` only, so entries may be negative
// representatives in (-mod, mod).
void NTT(LL *A, const LL *omega) {
  // Reorder the input into bit-reversed order.
  for (int i = 0; i < len; ++i) {
    const int j = rev[i];
    if (i < j) std::swap(A[i], A[j]);
  }

  // Butterfly passes over blocks of size 2, 4, ..., len.
  for (int span = 2; span <= len; span <<= 1) {
    const int half = span >> 1;
    const int stride = len / span;  // step through the twiddle table
    for (int base = 0; base < len; base += span) {
      for (int k = 0; k < half; ++k) {
        LL &lo = A[base + k];
        LL &hi = A[base + k + half];
        const LL even = lo;
        const LL odd = hi * omega[k * stride] % mod;
        lo = (even + odd) % mod;
        hi = (even - odd) % mod;
      }
    }
  }

  if (omega != ::iomega) return;

  // -(mod - 1) / len is congruent to the inverse of len modulo mod,
  // because len * (-(mod - 1) / len) = 1 - mod.
  const int inv_len = -(mod - 1) / len;
  for (int i = 0; i < len; ++i) A[i] = A[i] * inv_len % mod;
}

// inv[i]: modular inverse of i; ifac[i]: modular inverse of i!.
// Both are filled by main() and may hold negative representatives.
LL inv[N];
LL ifac[N];
// A[i] = f(i) / i! and G[i] = (-1)^i / i! before the transforms;
// afterwards G holds their convolution.
LL A[N];
LL G[N];

int main() {
n = readInt();
m = readInt() + 1;
LL x = readInt();
inv[1] = ifac[0] = ifac[1] = 1;
for (int i = 2; i < m; ++i)
ifac[i] = ifac[i - 1] * (inv[i] = -(mod / i) * inv[mod % i] % mod) % mod;
for (int i = 0; i < m; ++i) {
A[i] = readInt() * ifac[i] % mod;
G[i] = i & 1 ? -ifac[i] : ifac[i];
}
InitNTT(m * 2);
NTT(A, omega); NTT(G, omega);
for (int i = 0; i < len; ++i) G[i] = G[i] * A[i] % mod;
NTT(G, iomega);
LL t = 1, ans = 0;
for (int i = 0; i < m; ++i) {
ans = (ans + G[i] * t) % mod;
t = t * x % mod * (n - i) % mod;
}
printf("%lld\n", (ans + mod) % mod);
return 0;
}