Output distances in kdtree (elixir-nx#264)
* Output distances in kdtree

* Add checks on data in predict
msluszniak authored May 13, 2024
1 parent e0e92d0 commit ebdae8f
Showing 2 changed files with 158 additions and 62 deletions.
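For orientation, a minimal sketch of the API after this change, assembled from the doctests in the diff below (tensor values and option names are taken from those doctests; nothing here goes beyond what the diff shows):

    x = Nx.iota({10, 2})
    x_predict = Nx.tensor([[2, 5], [1, 9], [6, 4]])

    # The neighbor count and metric now live on the tree, fixed at fit time:
    kdtree = Scholar.Neighbors.KDTree.fit(x, num_neighbors: 3, metric: {:minkowski, 1})

    # predict/2 now returns distances alongside the neighbor indices:
    {indices, distances} = Scholar.Neighbors.KDTree.predict(kdtree, x_predict)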
138 changes: 94 additions & 44 deletions lib/scholar/neighbors/kd_tree.ex
@@ -1,6 +1,6 @@
defmodule Scholar.Neighbors.KDTree do
@moduledoc """
Implements a kd-tree, a space-partitioning data structure for organizing points
Implements a k-d tree, a space-partitioning data structure for organizing points
in a k-dimensional space.
It can be used to predict the K-Nearest Neighbors of a given input.
@@ -19,14 +19,13 @@ defmodule Scholar.Neighbors.KDTree do

import Nx.Defn
import Scholar.Shared
alias Scholar.Metrics.Distance

@derive {Nx.Container, keep: [:levels], containers: [:indices, :data]}
@enforce_keys [:levels, :indices, :data]
defstruct [:levels, :indices, :data]
@derive {Nx.Container, keep: [:levels, :num_neighbors, :metric], containers: [:indices, :data]}
@enforce_keys [:levels, :indices, :data, :num_neighbors, :metric]
defstruct [:levels, :indices, :data, :num_neighbors, :metric]

opts = [
k: [
num_neighbors: [
type: :pos_integer,
default: 3,
doc: "The number of neighbors to use by default for `k_neighbors` queries"
@@ -45,22 +44,48 @@
]
]

@predict_schema NimbleOptions.new!(opts)
@opts_schema NimbleOptions.new!(opts)

@doc """
Builds a KDTree.
## Examples
iex> Scholar.Neighbors.KDTree.fit(Nx.iota({5, 2}))
%Scholar.Neighbors.KDTree{
data: Nx.iota({5, 2}),
levels: 3,
indices: Nx.u32([3, 1, 4, 0, 2])
}
iex> tree = Scholar.Neighbors.KDTree.fit(Nx.iota({5, 2}))
iex> tree.data
Nx.tensor(
[
[0, 1],
[2, 3],
[4, 5],
[6, 7],
[8, 9]
]
)
iex> tree.levels
3
iex> tree.indices
Nx.u32([3, 1, 4, 0, 2])
"""
deftransform fit(tensor, _opts \\ []) do
%__MODULE__{levels: levels(tensor), indices: fit_n(tensor), data: tensor}
deftransform fit(tensor, opts \\ []) do
opts = NimbleOptions.validate!(opts, @opts_schema)

metric =
case opts[:metric] do
{:minkowski, p} ->
&Scholar.Metrics.Distance.minkowski(&1, &2, p: p)

:cosine ->
&Scholar.Metrics.Distance.pairwise_cosine/2
end

%__MODULE__{
levels: levels(tensor),
indices: fit_n(tensor),
data: tensor,
num_neighbors: opts[:num_neighbors],
metric: metric
}
end

defnp fit_n(tensor) do
@@ -247,8 +272,9 @@ defmodule Scholar.Neighbors.KDTree do
iex> x = Nx.iota({10, 2})
iex> x_predict = Nx.tensor([[2, 5], [1, 9], [6, 4]])
iex> kdtree = Scholar.Neighbors.KDTree.fit(x)
iex> Scholar.Neighbors.KDTree.predict(kdtree, x_predict, k: 3)
iex> kdtree = Scholar.Neighbors.KDTree.fit(x, num_neighbors: 3)
iex> {indices, distances} = Scholar.Neighbors.KDTree.predict(kdtree, x_predict)
iex> indices
#Nx.Tensor<
s64[3][3]
[
@@ -257,7 +283,21 @@
[2, 3, 1]
]
>
iex> Scholar.Neighbors.KDTree.predict(kdtree, x_predict, k: 3, metric: {:minkowski, 1})
iex> distances
#Nx.Tensor<
f32[3][3]
[
[2.0, 2.0, 4.4721360206604],
[5.0, 5.385164737701416, 6.082762718200684],
[2.2360680103302, 3.0, 4.123105525970459]
]
>
iex> x = Nx.iota({10, 2})
iex> x_predict = Nx.tensor([[2, 5], [1, 9], [6, 4]])
iex> kdtree = Scholar.Neighbors.KDTree.fit(x, num_neighbors: 3, metric: {:minkowski, 1})
iex> {indices, distances} = Scholar.Neighbors.KDTree.predict(kdtree, x_predict)
iex> indices
#Nx.Tensor<
s64[3][3]
[
@@ -266,26 +306,36 @@
[2, 3, 1]
]
>
iex> distances
#Nx.Tensor<
f32[3][3]
[
[2.0, 2.0, 6.0],
[7.0, 7.0, 7.0],
[3.0, 3.0, 5.0]
]
>
"""
deftransform predict(tree, data, opts \\ []) do
predict_n(tree, data, NimbleOptions.validate!(opts, @predict_schema))
deftransform predict(tree, data) do
if Nx.rank(data) != 2 do
raise ArgumentError, "Input data must be a 2D tensor"
end

if Nx.axis_size(data, -1) != Nx.axis_size(tree.data, -1) do
raise ArgumentError, "Input data must have the same number of features as the training data"
end

predict_n(tree, data)
end

defnp sort_by_distances(distances, point_indices) do
indices = Nx.argsort(distances)
{Nx.take(distances, indices), Nx.take(point_indices, indices)}
end

defnp compute_distance(x1, x2, opts) do
case opts[:metric] do
{:minkowski, 2} -> Distance.squared_euclidean(x1, x2)
{:minkowski, p} -> Distance.minkowski(x1, x2, p: p)
:cosine -> Distance.cosine(x1, x2)
end
end

defnp update_knn(nearest_neighbors, distances, data, indices, curr_node, point, k, opts) do
curr_dist = compute_distance(data[[indices[curr_node]]], point, opts)
metric = opts[:metric]
curr_dist = metric.(data[[indices[curr_node]]], point)

if curr_dist < distances[[-1]] do
nearest_neighbors =
@@ -311,8 +361,8 @@
end
end

defnp predict_n(tree, point, opts) do
k = opts[:k]
defnp predict_n(tree, point) do
k = tree.num_neighbors
node = Nx.as_type(root(), :s64)

input_vectorized_axes = point.vectorized_axes
@@ -330,6 +380,7 @@

indices = tree.indices |> Nx.as_type(:s64)
data = tree.data
metric = tree.metric

mode = down()
i = Nx.s64(0)
@@ -345,8 +396,8 @@
point
])

{nearest_neighbors, _} =
while {nearest_neighbors, {node, data, indices, point, distances, visited, i, mode}},
{{nearest_neighbors, distances}, _} =
while {{nearest_neighbors, distances}, {node, data, indices, point, visited, i, mode}},
node != -1 and i >= 0 do
coord_indicator = rem(i, dims)

@@ -384,7 +435,7 @@
indices,
point,
k,
opts
metric: metric
)

{parent(node), i - 1, visited, nearest_neighbors, distances, up()}
@@ -402,14 +453,13 @@
indices,
point,
k,
opts
metric: metric
)

if Nx.any(
compute_distance(
metric.(
point[[coord_indicator]],
data[[indices[right_child(node)], coord_indicator]],
opts
data[[indices[right_child(node)], coord_indicator]]
) <
distances
) do
@@ -431,14 +481,13 @@
indices,
point,
k,
opts
metric: metric
)

if Nx.any(
compute_distance(
metric.(
point[[coord_indicator]],
data[[indices[left_child(node)], coord_indicator]],
opts
data[[indices[left_child(node)], coord_indicator]]
) <
distances
) do
@@ -457,10 +506,11 @@
{node, i, visited, nearest_neighbors, distances, none()}
end

{nearest_neighbors, {node, data, indices, point, distances, visited, i, mode}}
{{nearest_neighbors, distances}, {node, data, indices, point, visited, i, mode}}
end

Nx.revectorize(nearest_neighbors, input_vectorized_axes, target_shape: {num_points, k})
{Nx.revectorize(nearest_neighbors, input_vectorized_axes, target_shape: {num_points, k}),
Nx.revectorize(distances, input_vectorized_axes, target_shape: {num_points, k})}
end

defnp down(), do: Nx.u8(0)
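A note on the mechanism in the file above: fit/2 resolves the :metric option once into an anonymous function stored on the struct ({:minkowski, p} and :cosine are the two forms handled), so the search in predict_n/2 calls tree.metric.(x1, x2) directly instead of threading opts through every defnp as before. A small sketch of that stored function in use, assuming the {:minkowski, 1} form from the doctest:

    tree = Scholar.Neighbors.KDTree.fit(Nx.iota({10, 2}), metric: {:minkowski, 1})

    # tree.metric is an ordinary two-argument function capture, so it can be
    # applied standalone; here it computes the Manhattan distance
    # |2 - 0| + |5 - 1| = 6, which matches an entry in the distances doctest above.
    tree.metric.(Nx.tensor([2, 5]), Nx.tensor([0, 1]))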
82 changes: 64 additions & 18 deletions test/scholar/neighbors/kd_tree_test.exs
@@ -1,5 +1,5 @@
defmodule Scholar.Neighbors.KDTreeTest do
use ExUnit.Case, async: true
use Scholar.Case, async: true
alias Scholar.Neighbors.KDTree
doctest KDTree

@@ -20,18 +20,24 @@ defmodule Scholar.Neighbors.KDTreeTest do

describe "fit" do
test "iota" do
assert %KDTree{levels: 3, indices: indices} = KDTree.fit(Nx.iota({5, 2}))
assert indices == Nx.u32([3, 1, 4, 0, 2])
tree = KDTree.fit(Nx.iota({5, 2}))
assert tree.levels == 3
assert tree.indices == Nx.u32([3, 1, 4, 0, 2])
assert tree.num_neighbors == 3
end

test "float" do
assert %KDTree{levels: 4, indices: indices} = KDTree.fit(example() |> Nx.as_type(:f32))
assert Nx.to_flat_list(indices) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4]
tree = KDTree.fit(Nx.as_type(example(), :f32))
assert tree.levels == 4
assert Nx.to_flat_list(tree.indices) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4]
assert tree.num_neighbors == 3
end

test "sample" do
assert %KDTree{levels: 4, indices: indices} = KDTree.fit(example())
assert Nx.to_flat_list(indices) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4]
tree = KDTree.fit(example())
assert tree.levels == 4
assert Nx.to_flat_list(tree.indices) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4]
assert tree.num_neighbors == 3
end
end

@@ -57,30 +63,70 @@ defmodule Scholar.Neighbors.KDTreeTest do
describe "predict knn" do
test "all defaults" do
kdtree = KDTree.fit(x())
{indices, distances} = KDTree.predict(kdtree, x_pred())

assert KDTree.predict(kdtree, x_pred()) ==
Nx.tensor([[0, 6, 4], [5, 2, 9], [0, 9, 2], [5, 2, 7]])
assert indices == Nx.tensor([[0, 6, 4], [5, 2, 9], [0, 9, 2], [5, 2, 7]])

assert_all_close(
distances,
Nx.tensor([
[3.464101552963257, 4.582575798034668, 4.795831680297852],
[4.242640495300293, 4.690415859222412, 4.795831680297852],
[3.7416574954986572, 5.5677642822265625, 6.0],
[3.872983455657959, 3.872983455657959, 6.164413928985596]
])
)
end

test "metric set to {:minkowski, 1.5}" do
kdtree = KDTree.fit(x())
kdtree = KDTree.fit(x(), metric: {:minkowski, 1.5})
{indices, distances} = KDTree.predict(kdtree, x_pred())

assert indices == Nx.tensor([[0, 6, 2], [5, 2, 9], [0, 9, 2], [5, 2, 7]])

assert KDTree.predict(kdtree, x_pred(), metric: {:minkowski, 1.5}) ==
Nx.tensor([[0, 6, 2], [5, 2, 9], [0, 9, 2], [5, 2, 7]])
assert_all_close(
distances,
Nx.tensor([
[4.065119743347168, 5.191402435302734, 5.862917423248291],
[5.198591709136963, 5.591182708740234, 5.869683265686035],
[4.334622859954834, 6.35192346572876, 6.9637274742126465],
[4.649191856384277, 4.649191856384277, 7.664907932281494]
])
)
end

test "k set to 4" do
kdtree = KDTree.fit(x())
kdtree = KDTree.fit(x(), num_neighbors: 4)
{indices, distances} = KDTree.predict(kdtree, x_pred())

assert indices == Nx.tensor([[0, 6, 4, 2], [5, 2, 9, 0], [0, 9, 2, 5], [5, 2, 7, 4]])

assert KDTree.predict(kdtree, x_pred(), k: 4) ==
Nx.tensor([[0, 6, 4, 2], [5, 2, 9, 0], [0, 9, 2, 5], [5, 2, 7, 4]])
assert_all_close(
distances,
Nx.tensor([
[3.464101552963257, 4.582575798034668, 4.795831680297852, 5.099019527435303],
[4.242640495300293, 4.690415859222412, 4.795831680297852, 7.4833149909973145],
[3.7416574954986572, 5.5677642822265625, 6.0, 6.480740547180176],
[3.872983455657959, 3.872983455657959, 6.164413928985596, 6.78233003616333]
])
)
end

test "float type data" do
kdtree = KDTree.fit(x() |> Nx.as_type(:f64))
kdtree = KDTree.fit(x() |> Nx.as_type(:f64), num_neighbors: 4)
{indices, distances} = KDTree.predict(kdtree, x_pred())

assert indices == Nx.tensor([[0, 6, 4, 2], [5, 2, 9, 0], [0, 9, 2, 5], [5, 2, 7, 4]])

assert KDTree.predict(kdtree, x_pred(), k: 4) ==
Nx.tensor([[0, 6, 4, 2], [5, 2, 9, 0], [0, 9, 2, 5], [5, 2, 7, 4]])
assert_all_close(
distances,
Nx.tensor([
[3.464101552963257, 4.582575798034668, 4.795831680297852, 5.099019527435303],
[4.242640495300293, 4.690415859222412, 4.795831680297852, 7.4833149909973145],
[3.7416574954986572, 5.5677642822265625, 6.0, 6.480740547180176],
[3.872983455657959, 3.872983455657959, 6.164413928985596, 6.78233003616333]
])
)
end
end
end
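The second commit bullet, the input checks added to predict/2, has no dedicated test in this diff. A hedged sketch of what the new guards reject (the error messages are copied from the diff; the calls themselves are hypothetical):

    tree = Scholar.Neighbors.KDTree.fit(Nx.iota({10, 2}), num_neighbors: 3)

    # Rank check: a 1-D tensor now raises
    # ** (ArgumentError) Input data must be a 2D tensor
    Scholar.Neighbors.KDTree.predict(tree, Nx.tensor([1, 2]))

    # Feature check: a column-count mismatch with the training data now raises
    # ** (ArgumentError) Input data must have the same number of features as the training data
    Scholar.Neighbors.KDTree.predict(tree, Nx.iota({3, 5}))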
