Skip to content

Commit eb14ddb

Browse files
committed
Minor change
1 parent e0133e4 commit eb14ddb

File tree

1 file changed

+29
-22
lines changed

1 file changed

+29
-22
lines changed

src/lds/learning.py

+29-22
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
import scipy.optimize
66
import warnings
7-
import copy
7+
# import copy
88

99
from . import inference
1010
from .tracking import utils
@@ -211,34 +211,41 @@ def closure():
211211
optimizer.zero_grad()
212212
curEval = -log_likelihood_fn()
213213
print(f"in closure, ll={-curEval}")
214-
curEval.backward(retain_graph=True)
214+
for i in range(len(x)):
215+
print(f"x[{i}]={x[i]}")
216+
# curEval.backward(retain_graph=True)
217+
curEval.backward()
215218
return curEval
216219

217220
termination_info = "success: reached maximum number of iterations"
218221
log_like = []
219222
elapsed_time = []
220223
start_time = time.time()
221224
for epoch in range(n_epochs):
222-
prev_x = copy.deepcopy(x)
223-
try:
224-
curEval = optimizer.step(closure)
225-
except RuntimeError:
226-
# begin backtracking
227-
if vars_to_estimate["sigma_a"]:
228-
sigma_a = prev_x.pop(0)
229-
sigma_a.requires_grad = False
230-
if vars_to_estimate["sqrt_diag_R"]:
231-
sqrt_diag_R = prev_x.pop(0)
232-
sqrt_diag_R.requires_grad = False
233-
if vars_to_estimate["m0"]:
234-
m0 = prev_x.pop(0)
235-
m0.requires_grad = False
236-
if vars_to_estimate["sqrt_diag_V0"]:
237-
sqrt_diag_V0 = prev_x.pop(0)
238-
sqrt_diag_V0.requires_grad = False
239-
# end backtracking
240-
termination_info = "nan generated"
241-
break
225+
curEval = optimizer.step(closure)
226+
# optimizer.step(closure)
227+
# prev_x = copy.deepcopy(x)
228+
# try:
229+
# curEval = optimizer.step(closure)
230+
# except RuntimeError:
231+
# breakpoint()
232+
# # begin backtracking
233+
# if vars_to_estimate["sigma_a"]:
234+
# sigma_a = prev_x.pop(0)
235+
# sigma_a.requires_grad = False
236+
# if vars_to_estimate["sqrt_diag_R"]:
237+
# sqrt_diag_R = prev_x.pop(0)
238+
# sqrt_diag_R.requires_grad = False
239+
# if vars_to_estimate["m0"]:
240+
# m0 = prev_x.pop(0)
241+
# m0.requires_grad = False
242+
# if vars_to_estimate["sqrt_diag_V0"]:
243+
# sqrt_diag_V0 = prev_x.pop(0)
244+
# sqrt_diag_V0.requires_grad = False
245+
# # end backtracking
246+
# termination_info = "nan generated"
247+
# break
248+
curEval = -log_likelihood_fn()
242249
log_like.append(-curEval.item())
243250
elapsed_time.append(time.time() - start_time)
244251
print("--------------------------------------------------------------------------------")

0 commit comments

Comments
 (0)