Skip to content

Commit 1d08d3e

Browse files
committed
✅ Cover covariance eigh and auto solver
1 parent e0a8da5 commit 1d08d3e

File tree

1 file changed

+49
-3
lines changed

1 file changed

+49
-3
lines changed

tests/test_fit_transform.py

+49-3
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,13 @@
99
from torch_pca import PCA
1010

1111

12-
def test_basic() -> None:
13-
"""Basic tests."""
12+
def test_fullsvd() -> None:
13+
"""Test basic full SVD."""
1414
input_1 = torch.load("tests/input_data.pt").to(torch.float32) + 2.0
15-
torch_model = PCA(n_components=2).fit(input_1)
15+
torch_model = PCA(
16+
n_components=2,
17+
svd_solver="full",
18+
).fit(input_1)
1619
sklearn_model = PCA_sklearn(
1720
n_components=2,
1821
svd_solver="full",
@@ -76,3 +79,46 @@ def test_centering() -> None:
7679
# Unkown centering
7780
with pytest.raises(ValueError, match="Unknown centering.*"):
7881
model.transform(inputs_1, center="UNKNOWN")
82+
83+
84+
def test_covariance_eigh() -> None:
85+
"""Test SVD based on covariance matrix."""
86+
input_1 = torch.load("tests/input_data.pt").to(torch.float32) + 2.0
87+
torch_model = PCA(n_components=2, svd_solver="covariance_eigh").fit(input_1)
88+
sklearn_model = PCA_sklearn(
89+
n_components=2,
90+
svd_solver="covariance_eigh",
91+
whiten=False,
92+
).fit(input_1)
93+
for attr_name in ["components_", "explained_variance_", "singular_values_"]:
94+
attr_torch = getattr(torch_model, attr_name)
95+
attr_sklearn = getattr(sklearn_model, attr_name)
96+
check.is_true(
97+
torch.allclose(
98+
attr_torch,
99+
torch.tensor(attr_sklearn, dtype=torch.float32),
100+
rtol=1e-5,
101+
atol=1e-5,
102+
),
103+
f"Failed for attribute {attr_name}",
104+
)
105+
106+
107+
def test_fail_svd() -> None:
108+
"""Test unknown SVD solver."""
109+
input_1 = torch.randn(200, 10)
110+
with pytest.raises(ValueError, match="Unknown SVD solver.*"):
111+
PCA(n_components=2, svd_solver="UNKNOWN").fit(input_1)
112+
113+
114+
def test_auto_svd() -> None:
115+
"""Test auto SVD solver."""
116+
input_1 = torch.randn(200, 10)
117+
model = PCA(n_components=2).fit(input_1)
118+
check.equal(model.svd_solver_, "covariance_eigh")
119+
input_2 = torch.randn(50, 8)
120+
model = PCA(n_components=2).fit(input_2)
121+
check.equal(model.svd_solver_, "full")
122+
input_3 = torch.randn(10, 1200)
123+
model = PCA(n_components=2).fit(input_3)
124+
check.equal(model.svd_solver_, "full")

0 commit comments

Comments
 (0)