快速傅立叶变换(FFT)学习笔记

经典问题:记 \(f(x) = \sum_{i=0}^{n} a_i x^i\)\(g(x) = \sum_{i=0}^{m} b_i x^i\),其中\(1\leq n, m \leq 10^5\),求\(f \cdot g (x)\)

显然如果我们做朴素的多项式乘法,时间复杂度是 \(O(nm)\) 的。我们可以使用快速傅立叶变换(FFT)在 \(O(c \log c)\) 的时间内解决此问题,其中 \(c\) 是不小于 \(n+m\) 的最小的 \(2\) 的幂。

本文是笔者学习FFT的笔记和一些思考,不推荐作为教程食用。

在学习FFT之前,我们首先要理解一些概念。

多项式的系数表示和点值表示

对于多项式\(f(x) = \sum_{i=0}^{n} a_i x^i\),把系数写成向量的形式\(\vec{a} = [a_0, a_1, \cdots, a_n]\),则\(\vec{a}\)称为多项式的系数表示。我们输入一个\(x\),对\(x\)\(f(x)\)可以得到一个\(y\),表示这个多项式函数在\(x\)点处的值。

我们知道,\(n+1\)个点可以唯一确定一个\(n\)次多项式,那么我们可以使用\(n+1\)个序偶对\((x_i, y_i)\)来表示这个多项式,这称为多项式的点值表示。如何从点值表示得到任意\(x\)处的函数值呢?可以使用拉格朗日差值法,由于与FFT关系不大,故这里不做具体说明。

离散傅立叶变换(DFT)与反变换(IDFT)

对于一个系数表示\(\vec{a} = [a_0, a_1, \cdots, a_n]\),对其做离散傅立叶变换得到:

\[ {\rm DFT}(\vec{a}) = [f(w_{n}^{0}),f(w_{n}^{1}),f(w_{n}^{2})\cdots,f(w_{n}^{n-1})]  \]

设其表示的多项式为\(\hat{f}(x)\),我们可以对其做离散傅立叶变换的反变换重新的到\(f(x)\),即向量\(\vec{a}\)

\[ \vec{a} = {\rm IDFT} ({\rm DFT} (\vec{a})) = \frac{1}{n} [\hat{f}(w_{n}^{0}),\hat{f}(w_{n}^{-1}),\hat{f}(w_{n}^{-2})\cdots,\hat{f}(w_{n}^{-(n-1)})] \]

其中\(w_{n}^{k}\)\(x^n = 1\)在复数域中的根,称为\(n\)次单位根,\(k=1\)的时候又称为本原单位根。

多项式乘法的原理

考虑\(n\)次的多项式\(f(x)\)\(m\)次多项式相乘会产生一个\(n+m\)次的多项式,所以我们只需要获得两个多项式的\(n+m+1\)个点形成新的多项式的点值表示,这一点可以取\(n+m+1\)次单位根做\(\rm DFT\),然后对得到的向量做\(\rm IDFT\)就可以得到目标多项式的系数表示了。即:

\[\vec{c} =  {\rm IDFT} [{\rm DFT}(\vec{a}) \circ {\rm DFT}(\vec{b})]  = \vec{a} * \vec{b}\]

这里也验证了傅立叶变换时域卷积等于频域乘积的性质。

快速求解DFT和IDFT的方法

我们知道卷积的求解和差值的求解都是\(O(n^2)\)的,和朴素的求法相比并没有提升,我么考虑用分治的方法来优化DFT。 考虑\(w_n\)\(w_{2n}\),有这样的两条性质:

\[ \left\{ \begin{aligned} & (w_{2n}^{k})^2 = w_{n}^{k} \\ & w_{2n}^{n+k} = -w_{2n}^{k}  \end{aligned} \right . \]

这里可以在复平面上画图理解,当然《复变函数》中也有该性质的证明,这里不必过于纠结,会用即可。

\(f(x) = \sum_{i=0}^{n} a_i x^i\)\(n=2m\),我们考虑将其项的次数按照奇偶进行分类:

\[ \begin{aligned} f(x) = & \sum_{i=0}^{n} a_i x^i \\ = & \sum_{i=0}^{m} a_{2i} x^{2i} + \sum_{i=0}^{m} a_{2i+1} x^{2i+1} \\ = & \sum_{i=0}^{m} a_{2i} x^{2i} + x \sum_{i=0}^{m} a_{2i+1} x^{2i} \\ = & f_0(x^2) + x f_1(x^2)   \end{aligned} \]

假设\(0 \leq k < m\),那么对于\(w_{n}^{k}\)

\[\begin{aligned} f(w_{n}^{k}) = & f_0((w_{n}^{k})^2) + w_{n}^{k} f_1((w_{n}^{k})^2) \\ = & f_0(w_{m}^{k}) + w_{n}^{k} f_1(w_{m}^{k}) \end{aligned} \]

对于\(w_{n}^{m+k}\)

\[ \begin{aligned} f(w_{n}^{m+k}) = & f_0((w_{n}^{m+k})^2) + w_{n}^{m+k} f_1((w_{n}^{m+k})^2) \\ = & f_0(w_{m}^{k}) - w_{n}^{k} f_1(w_{m}^{k}) \end{aligned} \]

因此就可以用分治的方法快速计算DFT了。

位逆序置换与非递归FFT

为了加速FFT的过程,我们可以将FFT写成非递归的形式,考虑展开递归树。

假设我们有8个数\([(0,1,2,3,4,5,6,7)]\),经过一次递归之后变成\([(0,2,4,6), (1,3,5,7)]\),经过第二次递归之后变成\([(0,4),(2,6),(1,5),(3,7)]\),经过第三次递归之后变成\([(0), (4), (2), (6), (1), (5),(3),(7)]\),我们发现第\(i\)位是\(i\)的二进制数翻转对应的那个十进制数字,因此我们可以把每个数字放在它最后的位置上,然后逐层向上还原,写成非递归版的FFT。

一份常数比较大的模板

洛谷上\(10^6\)级别的多项式乘法有两个点是TLE的,LOJ上\(10^5\)级别的多项式乘法通过没什么问题。(待后续更新)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef complex <double> CP;
const int MAX_N = int(2.1E6) + 5;
const double PI = acos(-1);
namespace FFT {
int n, aSz, bSz;
CP a[MAX_N], b[MAX_N], omg[MAX_N], inv[MAX_N];
void init() {
for (int i = 0; i < n; ++i) {
omg[i] = CP(cos(2 * PI * i / n), sin(2 * PI * i / n));
inv[i] = conj(omg[i]);
}
}
void fft(CP *a, CP *omg) {
int lim = 0;
while ((1 << lim) < n) ++lim;
for (int i = 0; i < n; ++i) {
int t = 0;
for (int j = 0; j < lim; ++j) {
if((i >> j) & 1) t |= (1 << (lim - j - 1));
}
if (i < t) swap(a[i], a[t]);
}
for (int l = 2; l <= n; l <<= 1) {
int m = l / 2;
for (CP *p = a; p != a + n; p += l) {
for (int i = 0; i < m; ++i) {
CP t = omg[n / l * i] * p[i + m];
p[i + m] = p[i] - t;
p[i] += t;
}
}
}
}
void run() {
n = 1;
while (n < aSz + bSz) n <<= 1;
init();
// printf("n = %d\n", n);
fft(a, omg);
fft(b, omg);
for (int i = 0; i < n; ++i) a[i] *= b[i];
fft(a, inv);
int len = aSz + bSz - 1;
for (int i = 0; i < len; ++i) {
printf("%d%c", int(round(a[i].real() / n)), i == len - 1 ? '\n' : ' ');
}
}
};
int main() {
scanf("%d%d", &FFT::aSz, &FFT::bSz);
++FFT::aSz, ++FFT::bSz;
int u;
for (int i = 0; i < FFT::aSz; ++i) {
scanf("%d", &u);
FFT::a[i].real(u);
}
for (int i = 0; i < FFT::bSz; ++i) {
scanf("%d", &u);
FFT::b[i].real(u);
}
FFT::run();
return 0;
}