-
Notifications
You must be signed in to change notification settings - Fork 595
/
Copy pathfft.py
executable file
·129 lines (117 loc) · 5.4 KB
/
fft.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import blst
# Starting value for sums of blst.P1 points (see _simple_ft), decoded from a
# serialized encoding: a 0x40 lead byte followed by zero bytes.
# NOTE(review): presumably blst's uncompressed encoding of the G1 point at
# infinity -- confirm against the blst serialization docs.
# P1.add is used in place in this file, so take a .dup() of this constant
# before accumulating into it.
P1_INF = blst.P1(bytes.fromhex("400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"))
def _simple_ft(vals, modulus, roots_of_unity):
    """Compute a discrete Fourier transform by direct O(n^2) summation.

    Output ``i`` is the sum over ``j`` of
    ``vals[j] * roots_of_unity[(i * j) % n]`` with ``n = len(roots_of_unity)``.

    Args:
        vals: Integers or ``blst.P1`` points; at least ``n`` entries are read.
        modulus: Modulus applied to integer outputs (not used for points).
        roots_of_unity: The ``n`` powers of the root of unity, in order.

    Returns:
        A list of ``n`` values of the same kind as ``vals``.
    """
    size = len(roots_of_unity)
    # Decide the element kind once; guarded so an empty domain still
    # returns [] without indexing into vals.
    points = len(vals) > 0 and isinstance(vals[0], blst.P1)
    out = []
    for i in range(size):
        if points:
            # P1.add works in place, so accumulate into a private copy.
            # Accumulating into P1_INF itself would corrupt the shared
            # constant and make every output entry the same object.
            acc = P1_INF.dup()
            for j in range(size):
                acc.add(vals[j].dup().mult(roots_of_unity[(i * j) % size]))
            out.append(acc)
        else:
            acc = 0
            for j in range(size):
                acc += vals[j] * roots_of_unity[(i * j) % size]
            out.append(acc % modulus)
    return out
def _fft(vals, modulus, roots_of_unity):
    """Recursive radix-2 FFT over integers mod ``modulus`` or ``blst.P1`` points.

    Integer inputs of length <= 4 are handed to ``_simple_ft``; point inputs
    recurse down to a single element.  Larger inputs are split into even- and
    odd-indexed halves which are transformed with every second root and then
    recombined with a butterfly step.

    Args:
        vals: Integers or ``blst.P1`` points, all of the same kind.
            NOTE(review): the split/recombine step only fills every output
            slot when the length is a power of two -- confirm callers
            guarantee this.
        modulus: Modulus for integer arithmetic (not used for points).
        roots_of_unity: Powers of the root of unity, one per entry of ``vals``.

    Returns:
        The transformed values as a list with ``len(vals)`` entries; an
        empty input yields an empty list.
    """
    if len(vals) == 0:
        return []
    # One type decision for the whole call: every element is expected to be
    # of the same kind as the first one.
    points = isinstance(vals[0], blst.P1)
    if points:
        if len(vals) == 1:
            return vals
    elif len(vals) <= 4:
        return _simple_ft(vals, modulus, roots_of_unity)

    half_roots = roots_of_unity[::2]
    evens = _fft(vals[::2], modulus, half_roots)
    odds = _fft(vals[1::2], modulus, half_roots)
    half = len(evens)
    out = [0] * len(vals)
    for i, (x, y) in enumerate(zip(evens, odds)):
        if points:
            # blst operations act in place, hence the dup() before each one.
            y_times_root = y.dup().mult(roots_of_unity[i])
            out[i] = x.dup().add(y_times_root)
            out[i + half] = x.dup().add(y_times_root.dup().neg())
        else:
            y_times_root = y * roots_of_unity[i]
            out[i] = (x + y_times_root) % modulus
            out[i + half] = (x - y_times_root) % modulus
    return out
def expand_root_of_unity(root_of_unity, modulus):
    """Return the successive powers of ``root_of_unity`` modulo ``modulus``.

    The list starts with ``1`` and ``root_of_unity`` and keeps multiplying
    until the value ``1`` comes round again, so the first and last entries
    are both ``1``.

    Args:
        root_of_unity: Element whose powers are listed.
        modulus: Modulus of the arithmetic.

    Returns:
        ``[1, w, w**2 % modulus, ..., 1]``.

    Raises:
        ValueError: If the powers of ``root_of_unity`` can never return to 1
            (it shares a factor with ``modulus``, or ``modulus < 2``).  The
            loop would otherwise never terminate.
    """
    from math import gcd

    # Only an element coprime to the modulus has a power equal to 1.
    if root_of_unity != 1 and (modulus < 2 or gcd(root_of_unity, modulus) != 1):
        raise ValueError(
            "root_of_unity must be invertible modulo modulus "
            "(got root_of_unity=%r, modulus=%r)" % (root_of_unity, modulus)
        )
    rootz = [1, root_of_unity]
    while rootz[-1] != 1:
        rootz.append((rootz[-1] * root_of_unity) % modulus)
    return rootz
def fft(vals, modulus, root_of_unity, inv=False):
    """Forward or inverse FFT of ``vals`` over the domain of ``root_of_unity``.

    Args:
        vals: List of integers or of ``blst.P1`` points.  Inputs shorter than
            the domain are padded: integers with ``0``, points with copies of
            the identity point ``P1_INF``.
        modulus: Modulus for integer arithmetic.  The inverse scaling factor
            is computed as ``n ** (modulus - 2)``, which assumes a prime
            modulus.
        root_of_unity: Generator of the evaluation domain.
        inv: When true, run the transform over the reversed roots and scale
            every output by ``1/n``.

    Returns:
        A list with the transformed values.
    """
    rootz = expand_root_of_unity(root_of_unity, modulus)
    points = len(vals) > 0 and isinstance(vals[0], blst.P1)

    # Fill in vals up to the domain size if needed.  A point list must be
    # padded with points: an integer 0 in a point list breaks _fft.
    missing = len(rootz) - 1 - len(vals)
    if missing > 0:
        if points:
            padding = [P1_INF.dup() for _ in range(missing)]
        else:
            padding = [0] * missing
        vals = vals + padding

    if not inv:
        # Regular FFT
        return _fft(vals, modulus, rootz[:-1])

    # Inverse FFT: same transform over the roots in reverse order, then
    # divide by the domain size.
    invlen = pow(len(vals), modulus - 2, modulus)
    transformed = _fft(vals, modulus, rootz[:0:-1])
    if points:
        return [x.dup().mult(invlen) for x in transformed]
    return [(x * invlen) % modulus for x in transformed]
# Note on call hierarchy: the functions below are not used by the project.
# Evaluates f(x) for f in evaluation form
def inv_fft_at_point(vals, modulus, root_of_unity, x):
    """Evaluate at ``x`` a polynomial given by its values on a root-of-unity domain.

    ``vals[i]`` is taken to be ``f(w**i)`` for ``w = root_of_unity``.  Each
    round splits ``f`` into even and odd parts, folds them into a polynomial
    of half the size whose value at ``x**2`` equals ``f(x)``, and recurses.

    Args:
        vals: Non-empty list of integer evaluations.
            NOTE(review): the halving only makes sense for power-of-two
            lengths -- confirm with callers.
        modulus: Odd modulus; ``(modulus + 1) // 2`` is used as 1/2.
        root_of_unity: Domain generator; its inverse is computed as
            ``w ** (len(vals) - 1)``, i.e. its order is assumed to be
            ``len(vals)``.
        x: Point at which to evaluate.

    Returns:
        ``f(x)`` reduced modulo ``modulus`` (``vals[0]`` unchanged when there
        is a single value).

    Raises:
        ValueError: If ``vals`` is empty.
    """
    size = len(vals)
    if size == 0:
        raise ValueError("vals must contain at least one evaluation")
    if size == 1:
        return vals[0]
    # 1/2 in the field
    half = (modulus + 1) // 2
    # 1/w
    inv_root = pow(root_of_unity, size - 1, modulus)
    mid = size // 2
    comb = []
    # Running value of x / w**i, kept reduced so the numbers stay small.
    x_over_root_power = x % modulus
    # Only the first half of the domain is needed for the next round;
    # f(-w**i) sits half a domain further along.
    for i in range(mid):
        f, g = vals[i], vals[i + mid]
        # e(x) = (f(x) + f(-x)) / 2 and o(x) = (f(x) - f(-x)) / 2
        even = (f + g) * half % modulus
        odd = (f - g) * half % modulus
        comb.append((odd * x_over_root_power + even) % modulus)
        x_over_root_power = x_over_root_power * inv_root % modulus
    return inv_fft_at_point(comb, modulus, root_of_unity ** 2 % modulus, x ** 2 % modulus)
def shift_domain(vals, modulus, root_of_unity, factor):
    """Re-evaluate a polynomial, given in evaluation form, at ``x / factor``.

    With ``vals[i] = f(w**i)`` for ``w = root_of_unity``, entry ``i`` of the
    result is ``f(w**i / factor)``.  This is the evaluation-form counterpart
    of ``shift_poly``.

    The polynomial is split as ``f(x) = e(x**2) + x * o(x**2)``; both halves
    are shifted recursively on the squared domain and then recombined.

    Args:
        vals: List of integer evaluations.
            NOTE(review): assumes a power-of-two length -- confirm.
        modulus: Modulus of the field; inverses are taken as
            ``a ** (modulus - 2)``, which assumes it is prime.
        root_of_unity: Generator of the evaluation domain.
        factor: Non-zero scaling factor.

    Returns:
        List of shifted evaluations (``vals`` itself when it has at most
        one element).
    """
    if len(vals) <= 1:
        return vals
    # 1/2 in the field
    half = (modulus + 1) // 2
    # 1/factor and 1/w
    inv_factor = pow(factor, modulus - 2, modulus)
    inv_root = pow(root_of_unity, modulus - 2, modulus)
    half_length = len(vals) // 2

    evens = []
    odds_over_x = []
    inv_root_power = 1
    # f(-w**i) is stored half a domain away from f(w**i).
    for i in range(half_length):
        f, g = vals[i], vals[i + half_length]
        # e(x**2) = (f(x) + f(-x)) / 2
        evens.append((f + g) * half % modulus)
        # o(x**2) = (f(x) - f(-x)) / (2x): the division by x = w**i is what
        # turns the odd part into a function of x**2 that can be recursed on.
        odds_over_x.append((f - g) * half * inv_root_power % modulus)
        inv_root_power = inv_root_power * inv_root % modulus

    next_root = root_of_unity ** 2 % modulus
    next_factor = factor ** 2 % modulus
    shifted_evens = shift_domain(evens, modulus, next_root, next_factor)
    shifted_odds = shift_domain(odds_over_x, modulus, next_root, next_factor)

    # f(x / factor) = e'(x**2) + (x / factor) * o'(x**2); the second half of
    # the domain is the negation of the first.
    low, high = [], []
    x_over_factor = inv_factor
    for e, o in zip(shifted_evens, shifted_odds):
        term = x_over_factor * o % modulus
        low.append((e + term) % modulus)
        high.append((e - term) % modulus)
        x_over_factor = x_over_factor * root_of_unity % modulus
    return low + high
def shift_poly(poly, modulus, factor):
    """Scale the coefficients of ``poly`` by successive powers of ``1/factor``.

    Coefficient ``k`` is multiplied by ``factor ** -k`` modulo ``modulus``.
    The inverse is computed as ``factor ** (modulus - 2)``.

    Args:
        poly: Iterable of coefficients, lowest degree first.
        modulus: Modulus of the arithmetic.
        factor: Scaling factor.

    Returns:
        A new list of scaled coefficients.
    """
    inv_factor = pow(factor, modulus - 2, modulus)

    def scaled_coefficients():
        # Walk the coefficients while carrying the current power of 1/factor.
        scale = 1
        for coeff in poly:
            yield coeff * scale % modulus
            scale = scale * inv_factor % modulus

    return list(scaled_coefficients())
def mul_polys(a, b, modulus, root_of_unity):
    """Multiply two polynomials in coefficient form using the FFT.

    Both inputs are zero-padded to the domain size, transformed, multiplied
    pointwise and transformed back.  The product is cyclic: terms of degree
    at or above the domain size wrap around.

    Args:
        a: Coefficient list of the first polynomial, lowest degree first.
        b: Coefficient list of the second polynomial.
        modulus: Modulus of the arithmetic; the ``1/n`` factor is computed as
            ``n ** (modulus - 2)``, which assumes a prime modulus.
        root_of_unity: Generator of the evaluation domain.

    Returns:
        Coefficient list of ``a * b`` modulo ``modulus``.
    """
    rootz = expand_root_of_unity(root_of_unity, modulus)
    size = len(rootz) - 1
    if size > len(a):
        a = a + [0] * (size - len(a))
    if size > len(b):
        b = b + [0] * (size - len(b))
    x1 = _fft(a, modulus, rootz[:-1])
    x2 = _fft(b, modulus, rootz[:-1])
    products = [(v1 * v2) % modulus for v1, v2 in zip(x1, x2)]
    # Transforming back over the reversed roots yields n times the
    # coefficients, so the result still has to be divided by n (as the
    # inverse branch of fft() does).
    unscaled = _fft(products, modulus, rootz[:0:-1])
    inv_size = pow(len(products), modulus - 2, modulus)
    return [(v * inv_size) % modulus for v in unscaled]