@@ -159,6 +159,10 @@ def fit(self, inputs: Tensor, *, determinist: bool = True) -> "PCA":
159
159
PCA
160
160
The PCA model fitted on the input data.
161
161
"""
162
+ # Auto-cast to float32 because float16 is not supported
163
+ if inputs .dtype == torch .float16 :
164
+ inputs = inputs .to (torch .float32 )
165
+
162
166
if self .svd_solver_ == "auto" :
163
167
self .svd_solver_ = choose_svd_solver (
164
168
inputs = inputs ,
@@ -184,14 +188,16 @@ def fit(self, inputs: Tensor, *, determinist: bool = True) -> "PCA":
184
188
eigenvals [eigenvals < 0.0 ] = 0.0
185
189
# Inverted indices
186
190
idx = range (eigenvals .size (0 ) - 1 , - 1 , - 1 )
187
- idx = torch .LongTensor (idx )
191
+ idx = torch .LongTensor (idx ). to ( eigenvals . device )
188
192
explained_variance = eigenvals .index_select (0 , idx )
189
193
total_var = torch .sum (explained_variance )
190
194
# Compute equivalent variables to full SVD output
191
195
vh_mat = eigenvecs .T .index_select (0 , idx )
192
196
coefs = torch .sqrt (explained_variance * (self .n_samples_ - 1 ))
193
197
u_mat = None
194
198
elif self .svd_solver_ == "randomized" :
199
+ if self .n_components_ is None :
200
+ self .n_components_ = min (inputs .shape [- 2 :])
195
201
if (
196
202
not isinstance (self .n_components_ , int )
197
203
or int (self .n_components_ ) != self .n_components_
@@ -267,11 +273,22 @@ def transform(self, inputs: Tensor, center: str = "fit") -> Tensor:
267
273
"""
268
274
self ._check_fitted ("transform" )
269
275
assert self .components_ is not None # for mypy
270
- transformed = inputs @ self .components_ .T
276
+ assert self .mean_ is not None # for mypy
277
+ components = (
278
+ self .components_ .to (torch .float16 )
279
+ if inputs .dtype == torch .float16
280
+ else self .components_
281
+ )
282
+ mean = (
283
+ self .mean_ .to (torch .float16 )
284
+ if inputs .dtype == torch .float16
285
+ else self .mean_
286
+ )
287
+ transformed = inputs @ components .T
271
288
if center == "fit" :
272
- transformed -= self . mean_ @ self . components_ .T
289
+ transformed -= mean @ components .T
273
290
elif center == "input" :
274
- transformed -= inputs .mean (dim = - 2 , keepdim = True ) @ self . components_ .T
291
+ transformed -= inputs .mean (dim = - 2 , keepdim = True ) @ components .T
275
292
elif center != "none" :
276
293
raise ValueError (
277
294
"Unknown centering, `center` argument should be "
0 commit comments