|
9 | 9 | from torch_pca import PCA
|
10 | 10 |
|
11 | 11 |
|
12 |
| -def test_basic() -> None: |
13 |
| - """Basic tests.""" |
| 12 | +def test_fullsvd() -> None: |
| 13 | + """Test basic full SVD.""" |
14 | 14 | input_1 = torch.load("tests/input_data.pt").to(torch.float32) + 2.0
|
15 |
| - torch_model = PCA(n_components=2).fit(input_1) |
| 15 | + torch_model = PCA( |
| 16 | + n_components=2, |
| 17 | + svd_solver="full", |
| 18 | + ).fit(input_1) |
16 | 19 | sklearn_model = PCA_sklearn(
|
17 | 20 | n_components=2,
|
18 | 21 | svd_solver="full",
|
@@ -76,3 +79,46 @@ def test_centering() -> None:
|
76 | 79 | # Unkown centering
|
77 | 80 | with pytest.raises(ValueError, match="Unknown centering.*"):
|
78 | 81 | model.transform(inputs_1, center="UNKNOWN")
|
| 82 | + |
| 83 | + |
| 84 | +def test_covariance_eigh() -> None: |
| 85 | + """Test SVD based on covariance matrix.""" |
| 86 | + input_1 = torch.load("tests/input_data.pt").to(torch.float32) + 2.0 |
| 87 | + torch_model = PCA(n_components=2, svd_solver="covariance_eigh").fit(input_1) |
| 88 | + sklearn_model = PCA_sklearn( |
| 89 | + n_components=2, |
| 90 | + svd_solver="covariance_eigh", |
| 91 | + whiten=False, |
| 92 | + ).fit(input_1) |
| 93 | + for attr_name in ["components_", "explained_variance_", "singular_values_"]: |
| 94 | + attr_torch = getattr(torch_model, attr_name) |
| 95 | + attr_sklearn = getattr(sklearn_model, attr_name) |
| 96 | + check.is_true( |
| 97 | + torch.allclose( |
| 98 | + attr_torch, |
| 99 | + torch.tensor(attr_sklearn, dtype=torch.float32), |
| 100 | + rtol=1e-5, |
| 101 | + atol=1e-5, |
| 102 | + ), |
| 103 | + f"Failed for attribute {attr_name}", |
| 104 | + ) |
| 105 | + |
| 106 | + |
| 107 | +def test_fail_svd() -> None: |
| 108 | + """Test unknown SVD solver.""" |
| 109 | + input_1 = torch.randn(200, 10) |
| 110 | + with pytest.raises(ValueError, match="Unknown SVD solver.*"): |
| 111 | + PCA(n_components=2, svd_solver="UNKNOWN").fit(input_1) |
| 112 | + |
| 113 | + |
| 114 | +def test_auto_svd() -> None: |
| 115 | + """Test auto SVD solver.""" |
| 116 | + input_1 = torch.randn(200, 10) |
| 117 | + model = PCA(n_components=2).fit(input_1) |
| 118 | + check.equal(model.svd_solver_, "covariance_eigh") |
| 119 | + input_2 = torch.randn(50, 8) |
| 120 | + model = PCA(n_components=2).fit(input_2) |
| 121 | + check.equal(model.svd_solver_, "full") |
| 122 | + input_3 = torch.randn(10, 1200) |
| 123 | + model = PCA(n_components=2).fit(input_3) |
| 124 | + check.equal(model.svd_solver_, "full") |
0 commit comments