Skip to content

Commit

Permalink
Merge branch '2310-fix-raftbench-conda-recipe' into doc-bench-docker
Browse files Browse the repository at this point in the history
  • Loading branch information
dantegd committed Sep 22, 2023
2 parents 7ccdd21 + b24e4f7 commit 6ea8fda
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ jobs:
arch: "amd64"
branch: ${{ inputs.branch }}
build_type: ${{ inputs.build_type || 'branch' }}
container_image: "rapidsai/ci:latest"
container_image: "rapidsai/ci-conda:latest"
date: ${{ inputs.date }}
node_type: "gpu-v100-latest-1"
run_script: "ci/build_docs.sh"
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ jobs:
build_type: pull-request
node_type: "gpu-v100-latest-1"
arch: "amd64"
container_image: "rapidsai/ci:latest"
container_image: "rapidsai/ci-conda:latest"
run_script: "ci/build_docs.sh"
wheel-build-pylibraft:
needs: checks
Expand Down
6 changes: 4 additions & 2 deletions conda/recipes/raft-ann-bench/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ requirements:
- h5py {{ h5py_version }}
- benchmark
- matplotlib
# rmm is needed to determine if package is gpu-enabled
- rmm ={{ minor_version }}
- python
- pandas
- pyyaml
# rmm is needed to determine if package is gpu-enabled
- rmm ={{ minor_version }}

run:
- python
Expand All @@ -104,6 +104,8 @@ requirements:
- python
- pandas
- pyyaml
# rmm is needed to determine if package is gpu-enabled
- rmm ={{ minor_version }}
about:
home: https://rapids.ai/
license: Apache-2.0
Expand Down
2 changes: 2 additions & 0 deletions cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/resource/detail/device_memory_resource.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/spatial/knn/detail/ann_utils.cuh>

Expand All @@ -46,6 +47,7 @@ void build_knn_graph(raft::resources const& res,
std::optional<ivf_pq::index_params> build_params = std::nullopt,
std::optional<ivf_pq::search_params> search_params = std::nullopt)
{
resource::detail::warn_non_pool_workspace(res, "raft::neighbors::cagra::build");
RAFT_EXPECTS(!build_params || build_params->metric == distance::DistanceType::L2Expanded,
"Currently only L2Expanded metric is supported");

Expand Down
8 changes: 7 additions & 1 deletion cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/core/resource/detail/device_memory_resource.hpp>
#include <raft/core/resources.hpp>
#include <raft/neighbors/cagra_types.hpp>
#include <rmm/cuda_stream_view.hpp>
Expand Down Expand Up @@ -60,17 +62,21 @@ void search_main(raft::resources const& res,
raft::device_matrix_view<internal_IdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<DistanceT, int64_t, row_major> distances)
{
resource::detail::warn_non_pool_workspace(res, "raft::neighbors::cagra::search");
RAFT_LOG_DEBUG("# dataset size = %lu, dim = %lu\n",
static_cast<size_t>(index.dataset().extent(0)),
static_cast<size_t>(index.dataset().extent(1)));
RAFT_LOG_DEBUG("# query size = %lu, dim = %lu\n",
static_cast<size_t>(queries.extent(0)),
static_cast<size_t>(queries.extent(1)));
RAFT_EXPECTS(queries.extent(1) == index.dim(), "Querise and index dim must match");
RAFT_EXPECTS(queries.extent(1) == index.dim(), "Queries and index dim must match");
const uint32_t topk = neighbors.extent(1);

if (params.max_queries == 0) { params.max_queries = queries.extent(0); }

common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"cagra::search(max_queries = %u, k = %u, dim = %zu)", params.max_queries, topk, index.dim());

std::unique_ptr<search_plan_impl<T, internal_IdxT, DistanceT>> plan =
factory<T, internal_IdxT, DistanceT>::create(
res, params, index.dim(), index.graph_degree(), topk);
Expand Down
5 changes: 5 additions & 0 deletions cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include <raft/core/mdarray.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/core/serialize.hpp>
#include <raft/neighbors/cagra_types.hpp>

Expand Down Expand Up @@ -54,6 +55,8 @@ void serialize(raft::resources const& res,
const index<T, IdxT>& index_,
bool include_dataset)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope("cagra::serialize");

RAFT_LOG_DEBUG(
"Saving CAGRA index, size %zu, dim %u", static_cast<size_t>(index_.size()), index_.dim());

Expand Down Expand Up @@ -113,6 +116,8 @@ void serialize(raft::resources const& res,
template <typename T, typename IdxT>
auto deserialize(raft::resources const& res, std::istream& is) -> index<T, IdxT>
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope("cagra::deserialize");

char dtype_string[4];
is.read(dtype_string, 4);

Expand Down

0 comments on commit 6ea8fda

Please sign in to comment.