From 830f186051f1d69428356a583f87b42fa0003ecd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Sun, 5 Nov 2023 09:39:15 +0100 Subject: [PATCH 01/12] Unbanded KDTree --- lib/scholar/neighbors/kd_tree.ex | 240 ++++++++++++++++++ .../neighbors/radius_nearest_neighbors.ex | 4 +- test/scholar/neighbors/kd_tree_test.exs | 26 ++ 3 files changed, 269 insertions(+), 1 deletion(-) create mode 100644 lib/scholar/neighbors/kd_tree.ex create mode 100644 test/scholar/neighbors/kd_tree_test.exs diff --git a/lib/scholar/neighbors/kd_tree.ex b/lib/scholar/neighbors/kd_tree.ex new file mode 100644 index 00000000..4f05de90 --- /dev/null +++ b/lib/scholar/neighbors/kd_tree.ex @@ -0,0 +1,240 @@ +defmodule Scholar.Neighbors.KDTree do + @moduledoc """ + Implements a kd-tree, a space-partitioning data structure for organizing points + in a k-dimensional space. + + This is implemented as one-dimensional tensor with indices pointed to highest + dimension of the given tensor. Traversal starts by calling `root/0` and then + accessing the `left_child/1` and `right_child/1`. The tree is left-balanced. + + Two construction modes are available: + + * `banded/2` - the tensor has min and max values with an amplitude given by `max - min`. + It is also guaranteed that the `amplitude * levels(tensor) + 1` does not overflow + the tensor. See `amplitude/1` to verify if this holds. This implementation happens + fully within `defn`. + + * `unbanded/2` - there are no known bands (min and max values) to the tensor. + This implementation is recursive and goes in and out of the `defn`, therefore + it cannot be called inside `defn`. + + ## References + + * [GPU-friendly, Parallel, and (Almost-)In-Place Construction of Left-Balanced k-d Trees](https://arxiv.org/pdf/2211.00120.pdf). + """ + + import Bitwise + import Nx.Defn + + # TODO: Benchmark + # TODO: Add tagged/amplitude version + + @derive {Nx.Container, keep: [:levels], containers: [:indexes]} + @enforce_keys [:levels, :indexes] + defstruct [:levels, :indexes] + + @doc """ + Builds a KDTree without known min-max bounds. + + If your tensor has a known bound (for exmaple, -1 and 1), + consider using the `banded/2` version which is more efficient. + + ## Options + + * `:compiler` - the default compiler to use for internal defn operations + + ## Examples + + iex> Scholar.Neighbors.KDTree.unbanded(Nx.iota({5, 2}), compiler: EXLA.Defn) + %Scholar.Neighbors.KDTree{ + levels: 3, + indexes: Nx.u32([3, 1, 4, 0, 2]) + } + + """ + def unbanded(tensor, opts \\ []) do + levels = levels(tensor) + {size, _dims} = Nx.shape(tensor) + subtree_size = unbanded_subtree_size(1, levels, size) + {left, mid, right} = Nx.Defn.jit_apply(&root_slice(&1, subtree_size), [tensor], opts) + + acc = <> + acc = recur([{1, left}, {2, right}], [], acc, tensor, 1, levels, opts) + %__MODULE__{levels: levels, indexes: Nx.from_binary(acc, :u32)} + end + + defp recur([{_i, %Nx.Tensor{shape: {1}} = leaf} | rest], next, acc, tensor, level, levels, opts) do + [leaf] = Nx.to_flat_list(leaf) + acc = <> + recur(rest, next, acc, tensor, level, levels, opts) + end + + defp recur([{i, %Nx.Tensor{shape: {2}} = node} | rest], next, acc, tensor, level, levels, opts) do + acc = <> + next = [{left_child(i), Nx.slice(node, [0], [1])} | next] + recur(rest, next, acc, tensor, level, levels, opts) + end + + defp recur([{i, indexes} | rest], next, acc, tensor, level, levels, opts) do + %Nx.Tensor{shape: {size, dims}} = tensor + k = rem(level, dims) + subtree_size = unbanded_subtree_size(left_child(i), levels, size) + + {left, mid, right} = + Nx.Defn.jit_apply(&recur_slice(&1, &2, &3, subtree_size), [tensor, indexes, k], opts) + + next = [{right_child(i), right}, {left_child(i), left} | next] + acc = <> + recur(rest, next, acc, tensor, level, levels, opts) + end + + defp recur([], [], acc, _tensor, _level, _levels, _opts) do + acc + end + + defp recur([], next, acc, tensor, level, levels, opts) do + recur(Enum.reverse(next), [], acc, tensor, level + 1, levels, opts) + end + + defp root_slice(tensor, subtree_size) do + indexes = Nx.argsort(tensor[[.., 0]]) + + {Nx.slice(indexes, [0], [subtree_size]), indexes[subtree_size], + Nx.slice(indexes, [subtree_size + 1], [Nx.size(indexes) - subtree_size - 1])} + end + + defp recur_slice(tensor, indexes, k, subtree_size) do + sorted = Nx.argsort(Nx.take(tensor, indexes)[[.., k]]) + indexes = Nx.take(indexes, sorted) + + {Nx.slice(indexes, [0], [subtree_size]), indexes[subtree_size], + Nx.slice(indexes, [subtree_size + 1], [Nx.size(indexes) - subtree_size - 1])} + end + + defp unbanded_subtree_size(i, levels, size) do + diff = levels - unbanded_level(i) - 1 + inner = (1 <<< diff) - 1 + fllc_s = (i <<< diff) + inner + inner + min(max(0, size - fllc_s), 1 <<< diff) + end + + defp unbanded_level(i) when is_integer(i), do: 31 - clz32(i + 1) + + @doc """ + Returns the amplitude of a tensor for banding. + + If -1 is returned, it means the tensor cannot use the `banded` algorithm + to generate a KDTree and `unbanded/2` must be used instead. + + This cannot be invoked inside a `defn`. + + ## Examples + + iex> Scholar.Neighbors.KDTree.amplitude(Nx.iota({10, 2})) + 19 + iex> Scholar.Neighbors.KDTree.amplitude(Nx.iota({20, 2}, type: :f32)) + 39.0 + iex> Scholar.Neighbors.KDTree.amplitude(Nx.iota({20, 2}, type: :u8)) + -1 + + """ + def amplitude(tensor) do + max = tensor |> Nx.reduce_max() |> Nx.to_number() + min = tensor |> Nx.reduce_min() |> Nx.to_number() + amplitude = max - min + limit = tensor.type |> Nx.Constants.max_finite() |> Nx.to_number() + + if max + (amplitude + 1) * (Nx.axis_size(tensor, 0) - 1) > limit do + -1 + else + amplitude + end + end + + @doc """ + Returns the number of resulting levels in a KDTree for `tensor`. + + ## Examples + + iex> Scholar.Neighbors.KDTree.levels(Nx.iota({10, 3})) + 4 + """ + deftransform levels(%Nx.Tensor{} = tensor) do + case Nx.shape(tensor) do + {size, _dims} -> 32 - clz32(size) + _ -> raise ArgumentError, "KDTrees requires a tensor of rank 2" + end + end + + @doc """ + Returns the root index. + + ## Examples + + iex> Scholar.Neighbors.KDTree.root() + 0 + + """ + deftransform root, do: 0 + + @doc """ + Returns the index of the left child of i. + + ## Examples + + iex> Scholar.Neighbors.KDTree.left_child(0) + 1 + iex> Scholar.Neighbors.KDTree.left_child(1) + 3 + + iex> Scholar.Neighbors.KDTree.left_child(Nx.u32(3)) + #Nx.Tensor< + u32 + 7 + > + + """ + deftransform left_child(i) when is_integer(i), do: 2 * i + 1 + deftransform left_child(%Nx.Tensor{} = t), do: Nx.add(Nx.multiply(2, t), 1) + + @doc """ + Returns the index of the right child of i. + + ## Examples + + iex> Scholar.Neighbors.KDTree.right_child(0) + 2 + iex> Scholar.Neighbors.KDTree.right_child(1) + 4 + + iex> Scholar.Neighbors.KDTree.right_child(Nx.u32(3)) + #Nx.Tensor< + u32 + 8 + > + + """ + deftransform right_child(i) when is_integer(i), do: 2 * i + 2 + deftransform right_child(%Nx.Tensor{} = t), do: Nx.add(Nx.multiply(2, t), 2) + + @clz_lookup {32, 31, 30, 30, 29, 29, 29, 29, 28, 28, 28, 28, 28, 28, 28, 28} + + defp clz32(x) when is_integer(x) do + n = + if x >= 1 <<< 16 do + if x >= 1 <<< 24 do + if x >= 1 <<< 28, do: 28, else: 24 + else + if x >= 1 <<< 20, do: 20, else: 16 + end + else + if x >= 1 <<< 8 do + if x >= 1 <<< 12, do: 12, else: 8 + else + if x >= 1 <<< 4, do: 4, else: 0 + end + end + + elem(@clz_lookup, x >>> n) - n + end +end diff --git a/lib/scholar/neighbors/radius_nearest_neighbors.ex b/lib/scholar/neighbors/radius_nearest_neighbors.ex index 66c378b1..b1eca0e3 100644 --- a/lib/scholar/neighbors/radius_nearest_neighbors.ex +++ b/lib/scholar/neighbors/radius_nearest_neighbors.ex @@ -1,6 +1,8 @@ defmodule Scholar.Neighbors.RadiusNearestNeighbors do @moduledoc """ - The Radius Nearest Neighbors. It implements both classification and regression. + The Radius Nearest Neighbors. + + It implements both classification and regression. """ import Nx.Defn import Scholar.Shared diff --git a/test/scholar/neighbors/kd_tree_test.exs b/test/scholar/neighbors/kd_tree_test.exs new file mode 100644 index 00000000..c71596ab --- /dev/null +++ b/test/scholar/neighbors/kd_tree_test.exs @@ -0,0 +1,26 @@ +defmodule Scholar.Neighbors.KDTreeTest do + use ExUnit.Case, async: true + doctest Scholar.Neighbors.KDTree + + defp example do + Nx.tensor([ + [10, 15], + [46, 63], + [68, 21], + [40, 33], + [25, 54], + [15, 43], + [44, 58], + [45, 40], + [62, 69], + [53, 67] + ]) + end + + test "unbanded" do + assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = + Scholar.Neighbors.KDTree.unbanded(example(), compiler: EXLA.Defn) + + assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] + end +end From 386654c5878ee74693f3548565759e8f88a1fe46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Sun, 5 Nov 2023 09:47:02 +0100 Subject: [PATCH 02/12] Handle degenerate cases --- lib/scholar/neighbors/kd_tree.ex | 24 ++++++++++++++++++------ test/scholar/neighbors/kd_tree_test.exs | 18 ++++++++++++++---- 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/lib/scholar/neighbors/kd_tree.ex b/lib/scholar/neighbors/kd_tree.ex index 4f05de90..b86355f4 100644 --- a/lib/scholar/neighbors/kd_tree.ex +++ b/lib/scholar/neighbors/kd_tree.ex @@ -36,7 +36,7 @@ defmodule Scholar.Neighbors.KDTree do @doc """ Builds a KDTree without known min-max bounds. - If your tensor has a known bound (for exmaple, -1 and 1), + If your tensor has a known bound (for example, -1 and 1), consider using the `banded/2` version which is more efficient. ## Options @@ -55,12 +55,20 @@ defmodule Scholar.Neighbors.KDTree do def unbanded(tensor, opts \\ []) do levels = levels(tensor) {size, _dims} = Nx.shape(tensor) - subtree_size = unbanded_subtree_size(1, levels, size) - {left, mid, right} = Nx.Defn.jit_apply(&root_slice(&1, subtree_size), [tensor], opts) - acc = <> - acc = recur([{1, left}, {2, right}], [], acc, tensor, 1, levels, opts) - %__MODULE__{levels: levels, indexes: Nx.from_binary(acc, :u32)} + indexes = + if size > 2 do + subtree_size = unbanded_subtree_size(1, levels, size) + {left, mid, right} = Nx.Defn.jit_apply(&root_slice(&1, subtree_size), [tensor], opts) + + acc = <> + acc = recur([{1, left}, {2, right}], [], acc, tensor, 1, levels, opts) + Nx.from_binary(acc, :u32) + else + degenerate_slice(tensor) + end + + %__MODULE__{levels: levels, indexes: indexes} end defp recur([{_i, %Nx.Tensor{shape: {1}} = leaf} | rest], next, acc, tensor, level, levels, opts) do @@ -103,6 +111,10 @@ defmodule Scholar.Neighbors.KDTree do Nx.slice(indexes, [subtree_size + 1], [Nx.size(indexes) - subtree_size - 1])} end + defp degenerate_slice(tensor) do + Nx.argsort(tensor[[.., 0]], direction: :desc) |> Nx.as_type(:u32) + end + defp recur_slice(tensor, indexes, k, subtree_size) do sorted = Nx.argsort(Nx.take(tensor, indexes)[[.., k]]) indexes = Nx.take(indexes, sorted) diff --git a/test/scholar/neighbors/kd_tree_test.exs b/test/scholar/neighbors/kd_tree_test.exs index c71596ab..280fb16f 100644 --- a/test/scholar/neighbors/kd_tree_test.exs +++ b/test/scholar/neighbors/kd_tree_test.exs @@ -17,10 +17,20 @@ defmodule Scholar.Neighbors.KDTreeTest do ]) end - test "unbanded" do - assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = - Scholar.Neighbors.KDTree.unbanded(example(), compiler: EXLA.Defn) + describe "unbanded" do + test "sample" do + assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = + Scholar.Neighbors.KDTree.unbanded(example(), compiler: EXLA.Defn) - assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] + assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] + end + + test "corner cases" do + assert Scholar.Neighbors.KDTree.unbanded(Nx.iota({1, 2}), compiler: EXLA.Defn) == + %Scholar.Neighbors.KDTree{levels: 1, indexes: Nx.u32([0])} + + assert Scholar.Neighbors.KDTree.unbanded(Nx.iota({2, 2}), compiler: EXLA.Defn) == + %Scholar.Neighbors.KDTree{levels: 2, indexes: Nx.u32([1, 0])} + end end end From 135117196200369062927c22012fe7d0bf246c22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Sun, 5 Nov 2023 14:12:24 +0100 Subject: [PATCH 03/12] banded wip --- lib/scholar/neighbors/kd_tree.ex | 112 +++++++++++++++++++++--- mix.exs | 2 +- mix.lock | 2 +- test/scholar/neighbors/kd_tree_test.exs | 14 +++ 4 files changed, 117 insertions(+), 13 deletions(-) diff --git a/lib/scholar/neighbors/kd_tree.ex b/lib/scholar/neighbors/kd_tree.ex index b86355f4..cf31c6d4 100644 --- a/lib/scholar/neighbors/kd_tree.ex +++ b/lib/scholar/neighbors/kd_tree.ex @@ -23,7 +23,6 @@ defmodule Scholar.Neighbors.KDTree do * [GPU-friendly, Parallel, and (Almost-)In-Place Construction of Left-Balanced k-d Trees](https://arxiv.org/pdf/2211.00120.pdf). """ - import Bitwise import Nx.Defn # TODO: Benchmark @@ -65,7 +64,7 @@ defmodule Scholar.Neighbors.KDTree do acc = recur([{1, left}, {2, right}], [], acc, tensor, 1, levels, opts) Nx.from_binary(acc, :u32) else - degenerate_slice(tensor) + Nx.argsort(tensor[[.., 0]], direction: :desc, type: :u32) end %__MODULE__{levels: levels, indexes: indexes} @@ -105,18 +104,14 @@ defmodule Scholar.Neighbors.KDTree do end defp root_slice(tensor, subtree_size) do - indexes = Nx.argsort(tensor[[.., 0]]) + indexes = Nx.argsort(tensor[[.., 0]], type: :u32) {Nx.slice(indexes, [0], [subtree_size]), indexes[subtree_size], Nx.slice(indexes, [subtree_size + 1], [Nx.size(indexes) - subtree_size - 1])} end - defp degenerate_slice(tensor) do - Nx.argsort(tensor[[.., 0]], direction: :desc) |> Nx.as_type(:u32) - end - defp recur_slice(tensor, indexes, k, subtree_size) do - sorted = Nx.argsort(Nx.take(tensor, indexes)[[.., k]]) + sorted = Nx.argsort(Nx.take(tensor, indexes)[[.., k]], type: :u32) indexes = Nx.take(indexes, sorted) {Nx.slice(indexes, [0], [subtree_size]), indexes[subtree_size], @@ -124,14 +119,107 @@ defmodule Scholar.Neighbors.KDTree do end defp unbanded_subtree_size(i, levels, size) do + import Bitwise diff = levels - unbanded_level(i) - 1 - inner = (1 <<< diff) - 1 - fllc_s = (i <<< diff) + inner - inner + min(max(0, size - fllc_s), 1 <<< diff) + shifted = 1 <<< diff + fllc_s = (i <<< diff) + shifted - 1 + shifted - 1 + min(max(0, size - fllc_s), shifted) end defp unbanded_level(i) when is_integer(i), do: 31 - clz32(i + 1) + @doc """ + BANDED + """ + defn banded(tensor, amplitude) do + levels = levels(tensor) + {size, dims} = Nx.shape(tensor) + band = amplitude + 1 + tags = Nx.broadcast(Nx.u32(0), {size}) + + {_level, tags, _tensor, _band} = + while {level = 0, tags, tensor, band}, level < levels - 1 do + k = rem(level, dims) + indexes = Nx.argsort(tensor[[.., k]] + band * tags, type: :u32) + tags = update_tags(tags, indexes, level, levels, size) + {level + 1, tags, tensor, band} + end + + %__MODULE__{levels: levels, indexes: tags} + end + + defnp update_tags(tags, indexes, level, levels, size) do + # 1 + # indexes = [0, 1, 2, 3, 4] + # tags = [0, 0, 0, 0, 0] + # out = [1, 1, 1, 0, 2] + # + # 2 + # indexes = [3, 0, 1, 2, 4] + # tags = [1, 1, 1, 0, 2] + # out = [3, 1, 4, 0, 2] + # + # out = [3, 1, 4, 0, 2] + pos = Nx.argsort(indexes) |> print_value(label: "POS") + + pivot = + (print_value(banded_segment_begin(tags, levels, size), label: "sb") + + print_value(banded_subtree_size(left_child(tags), levels, size), label: "ss")) + |> print_value(label: "PIVOT") + + Nx.select( + pos < (1 <<< level) - 1, + tags, + Nx.select( + pos < pivot, + left_child(tags), + Nx.select( + pos > pivot, + right_child(tags), + tags + ) + ) + ) + |> print_value(label: "TAGS") + end + + defnp banded_subtree_size(i, levels, size) do + diff = levels - banded_level(i) - 1 + shifted = 1 <<< diff + fllc_s = (i <<< diff) + shifted - 1 + shifted - 1 + min(max(0, size - fllc_s), shifted) + end + + # defnp banded_segment_begin(t, levels, size) do + # while t, j <- 0..(size - 1) do + # s = t[j] + # i = (1 <<< banded_level(s)) - 1 + + # {_, _, acc} = + # while {i, s, acc = i}, i + 1 <= s do + # {i + 1, s, acc + banded_subtree_size(i, levels, size)} + # end + + # Nx.put_slice(t, [j], Nx.stack(acc)) + # end + # end + + # defnp banded_segment_begin(i, levels, size) do + # level = banded_level(i) + # top = (1 <<< level) - 1 + # diff = levels - level - 1 + # shifted = 1 <<< diff + # left_siblings = i - top + + # left_siblings_inner = left_siblings * (shifted - 1) + # left_siblings_last_if_full = left_siblings * (1 <<< (diff - 1)) + + # top + left_siblings_inner + left_siblings_last_if_full - + # min(left_siblings_last_if_full, size - (1 <<< (levels - 1)) - 1) + # end + + defnp banded_level(i), do: 31 - Nx.count_leading_zeros(i + 1) + @doc """ Returns the amplitude of a tensor for banding. @@ -232,6 +320,8 @@ defmodule Scholar.Neighbors.KDTree do @clz_lookup {32, 31, 30, 30, 29, 29, 29, 29, 28, 28, 28, 28, 28, 28, 28, 28} defp clz32(x) when is_integer(x) do + import Bitwise + n = if x >= 1 <<< 16 do if x >= 1 <<< 24 do diff --git a/mix.exs b/mix.exs index c439392e..55d286fa 100644 --- a/mix.exs +++ b/mix.exs @@ -31,7 +31,7 @@ defmodule Scholar.MixProject do [ {:ex_doc, "~> 0.30", only: :docs}, # {:nx, "~> 0.6", override: true}, - {:nx, github: "elixir-nx/nx", sparse: "nx", override: true}, + {:nx, github: "elixir-nx/nx", sparse: "nx", override: true, branch: "v0.6"}, {:nimble_options, "~> 0.5.2 or ~> 1.0"}, {:exla, "~> 0.6", optional: true}, {:polaris, "~> 0.1"} diff --git a/mix.lock b/mix.lock index cb415b64..5087bf8d 100644 --- a/mix.lock +++ b/mix.lock @@ -9,7 +9,7 @@ "makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"}, "nimble_options": {:hex, :nimble_options, "0.5.2", "42703307b924880f8c08d97719da7472673391905f528259915782bb346e0a1b", [:mix], [], "hexpm", "4da7f904b915fd71db549bcdc25f8d56f378ef7ae07dc1d372cbe72ba950dce0"}, "nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"}, - "nx": {:git, "https://github.com/elixir-nx/nx.git", "a0b7e2e5cc7a62a55cd2e7bbc3e44ba2ac1c996b", [sparse: "nx"]}, + "nx": {:git, "https://github.com/elixir-nx/nx.git", "e52d9097a52ae39c1ece1dcc2c12ad6456fc0fe2", [sparse: "nx", branch: "v0.6"]}, "polaris": {:hex, :polaris, "0.1.0", "dca61b18e3e801ecdae6ac9f0eca5f19792b44a5cb4b8d63db50fc40fc038d22", [:mix], [{:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "13ef2b166650e533cb24b10e2f3b8ab4f2f449ba4d63156e8c569527f206e2c2"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, "xla": {:hex, :xla, "0.5.0", "fb8a02c02e5a4f4531fbf18a90c325e471037f983f0115d23f510e7dd9a6aa65", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "571ac797a4244b8ba8552ed0295a54397bd896708be51e4da6cbb784f6678061"}, diff --git a/test/scholar/neighbors/kd_tree_test.exs b/test/scholar/neighbors/kd_tree_test.exs index 280fb16f..41deaba2 100644 --- a/test/scholar/neighbors/kd_tree_test.exs +++ b/test/scholar/neighbors/kd_tree_test.exs @@ -33,4 +33,18 @@ defmodule Scholar.Neighbors.KDTreeTest do %Scholar.Neighbors.KDTree{levels: 2, indexes: Nx.u32([1, 0])} end end + + describe "banded" do + test "iota" do + assert Scholar.Neighbors.KDTree.banded(Nx.iota({5, 2}), 100) == + %Scholar.Neighbors.KDTree{levels: 3, indexes: Nx.u32([3, 1, 4, 0, 2])} + end + + test "sample" do + assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = + Scholar.Neighbors.KDTree.unbanded(example(), compiler: EXLA.Defn) + + assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] + end + end end From 24aca45d5cbfad206cc3e34099723cf6ed26fbf2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Sun, 5 Nov 2023 17:16:02 +0100 Subject: [PATCH 04/12] works --- lib/scholar/neighbors/kd_tree.ex | 69 +++++++++++-------------- test/scholar/neighbors/kd_tree_test.exs | 11 +++- 2 files changed, 39 insertions(+), 41 deletions(-) diff --git a/lib/scholar/neighbors/kd_tree.ex b/lib/scholar/neighbors/kd_tree.ex index cf31c6d4..d53ee1ff 100644 --- a/lib/scholar/neighbors/kd_tree.ex +++ b/lib/scholar/neighbors/kd_tree.ex @@ -137,35 +137,25 @@ defmodule Scholar.Neighbors.KDTree do band = amplitude + 1 tags = Nx.broadcast(Nx.u32(0), {size}) - {_level, tags, _tensor, _band} = - while {level = 0, tags, tensor, band}, level < levels - 1 do + {level, tags, _tensor, _band} = + while {level = Nx.u32(0), tags, tensor, band}, level < levels - 1 do k = rem(level, dims) indexes = Nx.argsort(tensor[[.., k]] + band * tags, type: :u32) tags = update_tags(tags, indexes, level, levels, size) {level + 1, tags, tensor, band} end - %__MODULE__{levels: levels, indexes: tags} + k = rem(level, dims) + indexes = Nx.argsort(tensor[[.., k]] + band * tags, type: :u32) + %__MODULE__{levels: levels, indexes: indexes} end defnp update_tags(tags, indexes, level, levels, size) do - # 1 - # indexes = [0, 1, 2, 3, 4] - # tags = [0, 0, 0, 0, 0] - # out = [1, 1, 1, 0, 2] - # - # 2 - # indexes = [3, 0, 1, 2, 4] - # tags = [1, 1, 1, 0, 2] - # out = [3, 1, 4, 0, 2] - # - # out = [3, 1, 4, 0, 2] - pos = Nx.argsort(indexes) |> print_value(label: "POS") + pos = Nx.argsort(indexes, type: :u32) pivot = - (print_value(banded_segment_begin(tags, levels, size), label: "sb") + - print_value(banded_subtree_size(left_child(tags), levels, size), label: "ss")) - |> print_value(label: "PIVOT") + banded_segment_begin(tags, levels, size) + + banded_subtree_size(left_child(tags), levels, size) Nx.select( pos < (1 <<< level) - 1, @@ -180,45 +170,46 @@ defmodule Scholar.Neighbors.KDTree do ) ) ) - |> print_value(label: "TAGS") end defnp banded_subtree_size(i, levels, size) do diff = levels - banded_level(i) - 1 shifted = 1 <<< diff - fllc_s = (i <<< diff) + shifted - 1 - shifted - 1 + min(max(0, size - fllc_s), shifted) + first_lowest_level = (i <<< diff) + shifted - 1 + # Use select instead of max to deal with overflows + lowest_level = Nx.select(first_lowest_level > size, Nx.u32(0), size - first_lowest_level) + shifted - 1 + min(lowest_level, shifted) end - # defnp banded_segment_begin(t, levels, size) do - # while t, j <- 0..(size - 1) do - # s = t[j] - # i = (1 <<< banded_level(s)) - 1 + defn banded_segment_begin(t, levels, size) do + while t, j <- 0..(size - 1) do + s = t[j] + i = (1 <<< banded_level(s)) - 1 - # {_, _, acc} = - # while {i, s, acc = i}, i + 1 <= s do - # {i + 1, s, acc + banded_subtree_size(i, levels, size)} - # end + {_, _, acc} = + while {i, s, acc = i}, i + 1 <= s do + {i + 1, s, acc + banded_subtree_size(i, levels, size)} + end - # Nx.put_slice(t, [j], Nx.stack(acc)) - # end - # end + Nx.put_slice(t, [j], Nx.stack(acc)) + end + end - # defnp banded_segment_begin(i, levels, size) do + # defn banded_segment_begin(i, levels, size) do # level = banded_level(i) # top = (1 <<< level) - 1 # diff = levels - level - 1 # shifted = 1 <<< diff # left_siblings = i - top - # left_siblings_inner = left_siblings * (shifted - 1) - # left_siblings_last_if_full = left_siblings * (1 <<< (diff - 1)) - - # top + left_siblings_inner + left_siblings_last_if_full - - # min(left_siblings_last_if_full, size - (1 <<< (levels - 1)) - 1) + # top + left_siblings * (shifted - 1) + + # min(left_siblings * shifted, size - (1 <<< (levels - 1)) - 1) # end - defnp banded_level(i), do: 31 - Nx.count_leading_zeros(i + 1) + # Since this property relies on u32, let's check the tensor type. + deftransformp banded_level(%Nx.Tensor{type: {:u, 32}} = i) do + Nx.subtract(31, Nx.count_leading_zeros(Nx.add(i, 1))) + end @doc """ Returns the amplitude of a tensor for banding. diff --git a/test/scholar/neighbors/kd_tree_test.exs b/test/scholar/neighbors/kd_tree_test.exs index 41deaba2..a8678c66 100644 --- a/test/scholar/neighbors/kd_tree_test.exs +++ b/test/scholar/neighbors/kd_tree_test.exs @@ -36,13 +36,20 @@ defmodule Scholar.Neighbors.KDTreeTest do describe "banded" do test "iota" do - assert Scholar.Neighbors.KDTree.banded(Nx.iota({5, 2}), 100) == + assert Scholar.Neighbors.KDTree.banded(Nx.iota({5, 2}), 10) == %Scholar.Neighbors.KDTree{levels: 3, indexes: Nx.u32([3, 1, 4, 0, 2])} end test "sample" do + input = Nx.u32([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + + assert Nx.Defn.jit_apply( + &Scholar.Neighbors.KDTree.banded_segment_begin(&1, 4, 10), + [input] + ) == Nx.u32([0, 1, 7, 3, 6, 8, 9, 7, 8, 9]) + assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = - Scholar.Neighbors.KDTree.unbanded(example(), compiler: EXLA.Defn) + Scholar.Neighbors.KDTree.banded(example(), 100) assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] end From fd4d1d943adbb119a0e9efe5e477874d8e9c21f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Sun, 5 Nov 2023 17:31:29 +0100 Subject: [PATCH 05/12] Banded works --- lib/scholar/neighbors/kd_tree.ex | 31 ++++++++----------------------- 1 file changed, 8 insertions(+), 23 deletions(-) diff --git a/lib/scholar/neighbors/kd_tree.ex b/lib/scholar/neighbors/kd_tree.ex index d53ee1ff..01e2202d 100644 --- a/lib/scholar/neighbors/kd_tree.ex +++ b/lib/scholar/neighbors/kd_tree.ex @@ -26,7 +26,6 @@ defmodule Scholar.Neighbors.KDTree do import Nx.Defn # TODO: Benchmark - # TODO: Add tagged/amplitude version @derive {Nx.Container, keep: [:levels], containers: [:indexes]} @enforce_keys [:levels, :indexes] @@ -181,31 +180,17 @@ defmodule Scholar.Neighbors.KDTree do shifted - 1 + min(lowest_level, shifted) end - defn banded_segment_begin(t, levels, size) do - while t, j <- 0..(size - 1) do - s = t[j] - i = (1 <<< banded_level(s)) - 1 - - {_, _, acc} = - while {i, s, acc = i}, i + 1 <= s do - {i + 1, s, acc + banded_subtree_size(i, levels, size)} - end + defn banded_segment_begin(i, levels, size) do + level = banded_level(i) + top = (1 <<< level) - 1 + diff = levels - level - 1 + shifted = 1 <<< diff + left_siblings = i - top - Nx.put_slice(t, [j], Nx.stack(acc)) - end + top + left_siblings * (shifted - 1) + + min(left_siblings * shifted, size - (1 <<< (levels - 1)) + 1) end - # defn banded_segment_begin(i, levels, size) do - # level = banded_level(i) - # top = (1 <<< level) - 1 - # diff = levels - level - 1 - # shifted = 1 <<< diff - # left_siblings = i - top - - # top + left_siblings * (shifted - 1) + - # min(left_siblings * shifted, size - (1 <<< (levels - 1)) - 1) - # end - # Since this property relies on u32, let's check the tensor type. deftransformp banded_level(%Nx.Tensor{type: {:u, 32}} = i) do Nx.subtract(31, Nx.count_leading_zeros(Nx.add(i, 1))) From 9a578276e7f286b262a4d0c70386b784d6384f04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Sun, 5 Nov 2023 19:13:16 +0100 Subject: [PATCH 06/12] Add benchmarks --- benchmarks/kdtree.exs | 14 ++++++ lib/scholar/neighbors/kd_tree.ex | 63 +++++++++++++++++++++---- mix.exs | 3 +- mix.lock | 5 +- test/scholar/neighbors/kd_tree_test.exs | 21 ++++++--- 5 files changed, 90 insertions(+), 16 deletions(-) create mode 100644 benchmarks/kdtree.exs diff --git a/benchmarks/kdtree.exs b/benchmarks/kdtree.exs new file mode 100644 index 00000000..dd437c7f --- /dev/null +++ b/benchmarks/kdtree.exs @@ -0,0 +1,14 @@ +Nx.global_default_backend(EXLA.Backend) +Nx.Defn.global_default_options(compiler: EXLA) + +key = Nx.Random.key(System.os_time()) +{uniform, _new_key} = Nx.Random.uniform(key, shape: {1000, 3}) + +Benchee.run( + %{ + "unbanded" => fn -> Scholar.Neighbors.KDTree.unbanded(uniform) end, + "banded" => fn -> Scholar.Neighbors.KDTree.banded(uniform, 2) end + }, + time: 10, + memory_time: 2 +) diff --git a/lib/scholar/neighbors/kd_tree.ex b/lib/scholar/neighbors/kd_tree.ex index 01e2202d..fe15d518 100644 --- a/lib/scholar/neighbors/kd_tree.ex +++ b/lib/scholar/neighbors/kd_tree.ex @@ -12,12 +12,16 @@ defmodule Scholar.Neighbors.KDTree do * `banded/2` - the tensor has min and max values with an amplitude given by `max - min`. It is also guaranteed that the `amplitude * levels(tensor) + 1` does not overflow the tensor. See `amplitude/1` to verify if this holds. This implementation happens - fully within `defn`. + fully within `defn`. This version is orders of magnitude faster than the `unbanded/2` + one. * `unbanded/2` - there are no known bands (min and max values) to the tensor. This implementation is recursive and goes in and out of the `defn`, therefore it cannot be called inside `defn`. + Each level traverses over the last axis of tensor, the index for a level can be + computed as: `rem(level, Nx.axis_size(tensor, -1))`. + ## References * [GPU-friendly, Parallel, and (Almost-)In-Place Construction of Left-Balanced k-d Trees](https://arxiv.org/pdf/2211.00120.pdf). @@ -25,8 +29,6 @@ defmodule Scholar.Neighbors.KDTree do import Nx.Defn - # TODO: Benchmark - @derive {Nx.Container, keep: [:levels], containers: [:indexes]} @enforce_keys [:levels, :indexes] defstruct [:levels, :indexes] @@ -34,8 +36,9 @@ defmodule Scholar.Neighbors.KDTree do @doc """ Builds a KDTree without known min-max bounds. - If your tensor has a known bound (for example, -1 and 1), - consider using the `banded/2` version which is more efficient. + If your tensor has a known band (for example, -1 and 1), + consider using the `banded/2` version which is often orders of + magnitude more efficient. ## Options @@ -128,7 +131,21 @@ defmodule Scholar.Neighbors.KDTree do defp unbanded_level(i) when is_integer(i), do: 31 - clz32(i + 1) @doc """ - BANDED + Builds a KDTree with known min-max bounds entirely within `defn`. + + This requires the amplitude `|max - min|` of the tensor to be given. + For example, a tensor where all values are between 0 and 1 has amplitude + 1. Values between -1 and 1 has amplitude 2. If your tensor is normalized, + then you know the amplitude. Otherwise you can use `amplitude/1` to check + it. + + ## Examples + + iex> Scholar.Neighbors.KDTree.banded(Nx.iota({5, 2}), 10) + %Scholar.Neighbors.KDTree{ + levels: 3, + indexes: Nx.u32([3, 1, 4, 0, 2]) + } """ defn banded(tensor, amplitude) do levels = levels(tensor) @@ -180,7 +197,7 @@ defmodule Scholar.Neighbors.KDTree do shifted - 1 + min(lowest_level, shifted) end - defn banded_segment_begin(i, levels, size) do + defnp banded_segment_begin(i, levels, size) do level = banded_level(i) top = (1 <<< level) - 1 diff = levels - level - 1 @@ -212,12 +229,14 @@ defmodule Scholar.Neighbors.KDTree do 39.0 iex> Scholar.Neighbors.KDTree.amplitude(Nx.iota({20, 2}, type: :u8)) -1 + iex> Scholar.Neighbors.KDTree.amplitude(Nx.negate(Nx.iota({10, 2}))) + 19 """ def amplitude(tensor) do max = tensor |> Nx.reduce_max() |> Nx.to_number() min = tensor |> Nx.reduce_min() |> Nx.to_number() - amplitude = max - min + amplitude = abs(max - min) limit = tensor.type |> Nx.Constants.max_finite() |> Nx.to_number() if max + (amplitude + 1) * (Nx.axis_size(tensor, 0) - 1) > limit do @@ -253,9 +272,34 @@ defmodule Scholar.Neighbors.KDTree do """ deftransform root, do: 0 + @doc """ + Returns the parent of child `i`. + + It is your responsibility to guarantee the result is positive. + + ## Examples + + iex> Scholar.Neighbors.KDTree.parent(1) + 0 + iex> Scholar.Neighbors.KDTree.parent(2) + 0 + + iex> Scholar.Neighbors.KDTree.parent(Nx.u32(3)) + #Nx.Tensor< + u32 + 1 + > + + """ + deftransform parent(i) when is_integer(i), do: div(i - 1, 2) + deftransform parent(%Nx.Tensor{} = t), do: Nx.quotient(Nx.subtract(t, 1), 2) + @doc """ Returns the index of the left child of i. + It is your responsibility to guarantee the result + is not greater than the leading axis of the tensor. + ## Examples iex> Scholar.Neighbors.KDTree.left_child(0) @@ -276,6 +320,9 @@ defmodule Scholar.Neighbors.KDTree do @doc """ Returns the index of the right child of i. + It is your responsibility to guarantee the result + is not greater than the leading axis of the tensor. + ## Examples iex> Scholar.Neighbors.KDTree.right_child(0) diff --git a/mix.exs b/mix.exs index 55d286fa..02acd3bb 100644 --- a/mix.exs +++ b/mix.exs @@ -34,7 +34,8 @@ defmodule Scholar.MixProject do {:nx, github: "elixir-nx/nx", sparse: "nx", override: true, branch: "v0.6"}, {:nimble_options, "~> 0.5.2 or ~> 1.0"}, {:exla, "~> 0.6", optional: true}, - {:polaris, "~> 0.1"} + {:polaris, "~> 0.1"}, + {:benchee, "~> 1.0", only: :dev} ] end diff --git a/mix.lock b/mix.lock index 5087bf8d..c17916d2 100644 --- a/mix.lock +++ b/mix.lock @@ -1,5 +1,7 @@ %{ + "benchee": {:hex, :benchee, "1.1.0", "f3a43817209a92a1fade36ef36b86e1052627fd8934a8b937ac9ab3a76c43062", [:mix], [{:deep_merge, "~> 1.0", [hex: :deep_merge, repo: "hexpm", optional: false]}, {:statistex, "~> 1.0", [hex: :statistex, repo: "hexpm", optional: false]}], "hexpm", "7da57d545003165a012b587077f6ba90b89210fd88074ce3c60ce239eb5e6d93"}, "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, + "deep_merge": {:hex, :deep_merge, "1.0.0", "b4aa1a0d1acac393bdf38b2291af38cb1d4a52806cf7a4906f718e1feb5ee961", [:mix], [], "hexpm", "ce708e5f094b9cd4e8f2be4f00d2f4250c4095be93f8cd6d018c753894885430"}, "earmark_parser": {:hex, :earmark_parser, "1.4.37", "2ad73550e27c8946648b06905a57e4d454e4d7229c2dafa72a0348c99d8be5f7", [:mix], [], "hexpm", "6b19783f2802f039806f375610faa22da130b8edc21209d0bff47918bb48360e"}, "elixir_make": {:hex, :elixir_make, "0.7.7", "7128c60c2476019ed978210c245badf08b03dbec4f24d05790ef791da11aa17c", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "5bc19fff950fad52bbe5f211b12db9ec82c6b34a9647da0c2224b8b8464c7e6c"}, "ex_doc": {:hex, :ex_doc, "0.30.6", "5f8b54854b240a2b55c9734c4b1d0dd7bdd41f71a095d42a70445c03cf05a281", [:mix], [{:earmark_parser, "~> 1.4.31", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "bd48f2ddacf4e482c727f9293d9498e0881597eae6ddc3d9562bd7923375109f"}, @@ -9,8 +11,9 @@ "makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"}, "nimble_options": {:hex, :nimble_options, "0.5.2", "42703307b924880f8c08d97719da7472673391905f528259915782bb346e0a1b", [:mix], [], "hexpm", "4da7f904b915fd71db549bcdc25f8d56f378ef7ae07dc1d372cbe72ba950dce0"}, "nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"}, - "nx": {:git, "https://github.com/elixir-nx/nx.git", "e52d9097a52ae39c1ece1dcc2c12ad6456fc0fe2", [sparse: "nx", branch: "v0.6"]}, + "nx": {:git, "https://github.com/elixir-nx/nx.git", "ef464cfd0935eb4c2c1fa9a40f099b098a0b95bf", [sparse: "nx", branch: "v0.6"]}, "polaris": {:hex, :polaris, "0.1.0", "dca61b18e3e801ecdae6ac9f0eca5f19792b44a5cb4b8d63db50fc40fc038d22", [:mix], [{:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "13ef2b166650e533cb24b10e2f3b8ab4f2f449ba4d63156e8c569527f206e2c2"}, + "statistex": {:hex, :statistex, "1.0.0", "f3dc93f3c0c6c92e5f291704cf62b99b553253d7969e9a5fa713e5481cd858a5", [:mix], [], "hexpm", "ff9d8bee7035028ab4742ff52fc80a2aa35cece833cf5319009b52f1b5a86c27"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, "xla": {:hex, :xla, "0.5.0", "fb8a02c02e5a4f4531fbf18a90c325e471037f983f0115d23f510e7dd9a6aa65", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "571ac797a4244b8ba8552ed0295a54397bd896708be51e4da6cbb784f6678061"}, } diff --git a/test/scholar/neighbors/kd_tree_test.exs b/test/scholar/neighbors/kd_tree_test.exs index a8678c66..66ee0f9f 100644 --- a/test/scholar/neighbors/kd_tree_test.exs +++ b/test/scholar/neighbors/kd_tree_test.exs @@ -25,6 +25,15 @@ defmodule Scholar.Neighbors.KDTreeTest do assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] end + test "float" do + assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = + Scholar.Neighbors.KDTree.unbanded(example() |> Nx.as_type(:f32), + compiler: EXLA.Defn + ) + + assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] + end + test "corner cases" do assert Scholar.Neighbors.KDTree.unbanded(Nx.iota({1, 2}), compiler: EXLA.Defn) == %Scholar.Neighbors.KDTree{levels: 1, indexes: Nx.u32([0])} @@ -40,14 +49,14 @@ defmodule Scholar.Neighbors.KDTreeTest do %Scholar.Neighbors.KDTree{levels: 3, indexes: Nx.u32([3, 1, 4, 0, 2])} end - test "sample" do - input = Nx.u32([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + test "float" do + assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = + Scholar.Neighbors.KDTree.banded(example() |> Nx.as_type(:f32), 100) - assert Nx.Defn.jit_apply( - &Scholar.Neighbors.KDTree.banded_segment_begin(&1, 4, 10), - [input] - ) == Nx.u32([0, 1, 7, 3, 6, 8, 9, 7, 8, 9]) + assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] + end + test "sample" do assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = Scholar.Neighbors.KDTree.banded(example(), 100) From a66037933867f1ef548ab93093e5d764ca6a080f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Sun, 5 Nov 2023 19:15:12 +0100 Subject: [PATCH 07/12] Rename --- benchmarks/{kdtree.exs => kd_tree.exs} | 1 + 1 file changed, 1 insertion(+) rename benchmarks/{kdtree.exs => kd_tree.exs} (92%) diff --git a/benchmarks/kdtree.exs b/benchmarks/kd_tree.exs similarity index 92% rename from benchmarks/kdtree.exs rename to benchmarks/kd_tree.exs index dd437c7f..1552ac2d 100644 --- a/benchmarks/kdtree.exs +++ b/benchmarks/kd_tree.exs @@ -1,3 +1,4 @@ +# mix run benchmarks/kd_tree.exs Nx.global_default_backend(EXLA.Backend) Nx.Defn.global_default_options(compiler: EXLA) From 45adb49a2cf6231ce1e8db622dfb102cb0325353 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Sun, 5 Nov 2023 19:17:45 +0100 Subject: [PATCH 08/12] More --- test/test_helper.exs | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_helper.exs b/test/test_helper.exs index e69de29b..5e96292b 100644 --- a/test/test_helper.exs +++ b/test/test_helper.exs @@ -0,0 +1 @@ +Application.ensure_all_started(:exla) From 8a4ee41509f569fc6e09f7a6cf8121bae9aed945 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Sun, 5 Nov 2023 22:58:48 +0100 Subject: [PATCH 09/12] Store data in the KDTree --- lib/scholar/neighbors/kd_tree.ex | 19 ++++++++++++------- test/scholar/neighbors/kd_tree_test.exs | 18 ++++++++++++------ 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/lib/scholar/neighbors/kd_tree.ex b/lib/scholar/neighbors/kd_tree.ex index fe15d518..e27473b7 100644 --- a/lib/scholar/neighbors/kd_tree.ex +++ b/lib/scholar/neighbors/kd_tree.ex @@ -29,9 +29,9 @@ defmodule Scholar.Neighbors.KDTree do import Nx.Defn - @derive {Nx.Container, keep: [:levels], containers: [:indexes]} - @enforce_keys [:levels, :indexes] - defstruct [:levels, :indexes] + @derive {Nx.Container, keep: [:levels], containers: [:indexes, :data]} + @enforce_keys [:levels, :indexes, :data] + defstruct [:levels, :indexes, :data] @doc """ Builds a KDTree without known min-max bounds. @@ -48,6 +48,7 @@ defmodule Scholar.Neighbors.KDTree do iex> Scholar.Neighbors.KDTree.unbanded(Nx.iota({5, 2}), compiler: EXLA.Defn) %Scholar.Neighbors.KDTree{ + data: Nx.iota({5, 2}), levels: 3, indexes: Nx.u32([3, 1, 4, 0, 2]) } @@ -69,7 +70,7 @@ defmodule Scholar.Neighbors.KDTree do Nx.argsort(tensor[[.., 0]], direction: :desc, type: :u32) end - %__MODULE__{levels: levels, indexes: indexes} + %__MODULE__{levels: levels, indexes: indexes, data: tensor} end defp recur([{_i, %Nx.Tensor{shape: {1}} = leaf} | rest], next, acc, tensor, level, levels, opts) do @@ -143,11 +144,16 @@ defmodule Scholar.Neighbors.KDTree do iex> Scholar.Neighbors.KDTree.banded(Nx.iota({5, 2}), 10) %Scholar.Neighbors.KDTree{ + data: Nx.iota({5, 2}), levels: 3, indexes: Nx.u32([3, 1, 4, 0, 2]) } """ - defn banded(tensor, amplitude) do + deftransform banded(tensor, amplitude) do + %__MODULE__{levels: levels(tensor), indexes: banded_n(tensor, amplitude), data: tensor} + end + + defnp banded_n(tensor, amplitude) do levels = levels(tensor) {size, dims} = Nx.shape(tensor) band = amplitude + 1 @@ -162,8 +168,7 @@ defmodule Scholar.Neighbors.KDTree do end k = rem(level, dims) - indexes = Nx.argsort(tensor[[.., k]] + band * tags, type: :u32) - %__MODULE__{levels: levels, indexes: indexes} + Nx.argsort(tensor[[.., k]] + band * tags, type: :u32) end defnp update_tags(tags, indexes, level, levels, size) do diff --git a/test/scholar/neighbors/kd_tree_test.exs b/test/scholar/neighbors/kd_tree_test.exs index 66ee0f9f..90835e35 100644 --- a/test/scholar/neighbors/kd_tree_test.exs +++ b/test/scholar/neighbors/kd_tree_test.exs @@ -35,18 +35,24 @@ defmodule Scholar.Neighbors.KDTreeTest do end test "corner cases" do - assert Scholar.Neighbors.KDTree.unbanded(Nx.iota({1, 2}), compiler: EXLA.Defn) == - %Scholar.Neighbors.KDTree{levels: 1, indexes: Nx.u32([0])} + assert %Scholar.Neighbors.KDTree{levels: 1, indexes: indexes} = + Scholar.Neighbors.KDTree.unbanded(Nx.iota({1, 2}), compiler: EXLA.Defn) - assert Scholar.Neighbors.KDTree.unbanded(Nx.iota({2, 2}), compiler: EXLA.Defn) == - %Scholar.Neighbors.KDTree{levels: 2, indexes: Nx.u32([1, 0])} + assert indexes == Nx.u32([0]) + + assert %Scholar.Neighbors.KDTree{levels: 2, indexes: indexes} = + Scholar.Neighbors.KDTree.unbanded(Nx.iota({2, 2}), compiler: EXLA.Defn) + + assert indexes == Nx.u32([1, 0]) end end describe "banded" do test "iota" do - assert Scholar.Neighbors.KDTree.banded(Nx.iota({5, 2}), 10) == - %Scholar.Neighbors.KDTree{levels: 3, indexes: Nx.u32([3, 1, 4, 0, 2])} + assert %Scholar.Neighbors.KDTree{levels: 3, indexes: indexes} = + Scholar.Neighbors.KDTree.banded(Nx.iota({5, 2}), 10) + + assert indexes == Nx.u32([3, 1, 4, 0, 2]) end test "float" do From ba86a276674ea4aa32b1e2e286c770275bf6f79b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Mon, 6 Nov 2023 09:13:59 +0100 Subject: [PATCH 10/12] band -> bound, clz -> log --- benchmarks/kd_tree.exs | 4 +- lib/scholar/neighbors/kd_tree.ex | 88 ++++++++++--------------- test/scholar/neighbors/kd_tree_test.exs | 18 ++--- 3 files changed, 45 insertions(+), 65 deletions(-) diff --git a/benchmarks/kd_tree.exs b/benchmarks/kd_tree.exs index 1552ac2d..74dc3a33 100644 --- a/benchmarks/kd_tree.exs +++ b/benchmarks/kd_tree.exs @@ -7,8 +7,8 @@ key = Nx.Random.key(System.os_time()) Benchee.run( %{ - "unbanded" => fn -> Scholar.Neighbors.KDTree.unbanded(uniform) end, - "banded" => fn -> Scholar.Neighbors.KDTree.banded(uniform, 2) end + "unbound" => fn -> Scholar.Neighbors.KDTree.unbound(uniform) end, + "bound" => fn -> Scholar.Neighbors.KDTree.bound(uniform, 2) end }, time: 10, memory_time: 2 diff --git a/lib/scholar/neighbors/kd_tree.ex b/lib/scholar/neighbors/kd_tree.ex index e27473b7..b0398132 100644 --- a/lib/scholar/neighbors/kd_tree.ex +++ b/lib/scholar/neighbors/kd_tree.ex @@ -9,13 +9,13 @@ defmodule Scholar.Neighbors.KDTree do Two construction modes are available: - * `banded/2` - the tensor has min and max values with an amplitude given by `max - min`. + * `bound/2` - the tensor has min and max values with an amplitude given by `max - min`. It is also guaranteed that the `amplitude * levels(tensor) + 1` does not overflow the tensor. See `amplitude/1` to verify if this holds. This implementation happens - fully within `defn`. This version is orders of magnitude faster than the `unbanded/2` + fully within `defn`. This version is orders of magnitude faster than the `unbound/2` one. - * `unbanded/2` - there are no known bands (min and max values) to the tensor. + * `unbound/2` - there are no known bounds (min and max values) to the tensor. This implementation is recursive and goes in and out of the `defn`, therefore it cannot be called inside `defn`. @@ -36,8 +36,8 @@ defmodule Scholar.Neighbors.KDTree do @doc """ Builds a KDTree without known min-max bounds. - If your tensor has a known band (for example, -1 and 1), - consider using the `banded/2` version which is often orders of + If your tensor has known bounds (for example, -1 and 1), + consider using the `bound/2` version which is often orders of magnitude more efficient. ## Options @@ -46,7 +46,7 @@ defmodule Scholar.Neighbors.KDTree do ## Examples - iex> Scholar.Neighbors.KDTree.unbanded(Nx.iota({5, 2}), compiler: EXLA.Defn) + iex> Scholar.Neighbors.KDTree.unbound(Nx.iota({5, 2}), compiler: EXLA.Defn) %Scholar.Neighbors.KDTree{ data: Nx.iota({5, 2}), levels: 3, @@ -54,13 +54,13 @@ defmodule Scholar.Neighbors.KDTree do } """ - def unbanded(tensor, opts \\ []) do + def unbound(tensor, opts \\ []) do levels = levels(tensor) {size, _dims} = Nx.shape(tensor) indexes = if size > 2 do - subtree_size = unbanded_subtree_size(1, levels, size) + subtree_size = unbound_subtree_size(1, levels, size) {left, mid, right} = Nx.Defn.jit_apply(&root_slice(&1, subtree_size), [tensor], opts) acc = <> @@ -88,7 +88,7 @@ defmodule Scholar.Neighbors.KDTree do defp recur([{i, indexes} | rest], next, acc, tensor, level, levels, opts) do %Nx.Tensor{shape: {size, dims}} = tensor k = rem(level, dims) - subtree_size = unbanded_subtree_size(left_child(i), levels, size) + subtree_size = unbound_subtree_size(left_child(i), levels, size) {left, mid, right} = Nx.Defn.jit_apply(&recur_slice(&1, &2, &3, subtree_size), [tensor, indexes, k], opts) @@ -121,39 +121,42 @@ defmodule Scholar.Neighbors.KDTree do Nx.slice(indexes, [subtree_size + 1], [Nx.size(indexes) - subtree_size - 1])} end - defp unbanded_subtree_size(i, levels, size) do + defp unbound_subtree_size(i, levels, size) do import Bitwise - diff = levels - unbanded_level(i) - 1 + diff = levels - unbound_level(i) - 1 shifted = 1 <<< diff fllc_s = (i <<< diff) + shifted - 1 shifted - 1 + min(max(0, size - fllc_s), shifted) end - defp unbanded_level(i) when is_integer(i), do: 31 - clz32(i + 1) + defp unbound_level(i) when is_integer(i), do: floor(:math.log2(i + 1)) @doc """ Builds a KDTree with known min-max bounds entirely within `defn`. - This requires the amplitude `|max - min|` of the tensor to be given. + This requires the amplitude `|max - min|` of the tensor to be given + such that `max + (amplitude + 1) * (size - 1)` does not overflow the + maximum tensor type. + For example, a tensor where all values are between 0 and 1 has amplitude - 1. Values between -1 and 1 has amplitude 2. If your tensor is normalized, - then you know the amplitude. Otherwise you can use `amplitude/1` to check - it. + 1. Values between -1 and 1 has amplitude 2. If your tensor is normalized + to floating points, then it is most likely bound (given their high + precision). You can use `amplitude/1` to check your assumptions. ## Examples - iex> Scholar.Neighbors.KDTree.banded(Nx.iota({5, 2}), 10) + iex> Scholar.Neighbors.KDTree.bound(Nx.iota({5, 2}), 10) %Scholar.Neighbors.KDTree{ data: Nx.iota({5, 2}), levels: 3, indexes: Nx.u32([3, 1, 4, 0, 2]) } """ - deftransform banded(tensor, amplitude) do - %__MODULE__{levels: levels(tensor), indexes: banded_n(tensor, amplitude), data: tensor} + deftransform bound(tensor, amplitude) do + %__MODULE__{levels: levels(tensor), indexes: bound_n(tensor, amplitude), data: tensor} end - defnp banded_n(tensor, amplitude) do + defnp bound_n(tensor, amplitude) do levels = levels(tensor) {size, dims} = Nx.shape(tensor) band = amplitude + 1 @@ -175,8 +178,8 @@ defmodule Scholar.Neighbors.KDTree do pos = Nx.argsort(indexes, type: :u32) pivot = - banded_segment_begin(tags, levels, size) + - banded_subtree_size(left_child(tags), levels, size) + bound_segment_begin(tags, levels, size) + + bound_subtree_size(left_child(tags), levels, size) Nx.select( pos < (1 <<< level) - 1, @@ -193,8 +196,8 @@ defmodule Scholar.Neighbors.KDTree do ) end - defnp banded_subtree_size(i, levels, size) do - diff = levels - banded_level(i) - 1 + defnp bound_subtree_size(i, levels, size) do + diff = levels - bound_level(i) - 1 shifted = 1 <<< diff first_lowest_level = (i <<< diff) + shifted - 1 # Use select instead of max to deal with overflows @@ -202,8 +205,8 @@ defmodule Scholar.Neighbors.KDTree do shifted - 1 + min(lowest_level, shifted) end - defnp banded_segment_begin(i, levels, size) do - level = banded_level(i) + defnp bound_segment_begin(i, levels, size) do + level = bound_level(i) top = (1 <<< level) - 1 diff = levels - level - 1 shifted = 1 <<< diff @@ -214,15 +217,15 @@ defmodule Scholar.Neighbors.KDTree do end # Since this property relies on u32, let's check the tensor type. - deftransformp banded_level(%Nx.Tensor{type: {:u, 32}} = i) do + deftransformp bound_level(%Nx.Tensor{type: {:u, 32}} = i) do Nx.subtract(31, Nx.count_leading_zeros(Nx.add(i, 1))) end @doc """ - Returns the amplitude of a tensor for banding. + Returns the amplitude of a bounded tensor. - If -1 is returned, it means the tensor cannot use the `banded` algorithm - to generate a KDTree and `unbanded/2` must be used instead. + If -1 is returned, it means the tensor cannot use the `bound` algorithm + to generate a KDTree and `unbound/2` must be used instead. This cannot be invoked inside a `defn`. @@ -261,7 +264,7 @@ defmodule Scholar.Neighbors.KDTree do """ deftransform levels(%Nx.Tensor{} = tensor) do case Nx.shape(tensor) do - {size, _dims} -> 32 - clz32(size) + {size, _dims} -> ceil(:math.log2(size + 1)) _ -> raise ArgumentError, "KDTrees requires a tensor of rank 2" end end @@ -344,27 +347,4 @@ defmodule Scholar.Neighbors.KDTree do """ deftransform right_child(i) when is_integer(i), do: 2 * i + 2 deftransform right_child(%Nx.Tensor{} = t), do: Nx.add(Nx.multiply(2, t), 2) - - @clz_lookup {32, 31, 30, 30, 29, 29, 29, 29, 28, 28, 28, 28, 28, 28, 28, 28} - - defp clz32(x) when is_integer(x) do - import Bitwise - - n = - if x >= 1 <<< 16 do - if x >= 1 <<< 24 do - if x >= 1 <<< 28, do: 28, else: 24 - else - if x >= 1 <<< 20, do: 20, else: 16 - end - else - if x >= 1 <<< 8 do - if x >= 1 <<< 12, do: 12, else: 8 - else - if x >= 1 <<< 4, do: 4, else: 0 - end - end - - elem(@clz_lookup, x >>> n) - n - end end diff --git a/test/scholar/neighbors/kd_tree_test.exs b/test/scholar/neighbors/kd_tree_test.exs index 90835e35..c67c4106 100644 --- a/test/scholar/neighbors/kd_tree_test.exs +++ b/test/scholar/neighbors/kd_tree_test.exs @@ -17,17 +17,17 @@ defmodule Scholar.Neighbors.KDTreeTest do ]) end - describe "unbanded" do + describe "unbound" do test "sample" do assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = - Scholar.Neighbors.KDTree.unbanded(example(), compiler: EXLA.Defn) + Scholar.Neighbors.KDTree.unbound(example(), compiler: EXLA.Defn) assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] end test "float" do assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = - Scholar.Neighbors.KDTree.unbanded(example() |> Nx.as_type(:f32), + Scholar.Neighbors.KDTree.unbound(example() |> Nx.as_type(:f32), compiler: EXLA.Defn ) @@ -36,35 +36,35 @@ defmodule Scholar.Neighbors.KDTreeTest do test "corner cases" do assert %Scholar.Neighbors.KDTree{levels: 1, indexes: indexes} = - Scholar.Neighbors.KDTree.unbanded(Nx.iota({1, 2}), compiler: EXLA.Defn) + Scholar.Neighbors.KDTree.unbound(Nx.iota({1, 2}), compiler: EXLA.Defn) assert indexes == Nx.u32([0]) assert %Scholar.Neighbors.KDTree{levels: 2, indexes: indexes} = - Scholar.Neighbors.KDTree.unbanded(Nx.iota({2, 2}), compiler: EXLA.Defn) + Scholar.Neighbors.KDTree.unbound(Nx.iota({2, 2}), compiler: EXLA.Defn) assert indexes == Nx.u32([1, 0]) end end - describe "banded" do + describe "bound" do test "iota" do assert %Scholar.Neighbors.KDTree{levels: 3, indexes: indexes} = - Scholar.Neighbors.KDTree.banded(Nx.iota({5, 2}), 10) + Scholar.Neighbors.KDTree.bound(Nx.iota({5, 2}), 10) assert indexes == Nx.u32([3, 1, 4, 0, 2]) end test "float" do assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = - Scholar.Neighbors.KDTree.banded(example() |> Nx.as_type(:f32), 100) + Scholar.Neighbors.KDTree.bound(example() |> Nx.as_type(:f32), 100) assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] end test "sample" do assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = - Scholar.Neighbors.KDTree.banded(example(), 100) + Scholar.Neighbors.KDTree.bound(example(), 100) assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] end From 81337bda83876ce9cf45e5c8bf6ab8d8b5761b50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Mon, 6 Nov 2023 09:14:29 +0100 Subject: [PATCH 11/12] PR feedback --- lib/scholar/neighbors/kd_tree.ex | 2 +- test/scholar/neighbors/kd_tree_test.exs | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/scholar/neighbors/kd_tree.ex b/lib/scholar/neighbors/kd_tree.ex index b0398132..85f6b1b4 100644 --- a/lib/scholar/neighbors/kd_tree.ex +++ b/lib/scholar/neighbors/kd_tree.ex @@ -46,7 +46,7 @@ defmodule Scholar.Neighbors.KDTree do ## Examples - iex> Scholar.Neighbors.KDTree.unbound(Nx.iota({5, 2}), compiler: EXLA.Defn) + iex> Scholar.Neighbors.KDTree.unbound(Nx.iota({5, 2}), compiler: EXLA) %Scholar.Neighbors.KDTree{ data: Nx.iota({5, 2}), levels: 3, diff --git a/test/scholar/neighbors/kd_tree_test.exs b/test/scholar/neighbors/kd_tree_test.exs index c67c4106..98cb43f9 100644 --- a/test/scholar/neighbors/kd_tree_test.exs +++ b/test/scholar/neighbors/kd_tree_test.exs @@ -20,7 +20,7 @@ defmodule Scholar.Neighbors.KDTreeTest do describe "unbound" do test "sample" do assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = - Scholar.Neighbors.KDTree.unbound(example(), compiler: EXLA.Defn) + Scholar.Neighbors.KDTree.unbound(example(), compiler: EXLA) assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] end @@ -28,7 +28,7 @@ defmodule Scholar.Neighbors.KDTreeTest do test "float" do assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = Scholar.Neighbors.KDTree.unbound(example() |> Nx.as_type(:f32), - compiler: EXLA.Defn + compiler: EXLA ) assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] @@ -36,12 +36,12 @@ defmodule Scholar.Neighbors.KDTreeTest do test "corner cases" do assert %Scholar.Neighbors.KDTree{levels: 1, indexes: indexes} = - Scholar.Neighbors.KDTree.unbound(Nx.iota({1, 2}), compiler: EXLA.Defn) + Scholar.Neighbors.KDTree.unbound(Nx.iota({1, 2}), compiler: EXLA) assert indexes == Nx.u32([0]) assert %Scholar.Neighbors.KDTree{levels: 2, indexes: indexes} = - Scholar.Neighbors.KDTree.unbound(Nx.iota({2, 2}), compiler: EXLA.Defn) + Scholar.Neighbors.KDTree.unbound(Nx.iota({2, 2}), compiler: EXLA) assert indexes == Nx.u32([1, 0]) end From bec091c1b9f4f3b6b1b62871dec4bde181655f60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Mon, 6 Nov 2023 09:41:07 +0100 Subject: [PATCH 12/12] bound -> bounded --- benchmarks/kd_tree.exs | 4 +- lib/scholar/neighbors/kd_tree.ex | 50 ++++++++++++------------- test/scholar/neighbors/kd_tree_test.exs | 18 ++++----- 3 files changed, 36 insertions(+), 36 deletions(-) diff --git a/benchmarks/kd_tree.exs b/benchmarks/kd_tree.exs index 74dc3a33..18dda797 100644 --- a/benchmarks/kd_tree.exs +++ b/benchmarks/kd_tree.exs @@ -7,8 +7,8 @@ key = Nx.Random.key(System.os_time()) Benchee.run( %{ - "unbound" => fn -> Scholar.Neighbors.KDTree.unbound(uniform) end, - "bound" => fn -> Scholar.Neighbors.KDTree.bound(uniform, 2) end + "unbounded" => fn -> Scholar.Neighbors.KDTree.unbounded(uniform) end, + "bounded" => fn -> Scholar.Neighbors.KDTree.bounded(uniform, 2) end }, time: 10, memory_time: 2 diff --git a/lib/scholar/neighbors/kd_tree.ex b/lib/scholar/neighbors/kd_tree.ex index 85f6b1b4..c80077d8 100644 --- a/lib/scholar/neighbors/kd_tree.ex +++ b/lib/scholar/neighbors/kd_tree.ex @@ -9,13 +9,13 @@ defmodule Scholar.Neighbors.KDTree do Two construction modes are available: - * `bound/2` - the tensor has min and max values with an amplitude given by `max - min`. + * `bounded/2` - the tensor has min and max values with an amplitude given by `max - min`. It is also guaranteed that the `amplitude * levels(tensor) + 1` does not overflow the tensor. See `amplitude/1` to verify if this holds. This implementation happens - fully within `defn`. This version is orders of magnitude faster than the `unbound/2` + fully within `defn`. This version is orders of magnitude faster than the `unbounded/2` one. - * `unbound/2` - there are no known bounds (min and max values) to the tensor. + * `unbounded/2` - there are no known bounds (min and max values) to the tensor. This implementation is recursive and goes in and out of the `defn`, therefore it cannot be called inside `defn`. @@ -37,7 +37,7 @@ defmodule Scholar.Neighbors.KDTree do Builds a KDTree without known min-max bounds. If your tensor has known bounds (for example, -1 and 1), - consider using the `bound/2` version which is often orders of + consider using the `bounded/2` version which is often orders of magnitude more efficient. ## Options @@ -46,7 +46,7 @@ defmodule Scholar.Neighbors.KDTree do ## Examples - iex> Scholar.Neighbors.KDTree.unbound(Nx.iota({5, 2}), compiler: EXLA) + iex> Scholar.Neighbors.KDTree.unbounded(Nx.iota({5, 2}), compiler: EXLA) %Scholar.Neighbors.KDTree{ data: Nx.iota({5, 2}), levels: 3, @@ -54,13 +54,13 @@ defmodule Scholar.Neighbors.KDTree do } """ - def unbound(tensor, opts \\ []) do + def unbounded(tensor, opts \\ []) do levels = levels(tensor) {size, _dims} = Nx.shape(tensor) indexes = if size > 2 do - subtree_size = unbound_subtree_size(1, levels, size) + subtree_size = unbounded_subtree_size(1, levels, size) {left, mid, right} = Nx.Defn.jit_apply(&root_slice(&1, subtree_size), [tensor], opts) acc = <> @@ -88,7 +88,7 @@ defmodule Scholar.Neighbors.KDTree do defp recur([{i, indexes} | rest], next, acc, tensor, level, levels, opts) do %Nx.Tensor{shape: {size, dims}} = tensor k = rem(level, dims) - subtree_size = unbound_subtree_size(left_child(i), levels, size) + subtree_size = unbounded_subtree_size(left_child(i), levels, size) {left, mid, right} = Nx.Defn.jit_apply(&recur_slice(&1, &2, &3, subtree_size), [tensor, indexes, k], opts) @@ -121,15 +121,15 @@ defmodule Scholar.Neighbors.KDTree do Nx.slice(indexes, [subtree_size + 1], [Nx.size(indexes) - subtree_size - 1])} end - defp unbound_subtree_size(i, levels, size) do + defp unbounded_subtree_size(i, levels, size) do import Bitwise - diff = levels - unbound_level(i) - 1 + diff = levels - unbounded_level(i) - 1 shifted = 1 <<< diff fllc_s = (i <<< diff) + shifted - 1 shifted - 1 + min(max(0, size - fllc_s), shifted) end - defp unbound_level(i) when is_integer(i), do: floor(:math.log2(i + 1)) + defp unbounded_level(i) when is_integer(i), do: floor(:math.log2(i + 1)) @doc """ Builds a KDTree with known min-max bounds entirely within `defn`. @@ -140,23 +140,23 @@ defmodule Scholar.Neighbors.KDTree do For example, a tensor where all values are between 0 and 1 has amplitude 1. Values between -1 and 1 has amplitude 2. If your tensor is normalized - to floating points, then it is most likely bound (given their high + to floating points, then it is most likely bounded (given their high precision). You can use `amplitude/1` to check your assumptions. ## Examples - iex> Scholar.Neighbors.KDTree.bound(Nx.iota({5, 2}), 10) + iex> Scholar.Neighbors.KDTree.bounded(Nx.iota({5, 2}), 10) %Scholar.Neighbors.KDTree{ data: Nx.iota({5, 2}), levels: 3, indexes: Nx.u32([3, 1, 4, 0, 2]) } """ - deftransform bound(tensor, amplitude) do - %__MODULE__{levels: levels(tensor), indexes: bound_n(tensor, amplitude), data: tensor} + deftransform bounded(tensor, amplitude) do + %__MODULE__{levels: levels(tensor), indexes: bounded_n(tensor, amplitude), data: tensor} end - defnp bound_n(tensor, amplitude) do + defnp bounded_n(tensor, amplitude) do levels = levels(tensor) {size, dims} = Nx.shape(tensor) band = amplitude + 1 @@ -178,8 +178,8 @@ defmodule Scholar.Neighbors.KDTree do pos = Nx.argsort(indexes, type: :u32) pivot = - bound_segment_begin(tags, levels, size) + - bound_subtree_size(left_child(tags), levels, size) + bounded_segment_begin(tags, levels, size) + + bounded_subtree_size(left_child(tags), levels, size) Nx.select( pos < (1 <<< level) - 1, @@ -196,8 +196,8 @@ defmodule Scholar.Neighbors.KDTree do ) end - defnp bound_subtree_size(i, levels, size) do - diff = levels - bound_level(i) - 1 + defnp bounded_subtree_size(i, levels, size) do + diff = levels - bounded_level(i) - 1 shifted = 1 <<< diff first_lowest_level = (i <<< diff) + shifted - 1 # Use select instead of max to deal with overflows @@ -205,8 +205,8 @@ defmodule Scholar.Neighbors.KDTree do shifted - 1 + min(lowest_level, shifted) end - defnp bound_segment_begin(i, levels, size) do - level = bound_level(i) + defnp bounded_segment_begin(i, levels, size) do + level = bounded_level(i) top = (1 <<< level) - 1 diff = levels - level - 1 shifted = 1 <<< diff @@ -217,15 +217,15 @@ defmodule Scholar.Neighbors.KDTree do end # Since this property relies on u32, let's check the tensor type. - deftransformp bound_level(%Nx.Tensor{type: {:u, 32}} = i) do + deftransformp bounded_level(%Nx.Tensor{type: {:u, 32}} = i) do Nx.subtract(31, Nx.count_leading_zeros(Nx.add(i, 1))) end @doc """ Returns the amplitude of a bounded tensor. - If -1 is returned, it means the tensor cannot use the `bound` algorithm - to generate a KDTree and `unbound/2` must be used instead. + If -1 is returned, it means the tensor cannot use the `bounded` algorithm + to generate a KDTree and `unbounded/2` must be used instead. This cannot be invoked inside a `defn`. diff --git a/test/scholar/neighbors/kd_tree_test.exs b/test/scholar/neighbors/kd_tree_test.exs index 98cb43f9..3df22fbe 100644 --- a/test/scholar/neighbors/kd_tree_test.exs +++ b/test/scholar/neighbors/kd_tree_test.exs @@ -17,17 +17,17 @@ defmodule Scholar.Neighbors.KDTreeTest do ]) end - describe "unbound" do + describe "unbounded" do test "sample" do assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = - Scholar.Neighbors.KDTree.unbound(example(), compiler: EXLA) + Scholar.Neighbors.KDTree.unbounded(example(), compiler: EXLA) assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] end test "float" do assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = - Scholar.Neighbors.KDTree.unbound(example() |> Nx.as_type(:f32), + Scholar.Neighbors.KDTree.unbounded(example() |> Nx.as_type(:f32), compiler: EXLA ) @@ -36,35 +36,35 @@ defmodule Scholar.Neighbors.KDTreeTest do test "corner cases" do assert %Scholar.Neighbors.KDTree{levels: 1, indexes: indexes} = - Scholar.Neighbors.KDTree.unbound(Nx.iota({1, 2}), compiler: EXLA) + Scholar.Neighbors.KDTree.unbounded(Nx.iota({1, 2}), compiler: EXLA) assert indexes == Nx.u32([0]) assert %Scholar.Neighbors.KDTree{levels: 2, indexes: indexes} = - Scholar.Neighbors.KDTree.unbound(Nx.iota({2, 2}), compiler: EXLA) + Scholar.Neighbors.KDTree.unbounded(Nx.iota({2, 2}), compiler: EXLA) assert indexes == Nx.u32([1, 0]) end end - describe "bound" do + describe "bounded" do test "iota" do assert %Scholar.Neighbors.KDTree{levels: 3, indexes: indexes} = - Scholar.Neighbors.KDTree.bound(Nx.iota({5, 2}), 10) + Scholar.Neighbors.KDTree.bounded(Nx.iota({5, 2}), 10) assert indexes == Nx.u32([3, 1, 4, 0, 2]) end test "float" do assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = - Scholar.Neighbors.KDTree.bound(example() |> Nx.as_type(:f32), 100) + Scholar.Neighbors.KDTree.bounded(example() |> Nx.as_type(:f32), 100) assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] end test "sample" do assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = - Scholar.Neighbors.KDTree.bound(example(), 100) + Scholar.Neighbors.KDTree.bounded(example(), 100) assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] end