-
Notifications
You must be signed in to change notification settings - Fork 0
/
FFT.cpp
90 lines (77 loc) · 2.05 KB
/
FFT.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
// Using std::complex<double> as the FFT element type (the commented-out typedef below) instead of the custom `base` struct is somewhat slower.
#include <complex>
#include <iostream>
#include <math.h>
#include <set>
#include <stdexcept>
#include <string.h>
#include <vector>
using namespace std;
// Fast-I/O switch ("optimizar_io" = "optimize I/O"): unsyncs the C++ streams from C stdio
// and unties cin (so reading no longer flushes the stream cin is tied to).
// Only safe when the program does not mix C stdio (printf/scanf) with iostreams.
// It expands to two statements, so do not use it as the body of an unbraced if/for.
#define optimizar_io ios_base::sync_with_stdio(0);cin.tie(0);
// Alternative FFT element type, left disabled in favour of the hand-written struct below.
//typedef complex<double> base;
// Minimal complex number used as the FFT element type: the value r + i*j.
struct base{
    double r, i;
    // Not explicit: a real number converts implicitly to base(x, 0).
    base(double re = 0, double im = 0) : r(re), i(im) {}
    // Real part, named like std::complex<double>::real() so the two types are interchangeable here.
    double real() const { return r; }
    // Divides both parts by an integer (used to normalize the inverse FFT).
    // Returns *this, as compound assignment operators conventionally do.
    base &operator/=(const int c) {
        r /= c;
        i /= c;
        return *this;
    }
};
// Complex product: (a.r + a.i*j) * (b.r + b.i*j).
base operator*(const base &a, const base &b) {
    return base(a.r * b.r - a.i * b.i, a.r * b.i + a.i * b.r);
}
// Component-wise complex sum a + b.
base operator+(const base &a, const base &b) {
    return base(a.r + b.r, a.i + b.i);
}
// Component-wise complex difference a - b.
base operator-(const base &a, const base &b) {
    return base(a.r - b.r, a.i - b.i);
}
// Capacity of the global FFT tables: the largest transform length n they can serve.
const int MAXN = 10000000;
// Bit-reversal permutation for the current transform length; filled by calc_rev.
int rev[MAXN];
// Per-stage twiddle factors w^0 .. w^(len/2 - 1), rewritten by fft for every stage.
// A stage of length len uses len/2 entries and len <= n <= MAXN, so MAXN / 2 entries
// suffice; sizing it to MAXN only doubled this table's static footprint.
base wlen_pw[MAXN / 2];
// Iterative in-place radix-2 FFT over a[0..n).
// Preconditions (not checked here):
//   * n is a power of two and calc_rev(n, log2(n)) has filled rev[0..n) for this n;
//   * the global twiddle buffer wlen_pw has room for at least n / 2 entries.
// invert == false: forward transform, stage angle +2*pi/len.
// invert == true : inverse transform (angle -2*pi/len), followed by division by n.
// Overwrites the global wlen_pw as scratch space.
void fft (base a[], int n, bool invert) {
    // pi computed portably: M_PI is a POSIX extension, not part of ISO C++.
    const double pi = acos(-1.0);
    // Bit-reversal permutation so the butterflies below can run in place.
    for (int i = 0; i < n; ++i)
        if (i < rev[i])
            swap(a[i], a[rev[i]]);
    for (int len = 2; len <= n; len <<= 1) {
        double ang = 2 * pi / len * (invert ? -1 : +1);
        int len2 = len >> 1;
        // Twiddles w^0 .. w^(len2-1) for this stage, shared by every block of length len.
        base wlen(cos(ang), sin(ang));
        wlen_pw[0] = base(1, 0);
        for (int i = 1; i < len2; ++i)
            wlen_pw[i] = wlen_pw[i - 1] * wlen;
        for (int i = 0; i < n; i += len) {
            base *pu = a + i;                 // walks the first half of the block
            base *pv = a + i + len2;          // walks the second half of the block
            base *const pu_end = a + i + len2;
            const base *pw = wlen_pw;
            for (; pu != pu_end; ++pu, ++pv, ++pw) {
                // Butterfly: (u, v) -> (u + w*v, u - w*v).
                const base t = *pv * *pw;
                *pv = *pu - t;
                *pu = *pu + t;
            }
        }
    }
    if (invert)
        for (int i = 0; i < n; ++i)
            a[i] /= n;
}
// Fills rev[0..n) with the bit-reversal permutation used by fft: rev[k] is k with its
// lowest log_n bits written in reverse order (any higher bits of k are ignored).
void calc_rev (int n, int log_n) {
    for (int k = 0; k < n; ++k) {
        int mirrored = 0;
        int bits = k;
        // Pop bits off the low end of k and push them onto the low end of mirrored;
        // after log_n steps bit j of k sits at position log_n-1-j.
        for (int step = 0; step < log_n; ++step) {
            mirrored = (mirrored << 1) | (bits & 1);
            bits >>= 1;
        }
        rev[k] = mirrored;
    }
}
inline static void multiply(const vector<int> &a, const vector<int> &b, vector<int> &res) {
vector<base> fa (a.begin(), a.end()), fb (b.begin(), b.end());
int n=1;
int logn = 0;
while(n < max(a.size(), b.size()))
n <<= 1, logn++;
n <<= 1, logn++;
calc_rev(n, logn);
fa.resize (n), fb.resize (n);
fft (&fa[0], n, false), fft (&fb[0], n, false);
for (int i = 0; i < n; i++){
fa[i] = fa[i] * fb[i];
}
fft (&fa[0], n, true);
res.resize(n);
for (int i = 0; i < n; i++){
res[i] = int (fa[i].real() + 0.5);
}
}