diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile
new file mode 100644
index 0000000000..9d35e3f97f
--- /dev/null
+++ b/.devcontainer/Dockerfile
@@ -0,0 +1,30 @@
+# syntax=docker/dockerfile:1.5
+
+ARG BASE
+ARG PYTHON_PACKAGE_MANAGER=conda
+
+FROM ${BASE} as pip-base
+
+ENV DEFAULT_VIRTUAL_ENV=rapids
+
+FROM ${BASE} as conda-base
+
+ENV DEFAULT_CONDA_ENV=rapids
+
+FROM ${PYTHON_PACKAGE_MANAGER}-base
+
+ARG CUDA
+ENV CUDAARCHS="RAPIDS"
+ENV CUDA_VERSION="${CUDA_VERSION:-${CUDA}}"
+
+ARG PYTHON_PACKAGE_MANAGER
+ENV PYTHON_PACKAGE_MANAGER="${PYTHON_PACKAGE_MANAGER}"
+
+ENV PYTHONSAFEPATH="1"
+ENV PYTHONUNBUFFERED="1"
+ENV PYTHONDONTWRITEBYTECODE="1"
+
+ENV SCCACHE_REGION="us-east-2"
+ENV SCCACHE_BUCKET="rapids-sccache-devs"
+ENV VAULT_HOST="https://vault.ops.k8s.rapids.ai"
+ENV HISTFILE="/home/coder/.cache/._bash_history"
diff --git a/.devcontainer/README.md b/.devcontainer/README.md
new file mode 100644
index 0000000000..3c76b8963d
--- /dev/null
+++ b/.devcontainer/README.md
@@ -0,0 +1,64 @@
+# RAFT Development Containers
+
+This directory contains [devcontainer configurations](https://containers.dev/implementors/json_reference/) for using VSCode to [develop in a container](https://code.visualstudio.com/docs/devcontainers/containers) via the `Remote Containers` [extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers) or [GitHub Codespaces](https://github.com/codespaces).
+
+This container is a turnkey development environment for building and testing the RAFT C++ and Python libraries.
+
+## Table of Contents
+
+* [Prerequisites](#prerequisites)
+* [Host bind mounts](#host-bind-mounts)
+* [Launch a Dev Container](#launch-a-dev-container)
+
+## Prerequisites
+
+* [VSCode](https://code.visualstudio.com/download)
+* [VSCode Remote Containers extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers)
+
+## Host bind mounts
+
+By default, the following directories are bind-mounted into the devcontainer:
+
+* `${repo}:/home/coder/raft`
+* `${repo}/../.aws:/home/coder/.aws`
+* `${repo}/../.local:/home/coder/.local`
+* `${repo}/../.cache:/home/coder/.cache`
+* `${repo}/../.conda:/home/coder/.conda`
+* `${repo}/../.config:/home/coder/.config`
+
+This ensures caches, configurations, dependencies, and your commits are persisted on the host across container runs.
+
+## Launch a Dev Container
+
+To launch a devcontainer from VSCode, open the RAFT repo and select the "Reopen in Container" button in the bottom right:
+
+Alternatively, open the VSCode command palette (typically `cmd/ctrl + shift + P`) and run the "Rebuild and Reopen in Container" command.
+
+## Using the devcontainer
+
+On startup, the devcontainer creates or updates the conda/pip environment using `raft/dependencies.yaml`.
+
+The container includes convenience functions to clean, configure, and build the various RAFT components:
+
+```shell
+$ clean-raft-cpp # only cleans the C++ build dir
+$ clean-pylibraft-python # only cleans the Python build dir
+$ clean-raft # cleans both C++ and Python build dirs
+
+$ configure-raft-cpp # only configures raft C++ lib
+
+$ build-raft-cpp # only builds raft C++ lib
+$ build-pylibraft-python # only builds raft Python lib
+$ build-raft # builds both C++ and Python libs
+```
+
+* The C++ build script is a small wrapper around `cmake -S ~/raft/cpp -B ~/raft/cpp/build` and `cmake --build ~/raft/cpp/build`
+* The Python build script is a small wrapper around `pip install --editable ~/raft/cpp`
+
+Unlike `build.sh`, these convenience scripts *don't* install the libraries after building them. Instead, they automatically inject the correct arguments to build the C++ libraries from source and use their build dirs as package roots:
+
+```shell
+$ cmake -S ~/raft/cpp -B ~/raft/cpp/build
+$ CMAKE_ARGS="-Draft_ROOT=~/raft/cpp/build" \ # <-- this argument is automatic
+ pip install -e ~/raft/cpp
+```
diff --git a/.devcontainer/cuda11.8-conda/devcontainer.json b/.devcontainer/cuda11.8-conda/devcontainer.json
new file mode 100644
index 0000000000..8da9b5428a
--- /dev/null
+++ b/.devcontainer/cuda11.8-conda/devcontainer.json
@@ -0,0 +1,37 @@
+{
+ "build": {
+ "context": "${localWorkspaceFolder}/.devcontainer",
+ "dockerfile": "${localWorkspaceFolder}/.devcontainer/Dockerfile",
+ "args": {
+ "CUDA": "11.8",
+ "PYTHON_PACKAGE_MANAGER": "conda",
+ "BASE": "rapidsai/devcontainers:23.10-cpp-llvm16-cuda11.8-mambaforge-ubuntu22.04"
+ }
+ },
+ "hostRequirements": {"gpu": "optional"},
+ "features": {
+ "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:23.10": {}
+ },
+ "overrideFeatureInstallOrder": [
+ "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils"
+ ],
+ "initializeCommand": ["/bin/bash", "-c", "mkdir -m 0755 -p ${localWorkspaceFolder}/../.{aws,cache,config,conda/pkgs,conda/${localWorkspaceFolderBasename}-cuda11.8-envs}"],
+ "postAttachCommand": ["/bin/bash", "-c", "if [ ${CODESPACES:-false} = 'true' ]; then . devcontainer-utils-post-attach-command; . rapids-post-attach-command; fi"],
+ "workspaceFolder": "/home/coder",
+ "workspaceMount": "source=${localWorkspaceFolder},target=/home/coder/raft,type=bind,consistency=consistent",
+ "mounts": [
+ "source=${localWorkspaceFolder}/../.aws,target=/home/coder/.aws,type=bind,consistency=consistent",
+ "source=${localWorkspaceFolder}/../.cache,target=/home/coder/.cache,type=bind,consistency=consistent",
+ "source=${localWorkspaceFolder}/../.config,target=/home/coder/.config,type=bind,consistency=consistent",
+ "source=${localWorkspaceFolder}/../.conda/pkgs,target=/home/coder/.conda/pkgs,type=bind,consistency=consistent",
+ "source=${localWorkspaceFolder}/../.conda/${localWorkspaceFolderBasename}-cuda11.8-envs,target=/home/coder/.conda/envs,type=bind,consistency=consistent"
+ ],
+ "customizations": {
+ "vscode": {
+ "extensions": [
+ "ms-python.flake8",
+ "nvidia.nsight-vscode-edition"
+ ]
+ }
+ }
+}
diff --git a/.devcontainer/cuda11.8-pip/devcontainer.json b/.devcontainer/cuda11.8-pip/devcontainer.json
new file mode 100644
index 0000000000..0b3ec79e37
--- /dev/null
+++ b/.devcontainer/cuda11.8-pip/devcontainer.json
@@ -0,0 +1,38 @@
+{
+ "build": {
+ "context": "${localWorkspaceFolder}/.devcontainer",
+ "dockerfile": "${localWorkspaceFolder}/.devcontainer/Dockerfile",
+ "args": {
+ "CUDA": "11.8",
+ "PYTHON_PACKAGE_MANAGER": "pip",
+ "BASE": "rapidsai/devcontainers:23.10-cpp-llvm16-cuda11.8-ubuntu22.04"
+ }
+ },
+ "hostRequirements": {"gpu": "optional"},
+ "features": {
+ "ghcr.io/rapidsai/devcontainers/features/ucx:23.10": {"version": "1.14.1"},
+ "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:23.10": {}
+ },
+ "overrideFeatureInstallOrder": [
+ "ghcr.io/rapidsai/devcontainers/features/ucx",
+ "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils"
+ ],
+ "initializeCommand": ["/bin/bash", "-c", "mkdir -m 0755 -p ${localWorkspaceFolder}/../.{aws,cache,config/pip,local/share/${localWorkspaceFolderBasename}-cuda11.8-venvs}"],
+ "postAttachCommand": ["/bin/bash", "-c", "if [ ${CODESPACES:-false} = 'true' ]; then . devcontainer-utils-post-attach-command; . rapids-post-attach-command; fi"],
+ "workspaceFolder": "/home/coder",
+ "workspaceMount": "source=${localWorkspaceFolder},target=/home/coder/raft,type=bind,consistency=consistent",
+ "mounts": [
+ "source=${localWorkspaceFolder}/../.aws,target=/home/coder/.aws,type=bind,consistency=consistent",
+ "source=${localWorkspaceFolder}/../.cache,target=/home/coder/.cache,type=bind,consistency=consistent",
+ "source=${localWorkspaceFolder}/../.config,target=/home/coder/.config,type=bind,consistency=consistent",
+ "source=${localWorkspaceFolder}/../.local/share/${localWorkspaceFolderBasename}-cuda11.8-venvs,target=/home/coder/.local/share/venvs,type=bind,consistency=consistent"
+ ],
+ "customizations": {
+ "vscode": {
+ "extensions": [
+ "ms-python.flake8",
+ "nvidia.nsight-vscode-edition"
+ ]
+ }
+ }
+}
diff --git a/.devcontainer/cuda12.0-conda/devcontainer.json b/.devcontainer/cuda12.0-conda/devcontainer.json
new file mode 100644
index 0000000000..f5af166b46
--- /dev/null
+++ b/.devcontainer/cuda12.0-conda/devcontainer.json
@@ -0,0 +1,37 @@
+{
+ "build": {
+ "context": "${localWorkspaceFolder}/.devcontainer",
+ "dockerfile": "${localWorkspaceFolder}/.devcontainer/Dockerfile",
+ "args": {
+ "CUDA": "12.0",
+ "PYTHON_PACKAGE_MANAGER": "conda",
+ "BASE": "rapidsai/devcontainers:23.10-cpp-mambaforge-ubuntu22.04"
+ }
+ },
+ "hostRequirements": {"gpu": "optional"},
+ "features": {
+ "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:23.10": {}
+ },
+ "overrideFeatureInstallOrder": [
+ "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils"
+ ],
+ "initializeCommand": ["/bin/bash", "-c", "mkdir -m 0755 -p ${localWorkspaceFolder}/../.{aws,cache,config,conda/pkgs,conda/${localWorkspaceFolderBasename}-cuda12.0-envs}"],
+ "postAttachCommand": ["/bin/bash", "-c", "if [ ${CODESPACES:-false} = 'true' ]; then . devcontainer-utils-post-attach-command; . rapids-post-attach-command; fi"],
+ "workspaceFolder": "/home/coder",
+ "workspaceMount": "source=${localWorkspaceFolder},target=/home/coder/raft,type=bind,consistency=consistent",
+ "mounts": [
+ "source=${localWorkspaceFolder}/../.aws,target=/home/coder/.aws,type=bind,consistency=consistent",
+ "source=${localWorkspaceFolder}/../.cache,target=/home/coder/.cache,type=bind,consistency=consistent",
+ "source=${localWorkspaceFolder}/../.config,target=/home/coder/.config,type=bind,consistency=consistent",
+ "source=${localWorkspaceFolder}/../.conda/pkgs,target=/home/coder/.conda/pkgs,type=bind,consistency=consistent",
+ "source=${localWorkspaceFolder}/../.conda/${localWorkspaceFolderBasename}-cuda12.0-envs,target=/home/coder/.conda/envs,type=bind,consistency=consistent"
+ ],
+ "customizations": {
+ "vscode": {
+ "extensions": [
+ "ms-python.flake8",
+ "nvidia.nsight-vscode-edition"
+ ]
+ }
+ }
+}
diff --git a/.devcontainer/cuda12.0-pip/devcontainer.json b/.devcontainer/cuda12.0-pip/devcontainer.json
new file mode 100644
index 0000000000..9f28002d38
--- /dev/null
+++ b/.devcontainer/cuda12.0-pip/devcontainer.json
@@ -0,0 +1,38 @@
+{
+ "build": {
+ "context": "${localWorkspaceFolder}/.devcontainer",
+ "dockerfile": "${localWorkspaceFolder}/.devcontainer/Dockerfile",
+ "args": {
+ "CUDA": "12.0",
+ "PYTHON_PACKAGE_MANAGER": "pip",
+ "BASE": "rapidsai/devcontainers:23.10-cpp-llvm16-cuda12.0-ubuntu22.04"
+ }
+ },
+ "hostRequirements": {"gpu": "optional"},
+ "features": {
+ "ghcr.io/rapidsai/devcontainers/features/ucx:23.10": {"version": "1.14.1"},
+ "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:23.10": {}
+ },
+ "overrideFeatureInstallOrder": [
+ "ghcr.io/rapidsai/devcontainers/features/ucx",
+ "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils"
+ ],
+ "initializeCommand": ["/bin/bash", "-c", "mkdir -m 0755 -p ${localWorkspaceFolder}/../.{aws,cache,config/pip,local/share/${localWorkspaceFolderBasename}-cuda12.0-venvs}"],
+ "postAttachCommand": ["/bin/bash", "-c", "if [ ${CODESPACES:-false} = 'true' ]; then . devcontainer-utils-post-attach-command; . rapids-post-attach-command; fi"],
+ "workspaceFolder": "/home/coder",
+ "workspaceMount": "source=${localWorkspaceFolder},target=/home/coder/raft,type=bind,consistency=consistent",
+ "mounts": [
+ "source=${localWorkspaceFolder}/../.aws,target=/home/coder/.aws,type=bind,consistency=consistent",
+ "source=${localWorkspaceFolder}/../.cache,target=/home/coder/.cache,type=bind,consistency=consistent",
+ "source=${localWorkspaceFolder}/../.config,target=/home/coder/.config,type=bind,consistency=consistent",
+ "source=${localWorkspaceFolder}/../.local/share/${localWorkspaceFolderBasename}-cuda12.0-venvs,target=/home/coder/.local/share/venvs,type=bind,consistency=consistent"
+ ],
+ "customizations": {
+ "vscode": {
+ "extensions": [
+ "ms-python.flake8",
+ "nvidia.nsight-vscode-edition"
+ ]
+ }
+ }
+}
diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index 00004c4e4d..107823d5ee 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -62,7 +62,7 @@ jobs:
arch: "amd64"
branch: ${{ inputs.branch }}
build_type: ${{ inputs.build_type || 'branch' }}
- container_image: "rapidsai/ci:latest"
+ container_image: "rapidsai/ci-conda:latest"
date: ${{ inputs.date }}
node_type: "gpu-v100-latest-1"
run_script: "ci/build_docs.sh"
diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml
index 4437e0dc85..e539877851 100644
--- a/.github/workflows/pr.yaml
+++ b/.github/workflows/pr.yaml
@@ -22,6 +22,7 @@ jobs:
- wheel-tests-pylibraft
- wheel-build-raft-dask
- wheel-tests-raft-dask
+ - devcontainer
secrets: inherit
uses: rapidsai/shared-action-workflows/.github/workflows/pr-builder.yaml@branch-23.10
checks:
@@ -62,7 +63,7 @@ jobs:
build_type: pull-request
node_type: "gpu-v100-latest-1"
arch: "amd64"
- container_image: "rapidsai/ci:latest"
+ container_image: "rapidsai/ci-conda:latest"
run_script: "ci/build_docs.sh"
wheel-build-pylibraft:
needs: checks
@@ -92,3 +93,11 @@ jobs:
with:
build_type: pull-request
script: ci/test_wheel_raft_dask.sh
+ devcontainer:
+ secrets: inherit
+ uses: rapidsai/shared-action-workflows/.github/workflows/build-in-devcontainer.yaml@branch-23.10
+ with:
+ build_command: |
+ sccache -z;
+ build-all -DBUILD_PRIMS_BENCH=ON -DBUILD_ANN_BENCH=ON --verbose;
+ sccache -s;
diff --git a/.gitignore b/.gitignore
index 7939fc1622..11b7bc3eba 100644
--- a/.gitignore
+++ b/.gitignore
@@ -62,3 +62,7 @@ _xml
# sphinx
_html
_text
+
+# clang tooling
+compile_commands.json
+.clangd/
diff --git a/build.sh b/build.sh
index 071820ba93..6200e6a2fa 100755
--- a/build.sh
+++ b/build.sh
@@ -78,8 +78,8 @@ INSTALL_TARGET=install
BUILD_REPORT_METRICS=""
BUILD_REPORT_INCL_CACHE_STATS=OFF
-TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST"
-BENCH_TARGETS="CLUSTER_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH"
+TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST"
+BENCH_TARGETS="CLUSTER_BENCH;CORE_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH"
CACHE_ARGS=""
NVTX=ON
diff --git a/ci/build_cpp.sh b/ci/build_cpp.sh
index d2d2d08b99..a41f81152d 100755
--- a/ci/build_cpp.sh
+++ b/ci/build_cpp.sh
@@ -1,5 +1,5 @@
#!/bin/bash
-# Copyright (c) 2022, NVIDIA CORPORATION.
+# Copyright (c) 2022-2023, NVIDIA CORPORATION.
set -euo pipefail
diff --git a/ci/release/update-version.sh b/ci/release/update-version.sh
index 6a7e319f5d..a867a71f68 100755
--- a/ci/release/update-version.sh
+++ b/ci/release/update-version.sh
@@ -47,10 +47,6 @@ sed_runner 's/'"branch-.*\/RAPIDS.cmake"'/'"branch-${NEXT_SHORT_TAG}\/RAPIDS.cma
sed_runner "s/__version__ = .*/__version__ = \"${NEXT_FULL_TAG}\"/g" python/pylibraft/pylibraft/__init__.py
sed_runner "s/__version__ = .*/__version__ = \"${NEXT_FULL_TAG}\"/g" python/raft-dask/raft_dask/__init__.py
-# Python pyproject.toml updates
-sed_runner "s/^version = .*/version = \"${NEXT_FULL_TAG}\"/g" python/pylibraft/pyproject.toml
-sed_runner "s/^version = .*/version = \"${NEXT_FULL_TAG}\"/g" python/raft-dask/pyproject.toml
-
# Wheel testing script
sed_runner "s/branch-.*/branch-${NEXT_SHORT_TAG}/g" ci/test_wheel_raft_dask.sh
@@ -74,6 +70,7 @@ for FILE in python/*/pyproject.toml; do
for DEP in "${DEPENDENCIES[@]}"; do
sed_runner "/\"${DEP}==/ s/==.*\"/==${NEXT_SHORT_TAG_PEP440}.*\"/g" ${FILE}
done
+ sed_runner "s/^version = .*/version = \"${NEXT_FULL_TAG}\"/g" "${FILE}"
sed_runner "/\"ucx-py==/ s/==.*\"/==${NEXT_UCX_PY_SHORT_TAG_PEP440}.*\"/g" ${FILE}
done
@@ -94,3 +91,10 @@ sed_runner "/^PROJECT_NUMBER/ s|\".*\"|\"${NEXT_SHORT_TAG}\"|g" cpp/doxygen/Doxy
sed_runner "/^set(RAFT_VERSION/ s|\".*\"|\"${NEXT_SHORT_TAG}\"|g" docs/source/build.md
sed_runner "/GIT_TAG.*branch-/ s|branch-.*|branch-${NEXT_SHORT_TAG}|g" docs/source/build.md
sed_runner "/rapidsai\/raft/ s|branch-[0-9][0-9].[0-9][0-9]|branch-${NEXT_SHORT_TAG}|g" docs/source/developer_guide.md
+
+# .devcontainer files
+find .devcontainer/ -type f -name devcontainer.json -print0 | while IFS= read -r -d '' filename; do
+ sed_runner "s@rapidsai/devcontainers:[0-9.]*@rapidsai/devcontainers:${NEXT_SHORT_TAG}@g" "${filename}"
+ sed_runner "s@rapidsai/devcontainers/features/ucx:[0-9.]*@rapidsai/devcontainers/features/ucx:${NEXT_SHORT_TAG_PEP440}@" "${filename}"
+ sed_runner "s@rapidsai/devcontainers/features/rapids-build-utils:[0-9.]*@rapidsai/devcontainers/features/rapids-build-utils:${NEXT_SHORT_TAG_PEP440}@" "${filename}"
+done
diff --git a/ci/test_wheel_raft_dask.sh b/ci/test_wheel_raft_dask.sh
index 676d642de9..fd9668e968 100755
--- a/ci/test_wheel_raft_dask.sh
+++ b/ci/test_wheel_raft_dask.sh
@@ -12,7 +12,7 @@ RAPIDS_PY_WHEEL_NAME="pylibraft_${RAPIDS_PY_CUDA_SUFFIX}" rapids-download-wheels
python -m pip install --no-deps ./local-pylibraft-dep/pylibraft*.whl
# Always install latest dask for testing
-python -m pip install git+https://github.com/dask/dask.git@main git+https://github.com/dask/distributed.git@main git+https://github.com/rapidsai/dask-cuda.git@branch-23.10
+python -m pip install git+https://github.com/dask/dask.git@2023.9.2 git+https://github.com/dask/distributed.git@2023.9.2 git+https://github.com/rapidsai/dask-cuda.git@branch-23.10
# echo to expand wildcard before adding `[extra]` requires for pip
python -m pip install $(echo ./dist/raft_dask*.whl)[test]
diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml
index 65b4232d83..739e1e9785 100644
--- a/conda/environments/all_cuda-118_arch-x86_64.yaml
+++ b/conda/environments/all_cuda-118_arch-x86_64.yaml
@@ -10,7 +10,7 @@ dependencies:
- breathe
- c-compiler
- clang-tools=16.0.6
-- clang=16.0.6
+- clang==16.0.6
- cmake>=3.26.4
- cuda-profiler-api=11.8.86
- cuda-python>=11.7.1,<12.0a0
@@ -19,10 +19,10 @@ dependencies:
- cupy>=12.0.0
- cxx-compiler
- cython>=3.0.0
-- dask-core>=2023.7.1
+- dask-core==2023.9.2
- dask-cuda==23.10.*
-- dask>=2023.7.1
-- distributed>=2023.7.1
+- dask==2023.9.2
+- distributed==2023.9.2
- doxygen>=1.8.20
- gcc_linux-64=11.*
- gmock>=1.13.0
@@ -43,6 +43,8 @@ dependencies:
- numba>=0.57
- numpy>=1.21
- numpydoc
+- nvcc_linux-64=11.8
+- pre-commit
- pydata-sphinx-theme
- pytest
- pytest-cov
diff --git a/conda/environments/all_cuda-120_arch-x86_64.yaml b/conda/environments/all_cuda-120_arch-x86_64.yaml
index 9db38ed1de..321c17bf4f 100644
--- a/conda/environments/all_cuda-120_arch-x86_64.yaml
+++ b/conda/environments/all_cuda-120_arch-x86_64.yaml
@@ -10,19 +10,20 @@ dependencies:
- breathe
- c-compiler
- clang-tools=16.0.6
-- clang=16.0.6
+- clang==16.0.6
- cmake>=3.26.4
- cuda-cudart-dev
+- cuda-nvcc
- cuda-profiler-api
- cuda-python>=12.0,<13.0a0
- cuda-version=12.0
- cupy>=12.0.0
- cxx-compiler
- cython>=3.0.0
-- dask-core>=2023.7.1
+- dask-core==2023.9.2
- dask-cuda==23.10.*
-- dask>=2023.7.1
-- distributed>=2023.7.1
+- dask==2023.9.2
+- distributed==2023.9.2
- doxygen>=1.8.20
- gcc_linux-64=11.*
- gmock>=1.13.0
@@ -39,6 +40,7 @@ dependencies:
- numba>=0.57
- numpy>=1.21
- numpydoc
+- pre-commit
- pydata-sphinx-theme
- pytest
- pytest-cov
diff --git a/conda/environments/bench_ann_cuda-118_arch-x86_64.yaml b/conda/environments/bench_ann_cuda-118_arch-x86_64.yaml
index 5a9ef5bd32..4f1df12dfa 100644
--- a/conda/environments/bench_ann_cuda-118_arch-x86_64.yaml
+++ b/conda/environments/bench_ann_cuda-118_arch-x86_64.yaml
@@ -10,7 +10,7 @@ dependencies:
- benchmark>=1.8.2
- c-compiler
- clang-tools=16.0.6
-- clang=16.0.6
+- clang==16.0.6
- cmake>=3.26.4
- cuda-profiler-api=11.8.86
- cuda-version=11.8
@@ -34,6 +34,7 @@ dependencies:
- nccl>=2.9.9
- ninja
- nlohmann_json>=3.11.2
+- nvcc_linux-64=11.8
- scikit-build>=0.13.1
- sysroot_linux-64==2.17
name: bench_ann_cuda-118_arch-x86_64
diff --git a/conda/recipes/raft-ann-bench/meta.yaml b/conda/recipes/raft-ann-bench/meta.yaml
index 91d0fdb729..a2ab0af643 100644
--- a/conda/recipes/raft-ann-bench/meta.yaml
+++ b/conda/recipes/raft-ann-bench/meta.yaml
@@ -78,11 +78,11 @@ requirements:
- h5py {{ h5py_version }}
- benchmark
- matplotlib
- # rmm is needed to determine if package is gpu-enabled
- - rmm ={{ minor_version }}
- python
- pandas
- pyyaml
+ # rmm is needed to determine if package is gpu-enabled
+ - rmm ={{ minor_version }}
run:
- python
@@ -104,6 +104,8 @@ requirements:
- python
- pandas
- pyyaml
+ # rmm is needed to determine if package is gpu-enabled
+ - rmm ={{ minor_version }}
about:
home: https://rapids.ai/
license: Apache-2.0
diff --git a/conda/recipes/raft-dask/meta.yaml b/conda/recipes/raft-dask/meta.yaml
index c9caa4dd9b..04dfef5063 100644
--- a/conda/recipes/raft-dask/meta.yaml
+++ b/conda/recipes/raft-dask/meta.yaml
@@ -60,10 +60,10 @@ requirements:
- cudatoolkit
{% endif %}
- {{ pin_compatible('cuda-version', max_pin='x', min_pin='x') }}
- - dask >=2023.7.1
- - dask-core >=2023.7.1
+ - dask ==2023.9.2
+ - dask-core ==2023.9.2
- dask-cuda ={{ minor_version }}
- - distributed >=2023.7.1
+ - distributed ==2023.9.2
- joblib >=0.11
- nccl >=2.9.9
- pylibraft {{ version }}
diff --git a/cpp/.clangd b/cpp/.clangd
new file mode 100644
index 0000000000..7c4fe036dd
--- /dev/null
+++ b/cpp/.clangd
@@ -0,0 +1,65 @@
+# https://clangd.llvm.org/config
+
+# Apply a config conditionally to all C files
+If:
+ PathMatch: .*\.(c|h)$
+
+---
+
+# Apply a config conditionally to all C++ files
+If:
+ PathMatch: .*\.(c|h)pp
+
+---
+
+# Apply a config conditionally to all CUDA files
+If:
+ PathMatch: .*\.cuh?
+CompileFlags:
+ Add:
+ - "-x"
+ - "cuda"
+ # No error on unknown CUDA versions
+ - "-Wno-unknown-cuda-version"
+ # Allow variadic CUDA functions
+ - "-Xclang=-fcuda-allow-variadic-functions"
+Diagnostics:
+ Suppress:
+ - "variadic_device_fn"
+ - "attributes_not_allowed"
+
+---
+
+# Tweak the clangd parse settings for all files
+CompileFlags:
+ Add:
+ # report all errors
+ - "-ferror-limit=0"
+ - "-fmacro-backtrace-limit=0"
+ - "-ftemplate-backtrace-limit=0"
+ # Skip the CUDA version check
+ - "--no-cuda-version-check"
+ Remove:
+ # remove gcc's -fcoroutines
+ - -fcoroutines
+ # remove nvc++ flags unknown to clang
+ - "-gpu=*"
+ - "-stdpar*"
+ # remove nvcc flags unknown to clang
+ - "-arch*"
+ - "-gencode*"
+ - "--generate-code*"
+ - "-ccbin*"
+ - "-t=*"
+ - "--threads*"
+ - "-Xptxas*"
+ - "-Xcudafe*"
+ - "-Xfatbin*"
+ - "-Xcompiler*"
+ - "--diag-suppress*"
+ - "--diag_suppress*"
+ - "--compiler-options*"
+ - "--expt-extended-lambda"
+ - "--expt-relaxed-constexpr"
+ - "-forward-unknown-to-host-compiler"
+ - "-Werror=cross-execution-space-call"
diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index d93b19f784..7d63751906 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -22,7 +22,8 @@ include(rapids-find)
option(BUILD_CPU_ONLY "Build CPU only components. Applies to RAFT ANN benchmarks currently" OFF)
-# workaround for rapids_cuda_init_architectures not working for arch detection with enable_language(CUDA)
+# workaround for rapids_cuda_init_architectures not working for arch detection with
+# enable_language(CUDA)
set(lang_list "CXX")
if(NOT BUILD_CPU_ONLY)
@@ -286,7 +287,8 @@ endif()
set_target_properties(raft_compiled PROPERTIES EXPORT_NAME compiled)
if(RAFT_COMPILE_LIBRARY)
- add_library(raft_objs OBJECT
+ add_library(
+ raft_objs OBJECT
src/core/logger.cpp
src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu
src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu
@@ -331,6 +333,7 @@ if(RAFT_COMPILE_LIBRARY)
src/neighbors/brute_force_knn_int64_t_float_uint32_t.cu
src/neighbors/brute_force_knn_int_float_int.cu
src/neighbors/brute_force_knn_uint32_t_float_uint32_t.cu
+ src/neighbors/brute_force_knn_index_float.cu
src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu
src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu
src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu
@@ -452,18 +455,21 @@ if(RAFT_COMPILE_LIBRARY)
src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu
src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu
src/util/memory_pool.cpp
- )
+ )
set_target_properties(
raft_objs
PROPERTIES CXX_STANDARD 17
CXX_STANDARD_REQUIRED ON
CUDA_STANDARD 17
CUDA_STANDARD_REQUIRED ON
- POSITION_INDEPENDENT_CODE ON)
+ POSITION_INDEPENDENT_CODE ON
+ )
target_compile_definitions(raft_objs PRIVATE "RAFT_EXPLICIT_INSTANTIATE_ONLY")
- target_compile_options(raft_objs PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:${RAFT_CXX_FLAGS}>"
- "$<$<COMPILE_LANGUAGE:CUDA>:${RAFT_CUDA_FLAGS}>")
+ target_compile_options(
+ raft_objs PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:${RAFT_CXX_FLAGS}>"
+ "$<$<COMPILE_LANGUAGE:CUDA>:${RAFT_CUDA_FLAGS}>"
+ )
add_library(raft_lib SHARED $<TARGET_OBJECTS:raft_objs>)
add_library(raft_lib_static STATIC $<TARGET_OBJECTS:raft_objs>)
@@ -477,13 +483,15 @@ if(RAFT_COMPILE_LIBRARY)
)
foreach(target raft_lib raft_lib_static raft_objs)
- target_link_libraries(${target} PUBLIC
- raft::raft
- ${RAFT_CTK_MATH_DEPENDENCIES} # TODO: Once `raft::resources` is used everywhere, this
- # will just be cublas
- $<TARGET_NAME_IF_EXISTS:OpenMP::OpenMP_CXX>)
+ target_link_libraries(
+ ${target}
+ PUBLIC raft::raft
+ ${RAFT_CTK_MATH_DEPENDENCIES} # TODO: Once `raft::resources` is used everywhere, this
+ # will just be cublas
+ $<TARGET_NAME_IF_EXISTS:OpenMP::OpenMP_CXX>
+ )
- #So consumers know when using libraft.so/libraft.a
+ # So consumers know when using libraft.so/libraft.a
target_compile_definitions(${target} PUBLIC "RAFT_COMPILED")
# ensure CUDA symbols aren't relocated to the middle of the debug build binaries
target_link_options(${target} PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/fatbin.ld")
diff --git a/cpp/bench/ann/src/raft/raft_benchmark.cu b/cpp/bench/ann/src/raft/raft_benchmark.cu
index 7ba381ab0a..a9ff6c2922 100644
--- a/cpp/bench/ann/src/raft/raft_benchmark.cu
+++ b/cpp/bench/ann/src/raft/raft_benchmark.cu
@@ -147,6 +147,13 @@ void parse_build_param(const nlohmann::json& conf,
if (conf.contains("intermediate_graph_degree")) {
param.intermediate_graph_degree = conf.at("intermediate_graph_degree");
}
+ if (conf.contains("graph_build_algo")) {
+ if (conf.at("graph_build_algo") == "IVF_PQ") {
+ param.build_algo = raft::neighbors::cagra::graph_build_algo::IVF_PQ;
+ } else if (conf.at("graph_build_algo") == "NN_DESCENT") {
+ param.build_algo = raft::neighbors::cagra::graph_build_algo::NN_DESCENT;
+ }
+ }
}
template
diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt
index e8d4739384..ca4b0f099d 100644
--- a/cpp/bench/prims/CMakeLists.txt
+++ b/cpp/bench/prims/CMakeLists.txt
@@ -77,6 +77,7 @@ if(BUILD_PRIMS_BENCH)
NAME CLUSTER_BENCH PATH bench/prims/cluster/kmeans_balanced.cu bench/prims/cluster/kmeans.cu
bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY
)
+ ConfigureBench(NAME CORE_BENCH PATH bench/prims/core/bitset.cu bench/prims/main.cpp)
ConfigureBench(
NAME TUNE_DISTANCE PATH bench/prims/distance/tune_pairwise/kernel.cu
@@ -155,4 +156,5 @@ if(BUILD_PRIMS_BENCH)
LIB
EXPLICIT_INSTANTIATE_ONLY
)
+
endif()
diff --git a/cpp/bench/prims/core/bitset.cu b/cpp/bench/prims/core/bitset.cu
new file mode 100644
index 0000000000..5f44aa9af5
--- /dev/null
+++ b/cpp/bench/prims/core/bitset.cu
@@ -0,0 +1,74 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <common/benchmark.hpp>
+#include <raft/core/bitset.cuh>
+#include <raft/core/device_mdarray.hpp>
+#include <raft/random/rng.cuh>
+
+namespace raft::bench::core {
+
+struct bitset_inputs {
+ uint32_t bitset_len;
+ uint32_t mask_len;
+ uint32_t query_len;
+}; // struct bitset_inputs
+
+template <typename bitset_t, typename index_t>
+struct bitset_bench : public fixture {
+ bitset_bench(const bitset_inputs& p)
+ : params(p),
+ mask{raft::make_device_vector<index_t, index_t>(res, p.mask_len)},
+ queries{raft::make_device_vector<index_t, index_t>(res, p.query_len)},
+ outputs{raft::make_device_vector<bool, index_t>(res, p.query_len)}
+ {
+ raft::random::RngState state{42};
+ raft::random::uniformInt(res, state, mask.view(), index_t{0}, index_t{p.bitset_len});
+ }
+
+ void run_benchmark(::benchmark::State& state) override
+ {
+ loop_on_state(state, [this]() {
+ auto my_bitset = raft::core::bitset<bitset_t, index_t>(
+ this->res, raft::make_const_mdspan(mask.view()), params.bitset_len);
+ my_bitset.test(res, raft::make_const_mdspan(queries.view()), outputs.view());
+ });
+ }
+
+ private:
+ raft::resources res;
+ bitset_inputs params;
+ raft::device_vector<index_t, index_t> mask, queries;
+ raft::device_vector<bool, index_t> outputs;
+}; // struct bitset
+
+const std::vector<bitset_inputs> bitset_input_vecs{
+ {256 * 1024 * 1024, 64 * 1024 * 1024, 256 * 1024 * 1024}, // Standard Bench
+ {256 * 1024 * 1024, 64 * 1024 * 1024, 1024 * 1024 * 1024}, // Extra queries
+ {128 * 1024 * 1024, 1024 * 1024 * 1024, 256 * 1024 * 1024}, // Extra mask to test atomics impact
+};
+
+using Uint8_32 = bitset_bench<uint8_t, uint32_t>;
+using Uint16_64 = bitset_bench<uint16_t, uint64_t>;
+using Uint32_32 = bitset_bench<uint32_t, uint32_t>;
+using Uint32_64 = bitset_bench<uint32_t, uint64_t>;
+
+RAFT_BENCH_REGISTER(Uint8_32, "", bitset_input_vecs);
+RAFT_BENCH_REGISTER(Uint16_64, "", bitset_input_vecs);
+RAFT_BENCH_REGISTER(Uint32_32, "", bitset_input_vecs);
+RAFT_BENCH_REGISTER(Uint32_64, "", bitset_input_vecs);
+
+} // namespace raft::bench::core
diff --git a/cpp/bench/prims/neighbors/cagra_bench.cuh b/cpp/bench/prims/neighbors/cagra_bench.cuh
index bb405088bb..63f6c14686 100644
--- a/cpp/bench/prims/neighbors/cagra_bench.cuh
+++ b/cpp/bench/prims/neighbors/cagra_bench.cuh
@@ -18,8 +18,10 @@
#include
#include
+#include <raft/core/bitset.cuh>
#include
#include
+#include <thrust/sequence.h>
#include
@@ -40,6 +42,8 @@ struct params {
int block_size;
int search_width;
int max_iterations;
+ /** Ratio of removed indices. */
+ double removed_ratio;
};
template <typename T, typename IdxT>
@@ -49,7 +53,8 @@ struct CagraBench : public fixture {
params_(ps),
queries_(make_device_matrix<T, int64_t>(handle, ps.n_queries, ps.n_dims)),
dataset_(make_device_matrix<T, int64_t>(handle, ps.n_samples, ps.n_dims)),
- knn_graph_(make_device_matrix<IdxT, int64_t>(handle, ps.n_samples, ps.degree))
+ knn_graph_(make_device_matrix<IdxT, int64_t>(handle, ps.n_samples, ps.degree)),
+ removed_indices_bitset_(handle, ps.n_samples)
{
// Generate random dataset and queriees
raft::random::RngState state{42};
@@ -74,6 +79,13 @@ struct CagraBench : public fixture {
auto metric = raft::distance::DistanceType::L2Expanded;
+ auto removed_indices =
+ raft::make_device_vector<IdxT, int64_t>(handle, ps.removed_ratio * ps.n_samples);
+ thrust::sequence(
+ resource::get_thrust_policy(handle),
+ thrust::device_pointer_cast(removed_indices.data_handle()),
+ thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0)));
+ removed_indices_bitset_.set(handle, removed_indices.view());
index_.emplace(raft::neighbors::cagra::index<T, IdxT>(
handle, metric, make_const_mdspan(dataset_.view()), make_const_mdspan(knn_graph_.view())));
}
@@ -95,10 +107,18 @@ struct CagraBench : public fixture {
distances.data_handle(), params_.n_queries, params_.k);
auto queries_v = make_const_mdspan(queries_.view());
- loop_on_state(state, [&]() {
- raft::neighbors::cagra::search(
- this->handle, search_params, *this->index_, queries_v, ind_v, dist_v);
- });
+ if (params_.removed_ratio > 0) {
+ auto filter = raft::neighbors::filtering::bitset_filter(removed_indices_bitset_.view());
+ loop_on_state(state, [&]() {
+ raft::neighbors::cagra::search_with_filtering(
+ this->handle, search_params, *this->index_, queries_v, ind_v, dist_v, filter);
+ });
+ } else {
+ loop_on_state(state, [&]() {
+ raft::neighbors::cagra::search(
+ this->handle, search_params, *this->index_, queries_v, ind_v, dist_v);
+ });
+ }
double data_size = params_.n_samples * params_.n_dims * sizeof(T);
double graph_size = params_.n_samples * params_.degree * sizeof(IdxT);
@@ -120,6 +140,7 @@ struct CagraBench : public fixture {
state.counters["block_size"] = params_.block_size;
state.counters["search_width"] = params_.search_width;
state.counters["iterations"] = iterations;
+ state.counters["removed_ratio"] = params_.removed_ratio;
}
private:
@@ -128,6 +149,7 @@ struct CagraBench : public fixture {
raft::device_matrix queries_;
raft::device_matrix dataset_;
raft::device_matrix knn_graph_;
+ raft::core::bitset removed_indices_bitset_;
};
inline const std::vector generate_inputs()
@@ -141,7 +163,8 @@ inline const std::vector generate_inputs()
{64}, // itopk_size
{0}, // block_size
{1}, // search_width
- {0} // max_iterations
+ {0}, // max_iterations
+ {0.0} // removed_ratio
);
auto inputs2 = raft::util::itertools::product({2000000ull, 10000000ull}, // n_samples
{128}, // dataset dim
@@ -151,7 +174,22 @@ inline const std::vector generate_inputs()
{64}, // itopk_size
{64, 128, 256, 512, 1024}, // block_size
{1}, // search_width
- {0} // max_iterations
+ {0}, // max_iterations
+ {0.0} // removed_ratio
+ );
+ inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());
+
+ inputs2 = raft::util::itertools::product(
+ {2000000ull, 10000000ull}, // n_samples
+ {128}, // dataset dim
+ {1, 10, 10000}, // n_queries
+ {255}, // k
+ {64}, // knn graph degree
+ {300}, // itopk_size
+ {256}, // block_size
+ {2}, // search_width
+ {0}, // max_iterations
+ {0.0, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64} // removed_ratio
);
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());
return inputs;
diff --git a/cpp/include/raft/core/bitset.cuh b/cpp/include/raft/core/bitset.cuh
new file mode 100644
index 0000000000..6747c5fab0
--- /dev/null
+++ b/cpp/include/raft/core/bitset.cuh
@@ -0,0 +1,308 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace raft::core {
+/**
+ * @defgroup bitset Bitset
+ * @{
+ */
+/**
+ * @brief View of a RAFT Bitset.
+ *
+ * This lightweight structure stores a pointer to a bitset in device memory with its length.
+ * It provides a test() device function to check if a given index is set in the bitset.
+ *
+ * @tparam bitset_t Underlying type of the bitset array. Default is uint32_t.
+ * @tparam index_t Indexing type used. Default is uint32_t.
+ */
+template
+struct bitset_view {
+ index_t static constexpr const bitset_element_size = sizeof(bitset_t) * 8;
+
+ _RAFT_HOST_DEVICE bitset_view(bitset_t* bitset_ptr, index_t bitset_len)
+ : bitset_ptr_{bitset_ptr}, bitset_len_{bitset_len}
+ {
+ }
+ /**
+ * @brief Create a bitset view from a device vector view of the bitset.
+ *
+ * @param bitset_span Device vector view of the bitset
+ * @param bitset_len Number of bits in the bitset
+ */
+ _RAFT_HOST_DEVICE bitset_view(raft::device_vector_view bitset_span,
+ index_t bitset_len)
+ : bitset_ptr_{bitset_span.data_handle()}, bitset_len_{bitset_len}
+ {
+ }
+ /**
+ * @brief Device function to test if a given index is set in the bitset.
+ *
+ * @param sample_index Single index to test
+ * @return bool True if the bit at the given index is set in the bitset
+ */
+ inline _RAFT_DEVICE auto test(const index_t sample_index) const -> bool
+ {
+ const bitset_t bit_element = bitset_ptr_[sample_index / bitset_element_size];
+ const index_t bit_index = sample_index % bitset_element_size;
+ const bool is_bit_set = (bit_element & (bitset_t{1} << bit_index)) != 0;
+ return is_bit_set;
+ }
+
+ /**
+ * @brief Get the device pointer to the bitset.
+ */
+ inline _RAFT_HOST_DEVICE auto data_handle() -> bitset_t* { return bitset_ptr_; }
+ inline _RAFT_HOST_DEVICE auto data_handle() const -> const bitset_t* { return bitset_ptr_; }
+ /**
+ * @brief Get the number of bits of the bitset representation.
+ */
+ inline _RAFT_HOST_DEVICE auto size() const -> index_t { return bitset_len_; }
+
+ /**
+ * @brief Get the number of elements used by the bitset representation.
+ */
+ inline _RAFT_HOST_DEVICE auto n_elements() const -> index_t
+ {
+ return raft::ceildiv(bitset_len_, bitset_element_size);
+ }
+
+ inline auto to_mdspan() -> raft::device_vector_view
+ {
+ return raft::make_device_vector_view(bitset_ptr_, n_elements());
+ }
+ inline auto to_mdspan() const -> raft::device_vector_view
+ {
+ return raft::make_device_vector_view(bitset_ptr_, n_elements());
+ }
+
+ private:
+ bitset_t* bitset_ptr_;
+ index_t bitset_len_;
+};
+
+/**
+ * @brief RAFT Bitset.
+ *
+ * This structure encapsulates a bitset in device memory. It provides a view() method to get a
+ * device-usable lightweight view of the bitset.
+ * Each index is represented by a single bit in the bitset. The total number of bytes used is
+ * ceil(bitset_len / 8).
+ * @tparam bitset_t Underlying type of the bitset array. Default is uint32_t.
+ * @tparam index_t Indexing type used. Default is uint32_t.
+ */
+template
+struct bitset {
+ index_t static constexpr const bitset_element_size = sizeof(bitset_t) * 8;
+
+ /**
+ * @brief Construct a new bitset object with a list of indices to unset.
+ *
+ * @param res RAFT resources
+ * @param mask_index List of indices whose bits are set to !default_value in the bitset
+ * @param bitset_len Length of the bitset
+ * @param default_value Default value to set the bits to. Default is true.
+ */
+ bitset(const raft::resources& res,
+ raft::device_vector_view mask_index,
+ index_t bitset_len,
+ bool default_value = true)
+ : bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)),
+ raft::resource::get_cuda_stream(res)},
+ bitset_len_{bitset_len},
+ default_value_{default_value}
+ {
+ cudaMemsetAsync(bitset_.data(),
+ default_value ? 0xff : 0x00,
+ n_elements() * sizeof(bitset_t),
+ resource::get_cuda_stream(res));
+ set(res, mask_index, !default_value);
+ }
+
+ /**
+ * @brief Construct a new bitset object
+ *
+ * @param res RAFT resources
+ * @param bitset_len Length of the bitset
+ * @param default_value Default value to set the bits to. Default is true.
+ */
+ bitset(const raft::resources& res, index_t bitset_len, bool default_value = true)
+ : bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)),
+ resource::get_cuda_stream(res)},
+ bitset_len_{bitset_len},
+ default_value_{default_value}
+ {
+ cudaMemsetAsync(bitset_.data(),
+ default_value ? 0xff : 0x00,
+ n_elements() * sizeof(bitset_t),
+ resource::get_cuda_stream(res));
+ }
+ // Disable copy constructor
+ bitset(const bitset&) = delete;
+ bitset(bitset&&) = default;
+ bitset& operator=(const bitset&) = delete;
+ bitset& operator=(bitset&&) = default;
+
+ /**
+ * @brief Create a device-usable view of the bitset.
+ *
+ * @return bitset_view
+ */
+ inline auto view() -> raft::core::bitset_view
+ {
+ return bitset_view(to_mdspan(), bitset_len_);
+ }
+ [[nodiscard]] inline auto view() const -> raft::core::bitset_view
+ {
+ return bitset_view(to_mdspan(), bitset_len_);
+ }
+
+ /**
+ * @brief Get the device pointer to the bitset.
+ */
+ inline auto data_handle() -> bitset_t* { return bitset_.data(); }
+ inline auto data_handle() const -> const bitset_t* { return bitset_.data(); }
+ /**
+ * @brief Get the number of bits of the bitset representation.
+ */
+ inline auto size() const -> index_t { return bitset_len_; }
+
+ /**
+ * @brief Get the number of elements used by the bitset representation.
+ */
+ inline auto n_elements() const -> index_t
+ {
+ return raft::ceildiv(bitset_len_, bitset_element_size);
+ }
+
+ /** @brief Get an mdspan view of the current bitset */
+ inline auto to_mdspan() -> raft::device_vector_view
+ {
+ return raft::make_device_vector_view(bitset_.data(), n_elements());
+ }
+ [[nodiscard]] inline auto to_mdspan() const -> raft::device_vector_view
+ {
+ return raft::make_device_vector_view(bitset_.data(), n_elements());
+ }
+
+ /** @brief Resize the bitset. If the requested size is larger, new memory is allocated and set to
+ * the default value. */
+ void resize(const raft::resources& res, index_t new_bitset_len)
+ {
+ auto old_size = raft::ceildiv(bitset_len_, bitset_element_size);
+ auto new_size = raft::ceildiv(new_bitset_len, bitset_element_size);
+ bitset_.resize(new_size);
+ bitset_len_ = new_bitset_len;
+ if (old_size < new_size) {
+ // If the new size is larger, set the new bits to the default value
+ cudaMemsetAsync(bitset_.data() + old_size,
+ default_value_ ? 0xff : 0x00,
+ (new_size - old_size) * sizeof(bitset_t),
+ resource::get_cuda_stream(res));
+ }
+ }
+
+ /**
+ * @brief Test a list of indices in a bitset.
+ *
+ * @tparam output_t Output type of the test. Default is bool.
+ * @param res RAFT resources
+ * @param queries List of indices to test
+ * @param output List of outputs
+ */
+ template
+ void test(const raft::resources& res,
+ raft::device_vector_view queries,
+ raft::device_vector_view output) const
+ {
+ RAFT_EXPECTS(output.extent(0) == queries.extent(0), "Output and queries must be same size");
+ auto bitset_view = view();
+ raft::linalg::map(
+ res,
+ output,
+ [bitset_view] __device__(index_t query) { return output_t(bitset_view.test(query)); },
+ queries);
+ }
+ /**
+ * @brief Set a list of indices in a bitset to set_value.
+ *
+ * @param res RAFT resources
+ * @param mask_index indices in the bitset whose bits are set to set_value
+ * @param set_value Value to set the bits to (true or false)
+ */
+ void set(const raft::resources& res,
+ raft::device_vector_view mask_index,
+ bool set_value = false)
+ {
+ auto* bitset_ptr = this->data_handle();
+ thrust::for_each_n(resource::get_thrust_policy(res),
+ mask_index.data_handle(),
+ mask_index.extent(0),
+ [bitset_ptr, set_value] __device__(const index_t sample_index) {
+ const index_t bit_element = sample_index / bitset_element_size;
+ const index_t bit_index = sample_index % bitset_element_size;
+ const bitset_t bitmask = bitset_t{1} << bit_index;
+ if (set_value) {
+ atomicOr(bitset_ptr + bit_element, bitmask);
+ } else {
+ const bitset_t bitmask2 = ~bitmask;
+ atomicAnd(bitset_ptr + bit_element, bitmask2);
+ }
+ });
+ }
+ /**
+ * @brief Flip all the bits in a bitset.
+ *
+ * @param res RAFT resources
+ */
+ void flip(const raft::resources& res)
+ {
+ auto bitset_span = this->to_mdspan();
+ raft::linalg::map(
+ res,
+ bitset_span,
+ [] __device__(bitset_t element) { return bitset_t(~element); },
+ raft::make_const_mdspan(bitset_span));
+ }
+ /**
+ * @brief Reset the bits in a bitset.
+ *
+ * @param res RAFT resources
+ */
+ void reset(const raft::resources& res)
+ {
+ cudaMemsetAsync(bitset_.data(),
+ default_value_ ? 0xff : 0x00,
+ n_elements() * sizeof(bitset_t),
+ resource::get_cuda_stream(res));
+ }
+
+ private:
+ raft::device_uvector bitset_;
+ index_t bitset_len_;
+ bool default_value_;
+};
+
+/** @} */
+} // end namespace raft::core
diff --git a/cpp/include/raft/neighbors/brute_force-ext.cuh b/cpp/include/raft/neighbors/brute_force-ext.cuh
index 862db75866..b8c00616da 100644
--- a/cpp/include/raft/neighbors/brute_force-ext.cuh
+++ b/cpp/include/raft/neighbors/brute_force-ext.cuh
@@ -22,7 +22,8 @@
#include // raft::identity_op
#include // raft::resources
#include // raft::distance::DistanceType
-#include // RAFT_EXPLICIT
+#include
+#include // RAFT_EXPLICIT
#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY
@@ -38,6 +39,19 @@ inline void knn_merge_parts(
size_t n_samples,
std::optional> translations = std::nullopt) RAFT_EXPLICIT;
+template
+index build(raft::resources const& res,
+ mdspan, row_major, Accessor> dataset,
+ raft::distance::DistanceType metric = distance::DistanceType::L2Unexpanded,
+ T metric_arg = 0.0) RAFT_EXPLICIT;
+
+template
+void search(raft::resources const& res,
+ const index& idx,
+ raft::device_matrix_view queries,
+ raft::device_matrix_view neighbors,
+ raft::device_matrix_view distances) RAFT_EXPLICIT;
+
template (
+ raft::resources const& res,
+ const raft::neighbors::brute_force::index& idx,
+ raft::device_matrix_view queries,
+ raft::device_matrix_view neighbors,
+ raft::device_matrix_view distances);
+
+extern template void search(
+ raft::resources const& res,
+ const raft::neighbors::brute_force::index& idx,
+ raft::device_matrix_view queries,
+ raft::device_matrix_view neighbors,
+ raft::device_matrix_view distances);
+
+extern template raft::neighbors::brute_force::index build(
+ raft::resources const& res,
+ raft::device_matrix_view dataset,
+ raft::distance::DistanceType metric,
+ float metric_arg);
+} // namespace raft::neighbors::brute_force
+
#define instantiate_raft_neighbors_brute_force_fused_l2_knn( \
value_t, idx_t, idx_layout, query_layout) \
extern template void raft::neighbors::brute_force::fused_l2_knn( \
diff --git a/cpp/include/raft/neighbors/brute_force-inl.cuh b/cpp/include/raft/neighbors/brute_force-inl.cuh
index bc9e09e5b0..88439a738b 100644
--- a/cpp/include/raft/neighbors/brute_force-inl.cuh
+++ b/cpp/include/raft/neighbors/brute_force-inl.cuh
@@ -19,6 +19,7 @@
#include
#include
#include
+#include
#include
#include
@@ -280,6 +281,101 @@ void fused_l2_knn(raft::resources const& handle,
metric);
}
-/** @} */ // end group brute_force_knn
+/**
+ * @brief Build the index from the dataset for efficient search.
+ *
+ * @tparam T data element type
+ *
+ * @param[in] res
+ * @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim]
+ * @param[in] metric: distance metric to use. Euclidean (L2) is used by default
+ * @param[in] metric_arg: the value of `p` for Minkowski (l-p) distances. This
+ * is ignored if the metric_type is not Minkowski.
+ *
+ * @return the constructed brute force index
+ */
+template
+index build(raft::resources const& res,
+ mdspan, row_major, Accessor> dataset,
+ raft::distance::DistanceType metric = distance::DistanceType::L2Unexpanded,
+ T metric_arg = 0.0)
+{
+ // certain distance metrics can benefit by pre-calculating the norms for the index dataset
+ // which lets us avoid calculating these at query time
+ std::optional> norms;
+ if (metric == raft::distance::DistanceType::L2Expanded ||
+ metric == raft::distance::DistanceType::L2SqrtExpanded ||
+ metric == raft::distance::DistanceType::CosineExpanded) {
+ norms = make_device_vector(res, dataset.extent(0));
+ // cosine needs the l2norm, whereas l2 distances need the squared norm
+ if (metric == raft::distance::DistanceType::CosineExpanded) {
+ raft::linalg::norm(res,
+ dataset,
+ norms->view(),
+ raft::linalg::NormType::L2Norm,
+ raft::linalg::Apply::ALONG_ROWS,
+ raft::sqrt_op{});
+ } else {
+ raft::linalg::norm(res,
+ dataset,
+ norms->view(),
+ raft::linalg::NormType::L2Norm,
+ raft::linalg::Apply::ALONG_ROWS);
+ }
+ }
+
+ return index(res, dataset, std::move(norms), metric, metric_arg);
+}
+/**
+ * @brief Brute Force search using the constructed index.
+ *
+ * @tparam T data element type
+ * @tparam IdxT type of the indices
+ *
+ * @param[in] res raft resources
+ * @param[in] idx brute force index
+ * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()]
+ * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
+ * [n_queries, k]
+ * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
+ * k]
+ */
+template
+void search(raft::resources const& res,
+ const index& idx,
+ raft::device_matrix_view queries,
+ raft::device_matrix_view neighbors,
+ raft::device_matrix_view distances)
+{
+ RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), "Value of k must match for outputs");
+ RAFT_EXPECTS(idx.dataset().extent(1) == queries.extent(1),
+ "Number of columns in queries must match brute force index");
+
+ auto k = neighbors.extent(1);
+ auto d = idx.dataset().extent(1);
+
+ std::vector dataset = {const_cast(idx.dataset().data_handle())};
+ std::vector sizes = {idx.dataset().extent(0)};
+ std::vector norms;
+ if (idx.has_norms()) { norms.push_back(const_cast(idx.norms().data_handle())); }
+
+ detail::brute_force_knn_impl(res,
+ dataset,
+ sizes,
+ d,
+ const_cast(queries.data_handle()),
+ queries.extent(0),
+ neighbors.data_handle(),
+ distances.data_handle(),
+ k,
+ true,
+ true,
+ nullptr,
+ idx.metric(),
+ idx.metric_arg(),
+ raft::identity_op(),
+ norms.size() ? &norms : nullptr);
+}
+/** @} */ // end group brute_force_knn
} // namespace raft::neighbors::brute_force
diff --git a/cpp/include/raft/neighbors/brute_force_types.hpp b/cpp/include/raft/neighbors/brute_force_types.hpp
new file mode 100644
index 0000000000..19dd6b8350
--- /dev/null
+++ b/cpp/include/raft/neighbors/brute_force_types.hpp
@@ -0,0 +1,165 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "ann_types.hpp"
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+namespace raft::neighbors::brute_force {
+/**
+ * @addtogroup brute_force
+ * @{
+ */
+
+/**
+ * @brief Brute Force index.
+ *
+ * The index stores the dataset and norms for the dataset in device memory.
+ *
+ * @tparam T data element type
+ */
+template
+struct index : ann::index {
+ public:
+ /** Distance metric used for retrieval */
+ [[nodiscard]] constexpr inline raft::distance::DistanceType metric() const noexcept
+ {
+ return metric_;
+ }
+
+ /** Total length of the index (number of vectors). */
+ [[nodiscard]] constexpr inline int64_t size() const noexcept { return dataset_view_.extent(0); }
+
+ /** Dimensionality of the data. */
+ [[nodiscard]] constexpr inline uint32_t dim() const noexcept { return dataset_view_.extent(1); }
+
+ /** Dataset [size, dim] */
+ [[nodiscard]] inline auto dataset() const noexcept
+ -> device_matrix_view
+ {
+ return dataset_view_;
+ }
+
+ /** Dataset norms */
+ [[nodiscard]] inline auto norms() const -> device_vector_view
+ {
+ return norms_view_.value();
+ }
+
+ /** Whether or not this index has dataset norms */
+ [[nodiscard]] inline bool has_norms() const noexcept { return norms_view_.has_value(); }
+
+ [[nodiscard]] inline T metric_arg() const noexcept { return metric_arg_; }
+
+ // Don't allow copying the index for performance reasons (try avoiding copying data)
+ index(const index&) = delete;
+ index(index&&) = default;
+ auto operator=(const index&) -> index& = delete;
+ auto operator=(index&&) -> index& = default;
+ ~index() = default;
+
+ /** Construct a brute force index from dataset
+ *
+ * Constructs a brute force index from a dataset. This lets us precompute norms for
+ * the dataset, providing a speed benefit over doing this at query time.
+
+ * If the dataset is already in GPU memory, then this class stores a non-owning reference to
+ * the dataset. If the dataset is in host memory, it will be copied to the device and the
+ * index will own the device memory.
+ */
+ template
+ index(raft::resources const& res,
+ mdspan, row_major, data_accessor> dataset,
+ std::optional>&& norms,
+ raft::distance::DistanceType metric,
+ T metric_arg = 0.0)
+ : ann::index(),
+ metric_(metric),
+ dataset_(make_device_matrix(res, 0, 0)),
+ norms_(std::move(norms)),
+ metric_arg_(metric_arg)
+ {
+ if (norms_) { norms_view_ = make_const_mdspan(norms_.value().view()); }
+ update_dataset(res, dataset);
+ resource::sync_stream(res);
+ }
+
+ /** Construct a brute force index from dataset
+ *
+ * This class stores a non-owning reference to the dataset and norms here.
+ * Having precomputed norms gives us a performance advantage at query time.
+ */
+ index(raft::resources const& res,
+ raft::device_matrix_view dataset_view,
+ std::optional> norms_view,
+ raft::distance::DistanceType metric,
+ T metric_arg = 0.0)
+ : ann::index(),
+ metric_(metric),
+ dataset_(make_device_matrix(res, 0, 0)),
+ dataset_view_(dataset_view),
+ norms_view_(norms_view),
+ metric_arg_(metric_arg)
+ {
+ }
+
+ private:
+ /**
+ * Replace the dataset with a new dataset.
+ */
+ void update_dataset(raft::resources const& res,
+ raft::device_matrix_view dataset)
+ {
+ dataset_view_ = dataset;
+ }
+
+ /**
+ * Replace the dataset with a new dataset.
+ *
+ * We create a copy of the dataset on the device. The index manages the lifetime of this copy.
+ */
+ void update_dataset(raft::resources const& res,
+ raft::host_matrix_view dataset)
+ {
+ dataset_ = make_device_matrix(dataset.extents(0), dataset.extents(1));
+ raft::copy(dataset_.data_handle(),
+ dataset.data_handle(),
+ dataset.size(),
+ resource::get_cuda_stream(res));
+ dataset_view_ = make_const_mdspan(dataset_.view());
+ }
+
+ raft::distance::DistanceType metric_;
+ raft::device_matrix dataset_;
+ std::optional> norms_;
+ std::optional> norms_view_;
+ raft::device_matrix_view dataset_view_;
+ T metric_arg_;
+};
+
+/** @} */
+
+} // namespace raft::neighbors::brute_force
diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh
index 903d0571dc..f9682a973f 100644
--- a/cpp/include/raft/neighbors/cagra.cuh
+++ b/cpp/include/raft/neighbors/cagra.cuh
@@ -35,12 +35,11 @@ namespace raft::neighbors::cagra {
*/
/**
- * @brief Build a kNN graph.
+ * @brief Build a kNN graph using IVF-PQ.
*
* The kNN graph is the first building block for CAGRA index.
- * This function uses the IVF-PQ method to build a kNN graph.
*
- * The output is a dense matrix that stores the neighbor indices for each pont in the dataset.
+ * The output is a dense matrix that stores the neighbor indices for each point in the dataset.
* Each point has the same number of neighbors.
*
* See [cagra::build](#cagra::build) for an alternative method.
@@ -52,16 +51,16 @@ namespace raft::neighbors::cagra {
* @code{.cpp}
* using namespace raft::neighbors;
* // use default index parameters
- * cagra::index_params build_params;
- * cagra::search_params search_params
+ * ivf_pq::index_params build_params;
+ * ivf_pq::search_params search_params
* auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128);
* // create knn graph
* cagra::build_knn_graph(res, dataset, knn_graph.view(), 2, build_params, search_params);
- * auto optimized_gaph = raft::make_host_matrix(dataset.extent(0), 64);
+ * auto optimized_graph = raft::make_host_matrix(dataset.extent(0), 64);
* cagra::optimize(res, dataset, knn_graph.view(), optimized_graph.view());
* // Construct an index from dataset and optimized knn_graph
* auto index = cagra::index(res, build_params.metric(), dataset,
- * optimized_graph.view());
+ * optimized_graph.view());
* @endcode
*
* @tparam DataT data element type
@@ -70,7 +69,7 @@ namespace raft::neighbors::cagra {
* @param[in] res raft resources
* @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim]
* @param[out] knn_graph a host matrix view to store the output knn graph [n_rows, graph_degree]
- * @param[in] refine_rate refinement rate for ivf-pq search
+ * @param[in] refine_rate (optional) refinement rate for ivf-pq search
* @param[in] build_params (optional) ivf_pq index building parameters for knn graph
* @param[in] search_params (optional) ivf_pq search parameters
*/
@@ -95,6 +94,58 @@ void build_knn_graph(raft::resources const& res,
res, dataset_internal, knn_graph_internal, refine_rate, build_params, search_params);
}
+/**
+ * @brief Build a kNN graph using NN-descent.
+ *
+ * The kNN graph is the first building block for CAGRA index.
+ *
+ * The output is a dense matrix that stores the neighbor indices for each point in the dataset.
+ * Each point has the same number of neighbors.
+ *
+ * See [cagra::build](#cagra::build) for an alternative method.
+ *
+ * The following distance metrics are supported:
+ * - L2Expanded
+ *
+ * Usage example:
+ * @code{.cpp}
+ * using namespace raft::neighbors;
+ * using namespace raft::neighbors::experimental;
+ * // use default index parameters
+ * nn_descent::index_params build_params;
+ * build_params.graph_degree = 128;
+ * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128);
+ * // create knn graph
+ * cagra::build_knn_graph(res, dataset, knn_graph.view(), build_params);
+ * auto optimized_graph = raft::make_host_matrix(dataset.extent(0), 64);
+ * cagra::optimize(res, dataset, nn_descent_index.graph.view(), optimized_graph.view());
+ * // Construct an index from dataset and optimized knn_graph
+ * auto index = cagra::index(res, build_params.metric(), dataset,
+ * optimized_graph.view());
+ * @endcode
+ *
+ * @tparam DataT data element type
+ * @tparam IdxT type of the dataset vector indices
+ * @tparam accessor host or device accessor_type for the dataset
+ * @param[in] res raft::resources is an object managing resources
+ * @param[in] dataset input raft::host/device_matrix_view that can be located in
+ * in host or device memory
+ * @param[out] knn_graph a host matrix view to store the output knn graph [n_rows, graph_degree]
+ * @param[in] build_params an instance of experimental::nn_descent::index_params that are parameters
+ * to run the nn-descent algorithm
+ */
+template , memory_type::device>>
+void build_knn_graph(raft::resources const& res,
+ mdspan, row_major, accessor> dataset,
+ raft::host_matrix_view knn_graph,
+ experimental::nn_descent::index_params build_params)
+{
+ detail::build_knn_graph(res, dataset, knn_graph, build_params);
+}
+
/**
* @brief Sort a KNN graph index.
* Preprocessing step for `cagra::optimize`: If a KNN graph is not built using
@@ -106,7 +157,7 @@ void build_knn_graph(raft::resources const& res,
* @code{.cpp}
* using namespace raft::neighbors;
* cagra::index_params build_params;
- * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128);
+ * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128);
* // build KNN graph not using `cagra::build_knn_graph`
* // build(knn_graph, dataset, ...);
* // sort graph index
@@ -115,7 +166,7 @@ void build_knn_graph(raft::resources const& res,
* cagra::optimize(res, dataset, knn_graph.view(), optimized_graph.view());
* // Construct an index from dataset and optimized knn_graph
* auto index = cagra::index(res, build_params.metric(), dataset,
- * optimized_graph.view());
+ * optimized_graph.view());
* @endcode
*
* @tparam DataT type of the data in the source dataset
@@ -259,7 +310,16 @@ index build(raft::resources const& res,
std::optional> knn_graph(
raft::make_host_matrix(dataset.extent(0), intermediate_degree));
- build_knn_graph(res, dataset, knn_graph->view());
+ if (params.build_algo == graph_build_algo::IVF_PQ) {
+ build_knn_graph(res, dataset, knn_graph->view());
+
+ } else {
+ // Use nn-descent to build CAGRA knn graph
+ auto nn_descent_params = experimental::nn_descent::index_params();
+ nn_descent_params.graph_degree = intermediate_degree;
+ nn_descent_params.intermediate_graph_degree = 1.5 * intermediate_degree;
+ build_knn_graph(res, dataset, knn_graph->view(), nn_descent_params);
+ }
auto cagra_graph = raft::make_host_matrix(dataset.extent(0), graph_degree);
@@ -316,9 +376,88 @@ void search(raft::resources const& res,
auto distances_internal = raft::make_device_matrix_view(
distances.data_handle(), distances.extent(0), distances.extent(1));
- cagra::detail::search_main(
- res, params, idx, queries_internal, neighbors_internal, distances_internal);
+ cagra::detail::search_main(res,
+ params,
+ idx,
+ queries_internal,
+ neighbors_internal,
+ distances_internal,
+ raft::neighbors::filtering::none_cagra_sample_filter());
+}
+
+/**
+ * @brief Search ANN using the constructed index with the given sample filter.
+ *
+ * Usage example:
+ * @code{.cpp}
+ * using namespace raft::neighbors;
+ * // use default index parameters
+ * cagra::index_params index_params;
+ * // create and fill the index from a [N, D] dataset
+ * auto index = cagra::build(res, index_params, dataset);
+ * // use default search parameters
+ * cagra::search_params search_params;
+ * // create a bitset to filter the search
+ * auto removed_indices = raft::make_device_vector(res, n_removed_indices);
+ * raft::core::bitset removed_indices_bitset(
+ * res, removed_indices.view(), dataset.extent(0));
+ * // search K nearest neighbours according to a bitset
+ * auto neighbors = raft::make_device_matrix(res, n_queries, k);
+ * auto distances = raft::make_device_matrix(res, n_queries, k);
+ * cagra::search_with_filtering(res, search_params, index, queries, neighbors, distances,
+ * filtering::bitset_filter(removed_indices_bitset.view()));
+ * @endcode
+ *
+ * @tparam T data element type
+ * @tparam IdxT type of the indices
+ * @tparam CagraSampleFilterT Device filter function, with the signature
+ * `(uint32_t query ix, uint32_t sample_ix) -> bool`
+ *
+ * @param[in] res raft resources
+ * @param[in] params configure the search
+ * @param[in] idx cagra index
+ * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()]
+ * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
+ * [n_queries, k]
+ * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
+ * k]
+ * @param[in] sample_filter a device filter function that greenlights samples for a given query
+ */
+template
+void search_with_filtering(raft::resources const& res,
+ const search_params& params,
+ const index& idx,
+ raft::device_matrix_view queries,
+ raft::device_matrix_view neighbors,
+ raft::device_matrix_view distances,
+ CagraSampleFilterT sample_filter = CagraSampleFilterT())
+{
+ RAFT_EXPECTS(
+ queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0),
+ "Number of rows in output neighbors and distances matrices must equal the number of queries.");
+
+ RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1),
+ "Number of columns in output neighbors and distances matrices must equal k");
+ RAFT_EXPECTS(queries.extent(1) == idx.dim(),
+ "Number of query dimensions should equal number of dimensions in the index.");
+
+ using internal_IdxT = typename std::make_unsigned::type;
+ auto queries_internal = raft::make_device_matrix_view(
+ queries.data_handle(), queries.extent(0), queries.extent(1));
+ auto neighbors_internal = raft::make_device_matrix_view(
+ reinterpret_cast(neighbors.data_handle()),
+ neighbors.extent(0),
+ neighbors.extent(1));
+ auto distances_internal = raft::make_device_matrix_view(
+ distances.data_handle(), distances.extent(0), distances.extent(1));
+
+ cagra::detail::search_main(
+ res, params, idx, queries_internal, neighbors_internal, distances_internal, sample_filter);
}
+
/** @} */ // end group cagra
} // namespace raft::neighbors::cagra
diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp
index 02e3f5338e..5061d6082d 100644
--- a/cpp/include/raft/neighbors/cagra_types.hpp
+++ b/cpp/include/raft/neighbors/cagra_types.hpp
@@ -40,11 +40,24 @@ namespace raft::neighbors::cagra {
* @{
*/
+/**
+ * @brief ANN algorithm used by CAGRA to build knn graph
+ *
+ */
+enum class graph_build_algo {
+ /* Use IVF-PQ to build all-neighbors knn graph */
+ IVF_PQ,
+ /* Experimental, use NN-Descent to build all-neighbors knn graph */
+ NN_DESCENT
+};
+
struct index_params : ann::index_params {
/** Degree of input graph for pruning. */
size_t intermediate_graph_degree = 128;
/** Degree of output graph. */
size_t graph_degree = 64;
+ /** ANN algorithm to build knn graph. */
+ graph_build_algo build_algo = graph_build_algo::IVF_PQ;
};
enum class search_algo {
@@ -165,9 +178,10 @@ struct index : ann::index {
~index() = default;
/** Construct an empty index. */
- index(raft::resources const& res)
+ index(raft::resources const& res,
+ raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded)
: ann::index(),
- metric_(raft::distance::DistanceType::L2Expanded),
+ metric_(metric),
dataset_(make_device_matrix<T, int64_t>(res, 0, 0)),
graph_(make_device_matrix<IdxT, int64_t>(res, 0, 0))
{
@@ -296,7 +310,11 @@ struct index : ann::index {
raft::host_matrix_view<IdxT, int64_t, row_major> knn_graph)
{
RAFT_LOG_DEBUG("Copying CAGRA knn graph from host to device");
- graph_ = make_device_matrix<IdxT, int64_t>(res, knn_graph.extent(0), knn_graph.extent(1));
+ if ((graph_.extent(0) != knn_graph.extent(0)) || (graph_.extent(1) != knn_graph.extent(1))) {
+ // clear existing memory before allocating to prevent OOM errors on large graphs
+ if (graph_.size()) { graph_ = make_device_matrix<IdxT, int64_t>(res, 0, 0); }
+ graph_ = make_device_matrix<IdxT, int64_t>(res, knn_graph.extent(0), knn_graph.extent(1));
+ }
raft::copy(graph_.data_handle(),
knn_graph.data_handle(),
knn_graph.size(),
@@ -311,7 +329,13 @@ struct index : ann::index {
mdspan<const T, matrix_extent<int64_t>, row_major, data_accessor> dataset)
{
size_t padded_dim = round_up_safe(dataset.extent(1) * sizeof(T), 16) / sizeof(T);
- dataset_ = make_device_matrix<T, int64_t>(res, dataset.extent(0), padded_dim);
+
+ if ((dataset_.extent(0) != dataset.extent(0)) ||
+ (static_cast<size_t>(dataset_.extent(1)) != padded_dim)) {
+ // clear existing memory before allocating to prevent OOM errors on large datasets
+ if (dataset_.size()) { dataset_ = make_device_matrix<T, int64_t>(res, 0, 0); }
+ dataset_ = make_device_matrix<T, int64_t>(res, dataset.extent(0), padded_dim);
+ }
if (dataset_.extent(1) == dataset.extent(1)) {
raft::copy(dataset_.data_handle(),
dataset.data_handle(),
@@ -351,6 +375,7 @@ struct index : ann::index {
// TODO: Remove deprecated experimental namespace in 23.12 release
namespace raft::neighbors::experimental::cagra {
+using raft::neighbors::cagra::graph_build_algo;
using raft::neighbors::cagra::hash_mode;
using raft::neighbors::cagra::index;
using raft::neighbors::cagra::index_params;
diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
index d19d7e7904..40024a3deb 100644
--- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
+++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
@@ -28,12 +28,14 @@
#include
#include
#include
+#include
#include
#include
#include
#include
#include
+#include
#include
namespace raft::neighbors::cagra::detail {
@@ -46,6 +48,7 @@ void build_knn_graph(raft::resources const& res,
std::optional build_params = std::nullopt,
std::optional search_params = std::nullopt)
{
+ resource::detail::warn_non_pool_workspace(res, "raft::neighbors::cagra::build");
RAFT_EXPECTS(!build_params || build_params->metric == distance::DistanceType::L2Expanded,
"Currently only L2Expanded metric is supported");
@@ -238,4 +241,27 @@ void build_knn_graph(raft::resources const& res,
if (!first) RAFT_LOG_DEBUG("# Finished building kNN graph");
}
+template <typename DataT, typename IdxT = uint32_t, typename accessor = host_device_accessor<std::experimental::default_accessor<DataT>, memory_type::host>>
+void build_knn_graph(raft::resources const& res,
+ mdspan<const DataT, matrix_extent<int64_t>, row_major, accessor> dataset,
+ raft::host_matrix_view<IdxT, int64_t, row_major> knn_graph,
+ experimental::nn_descent::index_params build_params)
+{
+ auto nn_descent_idx = experimental::nn_descent::index(res, knn_graph);
+ experimental::nn_descent::build(res, build_params, dataset, nn_descent_idx);
+
+ using internal_IdxT = typename std::make_unsigned<IdxT>::type;
+ using g_accessor = typename decltype(nn_descent_idx.graph())::accessor_type;
+ using g_accessor_internal =
+ host_device_accessor<std::experimental::default_accessor<internal_IdxT>, g_accessor::mem_type>;
+
+ auto knn_graph_internal =
+ mdspan<internal_IdxT, matrix_extent<int64_t>, row_major, g_accessor_internal>(
+ reinterpret_cast<internal_IdxT*>(nn_descent_idx.graph().data_handle()),
+ nn_descent_idx.graph().extent(0),
+ nn_descent_idx.graph().extent(1));
+
+ graph::sort_knn_graph(res, dataset, knn_graph_internal);
+}
+
} // namespace raft::neighbors::cagra::detail
diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
index 8190817b5b..81e714dc4e 100644
--- a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
+++ b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
@@ -18,10 +18,13 @@
#include
#include
+#include
#include
#include
#include
+#include
+#include
#include
#include
#include
@@ -32,6 +35,48 @@
namespace raft::neighbors::cagra::detail {
+template <typename CagraSampleFilterT>
+struct CagraSampleFilterWithQueryIdOffset {
+ const uint32_t offset;
+ CagraSampleFilterT filter;
+
+ CagraSampleFilterWithQueryIdOffset(const uint32_t offset, const CagraSampleFilterT filter)
+ : offset(offset), filter(filter)
+ {
+ }
+
+ _RAFT_DEVICE auto operator()(const uint32_t query_id, const uint32_t sample_id)
+ {
+ return filter(query_id + offset, sample_id);
+ }
+};
+
+template <typename CagraSampleFilterT>
+struct CagraSampleFilterT_Selector {
+ using type = CagraSampleFilterWithQueryIdOffset<CagraSampleFilterT>;
+};
+template <>
+struct CagraSampleFilterT_Selector<raft::neighbors::filtering::none_cagra_sample_filter> {
+ using type = raft::neighbors::filtering::none_cagra_sample_filter;
+};
+
+// A helper function to set a query id offset
+template <typename CagraSampleFilterT>
+inline typename CagraSampleFilterT_Selector<CagraSampleFilterT>::type set_offset(
+ CagraSampleFilterT filter, const uint32_t offset)
+{
+ typename CagraSampleFilterT_Selector<CagraSampleFilterT>::type new_filter(offset, filter);
+ return new_filter;
+}
+template <>
+inline
+ typename CagraSampleFilterT_Selector<raft::neighbors::filtering::none_cagra_sample_filter>::type
+ set_offset<raft::neighbors::filtering::none_cagra_sample_filter>(
+ raft::neighbors::filtering::none_cagra_sample_filter filter, const uint32_t)
+{
+ return filter;
+}
+
/**
* @brief Search ANN using the constructed index.
*
@@ -52,27 +97,37 @@ namespace raft::neighbors::cagra::detail {
* k]
*/
-template <typename T, typename internal_IdxT, typename IdxT = uint32_t, typename DistanceT = float>
+template <typename T, typename internal_IdxT, typename CagraSampleFilterT, typename IdxT = uint32_t, typename DistanceT = float>
void search_main(raft::resources const& res,
search_params params,
const index<T, IdxT>& index,
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<internal_IdxT, int64_t, row_major> neighbors,
- raft::device_matrix_view<DistanceT, int64_t, row_major> distances)
+ raft::device_matrix_view<DistanceT, int64_t, row_major> distances,
+ CagraSampleFilterT sample_filter = CagraSampleFilterT())
{
+ resource::detail::warn_non_pool_workspace(res, "raft::neighbors::cagra::search");
RAFT_LOG_DEBUG("# dataset size = %lu, dim = %lu\n",
static_cast(index.dataset().extent(0)),
static_cast(index.dataset().extent(1)));
RAFT_LOG_DEBUG("# query size = %lu, dim = %lu\n",
static_cast(queries.extent(0)),
static_cast