Skip to content

Commit

Permalink
Store data in the KDTree
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Nov 5, 2023
1 parent 45adb49 commit 8a4ee41
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 13 deletions.
19 changes: 12 additions & 7 deletions lib/scholar/neighbors/kd_tree.ex
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ defmodule Scholar.Neighbors.KDTree do

import Nx.Defn

@derive {Nx.Container, keep: [:levels], containers: [:indexes]}
@enforce_keys [:levels, :indexes]
defstruct [:levels, :indexes]
@derive {Nx.Container, keep: [:levels], containers: [:indexes, :data]}
@enforce_keys [:levels, :indexes, :data]
defstruct [:levels, :indexes, :data]

@doc """
Builds a KDTree without known min-max bounds.
Expand All @@ -48,6 +48,7 @@ defmodule Scholar.Neighbors.KDTree do
iex> Scholar.Neighbors.KDTree.unbanded(Nx.iota({5, 2}), compiler: EXLA.Defn)
%Scholar.Neighbors.KDTree{
data: Nx.iota({5, 2}),
levels: 3,
indexes: Nx.u32([3, 1, 4, 0, 2])
}
Expand All @@ -69,7 +70,7 @@ defmodule Scholar.Neighbors.KDTree do
Nx.argsort(tensor[[.., 0]], direction: :desc, type: :u32)
end

%__MODULE__{levels: levels, indexes: indexes}
%__MODULE__{levels: levels, indexes: indexes, data: tensor}
end

defp recur([{_i, %Nx.Tensor{shape: {1}} = leaf} | rest], next, acc, tensor, level, levels, opts) do
Expand Down Expand Up @@ -143,11 +144,16 @@ defmodule Scholar.Neighbors.KDTree do
iex> Scholar.Neighbors.KDTree.banded(Nx.iota({5, 2}), 10)
%Scholar.Neighbors.KDTree{
data: Nx.iota({5, 2}),
levels: 3,
indexes: Nx.u32([3, 1, 4, 0, 2])
}
"""
defn banded(tensor, amplitude) do
deftransform banded(tensor, amplitude) do
%__MODULE__{levels: levels(tensor), indexes: banded_n(tensor, amplitude), data: tensor}
end

defnp banded_n(tensor, amplitude) do
levels = levels(tensor)
{size, dims} = Nx.shape(tensor)
band = amplitude + 1
Expand All @@ -162,8 +168,7 @@ defmodule Scholar.Neighbors.KDTree do
end

k = rem(level, dims)
indexes = Nx.argsort(tensor[[.., k]] + band * tags, type: :u32)
%__MODULE__{levels: levels, indexes: indexes}
Nx.argsort(tensor[[.., k]] + band * tags, type: :u32)
end

defnp update_tags(tags, indexes, level, levels, size) do
Expand Down
18 changes: 12 additions & 6 deletions test/scholar/neighbors/kd_tree_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,24 @@ defmodule Scholar.Neighbors.KDTreeTest do
end

test "corner cases" do
assert Scholar.Neighbors.KDTree.unbanded(Nx.iota({1, 2}), compiler: EXLA.Defn) ==
%Scholar.Neighbors.KDTree{levels: 1, indexes: Nx.u32([0])}
assert %Scholar.Neighbors.KDTree{levels: 1, indexes: indexes} =
Scholar.Neighbors.KDTree.unbanded(Nx.iota({1, 2}), compiler: EXLA.Defn)

assert Scholar.Neighbors.KDTree.unbanded(Nx.iota({2, 2}), compiler: EXLA.Defn) ==
%Scholar.Neighbors.KDTree{levels: 2, indexes: Nx.u32([1, 0])}
assert indexes == Nx.u32([0])

assert %Scholar.Neighbors.KDTree{levels: 2, indexes: indexes} =
Scholar.Neighbors.KDTree.unbanded(Nx.iota({2, 2}), compiler: EXLA.Defn)

assert indexes == Nx.u32([1, 0])
end
end

describe "banded" do
test "iota" do
assert Scholar.Neighbors.KDTree.banded(Nx.iota({5, 2}), 10) ==
%Scholar.Neighbors.KDTree{levels: 3, indexes: Nx.u32([3, 1, 4, 0, 2])}
assert %Scholar.Neighbors.KDTree{levels: 3, indexes: indexes} =
Scholar.Neighbors.KDTree.banded(Nx.iota({5, 2}), 10)

assert indexes == Nx.u32([3, 1, 4, 0, 2])
end

test "float" do
Expand Down

0 comments on commit 8a4ee41

Please sign in to comment.