diff --git a/lib/scholar/covariance/shrunk_covariance.ex b/lib/scholar/covariance/shrunk_covariance.ex index 49687560..595f28b2 100644 --- a/lib/scholar/covariance/shrunk_covariance.ex +++ b/lib/scholar/covariance/shrunk_covariance.ex @@ -36,7 +36,8 @@ defmodule Scholar.Covariance.ShrunkCovariance do ## Return Values - The function returns a struct with the following parameters: + The function returns a struct with the following parameters: + * `:covariance` - Tensor of shape `{num_features, num_features}`. Estimated covariance matrix. * `:location` - Tensor of shape `{num_features,}`. Estimated location, i.e. the estimated mean. @@ -75,8 +76,6 @@ defmodule Scholar.Covariance.ShrunkCovariance do f32[2] [0.18202415108680725, -0.09216632694005966] > - - """ deftransform fit(x, opts \\ []) do @@ -115,6 +114,6 @@ defmodule Scholar.Covariance.ShrunkCovariance do mask = Nx.iota(Nx.shape(shrunk_cov)) selector = Nx.remainder(mask, num_features + 1) == 0 - Nx.select(selector, shrunk_cov + shrinkage * mu, shrunk_cov) + shrunk_cov + shrinkage * mu * selector end end