diff --git a/lib/scholar/preprocessing.ex b/lib/scholar/preprocessing.ex index 67c38b83..77d2b84b 100644 --- a/lib/scholar/preprocessing.ex +++ b/lib/scholar/preprocessing.ex @@ -133,14 +133,8 @@ defmodule Scholar.Preprocessing do > """ deftransform standard_scale(tensor, opts \\ []) do - standard_scale_n(tensor, NimbleOptions.validate!(opts, @general_schema)) - end - - defnp standard_scale_n(tensor, opts) do - std = Nx.standard_deviation(tensor, axes: opts[:axes], keep_axes: true) - mean_reduced = Nx.mean(tensor, axes: opts[:axes], keep_axes: true) - mean_reduced = Nx.select(std == 0, 0.0, mean_reduced) - (tensor - mean_reduced) / Nx.select(std == 0, 1.0, std) + opts = NimbleOptions.validate!(opts, @general_schema) + Scholar.Preprocessing.StandardScaler.fit_transform(tensor, opts) end @doc """ diff --git a/lib/scholar/scaler/standard_scaler.ex b/lib/scholar/preprocessing/standard_scaler.ex similarity index 67% rename from lib/scholar/scaler/standard_scaler.ex rename to lib/scholar/preprocessing/standard_scaler.ex index 676a84e9..9a028bd3 100644 --- a/lib/scholar/scaler/standard_scaler.ex +++ b/lib/scholar/preprocessing/standard_scaler.ex @@ -1,4 +1,4 @@ -defmodule Scholar.Scaler.StandardScaler do +defmodule Scholar.Preprocessing.StandardScaler do import Nx.Defn defstruct [:deviation, :mean] @@ -17,20 +17,27 @@ defmodule Scholar.Scaler.StandardScaler do deftransform fit(tensor, opts \\ []) do NimbleOptions.validate!(opts, @opts_schema) + {std, mean} = fit_n(tensor, opts) + %__MODULE__{deviation: std, mean: mean} + end + + defnp fit_n(tensor, opts) do std = Nx.standard_deviation(tensor, axes: opts[:axes], keep_axes: true) mean_reduced = Nx.mean(tensor, axes: opts[:axes], keep_axes: true) mean_reduced = Nx.select(Nx.equal(std, 0), 0.0, mean_reduced) - %__MODULE__{deviation: std, mean: mean_reduced} + + {std, mean_reduced} end - deftransform transform(tensor, %__MODULE__{deviation: std, mean: mean}) do + deftransform transform(%__MODULE__{deviation: std, mean: mean}, tensor) do scale(tensor, std, mean) end - deftransform fit_transform(tensor, opts \\ []) do - scaler = __MODULE__.fit(tensor, opts) - __MODULE__.transform(tensor, scaler) + defn fit_transform(tensor, opts \\ []) do + tensor + |> fit(opts) + |> transform(tensor) end defnp scale(tensor, std, mean) do diff --git a/test/scholar/scaler/standard_scaler_test.exs b/test/scholar/preprocessing/standard_scaler_test.exs similarity index 94% rename from test/scholar/scaler/standard_scaler_test.exs rename to test/scholar/preprocessing/standard_scaler_test.exs index b788e36d..4470efcf 100644 --- a/test/scholar/scaler/standard_scaler_test.exs +++ b/test/scholar/preprocessing/standard_scaler_test.exs @@ -1,6 +1,6 @@ defmodule StandardScalerTest do use Scholar.Case, async: true - alias Scholar.Scaler.StandardScaler + alias Scholar.Preprocessing.StandardScaler describe "fit_transform/2" do test "applies standard scaling to data" do