Skip to content

Commit

Permalink
[FEA] expose python & C API for prefiltered brute force (#174)
Browse files Browse the repository at this point in the history
Authors:
  - rhdong (https://github.com/rhdong)
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Ben Frederickson (https://github.com/benfred)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #174
  • Loading branch information
rhdong authored Jul 29, 2024
1 parent 98c07f9 commit 2517826
Show file tree
Hide file tree
Showing 20 changed files with 865 additions and 111 deletions.
12 changes: 10 additions & 2 deletions cpp/include/cuvs/neighbors/brute_force.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <cuvs/core/c_api.h>
#include <cuvs/distance/distance.h>
#include <cuvs/neighbors/common.h>
#include <dlpack/dlpack.h>
#include <stdint.h>

Expand Down Expand Up @@ -135,9 +136,13 @@ cuvsError_t cuvsBruteForceBuild(cuvsResources_t res,
* DLManagedTensor dataset;
* DLManagedTensor queries;
* DLManagedTensor neighbors;
* DLManagedTensor bitmap;
*
* cuvsFilter prefilter{(uintptr_t)&bitmap, BITMAP};
*
* // Search the `index` built using `cuvsBruteForceBuild`
* cuvsError_t search_status = cuvsBruteForceSearch(res, index, &queries, &neighbors, &distances);
* cuvsError_t search_status = cuvsBruteForceSearch(res, index, &queries, &neighbors, &distances,
* prefilter);
*
* // de-allocate `res`
* cuvsError_t res_destroy_status = cuvsResourcesDestroy(res);
Expand All @@ -148,12 +153,15 @@ cuvsError_t cuvsBruteForceBuild(cuvsResources_t res,
* @param[in] queries DLManagedTensor* queries dataset to search
* @param[out] neighbors DLManagedTensor* output `k` neighbors for queries
* @param[out] distances DLManagedTensor* output `k` distances for queries
* @param[in] prefilter cuvsFilter input prefilter that can be used
to filter queries and neighbors based on the given bitmap.
*/
cuvsError_t cuvsBruteForceSearch(cuvsResources_t res,
cuvsBruteForceIndex_t index,
DLManagedTensor* queries,
DLManagedTensor* neighbors,
DLManagedTensor* distances);
DLManagedTensor* distances,
cuvsFilter prefilter);
/**
* @}
*/
Expand Down
9 changes: 5 additions & 4 deletions cpp/include/cuvs/neighbors/brute_force.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ struct index : cuvs::neighbors::index {
* @param[in] metric cuvs::distance::DistanceType
* @param[in] metric_arg metric argument
*
* @return the constructed bruteforce index
* @return the constructed brute-force index
*/
auto build(raft::resources const& handle,
raft::device_matrix_view<const float, int64_t, raft::row_major> dataset,
Expand Down Expand Up @@ -221,13 +221,14 @@ auto build(raft::resources const& handle,
* @endcode
*
* @param[in] handle
* @param[in] index bruteforce constructed index
* @param[in] index brute-force constructed index
* @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k]
* @param[in] sample_filter an optional device bitmap filter function that greenlights samples for a
* given query
* @param[in] sample_filter An optional device bitmap filter function with a `row-major` layout and
* the shape of [n_queries, index->size()], which means the filter will use the first
* `index->size()` bits to indicate whether queries[0] should compute the distance with dataset.
*/
void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::index<float>& index,
Expand Down
61 changes: 61 additions & 0 deletions cpp/include/cuvs/neighbors/common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cuvs/core/c_api.h>
#include <cuvs/distance/distance.h>
#include <dlpack/dlpack.h>
#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

/**
* @defgroup filters Filters APIs
* @brief APIs related to filter functionality.
* @{
*/

/**
* @brief Enum to denote filter type.
*/
enum cuvsFilterType {
/* No filter */
NO_FILTER,
/* Filter an index with a bitset */
BITSET,
/* Filter an index with a bitmap */
BITMAP
};

/**
* @brief Struct to hold address of cuvs::neighbor::prefilter and its type
*
*/
typedef struct {
uintptr_t addr;
enum cuvsFilterType type;
} cuvsFilter;

/**
* @}
*/

#ifdef __cplusplus
}
#endif
34 changes: 27 additions & 7 deletions cpp/src/neighbors/brute_force_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <cuvs/core/interop.hpp>
#include <cuvs/neighbors/brute_force.h>
#include <cuvs/neighbors/brute_force.hpp>
#include <cuvs/neighbors/common.h>

namespace {

Expand All @@ -53,20 +54,38 @@ void _search(cuvsResources_t res,
cuvsBruteForceIndex index,
DLManagedTensor* queries_tensor,
DLManagedTensor* neighbors_tensor,
DLManagedTensor* distances_tensor)
DLManagedTensor* distances_tensor,
cuvsFilter prefilter)
{
auto res_ptr = reinterpret_cast<raft::resources*>(res);
auto index_ptr = reinterpret_cast<cuvs::neighbors::brute_force::index<T>*>(index.addr);

using queries_mdspan_type = raft::device_matrix_view<T const, int64_t, raft::row_major>;
using neighbors_mdspan_type = raft::device_matrix_view<int64_t, int64_t, raft::row_major>;
using distances_mdspan_type = raft::device_matrix_view<float, int64_t, raft::row_major>;
auto queries_mds = cuvs::core::from_dlpack<queries_mdspan_type>(queries_tensor);
auto neighbors_mds = cuvs::core::from_dlpack<neighbors_mdspan_type>(neighbors_tensor);
auto distances_mds = cuvs::core::from_dlpack<distances_mdspan_type>(distances_tensor);
using prefilter_mds_type = raft::device_vector_view<const uint32_t, int64_t>;
using prefilter_opt_type = cuvs::core::bitmap_view<const uint32_t, int64_t>;

auto queries_mds = cuvs::core::from_dlpack<queries_mdspan_type>(queries_tensor);
auto neighbors_mds = cuvs::core::from_dlpack<neighbors_mdspan_type>(neighbors_tensor);
auto distances_mds = cuvs::core::from_dlpack<distances_mdspan_type>(distances_tensor);

std::optional<cuvs::core::bitmap_view<const uint32_t, int64_t>> filter_opt;

if (prefilter.type == NO_FILTER) {
filter_opt = std::nullopt;
} else {
auto prefilter_ptr = reinterpret_cast<DLManagedTensor*>(prefilter.addr);
auto prefilter_mds = cuvs::core::from_dlpack<prefilter_mds_type>(prefilter_ptr);
auto prefilter_view = prefilter_opt_type((const uint32_t*)prefilter_mds.data_handle(),
queries_mds.extent(0),
index_ptr->dataset().extent(0));

filter_opt = std::make_optional<prefilter_opt_type>(prefilter_view);
}

cuvs::neighbors::brute_force::search(
*res_ptr, *index_ptr, queries_mds, neighbors_mds, distances_mds, std::nullopt);
*res_ptr, *index_ptr, queries_mds, neighbors_mds, distances_mds, filter_opt);
}

} // namespace
Expand Down Expand Up @@ -120,7 +139,8 @@ extern "C" cuvsError_t cuvsBruteForceSearch(cuvsResources_t res,
cuvsBruteForceIndex_t index_c_ptr,
DLManagedTensor* queries_tensor,
DLManagedTensor* neighbors_tensor,
DLManagedTensor* distances_tensor)
DLManagedTensor* distances_tensor,
cuvsFilter prefilter)
{
return cuvs::core::translate_exceptions([=] {
auto queries = queries_tensor->dl_tensor;
Expand All @@ -143,7 +163,7 @@ extern "C" cuvsError_t cuvsBruteForceSearch(cuvsResources_t res,
RAFT_EXPECTS(queries.dtype.code == index.dtype.code, "type mismatch between index and queries");

if (queries.dtype.code == kDLFloat && queries.dtype.bits == 32) {
_search<float>(res, index, queries_tensor, neighbors_tensor, distances_tensor);
_search<float>(res, index, queries_tensor, neighbors_tensor, distances_tensor, prefilter);
} else {
RAFT_FAIL("Unsupported queries DLtensor dtype: %d and bits: %d",
queries.dtype.code,
Expand Down
29 changes: 12 additions & 17 deletions cpp/test/neighbors/ann_cagra_c.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ TEST(CagraC, BuildSearch)
// create cuvsResources_t
cuvsResources_t res;
cuvsResourcesCreate(&res);
cudaStream_t stream;
cuvsStreamGet(res, &stream);

// create dataset DLTensor
DLManagedTensor dataset_tensor;
Expand All @@ -65,12 +67,11 @@ TEST(CagraC, BuildSearch)
cuvsCagraBuild(res, build_params, &dataset_tensor, index);

// create queries DLTensor
float* queries_d;
cudaMalloc(&queries_d, sizeof(float) * 4 * 2);
cudaMemcpy(queries_d, queries, sizeof(float) * 4 * 2, cudaMemcpyDefault);
rmm::device_uvector<float> queries_d(4 * 2, stream);
raft::copy(queries_d.data(), (float*)queries, 4 * 2, stream);

DLManagedTensor queries_tensor;
queries_tensor.dl_tensor.data = queries_d;
queries_tensor.dl_tensor.data = queries_d.data();
queries_tensor.dl_tensor.device.device_type = kDLCUDA;
queries_tensor.dl_tensor.ndim = 2;
queries_tensor.dl_tensor.dtype.code = kDLFloat;
Expand All @@ -81,11 +82,10 @@ TEST(CagraC, BuildSearch)
queries_tensor.dl_tensor.strides = nullptr;

// create neighbors DLTensor
uint32_t* neighbors_d;
cudaMalloc(&neighbors_d, sizeof(uint32_t) * 4);
rmm::device_uvector<uint32_t> neighbors_d(4, stream);

DLManagedTensor neighbors_tensor;
neighbors_tensor.dl_tensor.data = neighbors_d;
neighbors_tensor.dl_tensor.data = neighbors_d.data();
neighbors_tensor.dl_tensor.device.device_type = kDLCUDA;
neighbors_tensor.dl_tensor.ndim = 2;
neighbors_tensor.dl_tensor.dtype.code = kDLUInt;
Expand All @@ -96,11 +96,10 @@ TEST(CagraC, BuildSearch)
neighbors_tensor.dl_tensor.strides = nullptr;

// create distances DLTensor
float* distances_d;
cudaMalloc(&distances_d, sizeof(float) * 4);
rmm::device_uvector<float> distances_d(4, stream);

DLManagedTensor distances_tensor;
distances_tensor.dl_tensor.data = distances_d;
distances_tensor.dl_tensor.data = distances_d.data();
distances_tensor.dl_tensor.device.device_type = kDLCUDA;
distances_tensor.dl_tensor.ndim = 2;
distances_tensor.dl_tensor.dtype.code = kDLFloat;
Expand All @@ -116,14 +115,10 @@ TEST(CagraC, BuildSearch)
cuvsCagraSearch(res, search_params, index, &queries_tensor, &neighbors_tensor, &distances_tensor);

// verify output
ASSERT_TRUE(cuvs::devArrMatchHost(neighbors_exp, neighbors_d, 4, cuvs::Compare<uint32_t>()));
ASSERT_TRUE(
cuvs::devArrMatchHost(distances_exp, distances_d, 4, cuvs::CompareApprox<float>(0.001f)));

// delete device memory
cudaFree(queries_d);
cudaFree(neighbors_d);
cudaFree(distances_d);
cuvs::devArrMatchHost(neighbors_exp, neighbors_d.data(), 4, cuvs::Compare<uint32_t>()));
ASSERT_TRUE(cuvs::devArrMatchHost(
distances_exp, distances_d.data(), 4, cuvs::CompareApprox<float>(0.001f)));

// de-allocate index and res
cuvsCagraSearchParamsDestroy(search_params);
Expand Down
39 changes: 17 additions & 22 deletions cpp/test/neighbors/ann_ivf_flat_c.cu
Original file line number Diff line number Diff line change
Expand Up @@ -101,47 +101,42 @@ TEST(IvfFlatC, BuildSearch)
int64_t n_dim = 32;
uint32_t n_neighbors = 8;

raft::handle_t handle;
auto stream = raft::resource::get_cuda_stream(handle);

cuvsDistanceType metric = L2Expanded;
size_t n_probes = 20;
size_t n_lists = 1024;

float *index_data, *query_data, *distances_data;
int64_t* neighbors_data;
cudaMalloc(&index_data, sizeof(float) * n_rows * n_dim);
cudaMalloc(&query_data, sizeof(float) * n_queries * n_dim);
cudaMalloc(&neighbors_data, sizeof(int64_t) * n_queries * n_neighbors);
cudaMalloc(&distances_data, sizeof(float) * n_queries * n_neighbors);
rmm::device_uvector<float> index_data(n_rows * n_dim, stream);
rmm::device_uvector<float> query_data(n_queries * n_dim, stream);
rmm::device_uvector<int64_t> neighbors_data(n_queries * n_neighbors, stream);
rmm::device_uvector<float> distances_data(n_queries * n_neighbors, stream);

generate_random_data(index_data, n_rows * n_dim);
generate_random_data(query_data, n_queries * n_dim);
generate_random_data(index_data.data(), n_rows * n_dim);
generate_random_data(query_data.data(), n_queries * n_dim);

run_ivf_flat(n_rows,
n_queries,
n_dim,
n_neighbors,
index_data,
query_data,
distances_data,
neighbors_data,
index_data.data(),
query_data.data(),
distances_data.data(),
neighbors_data.data(),
metric,
n_probes,
n_lists);

recall_eval(query_data,
index_data,
neighbors_data,
distances_data,
recall_eval(query_data.data(),
index_data.data(),
neighbors_data.data(),
distances_data.data(),
n_queries,
n_rows,
n_dim,
n_neighbors,
metric,
n_probes,
n_lists);

// delete device memory
cudaFree(index_data);
cudaFree(query_data);
cudaFree(neighbors_data);
cudaFree(distances_data);
}
Loading

0 comments on commit 2517826

Please sign in to comment.