From 9a578276e7f286b262a4d0c70386b784d6384f04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Sun, 5 Nov 2023 19:13:16 +0100 Subject: [PATCH] Add benchmarks --- benchmarks/kdtree.exs | 14 ++++++ lib/scholar/neighbors/kd_tree.ex | 63 +++++++++++++++++++++---- mix.exs | 3 +- mix.lock | 5 +- test/scholar/neighbors/kd_tree_test.exs | 21 ++++++--- 5 files changed, 90 insertions(+), 16 deletions(-) create mode 100644 benchmarks/kdtree.exs diff --git a/benchmarks/kdtree.exs b/benchmarks/kdtree.exs new file mode 100644 index 00000000..dd437c7f --- /dev/null +++ b/benchmarks/kdtree.exs @@ -0,0 +1,14 @@ +Nx.global_default_backend(EXLA.Backend) +Nx.Defn.global_default_options(compiler: EXLA) + +key = Nx.Random.key(System.os_time()) +{uniform, _new_key} = Nx.Random.uniform(key, shape: {1000, 3}) + +Benchee.run( + %{ + "unbanded" => fn -> Scholar.Neighbors.KDTree.unbanded(uniform) end, + "banded" => fn -> Scholar.Neighbors.KDTree.banded(uniform, 2) end + }, + time: 10, + memory_time: 2 +) diff --git a/lib/scholar/neighbors/kd_tree.ex b/lib/scholar/neighbors/kd_tree.ex index 01e2202d..fe15d518 100644 --- a/lib/scholar/neighbors/kd_tree.ex +++ b/lib/scholar/neighbors/kd_tree.ex @@ -12,12 +12,16 @@ defmodule Scholar.Neighbors.KDTree do * `banded/2` - the tensor has min and max values with an amplitude given by `max - min`. It is also guaranteed that the `amplitude * levels(tensor) + 1` does not overflow the tensor. See `amplitude/1` to verify if this holds. This implementation happens - fully within `defn`. + fully within `defn`. This version is orders of magnitude faster than the `unbanded/2` + one. * `unbanded/2` - there are no known bands (min and max values) to the tensor. This implementation is recursive and goes in and out of the `defn`, therefore it cannot be called inside `defn`. + Each level traverses over the last axis of tensor, the index for a level can be + computed as: `rem(level, Nx.axis_size(tensor, -1))`. + ## References * [GPU-friendly, Parallel, and (Almost-)In-Place Construction of Left-Balanced k-d Trees](https://arxiv.org/pdf/2211.00120.pdf). @@ -25,8 +29,6 @@ defmodule Scholar.Neighbors.KDTree do import Nx.Defn - # TODO: Benchmark - @derive {Nx.Container, keep: [:levels], containers: [:indexes]} @enforce_keys [:levels, :indexes] defstruct [:levels, :indexes] @@ -34,8 +36,9 @@ defmodule Scholar.Neighbors.KDTree do @doc """ Builds a KDTree without known min-max bounds. - If your tensor has a known bound (for example, -1 and 1), - consider using the `banded/2` version which is more efficient. + If your tensor has a known band (for example, -1 and 1), + consider using the `banded/2` version which is often orders of + magnitude more efficient. ## Options @@ -128,7 +131,21 @@ defmodule Scholar.Neighbors.KDTree do defp unbanded_level(i) when is_integer(i), do: 31 - clz32(i + 1) @doc """ - BANDED + Builds a KDTree with known min-max bounds entirely within `defn`. + + This requires the amplitude `|max - min|` of the tensor to be given. + For example, a tensor where all values are between 0 and 1 has amplitude + 1. Values between -1 and 1 has amplitude 2. If your tensor is normalized, + then you know the amplitude. Otherwise you can use `amplitude/1` to check + it. + + ## Examples + + iex> Scholar.Neighbors.KDTree.banded(Nx.iota({5, 2}), 10) + %Scholar.Neighbors.KDTree{ + levels: 3, + indexes: Nx.u32([3, 1, 4, 0, 2]) + } """ defn banded(tensor, amplitude) do levels = levels(tensor) @@ -180,7 +197,7 @@ defmodule Scholar.Neighbors.KDTree do shifted - 1 + min(lowest_level, shifted) end - defn banded_segment_begin(i, levels, size) do + defnp banded_segment_begin(i, levels, size) do level = banded_level(i) top = (1 <<< level) - 1 diff = levels - level - 1 @@ -212,12 +229,14 @@ defmodule Scholar.Neighbors.KDTree do 39.0 iex> Scholar.Neighbors.KDTree.amplitude(Nx.iota({20, 2}, type: :u8)) -1 + iex> Scholar.Neighbors.KDTree.amplitude(Nx.negate(Nx.iota({10, 2}))) + 19 """ def amplitude(tensor) do max = tensor |> Nx.reduce_max() |> Nx.to_number() min = tensor |> Nx.reduce_min() |> Nx.to_number() - amplitude = max - min + amplitude = abs(max - min) limit = tensor.type |> Nx.Constants.max_finite() |> Nx.to_number() if max + (amplitude + 1) * (Nx.axis_size(tensor, 0) - 1) > limit do @@ -253,9 +272,34 @@ defmodule Scholar.Neighbors.KDTree do """ deftransform root, do: 0 + @doc """ + Returns the parent of child `i`. + + It is your responsibility to guarantee the result is positive. + + ## Examples + + iex> Scholar.Neighbors.KDTree.parent(1) + 0 + iex> Scholar.Neighbors.KDTree.parent(2) + 0 + + iex> Scholar.Neighbors.KDTree.parent(Nx.u32(3)) + #Nx.Tensor< + u32 + 1 + > + + """ + deftransform parent(i) when is_integer(i), do: div(i - 1, 2) + deftransform parent(%Nx.Tensor{} = t), do: Nx.quotient(Nx.subtract(t, 1), 2) + @doc """ Returns the index of the left child of i. + It is your responsibility to guarantee the result + is not greater than the leading axis of the tensor. + ## Examples iex> Scholar.Neighbors.KDTree.left_child(0) @@ -276,6 +320,9 @@ defmodule Scholar.Neighbors.KDTree do @doc """ Returns the index of the right child of i. + It is your responsibility to guarantee the result + is not greater than the leading axis of the tensor. + ## Examples iex> Scholar.Neighbors.KDTree.right_child(0) diff --git a/mix.exs b/mix.exs index 55d286fa..02acd3bb 100644 --- a/mix.exs +++ b/mix.exs @@ -34,7 +34,8 @@ defmodule Scholar.MixProject do {:nx, github: "elixir-nx/nx", sparse: "nx", override: true, branch: "v0.6"}, {:nimble_options, "~> 0.5.2 or ~> 1.0"}, {:exla, "~> 0.6", optional: true}, - {:polaris, "~> 0.1"} + {:polaris, "~> 0.1"}, + {:benchee, "~> 1.0", only: :dev} ] end diff --git a/mix.lock b/mix.lock index 5087bf8d..c17916d2 100644 --- a/mix.lock +++ b/mix.lock @@ -1,5 +1,7 @@ %{ + "benchee": {:hex, :benchee, "1.1.0", "f3a43817209a92a1fade36ef36b86e1052627fd8934a8b937ac9ab3a76c43062", [:mix], [{:deep_merge, "~> 1.0", [hex: :deep_merge, repo: "hexpm", optional: false]}, {:statistex, "~> 1.0", [hex: :statistex, repo: "hexpm", optional: false]}], "hexpm", "7da57d545003165a012b587077f6ba90b89210fd88074ce3c60ce239eb5e6d93"}, "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, + "deep_merge": {:hex, :deep_merge, "1.0.0", "b4aa1a0d1acac393bdf38b2291af38cb1d4a52806cf7a4906f718e1feb5ee961", [:mix], [], "hexpm", "ce708e5f094b9cd4e8f2be4f00d2f4250c4095be93f8cd6d018c753894885430"}, "earmark_parser": {:hex, :earmark_parser, "1.4.37", "2ad73550e27c8946648b06905a57e4d454e4d7229c2dafa72a0348c99d8be5f7", [:mix], [], "hexpm", "6b19783f2802f039806f375610faa22da130b8edc21209d0bff47918bb48360e"}, "elixir_make": {:hex, :elixir_make, "0.7.7", "7128c60c2476019ed978210c245badf08b03dbec4f24d05790ef791da11aa17c", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "5bc19fff950fad52bbe5f211b12db9ec82c6b34a9647da0c2224b8b8464c7e6c"}, "ex_doc": {:hex, :ex_doc, "0.30.6", "5f8b54854b240a2b55c9734c4b1d0dd7bdd41f71a095d42a70445c03cf05a281", [:mix], [{:earmark_parser, "~> 1.4.31", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "bd48f2ddacf4e482c727f9293d9498e0881597eae6ddc3d9562bd7923375109f"}, @@ -9,8 +11,9 @@ "makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"}, "nimble_options": {:hex, :nimble_options, "0.5.2", "42703307b924880f8c08d97719da7472673391905f528259915782bb346e0a1b", [:mix], [], "hexpm", "4da7f904b915fd71db549bcdc25f8d56f378ef7ae07dc1d372cbe72ba950dce0"}, "nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"}, - "nx": {:git, "https://github.com/elixir-nx/nx.git", "e52d9097a52ae39c1ece1dcc2c12ad6456fc0fe2", [sparse: "nx", branch: "v0.6"]}, + "nx": {:git, "https://github.com/elixir-nx/nx.git", "ef464cfd0935eb4c2c1fa9a40f099b098a0b95bf", [sparse: "nx", branch: "v0.6"]}, "polaris": {:hex, :polaris, "0.1.0", "dca61b18e3e801ecdae6ac9f0eca5f19792b44a5cb4b8d63db50fc40fc038d22", [:mix], [{:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "13ef2b166650e533cb24b10e2f3b8ab4f2f449ba4d63156e8c569527f206e2c2"}, + "statistex": {:hex, :statistex, "1.0.0", "f3dc93f3c0c6c92e5f291704cf62b99b553253d7969e9a5fa713e5481cd858a5", [:mix], [], "hexpm", "ff9d8bee7035028ab4742ff52fc80a2aa35cece833cf5319009b52f1b5a86c27"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, "xla": {:hex, :xla, "0.5.0", "fb8a02c02e5a4f4531fbf18a90c325e471037f983f0115d23f510e7dd9a6aa65", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "571ac797a4244b8ba8552ed0295a54397bd896708be51e4da6cbb784f6678061"}, } diff --git a/test/scholar/neighbors/kd_tree_test.exs b/test/scholar/neighbors/kd_tree_test.exs index a8678c66..66ee0f9f 100644 --- a/test/scholar/neighbors/kd_tree_test.exs +++ b/test/scholar/neighbors/kd_tree_test.exs @@ -25,6 +25,15 @@ defmodule Scholar.Neighbors.KDTreeTest do assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] end + test "float" do + assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = + Scholar.Neighbors.KDTree.unbanded(example() |> Nx.as_type(:f32), + compiler: EXLA.Defn + ) + + assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] + end + test "corner cases" do assert Scholar.Neighbors.KDTree.unbanded(Nx.iota({1, 2}), compiler: EXLA.Defn) == %Scholar.Neighbors.KDTree{levels: 1, indexes: Nx.u32([0])} @@ -40,14 +49,14 @@ defmodule Scholar.Neighbors.KDTreeTest do %Scholar.Neighbors.KDTree{levels: 3, indexes: Nx.u32([3, 1, 4, 0, 2])} end - test "sample" do - input = Nx.u32([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + test "float" do + assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = + Scholar.Neighbors.KDTree.banded(example() |> Nx.as_type(:f32), 100) - assert Nx.Defn.jit_apply( - &Scholar.Neighbors.KDTree.banded_segment_begin(&1, 4, 10), - [input] - ) == Nx.u32([0, 1, 7, 3, 6, 8, 9, 7, 8, 9]) + assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] + end + test "sample" do assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = Scholar.Neighbors.KDTree.banded(example(), 100)