Skip to content

Commit

Permalink
Merge branch 'brute_force_index' of github.com:benfred/raft into brut…
Browse files Browse the repository at this point in the history
…e_force_index
  • Loading branch information
benfred committed Sep 25, 2023
2 parents ea9b1df + e3041a5 commit 89d108c
Show file tree
Hide file tree
Showing 75 changed files with 7,028 additions and 2,047 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ jobs:
arch: "amd64"
branch: ${{ inputs.branch }}
build_type: ${{ inputs.build_type || 'branch' }}
container_image: "rapidsai/ci:latest"
container_image: "rapidsai/ci-conda:latest"
date: ${{ inputs.date }}
node_type: "gpu-v100-latest-1"
run_script: "ci/build_docs.sh"
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ jobs:
build_type: pull-request
node_type: "gpu-v100-latest-1"
arch: "amd64"
container_image: "rapidsai/ci:latest"
container_image: "rapidsai/ci-conda:latest"
run_script: "ci/build_docs.sh"
wheel-build-pylibraft:
needs: checks
Expand Down
5 changes: 1 addition & 4 deletions ci/release/update-version.sh
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,6 @@ sed_runner 's/'"branch-.*\/RAPIDS.cmake"'/'"branch-${NEXT_SHORT_TAG}\/RAPIDS.cma
sed_runner "s/__version__ = .*/__version__ = \"${NEXT_FULL_TAG}\"/g" python/pylibraft/pylibraft/__init__.py
sed_runner "s/__version__ = .*/__version__ = \"${NEXT_FULL_TAG}\"/g" python/raft-dask/raft_dask/__init__.py

# Python pyproject.toml updates
sed_runner "s/^version = .*/version = \"${NEXT_FULL_TAG}\"/g" python/pylibraft/pyproject.toml
sed_runner "s/^version = .*/version = \"${NEXT_FULL_TAG}\"/g" python/raft-dask/pyproject.toml

# Wheel testing script
sed_runner "s/branch-.*/branch-${NEXT_SHORT_TAG}/g" ci/test_wheel_raft_dask.sh

Expand All @@ -74,6 +70,7 @@ for FILE in python/*/pyproject.toml; do
for DEP in "${DEPENDENCIES[@]}"; do
sed_runner "/\"${DEP}==/ s/==.*\"/==${NEXT_SHORT_TAG_PEP440}.*\"/g" ${FILE}
done
sed_runner "s/^version = .*/version = \"${NEXT_FULL_TAG}\"/g" "${FILE}"
sed_runner "/\"ucx-py==/ s/==.*\"/==${NEXT_UCX_PY_SHORT_TAG_PEP440}.*\"/g" ${FILE}
done

Expand Down
6 changes: 4 additions & 2 deletions conda/recipes/raft-ann-bench/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ requirements:
- h5py {{ h5py_version }}
- benchmark
- matplotlib
# rmm is needed to determine if package is gpu-enabled
- rmm ={{ minor_version }}
- python
- pandas
- pyyaml
# rmm is needed to determine if package is gpu-enabled
- rmm ={{ minor_version }}

run:
- python
Expand All @@ -104,6 +104,8 @@ requirements:
- python
- pandas
- pyyaml
# rmm is needed to determine if package is gpu-enabled
- rmm ={{ minor_version }}
about:
home: https://rapids.ai/
license: Apache-2.0
Expand Down
75 changes: 68 additions & 7 deletions cpp/include/raft/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,14 @@ namespace raft::neighbors::cagra {
* // use default index parameters
* cagra::index_params build_params;
* cagra::search_params search_params;
* auto knn_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 128);
* auto knn_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 128);
* // create knn graph
* cagra::build_knn_graph(res, dataset, knn_graph.view(), 2, build_params, search_params);
* auto optimized_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 64);
* auto optimized_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 64);
* cagra::optimize(res, dataset, knn_graph.view(), optimized_graph.view());
* // Construct an index from dataset and optimized knn_graph
* auto index = cagra::index<T, IdxT>(res, build_params.metric(), dataset,
* optimized_graph.view());
* optimized_graph.view());
* @endcode
*
* @tparam DataT data element type
Expand Down Expand Up @@ -106,7 +106,7 @@ void build_knn_graph(raft::resources const& res,
* @code{.cpp}
* using namespace raft::neighbors;
* cagra::index_params build_params;
* auto knn_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 128);
* auto knn_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 128);
* // build KNN graph not using `cagra::build_knn_graph`
* // build(knn_graph, dataset, ...);
* // sort graph index
Expand All @@ -115,7 +115,7 @@ void build_knn_graph(raft::resources const& res,
* cagra::optimize(res, dataset, knn_graph.view(), optimized_graph.view());
* // Construct an index from dataset and optimized knn_graph
* auto index = cagra::index<T, IdxT>(res, build_params.metric(), dataset,
* optimized_graph.view());
* optimized_graph.view());
* @endcode
*
* @tparam DataT type of the data in the source dataset
Expand Down Expand Up @@ -316,9 +316,70 @@ void search(raft::resources const& res,
auto distances_internal = raft::make_device_matrix_view<float, int64_t, row_major>(
distances.data_handle(), distances.extent(0), distances.extent(1));

cagra::detail::search_main<T, internal_IdxT, IdxT>(
res, params, idx, queries_internal, neighbors_internal, distances_internal);
cagra::detail::search_main<T,
internal_IdxT,
decltype(raft::neighbors::filtering::none_cagra_sample_filter()),
IdxT>(res,
params,
idx,
queries_internal,
neighbors_internal,
distances_internal,
raft::neighbors::filtering::none_cagra_sample_filter());
}

/**
 * @brief Search ANN using the constructed index, restricted by a sample filter.
 *
 * See the [cagra::build](#cagra::build) documentation for a usage example.
 *
 * Before searching, the shapes of the arguments are validated: `neighbors` and `distances`
 * must each have one row per query and the same number of columns, and the number of query
 * columns must equal `idx.dim()`. Neighbor indices are written through a view over the same
 * memory typed as the unsigned counterpart of `IdxT`.
 *
 * @tparam T data element type
 * @tparam IdxT type of the indices
 * @tparam CagraSampleFilterT Device filter function, with the signature
 *         `(uint32_t query_ix, uint32_t sample_ix) -> bool`
 *
 * @param[in] res raft resources
 * @param[in] params configure the search
 * @param[in] idx cagra index
 * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()]
 * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
 * [n_queries, k]
 * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
 * k]
 * @param[in] sample_filter a device filter function that greenlights samples for a given query
 */
template <typename T, typename IdxT, typename CagraSampleFilterT>
void search_with_filtering(raft::resources const& res,
                           const search_params& params,
                           const index<T, IdxT>& idx,
                           raft::device_matrix_view<const T, int64_t, row_major> queries,
                           raft::device_matrix_view<IdxT, int64_t, row_major> neighbors,
                           raft::device_matrix_view<float, int64_t, row_major> distances,
                           CagraSampleFilterT sample_filter = CagraSampleFilterT())
{
  const auto n_queries = queries.extent(0);
  const auto query_dim = queries.extent(1);
  const auto k         = neighbors.extent(1);

  // Shape validation; the checks run in this order so the first violated one is reported.
  RAFT_EXPECTS(
    n_queries == neighbors.extent(0) && n_queries == distances.extent(0),
    "Number of rows in output neighbors and distances matrices must equal the number of queries.");

  RAFT_EXPECTS(k == distances.extent(1),
               "Number of columns in output neighbors and distances matrices must equal k");
  RAFT_EXPECTS(query_dim == idx.dim(),
               "Number of query dimensions should equal number of dimensions in the index.");

  // The search implementation is instantiated with the unsigned flavour of the index type.
  using unsigned_idx_t = std::make_unsigned_t<IdxT>;

  auto queries_view = raft::make_device_matrix_view<const T, int64_t, row_major>(
    queries.data_handle(), n_queries, query_dim);

  // Same memory as `neighbors`, only the element type of the view differs.
  auto* neighbors_ptr = reinterpret_cast<unsigned_idx_t*>(neighbors.data_handle());
  auto neighbors_view = raft::make_device_matrix_view<unsigned_idx_t, int64_t, row_major>(
    neighbors_ptr, neighbors.extent(0), k);

  auto distances_view = raft::make_device_matrix_view<float, int64_t, row_major>(
    distances.data_handle(), distances.extent(0), distances.extent(1));

  cagra::detail::search_main<T, unsigned_idx_t, CagraSampleFilterT, IdxT>(
    res, params, idx, queries_view, neighbors_view, distances_view, sample_filter);
}

/** @} */ // end group cagra

} // namespace raft::neighbors::cagra
Expand Down
19 changes: 15 additions & 4 deletions cpp/include/raft/neighbors/cagra_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,10 @@ struct index : ann::index {
~index() = default;

/** Construct an empty index. */
index(raft::resources const& res)
index(raft::resources const& res,
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded)
: ann::index(),
metric_(raft::distance::DistanceType::L2Expanded),
metric_(metric),
dataset_(make_device_matrix<T, int64_t>(res, 0, 0)),
graph_(make_device_matrix<IdxT, int64_t>(res, 0, 0))
{
Expand Down Expand Up @@ -296,7 +297,11 @@ struct index : ann::index {
raft::host_matrix_view<const IdxT, int64_t, row_major> knn_graph)
{
RAFT_LOG_DEBUG("Copying CAGRA knn graph from host to device");
graph_ = make_device_matrix<IdxT, int64_t>(res, knn_graph.extent(0), knn_graph.extent(1));
if ((graph_.extent(0) != knn_graph.extent(0)) || (graph_.extent(1) != knn_graph.extent(1))) {
// clear existing memory before allocating to prevent OOM errors on large graphs
if (graph_.size()) { graph_ = make_device_matrix<IdxT, int64_t>(res, 0, 0); }
graph_ = make_device_matrix<IdxT, int64_t>(res, knn_graph.extent(0), knn_graph.extent(1));
}
raft::copy(graph_.data_handle(),
knn_graph.data_handle(),
knn_graph.size(),
Expand All @@ -311,7 +316,13 @@ struct index : ann::index {
mdspan<const T, matrix_extent<int64_t>, row_major, data_accessor> dataset)
{
size_t padded_dim = round_up_safe<size_t>(dataset.extent(1) * sizeof(T), 16) / sizeof(T);
dataset_ = make_device_matrix<T, int64_t>(res, dataset.extent(0), padded_dim);

if ((dataset_.extent(0) != dataset.extent(0)) ||
(static_cast<size_t>(dataset_.extent(1)) != padded_dim)) {
// clear existing memory before allocating to prevent OOM errors on large datasets
if (dataset_.size()) { dataset_ = make_device_matrix<T, int64_t>(res, 0, 0); }
dataset_ = make_device_matrix<T, int64_t>(res, dataset.extent(0), padded_dim);
}
if (dataset_.extent(1) == dataset.extent(1)) {
raft::copy(dataset_.data_handle(),
dataset.data_handle(),
Expand Down
2 changes: 2 additions & 0 deletions cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/resource/detail/device_memory_resource.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/spatial/knn/detail/ann_utils.cuh>

Expand All @@ -46,6 +47,7 @@ void build_knn_graph(raft::resources const& res,
std::optional<ivf_pq::index_params> build_params = std::nullopt,
std::optional<ivf_pq::search_params> search_params = std::nullopt)
{
resource::detail::warn_non_pool_workspace(res, "raft::neighbors::cagra::build");
RAFT_EXPECTS(!build_params || build_params->metric == distance::DistanceType::L2Expanded,
"Currently only L2Expanded metric is supported");

Expand Down
68 changes: 62 additions & 6 deletions cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@

#include <raft/core/resource/cuda_stream.hpp>
#include <raft/neighbors/detail/ivf_pq_search.cuh>
#include <raft/neighbors/sample_filter_types.hpp>
#include <raft/spatial/knn/detail/ann_utils.cuh>

#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/core/resource/detail/device_memory_resource.hpp>
#include <raft/core/resources.hpp>
#include <raft/neighbors/cagra_types.hpp>
#include <rmm/cuda_stream_view.hpp>
Expand All @@ -32,6 +35,48 @@

namespace raft::neighbors::cagra::detail {

/**
 * @brief Adapter that shifts the query index by a fixed offset before consulting a filter.
 *
 * Invoking the adapter with `(query_id, sample_id)` forwards to the wrapped filter as
 * `filter(query_id + offset, sample_id)` and returns whatever the wrapped filter returns.
 *
 * @tparam CagraSampleFilterT type of the wrapped filter; it is called with two `uint32_t`
 *         arguments
 */
template <class CagraSampleFilterT>
struct CagraSampleFilterWithQueryIdOffset {
  /** Value added to every incoming query index. */
  const uint32_t offset;
  /** The wrapped filter, held by value. */
  CagraSampleFilterT filter;

  CagraSampleFilterWithQueryIdOffset(const uint32_t query_id_offset,
                                     const CagraSampleFilterT wrapped_filter)
    : offset(query_id_offset), filter(wrapped_filter)
  {
  }

  _RAFT_DEVICE auto operator()(const uint32_t query_id, const uint32_t sample_id)
  {
    return filter(query_id + offset, sample_id);
  }
};

/**
 * @brief Maps a filter type to the filter type that results from attaching a query id offset.
 *
 * In the general case the filter is wrapped in `CagraSampleFilterWithQueryIdOffset`.
 *
 * @tparam FilterT the original filter type
 */
template <class FilterT>
struct CagraSampleFilterT_Selector {
  using type = CagraSampleFilterWithQueryIdOffset<FilterT>;
};

/** Specialization: `none_cagra_sample_filter` maps to itself, without the offset wrapper. */
template <>
struct CagraSampleFilterT_Selector<raft::neighbors::filtering::none_cagra_sample_filter> {
  using type = raft::neighbors::filtering::none_cagra_sample_filter;
};

/**
 * @brief Attach a query id offset to a sample filter.
 *
 * @tparam CagraSampleFilterT type of the filter to wrap
 *
 * @param[in] filter the filter to wrap
 * @param[in] offset value handed to the wrapper together with the filter
 * @return an object of `CagraSampleFilterT_Selector<CagraSampleFilterT>::type` constructed from
 *         `(offset, filter)`
 */
template <class CagraSampleFilterT>
inline typename CagraSampleFilterT_Selector<CagraSampleFilterT>::type set_offset(
  CagraSampleFilterT filter, const uint32_t offset)
{
  using wrapped_filter_t = typename CagraSampleFilterT_Selector<CagraSampleFilterT>::type;
  return wrapped_filter_t(offset, filter);
}

/** Specialization for `none_cagra_sample_filter`: the offset is ignored and the filter is
 *  returned unchanged. */
template <>
inline
  typename CagraSampleFilterT_Selector<raft::neighbors::filtering::none_cagra_sample_filter>::type
  set_offset<raft::neighbors::filtering::none_cagra_sample_filter>(
    raft::neighbors::filtering::none_cagra_sample_filter filter, const uint32_t /*offset*/)
{
  return filter;
}

/**
* @brief Search ANN using the constructed index.
*
Expand All @@ -52,27 +97,37 @@ namespace raft::neighbors::cagra::detail {
* k]
*/

template <typename T, typename internal_IdxT, typename IdxT = uint32_t, typename DistanceT = float>
template <typename T,
typename internal_IdxT,
typename CagraSampleFilterT,
typename IdxT = uint32_t,
typename DistanceT = float>
void search_main(raft::resources const& res,
search_params params,
const index<T, IdxT>& index,
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<internal_IdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<DistanceT, int64_t, row_major> distances)
raft::device_matrix_view<DistanceT, int64_t, row_major> distances,
CagraSampleFilterT sample_filter = CagraSampleFilterT())
{
resource::detail::warn_non_pool_workspace(res, "raft::neighbors::cagra::search");
RAFT_LOG_DEBUG("# dataset size = %lu, dim = %lu\n",
static_cast<size_t>(index.dataset().extent(0)),
static_cast<size_t>(index.dataset().extent(1)));
RAFT_LOG_DEBUG("# query size = %lu, dim = %lu\n",
static_cast<size_t>(queries.extent(0)),
static_cast<size_t>(queries.extent(1)));
RAFT_EXPECTS(queries.extent(1) == index.dim(), "Querise and index dim must match");
RAFT_EXPECTS(queries.extent(1) == index.dim(), "Queries and index dim must match");
const uint32_t topk = neighbors.extent(1);

if (params.max_queries == 0) { params.max_queries = queries.extent(0); }

std::unique_ptr<search_plan_impl<T, internal_IdxT, DistanceT>> plan =
factory<T, internal_IdxT, DistanceT>::create(
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"cagra::search(max_queries = %u, k = %u, dim = %zu)", params.max_queries, topk, index.dim());

using CagraSampleFilterT_s = typename CagraSampleFilterT_Selector<CagraSampleFilterT>::type;
std::unique_ptr<search_plan_impl<T, internal_IdxT, DistanceT, CagraSampleFilterT_s>> plan =
factory<T, internal_IdxT, DistanceT, CagraSampleFilterT_s>::create(
res, params, index.dim(), index.graph_degree(), topk);

plan->check(neighbors.extent(1));
Expand Down Expand Up @@ -113,7 +168,8 @@ void search_main(raft::resources const& res,
n_queries,
_seed_ptr,
_num_executed_iterations,
topk);
topk,
set_offset(sample_filter, qid));
}

static_assert(std::is_same_v<DistanceT, float>,
Expand Down
24 changes: 18 additions & 6 deletions cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include <raft/core/mdarray.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/core/serialize.hpp>
#include <raft/neighbors/cagra_types.hpp>

Expand Down Expand Up @@ -54,6 +55,8 @@ void serialize(raft::resources const& res,
const index<T, IdxT>& index_,
bool include_dataset)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope("cagra::serialize");

RAFT_LOG_DEBUG(
"Saving CAGRA index, size %zu, dim %u", static_cast<size_t>(index_.size()), index_.dim());

Expand Down Expand Up @@ -113,6 +116,8 @@ void serialize(raft::resources const& res,
template <typename T, typename IdxT>
auto deserialize(raft::resources const& res, std::istream& is) -> index<T, IdxT>
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope("cagra::deserialize");

char dtype_string[4];
is.read(dtype_string, 4);

Expand All @@ -125,15 +130,22 @@ auto deserialize(raft::resources const& res, std::istream& is) -> index<T, IdxT>
auto graph_degree = deserialize_scalar<std::uint32_t>(res, is);
auto metric = deserialize_scalar<raft::distance::DistanceType>(res, is);

auto dataset = raft::make_host_matrix<T, int64_t>(n_rows, dim);
auto graph = raft::make_host_matrix<IdxT, int64_t>(n_rows, graph_degree);
auto graph = raft::make_host_matrix<IdxT, int64_t>(n_rows, graph_degree);
deserialize_mdspan(res, is, graph.view());

bool has_dataset = deserialize_scalar<bool>(res, is);
if (has_dataset) { deserialize_mdspan(res, is, dataset.view()); }

return index<T, IdxT>(
res, metric, raft::make_const_mdspan(dataset.view()), raft::make_const_mdspan(graph.view()));
if (has_dataset) {
auto dataset = raft::make_host_matrix<T, int64_t>(n_rows, dim);
deserialize_mdspan(res, is, dataset.view());
return index<T, IdxT>(
res, metric, raft::make_const_mdspan(dataset.view()), raft::make_const_mdspan(graph.view()));
} else {
// create a new index with no dataset - the user must supply via update_dataset themselves
// later (this avoids allocating GPU memory in the meantime)
index<T, IdxT> idx(res, metric);
idx.update_graph(res, raft::make_const_mdspan(graph.view()));
return idx;
}
}

template <typename T, typename IdxT>
Expand Down
Loading

0 comments on commit 89d108c

Please sign in to comment.