diff --git a/lib/scholar/neighbors/kd_tree.ex b/lib/scholar/neighbors/kd_tree.ex index fe15d518..e27473b7 100644 --- a/lib/scholar/neighbors/kd_tree.ex +++ b/lib/scholar/neighbors/kd_tree.ex @@ -29,9 +29,9 @@ defmodule Scholar.Neighbors.KDTree do import Nx.Defn - @derive {Nx.Container, keep: [:levels], containers: [:indexes]} - @enforce_keys [:levels, :indexes] - defstruct [:levels, :indexes] + @derive {Nx.Container, keep: [:levels], containers: [:indexes, :data]} + @enforce_keys [:levels, :indexes, :data] + defstruct [:levels, :indexes, :data] @doc """ Builds a KDTree without known min-max bounds. @@ -48,6 +48,7 @@ defmodule Scholar.Neighbors.KDTree do iex> Scholar.Neighbors.KDTree.unbanded(Nx.iota({5, 2}), compiler: EXLA.Defn) %Scholar.Neighbors.KDTree{ + data: Nx.iota({5, 2}), levels: 3, indexes: Nx.u32([3, 1, 4, 0, 2]) } @@ -69,7 +70,7 @@ defmodule Scholar.Neighbors.KDTree do Nx.argsort(tensor[[.., 0]], direction: :desc, type: :u32) end - %__MODULE__{levels: levels, indexes: indexes} + %__MODULE__{levels: levels, indexes: indexes, data: tensor} end defp recur([{_i, %Nx.Tensor{shape: {1}} = leaf} | rest], next, acc, tensor, level, levels, opts) do @@ -143,11 +144,16 @@ defmodule Scholar.Neighbors.KDTree do iex> Scholar.Neighbors.KDTree.banded(Nx.iota({5, 2}), 10) %Scholar.Neighbors.KDTree{ + data: Nx.iota({5, 2}), levels: 3, indexes: Nx.u32([3, 1, 4, 0, 2]) } """ - defn banded(tensor, amplitude) do + deftransform banded(tensor, amplitude) do + %__MODULE__{levels: levels(tensor), indexes: banded_n(tensor, amplitude), data: tensor} + end + + defnp banded_n(tensor, amplitude) do levels = levels(tensor) {size, dims} = Nx.shape(tensor) band = amplitude + 1 @@ -162,8 +168,7 @@ defmodule Scholar.Neighbors.KDTree do end k = rem(level, dims) - indexes = Nx.argsort(tensor[[.., k]] + band * tags, type: :u32) - %__MODULE__{levels: levels, indexes: indexes} + Nx.argsort(tensor[[.., k]] + band * tags, type: :u32) end defnp update_tags(tags, indexes, level, levels, size) do diff --git a/test/scholar/neighbors/kd_tree_test.exs b/test/scholar/neighbors/kd_tree_test.exs index 66ee0f9f..90835e35 100644 --- a/test/scholar/neighbors/kd_tree_test.exs +++ b/test/scholar/neighbors/kd_tree_test.exs @@ -35,18 +35,24 @@ defmodule Scholar.Neighbors.KDTreeTest do end test "corner cases" do - assert Scholar.Neighbors.KDTree.unbanded(Nx.iota({1, 2}), compiler: EXLA.Defn) == - %Scholar.Neighbors.KDTree{levels: 1, indexes: Nx.u32([0])} + assert %Scholar.Neighbors.KDTree{levels: 1, indexes: indexes} = + Scholar.Neighbors.KDTree.unbanded(Nx.iota({1, 2}), compiler: EXLA.Defn) - assert Scholar.Neighbors.KDTree.unbanded(Nx.iota({2, 2}), compiler: EXLA.Defn) == - %Scholar.Neighbors.KDTree{levels: 2, indexes: Nx.u32([1, 0])} + assert indexes == Nx.u32([0]) + + assert %Scholar.Neighbors.KDTree{levels: 2, indexes: indexes} = + Scholar.Neighbors.KDTree.unbanded(Nx.iota({2, 2}), compiler: EXLA.Defn) + + assert indexes == Nx.u32([1, 0]) end end describe "banded" do test "iota" do - assert Scholar.Neighbors.KDTree.banded(Nx.iota({5, 2}), 10) == - %Scholar.Neighbors.KDTree{levels: 3, indexes: Nx.u32([3, 1, 4, 0, 2])} + assert %Scholar.Neighbors.KDTree{levels: 3, indexes: indexes} = + Scholar.Neighbors.KDTree.banded(Nx.iota({5, 2}), 10) + + assert indexes == Nx.u32([3, 1, 4, 0, 2]) end test "float" do