Skip to content

Commit

Permalink
merged origin branch
Browse files Browse the repository at this point in the history
  • Loading branch information
norm4nn committed Nov 14, 2024
2 parents cba0089 + 34f8574 commit 4297069
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions lib/scholar/covariance/shrunk_covariance.ex
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ defmodule Scholar.Covariance.ShrunkCovariance do
## Return Values
The function returns a struct with the following parameters:
The function returns a struct with the following parameters:
* `:covariance` - Tensor of shape `{num_features, num_features}`. Estimated covariance matrix.
* `:location` - Tensor of shape `{num_features,}`.
Estimated location, i.e. the estimated mean.
Expand Down Expand Up @@ -75,8 +76,6 @@ defmodule Scholar.Covariance.ShrunkCovariance do
f32[2]
[0.18202415108680725, -0.09216632694005966]
>
"""

deftransform fit(x, opts \\ []) do
Expand Down Expand Up @@ -115,6 +114,6 @@ defmodule Scholar.Covariance.ShrunkCovariance do
mask = Nx.iota(Nx.shape(shrunk_cov))
selector = Nx.remainder(mask, num_features + 1) == 0

Nx.select(selector, shrunk_cov + shrinkage * mu, shrunk_cov)
shrunk_cov + shrinkage * mu * selector
end
end

0 comments on commit 4297069

Please sign in to comment.