From 96e8ed2491a32bd9fa3477829c7287291c1d90a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krsto=20Prorokovi=C4=87?= Date: Sat, 13 Apr 2024 22:12:00 +0200 Subject: [PATCH] move get_batches inside BruteKNN --- lib/scholar/neighbors/brute_knn.ex | 21 +++++++++++++++++++++ lib/scholar/shared.ex | 21 --------------------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/lib/scholar/neighbors/brute_knn.ex b/lib/scholar/neighbors/brute_knn.ex index 50283433..4fae09b9 100644 --- a/lib/scholar/neighbors/brute_knn.ex +++ b/lib/scholar/neighbors/brute_knn.ex @@ -212,6 +212,27 @@ defmodule Scholar.Neighbors.BruteKNN do {neighbor_indices, neighbor_distances} end + defn get_batches(tensor, opts) do + {size, dim} = Nx.shape(tensor) + batch_size = opts[:batch_size] + num_batches = div(size, batch_size) + leftover_size = rem(size, batch_size) + + batches = + tensor + |> Nx.slice_along_axis(0, num_batches * batch_size, axis: 0) + |> Nx.reshape({num_batches, batch_size, dim}) + + leftover = + if leftover_size > 0 do + Nx.slice_along_axis(tensor, num_batches * batch_size, leftover_size, axis: 0) + else + nil + end + + {batches, leftover} + end + defnp brute_force_search(data, query, opts) do k = opts[:num_neighbors] metric = opts[:metric] diff --git a/lib/scholar/shared.ex b/lib/scholar/shared.ex index 21f439ad..61330283 100644 --- a/lib/scholar/shared.ex +++ b/lib/scholar/shared.ex @@ -89,25 +89,4 @@ defmodule Scholar.Shared do valid_broadcast(to_parse - 1, n_dims, shape1, shape2) end - - defn get_batches(tensor, opts) do - {size, dim} = Nx.shape(tensor) - batch_size = opts[:batch_size] - num_batches = div(size, batch_size) - leftover_size = rem(size, batch_size) - - batches = - tensor - |> Nx.slice_along_axis(0, num_batches * batch_size, axis: 0) - |> Nx.reshape({num_batches, batch_size, dim}) - - leftover = - if leftover_size > 0 do - Nx.slice_along_axis(tensor, num_batches * batch_size, leftover_size, axis: 0) - else - nil - end - - {batches, leftover} - end end