diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index d19d7e7904..80e964df57 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -46,6 +47,7 @@ void build_knn_graph(raft::resources const& res, std::optional build_params = std::nullopt, std::optional search_params = std::nullopt) { + resource::detail::warn_non_pool_workspace(res, "raft::neighbors::cagra::build"); RAFT_EXPECTS(!build_params || build_params->metric == distance::DistanceType::L2Expanded, "Currently only L2Expanded metric is supported"); diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh index 8190817b5b..b484fa55f9 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh @@ -22,6 +22,8 @@ #include #include +#include +#include #include #include #include @@ -60,17 +62,21 @@ void search_main(raft::resources const& res, raft::device_matrix_view neighbors, raft::device_matrix_view distances) { + resource::detail::warn_non_pool_workspace(res, "raft::neighbors::cagra::search"); RAFT_LOG_DEBUG("# dataset size = %lu, dim = %lu\n", static_cast(index.dataset().extent(0)), static_cast(index.dataset().extent(1))); RAFT_LOG_DEBUG("# query size = %lu, dim = %lu\n", static_cast(queries.extent(0)), static_cast(queries.extent(1))); - RAFT_EXPECTS(queries.extent(1) == index.dim(), "Querise and index dim must match"); + RAFT_EXPECTS(queries.extent(1) == index.dim(), "Queries and index dim must match"); const uint32_t topk = neighbors.extent(1); if (params.max_queries == 0) { params.max_queries = queries.extent(0); } + common::nvtx::range fun_scope( + "cagra::search(max_queries = %u, k = %u, dim = %zu)", params.max_queries, topk, index.dim()); + std::unique_ptr> plan = factory::create( res, params, index.dim(), index.graph_degree(), topk); diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index 2c9cbd2563..234911e15c 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include @@ -54,6 +55,8 @@ void serialize(raft::resources const& res, const index& index_, bool include_dataset) { + common::nvtx::range fun_scope("cagra::serialize"); + RAFT_LOG_DEBUG( "Saving CAGRA index, size %zu, dim %u", static_cast(index_.size()), index_.dim()); @@ -113,6 +116,8 @@ void serialize(raft::resources const& res, template auto deserialize(raft::resources const& res, std::istream& is) -> index { + common::nvtx::range fun_scope("cagra::deserialize"); + char dtype_string[4]; is.read(dtype_string, 4);