快速傅里叶变换


0.用处#

用来将多项式乘法,也就是卷积,从 O(n2)\Omicron(n ^ 2) 优化到 O(nlogn)\Omicron(n \log n)

FFT 是 DFT(离散傅里叶变换)的优化,用来系数表示法转点值表示法,IFFT 是 IDFT 的优化,作用相反。

什么是点值表示法

因为一个多项式 f(x)=i=0naixif(x) = \sum _ {i = 0} ^ n a _ i x ^ i,如果我们选 nnxix _ i 带入,则可以列出一个关于 aia _ inn 元一次方程组,就可以解出每一个 aia _ i

因此,点集 {(x,f(x))}\{ (x,f(x)) \} 就可以唯一确定一个多项式。

而卷积用点值表示法可以这么写:

f(x)={(x,f(x))},g(x)={(x,g(x))}(fg)(x)={(x,f(x)g(x))}={(x,f(x)×g(x))}f(x) = \{ (x,f(x)) \},g(x) = \{ (x,g(x)) \} (f \ast g)(x) = \{ (x,f(x) \ast g(x)) \} = \{ (x,f(x) \times g(x)) \}

后一个式子就是朴素的复数乘法。

傅里叶变换有个性质:F(f(x)g(x))=F(f(x))F(g(x))\mathcal{F}(f(x) \ast g(x)) = \mathcal{F}(f(x)) \ast \mathcal{F}(g(x)),所以可以用它来优化。

1.前置知识#

复数#

定义 & 运算#

形如 z=a+biz = a + bi 的数为复数,a=za = \Re z,为实部,b=zb = \Im z,为虚部,ii 为虚数单位,i2=1i ^ 2 = -1。复数集为 C\Complex,加减法就是实部虚部对应相加/减。

乘法定义如下: (a+bi)(c+di)=(ac+bci)+(adidb)=(acbd)+(ad+bc)i(a + bi)(c + di) = (ac + bci) + (adi - db) = (ac - bd) + (ad + bc)i

实质就是模长相乘,辐角相加,因为 eiθ=cosθ+isinθ,a,bR,a+bi=keiθe ^ {i \theta} =\cos \theta + i \sin \theta,\forall a,b \in \R,a + bi = k e ^ {i \theta},所以 z1z2=k1eiθ1k2eiθ2=(k1k2)ei(θ1+θ2)z _ 1 z _ 2 = k _ 1e ^ {i \theta _ 1} k _ 2e ^ {i \theta _ 2} = (k _ 1 k _ 2)e ^ {i (\theta _ 1 + \theta _ 2)}

给出封装的代码模板:

CPP
1
2
3
4
5
6
7
8
struct Complex
{
    double re,im;
    Complex(double re,double im):re(re),im(im){};
    Complex operator+(const Complex &a) const {return Complex(re + a.re,im + a.im);}
    Complex operator-(const Complex &a) const {return Complex(re - a.re,im - a.im);}
    Complex operator*(const Complex &a) const {return Complex(re * a.re - im * a.im,re * a.im + im * a.re);}
};

单位根#

满足 ωn=1\omega ^ n = 1 的复数 ω\omega 称为 nn 次单位根。

ωnk\omega _ n ^ k 为第 kknn 次单位根。

则有性质:

  1. ωnk=ωanak\omega _ n ^ k = \omega _ {an} ^ {ak}
  2. ωnk=ωnk+n2\omega _ n ^ k = -\omega _ {n} ^ {k + \frac{n}{2}}
  3. ωn0=ωnn=1\omega _ n ^ 0 = \omega _ n ^ n = 1

2. 正文#

DFT#

其实就是点值表示法中,i,xi=ωni\forall i,x _ i = \omega _ n ^ i

FFT#

在 DFT 的基础上,用计算机的办法优化。

对于一个 n1n - 1 次多项式 f(x)=i=0n1aixif(x) = \sum _ {i = 0} ^ {n - 1} a _ i x ^ i,我们可以按指数的奇偶性来给每一项分类,即 f(x)=(a0+a2x2++an2xn2)+x(a1+a3x2++an1xn2)f(x) = (a _ 0 + a _ 2 x ^ 2 + \ldots + a _ {n - 2} x ^ {n - 2}) + x(a _ 1 + a _ 3 x ^ 2 + \ldots + a _ {n - 1} x ^ {n - 2})

为了方便,我们记 f0(x)=a0+a2x++an2xn21,f1(x)=a1+a3x++an1xn21f _ 0(x) = a _ 0 + a _ 2 x + \ldots + a _ {n - 2} x ^ {\frac{n}{2} - 1},f _ 1 (x) = a _ 1 + a _ 3 x + \ldots + a _ {n - 1} x ^ {\frac{n}{2} - 1}

显然有 f(x)=f0(x2)+xf1(x1)f(x) = f _ 0 (x ^ 2) + x f _ 1 (x ^ 1)

带入 x=ωnk,k<nx = \omega _ n ^ k,k \lt n,有 f(ωnk)=f0(ωnk)+ωnkf1(ωnk)=f0(ωn2k)ωnkf1(ωn2k)f(\omega _ n ^ k) = f _ 0(\omega _ n ^ k) + \omega _ n ^ k f _ 1(\omega _ n ^ k) = f _ 0 (\omega _ {\frac{n}{2}} ^ k) - \omega _ n ^ k f _ 1 (\omega _ {\frac{n}{2}} ^ k)

所以,可以递归求解,当 n=1n = 1 时结束。

IFFT#

就是上边的逆变换。具体操作就是带入的单位根变为 ωnk\omega _ n ^ {-k},即原单位根的倒数,然后再做一次 FFT,最后答案要除以 nn为什么是这样?

这个东西乍一看很烦,实则确实。建议使用全文背诵法记忆

模板代码
CPP
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#include <bits/stdc++.h>
#define int long long
using namespace std;

constexpr int MAXN = 4e6 + 24;
const double Pi = acos(-1.0);
struct Complex
{
    double re,im;
    Complex(double re = 0.0,double im = 0.0):re(re),im(im){};
    Complex operator+(const Complex &a) const {return Complex(re + a.re,im + a.im);}
    Complex operator-(const Complex &a) const {return Complex(re - a.re,im - a.im);}
    Complex operator*(const Complex &a) const {return Complex(re * a.re - im * a.im,re * a.im + im * a.re);}
};
void FFT(int limit,Complex *a,int type)
{
    if(limit == 1)
        return ;
    Complex a1[limit >> 1],a2[limit >> 1];
    for(int i = 0;i < limit;i += 2)
        a1[i >> 1] = a[i],a2[i >> 1] = a[i + 1];
    FFT(limit >> 1,a1,type);
    FFT(limit >> 1,a2,type);
    Complex omega = Complex(cos(2.0 * Pi / limit),type * sin(2.0 * Pi / limit)),Pow = Complex(1,0);
    for(int i = 0;i < (limit >> 1);i++,Pow = Pow * omega)
    {
        a[i] = a1[i] + Pow * a2[i];
        a[i + (limit >> 1)] = a1[i] - Pow * a2[i];
    }
}
int n,m;
Complex a[MAXN],b[MAXN];
signed main()
{
    ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
    cin >> n >> m;
    for(int i = 0;i <= n;i++)
        cin >> a[i].re;
    for(int i = 0;i <= m;i++)
        cin >> b[i].re;
    int limit = 1;
    while(limit <= n + m)
        limit <<= 1;
    FFT(limit,a,1);
    FFT(limit,b,1);
    for(int i = 0;i <= limit;i++)
        a[i] = a[i] * b[i];
    FFT(limit,a,-1);
    for(int i = 0;i <= n + m;i++)
        cout << (int)(a[i].re / limit + 0.5) << " ";
}

注意,这个 π\pi 一般用 acos(-1) 来求最精确,但如果你要背的话一定要写 3.141592653589793.14159265358979,少了会 WA。