Skip to content

Commit

Permalink
Vectorized knn query (elixir-nx#214)
Browse files Browse the repository at this point in the history
  • Loading branch information
msluszniak authored Dec 11, 2023
1 parent 8826219 commit 88744e3
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 19 deletions.
42 changes: 25 additions & 17 deletions lib/scholar/neighbors/kd_tree.ex
Original file line number Diff line number Diff line change
Expand Up @@ -407,19 +407,7 @@ defmodule Scholar.Neighbors.KDTree do
end

defnp predict_n(tree, data, opts) do
k = opts[:k]
num_samples = Nx.axis_size(data, 0)
knn = Nx.broadcast(Nx.s64(0), {num_samples, k})

{knn, _} =
while {knn, {tree, data, i = Nx.s64(0)}}, i < num_samples do
curr_point = data[[i]]
k_neighbors = query_one_point(curr_point, tree, opts)
knn = Nx.put_slice(knn, [i, 0], Nx.new_axis(k_neighbors, 0))
{knn, {tree, data, i + 1}}
end

knn
query_points(data, tree, opts)
end

defnp sort_by_distances(distances, point_indices) do
Expand Down Expand Up @@ -462,9 +450,18 @@ defmodule Scholar.Neighbors.KDTree do
end
end

defnp query_one_point(point, tree, opts) do
defnp query_points(point, tree, opts) do
k = opts[:k]
node = Nx.as_type(root(), :s64)

input_vectorized_axes = point.vectorized_axes
num_points = Nx.axis_size(point, 0)

point =
Nx.revectorize(point, [collapsed_axes: :auto, x: Nx.axis_size(point, -2)],
target_shape: {Nx.axis_size(point, -1)}
)

{size, dims} = Nx.shape(tree.data)
nearest_neighbors = Nx.broadcast(Nx.s64(0), {k})
distances = Nx.broadcast(Nx.Constants.infinity(), {k})
Expand All @@ -476,10 +473,21 @@ defmodule Scholar.Neighbors.KDTree do
down = 0
up = 1
mode = down
i = Nx.s64(0)

[nearest_neighbors, node, distances, visited, i, mode, point] =
Nx.broadcast_vectors([
nearest_neighbors,
node,
distances,
visited,
i,
mode,
point
])

{nearest_neighbors, _} =
while {nearest_neighbors,
{node, data, indices, point, distances, visited, i = Nx.s64(0), mode}},
while {nearest_neighbors, {node, data, indices, point, distances, visited, i, mode}},
node != -1 and i >= 0 do
coord_indicator = rem(i, dims)

Expand Down Expand Up @@ -593,6 +601,6 @@ defmodule Scholar.Neighbors.KDTree do
{nearest_neighbors, {node, data, indices, point, distances, visited, i, mode}}
end

nearest_neighbors
Nx.revectorize(nearest_neighbors, input_vectorized_axes, target_shape: {num_points, k})
end
end
3 changes: 2 additions & 1 deletion mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ defmodule Scholar.MixProject do
defp deps do
[
{:ex_doc, "~> 0.30", only: :docs},
{:nx, "~> 0.6.4 or ~> 0.7", override: true},
# {:nx, "~> 0.6.4 or ~> 0.7", override: true},
{:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
{:nimble_options, "~> 0.5.2 or ~> 1.0"},
{:exla, "~> 0.6.3 or ~> 0.7", optional: true},
{:polaris, "~> 0.1"},
Expand Down
2 changes: 1 addition & 1 deletion mix.lock
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"},
"nimble_options": {:hex, :nimble_options, "0.5.2", "42703307b924880f8c08d97719da7472673391905f528259915782bb346e0a1b", [:mix], [], "hexpm", "4da7f904b915fd71db549bcdc25f8d56f378ef7ae07dc1d372cbe72ba950dce0"},
"nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"},
"nx": {:hex, :nx, "0.6.4", "948d9f42f81e63fc901d243ac0a985c8bb87358be62e27826cfd67f58bc640af", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "bb9c2e2e3545b5eb4739d69046a988daaa212d127dba7d97801c291616aff6d6"},
"nx": {:git, "https://github.com/elixir-nx/nx.git", "ad0cd2e9c71b379ccbe9317349f0ad0560f5de29", [sparse: "nx"]},
"polaris": {:hex, :polaris, "0.1.0", "dca61b18e3e801ecdae6ac9f0eca5f19792b44a5cb4b8d63db50fc40fc038d22", [:mix], [{:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "13ef2b166650e533cb24b10e2f3b8ab4f2f449ba4d63156e8c569527f206e2c2"},
"statistex": {:hex, :statistex, "1.0.0", "f3dc93f3c0c6c92e5f291704cf62b99b553253d7969e9a5fa713e5481cd858a5", [:mix], [], "hexpm", "ff9d8bee7035028ab4742ff52fc80a2aa35cece833cf5319009b52f1b5a86c27"},
"telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
Expand Down

0 comments on commit 88744e3

Please sign in to comment.