Skip to content

Commit

Permalink
Add NVTX ranges for cagra search/serialize functions (#1737)
Browse files Browse the repository at this point in the history
* Add NVTX ranges for cagra search/serialize functions
* Add warn_non_pool_workspace for cagra build/search

Authors:
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1737
  • Loading branch information
benfred authored Sep 21, 2023
1 parent 8292ef1 commit 4f0a2d2
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 1 deletion.
2 changes: 2 additions & 0 deletions cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/resource/detail/device_memory_resource.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/spatial/knn/detail/ann_utils.cuh>

Expand All @@ -46,6 +47,7 @@ void build_knn_graph(raft::resources const& res,
std::optional<ivf_pq::index_params> build_params = std::nullopt,
std::optional<ivf_pq::search_params> search_params = std::nullopt)
{
resource::detail::warn_non_pool_workspace(res, "raft::neighbors::cagra::build");
RAFT_EXPECTS(!build_params || build_params->metric == distance::DistanceType::L2Expanded,
"Currently only L2Expanded metric is supported");

Expand Down
8 changes: 7 additions & 1 deletion cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/core/resource/detail/device_memory_resource.hpp>
#include <raft/core/resources.hpp>
#include <raft/neighbors/cagra_types.hpp>
#include <rmm/cuda_stream_view.hpp>
Expand Down Expand Up @@ -60,17 +62,21 @@ void search_main(raft::resources const& res,
raft::device_matrix_view<internal_IdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<DistanceT, int64_t, row_major> distances)
{
resource::detail::warn_non_pool_workspace(res, "raft::neighbors::cagra::search");
RAFT_LOG_DEBUG("# dataset size = %lu, dim = %lu\n",
static_cast<size_t>(index.dataset().extent(0)),
static_cast<size_t>(index.dataset().extent(1)));
RAFT_LOG_DEBUG("# query size = %lu, dim = %lu\n",
static_cast<size_t>(queries.extent(0)),
static_cast<size_t>(queries.extent(1)));
RAFT_EXPECTS(queries.extent(1) == index.dim(), "Querise and index dim must match");
RAFT_EXPECTS(queries.extent(1) == index.dim(), "Queries and index dim must match");
const uint32_t topk = neighbors.extent(1);

if (params.max_queries == 0) { params.max_queries = queries.extent(0); }

common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"cagra::search(max_queries = %u, k = %u, dim = %zu)", params.max_queries, topk, index.dim());

std::unique_ptr<search_plan_impl<T, internal_IdxT, DistanceT>> plan =
factory<T, internal_IdxT, DistanceT>::create(
res, params, index.dim(), index.graph_degree(), topk);
Expand Down
5 changes: 5 additions & 0 deletions cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include <raft/core/mdarray.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/core/serialize.hpp>
#include <raft/neighbors/cagra_types.hpp>

Expand Down Expand Up @@ -54,6 +55,8 @@ void serialize(raft::resources const& res,
const index<T, IdxT>& index_,
bool include_dataset)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope("cagra::serialize");

RAFT_LOG_DEBUG(
"Saving CAGRA index, size %zu, dim %u", static_cast<size_t>(index_.size()), index_.dim());

Expand Down Expand Up @@ -113,6 +116,8 @@ void serialize(raft::resources const& res,
template <typename T, typename IdxT>
auto deserialize(raft::resources const& res, std::istream& is) -> index<T, IdxT>
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope("cagra::deserialize");

char dtype_string[4];
is.read(dtype_string, 4);

Expand Down

0 comments on commit 4f0a2d2

Please sign in to comment.