File tree 1 file changed +3
-4
lines changed
1 file changed +3
-4
lines changed Original file line number Diff line number Diff line change @@ -36,7 +36,8 @@ defmodule Scholar.Covariance.ShrunkCovariance do
36
36
37
37
## Return Values
38
38
39
- The function returns a struct with the following parameters:
39
+ The function returns a struct with the following parameters:
40
+
40
41
* `:covariance` - Tensor of shape `{num_features, num_features}`. Estimated covariance matrix.
41
42
* `:location` - Tensor of shape `{num_features,}`.
42
43
Estimated location, i.e. the estimated mean.
@@ -75,8 +76,6 @@ defmodule Scholar.Covariance.ShrunkCovariance do
75
76
f32[2]
76
77
[0.18202415108680725, -0.09216632694005966]
77
78
>
78
-
79
-
80
79
"""
81
80
82
81
deftransform fit ( x , opts \\ [ ] ) do
@@ -115,6 +114,6 @@ defmodule Scholar.Covariance.ShrunkCovariance do
115
114
mask = Nx . iota ( Nx . shape ( shrunk_cov ) )
116
115
selector = Nx . remainder ( mask , num_features + 1 ) == 0
117
116
118
- Nx . select ( selector , shrunk_cov + shrinkage * mu , shrunk_cov )
117
+ shrunk_cov + shrinkage * mu * selector
119
118
end
120
119
end
You can’t perform that action at this time.
0 commit comments