Skip to content

Commit

Permalink
move get_batches inside BruteKNN
Browse files Browse the repository at this point in the history
  • Loading branch information
Krsto Proroković committed Apr 13, 2024
1 parent d685da4 commit 96e8ed2
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 21 deletions.
21 changes: 21 additions & 0 deletions lib/scholar/neighbors/brute_knn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,27 @@ defmodule Scholar.Neighbors.BruteKNN do
{neighbor_indices, neighbor_distances}
end

defn get_batches(tensor, opts) do
{size, dim} = Nx.shape(tensor)
batch_size = opts[:batch_size]
num_batches = div(size, batch_size)
leftover_size = rem(size, batch_size)

batches =
tensor
|> Nx.slice_along_axis(0, num_batches * batch_size, axis: 0)
|> Nx.reshape({num_batches, batch_size, dim})

leftover =
if leftover_size > 0 do
Nx.slice_along_axis(tensor, num_batches * batch_size, leftover_size, axis: 0)
else
nil
end

{batches, leftover}
end

defnp brute_force_search(data, query, opts) do
k = opts[:num_neighbors]
metric = opts[:metric]
Expand Down
21 changes: 0 additions & 21 deletions lib/scholar/shared.ex
Original file line number Diff line number Diff line change
Expand Up @@ -89,25 +89,4 @@ defmodule Scholar.Shared do

valid_broadcast(to_parse - 1, n_dims, shape1, shape2)
end

defn get_batches(tensor, opts) do
{size, dim} = Nx.shape(tensor)
batch_size = opts[:batch_size]
num_batches = div(size, batch_size)
leftover_size = rem(size, batch_size)

batches =
tensor
|> Nx.slice_along_axis(0, num_batches * batch_size, axis: 0)
|> Nx.reshape({num_batches, batch_size, dim})

leftover =
if leftover_size > 0 do
Nx.slice_along_axis(tensor, num_batches * batch_size, leftover_size, axis: 0)
else
nil
end

{batches, leftover}
end
end

0 comments on commit 96e8ed2

Please sign in to comment.