这可能是我第五次学FFT了……菜哭qwq
先给出一些个人认为非常优秀的参考资料:
一小时学会快速傅里叶变换(Fast Fourier Transform) - 知乎
快速傅里叶变换(FFT)用于计算两个$n$次多项式相乘,能把复杂度从朴素的$O(n^2)$优化到$O(nlog_2n)$。一个常见的应用是计算大整数相乘。
本文中所有多项式默认$x$为变量,其他字母均为常数。所有角均为弧度制。
一、多项式的两种表示方法
我们平时常用的表示方法称为“系数表示法”,即
$$A(x)=\sum _{i=0}^n a_ix^i$$
上面那个式子也可以看作一个以$x$为自变量的$n$次函数。用$n+1$个点可以确定一个$n$次函数(自行脑补初中学习的二次函数)。所以,给定$n+1$组$x$和对应的$A(x)$,就可以求出原多项式。用$n+1$个点表示一个$n$次多项式的方式称为“点值表示法”。
在“点值表示法”中,两个多项式相乘是$O(n)$的。因为对于同一个$x$,把它代入$A$和$B$求值的结果之积就是把它带入多项式$A\times B$求值的结果(这是多项式乘法的意义)。所以把点值表示法下的两个多项式的$n+1$个点的值相乘即可求出两多项式之积的点值表示。
线性复杂度点值表示好哇好
但是,把系数表示法转换成点值表示法需要对$n+1$个点求值,而每次求值是$O(n)$的,所以复杂度是$O(n^2)$。把点值表示法转换成系数表示法据说也是$O(n^2)$的(然而我只会$O(n^3)$的高斯消元qwq)。所以暴力取点然后算还不如直接朴素算法相乘……
但是有一种神奇的算法,通过取一些具有特殊性质的点可以把复杂度降到$O(nlog_2n)$。
二、单位根
从现在开始,所有$n$都默认是$2$的非负整数次幂,多项式次数为$n-1$。应用时如果多项式次数不是$2$的非负整数次幂减$1$,可以加系数为$0$的项补齐。
先看一些预备知识:
复数$a+bi$可以看作平面直角坐标系上的点$(a,b)$。这个点到原点的距离称为模长,即$\sqrt{a^2+b^2}$;原点与$(a,b)$所连的直线与实轴正半轴的夹角称为辐角,即$sin^{-1}\frac{b}{a}$。复数相乘的法则:模长相乘,辐角相加。
把以原点为圆心,$1$为半径的圆(称为“单位圆”)$n$等分,$n$个点中辐角最小的等分点(不考虑$1$)称为$n$次单位根,记作$\omega_n$,则这$n$个等分点可以表示为$\omega_n^k(0\leq k < n)$
这里如果不理解,可以考虑周角是$2\pi$,$n$次单位根的辐角是$\frac{2\pi}{n}$。$w_n^k=w_n^{k-1}\times w_n^1$,复数相乘时模长均为$1$,相乘仍为$1$。辐角$\frac{2\pi (k-1)}{n}$加上单位根的辐角$\frac{2\pi}{n}$变成$\frac{2\pi k}{n}$。
单位根具有如下性质:
1.折半引理
$$w_{2n}^{2k}=w_n^k$$
模长都是$1$,辐角$\frac{2\pi \times 2k}{2n}=\frac{2\pi k}{n}$,故相等。
2.消去引理
$$w_n^{k+\frac{n}{2}}=-w_n^k$$
这个从几何意义上考虑,$w_n^{k+\frac{n}{2}}$的辐角刚好比$w_n^k$多了$\frac{2\pi \times \frac{n}{2}}{n}=\pi$,刚好是一个平角,所以它们关于原点中心对称。互为相反数的复数关于原点中心对称。
3.(不知道叫什么的性质)其中$k$是整数
$$w_n^{a+kn}=w_n^a$$
这个也很好理解:$w_n^n$的辐角是$2\pi$,也就是转了一整圈回到了实轴正半轴上,这个复数就是实数$1$。乘上一个$w_n^n$就相当于给辐角加了一个周角,不会改变位置。
三、离散傅里叶变换(DFT)
DFT把多项式从系数表示法转换到点值表示法。
我们大力尝试把$n$次单位根的$0$到$n-1$次幂分别代入$n-1$次多项式$A(x)$。首先先对$A(x)$进行奇偶分组,得到:
$$A_1(x)=\sum_{i=0}^{\frac{n-1}{2}}a_{2i}·x^i$$
$$A_2(x)=\sum_{i=0}^{\frac{n-1}{2}}a_{2i+1}·x^i$$
则有:
$$A(x)=A_1(x^2)+x·A_2(x^2)$$
把$w_n^k$代入,得:
$$A(w_n^k)=A_1(w_n^{2k})+w_n^k·A_2(w_n^{2k})$$
根据折半引理,有:
$$A(w_n^k)=A_1(w_{\frac{n}{2}}^k)+w_n^k·A_2(w_{\frac{n}{2}}^k)$$
此时有一个特殊情况。当$\frac{n}{2}\leq k < n$,记$a=k-\frac{n}{2}$,则根据消去引理和上面第三个性质,有:
$$w_n^a=-w_n^k$$
$$w_{\frac{n}{2}}^a=w_{\frac{n}{2}}^k$$
所以
$$A(w_n^k)=A_1(w_{\frac{n}{2}}^a)-w_n^a·A_2(w_{\frac{n}{2}}^a)$$
这样变换主要是为了防止右侧式子里出现$w_n$的不同次幂。
按照这个式子可以递归计算。共递归$O(log_2n)$层,每层需要$O(n)$枚举$k$,因此可以在$O(nlog_2n)$内把系数表示法变为点值表示法。
四、离散傅里叶反变换(IDFT)
设$w_n^k(0\leq k<n)$代入多项式$A(x)$后得到的点值为$b_k$,令多项式$B(x)$:
$$B(x)=\sum_{i=0}^{n-1}b_ix^i$$
一个结论:设$w_n^{-k}(0\leq k<n)$代入$B(x)$后得到的点值为$c_k$,则多项式$A(x)$的系数$a_k=\frac{c_k}{n}$。下面来证明这个结论。
$$ \begin{aligned} c_k&=\sum_{i=0}^{n-1}b_i·w_n^{-ik}\ &=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_j·w_n^{ij}·w_n^{-ik}\ &=\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}w_n^{i(j-k)} \end{aligned} $$
脑补一下$\sum_{i=0}^{n-1}w_n^{i(j-k)}$怎么求。可以看出这是一个公比为$w_n^{j-k}$的等比数列。
当$j=k$,$w_n^0=1$,所以上式的值是$n$。
否则,根据等比数列求和公式,上式等于$w_n^{j-k}·\frac{w_n^{n(j-k)}-1}{w_n^{j-k}-1}$。$w_n^{n(j-k)}$相当于转了整整$(j-k)$圈,所以值为$1$,这个等比数列的和为$0$。
由于当$j \neq k$时上述等比数列值为$0$,所以$c_k=a_kn$,即$a_k=\frac{c_k}{n}$
至此,已经可以写出递归的FFT代码了。(常数大的一批qwq
实测洛谷3803有$77$分,会TLE两个点。
下面放上部分代码。建议继续阅读之前先充分理解这种写法。
const int N = (1e6 + 10) * 4;
const double PI = 3.141592653589793238462643383279502884197169399375105820974944;
struct cpx
{
double a, b;
cpx(){}
cpx(const double x, const double y = 0)
: a(x), b(y){}
cpx operator + (const cpx &c) const
{
return (cpx){a + c.a, b + c.b};
}
cpx operator - (const cpx &c) const
{
return (cpx){a - c.a, b - c.b};
}
cpx operator * (const cpx &c) const
{
return (cpx){a * c.a - b * c.b, a * c.b + b * c.a};
}
};
int n, m;
cpx a[N], b[N], buf[N];
inline cpx omega(const int n, const int k)
{
return (cpx){cos(2 * PI * k / n), sin(2 * PI * k / n)};
}
void FFT(cpx *a, const int n, const bool inv)
{
if (n == 1)
return;
static cpx buf[N];
int mid = n >> 1;
for (int i = 0; i < mid; i++)
{
buf[i] = a[i << 1];
buf[i + mid] = a[i << 1 | 1];
}
memcpy(a, buf, sizeof(cpx[n]));
//now a[i] is coefficient
FFT(a, mid, inv), FFT(a + mid, mid, inv);
//now a[i] is point value
//a[i] is A1(w_n^i), a[i + mid] is A2(w_n^i)
for (int i = 0; i < mid; i++)
{//calculate point value of A(w_n^i) and A(w_n^{i+n/2})
cpx x = omega(n, i * (inv ? -1 : 1));
buf[i] = a[i] + x * a[i + mid];
buf[i + mid] = a[i] - x * a[i + mid];
}
memcpy(a, buf, sizeof(cpx[n]));
}
int work()
{
read(n), read(m);
for (int i = 0; i <= n; i++)
{
int tmp;
read(tmp);
a[i] = tmp;
}
for (int i = 0; i <= m; i++)
{
int tmp;
read(tmp);
b[i] = tmp;
}
for (m += n, n = 1; n <= m; n <<= 1);
FFT(a, n, false), FFT(b, n, false);
for (int i = 0; i < n; i++)
a[i] = a[i] * b[i];
FFT(a, n, true);
for (int i = 0; i <= m; i++)
write((int)((a[i].a / n) + 0.5)), putchar(' ');
return 0;
}
五、优化
递归太慢了,我们用迭代。
考虑奇偶分组的过程。每一次把奇数项分到前面,偶数项分到后面,如${a_0,a_1,a_2,a_3,a_4,a_5,a_6,a_7}$,按照这个过程分组,最终每组只剩一个数的时候是${a_0,a_4,a_2,a_6,a_1,a_5,a_3,a_7}$。经过仔mo细bai观da察lao,发现$1_{(10)}=001_{(2)}$,$4_{(10)}=100_{(2)}$,一个数最终变成的数的下标是它的下标的二进制表示颠倒过来(并不知道为什么)。我们可以递推算这个(其中lg2是$log_2n$):
rev[i] = rev[i >> 1] >> 1 | ((i & 1) << (lg2 - 1))
可以先生成原数组经过$log_2n$次奇偶分组的最终状态,然后一层一层向上合并即可。
另外,标准库中的三角函数很慢,可以打出$w_n^k$和$w_n^{-k}$的表(或者只打一个表,因为$w_n^{-k}=w_n^{n-k}$)。当前分治的区间长度为$l$时,查询$w_l^k$相当于查询$w_n^{\frac{nk}{l}}$(这里要小心$nk$爆int……血的教训)。
代码如下(洛谷1919)
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cctype>
#include <cmath>
#include <string>
using namespace std;
namespace zyt
{
template<typename T>
inline void read(T &x)
{
char c;
bool f = false;
x = 0;
do
c = getchar();
while (c != '-' && !isdigit(c));
if (c == '-')
f = true, c = getchar();
do
x = x * 10 + c - '0', c = getchar();
while (isdigit(c));
if (f)
x = -x;
}
inline void read(char &c)
{
do
c = getchar();
while (!isgraph(c));
}
template<typename T>
inline void write(T x)
{
static char buf[20];
char *pos = buf;
if (x < 0)
putchar('-'), x = -x;
do
*pos++ = x % 10 + '0';
while (x /= 10);
while (pos > buf)
putchar(*--pos);
}
const int N = (1 << 17) + 11;
const double PI = acos(-1.0L);
struct cpx
{
double a, b;
cpx(const double x = 0, const double y = 0)
:a(x), b(y) {}
cpx operator + (const cpx &c) const
{
return (cpx){a + c.a, b + c.b};
}
cpx operator - (const cpx &c) const
{
return (cpx){a - c.a, b - c.b};
}
cpx operator * (const cpx &c) const
{
return (cpx){a * c.a - b * c.b, a * c.b + b * c.a};
}
cpx conj() const
{
return (cpx){a, -b};
}
~cpx(){}
}omega[N], inv[N];
int rev[N];
void FFT(cpx *a, const int n, const cpx *w)
{
for (int i = 0; i < n; i++)
if (i < rev[i])
swap(a[i], a[rev[i]]);
for (int len = 1; len < n; len <<= 1)
for (int i = 0; i < n; i += (len << 1))
for (int k = 0; k < len; k++)
{
cpx tmp = a[i + k] - w[k * (n / (len << 1))] * a[i + len + k];
a[i + k] = a[i + k] + w[k * (n / (len << 1))] * a[i + len + k];
a[i + len + k] = tmp;
}
}
void init(const int lg2)
{
for (int i = 0; i < (1 << lg2); i++)
{
rev[i] = rev[i >> 1] >> 1 | (i & 1) << (lg2 - 1);
omega[i] = (cpx){cos(2 * PI * i / (1 << lg2)), sin(2 * PI * i / (1 << lg2))};
inv[i] = omega[i].conj();
}
}
int work()
{
int n;
static cpx a[N], b[N];
read(n);
for (int i = 0; i < n; i++)
{
char c;
read(c);
a[i] = c - '0';
}
for (int i = 0; i < n; i++)
{
char c;
read(c);
b[i] = c - '0';
}
for (int i = 0; (i << 1) < n; i++)
swap(a[i], a[n - i - 1]), swap(b[i], b[n - i - 1]);
int lg2 = 0, tmp = n << 1;
for (n = 1; n < tmp; ++lg2, n <<= 1);
init(lg2);
FFT(a, n, omega), FFT(b, n, omega);
for (int i = 0; i < n; i++)
a[i] = a[i] * b[i];
FFT(a, n, inv);
bool st = false;
static int ans[N];
for (int i = 0; i < n; i++, n += (ans[n]))
{
ans[i] += (int)(a[i].a / n + 0.5);
ans[i + 1] += ans[i] / 10;
ans[i] %= 10;
}
for (int i = n - 1; i >= 0; i--)
if (st || ans[i])
write(ans[i]), st = true;
return 0;
}
}
int main()
{
return zyt::work();
}
来源:oschina
链接:https://my.oschina.net/u/4299463/blog/3751761