Skip to content

Commit

Permalink
Fix correct call to brute force in generate groundtruth of cuvs-bench (
Browse files Browse the repository at this point in the history
…#427)

Fixes issue with helper script for generating ground truthset in cuvs-bench, which was using the old RAFT NN API.

Authors:
  - Dante Gama Dessavre (https://github.com/dantegd)

Approvers:
  - Divye Gala (https://github.com/divyegala)
  - Ben Frederickson (https://github.com/benfred)

URL: #427
  • Loading branch information
dantegd authored Oct 28, 2024
1 parent e7f1085 commit 12b10e8
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions python/cuvs_bench/cuvs_bench/generate_groundtruth/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from pylibraft.common import DeviceResources
from rmm.allocators.cupy import rmm_cupy_allocator

from cuvs.neighbors.brute_force import knn
from cuvs.neighbors.brute_force import build, search

from .utils import memmap_bin_file, suffix_from_dtype, write_bin

Expand All @@ -49,7 +49,7 @@ def choose_random_queries(dataset, n_queries):


def calc_truth(dataset, queries, k, metric="sqeuclidean"):
handle = DeviceResources()
resources = DeviceResources()
n_samples = dataset.shape[0]
n = 500000 # batch size for processing neighbors
i = 0
Expand All @@ -63,8 +63,9 @@ def calc_truth(dataset, queries, k, metric="sqeuclidean"):

X = cp.asarray(dataset[i : i + n_batch, :], cp.float32)

D, Ind = knn(X, queries, k, metric=metric, handle=handle)
handle.sync()
index = build(X, metric=metric, resources=resources)
D, Ind = search(index, queries, k, resources=resources)
resources.sync()

D, Ind = cp.asarray(D), cp.asarray(Ind)
Ind += i # shift neighbor index by offset i
Expand Down

0 comments on commit 12b10e8

Please sign in to comment.