Skip to content

Commit

Permalink
address feedback, refactor to reuse code
Browse files Browse the repository at this point in the history
  • Loading branch information
santiago-imelio committed Dec 13, 2023
1 parent 84643fe commit c19c200
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 15 deletions.
10 changes: 2 additions & 8 deletions lib/scholar/preprocessing.ex
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,8 @@ defmodule Scholar.Preprocessing do
>
"""
deftransform standard_scale(tensor, opts \\ []) do
standard_scale_n(tensor, NimbleOptions.validate!(opts, @general_schema))
end

defnp standard_scale_n(tensor, opts) do
std = Nx.standard_deviation(tensor, axes: opts[:axes], keep_axes: true)
mean_reduced = Nx.mean(tensor, axes: opts[:axes], keep_axes: true)
mean_reduced = Nx.select(std == 0, 0.0, mean_reduced)
(tensor - mean_reduced) / Nx.select(std == 0, 1.0, std)
opts = NimbleOptions.validate!(opts, @general_schema)
Scholar.Preprocessing.StandardScaler.fit_transform(tensor, opts)
end

@doc """
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
defmodule Scholar.Scaler.StandardScaler do
defmodule Scholar.Preprocessing.StandardScaler do
import Nx.Defn

defstruct [:deviation, :mean]
Expand All @@ -17,20 +17,27 @@ defmodule Scholar.Scaler.StandardScaler do

# Fits a standard scaler to `tensor` and returns it as a `%__MODULE__{}` struct.
#
# `opts` is validated against `@opts_schema`; invalid options make
# `NimbleOptions.validate!/2` raise. The statistics themselves are computed by
# `fit_n/2`, which returns a `{std, mean}` tuple; these are stored in the
# struct's `:deviation` and `:mean` fields.
deftransform fit(tensor, opts \\ []) do
  # Rebind `opts` to the validated keyword list. `validate!/2` returns the
  # options with the schema's defaults applied, so discarding its result would
  # pass the raw, un-defaulted options on to `fit_n/2`.
  opts = NimbleOptions.validate!(opts, @opts_schema)
  {std, mean} = fit_n(tensor, opts)

  %__MODULE__{deviation: std, mean: mean}
end

defnp fit_n(tensor, opts) do
std = Nx.standard_deviation(tensor, axes: opts[:axes], keep_axes: true)
mean_reduced = Nx.mean(tensor, axes: opts[:axes], keep_axes: true)
mean_reduced = Nx.select(Nx.equal(std, 0), 0.0, mean_reduced)
%__MODULE__{deviation: std, mean: mean_reduced}

{std, mean_reduced}
end

deftransform transform(tensor, %__MODULE__{deviation: std, mean: mean}) do
deftransform transform(%__MODULE__{deviation: std, mean: mean}, tensor) do
scale(tensor, std, mean)
end

deftransform fit_transform(tensor, opts \\ []) do
scaler = __MODULE__.fit(tensor, opts)
__MODULE__.transform(tensor, scaler)
defn fit_transform(tensor, opts \\ []) do
tensor
|> fit(opts)
|> transform(tensor)
end

defnp scale(tensor, std, mean) do
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
defmodule StandardScalerTest do
use Scholar.Case, async: true
alias Scholar.Scaler.StandardScaler
alias Scholar.Preprocessing.StandardScaler

describe "fit_transform/2" do
test "applies standard scaling to data" do
Expand Down

0 comments on commit c19c200

Please sign in to comment.