Skip to content

Commit

Permalink
Move iota inside loop
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Dec 31, 2023
1 parent 64d1840 commit 0e99c60
Showing 1 changed file with 3 additions and 13 deletions.
16 changes: 3 additions & 13 deletions lib/scholar/neighbors/random_projection_forest.ex
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ defmodule Scholar.Neighbors.RandomProjectionForest do
pos = Nx.iota({size}, type: :u32),
cell_sizes = Nx.broadcast(Nx.u32(size), {size}),
tags = Nx.broadcast(Nx.u32(0), {size}),
nodes = Nx.iota({num_nodes}, type: :u32),
width = Nx.u32(1),
median_offset = Nx.u32(0)
}
Expand All @@ -175,7 +174,6 @@ defmodule Scholar.Neighbors.RandomProjectionForest do
left_sizes,
right_sizes,
level_proj,
nodes,
width,
median_offset,
medians
Expand All @@ -186,24 +184,15 @@ defmodule Scholar.Neighbors.RandomProjectionForest do
{
indices,
medians,
{tensor, hyperplanes, level + 1, pos, cell_sizes, tags, nodes, 2 * width,
{tensor, hyperplanes, level + 1, pos, cell_sizes, tags, 2 * width,
2 * median_offset + 1}
}
end

{indices, hyperplanes, medians}
end

defnp update_medians(
pos,
left_sizes,
right_sizes,
level_proj,
nodes,
width,
median_offset,
medians
) do
defnp update_medians(pos, left_sizes, right_sizes, level_proj, width, median_offset, medians) do
size = Nx.size(pos)
{num_trees, num_nodes} = Nx.shape(medians)

Expand All @@ -225,6 +214,7 @@ defmodule Scholar.Neighbors.RandomProjectionForest do

right_first = Nx.take_along_axis(level_proj, right_indices, axis: 1)

nodes = Nx.iota({num_nodes}, type: :u32)
medians_first = (left_first + right_first) / 2

median_mask = width <= nodes and nodes < width + median_offset
Expand Down

0 comments on commit 0e99c60

Please sign in to comment.