diff --git a/lib/scholar/impute/knn_imputter.ex b/lib/scholar/impute/knn_imputter.ex index 3107eca2..6b5a602e 100644 --- a/lib/scholar/impute/knn_imputter.ex +++ b/lib/scholar/impute/knn_imputter.ex @@ -83,11 +83,12 @@ defmodule Scholar.Impute.KNNImputter do x = if missing_values != :nan, - do: Nx.select(Nx.equal(x, missing_values), :nan, x), - else: x + do: Nx.select(Nx.equal(x, missing_values), :nan, x), + else: x + statistics = + knn_impute(x, num_neighbors: opts[:num_neighbors], missing_values: missing_values) - statistics = knn_impute(x, num_neighbors: opts[:num_neighbors], missing_values: missing_values) %__MODULE__{statistics: statistics, missing_values: missing_values} end @@ -139,7 +140,9 @@ defmodule Scholar.Impute.KNNImputter do neighbor_avg = calculate_knn(x, row, col, rows: rows, num_neighbors: opts[:num_neighbors]) - values_to_impute = Nx.put_slice(values_to_impute, [row, col], Nx.reshape(neighbor_avg, {1, 1})) + values_to_impute = + Nx.put_slice(values_to_impute, [row, col], Nx.reshape(neighbor_avg, {1, 1})) + {{col + 1, mask, num_neighbors, cols, row, x}, values_to_impute} else {{col + 1, mask, num_neighbors, num_cols, row, x}, values_to_impute} @@ -166,11 +169,10 @@ defmodule Scholar.Impute.KNNImputter do {_, row_distances} = while {{i = 0, x, row_with_value_to_fill, rows, nan_row, nan_col}, row_distances}, i < rows do - potential_donor = x[i] distance = - calculate_distance(row_with_value_to_fill, nan_col, potential_donor,nan_row) + calculate_distance(row_with_value_to_fill, nan_col, potential_donor, nan_row) row_distances = Nx.indexed_put(row_distances, Nx.new_axis(i, 0), distance) {{i + 1, x, row_with_value_to_fill, rows, nan_row, nan_col}, row_distances} @@ -183,7 +185,7 @@ defmodule Scholar.Impute.KNNImputter do Nx.sum(values) / num_neighbors end - defnp calculate_distance(row,nan_col,potential_donor,nan_row) do + defnp calculate_distance(row, nan_col, potential_donor, nan_row) do case row do ^nan_row -> Nx.Constants.infinity(Nx.type(row)) _ -> nan_euclidean(row, nan_col, potential_donor) diff --git a/test/scholar/impute/knn_imputter_test.exs b/test/scholar/impute/knn_imputter_test.exs index c76becb4..8937e7c9 100644 --- a/test/scholar/impute/knn_imputter_test.exs +++ b/test/scholar/impute/knn_imputter_test.exs @@ -73,7 +73,7 @@ defmodule KNNImputterTest do test "missing values different than :nan" do x = generate_data() x = Nx.select(Nx.is_nan(x), 19.0, x) -# x = Nx.select(Nx.equal(x,19), :nan, x) + # x = Nx.select(Nx.equal(x,19), :nan, x) jit_fit = Nx.Defn.jit(&KNNImputter.fit/2) jit_transform = Nx.Defn.jit(&KNNImputter.transform/2)