Skip to content

Commit

Permalink
Added docs
Browse files Browse the repository at this point in the history
  • Loading branch information
norm4nn committed Nov 8, 2024
1 parent f605f27 commit 4b50ca6
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
12 changes: 9 additions & 3 deletions lib/scholar/covariance/shrunk_covariance.ex
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
defmodule Scholar.Covariance.ShrunkCovariance do
@moduledoc """
Covariance estimator with shrinkage.
"""
import Nx.Defn
Expand Down Expand Up @@ -28,6 +30,7 @@ defmodule Scholar.Covariance.ShrunkCovariance do

@opts_schema NimbleOptions.new!(opts_schema)
@doc """
Fit the shrunk covariance model to `x`.
## Options
Expand Down Expand Up @@ -84,15 +87,18 @@ defmodule Scholar.Covariance.ShrunkCovariance do

defnp fit_n(x, opts) do
shrinkage = opts[:shrinkage]

if shrinkage < 0 or shrinkage > 1 do
raise ArgumentError,
"""
expected :shrinkage option to be in [0, 1] range, \
got shrinkage: #{inspect(Nx.shape(x))}\
"""
end

{x, location} = Scholar.Covariance.Utils.center(x, opts[:assume_centered])
covariance =

covariance =
Scholar.Covariance.Utils.empirical_covariance(x)
|> shrunk_covariance(shrinkage)

Expand All @@ -110,7 +116,7 @@ defmodule Scholar.Covariance.ShrunkCovariance do

mask = Nx.iota(Nx.shape(shrunk_cov))
selector = Nx.remainder(mask, num_features + 1) == 0

Nx.select(selector, shrunk_cov + shrinkage * mu, shrunk_cov)
end
end
end
6 changes: 5 additions & 1 deletion test/scholar/covariance/shrunk_covariance_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ defmodule Scholar.Covariance.ShrunkCovarianceTest do
atol: 1.0e-3
)

assert_all_close(model.location, Nx.tensor([-1.015519142150879, -0.4495307505130768, 0.06475571542978287]), atol: 1.0e-3)
assert_all_close(
model.location,
Nx.tensor([-1.015519142150879, -0.4495307505130768, 0.06475571542978287]),
atol: 1.0e-3
)
end

test "fit test 2" do
Expand Down

0 comments on commit 4b50ca6

Please sign in to comment.