Skip to content

Commit

Permalink
Add benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Nov 5, 2023
1 parent fd4d1d9 commit 9a57827
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 16 deletions.
14 changes: 14 additions & 0 deletions benchmarks/kdtree.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Benchmark comparing the two KDTree construction strategies on the same
# randomly generated set of points.
alias Scholar.Neighbors.KDTree

# Route tensor operations and `defn` compilation through EXLA.
Nx.global_default_backend(EXLA.Backend)
Nx.Defn.global_default_options(compiler: EXLA)

# The seed comes from the OS clock, so every run benchmarks different data.
seed = System.os_time()
{points, _next_key} = seed |> Nx.Random.key() |> Nx.Random.uniform(shape: {1000, 3})

benchmarks = %{
  "unbanded" => fn -> KDTree.unbanded(points) end,
  # The second argument of `banded/2` is the amplitude of the data.
  "banded" => fn -> KDTree.banded(points, 2) end
}

Benchee.run(benchmarks, time: 10, memory_time: 2)
63 changes: 55 additions & 8 deletions lib/scholar/neighbors/kd_tree.ex
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,33 @@ defmodule Scholar.Neighbors.KDTree do
* `banded/2` - the tensor has min and max values with an amplitude given by `max - min`.
It is also guaranteed that `amplitude * levels(tensor) + 1` does not overflow
the tensor type. See `amplitude/1` to verify if this holds. This implementation happens
fully within `defn`.
fully within `defn`. This version is orders of magnitude faster than the `unbanded/2`
one.
* `unbanded/2` - there are no known bands (min and max values) to the tensor.
This implementation is recursive and goes in and out of the `defn`, therefore
it cannot be called inside `defn`.
Each level traverses over the last axis of tensor, the index for a level can be
computed as: `rem(level, Nx.axis_size(tensor, -1))`.
## References
* [GPU-friendly, Parallel, and (Almost-)In-Place Construction of Left-Balanced k-d Trees](https://arxiv.org/pdf/2211.00120.pdf).
"""

import Nx.Defn

# TODO: Benchmark

@derive {Nx.Container, keep: [:levels], containers: [:indexes]}
@enforce_keys [:levels, :indexes]
defstruct [:levels, :indexes]

@doc """
Builds a KDTree without known min-max bounds.
If your tensor has a known bound (for example, -1 and 1),
consider using the `banded/2` version which is more efficient.
If your tensor has a known band (for example, -1 and 1),
consider using the `banded/2` version which is often orders of
magnitude more efficient.
## Options
Expand Down Expand Up @@ -128,7 +131,21 @@ defmodule Scholar.Neighbors.KDTree do
defp unbanded_level(i) when is_integer(i), do: 31 - clz32(i + 1)

@doc """
BANDED
Builds a KDTree with known min-max bounds entirely within `defn`.
This requires the amplitude `|max - min|` of the tensor to be given.
For example, a tensor where all values are between 0 and 1 has amplitude
1. Values between -1 and 1 have amplitude 2. If your tensor is normalized,
then you know the amplitude. Otherwise you can use `amplitude/1` to check
it.
## Examples
iex> Scholar.Neighbors.KDTree.banded(Nx.iota({5, 2}), 10)
%Scholar.Neighbors.KDTree{
levels: 3,
indexes: Nx.u32([3, 1, 4, 0, 2])
}
"""
defn banded(tensor, amplitude) do
levels = levels(tensor)
Expand Down Expand Up @@ -180,7 +197,7 @@ defmodule Scholar.Neighbors.KDTree do
shifted - 1 + min(lowest_level, shifted)
end

defn banded_segment_begin(i, levels, size) do
defnp banded_segment_begin(i, levels, size) do
level = banded_level(i)
top = (1 <<< level) - 1
diff = levels - level - 1
Expand Down Expand Up @@ -212,12 +229,14 @@ defmodule Scholar.Neighbors.KDTree do
39.0
iex> Scholar.Neighbors.KDTree.amplitude(Nx.iota({20, 2}, type: :u8))
-1
iex> Scholar.Neighbors.KDTree.amplitude(Nx.negate(Nx.iota({10, 2})))
19
"""
def amplitude(tensor) do
max = tensor |> Nx.reduce_max() |> Nx.to_number()
min = tensor |> Nx.reduce_min() |> Nx.to_number()
amplitude = max - min
amplitude = abs(max - min)
limit = tensor.type |> Nx.Constants.max_finite() |> Nx.to_number()

if max + (amplitude + 1) * (Nx.axis_size(tensor, 0) - 1) > limit do
Expand Down Expand Up @@ -253,9 +272,34 @@ defmodule Scholar.Neighbors.KDTree do
"""
deftransform root, do: 0

@doc """
Returns the index of the parent of node `i`.

`i` may be an integer or a tensor. The root (index `0`) has no parent,
so it is your responsibility to guarantee the result is a valid,
non-negative index.

## Examples

    iex> Scholar.Neighbors.KDTree.parent(1)
    0
    iex> Scholar.Neighbors.KDTree.parent(2)
    0
    iex> Scholar.Neighbors.KDTree.parent(Nx.u32(3))
    #Nx.Tensor<
      u32
      1
    >

"""
deftransform parent(i) when is_integer(i) do
  div(i - 1, 2)
end

deftransform parent(%Nx.Tensor{} = tensor) do
  tensor |> Nx.subtract(1) |> Nx.quotient(2)
end

@doc """
Returns the index of the left child of i.
It is your responsibility to guarantee the result
is less than the size of the leading axis of the tensor.
## Examples
iex> Scholar.Neighbors.KDTree.left_child(0)
Expand All @@ -276,6 +320,9 @@ defmodule Scholar.Neighbors.KDTree do
@doc """
Returns the index of the right child of i.
It is your responsibility to guarantee the result
is less than the size of the leading axis of the tensor.
## Examples
iex> Scholar.Neighbors.KDTree.right_child(0)
Expand Down
3 changes: 2 additions & 1 deletion mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ defmodule Scholar.MixProject do
{:nx, github: "elixir-nx/nx", sparse: "nx", override: true, branch: "v0.6"},
{:nimble_options, "~> 0.5.2 or ~> 1.0"},
{:exla, "~> 0.6", optional: true},
{:polaris, "~> 0.1"}
{:polaris, "~> 0.1"},
{:benchee, "~> 1.0", only: :dev}
]
end

Expand Down
5 changes: 4 additions & 1 deletion mix.lock
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
%{
"benchee": {:hex, :benchee, "1.1.0", "f3a43817209a92a1fade36ef36b86e1052627fd8934a8b937ac9ab3a76c43062", [:mix], [{:deep_merge, "~> 1.0", [hex: :deep_merge, repo: "hexpm", optional: false]}, {:statistex, "~> 1.0", [hex: :statistex, repo: "hexpm", optional: false]}], "hexpm", "7da57d545003165a012b587077f6ba90b89210fd88074ce3c60ce239eb5e6d93"},
"complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"},
"deep_merge": {:hex, :deep_merge, "1.0.0", "b4aa1a0d1acac393bdf38b2291af38cb1d4a52806cf7a4906f718e1feb5ee961", [:mix], [], "hexpm", "ce708e5f094b9cd4e8f2be4f00d2f4250c4095be93f8cd6d018c753894885430"},
"earmark_parser": {:hex, :earmark_parser, "1.4.37", "2ad73550e27c8946648b06905a57e4d454e4d7229c2dafa72a0348c99d8be5f7", [:mix], [], "hexpm", "6b19783f2802f039806f375610faa22da130b8edc21209d0bff47918bb48360e"},
"elixir_make": {:hex, :elixir_make, "0.7.7", "7128c60c2476019ed978210c245badf08b03dbec4f24d05790ef791da11aa17c", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "5bc19fff950fad52bbe5f211b12db9ec82c6b34a9647da0c2224b8b8464c7e6c"},
"ex_doc": {:hex, :ex_doc, "0.30.6", "5f8b54854b240a2b55c9734c4b1d0dd7bdd41f71a095d42a70445c03cf05a281", [:mix], [{:earmark_parser, "~> 1.4.31", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "bd48f2ddacf4e482c727f9293d9498e0881597eae6ddc3d9562bd7923375109f"},
Expand All @@ -9,8 +11,9 @@
"makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"},
"nimble_options": {:hex, :nimble_options, "0.5.2", "42703307b924880f8c08d97719da7472673391905f528259915782bb346e0a1b", [:mix], [], "hexpm", "4da7f904b915fd71db549bcdc25f8d56f378ef7ae07dc1d372cbe72ba950dce0"},
"nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"},
"nx": {:git, "https://github.com/elixir-nx/nx.git", "e52d9097a52ae39c1ece1dcc2c12ad6456fc0fe2", [sparse: "nx", branch: "v0.6"]},
"nx": {:git, "https://github.com/elixir-nx/nx.git", "ef464cfd0935eb4c2c1fa9a40f099b098a0b95bf", [sparse: "nx", branch: "v0.6"]},
"polaris": {:hex, :polaris, "0.1.0", "dca61b18e3e801ecdae6ac9f0eca5f19792b44a5cb4b8d63db50fc40fc038d22", [:mix], [{:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "13ef2b166650e533cb24b10e2f3b8ab4f2f449ba4d63156e8c569527f206e2c2"},
"statistex": {:hex, :statistex, "1.0.0", "f3dc93f3c0c6c92e5f291704cf62b99b553253d7969e9a5fa713e5481cd858a5", [:mix], [], "hexpm", "ff9d8bee7035028ab4742ff52fc80a2aa35cece833cf5319009b52f1b5a86c27"},
"telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
"xla": {:hex, :xla, "0.5.0", "fb8a02c02e5a4f4531fbf18a90c325e471037f983f0115d23f510e7dd9a6aa65", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "571ac797a4244b8ba8552ed0295a54397bd896708be51e4da6cbb784f6678061"},
}
21 changes: 15 additions & 6 deletions test/scholar/neighbors/kd_tree_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ defmodule Scholar.Neighbors.KDTreeTest do
assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4]
end

test "float" do

Check failure on line 28 in test/scholar/neighbors/kd_tree_test.exs

View workflow job for this annotation

GitHub Actions / main (1.14.5, 25.3)

test unbanded float (Scholar.Neighbors.KDTreeTest)
assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} =
Scholar.Neighbors.KDTree.unbanded(example() |> Nx.as_type(:f32),
compiler: EXLA.Defn
)

assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4]
end

test "corner cases" do
assert Scholar.Neighbors.KDTree.unbanded(Nx.iota({1, 2}), compiler: EXLA.Defn) ==
%Scholar.Neighbors.KDTree{levels: 1, indexes: Nx.u32([0])}
Expand All @@ -40,14 +49,14 @@ defmodule Scholar.Neighbors.KDTreeTest do
%Scholar.Neighbors.KDTree{levels: 3, indexes: Nx.u32([3, 1, 4, 0, 2])}
end

test "sample" do
input = Nx.u32([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
test "float" do
assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} =
Scholar.Neighbors.KDTree.banded(example() |> Nx.as_type(:f32), 100)

assert Nx.Defn.jit_apply(
&Scholar.Neighbors.KDTree.banded_segment_begin(&1, 4, 10),
[input]
) == Nx.u32([0, 1, 7, 3, 6, 8, 9, 7, 8, 9])
assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4]
end

test "sample" do
assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} =
Scholar.Neighbors.KDTree.banded(example(), 100)

Expand Down

0 comments on commit 9a57827

Please sign in to comment.