Skip to content

Commit ca605df

Browse files
authored
Merge pull request #8 from valentingol/gpudtype
🐛 Solve GPU and different dtypes issues
2 parents 9bff944 + 225be4b commit ca605df

File tree

4 files changed

+45
-6
lines changed

4 files changed

+45
-6
lines changed

src/torch_pca/pca_main.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,10 @@ def fit(self, inputs: Tensor, *, determinist: bool = True) -> "PCA":
159159
PCA
160160
The PCA model fitted on the input data.
161161
"""
162+
# Auto-cast to float32 because float16 is not supported
163+
if inputs.dtype == torch.float16:
164+
inputs = inputs.to(torch.float32)
165+
162166
if self.svd_solver_ == "auto":
163167
self.svd_solver_ = choose_svd_solver(
164168
inputs=inputs,
@@ -184,14 +188,16 @@ def fit(self, inputs: Tensor, *, determinist: bool = True) -> "PCA":
184188
eigenvals[eigenvals < 0.0] = 0.0
185189
# Inverted indices
186190
idx = range(eigenvals.size(0) - 1, -1, -1)
187-
idx = torch.LongTensor(idx)
191+
idx = torch.LongTensor(idx).to(eigenvals.device)
188192
explained_variance = eigenvals.index_select(0, idx)
189193
total_var = torch.sum(explained_variance)
190194
# Compute equivalent variables to full SVD output
191195
vh_mat = eigenvecs.T.index_select(0, idx)
192196
coefs = torch.sqrt(explained_variance * (self.n_samples_ - 1))
193197
u_mat = None
194198
elif self.svd_solver_ == "randomized":
199+
if self.n_components_ is None:
200+
self.n_components_ = min(inputs.shape[-2:])
195201
if (
196202
not isinstance(self.n_components_, int)
197203
or int(self.n_components_) != self.n_components_
@@ -267,11 +273,22 @@ def transform(self, inputs: Tensor, center: str = "fit") -> Tensor:
267273
"""
268274
self._check_fitted("transform")
269275
assert self.components_ is not None # for mypy
270-
transformed = inputs @ self.components_.T
276+
assert self.mean_ is not None # for mypy
277+
components = (
278+
self.components_.to(torch.float16)
279+
if inputs.dtype == torch.float16
280+
else self.components_
281+
)
282+
mean = (
283+
self.mean_.to(torch.float16)
284+
if inputs.dtype == torch.float16
285+
else self.mean_
286+
)
287+
transformed = inputs @ components.T
271288
if center == "fit":
272-
transformed -= self.mean_ @ self.components_.T
289+
transformed -= mean @ components.T
273290
elif center == "input":
274-
transformed -= inputs.mean(dim=-2, keepdim=True) @ self.components_.T
291+
transformed -= inputs.mean(dim=-2, keepdim=True) @ components.T
275292
elif center != "none":
276293
raise ValueError(
277294
"Unknown centering, `center` argument should be "

src/torch_pca/random_svd.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@ def lu_normalizer(inputs: Tensor) -> Tuple[Tensor, Tensor]:
3434

3535
if random_state is not None:
3636
torch.manual_seed(random_state)
37-
proj_mat = torch.randn(inputs.shape[-1], size, device=inputs.device)
37+
proj_mat = torch.randn(
38+
inputs.shape[-1], size, device=inputs.device, dtype=inputs.dtype
39+
)
3840
if power_iteration_normalizer == "auto":
3941
power_iteration_normalizer = "none" if n_iter <= 2 else "QR"
4042
qr_normalizer = torch.linalg.qr

src/torch_pca/svd.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def svd_flip(u_mat: Optional[Tensor], vh_mat: Tensor) -> Tuple[Tensor, Tensor]:
109109
Adjusted V^H matrix.
110110
"""
111111
max_abs_v_rows = torch.argmax(torch.abs(vh_mat), dim=1)
112-
shift = torch.arange(vh_mat.shape[0])
112+
shift = torch.arange(vh_mat.shape[0]).to(vh_mat.device)
113113
indices = max_abs_v_rows + shift * vh_mat.shape[1]
114114
flat_vh = torch.reshape(vh_mat, (-1,))
115115
signs = torch.sign(torch.take_along_dim(flat_vh, indices, dim=0))

tests/test_gpu.py

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""Test PCA with GPU and different dtypes."""

# Copyright (c) 2024 Valentin Goldité. All Rights Reserved.
import pytest_check as check
import torch

from torch_pca import PCA


def test_gpu() -> None:
    """Test with GPU and different dtypes.

    Checks that every SVD solver keeps both the device (``cuda:0``) and
    the dtype of the input tensor in its output.
    """
    # Load the reference data once and move it to the GPU; keep this
    # tensor untouched so every dtype is cast from the original values.
    base_inputs = torch.load("tests/input_data.pt").to("cuda:0")
    for dtype in [torch.float32, torch.float16, torch.float64]:
        # Cast from the pristine tensor each iteration: rebinding the
        # loop source (as in `inputs = inputs.to(dtype)`) would make the
        # float64 pass run on data already truncated to float16.
        inputs = base_inputs.to(dtype)
        out1 = PCA(svd_solver="full").fit_transform(inputs)
        out2 = PCA(svd_solver="covariance_eigh").fit_transform(inputs)
        out3 = PCA(svd_solver="randomized", random_state=0).fit_transform(inputs)
        for out in [out1, out2, out3]:
            # Outputs must stay on the same device and keep the input dtype.
            check.equal(str(out.device), "cuda:0")
            check.equal(out.dtype, dtype)

0 commit comments

Comments
 (0)