diff --git a/README.md b/README.md index 55b8fca8f6..7833a5cfa3 100755 --- a/README.md +++ b/README.md @@ -1,8 +1,12 @@ #
 RAFT: Reusable Accelerated Functions and Tools for Vector Search and More
+> [!IMPORTANT] +> The vector search and clustering algorithms in RAFT are being migrated to a new library dedicated to vector search called [cuVS](https://github.com/rapidsai/cuvs). We will continue to support the vector search algorithms in RAFT during this move, but will no longer update them after the RAPIDS 24.06 (June) release. We plan to complete the migration by RAPIDS 24.08 (August) release. + ![RAFT tech stack](img/raft-tech-stack-vss.png) + ## Contents
@@ -77,6 +81,8 @@ Projects that use the RAFT ANNS algorithms for accelerating vector search includ Please see the example [Jupyter notebook](https://github.com/rapidsai/raft/blob/HEAD/notebooks/VectorSearch_QuestionRetrieval.ipynb) to get started RAFT for vector search in Python. + + ### Information Retrieval RAFT contains a catalog of reusable primitives for composing algorithms that require fast neighborhood computations, such as diff --git a/build.sh b/build.sh index e5df0af826..45c7d1380f 100755 --- a/build.sh +++ b/build.sh @@ -1,6 +1,6 @@ #!/bin/bash -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. # raft build scripts @@ -305,7 +305,7 @@ if hasArg --allgpuarch; then BUILD_ALL_GPU_ARCH=1 fi -if hasArg --compile-lib || (( ${NUMARGS} == 0 )); then +if hasArg --compile-lib || hasArg pylibraft || (( ${NUMARGS} == 0 )); then COMPILE_LIBRARY=ON CMAKE_TARGET="${CMAKE_TARGET};raft_lib" fi @@ -405,7 +405,7 @@ fi ################################################################################ # Configure for building all C++ targets -if (( ${NUMARGS} == 0 )) || hasArg libraft || hasArg docs || hasArg tests || hasArg bench-prims || hasArg bench-ann; then +if (( ${NUMARGS} == 0 )) || hasArg libraft || hasArg docs || hasArg tests || hasArg bench-prims || hasArg bench-ann || ((${COMPILE_LIBRARY} == ON )); then if (( ${BUILD_ALL_GPU_ARCH} == 0 )); then RAFT_CMAKE_CUDA_ARCHITECTURES="NATIVE" echo "Building for the architecture of the GPU in the system..." @@ -512,6 +512,8 @@ fi if hasArg docs; then set -x + export RAPIDS_VERSION="$(sed -E -e 's/^([0-9]{2})\.([0-9]{2})\.([0-9]{2}).*$/\1.\2.\3/' "${REPODIR}/VERSION")" + export RAPIDS_VERSION_MAJOR_MINOR="$(sed -E -e 's/^([0-9]{2})\.([0-9]{2})\.([0-9]{2}).*$/\1.\2/' "${REPODIR}/VERSION")" cd ${DOXYGEN_BUILD_DIR} doxygen Doxyfile cd ${SPHINX_BUILD_DIR} diff --git a/ci/build_docs.sh b/ci/build_docs.sh index 4c07683642..3d72c815db 100755 --- a/ci/build_docs.sh +++ b/ci/build_docs.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. set -euo pipefail @@ -28,7 +28,9 @@ rapids-mamba-retry install \ pylibraft \ raft-dask -export RAPIDS_VERSION_NUMBER="24.04" +export RAPIDS_VERSION="$(rapids-version)" +export RAPIDS_VERSION_MAJOR_MINOR="$(rapids-version-major-minor)" +export RAPIDS_VERSION_NUMBER="$RAPIDS_VERSION_MAJOR_MINOR" export RAPIDS_DOCS_DIR="$(mktemp -d)" rapids-logger "Build CPP docs" diff --git a/ci/release/update-version.sh b/ci/release/update-version.sh index d268c16e0a..636f637d0c 100755 --- a/ci/release/update-version.sh +++ b/ci/release/update-version.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. ######################## # RAFT Version Updater # ######################## @@ -36,23 +36,11 @@ function sed_runner() { sed -i.bak ''"$1"'' $2 && rm -f ${2}.bak } -sed_runner "s/set(RAPIDS_VERSION .*)/set(RAPIDS_VERSION \"${NEXT_SHORT_TAG}\")/g" cpp/CMakeLists.txt sed_runner "s/set(RAPIDS_VERSION .*)/set(RAPIDS_VERSION \"${NEXT_SHORT_TAG}\")/g" cpp/template/cmake/thirdparty/fetch_rapids.cmake -sed_runner "s/set(RAFT_VERSION .*)/set(RAFT_VERSION \"${NEXT_FULL_TAG}\")/g" cpp/CMakeLists.txt -sed_runner 's/'"pylibraft_version .*)"'/'"pylibraft_version ${NEXT_FULL_TAG})"'/g' python/pylibraft/CMakeLists.txt -sed_runner 's/'"raft_dask_version .*)"'/'"raft_dask_version ${NEXT_FULL_TAG})"'/g' python/raft-dask/CMakeLists.txt -sed_runner 's/'"branch-.*\/RAPIDS.cmake"'/'"branch-${NEXT_SHORT_TAG}\/RAPIDS.cmake"'/g' fetch_rapids.cmake # Centralized version file update echo "${NEXT_FULL_TAG}" > VERSION -# Wheel testing script -sed_runner "s/branch-.*/branch-${NEXT_SHORT_TAG}/g" ci/test_wheel_raft_dask.sh - -# Docs update -sed_runner 's/version = .*/version = '"'${NEXT_SHORT_TAG}'"'/g' docs/source/conf.py -sed_runner 's/release = .*/release = '"'${NEXT_FULL_TAG}'"'/g' docs/source/conf.py - DEPENDENCIES=( dask-cuda pylibraft @@ -84,9 +72,6 @@ sed_runner "/^ucx_py_version:$/ {n;s/.*/ - \"${NEXT_UCX_PY_VERSION}\"/}" conda/ for FILE in .github/workflows/*.yaml; do sed_runner "/shared-workflows/ s/@.*/@branch-${NEXT_SHORT_TAG}/g" "${FILE}" done -sed_runner "s/RAPIDS_VERSION_NUMBER=\".*/RAPIDS_VERSION_NUMBER=\"${NEXT_SHORT_TAG}\"/g" ci/build_docs.sh - -sed_runner "/^PROJECT_NUMBER/ s|\".*\"|\"${NEXT_SHORT_TAG}\"|g" cpp/doxygen/Doxyfile sed_runner "/^set(RAFT_VERSION/ s|\".*\"|\"${NEXT_SHORT_TAG}\"|g" docs/source/build.md sed_runner "s|branch-[0-9][0-9].[0-9][0-9]|branch-${NEXT_SHORT_TAG}|g" docs/source/build.md diff --git a/conda/environments/all_cuda-118_arch-aarch64.yaml b/conda/environments/all_cuda-118_arch-aarch64.yaml index 40b031d677..e27532a489 100644 --- a/conda/environments/all_cuda-118_arch-aarch64.yaml +++ b/conda/environments/all_cuda-118_arch-aarch64.yaml @@ -39,7 +39,7 @@ dependencies: - nccl>=2.9.9 - ninja - numba>=0.57 -- numpy>=1.23 +- numpy>=1.23,<2.0a0 - numpydoc - nvcc_linux-aarch64=11.8 - pre-commit @@ -57,5 +57,5 @@ dependencies: - sysroot_linux-aarch64==2.17 - ucx-proc=*=gpu - ucx-py==0.37.* -- ucx>=1.13.0 +- ucx>=1.15.0,<1.16.0 name: all_cuda-118_arch-aarch64 diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 5485d09a37..bf535c5c04 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -39,7 +39,7 @@ dependencies: - nccl>=2.9.9 - ninja - numba>=0.57 -- numpy>=1.23 +- numpy>=1.23,<2.0a0 - numpydoc - nvcc_linux-64=11.8 - pre-commit @@ -57,5 +57,5 @@ dependencies: - sysroot_linux-64==2.17 - ucx-proc=*=gpu - ucx-py==0.37.* -- ucx>=1.13.0 +- ucx>=1.15.0,<1.16.0 name: all_cuda-118_arch-x86_64 diff --git a/conda/environments/all_cuda-122_arch-aarch64.yaml b/conda/environments/all_cuda-122_arch-aarch64.yaml index b688bf3952..8ea3843841 100644 --- a/conda/environments/all_cuda-122_arch-aarch64.yaml +++ b/conda/environments/all_cuda-122_arch-aarch64.yaml @@ -36,7 +36,7 @@ dependencies: - nccl>=2.9.9 - ninja - numba>=0.57 -- numpy>=1.23 +- numpy>=1.23,<2.0a0 - numpydoc - pre-commit - pydata-sphinx-theme @@ -53,5 +53,5 @@ dependencies: - sysroot_linux-aarch64==2.17 - ucx-proc=*=gpu - ucx-py==0.37.* -- ucx>=1.13.0 +- ucx>=1.15.0,<1.16.0 name: all_cuda-122_arch-aarch64 diff --git a/conda/environments/all_cuda-122_arch-x86_64.yaml b/conda/environments/all_cuda-122_arch-x86_64.yaml index 013f852aee..a3f6f7e99f 100644 --- a/conda/environments/all_cuda-122_arch-x86_64.yaml +++ b/conda/environments/all_cuda-122_arch-x86_64.yaml @@ -36,7 +36,7 @@ dependencies: - nccl>=2.9.9 - ninja - numba>=0.57 -- numpy>=1.23 +- numpy>=1.23,<2.0a0 - numpydoc - pre-commit - pydata-sphinx-theme @@ -53,5 +53,5 @@ dependencies: - sysroot_linux-64==2.17 - ucx-proc=*=gpu - ucx-py==0.37.* -- ucx>=1.13.0 +- ucx>=1.15.0,<1.16.0 name: all_cuda-122_arch-x86_64 diff --git a/conda/recipes/libraft/build_libraft_template.sh b/conda/recipes/libraft/build_libraft_template.sh index bd7719af76..86c0fa11b6 100644 --- a/conda/recipes/libraft/build_libraft_template.sh +++ b/conda/recipes/libraft/build_libraft_template.sh @@ -1,5 +1,5 @@ #!/usr/bin/env bash -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # Just building template so we verify it uses libraft.so and fail if it doesn't build -./build.sh template +./build.sh template --no-nvtx diff --git a/conda/recipes/pylibraft/meta.yaml b/conda/recipes/pylibraft/meta.yaml index 5c2829d297..e524a68f9e 100644 --- a/conda/recipes/pylibraft/meta.yaml +++ b/conda/recipes/pylibraft/meta.yaml @@ -65,7 +65,7 @@ requirements: {% endif %} - libraft {{ version }} - libraft-headers {{ version }} - - numpy >=1.23 + - numpy >=1.23,<2.0a0 - python x.x - rmm ={{ minor_version }} diff --git a/conda/recipes/raft-dask/conda_build_config.yaml b/conda/recipes/raft-dask/conda_build_config.yaml index 483e53026a..d2bdcbb351 100644 --- a/conda/recipes/raft-dask/conda_build_config.yaml +++ b/conda/recipes/raft-dask/conda_build_config.yaml @@ -14,7 +14,7 @@ sysroot_version: - "2.17" ucx_version: - - ">=1.14.1,<1.16.0" + - ">=1.15.0,<1.16.0" ucx_py_version: - "0.37.*" diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 650bc1a059..638ceb3b45 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -10,11 +10,8 @@ # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. -set(RAPIDS_VERSION "24.04") -set(RAFT_VERSION "24.04.00") - cmake_minimum_required(VERSION 3.26.4 FATAL_ERROR) -include(../fetch_rapids.cmake) +include(../rapids_config.cmake) include(rapids-cmake) include(rapids-cpm) include(rapids-export) @@ -34,7 +31,7 @@ endif() project( RAFT - VERSION ${RAFT_VERSION} + VERSION "${RAPIDS_VERSION}" LANGUAGES ${lang_list} ) diff --git a/cpp/doxygen/Doxyfile b/cpp/doxygen/Doxyfile index 779472d880..67566ac1f9 100644 --- a/cpp/doxygen/Doxyfile +++ b/cpp/doxygen/Doxyfile @@ -38,7 +38,7 @@ PROJECT_NAME = "RAFT C++ API" # could be handy for archiving the generated documentation or if some version # control system is used. -PROJECT_NUMBER = "24.04" +PROJECT_NUMBER = "$(RAPIDS_VERSION_MAJOR_MINOR)" # Using the PROJECT_BRIEF tag one can provide an optional one line description # for a project that appears at the top of each page and should give viewer a diff --git a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh index 282097742c..1cc272f74e 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh @@ -91,18 +91,11 @@ std::enable_if_t::value> cutlassDistanceKernel(const Da typename EpilogueOutputOp::Params epilog_op_param(dist_op, fin_op); - const DataT *a, *b; - - IdxT gemm_lda, gemm_ldb; - // Number of pipelines you want to use constexpr int NumStages = 3; // Alignment constexpr int Alignment = VecLen; - // default initialize problem size with row major inputs - auto problem_size = cutlass::gemm::GemmCoord(n, m, k); - using cutlassDistKernel = typename cutlass::gemm::kernel::PairwiseDistanceGemm::value> cutlassDistanceKernel(const Da using cutlassDist = cutlass::gemm::device::GemmUniversalAdapter; - if constexpr (isRowMajor) { - a = y; - b = x; - gemm_lda = ldb; - gemm_ldb = lda; - } else { - problem_size = cutlass::gemm::GemmCoord(m, n, k); - a = x; - b = y; - gemm_lda = lda; - gemm_ldb = ldb; + constexpr uint32_t gridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1); + constexpr uint32_t max_batch_size = gridYZMax * cutlassDistKernel::ThreadblockShape::kN; + IdxT numNbatches = (n - 1 + max_batch_size) / max_batch_size; + + for (IdxT i = 0; i < numNbatches; i++) { + const DataT *a, *b; + IdxT gemm_lda, gemm_ldb; + size_t offsetN = i * max_batch_size; + + if constexpr (isRowMajor) { + gemm_lda = ldb; + gemm_ldb = lda; + a = y + offsetN * gemm_lda; + b = x; + } else { + gemm_lda = lda; + gemm_ldb = ldb; + a = x; + b = y + offsetN; + } + IdxT chunkN = (i + 1) * max_batch_size; + IdxT currentN = (chunkN < n) ? max_batch_size : (n - offsetN); + + // default initialize problem size with row major inputs + auto problem_size = isRowMajor ? cutlass::gemm::GemmCoord(currentN, m, k) + : cutlass::gemm::GemmCoord(m, currentN, k); + + typename cutlassDist::Arguments arguments{ + mode, + problem_size, + batch_count, + epilog_op_param, + a, + b, + xn, // C matrix eq vector param, which here is A norm + nullptr, // tensor_Z, + (DataT*)yn + offsetN, // this is broadcast vec, which is required to be non-const param + dOutput + offsetN, // Output distance matrix + (int64_t)0, // batch stride A + (int64_t)0, // batch stride B + (int64_t)0, // batch stride Norm A + (int64_t)0, + (int64_t)0, // batch stride Norm B + (int64_t)0, // batch stride Output + gemm_lda, // stride A + gemm_ldb, // stride B + 1, // stride A norm + 0, // this is no-op for Z + 0, // This must be zero + ldd // stride Output matrix + }; + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = cutlassDist::get_workspace_size(arguments); + // Allocate workspace memory + rmm::device_uvector workspace(workspace_size, stream); + // Instantiate CUTLASS kernel depending on templates + cutlassDist cutlassDist_op; + // Check the problem size is supported or not + RAFT_CUTLASS_TRY(cutlassDist_op.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + RAFT_CUTLASS_TRY(cutlassDist_op.initialize(arguments, workspace.data(), stream)); + + // Launch initialized CUTLASS kernel + RAFT_CUTLASS_TRY(cutlassDist_op(stream)); } - - typename cutlassDist::Arguments arguments{ - mode, problem_size, batch_count, epilog_op_param, a, b, - xn, // C matrix eq vector param, which here is A norm - nullptr, // tensor_Z, - (DataT*)yn, // this is broadcast vec, which is required to be non-const param - dOutput, // Output distance matrix - (int64_t)0, // batch stride A - (int64_t)0, // batch stride B - (int64_t)0, // batch stride Norm A - (int64_t)0, - (int64_t)0, // batch stride Norm B - (int64_t)0, // batch stride Output - gemm_lda, // stride A - gemm_ldb, // stride B - 1, // stride A norm - 0, // this is no-op for Z - 0, // This must be zero - ldd // stride Output matrix - }; - - // Using the arguments, query for extra workspace required for matrix multiplication computation - size_t workspace_size = cutlassDist::get_workspace_size(arguments); - // Allocate workspace memory - rmm::device_uvector workspace(workspace_size, stream); - // Instantiate CUTLASS kernel depending on templates - cutlassDist cutlassDist_op; - // Check the problem size is supported or not - RAFT_CUTLASS_TRY(cutlassDist_op.can_implement(arguments)); - - // Initialize CUTLASS kernel with arguments and workspace pointer - RAFT_CUTLASS_TRY(cutlassDist_op.initialize(arguments, workspace.data(), stream)); - - // Launch initialized CUTLASS kernel - RAFT_CUTLASS_TRY(cutlassDist_op(stream)); } }; // namespace detail diff --git a/cpp/include/raft/stats/detail/mean.cuh b/cpp/include/raft/stats/detail/mean.cuh index cf4dbc7aa3..6c330acb26 100644 --- a/cpp/include/raft/stats/detail/mean.cuh +++ b/cpp/include/raft/stats/detail/mean.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,7 +43,7 @@ RAFT_KERNEL meanKernelRowMajor(Type* mu, const Type* data, IdxType D, IdxType N) __syncthreads(); raft::myAtomicAdd(smu + thisColId, thread_data); __syncthreads(); - if (threadIdx.x < ColsPerBlk) raft::myAtomicAdd(mu + colId, smu[thisColId]); + if (threadIdx.x < ColsPerBlk && colId < D) raft::myAtomicAdd(mu + colId, smu[thisColId]); } template diff --git a/cpp/include/raft/stats/detail/stddev.cuh b/cpp/include/raft/stats/detail/stddev.cuh index acee4a944e..bc2644a233 100644 --- a/cpp/include/raft/stats/detail/stddev.cuh +++ b/cpp/include/raft/stats/detail/stddev.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -45,7 +45,7 @@ RAFT_KERNEL stddevKernelRowMajor(Type* std, const Type* data, IdxType D, IdxType __syncthreads(); raft::myAtomicAdd(sstd + thisColId, thread_data); __syncthreads(); - if (threadIdx.x < ColsPerBlk) raft::myAtomicAdd(std + colId, sstd[thisColId]); + if (threadIdx.x < ColsPerBlk && colId < D) raft::myAtomicAdd(std + colId, sstd[thisColId]); } template diff --git a/cpp/include/raft/stats/detail/sum.cuh b/cpp/include/raft/stats/detail/sum.cuh index bb45eb50f4..4f85536e6c 100644 --- a/cpp/include/raft/stats/detail/sum.cuh +++ b/cpp/include/raft/stats/detail/sum.cuh @@ -34,30 +34,72 @@ RAFT_KERNEL sumKernelRowMajor(Type* mu, const Type* data, IdxType D, IdxType N) IdxType thisRowId = threadIdx.x / ColsPerBlk; IdxType colId = thisColId + ((IdxType)blockIdx.y * ColsPerBlk); IdxType rowId = thisRowId + ((IdxType)blockIdx.x * RowsPerBlkPerIter); - Type thread_data = Type(0); + Type thread_sum = Type(0); const IdxType stride = RowsPerBlkPerIter * gridDim.x; - for (IdxType i = rowId; i < N; i += stride) - thread_data += (colId < D) ? data[i * D + colId] : Type(0); + for (IdxType i = rowId; i < N; i += stride) { + thread_sum += (colId < D) ? data[i * D + colId] : Type(0); + } __shared__ Type smu[ColsPerBlk]; if (threadIdx.x < ColsPerBlk) smu[threadIdx.x] = Type(0); __syncthreads(); - raft::myAtomicAdd(smu + thisColId, thread_data); + raft::myAtomicAdd(smu + thisColId, thread_sum); + __syncthreads(); + if (threadIdx.x < ColsPerBlk && colId < D) raft::myAtomicAdd(mu + colId, smu[thisColId]); +} + +template +RAFT_KERNEL sumKahanKernelRowMajor(Type* mu, const Type* data, IdxType D, IdxType N) +{ + constexpr int RowsPerBlkPerIter = TPB / ColsPerBlk; + IdxType thisColId = threadIdx.x % ColsPerBlk; + IdxType thisRowId = threadIdx.x / ColsPerBlk; + IdxType colId = thisColId + ((IdxType)blockIdx.y * ColsPerBlk); + IdxType rowId = thisRowId + ((IdxType)blockIdx.x * RowsPerBlkPerIter); + Type thread_sum = Type(0); + Type thread_c = Type(0); + const IdxType stride = RowsPerBlkPerIter * gridDim.x; + for (IdxType i = rowId; i < N; i += stride) { + // KahanBabushkaNeumaierSum + const Type cur_value = (colId < D) ? data[i * D + colId] : Type(0); + const Type t = thread_sum + cur_value; + if (abs(thread_sum) >= abs(cur_value)) { + thread_c += (thread_sum - t) + cur_value; + } else { + thread_c += (cur_value - t) + thread_sum; + } + thread_sum = t; + } + thread_sum += thread_c; + __shared__ Type smu[ColsPerBlk]; + if (threadIdx.x < ColsPerBlk) smu[threadIdx.x] = Type(0); + __syncthreads(); + raft::myAtomicAdd(smu + thisColId, thread_sum); __syncthreads(); if (threadIdx.x < ColsPerBlk && colId < D) raft::myAtomicAdd(mu + colId, smu[thisColId]); } template -RAFT_KERNEL sumKernelColMajor(Type* mu, const Type* data, IdxType D, IdxType N) +RAFT_KERNEL sumKahanKernelColMajor(Type* mu, const Type* data, IdxType D, IdxType N) { typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; - Type thread_data = Type(0); + Type thread_sum = Type(0); + Type thread_c = Type(0); IdxType colStart = N * blockIdx.x; for (IdxType i = threadIdx.x; i < N; i += TPB) { - IdxType idx = colStart + i; - thread_data += data[idx]; + // KahanBabushkaNeumaierSum + IdxType idx = colStart + i; + const Type cur_value = data[idx]; + const Type t = thread_sum + cur_value; + if (abs(thread_sum) >= abs(cur_value)) { + thread_c += (thread_sum - t) + cur_value; + } else { + thread_c += (cur_value - t) + thread_sum; + } + thread_sum = t; } - Type acc = BlockReduce(temp_storage).Sum(thread_data); + thread_sum += thread_c; + Type acc = BlockReduce(temp_storage).Sum(thread_sum); if (threadIdx.x == 0) { mu[blockIdx.x] = acc; } } @@ -66,15 +108,21 @@ void sum(Type* output, const Type* input, IdxType D, IdxType N, bool rowMajor, c { static const int TPB = 256; if (rowMajor) { - static const int RowsPerThread = 4; - static const int ColsPerBlk = 32; - static const int RowsPerBlk = (TPB / ColsPerBlk) * RowsPerThread; - dim3 grid(raft::ceildiv(N, (IdxType)RowsPerBlk), raft::ceildiv(D, (IdxType)ColsPerBlk)); + static const int ColsPerBlk = 8; + static const int MinRowsPerThread = 16; + static const int MinRowsPerBlk = (TPB / ColsPerBlk) * MinRowsPerThread; + static const int MaxBlocksDimX = 8192; + + const IdxType grid_y = raft::ceildiv(D, (IdxType)ColsPerBlk); + const IdxType grid_x = + raft::min((IdxType)MaxBlocksDimX, raft::ceildiv(N, (IdxType)MinRowsPerBlk)); + + dim3 grid(grid_x, grid_y); RAFT_CUDA_TRY(cudaMemset(output, 0, sizeof(Type) * D)); - sumKernelRowMajor + sumKahanKernelRowMajor <<>>(output, input, D, N); } else { - sumKernelColMajor<<>>(output, input, D, N); + sumKahanKernelColMajor<<>>(output, input, D, N); } RAFT_CUDA_TRY(cudaPeekAtLastError()); } diff --git a/cpp/template/README.md b/cpp/template/README.md index 348dff270a..05ec48964f 100644 --- a/cpp/template/README.md +++ b/cpp/template/README.md @@ -8,7 +8,7 @@ Once the minimum requirements are satisfied, this example template application c This directory (`RAFT_SOURCE/cpp/template`) can be copied directly in order to build a new application with RAFT. -RAFT can be integrated into an existing CMake project by copying the contents in the `configure rapids-cmake` and `configure raft` sections of the provided `CMakeLists.txt` into your project, along with `cmake/thirdparty/get_raft.cmake`. +RAFT can be integrated into an existing CMake project by copying the contents in the `configure rapids-cmake` and `configure raft` sections of the provided `CMakeLists.txt` into your project, along with `cmake/thirdparty/get_raft.cmake`. Make sure to link against the appropriate Cmake targets. Use `raft::raft`to add make the headers available and `raft::compiled` when utilizing the shared library. diff --git a/cpp/template/build.sh b/cpp/template/build.sh index 3ac00fc9af..49c17f7499 100755 --- a/cpp/template/build.sh +++ b/cpp/template/build.sh @@ -1,6 +1,6 @@ #!/bin/bash -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # raft empty project template build script diff --git a/cpp/template/cmake/thirdparty/get_raft.cmake b/cpp/template/cmake/thirdparty/get_raft.cmake index 6128b5c43c..07b0897be0 100644 --- a/cpp/template/cmake/thirdparty/get_raft.cmake +++ b/cpp/template/cmake/thirdparty/get_raft.cmake @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -33,6 +33,12 @@ function(find_and_configure_raft) #----------------------------------------------------- # Invoke CPM find_package() #----------------------------------------------------- + # Since the RAFT_NVTX option is used by targets generated by + # find_package(RAFT_NVTX) and when building from source we want to + # make `RAFT_NVTX` a cache variable so we get consistent + # behavior + # + set(RAFT_NVTX ${PKG_ENABLE_NVTX} CACHE BOOL "Enable raft nvtx logging" FORCE) rapids_cpm_find(raft ${PKG_VERSION} GLOBAL_TARGETS raft::raft BUILD_EXPORT_SET raft-template-exports @@ -46,7 +52,6 @@ function(find_and_configure_raft) "BUILD_TESTS OFF" "BUILD_PRIMS_BENCH OFF" "BUILD_ANN_BENCH OFF" - "RAFT_NVTX ${ENABLE_NVTX}" "RAFT_COMPILE_LIBRARY ${PKG_COMPILE_LIBRARY}" ) endfunction() diff --git a/cpp/test/distance/dist_cos.cu b/cpp/test/distance/dist_cos.cu index caf55529ed..b792ec4039 100644 --- a/cpp/test/distance/dist_cos.cu +++ b/cpp/test/distance/dist_cos.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * Copyright (c) 2018-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,10 +29,12 @@ class DistanceExpCosXequalY : public DistanceTestSameBuffer {}; const std::vector> inputsf = { + {0.001f, 128, (65536 + 128) * 128, 8, true, 1234ULL}, {0.001f, 1024, 1024, 32, true, 1234ULL}, {0.001f, 1024, 32, 1024, true, 1234ULL}, {0.001f, 32, 1024, 1024, true, 1234ULL}, {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, (65536 + 128) * 128, 128, 8, false, 1234ULL}, {0.001f, 1024, 1024, 32, false, 1234ULL}, {0.001f, 1024, 32, 1024, false, 1234ULL}, {0.001f, 32, 1024, 1024, false, 1234ULL}, diff --git a/cpp/test/distance/dist_l2_exp.cu b/cpp/test/distance/dist_l2_exp.cu index 7bdbb44362..0203d9ed9d 100644 --- a/cpp/test/distance/dist_l2_exp.cu +++ b/cpp/test/distance/dist_l2_exp.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * Copyright (c) 2018-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,12 +29,14 @@ class DistanceEucExpTestXequalY : public DistanceTestSameBuffer {}; const std::vector> inputsf = { + {0.001f, 128, (65536 + 128) * 128, 8, true, 1234ULL}, {0.001f, 2048, 4096, 128, true, 1234ULL}, {0.001f, 1024, 1024, 32, true, 1234ULL}, {0.001f, 1024, 32, 1024, true, 1234ULL}, {0.001f, 32, 1024, 1024, true, 1234ULL}, {0.003f, 1024, 1024, 1024, true, 1234ULL}, {0.003f, 1021, 1021, 1021, true, 1234ULL}, + {0.001f, (65536 + 128) * 128, 128, 8, false, 1234ULL}, {0.001f, 1024, 1024, 32, false, 1234ULL}, {0.001f, 1024, 32, 1024, false, 1234ULL}, {0.001f, 32, 1024, 1024, false, 1234ULL}, diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 938cd219d0..2854a8f3df 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -339,7 +339,7 @@ void naiveDistance(DataType* dist, DataType metric_arg = 2.0f, cudaStream_t stream = 0) { - static const dim3 TPB(16, 32, 1); + static const dim3 TPB(4, 256, 1); dim3 nblks(raft::ceildiv(m, (int)TPB.x), raft::ceildiv(n, (int)TPB.y), 1); switch (type) { diff --git a/cpp/test/stats/mean.cu b/cpp/test/stats/mean.cu index 67931e2eed..61b57ce739 100644 --- a/cpp/test/stats/mean.cu +++ b/cpp/test/stats/mean.cu @@ -95,39 +95,39 @@ class MeanTest : public ::testing::TestWithParam> { // Note: For 1024 samples, 256 experiments, a mean of 1.0 with stddev=1.0, the // measured mean (of a normal distribution) will fall outside of an epsilon of // 0.15 only 4/10000 times. (epsilon of 0.1 will fail 30/100 times) -const std::vector> inputsf = {{0.15f, 1.f, 1024, 32, true, false, 1234ULL}, - {0.15f, 1.f, 1024, 64, true, false, 1234ULL}, - {0.15f, 1.f, 1024, 128, true, false, 1234ULL}, - {0.15f, 1.f, 1024, 256, true, false, 1234ULL}, - {0.15f, -1.f, 1024, 32, false, false, 1234ULL}, - {0.15f, -1.f, 1024, 64, false, false, 1234ULL}, - {0.15f, -1.f, 1024, 128, false, false, 1234ULL}, - {0.15f, -1.f, 1024, 256, false, false, 1234ULL}, - {0.15f, 1.f, 1024, 32, true, true, 1234ULL}, - {0.15f, 1.f, 1024, 64, true, true, 1234ULL}, - {0.15f, 1.f, 1024, 128, true, true, 1234ULL}, - {0.15f, 1.f, 1024, 256, true, true, 1234ULL}, - {0.15f, -1.f, 1024, 32, false, true, 1234ULL}, - {0.15f, -1.f, 1024, 64, false, true, 1234ULL}, - {0.15f, -1.f, 1024, 128, false, true, 1234ULL}, - {0.15f, -1.f, 1024, 256, false, true, 1234ULL}}; - -const std::vector> inputsd = {{0.15, 1.0, 1024, 32, true, false, 1234ULL}, - {0.15, 1.0, 1024, 64, true, false, 1234ULL}, - {0.15, 1.0, 1024, 128, true, false, 1234ULL}, - {0.15, 1.0, 1024, 256, true, false, 1234ULL}, - {0.15, -1.0, 1024, 32, false, false, 1234ULL}, - {0.15, -1.0, 1024, 64, false, false, 1234ULL}, - {0.15, -1.0, 1024, 128, false, false, 1234ULL}, - {0.15, -1.0, 1024, 256, false, false, 1234ULL}, - {0.15, 1.0, 1024, 32, true, true, 1234ULL}, - {0.15, 1.0, 1024, 64, true, true, 1234ULL}, - {0.15, 1.0, 1024, 128, true, true, 1234ULL}, - {0.15, 1.0, 1024, 256, true, true, 1234ULL}, - {0.15, -1.0, 1024, 32, false, true, 1234ULL}, - {0.15, -1.0, 1024, 64, false, true, 1234ULL}, - {0.15, -1.0, 1024, 128, false, true, 1234ULL}, - {0.15, -1.0, 1024, 256, false, true, 1234ULL}}; +const std::vector> inputsf = { + {0.15f, 1.f, 1024, 32, true, false, 1234ULL}, {0.15f, 1.f, 1024, 64, true, false, 1234ULL}, + {0.15f, 1.f, 1024, 128, true, false, 1234ULL}, {0.15f, 1.f, 1024, 256, true, false, 1234ULL}, + {0.15f, -1.f, 1024, 32, false, false, 1234ULL}, {0.15f, -1.f, 1024, 64, false, false, 1234ULL}, + {0.15f, -1.f, 1024, 128, false, false, 1234ULL}, {0.15f, -1.f, 1024, 256, false, false, 1234ULL}, + {0.15f, 1.f, 1024, 32, true, true, 1234ULL}, {0.15f, 1.f, 1024, 64, true, true, 1234ULL}, + {0.15f, 1.f, 1024, 128, true, true, 1234ULL}, {0.15f, 1.f, 1024, 256, true, true, 1234ULL}, + {0.15f, -1.f, 1024, 32, false, true, 1234ULL}, {0.15f, -1.f, 1024, 64, false, true, 1234ULL}, + {0.15f, -1.f, 1024, 128, false, true, 1234ULL}, {0.15f, -1.f, 1024, 256, false, true, 1234ULL}, + {0.15f, -1.f, 1030, 1, false, false, 1234ULL}, {0.15f, -1.f, 1030, 60, true, false, 1234ULL}, + {2.0f, -1.f, 31, 120, false, false, 1234ULL}, {2.0f, -1.f, 1, 130, true, false, 1234ULL}, + {0.15f, -1.f, 1030, 1, false, true, 1234ULL}, {0.15f, -1.f, 1030, 60, true, true, 1234ULL}, + {2.0f, -1.f, 31, 120, false, true, 1234ULL}, {2.0f, -1.f, 1, 130, false, true, 1234ULL}, + {2.0f, -1.f, 1, 1, false, false, 1234ULL}, {2.0f, -1.f, 1, 1, false, true, 1234ULL}, + {2.0f, -1.f, 7, 23, false, false, 1234ULL}, {2.0f, -1.f, 7, 23, false, true, 1234ULL}, + {2.0f, -1.f, 17, 5, false, false, 1234ULL}, {2.0f, -1.f, 17, 5, false, true, 1234ULL}}; + +const std::vector> inputsd = { + {0.15, 1.0, 1024, 32, true, false, 1234ULL}, {0.15, 1.0, 1024, 64, true, false, 1234ULL}, + {0.15, 1.0, 1024, 128, true, false, 1234ULL}, {0.15, 1.0, 1024, 256, true, false, 1234ULL}, + {0.15, -1.0, 1024, 32, false, false, 1234ULL}, {0.15, -1.0, 1024, 64, false, false, 1234ULL}, + {0.15, -1.0, 1024, 128, false, false, 1234ULL}, {0.15, -1.0, 1024, 256, false, false, 1234ULL}, + {0.15, 1.0, 1024, 32, true, true, 1234ULL}, {0.15, 1.0, 1024, 64, true, true, 1234ULL}, + {0.15, 1.0, 1024, 128, true, true, 1234ULL}, {0.15, 1.0, 1024, 256, true, true, 1234ULL}, + {0.15, -1.0, 1024, 32, false, true, 1234ULL}, {0.15, -1.0, 1024, 64, false, true, 1234ULL}, + {0.15, -1.0, 1024, 128, false, true, 1234ULL}, {0.15, -1.0, 1024, 256, false, true, 1234ULL}, + {0.15, -1.0, 1030, 1, false, false, 1234ULL}, {0.15, -1.0, 1030, 60, true, false, 1234ULL}, + {2.0, -1.0, 31, 120, false, false, 1234ULL}, {2.0, -1.0, 1, 130, true, false, 1234ULL}, + {0.15, -1.0, 1030, 1, false, true, 1234ULL}, {0.15, -1.0, 1030, 60, true, true, 1234ULL}, + {2.0, -1.0, 31, 120, false, true, 1234ULL}, {2.0, -1.0, 1, 130, false, true, 1234ULL}, + {2.0, -1.0, 1, 1, false, false, 1234ULL}, {2.0, -1.0, 1, 1, false, true, 1234ULL}, + {2.0, -1.0, 7, 23, false, false, 1234ULL}, {2.0, -1.0, 7, 23, false, true, 1234ULL}, + {2.0, -1.0, 17, 5, false, false, 1234ULL}, {2.0, -1.0, 17, 5, false, true, 1234ULL}}; typedef MeanTest MeanTestF; TEST_P(MeanTestF, Result) diff --git a/cpp/test/stats/minmax.cu b/cpp/test/stats/minmax.cu index 7563cb12be..fd909ebb90 100644 --- a/cpp/test/stats/minmax.cu +++ b/cpp/test/stats/minmax.cu @@ -145,45 +145,33 @@ class MinMaxTest : public ::testing::TestWithParam> { rmm::device_uvector minmax_ref; }; -const std::vector> inputsf = {{0.00001f, 1024, 32, 1234ULL}, - {0.00001f, 1024, 64, 1234ULL}, - {0.00001f, 1024, 128, 1234ULL}, - {0.00001f, 1024, 256, 1234ULL}, - {0.00001f, 1024, 512, 1234ULL}, - {0.00001f, 1024, 1024, 1234ULL}, - {0.00001f, 4096, 32, 1234ULL}, - {0.00001f, 4096, 64, 1234ULL}, - {0.00001f, 4096, 128, 1234ULL}, - {0.00001f, 4096, 256, 1234ULL}, - {0.00001f, 4096, 512, 1234ULL}, - {0.00001f, 4096, 1024, 1234ULL}, - {0.00001f, 8192, 32, 1234ULL}, - {0.00001f, 8192, 64, 1234ULL}, - {0.00001f, 8192, 128, 1234ULL}, - {0.00001f, 8192, 256, 1234ULL}, - {0.00001f, 8192, 512, 1234ULL}, - {0.00001f, 8192, 1024, 1234ULL}, - {0.00001f, 1024, 8192, 1234ULL}}; - -const std::vector> inputsd = {{0.0000001, 1024, 32, 1234ULL}, - {0.0000001, 1024, 64, 1234ULL}, - {0.0000001, 1024, 128, 1234ULL}, - {0.0000001, 1024, 256, 1234ULL}, - {0.0000001, 1024, 512, 1234ULL}, - {0.0000001, 1024, 1024, 1234ULL}, - {0.0000001, 4096, 32, 1234ULL}, - {0.0000001, 4096, 64, 1234ULL}, - {0.0000001, 4096, 128, 1234ULL}, - {0.0000001, 4096, 256, 1234ULL}, - {0.0000001, 4096, 512, 1234ULL}, - {0.0000001, 4096, 1024, 1234ULL}, - {0.0000001, 8192, 32, 1234ULL}, - {0.0000001, 8192, 64, 1234ULL}, - {0.0000001, 8192, 128, 1234ULL}, - {0.0000001, 8192, 256, 1234ULL}, - {0.0000001, 8192, 512, 1234ULL}, - {0.0000001, 8192, 1024, 1234ULL}, - {0.0000001, 1024, 8192, 1234ULL}}; +const std::vector> inputsf = { + {0.00001f, 1024, 32, 1234ULL}, {0.00001f, 1024, 64, 1234ULL}, {0.00001f, 1024, 128, 1234ULL}, + {0.00001f, 1024, 256, 1234ULL}, {0.00001f, 1024, 512, 1234ULL}, {0.00001f, 1024, 1024, 1234ULL}, + {0.00001f, 4096, 32, 1234ULL}, {0.00001f, 4096, 64, 1234ULL}, {0.00001f, 4096, 128, 1234ULL}, + {0.00001f, 4096, 256, 1234ULL}, {0.00001f, 4096, 512, 1234ULL}, {0.00001f, 4096, 1024, 1234ULL}, + {0.00001f, 8192, 32, 1234ULL}, {0.00001f, 8192, 64, 1234ULL}, {0.00001f, 8192, 128, 1234ULL}, + {0.00001f, 8192, 256, 1234ULL}, {0.00001f, 8192, 512, 1234ULL}, {0.00001f, 8192, 1024, 1234ULL}, + {0.00001f, 1024, 8192, 1234ULL}, {0.00001f, 1023, 5, 1234ULL}, {0.00001f, 1025, 30, 1234ULL}, + {0.00001f, 2047, 65, 1234ULL}, {0.00001f, 2049, 22, 1234ULL}, {0.00001f, 31, 644, 1234ULL}, + {0.00001f, 33, 999, 1234ULL}, {0.00001f, 1, 1, 1234ULL}, {0.00001f, 7, 23, 1234ULL}, + {0.00001f, 17, 5, 1234ULL}}; + +const std::vector> inputsd = { + {0.0000001, 1024, 32, 1234ULL}, {0.0000001, 1024, 64, 1234ULL}, + {0.0000001, 1024, 128, 1234ULL}, {0.0000001, 1024, 256, 1234ULL}, + {0.0000001, 1024, 512, 1234ULL}, {0.0000001, 1024, 1024, 1234ULL}, + {0.0000001, 4096, 32, 1234ULL}, {0.0000001, 4096, 64, 1234ULL}, + {0.0000001, 4096, 128, 1234ULL}, {0.0000001, 4096, 256, 1234ULL}, + {0.0000001, 4096, 512, 1234ULL}, {0.0000001, 4096, 1024, 1234ULL}, + {0.0000001, 8192, 32, 1234ULL}, {0.0000001, 8192, 64, 1234ULL}, + {0.0000001, 8192, 128, 1234ULL}, {0.0000001, 8192, 256, 1234ULL}, + {0.0000001, 8192, 512, 1234ULL}, {0.0000001, 8192, 1024, 1234ULL}, + {0.0000001, 1024, 8192, 1234ULL}, {0.0000001, 1023, 5, 1234ULL}, + {0.0000001, 1025, 30, 1234ULL}, {0.0000001, 2047, 65, 1234ULL}, + {0.0000001, 2049, 22, 1234ULL}, {0.0000001, 31, 644, 1234ULL}, + {0.0000001, 33, 999, 1234ULL}, {0.0000001, 1, 1, 1234ULL}, + {0.0000001, 7, 23, 1234ULL}, {0.0000001, 17, 5, 1234ULL}}; typedef MinMaxTest MinMaxTestF; TEST_P(MinMaxTestF, Result) diff --git a/cpp/test/stats/stddev.cu b/cpp/test/stats/stddev.cu index cf57d3a923..641621c1c6 100644 --- a/cpp/test/stats/stddev.cu +++ b/cpp/test/stats/stddev.cu @@ -141,7 +141,19 @@ const std::vector> inputsf = { {0.1f, -1.f, 2.f, 1024, 32, false, true, 1234ULL}, {0.1f, -1.f, 2.f, 1024, 64, false, true, 1234ULL}, {0.1f, -1.f, 2.f, 1024, 128, false, true, 1234ULL}, - {0.1f, -1.f, 2.f, 1024, 256, false, true, 1234ULL}}; + {0.1f, -1.f, 2.f, 1024, 256, false, true, 1234ULL}, + {0.1f, -1.f, 2.f, 1099, 97, false, false, 1234ULL}, + {0.1f, -1.f, 2.f, 1022, 694, true, false, 1234ULL}, + {0.5f, -1.f, 2.f, 31, 1, true, true, 1234ULL}, + {1.f, -1.f, 2.f, 1, 257, false, true, 1234ULL}, + {0.5f, -1.f, 2.f, 31, 1, false, false, 1234ULL}, + {1.f, -1.f, 2.f, 1, 257, true, false, 1234ULL}, + {1.f, -1.f, 2.f, 1, 1, false, false, 1234ULL}, + {1.f, -1.f, 2.f, 7, 23, false, false, 1234ULL}, + {1.f, -1.f, 2.f, 17, 5, false, false, 1234ULL}, + {1.f, -1.f, 2.f, 1, 1, false, true, 1234ULL}, + {1.f, -1.f, 2.f, 7, 23, false, true, 1234ULL}, + {1.f, -1.f, 2.f, 17, 5, false, true, 1234ULL}}; const std::vector> inputsd = { {0.1, 1.0, 2.0, 1024, 32, true, false, 1234ULL}, @@ -159,13 +171,33 @@ const std::vector> inputsd = { {0.1, -1.0, 2.0, 1024, 32, false, true, 1234ULL}, {0.1, -1.0, 2.0, 1024, 64, false, true, 1234ULL}, {0.1, -1.0, 2.0, 1024, 128, false, true, 1234ULL}, - {0.1, -1.0, 2.0, 1024, 256, false, true, 1234ULL}}; + {0.1, -1.0, 2.0, 1024, 256, false, true, 1234ULL}, + {0.1, -1.0, 2.0, 1099, 97, false, false, 1234ULL}, + {0.1, -1.0, 2.0, 1022, 694, true, false, 1234ULL}, + {0.5, -1.0, 2.0, 31, 1, true, true, 1234ULL}, + {1.0, -1.0, 2.0, 1, 257, false, true, 1234ULL}, + {0.5, -1.0, 2.0, 31, 1, false, false, 1234ULL}, + {1.0, -1.0, 2.0, 1, 257, true, false, 1234ULL}, + {1.0, -1.0, 2.0, 1, 1, false, false, 1234ULL}, + {1.0, -1.0, 2.0, 7, 23, false, false, 1234ULL}, + {1.0, -1.0, 2.0, 17, 5, false, false, 1234ULL}, + {1.0, -1.0, 2.0, 1, 1, false, true, 1234ULL}, + {1.0, -1.0, 2.0, 7, 23, false, true, 1234ULL}, + {1.0, -1.0, 2.0, 17, 5, false, true, 1234ULL}}; typedef StdDevTest StdDevTestF; TEST_P(StdDevTestF, Result) { - ASSERT_TRUE(devArrMatch( - params.stddev, stddev_act.data(), params.cols, CompareApprox(params.tolerance), stream)); + if (params.rows == 1) { + ASSERT_TRUE(devArrMatch( + float(0), stddev_act.data(), params.cols, CompareApprox(params.tolerance), stream)); + } else { + ASSERT_TRUE(devArrMatch(params.stddev, + stddev_act.data(), + params.cols, + CompareApprox(params.tolerance), + stream)); + } ASSERT_TRUE(devArrMatch(stddev_act.data(), vars_act.data(), @@ -177,11 +209,16 @@ TEST_P(StdDevTestF, Result) typedef StdDevTest StdDevTestD; TEST_P(StdDevTestD, Result) { - ASSERT_TRUE(devArrMatch(params.stddev, - stddev_act.data(), - params.cols, - CompareApprox(params.tolerance), - stream)); + if (params.rows == 1) { + ASSERT_TRUE(devArrMatch( + double(0), stddev_act.data(), params.cols, CompareApprox(params.tolerance), stream)); + } else { + ASSERT_TRUE(devArrMatch(params.stddev, + stddev_act.data(), + params.cols, + CompareApprox(params.tolerance), + stream)); + } ASSERT_TRUE(devArrMatch(stddev_act.data(), vars_act.data(), diff --git a/cpp/test/stats/sum.cu b/cpp/test/stats/sum.cu index 5a549f8ba4..bf2aa44a2c 100644 --- a/cpp/test/stats/sum.cu +++ b/cpp/test/stats/sum.cu @@ -33,7 +33,8 @@ template struct SumInputs { T tolerance; int rows, cols; - unsigned long long int seed; + bool rowMajor; + T value = T(1); }; template @@ -56,20 +57,34 @@ class SumTest : public ::testing::TestWithParam> { } protected: - void SetUp() override + void runTest() { int len = rows * cols; - T data_h[len]; + std::vector data_h(len); for (int i = 0; i < len; i++) { - data_h[i] = T(1); + data_h[i] = T(params.value); } - raft::update_device(data.data(), data_h, len, stream); - sum(handle, - raft::make_device_matrix_view(data.data(), rows, cols), - raft::make_device_vector_view(sum_act.data(), cols)); + raft::update_device(data.data(), data_h.data(), len, stream); + + if (params.rowMajor) { + using layout = raft::row_major; + sum(handle, + raft::make_device_matrix_view(data.data(), rows, cols), + raft::make_device_vector_view(sum_act.data(), cols)); + } else { + using layout = raft::col_major; + sum(handle, + raft::make_device_matrix_view(data.data(), rows, cols), + raft::make_device_vector_view(sum_act.data(), cols)); + } resource::sync_stream(handle, stream); + + double expected = double(params.rows) * params.value; + + ASSERT_TRUE(raft::devArrMatch( + T(expected), sum_act.data(), params.cols, raft::CompareApprox(params.tolerance))); } protected: @@ -81,27 +96,49 @@ class SumTest : public ::testing::TestWithParam> { rmm::device_uvector data, sum_act; }; -const std::vector> inputsf = { - {0.05f, 4, 5, 1234ULL}, {0.05f, 1024, 32, 1234ULL}, {0.05f, 1024, 256, 1234ULL}}; - -const std::vector> inputsd = {{0.05, 1024, 32, 1234ULL}, - {0.05, 1024, 256, 1234ULL}}; +const std::vector> inputsf = {{0.0001f, 4, 5, true, 1}, + {0.0001f, 1024, 32, true, 1}, + {0.0001f, 1024, 256, true, 1}, + {0.0001f, 100000000, 1, true, 0.001}, + {0.0001f, 1, 30, true, 0.001}, + {0.0001f, 1, 1, true, 0.001}, + {0.0001f, 17, 5, true, 0.001}, + {0.0001f, 7, 23, true, 0.001}, + {0.0001f, 3, 97, true, 0.001}, + {0.0001f, 4, 5, false, 1}, + {0.0001f, 1024, 32, false, 1}, + {0.0001f, 1024, 256, false, 1}, + {0.0001f, 100000000, 1, false, 0.001}, + {0.0001f, 1, 30, false, 0.001}, + {0.0001f, 1, 1, false, 0.001}, + {0.0001f, 17, 5, false, 0.001}, + {0.0001f, 7, 23, false, 0.001}, + {0.0001f, 3, 97, false, 0.001}}; + +const std::vector> inputsd = {{0.000001, 1024, 32, true, 1}, + {0.000001, 1024, 256, true, 1}, + {0.000001, 1024, 256, true, 1}, + {0.000001, 100000000, 1, true, 0.001}, + {0.000001, 1, 30, true, 0.0001}, + {0.000001, 1, 1, true, 0.0001}, + {0.000001, 17, 5, true, 0.0001}, + {0.000001, 7, 23, true, 0.0001}, + {0.000001, 3, 97, true, 0.0001}, + {0.000001, 1024, 32, false, 1}, + {0.000001, 1024, 256, false, 1}, + {0.000001, 1024, 256, false, 1}, + {0.000001, 100000000, 1, false, 0.001}, + {0.000001, 1, 30, false, 0.0001}, + {0.000001, 1, 1, false, 0.0001}, + {0.000001, 17, 5, false, 0.0001}, + {0.000001, 7, 23, false, 0.0001}, + {0.000001, 3, 97, false, 0.0001}}; typedef SumTest SumTestF; -TEST_P(SumTestF, Result) -{ - ASSERT_TRUE(raft::devArrMatch( - float(params.rows), sum_act.data(), params.cols, raft::CompareApprox(params.tolerance))); -} - typedef SumTest SumTestD; -TEST_P(SumTestD, Result) -{ - ASSERT_TRUE(raft::devArrMatch(double(params.rows), - sum_act.data(), - params.cols, - raft::CompareApprox(params.tolerance))); -} + +TEST_P(SumTestF, Result) { runTest(); } +TEST_P(SumTestD, Result) { runTest(); } INSTANTIATE_TEST_CASE_P(SumTests, SumTestF, ::testing::ValuesIn(inputsf)); diff --git a/dependencies.yaml b/dependencies.yaml index 72aa3427d1..836775a5a3 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -402,7 +402,7 @@ dependencies: common: - output_types: [conda, pyproject] packages: - - &numpy numpy>=1.23 + - &numpy numpy>=1.23,<2.0a0 - output_types: [conda] packages: - *rmm_conda @@ -443,7 +443,7 @@ dependencies: - ucx-py==0.37.* - output_types: conda packages: - - ucx>=1.13.0 + - ucx>=1.15.0,<1.16.0 - ucx-proc=*=gpu - &ucx_py_conda ucx-py==0.37.* - output_types: pyproject diff --git a/docs/source/conf.py b/docs/source/conf.py index 07dd4825fa..8b2040baa2 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,7 +1,10 @@ -# Copyright (c) 2018-2023, NVIDIA CORPORATION. +# Copyright (c) 2018-2024, NVIDIA CORPORATION. import os import sys +from packaging.version import Version + +import pylibraft # If extensions (or modules to document with autodoc) are in another # directory, add these directories to sys.path here. If the directory @@ -36,6 +39,8 @@ "sphinx_copybutton" ] + + breathe_default_project = "RAFT" breathe_projects = { "RAFT": "../../cpp/doxygen/_xml/", @@ -62,14 +67,23 @@ copyright = "2023, NVIDIA Corporation" author = "NVIDIA Corporation" +rst_prolog = """ + +.. attention:: + + The vector search and clustering algorithms in RAFT are being migrated to a new library dedicated to vector search called `cuVS `_. We will continue to support the vector search algorithms in RAFT during this move, but will no longer update them after the RAPIDS 24.06 (June) release. We plan to complete the migration by RAPIDS 24.08 (August) release. + +""" + # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # +RAFT_VERSION = Version(pylibraft.__version__) # The short X.Y version. -version = '24.04' +version = f"{RAFT_VERSION.major:02}.{RAFT_VERSION.minor:02}" # The full version, including alpha/beta/rc tags. -release = '24.04.00' +release = f"{RAFT_VERSION.major:02}.{RAFT_VERSION.minor:02}.{RAFT_VERSION.micro:02}" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/docs/source/index.rst b/docs/source/index.rst index ee89aed5a6..bee0e948ff 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -5,6 +5,7 @@ RAPIDS RAFT: Reusable Accelerated Functions and Tools for Vector Search and More :width: 800 :alt: RAFT Tech Stack + Useful Resources ################ diff --git a/fetch_rapids.cmake b/fetch_rapids.cmake deleted file mode 100644 index 1dca136c97..0000000000 --- a/fetch_rapids.cmake +++ /dev/null @@ -1,20 +0,0 @@ -# ============================================================================= -# Copyright (c) 2022-2023, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except -# in compliance with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software distributed under the License -# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing permissions and limitations under -# the License. -# ============================================================================= -if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS.cmake) - file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-24.04/RAPIDS.cmake - ${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS.cmake - ) -endif() - -include(${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS.cmake) diff --git a/python/pylibraft/CMakeLists.txt b/python/pylibraft/CMakeLists.txt index c17243728e..7a2d77041d 100644 --- a/python/pylibraft/CMakeLists.txt +++ b/python/pylibraft/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -14,9 +14,7 @@ cmake_minimum_required(VERSION 3.26.4 FATAL_ERROR) -include(../../fetch_rapids.cmake) - -set(pylibraft_version 24.04.00) +include(../../rapids_config.cmake) # We always need CUDA for pylibraft because the raft dependency brings in a header-only cuco # dependency that enables CUDA unconditionally. @@ -25,7 +23,7 @@ rapids_cuda_init_architectures(pylibraft) project( pylibraft - VERSION ${pylibraft_version} + VERSION "${RAPIDS_VERSION}" LANGUAGES CXX CUDA ) @@ -35,7 +33,7 @@ option(FIND_RAFT_CPP "Search for existing RAFT C++ installations before defaulti # If the user requested it we attempt to find RAFT. if(FIND_RAFT_CPP) - find_package(raft ${pylibraft_version} REQUIRED COMPONENTS compiled) + find_package(raft "${RAPIDS_VERSION}" REQUIRED COMPONENTS compiled) if(NOT TARGET raft::raft_lib) message( FATAL_ERROR diff --git a/python/pylibraft/pyproject.toml b/python/pylibraft/pyproject.toml index 6468220330..d687f70cf5 100644 --- a/python/pylibraft/pyproject.toml +++ b/python/pylibraft/pyproject.toml @@ -36,7 +36,7 @@ license = { text = "Apache 2.0" } requires-python = ">=3.9" dependencies = [ "cuda-python>=11.7.1,<12.0a0", - "numpy>=1.23", + "numpy>=1.23,<2.0a0", "rmm==24.4.*", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. classifiers = [ diff --git a/python/raft-dask/CMakeLists.txt b/python/raft-dask/CMakeLists.txt index ff441e343e..58e5ae8104 100644 --- a/python/raft-dask/CMakeLists.txt +++ b/python/raft-dask/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -14,15 +14,13 @@ cmake_minimum_required(VERSION 3.26.4 FATAL_ERROR) -set(raft_dask_version 24.04.00) - -include(../../fetch_rapids.cmake) +include(../../rapids_config.cmake) include(rapids-cuda) rapids_cuda_init_architectures(raft-dask-python) project( raft-dask-python - VERSION ${raft_dask_version} + VERSION "${RAPIDS_VERSION}" LANGUAGES CXX CUDA ) @@ -32,7 +30,7 @@ option(FIND_RAFT_CPP "Search for existing RAFT C++ installations before defaulti # If the user requested it we attempt to find RAFT. if(FIND_RAFT_CPP) - find_package(raft ${raft_dask_version} REQUIRED COMPONENTS distributed) + find_package(raft "${RAPIDS_VERSION}" REQUIRED COMPONENTS distributed) else() set(raft_FOUND OFF) endif() diff --git a/python/raft-dask/pyproject.toml b/python/raft-dask/pyproject.toml index b869290d5c..07e2463c5c 100644 --- a/python/raft-dask/pyproject.toml +++ b/python/raft-dask/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ "dask-cuda==24.4.*", "joblib>=0.11", "numba>=0.57", - "numpy>=1.23", + "numpy>=1.23,<2.0a0", "pylibraft==24.4.*", "rapids-dask-dependency==24.4.*", "ucx-py==0.37.*", diff --git a/python/raft-dask/raft_dask/common/comms.py b/python/raft-dask/raft_dask/common/comms.py index 118293c093..b2f7d1fb74 100644 --- a/python/raft-dask/raft_dask/common/comms.py +++ b/python/raft-dask/raft_dask/common/comms.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,8 +18,7 @@ import time import uuid import warnings -from collections import Counter, OrderedDict, defaultdict -from typing import Dict +from collections import OrderedDict from dask.distributed import default_client from dask_cuda.utils import nvml_device_index @@ -691,9 +690,11 @@ def _func_ucp_ports(client, workers): def _func_worker_ranks(client, workers): """ - For each worker connected to the client, - compute a global rank which is the sum - of the NVML device index and the worker rank offset. + For each worker connected to the client, compute a global rank which takes + into account the NVML device index and the worker IP + (group workers on same host and order by NVML device). + Note that the reason for sorting was nvbug 4149999 and is presumably + fixed afterNCCL 2.19.3. Parameters ---------- @@ -703,13 +704,13 @@ def _func_worker_ranks(client, workers): # TODO: Add Test this function # Running into build issues preventing testing nvml_device_index_d = client.run(_get_nvml_device_index, workers=workers) - worker_ips = [ - _get_worker_ip(worker_address) - for worker_address in nvml_device_index_d + # Sort workers first by IP and then by the nvml device index: + worker_info_list = [ + (_get_worker_ip(worker), nvml_device_index, worker) + for worker, nvml_device_index in nvml_device_index_d.items() ] - ranks = _map_nvml_device_id_to_contiguous_range(nvml_device_index_d) - worker_ip_offset_dict = _get_rank_offset_across_nodes(worker_ips) - return _append_rank_offset(ranks, worker_ip_offset_dict) + worker_info_list.sort() + return {wi[2]: i for i, wi in enumerate(worker_info_list)} def _get_nvml_device_index(): @@ -730,73 +731,3 @@ def _get_worker_ip(worker_address): worker_address (str): Full address string of the worker """ return ":".join(worker_address.split(":")[0:2]) - - -def _map_nvml_device_id_to_contiguous_range(nvml_device_index_d: dict) -> dict: - """ - For each worker address in nvml_device_index_d, map the corresponding - worker rank in the range(0, num_workers_per_node) where rank is decided - by the NVML device index. Worker with the lowest NVML device index gets - rank 0, and worker with the highest NVML device index gets rank - num_workers_per_node-1. - - Parameters - ---------- - nvml_device_index_d : dict - Dictionary of worker addresses mapped to their nvml device index. - - Returns - ------- - dict - Updated dictionary with worker addresses mapped to their rank. - """ - - rank_per_ip: Dict[str, int] = defaultdict(int) - - # Sort by NVML index to ensure that the worker - # with the lowest NVML index gets rank 0. - for worker, _ in sorted(nvml_device_index_d.items(), key=lambda x: x[1]): - ip = _get_worker_ip(worker) - - nvml_device_index_d[worker] = rank_per_ip[ip] - rank_per_ip[ip] += 1 - - return nvml_device_index_d - - -def _get_rank_offset_across_nodes(worker_ips): - """ - Get a dictionary of worker IP addresses mapped to the cumulative count of - their occurrences in the worker_ips list. The cumulative count serves as - the rank offset. - - Parameters - ---------- - worker_ips (list): List of worker IP addresses. - """ - worker_count_dict = Counter(worker_ips) - worker_offset_dict = {} - current_offset = 0 - for worker_ip, worker_count in worker_count_dict.items(): - worker_offset_dict[worker_ip] = current_offset - current_offset += worker_count - return worker_offset_dict - - -def _append_rank_offset(rank_dict, worker_ip_offset_dict): - """ - For each worker address in the rank dictionary, add the - corresponding worker offset from the worker_ip_offset_dict - to the rank value. - - Parameters - ---------- - rank_dict (dict): Dictionary of worker addresses mapped to their ranks. - worker_ip_offset_dict (dict): Dictionary of worker IP addresses - mapped to their offsets. - """ - for worker_ip, worker_offset in worker_ip_offset_dict.items(): - for worker_address in rank_dict: - if worker_ip in worker_address: - rank_dict[worker_address] += worker_offset - return rank_dict diff --git a/python/raft-dask/raft_dask/test/test_comms.py b/python/raft-dask/raft_dask/test/test_comms.py index 68c9fee556..b62d7185b2 100644 --- a/python/raft-dask/raft_dask/test/test_comms.py +++ b/python/raft-dask/raft_dask/test/test_comms.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -354,3 +354,16 @@ def test_device_multicast_sendrecv(n_trials, client): wait(dfs, timeout=5) assert list(map(lambda x: x.result(), dfs)) + + +@pytest.mark.nccl +@pytest.mark.parametrize( + "subset", [slice(-1, None), slice(1), slice(None, None, -2)] +) +def test_comm_init_worker_subset(client, subset): + # Basic test that initializing a subset of workers is fine + cb = Comms(comms_p2p=True, verbose=True) + + workers = list(client.scheduler_info()["workers"].keys()) + workers = workers[subset] + cb.init(workers=workers) diff --git a/rapids_config.cmake b/rapids_config.cmake new file mode 100644 index 0000000000..c8077f7f4b --- /dev/null +++ b/rapids_config.cmake @@ -0,0 +1,34 @@ +# ============================================================================= +# Copyright (c) 2018-2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= +file(READ "${CMAKE_CURRENT_LIST_DIR}/VERSION" _rapids_version) +if(_rapids_version MATCHES [[^([0-9][0-9])\.([0-9][0-9])\.([0-9][0-9])]]) + set(RAPIDS_VERSION_MAJOR "${CMAKE_MATCH_1}") + set(RAPIDS_VERSION_MINOR "${CMAKE_MATCH_2}") + set(RAPIDS_VERSION_PATCH "${CMAKE_MATCH_3}") + set(RAPIDS_VERSION_MAJOR_MINOR "${RAPIDS_VERSION_MAJOR}.${RAPIDS_VERSION_MINOR}") + set(RAPIDS_VERSION "${RAPIDS_VERSION_MAJOR}.${RAPIDS_VERSION_MINOR}.${RAPIDS_VERSION_PATCH}") +else() + string(REPLACE "\n" "\n " _rapids_version_formatted " ${_rapids_version}") + message( + FATAL_ERROR + "Could not determine RAPIDS version. Contents of VERSION file:\n${_rapids_version_formatted}") +endif() + +if(NOT EXISTS "${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS-${RAPIDS_VERSION_MAJOR_MINOR}.cmake") + file( + DOWNLOAD + "https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-${RAPIDS_VERSION_MAJOR_MINOR}/RAPIDS.cmake" + "${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS-${RAPIDS_VERSION_MAJOR_MINOR}.cmake") +endif() +include("${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS-${RAPIDS_VERSION_MAJOR_MINOR}.cmake")