FFT与多项式乘法

初探FFT——从多项式乘法开始

前置知识

复数

定义:z=a+biz=a+bi,其中 a,bRi=1a,b\in \mathbb{R} \,\,i=\sqrt{-1},也可以写成 z=reiθz=re^{i\theta},其中 rr 为它的模,θ\theta 为它的辐角

加法法则:(a+bi)+(c+di)=(a+c)+(b+d)i\left( a+bi \right) +\left( c+di \right) =\left( a+c \right) +\left( b+d \right) i

乘法法则:(a+bi)(c+di)=(acbd)+(ad+bc)i\left( a+bi \right) \left( c+di \right) =\left( ac-bd \right) +\left( ad+bc \right) i

除法法则:a+bic+di=ac+bdc2+d2+bcadc2+d2i\frac{a+bi}{c+di}=\frac{ac+bd}{c^2+d^2}+\frac{bc-ad}{c^2+d^2}i

欧拉定理:eiθ=cosθ+isinθe^{i\theta}=\cos \theta +i\sin \theta

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
struct Complex
{
double r, i;
Complex() { r = 0, i = 0; }
Complex(double real, double imag) : r(real), i(imag){};
};
Complex operator+(const Complex &a, const Complex &b)
{
return Complex(a.r + b.r, a.i + b.i);
}
Complex operator-(const Complex &a, const Complex &b)
{
return Complex(a.r - b.r, a.i - b.i);
}
Complex operator*(const Complex &a, const Complex &b)
{
return Complex(a.r * b.r - a.i * b.i, a.r * b.i + a.i * b.r);
}

单位根

单位根是方程 zn=1z^n=1 在复数范围内的 nn 个根

zn=rnenθi=1{r=1nθ=2kπ\begin{array}{l} z^n=r^ne^{n\theta i}=1\\ \Rightarrow \begin{cases} r=1\\ n\theta =2k\pi\\ \end{cases}\\ \end{array}

所以 ωn=1\omega ^n=1nn 个根为

ωnk=ei2kπn=cos2kπn+isin2kπn,(k=0,1,,n1)\begin{aligned} \omega _{n}^{k}&=e^{i\frac{2k\pi}{n}}\\ &=\cos \frac{2k\pi}{n}+i\sin \frac{2k\pi}{n}\\ \end{aligned},\left( k=0,1,\cdots ,n-1 \right)

ωn=ei2πn\omega _n=e^{i\frac{2\pi}{n}} 称为主 nn 次单位根

主单位根的三个引理:

消去引理:ωdndk=ωnk\omega _{dn}^{dk}=\omega _{n}^{k},例如 ωnn/2=ω2=1\omega _{n}^{n/2}=\omega _2=-1

proof.ωdndk=ei2dkπdn=ei2kπn=ωnkproof. \omega _{dn}^{dk}=e^{i\frac{2dk\pi}{dn}}=e^{i\frac{2k\pi}{n}}=\omega _{n}^{k}

折半引理:(ωnk+n/2)2=(ωnk)2=ωn/2k\left( \omega _{n}^{k+n/2} \right) ^2=\left( \omega _{n}^{k} \right) ^2=\omega _{n/2}^{k}

proof.ωnk+n/2=ωnkωnn/2=ωnk(ωnk)2=ωn2k=ωn/2k\begin{aligned} proof.&\omega _{n}^{k+n/2}=\omega _{n}^{k}\omega _{n}^{n/2}=-\omega _{n}^{k}\\ &\left( -\omega _{n}^{k} \right) ^2=\omega _{n}^{2k}=\omega _{n/2}^{k}\\ \end{aligned}

求和引理:i=0n1(ωnk)i=0,n>1\sum\nolimits_{i=0}^{n-1}{\left( \omega _{n}^{k} \right) ^i}=0,n>1

proof.i=0n1(ωnk)i=(ωnk)n1ωnk1=(ωnn)k1ωnk1=1k1ωnk1=0proof.\sum\nolimits_{i=0}^{n-1}{\left( \omega _{n}^{k} \right) ^i}=\frac{\left( \omega _{n}^{k} \right) ^n-1}{\omega _{n}^{k}-1}=\frac{\left( \omega _{n}^{n} \right) ^k-1}{\omega _{n}^{k}-1}=\frac{1^k-1}{\omega _{n}^{k}-1}=0

多项式

多项式有两种表示方法

  • 系数表示,将一个多项式表示成由其系数构成的向量的形式
    • 加法时间复杂度 O(n)O(n)
    • 乘法需要计算两多项式系数向量的卷积,复杂度 O(n2)O(n^2)
    • 求值使用秦九韶算法,复杂度 O(n)O(n)
  • 点值表示,对于 nn 次多项式,用至少 n+1n+1 个多项式上的点来表示,两多项式应该在相同位置取值
    • 加法,将纵坐标相加,时间复杂度 O(n)O(n)
    • 乘法,将纵坐标相乘,时间复杂度 O(n)O(n)
    • 插值,拉格朗日插值公式,时间复杂度 O(n2)O(n^2)

DFT 及 FFT 原理推导

离散傅里叶变换 DFT

A(x)=i=0n1aixiA\left( x \right) =\sum_{i=0}^{n-1}{a_ix^i}ωn0,ωn1,,ωnn1\omega _{n}^{0},\omega _{n}^{1},\cdots ,\omega _{n}^{n-1} 处的值为 y0,y1,,yn1y_0,y_1,\cdots ,y_{n-1}

离散傅里叶变换后的结果 yi=A(ωni)=j=0n1ωnijajy_i=A\left( \omega _{n}^{i} \right) =\sum_{j=0}^{n-1}{\omega _{n}^{ij}a^j},记为 y=DFTn(a)y=\mathrm{DFT}_n\left ( a \right)

使用秦九韶算法一次求值的时间复杂度是 O(n)O(n),求 n 个值的总复杂度是 O(n2)O(n^2)

快速傅里叶变换 FFT

通过单位根的性质去除冗余的计算量,将时间复杂度降为 O(nlogn)O\left ( n\log n \right)

将一个 n1n-1 次多项式的系数向量 A(x)a=[a0,a1,,an1]A\left ( x \right) \rightarrow \boldsymbol{a}=\left[ a_0, a_1,\cdots ,a_{n-1} \right] 分为偶数项和奇数项两个向量,记为 a[0]\boldsymbol{a}^{\left[ 0 \right]}a[1]\boldsymbol{a}^{\left[ 1 \right]}

FFT 的 n 应保证为 2 的幂

a[0]=[a0,a2,,an2]A[0](x)\boldsymbol{a}^{\left[ 0 \right]}=\left[ a_0,a_2,\cdots ,a_{n-2} \right] \rightarrow A^{\left[ 0 \right]}\left( x \right)

a[1]=[a1,a3,,an1]A[1](x)\boldsymbol{a}^{\left[ 1 \right]}=\left[ a_1, a_3,\cdots ,a_{n-1} \right] \rightarrow A^{\left[ 1 \right]}\left ( x \right)

xx 代入,写出它们的表达式,观察它们的联系

A(x)=a0+a1x++an1xn1A[0](x)=a0+a2x++an2xn/21A[1](x)=a1+a3x++an1xn/21\begin{array}{l} A\left( x \right) =a_0+a_1x+\cdots +a_{n-1}x^{n-1}\\ A^{\left[ 0 \right]}\left( x \right) =a_0+a_2x+\cdots +a_{n-2}x^{n/2-1}\\ A^{\left[ 1 \right]}\left( x \right) =a_1+a_3x+\cdots +a_{n-1}x^{n/2-1}\\ \end{array}

将后两个表达式的自变量换成 x2x^2

A(x)=a0+a1x++an1xn1A[0](x2)=a0+a2x2++an2xn2A[1](x2)=a1+a3x2++an1xn2\begin{array}{l} A\left( x \right) =a_0+a_1x+\cdots +a_{n-1}x^{n-1}\\ A^{\left[ 0 \right]}\left( x^2 \right) =a_0+a_2x^2+\cdots +a_{n-2}x^{n-2}\\ A^{\left[ 1 \right]}\left( x^2 \right) =a_1+a_3x^2+\cdots +a_{n-1}x^{n-2}\\ \end{array}

可以发现 A(x)=A[0](x2)+xA[1](x2)A\left ( x \right) =A^{\left[ 0 \right]}\left ( x^2 \right) +xA^{\left[ 1 \right]}\left ( x^2 \right)

下面代入两个具体的单位根

{A(ωnk)=A[0]((ωnk)2)+ωnkA[1]((ωnk)2)A(ωnk+n/2)=A[0]((ωnk+n/2)2)+ωnk+n/2A[1]((ωnk+n/2)2)\begin{cases} A\left( \omega _{n}^{k} \right) =A^{\left[ 0 \right]}\left( \left( \omega _{n}^{k} \right) ^2 \right) +\omega _{n}^{k}A^{\left[ 1 \right]}\left( \left( \omega _{n}^{k} \right) ^2 \right)\\ A\left( \omega _{n}^{k+n/2} \right) =A^{\left[ 0 \right]}\left( \left( \omega _{n}^{k+n/2} \right) ^2 \right) +\omega _{n}^{k+n/2}A^{\left[ 1 \right]}\left( \left( \omega _{n}^{k+n/2} \right) ^2 \right)\\ \end{cases}

根据之前提到的折半引理和消去引理,可以得到

{A(ωnk)=A[0](ωn/2k)+ωnkA[1](ωn/2k)A(ωnk+n/2)=A[0](ωn/2k)ωnkA[1](ωn/2k)\begin{cases} A\left( \omega _{n}^{k} \right) =A^{\left[ 0 \right]}\left( \omega _{n/2}^{k} \right) +\omega _{n}^{k}A^{\left[ 1 \right]}\left( \omega _{n/2}^{k} \right)\\ A\left( \omega _{n}^{k+n/2} \right) =A^{\left[ 0 \right]}\left( \omega _{n/2}^{k} \right) -\omega _{n}^{k}A^{\left[ 1 \right]}\left( \omega _{n/2}^{k} \right)\\ \end{cases}

A[0](x)A^{\left[ 0 \right]}\left ( x \right)A[1](x)A^{\left[ 1 \right]}\left ( x \right) 均为次数为 n/2n/2 的多项式,所以问题变成多项式 A[0](x)A^{\left[ 0 \right]}\left ( x \right)A[1](x)A^{\left[ 1 \right]}\left ( x \right) 在各个 n/2n/2 次单位根上的值,也就是从 DFTn(a)\mathrm{DFT}_n\left ( a \right) 转化为 DFTn/2(a[0])\mathrm{DFT}_{n/2}\left ( a^{\left[ 0 \right]} \right)DFTn/2(a[1])\mathrm{DFT}_{n/2}\left ( a^{\left[ 1 \right]} \right) 两个子问题,子问题可以通过同样的方式递归求解,再根据折半引理快速地合并结果。

T(n)=2T(n/2)+O(n)=O(nlogn)T\left ( n \right) =2T\left ( n/2 \right) +O\left ( n \right) =O\left ( n\log n \right)

C++代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#include <iostream>
#include <cmath>
using namespace std;
const double PI = acos(-1);
struct Complex
{
double r, i;
Complex() { r = 0, i = 0; }
Complex(double real, double image) : r(real), i(image){};
};
Complex operator+(const Complex &a, const Complex &b)
{
return Complex(a.r + b.r, a.i + b.i);
}
Complex operator-(const Complex &a, const Complex &b)
{
return Complex(a.r - b.r, a.i - b.i);
}
Complex operator*(const Complex &a, const Complex &b)
{
return Complex(a.r * b.r - a.i * b.i, a.r * b.i + a.i * b.r);
}
void FFT(Complex *a, int len)
{
if (len == 1)
return;
Complex a0[len >> 1], a1[len >> 1];
for (int i = 0; i < len; i += 2)
a0[i >> 1] = a[i], a1[i >> 1] = a[i + 1];
FFT(a0, len >> 1);
FFT(a1, len >> 1);
// wn也叫旋转因子,因为乘wn相当于在复平面逆时针旋转2*PI/len
Complex wn(cos(2 * PI / len), sin(2 * PI / len));
Complex w(1, 0);
for (int k = 0; k < (len >> 1); k++)
{
a[k] = a0[k] + w * a1[k];
a[k + len >> 1] = a0[k] - w * a1[k];
w = w * wn;
}
}

高效实现 FFT

蝴蝶操作

代码中存在公共子表达式 w * a1[k],可以计算一次后存在变量 t 中,这样的操作叫做蝴蝶操作

迭代实现

每次递归都把数组分为奇数位和偶数位,下面模拟这个过程

1
2
3
4
a0, a1, a2, a3, a4, a5, a6, a7
a0, a2, a4, a6 a1, a3, a5, a7
a0, a4 a2, a6 a1, a5 a3, a7
a0 a4 a2 a6 a1 a5 a3 a7

如果能知道递归最终的顺序,并按照这个顺序重新排列,那么就可以直接合并,不需要递归的时间,也节约了空间。

1
2
3
4
5
6
7
1  a0    a4    a2    a6    a1    a5    a3    a7
w2 w2 w2 w2
2 a0, a4 a2, a6 a1, a5 a3, a7
w4 w4
3 a0, a2, a4, a6 a1, a3, a5, a7
w8
4 a0, a1, a2, a3, a4, a5, a6, a7

通过观察可以发现,每次合并需要的主 nn 次单位根的 nn 取决于合并后的数组有多少元素,有多少个元素又可以通过当前层数确定。

位逆序置换

最后思考一下如何得到最终的顺序,这个操作称为位逆序置换

分析一下下标的规律

1
2
0 1 2 3 4 5 6 7
0 4 2 6 1 5 3 7

二进制

1
2
000 001 010 011 100 101 110 111
000 100 010 110 001 101 011 111

可以发现,最终位置的二进制和当前位置的二进制恰好是逆序的

位逆序置换可以 O(n)O(n) 从小到大递推实现,设 len=2klen=2^k,其中 kk 表示二进制数的长度,设 R(x)R(x) 表示长度为 kk 的二进制数 xx 翻转后的数(高位补 0),要求的是 R(0),R(1),,R(n1)R\left ( 0 \right) ,R\left ( 1 \right) ,\cdots ,R\left ( n-1 \right)

首先 R(0)=0R(0)=0

然后从小到大求 R(x)R(x),在求 R(x)R(x) 时,R(n2)R\left ( \lfloor \frac{n}{2} \rfloor \right) 是已知的。因此把 xx 向右移一位(除以 2),然后翻转,再右移一位,就得到了 x 除了(二进制)个位之外其它位的翻转结果。

考虑个位的翻转结果:如果个位是 0,翻转之后最高位就是 0,如果个位是 1,翻转之后最高位就是 1,因此再加上 len2=2k1\frac{len}{2}=2^{k-1}。综上

R(x)=R(x/2)2+(xmod2)×len2R\left( x \right) =\lfloor \frac{R\left( \lfloor x/2 \rfloor \right)}{2} \rfloor +\left( x\mathrm{mod}2 \right) \times \frac{len}{2}

举例:k=5,len=(100000)2k=5,len=\left( 100000 \right) _2。翻转 (11001)2\left( 11001 \right) _2

  1. 已经知道 R((1100)2)=R((01100)2)=(00110)2R\left( \left( 1100 \right) _2 \right) =R\left( \left( 01100 \right) _2 \right) =\left( 00110 \right) _2,右移一位得到 (00011)2\left( 00011 \right) _2
  2. 个位是 1,要翻转到最高位,即 (00011)2+2k1=(00011)2+(10000)2=(10011)2\left( 00011 \right) _2+2^{k-1}=\left( 00011 \right) _2+\left( 10000 \right) _2=\left( 10011 \right) _2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
// 同样需要保证 len 是 2 的幂
// 记 rev[i] 为 i 翻转后的值
void change(Complex *a, int len)
{
for (int i = 0; i < len; ++i)
{
rev[i] = rev[i >> 1] >> 1;
if (i & 1) // 如果最后一位是 1,则翻转成 len/2
rev[i] |= len >> 1;
}
for (int i = 0; i < len; ++i)
if (i < rev[i]) // 保证每对数只翻转一次
swap(a[i], a[rev[i]]);
}

FFT 逆变换

快速傅里叶逆变换 IDFT

DFTyi=j=0n1ωnijajy=Vna\mathrm{DFT}\rightarrow y_i=\sum_{j=0}^{n-1}{\omega _{n}^{ij}a_j}\rightarrow \boldsymbol{y}=\boldsymbol{V}_n\boldsymbol{a},其中 (Vn)ij=ωnij\left( \boldsymbol{V}_n \right) _{ij}=\omega _{n}^{ij}

离散傅里叶变换给出 a 求 y,而离散傅里叶逆变换给出 y 求 a

所以 a=Vn1y\boldsymbol{a}=\boldsymbol{V}_{n}^{-1}\boldsymbol{y},这里直接给出 (Vn1)ij=ωnijn\left( \boldsymbol{V}_{n}^{-1} \right) _{ij}=\frac{\omega _{n}^{-ij}}{n},可以通过矩阵相乘来验证结论

(Vn1Vn)ij=k=0n1ωnkin×ωnkj=k=0n1ωnk(ji)n=In\begin{aligned} \left( \boldsymbol{V}_{n}^{-1}\boldsymbol{V}_n \right) _{ij}&=\sum_{k=0}^{n-1}{\frac{\omega _{n}^{-ki}}{n}\times \omega _{n}^{kj}}\\ &=\sum_{k=0}^{n-1}{\frac{\omega _{n}^{k\left( j-i \right)}}{n}}\\ &=\boldsymbol{I}_n\\ \end{aligned}

根据前面提到的求和引理,只有当 i==ji==j 时,值才为 0,所以只有对角线上是 1,其它全为 0,结果是单位矩阵。

将逆矩阵代入,再进行化简,得到

ai=j=0n1ωnijnyj=1nj=0n1ωnijyj\begin{aligned} a_i&=\sum_{j=0}^{n-1}{\frac{\omega _{n}^{-ij}}{n}y_j}\\ &=\frac{1}{n}\sum_{j=0}^{n-1}{\omega _{n}^{-ij}y_j}\\ \end{aligned}

可以发现,逆变换的公式和傅里叶变换的公式十分相似,所以可以稍微改动傅里叶变换的代码来计算逆变换

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
/*
* 做 FFT
* len 必须是 2^k 形式
* opt == 1 时是 DFT,opt == -1 时是 IDFT
*/
void FFT(Complex *a, int len, int opt)
{
// 位逆序置换
change(a, len);
// 模拟合并过程,从高度为1开始
for (int dep = 1; dep <= log2(len); dep++)
{
// 合并后长度为m
int m = 1 << dep;
// wn:当前单位复根的间隔:w^1_m
Complex wn(cos(2 * PI / m), sin(opt * 2 * PI / m));
// 合并,共 len / m 次。
for (int k = 0; k < len; k += m)
{
// 计算当前单位复根,一开始是 1 = w^0_n,之后是以 wn 为间隔递增
Complex w(1, 0);
for (int j = 0; j < m / 2; j++)
{
// 左侧部分和右侧是子问题的解
Complex t = w * a[k + j + m / 2];
Complex u = a[k + j];
// 这就是把两部分分治的结果加起来
a[k + j] = u + t;
a[k + j + m / 2] = u - t;
w = w * wn;
}
}
}
if (opt == -1)
for (int i = 0; i < len; i++)
a[i].r /= len;
}

FFT 求多项式卷积

求两个向量的卷积,可以用 O(nlogn)O(n\log n) 的时间求出两个向量的离散傅里叶变换,也就相当于将系数表示转化为点值表示,然后用 O(n)O(n) 的时间对两个结果逐元素相乘,也就是点值表示中的乘法,最后再用 O(nlogn)O(n\log n) 的时间求出乘法结果的傅里叶逆变换,也就是将点值表示转化为系数表示。这样就可以在 O(nlogn)O(n\log n) 的时间复杂度内算出卷积。

上述过程也就是下面这个公式

ab=IDFT2n(DFT2n(a)DFT2n(b))\boldsymbol{a}\otimes \boldsymbol{b}=\mathrm{IDFT}_{2n}\left( \mathrm{DFT}_{2n}\left( \boldsymbol{a} \right) \odot \mathrm{DFT}_{2n}\left( \boldsymbol{b} \right) \right)

代码实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include <iostream>
#include <cmath>
using namespace std;
const int maxn = 200005;
const double PI = acos(-1);
struct Complex
{
double r, i;
Complex() { r = 0, i = 0; }
Complex(double real, double imag) : r(real), i(imag){};
} F[maxn], G[maxn];
Complex operator+(const Complex &a, const Complex &b)
{
return Complex(a.r + b.r, a.i + b.i);
}
Complex operator-(const Complex &a, const Complex &b)
{
return Complex(a.r - b.r, a.i - b.i);
}
Complex operator*(const Complex &a, const Complex &b)
{
return Complex(a.r * b.r - a.i * b.i, a.r * b.i + a.i * b.r);
}
int rev[maxn];
void change(Complex *a, int len)
{
for (int i = 0; i < len; ++i)
{
rev[i] = rev[i >> 1] >> 1;
if (i & 1)
rev[i] |= len >> 1;
}
for (int i = 0; i < len; ++i)
if (i < rev[i])
swap(a[i], a[rev[i]]);
}
void FFT(Complex *a, int len, int opt)
{
change(a, len);
for (int dep = 1; dep <= log2(len); dep++)
{
int m = 1 << dep;
Complex wn(cos(2 * PI / m), sin(opt * 2 * PI / m));
for (int k = 0; k < len; k += m)
{
Complex w(1, 0);
for (int j = 0; j < m / 2; j++)
{
Complex t = w * a[k + j + m / 2];
Complex u = a[k + j];
a[k + j] = u + t;
a[k + j + m / 2] = u - t;
w = w * wn;
}
}
}
if (opt == -1)
for (int i = 0; i < len; i++)
a[i].r /= len;
}
int main()
{
int n, m, len = 1;
cin >> n >> m;
for (int i = 0; i <= n; i++)
cin >> F[i].r;
for (int i = 0; i <= m; i++)
cin >> G[i].r;
while (len <= n + m)
len <<= 1;
FFT(F, len, 1);
FFT(G, len, 1);
for (int i = 0; i < len; i++)
F[i] = F[i] * G[i];
FFT(F, len, -1);
for (int i = 0; i <= n + m; i++)
printf("%d ", (int)(F[i].r + 0.5));
}

测试输入数据

1
2
3
2 2
1 2 3
4 5 6

测试输出数据

1
4 13 28 27 18

参考资料

  1. 〔manim | 算法 | 互动〕具体学习并实现快速傅里叶变换(FFT)| 多项式乘法 | 快速求卷积 | 学习区首发互动视频_哔哩哔哩_bilibili
  2. 快速傅里叶变换(FFT)——有史以来最巧妙的算法?_哔哩哔哩_bilibili
  3. 【官方双语】那么……什么是卷积?_哔哩哔哩_bilibili
  4. 这个算法改变了世界_哔哩哔哩_bilibili
  5. 【官方双语】形象展示傅里叶变换_哔哩哔哩_bilibili
  6. 【官方双语】微分方程概论-第五章:在3.14分钟内理解e^iπ_哔哩哔哩_bilibili
  7. 【官方双语】欧拉公式与初等群论_哔哩哔哩_bilibili