-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathfwi_example_minConf.jl
64 lines (49 loc) · 1.84 KB
/
fwi_example_minConf.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# 2D FWI on Overthrust model with SPG using minConf library
# Author: Philipp Witte, pwitte@eoas.ubc.ca
# Date: December 2017
#
using Statistics, Random, LinearAlgebra, PythonPlot
using JUDI, SlimOptim, HDF5, SegyIO
# Load starting model: grid size `n`, spacing `d`, origin `o` and the starting model `m0`
# (slowness squared, since the velocity below is recovered as 1/sqrt(m0)).
# The do-block form of `h5open` closes the file once the datasets have been read;
# the previous one-liner left the file handle open.
n, d, o, m0 = h5open("$(JUDI.JUDI_DATA)/overthrust_model.h5", "r") do file
    read(file, "n", "d", "o", "m0")
end
model0 = Model((n[1], n[2]), (d[1], d[2]), (o[1], o[2]), m0)

# Bound constraints on velocity (same units as v0 = 1/sqrt(m0))
v0 = sqrt.(1f0 ./ m0)
vmin = fill(1.3f0, model0.n)
vmax = fill(6.5f0, model0.n)
# Keep water column fixed: the first 21 samples along the second grid axis are pinned
# to the starting velocity by making the lower and upper bounds coincide there.
vmin[:, 1:21] .= v0[:, 1:21]
vmax[:, 1:21] .= v0[:, 1:21]

# Slowness squared [s^2/km^2]. The bounds swap roles under m = 1/v^2:
# the largest velocity gives the smallest slowness squared.
mmin = vec((1f0 ./ vmax) .^ 2)
mmax = vec((1f0 ./ vmin) .^ 2)
# Load data: observed shot records, read once from SEG-Y and wrapped as a judiVector
block = segy_read("$(JUDI.JUDI_DATA)/overthrust_shot_records.segy")
d_obs = judiVector(block)

# Set up wavelet: the source geometry comes from the same SEG-Y headers, and the
# Ricker wavelet uses the first source's recording time and sampling interval.
src_geometry = Geometry(block; key = "source")
wavelet = let tmax = src_geometry.t[1], dt = src_geometry.dt[1], f0 = 0.008f0
    ricker_wavelet(tmax, dt, f0)  # 8 Hz wavelet
end
q = judiVector(src_geometry, wavelet)
############################### FWI ###########################################
# Optimization parameters: the iteration cap can be overridden through the NITER
# environment variable (10 when unset); each objective call uses `batchsize` shots.
fevals = haskey(ENV, "NITER") ? parse(Int, ENV["NITER"]) : 10
batchsize = 8

# Objective function for minConf library: number of objective evaluations so far,
# bumped by `objective_function` on every call.
count = 0
"""
    objective_function(x)

Evaluate the FWI misfit and its gradient at the model `x` on a random batch of shots.
Return `(fval, grad)`.

Side effects:
- overwrites `model0.m` with `x`, reshaped to `model0.n`;
- increments the global evaluation counter `count`.

`batchsize` shots are drawn without replacement, capped at `d_obs.nsrc`. The gradient
is rescaled so that its largest magnitude is 0.125. An all-zero gradient is returned
unscaled, because normalizing it would produce NaN.
"""
function objective_function(x)
    model0.m .= reshape(x, model0.n)

    # fwi function value and gradient on a random subset of shots. Never draw more shots
    # than exist, so a large batchsize does not raise a BoundsError.
    nbatch = min(batchsize, d_obs.nsrc)
    i = randperm(d_obs.nsrc)[1:nbatch]
    fval, grad = fwi_objective(model0, q[i], d_obs[i])

    # Scale for line search. Guard against 0/0 -> NaN when the gradient vanishes.
    gmax = maximum(abs.(grad))
    if !iszero(gmax)
        grad = 0.125f0 * grad / gmax
    end

    # NOTE(review): `count` is a non-const global that shadows `Base.count` in this script.
    global count
    count += 1
    return fval, grad
end
# Bound projection: clip every entry of `x` into [mmin, mmax] and return it with the model's
# grid shape `model0.n`. This gives the same result as the element-wise median of
# (mmin, x, mmax) because mmin <= mmax everywhere (mmin comes from vmax and mmax from vmin).
# Unlike the median form, it allocates no n×3 matrix and does no per-row median.
proj(x) = reshape(clamp.(vec(x), mmin, mmax), model0.n)
# FWI with SPG: verbose output, maxIter = fevals, memory = 3, starting from model0.m
# with iterates kept inside the bounds by `proj`.
options = spg_options(; verbose = 3, maxIter = fevals, memory = 3)
sol = spg(objective_function, model0.m, proj, options)

# Plot result: map the recovered slowness squared back to velocity on the model grid.
# NOTE(review): the extent is hard-coded for a 10 km x 3 km section; derive it from
# model0 if a different model is loaded.
let vp = sqrt.(1f0 ./ sol.x)
    imshow(reshape(vp, model0.n)', extent = [0, 10, 3, 0])
end
xlabel("Lateral position [km]")
ylabel("Depth [km]")