Skip to content

Commit

Permalink
Merge pull request #2 from srzeszut/knn_imputer
Browse files Browse the repository at this point in the history
mix format
  • Loading branch information
srzeszut authored Nov 28, 2024
2 parents 071fb27 + d5913eb commit ac7fc1a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
16 changes: 9 additions & 7 deletions lib/scholar/impute/knn_imputter.ex
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,12 @@ defmodule Scholar.Impute.KNNImputter do

x =
if missing_values != :nan,
do: Nx.select(Nx.equal(x, missing_values), :nan, x),
else: x
do: Nx.select(Nx.equal(x, missing_values), :nan, x),
else: x

statistics =
knn_impute(x, num_neighbors: opts[:num_neighbors], missing_values: missing_values)

statistics = knn_impute(x, num_neighbors: opts[:num_neighbors], missing_values: missing_values)
%__MODULE__{statistics: statistics, missing_values: missing_values}
end

Expand Down Expand Up @@ -139,7 +140,9 @@ defmodule Scholar.Impute.KNNImputter do
neighbor_avg =
calculate_knn(x, row, col, rows: rows, num_neighbors: opts[:num_neighbors])

values_to_impute = Nx.put_slice(values_to_impute, [row, col], Nx.reshape(neighbor_avg, {1, 1}))
values_to_impute =
Nx.put_slice(values_to_impute, [row, col], Nx.reshape(neighbor_avg, {1, 1}))

{{col + 1, mask, num_neighbors, cols, row, x}, values_to_impute}
else
{{col + 1, mask, num_neighbors, num_cols, row, x}, values_to_impute}
Expand All @@ -166,11 +169,10 @@ defmodule Scholar.Impute.KNNImputter do
{_, row_distances} =
while {{i = 0, x, row_with_value_to_fill, rows, nan_row, nan_col}, row_distances},
i < rows do

potential_donor = x[i]

distance =
calculate_distance(row_with_value_to_fill, nan_col, potential_donor,nan_row)
calculate_distance(row_with_value_to_fill, nan_col, potential_donor, nan_row)

row_distances = Nx.indexed_put(row_distances, Nx.new_axis(i, 0), distance)
{{i + 1, x, row_with_value_to_fill, rows, nan_row, nan_col}, row_distances}
Expand All @@ -183,7 +185,7 @@ defmodule Scholar.Impute.KNNImputter do
Nx.sum(values) / num_neighbors
end

defnp calculate_distance(row,nan_col,potential_donor,nan_row) do
defnp calculate_distance(row, nan_col, potential_donor, nan_row) do
case row do
^nan_row -> Nx.Constants.infinity(Nx.type(row))
_ -> nan_euclidean(row, nan_col, potential_donor)
Expand Down
2 changes: 1 addition & 1 deletion test/scholar/impute/knn_imputter_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ defmodule KNNImputterTest do
test "missing values different than :nan" do
x = generate_data()
x = Nx.select(Nx.is_nan(x), 19.0, x)
# x = Nx.select(Nx.equal(x,19), :nan, x)
# x = Nx.select(Nx.equal(x,19), :nan, x)
jit_fit = Nx.Defn.jit(&KNNImputter.fit/2)
jit_transform = Nx.Defn.jit(&KNNImputter.transform/2)

Expand Down

0 comments on commit ac7fc1a

Please sign in to comment.