Skip to content

Commit e0a8da5

Browse files
committed
✨ Add automatic SVD solver decision
1 parent 9460ddb commit e0a8da5

File tree

1 file changed

+19
-2
lines changed

1 file changed

+19
-2
lines changed

src/torch_pca/svd.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,30 @@
11
"""Functions related to SVD."""
22

33
# Copyright (c) 2024 Valentin Goldité. All Rights Reserved.
4-
from typing import Tuple
4+
from typing import Optional, Tuple, Union
55

66
import torch
77
from torch import Tensor
88

9+
NComponentsType = Union[int, float, None, str]
910

10-
def svd_flip(u_mat: Tensor, vh_mat: Tensor) -> Tuple[Tensor, Tensor]:
11+
12+
def choose_svd_solver(inputs: Tensor, n_components: NComponentsType) -> str:
13+
"""Choose the SVD solver based on the input shape."""
14+
if inputs.shape[-1] <= 1_000 and inputs.shape[-2] >= 10 * inputs.shape[-1]:
15+
return "covariance_eigh"
16+
if max(inputs.shape[-2:]) <= 500 or n_components == "mle":
17+
return "full"
18+
# NOTE: The randomized solver is not implemented yet.
19+
# if (
20+
# isinstance(n_components, float)
21+
# and 1 <= n_components < 0.8 * min(inputs.shape)
22+
# ):
23+
# return "randomized"
24+
return "full"
25+
26+
27+
def svd_flip(u_mat: Optional[Tensor], vh_mat: Tensor) -> Tuple[Tensor, Tensor]:
1128
"""Sign correction to ensure deterministic output from SVD.
1229
1330
Adjusts the columns of u and the rows of v such that

0 commit comments

Comments
 (0)