From 12b10e88e8ea6e944e91dee8a0380c89999b3b21 Mon Sep 17 00:00:00 2001 From: Dante Gama Dessavre Date: Mon, 28 Oct 2024 11:56:53 -0500 Subject: [PATCH] Fix correct call to brute force in generate groundtruth of cuvs-bench (#427) Fixes issue with helper script for generating ground truthset in cuvs-bench, which was using the old RAFT NN API. Authors: - Dante Gama Dessavre (https://github.com/dantegd) Approvers: - Divye Gala (https://github.com/divyegala) - Ben Frederickson (https://github.com/benfred) URL: https://github.com/rapidsai/cuvs/pull/427 --- .../cuvs_bench/generate_groundtruth/__main__.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/cuvs_bench/cuvs_bench/generate_groundtruth/__main__.py b/python/cuvs_bench/cuvs_bench/generate_groundtruth/__main__.py index 2b4213016..dbee6cd36 100644 --- a/python/cuvs_bench/cuvs_bench/generate_groundtruth/__main__.py +++ b/python/cuvs_bench/cuvs_bench/generate_groundtruth/__main__.py @@ -24,7 +24,7 @@ from pylibraft.common import DeviceResources from rmm.allocators.cupy import rmm_cupy_allocator -from cuvs.neighbors.brute_force import knn +from cuvs.neighbors.brute_force import build, search from .utils import memmap_bin_file, suffix_from_dtype, write_bin @@ -49,7 +49,7 @@ def choose_random_queries(dataset, n_queries): def calc_truth(dataset, queries, k, metric="sqeuclidean"): - handle = DeviceResources() + resources = DeviceResources() n_samples = dataset.shape[0] n = 500000 # batch size for processing neighbors i = 0 @@ -63,8 +63,9 @@ def calc_truth(dataset, queries, k, metric="sqeuclidean"): X = cp.asarray(dataset[i : i + n_batch, :], cp.float32) - D, Ind = knn(X, queries, k, metric=metric, handle=handle) - handle.sync() + index = build(X, metric=metric, resources=resources) + D, Ind = search(index, queries, k, resources=resources) + resources.sync() D, Ind = cp.asarray(D), cp.asarray(Ind) Ind += i # shift neighbor index by offset i