分治FFT
目的
解决这样一类式子:
\[f[n] = \sum_{i = 0}^{n - 1}f[i]g[n - i]\]
算法
看上去和普通的卷积式子很像,但计算\(f\)的每一项时都要用到它前面已经求出的项,所以不能用一次FFT直接求出。如果每求一项就做一次FFT,总复杂度为\(O(n^2 \log n)\),比直接枚举的\(O(n^2)\)还高……
考虑优化这个算法。假设要计算区间\([l, r]\)内的\(f\)值,如果能快速算出区间\([l, mid]\)内的\(f\)值对区间\([mid + 1, r]\)内的\(f\)值产生的贡献,就可以采用CDQ分治:先递归求出\([l, mid]\)内的\(f\)值,再把它们对\([mid + 1, r]\)的贡献累加上去,最后递归处理\([mid + 1, r]\)。必须按这个顺序进行,因为右半区间的值依赖于左半区间已经求完的结果。
考虑\(x \in [mid + 1, r]\),\([l, mid]\)给它的贡献是:
\[h[x] = \sum_{i = l}^{mid}f[i]g[x - i]\]
为了方便,我们将求和范围扩充到\([l, x - 1]\)(计算这部分贡献时,把\(f[mid + 1], \dots, f[r]\)都视为\(0\),因此多出来的项不产生贡献),于是有:
\[h[x] = \sum_{i = l}^{x - 1}f[i]g[x - i]\]
为了便于FFT计算,将枚举改成从0开始。(把表达式中的\(i\)改成\(i + l\),因为原来的\(i\)等于现在的\(i + l\))
\[h[x] = \sum_{i = 0}^{x - l - 1}f[i + l]g[x - l - i]\]
为了表示成卷积形式,我们令:
\[a[i] = f[i + l], b[i - 1] = g[i]\]
代入原式,把\(f[i + l]\)换成\(a[i]\),把\(g[x - l - i]\)换成\(b[x - l - 1 - i]\):
\[h[x] = \sum_{i = 0}^{x - l - 1}a[i] b[x - l - 1 - i]\]
观察到后面刚好就是多项式乘法中某一项的系数,即
\[h[x] = (a * b)(x - l - 1)\]
在CDQ分治的过程中用FFT/NTT计算即可。分治的每一层中所有卷积的总长度为\(O(n)\),共\(O(\log n)\)层,因此总复杂度为\(O(n \log^2 n)\)。
代码
洛谷上的模板题。由于答案要对\(998244353\)取模,所以用NTT代替FFT。
#include <cstdio>
#include <utility>

// Divide-and-conquer NTT (CDQ) for the online convolution
//   f[0] = 1,  f[x] = sum_{i=0}^{x-1} f[i] * g[x - i]   (mod 998244353).
// Input (stdin): n, then g[1..n-1].
// Output (stdout): f[0..n-1], each followed by a space, then a newline.
// Capacity: the work arrays hold AC entries and the top-level transform size is
// the smallest power of two > n, so n must be at most 262143 (no check is done).

using LL = long long;

constexpr int MOD = 998244353;        // 119 * 2^23 + 1, so NTT sizes up to 2^23 exist
constexpr int G = 3, Gi = 332748118;  // root used for the forward NTT, and its inverse (3 * Gi == 1 mod MOD)
constexpr int AC = 400100;            // capacity of every work array

int n, lim, len;
int f[AC], g[AC], a[AC], b[AC], rev[AC];

// Reads the next non-negative decimal integer from stdin, skipping non-digit
// characters. Returns 0 if EOF is reached before any digit.
int read() {
    int x = 0;
    int c = std::getchar();  // int, not char, so EOF can be detected
    while (c != EOF && (c < '0' || c > '9')) c = std::getchar();
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = std::getchar();
    }
    return x;
}

// target = (target + delta) mod MOD, for target in [0, MOD) and delta in (-MOD, MOD).
inline void up(int &target, int delta) {
    target += delta;
    if (target < 0) target += MOD;
    if (target >= MOD) target -= MOD;
}

// Returns x^have mod MOD by binary exponentiation (x, have >= 0).
int qpow(int x, int have) {
    int rnt = 1;
    while (have) {
        if (have & 1) rnt = static_cast<int>(1LL * rnt * x % MOD);
        x = static_cast<int>(1LL * x * x % MOD);
        have >>= 1;
    }
    return rnt;
}

// Sets lim to the smallest power of two greater than `length` (len = log2(lim)),
// rebuilds the bit-reversal table, and clears a[0..lim) and b[0..lim).
// cdq() passes length = r - l + 1. The full product of a and b can be longer
// than lim, but the cyclic wrap-around only lands on indices below mid - l,
// which cdq() never reads, so every index it does read is exact.
void init(int length) {
    lim = 1;
    len = 0;
    while (lim <= length) {
        lim <<= 1;
        ++len;
    }
    for (int i = 0; i < lim; ++i) {
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (len - 1));
        a[i] = b[i] = 0;
    }
}

// In-place iterative NTT of A[0..lim). opt = 1 is the forward transform;
// opt = -1 is the inverse, including the final 1/lim scaling.
// Inputs in [0, MOD) produce outputs in [0, MOD).
void NTT(int *A, int opt) {
    for (int i = 0; i < lim; ++i)
        if (i < rev[i]) std::swap(A[i], A[rev[i]]);
    for (int i = 1; i < lim; i <<= 1) {
        const int W = qpow(opt > 0 ? G : Gi, (MOD - 1) / (i << 1));
        for (int j = 0; j < lim; j += i << 1) {
            int w = 1;
            for (int k = 0; k < i; ++k, w = static_cast<int>(1LL * w * W % MOD)) {
                const int x = A[j + k];
                const int y = static_cast<int>(1LL * w * A[j + k + i] % MOD);
                // x + y < 2 * MOD fits in int; keep both outputs normalized.
                A[j + k] = x + y >= MOD ? x + y - MOD : x + y;
                A[j + k + i] = x - y < 0 ? x - y + MOD : x - y;
            }
        }
    }
    if (opt == -1) {
        const int inv = qpow(lim, MOD - 2);
        for (int i = 0; i < lim; ++i) A[i] = static_cast<int>(1LL * A[i] * inv % MOD);
    }
}

// Reads n and g[1..n-1] (reduced mod MOD) and sets the base case f[0] = 1.
void pre() {
    n = read();
    f[0] = 1;
    for (int i = 1; i < n; ++i) g[i] = read() % MOD;
}

// Cyclic convolution of A and B over length lim; the result overwrites A.
// B is left in transformed (frequency-domain) form.
void cal(int *A, int *B) {
    NTT(A, 1);
    NTT(B, 1);
    for (int i = 0; i < lim; ++i)  // only indices [0, lim) belong to the transform
        A[i] = static_cast<int>(1LL * A[i] * B[i] % MOD);
    NTT(A, -1);
}

// CDQ step over [l, r]. On entry, f[x] for x in [l, r] already contains every
// term f[i] * g[x - i] with i < l. Solves [l, mid] recursively, adds the
// contribution of f[l..mid] to f[mid+1..r] with one convolution, then solves
// [mid+1, r].
void cdq(int l, int r) {
    if (l >= r) return;  // single index is final; empty range (n == 0) is a no-op
    const int mid = (l + r) >> 1, length = r - l + 1;
    cdq(l, mid);
    init(length);
    for (int i = l; i <= mid; ++i) a[i - l] = f[i];
    // Shift g by one (b[i - 1] = g[i]) so f[x] receives (a * b)[x - l - 1].
    for (int i = 1; i < length; ++i) b[i - 1] = g[i];
    cal(a, b);
    for (int i = mid + 1; i <= r; ++i) up(f[i], a[i - l - 1]);
    cdq(mid + 1, r);
}

int main() {
    // Reads from stdin directly; the former freopen("in.in") debug redirect
    // broke judge/pipe input.
    pre();
    cdq(0, n - 1);
    for (int i = 0; i < n; ++i) std::printf("%d ", f[i]);  // f[i] is kept in [0, MOD)
    std::printf("\n");
    return 0;
}
来源:https://www.cnblogs.com/ww3113306/p/10359556.html