Skip to content

Commit abe5132

Browse files
committed
Removed import torch from header
1 parent 703649c commit abe5132

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

src/lds/inference.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11

22
import math
33
import numpy as np
4-
import torch
54

65

76
class OnlineKalmanFilter:
@@ -248,6 +247,7 @@ def filterLDS_SS_withMissingValues_torch(y, B, Q, m0, V0, Z, R):
248247
249248
"""
250249

250+
import torch
251251
if torch.any(torch.isnan(y[:, 0])) or torch.any(torch.isnan(y[:, -1])):
252252
raise ValueError("The first or last observation cannot contain nan")
253253

src/lds/learning.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import time
44
import numpy as np
55
import scipy.optimize
6-
import torch
76
import warnings
87
import copy
98

@@ -176,6 +175,7 @@ def torch_lbfgs_optimize_SS_tracking_diagV0(y, B, sigma_a0, Qe, Z,
176175
"R": True, "m0": True, "V0": True},
177176
disp=True):
178177

178+
import torch
179179
def log_likelihood_fn():
180180
V0 = torch.diag(sqrt_diag_V0**2)
181181
R = torch.diag(sqrt_diag_R**2)
@@ -289,6 +289,7 @@ def torch_adam_optimize_SS_tracking_diagV0(y, B, sigma_a0, Qe, Z,
289289
"V0": True},
290290
):
291291

292+
import torch
292293
def log_likelihood_fn():
293294
V0 = torch.diag(sqrt_diag_V0**2)
294295
R = torch.diag(sqrt_diag_R**2)

0 commit comments

Comments
 (0)