Skip to content

Commit

Permalink
Properly cast mode to u8 in KDTree
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Apr 25, 2024
1 parent eb63b68 commit 70a85ff
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions lib/scholar/neighbors/kd_tree.ex
Original file line number Diff line number Diff line change
Expand Up @@ -271,10 +271,6 @@ defmodule Scholar.Neighbors.KDTree do
predict_n(tree, data, NimbleOptions.validate!(opts, @predict_schema))
end

defnp predict_n(tree, data, opts) do
query_points(data, tree, opts)
end

defnp sort_by_distances(distances, point_indices) do
indices = Nx.argsort(distances)
{Nx.take(distances, indices), Nx.take(point_indices, indices)}
Expand Down Expand Up @@ -315,7 +311,7 @@ defmodule Scholar.Neighbors.KDTree do
end
end

defnp query_points(point, tree, opts) do
defnp predict_n(tree, point, opts) do
k = opts[:k]
node = Nx.as_type(root(), :s64)

Expand Down Expand Up @@ -458,12 +454,12 @@ defmodule Scholar.Neighbors.KDTree do

# Should be not reachable
true ->
{node, i, visited, nearest_neighbors, distances, -1, down, up}
{node, i, visited, nearest_neighbors, distances, Nx.u8(-1), down, up}
end

# Should be not reachable
true ->
{node, i, visited, nearest_neighbors, distances, -1, down, up}
{node, i, visited, nearest_neighbors, distances, Nx.u8(-1), down, up}
end

{nearest_neighbors, {node, data, indices, point, distances, visited, i, mode, down, up}}
Expand Down

0 comments on commit 70a85ff

Please sign in to comment.