Skip to content

Commit 48b26cf

Browse files
authored
Merge pull request #650 from guaacoelho/cache_enhancement
Removes the creation of the solver from within the shots iteration loop.
2 parents 3beb8fc + bed03f7 commit 48b26cf

File tree

1 file changed

+38
-29
lines changed

1 file changed

+38
-29
lines changed

pylops/waveeqprocessing/twoway.py

+38-29
Original file line numberDiff line numberDiff line change
@@ -229,11 +229,13 @@ def updatesrc(self, wav):
229229
time_range=self.geometry.time_axis,
230230
)
231231

232-
def _srcillumination_oneshot(self, isrc: int) -> Tuple[NDArray, NDArray]:
232+
def _srcillumination_oneshot(self, solver: AcousticWaveSolverType, isrc: int) -> Tuple[NDArray, NDArray]:
233233
"""Source wavefield and illumination for one shot
234234
235235
Parameters
236236
----------
237+
solver : :obj:`AcousticWaveSolver`
238+
Devito's solver object.
237239
isrc : :obj:`int`
238240
Index of source to model
239241
@@ -245,21 +247,10 @@ def _srcillumination_oneshot(self, isrc: int) -> Tuple[NDArray, NDArray]:
245247
Source illumination
246248
247249
"""
248-
# create geometry for single source
249-
geometry = AcquisitionGeometry(
250-
self.model,
251-
self.geometry.rec_positions,
252-
self.geometry.src_positions[isrc, :],
253-
self.geometry.t0,
254-
self.geometry.tn,
255-
f0=self.geometry.f0,
256-
src_type=self.geometry.src_type,
257-
)
258-
solver = AcousticWaveSolver(self.model, geometry, space_order=self.space_order)
259250

260251
# assign source location to source object with custom wavelet
261252
if hasattr(self, "wav"):
262-
self.wav.coordinates.data[0, :] = self.geometry.src_positions[isrc, :]
253+
self.wav.coordinates.data[0, :] = solver.geometry.src_positions[:]
263254

264255
# source wavefield
265256
u0 = solver.forward(
@@ -279,13 +270,27 @@ def srcillumination_allshots(self, savewav: bool = False) -> None:
279270
Save source wavefield (``True``) or not (``False``)
280271
281272
"""
273+
# create geometry for single source
274+
geometry = AcquisitionGeometry(
275+
self.model,
276+
self.geometry.rec_positions,
277+
self.geometry.src_positions[0, :],
278+
self.geometry.t0,
279+
self.geometry.tn,
280+
f0=self.geometry.f0,
281+
src_type=self.geometry.src_type,
282+
)
283+
284+
solver = AcousticWaveSolver(self.model, geometry, space_order=self.space_order)
285+
282286
nsrc = self.geometry.src_positions.shape[0]
283287
if savewav:
284288
self.src_wavefield = []
285289
self.src_illumination = np.zeros(self.model.shape)
286290

287291
for isrc in range(nsrc):
288-
src_wav, src_ill = self._srcillumination_oneshot(isrc)
292+
solver.geometry.src_positions = self.geometry.src_positions[isrc, :]
293+
src_wav, src_ill = self._srcillumination_oneshot(solver, isrc)
289294
if savewav:
290295
self.src_wavefield.append(src_wav)
291296
self.src_illumination += src_ill
@@ -359,11 +364,13 @@ def _born_allshots(self, dm: NDArray) -> NDArray:
359364
dtot = np.array(dtot).reshape(nsrc, d.shape[0], d.shape[1])
360365
return dtot
361366

362-
def _bornadj_oneshot(self, isrc, dobs):
367+
def _bornadj_oneshot(self, solver: AcousticWaveSolverType, isrc, dobs):
363368
"""Adjoint born modelling for one shot
364369
365370
Parameters
366371
----------
372+
solver : :obj:`AcousticWaveSolver`
373+
Devito's solver object.
367374
isrc : :obj:`float`
368375
Index of source to model
369376
dobs : :obj:`np.ndarray`
@@ -375,25 +382,13 @@ def _bornadj_oneshot(self, isrc, dobs):
375382
Model
376383
377384
"""
378-
# create geometry for single source
379-
geometry = AcquisitionGeometry(
380-
self.model,
381-
self.geometry.rec_positions,
382-
self.geometry.src_positions[isrc, :],
383-
self.geometry.t0,
384-
self.geometry.tn,
385-
f0=self.geometry.f0,
386-
src_type=self.geometry.src_type,
387-
)
388385
# create boundary data
389386
recs = self.geometry.rec.copy()
390387
recs.data[:] = dobs.T[:]
391388

392-
solver = AcousticWaveSolver(self.model, geometry, space_order=self.space_order)
393-
394389
# assign source location to source object with custom wavelet
395390
if hasattr(self, "wav"):
396-
self.wav.coordinates.data[0, :] = self.geometry.src_positions[isrc, :]
391+
self.wav.coordinates.data[0, :] = solver.geometry.src_positions[:]
397392

398393
# source wavefield
399394
if hasattr(self, "src_wavefield"):
@@ -423,11 +418,25 @@ def _bornadj_allshots(self, dobs: NDArray) -> NDArray:
423418
Model
424419
425420
"""
421+
# create geometry for single source
422+
geometry = AcquisitionGeometry(
423+
self.model,
424+
self.geometry.rec_positions,
425+
self.geometry.src_positions[0, :],
426+
self.geometry.t0,
427+
self.geometry.tn,
428+
f0=self.geometry.f0,
429+
src_type=self.geometry.src_type,
430+
)
431+
426432
nsrc = self.geometry.src_positions.shape[0]
427433
mtot = np.zeros(self.model.shape, dtype=np.float32)
428434

435+
solver = AcousticWaveSolver(self.model, geometry, space_order=self.space_order)
436+
429437
for isrc in range(nsrc):
430-
m = self._bornadj_oneshot(isrc, dobs[isrc])
438+
solver.geometry.src_positions = self.geometry.src_positions[isrc, :]
439+
m = self._bornadj_oneshot(solver, isrc, dobs[isrc])
431440
mtot += self._crop_model(m.data, self.model.nbl)
432441
return mtot
433442

0 commit comments

Comments
 (0)