Skip to content

Commit c0d2f08

Browse files
committed
Numba for spectrum computation
1 parent 8a48392 commit c0d2f08

File tree

2 files changed

+205
-45
lines changed

2 files changed

+205
-45
lines changed

low_freq_dev/random.ipynb

-25
Original file line numberDiff line numberDiff line change
@@ -97,31 +97,6 @@
9797
"plt.show()"
9898
]
9999
},
100-
{
101-
"cell_type": "code",
102-
"execution_count": null,
103-
"metadata": {},
104-
"outputs": [],
105-
"source": []
106-
},
107-
{
108-
"cell_type": "markdown",
109-
"metadata": {},
110-
"source": [
111-
"-------------------------------------\n",
112-
"\n",
113-
"## BELOW HERE, WE ARE FIXING FUCKING SHIT THE FUCK UP LET'S FUCKING GOOOOOOOOOo"
114-
]
115-
},
116-
{
117-
"cell_type": "code",
118-
"execution_count": 16,
119-
"metadata": {},
120-
"outputs": [],
121-
"source": [
122-
"import von_karman"
123-
]
124-
},
125100
{
126101
"cell_type": "code",
127102
"execution_count": null,

low_freq_dev/von_karman.py

+205-20
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import concurrent.futures
2+
import time
23

34
import matplotlib.pyplot as plt
5+
import numba
46
import numpy as np
57
import scipy
68

@@ -101,6 +103,8 @@ def generate(self, eta_ones=False):
101103

102104
return u1, u2
103105

106+
# ------------------------------------------------------------------------------------------------ #
107+
104108
def compute_spectrum(self, u1=None, u2=None):
105109
"""
106110
Compute the spectrum of generated velocity fields. If no fields are provided, checks
@@ -139,6 +143,160 @@ class attributes for self.u1 and self.u2
139143

140144
return k1_pos, F11, F22
141145

146+
@staticmethod
@numba.njit(parallel=True)
def _compute_spectrum_numba_helper(k1_flat, k1_pos, power_u1_flat, power_u2_flat, dy):
    """
    Numba-compiled kernel that averages power values per wavenumber

    For every entry of ``k1_pos``, the samples of ``power_u1_flat`` and
    ``power_u2_flat`` located where ``k1_flat`` equals that entry exactly are
    averaged, and the mean is multiplied by ``dy``.

    Parameters
    ----------
    k1_flat: np.ndarray
        Flattened wavenumber array, compared element-wise with each k1_pos entry
    k1_pos: np.ndarray
        Wavenumbers to evaluate; both outputs copy its shape and dtype
    power_u1_flat: np.ndarray
        Power values of the first component, indexed in step with k1_flat
    power_u2_flat: np.ndarray
        Power values of the second component, indexed in step with k1_flat
    dy: float
        Factor applied to every mean

    Returns
    -------
    F11, F22: np.ndarray
        Scaled mean power per entry of k1_pos; an entry stays 0 when no element
        of k1_flat matches the wavenumber
    """
    n_bins = len(k1_pos)
    spec_11 = np.zeros_like(k1_pos)
    spec_22 = np.zeros_like(k1_pos)

    # Every iteration writes only to its own output slot, so the iterations
    # are independent of each other under prange.
    for b in numba.prange(n_bins):
        hits = np.where(k1_flat == k1_pos[b])[0]

        if hits.size > 0:
            # Mean of the power samples sharing this wavenumber, scaled by dy
            spec_11[b] = np.mean(power_u1_flat[hits]) * dy
            spec_22[b] = np.mean(power_u2_flat[hits]) * dy

    return spec_11, spec_22
165+
166+
def compute_spectrum_numba(self, u1=None, u2=None):
    """
    Numba-accelerated version of compute_spectrum

    Each omitted field falls back to the matching class attribute on its own,
    so supplying only one of u1/u2 no longer passes None to the FFT.

    Parameters
    ----------
    u1: np.ndarray, optional
        x- or longitudinal component of velocity field; defaults to self.u1
    u2: np.ndarray, optional
        y- or transversal component of velocity field; defaults to self.u2

    Returns
    -------
    k1_pos: np.ndarray
        Entries of self.k1_fft that are strictly positive
    F11, F22: np.ndarray
        For each entry of k1_pos, the mean of |FFT / (N1 * N2)|^2 over the grid
        points where self.k1 equals that wavenumber, multiplied by self.dy
    """
    # Resolve the two fields independently; the previous "both are None" test
    # left a lone missing field as None and failed inside np.fft.fft2.
    if u1 is None:
        u1 = self.u1
    if u2 is None:
        u2 = self.u2

    # Only the strictly positive wavenumbers are reported
    k1_pos = self.k1_fft[self.k1_fft > 0]

    # Power spectra of the 2D FFTs, normalized by the number of grid points
    norm = self.N1 * self.N2
    power_u1 = (np.abs(np.fft.fft2(u1)) / norm) ** 2
    power_u2 = (np.abs(np.fft.fft2(u2)) / norm) ** 2

    # ravel() instead of flatten(): the helper only reads these arrays, so a
    # view is enough and the unconditional copies are avoided.
    F11, F22 = self._compute_spectrum_numba_helper(
        self.k1.ravel(), k1_pos, power_u1.ravel(), power_u2.ravel(), self.dy
    )

    return k1_pos, F11, F22
202+
203+
def compute_spectrum_numpy_fast(self, u1=None, u2=None):
    """
    Fast NumPy implementation of spectrum computation without Numba

    Each omitted field falls back to the matching class attribute on its own,
    so supplying only one of u1/u2 no longer passes None to the FFT.

    Parameters
    ----------
    u1: np.ndarray, optional
        x- or longitudinal component of velocity field; defaults to self.u1
    u2: np.ndarray, optional
        y- or transversal component of velocity field; defaults to self.u2

    Returns
    -------
    k1_pos: np.ndarray
        Entries of self.k1_fft that are strictly positive
    F11, F22: np.ndarray
        For each entry of k1_pos, the mean of |FFT / (N1 * N2)|^2 over the grid
        points where self.k1 equals that wavenumber, multiplied by self.dy.
        Wavenumbers that do not occur in self.k1 keep a value of 0.
    """
    # Resolve the two fields independently; the previous "both are None" test
    # left a lone missing field as None and failed inside np.fft.fft2.
    if u1 is None:
        u1 = self.u1
    if u2 is None:
        u2 = self.u2

    # Only the strictly positive wavenumbers are reported
    k1_pos = self.k1_fft[self.k1_fft > 0]

    # Power spectra of the 2D FFTs, normalized by the number of grid points
    norm = self.N1 * self.N2
    power_u1 = (np.abs(np.fft.fft2(u1)) / norm) ** 2
    power_u2 = (np.abs(np.fft.fft2(u2)) / norm) ** 2

    F11 = np.zeros_like(k1_pos)
    F22 = np.zeros_like(k1_pos)

    # Walk k1_pos directly: its position is the output index, so no
    # np.unique pass over the grid or value-to-index lookup table is needed.
    for idx, k1_val in enumerate(k1_pos):
        mask = self.k1 == k1_val

        # A wavenumber missing from the grid keeps its zero entry; averaging
        # an empty selection would yield NaN and a RuntimeWarning.
        if mask.any():
            F11[idx] = np.mean(power_u1[mask]) * self.dy
            F22[idx] = np.mean(power_u2[mask]) * self.dy

    return k1_pos, F11, F22
254+
255+
def test_spectrum_computation(self, num_tests=3):
    """
    Test and compare different spectrum computation methods

    Times compute_spectrum, compute_spectrum_numba and
    compute_spectrum_numpy_fast over num_tests calls each, checks that the two
    alternative methods reproduce the original results, and prints a summary.

    Parameters
    ----------
    num_tests: int
        Number of timed calls per method; must be at least 1

    Returns
    -------
    dict
        Mean seconds per call under the keys "original", "numba", "numpy_fast"

    Raises
    ------
    ValueError
        If num_tests is smaller than 1
    AssertionError
        Raised by np.testing.assert_allclose when a method disagrees with the
        original beyond rtol=1e-7
    """
    if num_tests < 1:
        # Zero or negative counts previously surfaced as ZeroDivisionError or
        # unbound locals further down.
        raise ValueError(f"num_tests must be at least 1, got {num_tests}")

    # Generate velocity fields if not already present
    if not hasattr(self, "u1") or not hasattr(self, "u2"):
        self.generate()

    # Untimed first call, so that Numba's first-call compilation does not
    # count towards the measured time of the Numba method.
    self.compute_spectrum_numba()

    def timed(label, method):
        """Print label, call method num_tests times, return (last result, mean seconds)."""
        print(label)
        # perf_counter is the high-resolution monotonic clock intended for
        # benchmarking; time.time() can jump and has coarse resolution.
        start = time.perf_counter()
        for _ in range(num_tests):
            result = method()
        return result, (time.perf_counter() - start) / num_tests

    (k1_pos, F11, F22), orig_time = timed("Original method", self.compute_spectrum)
    (k1_pos_numba, F11_numba, F22_numba), numba_time = timed(
        "Numba method", self.compute_spectrum_numba
    )
    (k1_pos_np_fast, F11_np_fast, F22_np_fast), np_fast_time = timed(
        "Numpy fast method", self.compute_spectrum_numpy_fast
    )

    # Verify results match
    np.testing.assert_allclose(k1_pos, k1_pos_numba, rtol=1e-7)
    np.testing.assert_allclose(F11, F11_numba, rtol=1e-7)
    np.testing.assert_allclose(F22, F22_numba, rtol=1e-7)

    np.testing.assert_allclose(k1_pos, k1_pos_np_fast, rtol=1e-7)
    np.testing.assert_allclose(F11, F11_np_fast, rtol=1e-7)
    np.testing.assert_allclose(F22, F22_np_fast, rtol=1e-7)

    def speedup(elapsed):
        """Ratio of the original time to elapsed; inf when elapsed measured as zero."""
        return orig_time / elapsed if elapsed > 0 else float("inf")

    print(f"Original method: {orig_time:.4f} seconds")
    print(f"Numba method: {numba_time:.4f} seconds (speedup: {speedup(numba_time):.2f}x)")
    print(f"NumPy fast method: {np_fast_time:.4f} seconds (speedup: {speedup(np_fast_time):.2f}x)")

    return {"original": orig_time, "numba": numba_time, "numpy_fast": np_fast_time}
297+
298+
# ------------------------------------------------------------------------------------------------ #
299+
142300
def analytical_spectrum(self, k1_arr):
143301
"""
144302
Compute the analytical spectrum of the von Karman spectrum
@@ -369,6 +527,13 @@ def diagnostic_plot(u1, u2):
369527
# Spectrum plot
370528

371529

530+
def _compute_single_realization(config):
    """
    Generate one velocity-field realization and compute its spectrum

    Builds a fresh generator from ``config``, so each call is independent;
    plot_spectrum_comparison submits this function to its process pool once
    per realization.

    Parameters
    ----------
    config: dict
        Configuration passed unchanged to ``generator``

    Returns
    -------
    tuple
        (k1_sim, F11, F22) as returned by ``generator.compute_spectrum``
    """
    realization = generator(config)
    field_u1, field_u2 = realization.generate()
    wavenumbers, spec_11, spec_22 = realization.compute_spectrum(field_u1, field_u2)
    return wavenumbers, spec_11, spec_22
535+
536+
372537
def plot_spectrum_comparison(config: dict, num_realizations: int = 10):
373538
"""
374539
Generate velocity fields and plot their spectra compared to analytical spectrum
@@ -388,24 +553,31 @@ def plot_spectrum_comparison(config: dict, num_realizations: int = 10):
388553
k1_custom = np.logspace(-3, 3, 1000) / config["L"]
389554
F11_analytical, F22_analytical = gen.analytical_spectrum(k1_custom)
390555

391-
k1_pos = None
392-
F11_avg = None
393-
F22_avg = None
556+
results = []
394557

395-
for _ in range(num_realizations):
396-
u1, u2 = gen.generate()
397-
k1_sim, F11, F22 = gen.compute_spectrum(u1, u2)
558+
with concurrent.futures.ProcessPoolExecutor() as executor:
559+
futures = [executor.submit(_compute_single_realization, config) for _ in range(num_realizations)]
398560

399-
if F11_avg is None:
400-
k1_pos = k1_sim
401-
F11_avg = F11
402-
F22_avg = F22
403-
else:
404-
F11_avg += F11
405-
F22_avg += F22
561+
for future in concurrent.futures.as_completed(futures):
562+
try:
563+
results.append(future.result())
564+
except Exception as e:
565+
print(f"Error: {e}")
406566

407-
F11_avg /= num_realizations
408-
F22_avg /= num_realizations
567+
if not results:
568+
raise RuntimeError("No results were collected")
569+
570+
k1_pos = results[0][0]
571+
F11_avg = np.zeros_like(k1_pos)
572+
F22_avg = np.zeros_like(k1_pos)
573+
574+
for _, F11, F22 in results:
575+
F11_avg += F11
576+
F22_avg += F22
577+
578+
lr = len(results)
579+
F11_avg /= lr
580+
F22_avg /= lr
409581

410582
fig, axs = plt.subplots(1, 2, figsize=(12, 5))
411583

@@ -502,9 +674,22 @@ def plot_spectrum(config: dict, num_realizations=10):
502674
"N2": 9,
503675
}
504676

505-
plot_spectrum(config)
506-
plot_spectrum_comparison(config)
677+
FINE_CONFIG = {
678+
"L": 500,
679+
"epsilon": 0.01,
680+
"L1_factor": 2,
681+
"L2_factor": 2,
682+
"N1": 10,
683+
"N2": 10,
684+
}
685+
686+
gen = generator(FINE_CONFIG)
687+
gen.generate()
688+
gen.test_spectrum_computation()
507689

508-
gen = generator(config)
509-
u1, u2 = gen.generate()
510-
diagnostic_plot(u1, u2)
690+
# plot_spectrum(config)
691+
# plot_spectrum_comparison(FINE_CONFIG)
692+
693+
# gen = generator(config)
694+
# u1, u2 = gen.generate()
695+
# diagnostic_plot(u1, u2)

0 commit comments

Comments
 (0)