Skip to content

Commit

Permalink
Replace second order argsort with permutation inverse (elixir-nx#223)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko authored Dec 18, 2023
1 parent 2a0ef38 commit 6c60968
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion lib/scholar/neighbors/kd_tree.ex
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ defmodule Scholar.Neighbors.KDTree do
end

defnp update_tags(tags, indices, level, levels, size) do
pos = Nx.argsort(indices, type: :u32)
pos = inverse_permutation(indices)

pivot =
bounded_segment_begin(tags, levels, size) +
Expand All @@ -103,6 +103,17 @@ defmodule Scholar.Neighbors.KDTree do
)
end

defnp inverse_permutation(indices) do
shape = Nx.shape(indices)
type = Nx.type(indices)

Nx.indexed_put(
Nx.broadcast(Nx.tensor(0, type: type), shape),
Nx.new_axis(indices, -1),
Nx.iota(shape, type: type)
)
end

defnp bounded_subtree_size(i, levels, size) do
diff = levels - bounded_level(i) - 1
shifted = 1 <<< diff
Expand Down

0 comments on commit 6c60968

Please sign in to comment.