From b203e3f0e25058e47ceae76b2a13baf85c9ef8b4 Mon Sep 17 00:00:00 2001 From: Micka Date: Fri, 15 Mar 2024 17:25:54 +0100 Subject: [PATCH 01/10] Add `compile-library` by default on pylibraft build (#2090) This will avoid confusion for users launching only `./build.sh pylibraft`. Authors: - Micka (https://github.com/lowener) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Peter Andreas Entschev (https://github.com/pentschev) - Corey J. Nolet (https://github.com/cjnolet) - Robert Maynard (https://github.com/robertmaynard) - Ray Douglass (https://github.com/raydouglass) URL: https://github.com/rapidsai/raft/pull/2090 --- build.sh | 4 ++-- conda/recipes/libraft/build_libraft_template.sh | 4 ++-- cpp/template/README.md | 2 +- cpp/template/build.sh | 2 +- cpp/template/cmake/thirdparty/get_raft.cmake | 9 +++++++-- 5 files changed, 13 insertions(+), 8 deletions(-) diff --git a/build.sh b/build.sh index 148d23c9c1..45c7d1380f 100755 --- a/build.sh +++ b/build.sh @@ -305,7 +305,7 @@ if hasArg --allgpuarch; then BUILD_ALL_GPU_ARCH=1 fi -if hasArg --compile-lib || (( ${NUMARGS} == 0 )); then +if hasArg --compile-lib || hasArg pylibraft || (( ${NUMARGS} == 0 )); then COMPILE_LIBRARY=ON CMAKE_TARGET="${CMAKE_TARGET};raft_lib" fi @@ -405,7 +405,7 @@ fi ################################################################################ # Configure for building all C++ targets -if (( ${NUMARGS} == 0 )) || hasArg libraft || hasArg docs || hasArg tests || hasArg bench-prims || hasArg bench-ann; then +if (( ${NUMARGS} == 0 )) || hasArg libraft || hasArg docs || hasArg tests || hasArg bench-prims || hasArg bench-ann || ((${COMPILE_LIBRARY} == ON )); then if (( ${BUILD_ALL_GPU_ARCH} == 0 )); then RAFT_CMAKE_CUDA_ARCHITECTURES="NATIVE" echo "Building for the architecture of the GPU in the system..." diff --git a/conda/recipes/libraft/build_libraft_template.sh b/conda/recipes/libraft/build_libraft_template.sh index bd7719af76..86c0fa11b6 100644 --- a/conda/recipes/libraft/build_libraft_template.sh +++ b/conda/recipes/libraft/build_libraft_template.sh @@ -1,5 +1,5 @@ #!/usr/bin/env bash -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # Just building template so we verify it uses libraft.so and fail if it doesn't build -./build.sh template +./build.sh template --no-nvtx diff --git a/cpp/template/README.md b/cpp/template/README.md index 348dff270a..05ec48964f 100644 --- a/cpp/template/README.md +++ b/cpp/template/README.md @@ -8,7 +8,7 @@ Once the minimum requirements are satisfied, this example template application c This directory (`RAFT_SOURCE/cpp/template`) can be copied directly in order to build a new application with RAFT. -RAFT can be integrated into an existing CMake project by copying the contents in the `configure rapids-cmake` and `configure raft` sections of the provided `CMakeLists.txt` into your project, along with `cmake/thirdparty/get_raft.cmake`. +RAFT can be integrated into an existing CMake project by copying the contents in the `configure rapids-cmake` and `configure raft` sections of the provided `CMakeLists.txt` into your project, along with `cmake/thirdparty/get_raft.cmake`. Make sure to link against the appropriate Cmake targets. Use `raft::raft`to add make the headers available and `raft::compiled` when utilizing the shared library. diff --git a/cpp/template/build.sh b/cpp/template/build.sh index 3ac00fc9af..49c17f7499 100755 --- a/cpp/template/build.sh +++ b/cpp/template/build.sh @@ -1,6 +1,6 @@ #!/bin/bash -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # raft empty project template build script diff --git a/cpp/template/cmake/thirdparty/get_raft.cmake b/cpp/template/cmake/thirdparty/get_raft.cmake index 6128b5c43c..07b0897be0 100644 --- a/cpp/template/cmake/thirdparty/get_raft.cmake +++ b/cpp/template/cmake/thirdparty/get_raft.cmake @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -33,6 +33,12 @@ function(find_and_configure_raft) #----------------------------------------------------- # Invoke CPM find_package() #----------------------------------------------------- + # Since the RAFT_NVTX option is used by targets generated by + # find_package(RAFT_NVTX) and when building from source we want to + # make `RAFT_NVTX` a cache variable so we get consistent + # behavior + # + set(RAFT_NVTX ${PKG_ENABLE_NVTX} CACHE BOOL "Enable raft nvtx logging" FORCE) rapids_cpm_find(raft ${PKG_VERSION} GLOBAL_TARGETS raft::raft BUILD_EXPORT_SET raft-template-exports @@ -46,7 +52,6 @@ function(find_and_configure_raft) "BUILD_TESTS OFF" "BUILD_PRIMS_BENCH OFF" "BUILD_ANN_BENCH OFF" - "RAFT_NVTX ${ENABLE_NVTX}" "RAFT_COMPILE_LIBRARY ${PKG_COMPILE_LIBRARY}" ) endfunction() From 36484f48fe5dd60939fb7e0610d9c5091c0780b2 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Fri, 15 Mar 2024 14:25:10 -0700 Subject: [PATCH 02/10] MAINT: Simplify NCCL worker rank identification (#2228) This PR is based on @seberg work in https://github.com/rapidsai/raft/pull/1928 . From the PR: This is a follow up on https://github.com/rapidsai/raft/pull/1926, since the rank sorting seemed a bit hard to understand. It does modify the logic in the sense that the host is now sorted by IP as a way to group based on it. But I don't really think that host sorting was ever a goal? If the goal is really about being deterministic, then this should be more (or at least clearer) deterministic about order of worker IPs. OTOH, if the NVML device order doesn't matter, we could just sort the workers directly. The original https://github.com/rapidsai/raft/pull/1587 mentions: NCCL>1.11 expects a process with rank r to be mapped to r % num_gpus_per_node which is something that neither approach seems to quite assure, if such a requirement exists, I would want to do one of: Ensure we can guarantee this, but this requires initializing workers that are not involved in the operation. At least raise an error, because if NCCL will end up raising the error it will be very confusing. Authors: - Vibhu Jawa (https://github.com/VibhuJawa) - Sebastian Berg (https://github.com/seberg) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2228 --- python/raft-dask/raft_dask/common/comms.py | 95 +++---------------- python/raft-dask/raft_dask/test/test_comms.py | 15 ++- 2 files changed, 27 insertions(+), 83 deletions(-) diff --git a/python/raft-dask/raft_dask/common/comms.py b/python/raft-dask/raft_dask/common/comms.py index 118293c093..b2f7d1fb74 100644 --- a/python/raft-dask/raft_dask/common/comms.py +++ b/python/raft-dask/raft_dask/common/comms.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,8 +18,7 @@ import time import uuid import warnings -from collections import Counter, OrderedDict, defaultdict -from typing import Dict +from collections import OrderedDict from dask.distributed import default_client from dask_cuda.utils import nvml_device_index @@ -691,9 +690,11 @@ def _func_ucp_ports(client, workers): def _func_worker_ranks(client, workers): """ - For each worker connected to the client, - compute a global rank which is the sum - of the NVML device index and the worker rank offset. + For each worker connected to the client, compute a global rank which takes + into account the NVML device index and the worker IP + (group workers on same host and order by NVML device). + Note that the reason for sorting was nvbug 4149999 and is presumably + fixed afterNCCL 2.19.3. Parameters ---------- @@ -703,13 +704,13 @@ def _func_worker_ranks(client, workers): # TODO: Add Test this function # Running into build issues preventing testing nvml_device_index_d = client.run(_get_nvml_device_index, workers=workers) - worker_ips = [ - _get_worker_ip(worker_address) - for worker_address in nvml_device_index_d + # Sort workers first by IP and then by the nvml device index: + worker_info_list = [ + (_get_worker_ip(worker), nvml_device_index, worker) + for worker, nvml_device_index in nvml_device_index_d.items() ] - ranks = _map_nvml_device_id_to_contiguous_range(nvml_device_index_d) - worker_ip_offset_dict = _get_rank_offset_across_nodes(worker_ips) - return _append_rank_offset(ranks, worker_ip_offset_dict) + worker_info_list.sort() + return {wi[2]: i for i, wi in enumerate(worker_info_list)} def _get_nvml_device_index(): @@ -730,73 +731,3 @@ def _get_worker_ip(worker_address): worker_address (str): Full address string of the worker """ return ":".join(worker_address.split(":")[0:2]) - - -def _map_nvml_device_id_to_contiguous_range(nvml_device_index_d: dict) -> dict: - """ - For each worker address in nvml_device_index_d, map the corresponding - worker rank in the range(0, num_workers_per_node) where rank is decided - by the NVML device index. Worker with the lowest NVML device index gets - rank 0, and worker with the highest NVML device index gets rank - num_workers_per_node-1. - - Parameters - ---------- - nvml_device_index_d : dict - Dictionary of worker addresses mapped to their nvml device index. - - Returns - ------- - dict - Updated dictionary with worker addresses mapped to their rank. - """ - - rank_per_ip: Dict[str, int] = defaultdict(int) - - # Sort by NVML index to ensure that the worker - # with the lowest NVML index gets rank 0. - for worker, _ in sorted(nvml_device_index_d.items(), key=lambda x: x[1]): - ip = _get_worker_ip(worker) - - nvml_device_index_d[worker] = rank_per_ip[ip] - rank_per_ip[ip] += 1 - - return nvml_device_index_d - - -def _get_rank_offset_across_nodes(worker_ips): - """ - Get a dictionary of worker IP addresses mapped to the cumulative count of - their occurrences in the worker_ips list. The cumulative count serves as - the rank offset. - - Parameters - ---------- - worker_ips (list): List of worker IP addresses. - """ - worker_count_dict = Counter(worker_ips) - worker_offset_dict = {} - current_offset = 0 - for worker_ip, worker_count in worker_count_dict.items(): - worker_offset_dict[worker_ip] = current_offset - current_offset += worker_count - return worker_offset_dict - - -def _append_rank_offset(rank_dict, worker_ip_offset_dict): - """ - For each worker address in the rank dictionary, add the - corresponding worker offset from the worker_ip_offset_dict - to the rank value. - - Parameters - ---------- - rank_dict (dict): Dictionary of worker addresses mapped to their ranks. - worker_ip_offset_dict (dict): Dictionary of worker IP addresses - mapped to their offsets. - """ - for worker_ip, worker_offset in worker_ip_offset_dict.items(): - for worker_address in rank_dict: - if worker_ip in worker_address: - rank_dict[worker_address] += worker_offset - return rank_dict diff --git a/python/raft-dask/raft_dask/test/test_comms.py b/python/raft-dask/raft_dask/test/test_comms.py index 68c9fee556..b62d7185b2 100644 --- a/python/raft-dask/raft_dask/test/test_comms.py +++ b/python/raft-dask/raft_dask/test/test_comms.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -354,3 +354,16 @@ def test_device_multicast_sendrecv(n_trials, client): wait(dfs, timeout=5) assert list(map(lambda x: x.result(), dfs)) + + +@pytest.mark.nccl +@pytest.mark.parametrize( + "subset", [slice(-1, None), slice(1), slice(None, None, -2)] +) +def test_comm_init_worker_subset(client, subset): + # Basic test that initializing a subset of workers is fine + cb = Comms(comms_p2p=True, verbose=True) + + workers = list(client.scheduler_info()["workers"].keys()) + workers = workers[subset] + cb.init(workers=workers) From 7335267e7991d5eacef3d445968108a95d0f800a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Malte=20F=C3=B6rster?= <97973773+mfoerste4@users.noreply.github.com> Date: Sat, 16 Mar 2024 05:25:48 +0100 Subject: [PATCH 03/10] Fix illegal acces mean/stdev, sum add Kahan Summation (#2223) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR addresses #2204 and #2205. * fixes illegal access / test coverage for mean row-wise kernel * fixes illegal access / test coverage for stdev row-wise kernel * modified sum kernels to utilize Kahan/Neumaier summation per thread, also increase load per thread to benefit from this FYI, @tfeher Authors: - Malte Förster (https://github.com/mfoerste4) Approvers: - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/raft/pull/2223 --- cpp/include/raft/stats/detail/mean.cuh | 4 +- cpp/include/raft/stats/detail/stddev.cuh | 4 +- cpp/include/raft/stats/detail/sum.cuh | 78 +++++++++++++++++---- cpp/test/stats/mean.cu | 66 +++++++++--------- cpp/test/stats/minmax.cu | 66 +++++++----------- cpp/test/stats/stddev.cu | 55 ++++++++++++--- cpp/test/stats/sum.cu | 89 +++++++++++++++++------- 7 files changed, 236 insertions(+), 126 deletions(-) diff --git a/cpp/include/raft/stats/detail/mean.cuh b/cpp/include/raft/stats/detail/mean.cuh index cf4dbc7aa3..6c330acb26 100644 --- a/cpp/include/raft/stats/detail/mean.cuh +++ b/cpp/include/raft/stats/detail/mean.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,7 +43,7 @@ RAFT_KERNEL meanKernelRowMajor(Type* mu, const Type* data, IdxType D, IdxType N) __syncthreads(); raft::myAtomicAdd(smu + thisColId, thread_data); __syncthreads(); - if (threadIdx.x < ColsPerBlk) raft::myAtomicAdd(mu + colId, smu[thisColId]); + if (threadIdx.x < ColsPerBlk && colId < D) raft::myAtomicAdd(mu + colId, smu[thisColId]); } template diff --git a/cpp/include/raft/stats/detail/stddev.cuh b/cpp/include/raft/stats/detail/stddev.cuh index acee4a944e..bc2644a233 100644 --- a/cpp/include/raft/stats/detail/stddev.cuh +++ b/cpp/include/raft/stats/detail/stddev.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -45,7 +45,7 @@ RAFT_KERNEL stddevKernelRowMajor(Type* std, const Type* data, IdxType D, IdxType __syncthreads(); raft::myAtomicAdd(sstd + thisColId, thread_data); __syncthreads(); - if (threadIdx.x < ColsPerBlk) raft::myAtomicAdd(std + colId, sstd[thisColId]); + if (threadIdx.x < ColsPerBlk && colId < D) raft::myAtomicAdd(std + colId, sstd[thisColId]); } template diff --git a/cpp/include/raft/stats/detail/sum.cuh b/cpp/include/raft/stats/detail/sum.cuh index bb45eb50f4..4f85536e6c 100644 --- a/cpp/include/raft/stats/detail/sum.cuh +++ b/cpp/include/raft/stats/detail/sum.cuh @@ -34,30 +34,72 @@ RAFT_KERNEL sumKernelRowMajor(Type* mu, const Type* data, IdxType D, IdxType N) IdxType thisRowId = threadIdx.x / ColsPerBlk; IdxType colId = thisColId + ((IdxType)blockIdx.y * ColsPerBlk); IdxType rowId = thisRowId + ((IdxType)blockIdx.x * RowsPerBlkPerIter); - Type thread_data = Type(0); + Type thread_sum = Type(0); const IdxType stride = RowsPerBlkPerIter * gridDim.x; - for (IdxType i = rowId; i < N; i += stride) - thread_data += (colId < D) ? data[i * D + colId] : Type(0); + for (IdxType i = rowId; i < N; i += stride) { + thread_sum += (colId < D) ? data[i * D + colId] : Type(0); + } __shared__ Type smu[ColsPerBlk]; if (threadIdx.x < ColsPerBlk) smu[threadIdx.x] = Type(0); __syncthreads(); - raft::myAtomicAdd(smu + thisColId, thread_data); + raft::myAtomicAdd(smu + thisColId, thread_sum); + __syncthreads(); + if (threadIdx.x < ColsPerBlk && colId < D) raft::myAtomicAdd(mu + colId, smu[thisColId]); +} + +template +RAFT_KERNEL sumKahanKernelRowMajor(Type* mu, const Type* data, IdxType D, IdxType N) +{ + constexpr int RowsPerBlkPerIter = TPB / ColsPerBlk; + IdxType thisColId = threadIdx.x % ColsPerBlk; + IdxType thisRowId = threadIdx.x / ColsPerBlk; + IdxType colId = thisColId + ((IdxType)blockIdx.y * ColsPerBlk); + IdxType rowId = thisRowId + ((IdxType)blockIdx.x * RowsPerBlkPerIter); + Type thread_sum = Type(0); + Type thread_c = Type(0); + const IdxType stride = RowsPerBlkPerIter * gridDim.x; + for (IdxType i = rowId; i < N; i += stride) { + // KahanBabushkaNeumaierSum + const Type cur_value = (colId < D) ? data[i * D + colId] : Type(0); + const Type t = thread_sum + cur_value; + if (abs(thread_sum) >= abs(cur_value)) { + thread_c += (thread_sum - t) + cur_value; + } else { + thread_c += (cur_value - t) + thread_sum; + } + thread_sum = t; + } + thread_sum += thread_c; + __shared__ Type smu[ColsPerBlk]; + if (threadIdx.x < ColsPerBlk) smu[threadIdx.x] = Type(0); + __syncthreads(); + raft::myAtomicAdd(smu + thisColId, thread_sum); __syncthreads(); if (threadIdx.x < ColsPerBlk && colId < D) raft::myAtomicAdd(mu + colId, smu[thisColId]); } template -RAFT_KERNEL sumKernelColMajor(Type* mu, const Type* data, IdxType D, IdxType N) +RAFT_KERNEL sumKahanKernelColMajor(Type* mu, const Type* data, IdxType D, IdxType N) { typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; - Type thread_data = Type(0); + Type thread_sum = Type(0); + Type thread_c = Type(0); IdxType colStart = N * blockIdx.x; for (IdxType i = threadIdx.x; i < N; i += TPB) { - IdxType idx = colStart + i; - thread_data += data[idx]; + // KahanBabushkaNeumaierSum + IdxType idx = colStart + i; + const Type cur_value = data[idx]; + const Type t = thread_sum + cur_value; + if (abs(thread_sum) >= abs(cur_value)) { + thread_c += (thread_sum - t) + cur_value; + } else { + thread_c += (cur_value - t) + thread_sum; + } + thread_sum = t; } - Type acc = BlockReduce(temp_storage).Sum(thread_data); + thread_sum += thread_c; + Type acc = BlockReduce(temp_storage).Sum(thread_sum); if (threadIdx.x == 0) { mu[blockIdx.x] = acc; } } @@ -66,15 +108,21 @@ void sum(Type* output, const Type* input, IdxType D, IdxType N, bool rowMajor, c { static const int TPB = 256; if (rowMajor) { - static const int RowsPerThread = 4; - static const int ColsPerBlk = 32; - static const int RowsPerBlk = (TPB / ColsPerBlk) * RowsPerThread; - dim3 grid(raft::ceildiv(N, (IdxType)RowsPerBlk), raft::ceildiv(D, (IdxType)ColsPerBlk)); + static const int ColsPerBlk = 8; + static const int MinRowsPerThread = 16; + static const int MinRowsPerBlk = (TPB / ColsPerBlk) * MinRowsPerThread; + static const int MaxBlocksDimX = 8192; + + const IdxType grid_y = raft::ceildiv(D, (IdxType)ColsPerBlk); + const IdxType grid_x = + raft::min((IdxType)MaxBlocksDimX, raft::ceildiv(N, (IdxType)MinRowsPerBlk)); + + dim3 grid(grid_x, grid_y); RAFT_CUDA_TRY(cudaMemset(output, 0, sizeof(Type) * D)); - sumKernelRowMajor + sumKahanKernelRowMajor <<>>(output, input, D, N); } else { - sumKernelColMajor<<>>(output, input, D, N); + sumKahanKernelColMajor<<>>(output, input, D, N); } RAFT_CUDA_TRY(cudaPeekAtLastError()); } diff --git a/cpp/test/stats/mean.cu b/cpp/test/stats/mean.cu index 67931e2eed..61b57ce739 100644 --- a/cpp/test/stats/mean.cu +++ b/cpp/test/stats/mean.cu @@ -95,39 +95,39 @@ class MeanTest : public ::testing::TestWithParam> { // Note: For 1024 samples, 256 experiments, a mean of 1.0 with stddev=1.0, the // measured mean (of a normal distribution) will fall outside of an epsilon of // 0.15 only 4/10000 times. (epsilon of 0.1 will fail 30/100 times) -const std::vector> inputsf = {{0.15f, 1.f, 1024, 32, true, false, 1234ULL}, - {0.15f, 1.f, 1024, 64, true, false, 1234ULL}, - {0.15f, 1.f, 1024, 128, true, false, 1234ULL}, - {0.15f, 1.f, 1024, 256, true, false, 1234ULL}, - {0.15f, -1.f, 1024, 32, false, false, 1234ULL}, - {0.15f, -1.f, 1024, 64, false, false, 1234ULL}, - {0.15f, -1.f, 1024, 128, false, false, 1234ULL}, - {0.15f, -1.f, 1024, 256, false, false, 1234ULL}, - {0.15f, 1.f, 1024, 32, true, true, 1234ULL}, - {0.15f, 1.f, 1024, 64, true, true, 1234ULL}, - {0.15f, 1.f, 1024, 128, true, true, 1234ULL}, - {0.15f, 1.f, 1024, 256, true, true, 1234ULL}, - {0.15f, -1.f, 1024, 32, false, true, 1234ULL}, - {0.15f, -1.f, 1024, 64, false, true, 1234ULL}, - {0.15f, -1.f, 1024, 128, false, true, 1234ULL}, - {0.15f, -1.f, 1024, 256, false, true, 1234ULL}}; - -const std::vector> inputsd = {{0.15, 1.0, 1024, 32, true, false, 1234ULL}, - {0.15, 1.0, 1024, 64, true, false, 1234ULL}, - {0.15, 1.0, 1024, 128, true, false, 1234ULL}, - {0.15, 1.0, 1024, 256, true, false, 1234ULL}, - {0.15, -1.0, 1024, 32, false, false, 1234ULL}, - {0.15, -1.0, 1024, 64, false, false, 1234ULL}, - {0.15, -1.0, 1024, 128, false, false, 1234ULL}, - {0.15, -1.0, 1024, 256, false, false, 1234ULL}, - {0.15, 1.0, 1024, 32, true, true, 1234ULL}, - {0.15, 1.0, 1024, 64, true, true, 1234ULL}, - {0.15, 1.0, 1024, 128, true, true, 1234ULL}, - {0.15, 1.0, 1024, 256, true, true, 1234ULL}, - {0.15, -1.0, 1024, 32, false, true, 1234ULL}, - {0.15, -1.0, 1024, 64, false, true, 1234ULL}, - {0.15, -1.0, 1024, 128, false, true, 1234ULL}, - {0.15, -1.0, 1024, 256, false, true, 1234ULL}}; +const std::vector> inputsf = { + {0.15f, 1.f, 1024, 32, true, false, 1234ULL}, {0.15f, 1.f, 1024, 64, true, false, 1234ULL}, + {0.15f, 1.f, 1024, 128, true, false, 1234ULL}, {0.15f, 1.f, 1024, 256, true, false, 1234ULL}, + {0.15f, -1.f, 1024, 32, false, false, 1234ULL}, {0.15f, -1.f, 1024, 64, false, false, 1234ULL}, + {0.15f, -1.f, 1024, 128, false, false, 1234ULL}, {0.15f, -1.f, 1024, 256, false, false, 1234ULL}, + {0.15f, 1.f, 1024, 32, true, true, 1234ULL}, {0.15f, 1.f, 1024, 64, true, true, 1234ULL}, + {0.15f, 1.f, 1024, 128, true, true, 1234ULL}, {0.15f, 1.f, 1024, 256, true, true, 1234ULL}, + {0.15f, -1.f, 1024, 32, false, true, 1234ULL}, {0.15f, -1.f, 1024, 64, false, true, 1234ULL}, + {0.15f, -1.f, 1024, 128, false, true, 1234ULL}, {0.15f, -1.f, 1024, 256, false, true, 1234ULL}, + {0.15f, -1.f, 1030, 1, false, false, 1234ULL}, {0.15f, -1.f, 1030, 60, true, false, 1234ULL}, + {2.0f, -1.f, 31, 120, false, false, 1234ULL}, {2.0f, -1.f, 1, 130, true, false, 1234ULL}, + {0.15f, -1.f, 1030, 1, false, true, 1234ULL}, {0.15f, -1.f, 1030, 60, true, true, 1234ULL}, + {2.0f, -1.f, 31, 120, false, true, 1234ULL}, {2.0f, -1.f, 1, 130, false, true, 1234ULL}, + {2.0f, -1.f, 1, 1, false, false, 1234ULL}, {2.0f, -1.f, 1, 1, false, true, 1234ULL}, + {2.0f, -1.f, 7, 23, false, false, 1234ULL}, {2.0f, -1.f, 7, 23, false, true, 1234ULL}, + {2.0f, -1.f, 17, 5, false, false, 1234ULL}, {2.0f, -1.f, 17, 5, false, true, 1234ULL}}; + +const std::vector> inputsd = { + {0.15, 1.0, 1024, 32, true, false, 1234ULL}, {0.15, 1.0, 1024, 64, true, false, 1234ULL}, + {0.15, 1.0, 1024, 128, true, false, 1234ULL}, {0.15, 1.0, 1024, 256, true, false, 1234ULL}, + {0.15, -1.0, 1024, 32, false, false, 1234ULL}, {0.15, -1.0, 1024, 64, false, false, 1234ULL}, + {0.15, -1.0, 1024, 128, false, false, 1234ULL}, {0.15, -1.0, 1024, 256, false, false, 1234ULL}, + {0.15, 1.0, 1024, 32, true, true, 1234ULL}, {0.15, 1.0, 1024, 64, true, true, 1234ULL}, + {0.15, 1.0, 1024, 128, true, true, 1234ULL}, {0.15, 1.0, 1024, 256, true, true, 1234ULL}, + {0.15, -1.0, 1024, 32, false, true, 1234ULL}, {0.15, -1.0, 1024, 64, false, true, 1234ULL}, + {0.15, -1.0, 1024, 128, false, true, 1234ULL}, {0.15, -1.0, 1024, 256, false, true, 1234ULL}, + {0.15, -1.0, 1030, 1, false, false, 1234ULL}, {0.15, -1.0, 1030, 60, true, false, 1234ULL}, + {2.0, -1.0, 31, 120, false, false, 1234ULL}, {2.0, -1.0, 1, 130, true, false, 1234ULL}, + {0.15, -1.0, 1030, 1, false, true, 1234ULL}, {0.15, -1.0, 1030, 60, true, true, 1234ULL}, + {2.0, -1.0, 31, 120, false, true, 1234ULL}, {2.0, -1.0, 1, 130, false, true, 1234ULL}, + {2.0, -1.0, 1, 1, false, false, 1234ULL}, {2.0, -1.0, 1, 1, false, true, 1234ULL}, + {2.0, -1.0, 7, 23, false, false, 1234ULL}, {2.0, -1.0, 7, 23, false, true, 1234ULL}, + {2.0, -1.0, 17, 5, false, false, 1234ULL}, {2.0, -1.0, 17, 5, false, true, 1234ULL}}; typedef MeanTest MeanTestF; TEST_P(MeanTestF, Result) diff --git a/cpp/test/stats/minmax.cu b/cpp/test/stats/minmax.cu index 7563cb12be..fd909ebb90 100644 --- a/cpp/test/stats/minmax.cu +++ b/cpp/test/stats/minmax.cu @@ -145,45 +145,33 @@ class MinMaxTest : public ::testing::TestWithParam> { rmm::device_uvector minmax_ref; }; -const std::vector> inputsf = {{0.00001f, 1024, 32, 1234ULL}, - {0.00001f, 1024, 64, 1234ULL}, - {0.00001f, 1024, 128, 1234ULL}, - {0.00001f, 1024, 256, 1234ULL}, - {0.00001f, 1024, 512, 1234ULL}, - {0.00001f, 1024, 1024, 1234ULL}, - {0.00001f, 4096, 32, 1234ULL}, - {0.00001f, 4096, 64, 1234ULL}, - {0.00001f, 4096, 128, 1234ULL}, - {0.00001f, 4096, 256, 1234ULL}, - {0.00001f, 4096, 512, 1234ULL}, - {0.00001f, 4096, 1024, 1234ULL}, - {0.00001f, 8192, 32, 1234ULL}, - {0.00001f, 8192, 64, 1234ULL}, - {0.00001f, 8192, 128, 1234ULL}, - {0.00001f, 8192, 256, 1234ULL}, - {0.00001f, 8192, 512, 1234ULL}, - {0.00001f, 8192, 1024, 1234ULL}, - {0.00001f, 1024, 8192, 1234ULL}}; - -const std::vector> inputsd = {{0.0000001, 1024, 32, 1234ULL}, - {0.0000001, 1024, 64, 1234ULL}, - {0.0000001, 1024, 128, 1234ULL}, - {0.0000001, 1024, 256, 1234ULL}, - {0.0000001, 1024, 512, 1234ULL}, - {0.0000001, 1024, 1024, 1234ULL}, - {0.0000001, 4096, 32, 1234ULL}, - {0.0000001, 4096, 64, 1234ULL}, - {0.0000001, 4096, 128, 1234ULL}, - {0.0000001, 4096, 256, 1234ULL}, - {0.0000001, 4096, 512, 1234ULL}, - {0.0000001, 4096, 1024, 1234ULL}, - {0.0000001, 8192, 32, 1234ULL}, - {0.0000001, 8192, 64, 1234ULL}, - {0.0000001, 8192, 128, 1234ULL}, - {0.0000001, 8192, 256, 1234ULL}, - {0.0000001, 8192, 512, 1234ULL}, - {0.0000001, 8192, 1024, 1234ULL}, - {0.0000001, 1024, 8192, 1234ULL}}; +const std::vector> inputsf = { + {0.00001f, 1024, 32, 1234ULL}, {0.00001f, 1024, 64, 1234ULL}, {0.00001f, 1024, 128, 1234ULL}, + {0.00001f, 1024, 256, 1234ULL}, {0.00001f, 1024, 512, 1234ULL}, {0.00001f, 1024, 1024, 1234ULL}, + {0.00001f, 4096, 32, 1234ULL}, {0.00001f, 4096, 64, 1234ULL}, {0.00001f, 4096, 128, 1234ULL}, + {0.00001f, 4096, 256, 1234ULL}, {0.00001f, 4096, 512, 1234ULL}, {0.00001f, 4096, 1024, 1234ULL}, + {0.00001f, 8192, 32, 1234ULL}, {0.00001f, 8192, 64, 1234ULL}, {0.00001f, 8192, 128, 1234ULL}, + {0.00001f, 8192, 256, 1234ULL}, {0.00001f, 8192, 512, 1234ULL}, {0.00001f, 8192, 1024, 1234ULL}, + {0.00001f, 1024, 8192, 1234ULL}, {0.00001f, 1023, 5, 1234ULL}, {0.00001f, 1025, 30, 1234ULL}, + {0.00001f, 2047, 65, 1234ULL}, {0.00001f, 2049, 22, 1234ULL}, {0.00001f, 31, 644, 1234ULL}, + {0.00001f, 33, 999, 1234ULL}, {0.00001f, 1, 1, 1234ULL}, {0.00001f, 7, 23, 1234ULL}, + {0.00001f, 17, 5, 1234ULL}}; + +const std::vector> inputsd = { + {0.0000001, 1024, 32, 1234ULL}, {0.0000001, 1024, 64, 1234ULL}, + {0.0000001, 1024, 128, 1234ULL}, {0.0000001, 1024, 256, 1234ULL}, + {0.0000001, 1024, 512, 1234ULL}, {0.0000001, 1024, 1024, 1234ULL}, + {0.0000001, 4096, 32, 1234ULL}, {0.0000001, 4096, 64, 1234ULL}, + {0.0000001, 4096, 128, 1234ULL}, {0.0000001, 4096, 256, 1234ULL}, + {0.0000001, 4096, 512, 1234ULL}, {0.0000001, 4096, 1024, 1234ULL}, + {0.0000001, 8192, 32, 1234ULL}, {0.0000001, 8192, 64, 1234ULL}, + {0.0000001, 8192, 128, 1234ULL}, {0.0000001, 8192, 256, 1234ULL}, + {0.0000001, 8192, 512, 1234ULL}, {0.0000001, 8192, 1024, 1234ULL}, + {0.0000001, 1024, 8192, 1234ULL}, {0.0000001, 1023, 5, 1234ULL}, + {0.0000001, 1025, 30, 1234ULL}, {0.0000001, 2047, 65, 1234ULL}, + {0.0000001, 2049, 22, 1234ULL}, {0.0000001, 31, 644, 1234ULL}, + {0.0000001, 33, 999, 1234ULL}, {0.0000001, 1, 1, 1234ULL}, + {0.0000001, 7, 23, 1234ULL}, {0.0000001, 17, 5, 1234ULL}}; typedef MinMaxTest MinMaxTestF; TEST_P(MinMaxTestF, Result) diff --git a/cpp/test/stats/stddev.cu b/cpp/test/stats/stddev.cu index cf57d3a923..641621c1c6 100644 --- a/cpp/test/stats/stddev.cu +++ b/cpp/test/stats/stddev.cu @@ -141,7 +141,19 @@ const std::vector> inputsf = { {0.1f, -1.f, 2.f, 1024, 32, false, true, 1234ULL}, {0.1f, -1.f, 2.f, 1024, 64, false, true, 1234ULL}, {0.1f, -1.f, 2.f, 1024, 128, false, true, 1234ULL}, - {0.1f, -1.f, 2.f, 1024, 256, false, true, 1234ULL}}; + {0.1f, -1.f, 2.f, 1024, 256, false, true, 1234ULL}, + {0.1f, -1.f, 2.f, 1099, 97, false, false, 1234ULL}, + {0.1f, -1.f, 2.f, 1022, 694, true, false, 1234ULL}, + {0.5f, -1.f, 2.f, 31, 1, true, true, 1234ULL}, + {1.f, -1.f, 2.f, 1, 257, false, true, 1234ULL}, + {0.5f, -1.f, 2.f, 31, 1, false, false, 1234ULL}, + {1.f, -1.f, 2.f, 1, 257, true, false, 1234ULL}, + {1.f, -1.f, 2.f, 1, 1, false, false, 1234ULL}, + {1.f, -1.f, 2.f, 7, 23, false, false, 1234ULL}, + {1.f, -1.f, 2.f, 17, 5, false, false, 1234ULL}, + {1.f, -1.f, 2.f, 1, 1, false, true, 1234ULL}, + {1.f, -1.f, 2.f, 7, 23, false, true, 1234ULL}, + {1.f, -1.f, 2.f, 17, 5, false, true, 1234ULL}}; const std::vector> inputsd = { {0.1, 1.0, 2.0, 1024, 32, true, false, 1234ULL}, @@ -159,13 +171,33 @@ const std::vector> inputsd = { {0.1, -1.0, 2.0, 1024, 32, false, true, 1234ULL}, {0.1, -1.0, 2.0, 1024, 64, false, true, 1234ULL}, {0.1, -1.0, 2.0, 1024, 128, false, true, 1234ULL}, - {0.1, -1.0, 2.0, 1024, 256, false, true, 1234ULL}}; + {0.1, -1.0, 2.0, 1024, 256, false, true, 1234ULL}, + {0.1, -1.0, 2.0, 1099, 97, false, false, 1234ULL}, + {0.1, -1.0, 2.0, 1022, 694, true, false, 1234ULL}, + {0.5, -1.0, 2.0, 31, 1, true, true, 1234ULL}, + {1.0, -1.0, 2.0, 1, 257, false, true, 1234ULL}, + {0.5, -1.0, 2.0, 31, 1, false, false, 1234ULL}, + {1.0, -1.0, 2.0, 1, 257, true, false, 1234ULL}, + {1.0, -1.0, 2.0, 1, 1, false, false, 1234ULL}, + {1.0, -1.0, 2.0, 7, 23, false, false, 1234ULL}, + {1.0, -1.0, 2.0, 17, 5, false, false, 1234ULL}, + {1.0, -1.0, 2.0, 1, 1, false, true, 1234ULL}, + {1.0, -1.0, 2.0, 7, 23, false, true, 1234ULL}, + {1.0, -1.0, 2.0, 17, 5, false, true, 1234ULL}}; typedef StdDevTest StdDevTestF; TEST_P(StdDevTestF, Result) { - ASSERT_TRUE(devArrMatch( - params.stddev, stddev_act.data(), params.cols, CompareApprox(params.tolerance), stream)); + if (params.rows == 1) { + ASSERT_TRUE(devArrMatch( + float(0), stddev_act.data(), params.cols, CompareApprox(params.tolerance), stream)); + } else { + ASSERT_TRUE(devArrMatch(params.stddev, + stddev_act.data(), + params.cols, + CompareApprox(params.tolerance), + stream)); + } ASSERT_TRUE(devArrMatch(stddev_act.data(), vars_act.data(), @@ -177,11 +209,16 @@ TEST_P(StdDevTestF, Result) typedef StdDevTest StdDevTestD; TEST_P(StdDevTestD, Result) { - ASSERT_TRUE(devArrMatch(params.stddev, - stddev_act.data(), - params.cols, - CompareApprox(params.tolerance), - stream)); + if (params.rows == 1) { + ASSERT_TRUE(devArrMatch( + double(0), stddev_act.data(), params.cols, CompareApprox(params.tolerance), stream)); + } else { + ASSERT_TRUE(devArrMatch(params.stddev, + stddev_act.data(), + params.cols, + CompareApprox(params.tolerance), + stream)); + } ASSERT_TRUE(devArrMatch(stddev_act.data(), vars_act.data(), diff --git a/cpp/test/stats/sum.cu b/cpp/test/stats/sum.cu index 5a549f8ba4..bf2aa44a2c 100644 --- a/cpp/test/stats/sum.cu +++ b/cpp/test/stats/sum.cu @@ -33,7 +33,8 @@ template struct SumInputs { T tolerance; int rows, cols; - unsigned long long int seed; + bool rowMajor; + T value = T(1); }; template @@ -56,20 +57,34 @@ class SumTest : public ::testing::TestWithParam> { } protected: - void SetUp() override + void runTest() { int len = rows * cols; - T data_h[len]; + std::vector data_h(len); for (int i = 0; i < len; i++) { - data_h[i] = T(1); + data_h[i] = T(params.value); } - raft::update_device(data.data(), data_h, len, stream); - sum(handle, - raft::make_device_matrix_view(data.data(), rows, cols), - raft::make_device_vector_view(sum_act.data(), cols)); + raft::update_device(data.data(), data_h.data(), len, stream); + + if (params.rowMajor) { + using layout = raft::row_major; + sum(handle, + raft::make_device_matrix_view(data.data(), rows, cols), + raft::make_device_vector_view(sum_act.data(), cols)); + } else { + using layout = raft::col_major; + sum(handle, + raft::make_device_matrix_view(data.data(), rows, cols), + raft::make_device_vector_view(sum_act.data(), cols)); + } resource::sync_stream(handle, stream); + + double expected = double(params.rows) * params.value; + + ASSERT_TRUE(raft::devArrMatch( + T(expected), sum_act.data(), params.cols, raft::CompareApprox(params.tolerance))); } protected: @@ -81,27 +96,49 @@ class SumTest : public ::testing::TestWithParam> { rmm::device_uvector data, sum_act; }; -const std::vector> inputsf = { - {0.05f, 4, 5, 1234ULL}, {0.05f, 1024, 32, 1234ULL}, {0.05f, 1024, 256, 1234ULL}}; - -const std::vector> inputsd = {{0.05, 1024, 32, 1234ULL}, - {0.05, 1024, 256, 1234ULL}}; +const std::vector> inputsf = {{0.0001f, 4, 5, true, 1}, + {0.0001f, 1024, 32, true, 1}, + {0.0001f, 1024, 256, true, 1}, + {0.0001f, 100000000, 1, true, 0.001}, + {0.0001f, 1, 30, true, 0.001}, + {0.0001f, 1, 1, true, 0.001}, + {0.0001f, 17, 5, true, 0.001}, + {0.0001f, 7, 23, true, 0.001}, + {0.0001f, 3, 97, true, 0.001}, + {0.0001f, 4, 5, false, 1}, + {0.0001f, 1024, 32, false, 1}, + {0.0001f, 1024, 256, false, 1}, + {0.0001f, 100000000, 1, false, 0.001}, + {0.0001f, 1, 30, false, 0.001}, + {0.0001f, 1, 1, false, 0.001}, + {0.0001f, 17, 5, false, 0.001}, + {0.0001f, 7, 23, false, 0.001}, + {0.0001f, 3, 97, false, 0.001}}; + +const std::vector> inputsd = {{0.000001, 1024, 32, true, 1}, + {0.000001, 1024, 256, true, 1}, + {0.000001, 1024, 256, true, 1}, + {0.000001, 100000000, 1, true, 0.001}, + {0.000001, 1, 30, true, 0.0001}, + {0.000001, 1, 1, true, 0.0001}, + {0.000001, 17, 5, true, 0.0001}, + {0.000001, 7, 23, true, 0.0001}, + {0.000001, 3, 97, true, 0.0001}, + {0.000001, 1024, 32, false, 1}, + {0.000001, 1024, 256, false, 1}, + {0.000001, 1024, 256, false, 1}, + {0.000001, 100000000, 1, false, 0.001}, + {0.000001, 1, 30, false, 0.0001}, + {0.000001, 1, 1, false, 0.0001}, + {0.000001, 17, 5, false, 0.0001}, + {0.000001, 7, 23, false, 0.0001}, + {0.000001, 3, 97, false, 0.0001}}; typedef SumTest SumTestF; -TEST_P(SumTestF, Result) -{ - ASSERT_TRUE(raft::devArrMatch( - float(params.rows), sum_act.data(), params.cols, raft::CompareApprox(params.tolerance))); -} - typedef SumTest SumTestD; -TEST_P(SumTestD, Result) -{ - ASSERT_TRUE(raft::devArrMatch(double(params.rows), - sum_act.data(), - params.cols, - raft::CompareApprox(params.tolerance))); -} + +TEST_P(SumTestF, Result) { runTest(); } +TEST_P(SumTestD, Result) { runTest(); } INSTANTIATE_TEST_CASE_P(SumTests, SumTestF, ::testing::ValuesIn(inputsf)); From 69fd9714afa4a40cf3d68ef6a2ed137545ef52a2 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Mon, 18 Mar 2024 12:09:24 -0400 Subject: [PATCH 04/10] Replace local copyright check with pre-commit-hooks verify-copyright (#2220) The local `copyright.py` script is bug-prone. Replace it with a more robust centralized script from `pre-commit-hooks`. Issue: https://github.com/rapidsai/build-planning/issues/30 Authors: - Kyle Edwards (https://github.com/KyleFromNVIDIA) Approvers: - Jake Awe (https://github.com/AyodeAwe) URL: https://github.com/rapidsai/raft/pull/2220 --- .pre-commit-config.yaml | 23 +++- ci/checks/copyright.py | 290 ---------------------------------------- 2 files changed, 17 insertions(+), 296 deletions(-) delete mode 100644 ci/checks/copyright.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0d6ab7ee54..6bca1d228e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -81,12 +81,6 @@ repos: verbose: true require_serial: true exclude: .*/thirdparty/.* - - id: copyright-check - name: copyright-check - entry: python ./ci/checks/copyright.py --git-modified-only --update-current-year - language: python - pass_filenames: false - additional_dependencies: [gitpython] - id: include-check name: include-check entry: python ./cpp/scripts/include_checker.py cpp/bench cpp/include cpp/test @@ -109,6 +103,23 @@ repos: rev: v4.5.0 hooks: - id: check-json + - repo: https://github.com/rapidsai/pre-commit-hooks + rev: v0.0.1 + hooks: + - id: verify-copyright + files: | + (?x) + [.](cmake|cpp|cu|cuh|h|hpp|sh|pxd|py|pyx)$| + CMakeLists[.]txt$| + CMakeLists_standalone[.]txt$| + meta[.]yaml$| + setup[.]cfg$ + exclude: | + (?x) + cpp/include/raft/neighbors/detail/faiss_select/| + cpp/include/raft/thirdparty/| + docs/source/sphinxext/github_link[.]py| + cpp/cmake/modules/FindAVX[.]cmake default_language_version: python: python3 diff --git a/ci/checks/copyright.py b/ci/checks/copyright.py deleted file mode 100644 index 2af8b1b8ff..0000000000 --- a/ci/checks/copyright.py +++ /dev/null @@ -1,290 +0,0 @@ -# Copyright (c) 2020-2023, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import datetime -import re -import argparse -import io -import os -import sys - -import git - -SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) - -# Add the scripts dir for gitutils -sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, - "../../cpp/scripts"))) - -# Now import gitutils. Ignore flake8 error here since there is no other way to -# set up imports -import gitutils # noqa: E402 - -FilesToCheck = [ - re.compile(r"[.](cmake|cpp|cu|cuh|h|hpp|sh|pxd|py|pyx)$"), - re.compile(r"CMakeLists[.]txt$"), - re.compile(r"CMakeLists_standalone[.]txt$"), - re.compile(r"setup[.]cfg$"), - re.compile(r"meta[.]yaml$") -] -ExemptFiles = [ - re.compile("cpp/include/raft/neighbors/detail/faiss_select/"), - re.compile("cpp/include/raft/thirdparty/"), - re.compile("docs/source/sphinxext/github_link.py"), - re.compile("cpp/cmake/modules/FindAVX.cmake") -] - -# this will break starting at year 10000, which is probably OK :) -CheckSimple = re.compile( - r"Copyright *(?:\(c\))? *(\d{4}),? *NVIDIA C(?:ORPORATION|orporation)") -CheckDouble = re.compile( - r"Copyright *(?:\(c\))? *(\d{4})-(\d{4}),? *NVIDIA C(?:ORPORATION|orporation)" # noqa: E501 -) - - -def checkThisFile(f): - if isinstance(f, git.Diff): - if f.deleted_file or f.b_blob.size == 0: - return False - f = f.b_path - elif not os.path.exists(f) or os.stat(f).st_size == 0: - # This check covers things like symlinks which point to files that DNE - return False - for exempt in ExemptFiles: - if exempt.search(f): - return False - for checker in FilesToCheck: - if checker.search(f): - return True - return False - - -def modifiedFiles(): - """Get a set of all modified files, as Diff objects. - - The files returned have been modified in git since the merge base of HEAD - and the upstream of the target branch. We return the Diff objects so that - we can read only the staged changes. - """ - repo = git.Repo() - # Use the environment variable TARGET_BRANCH or RAPIDS_BASE_BRANCH (defined in CI) if possible - target_branch = os.environ.get("TARGET_BRANCH", os.environ.get("RAPIDS_BASE_BRANCH")) - if target_branch is None: - # Fall back to the closest branch if not on CI - target_branch = repo.git.describe( - all=True, tags=True, match="branch-*", abbrev=0 - ).lstrip("heads/") - - upstream_target_branch = None - if target_branch in repo.heads: - # Use the tracking branch of the local reference if it exists. This - # returns None if no tracking branch is set. - upstream_target_branch = repo.heads[target_branch].tracking_branch() - if upstream_target_branch is None: - # Fall back to the remote with the newest target_branch. This code - # path is used on CI because the only local branch reference is - # current-pr-branch, and thus target_branch is not in repo.heads. - # This also happens if no tracking branch is defined for the local - # target_branch. We use the remote with the latest commit if - # multiple remotes are defined. - candidate_branches = [ - remote.refs[target_branch] for remote in repo.remotes - if target_branch in remote.refs - ] - if len(candidate_branches) > 0: - upstream_target_branch = sorted( - candidate_branches, - key=lambda branch: branch.commit.committed_datetime, - )[-1] - else: - # If no remotes are defined, try to use the local version of the - # target_branch. If this fails, the repo configuration must be very - # strange and we can fix this script on a case-by-case basis. - upstream_target_branch = repo.heads[target_branch] - merge_base = repo.merge_base("HEAD", upstream_target_branch.commit)[0] - diff = merge_base.diff() - changed_files = {f for f in diff if f.b_path is not None} - return changed_files - - -def getCopyrightYears(line): - res = CheckSimple.search(line) - if res: - return int(res.group(1)), int(res.group(1)) - res = CheckDouble.search(line) - if res: - return int(res.group(1)), int(res.group(2)) - return None, None - - -def replaceCurrentYear(line, start, end): - # first turn a simple regex into double (if applicable). then update years - res = CheckSimple.sub(r"Copyright (c) \1-\1, NVIDIA CORPORATION", line) - res = CheckDouble.sub( - rf"Copyright (c) {start:04d}-{end:04d}, NVIDIA CORPORATION", - res, - ) - return res - - -def checkCopyright(f, update_current_year): - """Checks for copyright headers and their years.""" - errs = [] - thisYear = datetime.datetime.now().year - lineNum = 0 - crFound = False - yearMatched = False - - if isinstance(f, git.Diff): - path = f.b_path - lines = f.b_blob.data_stream.read().decode().splitlines(keepends=True) - else: - path = f - with open(f, encoding="utf-8") as fp: - lines = fp.readlines() - - for line in lines: - lineNum += 1 - start, end = getCopyrightYears(line) - if start is None: - continue - crFound = True - if start > end: - e = [ - path, - lineNum, - "First year after second year in the copyright " - "header (manual fix required)", - None, - ] - errs.append(e) - elif thisYear < start or thisYear > end: - e = [ - path, - lineNum, - "Current year not included in the copyright header", - None, - ] - if thisYear < start: - e[-1] = replaceCurrentYear(line, thisYear, end) - if thisYear > end: - e[-1] = replaceCurrentYear(line, start, thisYear) - errs.append(e) - else: - yearMatched = True - # copyright header itself not found - if not crFound: - e = [ - path, - 0, - "Copyright header missing or formatted incorrectly " - "(manual fix required)", - None, - ] - errs.append(e) - # even if the year matches a copyright header, make the check pass - if yearMatched: - errs = [] - - if update_current_year: - errs_update = [x for x in errs if x[-1] is not None] - if len(errs_update) > 0: - lines_changed = ", ".join(str(x[1]) for x in errs_update) - print(f"File: {path}. Changing line(s) {lines_changed}") - for _, lineNum, __, replacement in errs_update: - lines[lineNum - 1] = replacement - with open(path, "w", encoding="utf-8") as out_file: - out_file.writelines(lines) - - return errs - - -def getAllFilesUnderDir(root, pathFilter=None): - retList = [] - for dirpath, dirnames, filenames in os.walk(root): - for fn in filenames: - filePath = os.path.join(dirpath, fn) - if pathFilter(filePath): - retList.append(filePath) - return retList - - -def checkCopyright_main(): - """ - Checks for copyright headers in all the modified files. In case of local - repo, this script will just look for uncommitted files and in case of CI - it compares between branches "$PR_TARGET_BRANCH" and "current-pr-branch" - """ - retVal = 0 - - argparser = argparse.ArgumentParser( - "Checks for a consistent copyright header in git's modified files" - ) - argparser.add_argument( - "--update-current-year", - dest="update_current_year", - action="store_true", - required=False, - help="If set, " - "update the current year if a header is already " - "present and well formatted.", - ) - argparser.add_argument( - "--git-modified-only", - dest="git_modified_only", - action="store_true", - required=False, - help="If set, " - "only files seen as modified by git will be " - "processed.", - ) - - args, dirs = argparser.parse_known_args() - - if args.git_modified_only: - files = [f for f in modifiedFiles() if checkThisFile(f)] - else: - files = [] - for d in [os.path.abspath(d) for d in dirs]: - if not os.path.isdir(d): - raise ValueError(f"{d} is not a directory.") - files += getAllFilesUnderDir(d, pathFilter=checkThisFile) - - errors = [] - for f in files: - errors += checkCopyright(f, args.update_current_year) - - if len(errors) > 0: - if any(e[-1] is None for e in errors): - print("Copyright headers incomplete in some of the files!") - for e in errors: - print(" %s:%d Issue: %s" % (e[0], e[1], e[2])) - print("") - n_fixable = sum(1 for e in errors if e[-1] is not None) - path_parts = os.path.abspath(__file__).split(os.sep) - file_from_repo = os.sep.join(path_parts[path_parts.index("ci") :]) - if n_fixable > 0 and not args.update_current_year: - print( - f"You can run `python {file_from_repo} --git-modified-only " - "--update-current-year` and stage the results in git to " - f"fix {n_fixable} of these errors.\n" - ) - retVal = 1 - - return retVal - - -if __name__ == "__main__": - sys.exit(checkCopyright_main()) From 32f6f40c561a97bc0174f2ebf0894d3289a76d65 Mon Sep 17 00:00:00 2001 From: "Artem M. Chirkin" <9253178+achirkin@users.noreply.github.com> Date: Mon, 18 Mar 2024 17:14:11 +0100 Subject: [PATCH 05/10] Add CAGRA-Q build (compression) (#2213) Add a `cagra::compress` function that implements CAGRA-Q (VQ + PQ) compression of a given dataset. The result, `compressed_dataset`, is supposed to complement the CAGRA graph during `cagra::search` in place of a raw dataset. ### Current state: - The code runs and produces a meaningful output (tested internally by running the original prototype search with the generated compressed dataset); the recall levels are approximately the same as with the prototype implementation. - No test coverage yet (need to coordinate with the search PR https://github.com/rapidsai/raft/pull/2206) - Full `pq_bits` support ([4,5,6,7,8] - same as in IVF-PQ) - Any `pq_dim` values are accepted, but the dataset is not padded and thus `dim` must be a multiple of `pq_dim`. - The codebook math type is hardcoded to `half` to match the prototype implementation for now. This could be a runtime (build) parameter as well. - All common input data types should work (`uint8_t`, `int8_t`, `half`, and `float` compile), but I tested only `float`. Authors: - Artem M. Chirkin (https://github.com/achirkin) Approvers: - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/raft/pull/2213 --- .../core/detail/mdspan_numpy_serializer.hpp | 3 +- cpp/include/raft/neighbors/cagra.cuh | 89 ++-- cpp/include/raft/neighbors/cagra_types.hpp | 90 ++-- cpp/include/raft/neighbors/dataset.hpp | 330 ++++++++++++++ .../neighbors/detail/cagra/cagra_build.cuh | 10 + .../detail/cagra/cagra_serialize.cuh | 35 +- .../neighbors/detail/dataset_serialize.hpp | 192 ++++++++ .../raft/neighbors/detail/vpq_dataset.cuh | 427 ++++++++++++++++++ cpp/include/raft/neighbors/vpq_dataset.cuh | 51 +++ 9 files changed, 1103 insertions(+), 124 deletions(-) create mode 100644 cpp/include/raft/neighbors/dataset.hpp create mode 100644 cpp/include/raft/neighbors/detail/dataset_serialize.hpp create mode 100644 cpp/include/raft/neighbors/detail/vpq_dataset.cuh create mode 100644 cpp/include/raft/neighbors/vpq_dataset.cuh diff --git a/cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp b/cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp index 176309c8ce..3fb7b3005b 100644 --- a/cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp +++ b/cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp @@ -126,7 +126,8 @@ inline dtype_t get_numpy_dtype() } #if defined(_RAFT_HAS_CUDA) -template , bool> = true> +template , half>, bool> = true> inline dtype_t get_numpy_dtype() { return {RAFT_NUMPY_HOST_ENDIAN_CHAR, 'e', sizeof(T)}; diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index b8258297e6..b7e362f704 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -25,6 +25,7 @@ #include #include #include +#include #include @@ -279,62 +280,6 @@ index build(raft::resources const& res, return detail::build(res, params, dataset); } -/** - * @brief Search ANN using the constructed index. - * - * See the [cagra::build](#cagra::build) documentation for a usage example. - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] res raft resources - * @param[in] params configure the search - * @param[in] idx cagra index - * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, - * k] - */ -template -void search(raft::resources const& res, - const search_params& params, - const index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) -{ - RAFT_EXPECTS( - queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), - "Number of rows in output neighbors and distances matrices must equal the number of queries."); - - RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), - "Number of columns in output neighbors and distances matrices must equal k"); - RAFT_EXPECTS(queries.extent(1) == idx.dim(), - "Number of query dimensions should equal number of dimensions in the index."); - - using internal_IdxT = typename std::make_unsigned::type; - auto queries_internal = raft::make_device_matrix_view( - queries.data_handle(), queries.extent(0), queries.extent(1)); - auto neighbors_internal = raft::make_device_matrix_view( - reinterpret_cast(neighbors.data_handle()), - neighbors.extent(0), - neighbors.extent(1)); - auto distances_internal = raft::make_device_matrix_view( - distances.data_handle(), distances.extent(0), distances.extent(1)); - - cagra::detail::search_main(res, - params, - idx, - queries_internal, - neighbors_internal, - distances_internal, - raft::neighbors::filtering::none_cagra_sample_filter()); -} - /** * @brief Search ANN using the constructed index with the given sample filter. * @@ -401,10 +346,40 @@ void search_with_filtering(raft::resources const& res, auto distances_internal = raft::make_device_matrix_view( distances.data_handle(), distances.extent(0), distances.extent(1)); - cagra::detail::search_main( + return cagra::detail::search_main( res, params, idx, queries_internal, neighbors_internal, distances_internal, sample_filter); } +/** + * @brief Search ANN using the constructed index. + * + * See the [cagra::build](#cagra::build) documentation for a usage example. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] res raft resources + * @param[in] params configure the search + * @param[in] idx cagra index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + */ +template +void search(raft::resources const& res, + const search_params& params, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) +{ + using none_filter_type = raft::neighbors::filtering::none_cagra_sample_filter; + return cagra::search_with_filtering( + res, params, idx, queries, neighbors, distances, none_filter_type{}); +} + /** @} */ // end group cagra } // namespace raft::neighbors::cagra diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index 0f574ae5bb..807f89fd65 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -17,6 +17,7 @@ #pragma once #include "ann_types.hpp" +#include "dataset.hpp" #include #include @@ -35,6 +36,7 @@ #include #include #include + namespace raft::neighbors::cagra { /** * @addtogroup cagra @@ -61,6 +63,12 @@ struct index_params : ann::index_params { graph_build_algo build_algo = graph_build_algo::IVF_PQ; /** Number of Iterations to run if building with NN_DESCENT */ size_t nn_descent_niter = 20; + /** + * Specify compression params if compression is desired. + * + * NOTE: this is experimental new API, consider it unsafe. + */ + std::optional compression = std::nullopt; }; enum class search_algo { @@ -145,25 +153,37 @@ struct index : ann::index { /** Total length of the index (number of vectors). */ [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { - return dataset_view_.extent(0) ? dataset_view_.extent(0) : graph_view_.extent(0); + auto data_rows = dataset_->n_rows(); + return data_rows > 0 ? data_rows : graph_view_.extent(0); } /** Dimensionality of the data. */ - [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t - { - return dataset_view_.extent(1); - } + [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t { return dataset_->dim(); } /** Graph degree */ [[nodiscard]] constexpr inline auto graph_degree() const noexcept -> uint32_t { return graph_view_.extent(1); } - /** Dataset [size, dim] */ - [[nodiscard]] inline auto dataset() const noexcept + /** + * DEPRECATED: please use data() instead. + * If you need to query dataset dimensions, use the dim() and size() of the cagra index. + * The data_handle() is not always available: you need to do a dynamic_cast to the expected + * dataset type at runtime. + */ + [[nodiscard]] [[deprecated("Use data()")]] inline auto dataset() const noexcept -> device_matrix_view { - return dataset_view_; + auto p = dynamic_cast*>(dataset_.get()); + if (p != nullptr) { return p->view(); } + auto d = dataset_->dim(); + return make_device_strided_matrix_view(nullptr, 0, d, d); + } + + /** Dataset [size, dim] */ + [[nodiscard]] inline auto data() const noexcept -> const neighbors::dataset& + { + return *dataset_; } /** neighborhood graph [size, graph-degree] */ @@ -185,7 +205,7 @@ struct index : ann::index { raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded) : ann::index(), metric_(metric), - dataset_(make_device_matrix(res, 0, 0)), + dataset_(new neighbors::empty_dataset(0)), graph_(make_device_matrix(res, 0, 0)) { } @@ -251,12 +271,11 @@ struct index : ann::index { mdspan, row_major, graph_accessor> knn_graph) : ann::index(), metric_(metric), - dataset_(make_device_matrix(res, 0, 0)), + dataset_(make_aligned_dataset(res, dataset, 16)), graph_(make_device_matrix(res, 0, 0)) { RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0), "Dataset and knn_graph must have equal number of rows"); - update_dataset(res, dataset); update_graph(res, knn_graph); resource::sync_stream(res); } @@ -271,21 +290,14 @@ struct index : ann::index { void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { - if (dataset.extent(1) * sizeof(T) % 16 != 0) { - RAFT_LOG_DEBUG("Creating a padded copy of CAGRA dataset in device memory"); - copy_padded(res, dataset); - } else { - dataset_view_ = make_device_strided_matrix_view( - dataset.data_handle(), dataset.extent(0), dataset.extent(1), dataset.extent(1)); - } + dataset_ = make_aligned_dataset(res, dataset, 16); } /** Set the dataset reference explicitly to a device matrix view with padding. */ - void update_dataset(raft::resources const&, + void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { - RAFT_EXPECTS(dataset.stride(0) * sizeof(T) % 16 == 0, "Incorrect data padding."); - dataset_view_ = dataset; + dataset_ = make_aligned_dataset(res, dataset, 16); } /** @@ -296,8 +308,22 @@ struct index : ann::index { void update_dataset(raft::resources const& res, raft::host_matrix_view dataset) { - RAFT_LOG_DEBUG("Copying CAGRA dataset from host to device"); - copy_padded(res, dataset); + dataset_ = make_aligned_dataset(res, dataset, 16); + } + + /** Replace the dataset with a new dataset. */ + template + auto update_dataset(raft::resources const& res, DatasetT&& dataset) + -> std::enable_if_t, DatasetT>> + { + dataset_ = std::make_unique(std::move(dataset)); + } + + template + auto update_dataset(raft::resources const& res, std::unique_ptr&& dataset) + -> std::enable_if_t, DatasetT>> + { + dataset_ = std::move(dataset); } /** @@ -334,26 +360,10 @@ struct index : ann::index { } private: - /** Create a device copy of the dataset, and pad it if necessary. */ - template - void copy_padded(raft::resources const& res, - mdspan, row_major, data_accessor> dataset) - { - detail::copy_with_padding(res, dataset_, dataset); - - dataset_view_ = make_device_strided_matrix_view( - dataset_.data_handle(), dataset_.extent(0), dataset.extent(1), dataset_.extent(1)); - RAFT_LOG_DEBUG("CAGRA dataset strided matrix view %zux%zu, stride %zu", - static_cast(dataset_view_.extent(0)), - static_cast(dataset_view_.extent(1)), - static_cast(dataset_view_.stride(0))); - } - raft::distance::DistanceType metric_; - raft::device_matrix dataset_; raft::device_matrix graph_; - raft::device_matrix_view dataset_view_; raft::device_matrix_view graph_view_; + std::unique_ptr> dataset_; }; /** @} */ diff --git a/cpp/include/raft/neighbors/dataset.hpp b/cpp/include/raft/neighbors/dataset.hpp new file mode 100644 index 0000000000..e7a3ba97a4 --- /dev/null +++ b/cpp/include/raft/neighbors/dataset.hpp @@ -0,0 +1,330 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include // get_device_for_address +#include // rounding up + +#include +#include +#include + +#ifdef __cpp_lib_bitops +#include +#endif + +namespace raft::neighbors { + +/** Two-dimensional dataset; maybe owning, maybe compressed, maybe strided. */ +template +struct dataset { + using index_type = IdxT; + /** Size of the dataset. */ + [[nodiscard]] virtual auto n_rows() const noexcept -> index_type = 0; + /** Dimensionality of the dataset. */ + [[nodiscard]] virtual auto dim() const noexcept -> uint32_t = 0; + /** Whether the object owns the data. */ + [[nodiscard]] virtual auto is_owning() const noexcept -> bool = 0; + virtual ~dataset() noexcept = default; +}; + +template +struct empty_dataset : public dataset { + using index_type = IdxT; + uint32_t suggested_dim; + explicit empty_dataset(uint32_t dim) noexcept : suggested_dim(dim) {} + [[nodiscard]] auto n_rows() const noexcept -> index_type final { return 0; } + [[nodiscard]] auto dim() const noexcept -> uint32_t final { return suggested_dim; } + [[nodiscard]] auto is_owning() const noexcept -> bool final { return true; } +}; + +template +struct strided_dataset : public dataset { + using index_type = IdxT; + using value_type = DataT; + using view_type = device_matrix_view; + [[nodiscard]] auto n_rows() const noexcept -> index_type final { return view().extent(0); } + [[nodiscard]] auto dim() const noexcept -> uint32_t final + { + return static_cast(view().extent(1)); + } + /** Leading dimension of the dataset. */ + [[nodiscard]] constexpr auto stride() const noexcept -> uint32_t + { + auto v = view(); + return static_cast(v.stride(0) > 0 ? v.stride(0) : v.extent(1)); + } + /** Get the view of the data. */ + [[nodiscard]] virtual auto view() const noexcept -> view_type; +}; + +template +struct non_owning_dataset : public strided_dataset { + using index_type = IdxT; + using value_type = DataT; + using typename strided_dataset::view_type; + view_type data; + explicit non_owning_dataset(view_type v) noexcept : data(v) {} + [[nodiscard]] auto is_owning() const noexcept -> bool final { return false; } + [[nodiscard]] auto view() const noexcept -> view_type final { return data; }; +}; + +template +struct owning_dataset : public strided_dataset { + using index_type = IdxT; + using value_type = DataT; + using typename strided_dataset::view_type; + using storage_type = + mdarray, LayoutPolicy, ContainerPolicy>; + using mapping_type = typename view_type::mapping_type; + storage_type data; + mapping_type view_mapping; + owning_dataset(storage_type&& store, mapping_type view_mapping) noexcept + : data{std::move(store)}, view_mapping{view_mapping} + { + } + + [[nodiscard]] auto is_owning() const noexcept -> bool final { return true; } + [[nodiscard]] auto view() const noexcept -> view_type final + { + return view_type{data.data_handle(), view_mapping}; + }; +}; + +/** + * @brief Contstruct a strided matrix from any mdarray or mdspan. + * + * This function constructs a non-owning view if the input satisfied two conditions: + * + * 1) The data is accessible from the current device + * 2) The memory layout is the same as expected (row-major matrix with the required stride) + * + * Otherwise, this function constructs an owning device matrix and copies the data. + * When the data is copied, padding elements are filled with zeroes. + * + * @tparam SrcT the source mdarray or mdspan + * + * @param[in] res raft resources handle + * @param[in] src the source mdarray or mdspan + * @param[in] required_stride the leading dimension (in elements) + * @return maybe owning current-device-accessible strided matrix + */ +template +auto make_strided_dataset(const raft::resources& res, const SrcT& src, uint32_t required_stride) + -> std::unique_ptr> +{ + using extents_type = typename SrcT::extents_type; + using value_type = typename SrcT::value_type; + using index_type = typename SrcT::index_type; + using layout_type = typename SrcT::layout_type; + static_assert(extents_type::rank() == 2, "The input must be a matrix."); + static_assert(std::is_same_v || + std::is_same_v> || + std::is_same_v, + "The input must be row-major"); + RAFT_EXPECTS(src.extent(1) <= required_stride, + "The input row length must be not larger than the desired stride."); + cudaPointerAttributes ptr_attrs; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&ptr_attrs, src.data_handle())); + auto* device_ptr = reinterpret_cast(ptr_attrs.devicePointer); + const uint32_t src_stride = src.stride(0) > 0 ? src.stride(0) : src.extent(1); + const bool device_accessible = device_ptr != nullptr; + const bool row_major = src.stride(1) <= 1; + const bool stride_matches = required_stride == src_stride; + + if (device_accessible && row_major && stride_matches) { + // Everything matches: make a non-owning dataset + return std::make_unique>( + make_device_strided_matrix_view( + device_ptr, src.extent(0), src.extent(1), required_stride)); + } + // Something is wrong: have to make a copy and produce an owning dataset + auto out_layout = + make_strided_layout(src.extents(), std::array{required_stride, 1}); + auto out_array = make_device_matrix(res, src.extent(0), required_stride); + + using out_mdarray_type = decltype(out_array); + using out_layout_type = typename out_mdarray_type::layout_type; + using out_container_policy_type = typename out_mdarray_type::container_policy_type; + using out_owning_type = + owning_dataset; + + RAFT_CUDA_TRY(cudaMemsetAsync(out_array.data_handle(), + 0, + out_array.size() * sizeof(value_type), + resource::get_cuda_stream(res))); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(out_array.data_handle(), + sizeof(value_type) * required_stride, + src.data_handle(), + sizeof(value_type) * src_stride, + sizeof(value_type) * src.extent(1), + src.extent(0), + cudaMemcpyDefault, + resource::get_cuda_stream(res))); + + return std::make_unique(std::move(out_array), out_layout); +} + +/** + * @brief Contstruct a strided matrix from any mdarray or mdspan. + * + * A variant `make_strided_dataset` that allows specifying the byte alignment instead of the + * explicit stride length. + * + * @tparam SrcT the source mdarray or mdspan + * + * @param[in] res raft resources handle + * @param[in] src the source mdarray or mdspan + * @param[in] align_bytes the required byte alignment for the dataset rows. + * @return maybe owning current-device-accessible strided matrix + */ +template +auto make_aligned_dataset(const raft::resources& res, const SrcT& src, uint32_t align_bytes = 16) + -> std::unique_ptr> +{ + using value_type = typename SrcT::value_type; + constexpr size_t kSize = sizeof(value_type); + uint32_t required_stride = + raft::round_up_safe(src.extent(1) * kSize, std::lcm(align_bytes, kSize)) / kSize; + return make_strided_dataset(res, src, required_stride); +} + +/** Parameters for VPQ compression. */ +struct vpq_params { + /** + * The bit length of the vector element after compression by PQ. + * + * Possible values: [4, 5, 6, 7, 8]. + * + * Hint: the smaller the 'pq_bits', the smaller the index size and the better the search + * performance, but the lower the recall. + */ + uint32_t pq_bits = 8; + /** + * The dimensionality of the vector after compression by PQ. + * When zero, an optimal value is selected using a heuristic. + * + * TODO: at the moment `dim` must be a multiple `pq_dim`. + */ + uint32_t pq_dim = 0; + /** + * Vector Quantization (VQ) codebook size - number of "coarse cluster centers". + * When zero, an optimal value is selected using a heuristic. + */ + uint32_t vq_n_centers = 0; + /** The number of iterations searching for kmeans centers (both VQ & PQ phases). */ + uint32_t kmeans_n_iters = 25; + /** + * The fraction of data to use during iterative kmeans building (VQ phase). + * When zero, an optimal value is selected using a heuristic. + */ + double vq_kmeans_trainset_fraction = 0; + /** + * The fraction of data to use during iterative kmeans building (PQ phase). + * When zero, an optimal value is selected using a heuristic. + */ + double pq_kmeans_trainset_fraction = 0; +}; + +/** + * @brief VPQ compressed dataset. + * + * The dataset is compressed using two level quantization + * + * 1. Vector Quantization + * 2. Product Quantization of residuals + * + * @tparam MathT the type of elements in the codebooks + * @tparam IdxT type of the vector indices (represent dataset.extent(0)) + * + */ +template +struct vpq_dataset : public dataset { + /** Vector Quantization codebook - "coarse cluster centers". */ + device_matrix vq_code_book; + /** Product Quantization codebook - "fine cluster centers". */ + device_matrix pq_code_book; + /** Compressed dataset. */ + device_matrix data; + + vpq_dataset(device_matrix&& vq_code_book, + device_matrix&& pq_code_book, + device_matrix&& data) + : vq_code_book{std::move(vq_code_book)}, + pq_code_book{std::move(pq_code_book)}, + data{std::move(data)} + { + } + + [[nodiscard]] auto n_rows() const noexcept -> IdxT final { return data.extent(0); } + [[nodiscard]] auto dim() const noexcept -> uint32_t final { return vq_code_book.extent(1); } + [[nodiscard]] auto is_owning() const noexcept -> bool final { return true; } + + /** Row length of the encoded data in bytes. */ + [[nodiscard]] constexpr inline auto encoded_row_length() const noexcept -> uint32_t + { + return data.extent(1); + } + /** The number of "coarse cluster centers" */ + [[nodiscard]] constexpr inline auto vq_n_centers() const noexcept -> uint32_t + { + return vq_code_book.extent(0); + } + /** The bit length of an encoded vector element after compression by PQ. */ + [[nodiscard]] constexpr inline auto pq_bits() const noexcept -> uint32_t + { + /* + NOTE: pq_bits and the book size + + Normally, we'd store `pq_bits` as a part of the index. + However, we know there's an invariant `pq_n_centers = 1 << pq_bits`, i.e. the codebook size is + the same as the number of possible code values. Hence, we don't store the pq_bits and derive it + from the array dimensions instead. + */ + auto pq_width = pq_n_centers(); +#ifdef __cpp_lib_bitops + return std::countr_zero(pq_width); +#else + uint32_t pq_bits = 0; + while (pq_width > 1) { + pq_bits++; + pq_width >>= 1; + } + return pq_bits; +#endif + } + /** The dimensionality of an encoded vector after compression by PQ. */ + [[nodiscard]] constexpr inline auto pq_dim() const noexcept -> uint32_t + { + return raft::div_rounding_up_unsafe(dim(), pq_len()); + } + /** Dimensionality of a subspaces, i.e. the number of vector components mapped to a subspace */ + [[nodiscard]] constexpr inline auto pq_len() const noexcept -> uint32_t + { + return pq_code_book.extent(1); + } + /** The number of vectors in a PQ codebook (`1 << pq_bits`). */ + [[nodiscard]] constexpr inline auto pq_n_centers() const noexcept -> uint32_t + { + return pq_code_book.extent(0); + } +}; + +} // namespace raft::neighbors diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index 08cc2beaeb..d91e45257e 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -16,6 +16,7 @@ #pragma once #include "../../cagra_types.hpp" +#include "../../vpq_dataset.cuh" #include "graph_core.cuh" #include @@ -344,6 +345,15 @@ index build( RAFT_LOG_INFO("Graph optimized, creating index"); // Construct an index from dataset and optimized knn graph. if (construct_index_with_dataset) { + if (params.compression.has_value()) { + index idx(res, params.metric); + idx.update_graph(res, raft::make_const_mdspan(cagra_graph.view())); + idx.update_dataset( + res, + // TODO: hardcoding codebook math to `half`, we can do runtime dispatching later + neighbors::vpq_build(res, *params.compression, dataset)); + return idx; + } return index(res, params.metric, dataset, raft::make_const_mdspan(cagra_graph.view())); } else { // We just add the graph. User is expected to update dataset separately. This branch is used diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index d7bd27222b..600c8785e0 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -31,7 +32,7 @@ namespace raft::neighbors::cagra::detail { -constexpr int serialization_version = 3; +constexpr int serialization_version = 4; /** * Save the index to file. @@ -65,26 +66,14 @@ void serialize(raft::resources const& res, serialize_scalar(res, os, index_.metric()); serialize_mdspan(res, os, index_.graph()); - include_dataset &= (index_.dataset().extent(0) > 0); + include_dataset &= (index_.data().n_rows() > 0); serialize_scalar(res, os, include_dataset); if (include_dataset) { RAFT_LOG_INFO("Saving CAGRA index with dataset"); - auto dataset = index_.dataset(); - // Remove padding before saving the dataset - auto host_dataset = make_host_matrix(dataset.extent(0), dataset.extent(1)); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(), - sizeof(T) * host_dataset.extent(1), - dataset.data_handle(), - sizeof(T) * dataset.stride(0), - sizeof(T) * host_dataset.extent(1), - dataset.extent(0), - cudaMemcpyDefault, - resource::get_cuda_stream(res))); - resource::sync_stream(res); - serialize_mdspan(res, os, host_dataset.view()); + neighbors::detail::serialize(res, os, index_.data()); } else { - RAFT_LOG_INFO("Saving CAGRA index WITHOUT dataset"); + RAFT_LOG_DEBUG("Saving CAGRA index WITHOUT dataset"); } } @@ -256,19 +245,13 @@ auto deserialize(raft::resources const& res, std::istream& is) -> index auto graph = raft::make_host_matrix(n_rows, graph_degree); deserialize_mdspan(res, is, graph.view()); + index idx(res, metric); + idx.update_graph(res, raft::make_const_mdspan(graph.view())); bool has_dataset = deserialize_scalar(res, is); if (has_dataset) { - auto dataset = raft::make_host_matrix(n_rows, dim); - deserialize_mdspan(res, is, dataset.view()); - return index( - res, metric, raft::make_const_mdspan(dataset.view()), raft::make_const_mdspan(graph.view())); - } else { - // create a new index with no dataset - the user must supply via update_dataset themselves - // later (this avoids allocating GPU memory in the meantime) - index idx(res, metric); - idx.update_graph(res, raft::make_const_mdspan(graph.view())); - return idx; + idx.update_dataset(res, neighbors::detail::deserialize_dataset(res, is)); } + return idx; } template diff --git a/cpp/include/raft/neighbors/detail/dataset_serialize.hpp b/cpp/include/raft/neighbors/detail/dataset_serialize.hpp new file mode 100644 index 0000000000..a6a6ae59a5 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/dataset_serialize.hpp @@ -0,0 +1,192 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "../dataset.hpp" + +#include +#include +#include + +#include + +#include +#include + +namespace raft::neighbors::detail { + +using dataset_instance_tag = uint32_t; +constexpr dataset_instance_tag kSerializeEmptyDataset = 1; +constexpr dataset_instance_tag kSerializeStridedDataset = 2; +constexpr dataset_instance_tag kSerializeVPQDataset = 3; + +template +void serialize(const raft::resources& res, std::ostream& os, const empty_dataset& dataset) +{ + serialize_scalar(res, os, dataset.suggested_dim); +} + +template +void serialize(const raft::resources& res, + std::ostream& os, + const strided_dataset& dataset) +{ + auto n_rows = dataset.n_rows(); + auto dim = dataset.dim(); + auto stride = dataset.stride(); + serialize_scalar(res, os, n_rows); + serialize_scalar(res, os, dim); + serialize_scalar(res, os, stride); + // Remove padding before saving the dataset + auto src = dataset.view(); + auto dst = make_host_matrix(n_rows, dim); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(dst.data_handle(), + sizeof(DataT) * dim, + src.data_handle(), + sizeof(DataT) * stride, + sizeof(DataT) * dim, + n_rows, + cudaMemcpyDefault, + resource::get_cuda_stream(res))); + resource::sync_stream(res); + serialize_mdspan(res, os, dst.view()); +} + +template +void serialize(const raft::resources& res, + std::ostream& os, + const vpq_dataset& dataset) +{ + serialize_scalar(res, os, dataset.n_rows()); + serialize_scalar(res, os, dataset.dim()); + serialize_scalar(res, os, dataset.vq_n_centers()); + serialize_scalar(res, os, dataset.pq_n_centers()); + serialize_scalar(res, os, dataset.pq_len()); + serialize_scalar(res, os, dataset.encoded_row_length()); + serialize_mdspan(res, os, make_const_mdspan(dataset.vq_code_book.view())); + serialize_mdspan(res, os, make_const_mdspan(dataset.pq_code_book.view())); + serialize_mdspan(res, os, make_const_mdspan(dataset.data.view())); +} + +template +void serialize(const raft::resources& res, std::ostream& os, const dataset& dataset) +{ + if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + serialize_scalar(res, os, kSerializeEmptyDataset); + return serialize(res, os, *x); + } + if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + serialize_scalar(res, os, kSerializeStridedDataset); + serialize_scalar(res, os, CUDA_R_32F); + return serialize(res, os, *x); + } + if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + serialize_scalar(res, os, kSerializeStridedDataset); + serialize_scalar(res, os, CUDA_R_16F); + return serialize(res, os, *x); + } + if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + serialize_scalar(res, os, kSerializeStridedDataset); + serialize_scalar(res, os, CUDA_R_8I); + return serialize(res, os, *x); + } + if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + serialize_scalar(res, os, kSerializeStridedDataset); + serialize_scalar(res, os, CUDA_R_8U); + return serialize(res, os, *x); + } + if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + serialize_scalar(res, os, kSerializeVPQDataset); + serialize_scalar(res, os, CUDA_R_32F); + return serialize(res, os, *x); + } + if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + serialize_scalar(res, os, kSerializeVPQDataset); + serialize_scalar(res, os, CUDA_R_16F); + return serialize(res, os, *x); + } + RAFT_FAIL("unsupported dataset type."); +} + +template +auto deserialize_empty(raft::resources const& res, std::istream& is) + -> std::unique_ptr> +{ + auto suggested_dim = deserialize_scalar(res, is); + return std::make_unique>(suggested_dim); +} + +template +auto deserialize_strided(raft::resources const& res, std::istream& is) + -> std::unique_ptr> +{ + auto n_rows = deserialize_scalar(res, is); + auto dim = deserialize_scalar(res, is); + auto stride = deserialize_scalar(res, is); + auto host_array = make_host_matrix(n_rows, dim); + deserialize_mdspan(res, is, host_array.view()); + return make_strided_dataset(res, host_array, stride); +} + +template +auto deserialize_vpq(raft::resources const& res, std::istream& is) + -> std::unique_ptr> +{ + auto n_rows = deserialize_scalar(res, is); + auto dim = deserialize_scalar(res, is); + auto vq_n_centers = deserialize_scalar(res, is); + auto pq_n_centers = deserialize_scalar(res, is); + auto pq_len = deserialize_scalar(res, is); + auto encoded_row_length = deserialize_scalar(res, is); + + auto vq_code_book = make_device_matrix(res, vq_n_centers, dim); + auto pq_code_book = make_device_matrix(res, pq_n_centers, pq_len); + auto data = make_device_matrix(res, n_rows, encoded_row_length); + + deserialize_mdspan(res, is, vq_code_book.view()); + deserialize_mdspan(res, is, pq_code_book.view()); + deserialize_mdspan(res, is, data.view()); + + return std::make_unique>( + std::move(vq_code_book), std::move(pq_code_book), std::move(data)); +} + +template +auto deserialize_dataset(raft::resources const& res, std::istream& is) + -> std::unique_ptr> +{ + switch (deserialize_scalar(res, is)) { + case kSerializeEmptyDataset: return deserialize_empty(res, is); + case kSerializeStridedDataset: + switch (deserialize_scalar(res, is)) { + case CUDA_R_32F: return deserialize_strided(res, is); + case CUDA_R_16F: return deserialize_strided(res, is); + case CUDA_R_8I: return deserialize_strided(res, is); + case CUDA_R_8U: return deserialize_strided(res, is); + default: break; + } + case kSerializeVPQDataset: + switch (deserialize_scalar(res, is)) { + case CUDA_R_32F: return deserialize_vpq(res, is); + case CUDA_R_16F: return deserialize_vpq(res, is); + default: break; + } + default: break; + } + RAFT_FAIL("Failed to deserialize dataset: unsupported combination of instance tags."); +} + +} // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/detail/vpq_dataset.cuh b/cpp/include/raft/neighbors/detail/vpq_dataset.cuh new file mode 100644 index 0000000000..f6cd2a1ceb --- /dev/null +++ b/cpp/include/raft/neighbors/detail/vpq_dataset.cuh @@ -0,0 +1,427 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "../dataset.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include // pq_bits-bitfield +#include // utils::mapping etc +#include +#include + +// A temporary stub till https://github.com/rapidsai/raft/pull/2077 is re-merged +namespace raft::util { + +/** + * Subsample the dataset to create a training set. + * + * @tparam DatasetT a row-major mdspan or mdarray (device or host) + * + * @param res raft handle + * @param dataset input row-major mdspan or mdarray (device or host) + * @param n_samples the size of the output mdarray + * + * @return a newly allocated subset of the dataset. + */ +template +auto subsample(raft::resources const& res, + const DatasetT& dataset, + typename DatasetT::index_type n_samples) + -> raft::device_matrix +{ + using value_type = typename DatasetT::value_type; + using index_type = typename DatasetT::index_type; + static_assert(std::is_same_v, + "Only row-major layout is supported at the moment"); + RAFT_EXPECTS(n_samples <= dataset.extent(0), + "The number of samples must be smaller than the number of input rows in the current " + "implementation."); + size_t dim = dataset.extent(1); + size_t trainset_ratio = dataset.extent(0) / n_samples; + auto result = raft::make_device_matrix(res, n_samples, dataset.extent(1)); + + RAFT_CUDA_TRY(cudaMemcpy2DAsync(result.data_handle(), + sizeof(value_type) * dim, + dataset.data_handle(), + sizeof(value_type) * dim * trainset_ratio, + sizeof(value_type) * dim, + n_samples, + cudaMemcpyDefault, + raft::resource::get_cuda_stream(res))); + return result; +} + +} // namespace raft::util + +namespace raft::neighbors::detail { + +template +auto fill_missing_params_heuristics(const vpq_params& params, const DatasetT& dataset) -> vpq_params +{ + vpq_params r = params; + double n_rows = dataset.extent(0); + size_t dim = dataset.extent(1); + if (r.pq_dim == 0) { r.pq_dim = raft::div_rounding_up_safe(dim, size_t{4}); } + if (r.pq_bits == 0) { r.pq_bits = 8; } + if (r.vq_n_centers == 0) { r.vq_n_centers = raft::round_up_safe(std::sqrt(n_rows), 8); } + if (r.vq_kmeans_trainset_fraction == 0) { + double vq_trainset_size = 100.0 * r.vq_n_centers; + r.vq_kmeans_trainset_fraction = std::min(1.0, vq_trainset_size / n_rows); + } + if (r.pq_kmeans_trainset_fraction == 0) { + // NB: we'll have actually `pq_dim` times more samples than this + // (because the dataset is reinterpreted as `[n_rows * pq_dim, pq_len]`) + double pq_trainset_size = 1000.0 * (1u << r.pq_bits); + r.pq_kmeans_trainset_fraction = std::min(1.0, pq_trainset_size / n_rows); + } + return r; +} + +template +auto transform_data(const raft::resources& res, DatasetT dataset) + -> device_mdarray +{ + using index_type = typename DatasetT::index_type; + using extents_type = typename DatasetT::extents_type; + using layout_type = typename DatasetT::layout_type; + using out_mdarray_type = device_mdarray; + if constexpr (std::is_same_v>) { return dataset; } + + auto result = raft::make_device_mdarray(res, dataset.extents()); + + linalg::map(res, + result.view(), + spatial::knn::detail::utils::mapping{}, + raft::make_const_mdspan(dataset.view())); + + return result; +} + +/** Fix the internal indexing type to avoid integer underflows/overflows */ +using ix_t = int64_t; + +template +auto train_vq(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) + -> device_matrix +{ + const ix_t n_rows = dataset.extent(0); + const ix_t vq_n_centers = params.vq_n_centers; + const ix_t dim = dataset.extent(1); + const ix_t n_rows_train = n_rows * params.vq_kmeans_trainset_fraction; + + // Subsample the dataset and transform into the required type if necessary + auto vq_trainset = raft::util::subsample(res, dataset, n_rows_train); + auto vq_centers = raft::make_device_matrix(res, vq_n_centers, dim); + + using kmeans_in_type = typename DatasetT::value_type; + raft::cluster::kmeans_balanced_params kmeans_params; + kmeans_params.n_iters = params.kmeans_n_iters; + kmeans_params.metric = raft::distance::DistanceType::L2Expanded; + auto vq_centers_view = + raft::make_device_matrix_view(vq_centers.data_handle(), vq_n_centers, dim); + auto vq_trainset_view = raft::make_device_matrix_view( + vq_trainset.data_handle(), n_rows_train, dim); + raft::cluster::kmeans_balanced::fit( + res, + kmeans_params, + vq_trainset_view, + vq_centers_view, + spatial::knn::detail::utils::mapping{}); + + return vq_centers; +} + +template +auto predict_vq(const raft::resources& res, const DatasetT& dataset, const VqCentersT& vq_centers) + -> device_vector +{ + using kmeans_data_type = typename DatasetT::value_type; + using kmeans_math_type = typename VqCentersT::value_type; + using index_type = typename DatasetT::index_type; + using label_type = LabelT; + + auto vq_labels = raft::make_device_vector(res, dataset.extent(0)); + + raft::cluster::kmeans_balanced_params kmeans_params; + kmeans_params.metric = raft::distance::DistanceType::L2Expanded; + + auto vq_centers_view = raft::make_device_matrix_view( + vq_centers.data_handle(), vq_centers.extent(0), vq_centers.extent(1)); + + auto vq_dataset_view = raft::make_device_matrix_view( + dataset.data_handle(), dataset.extent(0), dataset.extent(1)); + + raft::cluster::kmeans_balanced:: + predict( + res, + kmeans_params, + vq_dataset_view, + vq_centers_view, + vq_labels.view(), + spatial::knn::detail::utils::mapping{}); + + return vq_labels; +} + +template +auto train_pq(const raft::resources& res, + const vpq_params& params, + const DatasetT& dataset, + const device_matrix_view& vq_centers) + -> device_matrix +{ + const ix_t n_rows = dataset.extent(0); + const ix_t dim = dataset.extent(1); + const ix_t pq_dim = params.pq_dim; + const ix_t pq_bits = params.pq_bits; + const ix_t pq_n_centers = ix_t{1} << pq_bits; + const ix_t pq_len = raft::div_rounding_up_safe(dim, pq_dim); + const ix_t n_rows_train = n_rows * params.pq_kmeans_trainset_fraction; + + // Subsample the dataset and transform into the required type if necessary + auto pq_trainset = transform_data(res, raft::util::subsample(res, dataset, n_rows_train)); + + // Subtract VQ centers + { + auto vq_labels = predict_vq(res, pq_trainset, vq_centers); + using index_type = typename DatasetT::index_type; + linalg::map_offset( + res, + pq_trainset.view(), + [labels = vq_labels.view(), centers = vq_centers, dim] __device__(index_type off, MathT x) { + index_type i = off / dim; + index_type j = off % dim; + return x - centers(labels(i), j); + }, + raft::make_const_mdspan(pq_trainset.view())); + } + + auto pq_centers = raft::make_device_matrix(res, pq_n_centers, pq_len); + + // Train PQ centers + { + raft::cluster::kmeans_balanced_params kmeans_params; + kmeans_params.n_iters = params.kmeans_n_iters; + kmeans_params.metric = raft::distance::DistanceType::L2Expanded; + + auto pq_centers_view = + raft::make_device_matrix_view(pq_centers.data_handle(), pq_n_centers, pq_len); + + auto pq_trainset_view = raft::make_device_matrix_view( + pq_trainset.data_handle(), n_rows_train * pq_dim, pq_len); + + raft::cluster::kmeans_balanced::fit( + res, kmeans_params, pq_trainset_view, pq_centers_view); + } + + return pq_centers; +} + +template +__device__ auto compute_code(device_matrix_view dataset, + device_matrix_view vq_centers, + device_matrix_view pq_centers, + IdxT i, + uint32_t j, + LabelT vq_label) -> uint8_t +{ + auto data_mapping = spatial::knn::detail::utils::mapping{}; + uint32_t lane_id = Pow2::mod(laneId()); + + const uint32_t pq_book_size = pq_centers.extent(0); + const uint32_t pq_len = pq_centers.extent(1); + float min_dist = std::numeric_limits::infinity(); + uint8_t code = 0; + // calculate the distance for each PQ cluster, find the minimum for each thread + for (uint32_t l = lane_id; l < pq_book_size; l += SubWarpSize) { + // NB: the L2 quantifiers on residuals are always trained on L2 metric. + float d = 0.0f; + for (uint32_t k = 0; k < pq_len; k++) { + auto jk = j * pq_len + k; + auto x = data_mapping(dataset(i, jk)) - vq_centers(vq_label, jk); + auto t = x - pq_centers(l, k); + d += t * t; + } + if (d < min_dist) { + min_dist = d; + code = uint8_t(l); + } + } + // reduce among threads +#pragma unroll + for (uint32_t stride = SubWarpSize >> 1; stride > 0; stride >>= 1) { + const auto other_dist = shfl_xor(min_dist, stride, SubWarpSize); + const auto other_code = shfl_xor(code, stride, SubWarpSize); + if (other_dist < min_dist) { + min_dist = other_dist; + code = other_code; + } + } + return code; +} + +template +__launch_bounds__(BlockSize) RAFT_KERNEL + process_and_fill_codes_kernel(device_matrix_view out_codes, + device_matrix_view dataset, + device_matrix_view vq_centers, + device_vector_view vq_labels, + device_matrix_view pq_centers) +{ + constexpr uint32_t kSubWarpSize = std::min(WarpSize, 1u << PqBits); + using subwarp_align = Pow2; + const IdxT row_ix = subwarp_align::div(IdxT{threadIdx.x} + IdxT{BlockSize} * IdxT{blockIdx.x}); + if (row_ix >= out_codes.extent(0)) { return; } + + const uint32_t pq_dim = raft::div_rounding_up_unsafe(vq_centers.extent(1), pq_centers.extent(1)); + + const uint32_t lane_id = Pow2::mod(threadIdx.x); + const LabelT vq_label = vq_labels(row_ix); + + // write label + auto* out_label_ptr = reinterpret_cast(&out_codes(row_ix, 0)); + if (lane_id == 0) { *out_label_ptr = vq_label; } + + auto* out_codes_ptr = reinterpret_cast(out_label_ptr + 1); + ivf_pq::detail::bitfield_view_t code_view{out_codes_ptr}; + for (uint32_t j = 0; j < pq_dim; j++) { + // find PQ label + uint8_t code = compute_code(dataset, vq_centers, pq_centers, row_ix, j, vq_label); + // TODO: this writes in global memory one byte per warp, which is very slow. + // It's better to keep the codes in the shared memory or registers and dump them at once. + if (lane_id == 0) { code_view[j] = code; } + } +} + +template +auto process_and_fill_codes(const raft::resources& res, + const vpq_params& params, + const DatasetT& dataset, + device_matrix_view vq_centers, + device_matrix_view pq_centers) + -> device_matrix +{ + using data_t = typename DatasetT::value_type; + using cdataset_t = vpq_dataset; + using label_t = uint32_t; + + const ix_t n_rows = dataset.extent(0); + const ix_t dim = dataset.extent(1); + const ix_t pq_dim = params.pq_dim; + const ix_t pq_bits = params.pq_bits; + const ix_t pq_n_centers = ix_t{1} << pq_bits; + // NB: codes must be aligned at least to sizeof(label_t) to be able to read labels. + const ix_t codes_rowlen = + sizeof(label_t) * (1 + raft::div_rounding_up_safe(pq_dim * pq_bits, 8 * sizeof(label_t))); + + auto codes = raft::make_device_matrix(res, n_rows, codes_rowlen); + + auto stream = raft::resource::get_cuda_stream(res); + + // TODO: with scaling workspace we could choose the batch size dynamically + constexpr ix_t kReasonableMaxBatchSize = 65536; + constexpr ix_t kBlockSize = 256; + const ix_t threads_per_vec = std::min(WarpSize, pq_n_centers); + dim3 threads(kBlockSize, 1, 1); + ix_t max_batch_size = std::min(n_rows, kReasonableMaxBatchSize); + auto kernel = [](uint32_t pq_bits) { + switch (pq_bits) { + case 4: return process_and_fill_codes_kernel; + case 5: return process_and_fill_codes_kernel; + case 6: return process_and_fill_codes_kernel; + case 7: return process_and_fill_codes_kernel; + case 8: return process_and_fill_codes_kernel; + default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); + } + }(pq_bits); + for (const auto& batch : + spatial::knn::detail::utils::batch_load_iterator(dataset.data_handle(), + n_rows, + dim, + max_batch_size, + stream, + rmm::mr::get_current_device_resource())) { + auto batch_view = raft::make_device_matrix_view(batch.data(), ix_t(batch.size()), dim); + auto labels = predict_vq(res, batch_view, vq_centers); + dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize / threads_per_vec), 1, 1); + kernel<<>>( + make_device_matrix_view( + codes.data_handle() + batch.offset() * codes_rowlen, batch.size(), codes_rowlen), + batch_view, + vq_centers, + make_const_mdspan(labels.view()), + pq_centers); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + + return codes; +} + +template +auto vpq_convert_math_type(const raft::resources& res, vpq_dataset&& src) + -> vpq_dataset +{ + auto vq_code_book = make_device_mdarray(res, src.vq_code_book.extents()); + auto pq_code_book = make_device_mdarray(res, src.pq_code_book.extents()); + + linalg::map(res, + vq_code_book.view(), + spatial::knn::detail::utils::mapping{}, + raft::make_const_mdspan(src.vq_code_book.view())); + linalg::map(res, + pq_code_book.view(), + spatial::knn::detail::utils::mapping{}, + raft::make_const_mdspan(src.pq_code_book.view())); + return vpq_dataset{ + std::move(vq_code_book), std::move(pq_code_book), std::move(src.data)}; +} + +template +auto vpq_build(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) + -> vpq_dataset +{ + // Use a heuristic to impute missing parameters. + auto ps = fill_missing_params_heuristics(params, dataset); + + // Train codes + auto vq_code_book = train_vq(res, ps, dataset); + auto pq_code_book = + train_pq(res, ps, dataset, raft::make_const_mdspan(vq_code_book.view())); + + // Encode dataset + auto codes = process_and_fill_codes(res, + ps, + dataset, + raft::make_const_mdspan(vq_code_book.view()), + raft::make_const_mdspan(pq_code_book.view())); + + return vpq_dataset{ + std::move(vq_code_book), std::move(pq_code_book), std::move(codes)}; +} + +} // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/vpq_dataset.cuh b/cpp/include/raft/neighbors/vpq_dataset.cuh new file mode 100644 index 0000000000..73ee6c52ed --- /dev/null +++ b/cpp/include/raft/neighbors/vpq_dataset.cuh @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "dataset.hpp" +#include "detail/vpq_dataset.cuh" + +#include + +namespace raft::neighbors { + +/** + * @brief Compress a dataset for use in CAGRA-Q search in place of the original data. + * + * @tparam DatasetT a row-major mdspan or mdarray (device or host). + * @tparam MathT a type of the codebook elements and internal math ops. + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] res + * @param[in] params VQ and PQ parameters for compressing the data + * @param[in] dataset a row-major mdspan or mdarray (device or host) [n_rows, dim]. + */ +template +auto vpq_build(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) + -> vpq_dataset +{ + if constexpr (std::is_same_v) { + return detail::vpq_convert_math_type( + res, detail::vpq_build(res, params, dataset)); + } else { + return detail::vpq_build(res, params, dataset); + } +} + +} // namespace raft::neighbors From d14cac2a9f36d24559c8a0a61806cef91913048b Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Mon, 18 Mar 2024 18:58:46 +0100 Subject: [PATCH 06/10] Add explicit instantiations for IVF-PQ search kernels used in tests (#2212) Compilation of IVF-PQ search kernels can be time consuming. In `libraft.so` the compilation is done in parallel for kernels without filtering and with `int64_t` index type. We have test with `uint32_t` index type as well as tests for `bitset_filter` with both 32 and 64 bit index types. This PR adds explicit template instantiations for the test. This way we avoid repeated compilation of the kernels with filter and this also enables parallel compilation of the `compute_similarity` kernel for different template types. The kernels with these additional type parameters are not added to `libraft.so`, only linked together with the test executable. Note that this PR does not increase the number of compiled kernels, but it enables to compile them in parallel. Authors: - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Artem M. Chirkin (https://github.com/achirkin) - Ben Frederickson (https://github.com/benfred) URL: https://github.com/rapidsai/raft/pull/2212 --- cpp/bench/prims/CMakeLists.txt | 8 + .../knn/ivf_pq_filter_float_int64_t.cu | 5 +- cpp/include/raft/core/detail/nvtx.hpp | 1 + .../raft/neighbors/detail/ivf_pq_build.cuh | 46 +++-- .../ivf_pq_compute_similarity_template.cuh | 71 +++++++ ...pq_compute_similarity_filters_test-ext.cuh | 181 ++++++++++++++++++ .../neighbors/ivf_pq_search_test-ext.cuh | 88 +++++++++ .../ivf_pq_compute_similarity_00_generate.py | 107 ++++------- .../ivf_pq_compute_similarity_float_float.cu | 57 +----- ...compute_similarity_float_float_bitset32.cu | 28 +++ ...compute_similarity_float_float_bitset64.cu | 28 +++ ...q_compute_similarity_float_float_filt32.cu | 28 +++ ...f_pq_compute_similarity_float_fp8_false.cu | 57 +----- ...ute_similarity_float_fp8_false_bitset32.cu | 28 +++ ...ute_similarity_float_fp8_false_bitset64.cu | 28 +++ ...mpute_similarity_float_fp8_false_filt32.cu | 28 +++ ...vf_pq_compute_similarity_float_fp8_true.cu | 57 +----- ...pute_similarity_float_fp8_true_bitset32.cu | 28 +++ ...pute_similarity_float_fp8_true_bitset64.cu | 28 +++ ...ompute_similarity_float_fp8_true_filt32.cu | 28 +++ .../ivf_pq_compute_similarity_float_half.cu | 57 +----- ..._compute_similarity_float_half_bitset32.cu | 28 +++ ..._compute_similarity_float_half_bitset64.cu | 28 +++ ...pq_compute_similarity_float_half_filt32.cu | 28 +++ ...vf_pq_compute_similarity_half_fp8_false.cu | 57 +----- ...pute_similarity_half_fp8_false_bitset32.cu | 28 +++ ...pute_similarity_half_fp8_false_bitset64.cu | 28 +++ ...ompute_similarity_half_fp8_false_filt32.cu | 28 +++ ...ivf_pq_compute_similarity_half_fp8_true.cu | 57 +----- ...mpute_similarity_half_fp8_true_bitset32.cu | 28 +++ ...mpute_similarity_half_fp8_true_bitset64.cu | 28 +++ ...compute_similarity_half_fp8_true_filt32.cu | 28 +++ .../ivf_pq_compute_similarity_half_half.cu | 57 +----- ...q_compute_similarity_half_half_bitset32.cu | 28 +++ ...q_compute_similarity_half_half_bitset64.cu | 28 +++ ..._pq_compute_similarity_half_half_filt32.cu | 28 +++ .../ivf_pq_search_filtering_float_int64_t.cu | 43 +++++ cpp/test/CMakeLists.txt | 24 +++ .../ann_ivf_pq/ivf_pq_build_float_uint32_t.cu | 37 ++++ .../ann_ivf_pq/ivf_pq_build_test-ext.cuh | 38 ++++ .../ivf_pq_search_float_uint32_t.cu | 68 +++++++ .../ann_ivf_pq/test_filter_float_int64_t.cu | 6 +- .../ann_ivf_pq/test_filter_int8_t_int64_t.cu | 6 +- .../ann_ivf_pq/test_float_uint32_t.cu | 12 +- .../ann_ivf_pq/test_int8_t_int64_t.cu | 3 +- 45 files changed, 1245 insertions(+), 486 deletions(-) create mode 100644 cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity_template.cuh create mode 100644 cpp/internal/raft_internal/neighbors/ivf_pq_compute_similarity_filters_test-ext.cuh create mode 100644 cpp/internal/raft_internal/neighbors/ivf_pq_search_test-ext.cuh create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float_bitset32.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float_bitset64.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float_filt32.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_bitset32.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_bitset64.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_filt32.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_bitset32.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_bitset64.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_filt32.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half_bitset32.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half_bitset64.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half_filt32.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_bitset32.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_bitset64.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_filt32.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_bitset32.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_bitset64.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_filt32.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half_bitset32.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half_bitset64.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half_filt32.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_search_filtering_float_int64_t.cu create mode 100644 cpp/test/neighbors/ann_ivf_pq/ivf_pq_build_float_uint32_t.cu create mode 100644 cpp/test/neighbors/ann_ivf_pq/ivf_pq_build_test-ext.cuh create mode 100644 cpp/test/neighbors/ann_ivf_pq/ivf_pq_search_float_uint32_t.cu diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 3a2431cd34..5577881ef7 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -156,6 +156,14 @@ if(BUILD_PRIMS_BENCH) bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu bench/prims/neighbors/knn/ivf_pq_int8_t_int64_t.cu bench/prims/neighbors/knn/ivf_pq_uint8_t_int64_t.cu + src/neighbors/detail/ivf_pq_search_filtering_float_int64_t.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_float_bitset64.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_bitset64.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_bitset64.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_half_bitset64.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_bitset64.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_bitset64.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_half_bitset64.cu bench/prims/neighbors/refine_float_int64_t.cu bench/prims/neighbors/refine_uint8_t_int64_t.cu bench/prims/main.cpp diff --git a/cpp/bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu b/cpp/bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu index 9534515cbb..1840eca99d 100644 --- a/cpp/bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu +++ b/cpp/bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,9 +14,10 @@ * limitations under the License. */ -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter #include "../knn.cuh" +#include +#include namespace raft::bench::spatial { KNN_REGISTER(float, int64_t, ivf_pq_filter_knn, kInputsFilter, kNoCopyOnly, kScopeFull); diff --git a/cpp/include/raft/core/detail/nvtx.hpp b/cpp/include/raft/core/detail/nvtx.hpp index 8afd1f16c6..82db75de84 100644 --- a/cpp/include/raft/core/detail/nvtx.hpp +++ b/cpp/include/raft/core/detail/nvtx.hpp @@ -28,6 +28,7 @@ #include #include #include +#include namespace raft::common::nvtx::detail { diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh index b7796d52fa..8e3f7dbaf3 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh @@ -61,6 +61,8 @@ namespace raft::neighbors::ivf_pq::detail { using namespace raft::spatial::knn::detail; // NOLINT +using internal_extents_t = int64_t; // The default mdspan extent type used internally. + template __launch_bounds__(BlockDim) RAFT_KERNEL copy_warped_kernel( T* out, uint32_t ld_out, const S* in, uint32_t ld_in, uint32_t n_cols, size_t n_rows) @@ -442,15 +444,16 @@ void train_per_subset(raft::resources const& handle, stream); // train PQ codebook for this subspace - auto sub_trainset_view = - raft::make_device_matrix_view(sub_trainset.data(), n_rows, index.pq_len()); - auto centers_tmp_view = raft::make_device_matrix_view( + auto sub_trainset_view = raft::make_device_matrix_view( + sub_trainset.data(), n_rows, index.pq_len()); + auto centers_tmp_view = raft::make_device_matrix_view( pq_centers_tmp.data() + index.pq_book_size() * index.pq_len() * j, index.pq_book_size(), index.pq_len()); - auto sub_labels_view = raft::make_device_vector_view(sub_labels.data(), n_rows); - auto cluster_sizes_view = - raft::make_device_vector_view(pq_cluster_sizes.data(), index.pq_book_size()); + auto sub_labels_view = + raft::make_device_vector_view(sub_labels.data(), n_rows); + auto cluster_sizes_view = raft::make_device_vector_view( + pq_cluster_sizes.data(), index.pq_book_size()); raft::cluster::kmeans_balanced_params kmeans_params; kmeans_params.n_iters = kmeans_n_iters; kmeans_params.metric = raft::distance::DistanceType::L2Expanded; @@ -525,17 +528,17 @@ void train_per_cluster(raft::resources const& handle, size_t available_rows = size_t(cluster_size) * size_t(index.pq_dim()); auto pq_n_rows = uint32_t(std::min(big_enough, available_rows)); // train PQ codebook for this cluster - auto rot_vectors_view = raft::make_device_matrix_view( + auto rot_vectors_view = raft::make_device_matrix_view( rot_vectors.data(), pq_n_rows, index.pq_len()); - auto centers_tmp_view = raft::make_device_matrix_view( + auto centers_tmp_view = raft::make_device_matrix_view( pq_centers_tmp.data() + static_cast(index.pq_book_size()) * static_cast(index.pq_len()) * static_cast(l), index.pq_book_size(), index.pq_len()); auto pq_labels_view = - raft::make_device_vector_view(pq_labels.data(), pq_n_rows); - auto pq_cluster_sizes_view = - raft::make_device_vector_view(pq_cluster_sizes.data(), index.pq_book_size()); + raft::make_device_vector_view(pq_labels.data(), pq_n_rows); + auto pq_cluster_sizes_view = raft::make_device_vector_view( + pq_cluster_sizes.data(), index.pq_book_size()); raft::cluster::kmeans_balanced_params kmeans_params; kmeans_params.n_iters = kmeans_n_iters; kmeans_params.metric = raft::distance::DistanceType::L2Expanded; @@ -1587,11 +1590,11 @@ void extend(raft::resources const& handle, cudaMemcpyDefault, stream)); for (const auto& batch : vec_batches) { - auto batch_data_view = - raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); - auto batch_labels_view = raft::make_device_vector_view( + auto batch_data_view = raft::make_device_matrix_view( + batch.data(), batch.size(), index->dim()); + auto batch_labels_view = raft::make_device_vector_view( new_data_labels.data() + batch.offset(), batch.size()); - auto centers_view = raft::make_device_matrix_view( + auto centers_view = raft::make_device_matrix_view( cluster_centers.data(), n_clusters, index->dim()); raft::cluster::kmeans_balanced_params kmeans_params; kmeans_params.metric = index->metric(); @@ -1767,10 +1770,10 @@ auto build(raft::resources const& handle, auto cluster_centers = cluster_centers_buf.data(); // Train balanced hierarchical kmeans clustering - auto trainset_const_view = - raft::make_device_matrix_view(trainset.data(), n_rows_train, index.dim()); - auto centers_view = - raft::make_device_matrix_view(cluster_centers, index.n_lists(), index.dim()); + auto trainset_const_view = raft::make_device_matrix_view( + trainset.data(), n_rows_train, index.dim()); + auto centers_view = raft::make_device_matrix_view( + cluster_centers, index.n_lists(), index.dim()); raft::cluster::kmeans_balanced_params kmeans_params; kmeans_params.n_iters = params.kmeans_n_iters; kmeans_params.metric = index.metric(); @@ -1779,9 +1782,10 @@ auto build(raft::resources const& handle, // Trainset labels are needed for training PQ codebooks rmm::device_uvector labels(n_rows_train, stream, device_memory); - auto centers_const_view = raft::make_device_matrix_view( + auto centers_const_view = raft::make_device_matrix_view( cluster_centers, index.n_lists(), index.dim()); - auto labels_view = raft::make_device_vector_view(labels.data(), n_rows_train); + auto labels_view = + raft::make_device_vector_view(labels.data(), n_rows_train); raft::cluster::kmeans_balanced::predict(handle, kmeans_params, trainset_const_view, diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity_template.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity_template.cuh new file mode 100644 index 0000000000..83dd994bd6 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity_template.cuh @@ -0,0 +1,71 @@ + +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is to be used in source files generated by + * src/neighbors/detailivf_pq_compute_similarity_00_generate.py + */ + +#pragma once + +#include +#include +#include + +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, IvfSampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + IvfSampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ + uint32_t* _out_indices); + +#define COMMA , diff --git a/cpp/internal/raft_internal/neighbors/ivf_pq_compute_similarity_filters_test-ext.cuh b/cpp/internal/raft_internal/neighbors/ivf_pq_compute_similarity_filters_test-ext.cuh new file mode 100644 index 0000000000..aa14ab19b8 --- /dev/null +++ b/cpp/internal/raft_internal/neighbors/ivf_pq_compute_similarity_filters_test-ext.cuh @@ -0,0 +1,181 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // RAFT_WEAK_FUNCTION +#include // raft::distance::DistanceType +#include +#include // raft::neighbors::ivf_pq::detail::fp_8bit +#include // none_ivf_sample_filter +#include // none_ivf_sample_filter + +#include // rmm::cuda_stream_view + +#include // __half + +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, IvfSampleFilterT) \ + extern template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + extern template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + IvfSampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ + uint32_t* _out_indices); + +#define COMMA , +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + half, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + half, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + float, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); + +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::bitset_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::bitset_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + half, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::bitset_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + half, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::bitset_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + float, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::bitset_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::bitset_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::bitset_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::bitset_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::bitset_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + half, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::bitset_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + half, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::bitset_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + float, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::bitset_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::bitset_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::bitset_filter>); +#undef COMMA + +#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/internal/raft_internal/neighbors/ivf_pq_search_test-ext.cuh b/cpp/internal/raft_internal/neighbors/ivf_pq_search_test-ext.cuh new file mode 100644 index 0000000000..7a65e2d2f8 --- /dev/null +++ b/cpp/internal/raft_internal/neighbors/ivf_pq_search_test-ext.cuh @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // raft::device_matrix_view +#include // raft::resources +#include +#include // raft::neighbors::ivf_pq::index +#include +#include + +#include + +#include // int64_t + +#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ + extern template void raft::neighbors::ivf_pq::search( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + \ + extern template void raft::neighbors::ivf_pq::search( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr); \ + \ + extern template void raft::neighbors::ivf_pq::search( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances) + +instantiate_raft_neighbors_ivf_pq_search(float, uint32_t); + +#undef instantiate_raft_neighbors_ivf_pq_search + +#define instantiate_raft_neighbors_ivf_pq_search_with_filtering(T, IdxT, FilterT) \ + extern template void raft::neighbors::ivf_pq::search_with_filtering( \ + raft::resources const& handle, \ + const search_params& params, \ + const index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + FilterT sample_filter) + +#define COMMA , +instantiate_raft_neighbors_ivf_pq_search_with_filtering( + float, uint32_t, raft::neighbors::filtering::bitset_filter); + +instantiate_raft_neighbors_ivf_pq_search_with_filtering( + float, uint32_t, raft::neighbors::filtering::none_ivf_sample_filter); + +instantiate_raft_neighbors_ivf_pq_search_with_filtering( + float, int64_t, raft::neighbors::filtering::bitset_filter); + +instantiate_raft_neighbors_ivf_pq_search_with_filtering( + int8_t, int64_t, raft::neighbors::filtering::bitset_filter); + +#undef COMMA +#undef instantiate_raft_neighbors_ivf_pq_search_with_filtering diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py index 670ed57ed1..9825a48f81 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -header = """ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. +header = """/* + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,78 +30,56 @@ /* * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py - * * Make changes there and run in this directory: - * * > python ivf_pq_compute_similarity_00_generate.py - * */ - -#include -#include - -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT, IvfSampleFilterT) \\ - template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \\ - const cudaDeviceProp& dev_props, \\ - bool manage_local_topk, \\ - int locality_hint, \\ - double preferred_shmem_carveout, \\ - uint32_t pq_bits, \\ - uint32_t pq_dim, \\ - uint32_t precomp_data_count, \\ - uint32_t n_queries, \\ - uint32_t n_probes, \\ - uint32_t topk) -> raft::neighbors::ivf_pq::detail::selected; \\ -\\ - template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \\ - raft::neighbors::ivf_pq::detail::selected s, \\ - rmm::cuda_stream_view stream, \\ - uint32_t dim, \\ - uint32_t n_probes, \\ - uint32_t pq_dim, \\ - uint32_t n_queries, \\ - uint32_t queries_offset, \\ - raft::distance::DistanceType metric, \\ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \\ - uint32_t topk, \\ - uint32_t max_samples, \\ - const float* cluster_centers, \\ - const float* pq_centers, \\ - const uint8_t* const* pq_dataset, \\ - const uint32_t* cluster_labels, \\ - const uint32_t* _chunk_indices, \\ - const float* queries, \\ - const uint32_t* index_list, \\ - float* query_kths, \\ - IvfSampleFilterT sample_filter, \\ - LutT* lut_scores, \\ - OutT* _out_scores, \\ - uint32_t* _out_indices); - - -#define COMMA , + +#include """ -trailer = """ -#undef COMMA - -#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select -""" +none_filter_int64 = "raft::neighbors::filtering::ivf_to_sample_filter" \ + "" +none_filter_int32 = "raft::neighbors::filtering::ivf_to_sample_filter" \ + "" +bitset_filter32 = "raft::neighbors::filtering::ivf_to_sample_filter" \ + ">" +bitset_filter64 = "raft::neighbors::filtering::ivf_to_sample_filter" \ + ">" types = dict( - half_fp8_false=("half", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>"), - half_fp8_true=("half", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>"), - half_half=("half", "half"), - float_half=("float", "half"), - float_float= ("float", "float"), - float_fp8_false=("float", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>"), - float_fp8_true=("float", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>"), + half_fp8_false=("half", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>", none_filter_int64), + half_fp8_true=("half", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>", none_filter_int64), + half_half=("half", "half", none_filter_int64), + float_half=("float", "half", none_filter_int64), + float_float= ("float", "float", none_filter_int64), + float_fp8_false=("float", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>", none_filter_int64), + float_fp8_true=("float", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>", none_filter_int64), + half_fp8_false_filt32=("half", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>", none_filter_int32), + half_fp8_true_filt32=("half", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>", none_filter_int32), + half_half_filt32=("half", "half", none_filter_int32), + float_half_filt32=("float", "half", none_filter_int32), + float_float_filt32= ("float", "float", none_filter_int32), + float_fp8_false_filt32=("float", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>", none_filter_int32), + float_fp8_true_filt32=("float", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>", none_filter_int32), + half_fp8_false_bitset32=("half", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>", bitset_filter32), + half_fp8_true_bitset32=("half", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>", bitset_filter32), + half_half_bitset32=("half", "half", bitset_filter32), + float_half_bitset32=("float", "half", bitset_filter32), + float_float_bitset32= ("float", "float", bitset_filter32), + float_fp8_false_bitset32=("float", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>", bitset_filter32), + float_fp8_true_bitset32=("float", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>", bitset_filter32), + half_fp8_false_bitset64=("half", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>", bitset_filter64), + half_fp8_true_bitset64=("half", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>", bitset_filter64), + half_half_bitset64=("half", "half", bitset_filter64), + float_half_bitset64=("float", "half", bitset_filter64), + float_float_bitset64= ("float", "float", bitset_filter64), + float_fp8_false_bitset64=("float", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>", bitset_filter64), + float_fp8_true_bitset64=("float", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>", bitset_filter64) ) -for path_key, (OutT, LutT) in types.items(): +for path_key, (OutT, LutT, FilterT) in types.items(): path = f"ivf_pq_compute_similarity_{path_key}.cu" with open(path, "w") as f: f.write(header) - f.write(f"instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select({OutT}, {LutT}, raft::neighbors::filtering::ivf_to_sample_filter);\n") - f.write(trailer) + f.write(f"instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select({OutT}, {LutT}, {FilterT});\n") print(f"src/neighbors/detail/{path}") diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu index 7e17d6822a..db51608ae1 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,65 +16,13 @@ /* * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py - * * Make changes there and run in this directory: - * * > python ivf_pq_compute_similarity_00_generate.py - * */ -#include -#include - -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, IvfSampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - IvfSampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ - uint32_t* _out_indices); - -#define COMMA , +#include instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, float, raft::neighbors::filtering::ivf_to_sample_filter< int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); - -#undef COMMA - -#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float_bitset32.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float_bitset32.cu new file mode 100644 index 0000000000..caaf40abdf --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float_bitset32.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + float, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::bitset_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float_bitset64.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float_bitset64.cu new file mode 100644 index 0000000000..7801c25e9f --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float_bitset64.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + float, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::bitset_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float_filt32.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float_filt32.cu new file mode 100644 index 0000000000..45ae348849 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float_filt32.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + float, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu index c1b72dab33..2f5bcf8f92 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,65 +16,13 @@ /* * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py - * * Make changes there and run in this directory: - * * > python ivf_pq_compute_similarity_00_generate.py - * */ -#include -#include - -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, IvfSampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - IvfSampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ - uint32_t* _out_indices); - -#define COMMA , +#include instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, raft::neighbors::filtering::ivf_to_sample_filter< int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); - -#undef COMMA - -#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_bitset32.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_bitset32.cu new file mode 100644 index 0000000000..e7f2c44254 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_bitset32.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::bitset_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_bitset64.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_bitset64.cu new file mode 100644 index 0000000000..01b6900bb8 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_bitset64.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::bitset_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_filt32.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_filt32.cu new file mode 100644 index 0000000000..9f8d453364 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_filt32.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu index fdff0860fc..06d21bcd50 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,65 +16,13 @@ /* * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py - * * Make changes there and run in this directory: - * * > python ivf_pq_compute_similarity_00_generate.py - * */ -#include -#include - -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, IvfSampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - IvfSampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ - uint32_t* _out_indices); - -#define COMMA , +#include instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, raft::neighbors::filtering::ivf_to_sample_filter< int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); - -#undef COMMA - -#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_bitset32.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_bitset32.cu new file mode 100644 index 0000000000..8b733a23c1 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_bitset32.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::bitset_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_bitset64.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_bitset64.cu new file mode 100644 index 0000000000..77e4f9a023 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_bitset64.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::bitset_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_filt32.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_filt32.cu new file mode 100644 index 0000000000..3e036e3df4 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_filt32.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu index 7205544370..ff42f5e041 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,65 +16,13 @@ /* * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py - * * Make changes there and run in this directory: - * * > python ivf_pq_compute_similarity_00_generate.py - * */ -#include -#include - -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, IvfSampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - IvfSampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ - uint32_t* _out_indices); - -#define COMMA , +#include instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, half, raft::neighbors::filtering::ivf_to_sample_filter< int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); - -#undef COMMA - -#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half_bitset32.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half_bitset32.cu new file mode 100644 index 0000000000..40b6313865 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half_bitset32.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + half, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::bitset_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half_bitset64.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half_bitset64.cu new file mode 100644 index 0000000000..9cedabdb11 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half_bitset64.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + half, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::bitset_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half_filt32.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half_filt32.cu new file mode 100644 index 0000000000..61422bbc36 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half_filt32.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + half, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu index 2ac6c3527b..d2064cfe97 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,65 +16,13 @@ /* * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py - * * Make changes there and run in this directory: - * * > python ivf_pq_compute_similarity_00_generate.py - * */ -#include -#include - -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, IvfSampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - IvfSampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ - uint32_t* _out_indices); - -#define COMMA , +#include instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, raft::neighbors::filtering::ivf_to_sample_filter< int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); - -#undef COMMA - -#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_bitset32.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_bitset32.cu new file mode 100644 index 0000000000..1127f39f71 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_bitset32.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::bitset_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_bitset64.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_bitset64.cu new file mode 100644 index 0000000000..0330bf58d6 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_bitset64.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::bitset_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_filt32.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_filt32.cu new file mode 100644 index 0000000000..d20f7921d5 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_filt32.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu index 70f3ffdb0c..9dc954406e 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,65 +16,13 @@ /* * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py - * * Make changes there and run in this directory: - * * > python ivf_pq_compute_similarity_00_generate.py - * */ -#include -#include - -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, IvfSampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - IvfSampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ - uint32_t* _out_indices); - -#define COMMA , +#include instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, raft::neighbors::filtering::ivf_to_sample_filter< int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); - -#undef COMMA - -#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_bitset32.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_bitset32.cu new file mode 100644 index 0000000000..9131fa25a8 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_bitset32.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::bitset_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_bitset64.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_bitset64.cu new file mode 100644 index 0000000000..8b4521b31b --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_bitset64.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::bitset_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_filt32.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_filt32.cu new file mode 100644 index 0000000000..71b63cf4a0 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_filt32.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu index 5cc1cb8038..f527d879be 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,65 +16,13 @@ /* * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py - * * Make changes there and run in this directory: - * * > python ivf_pq_compute_similarity_00_generate.py - * */ -#include -#include - -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, IvfSampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - IvfSampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ - uint32_t* _out_indices); - -#define COMMA , +#include instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, half, raft::neighbors::filtering::ivf_to_sample_filter< int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); - -#undef COMMA - -#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half_bitset32.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half_bitset32.cu new file mode 100644 index 0000000000..8e1962e2bb --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half_bitset32.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + half, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::bitset_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half_bitset64.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half_bitset64.cu new file mode 100644 index 0000000000..e9671703e7 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half_bitset64.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + half, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::bitset_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half_filt32.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half_filt32.cu new file mode 100644 index 0000000000..b66a07d1a9 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half_filt32.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + half, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_search_filtering_float_int64_t.cu b/cpp/src/neighbors/detail/ivf_pq_search_filtering_float_int64_t.cu new file mode 100644 index 0000000000..39af78f12e --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_search_filtering_float_int64_t.cu @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // raft::device_matrix_view +#include // raft::resources +#include +#include // raft::neighbors::ivf_pq::index +#include +#include + +#include + +#include // int64_t + +#define instantiate_raft_neighbors_ivf_pq_search_with_filtering(T, IdxT, FilterT) \ + template void raft::neighbors::ivf_pq::search_with_filtering( \ + raft::resources const& handle, \ + const search_params& params, \ + const index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + FilterT sample_filter) + +#define COMMA , +instantiate_raft_neighbors_ivf_pq_search_with_filtering( + float, int64_t, raft::neighbors::filtering::bitset_filter); + +#undef COMMA +#undef instantiate_raft_neighbors_ivf_pq_search_with_filtering diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index dd7eb839ab..65d6e738a2 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -401,6 +401,30 @@ if(BUILD_TESTS) test/neighbors/ann_ivf_flat/test_float_int64_t.cu test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu + test/neighbors/ann_ivf_pq/ivf_pq_build_float_uint32_t.cu + test/neighbors/ann_ivf_pq/ivf_pq_search_float_uint32_t.cu + src/neighbors/detail/ivf_pq_search_filtering_float_int64_t.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_float_filt32.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_filt32.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_filt32.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_half_filt32.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_filt32.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_filt32.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_half_filt32.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_float_bitset32.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_bitset32.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_bitset32.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_half_bitset32.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_bitset32.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_bitset32.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_half_bitset32.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_float_bitset64.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_bitset64.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_bitset64.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_half_bitset64.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_bitset64.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_bitset64.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_half_bitset64.cu test/neighbors/ann_ivf_pq/test_float_uint32_t.cu test/neighbors/ann_ivf_pq/test_float_int64_t.cu test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu diff --git a/cpp/test/neighbors/ann_ivf_pq/ivf_pq_build_float_uint32_t.cu b/cpp/test/neighbors/ann_ivf_pq/ivf_pq_build_float_uint32_t.cu new file mode 100644 index 0000000000..5ba21c3c2f --- /dev/null +++ b/cpp/test/neighbors/ann_ivf_pq/ivf_pq_build_float_uint32_t.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include // raft::neighbors::ivf_pq::index +#include + +#define instantiate_raft_neighbors_ivf_pq_build(T, IdxT) \ + template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::build( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset); \ + \ + template auto raft::neighbors::ivf_pq::build( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_pq::index; + +instantiate_raft_neighbors_ivf_pq_build(float, uint32_t); + +#undef instantiate_raft_neighbors_ivf_pq_build diff --git a/cpp/test/neighbors/ann_ivf_pq/ivf_pq_build_test-ext.cuh b/cpp/test/neighbors/ann_ivf_pq/ivf_pq_build_test-ext.cuh new file mode 100644 index 0000000000..cd5435ab2e --- /dev/null +++ b/cpp/test/neighbors/ann_ivf_pq/ivf_pq_build_test-ext.cuh @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include // raft::neighbors::ivf_pq::index +#include + +#define instantiate_raft_neighbors_ivf_pq_build(T, IdxT) \ + extern template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::build( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset); \ + \ + extern template auto raft::neighbors::ivf_pq::build( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_pq::index; + +instantiate_raft_neighbors_ivf_pq_build(float, uint32_t); + +#undef instantiate_raft_neighbors_ivf_pq_build diff --git a/cpp/test/neighbors/ann_ivf_pq/ivf_pq_search_float_uint32_t.cu b/cpp/test/neighbors/ann_ivf_pq/ivf_pq_search_float_uint32_t.cu new file mode 100644 index 0000000000..942d0fcc44 --- /dev/null +++ b/cpp/test/neighbors/ann_ivf_pq/ivf_pq_search_float_uint32_t.cu @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include // raft::neighbors::ivf_pq::index +#include + +#include + +#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ + template void raft::neighbors::ivf_pq::search( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + \ + template void raft::neighbors::ivf_pq::search( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_neighbors_ivf_pq_search(float, uint32_t); + +#undef instantiate_raft_neighbors_ivf_pq_search + +#define instantiate_raft_neighbors_ivf_pq_search_with_filtering(T, IdxT, FilterT) \ + template void raft::neighbors::ivf_pq::search_with_filtering( \ + raft::resources const& handle, \ + const search_params& params, \ + const index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + FilterT sample_filter) + +#define COMMA , +instantiate_raft_neighbors_ivf_pq_search_with_filtering( + float, uint32_t, raft::neighbors::filtering::bitset_filter); + +instantiate_raft_neighbors_ivf_pq_search_with_filtering( + int8_t, int64_t, raft::neighbors::filtering::bitset_filter); + +instantiate_raft_neighbors_ivf_pq_search_with_filtering( + float, uint32_t, raft::neighbors::filtering::none_ivf_sample_filter); + +#undef COMMA +#undef instantiate_raft_neighbors_ivf_pq_search_with_filtering diff --git a/cpp/test/neighbors/ann_ivf_pq/test_filter_float_int64_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_filter_float_int64_t.cu index 17f72fb08a..70d5d8761f 100644 --- a/cpp/test/neighbors/ann_ivf_pq/test_filter_float_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_pq/test_filter_float_int64_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,9 +14,11 @@ * limitations under the License. */ -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter #include "../ann_ivf_pq.cuh" +#include +#include + namespace raft::neighbors::ivf_pq { using f32_f32_i64_filter = ivf_pq_filter_test; diff --git a/cpp/test/neighbors/ann_ivf_pq/test_filter_int8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_filter_int8_t_int64_t.cu index 537dbb4979..ba96a8db0b 100644 --- a/cpp/test/neighbors/ann_ivf_pq/test_filter_int8_t_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_pq/test_filter_int8_t_int64_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,9 +14,11 @@ * limitations under the License. */ -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter #include "../ann_ivf_pq.cuh" +#include +#include + namespace raft::neighbors::ivf_pq { using f32_i08_i64_filter = ivf_pq_filter_test; diff --git a/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu index a6cfab1f19..b8ada2249a 100644 --- a/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu +++ b/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu @@ -14,15 +14,11 @@ * limitations under the License. */ -// XXX: the uint32_t instance is not compiled in libraft.so. So we allow -// instantiating the template here. -// -// TODO: consider removing this test or consider adding an instantiation to the -// library. - -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY - #include "../ann_ivf_pq.cuh" +#include "ivf_pq_build_test-ext.cuh" + +#include +#include namespace raft::neighbors::ivf_pq { diff --git a/cpp/test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu index 014e96a2db..970bdd6a12 100644 --- a/cpp/test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ #include "../ann_ivf_pq.cuh" +#include namespace raft::neighbors::ivf_pq { using f32_i08_i64 = ivf_pq_test; From bd50c37f52eeff459ca032bd0ef5f7920c2bcc8d Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Tue, 19 Mar 2024 10:51:23 +0100 Subject: [PATCH 07/10] Fix ANN bench ground truth generation for k>1024 (#2180) Generating ANN bench ground truth is affected by bug #2171, when k>1024. This PR fixes the issue for the ground truth generation. Authors: - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2180 --- .../raft/neighbors/detail/knn_merge_parts.cuh | 3 +++ .../raft-ann-bench/generate_groundtruth/__main__.py | 13 ++++--------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh b/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh index 0395e86e43..33324714fd 100644 --- a/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh +++ b/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include #include @@ -168,5 +169,7 @@ inline void knn_merge_parts(const value_t* inK, else if (k <= 1024) knn_merge_parts_impl( inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else + THROW("Unimplemented for k=%d, knn_merge_parts works for k<=1024", k); } } // namespace raft::neighbors::detail diff --git a/python/raft-ann-bench/src/raft-ann-bench/generate_groundtruth/__main__.py b/python/raft-ann-bench/src/raft-ann-bench/generate_groundtruth/__main__.py index f4d97edea5..a5ebb76635 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/generate_groundtruth/__main__.py +++ b/python/raft-ann-bench/src/raft-ann-bench/generate_groundtruth/__main__.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -62,17 +62,12 @@ def calc_truth(dataset, queries, k, metric="sqeuclidean"): X = cp.asarray(dataset[i : i + n_batch, :], cp.float32) - D, Ind = knn( - X, - queries, - k, - metric=metric, - handle=handle, - global_id_offset=i, # shift neighbor index by offset i - ) + D, Ind = knn(X, queries, k, metric=metric, handle=handle) handle.sync() D, Ind = cp.asarray(D), cp.asarray(Ind) + Ind += i # shift neighbor index by offset i + if distances is None: distances = D indices = Ind From e53aa0c18d77f263ed6aa058dff87d3a6dd43590 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Akif=20=C3=87=C3=96RD=C3=9CK?= Date: Tue, 19 Mar 2024 15:26:15 +0100 Subject: [PATCH 08/10] Fix bug in blockRankedReduce (#2226) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit There was a bug appearing for negative floating point numbers with a max reduce operation. The `std::numeric_limits::min()` is greater than the negative floating point values whereas we want it to be smaller than all representable values. This PR replaces the `min` with the `lowest`. Authors: - Akif ÇÖRDÜK (https://github.com/akifcorduk) - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Bradley Dice (https://github.com/bdice) - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/raft/pull/2226 --- cpp/include/raft/util/reduction.cuh | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/cpp/include/raft/util/reduction.cuh b/cpp/include/raft/util/reduction.cuh index 7e897e67f5..2c2b1aa228 100644 --- a/cpp/include/raft/util/reduction.cuh +++ b/cpp/include/raft/util/reduction.cuh @@ -157,11 +157,10 @@ DI std::pair blockRankedReduce(T val, val = values[lane]; idx = indices[lane]; } else { - // get the min if it is a max op, get the max if it is a min op - val = reduce_op(std::numeric_limits::min(), std::numeric_limits::max()) == - std::numeric_limits::min() - ? std::numeric_limits::max() - : std::numeric_limits::min(); + // get the lower_bound of the type if it is a max op, + // get the upper bound of the type if it is a min op + val = reduce_op(lower_bound(), upper_bound()) == lower_bound() ? upper_bound() + : lower_bound(); idx = -1; } __syncthreads(); From 413e34e786b1c6b50a9dc6d76ff57d1f8a31555d Mon Sep 17 00:00:00 2001 From: Mahesh Doijade <36705640+mdoijade@users.noreply.github.com> Date: Tue, 19 Mar 2024 22:33:54 +0530 Subject: [PATCH 09/10] Add fused cosine 1-NN cutlass based kernel (#2125) - Adds cosine 1-NN cutlass based kernel for SM 8.0 or higher using tensor cores. - based on 3x TF32 - unifies the fusedDistanceNN kernels for L2/cosine. - expose this API in pylibraft as `fused_distance_nn_arg_min` supporting cosine & L2 distance metrics. Authors: - Mahesh Doijade (https://github.com/mdoijade) - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Tamas Bela Feher (https://github.com/tfeher) - Ben Frederickson (https://github.com/benfred) URL: https://github.com/rapidsai/raft/pull/2125 --- cpp/CMakeLists.txt | 4 +- cpp/bench/prims/CMakeLists.txt | 12 +- .../distance/detail/fused_distance_nn.cuh | 97 ++++ .../custom_epilogue_with_broadcast.h | 1 + .../detail/fused_distance_nn/cutlass_base.cuh | 18 +- .../epilogue_elementwise.cuh | 9 +- .../fused_distance_nn/fused_cosine_nn.cuh | 136 ++++++ .../detail/fused_distance_nn/fused_l2_nn.cuh | 146 ++++++ .../fused_distance_nn/helper_structs.cuh | 146 ++++++ .../fused_distance_nn/persistent_gemm.h | 4 +- .../predicated_tile_iterator_reduced_vec.h | 99 ++--- .../detail/fused_distance_nn/simt_kernel.cuh | 187 ++++++++ .../raft/distance/detail/masked_nn.cuh | 2 +- .../raft/distance/fused_distance_nn-ext.cuh | 84 ++++ .../raft/distance/fused_distance_nn-inl.cuh | 328 ++++++++++++++ .../raft/distance/fused_distance_nn.cuh | 24 + ...pers.cuh => fused_distance_nn_helpers.cuh} | 4 +- cpp/include/raft/distance/fused_l2_nn-ext.cuh | 8 +- cpp/include/raft/distance/fused_l2_nn-inl.cuh | 4 +- .../distance/fused_distance_nn.hpp | 62 +++ .../raft_runtime/distance/fused_l2_nn.hpp | 36 +- cpp/src/distance/fused_distance_nn.cu | 53 +++ .../distance/fused_distance_min_arg.cu | 56 +++ .../distance/fused_distance_min_arg.hpp | 144 ++++++ .../raft_runtime/distance/fused_l2_min_arg.cu | 87 +--- cpp/test/CMakeLists.txt | 1 + cpp/test/distance/fused_cosine_nn.cu | 420 ++++++++++++++++++ cpp/test/distance/fused_l2_nn.cu | 1 - .../pylibraft/distance/CMakeLists.txt | 4 +- .../pylibraft/pylibraft/distance/__init__.py | 9 +- .../pylibraft/distance/fused_distance_nn.pyx | 200 +++++++++ .../test/test_fused_distance_argmin.py | 69 +++ 32 files changed, 2281 insertions(+), 174 deletions(-) create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn.cuh create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn/fused_l2_nn.cuh create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn/helper_structs.cuh create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn/simt_kernel.cuh create mode 100644 cpp/include/raft/distance/fused_distance_nn-ext.cuh create mode 100644 cpp/include/raft/distance/fused_distance_nn-inl.cuh create mode 100755 cpp/include/raft/distance/fused_distance_nn.cuh rename cpp/include/raft/distance/{fused_l2_nn_helpers.cuh => fused_distance_nn_helpers.cuh} (92%) create mode 100644 cpp/include/raft_runtime/distance/fused_distance_nn.hpp create mode 100644 cpp/src/distance/fused_distance_nn.cu create mode 100644 cpp/src/raft_runtime/distance/fused_distance_min_arg.cu create mode 100644 cpp/src/raft_runtime/distance/fused_distance_min_arg.hpp create mode 100644 cpp/test/distance/fused_cosine_nn.cu create mode 100644 python/pylibraft/pylibraft/distance/fused_distance_nn.pyx create mode 100755 python/pylibraft/pylibraft/test/test_fused_distance_argmin.py diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 638ceb3b45..6107b9325a 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -256,7 +256,7 @@ endif() if(RAFT_NVTX) # This enables NVTX within the project with no option to disable it downstream. - target_link_libraries(raft INTERFACE CUDA::nvToolsExt) + target_link_libraries(raft INTERFACE CUDA::nvtx3) target_compile_definitions(raft INTERFACE NVTX_ENABLED) else() # Allow enable NVTX downstream if not set here. This creates a new option at build/install time, @@ -324,6 +324,7 @@ if(RAFT_COMPILE_LIBRARY) src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu src/distance/distance.cu src/distance/fused_l2_nn.cu + src/distance/fused_distance_nn.cu src/linalg/detail/coalesced_reduction.cu src/matrix/detail/select_k_double_int64_t.cu src/matrix/detail/select_k_double_uint32_t.cu @@ -422,6 +423,7 @@ if(RAFT_COMPILE_LIBRARY) src/raft_runtime/cluster/update_centroids.cuh src/raft_runtime/cluster/update_centroids_double.cu src/raft_runtime/cluster/update_centroids_float.cu + src/raft_runtime/distance/fused_distance_min_arg.cu src/raft_runtime/distance/fused_l2_min_arg.cu src/raft_runtime/distance/pairwise_distance.cu src/raft_runtime/matrix/select_k_float_int64_t.cu diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 5577881ef7..903f4e4347 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -122,16 +122,8 @@ if(BUILD_PRIMS_BENCH) ) ConfigureBench( - NAME - MATRIX_BENCH - PATH - bench/prims/matrix/argmin.cu - bench/prims/matrix/gather.cu - bench/prims/matrix/select_k.cu - bench/prims/matrix/main.cpp - OPTIONAL - LIB - EXPLICIT_INSTANTIATE_ONLY + NAME MATRIX_BENCH PATH bench/prims/matrix/argmin.cu bench/prims/matrix/gather.cu + bench/prims/matrix/select_k.cu bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureBench( diff --git a/cpp/include/raft/distance/detail/fused_distance_nn.cuh b/cpp/include/raft/distance/detail/fused_distance_nn.cuh new file mode 100644 index 0000000000..4fbfdc8755 --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn.cuh @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // raft::KeyValuePair +#include // raft::identity_op +#include // ops::l2_exp_distance_op +#include +#include +#include +#include +#include +#include // PairwiseDistances +#include +#include // Policy +#include // raft::util::arch::SM_* +#include // raft::ceildiv, raft::shfl + +#include // size_t +#include // std::numeric_limits + +namespace raft { +namespace distance { + +namespace detail { + +template +void fusedDistanceNNImpl(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + int* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + bool initOutBuffer, + bool isRowMajor, + raft::distance::DistanceType metric, + float metric_arg, + cudaStream_t stream) +{ + // The kernel policy is determined by fusedDistanceNN. + typedef Policy P; + + dim3 blk(P::Nthreads); + auto nblks = raft::ceildiv(m, P::Nthreads); + constexpr auto maxVal = std::numeric_limits::max(); + typedef KeyValuePair KVPair; + + RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); + if (initOutBuffer) { + initKernel + <<>>(min, m, maxVal, redOp); + RAFT_CUDA_TRY(cudaGetLastError()); + } + + switch (metric) { + case raft::distance::DistanceType::CosineExpanded: + fusedCosineNN( + min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, stream); + break; + case raft::distance::DistanceType::L2SqrtExpanded: + case raft::distance::DistanceType::L2Expanded: + // initOutBuffer is take care by fusedDistanceNNImpl() so we set it false to fusedL2NNImpl. + fusedL2NNImpl( + min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, false, stream); + break; + default: assert("only cosine/l2 metric is supported with fusedDistanceNN\n"); break; + } +} + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h index 32e9214a0d..186715851b 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h @@ -611,6 +611,7 @@ class EpilogueWithBroadcastCustom : public EpilogueBase +#include + #include #include #include @@ -46,6 +48,14 @@ namespace raft { namespace distance { namespace detail { +template +RAFT_KERNEL initBinMutexKernel(cuda::binary_semaphore* mut, IdxT m) +{ + auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; + + if (tid < m) { mut[tid].release(); } +} + template ; constexpr int batch_count = 1; + rmm::device_uvector> bin_mutex(m, stream); + + int blks_ = (m / 256) + 1; + + initBinMutexKernel<<>>(bin_mutex.data(), m); + typename EpilogueOutputOp::Params epilog_op_param( - dist_op, cg_reduce_op, redOp, pairRedOp, mutexes); + dist_op, cg_reduce_op, redOp, pairRedOp, mutexes, bin_mutex.data()); // Number of pipelines you want to use constexpr int NumStages = 3; diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh index f9ab394585..e69b2486df 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh +++ b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh @@ -56,6 +56,8 @@ #pragma once +#include + #include #include #include @@ -121,6 +123,7 @@ class FusedDistanceNNEpilogueElementwise { KVPReduceOpT_ pair_redop_; ReduceOpT_ red_op_; int* mutexes_; + cuda::binary_semaphore* bin_mutex_; using CGReduceT = CGReduceOp_; // // Methods @@ -130,12 +133,14 @@ class FusedDistanceNNEpilogueElementwise { CGReduceOp cg_reduce_op, ReduceOpT_ red_op, KVPReduceOpT_ pair_redop, - int* mutexes) + int* mutexes, + cuda::binary_semaphore* bin_mutex) : cg_reduce_op(cg_reduce_op), dist_op_(dist_op), pair_redop_(pair_redop), red_op_(red_op), - mutexes_(mutexes) + mutexes_(mutexes), + bin_mutex_(bin_mutex) { } diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh new file mode 100644 index 0000000000..f29c8b4d4c --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh @@ -0,0 +1,136 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // raft::KeyValuePair +#include // raft::identity_op +#include // ops::l2_exp_distance_op +#include +#include +#include +#include // PairwiseDistances +#include // Policy +#include // raft::util::arch::SM_* +#include // raft::ceildiv, raft::shfl + +#include // size_t +#include // std::numeric_limits + +namespace raft { +namespace distance { + +namespace detail { + +template +void fusedCosineNN(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + int* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + cudaStream_t stream) +{ + // The kernel policy is determined by fusedL2NN. + typedef Policy P; + + dim3 blk(P::Nthreads); + constexpr auto maxVal = std::numeric_limits::max(); + typedef KeyValuePair KVPair; + + namespace arch = raft::util::arch; + using AccT = DataT; + ops::cosine_distance_op distance_op{}; + + raft::identity_op fin_op{}; + + auto kernel = fusedDistanceNNkernel; + + // Get pointer to fp32 SIMT kernel to determine the runtime architecture of the + // current system. Other methods to determine the architecture (that do not + // require a pointer) can be error prone. See: + // https://github.com/NVIDIA/cub/issues/545 + void* kernel_ptr = reinterpret_cast(kernel); + auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); + auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); + + if (cutlass_range.contains(runtime_arch)) { + // If device is SM_80 or later, use CUTLASS-based kernel. + using cosineOp = raft::distance::detail::ops::cosine_cutlass_op; + using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op; + kvp_cg_min_reduce_op_ cg_reduce_op; + cosineOp cosine_dist_op; + + IdxT lda, ldb, ldd; + lda = k, ldb = k, ldd = n; + + cutlassFusedDistanceNN(x, + y, + xn, + yn, + m, + n, + k, + lda, + ldb, + ldd, + min, + workspace, + cg_reduce_op, + cosine_dist_op, + redOp, + pairRedOp, + stream); + } else { + // If device less than SM_80, use fp32 SIMT kernel. + constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); + dim3 grid = launchConfigGenerator

(m, n, shmemSize, kernel); + + kernel<<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); + RAFT_CUDA_TRY(cudaGetLastError()); + } +} + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/fused_l2_nn.cuh new file mode 100644 index 0000000000..65475e73c7 --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn/fused_l2_nn.cuh @@ -0,0 +1,146 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // raft::KeyValuePair +#include // raft::identity_op +#include // ops::l2_exp_distance_op +#include +#include +#include +#include // PairwiseDistances +#include // Policy +#include // raft::util::arch::SM_* +#include // raft::ceildiv, raft::shfl + +#include // size_t +#include // std::numeric_limits + +namespace raft { +namespace distance { + +namespace detail { + +template +void fusedL2NNImpl(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + int* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream) +{ + // The kernel policy is determined by fusedL2NN. + typedef Policy P; + + dim3 blk(P::Nthreads); + auto nblks = raft::ceildiv(m, P::Nthreads); + constexpr auto maxVal = std::numeric_limits::max(); + typedef KeyValuePair KVPair; + + RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); + if (initOutBuffer) { + initKernel + <<>>(min, m, maxVal, redOp); + RAFT_CUDA_TRY(cudaGetLastError()); + } + + namespace arch = raft::util::arch; + using AccT = DataT; + ops::l2_exp_distance_op distance_op{sqrt}; + + raft::identity_op fin_op{}; + + auto kernel = fusedDistanceNNkernel; + + // Get pointer to fp32 SIMT kernel to determine the best compute architecture + // out of all for which the kernel was compiled for that matches closely + // to the current device. Other methods to determine the architecture (that do not + // require a pointer) can be error prone. See: + // https://github.com/NVIDIA/cub/issues/545 + void* kernel_ptr = reinterpret_cast(kernel); + auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); + auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); + + if (cutlass_range.contains(runtime_arch)) { + // If device is SM_80 or later, use CUTLASS-based kernel. + using L2Op = raft::distance::detail::ops::l2_exp_cutlass_op; + using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op; + kvp_cg_min_reduce_op_ cg_reduce_op; + L2Op L2_dist_op(sqrt); + + IdxT lda, ldb, ldd; + lda = k, ldb = k, ldd = n; + + cutlassFusedDistanceNN(x, + y, + xn, + yn, + m, + n, + k, + lda, + ldb, + ldd, + min, + workspace, + cg_reduce_op, + L2_dist_op, + redOp, + pairRedOp, + stream); + } else { + // If device less than SM_80, use fp32 SIMT kernel. + constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); + dim3 grid = launchConfigGenerator

(m, n, shmemSize, kernel); + + kernel<<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); + RAFT_CUDA_TRY(cudaGetLastError()); + } +} + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/helper_structs.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/helper_structs.cuh new file mode 100644 index 0000000000..e056c5d397 --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn/helper_structs.cuh @@ -0,0 +1,146 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // raft::KeyValuePair +#include // raft::identity_op +#include // ops::l2_exp_distance_op +#include +#include +#include // PairwiseDistances +#include // Policy +#include // raft::util::arch::SM_* +#include // raft::ceildiv, raft::shfl +#include + +#include // size_t +#include // std::numeric_limits + +namespace raft { +namespace distance { + +namespace detail { + +template +struct KVPMinReduceImpl { + typedef raft::KeyValuePair KVP; + DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } + DI KVP operator()(const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } + +}; // KVPMinReduce + +template +struct MinAndDistanceReduceOpImpl { + typedef typename raft::KeyValuePair KVP; + + DI void operator()(LabelT rid, KVP* out, const KVP& other) const + { + if (other.value < out->value) { + out->key = other.key; + out->value = other.value; + } + } + DI void operator()(LabelT rid, volatile KVP* out, const KVP& other) const + { + if (other.value < out->value) { + out->key = other.key; + out->value = other.value; + } + } + + DI void operator()(LabelT rid, DataT* out, const KVP& other) const + { + if (other.value < *out) { *out = other.value; } + } + + DI void operator()(LabelT rid, volatile DataT* out, const KVP& other) const + { + if (other.value < *out) { *out = other.value; } + } + + DI void operator()(LabelT rid, DataT* out, const DataT& other) const + { + if (other < *out) { *out = other; } + } + + DI void operator()(LabelT rid, volatile DataT* out, const DataT& other) const + { + if (other < *out) { *out = other; } + } + + DI void init(DataT* out, DataT maxVal) const { *out = maxVal; } + DI void init(KVP* out, DataT maxVal) const + { + out->value = maxVal; + out->key = 0xfffffff0; + } + + DI void init_key(DataT& out, LabelT idx) const { return; } + DI void init_key(KVP& out, LabelT idx) const { out.key = idx; } + + DI DataT get_value(KVP& out) const { return out.value; } + DI DataT get_value(DataT& out) const { return out; } +}; + +template +struct MinReduceOpImpl { + typedef typename raft::KeyValuePair KVP; + DI void operator()(LabelT rid, DataT* out, const KVP& other) + { + if (other.value < *out) { *out = other.value; } + } + + DI void init(DataT* out, DataT maxVal) { *out = maxVal; } +}; + +template +RAFT_KERNEL initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) +{ + auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; + if (tid < m) { redOp.init(min + tid, maxVal); } +} + +template +void initialize(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp, cudaStream_t stream) +{ + auto blks = raft::ceildiv(m, 256); + initKernel<<>>(min, m, maxVal, redOp); +} + +// cg::reduce functor for FusedDistanceNN used in its cutlass version +// to output the min distance value & key(loc id). +// This is used in fused_distance_nn/predicated_tile_iterator_reduced_vec.h +// store_with_byte_offset() passed to cg::reduce() & select_reduce. +template +struct kvp_cg_min_reduce_op { + typedef typename raft::KeyValuePair KVP; + + __host__ __device__ kvp_cg_min_reduce_op() noexcept {}; + + using AccTypeT = AccType; + using IndexT = Index; + // functor signature. + __host__ __device__ KVP operator()(KVP a, KVP b) const { return a.value < b.value ? a : b; } + + __host__ __device__ AccType operator()(AccType a, AccType b) const { return min(a, b); } + + __host__ __device__ bool isAmin(AccType a, AccType b) const { return a < b ? true : false; } +}; + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h index 4f05251705..f1a7c728e9 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h @@ -180,8 +180,7 @@ struct FusedDistanceNNPersistent { /// Default ctor CUTLASS_HOST_DEVICE Arguments() - : // problem_count(0), - threadblock_count(0), + : threadblock_count(0), ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), @@ -235,7 +234,6 @@ struct FusedDistanceNNPersistent { /// Parameters structure struct Params { - // typename ProblemVisitor::Params problem_visitor; temp_problem_visitor problem_visitor; int threadblock_count; diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h index 936b9f2c89..d61018593f 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h @@ -322,10 +322,8 @@ class PredicatedTileIteratorReducedVec { /// Parameters structure containing reference and precomputed state. Params params_; - /// Byte-level pointer - uint8_t* byte_pointer_; /// Byte-level pointer first tile offset of this threadblock. - uint8_t* first_tile_byte_pointer_; + volatile uint8_t* first_tile_byte_pointer_; /// Array of boolean values to contain steady-state predicates Mask mask_; @@ -350,6 +348,8 @@ class PredicatedTileIteratorReducedVec { /// Scatter indices int const* indices_; + const int do_gmem_reduction_; + // // Static asserts about internal strides // @@ -360,7 +360,6 @@ class PredicatedTileIteratorReducedVec { protected: SharedStorage& shared_storage_; - const bool& do_gmem_reduction_; private: // @@ -374,10 +373,10 @@ class PredicatedTileIteratorReducedVec { CUTLASS_DEVICE PredicatedTileIteratorReducedVec(SharedStorage& shared_storage, Params const& params, - Element* pointer, + volatile Element* pointer, TensorCoord extent, int thread_idx, - const bool& do_gmem_reduction, + const bool do_gmem_reduction, TensorCoord threadblock_offset = TensorCoord(), int const* indices = nullptr) : params_(params), @@ -409,6 +408,7 @@ class PredicatedTileIteratorReducedVec { EpilogueOpParams const& user_params = params_.user_param; shared_storage_.initSmem(user_params); } + __syncthreads(); // Null pointer performs no accesses if (!pointer) { mask_.clear(); } @@ -416,66 +416,53 @@ class PredicatedTileIteratorReducedVec { if (ScatterD && !indices) { mask_.clear(); } // Initialize pointer - first_tile_byte_pointer_ = reinterpret_cast(pointer) + + first_tile_byte_pointer_ = reinterpret_cast(pointer) + LongIndex(block_offset.row()) * LongIndex(params_.stride); - if (ScatterD) { - byte_pointer_ = reinterpret_cast(pointer) + - LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; - } - // Initialize internal state counter state_[0] = state_[1] = state_[2] = 0; } - /// Destructor - CUTLASS_DEVICE - ~PredicatedTileIteratorReducedVec() + CUTLASS_DEVICE void dumpToGmem() { + if (block_start_row_first_tile_ >= extent_row_) { return; } + if (do_gmem_reduction_) { EpilogueOpParams const& user_params = params_.user_param; - auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); - Element* shared_elem_arr = shared_storage_.data(); const uint32_t mutex_id = (block_start_row_first_tile_ / total_rows); - bool useGmemMutex = (gridDim.x != ((extent_row_ - 1 + total_rows) / total_rows)); - // If this is not optimal grid size perform mutex based gmem reduce. - if (useGmemMutex) { - // single lock per block for multiple rows - if (threadIdx.x == 0 && block_start_row_first_tile_ < extent_row_) { - // acquire mutex lock. - unsigned int ns = 8; - while (atomicCAS(user_params.mutexes_ + mutex_id, 0, 1) == 1) { - __nanosleep(ns); - if (ns < 256) { ns *= 2; } - } + const bool useGmemMutex = (gridDim.x != ((extent_row_ - 1 + total_rows) / total_rows)); + int row = threadIdx.x; + Element* shared_elem_arr = shared_storage_.data(); + Element row_local_min; + if (row < total_rows) { row_local_min = shared_elem_arr[row]; } + + // single lock per block for multiple rows + if (useGmemMutex && threadIdx.x == 0) { user_params.bin_mutex_[mutex_id].acquire(); } + __syncthreads(); + + if (row < total_rows) { + volatile Element* gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); + + if ((block_start_row_first_tile_ + row) < extent_row_) { + user_params.red_op_(block_start_row_first_tile_ + row, (gmem_ptr + row), row_local_min); } } __syncthreads(); - for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { - if (block_start_row_first_tile_ + row < extent_row_) { - user_params.red_op_( - block_start_row_first_tile_ + row, &gmem_ptr[row], shared_elem_arr[row]); - } - } + __threadfence(); - if (useGmemMutex) { - __threadfence(); - __syncthreads(); - if (threadIdx.x == 0 && block_start_row_first_tile_ < extent_row_) { - // release mutex lock. - atomicExch(user_params.mutexes_ + mutex_id, 0); - } + if (useGmemMutex && (threadIdx.x == 0)) { + // release mutex lock. + user_params.bin_mutex_[mutex_id].release(); } + shared_storage_.initSmem(user_params); + __syncthreads(); } } - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) - { - byte_pointer_ += pointer_offset * sizeof_bits::value / 8; - } + /// Destructor + CUTLASS_DEVICE + ~PredicatedTileIteratorReducedVec() {} /// Performs reduction and Stores a reduced output to memory CUTLASS_DEVICE @@ -515,9 +502,6 @@ class PredicatedTileIteratorReducedVec { user_params.red_op_.init(&red_val, maxVal); if (row_guard) { - const int iter_row = (row_id % total_rows); - const auto prev_red_val = user_params.red_op_.get_value(shared_elem_arr[iter_row]); - CUTLASS_PRAGMA_UNROLL for (int column = 0; column < ThreadMap::Iterations::kColumn * kElementsPerAccess; ++column) { @@ -536,6 +520,10 @@ class PredicatedTileIteratorReducedVec { user_params.red_op_(row_id, &red_val, this_val); } } + } + const int iter_row = (row_id % total_rows); + const auto prev_red_val = user_params.red_op_.get_value(shared_elem_arr[iter_row]); + if (row_guard) { // select_reduce doesn't need to use `red_op_` as at the warp level we use cg_reduce_op, // this satisfies the requirement of mst/single linkage of checking colors buffer. select_reduce red_obj( @@ -544,6 +532,7 @@ class PredicatedTileIteratorReducedVec { } } } + __syncthreads(); } /// Stores a fragment to memory @@ -575,14 +564,11 @@ class PredicatedTileIteratorReducedVec { { ++state_[0]; - if (!ScatterD) { byte_pointer_ += params_.advance_row; } - thread_start_row_ += ThreadMap::Shape::kRow; if (state_[0] == ThreadMap::Count::kRow) { state_[0] = 0; ++state_[1]; - byte_pointer_ += params_.advance_group; thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; @@ -590,18 +576,13 @@ class PredicatedTileIteratorReducedVec { if (state_[1] == ThreadMap::Count::kGroup) { state_[1] = 0; ++state_[2]; - byte_pointer_ += params_.advance_cluster; thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; - if (state_[2] == ThreadMap::Count::kCluster) { - state_[2] = 0; - byte_pointer_ += params_.advance_tile; - } + if (state_[2] == ThreadMap::Count::kCluster) { state_[2] = 0; } } } - return *this; } diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/simt_kernel.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/simt_kernel.cuh new file mode 100644 index 0000000000..7417fd5dac --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn/simt_kernel.cuh @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // raft::KeyValuePair +#include // ops::l2_exp_distance_op +#include // PairwiseDistances +#include // Policy + +#include // size_t +#include // std::numeric_limits + +namespace raft { +namespace distance { +namespace detail { + +// TODO: specialize this function for MinAndDistanceReduceOp +// with atomicCAS of 64 bit which will eliminate mutex and shfls +template +DI void updateReducedVal( + int* mutex, OutT* min, KVPair* val, ReduceOpT red_op, IdxT m, IdxT gridStrideY) +{ + const auto lid = threadIdx.x % raft::WarpSize; + const auto accrowid = threadIdx.x / P::AccThCols; + + // Update each output row in order within a warp. This will resolve hang + // issues with pre-Volta architectures +#pragma unroll + for (int j = 0; j < (raft::WarpSize / P::AccThCols); j++) { + if (lid == j * P::AccThCols) { +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + auto rid = gridStrideY + accrowid + i * P::AccThRows; + if (rid < m) { + auto value = val[i]; + while (atomicCAS(mutex + rid, 0, 1) == 1) + ; + __threadfence(); + red_op(rid, min + rid, value); + __threadfence(); + atomicCAS(mutex + rid, 1, 0); + } + } + } + } +} + +template +__launch_bounds__(P::Nthreads, 2) RAFT_KERNEL fusedDistanceNNkernel(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + DataT maxVal, + int* mutex, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + OpT distance_op, + FinalLambda fin_op) +{ +// compile only if below non-ampere arch. +#if __CUDA_ARCH__ < 800 + extern __shared__ char smem[]; + + typedef KeyValuePair KVPair; + KVPair val[P::AccRowsPerTh]; +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {0, maxVal}; + } + + // epilogue operation lambda for final value calculation + auto epilog_lambda = [n, pairRedOp, &val, maxVal] __device__( + DataT acc[P::AccRowsPerTh][P::AccColsPerTh], + DataT * regxn, + DataT * regyn, + IdxT gridStrideX, + IdxT gridStrideY) { + KVPReduceOpT pairRed_op(pairRedOp); + + // intra thread reduce + const auto acccolid = threadIdx.x % P::AccThCols; + const auto accrowid = threadIdx.x / P::AccThCols; +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + auto tmpkey = acccolid + j * P::AccThCols + gridStrideX; + KVPair tmp = {tmpkey, acc[i][j]}; + if (tmpkey < n) { + val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); + } + } + } + }; + + auto rowEpilog_lambda = + [m, mutex, min, pairRedOp, redOp, &val, maxVal] __device__(IdxT gridStrideY) { + KVPReduceOpT pairRed_op(pairRedOp); + ReduceOpT red_op(redOp); + + const auto accrowid = threadIdx.x / P::AccThCols; + const auto lid = raft::laneId(); + + // reduce +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = P::AccThCols / 2; j > 0; j >>= 1) { + // Actually, the srcLane (lid +j) should be (lid +j) % P:AccThCols, + // but the shfl op applies the modulo internally. + auto tmpkey = raft::shfl(val[i].key, lid + j, P::AccThCols); + auto tmpvalue = raft::shfl(val[i].value, lid + j, P::AccThCols); + KVPair tmp = {tmpkey, tmpvalue}; + val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); + } + } + + updateReducedVal(mutex, min, val, red_op, m, gridStrideY); + + // reset the val array. +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {0, maxVal}; + } + }; + + IdxT lda = k, ldb = k, ldd = n; + constexpr bool row_major = true; + constexpr bool write_out = false; + PairwiseDistances + obj(x, + y, + m, + n, + k, + lda, + ldb, + ldd, + xn, + yn, + nullptr, // Output pointer + smem, + distance_op, + epilog_lambda, + fin_op, + rowEpilog_lambda); + obj.run(); +#endif +} + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/detail/masked_nn.cuh b/cpp/include/raft/distance/detail/masked_nn.cuh index 536e4937ab..3e3699766f 100644 --- a/cpp/include/raft/distance/detail/masked_nn.cuh +++ b/cpp/include/raft/distance/detail/masked_nn.cuh @@ -19,7 +19,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/cpp/include/raft/distance/fused_distance_nn-ext.cuh b/cpp/include/raft/distance/fused_distance_nn-ext.cuh new file mode 100644 index 0000000000..263bbcea81 --- /dev/null +++ b/cpp/include/raft/distance/fused_distance_nn-ext.cuh @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // raft::KeyValuePair +#include // raft::resources +#include // include initialize and reduce operations +#include // RAFT_EXPLICIT + +#include // int64_t + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft { +namespace distance { + +template +void fusedDistanceNNMinReduce(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + bool sqrt, + bool initOutBuffer, + bool isRowMajor, + raft::distance::DistanceType metric, + float metric_arg, + cudaStream_t stream) RAFT_EXPLICIT; + +} // namespace distance +} // namespace raft + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_distance_fusedDistanceNNMinReduce(DataT, OutT, IdxT) \ + extern template void raft::distance::fusedDistanceNNMinReduce( \ + OutT * min, \ + const DataT* x, \ + const DataT* y, \ + const DataT* xn, \ + const DataT* yn, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + bool sqrt, \ + bool initOutBuffer, \ + bool isRowMajor, \ + raft::distance::DistanceType metric, \ + float metric_arg, \ + cudaStream_t stream) + +instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int); +instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int64_t); + +// We can't have comma's in the macro expansion, so we use the COMMA macro: +#define COMMA , + +instantiate_raft_distance_fusedDistanceNNMinReduce(float, raft::KeyValuePair, int); +instantiate_raft_distance_fusedDistanceNNMinReduce(float, + raft::KeyValuePair, + int64_t); + +#undef COMMA + +#undef instantiate_raft_distance_fusedDistanceNNMinReduce diff --git a/cpp/include/raft/distance/fused_distance_nn-inl.cuh b/cpp/include/raft/distance/fused_distance_nn-inl.cuh new file mode 100644 index 0000000000..ffe86a1c04 --- /dev/null +++ b/cpp/include/raft/distance/fused_distance_nn-inl.cuh @@ -0,0 +1,328 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __FUSED_DISTANCE_NN_H +#define __FUSED_DISTANCE_NN_H + +#pragma once + +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include + +namespace raft { +namespace distance { + +/** + * \ingroup fused_l2_nn + * @{ + */ +/** + * @brief Fused L2 distance and 1-nearest-neighbor computation in a single call. + * + * The benefits of such a call are 2-fold: 1) eliminate the need for an + * intermediate buffer to store the output of gemm 2) reduce the memory read + * traffic on this intermediate buffer, otherwise needed during the reduction + * phase for 1-NN. + * + * @tparam DataT data type + * @tparam OutT output type to either store 1-NN indices and their minimum + * distances or store only the min distances. Accordingly, one + * has to pass an appropriate `ReduceOpT` + * @tparam IdxT indexing arithmetic type + * @tparam ReduceOpT A struct to perform the final needed reduction operation + * and also to initialize the output array elements with the + * appropriate initial value needed for reduction. + * @tparam KVPReduceOpT A struct providing functions for key-value pair comparison. + * + * @param[out] min will contain the reduced output (Length = `m`) + * (on device) + * @param[in] x first matrix. Row major. Dim = `m x k`. + * (on device). + * @param[in] y second matrix. Row major. Dim = `n x k`. + * (on device). + * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). + * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) + * @param[in] m gemm m + * @param[in] n gemm n + * @param[in] k gemm k + * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) + * @param[in] redOp reduction operator in the epilogue + * @param[in] pairRedOp reduction operation on key value pairs + * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt + * @param[in] initOutBuffer whether to initialize the output buffer before the + * main kernel launch + * @param[in] isRowMajor whether the input/output is row or column major. + * @param[in] metric Distance metric to be used (supports L2, cosine) + * @param[in] metric_arg power argument for distances like Minkowski (not supported for now) + * @param[in] stream cuda stream + */ +template +void fusedDistanceNN(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + bool initOutBuffer, + bool isRowMajor, + raft::distance::DistanceType metric, + float metric_arg, + cudaStream_t stream) +{ + ASSERT(isRowMajor, "fusedDistanceNN only supports row major inputs"); + // When k is smaller than 32, the Policy4x4 results in redundant calculations + // as it uses tiles that have k=32. Therefore, use a "skinny" policy instead + // that uses tiles with a smaller value of k. + bool is_skinny = k < 32; + + size_t bytes = sizeof(DataT) * k; + auto px = reinterpret_cast(x); + auto py = reinterpret_cast(y); + if (16 % sizeof(DataT) == 0 && bytes % 16 == 0 && px % 16 == 0 && py % 16 == 0) { + if (is_skinny) { + detail::fusedDistanceNNImpl< + DataT, + OutT, + IdxT, + typename linalg::Policy4x4Skinny::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } else { + detail::fusedDistanceNNImpl::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } + } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0 && px % 8 == 0 && py % 8 == 0) { + if (is_skinny) { + detail::fusedDistanceNNImpl< + DataT, + OutT, + IdxT, + typename linalg::Policy4x4Skinny::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } else { + detail::fusedDistanceNNImpl::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } + } else { + if (is_skinny) { + detail::fusedDistanceNNImpl::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } else { + detail::fusedDistanceNNImpl::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } + } +} + +/** + * @brief Wrapper around fusedDistanceNN with minimum reduction operators. + * + * fusedDistanceNN cannot be compiled in the distance library due to the lambda + * operators, so this wrapper covers the most common case (minimum). + * + * @tparam DataT data type + * @tparam OutT output type to either store 1-NN indices and their minimum + * distances (e.g. raft::KeyValuePair) or store only the min + * distances. + * @tparam IdxT indexing arithmetic type + * @param[out] min will contain the reduced output (Length = `m`) + * (on device) + * @param[in] x first matrix. Row major. Dim = `m x k`. + * (on device). + * @param[in] y second matrix. Row major. Dim = `n x k`. + * (on device). + * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). + * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) + * @param[in] m gemm m + * @param[in] n gemm n + * @param[in] k gemm k + * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) + * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt + * @param[in] initOutBuffer whether to initialize the output buffer before the + * main kernel launch + * @param[in] isRowMajor whether the input/output is row or column major. + * @param[in] metric Distance metric to be used (supports L2, cosine) + * @param[in] metric_arg power argument for distances like Minkowski (not supported for now) + * @param[in] stream cuda stream + */ +template +void fusedDistanceNNMinReduce(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + bool sqrt, + bool initOutBuffer, + bool isRowMajor, + raft::distance::DistanceType metric, + float metric_arg, + cudaStream_t stream) +{ + MinAndDistanceReduceOp redOp; + KVPMinReduce pairRedOp; + + fusedDistanceNN(min, + x, + y, + xn, + yn, + m, + n, + k, + workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); +} + +/** @} */ + +} // namespace distance +} // namespace raft + +#endif diff --git a/cpp/include/raft/distance/fused_distance_nn.cuh b/cpp/include/raft/distance/fused_distance_nn.cuh new file mode 100755 index 0000000000..04c42e49a1 --- /dev/null +++ b/cpp/include/raft/distance/fused_distance_nn.cuh @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "fused_distance_nn-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +#include "fused_distance_nn-ext.cuh" +#endif diff --git a/cpp/include/raft/distance/fused_l2_nn_helpers.cuh b/cpp/include/raft/distance/fused_distance_nn_helpers.cuh similarity index 92% rename from cpp/include/raft/distance/fused_l2_nn_helpers.cuh rename to cpp/include/raft/distance/fused_distance_nn_helpers.cuh index 996f696ef6..3a570c681c 100644 --- a/cpp/include/raft/distance/fused_l2_nn_helpers.cuh +++ b/cpp/include/raft/distance/fused_distance_nn_helpers.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ #pragma once #include -#include +#include namespace raft::distance { diff --git a/cpp/include/raft/distance/fused_l2_nn-ext.cuh b/cpp/include/raft/distance/fused_l2_nn-ext.cuh index c7dc4c5fc6..d0ac83cd51 100644 --- a/cpp/include/raft/distance/fused_l2_nn-ext.cuh +++ b/cpp/include/raft/distance/fused_l2_nn-ext.cuh @@ -16,10 +16,10 @@ #pragma once -#include // raft::KeyValuePair -#include // raft::resources -#include // include initialize and reduce operations -#include // RAFT_EXPLICIT +#include // raft::KeyValuePair +#include // raft::resources +#include // include initialize and reduce operations +#include // RAFT_EXPLICIT #include // int64_t diff --git a/cpp/include/raft/distance/fused_l2_nn-inl.cuh b/cpp/include/raft/distance/fused_l2_nn-inl.cuh index 6f4db15c72..bf9a49d813 100644 --- a/cpp/include/raft/distance/fused_l2_nn-inl.cuh +++ b/cpp/include/raft/distance/fused_l2_nn-inl.cuh @@ -20,8 +20,8 @@ #pragma once #include -#include -#include +#include +#include #include #include diff --git a/cpp/include/raft_runtime/distance/fused_distance_nn.hpp b/cpp/include/raft_runtime/distance/fused_distance_nn.hpp new file mode 100644 index 0000000000..7c309d6fc7 --- /dev/null +++ b/cpp/include/raft_runtime/distance/fused_distance_nn.hpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace raft::runtime::distance { + +/** + * @defgroup fused_distance_nn_min_arg_runtime Fused Distance 1NN Runtime API + * @{ + */ + +/** + * @brief Wrapper around fusedDistanceNN with minimum reduction operators. + * + * fusedDistanceNN cannot be compiled in the distance library due to the lambda + * operators, so this wrapper covers the most common case (minimum). + * + * @param[in] handle raft handle + * @param[out] min will contain the reduced output (Length = `m`) + * (on device) + * @param[in] x first matrix. Row major. Dim = `m x k`. + * (on device). + * @param[in] y second matrix. Row major. Dim = `n x k`. + * (on device). + * @param[in] m gemm m + * @param[in] n gemm n + * @param[in] k gemm k + * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt + * @param[in] metric Distance metric to be used (supports L2, cosine) + * @param[in] isRowMajor whether the input/output is row or column major. + * @param[in] metric_arg power argument for distances like Minkowski (not supported for now) + */ +void fused_distance_nn_min_arg(raft::resources const& handle, + int* min, + const float* x, + const float* y, + int m, + int n, + int k, + bool sqrt, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg); + +/** @} */ // end group fused_distance_nn_min_arg_runtime + +} // end namespace raft::runtime::distance diff --git a/cpp/include/raft_runtime/distance/fused_l2_nn.hpp b/cpp/include/raft_runtime/distance/fused_l2_nn.hpp index 6154e03f4c..e46b3c5271 100644 --- a/cpp/include/raft_runtime/distance/fused_l2_nn.hpp +++ b/cpp/include/raft_runtime/distance/fused_l2_nn.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -42,23 +42,25 @@ namespace raft::runtime::distance { * @param[in] k gemm k * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt */ -void fused_l2_nn_min_arg(raft::resources const& handle, - int* min, - const float* x, - const float* y, - int m, - int n, - int k, - bool sqrt); +[[deprecated("use fused_distance_nn_min_arg instead")]] void fused_l2_nn_min_arg( + raft::resources const& handle, + int* min, + const float* x, + const float* y, + int m, + int n, + int k, + bool sqrt); -void fused_l2_nn_min_arg(raft::resources const& handle, - int* min, - const double* x, - const double* y, - int m, - int n, - int k, - bool sqrt); +[[deprecated("use fused_distance_nn_min_arg instead")]] void fused_l2_nn_min_arg( + raft::resources const& handle, + int* min, + const double* x, + const double* y, + int m, + int n, + int k, + bool sqrt); /** @} */ // end group fused_l2_nn_min_arg_runtime diff --git a/cpp/src/distance/fused_distance_nn.cu b/cpp/src/distance/fused_distance_nn.cu new file mode 100644 index 0000000000..dc722d929c --- /dev/null +++ b/cpp/src/distance/fused_distance_nn.cu @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // raft::KeyValuePair +#include + +#include // int64_t + +#define instantiate_raft_distance_fusedDistanceNNMinReduce(DataT, OutT, IdxT) \ + template void raft::distance::fusedDistanceNNMinReduce( \ + OutT * min, \ + const DataT* x, \ + const DataT* y, \ + const DataT* xn, \ + const DataT* yn, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + bool sqrt, \ + bool initOutBuffer, \ + bool isRowMajor, \ + raft::distance::DistanceType metric, \ + float metric_arg, \ + cudaStream_t stream) + +instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int); +instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int64_t); + +// We can't have comma's in the macro expansion, so we use the COMMA macro: +#define COMMA , + +instantiate_raft_distance_fusedDistanceNNMinReduce(float, raft::KeyValuePair, int); +instantiate_raft_distance_fusedDistanceNNMinReduce(float, + raft::KeyValuePair, + int64_t); + +#undef COMMA + +#undef instantiate_raft_distance_fusedDistanceNNMinReduce diff --git a/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu b/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu new file mode 100644 index 0000000000..dfdff4e94b --- /dev/null +++ b/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "fused_distance_min_arg.hpp" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace raft::runtime::distance { + +void fused_distance_nn_min_arg(raft::resources const& handle, + int* min, + const float* x, + const float* y, + int m, + int n, + int k, + bool sqrt, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg) +{ + switch (metric) { + case raft::distance::DistanceType::CosineExpanded: + compute_fused_cosine_nn_min_arg(handle, min, x, y, m, n, k, sqrt); + break; + case raft::distance::DistanceType::L2Expanded: + case raft::distance::DistanceType::L2SqrtExpanded: + compute_fused_l2_nn_min_arg(handle, min, x, y, m, n, k, sqrt); + break; + default: assert("only Cosine/L2 metric is supported with fusedDistanceNN\n"); break; + } +} + +} // end namespace raft::runtime::distance diff --git a/cpp/src/raft_runtime/distance/fused_distance_min_arg.hpp b/cpp/src/raft_runtime/distance/fused_distance_min_arg.hpp new file mode 100644 index 0000000000..6452752a79 --- /dev/null +++ b/cpp/src/raft_runtime/distance/fused_distance_min_arg.hpp @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace raft::runtime::distance { + +template +struct KeyValueIndexOp { + __host__ __device__ __forceinline__ IndexT + operator()(const raft::KeyValuePair& a) const + { + return a.key; + } +}; + +template +void compute_fused_l2_nn_min_arg(raft::resources const& handle, + idx_t* min, + const value_t* x, + const value_t* y, + idx_t m, + idx_t n, + idx_t k, + bool sqrt) +{ + rmm::device_uvector workspace(m, resource::get_cuda_stream(handle)); + auto kvp = raft::make_device_vector>(handle, m); + constexpr bool is_row_major = true; + + rmm::device_uvector x_norms(m, resource::get_cuda_stream(handle)); + rmm::device_uvector y_norms(n, resource::get_cuda_stream(handle)); + raft::linalg::rowNorm( + x_norms.data(), x, k, m, raft::linalg::L2Norm, is_row_major, resource::get_cuda_stream(handle)); + raft::linalg::rowNorm( + y_norms.data(), y, k, n, raft::linalg::L2Norm, is_row_major, resource::get_cuda_stream(handle)); + + raft::distance::fusedL2NNMinReduce(kvp.data_handle(), + x, + y, + x_norms.data(), + y_norms.data(), + m, + n, + k, + (void*)workspace.data(), + sqrt, + true, + resource::get_cuda_stream(handle)); + + KeyValueIndexOp conversion_op; + thrust::transform(resource::get_thrust_policy(handle), + kvp.data_handle(), + kvp.data_handle() + m, + min, + conversion_op); + resource::sync_stream(handle); +} + +template +void compute_fused_cosine_nn_min_arg(raft::resources const& handle, + idx_t* min, + const value_t* x, + const value_t* y, + idx_t m, + idx_t n, + idx_t k, + bool sqrt) +{ + rmm::device_uvector workspace(m, resource::get_cuda_stream(handle)); + auto kvp = raft::make_device_vector>(handle, m); + + rmm::device_uvector x_norms(m, resource::get_cuda_stream(handle)); + rmm::device_uvector y_norms(n, resource::get_cuda_stream(handle)); + constexpr bool is_row_major = true; + raft::linalg::rowNorm(x_norms.data(), + x, + k, + m, + raft::linalg::L2Norm, + is_row_major, + resource::get_cuda_stream(handle), + raft::sqrt_op{}); + raft::linalg::rowNorm(y_norms.data(), + y, + k, + n, + raft::linalg::L2Norm, + is_row_major, + resource::get_cuda_stream(handle), + raft::sqrt_op{}); + + raft::distance::fusedDistanceNNMinReduce(kvp.data_handle(), + x, + y, + x_norms.data(), + y_norms.data(), + m, + n, + k, + (void*)workspace.data(), + sqrt, + true, + is_row_major, + raft::distance::DistanceType::CosineExpanded, + 0.0f, + resource::get_cuda_stream(handle)); + + KeyValueIndexOp conversion_op; + thrust::transform(resource::get_thrust_policy(handle), + kvp.data_handle(), + kvp.data_handle() + m, + min, + conversion_op); + resource::sync_stream(handle); +} + +} // end namespace raft::runtime::distance diff --git a/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu b/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu index bf4b1a431a..870757dca1 100644 --- a/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu +++ b/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu @@ -14,6 +14,8 @@ * limitations under the License. */ +#include "fused_distance_min_arg.hpp" + #include #include #include @@ -28,77 +30,28 @@ namespace raft::runtime::distance { -template -struct KeyValueIndexOp { - __host__ __device__ __forceinline__ IndexT - operator()(const raft::KeyValuePair& a) const - { - return a.key; - } -}; - -template -void compute_fused_l2_nn_min_arg(raft::resources const& handle, - idx_t* min, - const value_t* x, - const value_t* y, - idx_t m, - idx_t n, - idx_t k, - bool sqrt) -{ - rmm::device_uvector workspace(m, resource::get_cuda_stream(handle)); - auto kvp = raft::make_device_vector>(handle, m); - - rmm::device_uvector x_norms(m, resource::get_cuda_stream(handle)); - rmm::device_uvector y_norms(n, resource::get_cuda_stream(handle)); - raft::linalg::rowNorm( - x_norms.data(), x, k, m, raft::linalg::L2Norm, true, resource::get_cuda_stream(handle)); - raft::linalg::rowNorm( - y_norms.data(), y, k, n, raft::linalg::L2Norm, true, resource::get_cuda_stream(handle)); - - raft::distance::fusedL2NNMinReduce(kvp.data_handle(), - x, - y, - x_norms.data(), - y_norms.data(), - m, - n, - k, - (void*)workspace.data(), - sqrt, - true, - resource::get_cuda_stream(handle)); - - KeyValueIndexOp conversion_op; - thrust::transform(resource::get_thrust_policy(handle), - kvp.data_handle(), - kvp.data_handle() + m, - min, - conversion_op); - resource::sync_stream(handle); -} - -void fused_l2_nn_min_arg(raft::resources const& handle, - int* min, - const float* x, - const float* y, - int m, - int n, - int k, - bool sqrt) +[[deprecated("use fused_distance_nn_min_arg instead")]] void fused_l2_nn_min_arg( + raft::resources const& handle, + int* min, + const float* x, + const float* y, + int m, + int n, + int k, + bool sqrt) { compute_fused_l2_nn_min_arg(handle, min, x, y, m, n, k, sqrt); } -void fused_l2_nn_min_arg(raft::resources const& handle, - int* min, - const double* x, - const double* y, - int m, - int n, - int k, - bool sqrt) +[[deprecated("use fused_distance_nn_min_arg instead")]] void fused_l2_nn_min_arg( + raft::resources const& handle, + int* min, + const double* x, + const double* y, + int m, + int n, + int k, + bool sqrt) { compute_fused_l2_nn_min_arg(handle, min, x, y, m, n, k, sqrt); } diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 65d6e738a2..bf44cf9c60 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -171,6 +171,7 @@ if(BUILD_TESTS) test/distance/masked_nn.cu test/distance/masked_nn_compress_to_bits.cu test/distance/fused_l2_nn.cu + test/distance/fused_cosine_nn.cu test/distance/gram.cu LIB EXPLICIT_INSTANTIATE_ONLY diff --git a/cpp/test/distance/fused_cosine_nn.cu b/cpp/test/distance/fused_cosine_nn.cu new file mode 100644 index 0000000000..d4d632e1dc --- /dev/null +++ b/cpp/test/distance/fused_cosine_nn.cu @@ -0,0 +1,420 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Search with filter instantiation + +#include "../test_utils.cuh" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace raft { +namespace distance { + +template +struct RaftKVPMinReduce { + typedef raft::KeyValuePair KVP; + + DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } + + DI KVP operator()(const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } + +}; // KVPMinReduce + +template +__global__ void naiveCosKernel(raft::KeyValuePair* min, + DataT* x, + DataT* y, + int m, + int n, + int k, + int* workspace, + DataT maxVal) +{ + int midx = threadIdx.y + blockIdx.y * blockDim.y; + int nidx = threadIdx.x + blockIdx.x * blockDim.x; + DataT acc_a = DataT(0); + DataT acc_b = DataT(0); + DataT acc_ab = DataT(0); + // if (midx >= m || nidx >= n) { return; } + + for (int i = 0; i < k; ++i) { + int xidx = i + midx * k; + int yidx = i + nidx * k; + auto a = x[xidx]; + auto b = y[yidx]; + acc_a += a * a; + acc_b += b * b; + acc_ab += a * b; + } + + // Use 1.0 - (cosine similarity) to calc the distance + DataT acc = maxVal; + if (midx < m || nidx < n) { acc = (DataT)1.0 - acc_ab / (raft::sqrt(acc_a) * raft::sqrt(acc_b)); } + + ReduceOpT redOp; + typedef cub::WarpReduce> WarpReduce; + __shared__ typename WarpReduce::TempStorage temp[NWARPS]; + int warpId = threadIdx.x / raft::WarpSize; + raft::KeyValuePair tmp; + tmp.key = nidx; + tmp.value = midx >= m || nidx >= n ? maxVal : acc; + tmp = WarpReduce(temp[warpId]).Reduce(tmp, RaftKVPMinReduce()); + if (threadIdx.x % raft::WarpSize == 0 && midx < m) { + while (atomicCAS(workspace + midx, 0, 1) == 1) + ; + __threadfence(); + redOp(midx, min + midx, tmp); + __threadfence(); + atomicCAS(workspace + midx, 1, 0); + } +} + +template +void naive(raft::KeyValuePair* min, + DataT* x, + DataT* y, + int m, + int n, + int k, + int* workspace, + cudaStream_t stream) +{ + static const dim3 TPB(32, 16, 1); + dim3 nblks(raft::ceildiv(n, (int)TPB.x), raft::ceildiv(m, (int)TPB.y), 1); + RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); + auto blks = raft::ceildiv(m, 256); + MinAndDistanceReduceOp op; + detail::initKernel, int> + <<>>(min, m, std::numeric_limits::max(), op); + RAFT_CUDA_TRY(cudaGetLastError()); + naiveCosKernel, 16> + <<>>(min, x, y, m, n, k, workspace, std::numeric_limits::max()); + RAFT_CUDA_TRY(cudaGetLastError()); +} + +template +struct Inputs { + DataT tolerance; + int m, n, k; + unsigned long long int seed; + + friend std::ostream& operator<<(std::ostream& os, const Inputs& p) + { + return os << "m: " << p.m + << ", " + "n: " + << p.n + << ", " + "k: " + << p.k + << ", " + "seed: " + << p.seed + << ", " + "tol: " + << p.tolerance; + } +}; + +template +class FusedCosineNNTest : public ::testing::TestWithParam> { + public: + FusedCosineNNTest() + : params(::testing::TestWithParam>::GetParam()), + stream(resource::get_cuda_stream(handle)), + x(params.m * params.k, stream), + y(params.n * params.k, stream), + xn(params.m, stream), + yn(params.n, stream), + min(params.m, stream), + min_ref(params.m, stream), + workspace(params.m * sizeof(int), stream) + { + } + + protected: + void SetUp() override + { + raft::random::RngState r(params.seed); + int m = params.m; + int n = params.n; + int k = params.k; + uniform(handle, r, x.data(), m * k, DataT(-1.0), DataT(1.0)); + uniform(handle, r, y.data(), n * k, DataT(-1.0), DataT(1.0)); + generateGoldenResult(); + raft::linalg::rowNorm( + xn.data(), x.data(), k, m, raft::linalg::L2Norm, true, stream, raft::sqrt_op{}); + raft::linalg::rowNorm( + yn.data(), y.data(), k, n, raft::linalg::L2Norm, true, stream, raft::sqrt_op{}); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } + + protected: + raft::resources handle; + cudaStream_t stream; + Inputs params; + rmm::device_uvector x; + rmm::device_uvector y; + rmm::device_uvector xn; + rmm::device_uvector yn; + rmm::device_uvector> min; + rmm::device_uvector> min_ref; + rmm::device_uvector workspace; + + virtual void generateGoldenResult() + { + int m = params.m; + int n = params.n; + int k = params.k; + naive(min_ref.data(), x.data(), y.data(), m, n, k, (int*)workspace.data(), stream); + } + + void runTest(raft::KeyValuePair* out) + { + int m = params.m; + int n = params.n; + int k = params.k; + raft::distance::DistanceType metric = raft::distance::DistanceType::CosineExpanded; + constexpr bool init_out_buffer = true; + fusedDistanceNNMinReduce, int>(out, + x.data(), + y.data(), + xn.data(), + yn.data(), + m, + n, + k, + (void*)workspace.data(), + false, + init_out_buffer, + true, + metric, + 0.0f, + stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } +}; + +template +struct CompareApproxAbsKVP { + typedef typename raft::KeyValuePair KVP; + CompareApproxAbsKVP(T eps_) : eps(eps_) {} + bool operator()(const KVP& a, const KVP& b) const + { + T diff = std::abs(std::abs(a.value) - std::abs(b.value)); + T m = std::max(std::abs(a.value), std::abs(b.value)); + T ratio = m >= eps ? diff / m : diff; + return (ratio <= eps); + } + + private: + T eps; +}; + +template +struct CompareExactKVP { + typedef typename raft::KeyValuePair KVP; + bool operator()(const KVP& a, const KVP& b) const + { + if (a.value != b.value) return false; + return true; + } +}; + +template +::testing::AssertionResult devArrMatch(const raft::KeyValuePair* expected, + const raft::KeyValuePair* actual, + size_t size, + L eq_compare, + cudaStream_t stream = 0) +{ + typedef typename raft::KeyValuePair KVP; + std::shared_ptr exp_h(new KVP[size]); + std::shared_ptr act_h(new KVP[size]); + raft::update_host(exp_h.get(), expected, size, stream); + raft::update_host(act_h.get(), actual, size, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + for (size_t i(0); i < size; ++i) { + auto exp = exp_h.get()[i]; + auto act = act_h.get()[i]; + if (!eq_compare(exp, act)) { + return ::testing::AssertionFailure() + << "actual=" << act.key << "," << act.value << " != expected=" << exp.key << "," + << exp.value << " @" << i; + } + } + return ::testing::AssertionSuccess(); +} + +const std::vector> inputsf = { + {0.001f, 32, 32, 32, 1234ULL}, + {0.001f, 32, 64, 32, 1234ULL}, + {0.001f, 64, 32, 32, 1234ULL}, + {0.001f, 64, 64, 32, 1234ULL}, + {0.001f, 128, 32, 32, 1234ULL}, + {0.001f, 128, 64, 32, 1234ULL}, + {0.001f, 128, 128, 64, 1234ULL}, + {0.001f, 64, 128, 128, 1234ULL}, + + {0.001f, 32, 32, 34, 1234ULL}, + {0.001f, 32, 64, 34, 1234ULL}, + {0.001f, 64, 32, 34, 1234ULL}, + {0.001f, 64, 64, 34, 1234ULL}, + {0.001f, 128, 32, 34, 1234ULL}, + {0.001f, 128, 64, 34, 1234ULL}, + {0.001f, 128, 128, 66, 1234ULL}, + {0.001f, 64, 128, 130, 1234ULL}, + + {0.001f, 32, 32, 33, 1234ULL}, + {0.001f, 32, 64, 33, 1234ULL}, + {0.001f, 64, 32, 33, 1234ULL}, + {0.001f, 64, 64, 33, 1234ULL}, + {0.001f, 128, 32, 33, 1234ULL}, + {0.001f, 128, 64, 33, 1234ULL}, + {0.001f, 128, 128, 65, 1234ULL}, + {0.001f, 64, 128, 129, 1234ULL}, + {0.006f, 1805, 134, 2, 1234ULL}, + {0.006f, 8192, 1024, 64, 1234ULL}, + {0.006f, 8192, 1025, 64, 1234ULL}, + + // Repeat with smaller values of k + {0.006f, 32, 32, 1, 1234ULL}, + {0.001f, 32, 64, 2, 1234ULL}, + {0.001f, 64, 32, 3, 1234ULL}, + {0.001f, 64, 64, 4, 1234ULL}, + {0.001f, 128, 32, 5, 1234ULL}, + {0.001f, 128, 64, 6, 1234ULL}, + {0.001f, 128, 128, 7, 1234ULL}, + {0.001f, 64, 128, 8, 1234ULL}, + + {0.001f, 32, 32, 9, 1234ULL}, + {0.001f, 32, 64, 10, 1234ULL}, + {0.001f, 64, 32, 11, 1234ULL}, + {0.001f, 64, 64, 12, 1234ULL}, + {0.001f, 128, 32, 13, 1234ULL}, + {0.001f, 128, 64, 14, 1234ULL}, + {0.001f, 128, 128, 15, 1234ULL}, + {0.001f, 64, 128, 16, 1234ULL}, + + {0.001f, 32, 32, 17, 1234ULL}, + {0.001f, 32, 64, 18, 1234ULL}, + {0.001f, 64, 32, 19, 1234ULL}, + {0.001f, 64, 64, 20, 1234ULL}, + {0.001f, 128, 32, 21, 1234ULL}, + {0.001f, 128, 64, 22, 1234ULL}, + {0.001f, 128, 128, 23, 1234ULL}, + {0.00001, 64, 128, 24, 1234ULL}, + {0.001f, 1805, 134, 25, 1234ULL}, + {0.006f, 8192, 1024, 25, 1234ULL}, + {0.006f, 8192, 1024, 66, 1234ULL}, +}; +typedef FusedCosineNNTest FusedCosineNNTestF; +TEST_P(FusedCosineNNTestF, Result) +{ + runTest(min.data()); + ASSERT_TRUE(devArrMatch( + min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(FusedCosineNNTests, FusedCosineNNTestF, ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.00001, 32, 32, 32, 1234ULL}, {0.00001, 32, 64, 32, 1234ULL}, + {0.00001, 64, 32, 32, 1234ULL}, {0.00001, 64, 64, 32, 1234ULL}, + {0.00001, 128, 32, 32, 1234ULL}, {0.00001, 128, 64, 32, 1234ULL}, + {0.00001, 128, 128, 64, 1234ULL}, {0.00001, 64, 128, 128, 1234ULL}, + + {0.00001, 32, 32, 34, 1234ULL}, {0.00001, 32, 64, 34, 1234ULL}, + {0.00001, 64, 32, 34, 1234ULL}, {0.00001, 64, 64, 34, 1234ULL}, + {0.00001, 128, 32, 34, 1234ULL}, {0.00001, 128, 64, 34, 1234ULL}, + {0.00001, 128, 128, 66, 1234ULL}, {0.00001, 64, 128, 130, 1234ULL}, + + {0.00001, 32, 32, 33, 1234ULL}, {0.00001, 32, 64, 33, 1234ULL}, + {0.00001, 64, 32, 33, 1234ULL}, {0.00001, 64, 64, 33, 1234ULL}, + {0.00001, 128, 32, 33, 1234ULL}, {0.00001, 128, 64, 33, 1234ULL}, + {0.00001, 128, 128, 65, 1234ULL}, {0.00001, 64, 128, 129, 1234ULL}, + + {0.00001, 1805, 134, 2, 1234ULL}, {0.00001, 8192, 1024, 25, 1234ULL}, +}; +typedef FusedCosineNNTest FusedCosineNNTestD; +TEST_P(FusedCosineNNTestD, Result) +{ + runTest(min.data()); + ASSERT_TRUE(devArrMatch( + min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(FusedCosineNNTests, FusedCosineNNTestD, ::testing::ValuesIn(inputsd)); + +/// This is to test output determinism of the prim +template +class FusedCosineNNDetTest : public FusedCosineNNTest { + public: + FusedCosineNNDetTest() : stream(resource::get_cuda_stream(handle)), min1(0, stream) {} + + void SetUp() override + { + FusedCosineNNTest::SetUp(); + int m = this->params.m; + min1.resize(m, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } + + void TearDown() override { FusedCosineNNTest::TearDown(); } + + protected: + raft::resources handle; + cudaStream_t stream; + + rmm::device_uvector> min1; + + static const int NumRepeats = 3; + + void generateGoldenResult() override {} +}; + +typedef FusedCosineNNDetTest FusedCosineNNDetTestF; +TEST_P(FusedCosineNNDetTestF, Result) +{ + runTest(min.data()); // assumed to be golden + for (int i = 0; i < NumRepeats; ++i) { + runTest(min1.data()); + ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); + cudaMemsetAsync(min1.data(), 0, sizeof(*min.data()) * params.m, stream); + } +} +INSTANTIATE_TEST_CASE_P(FusedCosineNNDetTests, FusedCosineNNDetTestF, ::testing::ValuesIn(inputsf)); + +typedef FusedCosineNNDetTest FusedCosineNNDetTestD; +TEST_P(FusedCosineNNDetTestD, Result) +{ + runTest(min.data()); // assumed to be golden + for (int i = 0; i < NumRepeats; ++i) { + runTest(min1.data()); + ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); + } +} +INSTANTIATE_TEST_CASE_P(FusedCosineNNDetTests, FusedCosineNNDetTestD, ::testing::ValuesIn(inputsd)); + +} // end namespace distance +} // end namespace raft diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index d21a525d88..6fd8f15808 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -18,7 +18,6 @@ #include #include -#include #include #include #include diff --git a/python/pylibraft/pylibraft/distance/CMakeLists.txt b/python/pylibraft/pylibraft/distance/CMakeLists.txt index 14f0cc441a..2530e07a98 100644 --- a/python/pylibraft/pylibraft/distance/CMakeLists.txt +++ b/python/pylibraft/pylibraft/distance/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -13,7 +13,7 @@ # ============================================================================= # Set the list of Cython files to build -set(cython_sources pairwise_distance.pyx fused_l2_nn.pyx) +set(cython_sources pairwise_distance.pyx fused_l2_nn.pyx fused_distance_nn.pyx) set(linked_libraries raft::raft raft::compiled) # Build all of the Cython targets diff --git a/python/pylibraft/pylibraft/distance/__init__.py b/python/pylibraft/pylibraft/distance/__init__.py index f059b5f3dd..d16ab30b2f 100644 --- a/python/pylibraft/pylibraft/distance/__init__.py +++ b/python/pylibraft/pylibraft/distance/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,12 @@ # limitations under the License. # +from .fused_distance_nn import fused_distance_nn_argmin from .fused_l2_nn import fused_l2_nn_argmin from .pairwise_distance import DISTANCE_TYPES, distance as pairwise_distance -__all__ = ["fused_l2_nn_argmin", "pairwise_distance"] +__all__ = [ + "fused_distance_nn_argmin", + "fused_l2_nn_argmin", + "pairwise_distance", +] diff --git a/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx b/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx new file mode 100644 index 0000000000..0e9fa4b366 --- /dev/null +++ b/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx @@ -0,0 +1,200 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +import numpy as np + +from cython.operator cimport dereference as deref +from libc.stdint cimport uintptr_t +from libcpp cimport bool + +from .distance_type cimport DistanceType + +from pylibraft.common import ( + Handle, + auto_convert_output, + cai_wrapper, + device_ndarray, +) +from pylibraft.common.handle import auto_sync_handle + +from pylibraft.common.handle cimport device_resources + + +cdef extern from "raft_runtime/distance/fused_distance_nn.hpp" \ + namespace "raft::runtime::distance" nogil: + + void fused_distance_nn_min_arg( + const device_resources &handle, + int* min, + const float* x, + const float* y, + int m, + int n, + int k, + bool sqrt, + DistanceType metric, + bool isRowMajor, + float metric_arg) except + + + +from pylibraft.distance.pairwise_distance import DISTANCE_TYPES + +SUPPORTED_DISTANCES = ["euclidean", "l2", "cosine", "sqeuclidean"] + + +@auto_sync_handle +@auto_convert_output +def fused_distance_nn_argmin(X, Y, out=None, sqrt=True, metric="euclidean", + handle=None): + """ + Compute the 1-nearest neighbors between X and Y using the distance metrics + + Valid values for metric: + ["euclidean", "l2", "cosine", "sqeuclidean"] + + Parameters + ---------- + + X : CUDA array interface compliant matrix shape (m, k) + Y : CUDA array interface compliant matrix shape (n, k) + out : Writable CUDA array interface matrix shape (m, 1) + metric : string denoting the metric type (default="euclidean") + + {handle_docstring} + + Examples + -------- + To compute the 1-nearest neighbors argmin: + + >>> import cupy as cp + >>> from pylibraft.common import Handle + >>> from pylibraft.distance import fused_distance_nn_argmin + >>> n_samples = 5000 + >>> n_clusters = 5 + >>> n_features = 50 + >>> in1 = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> in2 = cp.random.random_sample((n_clusters, n_features), + ... dtype=cp.float32) + >>> # A single RAFT handle can optionally be reused across + >>> # pylibraft functions. + >>> handle = Handle() + + >>> output = fused_distance_nn_argmin(in1, in2, handle=handle) + + >>> # pylibraft functions are often asynchronous so the + >>> # handle needs to be explicitly synchronized + >>> handle.sync() + + The output can also be computed in-place on a preallocated + array: + + >>> import cupy as cp + >>> from pylibraft.common import Handle + >>> from pylibraft.distance import fused_distance_nn_argmin + >>> n_samples = 5000 + >>> n_clusters = 5 + >>> n_features = 50 + >>> in1 = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> in2 = cp.random.random_sample((n_clusters, n_features), + ... dtype=cp.float32) + >>> output = cp.empty((n_samples, 1), dtype=cp.int32) + >>> # A single RAFT handle can optionally be reused across + >>> # pylibraft functions. + >>> handle = Handle() + + >>> fused_distance_nn_argmin(in1, in2, out=output, handle=handle) + array(...) + + >>> # pylibraft functions are often asynchronous so the + >>> # handle needs to be explicitly synchronized + >>> handle.sync() + """ + + x_cai = cai_wrapper(X) + y_cai = cai_wrapper(Y) + + x_dt = x_cai.dtype + y_dt = y_cai.dtype + + m = x_cai.shape[0] + n = y_cai.shape[0] + + if out is None: + output = device_ndarray.empty((m,), dtype="int32") + else: + output = out + + output_cai = cai_wrapper(output) + + x_k = x_cai.shape[1] + y_k = y_cai.shape[1] + + if x_k != y_k: + raise ValueError("Inputs must have same number of columns. " + "a=%s, b=%s" % (x_k, y_k)) + + if metric not in SUPPORTED_DISTANCES: + raise ValueError("metric %s is not supported" % metric) + + cdef DistanceType distance_type = DISTANCE_TYPES[metric] + + x_ptr = x_cai.data + y_ptr = y_cai.data + + d_ptr = output_cai.data + + handle = handle if handle is not None else Handle() + cdef device_resources *h = handle.getHandle() + + d_dt = output_cai.dtype + + x_c_contiguous = x_cai.c_contiguous + y_c_contiguous = y_cai.c_contiguous + + if x_c_contiguous != y_c_contiguous: + raise ValueError("Inputs must have matching strides") + + if not x_c_contiguous: + raise ValueError("Inputs must be C contiguous") + + if x_dt != y_dt: + raise ValueError("Inputs must have the same dtypes") + if d_dt != np.int32: + raise ValueError("Output array must be int32") + # unused arg for now. + metric_arg = 0.0 + if x_dt == np.float32: + fused_distance_nn_min_arg(deref(h), + d_ptr, + x_ptr, + y_ptr, + m, + n, + x_k, + sqrt, + distance_type, + x_c_contiguous, + metric_arg) + else: + raise ValueError("dtype %s not supported" % x_dt) + + return output diff --git a/python/pylibraft/pylibraft/test/test_fused_distance_argmin.py b/python/pylibraft/pylibraft/test/test_fused_distance_argmin.py new file mode 100755 index 0000000000..6736128242 --- /dev/null +++ b/python/pylibraft/pylibraft/test/test_fused_distance_argmin.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np +import pytest +from scipy.spatial.distance import cdist + +from pylibraft.common import DeviceResources, device_ndarray +from pylibraft.distance import fused_distance_nn_argmin + + +@pytest.mark.parametrize("inplace", [True, False]) +@pytest.mark.parametrize("n_rows", [10, 100]) +@pytest.mark.parametrize("n_clusters", [50, 100]) +@pytest.mark.parametrize("n_cols", [128, 31]) +@pytest.mark.parametrize("dtype", [np.float32]) +@pytest.mark.parametrize( + "metric", + [ + "euclidean", + "cosine", + "sqeuclidean", + ], +) +def test_fused_distance_nn_minarg( + n_rows, n_cols, n_clusters, dtype, inplace, metric +): + input1 = np.random.random_sample((n_rows, n_cols)) + input1 = np.asarray(input1, order="C").astype(dtype) + + input2 = np.random.random_sample((n_clusters, n_cols)) + input2 = np.asarray(input2, order="C").astype(dtype) + + output = np.zeros((n_rows), dtype="int32") + expected = cdist(input1, input2, metric) + + expected = expected.argmin(axis=1) + + input1_device = device_ndarray(input1) + input2_device = device_ndarray(input2) + output_device = device_ndarray(output) if inplace else None + + is_sqrt = True if metric == "sqeuclidean" else False + handle = DeviceResources() + ret_output = fused_distance_nn_argmin( + input1_device, + input2_device, + output_device, + is_sqrt, + metric, + handle=handle, + ) + handle.sync() + output_device = ret_output if not inplace else output_device + actual = output_device.copy_to_host() + + assert np.allclose(expected, actual, rtol=1e-4) From 0b9692b25f78cd1b27631e354e3f8921a976645c Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Tue, 19 Mar 2024 18:07:21 +0100 Subject: [PATCH 10/10] random sampling of dataset rows with improved memory utilization (#2155) The random sampling of IVF methods was reverted (#2144) due to large memory utilization #2141. This PR improves the memory consumption of subsamling: it is O(n_train) where n_train is the size of the subsampled dataset. This PR adds the following new APIs: - random::excess_sampling (todo may just call as sample_without_replacement) - matrix::sample_rows - matrix::gather for host input matrix Authors: - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Artem M. Chirkin (https://github.com/achirkin) - Ben Frederickson (https://github.com/benfred) URL: https://github.com/rapidsai/raft/pull/2155 --- cpp/bench/prims/CMakeLists.txt | 2 +- cpp/bench/prims/matrix/gather.cu | 38 ++++- cpp/bench/prims/random/subsample.cu | 112 ++++++++++++++ cpp/include/raft/matrix/detail/gather.cuh | 87 +++++++++++ .../raft/matrix/detail/sample_rows.cuh | 57 +++++++ cpp/include/raft/matrix/sample_rows.cuh | 75 ++++++++++ cpp/include/raft/random/detail/rng_impl.cuh | 138 ++++++++++++++++- cpp/include/raft/random/rng.cuh | 26 ++++ cpp/test/CMakeLists.txt | 2 + cpp/test/matrix/sample_rows.cu | 140 ++++++++++++++++++ cpp/test/random/excess_sampling.cu | 114 ++++++++++++++ 11 files changed, 786 insertions(+), 5 deletions(-) create mode 100644 cpp/bench/prims/random/subsample.cu create mode 100644 cpp/include/raft/matrix/detail/sample_rows.cuh create mode 100644 cpp/include/raft/matrix/sample_rows.cuh create mode 100644 cpp/test/matrix/sample_rows.cu create mode 100644 cpp/test/random/excess_sampling.cu diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 903f4e4347..95361e19ca 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -128,7 +128,7 @@ if(BUILD_PRIMS_BENCH) ConfigureBench( NAME RANDOM_BENCH PATH bench/prims/random/make_blobs.cu bench/prims/random/permute.cu - bench/prims/random/rng.cu bench/prims/main.cpp + bench/prims/random/rng.cu bench/prims/random/subsample.cu bench/prims/main.cpp ) ConfigureBench(NAME SPARSE_BENCH PATH bench/prims/sparse/convert_csr.cu bench/prims/main.cpp) diff --git a/cpp/bench/prims/matrix/gather.cu b/cpp/bench/prims/matrix/gather.cu index e6f26ba925..078f9e6198 100644 --- a/cpp/bench/prims/matrix/gather.cu +++ b/cpp/bench/prims/matrix/gather.cu @@ -16,34 +16,48 @@ #include +#include +#include #include #include #include #include #include +#include namespace raft::bench::matrix { template struct GatherParams { IdxT rows, cols, map_length; + bool host; }; template inline auto operator<<(std::ostream& os, const GatherParams& p) -> std::ostream& { - os << p.rows << "#" << p.cols << "#" << p.map_length; + os << p.rows << "#" << p.cols << "#" << p.map_length << (p.host ? "#host" : "#device"); return os; } template struct Gather : public fixture { Gather(const GatherParams& p) - : params(p), matrix(this->handle), map(this->handle), out(this->handle), stencil(this->handle) + : params(p), + old_mr(rmm::mr::get_current_device_resource()), + pool_mr(rmm::mr::get_current_device_resource(), 2 * (1ULL << 30)), + matrix(this->handle), + map(this->handle), + out(this->handle), + stencil(this->handle), + matrix_h(this->handle) { + rmm::mr::set_current_device_resource(&pool_mr); } + ~Gather() { rmm::mr::set_current_device_resource(old_mr); } + void allocate_data(const ::benchmark::State& state) override { matrix = raft::make_device_matrix(handle, params.rows, params.cols); @@ -59,6 +73,11 @@ struct Gather : public fixture { if constexpr (Conditional) { raft::random::uniform(handle, rng, stencil.data_handle(), params.map_length, T(-1), T(1)); } + + if (params.host) { + matrix_h = raft::make_host_matrix(handle, params.rows, params.cols); + raft::copy(matrix_h.data_handle(), matrix.data_handle(), matrix.size(), stream); + } resource::sync_stream(handle, stream); } @@ -77,14 +96,22 @@ struct Gather : public fixture { raft::matrix::gather_if( handle, matrix_const_view, out.view(), map_const_view, stencil_const_view, pred_op); } else { - raft::matrix::gather(handle, matrix_const_view, map_const_view, out.view()); + if (params.host) { + raft::matrix::detail::gather( + handle, make_const_mdspan(matrix_h.view()), map_const_view, out.view()); + } else { + raft::matrix::gather(handle, matrix_const_view, map_const_view, out.view()); + } } }); } private: GatherParams params; + rmm::mr::device_memory_resource* old_mr; + rmm::mr::pool_memory_resource pool_mr; raft::device_matrix matrix, out; + raft::host_matrix matrix_h; raft::device_vector stencil; raft::device_vector map; }; // struct Gather @@ -100,4 +127,9 @@ RAFT_BENCH_REGISTER((Gather), "", gather_inputs_i64); RAFT_BENCH_REGISTER((Gather), "", gather_inputs_i64); RAFT_BENCH_REGISTER((GatherIf), "", gather_inputs_i64); RAFT_BENCH_REGISTER((GatherIf), "", gather_inputs_i64); + +auto inputs_host = raft::util::itertools::product>( + {10000000}, {100}, {1000, 1000000, 10000000}, {true}); +RAFT_BENCH_REGISTER((Gather), "Host", inputs_host); + } // namespace raft::bench::matrix diff --git a/cpp/bench/prims/random/subsample.cu b/cpp/bench/prims/random/subsample.cu new file mode 100644 index 0000000000..4c8ca2bf31 --- /dev/null +++ b/cpp/bench/prims/random/subsample.cu @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +namespace raft::bench::random { + +struct sample_inputs { + int n_samples; + int n_train; + int method; +}; // struct sample_inputs + +inline auto operator<<(std::ostream& os, const sample_inputs& p) -> std::ostream& +{ + os << p.n_samples << "#" << p.n_train << "#" << p.method; + return os; +} + +// Sample with replacement. We use this as a baseline. +template +auto bernoulli_subsample(raft::resources const& res, IdxT n_samples, IdxT n_subsamples, int seed) + -> raft::device_vector +{ + RAFT_EXPECTS(n_subsamples <= n_samples, "Cannot have more training samples than dataset vectors"); + + auto indices = raft::make_device_vector(res, n_subsamples); + raft::random::RngState state(123456ULL); + raft::random::uniformInt( + res, state, indices.data_handle(), n_subsamples, IdxT(0), IdxT(n_samples)); + return indices; +} + +template +struct sample : public fixture { + sample(const sample_inputs& p) + : params(p), + old_mr(rmm::mr::get_current_device_resource()), + pool_mr(rmm::mr::get_current_device_resource(), 2 * GiB), + in(make_device_vector(res, p.n_samples)), + out(make_device_vector(res, p.n_train)) + { + rmm::mr::set_current_device_resource(&pool_mr); + raft::random::RngState r(123456ULL); + } + + ~sample() { rmm::mr::set_current_device_resource(old_mr); } + void run_benchmark(::benchmark::State& state) override + { + std::ostringstream label_stream; + label_stream << params; + state.SetLabel(label_stream.str()); + + raft::random::RngState r(123456ULL); + loop_on_state(state, [this, &r]() { + if (params.method == 1) { + this->out = + bernoulli_subsample(this->res, this->params.n_samples, this->params.n_train, 137); + } else if (params.method == 2) { + this->out = raft::random::excess_subsample( + this->res, r, this->params.n_samples, this->params.n_train); + } + }); + } + + private: + float GiB = 1073741824.0f; + raft::device_resources res; + rmm::mr::device_memory_resource* old_mr; + rmm::mr::pool_memory_resource pool_mr; + sample_inputs params; + raft::device_vector out, in; +}; // struct sample + +const std::vector input_vecs = {{100000000, 10000000, 1}, + {100000000, 50000000, 1}, + {100000000, 100000000, 1}, + {100000000, 10000000, 2}, + {100000000, 50000000, 2}, + {100000000, 100000000, 2}}; + +RAFT_BENCH_REGISTER(sample, "", input_vecs); + +} // namespace raft::bench::random diff --git a/cpp/include/raft/matrix/detail/gather.cuh b/cpp/include/raft/matrix/detail/gather.cuh index 651fec81c3..05cc9204bf 100644 --- a/cpp/include/raft/matrix/detail/gather.cuh +++ b/cpp/include/raft/matrix/detail/gather.cuh @@ -16,9 +16,19 @@ #pragma once +#include +#include +#include +#include +#include #include +#include +#include +#include #include +#include + #include namespace raft { @@ -336,6 +346,83 @@ void gather_if(const InputIteratorT in, gatherImpl(in, D, N, map, stencil, map_length, out, pred_op, transform_op, stream); } +/** + * Helper function to gather a set of vectors from a (host) dataset. + */ +template +void gather_buff(host_matrix_view dataset, + host_vector_view indices, + MatIdxT offset, + pinned_matrix_view buff) +{ + raft::common::nvtx::range fun_scope("gather_host_buff"); + IdxT batch_size = std::min(buff.extent(0), indices.extent(0) - offset); + +#pragma omp for + for (IdxT i = 0; i < batch_size; i++) { + IdxT in_idx = indices(offset + i); + for (IdxT k = 0; k < buff.extent(1); k++) { + buff(i, k) = dataset(in_idx, k); + } + } +} + +template +void gather(raft::resources const& res, + host_matrix_view dataset, + device_vector_view indices, + raft::device_matrix_view output) +{ + raft::common::nvtx::range fun_scope("gather"); + IdxT n_dim = output.extent(1); + IdxT n_train = output.extent(0); + auto indices_host = raft::make_host_vector(n_train); + raft::copy( + indices_host.data_handle(), indices.data_handle(), n_train, resource::get_cuda_stream(res)); + resource::sync_stream(res); + + const size_t buffer_size = 32768 * 1024; // bytes + const size_t max_batch_size = + std::min(round_up_safe(buffer_size / n_dim, 32), n_train); + RAFT_LOG_DEBUG("Gathering data with batch size %zu", max_batch_size); + + // Gather the vector on the host in tmp buffers. We use two buffers to overlap H2D sync + // and gathering the data. + auto out_tmp1 = raft::make_pinned_matrix(res, max_batch_size, n_dim); + auto out_tmp2 = raft::make_pinned_matrix(res, max_batch_size, n_dim); + + // Usually a limited number of threads provide sufficient bandwidth for gathering data. + int n_threads = std::min(omp_get_max_threads(), 32); + + // The gather_buff function has a parallel for loop. We start the the omp parallel + // region here, to avoid repeated overhead within the device_offset loop. +#pragma omp parallel num_threads(n_threads) + { + auto view1 = out_tmp1.view(); + auto view2 = out_tmp2.view(); + gather_buff(dataset, make_const_mdspan(indices_host.view()), (MatIdxT)0, view1); + for (MatIdxT device_offset = 0; device_offset < n_train; device_offset += max_batch_size) { + MatIdxT batch_size = std::min(max_batch_size, n_train - device_offset); + +#pragma omp master + raft::copy(output.data_handle() + device_offset * n_dim, + view1.data_handle(), + batch_size * n_dim, + resource::get_cuda_stream(res)); + // Start gathering the next batch on the host. + MatIdxT host_offset = device_offset + batch_size; + batch_size = std::min(max_batch_size, n_train - host_offset); + if (batch_size > 0) { + gather_buff(dataset, make_const_mdspan(indices_host.view()), host_offset, view2); + } +#pragma omp master + resource::sync_stream(res); +#pragma omp barrier + std::swap(view1, view2); + } + } +} + } // namespace detail } // namespace matrix } // namespace raft diff --git a/cpp/include/raft/matrix/detail/sample_rows.cuh b/cpp/include/raft/matrix/detail/sample_rows.cuh new file mode 100644 index 0000000000..e28ad648da --- /dev/null +++ b/cpp/include/raft/matrix/detail/sample_rows.cuh @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::matrix::detail { + +/** Select rows randomly from input and copy to output. */ +template +void sample_rows(raft::resources const& res, + random::RngState random_state, + const T* input, + IdxT n_rows_input, + raft::device_matrix_view output) +{ + IdxT n_dim = output.extent(1); + IdxT n_samples = output.extent(0); + + raft::device_vector train_indices = + raft::random::excess_subsample(res, random_state, n_rows_input, n_samples); + + cudaPointerAttributes attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, input)); + T* ptr = reinterpret_cast(attr.devicePointer); + if (ptr != nullptr) { + raft::matrix::gather(res, + raft::make_device_matrix_view(ptr, n_rows_input, n_dim), + raft::make_const_mdspan(train_indices.view()), + output); + } else { + auto dataset = raft::make_host_matrix_view(input, n_rows_input, n_dim); + raft::matrix::detail::gather(res, dataset, make_const_mdspan(train_indices.view()), output); + } +} +} // namespace raft::matrix::detail diff --git a/cpp/include/raft/matrix/sample_rows.cuh b/cpp/include/raft/matrix/sample_rows.cuh new file mode 100644 index 0000000000..7925d344e4 --- /dev/null +++ b/cpp/include/raft/matrix/sample_rows.cuh @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace raft::matrix { + +/** @brief Select rows randomly from input and copy to output. + * + * The rows are selected randomly. The random sampling method does not guarantee completely unique + * selection of rows, but it is close to being unique. + * + * @param res RAFT resource handle + * @param random_state + * @param dataset input dataset + * @param output subsampled dataset + */ +template +void sample_rows(raft::resources const& res, + random::RngState random_state, + mdspan, row_major, accessor> dataset, + raft::device_matrix_view output) +{ + RAFT_EXPECTS(dataset.extent(1) == output.extent(1), + "dataset dims must match, but received %ld vs %ld", + static_cast(dataset.extent(1)), + static_cast(output.extent(1))); + detail::sample_rows(res, random_state, dataset.data_handle(), dataset.extent(0), output); +} + +/** @brief Select rows randomly from input and copy to output. + * + * The rows are selected randomly. The random sampling method does not guarantee completely unique + * selection of rows, but it is close to being unique. + * + * @param res RAFT resource handle + * @param random_state + * @param dataset input dataset + * @param n_samples number of rows in the returned matrix + * + * @return subsampled dataset + * */ +template +raft::device_matrix sample_rows( + raft::resources const& res, + random::RngState random_state, + mdspan, row_major, accessor> dataset, + IdxT n_samples) +{ + auto output = raft::make_device_matrix(res, n_samples, dataset.extent(1)); + sample_rows(res, random_state, dataset, output.view()); + return output; +} + +} // namespace raft::matrix diff --git a/cpp/include/raft/random/detail/rng_impl.cuh b/cpp/include/raft/random/detail/rng_impl.cuh index 57f4c8d33d..61a944e9b6 100644 --- a/cpp/include/raft/random/detail/rng_impl.cuh +++ b/cpp/include/raft/random/detail/rng_impl.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * Copyright (c) 2018-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,12 +17,20 @@ #pragma once #include +#include +#include +#include +#include #include #include #include #include #include +#include + +#include + namespace raft { namespace random { namespace detail { @@ -278,6 +286,7 @@ std::enable_if_t> discrete(RngState& rng_state, len); } +/** Note the memory space requirements are O(4*len) */ template void sampleWithoutReplacement(RngState& rng_state, DataT* out, @@ -328,6 +337,133 @@ void affine_transform_params(RngState const& rng_state, IdxT n, IdxT& a, IdxT& b b = mt_rng() % n; } +/** @brief Sample without replacement from range 0..N-1. + * + * Elements are sampled uniformly. + * The algorithm will allocate a workspace of size O(4*n_samples) internally. + * + * We use max N random numbers. Depending on how large n_samples is w.r.t to N, we + * either use rejection sampling, or sort the [0..N-1] values using random keys. + * + * @tparam IdxT type of indices that we sample + * @tparam MatIdxT extent type of the returned mdarray + * + * @param res RAFT resource handle + * @param state random number generator state + * @param N number of elements to sample from. We will sample values in range 0..N-1 + * @param n_samples number of samples to return + * + * @return device mdarray with the random samples + */ +template +auto excess_subsample(raft::resources const& res, RngState& state, IdxT N, IdxT n_samples) + -> raft::device_vector +{ + RAFT_EXPECTS(n_samples <= N, "Cannot have more training samples than dataset vectors"); + + // Number of samples we'll need to sample (with replacement), to expect 'k' + // unique samples from 'n' is given by the following equation: log(1 - k/n)/log(1 - 1/n) ref: + // https://stats.stackexchange.com/questions/296005/the-expected-number-of-unique-elements-drawn-with-replacement + IdxT n_excess_samples = + n_samples < N + ? std::ceil(raft::log(1 - double(n_samples) / double(N)) / (raft::log(1 - 1 / double(N)))) + : N; + + // There is a variance of n_excess_samples, we take 10% more elements. + n_excess_samples += std::max(0.1 * n_samples, 100); + + // n_excess_sampless will be larger than N around k = 0.64*N. When we reach N, then instead of + // doing rejection sampling, we simply shuffle the range [0..N-1] using N random numbers. + n_excess_samples = std::min(n_excess_samples, N); + auto rnd_idx = raft::make_device_vector(res, n_excess_samples); + + auto linear_idx = raft::make_device_vector(res, rnd_idx.size()); + raft::linalg::map_offset(res, linear_idx.view(), identity_op()); + + uniformInt(res, state, rnd_idx.data_handle(), rnd_idx.size(), IdxT(0), IdxT(N)); + + // Sort indices according to rnd keys + size_t workspace_size = 0; + auto stream = resource::get_cuda_stream(res); + cub::DeviceMergeSort::SortPairs(nullptr, + workspace_size, + rnd_idx.data_handle(), + linear_idx.data_handle(), + rnd_idx.size(), + raft::less_op{}, + stream); + auto workspace = raft::make_device_vector(res, workspace_size); + cub::DeviceMergeSort::SortPairs(workspace.data_handle(), + workspace_size, + rnd_idx.data_handle(), + linear_idx.data_handle(), + rnd_idx.size(), + raft::less_op{}, + stream); + + if (rnd_idx.size() == static_cast(N)) { + // We shuffled the linear_idx array by sorting it according to rnd_idx. + // We return the first n_samples elements. + if (n_samples == N) { return linear_idx; } + rnd_idx = raft::make_device_vector(res, n_samples); + raft::copy(rnd_idx.data_handle(), linear_idx.data_handle(), n_samples, stream); + return rnd_idx; + } + // Else we do a rejection sampling (or excess sampling): we generated more random indices than + // needed and reject the duplicates. + auto keys_out = raft::make_device_vector(res, rnd_idx.size()); + auto values_out = raft::make_device_vector(res, rnd_idx.size()); + rmm::device_scalar num_selected(stream); + size_t worksize2 = 0; + cub::DeviceSelect::UniqueByKey(nullptr, + worksize2, + rnd_idx.data_handle(), + linear_idx.data_handle(), + keys_out.data_handle(), + values_out.data_handle(), + num_selected.data(), + rnd_idx.size(), + stream); + + if (worksize2 > workspace.size()) { + workspace = raft::make_device_vector(res, worksize2); + workspace_size = workspace.size(); + } + + cub::DeviceSelect::UniqueByKey(workspace.data_handle(), + workspace_size, + rnd_idx.data_handle(), + linear_idx.data_handle(), + keys_out.data_handle(), + values_out.data_handle(), + num_selected.data(), + rnd_idx.size(), + stream); + + IdxT selected = num_selected.value(stream); + + if (selected < n_samples) { + RAFT_LOG_DEBUG("Subsampling returned with less unique indices (%zu) than requested (%zu)", + (size_t)selected, + (size_t)n_samples); + + // We continue to select n_samples elements, this will now contains a few duplicates. + } + + // After duplicates are removed, we need to shuffle back to random order + cub::DeviceMergeSort::SortPairs(workspace.data_handle(), + workspace_size, + values_out.data_handle(), + keys_out.data_handle(), + n_samples, + raft::less_op{}, + stream); + + values_out = raft::make_device_vector(res, n_samples); + raft::copy(values_out.data_handle(), keys_out.data_handle(), n_samples, stream); + return values_out; +} + }; // end namespace detail }; // end namespace random }; // end namespace raft diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index 4e63669f98..7fd461980f 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -813,6 +813,32 @@ void sampleWithoutReplacement(raft::resources const& handle, rng_state, out, outIdx, in, wts, sampledLen, len, resource::get_cuda_stream(handle)); } +/** @brief Sample from range 0..N-1. + * + * Elements are sampled uniformly. The method aims to sample without replacement, + * but there is a small probability of a few having duplicate elements. + * + * The algorithm will allocate a workspace of size 4*n_samples*sizeof(IdxT) internally. + * + * We use max N random numbers. Depending on how large n_samples is w.r.t to N, we + * either use rejection sampling, or sort the [0..N-1] values using random keys. + * + * @tparam IdxT type of indices that we sample + * @tparam MatIdxT extent type of the returned mdarray + * + * @param res RAFT resource handle + * @param state random number generator state + * @param N number of elements to sample from. We will sample values in range 0..N-1. + * @param n_samples number of samples to return + * + * @return device mdarray with the random samples + */ +template +auto excess_subsample(raft::resources const& res, RngState& state, IdxT N, IdxT n_samples) +{ + return detail::excess_subsample(res, state, N, n_samples); +} + /** * @brief Generates the 'a' and 'b' parameters for a modulo affine * transformation equation: `(ax + b) % n` diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index bf44cf9c60..ecb871fccc 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -267,6 +267,7 @@ if(BUILD_TESTS) test/matrix/matrix.cu test/matrix/norm.cu test/matrix/reverse.cu + test/matrix/sample_rows.cu test/matrix/slice.cu test/matrix/triangular.cu test/sparse/spectral_matrix.cu @@ -294,6 +295,7 @@ if(BUILD_TESTS) test/random/rng_int.cu test/random/rmat_rectangular_generator.cu test/random/sample_without_replacement.cu + test/random/excess_sampling.cu ) ConfigureTest( diff --git a/cpp/test/matrix/sample_rows.cu b/cpp/test/matrix/sample_rows.cu new file mode 100644 index 0000000000..e332a918fe --- /dev/null +++ b/cpp/test/matrix/sample_rows.cu @@ -0,0 +1,140 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +namespace raft { +namespace matrix { + +struct inputs { + int N; + int dim; + int n_samples; + bool host; +}; + +::std::ostream& operator<<(::std::ostream& os, const inputs p) +{ + os << p.N << "#" << p.dim << "#" << p.n_samples << (p.host ? "#host" : "#device"); + return os; +} + +template +class SampleRowsTest : public ::testing::TestWithParam { + public: + SampleRowsTest() + : params(::testing::TestWithParam::GetParam()), + stream(resource::get_cuda_stream(res)), + state{137ULL}, + in(make_device_matrix(res, params.N, params.dim)), + out(make_device_matrix(res, 0, 0)), + in_h(make_host_matrix(res, params.N, params.dim)), + out_h(make_host_matrix(res, params.n_samples, params.dim)) + { + raft::random::uniform(res, state, in.data_handle(), in.size(), T(-1.0), T(1.0)); + for (int64_t i = 0; i < params.N; i++) { + for (int64_t k = 0; k < params.dim; k++) + in_h(i, k) = i * 1000 + k; + } + raft::copy(in.data_handle(), in_h.data_handle(), in_h.size(), stream); + } + + void check() + { + if (params.host) { + out = raft::matrix::sample_rows( + res, state, make_const_mdspan(in_h.view()), (int64_t)params.n_samples); + } else { + out = raft::matrix::sample_rows( + res, state, make_const_mdspan(in.view()), (int64_t)params.n_samples); + } + + raft::copy(out_h.data_handle(), out.data_handle(), out.size(), stream); + resource::sync_stream(res, stream); + + ASSERT_TRUE(out.extent(0) == params.n_samples); + ASSERT_TRUE(out.extent(1) == params.dim); + + std::unordered_set occurrence; + + for (int64_t i = 0; i < params.n_samples; ++i) { + T val = out_h(i, 0) / 1000; + ASSERT_TRUE(0 <= val && val < params.N) + << "out-of-range index @i=" << i << " val=" << val << " params=" << params; + EXPECT_TRUE(occurrence.find(val) == occurrence.end()) + << "repeated index @i=" << i << " idx=" << val << " params=" << params; + occurrence.insert(val); + for (int64_t k = 0; k < params.dim; k++) { + ASSERT_TRUE(raft::match(out_h(i, k), val * 1000 + k, raft::CompareApprox(1e-6))); + } + } + } + + protected: + inputs params; + raft::resources res; + cudaStream_t stream; + random::RngState state; + device_matrix in, out; + host_matrix in_h, out_h; +}; + +inline std::vector generate_inputs() +{ + std::vector input1 = + raft::util::itertools::product({10}, {1, 17, 96}, {1, 6, 9, 10}, {false}); + + std::vector input2 = + raft::util::itertools::product({137}, {1, 17, 128}, {1, 10, 100, 137}, {false}); + input1.insert(input1.end(), input2.begin(), input2.end()); + + input2 = raft::util::itertools::product( + {100000}, {1, 42}, {1, 137, 1000, 10000, 50000, 62000, 100000}, {false}); + + input1.insert(input1.end(), input2.begin(), input2.end()); + + int n = input1.size(); + // Add same tests for host data + for (int i = 0; i < n; i++) { + inputs x = input1[i]; + x.host = true; + input1.push_back(x); + } + return input1; +} + +const std::vector inputs1 = generate_inputs(); + +using SampleRowsTestInt64 = SampleRowsTest; +TEST_P(SampleRowsTestInt64, SamplingTest) { check(); } +INSTANTIATE_TEST_SUITE_P(SampleRowsTests, SampleRowsTestInt64, ::testing::ValuesIn(inputs1)); + +} // namespace matrix +} // namespace raft diff --git a/cpp/test/random/excess_sampling.cu b/cpp/test/random/excess_sampling.cu new file mode 100644 index 0000000000..e86436fb7d --- /dev/null +++ b/cpp/test/random/excess_sampling.cu @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace raft { +namespace random { + +using namespace raft::random; + +struct inputs { + int64_t N; + int64_t n_samples; +}; + +template +::std::ostream& operator<<(::std::ostream& os, const inputs p) +{ + os << p.N << "/" << p.n_samples; + return os; +} + +template +class ExcessSamplingTest : public ::testing::TestWithParam { + public: + ExcessSamplingTest() + : params(::testing::TestWithParam::GetParam()), + stream(resource::get_cuda_stream(res)), + state{137ULL} + { + } + + void check() + { + device_vector out = + raft::random::excess_subsample(res, state, params.N, params.n_samples); + ASSERT_TRUE(out.extent(0) == params.n_samples); + + auto h_out = make_host_vector(res, params.n_samples); + raft::copy(h_out.data_handle(), out.data_handle(), out.size(), stream); + resource::sync_stream(res, stream); + + std::unordered_set occurrence; + int64_t sum = 0; + for (int64_t i = 0; i < params.n_samples; ++i) { + T val = h_out(i); + sum += val; + ASSERT_TRUE(0 <= val && val < params.N) + << "out-of-range index @i=" << i << " val=" << val << " n_samples=" << params.n_samples; + ASSERT_TRUE(occurrence.find(val) == occurrence.end()) + << "repeated index @i=" << i << " idx=" << val; + occurrence.insert(val); + } + float avg = sum / (float)params.n_samples; + if (params.n_samples >= 100 && params.N / params.n_samples < 100) { + ASSERT_TRUE(raft::match(avg, (params.N - 1) / 2.0f, raft::CompareApprox(0.2))) + << "non-uniform sample"; + } + } + + protected: + inputs params; + raft::resources res; + cudaStream_t stream; + RngState state; +}; + +const std::vector input1 = {{1, 0}, + {1, 1}, + {10, 0}, + {10, 1}, + {10, 2}, + {10, 10}, + {137, 42}, + {200, 0}, + {200, 1}, + {200, 100}, + {200, 130}, + {200, 200}, + {10000, 893}, + {10000000000, 1023}}; + +using ExcessSamplingTestInt64 = ExcessSamplingTest; +TEST_P(ExcessSamplingTestInt64, SamplingTest) { check(); } +INSTANTIATE_TEST_SUITE_P(ExcessSamplingTests, ExcessSamplingTestInt64, ::testing::ValuesIn(input1)); + +} // namespace random +} // namespace raft