题面
解析
答案显然是$\frac{\sum_{i=1}^n\sum_{j=1}^m (a_i+b_j)^k}{n*m}$
因此只需要求出$\sum_{i=1}^n\sum_{j=1}^m (a_i+b_j)^k$即可
暴力展开:$$\begin{align*}\sum_{i=1}^n\sum_{j=1}^m (a_i+b_j)^k&=\sum_{i=1}^n\sum_{j=1}^m\sum_{p=0}^k\binom{k}{p}a_i^p*b_j^{k-p}\\ &=k!\sum_{p=0}^k\sum_{i=1}^n\frac{a_i^p}{p!}\sum_{j=1}^m\frac{b_j^{k-p}}{(k-p)!}\\&=k!\sum_{p=0}^k\frac{\sum_{i=1}^na_i^p}{p!}\frac{\sum_{j=1}^mb_j^{k-p}}{(k-p)!}\end{align*}$$
现在就是要对每个 $0\leqslant p \leqslant k$ 求出 $\sum_{i=1}^na_i^p$(求 $\sum_{j=1}^mb_j^{k-p}$ 是类似的)。由于要对所有询问的 $k$ 求答案,实际上需要对所有不超过最大 $k$ 的 $p$ 都求出这个幂和。
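这个比较常见,我在生成函数小结里有写,这里直接给出结论(其中 $'$ 表示对 $x$ 求导):$$\begin{align*}F(x)&=\sum_{j=0}^{\infty}\sum_{i=1}^na_i^jx^j\\&=n-x\left(\ln\prod_{i=1}^n(1-a_ix)\right)'\end{align*}$$推导:$\left(\ln(1-a_ix)\right)'=-\frac{a_i}{1-a_ix}$,故 $-x\left(\ln(1-a_ix)\right)'=\frac{a_ix}{1-a_ix}=\sum_{j\geqslant1}a_i^jx^j$;对 $i$ 求和,再补上 $j=0$ 时的常数项 $n$ 即得。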
$\prod_{i=1}^n(1-a_ix)$ 可以用分治 $NTT$ 在 $O(n\log^2n)$ 内求出,再对它做一次多项式 $\ln$ 即可得到 $F(x)$。
对 $a$、$b$ 分别求出它们的 $F(x)$,将第 $i$ 项系数除以 $i!$ 后卷积起来。卷积结果的第 $i$ 项系数乘以 $i!$ 再除以 $n*m$,就是 $k=i$ 时的答案。
总时间复杂度为 $O(N\log^2N)$,瓶颈在分治 $NTT$。
代码:
// For given a[1..n], b[1..m] and q, prints for every k = 1..q
//   (sum_{i=1}^{n} sum_{j=1}^{m} (a_i + b_j)^k) / (n * m)   modulo 998244353.
// Power sums S_a(p) = sum_i a_i^p come from the generating function
//   sum_p S_a(p) x^p = n - x * (ln prod_i (1 - a_i x))',
// where the product is built by divide-and-conquer NTT.  By the binomial
// theorem, sum_{i,j} (a_i + b_j)^k = k! * sum_p S_a(p)/p! * S_b(k-p)/(k-p)!,
// i.e. one convolution of the two "divided by p!" sequences.
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>
using namespace std;
typedef long long ll;

const int maxn = 200005, mod = 998244353, g = 3;  // g: primitive root used by NTT

// Reads one decimal integer from stdin; a '-' seen before the digits makes it negative.
// Returns 0 if EOF is reached before any digit (instead of looping forever).
inline int read()
{
    int c = getchar();  // int, not char, so EOF stays distinguishable from a byte
    int f = 1;
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') f = -1;
        c = getchar();
    }
    int ret = 0;
    while (c >= '0' && c <= '9') {
        ret = ret * 10 + (c - '0');
        c = getchar();
    }
    return ret * f;
}

// Modular add / subtract; both operands must already lie in [0, mod).
int add(int x, int y) { return x + y < mod ? x + y : x + y - mod; }
int rdc(int x, int y) { return x - y < 0 ? x - y + mod : x - y; }

// x^y mod `mod` by binary exponentiation (y >= 0).
ll qpow(ll x, int y)
{
    ll ret = 1;
    while (y) {
        if (y & 1) ret = ret * x % mod;
        x = x * x % mod;
        y >>= 1;
    }
    return ret;
}

int n, m, a[maxn], b[maxn], lim, bit, rev[maxn << 1];
ll fac[maxn], fnv[maxn];  // factorials and inverse factorials
// c, d and iv are scratch buffers; the routines below reset the parts they used to zero.
ll ginv, c[maxn << 1], d[maxn << 1], A[maxn << 1], B[maxn << 1], t[maxn << 1], iv[maxn << 1];

// Precomputes g^{-1} and fac / fnv over the whole table range [0, maxn).
void init()
{
    ginv = qpow(g, mod - 2);
    const int top = maxn - 1;
    fac[0] = 1;
    for (int i = 1; i <= top; ++i) fac[i] = fac[i - 1] * i % mod;
    fnv[top] = qpow(fac[top], mod - 2);
    for (int i = top - 1; i >= 0; --i) fnv[i] = fnv[i + 1] * (i + 1) % mod;
}

// Sets lim to the smallest power of two strictly greater than x and
// rebuilds the bit-reversal permutation for that length.
void NTT_init(int x)
{
    lim = 1;
    bit = 0;
    while (lim <= x) {
        lim <<= 1;
        ++bit;
    }
    for (int i = 1; i < lim; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}

// In-place iterative NTT of length lim (call NTT_init first).
// y == 1: forward transform; y == -1: inverse transform, scaled by 1/lim.
void NTT(ll *x, int y)
{
    for (int i = 1; i < lim; ++i)
        if (i < rev[i]) swap(x[i], x[rev[i]]);
    ll wn, w, u, v;
    for (int i = 1; i < lim; i <<= 1) {
        wn = qpow((y == 1) ? g : ginv, (mod - 1) / (i << 1));
        for (int j = 0; j < lim; j += (i << 1)) {
            w = 1;
            for (int k = 0; k < i; ++k) {
                u = x[j + k];
                v = x[j + k + i] * w % mod;
                x[j + k] = add(u, v);
                x[j + k + i] = rdc(u, v);
                w = w * wn % mod;
            }
        }
    }
    if (y == -1) {
        ll linv = qpow(lim, mod - 2);
        for (int i = 0; i < lim; ++i) x[i] = x[i] * linv % mod;
    }
}

// x = y^{-1} mod x^len via Newton iteration x <- x * (2 - y * x).
// Needs y[0] != 0 and x zero beyond the coefficients it already holds.
void get_inv(ll *x, ll *y, int len)
{
    if (len == 1) {
        x[0] = qpow(y[0], mod - 2);
        return;
    }
    get_inv(x, y, (len + 1) >> 1);
    for (int i = 0; i < len; ++i) c[i] = y[i];
    NTT_init(len << 1);
    NTT(x, 1);
    NTT(c, 1);
    for (int i = 0; i < lim; ++i) {
        x[i] = rdc(add(x[i], x[i]), (c[i] * x[i] % mod) * x[i] % mod);
        c[i] = 0;
    }
    NTT(x, -1);
    for (int i = len; i < lim; ++i) x[i] = 0;
}

// x = ln(y) mod x^len, computed as the integral of y' / y.
// x[0] is set to 0, which assumes y[0] == 1 (true for prod (1 - a_i x)).
void get_ln(ll *x, ll *y, int len)
{
    for (int i = 0; i < len; ++i) x[i] = y[i + 1] * (i + 1) % mod;  // x = y'
    get_inv(iv, y, len);
    NTT_init(len << 1);
    NTT(x, 1);
    NTT(iv, 1);
    for (int i = 0; i < lim; ++i) {
        x[i] = x[i] * iv[i] % mod;
        iv[i] = 0;
    }
    NTT(x, -1);
    for (int i = len - 1; i >= 1; --i) x[i] = x[i - 1] * qpow(i, mod - 2) % mod;  // integrate
    x[0] = 0;
    for (int i = len; i < lim; ++i) x[i] = 0;
}

vector<int> G[maxn << 1];  // G[node] = product polynomial for that segment-tree node

// Divide-and-conquer NTT: G[x] = prod_{i=l}^{r} (1 - y[i] * X).
// Assumes every y[i] lies in [0, mod) so rdc(0, y[i]) is a valid residue.
void solve(int x, int l, int r, int *y)
{
    G[x].clear();
    if (l == r) {
        G[x].push_back(1);
        G[x].push_back(rdc(0, y[l]));
        return;
    }
    const int lc = x << 1, rc = (x << 1) | 1;
    int mid = (l + r) >> 1;
    solve(lc, l, mid, y);
    solve(rc, mid + 1, r, y);
    for (int i = 0; i <= mid - l + 1; ++i) c[i] = G[lc][i];
    for (int i = 0; i <= r - mid; ++i) d[i] = G[rc][i];
    NTT_init(r - l + 1);
    NTT(c, 1);
    NTT(d, 1);
    for (int i = 0; i < lim; ++i) {
        c[i] = c[i] * d[i] % mod;
        d[i] = 0;
    }
    NTT(c, -1);
    for (int i = 0; i <= r - l + 1; ++i) {
        G[x].push_back(c[i]);
        c[i] = 0;
    }
    for (int i = r - l + 2; i < lim; ++i) c[i] = 0;
}

// out[k] = (sum_{i=1}^{cnt} arr[i]^k) / k!  for k = 0..max(q, cnt),
// via sum_k S(k) X^k = cnt - X * (ln prod_i (1 - arr[i] X))'.
// `out` must be zero-initialised on entry.
void power_sums(int *arr, int cnt, int q, ll *out)
{
    const int top = max(q, cnt);
    solve(1, 1, cnt, arr);
    memset(t, 0, sizeof(t));
    for (int i = 0; i <= cnt; ++i) t[i] = G[1][i];
    get_ln(out, t, top + 1);
    for (int i = 0; i <= top; ++i) out[i] = out[i + 1] * (i + 1) % mod;  // derivative
    for (int i = top; i >= 1; --i) out[i] = rdc(0, out[i - 1]);          // multiply by -X
    out[0] = cnt;                                                         // S(0) = cnt
    for (int i = 0; i <= top; ++i) out[i] = out[i] * fnv[i] % mod;       // divide by k!
}

int main()
{
    init();
    n = read();
    m = read();
    for (int i = 1; i <= n; ++i) a[i] = read();
    for (int i = 1; i <= m; ++i) b[i] = read();
    int q = read();
    power_sums(a, n, q, A);
    power_sums(b, m, q, B);
    // k! * (A * B)[k] = sum_{i,j} (a_i + b_j)^k by the binomial theorem.
    NTT_init(max(q, n) + max(q, m));
    NTT(A, 1);
    NTT(B, 1);
    for (int i = 0; i < lim; ++i) A[i] = A[i] * B[i] % mod;
    NTT(A, -1);
    ll mul = qpow(1LL * n * m % mod, mod - 2);  // 1 / (n * m)
    for (int i = 1; i <= q; ++i) printf("%lld\n", (A[i] * fac[i] % mod) * mul % mod);
    return 0;
}
来源:https://www.cnblogs.com/Joker-Yza/p/12640512.html