Skip to content

Commit 4297069

Browse files
committed
merged origin branch
2 parents cba0089 + 34f8574 commit 4297069

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

lib/scholar/covariance/shrunk_covariance.ex

+3-4
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ defmodule Scholar.Covariance.ShrunkCovariance do
3636
3737
## Return Values
3838
39-
The function returns a struct with the following parameters:
39+
The function returns a struct with the following parameters:
40+
4041
* `:covariance` - Tensor of shape `{num_features, num_features}`. Estimated covariance matrix.
4142
* `:location` - Tensor of shape `{num_features,}`.
4243
Estimated location, i.e. the estimated mean.
@@ -75,8 +76,6 @@ defmodule Scholar.Covariance.ShrunkCovariance do
7576
f32[2]
7677
[0.18202415108680725, -0.09216632694005966]
7778
>
78-
79-
8079
"""
8180

8281
deftransform fit(x, opts \\ []) do
@@ -115,6 +114,6 @@ defmodule Scholar.Covariance.ShrunkCovariance do
115114
mask = Nx.iota(Nx.shape(shrunk_cov))
116115
selector = Nx.remainder(mask, num_features + 1) == 0
117116

118-
Nx.select(selector, shrunk_cov + shrinkage * mu, shrunk_cov)
117+
shrunk_cov + shrinkage * mu * selector
119118
end
120119
end

0 commit comments

Comments
 (0)