Skip to content

Commit

Permalink
Fix bug with floating point data
Browse files Browse the repository at this point in the history
  • Loading branch information
msluszniak committed Apr 24, 2024
1 parent 0619dc5 commit 62bc90b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 18 deletions.
40 changes: 22 additions & 18 deletions lib/scholar/neighbors/kd_tree.ex
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ defmodule Scholar.Neighbors.KDTree do
"""

import Nx.Defn
import Scholar.Shared
alias Scholar.Metrics.Distance

@derive {Nx.Container, keep: [:levels], containers: [:indices, :data]}
Expand Down Expand Up @@ -327,15 +328,15 @@ defmodule Scholar.Neighbors.KDTree do
)

{size, dims} = Nx.shape(tree.data)
nearest_neighbors = Nx.broadcast(Nx.s64(0), {k})
distances = Nx.broadcast(Nx.Constants.infinity(), {k})
nearest_neighbors = Nx.broadcast(Nx.s64(-1), {k})
distances = Nx.broadcast(Nx.Constants.infinity(to_float_type(tree.data)), {k})
visited = Nx.broadcast(Nx.u8(0), {size})

indices = tree.indices |> Nx.as_type(:s64)
data = tree.data

down = 0
up = 1
down = Nx.u8(0)
up = Nx.u8(1)
mode = down
i = Nx.s64(0)

Expand All @@ -351,27 +352,28 @@ defmodule Scholar.Neighbors.KDTree do
])

{nearest_neighbors, _} =
while {nearest_neighbors, {node, data, indices, point, distances, visited, i, mode}},
while {nearest_neighbors,
{node, data, indices, point, distances, visited, i, mode, down, up}},
node != -1 and i >= 0 do
coord_indicator = rem(i, dims)

{node, i, visited, nearest_neighbors, distances, mode} =
{node, i, visited, nearest_neighbors, distances, mode, down, up} =
cond do
node >= size ->
{parent(node), i - 1, visited, nearest_neighbors, distances, up}
{parent(node), i - 1, visited, nearest_neighbors, distances, up, down, up}

mode == down and
point[[coord_indicator]] < data[[indices[node], coord_indicator]] ->
{left_child(node), i + 1, visited, nearest_neighbors, distances, down}
{left_child(node), i + 1, visited, nearest_neighbors, distances, down, down, up}

mode == down and
point[[coord_indicator]] >= data[[indices[node], coord_indicator]] ->
{right_child(node), i + 1, visited, nearest_neighbors, distances, down}
{right_child(node), i + 1, visited, nearest_neighbors, distances, down, down, up}

mode == up ->
cond do
visited[indices[node]] ->
{parent(node), i - 1, visited, nearest_neighbors, distances, up}
{parent(node), i - 1, visited, nearest_neighbors, distances, up, down, up}

(left_child(node) >= size and right_child(node) >= size) or
(left_child(node) < size and visited[indices[left_child(node)]] and
Expand All @@ -392,7 +394,7 @@ defmodule Scholar.Neighbors.KDTree do
opts
)

{parent(node), i - 1, visited, nearest_neighbors, distances, up}
{parent(node), i - 1, visited, nearest_neighbors, distances, up, down, up}

left_child(node) < size and visited[indices[left_child(node)]] and
right_child(node) < size and
Expand All @@ -418,9 +420,10 @@ defmodule Scholar.Neighbors.KDTree do
) <
distances
) do
{right_child(node), i + 1, visited, nearest_neighbors, distances, down}
{right_child(node), i + 1, visited, nearest_neighbors, distances, down, down,
up}
else
{parent(node), i - 1, visited, nearest_neighbors, distances, up}
{parent(node), i - 1, visited, nearest_neighbors, distances, up, down, up}
end

((right_child(node) < size and visited[indices[right_child(node)]]) or
Expand All @@ -447,22 +450,23 @@ defmodule Scholar.Neighbors.KDTree do
) <
distances
) do
{left_child(node), i + 1, visited, nearest_neighbors, distances, down}
{left_child(node), i + 1, visited, nearest_neighbors, distances, down, down,
up}
else
{parent(node), i - 1, visited, nearest_neighbors, distances, up}
{parent(node), i - 1, visited, nearest_neighbors, distances, up, down, up}
end

# Should not be reachable
true ->
{node, i, visited, nearest_neighbors, distances, mode}
{node, i, visited, nearest_neighbors, distances, -1, down, up}
end

# Should not be reachable
true ->
{node, i, visited, nearest_neighbors, distances, mode}
{node, i, visited, nearest_neighbors, distances, -1, down, up}
end

{nearest_neighbors, {node, data, indices, point, distances, visited, i, mode}}
{nearest_neighbors, {node, data, indices, point, distances, visited, i, mode, down, up}}
end

Nx.revectorize(nearest_neighbors, input_vectorized_axes, target_shape: {num_points, k})
Expand Down
7 changes: 7 additions & 0 deletions test/scholar/neighbors/kd_tree_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -75,5 +75,12 @@ defmodule Scholar.Neighbors.KDTreeTest do
assert KDTree.predict(kdtree, x_pred(), k: 4) ==
Nx.tensor([[0, 6, 4, 2], [5, 2, 9, 0], [0, 9, 2, 5], [5, 2, 7, 4]])
end

test "float type data" do
  # Fit the tree on an f64 copy of the fixture data; the nearest-neighbour
  # indices are expected to match the ones asserted for the default-typed data.
  expected_neighbors =
    Nx.tensor([[0, 6, 4, 2], [5, 2, 9, 0], [0, 9, 2, 5], [5, 2, 7, 4]])

  float_tree =
    x()
    |> Nx.as_type(:f64)
    |> KDTree.fit()

  assert KDTree.predict(float_tree, x_pred(), k: 4) == expected_neighbors
end
end
end

0 comments on commit 62bc90b

Please sign in to comment.