Skip to content

Commit

Permalink
Move sorting inside the loop, as it is faster and uses less memory
Browse files Browse the repository at this point in the history
  • Loading branch information
Krsto Proroković committed Dec 22, 2023
1 parent bb625c5 commit 223e150
Showing 1 changed file with 7 additions and 25 deletions.
32 changes: 7 additions & 25 deletions lib/scholar/neighbors/random_projection_forest.ex
Original file line number Diff line number Diff line change
Expand Up @@ -135,16 +135,13 @@ defmodule Scholar.Neighbors.RandomProjectionForest do
{hyperplanes, _key} =
Nx.Random.normal(key, type: type, shape: {num_trees, depth, dim})

proj = Nx.dot(hyperplanes, [2], tensor, [1])
sorted_indices = Nx.argsort(proj, axis: 2, stable: true, type: :u32)

{indices, medians, _} =
while {
indices = Nx.iota({num_trees, size}, axis: 1, type: :u32),
medians = Nx.broadcast(Nx.tensor(:nan, type: type), {num_trees, num_nodes}),
{
proj,
sorted_indices,
tensor,
hyperplanes,
level = Nx.u32(0),
pos = Nx.iota({size}, type: :u32),
cell_sizes = Nx.broadcast(Nx.u32(size), {size}),
Expand All @@ -155,13 +152,11 @@ defmodule Scholar.Neighbors.RandomProjectionForest do
}
},
level < depth do
level_proj = proj[[.., level]] |> Nx.take_along_axis(indices, axis: 1)

level_indices =
indices
|> inverse_permutation()
|> Nx.take_along_axis(sorted_indices[[.., level]], axis: 1)
level_proj =
Nx.dot(hyperplanes[[.., level]], [1], tensor, [1])
|> Nx.take_along_axis(indices, axis: 1)

level_indices = Nx.argsort(level_proj, axis: 1, type: :u32, stable: true)
orders = Nx.argsort(tags[level_indices], axis: 1, stable: true, type: :u32)
level_indices = Nx.take_along_axis(level_indices, orders, axis: 1)
indices = Nx.take_along_axis(indices, level_indices, axis: 1)
Expand Down Expand Up @@ -189,27 +184,14 @@ defmodule Scholar.Neighbors.RandomProjectionForest do
{
indices,
medians,
{proj, sorted_indices, level + 1, pos, cell_sizes, tags, nodes, 2 * width,
{tensor, hyperplanes, level + 1, pos, cell_sizes, tags, nodes, 2 * width,
2 * median_offset + 1}
}
end

{indices, hyperplanes, medians}
end

# Row-wise scatter: for every tree `t` and column `j`, writes `j` into
# position `{t, indices[t][j]}` of a zero-initialised `{num_trees, size}` tensor.
# When each row of `indices` is a permutation of `0..size-1`, the result is the
# inverse permutation of that row (assumption implied by the name — verify at
# the call site; repeated indices would be summed by `Nx.indexed_add/3`).
defnp inverse_permutation(indices) do
  {num_trees, size} = Nx.shape(indices)
  flat_len = num_trees * size

  # First coordinate of every scatter target: the tree (row) number.
  tree_ids = Nx.iota({num_trees, size, 1}, axis: 0)

  # Second coordinate: the value stored in `indices`; flatten to a list of {tree, column} pairs.
  coords =
    [tree_ids, Nx.new_axis(indices, 2)]
    |> Nx.concatenate(axis: 2)
    |> Nx.reshape({flat_len, 2})

  # Value written at each coordinate: the column the entry originally occupied.
  positions = Nx.reshape(Nx.iota({num_trees, size}, axis: 1), {flat_len})

  zeros = Nx.broadcast(Nx.u32(0), {num_trees, size})
  Nx.indexed_add(zeros, coords, positions)
end

defnp update_medians(
pos,
left_sizes,
Expand Down

0 comments on commit 223e150

Please sign in to comment.