Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add KNNImputer #303

Merged
merged 16 commits into from
Nov 28, 2024
Merged

Add KNNImputer #303

merged 16 commits into from
Nov 28, 2024

Conversation

srzeszut
Copy link
Contributor

I have added the KNNImputer and I am currently implementing tests to ensure that it behaves as expected across various scenarios, including edge cases.

Comment on lines 75 to 80
if opts[:missing_values] != :nan and
Nx.any(Nx.is_nan(x)) == Nx.tensor(1, type: :u8) do
raise ArgumentError,
":missing_values other than :nan possible only if there is no Nx.Constant.nan() in the array"
end

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check does not really work in Nx. If you call fit inside Nx.Defn.jit, then x is an expression, and we can't read its values to find out if there is a nan or not. The best we can do is to remove this check and document it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found this check in simple imputer
https://github.com/elixir-nx/scholar/blob/main/lib/scholar/impute/simple_imputer.ex
Are you sure it won't work?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is also broken there. :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have fixed it there: c024c5b


all_nan_rows_count = Nx.sum(all_nan_rows)

if num_neighbors > rows - 1 - Nx.to_number(all_nan_rows_count) do
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, this code won't work because, when you have an expression, you can't get a number from it. Can we remove this check? What happens if we don't check for this condition?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can test this by calling fit after jitting it with Nx.Defn.fit.


# if potential neighbor has nan in nan_col, we don't want to calculate distance and the case if potential_neighbour is the row to impute
{potential_neighbor} =
if potential_neighbor[nan_col] == Nx.Constants.nan() do
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if this check is guaranteed to work, given two NaNs are not guaranteed to be equal. Using Nx.is_nan would be more appropriate.


x =
if opts[:missing_values] != :nan,
do: Nx.select(Nx.equal(x, opts[:missing_values]), Nx.Constants.nan(), x),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use Nx.is_nan here NaN is not equal to itself

coordinates = coordinates - 1

# inputes zeros in nan_col to calculate distance with squared_euclidean
new_row = Nx.indexed_put(row, Nx.new_axis(nan_col, 0), Nx.tensor(0))
Copy link
Contributor

@msluszniak msluszniak Oct 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally, when you write in defn, you don't need to wrap this zero in Nx.tensor. I prefer to explicitly use Nx.<type> or Nx.tensor(x, type: type) to indicate the type of the tensor. Now, there are some cases where imputter has fixed type like :f32. I think that this might cause undesired upcasts when e.g. I have tensor of type :bf16. So I suggest to check if there are any unwanted casts / upcast.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed it but I don't know how to change this line
row_distances = Nx.iota({rows}, type: {:f, 32})
because i don't know what the type calculated distance will be at this point


# if row has all nans we skip it
{weight, potential_neighbor} =
if present_coordinates == 0 do
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned in comment up, try to replace "bare" numbers with typed tensors

@@ -0,0 +1,256 @@
defmodule Scholar.Impute.KNNImputer do
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be written with double t KNNImputter like formatter etc.

Copy link
Contributor

@msluszniak msluszniak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR, I dropped some comments :))

@krstopro
Copy link
Member

Hi @srzeszut and thanks for the pull request. I’m traveling now and don’t have my laptop with me. Will be back this Sunday, so I will have a look probably next week.

@srzeszut
Copy link
Contributor Author

srzeszut commented Oct 27, 2024

Thanks for the review, I apply suggested changes and left some comments.


if num_neighbors > rows - 1 - Nx.to_number(all_nan_rows_count) do
raise ArgumentError,
"Number of neighbors rows must be less than number valid of rows - 1 (valid row is row with more than 1 non nan value)"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

error messages start in lowercase. :)

Suggested change
"Number of neighbors rows must be less than number valid of rows - 1 (valid row is row with more than 1 non nan value)"
"number of neighbors rows must be less than number valid of rows - 1 (valid row is row with more than 1 non nan value)"


all_nan_rows_count = Nx.sum(all_nan_rows)

if num_neighbors > rows - 1 - Nx.to_number(all_nan_rows_count) do
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please add some tests? In particular, please add a test where you call jit this function and then you call it: Nx.Defn.jit(...).(arg1, arg2). It should reveal some errors around here. :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added tests and checked it. I removed those checks and added them in the description

Comment on lines 6 to 7
`n_neighbors` nearest neighbors found in the training set. Two samples are
close if the features that neither is missing are close.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
`n_neighbors` nearest neighbors found in the training set. Two samples are
close if the features that neither is missing are close.
`n_neighbors` nearest neighbors found in the training set. Two samples are
close if the features that neither is missing are close.


Preconditions:
* `number_of_neighbors` is a positive integer.
* number of neighbors must be less than number valid of rows - 1 (valid row is row with more than 1 non nan value) otherwise it is better to use simple imputter
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please try to break this long line :)

Comment on lines 106 to 116
test "Wrong impute rank" do
x = Nx.tensor([1, 2, 2, 3])

assert_raise ArgumentError,
"Wrong input rank. Expected: 2, got: 1",
fn ->
KNNImputter.fit(x, missing_values: 1, number_of_neighbors: 2)
end
end

test "Invalid n_neighbors value" do
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test names start in lowercase :)

Suggested change
test "Wrong impute rank" do
x = Nx.tensor([1, 2, 2, 3])
assert_raise ArgumentError,
"Wrong input rank. Expected: 2, got: 1",
fn ->
KNNImputter.fit(x, missing_values: 1, number_of_neighbors: 2)
end
end
test "Invalid n_neighbors value" do
test "invalid impute rank" do
x = Nx.tensor([1, 2, 2, 3])
assert_raise ArgumentError,
"Wrong input rank. Expected: 2, got: 1",
fn ->
KNNImputter.fit(x, missing_values: 1, number_of_neighbors: 2)
end
end
test "invalid n_neighbors value" do

Copy link
Contributor

@josevalim josevalim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dropped the last round of nitpicks and we are good to go!

Copy link
Member

@krstopro krstopro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First review. Some features we might wanna have:

  • Make k-NN algorithm configurable.
  • Make the metric configurable.

You can leave these for another pull request. Have a look at e.g. KNNClassifier how it is done over there.
I should have another look tonight.

The default value expects there are no NaNs in the input tensor.
"""
],
number_of_neighbors: [
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest changing this to num_neighbors to be consistent with the rest of Scholar.

Copy link
Member

@krstopro krstopro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Several minor comments for now. I have to go through the code at least once more as I don't exactly understand the logic here.


x =
if opts[:missing_values] != :nan,
do: Nx.select(Nx.equal(x, opts[:missing_values]), Nx.Constants.nan(), x),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should be able to use == instead of Nx.equal/2.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a deftransform, so Nx.equal is the proper function. == will be Elixir.Kernel.==

placeholder_value = Nx.Constants.nan() |> Nx.tensor()

statistics = knn_impute(x, placeholder_value, num_neighbors: num_neighbors)
missing_values = opts[:missing_values]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would move this line above so that you don't access opts[:missing_values] multiple times.


{_, values_to_impute} =
while {{row = 0, mask, num_neighbors, num_rows, x}, values_to_impute},
Nx.less(row, num_rows) do
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use < instead of Nx.less/2 over here.

Nx.less(row, num_rows) do
{_, values_to_impute} =
while {{col = 0, mask, num_neighbors, num_cols, row, x}, values_to_impute},
Nx.less(col, num_cols) do
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

{_, values_to_impute} =
while {{col = 0, mask, num_neighbors, num_cols, row, x}, values_to_impute},
Nx.less(col, num_cols) do
if mask[row][col] > 0 do
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if mask[row][col] do should work here.

Comment on lines 38 to 39
* `number_of_neighbors` is a positive integer.
* number of neighbors must be less than number valid of rows - 1 (valid row is row with more than 1 non nan value) otherwise it is better to use simple imputter
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* `number_of_neighbors` is a positive integer.
* number of neighbors must be less than number valid of rows - 1 (valid row is row with more than 1 non nan value) otherwise it is better to use simple imputter
* The number of neighbors must be less than the number of valid rows - 1.
A valid row is a row with more than 1 non-NaN values. Otherwise it is better to use a simpler imputer.

Preconditions:
* `number_of_neighbors` is a positive integer.
* number of neighbors must be less than number valid of rows - 1 (valid row is row with more than 1 non nan value) otherwise it is better to use simple imputter
* when you set a value different than :nan in `missing_values` there should be no NaNs in the input tensor
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* when you set a value different than :nan in `missing_values` there should be no NaNs in the input tensor
* When you set a value different than `:nan` in `missing_values` there should be no NaNs in the input tensor

* `:missing_values` - the same value as in `:missing_values`

* `:statistics` - The imputation fill value for each feature. Computing statistics can result in
[`Nx.Constant.nan/0`](https://hexdocs.pm/nx/Nx.Constants.html#nan/0) values.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
[`Nx.Constant.nan/0`](https://hexdocs.pm/nx/Nx.Constants.html#nan/0) values.
[`Nx.Constants.nan/0`](https://hexdocs.pm/nx/Nx.Constants.html#nan/0) values.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need the explicit linking in hexdoc?


The function returns a struct with the following parameters:

* `:missing_values` - the same value as in `:missing_values`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* `:missing_values` - the same value as in `:missing_values`
* `:missing_values` - the same value as in the `:missing_values` option


num_neighbors = opts[:number_of_neighbors]

placeholder_value = Nx.Constants.nan() |> Nx.tensor()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
placeholder_value = Nx.Constants.nan() |> Nx.tensor()
placeholder_value = Nx.Constants.nan()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you probably want to pass the input type here to avoid upcasts


opts_schema = [
missing_values: [
type: {:or, [:float, :integer, {:in, [:nan]}]},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
type: {:or, [:float, :integer, {:in, [:nan]}]},
type: {:or, [:float, :integer, {:in, [:nan]}]},

I believe this should allow :infinity and :neg_infinity too for completeness

Comment on lines 143 to 146
indices =
[Nx.stack(row), Nx.stack(col)]
|> Nx.concatenate()
|> Nx.stack()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
indices =
[Nx.stack(row), Nx.stack(col)]
|> Nx.concatenate()
|> Nx.stack()
indices = Nx.stack([row, col]) |> Nx.reshape({1, 2})

If I read the code correctly, row and col are scalars and this should yield the same result

|> Nx.concatenate()
|> Nx.stack()

values_to_impute = Nx.indexed_put(values_to_impute, indices, Nx.stack(neighbor_avg))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
values_to_impute = Nx.indexed_put(values_to_impute, indices, Nx.stack(neighbor_avg))
values_to_impute = Nx.put_slice(values_to_impute, [row, col], Nx.reshape(neighbor_avg, {1, 1}))

I think this is even simpler

Comment on lines 172 to 186
{_, row_distances} =
while {{i = 0, x, row_with_value_to_fill, rows, nan_row, nan_col}, row_distances},
Nx.less(i, rows) do
potential_donor = x[i]

distance =
if i == nan_row do
Nx.Constants.infinity(Nx.type(row_with_value_to_fill))
else
nan_euclidian(row_with_value_to_fill, nan_col, potential_donor)
end

row_distances = Nx.indexed_put(row_distances, Nx.new_axis(i, 0), distance)
{{i + 1, x, row_with_value_to_fill, rows, nan_row, nan_col}, row_distances}
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try this:

  potential_donors = Nx.vectorize(x, :rows)
  distances = nan_euclidean(row_with_value_to_fill, nan_col, potential_donors) |> Nx.devectorize()
  row_distances = Nx.indexed_put(distances, [i], Nx.Constants.infinity())

@srzeszut
Copy link
Contributor Author

Thanks for all the comments, I applied your suggested changes to the code.

@josevalim josevalim merged commit c11afad into elixir-nx:main Nov 28, 2024
2 checks passed
@josevalim
Copy link
Contributor

💚 💙 💜 💛 ❤️

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants