|
4 | 4 | import numpy as np
|
5 | 5 | import scipy.optimize
|
6 | 6 | import warnings
|
7 |
| -import copy |
| 7 | +# import copy |
8 | 8 |
|
9 | 9 | from . import inference
|
10 | 10 | from .tracking import utils
|
@@ -211,34 +211,41 @@ def closure():
|
211 | 211 | optimizer.zero_grad()
|
212 | 212 | curEval = -log_likelihood_fn()
|
213 | 213 | print(f"in closure, ll={-curEval}")
|
214 |
| - curEval.backward(retain_graph=True) |
| 214 | + for i in range(len(x)): |
| 215 | + print(f"x[{i}]={x[i]}") |
| 216 | + # curEval.backward(retain_graph=True) |
| 217 | + curEval.backward() |
215 | 218 | return curEval
|
216 | 219 |
|
217 | 220 | termination_info = "success: reached maximum number of iterations"
|
218 | 221 | log_like = []
|
219 | 222 | elapsed_time = []
|
220 | 223 | start_time = time.time()
|
221 | 224 | for epoch in range(n_epochs):
|
222 |
| - prev_x = copy.deepcopy(x) |
223 |
| - try: |
224 |
| - curEval = optimizer.step(closure) |
225 |
| - except RuntimeError: |
226 |
| - # begin backtracking |
227 |
| - if vars_to_estimate["sigma_a"]: |
228 |
| - sigma_a = prev_x.pop(0) |
229 |
| - sigma_a.requires_grad = False |
230 |
| - if vars_to_estimate["sqrt_diag_R"]: |
231 |
| - sqrt_diag_R = prev_x.pop(0) |
232 |
| - sqrt_diag_R.requires_grad = False |
233 |
| - if vars_to_estimate["m0"]: |
234 |
| - m0 = prev_x.pop(0) |
235 |
| - m0.requires_grad = False |
236 |
| - if vars_to_estimate["sqrt_diag_V0"]: |
237 |
| - sqrt_diag_V0 = prev_x.pop(0) |
238 |
| - sqrt_diag_V0.requires_grad = False |
239 |
| - # end backtracking |
240 |
| - termination_info = "nan generated" |
241 |
| - break |
| 225 | + curEval = optimizer.step(closure) |
| 226 | +# optimizer.step(closure) |
| 227 | +# prev_x = copy.deepcopy(x) |
| 228 | +# try: |
| 229 | +# curEval = optimizer.step(closure) |
| 230 | +# except RuntimeError: |
| 231 | +# breakpoint() |
| 232 | +# # begin backtracking |
| 233 | +# if vars_to_estimate["sigma_a"]: |
| 234 | +# sigma_a = prev_x.pop(0) |
| 235 | +# sigma_a.requires_grad = False |
| 236 | +# if vars_to_estimate["sqrt_diag_R"]: |
| 237 | +# sqrt_diag_R = prev_x.pop(0) |
| 238 | +# sqrt_diag_R.requires_grad = False |
| 239 | +# if vars_to_estimate["m0"]: |
| 240 | +# m0 = prev_x.pop(0) |
| 241 | +# m0.requires_grad = False |
| 242 | +# if vars_to_estimate["sqrt_diag_V0"]: |
| 243 | +# sqrt_diag_V0 = prev_x.pop(0) |
| 244 | +# sqrt_diag_V0.requires_grad = False |
| 245 | +# # end backtracking |
| 246 | +# termination_info = "nan generated" |
| 247 | +# break |
| 248 | + curEval = -log_likelihood_fn() |
242 | 249 | log_like.append(-curEval.item())
|
243 | 250 | elapsed_time.append(time.time() - start_time)
|
244 | 251 | print("--------------------------------------------------------------------------------")
|
|
0 commit comments