Skip to content

Commit

Permalink
Fixes to simple imputer
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Oct 28, 2024
1 parent 6c02f4f commit c024c5b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 21 deletions.
13 changes: 4 additions & 9 deletions lib/scholar/impute/simple_imputer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ defmodule Scholar.Impute.SimpleImputer do
default: :nan,
doc: ~S"""
The placeholder for the missing values. All occurrences of `:missing_values` will be imputed.
The default value expects there are no NaNs in the input tensor.
"""
],
strategy: [
Expand Down Expand Up @@ -72,17 +74,10 @@ defmodule Scholar.Impute.SimpleImputer do
"""
deftransform fit(x, opts \\ []) do
opts = NimbleOptions.validate!(opts, @opts_schema)

input_rank = Nx.rank(x)

if input_rank != 2 do
raise ArgumentError, "Wrong input rank. Expected: 2, got: #{inspect(input_rank)}"
end

if opts[:missing_values] != :nan and
Nx.any(Nx.is_nan(x)) == Nx.tensor(1, type: :u8) do
raise ArgumentError,
":missing_values other than :nan possible only if there is no Nx.Constant.nan() in the array"
raise ArgumentError, "wrong input rank. Expected: 2, got: #{inspect(input_rank)}"
end

{type, _num_bits} = x_type = Nx.type(x)
Expand All @@ -98,7 +93,7 @@ defmodule Scholar.Impute.SimpleImputer do
{fill_value_type, _} = Nx.type(opts[:fill_value])

raise ArgumentError,
"Wrong type of `:fill_value` for the given data. Expected: :f or :bf, got: #{inspect(fill_value_type)}"
"wrong type of `:fill_value` for the given data. Expected: :f or :bf, got: #{inspect(fill_value_type)}"

true ->
x
Expand Down
14 changes: 2 additions & 12 deletions test/scholar/impute/simple_imputer_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -122,27 +122,17 @@ defmodule SimpleImputerTest do
x = Nx.tensor([1, 2, 2, 3])

assert_raise ArgumentError,
"Wrong input rank. Expected: 2, got: 1",
"wrong input rank. Expected: 2, got: 1",
fn ->
SimpleImputer.fit(x, missing_values: 1, strategy: :mode)
end
end

test "Collision of nan" do
x = generate_data()

assert_raise ArgumentError,
":missing_values other than :nan possible only if there is no Nx.Constant.nan() in the array",
fn ->
SimpleImputer.fit(x, missing_values: 1.0, strategy: :mode)
end
end

test "Wrong :fill_value type" do
x = Nx.tensor([[1.0, 2.0, 2.0, 3.0]])

assert_raise ArgumentError,
"Wrong type of `:fill_value` for the given data. Expected: :f or :bf, got: :s",
"wrong type of `:fill_value` for the given data. Expected: :f or :bf, got: :s",
fn ->
SimpleImputer.fit(x,
missing_values: 1.0,
Expand Down

0 comments on commit c024c5b

Please sign in to comment.