From 4ec5bc543649c3006a57bf2741e92f7d0433d72b Mon Sep 17 00:00:00 2001 From: Szymon Date: Mon, 11 Nov 2024 12:14:10 +0100 Subject: [PATCH 1/4] Update lib/scholar/covariance/shrunk_covariance.ex Co-authored-by: Mateusz Sluszniak <56299341+msluszniak@users.noreply.github.com> --- lib/scholar/covariance/shrunk_covariance.ex | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/scholar/covariance/shrunk_covariance.ex b/lib/scholar/covariance/shrunk_covariance.ex index 770c1ca7..6e4e4ab7 100644 --- a/lib/scholar/covariance/shrunk_covariance.ex +++ b/lib/scholar/covariance/shrunk_covariance.ex @@ -115,6 +115,6 @@ defmodule Scholar.Covariance.ShrunkCovariance do mask = Nx.iota(Nx.shape(shrunk_cov)) selector = Nx.remainder(mask, num_features + 1) == 0 - Nx.select(selector, shrunk_cov + shrinkage * mu, shrunk_cov) + shrunk_cov + shrinkage * mu * selector end end From 3f38654c2bdccd9ab724ede8f6b072fbed3cf92d Mon Sep 17 00:00:00 2001 From: Szymon Date: Mon, 11 Nov 2024 12:14:30 +0100 Subject: [PATCH 2/4] Update lib/scholar/covariance/shrunk_covariance.ex MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: José Valim --- lib/scholar/covariance/shrunk_covariance.ex | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/scholar/covariance/shrunk_covariance.ex b/lib/scholar/covariance/shrunk_covariance.ex index 6e4e4ab7..8ee30476 100644 --- a/lib/scholar/covariance/shrunk_covariance.ex +++ b/lib/scholar/covariance/shrunk_covariance.ex @@ -36,7 +36,8 @@ defmodule Scholar.Covariance.ShrunkCovariance do ## Return Values - The function returns a struct with the following parameters: + The function returns a struct with the following parameters: + * `:covariance` - Tensor of shape `{num_features, num_features}`. Estimated covariance matrix. * `:location` - Tensor of shape `{num_features,}`. Estimated location, i.e. the estimated mean. From 5323af1eea8a13fe64e4997ee53ccf0ec0ceac1e Mon Sep 17 00:00:00 2001 From: Szymon Date: Tue, 12 Nov 2024 17:58:52 +0100 Subject: [PATCH 3/4] Update lib/scholar/covariance/shrunk_covariance.ex Co-authored-by: Paulo Valente <16843419+polvalente@users.noreply.github.com> --- lib/scholar/covariance/shrunk_covariance.ex | 2 -- 1 file changed, 2 deletions(-) diff --git a/lib/scholar/covariance/shrunk_covariance.ex b/lib/scholar/covariance/shrunk_covariance.ex index 8ee30476..0302a350 100644 --- a/lib/scholar/covariance/shrunk_covariance.ex +++ b/lib/scholar/covariance/shrunk_covariance.ex @@ -76,8 +76,6 @@ defmodule Scholar.Covariance.ShrunkCovariance do f32[2] [0.18202415108680725, -0.09216632694005966] > - - """ deftransform fit(x, opts \\ []) do From 34f85744ce0dd8cdf7cb07a7a63a8ebc215c4dc1 Mon Sep 17 00:00:00 2001 From: Szymon Date: Tue, 12 Nov 2024 17:59:19 +0100 Subject: [PATCH 4/4] Update lib/scholar/covariance/utils.ex MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Krsto Proroković --- lib/scholar/covariance/utils.ex | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/scholar/covariance/utils.ex b/lib/scholar/covariance/utils.ex index e5b62be6..3997bd48 100644 --- a/lib/scholar/covariance/utils.ex +++ b/lib/scholar/covariance/utils.ex @@ -3,7 +3,7 @@ defmodule Scholar.Covariance.Utils do import Nx.Defn require Nx - defn center(x, assume_centered \\ false) do +defn center(x, assume_centered? \\ false) do x = case Nx.shape(x) do {_} -> Nx.new_axis(x, 1)