Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added ShrunkCovariance #309

Merged
merged 12 commits into from
Nov 19, 2024
44 changes: 5 additions & 39 deletions lib/scholar/covariance/ledoit_wolf.ex
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ defmodule Scholar.Covariance.LedoitWolf do
defstruct [:covariance, :shrinkage, :location]

opts_schema = [
assume_centered: [
assume_centered?: [
default: false,
type: :boolean,
doc: """
Expand Down Expand Up @@ -93,7 +93,7 @@ defmodule Scholar.Covariance.LedoitWolf do

iex> key = Nx.Random.key(0)
iex> {x, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0.0, 0.0, 0.0]), Nx.tensor([[3.0, 2.0, 1.0], [1.0, 2.0, 3.0], [1.3, 1.0, 2.2]]), shape: {10}, type: :f32)
iex> cov = Scholar.Covariance.LedoitWolf.fit(x, assume_centered: true)
iex> cov = Scholar.Covariance.LedoitWolf.fit(x, assume_centered?: true)
iex> cov.covariance
#Nx.Tensor<
f32[3][3]
Expand All @@ -110,7 +110,7 @@ defmodule Scholar.Covariance.LedoitWolf do
end

defnp fit_n(x, opts) do
{x, location} = center(x, opts)
{x, location} = Scholar.Covariance.Utils.center(x, opts[:assume_centered?])

{covariance, shrinkage} =
ledoit_wolf(x)
Expand All @@ -122,23 +122,6 @@ defmodule Scholar.Covariance.LedoitWolf do
}
end

defnp center(x, opts) do
x =
case Nx.shape(x) do
{_} -> Nx.new_axis(x, 1)
_ -> x
end

location =
if opts[:assume_centered] do
0
else
Nx.mean(x, axes: [0])
end

{x - location, location}
end

defnp ledoit_wolf(x) do
case Nx.shape(x) do
{_n, 1} ->
Expand All @@ -149,23 +132,6 @@ defmodule Scholar.Covariance.LedoitWolf do
end
end

defnp empirical_covariance(x) do
n = Nx.axis_size(x, 0)

covariance = Nx.dot(x, [0], x, [0]) / n

case Nx.shape(covariance) do
{} -> Nx.reshape(covariance, {1, 1})
_ -> covariance
end
end

defnp trace(x) do
x
|> Nx.take_diagonal()
|> Nx.sum()
end

defnp ledoit_wolf_shrinkage(x) do
case Nx.shape(x) do
{_, 1} ->
Expand All @@ -182,9 +148,9 @@ defmodule Scholar.Covariance.LedoitWolf do

defnp ledoit_wolf_shrinkage_complex(x) do
{num_samples, num_features} = Nx.shape(x)
emp_cov = empirical_covariance(x)
emp_cov = Scholar.Covariance.Utils.empirical_covariance(x)

emp_cov_trace = trace(emp_cov)
emp_cov_trace = Scholar.Covariance.Utils.trace(emp_cov)
mu = Nx.sum(emp_cov_trace) / num_features

flatten_delta = Nx.flatten(emp_cov)
Expand Down
120 changes: 120 additions & 0 deletions lib/scholar/covariance/shrunk_covariance.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
defmodule Scholar.Covariance.ShrunkCovariance do
@moduledoc """
Covariance estimator with shrinkage.
"""
import Nx.Defn

@derive {Nx.Container, containers: [:covariance, :location]}
defstruct [:covariance, :location]

opts_schema = [
assume_centered?: [
default: false,
type: :boolean,
doc: """
If `true`, data will not be centered before computation.
Useful when working with data whose mean is almost, but not exactly
zero.
If `false`, data will be centered before computation.
"""
],
shrinkage: [
default: 0.1,
type: :float,
doc: "Coefficient in the convex combination used for the computation
of the shrunk estimate. Range is [0, 1]."
]
]

@opts_schema NimbleOptions.new!(opts_schema)
@doc """
Fit the shrunk covariance model to `x`.

## Options

#{NimbleOptions.docs(@opts_schema)}

## Return Values

The function returns a struct with the following parameters:
norm4nn marked this conversation as resolved.
Show resolved Hide resolved
* `:covariance` - Tensor of shape `{num_features, num_features}`. Estimated covariance matrix.
* `:location` - Tensor of shape `{num_features,}`.
Estimated location, i.e. the estimated mean.

## Examples

iex> key = Nx.Random.key(0)
iex> {x, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0.0, 0.0]), Nx.tensor([[0.8, 0.3], [0.2, 0.4]]), shape: {10}, type: :f32)
iex> model = Scholar.Covariance.ShrunkCovariance.fit(x)
iex> model.covariance
#Nx.Tensor<
f32[2][2]
[
[0.7721845507621765, 0.19141492247581482],
[0.19141492247581482, 0.33952537178993225]
]
>
iex> model.location
#Nx.Tensor<
f32[2]
[0.18202415108680725, -0.09216632694005966]
>
iex> key = Nx.Random.key(0)
iex> {x, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0.0, 0.0]), Nx.tensor([[0.8, 0.3], [0.2, 0.4]]), shape: {10}, type: :f32)
iex> model = Scholar.Covariance.ShrunkCovariance.fit(x, shrinkage: 0.4)
iex> model.covariance
#Nx.Tensor<
f32[2][2]
[
[0.7000747323036194, 0.1276099532842636],
[0.1276099532842636, 0.41163527965545654]
]
>
iex> model.location
#Nx.Tensor<
f32[2]
[0.18202415108680725, -0.09216632694005966]
>


norm4nn marked this conversation as resolved.
Show resolved Hide resolved
"""

deftransform fit(x, opts \\ []) do
fit_n(x, NimbleOptions.validate!(opts, @opts_schema))
end

defnp fit_n(x, opts) do
shrinkage = opts[:shrinkage]

if shrinkage < 0 or shrinkage > 1 do
raise ArgumentError,
"""
expected :shrinkage option to be in [0, 1] range, \
got shrinkage: #{inspect(Nx.shape(x))}\
"""
end

{x, location} = Scholar.Covariance.Utils.center(x, opts[:assume_centered?])

covariance =
Scholar.Covariance.Utils.empirical_covariance(x)
|> shrunk_covariance(shrinkage)

%__MODULE__{
covariance: covariance,
location: location
}
end

defnp shrunk_covariance(emp_cov, shrinkage) do
num_features = Nx.axis_size(emp_cov, 1)
shrunk_cov = (1.0 - shrinkage) * emp_cov
emp_cov_trace = Scholar.Covariance.Utils.trace(emp_cov)
mu = Nx.sum(emp_cov_trace) / num_features

mask = Nx.iota(Nx.shape(shrunk_cov))
selector = Nx.remainder(mask, num_features + 1) == 0

Nx.select(selector, shrunk_cov + shrinkage * mu, shrunk_cov)
norm4nn marked this conversation as resolved.
Show resolved Hide resolved
end
end
39 changes: 39 additions & 0 deletions lib/scholar/covariance/utils.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
defmodule Scholar.Covariance.Utils do
@moduledoc false
import Nx.Defn
require Nx

defn center(x, assume_centered \\ false) do
norm4nn marked this conversation as resolved.
Show resolved Hide resolved
x =
case Nx.shape(x) do
{_} -> Nx.new_axis(x, 1)
_ -> x
end

location =
if assume_centered do
0
else
Nx.mean(x, axes: [0])
end

{x - location, location}
end

defn empirical_covariance(x) do
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should be able to use Nx.covariance/2 instead.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so, look at the PR description

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, I forgot to ask: @norm4nn did you try setting ddof: 0 when calling Nx.covariance/2?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, ddof: 0 is default. It's strange that the tests are failing, because Nx.covariance/2 does exactly what is implemented here. Could it be the case that the data in your test is not centered and you are using Nx.covariance/2?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, you are right! This is the case, I will fix this on Thursday.

n = Nx.axis_size(x, 0)

covariance = Nx.dot(x, [0], x, [0]) / n

case Nx.shape(covariance) do
{} -> Nx.reshape(covariance, {1, 1})
_ -> covariance
end
end

defn trace(x) do
x
|> Nx.take_diagonal()
|> Nx.sum()
end
end
4 changes: 2 additions & 2 deletions test/scholar/covariance/ledoit_wolf_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ defmodule Scholar.Covariance.LedoitWolfTest do
)
end

test "fit test - :assume_centered is true" do
test "fit test - :assume_centered? is true" do
key = key()

{x, _new_key} =
Expand All @@ -52,7 +52,7 @@ defmodule Scholar.Covariance.LedoitWolfTest do
type: :f32
)

model = LedoitWolf.fit(x, assume_centered: true)
model = LedoitWolf.fit(x, assume_centered?: true)

assert_all_close(
model.covariance,
Expand Down
Loading
Loading