Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/fea-mdspan_copy' into fea-mdspan…
Browse files Browse the repository at this point in the history
…_copy
  • Loading branch information
wphicks committed Sep 22, 2023
2 parents 62ac60a + 7f407ed commit 5bddcc8
Show file tree
Hide file tree
Showing 10 changed files with 809 additions and 40 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ jobs:
arch: "amd64"
branch: ${{ inputs.branch }}
build_type: ${{ inputs.build_type || 'branch' }}
container_image: "rapidsai/ci:latest"
container_image: "rapidsai/ci-conda:latest"
date: ${{ inputs.date }}
node_type: "gpu-v100-latest-1"
run_script: "ci/build_docs.sh"
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ jobs:
build_type: pull-request
node_type: "gpu-v100-latest-1"
arch: "amd64"
container_image: "rapidsai/ci:latest"
container_image: "rapidsai/ci-conda:latest"
run_script: "ci/build_docs.sh"
wheel-build-pylibraft:
needs: checks
Expand Down
5 changes: 1 addition & 4 deletions ci/release/update-version.sh
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,6 @@ sed_runner 's/'"branch-.*\/RAPIDS.cmake"'/'"branch-${NEXT_SHORT_TAG}\/RAPIDS.cma
sed_runner "s/__version__ = .*/__version__ = \"${NEXT_FULL_TAG}\"/g" python/pylibraft/pylibraft/__init__.py
sed_runner "s/__version__ = .*/__version__ = \"${NEXT_FULL_TAG}\"/g" python/raft-dask/raft_dask/__init__.py

# Python pyproject.toml updates
sed_runner "s/^version = .*/version = \"${NEXT_FULL_TAG}\"/g" python/pylibraft/pyproject.toml
sed_runner "s/^version = .*/version = \"${NEXT_FULL_TAG}\"/g" python/raft-dask/pyproject.toml

# Wheel testing script
sed_runner "s/branch-.*/branch-${NEXT_SHORT_TAG}/g" ci/test_wheel_raft_dask.sh

Expand All @@ -74,6 +70,7 @@ for FILE in python/*/pyproject.toml; do
for DEP in "${DEPENDENCIES[@]}"; do
sed_runner "/\"${DEP}==/ s/==.*\"/==${NEXT_SHORT_TAG_PEP440}.*\"/g" ${FILE}
done
sed_runner "s/^version = .*/version = \"${NEXT_FULL_TAG}\"/g" "${FILE}"
sed_runner "/\"ucx-py==/ s/==.*\"/==${NEXT_UCX_PY_SHORT_TAG_PEP440}.*\"/g" ${FILE}
done

Expand Down
6 changes: 4 additions & 2 deletions conda/recipes/raft-ann-bench/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ requirements:
- h5py {{ h5py_version }}
- benchmark
- matplotlib
# rmm is needed to determine if package is gpu-enabled
- rmm ={{ minor_version }}
- python
- pandas
- pyyaml
# rmm is needed to determine if package is gpu-enabled
- rmm ={{ minor_version }}

run:
- python
Expand All @@ -104,6 +104,8 @@ requirements:
- python
- pandas
- pyyaml
# rmm is needed to determine if package is gpu-enabled
- rmm ={{ minor_version }}
about:
home: https://rapids.ai/
license: Apache-2.0
Expand Down
2 changes: 2 additions & 0 deletions cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/resource/detail/device_memory_resource.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/spatial/knn/detail/ann_utils.cuh>

Expand All @@ -46,6 +47,7 @@ void build_knn_graph(raft::resources const& res,
std::optional<ivf_pq::index_params> build_params = std::nullopt,
std::optional<ivf_pq::search_params> search_params = std::nullopt)
{
resource::detail::warn_non_pool_workspace(res, "raft::neighbors::cagra::build");
RAFT_EXPECTS(!build_params || build_params->metric == distance::DistanceType::L2Expanded,
"Currently only L2Expanded metric is supported");

Expand Down
8 changes: 7 additions & 1 deletion cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/core/resource/detail/device_memory_resource.hpp>
#include <raft/core/resources.hpp>
#include <raft/neighbors/cagra_types.hpp>
#include <rmm/cuda_stream_view.hpp>
Expand Down Expand Up @@ -60,17 +62,21 @@ void search_main(raft::resources const& res,
raft::device_matrix_view<internal_IdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<DistanceT, int64_t, row_major> distances)
{
resource::detail::warn_non_pool_workspace(res, "raft::neighbors::cagra::search");
RAFT_LOG_DEBUG("# dataset size = %lu, dim = %lu\n",
static_cast<size_t>(index.dataset().extent(0)),
static_cast<size_t>(index.dataset().extent(1)));
RAFT_LOG_DEBUG("# query size = %lu, dim = %lu\n",
static_cast<size_t>(queries.extent(0)),
static_cast<size_t>(queries.extent(1)));
RAFT_EXPECTS(queries.extent(1) == index.dim(), "Querise and index dim must match");
RAFT_EXPECTS(queries.extent(1) == index.dim(), "Queries and index dim must match");
const uint32_t topk = neighbors.extent(1);

if (params.max_queries == 0) { params.max_queries = queries.extent(0); }

common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"cagra::search(max_queries = %u, k = %u, dim = %zu)", params.max_queries, topk, index.dim());

std::unique_ptr<search_plan_impl<T, internal_IdxT, DistanceT>> plan =
factory<T, internal_IdxT, DistanceT>::create(
res, params, index.dim(), index.graph_degree(), topk);
Expand Down
5 changes: 5 additions & 0 deletions cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include <raft/core/mdarray.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/core/serialize.hpp>
#include <raft/neighbors/cagra_types.hpp>

Expand Down Expand Up @@ -54,6 +55,8 @@ void serialize(raft::resources const& res,
const index<T, IdxT>& index_,
bool include_dataset)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope("cagra::serialize");

RAFT_LOG_DEBUG(
"Saving CAGRA index, size %zu, dim %u", static_cast<size_t>(index_.size()), index_.dim());

Expand Down Expand Up @@ -113,6 +116,8 @@ void serialize(raft::resources const& res,
template <typename T, typename IdxT>
auto deserialize(raft::resources const& res, std::istream& is) -> index<T, IdxT>
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope("cagra::deserialize");

char dtype_string[4];
is.read(dtype_string, 4);

Expand Down
Loading

0 comments on commit 5bddcc8

Please sign in to comment.