From d730ded916ec41799c1709ddb69d49100433df40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Mon, 6 Nov 2023 13:46:08 +0100 Subject: [PATCH] KDTrees (#206) --- benchmarks/kd_tree.exs | 15 + lib/scholar/neighbors/kd_tree.ex | 350 ++++++++++++++++++ .../neighbors/radius_nearest_neighbors.ex | 4 +- mix.exs | 5 +- mix.lock | 5 +- test/scholar/neighbors/kd_tree_test.exs | 72 ++++ test/test_helper.exs | 1 + 7 files changed, 448 insertions(+), 4 deletions(-) create mode 100644 benchmarks/kd_tree.exs create mode 100644 lib/scholar/neighbors/kd_tree.ex create mode 100644 test/scholar/neighbors/kd_tree_test.exs diff --git a/benchmarks/kd_tree.exs b/benchmarks/kd_tree.exs new file mode 100644 index 00000000..18dda797 --- /dev/null +++ b/benchmarks/kd_tree.exs @@ -0,0 +1,15 @@ +# mix run benchmarks/kd_tree.exs +Nx.global_default_backend(EXLA.Backend) +Nx.Defn.global_default_options(compiler: EXLA) + +key = Nx.Random.key(System.os_time()) +{uniform, _new_key} = Nx.Random.uniform(key, shape: {1000, 3}) + +Benchee.run( + %{ + "unbounded" => fn -> Scholar.Neighbors.KDTree.unbounded(uniform) end, + "bounded" => fn -> Scholar.Neighbors.KDTree.bounded(uniform, 2) end + }, + time: 10, + memory_time: 2 +) diff --git a/lib/scholar/neighbors/kd_tree.ex b/lib/scholar/neighbors/kd_tree.ex new file mode 100644 index 00000000..c80077d8 --- /dev/null +++ b/lib/scholar/neighbors/kd_tree.ex @@ -0,0 +1,350 @@ +defmodule Scholar.Neighbors.KDTree do + @moduledoc """ + Implements a kd-tree, a space-partitioning data structure for organizing points + in a k-dimensional space. + + This is implemented as one-dimensional tensor with indices pointed to highest + dimension of the given tensor. Traversal starts by calling `root/0` and then + accessing the `left_child/1` and `right_child/1`. The tree is left-balanced. + + Two construction modes are available: + + * `bounded/2` - the tensor has min and max values with an amplitude given by `max - min`. + It is also guaranteed that the `amplitude * levels(tensor) + 1` does not overflow + the tensor. See `amplitude/1` to verify if this holds. This implementation happens + fully within `defn`. This version is orders of magnitude faster than the `unbounded/2` + one. + + * `unbounded/2` - there are no known bounds (min and max values) to the tensor. + This implementation is recursive and goes in and out of the `defn`, therefore + it cannot be called inside `defn`. + + Each level traverses over the last axis of tensor, the index for a level can be + computed as: `rem(level, Nx.axis_size(tensor, -1))`. + + ## References + + * [GPU-friendly, Parallel, and (Almost-)In-Place Construction of Left-Balanced k-d Trees](https://arxiv.org/pdf/2211.00120.pdf). + """ + + import Nx.Defn + + @derive {Nx.Container, keep: [:levels], containers: [:indexes, :data]} + @enforce_keys [:levels, :indexes, :data] + defstruct [:levels, :indexes, :data] + + @doc """ + Builds a KDTree without known min-max bounds. + + If your tensor has known bounds (for example, -1 and 1), + consider using the `bounded/2` version which is often orders of + magnitude more efficient. + + ## Options + + * `:compiler` - the default compiler to use for internal defn operations + + ## Examples + + iex> Scholar.Neighbors.KDTree.unbounded(Nx.iota({5, 2}), compiler: EXLA) + %Scholar.Neighbors.KDTree{ + data: Nx.iota({5, 2}), + levels: 3, + indexes: Nx.u32([3, 1, 4, 0, 2]) + } + + """ + def unbounded(tensor, opts \\ []) do + levels = levels(tensor) + {size, _dims} = Nx.shape(tensor) + + indexes = + if size > 2 do + subtree_size = unbounded_subtree_size(1, levels, size) + {left, mid, right} = Nx.Defn.jit_apply(&root_slice(&1, subtree_size), [tensor], opts) + + acc = <> + acc = recur([{1, left}, {2, right}], [], acc, tensor, 1, levels, opts) + Nx.from_binary(acc, :u32) + else + Nx.argsort(tensor[[.., 0]], direction: :desc, type: :u32) + end + + %__MODULE__{levels: levels, indexes: indexes, data: tensor} + end + + defp recur([{_i, %Nx.Tensor{shape: {1}} = leaf} | rest], next, acc, tensor, level, levels, opts) do + [leaf] = Nx.to_flat_list(leaf) + acc = <> + recur(rest, next, acc, tensor, level, levels, opts) + end + + defp recur([{i, %Nx.Tensor{shape: {2}} = node} | rest], next, acc, tensor, level, levels, opts) do + acc = <> + next = [{left_child(i), Nx.slice(node, [0], [1])} | next] + recur(rest, next, acc, tensor, level, levels, opts) + end + + defp recur([{i, indexes} | rest], next, acc, tensor, level, levels, opts) do + %Nx.Tensor{shape: {size, dims}} = tensor + k = rem(level, dims) + subtree_size = unbounded_subtree_size(left_child(i), levels, size) + + {left, mid, right} = + Nx.Defn.jit_apply(&recur_slice(&1, &2, &3, subtree_size), [tensor, indexes, k], opts) + + next = [{right_child(i), right}, {left_child(i), left} | next] + acc = <> + recur(rest, next, acc, tensor, level, levels, opts) + end + + defp recur([], [], acc, _tensor, _level, _levels, _opts) do + acc + end + + defp recur([], next, acc, tensor, level, levels, opts) do + recur(Enum.reverse(next), [], acc, tensor, level + 1, levels, opts) + end + + defp root_slice(tensor, subtree_size) do + indexes = Nx.argsort(tensor[[.., 0]], type: :u32) + + {Nx.slice(indexes, [0], [subtree_size]), indexes[subtree_size], + Nx.slice(indexes, [subtree_size + 1], [Nx.size(indexes) - subtree_size - 1])} + end + + defp recur_slice(tensor, indexes, k, subtree_size) do + sorted = Nx.argsort(Nx.take(tensor, indexes)[[.., k]], type: :u32) + indexes = Nx.take(indexes, sorted) + + {Nx.slice(indexes, [0], [subtree_size]), indexes[subtree_size], + Nx.slice(indexes, [subtree_size + 1], [Nx.size(indexes) - subtree_size - 1])} + end + + defp unbounded_subtree_size(i, levels, size) do + import Bitwise + diff = levels - unbounded_level(i) - 1 + shifted = 1 <<< diff + fllc_s = (i <<< diff) + shifted - 1 + shifted - 1 + min(max(0, size - fllc_s), shifted) + end + + defp unbounded_level(i) when is_integer(i), do: floor(:math.log2(i + 1)) + + @doc """ + Builds a KDTree with known min-max bounds entirely within `defn`. + + This requires the amplitude `|max - min|` of the tensor to be given + such that `max + (amplitude + 1) * (size - 1)` does not overflow the + maximum tensor type. + + For example, a tensor where all values are between 0 and 1 has amplitude + 1. Values between -1 and 1 has amplitude 2. If your tensor is normalized + to floating points, then it is most likely bounded (given their high + precision). You can use `amplitude/1` to check your assumptions. + + ## Examples + + iex> Scholar.Neighbors.KDTree.bounded(Nx.iota({5, 2}), 10) + %Scholar.Neighbors.KDTree{ + data: Nx.iota({5, 2}), + levels: 3, + indexes: Nx.u32([3, 1, 4, 0, 2]) + } + """ + deftransform bounded(tensor, amplitude) do + %__MODULE__{levels: levels(tensor), indexes: bounded_n(tensor, amplitude), data: tensor} + end + + defnp bounded_n(tensor, amplitude) do + levels = levels(tensor) + {size, dims} = Nx.shape(tensor) + band = amplitude + 1 + tags = Nx.broadcast(Nx.u32(0), {size}) + + {level, tags, _tensor, _band} = + while {level = Nx.u32(0), tags, tensor, band}, level < levels - 1 do + k = rem(level, dims) + indexes = Nx.argsort(tensor[[.., k]] + band * tags, type: :u32) + tags = update_tags(tags, indexes, level, levels, size) + {level + 1, tags, tensor, band} + end + + k = rem(level, dims) + Nx.argsort(tensor[[.., k]] + band * tags, type: :u32) + end + + defnp update_tags(tags, indexes, level, levels, size) do + pos = Nx.argsort(indexes, type: :u32) + + pivot = + bounded_segment_begin(tags, levels, size) + + bounded_subtree_size(left_child(tags), levels, size) + + Nx.select( + pos < (1 <<< level) - 1, + tags, + Nx.select( + pos < pivot, + left_child(tags), + Nx.select( + pos > pivot, + right_child(tags), + tags + ) + ) + ) + end + + defnp bounded_subtree_size(i, levels, size) do + diff = levels - bounded_level(i) - 1 + shifted = 1 <<< diff + first_lowest_level = (i <<< diff) + shifted - 1 + # Use select instead of max to deal with overflows + lowest_level = Nx.select(first_lowest_level > size, Nx.u32(0), size - first_lowest_level) + shifted - 1 + min(lowest_level, shifted) + end + + defnp bounded_segment_begin(i, levels, size) do + level = bounded_level(i) + top = (1 <<< level) - 1 + diff = levels - level - 1 + shifted = 1 <<< diff + left_siblings = i - top + + top + left_siblings * (shifted - 1) + + min(left_siblings * shifted, size - (1 <<< (levels - 1)) + 1) + end + + # Since this property relies on u32, let's check the tensor type. + deftransformp bounded_level(%Nx.Tensor{type: {:u, 32}} = i) do + Nx.subtract(31, Nx.count_leading_zeros(Nx.add(i, 1))) + end + + @doc """ + Returns the amplitude of a bounded tensor. + + If -1 is returned, it means the tensor cannot use the `bounded` algorithm + to generate a KDTree and `unbounded/2` must be used instead. + + This cannot be invoked inside a `defn`. + + ## Examples + + iex> Scholar.Neighbors.KDTree.amplitude(Nx.iota({10, 2})) + 19 + iex> Scholar.Neighbors.KDTree.amplitude(Nx.iota({20, 2}, type: :f32)) + 39.0 + iex> Scholar.Neighbors.KDTree.amplitude(Nx.iota({20, 2}, type: :u8)) + -1 + iex> Scholar.Neighbors.KDTree.amplitude(Nx.negate(Nx.iota({10, 2}))) + 19 + + """ + def amplitude(tensor) do + max = tensor |> Nx.reduce_max() |> Nx.to_number() + min = tensor |> Nx.reduce_min() |> Nx.to_number() + amplitude = abs(max - min) + limit = tensor.type |> Nx.Constants.max_finite() |> Nx.to_number() + + if max + (amplitude + 1) * (Nx.axis_size(tensor, 0) - 1) > limit do + -1 + else + amplitude + end + end + + @doc """ + Returns the number of resulting levels in a KDTree for `tensor`. + + ## Examples + + iex> Scholar.Neighbors.KDTree.levels(Nx.iota({10, 3})) + 4 + """ + deftransform levels(%Nx.Tensor{} = tensor) do + case Nx.shape(tensor) do + {size, _dims} -> ceil(:math.log2(size + 1)) + _ -> raise ArgumentError, "KDTrees requires a tensor of rank 2" + end + end + + @doc """ + Returns the root index. + + ## Examples + + iex> Scholar.Neighbors.KDTree.root() + 0 + + """ + deftransform root, do: 0 + + @doc """ + Returns the parent of child `i`. + + It is your responsibility to guarantee the result is positive. + + ## Examples + + iex> Scholar.Neighbors.KDTree.parent(1) + 0 + iex> Scholar.Neighbors.KDTree.parent(2) + 0 + + iex> Scholar.Neighbors.KDTree.parent(Nx.u32(3)) + #Nx.Tensor< + u32 + 1 + > + + """ + deftransform parent(i) when is_integer(i), do: div(i - 1, 2) + deftransform parent(%Nx.Tensor{} = t), do: Nx.quotient(Nx.subtract(t, 1), 2) + + @doc """ + Returns the index of the left child of i. + + It is your responsibility to guarantee the result + is not greater than the leading axis of the tensor. + + ## Examples + + iex> Scholar.Neighbors.KDTree.left_child(0) + 1 + iex> Scholar.Neighbors.KDTree.left_child(1) + 3 + + iex> Scholar.Neighbors.KDTree.left_child(Nx.u32(3)) + #Nx.Tensor< + u32 + 7 + > + + """ + deftransform left_child(i) when is_integer(i), do: 2 * i + 1 + deftransform left_child(%Nx.Tensor{} = t), do: Nx.add(Nx.multiply(2, t), 1) + + @doc """ + Returns the index of the right child of i. + + It is your responsibility to guarantee the result + is not greater than the leading axis of the tensor. + + ## Examples + + iex> Scholar.Neighbors.KDTree.right_child(0) + 2 + iex> Scholar.Neighbors.KDTree.right_child(1) + 4 + + iex> Scholar.Neighbors.KDTree.right_child(Nx.u32(3)) + #Nx.Tensor< + u32 + 8 + > + + """ + deftransform right_child(i) when is_integer(i), do: 2 * i + 2 + deftransform right_child(%Nx.Tensor{} = t), do: Nx.add(Nx.multiply(2, t), 2) +end diff --git a/lib/scholar/neighbors/radius_nearest_neighbors.ex b/lib/scholar/neighbors/radius_nearest_neighbors.ex index 66c378b1..b1eca0e3 100644 --- a/lib/scholar/neighbors/radius_nearest_neighbors.ex +++ b/lib/scholar/neighbors/radius_nearest_neighbors.ex @@ -1,6 +1,8 @@ defmodule Scholar.Neighbors.RadiusNearestNeighbors do @moduledoc """ - The Radius Nearest Neighbors. It implements both classification and regression. + The Radius Nearest Neighbors. + + It implements both classification and regression. """ import Nx.Defn import Scholar.Shared diff --git a/mix.exs b/mix.exs index c439392e..02acd3bb 100644 --- a/mix.exs +++ b/mix.exs @@ -31,10 +31,11 @@ defmodule Scholar.MixProject do [ {:ex_doc, "~> 0.30", only: :docs}, # {:nx, "~> 0.6", override: true}, - {:nx, github: "elixir-nx/nx", sparse: "nx", override: true}, + {:nx, github: "elixir-nx/nx", sparse: "nx", override: true, branch: "v0.6"}, {:nimble_options, "~> 0.5.2 or ~> 1.0"}, {:exla, "~> 0.6", optional: true}, - {:polaris, "~> 0.1"} + {:polaris, "~> 0.1"}, + {:benchee, "~> 1.0", only: :dev} ] end diff --git a/mix.lock b/mix.lock index cb415b64..c17916d2 100644 --- a/mix.lock +++ b/mix.lock @@ -1,5 +1,7 @@ %{ + "benchee": {:hex, :benchee, "1.1.0", "f3a43817209a92a1fade36ef36b86e1052627fd8934a8b937ac9ab3a76c43062", [:mix], [{:deep_merge, "~> 1.0", [hex: :deep_merge, repo: "hexpm", optional: false]}, {:statistex, "~> 1.0", [hex: :statistex, repo: "hexpm", optional: false]}], "hexpm", "7da57d545003165a012b587077f6ba90b89210fd88074ce3c60ce239eb5e6d93"}, "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, + "deep_merge": {:hex, :deep_merge, "1.0.0", "b4aa1a0d1acac393bdf38b2291af38cb1d4a52806cf7a4906f718e1feb5ee961", [:mix], [], "hexpm", "ce708e5f094b9cd4e8f2be4f00d2f4250c4095be93f8cd6d018c753894885430"}, "earmark_parser": {:hex, :earmark_parser, "1.4.37", "2ad73550e27c8946648b06905a57e4d454e4d7229c2dafa72a0348c99d8be5f7", [:mix], [], "hexpm", "6b19783f2802f039806f375610faa22da130b8edc21209d0bff47918bb48360e"}, "elixir_make": {:hex, :elixir_make, "0.7.7", "7128c60c2476019ed978210c245badf08b03dbec4f24d05790ef791da11aa17c", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "5bc19fff950fad52bbe5f211b12db9ec82c6b34a9647da0c2224b8b8464c7e6c"}, "ex_doc": {:hex, :ex_doc, "0.30.6", "5f8b54854b240a2b55c9734c4b1d0dd7bdd41f71a095d42a70445c03cf05a281", [:mix], [{:earmark_parser, "~> 1.4.31", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "bd48f2ddacf4e482c727f9293d9498e0881597eae6ddc3d9562bd7923375109f"}, @@ -9,8 +11,9 @@ "makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"}, "nimble_options": {:hex, :nimble_options, "0.5.2", "42703307b924880f8c08d97719da7472673391905f528259915782bb346e0a1b", [:mix], [], "hexpm", "4da7f904b915fd71db549bcdc25f8d56f378ef7ae07dc1d372cbe72ba950dce0"}, "nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"}, - "nx": {:git, "https://github.com/elixir-nx/nx.git", "a0b7e2e5cc7a62a55cd2e7bbc3e44ba2ac1c996b", [sparse: "nx"]}, + "nx": {:git, "https://github.com/elixir-nx/nx.git", "ef464cfd0935eb4c2c1fa9a40f099b098a0b95bf", [sparse: "nx", branch: "v0.6"]}, "polaris": {:hex, :polaris, "0.1.0", "dca61b18e3e801ecdae6ac9f0eca5f19792b44a5cb4b8d63db50fc40fc038d22", [:mix], [{:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "13ef2b166650e533cb24b10e2f3b8ab4f2f449ba4d63156e8c569527f206e2c2"}, + "statistex": {:hex, :statistex, "1.0.0", "f3dc93f3c0c6c92e5f291704cf62b99b553253d7969e9a5fa713e5481cd858a5", [:mix], [], "hexpm", "ff9d8bee7035028ab4742ff52fc80a2aa35cece833cf5319009b52f1b5a86c27"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, "xla": {:hex, :xla, "0.5.0", "fb8a02c02e5a4f4531fbf18a90c325e471037f983f0115d23f510e7dd9a6aa65", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "571ac797a4244b8ba8552ed0295a54397bd896708be51e4da6cbb784f6678061"}, } diff --git a/test/scholar/neighbors/kd_tree_test.exs b/test/scholar/neighbors/kd_tree_test.exs new file mode 100644 index 00000000..3df22fbe --- /dev/null +++ b/test/scholar/neighbors/kd_tree_test.exs @@ -0,0 +1,72 @@ +defmodule Scholar.Neighbors.KDTreeTest do + use ExUnit.Case, async: true + doctest Scholar.Neighbors.KDTree + + defp example do + Nx.tensor([ + [10, 15], + [46, 63], + [68, 21], + [40, 33], + [25, 54], + [15, 43], + [44, 58], + [45, 40], + [62, 69], + [53, 67] + ]) + end + + describe "unbounded" do + test "sample" do + assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = + Scholar.Neighbors.KDTree.unbounded(example(), compiler: EXLA) + + assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] + end + + test "float" do + assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = + Scholar.Neighbors.KDTree.unbounded(example() |> Nx.as_type(:f32), + compiler: EXLA + ) + + assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] + end + + test "corner cases" do + assert %Scholar.Neighbors.KDTree{levels: 1, indexes: indexes} = + Scholar.Neighbors.KDTree.unbounded(Nx.iota({1, 2}), compiler: EXLA) + + assert indexes == Nx.u32([0]) + + assert %Scholar.Neighbors.KDTree{levels: 2, indexes: indexes} = + Scholar.Neighbors.KDTree.unbounded(Nx.iota({2, 2}), compiler: EXLA) + + assert indexes == Nx.u32([1, 0]) + end + end + + describe "bounded" do + test "iota" do + assert %Scholar.Neighbors.KDTree{levels: 3, indexes: indexes} = + Scholar.Neighbors.KDTree.bounded(Nx.iota({5, 2}), 10) + + assert indexes == Nx.u32([3, 1, 4, 0, 2]) + end + + test "float" do + assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = + Scholar.Neighbors.KDTree.bounded(example() |> Nx.as_type(:f32), 100) + + assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] + end + + test "sample" do + assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = + Scholar.Neighbors.KDTree.bounded(example(), 100) + + assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] + end + end +end diff --git a/test/test_helper.exs b/test/test_helper.exs index e69de29b..5e96292b 100644 --- a/test/test_helper.exs +++ b/test/test_helper.exs @@ -0,0 +1 @@ +Application.ensure_all_started(:exla)