From b203e3f0e25058e47ceae76b2a13baf85c9ef8b4 Mon Sep 17 00:00:00 2001 From: Micka Date: Fri, 15 Mar 2024 17:25:54 +0100 Subject: [PATCH 01/18] Add `compile-library` by default on pylibraft build (#2090) This will avoid confusion for users launching only `./build.sh pylibraft`. Authors: - Micka (https://github.com/lowener) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Peter Andreas Entschev (https://github.com/pentschev) - Corey J. Nolet (https://github.com/cjnolet) - Robert Maynard (https://github.com/robertmaynard) - Ray Douglass (https://github.com/raydouglass) URL: https://github.com/rapidsai/raft/pull/2090 --- build.sh | 4 ++-- conda/recipes/libraft/build_libraft_template.sh | 4 ++-- cpp/template/README.md | 2 +- cpp/template/build.sh | 2 +- cpp/template/cmake/thirdparty/get_raft.cmake | 9 +++++++-- 5 files changed, 13 insertions(+), 8 deletions(-) diff --git a/build.sh b/build.sh index 148d23c9c1..45c7d1380f 100755 --- a/build.sh +++ b/build.sh @@ -305,7 +305,7 @@ if hasArg --allgpuarch; then BUILD_ALL_GPU_ARCH=1 fi -if hasArg --compile-lib || (( ${NUMARGS} == 0 )); then +if hasArg --compile-lib || hasArg pylibraft || (( ${NUMARGS} == 0 )); then COMPILE_LIBRARY=ON CMAKE_TARGET="${CMAKE_TARGET};raft_lib" fi @@ -405,7 +405,7 @@ fi ################################################################################ # Configure for building all C++ targets -if (( ${NUMARGS} == 0 )) || hasArg libraft || hasArg docs || hasArg tests || hasArg bench-prims || hasArg bench-ann; then +if (( ${NUMARGS} == 0 )) || hasArg libraft || hasArg docs || hasArg tests || hasArg bench-prims || hasArg bench-ann || ((${COMPILE_LIBRARY} == ON )); then if (( ${BUILD_ALL_GPU_ARCH} == 0 )); then RAFT_CMAKE_CUDA_ARCHITECTURES="NATIVE" echo "Building for the architecture of the GPU in the system..." diff --git a/conda/recipes/libraft/build_libraft_template.sh b/conda/recipes/libraft/build_libraft_template.sh index bd7719af76..86c0fa11b6 100644 --- a/conda/recipes/libraft/build_libraft_template.sh +++ b/conda/recipes/libraft/build_libraft_template.sh @@ -1,5 +1,5 @@ #!/usr/bin/env bash -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # Just building template so we verify it uses libraft.so and fail if it doesn't build -./build.sh template +./build.sh template --no-nvtx diff --git a/cpp/template/README.md b/cpp/template/README.md index 348dff270a..05ec48964f 100644 --- a/cpp/template/README.md +++ b/cpp/template/README.md @@ -8,7 +8,7 @@ Once the minimum requirements are satisfied, this example template application c This directory (`RAFT_SOURCE/cpp/template`) can be copied directly in order to build a new application with RAFT. -RAFT can be integrated into an existing CMake project by copying the contents in the `configure rapids-cmake` and `configure raft` sections of the provided `CMakeLists.txt` into your project, along with `cmake/thirdparty/get_raft.cmake`. +RAFT can be integrated into an existing CMake project by copying the contents in the `configure rapids-cmake` and `configure raft` sections of the provided `CMakeLists.txt` into your project, along with `cmake/thirdparty/get_raft.cmake`. Make sure to link against the appropriate Cmake targets. Use `raft::raft`to add make the headers available and `raft::compiled` when utilizing the shared library. diff --git a/cpp/template/build.sh b/cpp/template/build.sh index 3ac00fc9af..49c17f7499 100755 --- a/cpp/template/build.sh +++ b/cpp/template/build.sh @@ -1,6 +1,6 @@ #!/bin/bash -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # raft empty project template build script diff --git a/cpp/template/cmake/thirdparty/get_raft.cmake b/cpp/template/cmake/thirdparty/get_raft.cmake index 6128b5c43c..07b0897be0 100644 --- a/cpp/template/cmake/thirdparty/get_raft.cmake +++ b/cpp/template/cmake/thirdparty/get_raft.cmake @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -33,6 +33,12 @@ function(find_and_configure_raft) #----------------------------------------------------- # Invoke CPM find_package() #----------------------------------------------------- + # Since the RAFT_NVTX option is used by targets generated by + # find_package(RAFT_NVTX) and when building from source we want to + # make `RAFT_NVTX` a cache variable so we get consistent + # behavior + # + set(RAFT_NVTX ${PKG_ENABLE_NVTX} CACHE BOOL "Enable raft nvtx logging" FORCE) rapids_cpm_find(raft ${PKG_VERSION} GLOBAL_TARGETS raft::raft BUILD_EXPORT_SET raft-template-exports @@ -46,7 +52,6 @@ function(find_and_configure_raft) "BUILD_TESTS OFF" "BUILD_PRIMS_BENCH OFF" "BUILD_ANN_BENCH OFF" - "RAFT_NVTX ${ENABLE_NVTX}" "RAFT_COMPILE_LIBRARY ${PKG_COMPILE_LIBRARY}" ) endfunction() From 36484f48fe5dd60939fb7e0610d9c5091c0780b2 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Fri, 15 Mar 2024 14:25:10 -0700 Subject: [PATCH 02/18] MAINT: Simplify NCCL worker rank identification (#2228) This PR is based on @seberg work in https://github.com/rapidsai/raft/pull/1928 . From the PR: This is a follow up on https://github.com/rapidsai/raft/pull/1926, since the rank sorting seemed a bit hard to understand. It does modify the logic in the sense that the host is now sorted by IP as a way to group based on it. But I don't really think that host sorting was ever a goal? If the goal is really about being deterministic, then this should be more (or at least clearer) deterministic about order of worker IPs. OTOH, if the NVML device order doesn't matter, we could just sort the workers directly. The original https://github.com/rapidsai/raft/pull/1587 mentions: NCCL>1.11 expects a process with rank r to be mapped to r % num_gpus_per_node which is something that neither approach seems to quite assure, if such a requirement exists, I would want to do one of: Ensure we can guarantee this, but this requires initializing workers that are not involved in the operation. At least raise an error, because if NCCL will end up raising the error it will be very confusing. Authors: - Vibhu Jawa (https://github.com/VibhuJawa) - Sebastian Berg (https://github.com/seberg) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2228 --- python/raft-dask/raft_dask/common/comms.py | 95 +++---------------- python/raft-dask/raft_dask/test/test_comms.py | 15 ++- 2 files changed, 27 insertions(+), 83 deletions(-) diff --git a/python/raft-dask/raft_dask/common/comms.py b/python/raft-dask/raft_dask/common/comms.py index 118293c093..b2f7d1fb74 100644 --- a/python/raft-dask/raft_dask/common/comms.py +++ b/python/raft-dask/raft_dask/common/comms.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,8 +18,7 @@ import time import uuid import warnings -from collections import Counter, OrderedDict, defaultdict -from typing import Dict +from collections import OrderedDict from dask.distributed import default_client from dask_cuda.utils import nvml_device_index @@ -691,9 +690,11 @@ def _func_ucp_ports(client, workers): def _func_worker_ranks(client, workers): """ - For each worker connected to the client, - compute a global rank which is the sum - of the NVML device index and the worker rank offset. + For each worker connected to the client, compute a global rank which takes + into account the NVML device index and the worker IP + (group workers on same host and order by NVML device). + Note that the reason for sorting was nvbug 4149999 and is presumably + fixed afterNCCL 2.19.3. Parameters ---------- @@ -703,13 +704,13 @@ def _func_worker_ranks(client, workers): # TODO: Add Test this function # Running into build issues preventing testing nvml_device_index_d = client.run(_get_nvml_device_index, workers=workers) - worker_ips = [ - _get_worker_ip(worker_address) - for worker_address in nvml_device_index_d + # Sort workers first by IP and then by the nvml device index: + worker_info_list = [ + (_get_worker_ip(worker), nvml_device_index, worker) + for worker, nvml_device_index in nvml_device_index_d.items() ] - ranks = _map_nvml_device_id_to_contiguous_range(nvml_device_index_d) - worker_ip_offset_dict = _get_rank_offset_across_nodes(worker_ips) - return _append_rank_offset(ranks, worker_ip_offset_dict) + worker_info_list.sort() + return {wi[2]: i for i, wi in enumerate(worker_info_list)} def _get_nvml_device_index(): @@ -730,73 +731,3 @@ def _get_worker_ip(worker_address): worker_address (str): Full address string of the worker """ return ":".join(worker_address.split(":")[0:2]) - - -def _map_nvml_device_id_to_contiguous_range(nvml_device_index_d: dict) -> dict: - """ - For each worker address in nvml_device_index_d, map the corresponding - worker rank in the range(0, num_workers_per_node) where rank is decided - by the NVML device index. Worker with the lowest NVML device index gets - rank 0, and worker with the highest NVML device index gets rank - num_workers_per_node-1. - - Parameters - ---------- - nvml_device_index_d : dict - Dictionary of worker addresses mapped to their nvml device index. - - Returns - ------- - dict - Updated dictionary with worker addresses mapped to their rank. - """ - - rank_per_ip: Dict[str, int] = defaultdict(int) - - # Sort by NVML index to ensure that the worker - # with the lowest NVML index gets rank 0. - for worker, _ in sorted(nvml_device_index_d.items(), key=lambda x: x[1]): - ip = _get_worker_ip(worker) - - nvml_device_index_d[worker] = rank_per_ip[ip] - rank_per_ip[ip] += 1 - - return nvml_device_index_d - - -def _get_rank_offset_across_nodes(worker_ips): - """ - Get a dictionary of worker IP addresses mapped to the cumulative count of - their occurrences in the worker_ips list. The cumulative count serves as - the rank offset. - - Parameters - ---------- - worker_ips (list): List of worker IP addresses. - """ - worker_count_dict = Counter(worker_ips) - worker_offset_dict = {} - current_offset = 0 - for worker_ip, worker_count in worker_count_dict.items(): - worker_offset_dict[worker_ip] = current_offset - current_offset += worker_count - return worker_offset_dict - - -def _append_rank_offset(rank_dict, worker_ip_offset_dict): - """ - For each worker address in the rank dictionary, add the - corresponding worker offset from the worker_ip_offset_dict - to the rank value. - - Parameters - ---------- - rank_dict (dict): Dictionary of worker addresses mapped to their ranks. - worker_ip_offset_dict (dict): Dictionary of worker IP addresses - mapped to their offsets. - """ - for worker_ip, worker_offset in worker_ip_offset_dict.items(): - for worker_address in rank_dict: - if worker_ip in worker_address: - rank_dict[worker_address] += worker_offset - return rank_dict diff --git a/python/raft-dask/raft_dask/test/test_comms.py b/python/raft-dask/raft_dask/test/test_comms.py index 68c9fee556..b62d7185b2 100644 --- a/python/raft-dask/raft_dask/test/test_comms.py +++ b/python/raft-dask/raft_dask/test/test_comms.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -354,3 +354,16 @@ def test_device_multicast_sendrecv(n_trials, client): wait(dfs, timeout=5) assert list(map(lambda x: x.result(), dfs)) + + +@pytest.mark.nccl +@pytest.mark.parametrize( + "subset", [slice(-1, None), slice(1), slice(None, None, -2)] +) +def test_comm_init_worker_subset(client, subset): + # Basic test that initializing a subset of workers is fine + cb = Comms(comms_p2p=True, verbose=True) + + workers = list(client.scheduler_info()["workers"].keys()) + workers = workers[subset] + cb.init(workers=workers) From 7335267e7991d5eacef3d445968108a95d0f800a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Malte=20F=C3=B6rster?= <97973773+mfoerste4@users.noreply.github.com> Date: Sat, 16 Mar 2024 05:25:48 +0100 Subject: [PATCH 03/18] Fix illegal acces mean/stdev, sum add Kahan Summation (#2223) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR addresses #2204 and #2205. * fixes illegal access / test coverage for mean row-wise kernel * fixes illegal access / test coverage for stdev row-wise kernel * modified sum kernels to utilize Kahan/Neumaier summation per thread, also increase load per thread to benefit from this FYI, @tfeher Authors: - Malte Förster (https://github.com/mfoerste4) Approvers: - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/raft/pull/2223 --- cpp/include/raft/stats/detail/mean.cuh | 4 +- cpp/include/raft/stats/detail/stddev.cuh | 4 +- cpp/include/raft/stats/detail/sum.cuh | 78 +++++++++++++++++---- cpp/test/stats/mean.cu | 66 +++++++++--------- cpp/test/stats/minmax.cu | 66 +++++++----------- cpp/test/stats/stddev.cu | 55 ++++++++++++--- cpp/test/stats/sum.cu | 89 +++++++++++++++++------- 7 files changed, 236 insertions(+), 126 deletions(-) diff --git a/cpp/include/raft/stats/detail/mean.cuh b/cpp/include/raft/stats/detail/mean.cuh index cf4dbc7aa3..6c330acb26 100644 --- a/cpp/include/raft/stats/detail/mean.cuh +++ b/cpp/include/raft/stats/detail/mean.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,7 +43,7 @@ RAFT_KERNEL meanKernelRowMajor(Type* mu, const Type* data, IdxType D, IdxType N) __syncthreads(); raft::myAtomicAdd(smu + thisColId, thread_data); __syncthreads(); - if (threadIdx.x < ColsPerBlk) raft::myAtomicAdd(mu + colId, smu[thisColId]); + if (threadIdx.x < ColsPerBlk && colId < D) raft::myAtomicAdd(mu + colId, smu[thisColId]); } template diff --git a/cpp/include/raft/stats/detail/stddev.cuh b/cpp/include/raft/stats/detail/stddev.cuh index acee4a944e..bc2644a233 100644 --- a/cpp/include/raft/stats/detail/stddev.cuh +++ b/cpp/include/raft/stats/detail/stddev.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -45,7 +45,7 @@ RAFT_KERNEL stddevKernelRowMajor(Type* std, const Type* data, IdxType D, IdxType __syncthreads(); raft::myAtomicAdd(sstd + thisColId, thread_data); __syncthreads(); - if (threadIdx.x < ColsPerBlk) raft::myAtomicAdd(std + colId, sstd[thisColId]); + if (threadIdx.x < ColsPerBlk && colId < D) raft::myAtomicAdd(std + colId, sstd[thisColId]); } template diff --git a/cpp/include/raft/stats/detail/sum.cuh b/cpp/include/raft/stats/detail/sum.cuh index bb45eb50f4..4f85536e6c 100644 --- a/cpp/include/raft/stats/detail/sum.cuh +++ b/cpp/include/raft/stats/detail/sum.cuh @@ -34,30 +34,72 @@ RAFT_KERNEL sumKernelRowMajor(Type* mu, const Type* data, IdxType D, IdxType N) IdxType thisRowId = threadIdx.x / ColsPerBlk; IdxType colId = thisColId + ((IdxType)blockIdx.y * ColsPerBlk); IdxType rowId = thisRowId + ((IdxType)blockIdx.x * RowsPerBlkPerIter); - Type thread_data = Type(0); + Type thread_sum = Type(0); const IdxType stride = RowsPerBlkPerIter * gridDim.x; - for (IdxType i = rowId; i < N; i += stride) - thread_data += (colId < D) ? data[i * D + colId] : Type(0); + for (IdxType i = rowId; i < N; i += stride) { + thread_sum += (colId < D) ? data[i * D + colId] : Type(0); + } __shared__ Type smu[ColsPerBlk]; if (threadIdx.x < ColsPerBlk) smu[threadIdx.x] = Type(0); __syncthreads(); - raft::myAtomicAdd(smu + thisColId, thread_data); + raft::myAtomicAdd(smu + thisColId, thread_sum); + __syncthreads(); + if (threadIdx.x < ColsPerBlk && colId < D) raft::myAtomicAdd(mu + colId, smu[thisColId]); +} + +template +RAFT_KERNEL sumKahanKernelRowMajor(Type* mu, const Type* data, IdxType D, IdxType N) +{ + constexpr int RowsPerBlkPerIter = TPB / ColsPerBlk; + IdxType thisColId = threadIdx.x % ColsPerBlk; + IdxType thisRowId = threadIdx.x / ColsPerBlk; + IdxType colId = thisColId + ((IdxType)blockIdx.y * ColsPerBlk); + IdxType rowId = thisRowId + ((IdxType)blockIdx.x * RowsPerBlkPerIter); + Type thread_sum = Type(0); + Type thread_c = Type(0); + const IdxType stride = RowsPerBlkPerIter * gridDim.x; + for (IdxType i = rowId; i < N; i += stride) { + // KahanBabushkaNeumaierSum + const Type cur_value = (colId < D) ? data[i * D + colId] : Type(0); + const Type t = thread_sum + cur_value; + if (abs(thread_sum) >= abs(cur_value)) { + thread_c += (thread_sum - t) + cur_value; + } else { + thread_c += (cur_value - t) + thread_sum; + } + thread_sum = t; + } + thread_sum += thread_c; + __shared__ Type smu[ColsPerBlk]; + if (threadIdx.x < ColsPerBlk) smu[threadIdx.x] = Type(0); + __syncthreads(); + raft::myAtomicAdd(smu + thisColId, thread_sum); __syncthreads(); if (threadIdx.x < ColsPerBlk && colId < D) raft::myAtomicAdd(mu + colId, smu[thisColId]); } template -RAFT_KERNEL sumKernelColMajor(Type* mu, const Type* data, IdxType D, IdxType N) +RAFT_KERNEL sumKahanKernelColMajor(Type* mu, const Type* data, IdxType D, IdxType N) { typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; - Type thread_data = Type(0); + Type thread_sum = Type(0); + Type thread_c = Type(0); IdxType colStart = N * blockIdx.x; for (IdxType i = threadIdx.x; i < N; i += TPB) { - IdxType idx = colStart + i; - thread_data += data[idx]; + // KahanBabushkaNeumaierSum + IdxType idx = colStart + i; + const Type cur_value = data[idx]; + const Type t = thread_sum + cur_value; + if (abs(thread_sum) >= abs(cur_value)) { + thread_c += (thread_sum - t) + cur_value; + } else { + thread_c += (cur_value - t) + thread_sum; + } + thread_sum = t; } - Type acc = BlockReduce(temp_storage).Sum(thread_data); + thread_sum += thread_c; + Type acc = BlockReduce(temp_storage).Sum(thread_sum); if (threadIdx.x == 0) { mu[blockIdx.x] = acc; } } @@ -66,15 +108,21 @@ void sum(Type* output, const Type* input, IdxType D, IdxType N, bool rowMajor, c { static const int TPB = 256; if (rowMajor) { - static const int RowsPerThread = 4; - static const int ColsPerBlk = 32; - static const int RowsPerBlk = (TPB / ColsPerBlk) * RowsPerThread; - dim3 grid(raft::ceildiv(N, (IdxType)RowsPerBlk), raft::ceildiv(D, (IdxType)ColsPerBlk)); + static const int ColsPerBlk = 8; + static const int MinRowsPerThread = 16; + static const int MinRowsPerBlk = (TPB / ColsPerBlk) * MinRowsPerThread; + static const int MaxBlocksDimX = 8192; + + const IdxType grid_y = raft::ceildiv(D, (IdxType)ColsPerBlk); + const IdxType grid_x = + raft::min((IdxType)MaxBlocksDimX, raft::ceildiv(N, (IdxType)MinRowsPerBlk)); + + dim3 grid(grid_x, grid_y); RAFT_CUDA_TRY(cudaMemset(output, 0, sizeof(Type) * D)); - sumKernelRowMajor + sumKahanKernelRowMajor <<>>(output, input, D, N); } else { - sumKernelColMajor<<>>(output, input, D, N); + sumKahanKernelColMajor<<>>(output, input, D, N); } RAFT_CUDA_TRY(cudaPeekAtLastError()); } diff --git a/cpp/test/stats/mean.cu b/cpp/test/stats/mean.cu index 67931e2eed..61b57ce739 100644 --- a/cpp/test/stats/mean.cu +++ b/cpp/test/stats/mean.cu @@ -95,39 +95,39 @@ class MeanTest : public ::testing::TestWithParam> { // Note: For 1024 samples, 256 experiments, a mean of 1.0 with stddev=1.0, the // measured mean (of a normal distribution) will fall outside of an epsilon of // 0.15 only 4/10000 times. (epsilon of 0.1 will fail 30/100 times) -const std::vector> inputsf = {{0.15f, 1.f, 1024, 32, true, false, 1234ULL}, - {0.15f, 1.f, 1024, 64, true, false, 1234ULL}, - {0.15f, 1.f, 1024, 128, true, false, 1234ULL}, - {0.15f, 1.f, 1024, 256, true, false, 1234ULL}, - {0.15f, -1.f, 1024, 32, false, false, 1234ULL}, - {0.15f, -1.f, 1024, 64, false, false, 1234ULL}, - {0.15f, -1.f, 1024, 128, false, false, 1234ULL}, - {0.15f, -1.f, 1024, 256, false, false, 1234ULL}, - {0.15f, 1.f, 1024, 32, true, true, 1234ULL}, - {0.15f, 1.f, 1024, 64, true, true, 1234ULL}, - {0.15f, 1.f, 1024, 128, true, true, 1234ULL}, - {0.15f, 1.f, 1024, 256, true, true, 1234ULL}, - {0.15f, -1.f, 1024, 32, false, true, 1234ULL}, - {0.15f, -1.f, 1024, 64, false, true, 1234ULL}, - {0.15f, -1.f, 1024, 128, false, true, 1234ULL}, - {0.15f, -1.f, 1024, 256, false, true, 1234ULL}}; - -const std::vector> inputsd = {{0.15, 1.0, 1024, 32, true, false, 1234ULL}, - {0.15, 1.0, 1024, 64, true, false, 1234ULL}, - {0.15, 1.0, 1024, 128, true, false, 1234ULL}, - {0.15, 1.0, 1024, 256, true, false, 1234ULL}, - {0.15, -1.0, 1024, 32, false, false, 1234ULL}, - {0.15, -1.0, 1024, 64, false, false, 1234ULL}, - {0.15, -1.0, 1024, 128, false, false, 1234ULL}, - {0.15, -1.0, 1024, 256, false, false, 1234ULL}, - {0.15, 1.0, 1024, 32, true, true, 1234ULL}, - {0.15, 1.0, 1024, 64, true, true, 1234ULL}, - {0.15, 1.0, 1024, 128, true, true, 1234ULL}, - {0.15, 1.0, 1024, 256, true, true, 1234ULL}, - {0.15, -1.0, 1024, 32, false, true, 1234ULL}, - {0.15, -1.0, 1024, 64, false, true, 1234ULL}, - {0.15, -1.0, 1024, 128, false, true, 1234ULL}, - {0.15, -1.0, 1024, 256, false, true, 1234ULL}}; +const std::vector> inputsf = { + {0.15f, 1.f, 1024, 32, true, false, 1234ULL}, {0.15f, 1.f, 1024, 64, true, false, 1234ULL}, + {0.15f, 1.f, 1024, 128, true, false, 1234ULL}, {0.15f, 1.f, 1024, 256, true, false, 1234ULL}, + {0.15f, -1.f, 1024, 32, false, false, 1234ULL}, {0.15f, -1.f, 1024, 64, false, false, 1234ULL}, + {0.15f, -1.f, 1024, 128, false, false, 1234ULL}, {0.15f, -1.f, 1024, 256, false, false, 1234ULL}, + {0.15f, 1.f, 1024, 32, true, true, 1234ULL}, {0.15f, 1.f, 1024, 64, true, true, 1234ULL}, + {0.15f, 1.f, 1024, 128, true, true, 1234ULL}, {0.15f, 1.f, 1024, 256, true, true, 1234ULL}, + {0.15f, -1.f, 1024, 32, false, true, 1234ULL}, {0.15f, -1.f, 1024, 64, false, true, 1234ULL}, + {0.15f, -1.f, 1024, 128, false, true, 1234ULL}, {0.15f, -1.f, 1024, 256, false, true, 1234ULL}, + {0.15f, -1.f, 1030, 1, false, false, 1234ULL}, {0.15f, -1.f, 1030, 60, true, false, 1234ULL}, + {2.0f, -1.f, 31, 120, false, false, 1234ULL}, {2.0f, -1.f, 1, 130, true, false, 1234ULL}, + {0.15f, -1.f, 1030, 1, false, true, 1234ULL}, {0.15f, -1.f, 1030, 60, true, true, 1234ULL}, + {2.0f, -1.f, 31, 120, false, true, 1234ULL}, {2.0f, -1.f, 1, 130, false, true, 1234ULL}, + {2.0f, -1.f, 1, 1, false, false, 1234ULL}, {2.0f, -1.f, 1, 1, false, true, 1234ULL}, + {2.0f, -1.f, 7, 23, false, false, 1234ULL}, {2.0f, -1.f, 7, 23, false, true, 1234ULL}, + {2.0f, -1.f, 17, 5, false, false, 1234ULL}, {2.0f, -1.f, 17, 5, false, true, 1234ULL}}; + +const std::vector> inputsd = { + {0.15, 1.0, 1024, 32, true, false, 1234ULL}, {0.15, 1.0, 1024, 64, true, false, 1234ULL}, + {0.15, 1.0, 1024, 128, true, false, 1234ULL}, {0.15, 1.0, 1024, 256, true, false, 1234ULL}, + {0.15, -1.0, 1024, 32, false, false, 1234ULL}, {0.15, -1.0, 1024, 64, false, false, 1234ULL}, + {0.15, -1.0, 1024, 128, false, false, 1234ULL}, {0.15, -1.0, 1024, 256, false, false, 1234ULL}, + {0.15, 1.0, 1024, 32, true, true, 1234ULL}, {0.15, 1.0, 1024, 64, true, true, 1234ULL}, + {0.15, 1.0, 1024, 128, true, true, 1234ULL}, {0.15, 1.0, 1024, 256, true, true, 1234ULL}, + {0.15, -1.0, 1024, 32, false, true, 1234ULL}, {0.15, -1.0, 1024, 64, false, true, 1234ULL}, + {0.15, -1.0, 1024, 128, false, true, 1234ULL}, {0.15, -1.0, 1024, 256, false, true, 1234ULL}, + {0.15, -1.0, 1030, 1, false, false, 1234ULL}, {0.15, -1.0, 1030, 60, true, false, 1234ULL}, + {2.0, -1.0, 31, 120, false, false, 1234ULL}, {2.0, -1.0, 1, 130, true, false, 1234ULL}, + {0.15, -1.0, 1030, 1, false, true, 1234ULL}, {0.15, -1.0, 1030, 60, true, true, 1234ULL}, + {2.0, -1.0, 31, 120, false, true, 1234ULL}, {2.0, -1.0, 1, 130, false, true, 1234ULL}, + {2.0, -1.0, 1, 1, false, false, 1234ULL}, {2.0, -1.0, 1, 1, false, true, 1234ULL}, + {2.0, -1.0, 7, 23, false, false, 1234ULL}, {2.0, -1.0, 7, 23, false, true, 1234ULL}, + {2.0, -1.0, 17, 5, false, false, 1234ULL}, {2.0, -1.0, 17, 5, false, true, 1234ULL}}; typedef MeanTest MeanTestF; TEST_P(MeanTestF, Result) diff --git a/cpp/test/stats/minmax.cu b/cpp/test/stats/minmax.cu index 7563cb12be..fd909ebb90 100644 --- a/cpp/test/stats/minmax.cu +++ b/cpp/test/stats/minmax.cu @@ -145,45 +145,33 @@ class MinMaxTest : public ::testing::TestWithParam> { rmm::device_uvector minmax_ref; }; -const std::vector> inputsf = {{0.00001f, 1024, 32, 1234ULL}, - {0.00001f, 1024, 64, 1234ULL}, - {0.00001f, 1024, 128, 1234ULL}, - {0.00001f, 1024, 256, 1234ULL}, - {0.00001f, 1024, 512, 1234ULL}, - {0.00001f, 1024, 1024, 1234ULL}, - {0.00001f, 4096, 32, 1234ULL}, - {0.00001f, 4096, 64, 1234ULL}, - {0.00001f, 4096, 128, 1234ULL}, - {0.00001f, 4096, 256, 1234ULL}, - {0.00001f, 4096, 512, 1234ULL}, - {0.00001f, 4096, 1024, 1234ULL}, - {0.00001f, 8192, 32, 1234ULL}, - {0.00001f, 8192, 64, 1234ULL}, - {0.00001f, 8192, 128, 1234ULL}, - {0.00001f, 8192, 256, 1234ULL}, - {0.00001f, 8192, 512, 1234ULL}, - {0.00001f, 8192, 1024, 1234ULL}, - {0.00001f, 1024, 8192, 1234ULL}}; - -const std::vector> inputsd = {{0.0000001, 1024, 32, 1234ULL}, - {0.0000001, 1024, 64, 1234ULL}, - {0.0000001, 1024, 128, 1234ULL}, - {0.0000001, 1024, 256, 1234ULL}, - {0.0000001, 1024, 512, 1234ULL}, - {0.0000001, 1024, 1024, 1234ULL}, - {0.0000001, 4096, 32, 1234ULL}, - {0.0000001, 4096, 64, 1234ULL}, - {0.0000001, 4096, 128, 1234ULL}, - {0.0000001, 4096, 256, 1234ULL}, - {0.0000001, 4096, 512, 1234ULL}, - {0.0000001, 4096, 1024, 1234ULL}, - {0.0000001, 8192, 32, 1234ULL}, - {0.0000001, 8192, 64, 1234ULL}, - {0.0000001, 8192, 128, 1234ULL}, - {0.0000001, 8192, 256, 1234ULL}, - {0.0000001, 8192, 512, 1234ULL}, - {0.0000001, 8192, 1024, 1234ULL}, - {0.0000001, 1024, 8192, 1234ULL}}; +const std::vector> inputsf = { + {0.00001f, 1024, 32, 1234ULL}, {0.00001f, 1024, 64, 1234ULL}, {0.00001f, 1024, 128, 1234ULL}, + {0.00001f, 1024, 256, 1234ULL}, {0.00001f, 1024, 512, 1234ULL}, {0.00001f, 1024, 1024, 1234ULL}, + {0.00001f, 4096, 32, 1234ULL}, {0.00001f, 4096, 64, 1234ULL}, {0.00001f, 4096, 128, 1234ULL}, + {0.00001f, 4096, 256, 1234ULL}, {0.00001f, 4096, 512, 1234ULL}, {0.00001f, 4096, 1024, 1234ULL}, + {0.00001f, 8192, 32, 1234ULL}, {0.00001f, 8192, 64, 1234ULL}, {0.00001f, 8192, 128, 1234ULL}, + {0.00001f, 8192, 256, 1234ULL}, {0.00001f, 8192, 512, 1234ULL}, {0.00001f, 8192, 1024, 1234ULL}, + {0.00001f, 1024, 8192, 1234ULL}, {0.00001f, 1023, 5, 1234ULL}, {0.00001f, 1025, 30, 1234ULL}, + {0.00001f, 2047, 65, 1234ULL}, {0.00001f, 2049, 22, 1234ULL}, {0.00001f, 31, 644, 1234ULL}, + {0.00001f, 33, 999, 1234ULL}, {0.00001f, 1, 1, 1234ULL}, {0.00001f, 7, 23, 1234ULL}, + {0.00001f, 17, 5, 1234ULL}}; + +const std::vector> inputsd = { + {0.0000001, 1024, 32, 1234ULL}, {0.0000001, 1024, 64, 1234ULL}, + {0.0000001, 1024, 128, 1234ULL}, {0.0000001, 1024, 256, 1234ULL}, + {0.0000001, 1024, 512, 1234ULL}, {0.0000001, 1024, 1024, 1234ULL}, + {0.0000001, 4096, 32, 1234ULL}, {0.0000001, 4096, 64, 1234ULL}, + {0.0000001, 4096, 128, 1234ULL}, {0.0000001, 4096, 256, 1234ULL}, + {0.0000001, 4096, 512, 1234ULL}, {0.0000001, 4096, 1024, 1234ULL}, + {0.0000001, 8192, 32, 1234ULL}, {0.0000001, 8192, 64, 1234ULL}, + {0.0000001, 8192, 128, 1234ULL}, {0.0000001, 8192, 256, 1234ULL}, + {0.0000001, 8192, 512, 1234ULL}, {0.0000001, 8192, 1024, 1234ULL}, + {0.0000001, 1024, 8192, 1234ULL}, {0.0000001, 1023, 5, 1234ULL}, + {0.0000001, 1025, 30, 1234ULL}, {0.0000001, 2047, 65, 1234ULL}, + {0.0000001, 2049, 22, 1234ULL}, {0.0000001, 31, 644, 1234ULL}, + {0.0000001, 33, 999, 1234ULL}, {0.0000001, 1, 1, 1234ULL}, + {0.0000001, 7, 23, 1234ULL}, {0.0000001, 17, 5, 1234ULL}}; typedef MinMaxTest MinMaxTestF; TEST_P(MinMaxTestF, Result) diff --git a/cpp/test/stats/stddev.cu b/cpp/test/stats/stddev.cu index cf57d3a923..641621c1c6 100644 --- a/cpp/test/stats/stddev.cu +++ b/cpp/test/stats/stddev.cu @@ -141,7 +141,19 @@ const std::vector> inputsf = { {0.1f, -1.f, 2.f, 1024, 32, false, true, 1234ULL}, {0.1f, -1.f, 2.f, 1024, 64, false, true, 1234ULL}, {0.1f, -1.f, 2.f, 1024, 128, false, true, 1234ULL}, - {0.1f, -1.f, 2.f, 1024, 256, false, true, 1234ULL}}; + {0.1f, -1.f, 2.f, 1024, 256, false, true, 1234ULL}, + {0.1f, -1.f, 2.f, 1099, 97, false, false, 1234ULL}, + {0.1f, -1.f, 2.f, 1022, 694, true, false, 1234ULL}, + {0.5f, -1.f, 2.f, 31, 1, true, true, 1234ULL}, + {1.f, -1.f, 2.f, 1, 257, false, true, 1234ULL}, + {0.5f, -1.f, 2.f, 31, 1, false, false, 1234ULL}, + {1.f, -1.f, 2.f, 1, 257, true, false, 1234ULL}, + {1.f, -1.f, 2.f, 1, 1, false, false, 1234ULL}, + {1.f, -1.f, 2.f, 7, 23, false, false, 1234ULL}, + {1.f, -1.f, 2.f, 17, 5, false, false, 1234ULL}, + {1.f, -1.f, 2.f, 1, 1, false, true, 1234ULL}, + {1.f, -1.f, 2.f, 7, 23, false, true, 1234ULL}, + {1.f, -1.f, 2.f, 17, 5, false, true, 1234ULL}}; const std::vector> inputsd = { {0.1, 1.0, 2.0, 1024, 32, true, false, 1234ULL}, @@ -159,13 +171,33 @@ const std::vector> inputsd = { {0.1, -1.0, 2.0, 1024, 32, false, true, 1234ULL}, {0.1, -1.0, 2.0, 1024, 64, false, true, 1234ULL}, {0.1, -1.0, 2.0, 1024, 128, false, true, 1234ULL}, - {0.1, -1.0, 2.0, 1024, 256, false, true, 1234ULL}}; + {0.1, -1.0, 2.0, 1024, 256, false, true, 1234ULL}, + {0.1, -1.0, 2.0, 1099, 97, false, false, 1234ULL}, + {0.1, -1.0, 2.0, 1022, 694, true, false, 1234ULL}, + {0.5, -1.0, 2.0, 31, 1, true, true, 1234ULL}, + {1.0, -1.0, 2.0, 1, 257, false, true, 1234ULL}, + {0.5, -1.0, 2.0, 31, 1, false, false, 1234ULL}, + {1.0, -1.0, 2.0, 1, 257, true, false, 1234ULL}, + {1.0, -1.0, 2.0, 1, 1, false, false, 1234ULL}, + {1.0, -1.0, 2.0, 7, 23, false, false, 1234ULL}, + {1.0, -1.0, 2.0, 17, 5, false, false, 1234ULL}, + {1.0, -1.0, 2.0, 1, 1, false, true, 1234ULL}, + {1.0, -1.0, 2.0, 7, 23, false, true, 1234ULL}, + {1.0, -1.0, 2.0, 17, 5, false, true, 1234ULL}}; typedef StdDevTest StdDevTestF; TEST_P(StdDevTestF, Result) { - ASSERT_TRUE(devArrMatch( - params.stddev, stddev_act.data(), params.cols, CompareApprox(params.tolerance), stream)); + if (params.rows == 1) { + ASSERT_TRUE(devArrMatch( + float(0), stddev_act.data(), params.cols, CompareApprox(params.tolerance), stream)); + } else { + ASSERT_TRUE(devArrMatch(params.stddev, + stddev_act.data(), + params.cols, + CompareApprox(params.tolerance), + stream)); + } ASSERT_TRUE(devArrMatch(stddev_act.data(), vars_act.data(), @@ -177,11 +209,16 @@ TEST_P(StdDevTestF, Result) typedef StdDevTest StdDevTestD; TEST_P(StdDevTestD, Result) { - ASSERT_TRUE(devArrMatch(params.stddev, - stddev_act.data(), - params.cols, - CompareApprox(params.tolerance), - stream)); + if (params.rows == 1) { + ASSERT_TRUE(devArrMatch( + double(0), stddev_act.data(), params.cols, CompareApprox(params.tolerance), stream)); + } else { + ASSERT_TRUE(devArrMatch(params.stddev, + stddev_act.data(), + params.cols, + CompareApprox(params.tolerance), + stream)); + } ASSERT_TRUE(devArrMatch(stddev_act.data(), vars_act.data(), diff --git a/cpp/test/stats/sum.cu b/cpp/test/stats/sum.cu index 5a549f8ba4..bf2aa44a2c 100644 --- a/cpp/test/stats/sum.cu +++ b/cpp/test/stats/sum.cu @@ -33,7 +33,8 @@ template struct SumInputs { T tolerance; int rows, cols; - unsigned long long int seed; + bool rowMajor; + T value = T(1); }; template @@ -56,20 +57,34 @@ class SumTest : public ::testing::TestWithParam> { } protected: - void SetUp() override + void runTest() { int len = rows * cols; - T data_h[len]; + std::vector data_h(len); for (int i = 0; i < len; i++) { - data_h[i] = T(1); + data_h[i] = T(params.value); } - raft::update_device(data.data(), data_h, len, stream); - sum(handle, - raft::make_device_matrix_view(data.data(), rows, cols), - raft::make_device_vector_view(sum_act.data(), cols)); + raft::update_device(data.data(), data_h.data(), len, stream); + + if (params.rowMajor) { + using layout = raft::row_major; + sum(handle, + raft::make_device_matrix_view(data.data(), rows, cols), + raft::make_device_vector_view(sum_act.data(), cols)); + } else { + using layout = raft::col_major; + sum(handle, + raft::make_device_matrix_view(data.data(), rows, cols), + raft::make_device_vector_view(sum_act.data(), cols)); + } resource::sync_stream(handle, stream); + + double expected = double(params.rows) * params.value; + + ASSERT_TRUE(raft::devArrMatch( + T(expected), sum_act.data(), params.cols, raft::CompareApprox(params.tolerance))); } protected: @@ -81,27 +96,49 @@ class SumTest : public ::testing::TestWithParam> { rmm::device_uvector data, sum_act; }; -const std::vector> inputsf = { - {0.05f, 4, 5, 1234ULL}, {0.05f, 1024, 32, 1234ULL}, {0.05f, 1024, 256, 1234ULL}}; - -const std::vector> inputsd = {{0.05, 1024, 32, 1234ULL}, - {0.05, 1024, 256, 1234ULL}}; +const std::vector> inputsf = {{0.0001f, 4, 5, true, 1}, + {0.0001f, 1024, 32, true, 1}, + {0.0001f, 1024, 256, true, 1}, + {0.0001f, 100000000, 1, true, 0.001}, + {0.0001f, 1, 30, true, 0.001}, + {0.0001f, 1, 1, true, 0.001}, + {0.0001f, 17, 5, true, 0.001}, + {0.0001f, 7, 23, true, 0.001}, + {0.0001f, 3, 97, true, 0.001}, + {0.0001f, 4, 5, false, 1}, + {0.0001f, 1024, 32, false, 1}, + {0.0001f, 1024, 256, false, 1}, + {0.0001f, 100000000, 1, false, 0.001}, + {0.0001f, 1, 30, false, 0.001}, + {0.0001f, 1, 1, false, 0.001}, + {0.0001f, 17, 5, false, 0.001}, + {0.0001f, 7, 23, false, 0.001}, + {0.0001f, 3, 97, false, 0.001}}; + +const std::vector> inputsd = {{0.000001, 1024, 32, true, 1}, + {0.000001, 1024, 256, true, 1}, + {0.000001, 1024, 256, true, 1}, + {0.000001, 100000000, 1, true, 0.001}, + {0.000001, 1, 30, true, 0.0001}, + {0.000001, 1, 1, true, 0.0001}, + {0.000001, 17, 5, true, 0.0001}, + {0.000001, 7, 23, true, 0.0001}, + {0.000001, 3, 97, true, 0.0001}, + {0.000001, 1024, 32, false, 1}, + {0.000001, 1024, 256, false, 1}, + {0.000001, 1024, 256, false, 1}, + {0.000001, 100000000, 1, false, 0.001}, + {0.000001, 1, 30, false, 0.0001}, + {0.000001, 1, 1, false, 0.0001}, + {0.000001, 17, 5, false, 0.0001}, + {0.000001, 7, 23, false, 0.0001}, + {0.000001, 3, 97, false, 0.0001}}; typedef SumTest SumTestF; -TEST_P(SumTestF, Result) -{ - ASSERT_TRUE(raft::devArrMatch( - float(params.rows), sum_act.data(), params.cols, raft::CompareApprox(params.tolerance))); -} - typedef SumTest SumTestD; -TEST_P(SumTestD, Result) -{ - ASSERT_TRUE(raft::devArrMatch(double(params.rows), - sum_act.data(), - params.cols, - raft::CompareApprox(params.tolerance))); -} + +TEST_P(SumTestF, Result) { runTest(); } +TEST_P(SumTestD, Result) { runTest(); } INSTANTIATE_TEST_CASE_P(SumTests, SumTestF, ::testing::ValuesIn(inputsf)); From 69fd9714afa4a40cf3d68ef6a2ed137545ef52a2 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Mon, 18 Mar 2024 12:09:24 -0400 Subject: [PATCH 04/18] Replace local copyright check with pre-commit-hooks verify-copyright (#2220) The local `copyright.py` script is bug-prone. Replace it with a more robust centralized script from `pre-commit-hooks`. Issue: https://github.com/rapidsai/build-planning/issues/30 Authors: - Kyle Edwards (https://github.com/KyleFromNVIDIA) Approvers: - Jake Awe (https://github.com/AyodeAwe) URL: https://github.com/rapidsai/raft/pull/2220 --- .pre-commit-config.yaml | 23 +++- ci/checks/copyright.py | 290 ---------------------------------------- 2 files changed, 17 insertions(+), 296 deletions(-) delete mode 100644 ci/checks/copyright.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0d6ab7ee54..6bca1d228e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -81,12 +81,6 @@ repos: verbose: true require_serial: true exclude: .*/thirdparty/.* - - id: copyright-check - name: copyright-check - entry: python ./ci/checks/copyright.py --git-modified-only --update-current-year - language: python - pass_filenames: false - additional_dependencies: [gitpython] - id: include-check name: include-check entry: python ./cpp/scripts/include_checker.py cpp/bench cpp/include cpp/test @@ -109,6 +103,23 @@ repos: rev: v4.5.0 hooks: - id: check-json + - repo: https://github.com/rapidsai/pre-commit-hooks + rev: v0.0.1 + hooks: + - id: verify-copyright + files: | + (?x) + [.](cmake|cpp|cu|cuh|h|hpp|sh|pxd|py|pyx)$| + CMakeLists[.]txt$| + CMakeLists_standalone[.]txt$| + meta[.]yaml$| + setup[.]cfg$ + exclude: | + (?x) + cpp/include/raft/neighbors/detail/faiss_select/| + cpp/include/raft/thirdparty/| + docs/source/sphinxext/github_link[.]py| + cpp/cmake/modules/FindAVX[.]cmake default_language_version: python: python3 diff --git a/ci/checks/copyright.py b/ci/checks/copyright.py deleted file mode 100644 index 2af8b1b8ff..0000000000 --- a/ci/checks/copyright.py +++ /dev/null @@ -1,290 +0,0 @@ -# Copyright (c) 2020-2023, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import datetime -import re -import argparse -import io -import os -import sys - -import git - -SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) - -# Add the scripts dir for gitutils -sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, - "../../cpp/scripts"))) - -# Now import gitutils. Ignore flake8 error here since there is no other way to -# set up imports -import gitutils # noqa: E402 - -FilesToCheck = [ - re.compile(r"[.](cmake|cpp|cu|cuh|h|hpp|sh|pxd|py|pyx)$"), - re.compile(r"CMakeLists[.]txt$"), - re.compile(r"CMakeLists_standalone[.]txt$"), - re.compile(r"setup[.]cfg$"), - re.compile(r"meta[.]yaml$") -] -ExemptFiles = [ - re.compile("cpp/include/raft/neighbors/detail/faiss_select/"), - re.compile("cpp/include/raft/thirdparty/"), - re.compile("docs/source/sphinxext/github_link.py"), - re.compile("cpp/cmake/modules/FindAVX.cmake") -] - -# this will break starting at year 10000, which is probably OK :) -CheckSimple = re.compile( - r"Copyright *(?:\(c\))? *(\d{4}),? *NVIDIA C(?:ORPORATION|orporation)") -CheckDouble = re.compile( - r"Copyright *(?:\(c\))? *(\d{4})-(\d{4}),? *NVIDIA C(?:ORPORATION|orporation)" # noqa: E501 -) - - -def checkThisFile(f): - if isinstance(f, git.Diff): - if f.deleted_file or f.b_blob.size == 0: - return False - f = f.b_path - elif not os.path.exists(f) or os.stat(f).st_size == 0: - # This check covers things like symlinks which point to files that DNE - return False - for exempt in ExemptFiles: - if exempt.search(f): - return False - for checker in FilesToCheck: - if checker.search(f): - return True - return False - - -def modifiedFiles(): - """Get a set of all modified files, as Diff objects. - - The files returned have been modified in git since the merge base of HEAD - and the upstream of the target branch. We return the Diff objects so that - we can read only the staged changes. - """ - repo = git.Repo() - # Use the environment variable TARGET_BRANCH or RAPIDS_BASE_BRANCH (defined in CI) if possible - target_branch = os.environ.get("TARGET_BRANCH", os.environ.get("RAPIDS_BASE_BRANCH")) - if target_branch is None: - # Fall back to the closest branch if not on CI - target_branch = repo.git.describe( - all=True, tags=True, match="branch-*", abbrev=0 - ).lstrip("heads/") - - upstream_target_branch = None - if target_branch in repo.heads: - # Use the tracking branch of the local reference if it exists. This - # returns None if no tracking branch is set. - upstream_target_branch = repo.heads[target_branch].tracking_branch() - if upstream_target_branch is None: - # Fall back to the remote with the newest target_branch. This code - # path is used on CI because the only local branch reference is - # current-pr-branch, and thus target_branch is not in repo.heads. - # This also happens if no tracking branch is defined for the local - # target_branch. We use the remote with the latest commit if - # multiple remotes are defined. - candidate_branches = [ - remote.refs[target_branch] for remote in repo.remotes - if target_branch in remote.refs - ] - if len(candidate_branches) > 0: - upstream_target_branch = sorted( - candidate_branches, - key=lambda branch: branch.commit.committed_datetime, - )[-1] - else: - # If no remotes are defined, try to use the local version of the - # target_branch. If this fails, the repo configuration must be very - # strange and we can fix this script on a case-by-case basis. - upstream_target_branch = repo.heads[target_branch] - merge_base = repo.merge_base("HEAD", upstream_target_branch.commit)[0] - diff = merge_base.diff() - changed_files = {f for f in diff if f.b_path is not None} - return changed_files - - -def getCopyrightYears(line): - res = CheckSimple.search(line) - if res: - return int(res.group(1)), int(res.group(1)) - res = CheckDouble.search(line) - if res: - return int(res.group(1)), int(res.group(2)) - return None, None - - -def replaceCurrentYear(line, start, end): - # first turn a simple regex into double (if applicable). then update years - res = CheckSimple.sub(r"Copyright (c) \1-\1, NVIDIA CORPORATION", line) - res = CheckDouble.sub( - rf"Copyright (c) {start:04d}-{end:04d}, NVIDIA CORPORATION", - res, - ) - return res - - -def checkCopyright(f, update_current_year): - """Checks for copyright headers and their years.""" - errs = [] - thisYear = datetime.datetime.now().year - lineNum = 0 - crFound = False - yearMatched = False - - if isinstance(f, git.Diff): - path = f.b_path - lines = f.b_blob.data_stream.read().decode().splitlines(keepends=True) - else: - path = f - with open(f, encoding="utf-8") as fp: - lines = fp.readlines() - - for line in lines: - lineNum += 1 - start, end = getCopyrightYears(line) - if start is None: - continue - crFound = True - if start > end: - e = [ - path, - lineNum, - "First year after second year in the copyright " - "header (manual fix required)", - None, - ] - errs.append(e) - elif thisYear < start or thisYear > end: - e = [ - path, - lineNum, - "Current year not included in the copyright header", - None, - ] - if thisYear < start: - e[-1] = replaceCurrentYear(line, thisYear, end) - if thisYear > end: - e[-1] = replaceCurrentYear(line, start, thisYear) - errs.append(e) - else: - yearMatched = True - # copyright header itself not found - if not crFound: - e = [ - path, - 0, - "Copyright header missing or formatted incorrectly " - "(manual fix required)", - None, - ] - errs.append(e) - # even if the year matches a copyright header, make the check pass - if yearMatched: - errs = [] - - if update_current_year: - errs_update = [x for x in errs if x[-1] is not None] - if len(errs_update) > 0: - lines_changed = ", ".join(str(x[1]) for x in errs_update) - print(f"File: {path}. Changing line(s) {lines_changed}") - for _, lineNum, __, replacement in errs_update: - lines[lineNum - 1] = replacement - with open(path, "w", encoding="utf-8") as out_file: - out_file.writelines(lines) - - return errs - - -def getAllFilesUnderDir(root, pathFilter=None): - retList = [] - for dirpath, dirnames, filenames in os.walk(root): - for fn in filenames: - filePath = os.path.join(dirpath, fn) - if pathFilter(filePath): - retList.append(filePath) - return retList - - -def checkCopyright_main(): - """ - Checks for copyright headers in all the modified files. In case of local - repo, this script will just look for uncommitted files and in case of CI - it compares between branches "$PR_TARGET_BRANCH" and "current-pr-branch" - """ - retVal = 0 - - argparser = argparse.ArgumentParser( - "Checks for a consistent copyright header in git's modified files" - ) - argparser.add_argument( - "--update-current-year", - dest="update_current_year", - action="store_true", - required=False, - help="If set, " - "update the current year if a header is already " - "present and well formatted.", - ) - argparser.add_argument( - "--git-modified-only", - dest="git_modified_only", - action="store_true", - required=False, - help="If set, " - "only files seen as modified by git will be " - "processed.", - ) - - args, dirs = argparser.parse_known_args() - - if args.git_modified_only: - files = [f for f in modifiedFiles() if checkThisFile(f)] - else: - files = [] - for d in [os.path.abspath(d) for d in dirs]: - if not os.path.isdir(d): - raise ValueError(f"{d} is not a directory.") - files += getAllFilesUnderDir(d, pathFilter=checkThisFile) - - errors = [] - for f in files: - errors += checkCopyright(f, args.update_current_year) - - if len(errors) > 0: - if any(e[-1] is None for e in errors): - print("Copyright headers incomplete in some of the files!") - for e in errors: - print(" %s:%d Issue: %s" % (e[0], e[1], e[2])) - print("") - n_fixable = sum(1 for e in errors if e[-1] is not None) - path_parts = os.path.abspath(__file__).split(os.sep) - file_from_repo = os.sep.join(path_parts[path_parts.index("ci") :]) - if n_fixable > 0 and not args.update_current_year: - print( - f"You can run `python {file_from_repo} --git-modified-only " - "--update-current-year` and stage the results in git to " - f"fix {n_fixable} of these errors.\n" - ) - retVal = 1 - - return retVal - - -if __name__ == "__main__": - sys.exit(checkCopyright_main()) From 32f6f40c561a97bc0174f2ebf0894d3289a76d65 Mon Sep 17 00:00:00 2001 From: "Artem M. Chirkin" <9253178+achirkin@users.noreply.github.com> Date: Mon, 18 Mar 2024 17:14:11 +0100 Subject: [PATCH 05/18] Add CAGRA-Q build (compression) (#2213) Add a `cagra::compress` function that implements CAGRA-Q (VQ + PQ) compression of a given dataset. The result, `compressed_dataset`, is supposed to complement the CAGRA graph during `cagra::search` in place of a raw dataset. ### Current state: - The code runs and produces a meaningful output (tested internally by running the original prototype search with the generated compressed dataset); the recall levels are approximately the same as with the prototype implementation. - No test coverage yet (need to coordinate with the search PR https://github.com/rapidsai/raft/pull/2206) - Full `pq_bits` support ([4,5,6,7,8] - same as in IVF-PQ) - Any `pq_dim` values are accepted, but the dataset is not padded and thus `dim` must be a multiple of `pq_dim`. - The codebook math type is hardcoded to `half` to match the prototype implementation for now. This could be a runtime (build) parameter as well. - All common input data types should work (`uint8_t`, `int8_t`, `half`, and `float` compile), but I tested only `float`. Authors: - Artem M. Chirkin (https://github.com/achirkin) Approvers: - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/raft/pull/2213 --- .../core/detail/mdspan_numpy_serializer.hpp | 3 +- cpp/include/raft/neighbors/cagra.cuh | 89 ++-- cpp/include/raft/neighbors/cagra_types.hpp | 90 ++-- cpp/include/raft/neighbors/dataset.hpp | 330 ++++++++++++++ .../neighbors/detail/cagra/cagra_build.cuh | 10 + .../detail/cagra/cagra_serialize.cuh | 35 +- .../neighbors/detail/dataset_serialize.hpp | 192 ++++++++ .../raft/neighbors/detail/vpq_dataset.cuh | 427 ++++++++++++++++++ cpp/include/raft/neighbors/vpq_dataset.cuh | 51 +++ 9 files changed, 1103 insertions(+), 124 deletions(-) create mode 100644 cpp/include/raft/neighbors/dataset.hpp create mode 100644 cpp/include/raft/neighbors/detail/dataset_serialize.hpp create mode 100644 cpp/include/raft/neighbors/detail/vpq_dataset.cuh create mode 100644 cpp/include/raft/neighbors/vpq_dataset.cuh diff --git a/cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp b/cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp index 176309c8ce..3fb7b3005b 100644 --- a/cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp +++ b/cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp @@ -126,7 +126,8 @@ inline dtype_t get_numpy_dtype() } #if defined(_RAFT_HAS_CUDA) -template , bool> = true> +template , half>, bool> = true> inline dtype_t get_numpy_dtype() { return {RAFT_NUMPY_HOST_ENDIAN_CHAR, 'e', sizeof(T)}; diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index b8258297e6..b7e362f704 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -25,6 +25,7 @@ #include #include #include +#include #include @@ -279,62 +280,6 @@ index build(raft::resources const& res, return detail::build(res, params, dataset); } -/** - * @brief Search ANN using the constructed index. - * - * See the [cagra::build](#cagra::build) documentation for a usage example. - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] res raft resources - * @param[in] params configure the search - * @param[in] idx cagra index - * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, - * k] - */ -template -void search(raft::resources const& res, - const search_params& params, - const index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) -{ - RAFT_EXPECTS( - queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), - "Number of rows in output neighbors and distances matrices must equal the number of queries."); - - RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), - "Number of columns in output neighbors and distances matrices must equal k"); - RAFT_EXPECTS(queries.extent(1) == idx.dim(), - "Number of query dimensions should equal number of dimensions in the index."); - - using internal_IdxT = typename std::make_unsigned::type; - auto queries_internal = raft::make_device_matrix_view( - queries.data_handle(), queries.extent(0), queries.extent(1)); - auto neighbors_internal = raft::make_device_matrix_view( - reinterpret_cast(neighbors.data_handle()), - neighbors.extent(0), - neighbors.extent(1)); - auto distances_internal = raft::make_device_matrix_view( - distances.data_handle(), distances.extent(0), distances.extent(1)); - - cagra::detail::search_main(res, - params, - idx, - queries_internal, - neighbors_internal, - distances_internal, - raft::neighbors::filtering::none_cagra_sample_filter()); -} - /** * @brief Search ANN using the constructed index with the given sample filter. * @@ -401,10 +346,40 @@ void search_with_filtering(raft::resources const& res, auto distances_internal = raft::make_device_matrix_view( distances.data_handle(), distances.extent(0), distances.extent(1)); - cagra::detail::search_main( + return cagra::detail::search_main( res, params, idx, queries_internal, neighbors_internal, distances_internal, sample_filter); } +/** + * @brief Search ANN using the constructed index. + * + * See the [cagra::build](#cagra::build) documentation for a usage example. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] res raft resources + * @param[in] params configure the search + * @param[in] idx cagra index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + */ +template +void search(raft::resources const& res, + const search_params& params, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) +{ + using none_filter_type = raft::neighbors::filtering::none_cagra_sample_filter; + return cagra::search_with_filtering( + res, params, idx, queries, neighbors, distances, none_filter_type{}); +} + /** @} */ // end group cagra } // namespace raft::neighbors::cagra diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index 0f574ae5bb..807f89fd65 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -17,6 +17,7 @@ #pragma once #include "ann_types.hpp" +#include "dataset.hpp" #include #include @@ -35,6 +36,7 @@ #include #include #include + namespace raft::neighbors::cagra { /** * @addtogroup cagra @@ -61,6 +63,12 @@ struct index_params : ann::index_params { graph_build_algo build_algo = graph_build_algo::IVF_PQ; /** Number of Iterations to run if building with NN_DESCENT */ size_t nn_descent_niter = 20; + /** + * Specify compression params if compression is desired. + * + * NOTE: this is experimental new API, consider it unsafe. + */ + std::optional compression = std::nullopt; }; enum class search_algo { @@ -145,25 +153,37 @@ struct index : ann::index { /** Total length of the index (number of vectors). */ [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { - return dataset_view_.extent(0) ? dataset_view_.extent(0) : graph_view_.extent(0); + auto data_rows = dataset_->n_rows(); + return data_rows > 0 ? data_rows : graph_view_.extent(0); } /** Dimensionality of the data. */ - [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t - { - return dataset_view_.extent(1); - } + [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t { return dataset_->dim(); } /** Graph degree */ [[nodiscard]] constexpr inline auto graph_degree() const noexcept -> uint32_t { return graph_view_.extent(1); } - /** Dataset [size, dim] */ - [[nodiscard]] inline auto dataset() const noexcept + /** + * DEPRECATED: please use data() instead. + * If you need to query dataset dimensions, use the dim() and size() of the cagra index. + * The data_handle() is not always available: you need to do a dynamic_cast to the expected + * dataset type at runtime. + */ + [[nodiscard]] [[deprecated("Use data()")]] inline auto dataset() const noexcept -> device_matrix_view { - return dataset_view_; + auto p = dynamic_cast*>(dataset_.get()); + if (p != nullptr) { return p->view(); } + auto d = dataset_->dim(); + return make_device_strided_matrix_view(nullptr, 0, d, d); + } + + /** Dataset [size, dim] */ + [[nodiscard]] inline auto data() const noexcept -> const neighbors::dataset& + { + return *dataset_; } /** neighborhood graph [size, graph-degree] */ @@ -185,7 +205,7 @@ struct index : ann::index { raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded) : ann::index(), metric_(metric), - dataset_(make_device_matrix(res, 0, 0)), + dataset_(new neighbors::empty_dataset(0)), graph_(make_device_matrix(res, 0, 0)) { } @@ -251,12 +271,11 @@ struct index : ann::index { mdspan, row_major, graph_accessor> knn_graph) : ann::index(), metric_(metric), - dataset_(make_device_matrix(res, 0, 0)), + dataset_(make_aligned_dataset(res, dataset, 16)), graph_(make_device_matrix(res, 0, 0)) { RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0), "Dataset and knn_graph must have equal number of rows"); - update_dataset(res, dataset); update_graph(res, knn_graph); resource::sync_stream(res); } @@ -271,21 +290,14 @@ struct index : ann::index { void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { - if (dataset.extent(1) * sizeof(T) % 16 != 0) { - RAFT_LOG_DEBUG("Creating a padded copy of CAGRA dataset in device memory"); - copy_padded(res, dataset); - } else { - dataset_view_ = make_device_strided_matrix_view( - dataset.data_handle(), dataset.extent(0), dataset.extent(1), dataset.extent(1)); - } + dataset_ = make_aligned_dataset(res, dataset, 16); } /** Set the dataset reference explicitly to a device matrix view with padding. */ - void update_dataset(raft::resources const&, + void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { - RAFT_EXPECTS(dataset.stride(0) * sizeof(T) % 16 == 0, "Incorrect data padding."); - dataset_view_ = dataset; + dataset_ = make_aligned_dataset(res, dataset, 16); } /** @@ -296,8 +308,22 @@ struct index : ann::index { void update_dataset(raft::resources const& res, raft::host_matrix_view dataset) { - RAFT_LOG_DEBUG("Copying CAGRA dataset from host to device"); - copy_padded(res, dataset); + dataset_ = make_aligned_dataset(res, dataset, 16); + } + + /** Replace the dataset with a new dataset. */ + template + auto update_dataset(raft::resources const& res, DatasetT&& dataset) + -> std::enable_if_t, DatasetT>> + { + dataset_ = std::make_unique(std::move(dataset)); + } + + template + auto update_dataset(raft::resources const& res, std::unique_ptr&& dataset) + -> std::enable_if_t, DatasetT>> + { + dataset_ = std::move(dataset); } /** @@ -334,26 +360,10 @@ struct index : ann::index { } private: - /** Create a device copy of the dataset, and pad it if necessary. */ - template - void copy_padded(raft::resources const& res, - mdspan, row_major, data_accessor> dataset) - { - detail::copy_with_padding(res, dataset_, dataset); - - dataset_view_ = make_device_strided_matrix_view( - dataset_.data_handle(), dataset_.extent(0), dataset.extent(1), dataset_.extent(1)); - RAFT_LOG_DEBUG("CAGRA dataset strided matrix view %zux%zu, stride %zu", - static_cast(dataset_view_.extent(0)), - static_cast(dataset_view_.extent(1)), - static_cast(dataset_view_.stride(0))); - } - raft::distance::DistanceType metric_; - raft::device_matrix dataset_; raft::device_matrix graph_; - raft::device_matrix_view dataset_view_; raft::device_matrix_view graph_view_; + std::unique_ptr> dataset_; }; /** @} */ diff --git a/cpp/include/raft/neighbors/dataset.hpp b/cpp/include/raft/neighbors/dataset.hpp new file mode 100644 index 0000000000..e7a3ba97a4 --- /dev/null +++ b/cpp/include/raft/neighbors/dataset.hpp @@ -0,0 +1,330 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include // get_device_for_address +#include // rounding up + +#include +#include +#include + +#ifdef __cpp_lib_bitops +#include +#endif + +namespace raft::neighbors { + +/** Two-dimensional dataset; maybe owning, maybe compressed, maybe strided. */ +template +struct dataset { + using index_type = IdxT; + /** Size of the dataset. */ + [[nodiscard]] virtual auto n_rows() const noexcept -> index_type = 0; + /** Dimensionality of the dataset. */ + [[nodiscard]] virtual auto dim() const noexcept -> uint32_t = 0; + /** Whether the object owns the data. */ + [[nodiscard]] virtual auto is_owning() const noexcept -> bool = 0; + virtual ~dataset() noexcept = default; +}; + +template +struct empty_dataset : public dataset { + using index_type = IdxT; + uint32_t suggested_dim; + explicit empty_dataset(uint32_t dim) noexcept : suggested_dim(dim) {} + [[nodiscard]] auto n_rows() const noexcept -> index_type final { return 0; } + [[nodiscard]] auto dim() const noexcept -> uint32_t final { return suggested_dim; } + [[nodiscard]] auto is_owning() const noexcept -> bool final { return true; } +}; + +template +struct strided_dataset : public dataset { + using index_type = IdxT; + using value_type = DataT; + using view_type = device_matrix_view; + [[nodiscard]] auto n_rows() const noexcept -> index_type final { return view().extent(0); } + [[nodiscard]] auto dim() const noexcept -> uint32_t final + { + return static_cast(view().extent(1)); + } + /** Leading dimension of the dataset. */ + [[nodiscard]] constexpr auto stride() const noexcept -> uint32_t + { + auto v = view(); + return static_cast(v.stride(0) > 0 ? v.stride(0) : v.extent(1)); + } + /** Get the view of the data. */ + [[nodiscard]] virtual auto view() const noexcept -> view_type; +}; + +template +struct non_owning_dataset : public strided_dataset { + using index_type = IdxT; + using value_type = DataT; + using typename strided_dataset::view_type; + view_type data; + explicit non_owning_dataset(view_type v) noexcept : data(v) {} + [[nodiscard]] auto is_owning() const noexcept -> bool final { return false; } + [[nodiscard]] auto view() const noexcept -> view_type final { return data; }; +}; + +template +struct owning_dataset : public strided_dataset { + using index_type = IdxT; + using value_type = DataT; + using typename strided_dataset::view_type; + using storage_type = + mdarray, LayoutPolicy, ContainerPolicy>; + using mapping_type = typename view_type::mapping_type; + storage_type data; + mapping_type view_mapping; + owning_dataset(storage_type&& store, mapping_type view_mapping) noexcept + : data{std::move(store)}, view_mapping{view_mapping} + { + } + + [[nodiscard]] auto is_owning() const noexcept -> bool final { return true; } + [[nodiscard]] auto view() const noexcept -> view_type final + { + return view_type{data.data_handle(), view_mapping}; + }; +}; + +/** + * @brief Contstruct a strided matrix from any mdarray or mdspan. + * + * This function constructs a non-owning view if the input satisfied two conditions: + * + * 1) The data is accessible from the current device + * 2) The memory layout is the same as expected (row-major matrix with the required stride) + * + * Otherwise, this function constructs an owning device matrix and copies the data. + * When the data is copied, padding elements are filled with zeroes. + * + * @tparam SrcT the source mdarray or mdspan + * + * @param[in] res raft resources handle + * @param[in] src the source mdarray or mdspan + * @param[in] required_stride the leading dimension (in elements) + * @return maybe owning current-device-accessible strided matrix + */ +template +auto make_strided_dataset(const raft::resources& res, const SrcT& src, uint32_t required_stride) + -> std::unique_ptr> +{ + using extents_type = typename SrcT::extents_type; + using value_type = typename SrcT::value_type; + using index_type = typename SrcT::index_type; + using layout_type = typename SrcT::layout_type; + static_assert(extents_type::rank() == 2, "The input must be a matrix."); + static_assert(std::is_same_v || + std::is_same_v> || + std::is_same_v, + "The input must be row-major"); + RAFT_EXPECTS(src.extent(1) <= required_stride, + "The input row length must be not larger than the desired stride."); + cudaPointerAttributes ptr_attrs; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&ptr_attrs, src.data_handle())); + auto* device_ptr = reinterpret_cast(ptr_attrs.devicePointer); + const uint32_t src_stride = src.stride(0) > 0 ? src.stride(0) : src.extent(1); + const bool device_accessible = device_ptr != nullptr; + const bool row_major = src.stride(1) <= 1; + const bool stride_matches = required_stride == src_stride; + + if (device_accessible && row_major && stride_matches) { + // Everything matches: make a non-owning dataset + return std::make_unique>( + make_device_strided_matrix_view( + device_ptr, src.extent(0), src.extent(1), required_stride)); + } + // Something is wrong: have to make a copy and produce an owning dataset + auto out_layout = + make_strided_layout(src.extents(), std::array{required_stride, 1}); + auto out_array = make_device_matrix(res, src.extent(0), required_stride); + + using out_mdarray_type = decltype(out_array); + using out_layout_type = typename out_mdarray_type::layout_type; + using out_container_policy_type = typename out_mdarray_type::container_policy_type; + using out_owning_type = + owning_dataset; + + RAFT_CUDA_TRY(cudaMemsetAsync(out_array.data_handle(), + 0, + out_array.size() * sizeof(value_type), + resource::get_cuda_stream(res))); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(out_array.data_handle(), + sizeof(value_type) * required_stride, + src.data_handle(), + sizeof(value_type) * src_stride, + sizeof(value_type) * src.extent(1), + src.extent(0), + cudaMemcpyDefault, + resource::get_cuda_stream(res))); + + return std::make_unique(std::move(out_array), out_layout); +} + +/** + * @brief Contstruct a strided matrix from any mdarray or mdspan. + * + * A variant `make_strided_dataset` that allows specifying the byte alignment instead of the + * explicit stride length. + * + * @tparam SrcT the source mdarray or mdspan + * + * @param[in] res raft resources handle + * @param[in] src the source mdarray or mdspan + * @param[in] align_bytes the required byte alignment for the dataset rows. + * @return maybe owning current-device-accessible strided matrix + */ +template +auto make_aligned_dataset(const raft::resources& res, const SrcT& src, uint32_t align_bytes = 16) + -> std::unique_ptr> +{ + using value_type = typename SrcT::value_type; + constexpr size_t kSize = sizeof(value_type); + uint32_t required_stride = + raft::round_up_safe(src.extent(1) * kSize, std::lcm(align_bytes, kSize)) / kSize; + return make_strided_dataset(res, src, required_stride); +} + +/** Parameters for VPQ compression. */ +struct vpq_params { + /** + * The bit length of the vector element after compression by PQ. + * + * Possible values: [4, 5, 6, 7, 8]. + * + * Hint: the smaller the 'pq_bits', the smaller the index size and the better the search + * performance, but the lower the recall. + */ + uint32_t pq_bits = 8; + /** + * The dimensionality of the vector after compression by PQ. + * When zero, an optimal value is selected using a heuristic. + * + * TODO: at the moment `dim` must be a multiple `pq_dim`. + */ + uint32_t pq_dim = 0; + /** + * Vector Quantization (VQ) codebook size - number of "coarse cluster centers". + * When zero, an optimal value is selected using a heuristic. + */ + uint32_t vq_n_centers = 0; + /** The number of iterations searching for kmeans centers (both VQ & PQ phases). */ + uint32_t kmeans_n_iters = 25; + /** + * The fraction of data to use during iterative kmeans building (VQ phase). + * When zero, an optimal value is selected using a heuristic. + */ + double vq_kmeans_trainset_fraction = 0; + /** + * The fraction of data to use during iterative kmeans building (PQ phase). + * When zero, an optimal value is selected using a heuristic. + */ + double pq_kmeans_trainset_fraction = 0; +}; + +/** + * @brief VPQ compressed dataset. + * + * The dataset is compressed using two level quantization + * + * 1. Vector Quantization + * 2. Product Quantization of residuals + * + * @tparam MathT the type of elements in the codebooks + * @tparam IdxT type of the vector indices (represent dataset.extent(0)) + * + */ +template +struct vpq_dataset : public dataset { + /** Vector Quantization codebook - "coarse cluster centers". */ + device_matrix vq_code_book; + /** Product Quantization codebook - "fine cluster centers". */ + device_matrix pq_code_book; + /** Compressed dataset. */ + device_matrix data; + + vpq_dataset(device_matrix&& vq_code_book, + device_matrix&& pq_code_book, + device_matrix&& data) + : vq_code_book{std::move(vq_code_book)}, + pq_code_book{std::move(pq_code_book)}, + data{std::move(data)} + { + } + + [[nodiscard]] auto n_rows() const noexcept -> IdxT final { return data.extent(0); } + [[nodiscard]] auto dim() const noexcept -> uint32_t final { return vq_code_book.extent(1); } + [[nodiscard]] auto is_owning() const noexcept -> bool final { return true; } + + /** Row length of the encoded data in bytes. */ + [[nodiscard]] constexpr inline auto encoded_row_length() const noexcept -> uint32_t + { + return data.extent(1); + } + /** The number of "coarse cluster centers" */ + [[nodiscard]] constexpr inline auto vq_n_centers() const noexcept -> uint32_t + { + return vq_code_book.extent(0); + } + /** The bit length of an encoded vector element after compression by PQ. */ + [[nodiscard]] constexpr inline auto pq_bits() const noexcept -> uint32_t + { + /* + NOTE: pq_bits and the book size + + Normally, we'd store `pq_bits` as a part of the index. + However, we know there's an invariant `pq_n_centers = 1 << pq_bits`, i.e. the codebook size is + the same as the number of possible code values. Hence, we don't store the pq_bits and derive it + from the array dimensions instead. + */ + auto pq_width = pq_n_centers(); +#ifdef __cpp_lib_bitops + return std::countr_zero(pq_width); +#else + uint32_t pq_bits = 0; + while (pq_width > 1) { + pq_bits++; + pq_width >>= 1; + } + return pq_bits; +#endif + } + /** The dimensionality of an encoded vector after compression by PQ. */ + [[nodiscard]] constexpr inline auto pq_dim() const noexcept -> uint32_t + { + return raft::div_rounding_up_unsafe(dim(), pq_len()); + } + /** Dimensionality of a subspaces, i.e. the number of vector components mapped to a subspace */ + [[nodiscard]] constexpr inline auto pq_len() const noexcept -> uint32_t + { + return pq_code_book.extent(1); + } + /** The number of vectors in a PQ codebook (`1 << pq_bits`). */ + [[nodiscard]] constexpr inline auto pq_n_centers() const noexcept -> uint32_t + { + return pq_code_book.extent(0); + } +}; + +} // namespace raft::neighbors diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index 08cc2beaeb..d91e45257e 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -16,6 +16,7 @@ #pragma once #include "../../cagra_types.hpp" +#include "../../vpq_dataset.cuh" #include "graph_core.cuh" #include @@ -344,6 +345,15 @@ index build( RAFT_LOG_INFO("Graph optimized, creating index"); // Construct an index from dataset and optimized knn graph. if (construct_index_with_dataset) { + if (params.compression.has_value()) { + index idx(res, params.metric); + idx.update_graph(res, raft::make_const_mdspan(cagra_graph.view())); + idx.update_dataset( + res, + // TODO: hardcoding codebook math to `half`, we can do runtime dispatching later + neighbors::vpq_build(res, *params.compression, dataset)); + return idx; + } return index(res, params.metric, dataset, raft::make_const_mdspan(cagra_graph.view())); } else { // We just add the graph. User is expected to update dataset separately. This branch is used diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index d7bd27222b..600c8785e0 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -31,7 +32,7 @@ namespace raft::neighbors::cagra::detail { -constexpr int serialization_version = 3; +constexpr int serialization_version = 4; /** * Save the index to file. @@ -65,26 +66,14 @@ void serialize(raft::resources const& res, serialize_scalar(res, os, index_.metric()); serialize_mdspan(res, os, index_.graph()); - include_dataset &= (index_.dataset().extent(0) > 0); + include_dataset &= (index_.data().n_rows() > 0); serialize_scalar(res, os, include_dataset); if (include_dataset) { RAFT_LOG_INFO("Saving CAGRA index with dataset"); - auto dataset = index_.dataset(); - // Remove padding before saving the dataset - auto host_dataset = make_host_matrix(dataset.extent(0), dataset.extent(1)); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(), - sizeof(T) * host_dataset.extent(1), - dataset.data_handle(), - sizeof(T) * dataset.stride(0), - sizeof(T) * host_dataset.extent(1), - dataset.extent(0), - cudaMemcpyDefault, - resource::get_cuda_stream(res))); - resource::sync_stream(res); - serialize_mdspan(res, os, host_dataset.view()); + neighbors::detail::serialize(res, os, index_.data()); } else { - RAFT_LOG_INFO("Saving CAGRA index WITHOUT dataset"); + RAFT_LOG_DEBUG("Saving CAGRA index WITHOUT dataset"); } } @@ -256,19 +245,13 @@ auto deserialize(raft::resources const& res, std::istream& is) -> index auto graph = raft::make_host_matrix(n_rows, graph_degree); deserialize_mdspan(res, is, graph.view()); + index idx(res, metric); + idx.update_graph(res, raft::make_const_mdspan(graph.view())); bool has_dataset = deserialize_scalar(res, is); if (has_dataset) { - auto dataset = raft::make_host_matrix(n_rows, dim); - deserialize_mdspan(res, is, dataset.view()); - return index( - res, metric, raft::make_const_mdspan(dataset.view()), raft::make_const_mdspan(graph.view())); - } else { - // create a new index with no dataset - the user must supply via update_dataset themselves - // later (this avoids allocating GPU memory in the meantime) - index idx(res, metric); - idx.update_graph(res, raft::make_const_mdspan(graph.view())); - return idx; + idx.update_dataset(res, neighbors::detail::deserialize_dataset(res, is)); } + return idx; } template diff --git a/cpp/include/raft/neighbors/detail/dataset_serialize.hpp b/cpp/include/raft/neighbors/detail/dataset_serialize.hpp new file mode 100644 index 0000000000..a6a6ae59a5 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/dataset_serialize.hpp @@ -0,0 +1,192 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "../dataset.hpp" + +#include +#include +#include + +#include + +#include +#include + +namespace raft::neighbors::detail { + +using dataset_instance_tag = uint32_t; +constexpr dataset_instance_tag kSerializeEmptyDataset = 1; +constexpr dataset_instance_tag kSerializeStridedDataset = 2; +constexpr dataset_instance_tag kSerializeVPQDataset = 3; + +template +void serialize(const raft::resources& res, std::ostream& os, const empty_dataset& dataset) +{ + serialize_scalar(res, os, dataset.suggested_dim); +} + +template +void serialize(const raft::resources& res, + std::ostream& os, + const strided_dataset& dataset) +{ + auto n_rows = dataset.n_rows(); + auto dim = dataset.dim(); + auto stride = dataset.stride(); + serialize_scalar(res, os, n_rows); + serialize_scalar(res, os, dim); + serialize_scalar(res, os, stride); + // Remove padding before saving the dataset + auto src = dataset.view(); + auto dst = make_host_matrix(n_rows, dim); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(dst.data_handle(), + sizeof(DataT) * dim, + src.data_handle(), + sizeof(DataT) * stride, + sizeof(DataT) * dim, + n_rows, + cudaMemcpyDefault, + resource::get_cuda_stream(res))); + resource::sync_stream(res); + serialize_mdspan(res, os, dst.view()); +} + +template +void serialize(const raft::resources& res, + std::ostream& os, + const vpq_dataset& dataset) +{ + serialize_scalar(res, os, dataset.n_rows()); + serialize_scalar(res, os, dataset.dim()); + serialize_scalar(res, os, dataset.vq_n_centers()); + serialize_scalar(res, os, dataset.pq_n_centers()); + serialize_scalar(res, os, dataset.pq_len()); + serialize_scalar(res, os, dataset.encoded_row_length()); + serialize_mdspan(res, os, make_const_mdspan(dataset.vq_code_book.view())); + serialize_mdspan(res, os, make_const_mdspan(dataset.pq_code_book.view())); + serialize_mdspan(res, os, make_const_mdspan(dataset.data.view())); +} + +template +void serialize(const raft::resources& res, std::ostream& os, const dataset& dataset) +{ + if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + serialize_scalar(res, os, kSerializeEmptyDataset); + return serialize(res, os, *x); + } + if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + serialize_scalar(res, os, kSerializeStridedDataset); + serialize_scalar(res, os, CUDA_R_32F); + return serialize(res, os, *x); + } + if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + serialize_scalar(res, os, kSerializeStridedDataset); + serialize_scalar(res, os, CUDA_R_16F); + return serialize(res, os, *x); + } + if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + serialize_scalar(res, os, kSerializeStridedDataset); + serialize_scalar(res, os, CUDA_R_8I); + return serialize(res, os, *x); + } + if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + serialize_scalar(res, os, kSerializeStridedDataset); + serialize_scalar(res, os, CUDA_R_8U); + return serialize(res, os, *x); + } + if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + serialize_scalar(res, os, kSerializeVPQDataset); + serialize_scalar(res, os, CUDA_R_32F); + return serialize(res, os, *x); + } + if (auto x = dynamic_cast*>(&dataset); x != nullptr) { + serialize_scalar(res, os, kSerializeVPQDataset); + serialize_scalar(res, os, CUDA_R_16F); + return serialize(res, os, *x); + } + RAFT_FAIL("unsupported dataset type."); +} + +template +auto deserialize_empty(raft::resources const& res, std::istream& is) + -> std::unique_ptr> +{ + auto suggested_dim = deserialize_scalar(res, is); + return std::make_unique>(suggested_dim); +} + +template +auto deserialize_strided(raft::resources const& res, std::istream& is) + -> std::unique_ptr> +{ + auto n_rows = deserialize_scalar(res, is); + auto dim = deserialize_scalar(res, is); + auto stride = deserialize_scalar(res, is); + auto host_array = make_host_matrix(n_rows, dim); + deserialize_mdspan(res, is, host_array.view()); + return make_strided_dataset(res, host_array, stride); +} + +template +auto deserialize_vpq(raft::resources const& res, std::istream& is) + -> std::unique_ptr> +{ + auto n_rows = deserialize_scalar(res, is); + auto dim = deserialize_scalar(res, is); + auto vq_n_centers = deserialize_scalar(res, is); + auto pq_n_centers = deserialize_scalar(res, is); + auto pq_len = deserialize_scalar(res, is); + auto encoded_row_length = deserialize_scalar(res, is); + + auto vq_code_book = make_device_matrix(res, vq_n_centers, dim); + auto pq_code_book = make_device_matrix(res, pq_n_centers, pq_len); + auto data = make_device_matrix(res, n_rows, encoded_row_length); + + deserialize_mdspan(res, is, vq_code_book.view()); + deserialize_mdspan(res, is, pq_code_book.view()); + deserialize_mdspan(res, is, data.view()); + + return std::make_unique>( + std::move(vq_code_book), std::move(pq_code_book), std::move(data)); +} + +template +auto deserialize_dataset(raft::resources const& res, std::istream& is) + -> std::unique_ptr> +{ + switch (deserialize_scalar(res, is)) { + case kSerializeEmptyDataset: return deserialize_empty(res, is); + case kSerializeStridedDataset: + switch (deserialize_scalar(res, is)) { + case CUDA_R_32F: return deserialize_strided(res, is); + case CUDA_R_16F: return deserialize_strided(res, is); + case CUDA_R_8I: return deserialize_strided(res, is); + case CUDA_R_8U: return deserialize_strided(res, is); + default: break; + } + case kSerializeVPQDataset: + switch (deserialize_scalar(res, is)) { + case CUDA_R_32F: return deserialize_vpq(res, is); + case CUDA_R_16F: return deserialize_vpq(res, is); + default: break; + } + default: break; + } + RAFT_FAIL("Failed to deserialize dataset: unsupported combination of instance tags."); +} + +} // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/detail/vpq_dataset.cuh b/cpp/include/raft/neighbors/detail/vpq_dataset.cuh new file mode 100644 index 0000000000..f6cd2a1ceb --- /dev/null +++ b/cpp/include/raft/neighbors/detail/vpq_dataset.cuh @@ -0,0 +1,427 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "../dataset.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include // pq_bits-bitfield +#include // utils::mapping etc +#include +#include + +// A temporary stub till https://github.com/rapidsai/raft/pull/2077 is re-merged +namespace raft::util { + +/** + * Subsample the dataset to create a training set. + * + * @tparam DatasetT a row-major mdspan or mdarray (device or host) + * + * @param res raft handle + * @param dataset input row-major mdspan or mdarray (device or host) + * @param n_samples the size of the output mdarray + * + * @return a newly allocated subset of the dataset. + */ +template +auto subsample(raft::resources const& res, + const DatasetT& dataset, + typename DatasetT::index_type n_samples) + -> raft::device_matrix +{ + using value_type = typename DatasetT::value_type; + using index_type = typename DatasetT::index_type; + static_assert(std::is_same_v, + "Only row-major layout is supported at the moment"); + RAFT_EXPECTS(n_samples <= dataset.extent(0), + "The number of samples must be smaller than the number of input rows in the current " + "implementation."); + size_t dim = dataset.extent(1); + size_t trainset_ratio = dataset.extent(0) / n_samples; + auto result = raft::make_device_matrix(res, n_samples, dataset.extent(1)); + + RAFT_CUDA_TRY(cudaMemcpy2DAsync(result.data_handle(), + sizeof(value_type) * dim, + dataset.data_handle(), + sizeof(value_type) * dim * trainset_ratio, + sizeof(value_type) * dim, + n_samples, + cudaMemcpyDefault, + raft::resource::get_cuda_stream(res))); + return result; +} + +} // namespace raft::util + +namespace raft::neighbors::detail { + +template +auto fill_missing_params_heuristics(const vpq_params& params, const DatasetT& dataset) -> vpq_params +{ + vpq_params r = params; + double n_rows = dataset.extent(0); + size_t dim = dataset.extent(1); + if (r.pq_dim == 0) { r.pq_dim = raft::div_rounding_up_safe(dim, size_t{4}); } + if (r.pq_bits == 0) { r.pq_bits = 8; } + if (r.vq_n_centers == 0) { r.vq_n_centers = raft::round_up_safe(std::sqrt(n_rows), 8); } + if (r.vq_kmeans_trainset_fraction == 0) { + double vq_trainset_size = 100.0 * r.vq_n_centers; + r.vq_kmeans_trainset_fraction = std::min(1.0, vq_trainset_size / n_rows); + } + if (r.pq_kmeans_trainset_fraction == 0) { + // NB: we'll have actually `pq_dim` times more samples than this + // (because the dataset is reinterpreted as `[n_rows * pq_dim, pq_len]`) + double pq_trainset_size = 1000.0 * (1u << r.pq_bits); + r.pq_kmeans_trainset_fraction = std::min(1.0, pq_trainset_size / n_rows); + } + return r; +} + +template +auto transform_data(const raft::resources& res, DatasetT dataset) + -> device_mdarray +{ + using index_type = typename DatasetT::index_type; + using extents_type = typename DatasetT::extents_type; + using layout_type = typename DatasetT::layout_type; + using out_mdarray_type = device_mdarray; + if constexpr (std::is_same_v>) { return dataset; } + + auto result = raft::make_device_mdarray(res, dataset.extents()); + + linalg::map(res, + result.view(), + spatial::knn::detail::utils::mapping{}, + raft::make_const_mdspan(dataset.view())); + + return result; +} + +/** Fix the internal indexing type to avoid integer underflows/overflows */ +using ix_t = int64_t; + +template +auto train_vq(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) + -> device_matrix +{ + const ix_t n_rows = dataset.extent(0); + const ix_t vq_n_centers = params.vq_n_centers; + const ix_t dim = dataset.extent(1); + const ix_t n_rows_train = n_rows * params.vq_kmeans_trainset_fraction; + + // Subsample the dataset and transform into the required type if necessary + auto vq_trainset = raft::util::subsample(res, dataset, n_rows_train); + auto vq_centers = raft::make_device_matrix(res, vq_n_centers, dim); + + using kmeans_in_type = typename DatasetT::value_type; + raft::cluster::kmeans_balanced_params kmeans_params; + kmeans_params.n_iters = params.kmeans_n_iters; + kmeans_params.metric = raft::distance::DistanceType::L2Expanded; + auto vq_centers_view = + raft::make_device_matrix_view(vq_centers.data_handle(), vq_n_centers, dim); + auto vq_trainset_view = raft::make_device_matrix_view( + vq_trainset.data_handle(), n_rows_train, dim); + raft::cluster::kmeans_balanced::fit( + res, + kmeans_params, + vq_trainset_view, + vq_centers_view, + spatial::knn::detail::utils::mapping{}); + + return vq_centers; +} + +template +auto predict_vq(const raft::resources& res, const DatasetT& dataset, const VqCentersT& vq_centers) + -> device_vector +{ + using kmeans_data_type = typename DatasetT::value_type; + using kmeans_math_type = typename VqCentersT::value_type; + using index_type = typename DatasetT::index_type; + using label_type = LabelT; + + auto vq_labels = raft::make_device_vector(res, dataset.extent(0)); + + raft::cluster::kmeans_balanced_params kmeans_params; + kmeans_params.metric = raft::distance::DistanceType::L2Expanded; + + auto vq_centers_view = raft::make_device_matrix_view( + vq_centers.data_handle(), vq_centers.extent(0), vq_centers.extent(1)); + + auto vq_dataset_view = raft::make_device_matrix_view( + dataset.data_handle(), dataset.extent(0), dataset.extent(1)); + + raft::cluster::kmeans_balanced:: + predict( + res, + kmeans_params, + vq_dataset_view, + vq_centers_view, + vq_labels.view(), + spatial::knn::detail::utils::mapping{}); + + return vq_labels; +} + +template +auto train_pq(const raft::resources& res, + const vpq_params& params, + const DatasetT& dataset, + const device_matrix_view& vq_centers) + -> device_matrix +{ + const ix_t n_rows = dataset.extent(0); + const ix_t dim = dataset.extent(1); + const ix_t pq_dim = params.pq_dim; + const ix_t pq_bits = params.pq_bits; + const ix_t pq_n_centers = ix_t{1} << pq_bits; + const ix_t pq_len = raft::div_rounding_up_safe(dim, pq_dim); + const ix_t n_rows_train = n_rows * params.pq_kmeans_trainset_fraction; + + // Subsample the dataset and transform into the required type if necessary + auto pq_trainset = transform_data(res, raft::util::subsample(res, dataset, n_rows_train)); + + // Subtract VQ centers + { + auto vq_labels = predict_vq(res, pq_trainset, vq_centers); + using index_type = typename DatasetT::index_type; + linalg::map_offset( + res, + pq_trainset.view(), + [labels = vq_labels.view(), centers = vq_centers, dim] __device__(index_type off, MathT x) { + index_type i = off / dim; + index_type j = off % dim; + return x - centers(labels(i), j); + }, + raft::make_const_mdspan(pq_trainset.view())); + } + + auto pq_centers = raft::make_device_matrix(res, pq_n_centers, pq_len); + + // Train PQ centers + { + raft::cluster::kmeans_balanced_params kmeans_params; + kmeans_params.n_iters = params.kmeans_n_iters; + kmeans_params.metric = raft::distance::DistanceType::L2Expanded; + + auto pq_centers_view = + raft::make_device_matrix_view(pq_centers.data_handle(), pq_n_centers, pq_len); + + auto pq_trainset_view = raft::make_device_matrix_view( + pq_trainset.data_handle(), n_rows_train * pq_dim, pq_len); + + raft::cluster::kmeans_balanced::fit( + res, kmeans_params, pq_trainset_view, pq_centers_view); + } + + return pq_centers; +} + +template +__device__ auto compute_code(device_matrix_view dataset, + device_matrix_view vq_centers, + device_matrix_view pq_centers, + IdxT i, + uint32_t j, + LabelT vq_label) -> uint8_t +{ + auto data_mapping = spatial::knn::detail::utils::mapping{}; + uint32_t lane_id = Pow2::mod(laneId()); + + const uint32_t pq_book_size = pq_centers.extent(0); + const uint32_t pq_len = pq_centers.extent(1); + float min_dist = std::numeric_limits::infinity(); + uint8_t code = 0; + // calculate the distance for each PQ cluster, find the minimum for each thread + for (uint32_t l = lane_id; l < pq_book_size; l += SubWarpSize) { + // NB: the L2 quantifiers on residuals are always trained on L2 metric. + float d = 0.0f; + for (uint32_t k = 0; k < pq_len; k++) { + auto jk = j * pq_len + k; + auto x = data_mapping(dataset(i, jk)) - vq_centers(vq_label, jk); + auto t = x - pq_centers(l, k); + d += t * t; + } + if (d < min_dist) { + min_dist = d; + code = uint8_t(l); + } + } + // reduce among threads +#pragma unroll + for (uint32_t stride = SubWarpSize >> 1; stride > 0; stride >>= 1) { + const auto other_dist = shfl_xor(min_dist, stride, SubWarpSize); + const auto other_code = shfl_xor(code, stride, SubWarpSize); + if (other_dist < min_dist) { + min_dist = other_dist; + code = other_code; + } + } + return code; +} + +template +__launch_bounds__(BlockSize) RAFT_KERNEL + process_and_fill_codes_kernel(device_matrix_view out_codes, + device_matrix_view dataset, + device_matrix_view vq_centers, + device_vector_view vq_labels, + device_matrix_view pq_centers) +{ + constexpr uint32_t kSubWarpSize = std::min(WarpSize, 1u << PqBits); + using subwarp_align = Pow2; + const IdxT row_ix = subwarp_align::div(IdxT{threadIdx.x} + IdxT{BlockSize} * IdxT{blockIdx.x}); + if (row_ix >= out_codes.extent(0)) { return; } + + const uint32_t pq_dim = raft::div_rounding_up_unsafe(vq_centers.extent(1), pq_centers.extent(1)); + + const uint32_t lane_id = Pow2::mod(threadIdx.x); + const LabelT vq_label = vq_labels(row_ix); + + // write label + auto* out_label_ptr = reinterpret_cast(&out_codes(row_ix, 0)); + if (lane_id == 0) { *out_label_ptr = vq_label; } + + auto* out_codes_ptr = reinterpret_cast(out_label_ptr + 1); + ivf_pq::detail::bitfield_view_t code_view{out_codes_ptr}; + for (uint32_t j = 0; j < pq_dim; j++) { + // find PQ label + uint8_t code = compute_code(dataset, vq_centers, pq_centers, row_ix, j, vq_label); + // TODO: this writes in global memory one byte per warp, which is very slow. + // It's better to keep the codes in the shared memory or registers and dump them at once. + if (lane_id == 0) { code_view[j] = code; } + } +} + +template +auto process_and_fill_codes(const raft::resources& res, + const vpq_params& params, + const DatasetT& dataset, + device_matrix_view vq_centers, + device_matrix_view pq_centers) + -> device_matrix +{ + using data_t = typename DatasetT::value_type; + using cdataset_t = vpq_dataset; + using label_t = uint32_t; + + const ix_t n_rows = dataset.extent(0); + const ix_t dim = dataset.extent(1); + const ix_t pq_dim = params.pq_dim; + const ix_t pq_bits = params.pq_bits; + const ix_t pq_n_centers = ix_t{1} << pq_bits; + // NB: codes must be aligned at least to sizeof(label_t) to be able to read labels. + const ix_t codes_rowlen = + sizeof(label_t) * (1 + raft::div_rounding_up_safe(pq_dim * pq_bits, 8 * sizeof(label_t))); + + auto codes = raft::make_device_matrix(res, n_rows, codes_rowlen); + + auto stream = raft::resource::get_cuda_stream(res); + + // TODO: with scaling workspace we could choose the batch size dynamically + constexpr ix_t kReasonableMaxBatchSize = 65536; + constexpr ix_t kBlockSize = 256; + const ix_t threads_per_vec = std::min(WarpSize, pq_n_centers); + dim3 threads(kBlockSize, 1, 1); + ix_t max_batch_size = std::min(n_rows, kReasonableMaxBatchSize); + auto kernel = [](uint32_t pq_bits) { + switch (pq_bits) { + case 4: return process_and_fill_codes_kernel; + case 5: return process_and_fill_codes_kernel; + case 6: return process_and_fill_codes_kernel; + case 7: return process_and_fill_codes_kernel; + case 8: return process_and_fill_codes_kernel; + default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); + } + }(pq_bits); + for (const auto& batch : + spatial::knn::detail::utils::batch_load_iterator(dataset.data_handle(), + n_rows, + dim, + max_batch_size, + stream, + rmm::mr::get_current_device_resource())) { + auto batch_view = raft::make_device_matrix_view(batch.data(), ix_t(batch.size()), dim); + auto labels = predict_vq(res, batch_view, vq_centers); + dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize / threads_per_vec), 1, 1); + kernel<<>>( + make_device_matrix_view( + codes.data_handle() + batch.offset() * codes_rowlen, batch.size(), codes_rowlen), + batch_view, + vq_centers, + make_const_mdspan(labels.view()), + pq_centers); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + + return codes; +} + +template +auto vpq_convert_math_type(const raft::resources& res, vpq_dataset&& src) + -> vpq_dataset +{ + auto vq_code_book = make_device_mdarray(res, src.vq_code_book.extents()); + auto pq_code_book = make_device_mdarray(res, src.pq_code_book.extents()); + + linalg::map(res, + vq_code_book.view(), + spatial::knn::detail::utils::mapping{}, + raft::make_const_mdspan(src.vq_code_book.view())); + linalg::map(res, + pq_code_book.view(), + spatial::knn::detail::utils::mapping{}, + raft::make_const_mdspan(src.pq_code_book.view())); + return vpq_dataset{ + std::move(vq_code_book), std::move(pq_code_book), std::move(src.data)}; +} + +template +auto vpq_build(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) + -> vpq_dataset +{ + // Use a heuristic to impute missing parameters. + auto ps = fill_missing_params_heuristics(params, dataset); + + // Train codes + auto vq_code_book = train_vq(res, ps, dataset); + auto pq_code_book = + train_pq(res, ps, dataset, raft::make_const_mdspan(vq_code_book.view())); + + // Encode dataset + auto codes = process_and_fill_codes(res, + ps, + dataset, + raft::make_const_mdspan(vq_code_book.view()), + raft::make_const_mdspan(pq_code_book.view())); + + return vpq_dataset{ + std::move(vq_code_book), std::move(pq_code_book), std::move(codes)}; +} + +} // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/vpq_dataset.cuh b/cpp/include/raft/neighbors/vpq_dataset.cuh new file mode 100644 index 0000000000..73ee6c52ed --- /dev/null +++ b/cpp/include/raft/neighbors/vpq_dataset.cuh @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "dataset.hpp" +#include "detail/vpq_dataset.cuh" + +#include + +namespace raft::neighbors { + +/** + * @brief Compress a dataset for use in CAGRA-Q search in place of the original data. + * + * @tparam DatasetT a row-major mdspan or mdarray (device or host). + * @tparam MathT a type of the codebook elements and internal math ops. + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] res + * @param[in] params VQ and PQ parameters for compressing the data + * @param[in] dataset a row-major mdspan or mdarray (device or host) [n_rows, dim]. + */ +template +auto vpq_build(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) + -> vpq_dataset +{ + if constexpr (std::is_same_v) { + return detail::vpq_convert_math_type( + res, detail::vpq_build(res, params, dataset)); + } else { + return detail::vpq_build(res, params, dataset); + } +} + +} // namespace raft::neighbors From d14cac2a9f36d24559c8a0a61806cef91913048b Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Mon, 18 Mar 2024 18:58:46 +0100 Subject: [PATCH 06/18] Add explicit instantiations for IVF-PQ search kernels used in tests (#2212) Compilation of IVF-PQ search kernels can be time consuming. In `libraft.so` the compilation is done in parallel for kernels without filtering and with `int64_t` index type. We have test with `uint32_t` index type as well as tests for `bitset_filter` with both 32 and 64 bit index types. This PR adds explicit template instantiations for the test. This way we avoid repeated compilation of the kernels with filter and this also enables parallel compilation of the `compute_similarity` kernel for different template types. The kernels with these additional type parameters are not added to `libraft.so`, only linked together with the test executable. Note that this PR does not increase the number of compiled kernels, but it enables to compile them in parallel. Authors: - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Artem M. Chirkin (https://github.com/achirkin) - Ben Frederickson (https://github.com/benfred) URL: https://github.com/rapidsai/raft/pull/2212 --- cpp/bench/prims/CMakeLists.txt | 8 + .../knn/ivf_pq_filter_float_int64_t.cu | 5 +- cpp/include/raft/core/detail/nvtx.hpp | 1 + .../raft/neighbors/detail/ivf_pq_build.cuh | 46 +++-- .../ivf_pq_compute_similarity_template.cuh | 71 +++++++ ...pq_compute_similarity_filters_test-ext.cuh | 181 ++++++++++++++++++ .../neighbors/ivf_pq_search_test-ext.cuh | 88 +++++++++ .../ivf_pq_compute_similarity_00_generate.py | 107 ++++------- .../ivf_pq_compute_similarity_float_float.cu | 57 +----- ...compute_similarity_float_float_bitset32.cu | 28 +++ ...compute_similarity_float_float_bitset64.cu | 28 +++ ...q_compute_similarity_float_float_filt32.cu | 28 +++ ...f_pq_compute_similarity_float_fp8_false.cu | 57 +----- ...ute_similarity_float_fp8_false_bitset32.cu | 28 +++ ...ute_similarity_float_fp8_false_bitset64.cu | 28 +++ ...mpute_similarity_float_fp8_false_filt32.cu | 28 +++ ...vf_pq_compute_similarity_float_fp8_true.cu | 57 +----- ...pute_similarity_float_fp8_true_bitset32.cu | 28 +++ ...pute_similarity_float_fp8_true_bitset64.cu | 28 +++ ...ompute_similarity_float_fp8_true_filt32.cu | 28 +++ .../ivf_pq_compute_similarity_float_half.cu | 57 +----- ..._compute_similarity_float_half_bitset32.cu | 28 +++ ..._compute_similarity_float_half_bitset64.cu | 28 +++ ...pq_compute_similarity_float_half_filt32.cu | 28 +++ ...vf_pq_compute_similarity_half_fp8_false.cu | 57 +----- ...pute_similarity_half_fp8_false_bitset32.cu | 28 +++ ...pute_similarity_half_fp8_false_bitset64.cu | 28 +++ ...ompute_similarity_half_fp8_false_filt32.cu | 28 +++ ...ivf_pq_compute_similarity_half_fp8_true.cu | 57 +----- ...mpute_similarity_half_fp8_true_bitset32.cu | 28 +++ ...mpute_similarity_half_fp8_true_bitset64.cu | 28 +++ ...compute_similarity_half_fp8_true_filt32.cu | 28 +++ .../ivf_pq_compute_similarity_half_half.cu | 57 +----- ...q_compute_similarity_half_half_bitset32.cu | 28 +++ ...q_compute_similarity_half_half_bitset64.cu | 28 +++ ..._pq_compute_similarity_half_half_filt32.cu | 28 +++ .../ivf_pq_search_filtering_float_int64_t.cu | 43 +++++ cpp/test/CMakeLists.txt | 24 +++ .../ann_ivf_pq/ivf_pq_build_float_uint32_t.cu | 37 ++++ .../ann_ivf_pq/ivf_pq_build_test-ext.cuh | 38 ++++ .../ivf_pq_search_float_uint32_t.cu | 68 +++++++ .../ann_ivf_pq/test_filter_float_int64_t.cu | 6 +- .../ann_ivf_pq/test_filter_int8_t_int64_t.cu | 6 +- .../ann_ivf_pq/test_float_uint32_t.cu | 12 +- .../ann_ivf_pq/test_int8_t_int64_t.cu | 3 +- 45 files changed, 1245 insertions(+), 486 deletions(-) create mode 100644 cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity_template.cuh create mode 100644 cpp/internal/raft_internal/neighbors/ivf_pq_compute_similarity_filters_test-ext.cuh create mode 100644 cpp/internal/raft_internal/neighbors/ivf_pq_search_test-ext.cuh create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float_bitset32.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float_bitset64.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float_filt32.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_bitset32.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_bitset64.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_filt32.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_bitset32.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_bitset64.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_filt32.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half_bitset32.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half_bitset64.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half_filt32.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_bitset32.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_bitset64.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_filt32.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_bitset32.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_bitset64.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_filt32.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half_bitset32.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half_bitset64.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half_filt32.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_search_filtering_float_int64_t.cu create mode 100644 cpp/test/neighbors/ann_ivf_pq/ivf_pq_build_float_uint32_t.cu create mode 100644 cpp/test/neighbors/ann_ivf_pq/ivf_pq_build_test-ext.cuh create mode 100644 cpp/test/neighbors/ann_ivf_pq/ivf_pq_search_float_uint32_t.cu diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 3a2431cd34..5577881ef7 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -156,6 +156,14 @@ if(BUILD_PRIMS_BENCH) bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu bench/prims/neighbors/knn/ivf_pq_int8_t_int64_t.cu bench/prims/neighbors/knn/ivf_pq_uint8_t_int64_t.cu + src/neighbors/detail/ivf_pq_search_filtering_float_int64_t.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_float_bitset64.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_bitset64.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_bitset64.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_half_bitset64.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_bitset64.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_bitset64.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_half_bitset64.cu bench/prims/neighbors/refine_float_int64_t.cu bench/prims/neighbors/refine_uint8_t_int64_t.cu bench/prims/main.cpp diff --git a/cpp/bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu b/cpp/bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu index 9534515cbb..1840eca99d 100644 --- a/cpp/bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu +++ b/cpp/bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,9 +14,10 @@ * limitations under the License. */ -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter #include "../knn.cuh" +#include +#include namespace raft::bench::spatial { KNN_REGISTER(float, int64_t, ivf_pq_filter_knn, kInputsFilter, kNoCopyOnly, kScopeFull); diff --git a/cpp/include/raft/core/detail/nvtx.hpp b/cpp/include/raft/core/detail/nvtx.hpp index 8afd1f16c6..82db75de84 100644 --- a/cpp/include/raft/core/detail/nvtx.hpp +++ b/cpp/include/raft/core/detail/nvtx.hpp @@ -28,6 +28,7 @@ #include #include #include +#include namespace raft::common::nvtx::detail { diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh index b7796d52fa..8e3f7dbaf3 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh @@ -61,6 +61,8 @@ namespace raft::neighbors::ivf_pq::detail { using namespace raft::spatial::knn::detail; // NOLINT +using internal_extents_t = int64_t; // The default mdspan extent type used internally. + template __launch_bounds__(BlockDim) RAFT_KERNEL copy_warped_kernel( T* out, uint32_t ld_out, const S* in, uint32_t ld_in, uint32_t n_cols, size_t n_rows) @@ -442,15 +444,16 @@ void train_per_subset(raft::resources const& handle, stream); // train PQ codebook for this subspace - auto sub_trainset_view = - raft::make_device_matrix_view(sub_trainset.data(), n_rows, index.pq_len()); - auto centers_tmp_view = raft::make_device_matrix_view( + auto sub_trainset_view = raft::make_device_matrix_view( + sub_trainset.data(), n_rows, index.pq_len()); + auto centers_tmp_view = raft::make_device_matrix_view( pq_centers_tmp.data() + index.pq_book_size() * index.pq_len() * j, index.pq_book_size(), index.pq_len()); - auto sub_labels_view = raft::make_device_vector_view(sub_labels.data(), n_rows); - auto cluster_sizes_view = - raft::make_device_vector_view(pq_cluster_sizes.data(), index.pq_book_size()); + auto sub_labels_view = + raft::make_device_vector_view(sub_labels.data(), n_rows); + auto cluster_sizes_view = raft::make_device_vector_view( + pq_cluster_sizes.data(), index.pq_book_size()); raft::cluster::kmeans_balanced_params kmeans_params; kmeans_params.n_iters = kmeans_n_iters; kmeans_params.metric = raft::distance::DistanceType::L2Expanded; @@ -525,17 +528,17 @@ void train_per_cluster(raft::resources const& handle, size_t available_rows = size_t(cluster_size) * size_t(index.pq_dim()); auto pq_n_rows = uint32_t(std::min(big_enough, available_rows)); // train PQ codebook for this cluster - auto rot_vectors_view = raft::make_device_matrix_view( + auto rot_vectors_view = raft::make_device_matrix_view( rot_vectors.data(), pq_n_rows, index.pq_len()); - auto centers_tmp_view = raft::make_device_matrix_view( + auto centers_tmp_view = raft::make_device_matrix_view( pq_centers_tmp.data() + static_cast(index.pq_book_size()) * static_cast(index.pq_len()) * static_cast(l), index.pq_book_size(), index.pq_len()); auto pq_labels_view = - raft::make_device_vector_view(pq_labels.data(), pq_n_rows); - auto pq_cluster_sizes_view = - raft::make_device_vector_view(pq_cluster_sizes.data(), index.pq_book_size()); + raft::make_device_vector_view(pq_labels.data(), pq_n_rows); + auto pq_cluster_sizes_view = raft::make_device_vector_view( + pq_cluster_sizes.data(), index.pq_book_size()); raft::cluster::kmeans_balanced_params kmeans_params; kmeans_params.n_iters = kmeans_n_iters; kmeans_params.metric = raft::distance::DistanceType::L2Expanded; @@ -1587,11 +1590,11 @@ void extend(raft::resources const& handle, cudaMemcpyDefault, stream)); for (const auto& batch : vec_batches) { - auto batch_data_view = - raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); - auto batch_labels_view = raft::make_device_vector_view( + auto batch_data_view = raft::make_device_matrix_view( + batch.data(), batch.size(), index->dim()); + auto batch_labels_view = raft::make_device_vector_view( new_data_labels.data() + batch.offset(), batch.size()); - auto centers_view = raft::make_device_matrix_view( + auto centers_view = raft::make_device_matrix_view( cluster_centers.data(), n_clusters, index->dim()); raft::cluster::kmeans_balanced_params kmeans_params; kmeans_params.metric = index->metric(); @@ -1767,10 +1770,10 @@ auto build(raft::resources const& handle, auto cluster_centers = cluster_centers_buf.data(); // Train balanced hierarchical kmeans clustering - auto trainset_const_view = - raft::make_device_matrix_view(trainset.data(), n_rows_train, index.dim()); - auto centers_view = - raft::make_device_matrix_view(cluster_centers, index.n_lists(), index.dim()); + auto trainset_const_view = raft::make_device_matrix_view( + trainset.data(), n_rows_train, index.dim()); + auto centers_view = raft::make_device_matrix_view( + cluster_centers, index.n_lists(), index.dim()); raft::cluster::kmeans_balanced_params kmeans_params; kmeans_params.n_iters = params.kmeans_n_iters; kmeans_params.metric = index.metric(); @@ -1779,9 +1782,10 @@ auto build(raft::resources const& handle, // Trainset labels are needed for training PQ codebooks rmm::device_uvector labels(n_rows_train, stream, device_memory); - auto centers_const_view = raft::make_device_matrix_view( + auto centers_const_view = raft::make_device_matrix_view( cluster_centers, index.n_lists(), index.dim()); - auto labels_view = raft::make_device_vector_view(labels.data(), n_rows_train); + auto labels_view = + raft::make_device_vector_view(labels.data(), n_rows_train); raft::cluster::kmeans_balanced::predict(handle, kmeans_params, trainset_const_view, diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity_template.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity_template.cuh new file mode 100644 index 0000000000..83dd994bd6 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity_template.cuh @@ -0,0 +1,71 @@ + +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is to be used in source files generated by + * src/neighbors/detailivf_pq_compute_similarity_00_generate.py + */ + +#pragma once + +#include +#include +#include + +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, IvfSampleFilterT) \ + template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + IvfSampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ + uint32_t* _out_indices); + +#define COMMA , diff --git a/cpp/internal/raft_internal/neighbors/ivf_pq_compute_similarity_filters_test-ext.cuh b/cpp/internal/raft_internal/neighbors/ivf_pq_compute_similarity_filters_test-ext.cuh new file mode 100644 index 0000000000..aa14ab19b8 --- /dev/null +++ b/cpp/internal/raft_internal/neighbors/ivf_pq_compute_similarity_filters_test-ext.cuh @@ -0,0 +1,181 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // RAFT_WEAK_FUNCTION +#include // raft::distance::DistanceType +#include +#include // raft::neighbors::ivf_pq::detail::fp_8bit +#include // none_ivf_sample_filter +#include // none_ivf_sample_filter + +#include // rmm::cuda_stream_view + +#include // __half + +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ + OutT, LutT, IvfSampleFilterT) \ + extern template auto \ + raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + extern template void \ + raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + uint32_t queries_offset, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + IvfSampleFilterT sample_filter, \ + LutT* lut_scores, \ + OutT* _out_scores, \ + uint32_t* _out_indices); + +#define COMMA , +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + half, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + half, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + float, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); + +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::bitset_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::bitset_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + half, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::bitset_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + half, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::bitset_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + float, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::bitset_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::bitset_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::bitset_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::bitset_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::bitset_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + half, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::bitset_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + half, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::bitset_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + float, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::bitset_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::bitset_filter>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::bitset_filter>); +#undef COMMA + +#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/internal/raft_internal/neighbors/ivf_pq_search_test-ext.cuh b/cpp/internal/raft_internal/neighbors/ivf_pq_search_test-ext.cuh new file mode 100644 index 0000000000..7a65e2d2f8 --- /dev/null +++ b/cpp/internal/raft_internal/neighbors/ivf_pq_search_test-ext.cuh @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // raft::device_matrix_view +#include // raft::resources +#include +#include // raft::neighbors::ivf_pq::index +#include +#include + +#include + +#include // int64_t + +#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ + extern template void raft::neighbors::ivf_pq::search( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + \ + extern template void raft::neighbors::ivf_pq::search( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr); \ + \ + extern template void raft::neighbors::ivf_pq::search( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances) + +instantiate_raft_neighbors_ivf_pq_search(float, uint32_t); + +#undef instantiate_raft_neighbors_ivf_pq_search + +#define instantiate_raft_neighbors_ivf_pq_search_with_filtering(T, IdxT, FilterT) \ + extern template void raft::neighbors::ivf_pq::search_with_filtering( \ + raft::resources const& handle, \ + const search_params& params, \ + const index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + FilterT sample_filter) + +#define COMMA , +instantiate_raft_neighbors_ivf_pq_search_with_filtering( + float, uint32_t, raft::neighbors::filtering::bitset_filter); + +instantiate_raft_neighbors_ivf_pq_search_with_filtering( + float, uint32_t, raft::neighbors::filtering::none_ivf_sample_filter); + +instantiate_raft_neighbors_ivf_pq_search_with_filtering( + float, int64_t, raft::neighbors::filtering::bitset_filter); + +instantiate_raft_neighbors_ivf_pq_search_with_filtering( + int8_t, int64_t, raft::neighbors::filtering::bitset_filter); + +#undef COMMA +#undef instantiate_raft_neighbors_ivf_pq_search_with_filtering diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py index 670ed57ed1..9825a48f81 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -header = """ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. +header = """/* + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,78 +30,56 @@ /* * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py - * * Make changes there and run in this directory: - * * > python ivf_pq_compute_similarity_00_generate.py - * */ - -#include -#include - -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT, IvfSampleFilterT) \\ - template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \\ - const cudaDeviceProp& dev_props, \\ - bool manage_local_topk, \\ - int locality_hint, \\ - double preferred_shmem_carveout, \\ - uint32_t pq_bits, \\ - uint32_t pq_dim, \\ - uint32_t precomp_data_count, \\ - uint32_t n_queries, \\ - uint32_t n_probes, \\ - uint32_t topk) -> raft::neighbors::ivf_pq::detail::selected; \\ -\\ - template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \\ - raft::neighbors::ivf_pq::detail::selected s, \\ - rmm::cuda_stream_view stream, \\ - uint32_t dim, \\ - uint32_t n_probes, \\ - uint32_t pq_dim, \\ - uint32_t n_queries, \\ - uint32_t queries_offset, \\ - raft::distance::DistanceType metric, \\ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \\ - uint32_t topk, \\ - uint32_t max_samples, \\ - const float* cluster_centers, \\ - const float* pq_centers, \\ - const uint8_t* const* pq_dataset, \\ - const uint32_t* cluster_labels, \\ - const uint32_t* _chunk_indices, \\ - const float* queries, \\ - const uint32_t* index_list, \\ - float* query_kths, \\ - IvfSampleFilterT sample_filter, \\ - LutT* lut_scores, \\ - OutT* _out_scores, \\ - uint32_t* _out_indices); - - -#define COMMA , + +#include """ -trailer = """ -#undef COMMA - -#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select -""" +none_filter_int64 = "raft::neighbors::filtering::ivf_to_sample_filter" \ + "" +none_filter_int32 = "raft::neighbors::filtering::ivf_to_sample_filter" \ + "" +bitset_filter32 = "raft::neighbors::filtering::ivf_to_sample_filter" \ + ">" +bitset_filter64 = "raft::neighbors::filtering::ivf_to_sample_filter" \ + ">" types = dict( - half_fp8_false=("half", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>"), - half_fp8_true=("half", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>"), - half_half=("half", "half"), - float_half=("float", "half"), - float_float= ("float", "float"), - float_fp8_false=("float", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>"), - float_fp8_true=("float", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>"), + half_fp8_false=("half", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>", none_filter_int64), + half_fp8_true=("half", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>", none_filter_int64), + half_half=("half", "half", none_filter_int64), + float_half=("float", "half", none_filter_int64), + float_float= ("float", "float", none_filter_int64), + float_fp8_false=("float", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>", none_filter_int64), + float_fp8_true=("float", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>", none_filter_int64), + half_fp8_false_filt32=("half", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>", none_filter_int32), + half_fp8_true_filt32=("half", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>", none_filter_int32), + half_half_filt32=("half", "half", none_filter_int32), + float_half_filt32=("float", "half", none_filter_int32), + float_float_filt32= ("float", "float", none_filter_int32), + float_fp8_false_filt32=("float", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>", none_filter_int32), + float_fp8_true_filt32=("float", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>", none_filter_int32), + half_fp8_false_bitset32=("half", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>", bitset_filter32), + half_fp8_true_bitset32=("half", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>", bitset_filter32), + half_half_bitset32=("half", "half", bitset_filter32), + float_half_bitset32=("float", "half", bitset_filter32), + float_float_bitset32= ("float", "float", bitset_filter32), + float_fp8_false_bitset32=("float", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>", bitset_filter32), + float_fp8_true_bitset32=("float", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>", bitset_filter32), + half_fp8_false_bitset64=("half", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>", bitset_filter64), + half_fp8_true_bitset64=("half", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>", bitset_filter64), + half_half_bitset64=("half", "half", bitset_filter64), + float_half_bitset64=("float", "half", bitset_filter64), + float_float_bitset64= ("float", "float", bitset_filter64), + float_fp8_false_bitset64=("float", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>", bitset_filter64), + float_fp8_true_bitset64=("float", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>", bitset_filter64) ) -for path_key, (OutT, LutT) in types.items(): +for path_key, (OutT, LutT, FilterT) in types.items(): path = f"ivf_pq_compute_similarity_{path_key}.cu" with open(path, "w") as f: f.write(header) - f.write(f"instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select({OutT}, {LutT}, raft::neighbors::filtering::ivf_to_sample_filter);\n") - f.write(trailer) + f.write(f"instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select({OutT}, {LutT}, {FilterT});\n") print(f"src/neighbors/detail/{path}") diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu index 7e17d6822a..db51608ae1 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,65 +16,13 @@ /* * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py - * * Make changes there and run in this directory: - * * > python ivf_pq_compute_similarity_00_generate.py - * */ -#include -#include - -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, IvfSampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - IvfSampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ - uint32_t* _out_indices); - -#define COMMA , +#include instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, float, raft::neighbors::filtering::ivf_to_sample_filter< int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); - -#undef COMMA - -#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float_bitset32.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float_bitset32.cu new file mode 100644 index 0000000000..caaf40abdf --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float_bitset32.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + float, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::bitset_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float_bitset64.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float_bitset64.cu new file mode 100644 index 0000000000..7801c25e9f --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float_bitset64.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + float, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::bitset_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float_filt32.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float_filt32.cu new file mode 100644 index 0000000000..45ae348849 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float_filt32.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + float, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu index c1b72dab33..2f5bcf8f92 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,65 +16,13 @@ /* * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py - * * Make changes there and run in this directory: - * * > python ivf_pq_compute_similarity_00_generate.py - * */ -#include -#include - -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, IvfSampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - IvfSampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ - uint32_t* _out_indices); - -#define COMMA , +#include instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, raft::neighbors::filtering::ivf_to_sample_filter< int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); - -#undef COMMA - -#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_bitset32.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_bitset32.cu new file mode 100644 index 0000000000..e7f2c44254 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_bitset32.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::bitset_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_bitset64.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_bitset64.cu new file mode 100644 index 0000000000..01b6900bb8 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_bitset64.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::bitset_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_filt32.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_filt32.cu new file mode 100644 index 0000000000..9f8d453364 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_filt32.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu index fdff0860fc..06d21bcd50 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,65 +16,13 @@ /* * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py - * * Make changes there and run in this directory: - * * > python ivf_pq_compute_similarity_00_generate.py - * */ -#include -#include - -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, IvfSampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - IvfSampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ - uint32_t* _out_indices); - -#define COMMA , +#include instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, raft::neighbors::filtering::ivf_to_sample_filter< int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); - -#undef COMMA - -#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_bitset32.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_bitset32.cu new file mode 100644 index 0000000000..8b733a23c1 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_bitset32.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::bitset_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_bitset64.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_bitset64.cu new file mode 100644 index 0000000000..77e4f9a023 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_bitset64.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::bitset_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_filt32.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_filt32.cu new file mode 100644 index 0000000000..3e036e3df4 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_filt32.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu index 7205544370..ff42f5e041 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,65 +16,13 @@ /* * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py - * * Make changes there and run in this directory: - * * > python ivf_pq_compute_similarity_00_generate.py - * */ -#include -#include - -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, IvfSampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - IvfSampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ - uint32_t* _out_indices); - -#define COMMA , +#include instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( float, half, raft::neighbors::filtering::ivf_to_sample_filter< int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); - -#undef COMMA - -#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half_bitset32.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half_bitset32.cu new file mode 100644 index 0000000000..40b6313865 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half_bitset32.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + half, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::bitset_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half_bitset64.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half_bitset64.cu new file mode 100644 index 0000000000..9cedabdb11 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half_bitset64.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + half, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::bitset_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half_filt32.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half_filt32.cu new file mode 100644 index 0000000000..61422bbc36 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half_filt32.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, + half, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu index 2ac6c3527b..d2064cfe97 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,65 +16,13 @@ /* * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py - * * Make changes there and run in this directory: - * * > python ivf_pq_compute_similarity_00_generate.py - * */ -#include -#include - -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, IvfSampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - IvfSampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ - uint32_t* _out_indices); - -#define COMMA , +#include instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, raft::neighbors::filtering::ivf_to_sample_filter< int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); - -#undef COMMA - -#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_bitset32.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_bitset32.cu new file mode 100644 index 0000000000..1127f39f71 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_bitset32.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::bitset_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_bitset64.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_bitset64.cu new file mode 100644 index 0000000000..0330bf58d6 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_bitset64.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::bitset_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_filt32.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_filt32.cu new file mode 100644 index 0000000000..d20f7921d5 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_filt32.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu index 70f3ffdb0c..9dc954406e 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,65 +16,13 @@ /* * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py - * * Make changes there and run in this directory: - * * > python ivf_pq_compute_similarity_00_generate.py - * */ -#include -#include - -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, IvfSampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - IvfSampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ - uint32_t* _out_indices); - -#define COMMA , +#include instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, raft::neighbors::filtering::ivf_to_sample_filter< int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); - -#undef COMMA - -#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_bitset32.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_bitset32.cu new file mode 100644 index 0000000000..9131fa25a8 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_bitset32.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::bitset_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_bitset64.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_bitset64.cu new file mode 100644 index 0000000000..8b4521b31b --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_bitset64.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::bitset_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_filt32.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_filt32.cu new file mode 100644 index 0000000000..71b63cf4a0 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_filt32.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu index 5cc1cb8038..f527d879be 100644 --- a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,65 +16,13 @@ /* * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py - * * Make changes there and run in this directory: - * * > python ivf_pq_compute_similarity_00_generate.py - * */ -#include -#include - -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, IvfSampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - IvfSampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ - uint32_t* _out_indices); - -#define COMMA , +#include instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( half, half, raft::neighbors::filtering::ivf_to_sample_filter< int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); - -#undef COMMA - -#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half_bitset32.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half_bitset32.cu new file mode 100644 index 0000000000..8e1962e2bb --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half_bitset32.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + half, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::bitset_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half_bitset64.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half_bitset64.cu new file mode 100644 index 0000000000..e9671703e7 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half_bitset64.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + half, + raft::neighbors::filtering::ivf_to_sample_filter< + int64_t COMMA raft::neighbors::filtering::bitset_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half_filt32.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half_filt32.cu new file mode 100644 index 0000000000..b66a07d1a9 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half_filt32.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * Make changes there and run in this directory: + * > python ivf_pq_compute_similarity_00_generate.py + */ + +#include +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, + half, + raft::neighbors::filtering::ivf_to_sample_filter< + uint32_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); diff --git a/cpp/src/neighbors/detail/ivf_pq_search_filtering_float_int64_t.cu b/cpp/src/neighbors/detail/ivf_pq_search_filtering_float_int64_t.cu new file mode 100644 index 0000000000..39af78f12e --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_search_filtering_float_int64_t.cu @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // raft::device_matrix_view +#include // raft::resources +#include +#include // raft::neighbors::ivf_pq::index +#include +#include + +#include + +#include // int64_t + +#define instantiate_raft_neighbors_ivf_pq_search_with_filtering(T, IdxT, FilterT) \ + template void raft::neighbors::ivf_pq::search_with_filtering( \ + raft::resources const& handle, \ + const search_params& params, \ + const index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + FilterT sample_filter) + +#define COMMA , +instantiate_raft_neighbors_ivf_pq_search_with_filtering( + float, int64_t, raft::neighbors::filtering::bitset_filter); + +#undef COMMA +#undef instantiate_raft_neighbors_ivf_pq_search_with_filtering diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index dd7eb839ab..65d6e738a2 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -401,6 +401,30 @@ if(BUILD_TESTS) test/neighbors/ann_ivf_flat/test_float_int64_t.cu test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu + test/neighbors/ann_ivf_pq/ivf_pq_build_float_uint32_t.cu + test/neighbors/ann_ivf_pq/ivf_pq_search_float_uint32_t.cu + src/neighbors/detail/ivf_pq_search_filtering_float_int64_t.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_float_filt32.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_filt32.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_filt32.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_half_filt32.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_filt32.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_filt32.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_half_filt32.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_float_bitset32.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_bitset32.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_bitset32.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_half_bitset32.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_bitset32.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_bitset32.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_half_bitset32.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_float_bitset64.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_bitset64.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_bitset64.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_half_bitset64.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_bitset64.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_bitset64.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_half_bitset64.cu test/neighbors/ann_ivf_pq/test_float_uint32_t.cu test/neighbors/ann_ivf_pq/test_float_int64_t.cu test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu diff --git a/cpp/test/neighbors/ann_ivf_pq/ivf_pq_build_float_uint32_t.cu b/cpp/test/neighbors/ann_ivf_pq/ivf_pq_build_float_uint32_t.cu new file mode 100644 index 0000000000..5ba21c3c2f --- /dev/null +++ b/cpp/test/neighbors/ann_ivf_pq/ivf_pq_build_float_uint32_t.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include // raft::neighbors::ivf_pq::index +#include + +#define instantiate_raft_neighbors_ivf_pq_build(T, IdxT) \ + template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::build( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset); \ + \ + template auto raft::neighbors::ivf_pq::build( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_pq::index; + +instantiate_raft_neighbors_ivf_pq_build(float, uint32_t); + +#undef instantiate_raft_neighbors_ivf_pq_build diff --git a/cpp/test/neighbors/ann_ivf_pq/ivf_pq_build_test-ext.cuh b/cpp/test/neighbors/ann_ivf_pq/ivf_pq_build_test-ext.cuh new file mode 100644 index 0000000000..cd5435ab2e --- /dev/null +++ b/cpp/test/neighbors/ann_ivf_pq/ivf_pq_build_test-ext.cuh @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include // raft::neighbors::ivf_pq::index +#include + +#define instantiate_raft_neighbors_ivf_pq_build(T, IdxT) \ + extern template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::build( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset); \ + \ + extern template auto raft::neighbors::ivf_pq::build( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_pq::index; + +instantiate_raft_neighbors_ivf_pq_build(float, uint32_t); + +#undef instantiate_raft_neighbors_ivf_pq_build diff --git a/cpp/test/neighbors/ann_ivf_pq/ivf_pq_search_float_uint32_t.cu b/cpp/test/neighbors/ann_ivf_pq/ivf_pq_search_float_uint32_t.cu new file mode 100644 index 0000000000..942d0fcc44 --- /dev/null +++ b/cpp/test/neighbors/ann_ivf_pq/ivf_pq_search_float_uint32_t.cu @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include // raft::neighbors::ivf_pq::index +#include + +#include + +#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ + template void raft::neighbors::ivf_pq::search( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + \ + template void raft::neighbors::ivf_pq::search( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_neighbors_ivf_pq_search(float, uint32_t); + +#undef instantiate_raft_neighbors_ivf_pq_search + +#define instantiate_raft_neighbors_ivf_pq_search_with_filtering(T, IdxT, FilterT) \ + template void raft::neighbors::ivf_pq::search_with_filtering( \ + raft::resources const& handle, \ + const search_params& params, \ + const index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances, \ + FilterT sample_filter) + +#define COMMA , +instantiate_raft_neighbors_ivf_pq_search_with_filtering( + float, uint32_t, raft::neighbors::filtering::bitset_filter); + +instantiate_raft_neighbors_ivf_pq_search_with_filtering( + int8_t, int64_t, raft::neighbors::filtering::bitset_filter); + +instantiate_raft_neighbors_ivf_pq_search_with_filtering( + float, uint32_t, raft::neighbors::filtering::none_ivf_sample_filter); + +#undef COMMA +#undef instantiate_raft_neighbors_ivf_pq_search_with_filtering diff --git a/cpp/test/neighbors/ann_ivf_pq/test_filter_float_int64_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_filter_float_int64_t.cu index 17f72fb08a..70d5d8761f 100644 --- a/cpp/test/neighbors/ann_ivf_pq/test_filter_float_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_pq/test_filter_float_int64_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,9 +14,11 @@ * limitations under the License. */ -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter #include "../ann_ivf_pq.cuh" +#include +#include + namespace raft::neighbors::ivf_pq { using f32_f32_i64_filter = ivf_pq_filter_test; diff --git a/cpp/test/neighbors/ann_ivf_pq/test_filter_int8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_filter_int8_t_int64_t.cu index 537dbb4979..ba96a8db0b 100644 --- a/cpp/test/neighbors/ann_ivf_pq/test_filter_int8_t_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_pq/test_filter_int8_t_int64_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,9 +14,11 @@ * limitations under the License. */ -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter #include "../ann_ivf_pq.cuh" +#include +#include + namespace raft::neighbors::ivf_pq { using f32_i08_i64_filter = ivf_pq_filter_test; diff --git a/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu index a6cfab1f19..b8ada2249a 100644 --- a/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu +++ b/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu @@ -14,15 +14,11 @@ * limitations under the License. */ -// XXX: the uint32_t instance is not compiled in libraft.so. So we allow -// instantiating the template here. -// -// TODO: consider removing this test or consider adding an instantiation to the -// library. - -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY - #include "../ann_ivf_pq.cuh" +#include "ivf_pq_build_test-ext.cuh" + +#include +#include namespace raft::neighbors::ivf_pq { diff --git a/cpp/test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu index 014e96a2db..970bdd6a12 100644 --- a/cpp/test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ #include "../ann_ivf_pq.cuh" +#include namespace raft::neighbors::ivf_pq { using f32_i08_i64 = ivf_pq_test; From bd50c37f52eeff459ca032bd0ef5f7920c2bcc8d Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Tue, 19 Mar 2024 10:51:23 +0100 Subject: [PATCH 07/18] Fix ANN bench ground truth generation for k>1024 (#2180) Generating ANN bench ground truth is affected by bug #2171, when k>1024. This PR fixes the issue for the ground truth generation. Authors: - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2180 --- .../raft/neighbors/detail/knn_merge_parts.cuh | 3 +++ .../raft-ann-bench/generate_groundtruth/__main__.py | 13 ++++--------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh b/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh index 0395e86e43..33324714fd 100644 --- a/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh +++ b/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include #include @@ -168,5 +169,7 @@ inline void knn_merge_parts(const value_t* inK, else if (k <= 1024) knn_merge_parts_impl( inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else + THROW("Unimplemented for k=%d, knn_merge_parts works for k<=1024", k); } } // namespace raft::neighbors::detail diff --git a/python/raft-ann-bench/src/raft-ann-bench/generate_groundtruth/__main__.py b/python/raft-ann-bench/src/raft-ann-bench/generate_groundtruth/__main__.py index f4d97edea5..a5ebb76635 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/generate_groundtruth/__main__.py +++ b/python/raft-ann-bench/src/raft-ann-bench/generate_groundtruth/__main__.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -62,17 +62,12 @@ def calc_truth(dataset, queries, k, metric="sqeuclidean"): X = cp.asarray(dataset[i : i + n_batch, :], cp.float32) - D, Ind = knn( - X, - queries, - k, - metric=metric, - handle=handle, - global_id_offset=i, # shift neighbor index by offset i - ) + D, Ind = knn(X, queries, k, metric=metric, handle=handle) handle.sync() D, Ind = cp.asarray(D), cp.asarray(Ind) + Ind += i # shift neighbor index by offset i + if distances is None: distances = D indices = Ind From e53aa0c18d77f263ed6aa058dff87d3a6dd43590 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Akif=20=C3=87=C3=96RD=C3=9CK?= Date: Tue, 19 Mar 2024 15:26:15 +0100 Subject: [PATCH 08/18] Fix bug in blockRankedReduce (#2226) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit There was a bug appearing for negative floating point numbers with a max reduce operation. The `std::numeric_limits::min()` is greater than the negative floating point values whereas we want it to be smaller than all representable values. This PR replaces the `min` with the `lowest`. Authors: - Akif ÇÖRDÜK (https://github.com/akifcorduk) - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Bradley Dice (https://github.com/bdice) - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/raft/pull/2226 --- cpp/include/raft/util/reduction.cuh | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/cpp/include/raft/util/reduction.cuh b/cpp/include/raft/util/reduction.cuh index 7e897e67f5..2c2b1aa228 100644 --- a/cpp/include/raft/util/reduction.cuh +++ b/cpp/include/raft/util/reduction.cuh @@ -157,11 +157,10 @@ DI std::pair blockRankedReduce(T val, val = values[lane]; idx = indices[lane]; } else { - // get the min if it is a max op, get the max if it is a min op - val = reduce_op(std::numeric_limits::min(), std::numeric_limits::max()) == - std::numeric_limits::min() - ? std::numeric_limits::max() - : std::numeric_limits::min(); + // get the lower_bound of the type if it is a max op, + // get the upper bound of the type if it is a min op + val = reduce_op(lower_bound(), upper_bound()) == lower_bound() ? upper_bound() + : lower_bound(); idx = -1; } __syncthreads(); From 413e34e786b1c6b50a9dc6d76ff57d1f8a31555d Mon Sep 17 00:00:00 2001 From: Mahesh Doijade <36705640+mdoijade@users.noreply.github.com> Date: Tue, 19 Mar 2024 22:33:54 +0530 Subject: [PATCH 09/18] Add fused cosine 1-NN cutlass based kernel (#2125) - Adds cosine 1-NN cutlass based kernel for SM 8.0 or higher using tensor cores. - based on 3x TF32 - unifies the fusedDistanceNN kernels for L2/cosine. - expose this API in pylibraft as `fused_distance_nn_arg_min` supporting cosine & L2 distance metrics. Authors: - Mahesh Doijade (https://github.com/mdoijade) - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Tamas Bela Feher (https://github.com/tfeher) - Ben Frederickson (https://github.com/benfred) URL: https://github.com/rapidsai/raft/pull/2125 --- cpp/CMakeLists.txt | 4 +- cpp/bench/prims/CMakeLists.txt | 12 +- .../distance/detail/fused_distance_nn.cuh | 97 ++++ .../custom_epilogue_with_broadcast.h | 1 + .../detail/fused_distance_nn/cutlass_base.cuh | 18 +- .../epilogue_elementwise.cuh | 9 +- .../fused_distance_nn/fused_cosine_nn.cuh | 136 ++++++ .../detail/fused_distance_nn/fused_l2_nn.cuh | 146 ++++++ .../fused_distance_nn/helper_structs.cuh | 146 ++++++ .../fused_distance_nn/persistent_gemm.h | 4 +- .../predicated_tile_iterator_reduced_vec.h | 99 ++--- .../detail/fused_distance_nn/simt_kernel.cuh | 187 ++++++++ .../raft/distance/detail/masked_nn.cuh | 2 +- .../raft/distance/fused_distance_nn-ext.cuh | 84 ++++ .../raft/distance/fused_distance_nn-inl.cuh | 328 ++++++++++++++ .../raft/distance/fused_distance_nn.cuh | 24 + ...pers.cuh => fused_distance_nn_helpers.cuh} | 4 +- cpp/include/raft/distance/fused_l2_nn-ext.cuh | 8 +- cpp/include/raft/distance/fused_l2_nn-inl.cuh | 4 +- .../distance/fused_distance_nn.hpp | 62 +++ .../raft_runtime/distance/fused_l2_nn.hpp | 36 +- cpp/src/distance/fused_distance_nn.cu | 53 +++ .../distance/fused_distance_min_arg.cu | 56 +++ .../distance/fused_distance_min_arg.hpp | 144 ++++++ .../raft_runtime/distance/fused_l2_min_arg.cu | 87 +--- cpp/test/CMakeLists.txt | 1 + cpp/test/distance/fused_cosine_nn.cu | 420 ++++++++++++++++++ cpp/test/distance/fused_l2_nn.cu | 1 - .../pylibraft/distance/CMakeLists.txt | 4 +- .../pylibraft/pylibraft/distance/__init__.py | 9 +- .../pylibraft/distance/fused_distance_nn.pyx | 200 +++++++++ .../test/test_fused_distance_argmin.py | 69 +++ 32 files changed, 2281 insertions(+), 174 deletions(-) create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn.cuh create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn/fused_l2_nn.cuh create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn/helper_structs.cuh create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn/simt_kernel.cuh create mode 100644 cpp/include/raft/distance/fused_distance_nn-ext.cuh create mode 100644 cpp/include/raft/distance/fused_distance_nn-inl.cuh create mode 100755 cpp/include/raft/distance/fused_distance_nn.cuh rename cpp/include/raft/distance/{fused_l2_nn_helpers.cuh => fused_distance_nn_helpers.cuh} (92%) create mode 100644 cpp/include/raft_runtime/distance/fused_distance_nn.hpp create mode 100644 cpp/src/distance/fused_distance_nn.cu create mode 100644 cpp/src/raft_runtime/distance/fused_distance_min_arg.cu create mode 100644 cpp/src/raft_runtime/distance/fused_distance_min_arg.hpp create mode 100644 cpp/test/distance/fused_cosine_nn.cu create mode 100644 python/pylibraft/pylibraft/distance/fused_distance_nn.pyx create mode 100755 python/pylibraft/pylibraft/test/test_fused_distance_argmin.py diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 638ceb3b45..6107b9325a 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -256,7 +256,7 @@ endif() if(RAFT_NVTX) # This enables NVTX within the project with no option to disable it downstream. - target_link_libraries(raft INTERFACE CUDA::nvToolsExt) + target_link_libraries(raft INTERFACE CUDA::nvtx3) target_compile_definitions(raft INTERFACE NVTX_ENABLED) else() # Allow enable NVTX downstream if not set here. This creates a new option at build/install time, @@ -324,6 +324,7 @@ if(RAFT_COMPILE_LIBRARY) src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu src/distance/distance.cu src/distance/fused_l2_nn.cu + src/distance/fused_distance_nn.cu src/linalg/detail/coalesced_reduction.cu src/matrix/detail/select_k_double_int64_t.cu src/matrix/detail/select_k_double_uint32_t.cu @@ -422,6 +423,7 @@ if(RAFT_COMPILE_LIBRARY) src/raft_runtime/cluster/update_centroids.cuh src/raft_runtime/cluster/update_centroids_double.cu src/raft_runtime/cluster/update_centroids_float.cu + src/raft_runtime/distance/fused_distance_min_arg.cu src/raft_runtime/distance/fused_l2_min_arg.cu src/raft_runtime/distance/pairwise_distance.cu src/raft_runtime/matrix/select_k_float_int64_t.cu diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 5577881ef7..903f4e4347 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -122,16 +122,8 @@ if(BUILD_PRIMS_BENCH) ) ConfigureBench( - NAME - MATRIX_BENCH - PATH - bench/prims/matrix/argmin.cu - bench/prims/matrix/gather.cu - bench/prims/matrix/select_k.cu - bench/prims/matrix/main.cpp - OPTIONAL - LIB - EXPLICIT_INSTANTIATE_ONLY + NAME MATRIX_BENCH PATH bench/prims/matrix/argmin.cu bench/prims/matrix/gather.cu + bench/prims/matrix/select_k.cu bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureBench( diff --git a/cpp/include/raft/distance/detail/fused_distance_nn.cuh b/cpp/include/raft/distance/detail/fused_distance_nn.cuh new file mode 100644 index 0000000000..4fbfdc8755 --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn.cuh @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // raft::KeyValuePair +#include // raft::identity_op +#include // ops::l2_exp_distance_op +#include +#include +#include +#include +#include +#include // PairwiseDistances +#include +#include // Policy +#include // raft::util::arch::SM_* +#include // raft::ceildiv, raft::shfl + +#include // size_t +#include // std::numeric_limits + +namespace raft { +namespace distance { + +namespace detail { + +template +void fusedDistanceNNImpl(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + int* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + bool initOutBuffer, + bool isRowMajor, + raft::distance::DistanceType metric, + float metric_arg, + cudaStream_t stream) +{ + // The kernel policy is determined by fusedDistanceNN. + typedef Policy P; + + dim3 blk(P::Nthreads); + auto nblks = raft::ceildiv(m, P::Nthreads); + constexpr auto maxVal = std::numeric_limits::max(); + typedef KeyValuePair KVPair; + + RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); + if (initOutBuffer) { + initKernel + <<>>(min, m, maxVal, redOp); + RAFT_CUDA_TRY(cudaGetLastError()); + } + + switch (metric) { + case raft::distance::DistanceType::CosineExpanded: + fusedCosineNN( + min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, stream); + break; + case raft::distance::DistanceType::L2SqrtExpanded: + case raft::distance::DistanceType::L2Expanded: + // initOutBuffer is take care by fusedDistanceNNImpl() so we set it false to fusedL2NNImpl. + fusedL2NNImpl( + min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, false, stream); + break; + default: assert("only cosine/l2 metric is supported with fusedDistanceNN\n"); break; + } +} + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h index 32e9214a0d..186715851b 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h @@ -611,6 +611,7 @@ class EpilogueWithBroadcastCustom : public EpilogueBase +#include + #include #include #include @@ -46,6 +48,14 @@ namespace raft { namespace distance { namespace detail { +template +RAFT_KERNEL initBinMutexKernel(cuda::binary_semaphore* mut, IdxT m) +{ + auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; + + if (tid < m) { mut[tid].release(); } +} + template ; constexpr int batch_count = 1; + rmm::device_uvector> bin_mutex(m, stream); + + int blks_ = (m / 256) + 1; + + initBinMutexKernel<<>>(bin_mutex.data(), m); + typename EpilogueOutputOp::Params epilog_op_param( - dist_op, cg_reduce_op, redOp, pairRedOp, mutexes); + dist_op, cg_reduce_op, redOp, pairRedOp, mutexes, bin_mutex.data()); // Number of pipelines you want to use constexpr int NumStages = 3; diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh index f9ab394585..e69b2486df 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh +++ b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh @@ -56,6 +56,8 @@ #pragma once +#include + #include #include #include @@ -121,6 +123,7 @@ class FusedDistanceNNEpilogueElementwise { KVPReduceOpT_ pair_redop_; ReduceOpT_ red_op_; int* mutexes_; + cuda::binary_semaphore* bin_mutex_; using CGReduceT = CGReduceOp_; // // Methods @@ -130,12 +133,14 @@ class FusedDistanceNNEpilogueElementwise { CGReduceOp cg_reduce_op, ReduceOpT_ red_op, KVPReduceOpT_ pair_redop, - int* mutexes) + int* mutexes, + cuda::binary_semaphore* bin_mutex) : cg_reduce_op(cg_reduce_op), dist_op_(dist_op), pair_redop_(pair_redop), red_op_(red_op), - mutexes_(mutexes) + mutexes_(mutexes), + bin_mutex_(bin_mutex) { } diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh new file mode 100644 index 0000000000..f29c8b4d4c --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh @@ -0,0 +1,136 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // raft::KeyValuePair +#include // raft::identity_op +#include // ops::l2_exp_distance_op +#include +#include +#include +#include // PairwiseDistances +#include // Policy +#include // raft::util::arch::SM_* +#include // raft::ceildiv, raft::shfl + +#include // size_t +#include // std::numeric_limits + +namespace raft { +namespace distance { + +namespace detail { + +template +void fusedCosineNN(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + int* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + cudaStream_t stream) +{ + // The kernel policy is determined by fusedL2NN. + typedef Policy P; + + dim3 blk(P::Nthreads); + constexpr auto maxVal = std::numeric_limits::max(); + typedef KeyValuePair KVPair; + + namespace arch = raft::util::arch; + using AccT = DataT; + ops::cosine_distance_op distance_op{}; + + raft::identity_op fin_op{}; + + auto kernel = fusedDistanceNNkernel; + + // Get pointer to fp32 SIMT kernel to determine the runtime architecture of the + // current system. Other methods to determine the architecture (that do not + // require a pointer) can be error prone. See: + // https://github.com/NVIDIA/cub/issues/545 + void* kernel_ptr = reinterpret_cast(kernel); + auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); + auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); + + if (cutlass_range.contains(runtime_arch)) { + // If device is SM_80 or later, use CUTLASS-based kernel. + using cosineOp = raft::distance::detail::ops::cosine_cutlass_op; + using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op; + kvp_cg_min_reduce_op_ cg_reduce_op; + cosineOp cosine_dist_op; + + IdxT lda, ldb, ldd; + lda = k, ldb = k, ldd = n; + + cutlassFusedDistanceNN(x, + y, + xn, + yn, + m, + n, + k, + lda, + ldb, + ldd, + min, + workspace, + cg_reduce_op, + cosine_dist_op, + redOp, + pairRedOp, + stream); + } else { + // If device less than SM_80, use fp32 SIMT kernel. + constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); + dim3 grid = launchConfigGenerator

(m, n, shmemSize, kernel); + + kernel<<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); + RAFT_CUDA_TRY(cudaGetLastError()); + } +} + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/fused_l2_nn.cuh new file mode 100644 index 0000000000..65475e73c7 --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn/fused_l2_nn.cuh @@ -0,0 +1,146 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // raft::KeyValuePair +#include // raft::identity_op +#include // ops::l2_exp_distance_op +#include +#include +#include +#include // PairwiseDistances +#include // Policy +#include // raft::util::arch::SM_* +#include // raft::ceildiv, raft::shfl + +#include // size_t +#include // std::numeric_limits + +namespace raft { +namespace distance { + +namespace detail { + +template +void fusedL2NNImpl(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + int* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream) +{ + // The kernel policy is determined by fusedL2NN. + typedef Policy P; + + dim3 blk(P::Nthreads); + auto nblks = raft::ceildiv(m, P::Nthreads); + constexpr auto maxVal = std::numeric_limits::max(); + typedef KeyValuePair KVPair; + + RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); + if (initOutBuffer) { + initKernel + <<>>(min, m, maxVal, redOp); + RAFT_CUDA_TRY(cudaGetLastError()); + } + + namespace arch = raft::util::arch; + using AccT = DataT; + ops::l2_exp_distance_op distance_op{sqrt}; + + raft::identity_op fin_op{}; + + auto kernel = fusedDistanceNNkernel; + + // Get pointer to fp32 SIMT kernel to determine the best compute architecture + // out of all for which the kernel was compiled for that matches closely + // to the current device. Other methods to determine the architecture (that do not + // require a pointer) can be error prone. See: + // https://github.com/NVIDIA/cub/issues/545 + void* kernel_ptr = reinterpret_cast(kernel); + auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); + auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); + + if (cutlass_range.contains(runtime_arch)) { + // If device is SM_80 or later, use CUTLASS-based kernel. + using L2Op = raft::distance::detail::ops::l2_exp_cutlass_op; + using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op; + kvp_cg_min_reduce_op_ cg_reduce_op; + L2Op L2_dist_op(sqrt); + + IdxT lda, ldb, ldd; + lda = k, ldb = k, ldd = n; + + cutlassFusedDistanceNN(x, + y, + xn, + yn, + m, + n, + k, + lda, + ldb, + ldd, + min, + workspace, + cg_reduce_op, + L2_dist_op, + redOp, + pairRedOp, + stream); + } else { + // If device less than SM_80, use fp32 SIMT kernel. + constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); + dim3 grid = launchConfigGenerator

(m, n, shmemSize, kernel); + + kernel<<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); + RAFT_CUDA_TRY(cudaGetLastError()); + } +} + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/helper_structs.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/helper_structs.cuh new file mode 100644 index 0000000000..e056c5d397 --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn/helper_structs.cuh @@ -0,0 +1,146 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // raft::KeyValuePair +#include // raft::identity_op +#include // ops::l2_exp_distance_op +#include +#include +#include // PairwiseDistances +#include // Policy +#include // raft::util::arch::SM_* +#include // raft::ceildiv, raft::shfl +#include + +#include // size_t +#include // std::numeric_limits + +namespace raft { +namespace distance { + +namespace detail { + +template +struct KVPMinReduceImpl { + typedef raft::KeyValuePair KVP; + DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } + DI KVP operator()(const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } + +}; // KVPMinReduce + +template +struct MinAndDistanceReduceOpImpl { + typedef typename raft::KeyValuePair KVP; + + DI void operator()(LabelT rid, KVP* out, const KVP& other) const + { + if (other.value < out->value) { + out->key = other.key; + out->value = other.value; + } + } + DI void operator()(LabelT rid, volatile KVP* out, const KVP& other) const + { + if (other.value < out->value) { + out->key = other.key; + out->value = other.value; + } + } + + DI void operator()(LabelT rid, DataT* out, const KVP& other) const + { + if (other.value < *out) { *out = other.value; } + } + + DI void operator()(LabelT rid, volatile DataT* out, const KVP& other) const + { + if (other.value < *out) { *out = other.value; } + } + + DI void operator()(LabelT rid, DataT* out, const DataT& other) const + { + if (other < *out) { *out = other; } + } + + DI void operator()(LabelT rid, volatile DataT* out, const DataT& other) const + { + if (other < *out) { *out = other; } + } + + DI void init(DataT* out, DataT maxVal) const { *out = maxVal; } + DI void init(KVP* out, DataT maxVal) const + { + out->value = maxVal; + out->key = 0xfffffff0; + } + + DI void init_key(DataT& out, LabelT idx) const { return; } + DI void init_key(KVP& out, LabelT idx) const { out.key = idx; } + + DI DataT get_value(KVP& out) const { return out.value; } + DI DataT get_value(DataT& out) const { return out; } +}; + +template +struct MinReduceOpImpl { + typedef typename raft::KeyValuePair KVP; + DI void operator()(LabelT rid, DataT* out, const KVP& other) + { + if (other.value < *out) { *out = other.value; } + } + + DI void init(DataT* out, DataT maxVal) { *out = maxVal; } +}; + +template +RAFT_KERNEL initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) +{ + auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; + if (tid < m) { redOp.init(min + tid, maxVal); } +} + +template +void initialize(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp, cudaStream_t stream) +{ + auto blks = raft::ceildiv(m, 256); + initKernel<<>>(min, m, maxVal, redOp); +} + +// cg::reduce functor for FusedDistanceNN used in its cutlass version +// to output the min distance value & key(loc id). +// This is used in fused_distance_nn/predicated_tile_iterator_reduced_vec.h +// store_with_byte_offset() passed to cg::reduce() & select_reduce. +template +struct kvp_cg_min_reduce_op { + typedef typename raft::KeyValuePair KVP; + + __host__ __device__ kvp_cg_min_reduce_op() noexcept {}; + + using AccTypeT = AccType; + using IndexT = Index; + // functor signature. + __host__ __device__ KVP operator()(KVP a, KVP b) const { return a.value < b.value ? a : b; } + + __host__ __device__ AccType operator()(AccType a, AccType b) const { return min(a, b); } + + __host__ __device__ bool isAmin(AccType a, AccType b) const { return a < b ? true : false; } +}; + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h index 4f05251705..f1a7c728e9 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h @@ -180,8 +180,7 @@ struct FusedDistanceNNPersistent { /// Default ctor CUTLASS_HOST_DEVICE Arguments() - : // problem_count(0), - threadblock_count(0), + : threadblock_count(0), ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), @@ -235,7 +234,6 @@ struct FusedDistanceNNPersistent { /// Parameters structure struct Params { - // typename ProblemVisitor::Params problem_visitor; temp_problem_visitor problem_visitor; int threadblock_count; diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h index 936b9f2c89..d61018593f 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h @@ -322,10 +322,8 @@ class PredicatedTileIteratorReducedVec { /// Parameters structure containing reference and precomputed state. Params params_; - /// Byte-level pointer - uint8_t* byte_pointer_; /// Byte-level pointer first tile offset of this threadblock. - uint8_t* first_tile_byte_pointer_; + volatile uint8_t* first_tile_byte_pointer_; /// Array of boolean values to contain steady-state predicates Mask mask_; @@ -350,6 +348,8 @@ class PredicatedTileIteratorReducedVec { /// Scatter indices int const* indices_; + const int do_gmem_reduction_; + // // Static asserts about internal strides // @@ -360,7 +360,6 @@ class PredicatedTileIteratorReducedVec { protected: SharedStorage& shared_storage_; - const bool& do_gmem_reduction_; private: // @@ -374,10 +373,10 @@ class PredicatedTileIteratorReducedVec { CUTLASS_DEVICE PredicatedTileIteratorReducedVec(SharedStorage& shared_storage, Params const& params, - Element* pointer, + volatile Element* pointer, TensorCoord extent, int thread_idx, - const bool& do_gmem_reduction, + const bool do_gmem_reduction, TensorCoord threadblock_offset = TensorCoord(), int const* indices = nullptr) : params_(params), @@ -409,6 +408,7 @@ class PredicatedTileIteratorReducedVec { EpilogueOpParams const& user_params = params_.user_param; shared_storage_.initSmem(user_params); } + __syncthreads(); // Null pointer performs no accesses if (!pointer) { mask_.clear(); } @@ -416,66 +416,53 @@ class PredicatedTileIteratorReducedVec { if (ScatterD && !indices) { mask_.clear(); } // Initialize pointer - first_tile_byte_pointer_ = reinterpret_cast(pointer) + + first_tile_byte_pointer_ = reinterpret_cast(pointer) + LongIndex(block_offset.row()) * LongIndex(params_.stride); - if (ScatterD) { - byte_pointer_ = reinterpret_cast(pointer) + - LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; - } - // Initialize internal state counter state_[0] = state_[1] = state_[2] = 0; } - /// Destructor - CUTLASS_DEVICE - ~PredicatedTileIteratorReducedVec() + CUTLASS_DEVICE void dumpToGmem() { + if (block_start_row_first_tile_ >= extent_row_) { return; } + if (do_gmem_reduction_) { EpilogueOpParams const& user_params = params_.user_param; - auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); - Element* shared_elem_arr = shared_storage_.data(); const uint32_t mutex_id = (block_start_row_first_tile_ / total_rows); - bool useGmemMutex = (gridDim.x != ((extent_row_ - 1 + total_rows) / total_rows)); - // If this is not optimal grid size perform mutex based gmem reduce. - if (useGmemMutex) { - // single lock per block for multiple rows - if (threadIdx.x == 0 && block_start_row_first_tile_ < extent_row_) { - // acquire mutex lock. - unsigned int ns = 8; - while (atomicCAS(user_params.mutexes_ + mutex_id, 0, 1) == 1) { - __nanosleep(ns); - if (ns < 256) { ns *= 2; } - } + const bool useGmemMutex = (gridDim.x != ((extent_row_ - 1 + total_rows) / total_rows)); + int row = threadIdx.x; + Element* shared_elem_arr = shared_storage_.data(); + Element row_local_min; + if (row < total_rows) { row_local_min = shared_elem_arr[row]; } + + // single lock per block for multiple rows + if (useGmemMutex && threadIdx.x == 0) { user_params.bin_mutex_[mutex_id].acquire(); } + __syncthreads(); + + if (row < total_rows) { + volatile Element* gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); + + if ((block_start_row_first_tile_ + row) < extent_row_) { + user_params.red_op_(block_start_row_first_tile_ + row, (gmem_ptr + row), row_local_min); } } __syncthreads(); - for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { - if (block_start_row_first_tile_ + row < extent_row_) { - user_params.red_op_( - block_start_row_first_tile_ + row, &gmem_ptr[row], shared_elem_arr[row]); - } - } + __threadfence(); - if (useGmemMutex) { - __threadfence(); - __syncthreads(); - if (threadIdx.x == 0 && block_start_row_first_tile_ < extent_row_) { - // release mutex lock. - atomicExch(user_params.mutexes_ + mutex_id, 0); - } + if (useGmemMutex && (threadIdx.x == 0)) { + // release mutex lock. + user_params.bin_mutex_[mutex_id].release(); } + shared_storage_.initSmem(user_params); + __syncthreads(); } } - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) - { - byte_pointer_ += pointer_offset * sizeof_bits::value / 8; - } + /// Destructor + CUTLASS_DEVICE + ~PredicatedTileIteratorReducedVec() {} /// Performs reduction and Stores a reduced output to memory CUTLASS_DEVICE @@ -515,9 +502,6 @@ class PredicatedTileIteratorReducedVec { user_params.red_op_.init(&red_val, maxVal); if (row_guard) { - const int iter_row = (row_id % total_rows); - const auto prev_red_val = user_params.red_op_.get_value(shared_elem_arr[iter_row]); - CUTLASS_PRAGMA_UNROLL for (int column = 0; column < ThreadMap::Iterations::kColumn * kElementsPerAccess; ++column) { @@ -536,6 +520,10 @@ class PredicatedTileIteratorReducedVec { user_params.red_op_(row_id, &red_val, this_val); } } + } + const int iter_row = (row_id % total_rows); + const auto prev_red_val = user_params.red_op_.get_value(shared_elem_arr[iter_row]); + if (row_guard) { // select_reduce doesn't need to use `red_op_` as at the warp level we use cg_reduce_op, // this satisfies the requirement of mst/single linkage of checking colors buffer. select_reduce red_obj( @@ -544,6 +532,7 @@ class PredicatedTileIteratorReducedVec { } } } + __syncthreads(); } /// Stores a fragment to memory @@ -575,14 +564,11 @@ class PredicatedTileIteratorReducedVec { { ++state_[0]; - if (!ScatterD) { byte_pointer_ += params_.advance_row; } - thread_start_row_ += ThreadMap::Shape::kRow; if (state_[0] == ThreadMap::Count::kRow) { state_[0] = 0; ++state_[1]; - byte_pointer_ += params_.advance_group; thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; @@ -590,18 +576,13 @@ class PredicatedTileIteratorReducedVec { if (state_[1] == ThreadMap::Count::kGroup) { state_[1] = 0; ++state_[2]; - byte_pointer_ += params_.advance_cluster; thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; - if (state_[2] == ThreadMap::Count::kCluster) { - state_[2] = 0; - byte_pointer_ += params_.advance_tile; - } + if (state_[2] == ThreadMap::Count::kCluster) { state_[2] = 0; } } } - return *this; } diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/simt_kernel.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/simt_kernel.cuh new file mode 100644 index 0000000000..7417fd5dac --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn/simt_kernel.cuh @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // raft::KeyValuePair +#include // ops::l2_exp_distance_op +#include // PairwiseDistances +#include // Policy + +#include // size_t +#include // std::numeric_limits + +namespace raft { +namespace distance { +namespace detail { + +// TODO: specialize this function for MinAndDistanceReduceOp +// with atomicCAS of 64 bit which will eliminate mutex and shfls +template +DI void updateReducedVal( + int* mutex, OutT* min, KVPair* val, ReduceOpT red_op, IdxT m, IdxT gridStrideY) +{ + const auto lid = threadIdx.x % raft::WarpSize; + const auto accrowid = threadIdx.x / P::AccThCols; + + // Update each output row in order within a warp. This will resolve hang + // issues with pre-Volta architectures +#pragma unroll + for (int j = 0; j < (raft::WarpSize / P::AccThCols); j++) { + if (lid == j * P::AccThCols) { +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + auto rid = gridStrideY + accrowid + i * P::AccThRows; + if (rid < m) { + auto value = val[i]; + while (atomicCAS(mutex + rid, 0, 1) == 1) + ; + __threadfence(); + red_op(rid, min + rid, value); + __threadfence(); + atomicCAS(mutex + rid, 1, 0); + } + } + } + } +} + +template +__launch_bounds__(P::Nthreads, 2) RAFT_KERNEL fusedDistanceNNkernel(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + DataT maxVal, + int* mutex, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + OpT distance_op, + FinalLambda fin_op) +{ +// compile only if below non-ampere arch. +#if __CUDA_ARCH__ < 800 + extern __shared__ char smem[]; + + typedef KeyValuePair KVPair; + KVPair val[P::AccRowsPerTh]; +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {0, maxVal}; + } + + // epilogue operation lambda for final value calculation + auto epilog_lambda = [n, pairRedOp, &val, maxVal] __device__( + DataT acc[P::AccRowsPerTh][P::AccColsPerTh], + DataT * regxn, + DataT * regyn, + IdxT gridStrideX, + IdxT gridStrideY) { + KVPReduceOpT pairRed_op(pairRedOp); + + // intra thread reduce + const auto acccolid = threadIdx.x % P::AccThCols; + const auto accrowid = threadIdx.x / P::AccThCols; +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + auto tmpkey = acccolid + j * P::AccThCols + gridStrideX; + KVPair tmp = {tmpkey, acc[i][j]}; + if (tmpkey < n) { + val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); + } + } + } + }; + + auto rowEpilog_lambda = + [m, mutex, min, pairRedOp, redOp, &val, maxVal] __device__(IdxT gridStrideY) { + KVPReduceOpT pairRed_op(pairRedOp); + ReduceOpT red_op(redOp); + + const auto accrowid = threadIdx.x / P::AccThCols; + const auto lid = raft::laneId(); + + // reduce +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = P::AccThCols / 2; j > 0; j >>= 1) { + // Actually, the srcLane (lid +j) should be (lid +j) % P:AccThCols, + // but the shfl op applies the modulo internally. + auto tmpkey = raft::shfl(val[i].key, lid + j, P::AccThCols); + auto tmpvalue = raft::shfl(val[i].value, lid + j, P::AccThCols); + KVPair tmp = {tmpkey, tmpvalue}; + val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); + } + } + + updateReducedVal(mutex, min, val, red_op, m, gridStrideY); + + // reset the val array. +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {0, maxVal}; + } + }; + + IdxT lda = k, ldb = k, ldd = n; + constexpr bool row_major = true; + constexpr bool write_out = false; + PairwiseDistances + obj(x, + y, + m, + n, + k, + lda, + ldb, + ldd, + xn, + yn, + nullptr, // Output pointer + smem, + distance_op, + epilog_lambda, + fin_op, + rowEpilog_lambda); + obj.run(); +#endif +} + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/detail/masked_nn.cuh b/cpp/include/raft/distance/detail/masked_nn.cuh index 536e4937ab..3e3699766f 100644 --- a/cpp/include/raft/distance/detail/masked_nn.cuh +++ b/cpp/include/raft/distance/detail/masked_nn.cuh @@ -19,7 +19,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/cpp/include/raft/distance/fused_distance_nn-ext.cuh b/cpp/include/raft/distance/fused_distance_nn-ext.cuh new file mode 100644 index 0000000000..263bbcea81 --- /dev/null +++ b/cpp/include/raft/distance/fused_distance_nn-ext.cuh @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // raft::KeyValuePair +#include // raft::resources +#include // include initialize and reduce operations +#include // RAFT_EXPLICIT + +#include // int64_t + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft { +namespace distance { + +template +void fusedDistanceNNMinReduce(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + bool sqrt, + bool initOutBuffer, + bool isRowMajor, + raft::distance::DistanceType metric, + float metric_arg, + cudaStream_t stream) RAFT_EXPLICIT; + +} // namespace distance +} // namespace raft + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_distance_fusedDistanceNNMinReduce(DataT, OutT, IdxT) \ + extern template void raft::distance::fusedDistanceNNMinReduce( \ + OutT * min, \ + const DataT* x, \ + const DataT* y, \ + const DataT* xn, \ + const DataT* yn, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + bool sqrt, \ + bool initOutBuffer, \ + bool isRowMajor, \ + raft::distance::DistanceType metric, \ + float metric_arg, \ + cudaStream_t stream) + +instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int); +instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int64_t); + +// We can't have comma's in the macro expansion, so we use the COMMA macro: +#define COMMA , + +instantiate_raft_distance_fusedDistanceNNMinReduce(float, raft::KeyValuePair, int); +instantiate_raft_distance_fusedDistanceNNMinReduce(float, + raft::KeyValuePair, + int64_t); + +#undef COMMA + +#undef instantiate_raft_distance_fusedDistanceNNMinReduce diff --git a/cpp/include/raft/distance/fused_distance_nn-inl.cuh b/cpp/include/raft/distance/fused_distance_nn-inl.cuh new file mode 100644 index 0000000000..ffe86a1c04 --- /dev/null +++ b/cpp/include/raft/distance/fused_distance_nn-inl.cuh @@ -0,0 +1,328 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __FUSED_DISTANCE_NN_H +#define __FUSED_DISTANCE_NN_H + +#pragma once + +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include + +namespace raft { +namespace distance { + +/** + * \ingroup fused_l2_nn + * @{ + */ +/** + * @brief Fused L2 distance and 1-nearest-neighbor computation in a single call. + * + * The benefits of such a call are 2-fold: 1) eliminate the need for an + * intermediate buffer to store the output of gemm 2) reduce the memory read + * traffic on this intermediate buffer, otherwise needed during the reduction + * phase for 1-NN. + * + * @tparam DataT data type + * @tparam OutT output type to either store 1-NN indices and their minimum + * distances or store only the min distances. Accordingly, one + * has to pass an appropriate `ReduceOpT` + * @tparam IdxT indexing arithmetic type + * @tparam ReduceOpT A struct to perform the final needed reduction operation + * and also to initialize the output array elements with the + * appropriate initial value needed for reduction. + * @tparam KVPReduceOpT A struct providing functions for key-value pair comparison. + * + * @param[out] min will contain the reduced output (Length = `m`) + * (on device) + * @param[in] x first matrix. Row major. Dim = `m x k`. + * (on device). + * @param[in] y second matrix. Row major. Dim = `n x k`. + * (on device). + * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). + * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) + * @param[in] m gemm m + * @param[in] n gemm n + * @param[in] k gemm k + * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) + * @param[in] redOp reduction operator in the epilogue + * @param[in] pairRedOp reduction operation on key value pairs + * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt + * @param[in] initOutBuffer whether to initialize the output buffer before the + * main kernel launch + * @param[in] isRowMajor whether the input/output is row or column major. + * @param[in] metric Distance metric to be used (supports L2, cosine) + * @param[in] metric_arg power argument for distances like Minkowski (not supported for now) + * @param[in] stream cuda stream + */ +template +void fusedDistanceNN(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + bool initOutBuffer, + bool isRowMajor, + raft::distance::DistanceType metric, + float metric_arg, + cudaStream_t stream) +{ + ASSERT(isRowMajor, "fusedDistanceNN only supports row major inputs"); + // When k is smaller than 32, the Policy4x4 results in redundant calculations + // as it uses tiles that have k=32. Therefore, use a "skinny" policy instead + // that uses tiles with a smaller value of k. + bool is_skinny = k < 32; + + size_t bytes = sizeof(DataT) * k; + auto px = reinterpret_cast(x); + auto py = reinterpret_cast(y); + if (16 % sizeof(DataT) == 0 && bytes % 16 == 0 && px % 16 == 0 && py % 16 == 0) { + if (is_skinny) { + detail::fusedDistanceNNImpl< + DataT, + OutT, + IdxT, + typename linalg::Policy4x4Skinny::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } else { + detail::fusedDistanceNNImpl::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } + } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0 && px % 8 == 0 && py % 8 == 0) { + if (is_skinny) { + detail::fusedDistanceNNImpl< + DataT, + OutT, + IdxT, + typename linalg::Policy4x4Skinny::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } else { + detail::fusedDistanceNNImpl::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } + } else { + if (is_skinny) { + detail::fusedDistanceNNImpl::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } else { + detail::fusedDistanceNNImpl::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } + } +} + +/** + * @brief Wrapper around fusedDistanceNN with minimum reduction operators. + * + * fusedDistanceNN cannot be compiled in the distance library due to the lambda + * operators, so this wrapper covers the most common case (minimum). + * + * @tparam DataT data type + * @tparam OutT output type to either store 1-NN indices and their minimum + * distances (e.g. raft::KeyValuePair) or store only the min + * distances. + * @tparam IdxT indexing arithmetic type + * @param[out] min will contain the reduced output (Length = `m`) + * (on device) + * @param[in] x first matrix. Row major. Dim = `m x k`. + * (on device). + * @param[in] y second matrix. Row major. Dim = `n x k`. + * (on device). + * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). + * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) + * @param[in] m gemm m + * @param[in] n gemm n + * @param[in] k gemm k + * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) + * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt + * @param[in] initOutBuffer whether to initialize the output buffer before the + * main kernel launch + * @param[in] isRowMajor whether the input/output is row or column major. + * @param[in] metric Distance metric to be used (supports L2, cosine) + * @param[in] metric_arg power argument for distances like Minkowski (not supported for now) + * @param[in] stream cuda stream + */ +template +void fusedDistanceNNMinReduce(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + bool sqrt, + bool initOutBuffer, + bool isRowMajor, + raft::distance::DistanceType metric, + float metric_arg, + cudaStream_t stream) +{ + MinAndDistanceReduceOp redOp; + KVPMinReduce pairRedOp; + + fusedDistanceNN(min, + x, + y, + xn, + yn, + m, + n, + k, + workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); +} + +/** @} */ + +} // namespace distance +} // namespace raft + +#endif diff --git a/cpp/include/raft/distance/fused_distance_nn.cuh b/cpp/include/raft/distance/fused_distance_nn.cuh new file mode 100755 index 0000000000..04c42e49a1 --- /dev/null +++ b/cpp/include/raft/distance/fused_distance_nn.cuh @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "fused_distance_nn-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +#include "fused_distance_nn-ext.cuh" +#endif diff --git a/cpp/include/raft/distance/fused_l2_nn_helpers.cuh b/cpp/include/raft/distance/fused_distance_nn_helpers.cuh similarity index 92% rename from cpp/include/raft/distance/fused_l2_nn_helpers.cuh rename to cpp/include/raft/distance/fused_distance_nn_helpers.cuh index 996f696ef6..3a570c681c 100644 --- a/cpp/include/raft/distance/fused_l2_nn_helpers.cuh +++ b/cpp/include/raft/distance/fused_distance_nn_helpers.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ #pragma once #include -#include +#include namespace raft::distance { diff --git a/cpp/include/raft/distance/fused_l2_nn-ext.cuh b/cpp/include/raft/distance/fused_l2_nn-ext.cuh index c7dc4c5fc6..d0ac83cd51 100644 --- a/cpp/include/raft/distance/fused_l2_nn-ext.cuh +++ b/cpp/include/raft/distance/fused_l2_nn-ext.cuh @@ -16,10 +16,10 @@ #pragma once -#include // raft::KeyValuePair -#include // raft::resources -#include // include initialize and reduce operations -#include // RAFT_EXPLICIT +#include // raft::KeyValuePair +#include // raft::resources +#include // include initialize and reduce operations +#include // RAFT_EXPLICIT #include // int64_t diff --git a/cpp/include/raft/distance/fused_l2_nn-inl.cuh b/cpp/include/raft/distance/fused_l2_nn-inl.cuh index 6f4db15c72..bf9a49d813 100644 --- a/cpp/include/raft/distance/fused_l2_nn-inl.cuh +++ b/cpp/include/raft/distance/fused_l2_nn-inl.cuh @@ -20,8 +20,8 @@ #pragma once #include -#include -#include +#include +#include #include #include diff --git a/cpp/include/raft_runtime/distance/fused_distance_nn.hpp b/cpp/include/raft_runtime/distance/fused_distance_nn.hpp new file mode 100644 index 0000000000..7c309d6fc7 --- /dev/null +++ b/cpp/include/raft_runtime/distance/fused_distance_nn.hpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace raft::runtime::distance { + +/** + * @defgroup fused_distance_nn_min_arg_runtime Fused Distance 1NN Runtime API + * @{ + */ + +/** + * @brief Wrapper around fusedDistanceNN with minimum reduction operators. + * + * fusedDistanceNN cannot be compiled in the distance library due to the lambda + * operators, so this wrapper covers the most common case (minimum). + * + * @param[in] handle raft handle + * @param[out] min will contain the reduced output (Length = `m`) + * (on device) + * @param[in] x first matrix. Row major. Dim = `m x k`. + * (on device). + * @param[in] y second matrix. Row major. Dim = `n x k`. + * (on device). + * @param[in] m gemm m + * @param[in] n gemm n + * @param[in] k gemm k + * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt + * @param[in] metric Distance metric to be used (supports L2, cosine) + * @param[in] isRowMajor whether the input/output is row or column major. + * @param[in] metric_arg power argument for distances like Minkowski (not supported for now) + */ +void fused_distance_nn_min_arg(raft::resources const& handle, + int* min, + const float* x, + const float* y, + int m, + int n, + int k, + bool sqrt, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg); + +/** @} */ // end group fused_distance_nn_min_arg_runtime + +} // end namespace raft::runtime::distance diff --git a/cpp/include/raft_runtime/distance/fused_l2_nn.hpp b/cpp/include/raft_runtime/distance/fused_l2_nn.hpp index 6154e03f4c..e46b3c5271 100644 --- a/cpp/include/raft_runtime/distance/fused_l2_nn.hpp +++ b/cpp/include/raft_runtime/distance/fused_l2_nn.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -42,23 +42,25 @@ namespace raft::runtime::distance { * @param[in] k gemm k * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt */ -void fused_l2_nn_min_arg(raft::resources const& handle, - int* min, - const float* x, - const float* y, - int m, - int n, - int k, - bool sqrt); +[[deprecated("use fused_distance_nn_min_arg instead")]] void fused_l2_nn_min_arg( + raft::resources const& handle, + int* min, + const float* x, + const float* y, + int m, + int n, + int k, + bool sqrt); -void fused_l2_nn_min_arg(raft::resources const& handle, - int* min, - const double* x, - const double* y, - int m, - int n, - int k, - bool sqrt); +[[deprecated("use fused_distance_nn_min_arg instead")]] void fused_l2_nn_min_arg( + raft::resources const& handle, + int* min, + const double* x, + const double* y, + int m, + int n, + int k, + bool sqrt); /** @} */ // end group fused_l2_nn_min_arg_runtime diff --git a/cpp/src/distance/fused_distance_nn.cu b/cpp/src/distance/fused_distance_nn.cu new file mode 100644 index 0000000000..dc722d929c --- /dev/null +++ b/cpp/src/distance/fused_distance_nn.cu @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // raft::KeyValuePair +#include + +#include // int64_t + +#define instantiate_raft_distance_fusedDistanceNNMinReduce(DataT, OutT, IdxT) \ + template void raft::distance::fusedDistanceNNMinReduce( \ + OutT * min, \ + const DataT* x, \ + const DataT* y, \ + const DataT* xn, \ + const DataT* yn, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + bool sqrt, \ + bool initOutBuffer, \ + bool isRowMajor, \ + raft::distance::DistanceType metric, \ + float metric_arg, \ + cudaStream_t stream) + +instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int); +instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int64_t); + +// We can't have comma's in the macro expansion, so we use the COMMA macro: +#define COMMA , + +instantiate_raft_distance_fusedDistanceNNMinReduce(float, raft::KeyValuePair, int); +instantiate_raft_distance_fusedDistanceNNMinReduce(float, + raft::KeyValuePair, + int64_t); + +#undef COMMA + +#undef instantiate_raft_distance_fusedDistanceNNMinReduce diff --git a/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu b/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu new file mode 100644 index 0000000000..dfdff4e94b --- /dev/null +++ b/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "fused_distance_min_arg.hpp" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace raft::runtime::distance { + +void fused_distance_nn_min_arg(raft::resources const& handle, + int* min, + const float* x, + const float* y, + int m, + int n, + int k, + bool sqrt, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg) +{ + switch (metric) { + case raft::distance::DistanceType::CosineExpanded: + compute_fused_cosine_nn_min_arg(handle, min, x, y, m, n, k, sqrt); + break; + case raft::distance::DistanceType::L2Expanded: + case raft::distance::DistanceType::L2SqrtExpanded: + compute_fused_l2_nn_min_arg(handle, min, x, y, m, n, k, sqrt); + break; + default: assert("only Cosine/L2 metric is supported with fusedDistanceNN\n"); break; + } +} + +} // end namespace raft::runtime::distance diff --git a/cpp/src/raft_runtime/distance/fused_distance_min_arg.hpp b/cpp/src/raft_runtime/distance/fused_distance_min_arg.hpp new file mode 100644 index 0000000000..6452752a79 --- /dev/null +++ b/cpp/src/raft_runtime/distance/fused_distance_min_arg.hpp @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace raft::runtime::distance { + +template +struct KeyValueIndexOp { + __host__ __device__ __forceinline__ IndexT + operator()(const raft::KeyValuePair& a) const + { + return a.key; + } +}; + +template +void compute_fused_l2_nn_min_arg(raft::resources const& handle, + idx_t* min, + const value_t* x, + const value_t* y, + idx_t m, + idx_t n, + idx_t k, + bool sqrt) +{ + rmm::device_uvector workspace(m, resource::get_cuda_stream(handle)); + auto kvp = raft::make_device_vector>(handle, m); + constexpr bool is_row_major = true; + + rmm::device_uvector x_norms(m, resource::get_cuda_stream(handle)); + rmm::device_uvector y_norms(n, resource::get_cuda_stream(handle)); + raft::linalg::rowNorm( + x_norms.data(), x, k, m, raft::linalg::L2Norm, is_row_major, resource::get_cuda_stream(handle)); + raft::linalg::rowNorm( + y_norms.data(), y, k, n, raft::linalg::L2Norm, is_row_major, resource::get_cuda_stream(handle)); + + raft::distance::fusedL2NNMinReduce(kvp.data_handle(), + x, + y, + x_norms.data(), + y_norms.data(), + m, + n, + k, + (void*)workspace.data(), + sqrt, + true, + resource::get_cuda_stream(handle)); + + KeyValueIndexOp conversion_op; + thrust::transform(resource::get_thrust_policy(handle), + kvp.data_handle(), + kvp.data_handle() + m, + min, + conversion_op); + resource::sync_stream(handle); +} + +template +void compute_fused_cosine_nn_min_arg(raft::resources const& handle, + idx_t* min, + const value_t* x, + const value_t* y, + idx_t m, + idx_t n, + idx_t k, + bool sqrt) +{ + rmm::device_uvector workspace(m, resource::get_cuda_stream(handle)); + auto kvp = raft::make_device_vector>(handle, m); + + rmm::device_uvector x_norms(m, resource::get_cuda_stream(handle)); + rmm::device_uvector y_norms(n, resource::get_cuda_stream(handle)); + constexpr bool is_row_major = true; + raft::linalg::rowNorm(x_norms.data(), + x, + k, + m, + raft::linalg::L2Norm, + is_row_major, + resource::get_cuda_stream(handle), + raft::sqrt_op{}); + raft::linalg::rowNorm(y_norms.data(), + y, + k, + n, + raft::linalg::L2Norm, + is_row_major, + resource::get_cuda_stream(handle), + raft::sqrt_op{}); + + raft::distance::fusedDistanceNNMinReduce(kvp.data_handle(), + x, + y, + x_norms.data(), + y_norms.data(), + m, + n, + k, + (void*)workspace.data(), + sqrt, + true, + is_row_major, + raft::distance::DistanceType::CosineExpanded, + 0.0f, + resource::get_cuda_stream(handle)); + + KeyValueIndexOp conversion_op; + thrust::transform(resource::get_thrust_policy(handle), + kvp.data_handle(), + kvp.data_handle() + m, + min, + conversion_op); + resource::sync_stream(handle); +} + +} // end namespace raft::runtime::distance diff --git a/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu b/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu index bf4b1a431a..870757dca1 100644 --- a/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu +++ b/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu @@ -14,6 +14,8 @@ * limitations under the License. */ +#include "fused_distance_min_arg.hpp" + #include #include #include @@ -28,77 +30,28 @@ namespace raft::runtime::distance { -template -struct KeyValueIndexOp { - __host__ __device__ __forceinline__ IndexT - operator()(const raft::KeyValuePair& a) const - { - return a.key; - } -}; - -template -void compute_fused_l2_nn_min_arg(raft::resources const& handle, - idx_t* min, - const value_t* x, - const value_t* y, - idx_t m, - idx_t n, - idx_t k, - bool sqrt) -{ - rmm::device_uvector workspace(m, resource::get_cuda_stream(handle)); - auto kvp = raft::make_device_vector>(handle, m); - - rmm::device_uvector x_norms(m, resource::get_cuda_stream(handle)); - rmm::device_uvector y_norms(n, resource::get_cuda_stream(handle)); - raft::linalg::rowNorm( - x_norms.data(), x, k, m, raft::linalg::L2Norm, true, resource::get_cuda_stream(handle)); - raft::linalg::rowNorm( - y_norms.data(), y, k, n, raft::linalg::L2Norm, true, resource::get_cuda_stream(handle)); - - raft::distance::fusedL2NNMinReduce(kvp.data_handle(), - x, - y, - x_norms.data(), - y_norms.data(), - m, - n, - k, - (void*)workspace.data(), - sqrt, - true, - resource::get_cuda_stream(handle)); - - KeyValueIndexOp conversion_op; - thrust::transform(resource::get_thrust_policy(handle), - kvp.data_handle(), - kvp.data_handle() + m, - min, - conversion_op); - resource::sync_stream(handle); -} - -void fused_l2_nn_min_arg(raft::resources const& handle, - int* min, - const float* x, - const float* y, - int m, - int n, - int k, - bool sqrt) +[[deprecated("use fused_distance_nn_min_arg instead")]] void fused_l2_nn_min_arg( + raft::resources const& handle, + int* min, + const float* x, + const float* y, + int m, + int n, + int k, + bool sqrt) { compute_fused_l2_nn_min_arg(handle, min, x, y, m, n, k, sqrt); } -void fused_l2_nn_min_arg(raft::resources const& handle, - int* min, - const double* x, - const double* y, - int m, - int n, - int k, - bool sqrt) +[[deprecated("use fused_distance_nn_min_arg instead")]] void fused_l2_nn_min_arg( + raft::resources const& handle, + int* min, + const double* x, + const double* y, + int m, + int n, + int k, + bool sqrt) { compute_fused_l2_nn_min_arg(handle, min, x, y, m, n, k, sqrt); } diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 65d6e738a2..bf44cf9c60 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -171,6 +171,7 @@ if(BUILD_TESTS) test/distance/masked_nn.cu test/distance/masked_nn_compress_to_bits.cu test/distance/fused_l2_nn.cu + test/distance/fused_cosine_nn.cu test/distance/gram.cu LIB EXPLICIT_INSTANTIATE_ONLY diff --git a/cpp/test/distance/fused_cosine_nn.cu b/cpp/test/distance/fused_cosine_nn.cu new file mode 100644 index 0000000000..d4d632e1dc --- /dev/null +++ b/cpp/test/distance/fused_cosine_nn.cu @@ -0,0 +1,420 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Search with filter instantiation + +#include "../test_utils.cuh" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace raft { +namespace distance { + +template +struct RaftKVPMinReduce { + typedef raft::KeyValuePair KVP; + + DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } + + DI KVP operator()(const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } + +}; // KVPMinReduce + +template +__global__ void naiveCosKernel(raft::KeyValuePair* min, + DataT* x, + DataT* y, + int m, + int n, + int k, + int* workspace, + DataT maxVal) +{ + int midx = threadIdx.y + blockIdx.y * blockDim.y; + int nidx = threadIdx.x + blockIdx.x * blockDim.x; + DataT acc_a = DataT(0); + DataT acc_b = DataT(0); + DataT acc_ab = DataT(0); + // if (midx >= m || nidx >= n) { return; } + + for (int i = 0; i < k; ++i) { + int xidx = i + midx * k; + int yidx = i + nidx * k; + auto a = x[xidx]; + auto b = y[yidx]; + acc_a += a * a; + acc_b += b * b; + acc_ab += a * b; + } + + // Use 1.0 - (cosine similarity) to calc the distance + DataT acc = maxVal; + if (midx < m || nidx < n) { acc = (DataT)1.0 - acc_ab / (raft::sqrt(acc_a) * raft::sqrt(acc_b)); } + + ReduceOpT redOp; + typedef cub::WarpReduce> WarpReduce; + __shared__ typename WarpReduce::TempStorage temp[NWARPS]; + int warpId = threadIdx.x / raft::WarpSize; + raft::KeyValuePair tmp; + tmp.key = nidx; + tmp.value = midx >= m || nidx >= n ? maxVal : acc; + tmp = WarpReduce(temp[warpId]).Reduce(tmp, RaftKVPMinReduce()); + if (threadIdx.x % raft::WarpSize == 0 && midx < m) { + while (atomicCAS(workspace + midx, 0, 1) == 1) + ; + __threadfence(); + redOp(midx, min + midx, tmp); + __threadfence(); + atomicCAS(workspace + midx, 1, 0); + } +} + +template +void naive(raft::KeyValuePair* min, + DataT* x, + DataT* y, + int m, + int n, + int k, + int* workspace, + cudaStream_t stream) +{ + static const dim3 TPB(32, 16, 1); + dim3 nblks(raft::ceildiv(n, (int)TPB.x), raft::ceildiv(m, (int)TPB.y), 1); + RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); + auto blks = raft::ceildiv(m, 256); + MinAndDistanceReduceOp op; + detail::initKernel, int> + <<>>(min, m, std::numeric_limits::max(), op); + RAFT_CUDA_TRY(cudaGetLastError()); + naiveCosKernel, 16> + <<>>(min, x, y, m, n, k, workspace, std::numeric_limits::max()); + RAFT_CUDA_TRY(cudaGetLastError()); +} + +template +struct Inputs { + DataT tolerance; + int m, n, k; + unsigned long long int seed; + + friend std::ostream& operator<<(std::ostream& os, const Inputs& p) + { + return os << "m: " << p.m + << ", " + "n: " + << p.n + << ", " + "k: " + << p.k + << ", " + "seed: " + << p.seed + << ", " + "tol: " + << p.tolerance; + } +}; + +template +class FusedCosineNNTest : public ::testing::TestWithParam> { + public: + FusedCosineNNTest() + : params(::testing::TestWithParam>::GetParam()), + stream(resource::get_cuda_stream(handle)), + x(params.m * params.k, stream), + y(params.n * params.k, stream), + xn(params.m, stream), + yn(params.n, stream), + min(params.m, stream), + min_ref(params.m, stream), + workspace(params.m * sizeof(int), stream) + { + } + + protected: + void SetUp() override + { + raft::random::RngState r(params.seed); + int m = params.m; + int n = params.n; + int k = params.k; + uniform(handle, r, x.data(), m * k, DataT(-1.0), DataT(1.0)); + uniform(handle, r, y.data(), n * k, DataT(-1.0), DataT(1.0)); + generateGoldenResult(); + raft::linalg::rowNorm( + xn.data(), x.data(), k, m, raft::linalg::L2Norm, true, stream, raft::sqrt_op{}); + raft::linalg::rowNorm( + yn.data(), y.data(), k, n, raft::linalg::L2Norm, true, stream, raft::sqrt_op{}); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } + + protected: + raft::resources handle; + cudaStream_t stream; + Inputs params; + rmm::device_uvector x; + rmm::device_uvector y; + rmm::device_uvector xn; + rmm::device_uvector yn; + rmm::device_uvector> min; + rmm::device_uvector> min_ref; + rmm::device_uvector workspace; + + virtual void generateGoldenResult() + { + int m = params.m; + int n = params.n; + int k = params.k; + naive(min_ref.data(), x.data(), y.data(), m, n, k, (int*)workspace.data(), stream); + } + + void runTest(raft::KeyValuePair* out) + { + int m = params.m; + int n = params.n; + int k = params.k; + raft::distance::DistanceType metric = raft::distance::DistanceType::CosineExpanded; + constexpr bool init_out_buffer = true; + fusedDistanceNNMinReduce, int>(out, + x.data(), + y.data(), + xn.data(), + yn.data(), + m, + n, + k, + (void*)workspace.data(), + false, + init_out_buffer, + true, + metric, + 0.0f, + stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } +}; + +template +struct CompareApproxAbsKVP { + typedef typename raft::KeyValuePair KVP; + CompareApproxAbsKVP(T eps_) : eps(eps_) {} + bool operator()(const KVP& a, const KVP& b) const + { + T diff = std::abs(std::abs(a.value) - std::abs(b.value)); + T m = std::max(std::abs(a.value), std::abs(b.value)); + T ratio = m >= eps ? diff / m : diff; + return (ratio <= eps); + } + + private: + T eps; +}; + +template +struct CompareExactKVP { + typedef typename raft::KeyValuePair KVP; + bool operator()(const KVP& a, const KVP& b) const + { + if (a.value != b.value) return false; + return true; + } +}; + +template +::testing::AssertionResult devArrMatch(const raft::KeyValuePair* expected, + const raft::KeyValuePair* actual, + size_t size, + L eq_compare, + cudaStream_t stream = 0) +{ + typedef typename raft::KeyValuePair KVP; + std::shared_ptr exp_h(new KVP[size]); + std::shared_ptr act_h(new KVP[size]); + raft::update_host(exp_h.get(), expected, size, stream); + raft::update_host(act_h.get(), actual, size, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + for (size_t i(0); i < size; ++i) { + auto exp = exp_h.get()[i]; + auto act = act_h.get()[i]; + if (!eq_compare(exp, act)) { + return ::testing::AssertionFailure() + << "actual=" << act.key << "," << act.value << " != expected=" << exp.key << "," + << exp.value << " @" << i; + } + } + return ::testing::AssertionSuccess(); +} + +const std::vector> inputsf = { + {0.001f, 32, 32, 32, 1234ULL}, + {0.001f, 32, 64, 32, 1234ULL}, + {0.001f, 64, 32, 32, 1234ULL}, + {0.001f, 64, 64, 32, 1234ULL}, + {0.001f, 128, 32, 32, 1234ULL}, + {0.001f, 128, 64, 32, 1234ULL}, + {0.001f, 128, 128, 64, 1234ULL}, + {0.001f, 64, 128, 128, 1234ULL}, + + {0.001f, 32, 32, 34, 1234ULL}, + {0.001f, 32, 64, 34, 1234ULL}, + {0.001f, 64, 32, 34, 1234ULL}, + {0.001f, 64, 64, 34, 1234ULL}, + {0.001f, 128, 32, 34, 1234ULL}, + {0.001f, 128, 64, 34, 1234ULL}, + {0.001f, 128, 128, 66, 1234ULL}, + {0.001f, 64, 128, 130, 1234ULL}, + + {0.001f, 32, 32, 33, 1234ULL}, + {0.001f, 32, 64, 33, 1234ULL}, + {0.001f, 64, 32, 33, 1234ULL}, + {0.001f, 64, 64, 33, 1234ULL}, + {0.001f, 128, 32, 33, 1234ULL}, + {0.001f, 128, 64, 33, 1234ULL}, + {0.001f, 128, 128, 65, 1234ULL}, + {0.001f, 64, 128, 129, 1234ULL}, + {0.006f, 1805, 134, 2, 1234ULL}, + {0.006f, 8192, 1024, 64, 1234ULL}, + {0.006f, 8192, 1025, 64, 1234ULL}, + + // Repeat with smaller values of k + {0.006f, 32, 32, 1, 1234ULL}, + {0.001f, 32, 64, 2, 1234ULL}, + {0.001f, 64, 32, 3, 1234ULL}, + {0.001f, 64, 64, 4, 1234ULL}, + {0.001f, 128, 32, 5, 1234ULL}, + {0.001f, 128, 64, 6, 1234ULL}, + {0.001f, 128, 128, 7, 1234ULL}, + {0.001f, 64, 128, 8, 1234ULL}, + + {0.001f, 32, 32, 9, 1234ULL}, + {0.001f, 32, 64, 10, 1234ULL}, + {0.001f, 64, 32, 11, 1234ULL}, + {0.001f, 64, 64, 12, 1234ULL}, + {0.001f, 128, 32, 13, 1234ULL}, + {0.001f, 128, 64, 14, 1234ULL}, + {0.001f, 128, 128, 15, 1234ULL}, + {0.001f, 64, 128, 16, 1234ULL}, + + {0.001f, 32, 32, 17, 1234ULL}, + {0.001f, 32, 64, 18, 1234ULL}, + {0.001f, 64, 32, 19, 1234ULL}, + {0.001f, 64, 64, 20, 1234ULL}, + {0.001f, 128, 32, 21, 1234ULL}, + {0.001f, 128, 64, 22, 1234ULL}, + {0.001f, 128, 128, 23, 1234ULL}, + {0.00001, 64, 128, 24, 1234ULL}, + {0.001f, 1805, 134, 25, 1234ULL}, + {0.006f, 8192, 1024, 25, 1234ULL}, + {0.006f, 8192, 1024, 66, 1234ULL}, +}; +typedef FusedCosineNNTest FusedCosineNNTestF; +TEST_P(FusedCosineNNTestF, Result) +{ + runTest(min.data()); + ASSERT_TRUE(devArrMatch( + min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(FusedCosineNNTests, FusedCosineNNTestF, ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.00001, 32, 32, 32, 1234ULL}, {0.00001, 32, 64, 32, 1234ULL}, + {0.00001, 64, 32, 32, 1234ULL}, {0.00001, 64, 64, 32, 1234ULL}, + {0.00001, 128, 32, 32, 1234ULL}, {0.00001, 128, 64, 32, 1234ULL}, + {0.00001, 128, 128, 64, 1234ULL}, {0.00001, 64, 128, 128, 1234ULL}, + + {0.00001, 32, 32, 34, 1234ULL}, {0.00001, 32, 64, 34, 1234ULL}, + {0.00001, 64, 32, 34, 1234ULL}, {0.00001, 64, 64, 34, 1234ULL}, + {0.00001, 128, 32, 34, 1234ULL}, {0.00001, 128, 64, 34, 1234ULL}, + {0.00001, 128, 128, 66, 1234ULL}, {0.00001, 64, 128, 130, 1234ULL}, + + {0.00001, 32, 32, 33, 1234ULL}, {0.00001, 32, 64, 33, 1234ULL}, + {0.00001, 64, 32, 33, 1234ULL}, {0.00001, 64, 64, 33, 1234ULL}, + {0.00001, 128, 32, 33, 1234ULL}, {0.00001, 128, 64, 33, 1234ULL}, + {0.00001, 128, 128, 65, 1234ULL}, {0.00001, 64, 128, 129, 1234ULL}, + + {0.00001, 1805, 134, 2, 1234ULL}, {0.00001, 8192, 1024, 25, 1234ULL}, +}; +typedef FusedCosineNNTest FusedCosineNNTestD; +TEST_P(FusedCosineNNTestD, Result) +{ + runTest(min.data()); + ASSERT_TRUE(devArrMatch( + min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(FusedCosineNNTests, FusedCosineNNTestD, ::testing::ValuesIn(inputsd)); + +/// This is to test output determinism of the prim +template +class FusedCosineNNDetTest : public FusedCosineNNTest { + public: + FusedCosineNNDetTest() : stream(resource::get_cuda_stream(handle)), min1(0, stream) {} + + void SetUp() override + { + FusedCosineNNTest::SetUp(); + int m = this->params.m; + min1.resize(m, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } + + void TearDown() override { FusedCosineNNTest::TearDown(); } + + protected: + raft::resources handle; + cudaStream_t stream; + + rmm::device_uvector> min1; + + static const int NumRepeats = 3; + + void generateGoldenResult() override {} +}; + +typedef FusedCosineNNDetTest FusedCosineNNDetTestF; +TEST_P(FusedCosineNNDetTestF, Result) +{ + runTest(min.data()); // assumed to be golden + for (int i = 0; i < NumRepeats; ++i) { + runTest(min1.data()); + ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); + cudaMemsetAsync(min1.data(), 0, sizeof(*min.data()) * params.m, stream); + } +} +INSTANTIATE_TEST_CASE_P(FusedCosineNNDetTests, FusedCosineNNDetTestF, ::testing::ValuesIn(inputsf)); + +typedef FusedCosineNNDetTest FusedCosineNNDetTestD; +TEST_P(FusedCosineNNDetTestD, Result) +{ + runTest(min.data()); // assumed to be golden + for (int i = 0; i < NumRepeats; ++i) { + runTest(min1.data()); + ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); + } +} +INSTANTIATE_TEST_CASE_P(FusedCosineNNDetTests, FusedCosineNNDetTestD, ::testing::ValuesIn(inputsd)); + +} // end namespace distance +} // end namespace raft diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index d21a525d88..6fd8f15808 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -18,7 +18,6 @@ #include #include -#include #include #include #include diff --git a/python/pylibraft/pylibraft/distance/CMakeLists.txt b/python/pylibraft/pylibraft/distance/CMakeLists.txt index 14f0cc441a..2530e07a98 100644 --- a/python/pylibraft/pylibraft/distance/CMakeLists.txt +++ b/python/pylibraft/pylibraft/distance/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -13,7 +13,7 @@ # ============================================================================= # Set the list of Cython files to build -set(cython_sources pairwise_distance.pyx fused_l2_nn.pyx) +set(cython_sources pairwise_distance.pyx fused_l2_nn.pyx fused_distance_nn.pyx) set(linked_libraries raft::raft raft::compiled) # Build all of the Cython targets diff --git a/python/pylibraft/pylibraft/distance/__init__.py b/python/pylibraft/pylibraft/distance/__init__.py index f059b5f3dd..d16ab30b2f 100644 --- a/python/pylibraft/pylibraft/distance/__init__.py +++ b/python/pylibraft/pylibraft/distance/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,12 @@ # limitations under the License. # +from .fused_distance_nn import fused_distance_nn_argmin from .fused_l2_nn import fused_l2_nn_argmin from .pairwise_distance import DISTANCE_TYPES, distance as pairwise_distance -__all__ = ["fused_l2_nn_argmin", "pairwise_distance"] +__all__ = [ + "fused_distance_nn_argmin", + "fused_l2_nn_argmin", + "pairwise_distance", +] diff --git a/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx b/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx new file mode 100644 index 0000000000..0e9fa4b366 --- /dev/null +++ b/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx @@ -0,0 +1,200 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +import numpy as np + +from cython.operator cimport dereference as deref +from libc.stdint cimport uintptr_t +from libcpp cimport bool + +from .distance_type cimport DistanceType + +from pylibraft.common import ( + Handle, + auto_convert_output, + cai_wrapper, + device_ndarray, +) +from pylibraft.common.handle import auto_sync_handle + +from pylibraft.common.handle cimport device_resources + + +cdef extern from "raft_runtime/distance/fused_distance_nn.hpp" \ + namespace "raft::runtime::distance" nogil: + + void fused_distance_nn_min_arg( + const device_resources &handle, + int* min, + const float* x, + const float* y, + int m, + int n, + int k, + bool sqrt, + DistanceType metric, + bool isRowMajor, + float metric_arg) except + + + +from pylibraft.distance.pairwise_distance import DISTANCE_TYPES + +SUPPORTED_DISTANCES = ["euclidean", "l2", "cosine", "sqeuclidean"] + + +@auto_sync_handle +@auto_convert_output +def fused_distance_nn_argmin(X, Y, out=None, sqrt=True, metric="euclidean", + handle=None): + """ + Compute the 1-nearest neighbors between X and Y using the distance metrics + + Valid values for metric: + ["euclidean", "l2", "cosine", "sqeuclidean"] + + Parameters + ---------- + + X : CUDA array interface compliant matrix shape (m, k) + Y : CUDA array interface compliant matrix shape (n, k) + out : Writable CUDA array interface matrix shape (m, 1) + metric : string denoting the metric type (default="euclidean") + + {handle_docstring} + + Examples + -------- + To compute the 1-nearest neighbors argmin: + + >>> import cupy as cp + >>> from pylibraft.common import Handle + >>> from pylibraft.distance import fused_distance_nn_argmin + >>> n_samples = 5000 + >>> n_clusters = 5 + >>> n_features = 50 + >>> in1 = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> in2 = cp.random.random_sample((n_clusters, n_features), + ... dtype=cp.float32) + >>> # A single RAFT handle can optionally be reused across + >>> # pylibraft functions. + >>> handle = Handle() + + >>> output = fused_distance_nn_argmin(in1, in2, handle=handle) + + >>> # pylibraft functions are often asynchronous so the + >>> # handle needs to be explicitly synchronized + >>> handle.sync() + + The output can also be computed in-place on a preallocated + array: + + >>> import cupy as cp + >>> from pylibraft.common import Handle + >>> from pylibraft.distance import fused_distance_nn_argmin + >>> n_samples = 5000 + >>> n_clusters = 5 + >>> n_features = 50 + >>> in1 = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> in2 = cp.random.random_sample((n_clusters, n_features), + ... dtype=cp.float32) + >>> output = cp.empty((n_samples, 1), dtype=cp.int32) + >>> # A single RAFT handle can optionally be reused across + >>> # pylibraft functions. + >>> handle = Handle() + + >>> fused_distance_nn_argmin(in1, in2, out=output, handle=handle) + array(...) + + >>> # pylibraft functions are often asynchronous so the + >>> # handle needs to be explicitly synchronized + >>> handle.sync() + """ + + x_cai = cai_wrapper(X) + y_cai = cai_wrapper(Y) + + x_dt = x_cai.dtype + y_dt = y_cai.dtype + + m = x_cai.shape[0] + n = y_cai.shape[0] + + if out is None: + output = device_ndarray.empty((m,), dtype="int32") + else: + output = out + + output_cai = cai_wrapper(output) + + x_k = x_cai.shape[1] + y_k = y_cai.shape[1] + + if x_k != y_k: + raise ValueError("Inputs must have same number of columns. " + "a=%s, b=%s" % (x_k, y_k)) + + if metric not in SUPPORTED_DISTANCES: + raise ValueError("metric %s is not supported" % metric) + + cdef DistanceType distance_type = DISTANCE_TYPES[metric] + + x_ptr = x_cai.data + y_ptr = y_cai.data + + d_ptr = output_cai.data + + handle = handle if handle is not None else Handle() + cdef device_resources *h = handle.getHandle() + + d_dt = output_cai.dtype + + x_c_contiguous = x_cai.c_contiguous + y_c_contiguous = y_cai.c_contiguous + + if x_c_contiguous != y_c_contiguous: + raise ValueError("Inputs must have matching strides") + + if not x_c_contiguous: + raise ValueError("Inputs must be C contiguous") + + if x_dt != y_dt: + raise ValueError("Inputs must have the same dtypes") + if d_dt != np.int32: + raise ValueError("Output array must be int32") + # unused arg for now. + metric_arg = 0.0 + if x_dt == np.float32: + fused_distance_nn_min_arg(deref(h), + d_ptr, + x_ptr, + y_ptr, + m, + n, + x_k, + sqrt, + distance_type, + x_c_contiguous, + metric_arg) + else: + raise ValueError("dtype %s not supported" % x_dt) + + return output diff --git a/python/pylibraft/pylibraft/test/test_fused_distance_argmin.py b/python/pylibraft/pylibraft/test/test_fused_distance_argmin.py new file mode 100755 index 0000000000..6736128242 --- /dev/null +++ b/python/pylibraft/pylibraft/test/test_fused_distance_argmin.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np +import pytest +from scipy.spatial.distance import cdist + +from pylibraft.common import DeviceResources, device_ndarray +from pylibraft.distance import fused_distance_nn_argmin + + +@pytest.mark.parametrize("inplace", [True, False]) +@pytest.mark.parametrize("n_rows", [10, 100]) +@pytest.mark.parametrize("n_clusters", [50, 100]) +@pytest.mark.parametrize("n_cols", [128, 31]) +@pytest.mark.parametrize("dtype", [np.float32]) +@pytest.mark.parametrize( + "metric", + [ + "euclidean", + "cosine", + "sqeuclidean", + ], +) +def test_fused_distance_nn_minarg( + n_rows, n_cols, n_clusters, dtype, inplace, metric +): + input1 = np.random.random_sample((n_rows, n_cols)) + input1 = np.asarray(input1, order="C").astype(dtype) + + input2 = np.random.random_sample((n_clusters, n_cols)) + input2 = np.asarray(input2, order="C").astype(dtype) + + output = np.zeros((n_rows), dtype="int32") + expected = cdist(input1, input2, metric) + + expected = expected.argmin(axis=1) + + input1_device = device_ndarray(input1) + input2_device = device_ndarray(input2) + output_device = device_ndarray(output) if inplace else None + + is_sqrt = True if metric == "sqeuclidean" else False + handle = DeviceResources() + ret_output = fused_distance_nn_argmin( + input1_device, + input2_device, + output_device, + is_sqrt, + metric, + handle=handle, + ) + handle.sync() + output_device = ret_output if not inplace else output_device + actual = output_device.copy_to_host() + + assert np.allclose(expected, actual, rtol=1e-4) From 0b9692b25f78cd1b27631e354e3f8921a976645c Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Tue, 19 Mar 2024 18:07:21 +0100 Subject: [PATCH 10/18] random sampling of dataset rows with improved memory utilization (#2155) The random sampling of IVF methods was reverted (#2144) due to large memory utilization #2141. This PR improves the memory consumption of subsamling: it is O(n_train) where n_train is the size of the subsampled dataset. This PR adds the following new APIs: - random::excess_sampling (todo may just call as sample_without_replacement) - matrix::sample_rows - matrix::gather for host input matrix Authors: - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Artem M. Chirkin (https://github.com/achirkin) - Ben Frederickson (https://github.com/benfred) URL: https://github.com/rapidsai/raft/pull/2155 --- cpp/bench/prims/CMakeLists.txt | 2 +- cpp/bench/prims/matrix/gather.cu | 38 ++++- cpp/bench/prims/random/subsample.cu | 112 ++++++++++++++ cpp/include/raft/matrix/detail/gather.cuh | 87 +++++++++++ .../raft/matrix/detail/sample_rows.cuh | 57 +++++++ cpp/include/raft/matrix/sample_rows.cuh | 75 ++++++++++ cpp/include/raft/random/detail/rng_impl.cuh | 138 ++++++++++++++++- cpp/include/raft/random/rng.cuh | 26 ++++ cpp/test/CMakeLists.txt | 2 + cpp/test/matrix/sample_rows.cu | 140 ++++++++++++++++++ cpp/test/random/excess_sampling.cu | 114 ++++++++++++++ 11 files changed, 786 insertions(+), 5 deletions(-) create mode 100644 cpp/bench/prims/random/subsample.cu create mode 100644 cpp/include/raft/matrix/detail/sample_rows.cuh create mode 100644 cpp/include/raft/matrix/sample_rows.cuh create mode 100644 cpp/test/matrix/sample_rows.cu create mode 100644 cpp/test/random/excess_sampling.cu diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 903f4e4347..95361e19ca 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -128,7 +128,7 @@ if(BUILD_PRIMS_BENCH) ConfigureBench( NAME RANDOM_BENCH PATH bench/prims/random/make_blobs.cu bench/prims/random/permute.cu - bench/prims/random/rng.cu bench/prims/main.cpp + bench/prims/random/rng.cu bench/prims/random/subsample.cu bench/prims/main.cpp ) ConfigureBench(NAME SPARSE_BENCH PATH bench/prims/sparse/convert_csr.cu bench/prims/main.cpp) diff --git a/cpp/bench/prims/matrix/gather.cu b/cpp/bench/prims/matrix/gather.cu index e6f26ba925..078f9e6198 100644 --- a/cpp/bench/prims/matrix/gather.cu +++ b/cpp/bench/prims/matrix/gather.cu @@ -16,34 +16,48 @@ #include +#include +#include #include #include #include #include #include +#include namespace raft::bench::matrix { template struct GatherParams { IdxT rows, cols, map_length; + bool host; }; template inline auto operator<<(std::ostream& os, const GatherParams& p) -> std::ostream& { - os << p.rows << "#" << p.cols << "#" << p.map_length; + os << p.rows << "#" << p.cols << "#" << p.map_length << (p.host ? "#host" : "#device"); return os; } template struct Gather : public fixture { Gather(const GatherParams& p) - : params(p), matrix(this->handle), map(this->handle), out(this->handle), stencil(this->handle) + : params(p), + old_mr(rmm::mr::get_current_device_resource()), + pool_mr(rmm::mr::get_current_device_resource(), 2 * (1ULL << 30)), + matrix(this->handle), + map(this->handle), + out(this->handle), + stencil(this->handle), + matrix_h(this->handle) { + rmm::mr::set_current_device_resource(&pool_mr); } + ~Gather() { rmm::mr::set_current_device_resource(old_mr); } + void allocate_data(const ::benchmark::State& state) override { matrix = raft::make_device_matrix(handle, params.rows, params.cols); @@ -59,6 +73,11 @@ struct Gather : public fixture { if constexpr (Conditional) { raft::random::uniform(handle, rng, stencil.data_handle(), params.map_length, T(-1), T(1)); } + + if (params.host) { + matrix_h = raft::make_host_matrix(handle, params.rows, params.cols); + raft::copy(matrix_h.data_handle(), matrix.data_handle(), matrix.size(), stream); + } resource::sync_stream(handle, stream); } @@ -77,14 +96,22 @@ struct Gather : public fixture { raft::matrix::gather_if( handle, matrix_const_view, out.view(), map_const_view, stencil_const_view, pred_op); } else { - raft::matrix::gather(handle, matrix_const_view, map_const_view, out.view()); + if (params.host) { + raft::matrix::detail::gather( + handle, make_const_mdspan(matrix_h.view()), map_const_view, out.view()); + } else { + raft::matrix::gather(handle, matrix_const_view, map_const_view, out.view()); + } } }); } private: GatherParams params; + rmm::mr::device_memory_resource* old_mr; + rmm::mr::pool_memory_resource pool_mr; raft::device_matrix matrix, out; + raft::host_matrix matrix_h; raft::device_vector stencil; raft::device_vector map; }; // struct Gather @@ -100,4 +127,9 @@ RAFT_BENCH_REGISTER((Gather), "", gather_inputs_i64); RAFT_BENCH_REGISTER((Gather), "", gather_inputs_i64); RAFT_BENCH_REGISTER((GatherIf), "", gather_inputs_i64); RAFT_BENCH_REGISTER((GatherIf), "", gather_inputs_i64); + +auto inputs_host = raft::util::itertools::product>( + {10000000}, {100}, {1000, 1000000, 10000000}, {true}); +RAFT_BENCH_REGISTER((Gather), "Host", inputs_host); + } // namespace raft::bench::matrix diff --git a/cpp/bench/prims/random/subsample.cu b/cpp/bench/prims/random/subsample.cu new file mode 100644 index 0000000000..4c8ca2bf31 --- /dev/null +++ b/cpp/bench/prims/random/subsample.cu @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +namespace raft::bench::random { + +struct sample_inputs { + int n_samples; + int n_train; + int method; +}; // struct sample_inputs + +inline auto operator<<(std::ostream& os, const sample_inputs& p) -> std::ostream& +{ + os << p.n_samples << "#" << p.n_train << "#" << p.method; + return os; +} + +// Sample with replacement. We use this as a baseline. +template +auto bernoulli_subsample(raft::resources const& res, IdxT n_samples, IdxT n_subsamples, int seed) + -> raft::device_vector +{ + RAFT_EXPECTS(n_subsamples <= n_samples, "Cannot have more training samples than dataset vectors"); + + auto indices = raft::make_device_vector(res, n_subsamples); + raft::random::RngState state(123456ULL); + raft::random::uniformInt( + res, state, indices.data_handle(), n_subsamples, IdxT(0), IdxT(n_samples)); + return indices; +} + +template +struct sample : public fixture { + sample(const sample_inputs& p) + : params(p), + old_mr(rmm::mr::get_current_device_resource()), + pool_mr(rmm::mr::get_current_device_resource(), 2 * GiB), + in(make_device_vector(res, p.n_samples)), + out(make_device_vector(res, p.n_train)) + { + rmm::mr::set_current_device_resource(&pool_mr); + raft::random::RngState r(123456ULL); + } + + ~sample() { rmm::mr::set_current_device_resource(old_mr); } + void run_benchmark(::benchmark::State& state) override + { + std::ostringstream label_stream; + label_stream << params; + state.SetLabel(label_stream.str()); + + raft::random::RngState r(123456ULL); + loop_on_state(state, [this, &r]() { + if (params.method == 1) { + this->out = + bernoulli_subsample(this->res, this->params.n_samples, this->params.n_train, 137); + } else if (params.method == 2) { + this->out = raft::random::excess_subsample( + this->res, r, this->params.n_samples, this->params.n_train); + } + }); + } + + private: + float GiB = 1073741824.0f; + raft::device_resources res; + rmm::mr::device_memory_resource* old_mr; + rmm::mr::pool_memory_resource pool_mr; + sample_inputs params; + raft::device_vector out, in; +}; // struct sample + +const std::vector input_vecs = {{100000000, 10000000, 1}, + {100000000, 50000000, 1}, + {100000000, 100000000, 1}, + {100000000, 10000000, 2}, + {100000000, 50000000, 2}, + {100000000, 100000000, 2}}; + +RAFT_BENCH_REGISTER(sample, "", input_vecs); + +} // namespace raft::bench::random diff --git a/cpp/include/raft/matrix/detail/gather.cuh b/cpp/include/raft/matrix/detail/gather.cuh index 651fec81c3..05cc9204bf 100644 --- a/cpp/include/raft/matrix/detail/gather.cuh +++ b/cpp/include/raft/matrix/detail/gather.cuh @@ -16,9 +16,19 @@ #pragma once +#include +#include +#include +#include +#include #include +#include +#include +#include #include +#include + #include namespace raft { @@ -336,6 +346,83 @@ void gather_if(const InputIteratorT in, gatherImpl(in, D, N, map, stencil, map_length, out, pred_op, transform_op, stream); } +/** + * Helper function to gather a set of vectors from a (host) dataset. + */ +template +void gather_buff(host_matrix_view dataset, + host_vector_view indices, + MatIdxT offset, + pinned_matrix_view buff) +{ + raft::common::nvtx::range fun_scope("gather_host_buff"); + IdxT batch_size = std::min(buff.extent(0), indices.extent(0) - offset); + +#pragma omp for + for (IdxT i = 0; i < batch_size; i++) { + IdxT in_idx = indices(offset + i); + for (IdxT k = 0; k < buff.extent(1); k++) { + buff(i, k) = dataset(in_idx, k); + } + } +} + +template +void gather(raft::resources const& res, + host_matrix_view dataset, + device_vector_view indices, + raft::device_matrix_view output) +{ + raft::common::nvtx::range fun_scope("gather"); + IdxT n_dim = output.extent(1); + IdxT n_train = output.extent(0); + auto indices_host = raft::make_host_vector(n_train); + raft::copy( + indices_host.data_handle(), indices.data_handle(), n_train, resource::get_cuda_stream(res)); + resource::sync_stream(res); + + const size_t buffer_size = 32768 * 1024; // bytes + const size_t max_batch_size = + std::min(round_up_safe(buffer_size / n_dim, 32), n_train); + RAFT_LOG_DEBUG("Gathering data with batch size %zu", max_batch_size); + + // Gather the vector on the host in tmp buffers. We use two buffers to overlap H2D sync + // and gathering the data. + auto out_tmp1 = raft::make_pinned_matrix(res, max_batch_size, n_dim); + auto out_tmp2 = raft::make_pinned_matrix(res, max_batch_size, n_dim); + + // Usually a limited number of threads provide sufficient bandwidth for gathering data. + int n_threads = std::min(omp_get_max_threads(), 32); + + // The gather_buff function has a parallel for loop. We start the the omp parallel + // region here, to avoid repeated overhead within the device_offset loop. +#pragma omp parallel num_threads(n_threads) + { + auto view1 = out_tmp1.view(); + auto view2 = out_tmp2.view(); + gather_buff(dataset, make_const_mdspan(indices_host.view()), (MatIdxT)0, view1); + for (MatIdxT device_offset = 0; device_offset < n_train; device_offset += max_batch_size) { + MatIdxT batch_size = std::min(max_batch_size, n_train - device_offset); + +#pragma omp master + raft::copy(output.data_handle() + device_offset * n_dim, + view1.data_handle(), + batch_size * n_dim, + resource::get_cuda_stream(res)); + // Start gathering the next batch on the host. + MatIdxT host_offset = device_offset + batch_size; + batch_size = std::min(max_batch_size, n_train - host_offset); + if (batch_size > 0) { + gather_buff(dataset, make_const_mdspan(indices_host.view()), host_offset, view2); + } +#pragma omp master + resource::sync_stream(res); +#pragma omp barrier + std::swap(view1, view2); + } + } +} + } // namespace detail } // namespace matrix } // namespace raft diff --git a/cpp/include/raft/matrix/detail/sample_rows.cuh b/cpp/include/raft/matrix/detail/sample_rows.cuh new file mode 100644 index 0000000000..e28ad648da --- /dev/null +++ b/cpp/include/raft/matrix/detail/sample_rows.cuh @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::matrix::detail { + +/** Select rows randomly from input and copy to output. */ +template +void sample_rows(raft::resources const& res, + random::RngState random_state, + const T* input, + IdxT n_rows_input, + raft::device_matrix_view output) +{ + IdxT n_dim = output.extent(1); + IdxT n_samples = output.extent(0); + + raft::device_vector train_indices = + raft::random::excess_subsample(res, random_state, n_rows_input, n_samples); + + cudaPointerAttributes attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, input)); + T* ptr = reinterpret_cast(attr.devicePointer); + if (ptr != nullptr) { + raft::matrix::gather(res, + raft::make_device_matrix_view(ptr, n_rows_input, n_dim), + raft::make_const_mdspan(train_indices.view()), + output); + } else { + auto dataset = raft::make_host_matrix_view(input, n_rows_input, n_dim); + raft::matrix::detail::gather(res, dataset, make_const_mdspan(train_indices.view()), output); + } +} +} // namespace raft::matrix::detail diff --git a/cpp/include/raft/matrix/sample_rows.cuh b/cpp/include/raft/matrix/sample_rows.cuh new file mode 100644 index 0000000000..7925d344e4 --- /dev/null +++ b/cpp/include/raft/matrix/sample_rows.cuh @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace raft::matrix { + +/** @brief Select rows randomly from input and copy to output. + * + * The rows are selected randomly. The random sampling method does not guarantee completely unique + * selection of rows, but it is close to being unique. + * + * @param res RAFT resource handle + * @param random_state + * @param dataset input dataset + * @param output subsampled dataset + */ +template +void sample_rows(raft::resources const& res, + random::RngState random_state, + mdspan, row_major, accessor> dataset, + raft::device_matrix_view output) +{ + RAFT_EXPECTS(dataset.extent(1) == output.extent(1), + "dataset dims must match, but received %ld vs %ld", + static_cast(dataset.extent(1)), + static_cast(output.extent(1))); + detail::sample_rows(res, random_state, dataset.data_handle(), dataset.extent(0), output); +} + +/** @brief Select rows randomly from input and copy to output. + * + * The rows are selected randomly. The random sampling method does not guarantee completely unique + * selection of rows, but it is close to being unique. + * + * @param res RAFT resource handle + * @param random_state + * @param dataset input dataset + * @param n_samples number of rows in the returned matrix + * + * @return subsampled dataset + * */ +template +raft::device_matrix sample_rows( + raft::resources const& res, + random::RngState random_state, + mdspan, row_major, accessor> dataset, + IdxT n_samples) +{ + auto output = raft::make_device_matrix(res, n_samples, dataset.extent(1)); + sample_rows(res, random_state, dataset, output.view()); + return output; +} + +} // namespace raft::matrix diff --git a/cpp/include/raft/random/detail/rng_impl.cuh b/cpp/include/raft/random/detail/rng_impl.cuh index 57f4c8d33d..61a944e9b6 100644 --- a/cpp/include/raft/random/detail/rng_impl.cuh +++ b/cpp/include/raft/random/detail/rng_impl.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * Copyright (c) 2018-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,12 +17,20 @@ #pragma once #include +#include +#include +#include +#include #include #include #include #include #include +#include + +#include + namespace raft { namespace random { namespace detail { @@ -278,6 +286,7 @@ std::enable_if_t> discrete(RngState& rng_state, len); } +/** Note the memory space requirements are O(4*len) */ template void sampleWithoutReplacement(RngState& rng_state, DataT* out, @@ -328,6 +337,133 @@ void affine_transform_params(RngState const& rng_state, IdxT n, IdxT& a, IdxT& b b = mt_rng() % n; } +/** @brief Sample without replacement from range 0..N-1. + * + * Elements are sampled uniformly. + * The algorithm will allocate a workspace of size O(4*n_samples) internally. + * + * We use max N random numbers. Depending on how large n_samples is w.r.t to N, we + * either use rejection sampling, or sort the [0..N-1] values using random keys. + * + * @tparam IdxT type of indices that we sample + * @tparam MatIdxT extent type of the returned mdarray + * + * @param res RAFT resource handle + * @param state random number generator state + * @param N number of elements to sample from. We will sample values in range 0..N-1 + * @param n_samples number of samples to return + * + * @return device mdarray with the random samples + */ +template +auto excess_subsample(raft::resources const& res, RngState& state, IdxT N, IdxT n_samples) + -> raft::device_vector +{ + RAFT_EXPECTS(n_samples <= N, "Cannot have more training samples than dataset vectors"); + + // Number of samples we'll need to sample (with replacement), to expect 'k' + // unique samples from 'n' is given by the following equation: log(1 - k/n)/log(1 - 1/n) ref: + // https://stats.stackexchange.com/questions/296005/the-expected-number-of-unique-elements-drawn-with-replacement + IdxT n_excess_samples = + n_samples < N + ? std::ceil(raft::log(1 - double(n_samples) / double(N)) / (raft::log(1 - 1 / double(N)))) + : N; + + // There is a variance of n_excess_samples, we take 10% more elements. + n_excess_samples += std::max(0.1 * n_samples, 100); + + // n_excess_sampless will be larger than N around k = 0.64*N. When we reach N, then instead of + // doing rejection sampling, we simply shuffle the range [0..N-1] using N random numbers. + n_excess_samples = std::min(n_excess_samples, N); + auto rnd_idx = raft::make_device_vector(res, n_excess_samples); + + auto linear_idx = raft::make_device_vector(res, rnd_idx.size()); + raft::linalg::map_offset(res, linear_idx.view(), identity_op()); + + uniformInt(res, state, rnd_idx.data_handle(), rnd_idx.size(), IdxT(0), IdxT(N)); + + // Sort indices according to rnd keys + size_t workspace_size = 0; + auto stream = resource::get_cuda_stream(res); + cub::DeviceMergeSort::SortPairs(nullptr, + workspace_size, + rnd_idx.data_handle(), + linear_idx.data_handle(), + rnd_idx.size(), + raft::less_op{}, + stream); + auto workspace = raft::make_device_vector(res, workspace_size); + cub::DeviceMergeSort::SortPairs(workspace.data_handle(), + workspace_size, + rnd_idx.data_handle(), + linear_idx.data_handle(), + rnd_idx.size(), + raft::less_op{}, + stream); + + if (rnd_idx.size() == static_cast(N)) { + // We shuffled the linear_idx array by sorting it according to rnd_idx. + // We return the first n_samples elements. + if (n_samples == N) { return linear_idx; } + rnd_idx = raft::make_device_vector(res, n_samples); + raft::copy(rnd_idx.data_handle(), linear_idx.data_handle(), n_samples, stream); + return rnd_idx; + } + // Else we do a rejection sampling (or excess sampling): we generated more random indices than + // needed and reject the duplicates. + auto keys_out = raft::make_device_vector(res, rnd_idx.size()); + auto values_out = raft::make_device_vector(res, rnd_idx.size()); + rmm::device_scalar num_selected(stream); + size_t worksize2 = 0; + cub::DeviceSelect::UniqueByKey(nullptr, + worksize2, + rnd_idx.data_handle(), + linear_idx.data_handle(), + keys_out.data_handle(), + values_out.data_handle(), + num_selected.data(), + rnd_idx.size(), + stream); + + if (worksize2 > workspace.size()) { + workspace = raft::make_device_vector(res, worksize2); + workspace_size = workspace.size(); + } + + cub::DeviceSelect::UniqueByKey(workspace.data_handle(), + workspace_size, + rnd_idx.data_handle(), + linear_idx.data_handle(), + keys_out.data_handle(), + values_out.data_handle(), + num_selected.data(), + rnd_idx.size(), + stream); + + IdxT selected = num_selected.value(stream); + + if (selected < n_samples) { + RAFT_LOG_DEBUG("Subsampling returned with less unique indices (%zu) than requested (%zu)", + (size_t)selected, + (size_t)n_samples); + + // We continue to select n_samples elements, this will now contains a few duplicates. + } + + // After duplicates are removed, we need to shuffle back to random order + cub::DeviceMergeSort::SortPairs(workspace.data_handle(), + workspace_size, + values_out.data_handle(), + keys_out.data_handle(), + n_samples, + raft::less_op{}, + stream); + + values_out = raft::make_device_vector(res, n_samples); + raft::copy(values_out.data_handle(), keys_out.data_handle(), n_samples, stream); + return values_out; +} + }; // end namespace detail }; // end namespace random }; // end namespace raft diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index 4e63669f98..7fd461980f 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -813,6 +813,32 @@ void sampleWithoutReplacement(raft::resources const& handle, rng_state, out, outIdx, in, wts, sampledLen, len, resource::get_cuda_stream(handle)); } +/** @brief Sample from range 0..N-1. + * + * Elements are sampled uniformly. The method aims to sample without replacement, + * but there is a small probability of a few having duplicate elements. + * + * The algorithm will allocate a workspace of size 4*n_samples*sizeof(IdxT) internally. + * + * We use max N random numbers. Depending on how large n_samples is w.r.t to N, we + * either use rejection sampling, or sort the [0..N-1] values using random keys. + * + * @tparam IdxT type of indices that we sample + * @tparam MatIdxT extent type of the returned mdarray + * + * @param res RAFT resource handle + * @param state random number generator state + * @param N number of elements to sample from. We will sample values in range 0..N-1. + * @param n_samples number of samples to return + * + * @return device mdarray with the random samples + */ +template +auto excess_subsample(raft::resources const& res, RngState& state, IdxT N, IdxT n_samples) +{ + return detail::excess_subsample(res, state, N, n_samples); +} + /** * @brief Generates the 'a' and 'b' parameters for a modulo affine * transformation equation: `(ax + b) % n` diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index bf44cf9c60..ecb871fccc 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -267,6 +267,7 @@ if(BUILD_TESTS) test/matrix/matrix.cu test/matrix/norm.cu test/matrix/reverse.cu + test/matrix/sample_rows.cu test/matrix/slice.cu test/matrix/triangular.cu test/sparse/spectral_matrix.cu @@ -294,6 +295,7 @@ if(BUILD_TESTS) test/random/rng_int.cu test/random/rmat_rectangular_generator.cu test/random/sample_without_replacement.cu + test/random/excess_sampling.cu ) ConfigureTest( diff --git a/cpp/test/matrix/sample_rows.cu b/cpp/test/matrix/sample_rows.cu new file mode 100644 index 0000000000..e332a918fe --- /dev/null +++ b/cpp/test/matrix/sample_rows.cu @@ -0,0 +1,140 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +namespace raft { +namespace matrix { + +struct inputs { + int N; + int dim; + int n_samples; + bool host; +}; + +::std::ostream& operator<<(::std::ostream& os, const inputs p) +{ + os << p.N << "#" << p.dim << "#" << p.n_samples << (p.host ? "#host" : "#device"); + return os; +} + +template +class SampleRowsTest : public ::testing::TestWithParam { + public: + SampleRowsTest() + : params(::testing::TestWithParam::GetParam()), + stream(resource::get_cuda_stream(res)), + state{137ULL}, + in(make_device_matrix(res, params.N, params.dim)), + out(make_device_matrix(res, 0, 0)), + in_h(make_host_matrix(res, params.N, params.dim)), + out_h(make_host_matrix(res, params.n_samples, params.dim)) + { + raft::random::uniform(res, state, in.data_handle(), in.size(), T(-1.0), T(1.0)); + for (int64_t i = 0; i < params.N; i++) { + for (int64_t k = 0; k < params.dim; k++) + in_h(i, k) = i * 1000 + k; + } + raft::copy(in.data_handle(), in_h.data_handle(), in_h.size(), stream); + } + + void check() + { + if (params.host) { + out = raft::matrix::sample_rows( + res, state, make_const_mdspan(in_h.view()), (int64_t)params.n_samples); + } else { + out = raft::matrix::sample_rows( + res, state, make_const_mdspan(in.view()), (int64_t)params.n_samples); + } + + raft::copy(out_h.data_handle(), out.data_handle(), out.size(), stream); + resource::sync_stream(res, stream); + + ASSERT_TRUE(out.extent(0) == params.n_samples); + ASSERT_TRUE(out.extent(1) == params.dim); + + std::unordered_set occurrence; + + for (int64_t i = 0; i < params.n_samples; ++i) { + T val = out_h(i, 0) / 1000; + ASSERT_TRUE(0 <= val && val < params.N) + << "out-of-range index @i=" << i << " val=" << val << " params=" << params; + EXPECT_TRUE(occurrence.find(val) == occurrence.end()) + << "repeated index @i=" << i << " idx=" << val << " params=" << params; + occurrence.insert(val); + for (int64_t k = 0; k < params.dim; k++) { + ASSERT_TRUE(raft::match(out_h(i, k), val * 1000 + k, raft::CompareApprox(1e-6))); + } + } + } + + protected: + inputs params; + raft::resources res; + cudaStream_t stream; + random::RngState state; + device_matrix in, out; + host_matrix in_h, out_h; +}; + +inline std::vector generate_inputs() +{ + std::vector input1 = + raft::util::itertools::product({10}, {1, 17, 96}, {1, 6, 9, 10}, {false}); + + std::vector input2 = + raft::util::itertools::product({137}, {1, 17, 128}, {1, 10, 100, 137}, {false}); + input1.insert(input1.end(), input2.begin(), input2.end()); + + input2 = raft::util::itertools::product( + {100000}, {1, 42}, {1, 137, 1000, 10000, 50000, 62000, 100000}, {false}); + + input1.insert(input1.end(), input2.begin(), input2.end()); + + int n = input1.size(); + // Add same tests for host data + for (int i = 0; i < n; i++) { + inputs x = input1[i]; + x.host = true; + input1.push_back(x); + } + return input1; +} + +const std::vector inputs1 = generate_inputs(); + +using SampleRowsTestInt64 = SampleRowsTest; +TEST_P(SampleRowsTestInt64, SamplingTest) { check(); } +INSTANTIATE_TEST_SUITE_P(SampleRowsTests, SampleRowsTestInt64, ::testing::ValuesIn(inputs1)); + +} // namespace matrix +} // namespace raft diff --git a/cpp/test/random/excess_sampling.cu b/cpp/test/random/excess_sampling.cu new file mode 100644 index 0000000000..e86436fb7d --- /dev/null +++ b/cpp/test/random/excess_sampling.cu @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace raft { +namespace random { + +using namespace raft::random; + +struct inputs { + int64_t N; + int64_t n_samples; +}; + +template +::std::ostream& operator<<(::std::ostream& os, const inputs p) +{ + os << p.N << "/" << p.n_samples; + return os; +} + +template +class ExcessSamplingTest : public ::testing::TestWithParam { + public: + ExcessSamplingTest() + : params(::testing::TestWithParam::GetParam()), + stream(resource::get_cuda_stream(res)), + state{137ULL} + { + } + + void check() + { + device_vector out = + raft::random::excess_subsample(res, state, params.N, params.n_samples); + ASSERT_TRUE(out.extent(0) == params.n_samples); + + auto h_out = make_host_vector(res, params.n_samples); + raft::copy(h_out.data_handle(), out.data_handle(), out.size(), stream); + resource::sync_stream(res, stream); + + std::unordered_set occurrence; + int64_t sum = 0; + for (int64_t i = 0; i < params.n_samples; ++i) { + T val = h_out(i); + sum += val; + ASSERT_TRUE(0 <= val && val < params.N) + << "out-of-range index @i=" << i << " val=" << val << " n_samples=" << params.n_samples; + ASSERT_TRUE(occurrence.find(val) == occurrence.end()) + << "repeated index @i=" << i << " idx=" << val; + occurrence.insert(val); + } + float avg = sum / (float)params.n_samples; + if (params.n_samples >= 100 && params.N / params.n_samples < 100) { + ASSERT_TRUE(raft::match(avg, (params.N - 1) / 2.0f, raft::CompareApprox(0.2))) + << "non-uniform sample"; + } + } + + protected: + inputs params; + raft::resources res; + cudaStream_t stream; + RngState state; +}; + +const std::vector input1 = {{1, 0}, + {1, 1}, + {10, 0}, + {10, 1}, + {10, 2}, + {10, 10}, + {137, 42}, + {200, 0}, + {200, 1}, + {200, 100}, + {200, 130}, + {200, 200}, + {10000, 893}, + {10000000000, 1023}}; + +using ExcessSamplingTestInt64 = ExcessSamplingTest; +TEST_P(ExcessSamplingTestInt64, SamplingTest) { check(); } +INSTANTIATE_TEST_SUITE_P(ExcessSamplingTests, ExcessSamplingTestInt64, ::testing::ValuesIn(input1)); + +} // namespace random +} // namespace raft From 335236c705c0c53da8a4bf6a22835fdbe669f1df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Malte=20F=C3=B6rster?= <97973773+mfoerste4@users.noreply.github.com> Date: Wed, 20 Mar 2024 15:08:19 +0100 Subject: [PATCH 11/18] Performance optimization of IVF-flat / select_k (#2221) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR is a followup to #2169. To enable IVF-flat with k>256 we need an additional select_k invocation which was unexpectedly slow. There are two reasons for that: First problem is the data handed to select_k: The valid data length per row is much smaller than the conservative maximum that could be achieved by probing the N largest probes. Therefore each query row contains roughly ~50% dummy values. This is also the case for IVF-PQ, but did not show up as prominent due to the second reason. The second problem, and also a difference to the IVF-PQ algorithm - is that a 64bit payload data type is used for selectK. The performance of selectK with 64bit index type is significantly slower than with 32bit, especially when many elements are in the same range: ``` Benchmark Time CPU Iterations ----------------------------------------------------------------------------------------------------- SelectK/float/uint32_t/kRadix11bitsExtraPass/1/manual_time 1.68 ms 1.74 ms 413 1357#200000#512 SelectK/float/uint32_t/kRadix11bitsExtraPass/3/manual_time 2.31 ms 2.37 ms 302 1357#200000#512#same-leading-bits SelectK/float/int64_t/kRadix11bitsExtraPass/1/manual_time 5.92 ms 5.98 ms 116 1357#200000#512 SelectK/float/int64_t/kRadix11bitsExtraPass/3/manual_time 83.7 ms 83.8 ms 8 1357#200000#512#same-leading-bits ----------------------------------------------------------------------------------------------------- ``` The data distribution within a IVF-flat benchmark resulted in a select_k time of ~24ms. ### scope: * additional parameter added to select_k to optionally pass individual row lengths for every batch entry. This parameter is utilized by both IVF-Flat and IVF-PQ and results in a ~2x speedup (50 nodes out of 5000) of the final `select_k`. * refactor ivf-flat search to work with 32bit indices by storing positions instead of actual indices. This allows to utilize 32bit index type select_k for ~10x speedup in the final `select_k`. FYI @tfeher @achirkin ### not in scope: * General optimization of select_k: In the current implementation there is no difference in the type of the payload and the actual index type. Especially the type of the histogram has a large effect on performance (due to the atomics). Authors: - Malte Förster (https://github.com/mfoerste4) Approvers: - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/raft/pull/2221 --- .../raft/matrix/detail/select_k-ext.cuh | 8 +- .../raft/matrix/detail/select_k-inl.cuh | 16 ++-- .../raft/matrix/detail/select_radix.cuh | 35 ++++++- .../raft/neighbors/detail/ivf_common.cuh | 20 ++-- .../detail/ivf_flat_interleaved_scan-ext.cuh | 4 +- .../detail/ivf_flat_interleaved_scan-inl.cuh | 25 ++--- .../neighbors/detail/ivf_flat_search-inl.cuh | 91 +++++++++++-------- .../raft/neighbors/detail/ivf_pq_search.cuh | 5 +- .../raft/neighbors/detail/refine_device.cuh | 36 +++++++- .../matrix/detail/select_k_double_int64_t.cu | 3 +- .../matrix/detail/select_k_double_uint32_t.cu | 3 +- cpp/src/matrix/detail/select_k_float_int32.cu | 3 +- .../matrix/detail/select_k_float_int64_t.cu | 3 +- .../matrix/detail/select_k_float_uint32_t.cu | 3 +- .../matrix/detail/select_k_half_int64_t.cu | 3 +- .../matrix/detail/select_k_half_uint32_t.cu | 3 +- ...at_interleaved_scan_float_float_int64_t.cu | 2 +- ...flat_interleaved_scan_half_half_int64_t.cu | 2 +- ...interleaved_scan_int8_t_int32_t_int64_t.cu | 2 +- ...terleaved_scan_uint8_t_uint32_t_int64_t.cu | 2 +- cpp/test/neighbors/ann_cagra.cuh | 8 +- cpp/test/neighbors/ann_utils.cuh | 43 ++++++++- 22 files changed, 221 insertions(+), 99 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_k-ext.cuh b/cpp/include/raft/matrix/detail/select_k-ext.cuh index 6a7847d8a0..506cbffcb9 100644 --- a/cpp/include/raft/matrix/detail/select_k-ext.cuh +++ b/cpp/include/raft/matrix/detail/select_k-ext.cuh @@ -41,8 +41,9 @@ void select_k(raft::resources const& handle, T* out_val, IdxT* out_idx, bool select_min, - bool sorted = false, - SelectAlgo algo = SelectAlgo::kAuto) RAFT_EXPLICIT; + bool sorted = false, + SelectAlgo algo = SelectAlgo::kAuto, + const IdxT* len_i = nullptr) RAFT_EXPLICIT; } // namespace raft::matrix::detail #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY @@ -58,7 +59,8 @@ void select_k(raft::resources const& handle, IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(__half, uint32_t); instantiate_raft_matrix_detail_select_k(__half, int64_t); instantiate_raft_matrix_detail_select_k(float, int64_t); diff --git a/cpp/include/raft/matrix/detail/select_k-inl.cuh b/cpp/include/raft/matrix/detail/select_k-inl.cuh index 8f40e6ae00..93d233152b 100644 --- a/cpp/include/raft/matrix/detail/select_k-inl.cuh +++ b/cpp/include/raft/matrix/detail/select_k-inl.cuh @@ -229,6 +229,9 @@ void segmented_sort_by_key(raft::resources const& handle, * whether to make sure selected pairs are sorted by value * @param[in] algo * the selection algorithm to use + * @param[in] len_i + * array of size (batch_size) providing lengths for each individual row + * only radix select-k supported */ template void select_k(raft::resources const& handle, @@ -240,8 +243,9 @@ void select_k(raft::resources const& handle, T* out_val, IdxT* out_idx, bool select_min, - bool sorted = false, - SelectAlgo algo = SelectAlgo::kAuto) + bool sorted = false, + SelectAlgo algo = SelectAlgo::kAuto, + const IdxT* len_i = nullptr) { common::nvtx::range fun_scope( "matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k); @@ -262,9 +266,8 @@ void select_k(raft::resources const& handle, out_val, out_idx, select_min, - true // fused_last_filter - ); - + true, // fused_last_filter + len_i); } else { bool fused_last_filter = algo == SelectAlgo::kRadix11bits; detail::select::radix::select_k(handle, @@ -276,7 +279,8 @@ void select_k(raft::resources const& handle, out_val, out_idx, select_min, - fused_last_filter); + fused_last_filter, + len_i); } if (sorted) { auto offsets = make_device_mdarray( diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index 82983b7cd2..36a346fda3 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -557,6 +557,7 @@ RAFT_KERNEL radix_kernel(const T* in, Counter* counters, IdxT* histograms, const IdxT len, + const IdxT* len_i, const IdxT k, const bool select_min, const int pass) @@ -598,6 +599,14 @@ RAFT_KERNEL radix_kernel(const T* in, in_buf += batch_id * buf_len; in_idx_buf += batch_id * buf_len; } + + // in case we have individual len for each query defined we want to make sure + // that we only iterate valid elements. + if (len_i != nullptr) { + const IdxT max_len = max(len_i[batch_id], k); + if (max_len < previous_len) previous_len = max_len; + } + // "current_len > buf_len" means current pass will skip writing buffer if (pass == 0 || current_len > buf_len) { out_buf = nullptr; @@ -829,6 +838,7 @@ void radix_topk(const T* in, IdxT* out_idx, bool select_min, bool fused_last_filter, + const IdxT* len_i, unsigned grid_dim, int sm_cnt, rmm::cuda_stream_view stream, @@ -868,6 +878,7 @@ void radix_topk(const T* in, const IdxT* chunk_in_idx = in_idx ? (in_idx + offset * len) : nullptr; T* chunk_out = out + offset * k; IdxT* chunk_out_idx = out_idx + offset * k; + const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr; const T* in_buf = nullptr; const IdxT* in_idx_buf = nullptr; @@ -905,6 +916,7 @@ void radix_topk(const T* in, counters.data(), histograms.data(), len, + chunk_len_i, k, select_min, pass); @@ -1007,6 +1019,7 @@ template RAFT_KERNEL radix_topk_one_block_kernel(const T* in, const IdxT* in_idx, const IdxT len, + const IdxT* len_i, const IdxT k, T* out, IdxT* out_idx, @@ -1057,6 +1070,13 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in, out_idx_buf = nullptr; } + // in case we have individual len for each query defined we want to make sure + // that we only iterate valid elements. + if (len_i != nullptr) { + const IdxT max_len = max(len_i[batch_id], k); + if (max_len < previous_len) previous_len = max_len; + } + filter_and_histogram_for_one_block(in_buf, in_idx_buf, out_buf, @@ -1106,6 +1126,7 @@ void radix_topk_one_block(const T* in, T* out, IdxT* out_idx, bool select_min, + const IdxT* len_i, int sm_cnt, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) @@ -1121,10 +1142,12 @@ void radix_topk_one_block(const T* in, max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)), stream, mr); for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { - int chunk_size = std::min(max_chunk_size, batch_size - offset); + int chunk_size = std::min(max_chunk_size, batch_size - offset); + const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr; kernel<<>>(in + offset * len, in_idx ? (in_idx + offset * len) : nullptr, len, + chunk_len_i, k, out + offset * k, out_idx + offset * k, @@ -1188,6 +1211,8 @@ void radix_topk_one_block(const T* in, * blocks is called. The later case is preferable when leading bits of input data are almost the * same. That is, when the value range of input data is narrow. In such case, there could be a * large number of inputs for the last filter, hence using multiple thread blocks is beneficial. + * @param len_i + * optional array of size (batch_size) providing lengths for each individual row */ template void select_k(raft::resources const& res, @@ -1199,7 +1224,8 @@ void select_k(raft::resources const& res, T* out, IdxT* out_idx, bool select_min, - bool fused_last_filter) + bool fused_last_filter, + const IdxT* len_i) { auto stream = resource::get_cuda_stream(res); auto mr = resource::get_workspace_resource(res); @@ -1223,13 +1249,13 @@ void select_k(raft::resources const& res, if (len <= BlockSize * items_per_thread) { impl::radix_topk_one_block( - in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr); + in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr); } else { unsigned grid_dim = impl::calc_grid_dim(batch_size, len, sm_cnt); if (grid_dim == 1) { impl::radix_topk_one_block( - in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr); + in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr); } else { impl::radix_topk(in, in_idx, @@ -1240,6 +1266,7 @@ void select_k(raft::resources const& res, out_idx, select_min, fused_last_filter, + len_i, grid_dim, sm_cnt, stream, diff --git a/cpp/include/raft/neighbors/detail/ivf_common.cuh b/cpp/include/raft/neighbors/detail/ivf_common.cuh index ef7ae7c804..df0319e181 100644 --- a/cpp/include/raft/neighbors/detail/ivf_common.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_common.cuh @@ -147,11 +147,11 @@ __device__ inline auto find_chunk_ix(uint32_t& sample_ix, // NOLINT return ix_min; } -template +template __launch_bounds__(BlockDim) RAFT_KERNEL - postprocess_neighbors_kernel(IdxT1* neighbors_out, // [n_queries, topk] - const IdxT2* neighbors_in, // [n_queries, topk] - const IdxT1* const* db_indices, // [n_clusters][..] + postprocess_neighbors_kernel(IdxT* neighbors_out, // [n_queries, topk] + const uint32_t* neighbors_in, // [n_queries, topk] + const IdxT* const* db_indices, // [n_clusters][..] const uint32_t* clusters_to_probe, // [n_queries, n_probes] const uint32_t* chunk_indices, // [n_queries, n_probes] uint32_t n_queries, @@ -170,7 +170,7 @@ __launch_bounds__(BlockDim) RAFT_KERNEL const uint32_t chunk_ix = find_chunk_ix(data_ix, n_probes, chunk_indices); const bool valid = chunk_ix < n_probes; neighbors_out[k] = - valid ? db_indices[clusters_to_probe[chunk_ix]][data_ix] : kOutOfBoundsRecord; + valid ? db_indices[clusters_to_probe[chunk_ix]][data_ix] : kOutOfBoundsRecord; } /** @@ -180,10 +180,10 @@ __launch_bounds__(BlockDim) RAFT_KERNEL * probed clusters / defined by the `chunk_indices`. * We assume the searched sample sizes (for a single query) fit into `uint32_t`. */ -template -void postprocess_neighbors(IdxT1* neighbors_out, // [n_queries, topk] - const IdxT2* neighbors_in, // [n_queries, topk] - const IdxT1* const* db_indices, // [n_clusters][..] +template +void postprocess_neighbors(IdxT* neighbors_out, // [n_queries, topk] + const uint32_t* neighbors_in, // [n_queries, topk] + const IdxT* const* db_indices, // [n_clusters][..] const uint32_t* clusters_to_probe, // [n_queries, n_probes] const uint32_t* chunk_indices, // [n_queries, n_probes] uint32_t n_queries, @@ -193,7 +193,7 @@ void postprocess_neighbors(IdxT1* neighbors_out, // [n_queries, to { constexpr int kPNThreads = 256; const int pn_blocks = raft::div_rounding_up_unsafe(n_queries * topk, kPNThreads); - postprocess_neighbors_kernel + postprocess_neighbors_kernel <<>>(neighbors_out, neighbors_in, db_indices, diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh index 7c2d1d2157..140a9f17c8 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh @@ -45,7 +45,7 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& i const uint32_t* chunk_indices, const bool select_min, IvfSampleFilterT sample_filter, - IdxT* neighbors, + uint32_t* neighbors, float* distances, uint32_t& grid_dim_x, rmm::cuda_stream_view stream) RAFT_EXPLICIT; @@ -70,7 +70,7 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& i const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ - IdxT* neighbors, \ + uint32_t* neighbors, \ float* distances, \ uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh index 6fc528e26b..9cd8b70148 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh @@ -690,7 +690,6 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) const uint32_t query_smem_elems, const T* query, const uint32_t* coarse_index, - const IdxT* const* list_indices_ptrs, const T* const* list_data_ptrs, const uint32_t* list_sizes, const uint32_t queries_offset, @@ -700,7 +699,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) const uint32_t* chunk_indices, const uint32_t dim, IvfSampleFilterT sample_filter, - IdxT* neighbors, + uint32_t* neighbors, float* distances) { extern __shared__ __align__(256) uint8_t interleaved_scan_kernel_smem[]; @@ -719,8 +718,8 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) distances += query_id * k * gridDim.x + blockIdx.x * k; } else { distances += query_id * uint64_t(max_samples); - chunk_indices += (n_probes * query_id); } + chunk_indices += (n_probes * query_id); coarse_index += query_id * n_probes; } @@ -728,7 +727,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) copy_vectorized(query_shared, query, std::min(dim, query_smem_elems)); __syncthreads(); - using local_topk_t = block_sort_t; + using local_topk_t = block_sort_t; local_topk_t queue(k); { using align_warp = Pow2; @@ -752,11 +751,9 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) align_warp::div(list_length + align_warp::Mask); // ceildiv by power of 2 uint32_t sample_offset = 0; - if constexpr (!kManageLocalTopK) { - if (probe_id > 0) { sample_offset = chunk_indices[probe_id - 1]; } - assert(list_length == chunk_indices[probe_id] - sample_offset); - assert(sample_offset + list_length <= max_samples); - } + if (probe_id > 0) { sample_offset = chunk_indices[probe_id - 1]; } + assert(list_length == chunk_indices[probe_id] - sample_offset); + assert(sample_offset + list_length <= max_samples); constexpr int kUnroll = WarpSize / Veclen; constexpr uint32_t kNumWarps = kThreadsPerBlock / WarpSize; @@ -806,8 +803,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) // Enqueue one element per thread const float val = valid ? static_cast(dist) : local_topk_t::queue_t::kDummy; if constexpr (kManageLocalTopK) { - const size_t idx = valid ? static_cast(list_indices_ptrs[list_id][vec_id]) : 0; - queue.add(val, idx); + queue.add(val, sample_offset + vec_id); } else { if (vec_id < list_length) distances[sample_offset + vec_id] = val; } @@ -873,7 +869,7 @@ void launch_kernel(Lambda lambda, const uint32_t max_samples, const uint32_t* chunk_indices, IvfSampleFilterT sample_filter, - IdxT* neighbors, + uint32_t* neighbors, float* distances, uint32_t& grid_dim_x, rmm::cuda_stream_view stream) @@ -927,7 +923,6 @@ void launch_kernel(Lambda lambda, query_smem_elems, queries, coarse_index, - index.inds_ptrs().data_handle(), index.data_ptrs().data_handle(), index.list_sizes().data_handle(), queries_offset + query_offset, @@ -945,8 +940,8 @@ void launch_kernel(Lambda lambda, distances += grid_dim_y * grid_dim_x * k; } else { distances += grid_dim_y * max_samples; - chunk_indices += grid_dim_y * n_probes; } + chunk_indices += grid_dim_y * n_probes; coarse_index += grid_dim_y * n_probes; } } @@ -1161,7 +1156,7 @@ void ivfflat_interleaved_scan(const index& index, const uint32_t* chunk_indices, const bool select_min, IvfSampleFilterT sample_filter, - IdxT* neighbors, + uint32_t* neighbors, float* distances, uint32_t& grid_dim_x, rmm::cuda_stream_view stream) diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh index 98bdeda42f..441fb76b2f 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh @@ -67,13 +67,16 @@ void search_impl(raft::resources const& handle, // Optional structures if postprocessing is required // The topk distance value of candidate vectors from each cluster(list) rmm::device_uvector distances_tmp_dev(0, stream, search_mr); - // The topk index of candidate vectors from each cluster(list) - rmm::device_uvector indices_tmp_dev(0, stream, search_mr); // Number of samples for each query rmm::device_uvector num_samples(0, stream, search_mr); // Offsets per probe for each query rmm::device_uvector chunk_index(0, stream, search_mr); + // The topk index of candidate vectors from each cluster(list), local index offset + // also we might need additional storage for select_k + rmm::device_uvector indices_tmp_dev(0, stream, search_mr); + rmm::device_uvector neighbors_uint32_buf(0, stream, search_mr); + size_t float_query_size; if constexpr (std::is_integral_v) { float_query_size = n_queries * index.dim(); @@ -175,23 +178,29 @@ void search_impl(raft::resources const& handle, grid_dim_x = 1; } + num_samples.resize(n_queries, stream); + chunk_index.resize(n_queries_probes, stream); + + ivf::detail::calc_chunk_indices::configure(n_probes, n_queries)(index.list_sizes().data_handle(), + coarse_indices_dev.data(), + chunk_index.data(), + num_samples.data(), + stream); + auto distances_dev_ptr = distances; - auto indices_dev_ptr = neighbors; + + uint32_t* neighbors_uint32 = nullptr; + if constexpr (sizeof(IdxT) == sizeof(uint32_t)) { + neighbors_uint32 = reinterpret_cast(neighbors); + } else { + neighbors_uint32_buf.resize(std::size_t(n_queries) * std::size_t(k), stream); + neighbors_uint32 = neighbors_uint32_buf.data(); + } + + uint32_t* indices_dev_ptr = nullptr; bool manage_local_topk = is_local_topk_feasible(k); if (!manage_local_topk || grid_dim_x > 1) { - if (!manage_local_topk) { - num_samples.resize(n_queries, stream); - chunk_index.resize(n_queries_probes, stream); - - ivf::detail::calc_chunk_indices::configure(n_probes, n_queries)( - index.list_sizes().data_handle(), - coarse_indices_dev.data(), - chunk_index.data(), - num_samples.data(), - stream); - } - auto target_size = std::size_t(n_queries) * (manage_local_topk ? grid_dim_x * k : max_samples); distances_tmp_dev.resize(target_size, stream); @@ -199,6 +208,8 @@ void search_impl(raft::resources const& handle, distances_dev_ptr = distances_tmp_dev.data(); indices_dev_ptr = indices_tmp_dev.data(); + } else { + indices_dev_ptr = neighbors_uint32; } ivfflat_interleaved_scan::value_t, IdxT, IvfSampleFilterT>( @@ -224,31 +235,33 @@ void search_impl(raft::resources const& handle, // Merge topk values from different blocks if (!manage_local_topk || grid_dim_x > 1) { - matrix::detail::select_k(handle, - distances_tmp_dev.data(), - indices_tmp_dev.data(), - n_queries, - manage_local_topk ? (k * grid_dim_x) : max_samples, - k, - distances, - neighbors, - select_min); - - if (!manage_local_topk) { - // post process distances && neighbor IDs - ivf::detail::postprocess_distances( - distances, distances, index.metric(), n_queries, k, 1.0, false, stream); - ivf::detail::postprocess_neighbors(neighbors, - neighbors, - index.inds_ptrs().data_handle(), - coarse_indices_dev.data(), - chunk_index.data(), - n_queries, - n_probes, - k, - stream); - } + matrix::detail::select_k(handle, + distances_tmp_dev.data(), + indices_tmp_dev.data(), + n_queries, + manage_local_topk ? (k * grid_dim_x) : max_samples, + k, + distances, + neighbors_uint32, + select_min, + false, + matrix::SelectAlgo::kAuto, + manage_local_topk ? nullptr : num_samples.data()); + } + if (!manage_local_topk) { + // post process distances && neighbor IDs + ivf::detail::postprocess_distances( + distances, distances, index.metric(), n_queries, k, 1.0, false, stream); } + ivf::detail::postprocess_neighbors(neighbors, + neighbors_uint32, + index.inds_ptrs().data_handle(), + coarse_indices_dev.data(), + chunk_index.data(), + n_queries, + n_probes, + k, + stream); } /** See raft::neighbors::ivf_flat::search docs */ diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh index d445f909e5..4c5da38092 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh @@ -447,7 +447,10 @@ void ivfpq_search_worker(raft::resources const& handle, topK, topk_dists.data(), neighbors_uint32, - true); + true, + false, + matrix::SelectAlgo::kAuto, + manage_local_topk ? nullptr : num_samples.data()); // Postprocessing ivf::detail::postprocess_distances( diff --git a/cpp/include/raft/neighbors/detail/refine_device.cuh b/cpp/include/raft/neighbors/detail/refine_device.cuh index e76e52657b..bdc29ca121 100644 --- a/cpp/include/raft/neighbors/detail/refine_device.cuh +++ b/cpp/include/raft/neighbors/detail/refine_device.cuh @@ -88,6 +88,27 @@ void refine_device(raft::resources const& handle, n_queries, n_candidates); uint32_t grid_dim_x = 1; + + // the neighbor ids will be computed in uint32_t as offset + rmm::device_uvector neighbors_uint32_buf(0, resource::get_cuda_stream(handle)); + // Offsets per probe for each query [n_queries] as n_probes = 1 + rmm::device_uvector chunk_index(n_queries, resource::get_cuda_stream(handle)); + + // we know that each cluster has exactly n_candidates entries + thrust::fill(resource::get_thrust_policy(handle), + chunk_index.data(), + chunk_index.data() + n_queries, + uint32_t(n_candidates)); + + uint32_t* neighbors_uint32 = nullptr; + if constexpr (sizeof(idx_t) == sizeof(uint32_t)) { + neighbors_uint32 = reinterpret_cast(indices.data_handle()); + } else { + neighbors_uint32_buf.resize(std::size_t(n_queries) * std::size_t(k), + resource::get_cuda_stream(handle)); + neighbors_uint32 = neighbors_uint32_buf.data(); + } + raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< data_t, typename raft::spatial::knn::detail::utils::config::value_t, @@ -100,13 +121,24 @@ void refine_device(raft::resources const& handle, 1, k, 0, - nullptr, + chunk_index.data(), raft::distance::is_min_close(metric), raft::neighbors::filtering::none_ivf_sample_filter(), - indices.data_handle(), + neighbors_uint32, distances.data_handle(), grid_dim_x, resource::get_cuda_stream(handle)); + + // postprocessing -- neighbors from position to actual id + ivf::detail::postprocess_neighbors(indices.data_handle(), + neighbors_uint32, + refinement_index.inds_ptrs().data_handle(), + fake_coarse_idx.data(), + chunk_index.data(), + n_queries, + 1, + k, + resource::get_cuda_stream(handle)); } } // namespace raft::neighbors::detail diff --git a/cpp/src/matrix/detail/select_k_double_int64_t.cu b/cpp/src/matrix/detail/select_k_double_int64_t.cu index e32b4ef6f0..bf234aacbf 100644 --- a/cpp/src/matrix/detail/select_k_double_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_double_int64_t.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(double, int64_t); diff --git a/cpp/src/matrix/detail/select_k_double_uint32_t.cu b/cpp/src/matrix/detail/select_k_double_uint32_t.cu index 21c954ca46..7f0511a76a 100644 --- a/cpp/src/matrix/detail/select_k_double_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_double_uint32_t.cu @@ -29,7 +29,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(double, uint32_t); diff --git a/cpp/src/matrix/detail/select_k_float_int32.cu b/cpp/src/matrix/detail/select_k_float_int32.cu index 7f163a0b0d..e68b1e32df 100644 --- a/cpp/src/matrix/detail/select_k_float_int32.cu +++ b/cpp/src/matrix/detail/select_k_float_int32.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(float, int); diff --git a/cpp/src/matrix/detail/select_k_float_int64_t.cu b/cpp/src/matrix/detail/select_k_float_int64_t.cu index 87b6525356..5aa40d8c9d 100644 --- a/cpp/src/matrix/detail/select_k_float_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_float_int64_t.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(float, int64_t); diff --git a/cpp/src/matrix/detail/select_k_float_uint32_t.cu b/cpp/src/matrix/detail/select_k_float_uint32_t.cu index e698f811d8..9aba147edf 100644 --- a/cpp/src/matrix/detail/select_k_float_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_float_uint32_t.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(float, uint32_t); diff --git a/cpp/src/matrix/detail/select_k_half_int64_t.cu b/cpp/src/matrix/detail/select_k_half_int64_t.cu index 0eee20b1fa..bc513e4aeb 100644 --- a/cpp/src/matrix/detail/select_k_half_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_half_int64_t.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(__half, int64_t); diff --git a/cpp/src/matrix/detail/select_k_half_uint32_t.cu b/cpp/src/matrix/detail/select_k_half_uint32_t.cu index f4e6bae21f..e46c7d46bb 100644 --- a/cpp/src/matrix/detail/select_k_half_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_half_uint32_t.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(__half, uint32_t); diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu index def33e493e..5ac820e0dd 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu @@ -33,7 +33,7 @@ const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ - IdxT* neighbors, \ + uint32_t* neighbors, \ float* distances, \ uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu index e96600ee02..4d847cdeb1 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu @@ -35,7 +35,7 @@ const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ - IdxT* neighbors, \ + uint32_t* neighbors, \ float* distances, \ uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu index 13c9d2e283..8a0e8f0118 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu @@ -33,7 +33,7 @@ const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ - IdxT* neighbors, \ + uint32_t* neighbors, \ float* distances, \ uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu index 51f02343fc..7cad992e2b 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu @@ -33,7 +33,7 @@ const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ - IdxT* neighbors, \ + uint32_t* neighbors, \ float* distances, \ uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index a111de0762..7278f71a24 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -549,6 +549,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { EXPECT_FALSE(unacceptable_node); double min_recall = ps.min_recall; + // TODO(mfoerster): re-enable uniquenes test EXPECT_TRUE(eval_neighbours(indices_naive, indices_Cagra, distances_naive, @@ -556,7 +557,8 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { ps.n_queries, ps.k, 0.003, - min_recall)); + min_recall, + false)); EXPECT_TRUE(eval_distances(handle_, database.data(), search_queries.data(), @@ -668,6 +670,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { } double min_recall = ps.min_recall; + // TODO(mfoerster): re-enable uniquenes test EXPECT_TRUE(eval_neighbours(indices_naive, indices_Cagra, distances_naive, @@ -675,7 +678,8 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { ps.n_queries, ps.k, 0.003, - min_recall)); + min_recall, + false)); EXPECT_TRUE(eval_distances(handle_, database.data(), search_queries.data(), diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh index afd083d512..6be2ac7fc7 100644 --- a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -35,6 +35,7 @@ #include #include +#include namespace raft::neighbors { @@ -153,13 +154,40 @@ auto calc_recall(const std::vector& expected_idx, static_cast(match_count) / static_cast(total_count), match_count, total_count); } +/** check uniqueness of indices + */ +template +auto check_unique_indices(const std::vector& actual_idx, size_t rows, size_t cols) +{ + size_t max_count; + std::set unique_indices; + for (size_t i = 0; i < rows; ++i) { + unique_indices.clear(); + max_count = 0; + for (size_t k = 0; k < cols; ++k) { + size_t idx_k = i * cols + k; // row major assumption! + auto act_idx = actual_idx[idx_k]; + if (act_idx == std::numeric_limits::max()) { + max_count++; + } else if (unique_indices.find(act_idx) == unique_indices.end()) { + unique_indices.insert(act_idx); + } else { + return testing::AssertionFailure() + << "Duplicated index " << act_idx << " at k " << k << " for query " << i << "! "; + } + } + } + return testing::AssertionSuccess(); +} + template auto eval_recall(const std::vector& expected_idx, const std::vector& actual_idx, size_t rows, size_t cols, double eps, - double min_recall) -> testing::AssertionResult + double min_recall, + bool test_unique = true) -> testing::AssertionResult { auto [actual_recall, match_count, total_count] = calc_recall(expected_idx, actual_idx, rows, cols); @@ -176,7 +204,10 @@ auto eval_recall(const std::vector& expected_idx, << "actual recall (" << actual_recall << ") is lower than the minimum expected recall (" << min_recall << "); eps = " << eps << ". "; } - return testing::AssertionSuccess(); + if (test_unique) + return check_unique_indices(actual_idx, rows, cols); + else + return testing::AssertionSuccess(); } /** Overload of calc_recall to account for distances @@ -224,7 +255,8 @@ auto eval_neighbours(const std::vector& expected_idx, size_t rows, size_t cols, double eps, - double min_recall) -> testing::AssertionResult + double min_recall, + bool test_unique = true) -> testing::AssertionResult { auto [actual_recall, match_count, total_count] = calc_recall(expected_idx, actual_idx, expected_dist, actual_dist, rows, cols, eps); @@ -241,7 +273,10 @@ auto eval_neighbours(const std::vector& expected_idx, << "actual recall (" << actual_recall << ") is lower than the minimum expected recall (" << min_recall << "); eps = " << eps << ". "; } - return testing::AssertionSuccess(); + if (test_unique) + return check_unique_indices(actual_idx, rows, cols); + else + return testing::AssertionSuccess(); } template From de7341e795fc880bb4fa34e4c833adfc9a54afc1 Mon Sep 17 00:00:00 2001 From: tsuki <12711693+enp1s0@users.noreply.github.com> Date: Thu, 21 Mar 2024 14:52:39 +0900 Subject: [PATCH 12/18] CAGRA-Q search (#2206) Rel: #1889 ## Limitations - Only 8-bit PQ is supported - Sub-space size is only 2 supported Authors: - tsuki (https://github.com/enp1s0) - Artem M. Chirkin (https://github.com/achirkin) - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/raft/pull/2206 --- cpp/CMakeLists.txt | 96 +++ cpp/include/raft/neighbors/dataset.hpp | 2 +- .../neighbors/detail/cagra/cagra_search.cuh | 228 ++++++-- .../detail/cagra/compute_distance.hpp | 298 ++++------ .../detail/cagra/compute_distance_vpq.cuh | 227 ++++++++ .../neighbors/detail/cagra/device_common.hpp | 8 +- .../raft/neighbors/detail/cagra/factory.cuh | 28 +- .../detail/cagra/search_multi_cta.cuh | 100 ++-- .../cagra/search_multi_cta_kernel-ext.cuh | 399 +++++++++++-- .../cagra/search_multi_cta_kernel-inl.cuh | 185 +++--- .../detail/cagra/search_multi_kernel.cuh | 386 ++++++------ .../neighbors/detail/cagra/search_plan.cuh | 8 +- .../detail/cagra/search_single_cta.cuh | 86 +-- .../cagra/search_single_cta_kernel-ext.cuh | 548 ++++++++++++++++-- .../cagra/search_single_cta_kernel-inl.cuh | 230 ++++---- .../raft/neighbors/detail/cagra/utils.hpp | 5 + .../raft/neighbors/detail/refine_host-ext.hpp | 1 + .../raft/neighbors/detail/vpq_dataset.cuh | 2 +- cpp/include/raft/neighbors/refine-ext.cuh | 21 +- .../cagra/q_search_multi_cta_00_generate.py | 84 +++ ...float_uint32_dim1024_t32_8pq_2subd_half.cu | 37 ++ ...float_uint32_dim1024_t32_8pq_4subd_half.cu | 37 ++ ...a_float_uint32_dim128_t8_8pq_2subd_half.cu | 37 ++ ...a_float_uint32_dim128_t8_8pq_4subd_half.cu | 37 ++ ..._float_uint32_dim256_t16_8pq_2subd_half.cu | 37 ++ ..._float_uint32_dim256_t16_8pq_4subd_half.cu | 37 ++ ..._float_uint32_dim512_t32_8pq_2subd_half.cu | 37 ++ ..._float_uint32_dim512_t32_8pq_4subd_half.cu | 37 ++ ...float_uint64_dim1024_t32_8pq_2subd_half.cu | 37 ++ ...float_uint64_dim1024_t32_8pq_4subd_half.cu | 37 ++ ...a_float_uint64_dim128_t8_8pq_2subd_half.cu | 37 ++ ...a_float_uint64_dim128_t8_8pq_4subd_half.cu | 37 ++ ..._float_uint64_dim256_t16_8pq_2subd_half.cu | 37 ++ ..._float_uint64_dim256_t16_8pq_4subd_half.cu | 37 ++ ..._float_uint64_dim512_t32_8pq_2subd_half.cu | 37 ++ ..._float_uint64_dim512_t32_8pq_4subd_half.cu | 37 ++ ..._half_uint32_dim1024_t32_8pq_2subd_half.cu | 37 ++ ..._half_uint32_dim1024_t32_8pq_4subd_half.cu | 37 ++ ...ta_half_uint32_dim128_t8_8pq_2subd_half.cu | 37 ++ ...ta_half_uint32_dim128_t8_8pq_4subd_half.cu | 37 ++ ...a_half_uint32_dim256_t16_8pq_2subd_half.cu | 37 ++ ...a_half_uint32_dim256_t16_8pq_4subd_half.cu | 37 ++ ...a_half_uint32_dim512_t32_8pq_2subd_half.cu | 37 ++ ...a_half_uint32_dim512_t32_8pq_4subd_half.cu | 37 ++ ..._half_uint64_dim1024_t32_8pq_2subd_half.cu | 37 ++ ..._half_uint64_dim1024_t32_8pq_4subd_half.cu | 37 ++ ...ta_half_uint64_dim128_t8_8pq_2subd_half.cu | 37 ++ ...ta_half_uint64_dim128_t8_8pq_4subd_half.cu | 37 ++ ...a_half_uint64_dim256_t16_8pq_2subd_half.cu | 37 ++ ...a_half_uint64_dim256_t16_8pq_4subd_half.cu | 37 ++ ...a_half_uint64_dim512_t32_8pq_2subd_half.cu | 37 ++ ...a_half_uint64_dim512_t32_8pq_4subd_half.cu | 37 ++ ..._int8_uint32_dim1024_t32_8pq_2subd_half.cu | 37 ++ ..._int8_uint32_dim1024_t32_8pq_4subd_half.cu | 37 ++ ...ta_int8_uint32_dim128_t8_8pq_2subd_half.cu | 37 ++ ...ta_int8_uint32_dim128_t8_8pq_4subd_half.cu | 37 ++ ...a_int8_uint32_dim256_t16_8pq_2subd_half.cu | 37 ++ ...a_int8_uint32_dim256_t16_8pq_4subd_half.cu | 37 ++ ...a_int8_uint32_dim512_t32_8pq_2subd_half.cu | 37 ++ ...a_int8_uint32_dim512_t32_8pq_4subd_half.cu | 37 ++ ...uint8_uint32_dim1024_t32_8pq_2subd_half.cu | 37 ++ ...uint8_uint32_dim1024_t32_8pq_4subd_half.cu | 37 ++ ...a_uint8_uint32_dim128_t8_8pq_2subd_half.cu | 37 ++ ...a_uint8_uint32_dim128_t8_8pq_4subd_half.cu | 37 ++ ..._uint8_uint32_dim256_t16_8pq_2subd_half.cu | 37 ++ ..._uint8_uint32_dim256_t16_8pq_4subd_half.cu | 37 ++ ..._uint8_uint32_dim512_t32_8pq_2subd_half.cu | 37 ++ ..._uint8_uint32_dim512_t32_8pq_4subd_half.cu | 37 ++ .../cagra/q_search_single_cta_00_generate.py | 89 +++ ...float_uint32_dim1024_t32_8pq_2subd_half.cu | 37 ++ ...float_uint32_dim1024_t32_8pq_4subd_half.cu | 37 ++ ...a_float_uint32_dim128_t8_8pq_2subd_half.cu | 37 ++ ...a_float_uint32_dim128_t8_8pq_4subd_half.cu | 37 ++ ..._float_uint32_dim256_t16_8pq_2subd_half.cu | 37 ++ ..._float_uint32_dim256_t16_8pq_4subd_half.cu | 37 ++ ..._float_uint32_dim512_t32_8pq_2subd_half.cu | 37 ++ ..._float_uint32_dim512_t32_8pq_4subd_half.cu | 37 ++ ...float_uint64_dim1024_t32_8pq_2subd_half.cu | 37 ++ ...float_uint64_dim1024_t32_8pq_4subd_half.cu | 37 ++ ...a_float_uint64_dim128_t8_8pq_2subd_half.cu | 37 ++ ...a_float_uint64_dim128_t8_8pq_4subd_half.cu | 37 ++ ..._float_uint64_dim256_t16_8pq_2subd_half.cu | 37 ++ ..._float_uint64_dim256_t16_8pq_4subd_half.cu | 37 ++ ..._float_uint64_dim512_t32_8pq_2subd_half.cu | 37 ++ ..._float_uint64_dim512_t32_8pq_4subd_half.cu | 37 ++ ..._half_uint32_dim1024_t32_8pq_2subd_half.cu | 37 ++ ..._half_uint32_dim1024_t32_8pq_4subd_half.cu | 37 ++ ...ta_half_uint32_dim128_t8_8pq_2subd_half.cu | 37 ++ ...ta_half_uint32_dim128_t8_8pq_4subd_half.cu | 37 ++ ...a_half_uint32_dim256_t16_8pq_2subd_half.cu | 37 ++ ...a_half_uint32_dim256_t16_8pq_4subd_half.cu | 37 ++ ...a_half_uint32_dim512_t32_8pq_2subd_half.cu | 37 ++ ...a_half_uint32_dim512_t32_8pq_4subd_half.cu | 37 ++ ..._half_uint64_dim1024_t32_8pq_2subd_half.cu | 37 ++ ..._half_uint64_dim1024_t32_8pq_4subd_half.cu | 37 ++ ...ta_half_uint64_dim128_t8_8pq_2subd_half.cu | 37 ++ ...ta_half_uint64_dim128_t8_8pq_4subd_half.cu | 37 ++ ...a_half_uint64_dim256_t16_8pq_2subd_half.cu | 37 ++ ...a_half_uint64_dim256_t16_8pq_4subd_half.cu | 37 ++ ...a_half_uint64_dim512_t32_8pq_2subd_half.cu | 37 ++ ...a_half_uint64_dim512_t32_8pq_4subd_half.cu | 37 ++ ..._int8_uint32_dim1024_t32_8pq_2subd_half.cu | 37 ++ ..._int8_uint32_dim1024_t32_8pq_4subd_half.cu | 37 ++ ...ta_int8_uint32_dim128_t8_8pq_2subd_half.cu | 37 ++ ...ta_int8_uint32_dim128_t8_8pq_4subd_half.cu | 37 ++ ...a_int8_uint32_dim256_t16_8pq_2subd_half.cu | 37 ++ ...a_int8_uint32_dim256_t16_8pq_4subd_half.cu | 37 ++ ...a_int8_uint32_dim512_t32_8pq_2subd_half.cu | 37 ++ ...a_int8_uint32_dim512_t32_8pq_4subd_half.cu | 37 ++ ...uint8_uint32_dim1024_t32_8pq_2subd_half.cu | 37 ++ ...uint8_uint32_dim1024_t32_8pq_4subd_half.cu | 37 ++ ...a_uint8_uint32_dim128_t8_8pq_2subd_half.cu | 37 ++ ...a_uint8_uint32_dim128_t8_8pq_4subd_half.cu | 37 ++ ..._uint8_uint32_dim256_t16_8pq_2subd_half.cu | 37 ++ ..._uint8_uint32_dim256_t16_8pq_4subd_half.cu | 37 ++ ..._uint8_uint32_dim512_t32_8pq_2subd_half.cu | 37 ++ ..._uint8_uint32_dim512_t32_8pq_4subd_half.cu | 37 ++ .../detail/cagra/search_multi_cta.cuh | 51 ++ .../cagra/search_multi_cta_00_generate.py | 42 +- ...arch_multi_cta_float_uint32_dim1024_t32.cu | 45 +- ...search_multi_cta_float_uint32_dim128_t8.cu | 45 +- ...earch_multi_cta_float_uint32_dim256_t16.cu | 45 +- ...earch_multi_cta_float_uint32_dim512_t32.cu | 45 +- ...arch_multi_cta_float_uint64_dim1024_t32.cu | 45 +- ...search_multi_cta_float_uint64_dim128_t8.cu | 45 +- ...earch_multi_cta_float_uint64_dim256_t16.cu | 45 +- ...earch_multi_cta_float_uint64_dim512_t32.cu | 45 +- ...earch_multi_cta_half_uint32_dim1024_t32.cu | 43 +- .../search_multi_cta_half_uint32_dim128_t8.cu | 43 +- ...search_multi_cta_half_uint32_dim256_t16.cu | 43 +- ...search_multi_cta_half_uint32_dim512_t32.cu | 43 +- ...earch_multi_cta_half_uint64_dim1024_t32.cu | 43 +- .../search_multi_cta_half_uint64_dim128_t8.cu | 43 +- ...search_multi_cta_half_uint64_dim256_t16.cu | 43 +- ...search_multi_cta_half_uint64_dim512_t32.cu | 43 +- ...earch_multi_cta_int8_uint32_dim1024_t32.cu | 45 +- .../search_multi_cta_int8_uint32_dim128_t8.cu | 45 +- ...search_multi_cta_int8_uint32_dim256_t16.cu | 45 +- ...search_multi_cta_int8_uint32_dim512_t32.cu | 45 +- ...arch_multi_cta_uint8_uint32_dim1024_t32.cu | 45 +- ...search_multi_cta_uint8_uint32_dim128_t8.cu | 45 +- ...earch_multi_cta_uint8_uint32_dim256_t16.cu | 45 +- ...earch_multi_cta_uint8_uint32_dim512_t32.cu | 45 +- .../detail/cagra/search_single_cta.cuh | 52 ++ .../cagra/search_single_cta_00_generate.py | 43 +- ...rch_single_cta_float_uint32_dim1024_t32.cu | 48 +- ...earch_single_cta_float_uint32_dim128_t8.cu | 48 +- ...arch_single_cta_float_uint32_dim256_t16.cu | 48 +- ...arch_single_cta_float_uint32_dim512_t32.cu | 48 +- ...rch_single_cta_float_uint64_dim1024_t32.cu | 48 +- ...earch_single_cta_float_uint64_dim128_t8.cu | 48 +- ...arch_single_cta_float_uint64_dim256_t16.cu | 48 +- ...arch_single_cta_float_uint64_dim512_t32.cu | 48 +- ...arch_single_cta_half_uint32_dim1024_t32.cu | 46 +- ...search_single_cta_half_uint32_dim128_t8.cu | 46 +- ...earch_single_cta_half_uint32_dim256_t16.cu | 46 +- ...earch_single_cta_half_uint32_dim512_t32.cu | 46 +- ...arch_single_cta_half_uint64_dim1024_t32.cu | 46 +- ...search_single_cta_half_uint64_dim128_t8.cu | 46 +- ...earch_single_cta_half_uint64_dim256_t16.cu | 46 +- ...earch_single_cta_half_uint64_dim512_t32.cu | 46 +- ...arch_single_cta_int8_uint32_dim1024_t32.cu | 48 +- ...search_single_cta_int8_uint32_dim128_t8.cu | 48 +- ...earch_single_cta_int8_uint32_dim256_t16.cu | 48 +- ...earch_single_cta_int8_uint32_dim512_t32.cu | 48 +- ...rch_single_cta_uint8_uint32_dim1024_t32.cu | 48 +- ...earch_single_cta_uint8_uint32_dim128_t8.cu | 48 +- ...arch_single_cta_uint8_uint32_dim256_t16.cu | 48 +- ...arch_single_cta_uint8_uint32_dim512_t32.cu | 48 +- .../detail/refine_host_float_float.cpp | 3 +- cpp/src/neighbors/refine_float_float.cu | 30 +- cpp/test/CMakeLists.txt | 2 + .../ann_cagra/search_kernel_uint64_t.cuh | 192 +++--- cpp/test/neighbors/ann_cagra_vpq.cuh | 336 +++++++++++ .../ann_cagra_vpq/test_float_int64_t.cu | 29 + .../ann_cagra_vpq/test_float_uint32_t.cu | 28 + cpp/test/neighbors/ann_utils.cuh | 2 +- 177 files changed, 6795 insertions(+), 2798 deletions(-) create mode 100644 cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_00_generate.py create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim1024_t32_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim1024_t32_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim128_t8_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim128_t8_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim256_t16_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim256_t16_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim512_t32_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim512_t32_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim1024_t32_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim1024_t32_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim128_t8_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim128_t8_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim256_t16_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim256_t16_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim512_t32_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim512_t32_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim1024_t32_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim1024_t32_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim128_t8_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim128_t8_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim256_t16_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim256_t16_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim512_t32_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim512_t32_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim1024_t32_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim1024_t32_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim128_t8_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim128_t8_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim256_t16_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim256_t16_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim512_t32_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim512_t32_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim1024_t32_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim1024_t32_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim128_t8_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim128_t8_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim256_t16_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim256_t16_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim512_t32_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim512_t32_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim1024_t32_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim1024_t32_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim128_t8_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim128_t8_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim256_t16_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim256_t16_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim512_t32_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim512_t32_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_00_generate.py create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim1024_t32_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim1024_t32_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim128_t8_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim128_t8_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim256_t16_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim256_t16_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim512_t32_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim512_t32_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim1024_t32_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim1024_t32_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim128_t8_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim128_t8_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim256_t16_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim256_t16_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim512_t32_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim512_t32_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim1024_t32_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim1024_t32_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim128_t8_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim128_t8_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim256_t16_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim256_t16_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim512_t32_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim512_t32_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim1024_t32_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim1024_t32_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim128_t8_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim128_t8_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim256_t16_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim256_t16_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim512_t32_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim512_t32_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim1024_t32_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim1024_t32_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim128_t8_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim128_t8_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim256_t16_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim256_t16_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim512_t32_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim512_t32_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim1024_t32_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim1024_t32_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim128_t8_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim128_t8_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim256_t16_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim256_t16_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim512_t32_8pq_2subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim512_t32_8pq_4subd_half.cu create mode 100644 cpp/src/neighbors/detail/cagra/search_multi_cta.cuh create mode 100644 cpp/src/neighbors/detail/cagra/search_single_cta.cuh create mode 100755 cpp/test/neighbors/ann_cagra_vpq.cuh create mode 100644 cpp/test/neighbors/ann_cagra_vpq/test_float_int64_t.cu create mode 100644 cpp/test/neighbors/ann_cagra_vpq/test_float_uint32_t.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 6107b9325a..cbae4bfb3f 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -372,6 +372,102 @@ if(RAFT_COMPILE_LIBRARY) src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu + src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim128_t8_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim128_t8_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim256_t16_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim256_t16_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim512_t32_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim512_t32_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim1024_t32_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim1024_t32_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim128_t8_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim128_t8_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim256_t16_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim256_t16_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim512_t32_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim512_t32_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim1024_t32_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim1024_t32_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim128_t8_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim128_t8_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim256_t16_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim256_t16_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim512_t32_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim512_t32_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim1024_t32_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim1024_t32_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim128_t8_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim128_t8_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim256_t16_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim256_t16_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim512_t32_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim512_t32_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim1024_t32_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim1024_t32_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim128_t8_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim128_t8_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim256_t16_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim256_t16_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim512_t32_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim512_t32_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim1024_t32_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim1024_t32_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim128_t8_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim128_t8_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim256_t16_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim256_t16_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim512_t32_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim512_t32_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim1024_t32_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim1024_t32_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim128_t8_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim128_t8_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim256_t16_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim256_t16_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim512_t32_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim512_t32_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim1024_t32_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim1024_t32_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim128_t8_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim128_t8_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim256_t16_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim256_t16_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim512_t32_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim512_t32_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim1024_t32_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim1024_t32_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim128_t8_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim128_t8_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim256_t16_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim256_t16_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim512_t32_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim512_t32_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim1024_t32_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim1024_t32_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim128_t8_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim128_t8_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim256_t16_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim256_t16_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim512_t32_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim512_t32_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim1024_t32_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim1024_t32_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim128_t8_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim128_t8_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim256_t16_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim256_t16_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim512_t32_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim512_t32_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim1024_t32_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim1024_t32_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim128_t8_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim128_t8_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim256_t16_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim256_t16_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim512_t32_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim512_t32_8pq_4subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim1024_t32_8pq_2subd_half.cu + src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim1024_t32_8pq_4subd_half.cu src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu diff --git a/cpp/include/raft/neighbors/dataset.hpp b/cpp/include/raft/neighbors/dataset.hpp index e7a3ba97a4..a6444775f4 100644 --- a/cpp/include/raft/neighbors/dataset.hpp +++ b/cpp/include/raft/neighbors/dataset.hpp @@ -72,7 +72,7 @@ struct strided_dataset : public dataset { return static_cast(v.stride(0) > 0 ? v.stride(0) : v.extent(1)); } /** Get the view of the data. */ - [[nodiscard]] virtual auto view() const noexcept -> view_type; + [[nodiscard]] virtual auto view() const noexcept -> view_type = 0; }; template diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh index 0832e75633..d30f69ddcd 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh @@ -16,6 +16,7 @@ #pragma once +#include "compute_distance_vpq.cuh" #include "factory.cuh" #include "search_plan.cuh" #include "search_single_cta.cuh" @@ -77,46 +78,24 @@ inline return filter; } -/** - * @brief Search ANN using the constructed index. - * - * See the [build](#build) documentation for a usage example. - * - * @tparam T data element type - * @tparam IdxT type of database vector indices - * @tparam internal_IdxT during search we map IdxT to internal_IdxT, this way we do not need - * separate kernels for int/uint. - * - * @param[in] handle - * @param[in] params configure the search - * @param[in] idx ivf-pq constructed index - * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, - * k] - */ - -template -void search_main(raft::resources const& res, - search_params params, - const index& index, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - CagraSampleFilterT sample_filter = CagraSampleFilterT()) +template +void search_main_core( + raft::resources const& res, + search_params params, + DatasetDescriptorT dataset_desc, + raft::device_matrix_view graph, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + CagraSampleFilterT sample_filter = CagraSampleFilterT()) { RAFT_LOG_DEBUG("# dataset size = %lu, dim = %lu\n", - static_cast(index.dataset().extent(0)), - static_cast(index.dataset().extent(1))); + static_cast(index.data().n_rows()), + static_cast(index.data().dim())); RAFT_LOG_DEBUG("# query size = %lu, dim = %lu\n", static_cast(queries.extent(0)), static_cast(queries.extent(1))); - RAFT_EXPECTS(queries.extent(1) == index.dim(), "Queries and index dim must match"); + RAFT_EXPECTS(queries.extent(1) == dataset_desc.dim, "Queries and index dim must match"); const uint32_t topk = neighbors.extent(1); cudaDeviceProp deviceProp = resource::get_device_properties(res); @@ -125,12 +104,15 @@ void search_main(raft::resources const& res, } common::nvtx::range fun_scope( - "cagra::search(max_queries = %u, k = %u, dim = %zu)", params.max_queries, topk, index.dim()); + "cagra::search(max_queries = %u, k = %u, dim = %zu)", + params.max_queries, + topk, + dataset_desc.dim); using CagraSampleFilterT_s = typename CagraSampleFilterT_Selector::type; - std::unique_ptr> plan = - factory::create( - res, params, index.dim(), index.graph_degree(), topk); + std::unique_ptr> plan = + factory::create( + res, params, dataset_desc.dim, graph.extent(1), topk); plan->check(topk); @@ -140,30 +122,22 @@ void search_main(raft::resources const& res, for (unsigned qid = 0; qid < queries.extent(0); qid += max_queries) { const uint32_t n_queries = std::min(max_queries, queries.extent(0) - qid); - internal_IdxT* _topk_indices_ptr = - reinterpret_cast(neighbors.data_handle()) + (topk * qid); - DistanceT* _topk_distances_ptr = distances.data_handle() + (topk * qid); + auto _topk_indices_ptr = + reinterpret_cast(neighbors.data_handle()) + + (topk * qid); + auto _topk_distances_ptr = distances.data_handle() + (topk * qid); // todo(tfeher): one could keep distances optional and pass nullptr - const T* _query_ptr = queries.data_handle() + (query_dim * qid); - const internal_IdxT* _seed_ptr = + const auto* _query_ptr = queries.data_handle() + (query_dim * qid); + const auto* _seed_ptr = plan->num_seeds > 0 - ? reinterpret_cast(plan->dev_seed.data()) + (plan->num_seeds * qid) + ? reinterpret_cast(plan->dev_seed.data()) + + (plan->num_seeds * qid) : nullptr; uint32_t* _num_executed_iterations = nullptr; - auto dataset_internal = - make_device_strided_matrix_view(index.dataset().data_handle(), - index.dataset().extent(0), - index.dataset().extent(1), - index.dataset().stride(0)); - auto graph_internal = raft::make_device_matrix_view( - reinterpret_cast(index.graph().data_handle()), - index.graph().extent(0), - index.graph().extent(1)); - (*plan)(res, - dataset_internal, - graph_internal, + dataset_desc, + graph, _topk_indices_ptr, _topk_distances_ptr, _query_ptr, @@ -173,6 +147,146 @@ void search_main(raft::resources const& res, topk, set_offset(sample_filter, qid)); } +} + +template +void launch_vpq_search_main_core( + raft::resources const& res, + const vpq_dataset* vpq_dset, + search_params params, + raft::device_matrix_view graph, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + CagraSampleFilterT sample_filter) +{ + RAFT_EXPECTS(vpq_dset->pq_bits() == 8, "Only pq_bits = 8 is supported for now"); + RAFT_EXPECTS(vpq_dset->pq_len() == 2, "Only pq_len 2 is supported for now"); + RAFT_EXPECTS(vpq_dset->dim() % vpq_dset->pq_dim() == 0, + "dim must be a multiple of pq_dim at the moment"); + + const float vq_scale = 1.0f; + const float pq_scale = 1.0f; + + if (vpq_dset->pq_bits() == 8) { + if (vpq_dset->pq_len() == 2) { + using dataset_desc_t = cagra_q_dataset_descriptor_t; + dataset_desc_t dataset_desc(vpq_dset->data.data_handle(), + vpq_dset->encoded_row_length(), + vpq_dset->pq_dim(), + vpq_dset->vq_code_book.data_handle(), + vq_scale, + vpq_dset->pq_code_book.data_handle(), + pq_scale, + size_t(vpq_dset->n_rows()), + vpq_dset->dim()); + search_main_core( + res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter); + } else if (vpq_dset->pq_len() == 4) { + using dataset_desc_t = cagra_q_dataset_descriptor_t; + dataset_desc_t dataset_desc(vpq_dset->data.data_handle(), + vpq_dset->encoded_row_length(), + vpq_dset->pq_dim(), + vpq_dset->vq_code_book.data_handle(), + vq_scale, + vpq_dset->pq_code_book.data_handle(), + pq_scale, + size_t(vpq_dset->n_rows()), + vpq_dset->dim()); + search_main_core( + res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter); + } else { + RAFT_FAIL("Subspace dimension must be 2 or 4"); + } + } else { + RAFT_FAIL("Only 8-bit PQ is supported now"); + } +} + +/** + * @brief Search ANN using the constructed index. + * + * See the [build](#build) documentation for a usage example. + * + * @tparam T data element type + * @tparam IdxT type of database vector indices + * @tparam internal_IdxT during search we map IdxT to internal_IdxT, this way we do not need + * separate kernels for int/uint. + * + * @param[in] handle + * @param[in] params configure the search + * @param[in] idx ivf-pq constructed index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + */ +template +void search_main(raft::resources const& res, + search_params params, + const index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + CagraSampleFilterT sample_filter = CagraSampleFilterT()) +{ + const auto& graph = index.graph(); + auto graph_internal = raft::make_device_matrix_view( + reinterpret_cast(graph.data_handle()), graph.extent(0), graph.extent(1)); + + // n_rows has the same type as the dataset index (the array extents type) + using ds_idx_type = decltype(index.data().n_rows()); + // Dispatch search parameters based on the dataset kind. + if (auto* strided_dset = dynamic_cast*>(&index.data()); + strided_dset != nullptr) { + // Set TEAM_SIZE and DATASET_BLOCK_SIZE to zero tentatively since these parameters cannot be + // determined here. They are set just before kernel launch. + using dataset_desc_t = standard_dataset_descriptor_t; + // Search using a plain (strided) row-major dataset + const dataset_desc_t dataset_desc(strided_dset->view().data_handle(), + strided_dset->n_rows(), + strided_dset->dim(), + strided_dset->stride()); + + search_main_core( + res, params, dataset_desc, graph_internal, queries, neighbors, distances, sample_filter); + } else if (auto* vpq_dset = dynamic_cast*>(&index.data()); + vpq_dset != nullptr) { + // Search using a compressed dataset + RAFT_FAIL("FP32 VPQ dataset support is coming soon"); + } else if (auto* vpq_dset = dynamic_cast*>(&index.data()); + vpq_dset != nullptr) { + launch_vpq_search_main_core( + res, vpq_dset, params, graph_internal, queries, neighbors, distances, sample_filter); + } else if (auto* empty_dset = dynamic_cast*>(&index.data()); + empty_dset != nullptr) { + // Forgot to add a dataset. + RAFT_FAIL( + "Attempted to search without a dataset. Please call index.update_dataset(...) first."); + } else { + // This is a logic error. + RAFT_FAIL("Unrecognized dataset format"); + } static_assert(std::is_same_v, "only float distances are supported at the moment"); diff --git a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp index 3732dcf3fe..49e14be73d 100644 --- a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp +++ b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp @@ -20,6 +20,7 @@ #include "utils.hpp" #include +#include #include @@ -36,152 +37,16 @@ _RAFT_DEVICE constexpr unsigned get_vlen() return utils::size_of() / utils::size_of(); } -template -struct data_load_t { - union { - LOAD_T load; - DATA_T data[VLEN]; - }; -}; - -template -struct distance_op; -template -struct distance_op { - const float* const query_buffer; - __device__ distance_op(const float* const query_buffer) : query_buffer(query_buffer) {} - - __device__ DISTANCE_T operator()(const DATA_T* const dataset_ptr, - const std::uint32_t dataset_dim, - const bool valid) - { - const unsigned lane_id = threadIdx.x % TEAM_SIZE; - constexpr unsigned vlen = get_vlen(); - constexpr unsigned reg_nelem = - (DATASET_BLOCK_DIM + (TEAM_SIZE * vlen) - 1) / (TEAM_SIZE * vlen); - data_load_t dl_buff[reg_nelem]; - - DISTANCE_T norm2 = 0; - if (valid) { - for (uint32_t elem_offset = 0; elem_offset < dataset_dim; elem_offset += DATASET_BLOCK_DIM) { -#pragma unroll - for (uint32_t e = 0; e < reg_nelem; e++) { - const uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen + elem_offset; - if (k >= dataset_dim) break; - dl_buff[e].load = *reinterpret_cast(dataset_ptr + k); - } -#pragma unroll - for (uint32_t e = 0; e < reg_nelem; e++) { - const uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen + elem_offset; - if (k >= dataset_dim) break; -#pragma unroll - for (uint32_t v = 0; v < vlen; v++) { - const uint32_t kv = k + v; - // if (kv >= dataset_dim) break; - DISTANCE_T diff = query_buffer[device::swizzling(kv)]; - diff -= spatial::knn::detail::utils::mapping{}(dl_buff[e].data[v]); - norm2 += diff * diff; - } - } - } - } - for (uint32_t offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) { - norm2 += __shfl_xor_sync(0xffffffff, norm2, offset); - } - return norm2; - } -}; -template -struct distance_op { - static constexpr unsigned N_FRAGS = (DATASET_BLOCK_DIM + TEAM_SIZE - 1) / TEAM_SIZE; - float query_frags[N_FRAGS]; - - __device__ distance_op(const float* const query_buffer) - { - constexpr unsigned vlen = get_vlen(); - constexpr unsigned reg_nelem = - (DATASET_BLOCK_DIM + (TEAM_SIZE * vlen) - 1) / (TEAM_SIZE * vlen); - const std::uint32_t lane_id = threadIdx.x % TEAM_SIZE; - // Pre-load query vectors into registers when register usage is not too large. -#pragma unroll - for (unsigned e = 0; e < reg_nelem; e++) { - const unsigned k = (lane_id + (TEAM_SIZE * e)) * vlen; - // if (k >= dataset_dim) break; -#pragma unroll - for (unsigned v = 0; v < vlen; v++) { - const unsigned kv = k + v; - const unsigned ev = (vlen * e) + v; - query_frags[ev] = query_buffer[device::swizzling(kv)]; - } - } - } - - __device__ DISTANCE_T operator()(const DATA_T* const dataset_ptr, - const std::uint32_t dataset_dim, - const bool valid) - { - const unsigned lane_id = threadIdx.x % TEAM_SIZE; - constexpr unsigned vlen = get_vlen(); - constexpr unsigned reg_nelem = - (DATASET_BLOCK_DIM + (TEAM_SIZE * vlen) - 1) / (TEAM_SIZE * vlen); - data_load_t dl_buff[reg_nelem]; - - DISTANCE_T norm2 = 0; - if (valid) { -#pragma unroll - for (unsigned e = 0; e < reg_nelem; e++) { - const unsigned k = (lane_id + (TEAM_SIZE * e)) * vlen; - if (k >= dataset_dim) break; - dl_buff[e].load = *reinterpret_cast(dataset_ptr + k); - } -#pragma unroll - for (unsigned e = 0; e < reg_nelem; e++) { - const unsigned k = (lane_id + (TEAM_SIZE * e)) * vlen; - if (k >= dataset_dim) break; -#pragma unroll - for (unsigned v = 0; v < vlen; v++) { - DISTANCE_T diff; - const unsigned ev = (vlen * e) + v; - diff = query_frags[ev]; - diff -= spatial::knn::detail::utils::mapping{}(dl_buff[e].data[v]); - norm2 += diff * diff; - } - } - } - for (uint32_t offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) { - norm2 += __shfl_xor_sync(0xffffffff, norm2, offset); - } - return norm2; - } -}; - template _RAFT_DEVICE void compute_distance_to_random_nodes( INDEX_T* const result_indices_ptr, // [num_pickup] DISTANCE_T* const result_distances_ptr, // [num_pickup] - const float* const query_buffer, - const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim] - const std::size_t dataset_dim, - const std::size_t dataset_size, - const std::size_t dataset_ld, + const typename DATASET_DESCRIPTOR_T::QUERY_T* const query_buffer, + const DATASET_DESCRIPTOR_T& dataset_desc, const std::size_t num_pickup, const unsigned num_distilation, const uint64_t rand_xor_mask, @@ -195,9 +60,6 @@ _RAFT_DEVICE void compute_distance_to_random_nodes( uint32_t max_i = num_pickup; if (max_i % (32 / TEAM_SIZE)) { max_i += (32 / TEAM_SIZE) - (max_i % (32 / TEAM_SIZE)); } - distance_op dist_op( - query_buffer); - for (uint32_t i = threadIdx.x / TEAM_SIZE; i < max_i; i += blockDim.x / TEAM_SIZE) { const bool valid_i = (i < num_pickup); @@ -212,11 +74,12 @@ _RAFT_DEVICE void compute_distance_to_random_nodes( if (seed_ptr && (gid < num_seeds)) { seed_index = seed_ptr[gid]; } else { - seed_index = device::xorshift64(gid ^ rand_xor_mask) % dataset_size; + seed_index = device::xorshift64(gid ^ rand_xor_mask) % dataset_desc.size; } } - const auto norm2 = dist_op(dataset_ptr + dataset_ld * seed_index, dataset_dim, valid_i); + const auto norm2 = dataset_desc.template compute_similarity( + query_buffer, seed_index, valid_i); if (valid_i && (norm2 < best_norm2_team_local)) { best_norm2_team_local = norm2; @@ -240,27 +103,25 @@ _RAFT_DEVICE void compute_distance_to_random_nodes( template -_RAFT_DEVICE void compute_distance_to_child_nodes(INDEX_T* const result_child_indices_ptr, - DISTANCE_T* const result_child_distances_ptr, - // query - const float* const query_buffer, - // [dataset_dim, dataset_size] - const DATA_T* const dataset_ptr, - const std::size_t dataset_dim, - const std::size_t dataset_ld, - // [knn_k, dataset_size] - const INDEX_T* const knn_graph, - const std::uint32_t knn_k, - // hashmap - INDEX_T* const visited_hashmap_ptr, - const std::uint32_t hash_bitlen, - const INDEX_T* const parent_indices, - const INDEX_T* const internal_topk_list, - const std::uint32_t search_width) +_RAFT_DEVICE void compute_distance_to_child_nodes( + INDEX_T* const result_child_indices_ptr, + DISTANCE_T* const result_child_distances_ptr, + // query + const typename DATASET_DESCRIPTOR_T::QUERY_T* const query_buffer, + // [dataset_dim, dataset_size] + const DATASET_DESCRIPTOR_T& dataset_desc, + // [knn_k, dataset_size] + const INDEX_T* const knn_graph, + const std::uint32_t knn_k, + // hashmap + INDEX_T* const visited_hashmap_ptr, + const std::uint32_t hash_bitlen, + const INDEX_T* const parent_indices, + const INDEX_T* const internal_topk_list, + const std::uint32_t search_width) { constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; const INDEX_T invalid_index = utils::get_max_value(); @@ -281,16 +142,6 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(INDEX_T* const result_child_in } result_child_indices_ptr[i] = child_id; } - - // [Notice] - // Loading the query vector here from shared memory into registers reduces - // shared memory trafiic. However, register usage increase. The - // MAX_N_FRAGS below is used as the threshold to enable or disable this, - // but the appropriate value should be discussed. - constexpr unsigned N_FRAGS = (DATASET_BLOCK_DIM + TEAM_SIZE - 1) / TEAM_SIZE; - constexpr bool use_fragment = N_FRAGS <= MAX_N_FRAGS; - distance_op dist_op( - query_buffer); __syncthreads(); // Compute the distance to child nodes @@ -302,8 +153,8 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(INDEX_T* const result_child_in INDEX_T child_id = invalid_index; if (valid_i) { child_id = result_child_indices_ptr[i]; } - DISTANCE_T norm2 = - dist_op(dataset_ptr + child_id * dataset_ld, dataset_dim, child_id != invalid_index); + const auto norm2 = dataset_desc.template compute_similarity( + query_buffer, child_id, child_id != invalid_index); // Store the distance const unsigned lane_id = threadIdx.x % TEAM_SIZE; @@ -318,4 +169,101 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(INDEX_T* const result_child_in } } // namespace device + +template +struct dataset_descriptor_base_t { + using INDEX_T = INDEX_T_; + using QUERY_T = QUERY_T_; + using DISTANCE_T = DISTANCE_T_; + + const INDEX_T size; + const std::uint32_t dim; + + dataset_descriptor_base_t(const INDEX_T size, const std::uint32_t dim) : size(size), dim(dim) {} +}; + +template +struct standard_dataset_descriptor_t + : public dataset_descriptor_base_t { + using LOAD_T = device::LOAD_128BIT_T; + using DATA_T = DATA_T_; + using QUERY_T = typename dataset_descriptor_base_t::QUERY_T; + + const DATA_T* const ptr; + const std::size_t ld; + using dataset_descriptor_base_t::size; + using dataset_descriptor_base_t::dim; + + standard_dataset_descriptor_t(const DATA_T* const ptr, + const std::size_t size, + const std::uint32_t dim, + const std::size_t ld) + : dataset_descriptor_base_t(size, dim), ptr(ptr), ld(ld) + { + } + + static const std::uint32_t smem_buffer_size_in_byte = 0; + __device__ void set_smem_ptr(void* const){}; + + template + __device__ void copy_query(const DATA_T* const dmem_query_ptr, + QUERY_T* const smem_query_ptr, + const std::uint32_t query_smem_buffer_length) + { + for (unsigned i = threadIdx.x; i < query_smem_buffer_length; i += blockDim.x) { + unsigned j = device::swizzling(i); + if (i < dim) { + smem_query_ptr[j] = spatial::knn::detail::utils::mapping{}(dmem_query_ptr[i]); + } else { + smem_query_ptr[j] = 0.0; + } + } + } + + template + __device__ DISTANCE_T compute_similarity(const QUERY_T* const query_ptr, + const INDEX_T dataset_i, + const bool valid) const + { + const auto dataset_ptr = ptr + dataset_i * ld; + const unsigned lane_id = threadIdx.x % TEAM_SIZE; + constexpr unsigned vlen = device::get_vlen(); + // #include (DATASET_BLOCK_DIM, TEAM_SIZE * vlen); + raft::TxN_t dl_buff[reg_nelem]; + + DISTANCE_T norm2 = 0; + if (valid) { + for (uint32_t elem_offset = 0; elem_offset < dim; elem_offset += DATASET_BLOCK_DIM) { +#pragma unroll + for (uint32_t e = 0; e < reg_nelem; e++) { + const uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen + elem_offset; + if (k >= dim) break; + dl_buff[e].load(dataset_ptr, k); + } +#pragma unroll + for (uint32_t e = 0; e < reg_nelem; e++) { + const uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen + elem_offset; + if (k >= dim) break; +#pragma unroll + for (uint32_t v = 0; v < vlen; v++) { + const uint32_t kv = k + v; + // Note this loop can go above the dataset_dim for padded arrays. This is not a problem + // because: + // - Above the last element (dataset_dim-1), the query array is filled with zeros. + // - The data buffer has to be also padded with zeros. + DISTANCE_T diff = query_ptr[device::swizzling(kv)]; + diff -= spatial::knn::detail::utils::mapping{}(dl_buff[e].val.data[v]); + norm2 += diff * diff; + } + } + } + } + for (uint32_t offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) { + norm2 += __shfl_xor_sync(0xffffffff, norm2, offset); + } + return norm2; + } +}; + } // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh b/cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh new file mode 100644 index 0000000000..0204addba7 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh @@ -0,0 +1,227 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "compute_distance.hpp" + +#include + +namespace raft::neighbors::cagra::detail { +template +struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t { + using LOAD_T = device::LOAD_128BIT_T; + using DATA_T = DATA_T_; + using CODE_BOOK_T = CODE_BOOK_T_; + using QUERY_T = typename dataset_descriptor_base_t::QUERY_T; + + const std::uint8_t* encoded_dataset_ptr; + const std::uint32_t encoded_dataset_dim; + const std::uint32_t n_subspace; + const CODE_BOOK_T* vq_code_book_ptr; + const float vq_scale; + const CODE_BOOK_T* pq_code_book_ptr; + const float pq_scale; + using dataset_descriptor_base_t::size; + using dataset_descriptor_base_t::dim; + + // Set on device + CODE_BOOK_T* smem_pq_code_book_ptr; + static const std::uint32_t smem_buffer_size_in_byte = + (1 << PQ_BITS) * PQ_LEN * utils::size_of(); + + __device__ void set_smem_ptr(void* const smem_ptr) + { + smem_pq_code_book_ptr = reinterpret_cast(smem_ptr); + + // Copy PQ table + if constexpr (std::is_same::value) { + for (unsigned i = threadIdx.x * 2; i < (1 << PQ_BITS) * PQ_LEN; i += blockDim.x * 2) { + half2 buf2; + buf2.x = pq_code_book_ptr[i]; + buf2.y = pq_code_book_ptr[i + 1]; + (reinterpret_cast(smem_pq_code_book_ptr + i))[0] = buf2; + } + } else { + for (unsigned i = threadIdx.x; i < (1 << PQ_BITS) * PQ_LEN; i += blockDim.x) { + // TODO: vectorize + smem_pq_code_book_ptr[i] = pq_code_book_ptr[i]; + } + } + } + + cagra_q_dataset_descriptor_t(const std::uint8_t* encoded_dataset_ptr, + const std::uint32_t encoded_dataset_dim, + const std::uint32_t n_subspace, + const CODE_BOOK_T* const vq_code_book_ptr, + const float vq_scale, + const CODE_BOOK_T* const pq_code_book_ptr, + const float pq_scale, + const std::size_t size, + const std::uint32_t dim) + : dataset_descriptor_base_t(size, dim), + encoded_dataset_ptr(encoded_dataset_ptr), + encoded_dataset_dim(encoded_dataset_dim), + n_subspace(n_subspace), + vq_code_book_ptr(vq_code_book_ptr), + vq_scale(vq_scale), + pq_code_book_ptr(pq_code_book_ptr), + pq_scale(pq_scale) + { + } + + template + __device__ void copy_query(const DATA_T* const dmem_query_ptr, + QUERY_T* const smem_query_ptr, + const std::uint32_t query_smem_buffer_length) + { + constexpr spatial::knn::detail::utils::mapping mapping{}; + for (unsigned i = threadIdx.x * 2; i < dim; i += blockDim.x * 2) { + half2 buf2{0, 0}; + if (i < dim) { buf2.x = mapping(dmem_query_ptr[i]); } + if (i + 1 < dim) { buf2.y = mapping(dmem_query_ptr[i + 1]); } + if ((PQ_BITS == 8) && (PQ_LEN % 2 == 0)) { + // Use swizzling in the condition to reduce bank conflicts in shared + // memory, which are likely to occur when pq_code_book_dim is large. + ((half2*)smem_query_ptr)[device::swizzling(i / 2)] = + buf2; + } else { + (reinterpret_cast(smem_query_ptr + i))[0] = buf2; + } + } + } + + template + __device__ DISTANCE_T compute_similarity(const QUERY_T* const query_ptr, + const INDEX_T node_id, + const bool valid) const + { + float norm = 0; + if (valid) { + const unsigned lane_id = threadIdx.x % TEAM_SIZE; + const uint32_t vq_code = *(reinterpret_cast( + encoded_dataset_ptr + (static_cast(encoded_dataset_dim) * node_id))); + if (PQ_BITS == 8) { + for (uint32_t elem_offset = 0; elem_offset < dim; elem_offset += DATASET_BLOCK_DIM) { + constexpr unsigned vlen = 4; // **** DO NOT CHANGE **** + constexpr unsigned nelem = + raft::div_rounding_up_unsafe(DATASET_BLOCK_DIM / PQ_LEN, TEAM_SIZE * vlen); + // Loading PQ codes + uint32_t pq_codes[nelem]; +#pragma unroll + for (std::uint32_t e = 0; e < nelem; e++) { + const std::uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen + elem_offset / PQ_LEN; + if (k >= n_subspace) break; + // Loading 4 x 8-bit PQ-codes using 32-bit load ops (from device memory) + pq_codes[e] = *(reinterpret_cast( + encoded_dataset_ptr + (static_cast(encoded_dataset_dim) * node_id) + + 4 + k)); + } + // + if constexpr ((std::is_same::value) && (PQ_LEN % 2 == 0)) { + // **** Use half2 for distance computation **** + half2 norm2{0, 0}; +#pragma unroll + for (std::uint32_t e = 0; e < nelem; e++) { + const std::uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen + elem_offset / PQ_LEN; + if (k >= n_subspace) break; + // Loading VQ code-book + raft::TxN_t vq_vals[PQ_LEN]; +#pragma unroll + for (std::uint32_t m = 0; m < PQ_LEN; m += 1) { + const uint32_t d = (vlen * m) + (PQ_LEN * k); + if (d >= dim) break; + vq_vals[m].load( + reinterpret_cast(vq_code_book_ptr + d + (dim * vq_code)), 0); + } + // Compute distance + std::uint32_t pq_code = pq_codes[e]; +#pragma unroll + for (std::uint32_t v = 0; v < vlen; v++) { + if (PQ_LEN * (v + k) >= dim) break; +#pragma unroll + for (std::uint32_t m = 0; m < PQ_LEN; m += 2) { + const std::uint32_t d1 = m + (PQ_LEN * v); + const std::uint32_t d = d1 + (PQ_LEN * k); + // Loading query vector in smem + half2 diff2 = (reinterpret_cast( + query_ptr))[device::swizzling(d / 2)]; + // Loading PQ code book in smem + diff2 -= *(reinterpret_cast( + smem_pq_code_book_ptr + (1 << PQ_BITS) * 2 * (m / 2) + (2 * (pq_code & 0xff)))); + diff2 -= vq_vals[d1 / vlen].val.data[(d1 % vlen) / 2]; + norm2 += diff2 * diff2; + } + pq_code >>= 8; + } + } + norm += static_cast(norm2.x + norm2.y); + } else { + // **** Use float for distance computation **** +#pragma unroll + for (std::uint32_t e = 0; e < nelem; e++) { + const std::uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen + elem_offset / PQ_LEN; + if (k >= n_subspace) break; + // Loading VQ code-book + raft::TxN_t vq_vals[PQ_LEN]; +#pragma unroll + for (std::uint32_t m = 0; m < PQ_LEN; m++) { + const std::uint32_t d = (vlen * m) + (PQ_LEN * k); + if (d >= dim) break; + // Loading 4 x 8/16-bit VQ-values using 32/64-bit load ops (from L2$ or device + // memory) + vq_vals[m].load( + reinterpret_cast(vq_code_book_ptr + d + (dim * vq_code)), 0); + } + // Compute distance + std::uint32_t pq_code = pq_codes[e]; +#pragma unroll + for (std::uint32_t v = 0; v < vlen; v++) { + if (PQ_LEN * (v + k) >= dim) break; + raft::TxN_t pq_vals; + pq_vals.load( + reinterpret_cast(smem_pq_code_book_ptr + PQ_LEN * (pq_code & 0xff)), + 0); // (from L1$ or smem) +#pragma unroll + for (std::uint32_t m = 0; m < PQ_LEN; m++) { + const std::uint32_t d1 = m + (PQ_LEN * v); + const std::uint32_t d = d1 + (PQ_LEN * k); + // if (d >= dataset_dim) break; + DISTANCE_T diff = query_ptr[d]; // (from smem) + diff -= pq_scale * static_cast(pq_vals.data[m]); + diff -= vq_scale * static_cast(vq_vals[d1 / vlen].val.data[d1 % vlen]); + norm += diff * diff; + } + pq_code >>= 8; + } + } + } + } + } + } + for (uint32_t offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) { + norm += __shfl_xor_sync(0xffffffff, norm, offset); + } + return norm; + } +}; + +} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/device_common.hpp b/cpp/include/raft/neighbors/detail/cagra/device_common.hpp index cd7469b55e..d4d69e6a67 100644 --- a/cpp/include/raft/neighbors/detail/cagra/device_common.hpp +++ b/cpp/include/raft/neighbors/detail/cagra/device_common.hpp @@ -42,13 +42,17 @@ _RAFT_HOST_DEVICE inline uint64_t xorshift64(uint64_t u) return u * 0x2545F4914F6CDD1DULL; } -template +template _RAFT_DEVICE inline T swizzling(T x) { // Address swizzling reduces bank conflicts in shared memory, but increases // the amount of operation instead. // return x; - return x ^ (x >> 5); // "x" must be less than 1024 + if constexpr (X_MAX <= 1024) { + return (x) ^ ((x) >> 5); + } else { + return (x) ^ (((x) >> 5) & 0x1f); + } } } // namespace device diff --git a/cpp/include/raft/neighbors/detail/cagra/factory.cuh b/cpp/include/raft/neighbors/detail/cagra/factory.cuh index 0aee912e25..4944b57c46 100644 --- a/cpp/include/raft/neighbors/detail/cagra/factory.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/factory.cuh @@ -25,16 +25,18 @@ namespace raft::neighbors::cagra::detail { -template class factory { + using T = typename DATASET_DESCRIPTOR_T::DATA_T; + using IdxT = typename DATASET_DESCRIPTOR_T::INDEX_T; + using DistanceT = typename DATASET_DESCRIPTOR_T::DISTANCE_T; + public: /** * Create a search structure for dataset with dim features. */ - static std::unique_ptr> create( + static std::unique_ptr> create( raft::resources const& res, search_params const& params, int64_t dim, @@ -63,28 +65,28 @@ class factory { break; default: THROW("Incorrect dataset_block_dim (%lu)\n", plan.dataset_block_dim); } - return std::unique_ptr>(); + return std::unique_ptr>(); } private: template - static std::unique_ptr> dispatch_kernel( - raft::resources const& res, search_plan_impl_base& plan) + static std::unique_ptr> + dispatch_kernel(raft::resources const& res, search_plan_impl_base& plan) { if (plan.algo == search_algo::SINGLE_CTA) { - return std::unique_ptr>( + return std::unique_ptr>( new single_cta_search:: - search( + search( res, plan, plan.dim, plan.graph_degree, plan.topk)); } else if (plan.algo == search_algo::MULTI_CTA) { - return std::unique_ptr>( + return std::unique_ptr>( new multi_cta_search:: - search( + search( res, plan, plan.dim, plan.graph_degree, plan.topk)); } else { - return std::unique_ptr>( + return std::unique_ptr>( new multi_kernel_search:: - search( + search( res, plan, plan.dim, plan.graph_degree, plan.topk)); } } diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh index 1fcd159959..8192b1ae51 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh @@ -45,44 +45,46 @@ namespace multi_cta_search { template -struct search : public search_plan_impl { - using search_plan_impl::max_queries; - using search_plan_impl::itopk_size; - using search_plan_impl::algo; - using search_plan_impl::team_size; - using search_plan_impl::search_width; - using search_plan_impl::min_iterations; - using search_plan_impl::max_iterations; - using search_plan_impl::thread_block_size; - using search_plan_impl::hashmap_mode; - using search_plan_impl::hashmap_min_bitlen; - using search_plan_impl::hashmap_max_fill_rate; - using search_plan_impl::num_random_samplings; - using search_plan_impl::rand_xor_mask; +struct search : public search_plan_impl { + using DATA_T = typename DATASET_DESCRIPTOR_T::DATA_T; + using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T; + using DISTANCE_T = typename DATASET_DESCRIPTOR_T::DISTANCE_T; - using search_plan_impl::dim; - using search_plan_impl::graph_degree; - using search_plan_impl::topk; + using search_plan_impl::max_queries; + using search_plan_impl::itopk_size; + using search_plan_impl::algo; + using search_plan_impl::team_size; + using search_plan_impl::search_width; + using search_plan_impl::min_iterations; + using search_plan_impl::max_iterations; + using search_plan_impl::thread_block_size; + using search_plan_impl::hashmap_mode; + using search_plan_impl::hashmap_min_bitlen; + using search_plan_impl::hashmap_max_fill_rate; + using search_plan_impl::num_random_samplings; + using search_plan_impl::rand_xor_mask; - using search_plan_impl::hash_bitlen; + using search_plan_impl::dim; + using search_plan_impl::graph_degree; + using search_plan_impl::topk; - using search_plan_impl::small_hash_bitlen; - using search_plan_impl::small_hash_reset_interval; - using search_plan_impl::hashmap_size; - using search_plan_impl::dataset_size; - using search_plan_impl::result_buffer_size; + using search_plan_impl::hash_bitlen; - using search_plan_impl::smem_size; + using search_plan_impl::small_hash_bitlen; + using search_plan_impl::small_hash_reset_interval; + using search_plan_impl::hashmap_size; + using search_plan_impl::dataset_size; + using search_plan_impl::result_buffer_size; - using search_plan_impl::hashmap; - using search_plan_impl::num_executed_iterations; - using search_plan_impl::dev_seed; - using search_plan_impl::num_seeds; + using search_plan_impl::smem_size; + + using search_plan_impl::hashmap; + using search_plan_impl::num_executed_iterations; + using search_plan_impl::dev_seed; + using search_plan_impl::num_seeds; uint32_t num_cta_per_query; rmm::device_uvector intermediate_indices; @@ -95,8 +97,7 @@ struct search : public search_plan_impl( - res, params, dim, graph_degree, topk), + : search_plan_impl(res, params, dim, graph_degree, topk), intermediate_indices(0, resource::get_cuda_stream(res)), intermediate_distances(0, resource::get_cuda_stream(res)), topk_workspace(0, resource::get_cuda_stream(res)) @@ -120,9 +121,11 @@ struct search : public search_plan_impl(dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; + smem_size = sizeof(float) * query_smem_buffer_length + (sizeof(INDEX_T) + sizeof(DISTANCE_T)) * result_buffer_size_32 + - sizeof(uint32_t) * search_width + sizeof(uint32_t); + sizeof(uint32_t) * search_width + sizeof(uint32_t) + + DATASET_DESCRIPTOR_T::smem_buffer_size_in_byte; RAFT_LOG_DEBUG("# smem_size: %u", smem_size); // @@ -191,22 +194,25 @@ struct search : public search_plan_impl dataset, - raft::device_matrix_view graph, - INDEX_T* const topk_indices_ptr, // [num_queries, topk] - DISTANCE_T* const topk_distances_ptr, // [num_queries, topk] - const DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const uint32_t num_queries, - const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] - uint32_t* const num_executed_iterations, // [num_queries,] - uint32_t topk, - SAMPLE_FILTER_T sample_filter) + void operator()( + raft::resources const& res, + // raft::device_matrix_view dataset, + DATASET_DESCRIPTOR_T dataset_desc, + raft::device_matrix_view + graph, + typename DATASET_DESCRIPTOR_T::INDEX_T* const topk_indices_ptr, // [num_queries, topk] + typename DATASET_DESCRIPTOR_T::DISTANCE_T* const topk_distances_ptr, // [num_queries, topk] + const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim] + const uint32_t num_queries, + const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] + uint32_t* const num_executed_iterations, // [num_queries,] + uint32_t topk, + SAMPLE_FILTER_T sample_filter) { cudaStream_t stream = resource::get_cuda_stream(res); - select_and_run( - dataset, + select_and_run( + dataset_desc, graph, intermediate_indices.data(), intermediate_distances.data(), diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh index 7a5ad17460..50f9e69593 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh @@ -15,6 +15,7 @@ */ #pragma once +#include #include // none_cagra_sample_filter #include // RAFT_EXPLICIT @@ -27,63 +28,66 @@ namespace multi_cta_search { template -void select_and_run(raft::device_matrix_view dataset, - raft::device_matrix_view graph, - INDEX_T* const topk_indices_ptr, - DISTANCE_T* const topk_distances_ptr, - const DATA_T* const queries_ptr, - const uint32_t num_queries, - const INDEX_T* dev_seed_ptr, - uint32_t* const num_executed_iterations, - uint32_t topk, - uint32_t block_size, - uint32_t result_buffer_size, - uint32_t smem_size, - int64_t hash_bitlen, - INDEX_T* hashmap_ptr, - uint32_t num_cta_per_query, - uint32_t num_random_samplings, - uint64_t rand_xor_mask, - uint32_t num_seeds, - size_t itopk_size, - size_t search_width, - size_t min_iterations, - size_t max_iterations, - SAMPLE_FILTER_T sample_filter, - cudaStream_t stream) RAFT_EXPLICIT; +void select_and_run( + DATASET_DESCRIPTOR_T dataset_desc, + raft::device_matrix_view graph, + typename DATASET_DESCRIPTOR_T::INDEX_T* const topk_indices_ptr, + typename DATASET_DESCRIPTOR_T::DISTANCE_T* const topk_distances_ptr, + const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, + const uint32_t num_queries, + const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr, + uint32_t* const num_executed_iterations, + uint32_t topk, + uint32_t block_size, + uint32_t result_buffer_size, + uint32_t smem_size, + int64_t hash_bitlen, + typename DATASET_DESCRIPTOR_T::INDEX_T* hashmap_ptr, + uint32_t num_cta_per_query, + uint32_t num_random_samplings, + uint64_t rand_xor_mask, + uint32_t num_seeds, + size_t itopk_size, + size_t search_width, + size_t min_iterations, + size_t max_iterations, + SAMPLE_FILTER_T sample_filter, + cudaStream_t stream) RAFT_EXPLICIT; #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - extern template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ +#define instantiate_kernel_selection( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + extern template void select_and_run< \ + TEAM_SIZE, \ + MAX_DATASET_DIM, \ + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, \ + SAMPLE_FILTER_T>( \ + raft::neighbors::cagra::detail::standard_dataset_descriptor_t \ + dataset_desc, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); instantiate_kernel_selection( @@ -120,5 +124,292 @@ instantiate_kernel_selection( 32, 512, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection + +#define instantiate_q_kernel_selection(TEAM_SIZE, \ + MAX_DATASET_DIM, \ + CODE_BOOK_T, \ + PQ_BITS, \ + PQ_CODE_BOOK_DIM, \ + DATA_T, \ + INDEX_T, \ + DISTANCE_T, \ + SAMPLE_FILTER_T) \ + extern template void \ + select_and_run, \ + SAMPLE_FILTER_T>( \ + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t dataset_desc, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ + cudaStream_t stream); + +instantiate_q_kernel_selection( + 8, 128, half, 8, 2, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_kernel_selection( + 16, 256, half, 8, 2, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_kernel_selection( + 32, 512, half, 8, 2, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_kernel_selection(32, + 1024, + half, + 8, + 2, + half, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_kernel_selection( + 8, 128, half, 8, 4, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_kernel_selection( + 16, 256, half, 8, 4, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_kernel_selection( + 32, 512, half, 8, 4, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_kernel_selection(32, + 1024, + half, + 8, + 4, + half, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); + +instantiate_q_kernel_selection( + 8, 128, half, 8, 2, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_kernel_selection(16, + 256, + half, + 8, + 2, + float, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_kernel_selection(32, + 512, + half, + 8, + 2, + float, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_kernel_selection(32, + 1024, + half, + 8, + 2, + float, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_kernel_selection( + 8, 128, half, 8, 4, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_kernel_selection(16, + 256, + half, + 8, + 4, + float, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_kernel_selection(32, + 512, + half, + 8, + 4, + float, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_kernel_selection(32, + 1024, + half, + 8, + 4, + float, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); + +instantiate_q_kernel_selection(8, + 128, + half, + 8, + 2, + uint8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_kernel_selection(16, + 256, + half, + 8, + 2, + uint8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_kernel_selection(32, + 512, + half, + 8, + 2, + uint8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_kernel_selection(32, + 1024, + half, + 8, + 2, + uint8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_kernel_selection(8, + 128, + half, + 8, + 4, + uint8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_kernel_selection(16, + 256, + half, + 8, + 4, + uint8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_kernel_selection(32, + 512, + half, + 8, + 4, + uint8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_kernel_selection(32, + 1024, + half, + 8, + 4, + uint8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); + +instantiate_q_kernel_selection(8, + 128, + half, + 8, + 2, + int8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_kernel_selection(16, + 256, + half, + 8, + 2, + int8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_kernel_selection(32, + 512, + half, + 8, + 2, + int8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_kernel_selection(32, + 1024, + half, + 8, + 2, + int8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_kernel_selection(8, + 128, + half, + 8, + 4, + int8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_kernel_selection(16, + 256, + half, + 8, + 4, + int8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_kernel_selection(32, + 512, + half, + 8, + 4, + int8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_kernel_selection(32, + 1024, + half, + 8, + 4, + int8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); + +#undef instantiate_q_kernel_selection } // namespace multi_cta_search } // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh index 30f56780d6..48c22d9d14 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh @@ -123,29 +123,26 @@ __device__ inline void topk_by_bitonic_sort(float* distances, // [num_elements] // // multiple CTAs per single query // -template __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel( - INDEX_T* const result_indices_ptr, // [num_queries, num_cta_per_query, itopk_size] - DISTANCE_T* const result_distances_ptr, // [num_queries, num_cta_per_query, itopk_size] - const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim] - const size_t dataset_dim, - const size_t dataset_size, - const size_t dataset_ld, - const DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const INDEX_T* const knn_graph, // [dataset_size, graph_degree] + typename DATASET_DESCRIPTOR_T::INDEX_T* const + result_indices_ptr, // [num_queries, num_cta_per_query, itopk_size] + typename DATASET_DESCRIPTOR_T::DISTANCE_T* const + result_distances_ptr, // [num_queries, num_cta_per_query, itopk_size] + DATASET_DESCRIPTOR_T dataset_desc, + const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim] + const typename DATASET_DESCRIPTOR_T::INDEX_T* const knn_graph, // [dataset_size, graph_degree] const uint32_t graph_degree, const unsigned num_distilation, const uint64_t rand_xor_mask, - const INDEX_T* seed_ptr, // [num_queries, num_seeds] + const typename DATASET_DESCRIPTOR_T::INDEX_T* seed_ptr, // [num_queries, num_seeds] const uint32_t num_seeds, - INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + typename DATASET_DESCRIPTOR_T::INDEX_T* const + visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] const uint32_t hash_bitlen, const uint32_t itopk_size, const uint32_t search_width, @@ -154,6 +151,11 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel( uint32_t* const num_executed_iterations, /* stats */ SAMPLE_FILTER_T sample_filter) { + using DATA_T = typename DATASET_DESCRIPTOR_T::DATA_T; + using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T; + using DISTANCE_T = typename DATASET_DESCRIPTOR_T::DISTANCE_T; + using QUERY_T = typename DATASET_DESCRIPTOR_T::QUERY_T; + const auto num_queries = gridDim.y; const auto query_id = blockIdx.y; const auto num_cta_per_query = gridDim.x; @@ -188,14 +190,20 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel( assert(result_buffer_size_32 <= MAX_ELEMENTS); const auto query_smem_buffer_length = - raft::ceildiv(dataset_dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; - auto query_buffer = reinterpret_cast(smem); + raft::ceildiv(dataset_desc.dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; + auto query_buffer = reinterpret_cast(smem); auto result_indices_buffer = reinterpret_cast(query_buffer + query_smem_buffer_length); auto result_distances_buffer = reinterpret_cast(result_indices_buffer + result_buffer_size_32); auto parent_indices_buffer = reinterpret_cast(result_distances_buffer + result_buffer_size_32); - auto terminate_flag = reinterpret_cast(parent_indices_buffer + search_width); + auto distance_work_buffer_ptr = + reinterpret_cast(parent_indices_buffer + search_width); + auto terminate_flag = reinterpret_cast(distance_work_buffer_ptr + + DATASET_DESCRIPTOR_T::smem_buffer_size_in_byte); + + // Set smem working buffer for the distance calculation + dataset_desc.set_smem_ptr(distance_work_buffer_ptr); #if 0 /* debug */ @@ -204,15 +212,10 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel( result_distances_buffer[i] = utils::get_max_value(); } #endif - const DATA_T* const query_ptr = queries_ptr + (dataset_dim * query_id); - for (unsigned i = threadIdx.x; i < query_smem_buffer_length; i += blockDim.x) { - unsigned j = device::swizzling(i); - if (i < dataset_dim) { - query_buffer[j] = spatial::knn::detail::utils::mapping{}(query_ptr[i]); - } else { - query_buffer[j] = 0.0; - } - } + const DATA_T* const query_ptr = queries_ptr + (dataset_desc.dim * query_id); + dataset_desc.template copy_query( + query_ptr, query_buffer, query_smem_buffer_length); + if (threadIdx.x == 0) { terminate_flag[0] = 0; } INDEX_T* const local_visited_hashmap_ptr = visited_hashmap_ptr + (hashmap::get_size(hash_bitlen) * query_id); @@ -224,23 +227,19 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel( const INDEX_T* const local_seed_ptr = seed_ptr ? seed_ptr + (num_seeds * query_id) : nullptr; uint32_t block_id = cta_id + (num_cta_per_query * query_id); uint32_t num_blocks = num_cta_per_query * num_queries; - device::compute_distance_to_random_nodes( - result_indices_buffer, - result_distances_buffer, - query_buffer, - dataset_ptr, - dataset_dim, - dataset_size, - dataset_ld, - result_buffer_size, - num_distilation, - rand_xor_mask, - local_seed_ptr, - num_seeds, - local_visited_hashmap_ptr, - hash_bitlen, - block_id, - num_blocks); + device::compute_distance_to_random_nodes(result_indices_buffer, + result_distances_buffer, + query_buffer, + dataset_desc, + result_buffer_size, + num_distilation, + rand_xor_mask, + local_seed_ptr, + num_seeds, + local_visited_hashmap_ptr, + hash_bitlen, + block_id, + num_blocks); __syncthreads(); _CLK_REC(clk_compute_1st_distance); @@ -272,13 +271,11 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel( _CLK_START(); // constexpr unsigned max_n_frags = 16; constexpr unsigned max_n_frags = 0; - device::compute_distance_to_child_nodes( + device::compute_distance_to_child_nodes( result_indices_buffer + itopk_size, result_distances_buffer + itopk_size, query_buffer, - dataset_ptr, - dataset_dim, - dataset_ld, + dataset_desc, knn_graph, graph_degree, local_visited_hashmap_ptr, @@ -398,53 +395,35 @@ void set_value_batch(T* const dev_ptr, <<>>(dev_ptr, ld, val, count, batch_size); } -template struct search_kernel_config { // Search kernel function type. Note that the actual values for the template value // parameters do not matter, because they are not part of the function signature. The // second to fourth value parameters will be selected by the choose_* functions below. using kernel_t = decltype(&search_kernel); static auto choose_buffer_size(unsigned result_buffer_size, unsigned block_size) -> kernel_t { if (result_buffer_size <= 64) { - return search_kernel; + return search_kernel; } else if (result_buffer_size <= 128) { return search_kernel; } else if (result_buffer_size <= 256) { return search_kernel; } THROW("Result buffer size %u larger than max buffer size %u", result_buffer_size, 256); @@ -453,26 +432,24 @@ struct search_kernel_config { template -void select_and_run( // raft::resources const& res, - raft::device_matrix_view dataset, - raft::device_matrix_view graph, - INDEX_T* const topk_indices_ptr, // [num_queries, topk] - DISTANCE_T* const topk_distances_ptr, // [num_queries, topk] - const DATA_T* const queries_ptr, // [num_queries, dataset_dim] +void select_and_run( + DATASET_DESCRIPTOR_T dataset_desc, + raft::device_matrix_view graph, + typename DATASET_DESCRIPTOR_T::INDEX_T* const topk_indices_ptr, // [num_queries, topk] + typename DATASET_DESCRIPTOR_T::DISTANCE_T* const topk_distances_ptr, // [num_queries, topk] + const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim] const uint32_t num_queries, - const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] - uint32_t* const num_executed_iterations, // [num_queries,] + const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] + uint32_t* const num_executed_iterations, // [num_queries,] uint32_t topk, // multi_cta_search (params struct) uint32_t block_size, // uint32_t result_buffer_size, uint32_t smem_size, int64_t hash_bitlen, - INDEX_T* hashmap_ptr, + typename DATASET_DESCRIPTOR_T::INDEX_T* hashmap_ptr, uint32_t num_cta_per_query, uint32_t num_random_samplings, uint64_t rand_xor_mask, @@ -485,19 +462,20 @@ void select_and_run( // raft::resources const& res, cudaStream_t stream) { auto kernel = - search_kernel_config::choose_buffer_size(result_buffer_size, block_size); - - RAFT_CUDA_TRY( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + search_kernel_config:: + choose_buffer_size(result_buffer_size, block_size); + + RAFT_CUDA_TRY(cudaFuncSetAttribute(kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size + DATASET_DESCRIPTOR_T::smem_buffer_size_in_byte)); // Initialize hash table const uint32_t hash_size = hashmap::get_size(hash_bitlen); - set_value_batch( - hashmap_ptr, hash_size, utils::get_max_value(), hash_size, num_queries, stream); + set_value_batch(hashmap_ptr, + hash_size, + utils::get_max_value(), + hash_size, + num_queries, + stream); dim3 block_dims(block_size, 1, 1); dim3 grid_dims(num_cta_per_query, num_queries, 1); @@ -508,10 +486,7 @@ void select_and_run( // raft::resources const& res, smem_size); kernel<<>>(topk_indices_ptr, topk_distances_ptr, - dataset.data_handle(), - dataset.extent(1), - dataset.extent(0), - dataset.stride(0), + dataset_desc, queries_ptr, graph.data_handle(), graph.extent(1), diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh index e4a30675bb..10788da432 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh @@ -16,6 +16,7 @@ #pragma once #include "compute_distance.hpp" +#include "compute_distance_vpq.cuh" #include "device_common.hpp" #include "hashmap.hpp" #include "search_plan.cuh" @@ -86,27 +87,25 @@ void get_value(T* const host_ptr, const T* const dev_ptr, cudaStream_t cuda_stre } // MAX_DATASET_DIM : must equal to or greater than dataset_dim -template -RAFT_KERNEL random_pickup_kernel(const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim] - const std::size_t dataset_dim, - const std::size_t dataset_size, - const std::size_t dataset_ld, - const DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const std::size_t num_pickup, - const unsigned num_distilation, - const uint64_t rand_xor_mask, - const INDEX_T* seed_ptr, // [num_queries, num_seeds] - const uint32_t num_seeds, - INDEX_T* const result_indices_ptr, // [num_queries, ldr] - DISTANCE_T* const result_distances_ptr, // [num_queries, ldr] - const std::uint32_t ldr, // (*) ldr >= num_pickup - INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << bitlen] - const std::uint32_t hash_bitlen) +template +RAFT_KERNEL random_pickup_kernel( + const DATASET_DESCRIPTOR_T dataset_desc, + const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim] + const std::size_t num_pickup, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const typename DATASET_DESCRIPTOR_T::INDEX_T* seed_ptr, // [num_queries, num_seeds] + const uint32_t num_seeds, + typename DATASET_DESCRIPTOR_T::INDEX_T* const result_indices_ptr, // [num_queries, ldr] + typename DATASET_DESCRIPTOR_T::DISTANCE_T* const result_distances_ptr, // [num_queries, ldr] + const std::uint32_t ldr, // (*) ldr >= num_pickup + typename DATASET_DESCRIPTOR_T::INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << bitlen] + const std::uint32_t hash_bitlen) { + using DATA_T = typename DATASET_DESCRIPTOR_T::DATA_T; + using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T; + using DISTANCE_T = typename DATASET_DESCRIPTOR_T::DISTANCE_T; + const auto ldb = hashmap::get_size(hash_bitlen); const auto global_team_index = (blockIdx.x * blockDim.x + threadIdx.x) / TEAM_SIZE; const uint32_t query_id = blockIdx.y; @@ -114,19 +113,17 @@ RAFT_KERNEL random_pickup_kernel(const DATA_T* const dataset_ptr, // [dataset_s // Load a query extern __shared__ float query_buffer[]; const auto query_smem_buffer_length = - raft::ceildiv(dataset_dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; + raft::ceildiv(dataset_desc.dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; for (uint32_t i = threadIdx.x; i < query_smem_buffer_length; i += blockDim.x) { unsigned j = device::swizzling(i); - if (i < dataset_dim) { - query_buffer[j] = - spatial::knn::detail::utils::mapping{}((queries_ptr + query_id * dataset_dim)[i]); + if (i < dataset_desc.dim) { + query_buffer[j] = spatial::knn::detail::utils::mapping{}( + (queries_ptr + query_id * dataset_desc.dim)[i]); } else { query_buffer[j] = 0.0; } } __syncthreads(); - device::distance_op dist_op( - query_buffer); INDEX_T best_index_team_local; DISTANCE_T best_norm2_team_local = utils::get_max_value(); @@ -136,10 +133,12 @@ RAFT_KERNEL random_pickup_kernel(const DATA_T* const dataset_ptr, // [dataset_s seed_index = seed_ptr[global_team_index + (num_seeds * query_id)]; } else { // Chose a seed node randomly - seed_index = device::xorshift64((global_team_index ^ rand_xor_mask) * (i + 1)) % dataset_size; + seed_index = + device::xorshift64((global_team_index ^ rand_xor_mask) * (i + 1)) % dataset_desc.size; } - const auto norm2 = dist_op(dataset_ptr + (dataset_ld * seed_index), dataset_dim, true); + const auto norm2 = dataset_desc.template compute_similarity( + query_buffer, seed_index, true); if (norm2 < best_norm2_team_local) { best_norm2_team_local = norm2; @@ -161,28 +160,22 @@ RAFT_KERNEL random_pickup_kernel(const DATA_T* const dataset_ptr, // [dataset_s } // MAX_DATASET_DIM : must be equal to or greater than dataset_dim -template -void random_pickup(const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim] - const std::size_t dataset_dim, - const std::size_t dataset_size, - const std::size_t dataset_ld, - const DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const std::size_t num_queries, - const std::size_t num_pickup, - const unsigned num_distilation, - const uint64_t rand_xor_mask, - const INDEX_T* seed_ptr, // [num_queries, num_seeds] - const uint32_t num_seeds, - INDEX_T* const result_indices_ptr, // [num_queries, ldr] - DISTANCE_T* const result_distances_ptr, // [num_queries, ldr] - const std::size_t ldr, // (*) ldr >= num_pickup - INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << bitlen] - const std::uint32_t hash_bitlen, - cudaStream_t const cuda_stream = 0) +template +void random_pickup( + const DATASET_DESCRIPTOR_T dataset_desc, + const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim] + const std::size_t num_queries, + const std::size_t num_pickup, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const typename DATASET_DESCRIPTOR_T::INDEX_T* seed_ptr, // [num_queries, num_seeds] + const uint32_t num_seeds, + typename DATASET_DESCRIPTOR_T::INDEX_T* const result_indices_ptr, // [num_queries, ldr] + typename DATASET_DESCRIPTOR_T::DISTANCE_T* const result_distances_ptr, // [num_queries, ldr] + const std::size_t ldr, // (*) ldr >= num_pickup + typename DATASET_DESCRIPTOR_T::INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << bitlen] + const std::uint32_t hash_bitlen, + cudaStream_t const cuda_stream = 0) { const auto block_size = 256u; const auto num_teams_per_threadblock = block_size / TEAM_SIZE; @@ -190,14 +183,11 @@ void random_pickup(const DATA_T* const dataset_ptr, // [dataset_size, dataset_d num_queries); const auto query_smem_buffer_length = - raft::ceildiv(dataset_dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; + raft::ceildiv(dataset_desc.dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; const auto smem_size = query_smem_buffer_length * sizeof(float); - random_pickup_kernel - <<>>(dataset_ptr, - dataset_dim, - dataset_size, - dataset_ld, + random_pickup_kernel + <<>>(dataset_desc, queries_ptr, num_pickup, num_distilation, @@ -313,30 +303,33 @@ void pickup_next_parents(INDEX_T* const parent_candidates_ptr, // [num_queries, template RAFT_KERNEL compute_distance_to_child_nodes_kernel( - const INDEX_T* const parent_node_list, // [num_queries, search_width] - INDEX_T* const parent_candidates_ptr, // [num_queries, search_width] - DISTANCE_T* const parent_distance_ptr, // [num_queries, search_width] + const typename DATASET_DESCRIPTOR_T::INDEX_T* const + parent_node_list, // [num_queries, search_width] + typename DATASET_DESCRIPTOR_T::INDEX_T* const + parent_candidates_ptr, // [num_queries, search_width] + typename DATASET_DESCRIPTOR_T::DISTANCE_T* const + parent_distance_ptr, // [num_queries, search_width] const std::size_t lds, const std::uint32_t search_width, - const DATA_T* const dataset_ptr, // [dataset_size, data_dim] - const std::uint32_t dataset_dim, - const std::uint32_t dataset_size, - const std::uint32_t dataset_ld, - const INDEX_T* const neighbor_graph_ptr, // [dataset_size, graph_degree] + const DATASET_DESCRIPTOR_T dataset_desc, + const typename DATASET_DESCRIPTOR_T::INDEX_T* const + neighbor_graph_ptr, // [dataset_size, graph_degree] const std::uint32_t graph_degree, - const DATA_T* query_ptr, // [num_queries, data_dim] - INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + const typename DATASET_DESCRIPTOR_T::DATA_T* query_ptr, // [num_queries, data_dim] + typename DATASET_DESCRIPTOR_T::INDEX_T* const + visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] const std::uint32_t hash_bitlen, - INDEX_T* const result_indices_ptr, // [num_queries, ldd] - DISTANCE_T* const result_distances_ptr, // [num_queries, ldd] - const std::uint32_t ldd, // (*) ldd >= search_width * graph_degree + typename DATASET_DESCRIPTOR_T::INDEX_T* const result_indices_ptr, // [num_queries, ldd] + typename DATASET_DESCRIPTOR_T::DISTANCE_T* const result_distances_ptr, // [num_queries, ldd] + const std::uint32_t ldd, // (*) ldd >= search_width * graph_degree SAMPLE_FILTER_T sample_filter) { + using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T; + using DISTANCE_T = typename DATASET_DESCRIPTOR_T::DISTANCE_T; + const uint32_t ldb = hashmap::get_size(hash_bitlen); const auto tid = threadIdx.x + blockDim.x * blockIdx.x; const auto global_team_id = tid / TEAM_SIZE; @@ -344,12 +337,12 @@ RAFT_KERNEL compute_distance_to_child_nodes_kernel( extern __shared__ float query_buffer[]; const auto query_smem_buffer_length = - raft::ceildiv(dataset_dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; + raft::ceildiv(dataset_desc.dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; for (uint32_t i = threadIdx.x; i < query_smem_buffer_length; i += blockDim.x) { unsigned j = device::swizzling(i); - if (i < dataset_dim) { + if (i < dataset_desc.dim) { query_buffer[j] = - spatial::knn::detail::utils::mapping{}((query_ptr + query_id * dataset_dim)[i]); + spatial::knn::detail::utils::mapping{}((query_ptr + query_id * dataset_desc.dim)[i]); } else { query_buffer[j] = 0.0; } @@ -357,9 +350,6 @@ RAFT_KERNEL compute_distance_to_child_nodes_kernel( __syncthreads(); if (global_team_id >= search_width * graph_degree) { return; } - device::distance_op dist_op( - query_buffer); - const std::size_t parent_list_index = parent_node_list[global_team_id / graph_degree + (search_width * blockIdx.y)]; @@ -381,8 +371,8 @@ RAFT_KERNEL compute_distance_to_child_nodes_kernel( const auto compute_distance_flag = hashmap::insert( visited_hashmap_ptr + (ldb * blockIdx.y), hash_bitlen, child_id); - const auto norm2 = - dist_op(dataset_ptr + (dataset_ld * child_id), dataset_dim, compute_distance_flag); + const auto norm2 = dataset_desc.template compute_similarity( + query_buffer, child_id, compute_distance_flag); if (compute_distance_flag) { if (threadIdx.x % TEAM_SIZE == 0) { @@ -407,29 +397,29 @@ RAFT_KERNEL compute_distance_to_child_nodes_kernel( template + class SAMPLE_FILTER_T, + class DATASET_DESCRIPTOR_T> void compute_distance_to_child_nodes( - const INDEX_T* const parent_node_list, // [num_queries, search_width] - INDEX_T* const parent_candidates_ptr, // [num_queries, search_width] - DISTANCE_T* const parent_distance_ptr, // [num_queries, search_width] + const typename DATASET_DESCRIPTOR_T::INDEX_T* const + parent_node_list, // [num_queries, search_width] + typename DATASET_DESCRIPTOR_T::INDEX_T* const + parent_candidates_ptr, // [num_queries, search_width] + typename DATASET_DESCRIPTOR_T::DISTANCE_T* const + parent_distance_ptr, // [num_queries, search_width] const std::size_t lds, const uint32_t search_width, - const DATA_T* const dataset_ptr, // [dataset_size, data_dim] - const std::uint32_t dataset_dim, - const std::uint32_t dataset_size, - const std::uint32_t dataset_ld, - const INDEX_T* const neighbor_graph_ptr, // [dataset_size, graph_degree] + const DATASET_DESCRIPTOR_T dataset_desc, + const typename DATASET_DESCRIPTOR_T::INDEX_T* const + neighbor_graph_ptr, // [dataset_size, graph_degree] const std::uint32_t graph_degree, - const DATA_T* query_ptr, // [num_queries, data_dim] + const typename DATASET_DESCRIPTOR_T::DATA_T* query_ptr, // [num_queries, data_dim] const std::uint32_t num_queries, - INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + typename DATASET_DESCRIPTOR_T::INDEX_T* const + visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] const std::uint32_t hash_bitlen, - INDEX_T* const result_indices_ptr, // [num_queries, ldd] - DISTANCE_T* const result_distances_ptr, // [num_queries, ldd] - const std::uint32_t ldd, // (*) ldd >= search_width * graph_degree + typename DATASET_DESCRIPTOR_T::INDEX_T* const result_indices_ptr, // [num_queries, ldd] + typename DATASET_DESCRIPTOR_T::DISTANCE_T* const result_distances_ptr, // [num_queries, ldd] + const std::uint32_t ldd, // (*) ldd >= search_width * graph_degree SAMPLE_FILTER_T sample_filter, cudaStream_t cuda_stream = 0) { @@ -439,20 +429,21 @@ void compute_distance_to_child_nodes( num_queries); const auto query_smem_buffer_length = - raft::ceildiv(dataset_dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; + raft::ceildiv(dataset_desc.dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; - const auto smem_size = query_smem_buffer_length * sizeof(float); + const auto smem_size = + query_smem_buffer_length * sizeof(float) + DATASET_DESCRIPTOR_T::smem_buffer_size_in_byte; - compute_distance_to_child_nodes_kernel + compute_distance_to_child_nodes_kernel <<>>(parent_node_list, parent_candidates_ptr, parent_distance_ptr, lds, search_width, - dataset_ptr, - dataset_dim, - dataset_size, - dataset_ld, + dataset_desc, neighbor_graph_ptr, graph_degree, query_ptr, @@ -609,47 +600,51 @@ void set_value_batch(T* const dev_ptr, // |<--- result_buffer_size --->| // Double buffer (B) template -struct search : search_plan_impl { - using search_plan_impl::max_queries; - using search_plan_impl::itopk_size; - using search_plan_impl::algo; - using search_plan_impl::team_size; - using search_plan_impl::search_width; - using search_plan_impl::min_iterations; - using search_plan_impl::max_iterations; - using search_plan_impl::thread_block_size; - using search_plan_impl::hashmap_mode; - using search_plan_impl::hashmap_min_bitlen; - using search_plan_impl::hashmap_max_fill_rate; - using search_plan_impl::num_random_samplings; - using search_plan_impl::rand_xor_mask; - - using search_plan_impl::dim; - using search_plan_impl::graph_degree; - using search_plan_impl::topk; - - using search_plan_impl::hash_bitlen; - - using search_plan_impl::small_hash_bitlen; - using search_plan_impl::small_hash_reset_interval; - using search_plan_impl::hashmap_size; - using search_plan_impl::dataset_size; - using search_plan_impl::result_buffer_size; - - using search_plan_impl::smem_size; - - using search_plan_impl::hashmap; - using search_plan_impl::num_executed_iterations; - using search_plan_impl::dev_seed; - using search_plan_impl::num_seeds; +struct search : search_plan_impl { + using DATA_T = typename DATASET_DESCRIPTOR_T::DATA_T; + using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T; + using DISTANCE_T = typename DATASET_DESCRIPTOR_T::DISTANCE_T; + + static_assert(std::is_same_v, "Only float is supported as resulting distance"); + + using search_plan_impl::max_queries; + using search_plan_impl::itopk_size; + using search_plan_impl::algo; + using search_plan_impl::team_size; + using search_plan_impl::search_width; + using search_plan_impl::min_iterations; + using search_plan_impl::max_iterations; + using search_plan_impl::thread_block_size; + using search_plan_impl::hashmap_mode; + using search_plan_impl::hashmap_min_bitlen; + using search_plan_impl::hashmap_max_fill_rate; + using search_plan_impl::num_random_samplings; + using search_plan_impl::rand_xor_mask; + + using search_plan_impl::dim; + using search_plan_impl::graph_degree; + using search_plan_impl::topk; + + using search_plan_impl::hash_bitlen; + + using search_plan_impl::small_hash_bitlen; + using search_plan_impl::small_hash_reset_interval; + using search_plan_impl::hashmap_size; + using search_plan_impl::dataset_size; + using search_plan_impl::result_buffer_size; + + using search_plan_impl::smem_size; + + using search_plan_impl::hashmap; + using search_plan_impl::num_executed_iterations; + using search_plan_impl::dev_seed; + using search_plan_impl::num_seeds; size_t result_buffer_allocation_size; - rmm::device_uvector result_indices; // results_indices_buffer - rmm::device_uvector result_distances; // result_distances_buffer + rmm::device_uvector result_indices; // results_indices_buffer + rmm::device_uvector result_distances; // result_distances_buffer rmm::device_uvector parent_node_list; rmm::device_uvector topk_hint; rmm::device_scalar terminate_flag; // dev_terminate_flag, host_terminate_flag.; @@ -666,8 +661,7 @@ struct search : search_plan_impl { int64_t dim, int64_t graph_degree, uint32_t topk) - : search_plan_impl( - res, params, dim, graph_degree, topk), + : search_plan_impl(res, params, dim, graph_degree, topk), result_indices(0, resource::get_cuda_stream(res)), result_distances(0, resource::get_cuda_stream(res)), parent_node_list(0, resource::get_cuda_stream(res)), @@ -800,7 +794,7 @@ struct search : search_plan_impl { } void operator()(raft::resources const& res, - raft::device_matrix_view dataset, + DATASET_DESCRIPTOR_T dataset_desc, raft::device_matrix_view graph, INDEX_T* const topk_indices_ptr, // [num_queries, topk] DISTANCE_T* const topk_distances_ptr, // [num_queries, topk] @@ -828,24 +822,20 @@ struct search : search_plan_impl { } // Choose initial entry point candidates at random - random_pickup( - dataset.data_handle(), - dataset.extent(1), - dataset.extent(0), - dataset.stride(0), - queries_ptr, - num_queries, - result_buffer_size, - num_random_samplings, - rand_xor_mask, - dev_seed_ptr, - num_seeds, - result_indices.data(), - result_distances.data(), - result_buffer_allocation_size, - hashmap.data(), - hash_bitlen, - stream); + random_pickup(dataset_desc, + queries_ptr, + num_queries, + result_buffer_size, + num_random_samplings, + rand_xor_mask, + dev_seed_ptr, + num_seeds, + result_indices.data(), + result_distances.data(), + result_buffer_allocation_size, + hashmap.data(), + hash_bitlen, + stream); unsigned iter = 0; while (1) { @@ -897,16 +887,13 @@ struct search : search_plan_impl { } // Compute distance to child nodes that are adjacent to the parent node - compute_distance_to_child_nodes( + compute_distance_to_child_nodes( parent_node_list.data(), result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size, result_distances.data() + (1 - (iter & 0x1)) * result_buffer_size, result_buffer_allocation_size, search_width, - dataset.data_handle(), - dataset.extent(1), - dataset.extent(0), - dataset.stride(0), + dataset_desc, graph.data_handle(), graph.extent(1), queries_ptr, @@ -993,5 +980,68 @@ struct search : search_plan_impl { } }; +template +struct search, + SAMPLE_FILTER_T> + : public search_plan_impl, + SAMPLE_FILTER_T> { + using DATASET_DESCRIPTOR_T = cagra_q_dataset_descriptor_t; + using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T; + using DISTANCE_T = typename DATASET_DESCRIPTOR_T::DISTANCE_T; + + search(raft::resources const& res, + search_params params, + int64_t dim, + int64_t graph_degree, + uint32_t topk) + : search_plan_impl(res, params, dim, graph_degree, topk) + { + THROW("The multi-kernel mode does not support VPQ"); + } + + void set_params(raft::resources const& res) {} + + void operator()(raft::resources const& res, + DATASET_DESCRIPTOR_T dataset_desc, + raft::device_matrix_view graph, + INDEX_T* const topk_indices_ptr, // [num_queries, topk] + DISTANCE_T* const topk_distances_ptr, // [num_queries, topk] + const DATA_T* const queries_ptr, // [num_queries, dataset_dim] + const uint32_t num_queries, + const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] + uint32_t* const num_executed_iterations, // [num_queries,] + uint32_t topk, + SAMPLE_FILTER_T sample_filter) + { + } +}; + } // namespace multi_kernel_search } // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh index 11ef7e5211..be5ac0554f 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh @@ -71,8 +71,12 @@ struct search_plan_impl_base : public search_params { } }; -template +template struct search_plan_impl : public search_plan_impl_base { + using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T; + using DISTANCE_T = typename DATASET_DESCRIPTOR_T::DISTANCE_T; + using DATA_T = typename DATASET_DESCRIPTOR_T::DATA_T; + int64_t hash_bitlen; size_t small_hash_bitlen; @@ -111,7 +115,7 @@ struct search_plan_impl : public search_plan_impl_base { virtual ~search_plan_impl() {} virtual void operator()(raft::resources const& res, - raft::device_matrix_view dataset, + DATASET_DESCRIPTOR_T dataset_desc, raft::device_matrix_view graph, INDEX_T* const result_indices_ptr, // [num_queries, topk] DISTANCE_T* const result_distances_ptr, // [num_queries, topk] diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh index f1e74ee7a5..4430b929fb 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh @@ -48,43 +48,45 @@ namespace single_cta_search { template -struct search : search_plan_impl { - using search_plan_impl::max_queries; - using search_plan_impl::itopk_size; - using search_plan_impl::algo; - using search_plan_impl::team_size; - using search_plan_impl::search_width; - using search_plan_impl::min_iterations; - using search_plan_impl::max_iterations; - using search_plan_impl::thread_block_size; - using search_plan_impl::hashmap_mode; - using search_plan_impl::hashmap_min_bitlen; - using search_plan_impl::hashmap_max_fill_rate; - using search_plan_impl::num_random_samplings; - using search_plan_impl::rand_xor_mask; - - using search_plan_impl::dim; - using search_plan_impl::graph_degree; - using search_plan_impl::topk; - - using search_plan_impl::hash_bitlen; - - using search_plan_impl::small_hash_bitlen; - using search_plan_impl::small_hash_reset_interval; - using search_plan_impl::hashmap_size; - using search_plan_impl::dataset_size; - using search_plan_impl::result_buffer_size; - - using search_plan_impl::smem_size; - - using search_plan_impl::hashmap; - using search_plan_impl::num_executed_iterations; - using search_plan_impl::dev_seed; - using search_plan_impl::num_seeds; +struct search : search_plan_impl { + using DATA_T = typename DATASET_DESCRIPTOR_T::DATA_T; + using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T; + using DISTANCE_T = typename DATASET_DESCRIPTOR_T::DISTANCE_T; + + using search_plan_impl::max_queries; + using search_plan_impl::itopk_size; + using search_plan_impl::algo; + using search_plan_impl::team_size; + using search_plan_impl::search_width; + using search_plan_impl::min_iterations; + using search_plan_impl::max_iterations; + using search_plan_impl::thread_block_size; + using search_plan_impl::hashmap_mode; + using search_plan_impl::hashmap_min_bitlen; + using search_plan_impl::hashmap_max_fill_rate; + using search_plan_impl::num_random_samplings; + using search_plan_impl::rand_xor_mask; + + using search_plan_impl::dim; + using search_plan_impl::graph_degree; + using search_plan_impl::topk; + + using search_plan_impl::hash_bitlen; + + using search_plan_impl::small_hash_bitlen; + using search_plan_impl::small_hash_reset_interval; + using search_plan_impl::hashmap_size; + using search_plan_impl::dataset_size; + using search_plan_impl::result_buffer_size; + + using search_plan_impl::smem_size; + + using search_plan_impl::hashmap; + using search_plan_impl::num_executed_iterations; + using search_plan_impl::dev_seed; + using search_plan_impl::num_seeds; uint32_t num_itopk_candidates; @@ -93,8 +95,7 @@ struct search : search_plan_impl { int64_t dim, int64_t graph_degree, uint32_t topk) - : search_plan_impl( - res, params, dim, graph_degree, topk) + : search_plan_impl(res, params, dim, graph_degree, topk) { set_params(res); } @@ -128,7 +129,8 @@ struct search : search_plan_impl { sizeof(float) * query_smem_buffer_length + (sizeof(INDEX_T) + sizeof(DISTANCE_T)) * result_buffer_size_32 + sizeof(INDEX_T) * hashmap::get_size(small_hash_bitlen) + sizeof(INDEX_T) * search_width + - sizeof(std::uint32_t) * topk_ws_size + sizeof(std::uint32_t); + sizeof(std::uint32_t) * topk_ws_size + sizeof(std::uint32_t) + + DATASET_DESCRIPTOR_T::smem_buffer_size_in_byte; smem_size = base_smem_size; if (num_itopk_candidates > 256) { // Tentatively calculate the required share memory size when radix @@ -205,7 +207,7 @@ struct search : search_plan_impl { } void operator()(raft::resources const& res, - raft::device_matrix_view dataset, + DATASET_DESCRIPTOR_T dataset_desc, raft::device_matrix_view graph, INDEX_T* const result_indices_ptr, // [num_queries, topk] DISTANCE_T* const result_distances_ptr, // [num_queries, topk] @@ -217,8 +219,8 @@ struct search : search_plan_impl { SAMPLE_FILTER_T sample_filter) { cudaStream_t stream = resource::get_cuda_stream(res); - select_and_run( - dataset, + select_and_run( + dataset_desc, graph, result_indices_ptr, result_distances_ptr, diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh index fef060ffee..a836334667 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh @@ -27,25 +27,23 @@ namespace single_cta_search { template void select_and_run( // raft::resources const& res, - raft::device_matrix_view dataset, - raft::device_matrix_view graph, - INDEX_T* const topk_indices_ptr, // [num_queries, topk] - DISTANCE_T* const topk_distances_ptr, // [num_queries, topk] - const DATA_T* const queries_ptr, // [num_queries, dataset_dim] + DATASET_DESCRIPTOR_T dataset_desc, + raft::device_matrix_view graph, + typename DATASET_DESCRIPTOR_T::INDEX_T* const topk_indices_ptr, // [num_queries, topk] + typename DATASET_DESCRIPTOR_T::DISTANCE_T* const topk_distances_ptr, // [num_queries, topk] + const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim] const uint32_t num_queries, - const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] - uint32_t* const num_executed_iterations, // [num_queries,] + const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] + uint32_t* const num_executed_iterations, // [num_queries,] uint32_t topk, uint32_t num_itopk_candidates, uint32_t block_size, uint32_t smem_size, int64_t hash_bitlen, - INDEX_T* hashmap_ptr, + typename DATASET_DESCRIPTOR_T::INDEX_T* hashmap_ptr, size_t small_hash_bitlen, size_t small_hash_reset_interval, uint32_t num_random_samplings, @@ -60,34 +58,38 @@ void select_and_run( // raft::resources const& res, #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - extern template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ +#define instantiate_single_cta_select_and_run( \ + TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + extern template void select_and_run< \ + TEAM_SIZE, \ + MAX_DATASET_DIM, \ + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, \ + SAMPLE_FILTER_T>( \ + raft::neighbors::cagra::detail::standard_dataset_descriptor_t \ + dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); instantiate_single_cta_select_and_run( @@ -125,5 +127,473 @@ instantiate_single_cta_select_and_run( #undef instantiate_single_cta_select_and_run +#define instantiate_q_single_cta_select_and_run(TEAM_SIZE, \ + MAX_DATASET_DIM, \ + CODE_BOOK_T, \ + PQ_BITS, \ + PQ_CODE_BOOK_DIM, \ + DATA_T, \ + INDEX_T, \ + DISTANCE_T, \ + SAMPLE_FILTER_T) \ + extern template void \ + select_and_run, \ + SAMPLE_FILTER_T>( \ + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t dataset, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ + cudaStream_t stream); + +instantiate_q_single_cta_select_and_run( + 8, 128, half, 8, 2, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run( + 16, 256, half, 8, 2, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run( + 32, 512, half, 8, 2, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(32, + 1024, + half, + 8, + 2, + half, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run( + 8, 128, half, 8, 4, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run( + 16, 256, half, 8, 4, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run( + 32, 512, half, 8, 4, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(32, + 1024, + half, + 8, + 4, + half, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); + +instantiate_q_single_cta_select_and_run( + 8, 128, half, 8, 2, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(16, + 256, + half, + 8, + 2, + float, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(32, + 512, + half, + 8, + 2, + float, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(32, + 1024, + half, + 8, + 2, + float, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run( + 8, 128, half, 8, 4, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(16, + 256, + half, + 8, + 4, + float, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(32, + 512, + half, + 8, + 4, + float, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(32, + 1024, + half, + 8, + 4, + float, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); + +instantiate_q_single_cta_select_and_run( + 8, 128, half, 8, 2, half, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run( + 16, 256, half, 8, 2, half, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run( + 32, 512, half, 8, 2, half, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run( + 32, 1024, half, 8, 2, half, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run( + 8, 128, half, 8, 4, half, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run( + 16, 256, half, 8, 4, half, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run( + 32, 512, half, 8, 4, half, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run( + 32, 1024, half, 8, 4, half, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); + +instantiate_q_single_cta_select_and_run( + 8, 128, half, 8, 2, float, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run( + 16, 256, half, 8, 2, float, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run( + 32, 512, half, 8, 2, float, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(32, + 1024, + half, + 8, + 2, + float, + int64_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run( + 8, 128, half, 8, 4, float, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run( + 16, 256, half, 8, 4, float, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run( + 32, 512, half, 8, 4, float, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(32, + 1024, + half, + 8, + 4, + float, + int64_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); + +instantiate_q_single_cta_select_and_run(8, + 128, + half, + 8, + 2, + uint8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(16, + 256, + half, + 8, + 2, + uint8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(32, + 512, + half, + 8, + 2, + uint8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(32, + 1024, + half, + 8, + 2, + uint8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(8, + 128, + half, + 8, + 4, + uint8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(16, + 256, + half, + 8, + 4, + uint8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(32, + 512, + half, + 8, + 4, + uint8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(32, + 1024, + half, + 8, + 4, + uint8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); + +instantiate_q_single_cta_select_and_run(8, + 128, + half, + 8, + 2, + int8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(16, + 256, + half, + 8, + 2, + int8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(32, + 512, + half, + 8, + 2, + int8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(32, + 1024, + half, + 8, + 2, + int8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(8, + 128, + half, + 8, + 4, + int8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(16, + 256, + half, + 8, + 4, + int8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(32, + 512, + half, + 8, + 4, + int8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(32, + 1024, + half, + 8, + 4, + int8_t, + uint32_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); + +instantiate_q_single_cta_select_and_run(8, + 128, + half, + 8, + 2, + uint8_t, + int64_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(16, + 256, + half, + 8, + 2, + uint8_t, + int64_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(32, + 512, + half, + 8, + 2, + uint8_t, + int64_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(32, + 1024, + half, + 8, + 2, + uint8_t, + int64_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(8, + 128, + half, + 8, + 4, + uint8_t, + int64_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(16, + 256, + half, + 8, + 4, + uint8_t, + int64_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(32, + 512, + half, + 8, + 4, + uint8_t, + int64_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(32, + 1024, + half, + 8, + 4, + uint8_t, + int64_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); + +instantiate_q_single_cta_select_and_run( + 8, 128, half, 8, 2, int8_t, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(16, + 256, + half, + 8, + 2, + int8_t, + int64_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(32, + 512, + half, + 8, + 2, + int8_t, + int64_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(32, + 1024, + half, + 8, + 2, + int8_t, + int64_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run( + 8, 128, half, 8, 4, int8_t, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(16, + 256, + half, + 8, + 4, + int8_t, + int64_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(32, + 512, + half, + 8, + 4, + int8_t, + int64_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_q_single_cta_select_and_run(32, + 1024, + half, + 8, + 4, + int8_t, + int64_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); + +#undef instantiate_q_single_cta_select_and_run + } // namespace single_cta_search } // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh index 652115928b..a697f9512c 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh @@ -456,42 +456,44 @@ __device__ inline void set_value_device(T* const ptr, const T fill, const std::u } // One query one thread block -template -__launch_bounds__(1024, 1) RAFT_KERNEL - search_kernel(INDEX_T* const result_indices_ptr, // [num_queries, top_k] - DISTANCE_T* const result_distances_ptr, // [num_queries, top_k] - const std::uint32_t top_k, - const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim] - const std::size_t dataset_dim, - const std::size_t dataset_size, - const std::size_t dataset_ld, // stride of dataset - const DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const INDEX_T* const knn_graph, // [dataset_size, graph_degree] - const std::uint32_t graph_degree, - const unsigned num_distilation, - const uint64_t rand_xor_mask, - const INDEX_T* seed_ptr, // [num_queries, num_seeds] - const uint32_t num_seeds, - INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] - const std::uint32_t internal_topk, - const std::uint32_t search_width, - const std::uint32_t min_iteration, - const std::uint32_t max_iteration, - std::uint32_t* const num_executed_iterations, // [num_queries] - const std::uint32_t hash_bitlen, - const std::uint32_t small_hash_bitlen, - const std::uint32_t small_hash_reset_interval, - SAMPLE_FILTER_T sample_filter) +__launch_bounds__(1024, 1) RAFT_KERNEL search_kernel( + typename DATASET_DESCRIPTOR_T::INDEX_T* const result_indices_ptr, // [num_queries, top_k] + typename DATASET_DESCRIPTOR_T::DISTANCE_T* const result_distances_ptr, // [num_queries, top_k] + const std::uint32_t top_k, + DATASET_DESCRIPTOR_T dataset_desc, + const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim] + const typename DATASET_DESCRIPTOR_T::INDEX_T* const knn_graph, // [dataset_size, graph_degree] + const std::uint32_t graph_degree, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const typename DATASET_DESCRIPTOR_T::INDEX_T* seed_ptr, // [num_queries, num_seeds] + const uint32_t num_seeds, + typename DATASET_DESCRIPTOR_T::INDEX_T* const + visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + const std::uint32_t internal_topk, + const std::uint32_t search_width, + const std::uint32_t min_iteration, + const std::uint32_t max_iteration, + std::uint32_t* const num_executed_iterations, // [num_queries] + const std::uint32_t hash_bitlen, + const std::uint32_t small_hash_bitlen, + const std::uint32_t small_hash_reset_interval, + SAMPLE_FILTER_T sample_filter) { - using LOAD_T = device::LOAD_128BIT_T; + using LOAD_T = device::LOAD_128BIT_T; + + using DATA_T = typename DATASET_DESCRIPTOR_T::DATA_T; + using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T; + using DISTANCE_T = typename DATASET_DESCRIPTOR_T::DISTANCE_T; + using QUERY_T = typename DATASET_DESCRIPTOR_T::QUERY_T; + const auto query_id = blockIdx.y; #ifdef _CLK_BREAKDOWN @@ -525,30 +527,31 @@ __launch_bounds__(1024, 1) RAFT_KERNEL const auto small_hash_size = hashmap::get_size(small_hash_bitlen); const auto query_smem_buffer_length = - raft::ceildiv(dataset_dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; - auto query_buffer = reinterpret_cast(smem); + raft::ceildiv(dataset_desc.dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; + auto query_buffer = reinterpret_cast(smem); auto result_indices_buffer = reinterpret_cast(query_buffer + query_smem_buffer_length); auto result_distances_buffer = reinterpret_cast(result_indices_buffer + result_buffer_size_32); auto visited_hash_buffer = reinterpret_cast(result_distances_buffer + result_buffer_size_32); auto parent_list_buffer = reinterpret_cast(visited_hash_buffer + small_hash_size); - auto topk_ws = reinterpret_cast(parent_list_buffer + search_width); - auto terminate_flag = reinterpret_cast(topk_ws + 3); - auto smem_working_ptr = reinterpret_cast(terminate_flag + 1); + auto distance_work_buffer_ptr = + reinterpret_cast(parent_list_buffer + search_width); + auto topk_ws = reinterpret_cast(distance_work_buffer_ptr + + DATASET_DESCRIPTOR_T::smem_buffer_size_in_byte); + auto terminate_flag = reinterpret_cast(topk_ws + 3); + auto smem_work_ptr = reinterpret_cast(terminate_flag + 1); + + // Set smem working buffer for the distance calculation + dataset_desc.set_smem_ptr(distance_work_buffer_ptr); // A flag for filtering. auto filter_flag = terminate_flag; - const DATA_T* const query_ptr = queries_ptr + query_id * dataset_dim; - for (unsigned i = threadIdx.x; i < query_smem_buffer_length; i += blockDim.x) { - unsigned j = device::swizzling(i); - if (i < dataset_dim) { - query_buffer[j] = spatial::knn::detail::utils::mapping{}(query_ptr[i]); - } else { - query_buffer[j] = 0.0; - } - } + const DATA_T* const query_ptr = queries_ptr + query_id * dataset_desc.dim; + dataset_desc.template copy_query( + query_ptr, query_buffer, query_smem_buffer_length); + if (threadIdx.x == 0) { terminate_flag[0] = 0; topk_ws[0] = ~0u; @@ -568,21 +571,17 @@ __launch_bounds__(1024, 1) RAFT_KERNEL // compute distance to randomly selecting nodes _CLK_START(); const INDEX_T* const local_seed_ptr = seed_ptr ? seed_ptr + (num_seeds * query_id) : nullptr; - device::compute_distance_to_random_nodes( - result_indices_buffer, - result_distances_buffer, - query_buffer, - dataset_ptr, - dataset_dim, - dataset_size, - dataset_ld, - result_buffer_size, - num_distilation, - rand_xor_mask, - local_seed_ptr, - num_seeds, - local_visited_hashmap_ptr, - hash_bitlen); + device::compute_distance_to_random_nodes(result_indices_buffer, + result_distances_buffer, + query_buffer, + dataset_desc, + result_buffer_size, + num_distilation, + rand_xor_mask, + local_seed_ptr, + num_seeds, + local_visited_hashmap_ptr, + hash_bitlen); __syncthreads(); _CLK_REC(clk_compute_1st_distance); @@ -667,7 +666,7 @@ __launch_bounds__(1024, 1) RAFT_KERNEL nullptr, topk_ws, true, - reinterpret_cast(smem_working_ptr)); + reinterpret_cast(smem_work_ptr)); _CLK_REC(clk_topk); // reset small-hash table @@ -688,7 +687,7 @@ __launch_bounds__(1024, 1) RAFT_KERNEL parent_list_buffer, result_indices_buffer, internal_topk, - dataset_size, + dataset_desc.size, search_width); _CLK_REC(clk_pickup_parents); } @@ -708,13 +707,11 @@ __launch_bounds__(1024, 1) RAFT_KERNEL // compute the norms between child nodes and query node _CLK_START(); constexpr unsigned max_n_frags = 8; - device::compute_distance_to_child_nodes( + device::compute_distance_to_child_nodes( result_indices_buffer + internal_topk, result_distances_buffer + internal_topk, query_buffer, - dataset_ptr, - dataset_dim, - dataset_ld, + dataset_desc, knn_graph, graph_degree, local_visited_hashmap_ptr, @@ -814,50 +811,53 @@ __launch_bounds__(1024, 1) RAFT_KERNEL #endif } -template struct search_kernel_config { - using kernel_t = - decltype(&search_kernel); + using kernel_t = decltype(&search_kernel); template static auto choose_search_kernel(unsigned itopk_size) -> kernel_t { if (itopk_size <= 64) { - return search_kernel; + return search_kernel; } else if (itopk_size <= 128) { return search_kernel; } else if (itopk_size <= 256) { return search_kernel; } else if (itopk_size <= 512) { return search_kernel; } THROW("No kernel for parametels itopk_size %u, max_candidates %u", itopk_size, MAX_CANDIDATES); @@ -878,9 +878,21 @@ struct search_kernel_config { // Radix-based topk is used constexpr unsigned max_candidates = 32; // to avoid build failure if (itopk_size <= 256) { - return search_kernel; + return search_kernel; } else if (itopk_size <= 512) { - return search_kernel; + return search_kernel; } } THROW("No kernel for parametels itopk_size %u, num_itopk_candidates %u", @@ -891,25 +903,23 @@ struct search_kernel_config { template -void select_and_run( // raft::resources const& res, - raft::device_matrix_view dataset, - raft::device_matrix_view graph, - INDEX_T* const topk_indices_ptr, // [num_queries, topk] - DISTANCE_T* const topk_distances_ptr, // [num_queries, topk] - const DATA_T* const queries_ptr, // [num_queries, dataset_dim] +void select_and_run( + DATASET_DESCRIPTOR_T dataset_desc, + raft::device_matrix_view graph, + typename DATASET_DESCRIPTOR_T::INDEX_T* const topk_indices_ptr, // [num_queries, topk] + typename DATASET_DESCRIPTOR_T::DISTANCE_T* const topk_distances_ptr, // [num_queries, topk] + const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim] const uint32_t num_queries, - const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] - uint32_t* const num_executed_iterations, // [num_queries,] + const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] + uint32_t* const num_executed_iterations, // [num_queries,] uint32_t topk, uint32_t num_itopk_candidates, uint32_t block_size, // uint32_t smem_size, int64_t hash_bitlen, - INDEX_T* hashmap_ptr, + typename DATASET_DESCRIPTOR_T::INDEX_T* hashmap_ptr, size_t small_hash_bitlen, size_t small_hash_reset_interval, uint32_t num_random_samplings, @@ -923,16 +933,11 @@ void select_and_run( // raft::resources const& res, cudaStream_t stream) { auto kernel = - search_kernel_config::choose_itopk_and_mx_candidates(itopk_size, - num_itopk_candidates, - block_size); - RAFT_CUDA_TRY( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + search_kernel_config:: + choose_itopk_and_mx_candidates(itopk_size, num_itopk_candidates, block_size); + RAFT_CUDA_TRY(cudaFuncSetAttribute(kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size + DATASET_DESCRIPTOR_T::smem_buffer_size_in_byte)); dim3 thread_dims(block_size, 1, 1); dim3 block_dims(1, num_queries, 1); RAFT_LOG_DEBUG( @@ -940,10 +945,7 @@ void select_and_run( // raft::resources const& res, kernel<<>>(topk_indices_ptr, topk_distances_ptr, topk, - dataset.data_handle(), - dataset.extent(1), - dataset.extent(0), - dataset.stride(0), + dataset_desc, queries_ptr, graph.data_handle(), graph.extent(1), diff --git a/cpp/include/raft/neighbors/detail/cagra/utils.hpp b/cpp/include/raft/neighbors/detail/cagra/utils.hpp index 7e403abe91..265cbfdceb 100644 --- a/cpp/include/raft/neighbors/detail/cagra/utils.hpp +++ b/cpp/include/raft/neighbors/detail/cagra/utils.hpp @@ -111,6 +111,11 @@ _RAFT_HOST_DEVICE constexpr unsigned size_of() { return 2; } +template <> +_RAFT_HOST_DEVICE constexpr unsigned size_of() +{ + return 4; +} // max values for data types template diff --git a/cpp/include/raft/neighbors/detail/refine_host-ext.hpp b/cpp/include/raft/neighbors/detail/refine_host-ext.hpp index 69d2bd29b2..f5c8c73bb9 100644 --- a/cpp/include/raft/neighbors/detail/refine_host-ext.hpp +++ b/cpp/include/raft/neighbors/detail/refine_host-ext.hpp @@ -54,6 +54,7 @@ template distance::DistanceType metric); instantiate_raft_neighbors_refine(int64_t, float, float, int64_t); +instantiate_raft_neighbors_refine(uint32_t, float, float, int64_t); instantiate_raft_neighbors_refine(int64_t, int8_t, float, int64_t); instantiate_raft_neighbors_refine(int64_t, uint8_t, float, int64_t); diff --git a/cpp/include/raft/neighbors/detail/vpq_dataset.cuh b/cpp/include/raft/neighbors/detail/vpq_dataset.cuh index f6cd2a1ceb..f1321ba343 100644 --- a/cpp/include/raft/neighbors/detail/vpq_dataset.cuh +++ b/cpp/include/raft/neighbors/detail/vpq_dataset.cuh @@ -81,7 +81,7 @@ auto fill_missing_params_heuristics(const vpq_params& params, const DatasetT& da vpq_params r = params; double n_rows = dataset.extent(0); size_t dim = dataset.extent(1); - if (r.pq_dim == 0) { r.pq_dim = raft::div_rounding_up_safe(dim, size_t{4}); } + if (r.pq_dim == 0) { r.pq_dim = raft::div_rounding_up_safe(dim, size_t{2}); } if (r.pq_bits == 0) { r.pq_bits = 8; } if (r.vq_n_centers == 0) { r.vq_n_centers = raft::round_up_safe(std::sqrt(n_rows), 8); } if (r.vq_kmeans_trainset_fraction == 0) { diff --git a/cpp/include/raft/neighbors/refine-ext.cuh b/cpp/include/raft/neighbors/refine-ext.cuh index fc57494b22..7948a0e4f2 100644 --- a/cpp/include/raft/neighbors/refine-ext.cuh +++ b/cpp/include/raft/neighbors/refine-ext.cuh @@ -52,7 +52,7 @@ void refine(raft::resources const& handle, #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY -#define instantiate_raft_neighbors_refine(idx_t, data_t, distance_t, matrix_idx) \ +#define instantiate_raft_neighbors_refine_d(idx_t, data_t, distance_t, matrix_idx) \ extern template void raft::neighbors::refine( \ raft::resources const& handle, \ raft::device_matrix_view dataset, \ @@ -60,8 +60,9 @@ void refine(raft::resources const& handle, raft::device_matrix_view neighbor_candidates, \ raft::device_matrix_view indices, \ raft::device_matrix_view distances, \ - raft::distance::DistanceType metric); \ - \ + raft::distance::DistanceType metric); + +#define instantiate_raft_neighbors_refine_h(idx_t, data_t, distance_t, matrix_idx) \ extern template void raft::neighbors::refine( \ raft::resources const& handle, \ raft::host_matrix_view dataset, \ @@ -71,8 +72,14 @@ void refine(raft::resources const& handle, raft::host_matrix_view distances, \ raft::distance::DistanceType metric); -instantiate_raft_neighbors_refine(int64_t, float, float, int64_t); -instantiate_raft_neighbors_refine(int64_t, int8_t, float, int64_t); -instantiate_raft_neighbors_refine(int64_t, uint8_t, float, int64_t); +instantiate_raft_neighbors_refine_d(int64_t, float, float, int64_t); +instantiate_raft_neighbors_refine_d(int64_t, int8_t, float, int64_t); +instantiate_raft_neighbors_refine_d(int64_t, uint8_t, float, int64_t); + +instantiate_raft_neighbors_refine_h(int64_t, float, float, int64_t); +instantiate_raft_neighbors_refine_h(uint32_t, float, float, int64_t); +instantiate_raft_neighbors_refine_h(int64_t, int8_t, float, int64_t); +instantiate_raft_neighbors_refine_h(int64_t, uint8_t, float, int64_t); -#undef instantiate_raft_neighbors_refine +#undef instantiate_raft_neighbors_refine_d +#undef instantiate_raft_neighbors_refine_h diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_00_generate.py b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_00_generate.py new file mode 100644 index 0000000000..e827c06be5 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_00_generate.py @@ -0,0 +1,84 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +header = """/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +""" + +trailer = """ +} // namespace raft::neighbors::cagra::detail::multi_cta_search +""" + +mxdim_team = [(128, 8), (256, 16), (512, 32), (1024, 32)] +pq_bits = [8] +subspace_dims = [2, 4] +# block = [(64, 16), (128, 8), (256, 4), (512, 2), (1024, 1)] +# mxelem = [64, 128, 256] +load_types = ["uint4"] +code_book_types = ["half"] +search_types = dict( + float_uint32=( + "float", + "uint32_t", + "float", + ), # data_t, vec_idx_t, distance_t + half_uint32=("half", "uint32_t", "float"), + int8_uint32=("int8_t", "uint32_t", "float"), + uint8_uint32=("uint8_t", "uint32_t", "float"), + float_uint64=("float", "uint64_t", "float"), + half_uint64=("half", "uint64_t", "float"), +) +# knn +for type_path, (data_t, idx_t, distance_t) in search_types.items(): + for (mxdim, team) in mxdim_team: + for code_book_t in code_book_types: + for subspace_dim in subspace_dims: + for pq_bit in pq_bits: + path = f"q_search_multi_cta_{type_path}_dim{mxdim}_t{team}_{pq_bit}pq_{subspace_dim}subd_{code_book_t}.cu" + with open(path, "w") as f: + f.write(header) + f.write( + f"instantiate_kernel_selection(\n {team}, {mxdim}, raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t<{data_t} COMMA {code_book_t} COMMA {pq_bit} COMMA {subspace_dim} COMMA {distance_t} COMMA {idx_t}>, raft::neighbors::filtering::none_cagra_sample_filter);\n" + ) + f.write(trailer) + # For pasting into CMakeLists.txt + print(f"src/neighbors/detail/cagra/{path}") diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim1024_t32_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim1024_t32_8pq_2subd_half.cu new file mode 100644 index 0000000000..0bd386144c --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim1024_t32_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(32, + 1024, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim1024_t32_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim1024_t32_8pq_4subd_half.cu new file mode 100644 index 0000000000..cd891b8e97 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim1024_t32_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(32, + 1024, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim128_t8_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim128_t8_8pq_2subd_half.cu new file mode 100644 index 0000000000..66e8357498 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim128_t8_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(8, + 128, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim128_t8_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim128_t8_8pq_4subd_half.cu new file mode 100644 index 0000000000..eb84983f9e --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim128_t8_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(8, + 128, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim256_t16_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim256_t16_8pq_2subd_half.cu new file mode 100644 index 0000000000..c66f8a0ae3 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim256_t16_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(16, + 256, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim256_t16_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim256_t16_8pq_4subd_half.cu new file mode 100644 index 0000000000..2a1783944c --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim256_t16_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(16, + 256, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim512_t32_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim512_t32_8pq_2subd_half.cu new file mode 100644 index 0000000000..9fa74f1134 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim512_t32_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(32, + 512, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim512_t32_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim512_t32_8pq_4subd_half.cu new file mode 100644 index 0000000000..8fc91b5a10 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim512_t32_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(32, + 512, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim1024_t32_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim1024_t32_8pq_2subd_half.cu new file mode 100644 index 0000000000..4e68c00525 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim1024_t32_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(32, + 1024, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim1024_t32_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim1024_t32_8pq_4subd_half.cu new file mode 100644 index 0000000000..5fe526ae47 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim1024_t32_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(32, + 1024, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim128_t8_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim128_t8_8pq_2subd_half.cu new file mode 100644 index 0000000000..64c89a880a --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim128_t8_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(8, + 128, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim128_t8_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim128_t8_8pq_4subd_half.cu new file mode 100644 index 0000000000..c3e2427f57 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim128_t8_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(8, + 128, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim256_t16_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim256_t16_8pq_2subd_half.cu new file mode 100644 index 0000000000..0a8826df1c --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim256_t16_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(16, + 256, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim256_t16_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim256_t16_8pq_4subd_half.cu new file mode 100644 index 0000000000..8019bec3e3 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim256_t16_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(16, + 256, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim512_t32_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim512_t32_8pq_2subd_half.cu new file mode 100644 index 0000000000..1a2a364037 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim512_t32_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(32, + 512, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim512_t32_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim512_t32_8pq_4subd_half.cu new file mode 100644 index 0000000000..2f661538e6 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim512_t32_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(32, + 512, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim1024_t32_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim1024_t32_8pq_2subd_half.cu new file mode 100644 index 0000000000..aec486769f --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim1024_t32_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(32, + 1024, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim1024_t32_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim1024_t32_8pq_4subd_half.cu new file mode 100644 index 0000000000..03f27085d8 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim1024_t32_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(32, + 1024, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim128_t8_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim128_t8_8pq_2subd_half.cu new file mode 100644 index 0000000000..119d1f2921 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim128_t8_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(8, + 128, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim128_t8_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim128_t8_8pq_4subd_half.cu new file mode 100644 index 0000000000..666c676e87 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim128_t8_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(8, + 128, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim256_t16_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim256_t16_8pq_2subd_half.cu new file mode 100644 index 0000000000..e53b456a54 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim256_t16_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(16, + 256, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim256_t16_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim256_t16_8pq_4subd_half.cu new file mode 100644 index 0000000000..2aee739141 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim256_t16_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(16, + 256, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim512_t32_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim512_t32_8pq_2subd_half.cu new file mode 100644 index 0000000000..daa442b514 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim512_t32_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(32, + 512, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim512_t32_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim512_t32_8pq_4subd_half.cu new file mode 100644 index 0000000000..a19346d19b --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim512_t32_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(32, + 512, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim1024_t32_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim1024_t32_8pq_2subd_half.cu new file mode 100644 index 0000000000..1c1d5381c9 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim1024_t32_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(32, + 1024, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim1024_t32_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim1024_t32_8pq_4subd_half.cu new file mode 100644 index 0000000000..b7402a3c38 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim1024_t32_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(32, + 1024, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim128_t8_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim128_t8_8pq_2subd_half.cu new file mode 100644 index 0000000000..f493b83bee --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim128_t8_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(8, + 128, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim128_t8_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim128_t8_8pq_4subd_half.cu new file mode 100644 index 0000000000..8efcbe0650 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim128_t8_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(8, + 128, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim256_t16_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim256_t16_8pq_2subd_half.cu new file mode 100644 index 0000000000..cb770f44ba --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim256_t16_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(16, + 256, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim256_t16_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim256_t16_8pq_4subd_half.cu new file mode 100644 index 0000000000..0fd8ab809c --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim256_t16_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(16, + 256, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim512_t32_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim512_t32_8pq_2subd_half.cu new file mode 100644 index 0000000000..50cf198883 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim512_t32_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(32, + 512, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim512_t32_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim512_t32_8pq_4subd_half.cu new file mode 100644 index 0000000000..1548ed831e --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim512_t32_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(32, + 512, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim1024_t32_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim1024_t32_8pq_2subd_half.cu new file mode 100644 index 0000000000..c60ea7c87d --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim1024_t32_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(32, + 1024, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + int8_t COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim1024_t32_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim1024_t32_8pq_4subd_half.cu new file mode 100644 index 0000000000..4a68e1e43c --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim1024_t32_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(32, + 1024, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + int8_t COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim128_t8_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim128_t8_8pq_2subd_half.cu new file mode 100644 index 0000000000..df9fabd6a5 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim128_t8_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(8, + 128, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + int8_t COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim128_t8_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim128_t8_8pq_4subd_half.cu new file mode 100644 index 0000000000..77075b0a44 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim128_t8_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(8, + 128, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + int8_t COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim256_t16_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim256_t16_8pq_2subd_half.cu new file mode 100644 index 0000000000..374af8b56b --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim256_t16_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(16, + 256, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + int8_t COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim256_t16_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim256_t16_8pq_4subd_half.cu new file mode 100644 index 0000000000..ddb80458fd --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim256_t16_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(16, + 256, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + int8_t COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim512_t32_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim512_t32_8pq_2subd_half.cu new file mode 100644 index 0000000000..14e5c5d3dc --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim512_t32_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(32, + 512, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + int8_t COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim512_t32_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim512_t32_8pq_4subd_half.cu new file mode 100644 index 0000000000..3c1776760a --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim512_t32_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(32, + 512, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + int8_t COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim1024_t32_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim1024_t32_8pq_2subd_half.cu new file mode 100644 index 0000000000..e5a0a8882c --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim1024_t32_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(32, + 1024, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + uint8_t COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim1024_t32_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim1024_t32_8pq_4subd_half.cu new file mode 100644 index 0000000000..cee80390e8 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim1024_t32_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(32, + 1024, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + uint8_t COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim128_t8_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim128_t8_8pq_2subd_half.cu new file mode 100644 index 0000000000..88678bf4ff --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim128_t8_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(8, + 128, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + uint8_t COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim128_t8_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim128_t8_8pq_4subd_half.cu new file mode 100644 index 0000000000..baa7ee358a --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim128_t8_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(8, + 128, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + uint8_t COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim256_t16_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim256_t16_8pq_2subd_half.cu new file mode 100644 index 0000000000..5c44f052f2 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim256_t16_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(16, + 256, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + uint8_t COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim256_t16_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim256_t16_8pq_4subd_half.cu new file mode 100644 index 0000000000..127a065fb5 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim256_t16_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(16, + 256, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + uint8_t COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim512_t32_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim512_t32_8pq_2subd_half.cu new file mode 100644 index 0000000000..fcf6985f97 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim512_t32_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(32, + 512, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + uint8_t COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim512_t32_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim512_t32_8pq_4subd_half.cu new file mode 100644 index 0000000000..f361e771b5 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim512_t32_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_multi_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_multi_cta_00_generate.py + * + */ + +#include "search_multi_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { +instantiate_kernel_selection(32, + 512, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + uint8_t COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_00_generate.py b/cpp/src/neighbors/detail/cagra/q_search_single_cta_00_generate.py new file mode 100644 index 0000000000..418d528a82 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_00_generate.py @@ -0,0 +1,89 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +header = """/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +""" + +trailer = """ +} // namespace raft::neighbors::cagra::detail::single_cta_search +""" + +mxdim_team = [(128, 8), (256, 16), (512, 32), (1024, 32)] +# block = [(64, 16), (128, 8), (256, 4), (512, 2), (1024, 1)] +# itopk_candidates = [64, 128, 256] +# itopk_size = [64, 128, 256, 512] +# mxelem = [64, 128, 256] + +pq_bits = [8] +subspace_dims = [2, 4] + +# rblock = [(256, 4), (512, 2), (1024, 1)] +# rcandidates = [32] +# rsize = [256, 512] +code_book_types = ["half"] + +search_types = dict( + float_uint32=("float", "uint32_t", "float"), # data_t, idx_t, distance_t + half_uint32=("half", "uint32_t", "float"), + int8_uint32=("int8_t", "uint32_t", "float"), + uint8_uint32=("uint8_t", "uint32_t", "float"), + float_uint64=("float", "uint64_t", "float"), + half_uint64=("half", "uint64_t", "float"), +) + +# knn +for type_path, (data_t, idx_t, distance_t) in search_types.items(): + for (mxdim, team) in mxdim_team: + for code_book_t in code_book_types: + for subspace_dim in subspace_dims: + for pq_bit in pq_bits: + path = f"q_search_single_cta_{type_path}_dim{mxdim}_t{team}_{pq_bit}pq_{subspace_dim}subd_{code_book_t}.cu" + with open(path, "w") as f: + f.write(header) + f.write( + f"instantiate_kernel_selection(\n {team}, {mxdim}, raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t<{data_t} COMMA {code_book_t} COMMA {pq_bit} COMMA {subspace_dim} COMMA {distance_t} COMMA {idx_t}>, raft::neighbors::filtering::none_cagra_sample_filter);\n" + ) + + f.write(trailer) + # For pasting into CMakeLists.txt + print(f"src/neighbors/detail/cagra/{path}") diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim1024_t32_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim1024_t32_8pq_2subd_half.cu new file mode 100644 index 0000000000..d61ad0ce15 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim1024_t32_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(32, + 1024, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim1024_t32_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim1024_t32_8pq_4subd_half.cu new file mode 100644 index 0000000000..410d2377ec --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim1024_t32_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(32, + 1024, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim128_t8_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim128_t8_8pq_2subd_half.cu new file mode 100644 index 0000000000..60cd58bab9 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim128_t8_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(8, + 128, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim128_t8_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim128_t8_8pq_4subd_half.cu new file mode 100644 index 0000000000..dfe5e6f14e --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim128_t8_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(8, + 128, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim256_t16_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim256_t16_8pq_2subd_half.cu new file mode 100644 index 0000000000..9a5d862276 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim256_t16_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(16, + 256, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim256_t16_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim256_t16_8pq_4subd_half.cu new file mode 100644 index 0000000000..d92ab50a58 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim256_t16_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(16, + 256, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim512_t32_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim512_t32_8pq_2subd_half.cu new file mode 100644 index 0000000000..aac197d590 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim512_t32_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(32, + 512, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim512_t32_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim512_t32_8pq_4subd_half.cu new file mode 100644 index 0000000000..f38a10e6d0 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim512_t32_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(32, + 512, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim1024_t32_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim1024_t32_8pq_2subd_half.cu new file mode 100644 index 0000000000..5523e63038 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim1024_t32_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(32, + 1024, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim1024_t32_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim1024_t32_8pq_4subd_half.cu new file mode 100644 index 0000000000..b06ef3d4fd --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim1024_t32_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(32, + 1024, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim128_t8_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim128_t8_8pq_2subd_half.cu new file mode 100644 index 0000000000..1fddee0e06 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim128_t8_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(8, + 128, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim128_t8_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim128_t8_8pq_4subd_half.cu new file mode 100644 index 0000000000..2aee442186 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim128_t8_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(8, + 128, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim256_t16_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim256_t16_8pq_2subd_half.cu new file mode 100644 index 0000000000..7a15e85280 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim256_t16_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(16, + 256, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim256_t16_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim256_t16_8pq_4subd_half.cu new file mode 100644 index 0000000000..efba46c248 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim256_t16_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(16, + 256, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim512_t32_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim512_t32_8pq_2subd_half.cu new file mode 100644 index 0000000000..990582f18b --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim512_t32_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(32, + 512, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim512_t32_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim512_t32_8pq_4subd_half.cu new file mode 100644 index 0000000000..a55907c66f --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim512_t32_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(32, + 512, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + float COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim1024_t32_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim1024_t32_8pq_2subd_half.cu new file mode 100644 index 0000000000..55fd749720 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim1024_t32_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(32, + 1024, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim1024_t32_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim1024_t32_8pq_4subd_half.cu new file mode 100644 index 0000000000..4b4063652a --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim1024_t32_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(32, + 1024, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim128_t8_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim128_t8_8pq_2subd_half.cu new file mode 100644 index 0000000000..bae83dc0fa --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim128_t8_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(8, + 128, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim128_t8_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim128_t8_8pq_4subd_half.cu new file mode 100644 index 0000000000..99492db344 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim128_t8_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(8, + 128, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim256_t16_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim256_t16_8pq_2subd_half.cu new file mode 100644 index 0000000000..797142e317 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim256_t16_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(16, + 256, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim256_t16_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim256_t16_8pq_4subd_half.cu new file mode 100644 index 0000000000..9a36c35ae0 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim256_t16_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(16, + 256, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim512_t32_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim512_t32_8pq_2subd_half.cu new file mode 100644 index 0000000000..e0a01e84cc --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim512_t32_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(32, + 512, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim512_t32_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim512_t32_8pq_4subd_half.cu new file mode 100644 index 0000000000..14de1b8941 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim512_t32_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(32, + 512, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim1024_t32_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim1024_t32_8pq_2subd_half.cu new file mode 100644 index 0000000000..b1d50fb445 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim1024_t32_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(32, + 1024, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim1024_t32_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim1024_t32_8pq_4subd_half.cu new file mode 100644 index 0000000000..c189a91764 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim1024_t32_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(32, + 1024, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim128_t8_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim128_t8_8pq_2subd_half.cu new file mode 100644 index 0000000000..8693ee3716 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim128_t8_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(8, + 128, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim128_t8_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim128_t8_8pq_4subd_half.cu new file mode 100644 index 0000000000..216ffd1ec5 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim128_t8_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(8, + 128, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim256_t16_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim256_t16_8pq_2subd_half.cu new file mode 100644 index 0000000000..36985d218b --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim256_t16_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(16, + 256, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim256_t16_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim256_t16_8pq_4subd_half.cu new file mode 100644 index 0000000000..8d55fe2b09 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim256_t16_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(16, + 256, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim512_t32_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim512_t32_8pq_2subd_half.cu new file mode 100644 index 0000000000..2fdb1cbc20 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim512_t32_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(32, + 512, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim512_t32_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim512_t32_8pq_4subd_half.cu new file mode 100644 index 0000000000..6dc3dc2ca8 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim512_t32_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(32, + 512, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + half COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint64_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim1024_t32_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim1024_t32_8pq_2subd_half.cu new file mode 100644 index 0000000000..21f8633033 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim1024_t32_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(32, + 1024, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + int8_t COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim1024_t32_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim1024_t32_8pq_4subd_half.cu new file mode 100644 index 0000000000..1a3867e06f --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim1024_t32_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(32, + 1024, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + int8_t COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim128_t8_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim128_t8_8pq_2subd_half.cu new file mode 100644 index 0000000000..9cbb16188a --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim128_t8_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(8, + 128, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + int8_t COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim128_t8_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim128_t8_8pq_4subd_half.cu new file mode 100644 index 0000000000..305a1754bc --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim128_t8_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(8, + 128, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + int8_t COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim256_t16_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim256_t16_8pq_2subd_half.cu new file mode 100644 index 0000000000..900e1b69d9 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim256_t16_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(16, + 256, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + int8_t COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim256_t16_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim256_t16_8pq_4subd_half.cu new file mode 100644 index 0000000000..a0bb2259f0 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim256_t16_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(16, + 256, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + int8_t COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim512_t32_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim512_t32_8pq_2subd_half.cu new file mode 100644 index 0000000000..09d36a39a0 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim512_t32_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(32, + 512, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + int8_t COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim512_t32_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim512_t32_8pq_4subd_half.cu new file mode 100644 index 0000000000..dc9cbb2b56 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim512_t32_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(32, + 512, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + int8_t COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim1024_t32_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim1024_t32_8pq_2subd_half.cu new file mode 100644 index 0000000000..c5508a38e2 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim1024_t32_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(32, + 1024, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + uint8_t COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim1024_t32_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim1024_t32_8pq_4subd_half.cu new file mode 100644 index 0000000000..7024425155 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim1024_t32_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(32, + 1024, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + uint8_t COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim128_t8_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim128_t8_8pq_2subd_half.cu new file mode 100644 index 0000000000..68687bc9cf --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim128_t8_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(8, + 128, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + uint8_t COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim128_t8_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim128_t8_8pq_4subd_half.cu new file mode 100644 index 0000000000..60efc55a30 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim128_t8_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(8, + 128, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + uint8_t COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim256_t16_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim256_t16_8pq_2subd_half.cu new file mode 100644 index 0000000000..b2dfaac5fe --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim256_t16_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(16, + 256, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + uint8_t COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim256_t16_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim256_t16_8pq_4subd_half.cu new file mode 100644 index 0000000000..891e9ef7cc --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim256_t16_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(16, + 256, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + uint8_t COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim512_t32_8pq_2subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim512_t32_8pq_2subd_half.cu new file mode 100644 index 0000000000..91e617204c --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim512_t32_8pq_2subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(32, + 512, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + uint8_t COMMA half COMMA 8 COMMA 2 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim512_t32_8pq_4subd_half.cu b/cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim512_t32_8pq_4subd_half.cu new file mode 100644 index 0000000000..a01d497676 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim512_t32_8pq_4subd_half.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by q_search_single_cta_00_generate.py + * + * Make changes there and run in this directory: + * + * > python q_search_single_cta_00_generate.py + * + */ + +#include "search_single_cta.cuh" + +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection(32, + 512, + raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t< + uint8_t COMMA half COMMA 8 COMMA 4 COMMA float COMMA uint32_t>, + raft::neighbors::filtering::none_cagra_sample_filter); + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh new file mode 100644 index 0000000000..179bf8f20f --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace raft::neighbors::cagra::detail::multi_cta_search { + +#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATASET_DESC_T, SAMPLE_FILTER_T) \ + template void select_and_run( \ + DATASET_DESC_T dataset_desc, \ + raft::device_matrix_view graph, \ + typename DATASET_DESC_T::INDEX_T* const topk_indices_ptr, \ + typename DATASET_DESC_T::DISTANCE_T* const topk_distances_ptr, \ + const typename DATASET_DESC_T::DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const typename DATASET_DESC_T::INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + typename DATASET_DESC_T::INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ + cudaStream_t stream); + +#define COMMA , + +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py b/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py index 6f8766c86b..6f023c39f1 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -header = """ -/* +header = """/* * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -38,45 +37,14 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::multi_cta_search { +#include "search_multi_cta.cuh" -#define instantiate_kernel_selection( \\ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \\ - template void \\ - select_and_run( \\ - raft::device_matrix_view dataset, \\ - raft::device_matrix_view graph, \\ - INDEX_T* const topk_indices_ptr, \\ - DISTANCE_T* const topk_distances_ptr, \\ - const DATA_T* const queries_ptr, \\ - const uint32_t num_queries, \\ - const INDEX_T* dev_seed_ptr, \\ - uint32_t* const num_executed_iterations, \\ - uint32_t topk, \\ - uint32_t block_size, \\ - uint32_t result_buffer_size, \\ - uint32_t smem_size, \\ - int64_t hash_bitlen, \\ - INDEX_T* hashmap_ptr, \\ - uint32_t num_cta_per_query, \\ - uint32_t num_random_samplings, \\ - uint64_t rand_xor_mask, \\ - uint32_t num_seeds, \\ - size_t itopk_size, \\ - size_t search_width, \\ - size_t min_iterations, \\ - size_t max_iterations, \\ - SAMPLE_FILTER_T sample_filter, \\ - cudaStream_t stream); +#include +namespace raft::neighbors::cagra::detail::multi_cta_search { """ trailer = """ -#undef instantiate_kernel_selection - } // namespace raft::neighbors::cagra::detail::multi_cta_search """ @@ -103,7 +71,7 @@ with open(path, "w") as f: f.write(header) f.write( - f"instantiate_kernel_selection(\n {team}, {mxdim}, {data_t}, {idx_t}, {distance_t}, raft::neighbors::filtering::none_cagra_sample_filter);\n" + f"instantiate_kernel_selection(\n {team}, {mxdim}, raft::neighbors::cagra::detail::standard_dataset_descriptor_t<{data_t} COMMA {idx_t} COMMA {distance_t}>, raft::neighbors::filtering::none_cagra_sample_filter);\n" ) f.write(trailer) # For pasting into CMakeLists.txt diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu index 1a3b2284bd..0e28d7a876 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,43 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::multi_cta_search { +#include "search_multi_cta.cuh" -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include +namespace raft::neighbors::cagra::detail::multi_cta_search { instantiate_kernel_selection( - 32, 1024, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_kernel_selection + 32, + 1024, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu index 36e86d9ed6..5e5e80a5de 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,43 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::multi_cta_search { +#include "search_multi_cta.cuh" -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include +namespace raft::neighbors::cagra::detail::multi_cta_search { instantiate_kernel_selection( - 8, 128, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_kernel_selection + 8, + 128, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu index 6f1af2d93f..9039496968 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,43 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::multi_cta_search { +#include "search_multi_cta.cuh" -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include +namespace raft::neighbors::cagra::detail::multi_cta_search { instantiate_kernel_selection( - 16, 256, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_kernel_selection + 16, + 256, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu index 1279f8e415..fe1c7e77e5 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,43 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::multi_cta_search { +#include "search_multi_cta.cuh" -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include +namespace raft::neighbors::cagra::detail::multi_cta_search { instantiate_kernel_selection( - 32, 512, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_kernel_selection + 32, + 512, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim1024_t32.cu index 0dabff0df5..7ef36baf7d 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim1024_t32.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,43 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::multi_cta_search { +#include "search_multi_cta.cuh" -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include +namespace raft::neighbors::cagra::detail::multi_cta_search { instantiate_kernel_selection( - 32, 1024, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_kernel_selection + 32, + 1024, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim128_t8.cu index 72bb74cdb8..da51c16314 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim128_t8.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,43 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::multi_cta_search { +#include "search_multi_cta.cuh" -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include +namespace raft::neighbors::cagra::detail::multi_cta_search { instantiate_kernel_selection( - 8, 128, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_kernel_selection + 8, + 128, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim256_t16.cu index dceea10b5d..99a4f7feb7 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim256_t16.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,43 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::multi_cta_search { +#include "search_multi_cta.cuh" -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include +namespace raft::neighbors::cagra::detail::multi_cta_search { instantiate_kernel_selection( - 16, 256, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_kernel_selection + 16, + 256, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim512_t32.cu index acb8bd6a12..50cdc97dd7 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim512_t32.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,43 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::multi_cta_search { +#include "search_multi_cta.cuh" -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include +namespace raft::neighbors::cagra::detail::multi_cta_search { instantiate_kernel_selection( - 32, 512, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_kernel_selection + 32, + 512, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim1024_t32.cu index fa89bca45f..b2d9cdb600 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim1024_t32.cu @@ -1,4 +1,3 @@ - /* * Copyright (c) 2023-2024, NVIDIA CORPORATION. * @@ -24,43 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::multi_cta_search { +#include "search_multi_cta.cuh" -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include +namespace raft::neighbors::cagra::detail::multi_cta_search { instantiate_kernel_selection( - 32, 1024, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_kernel_selection + 32, + 1024, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim128_t8.cu index 645ca61ff5..d756b295b7 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim128_t8.cu @@ -1,4 +1,3 @@ - /* * Copyright (c) 2023-2024, NVIDIA CORPORATION. * @@ -24,43 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::multi_cta_search { +#include "search_multi_cta.cuh" -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include +namespace raft::neighbors::cagra::detail::multi_cta_search { instantiate_kernel_selection( - 8, 128, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_kernel_selection + 8, + 128, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim256_t16.cu index 41b6f9b420..b1e998762c 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim256_t16.cu @@ -1,4 +1,3 @@ - /* * Copyright (c) 2023-2024, NVIDIA CORPORATION. * @@ -24,43 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::multi_cta_search { +#include "search_multi_cta.cuh" -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include +namespace raft::neighbors::cagra::detail::multi_cta_search { instantiate_kernel_selection( - 16, 256, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_kernel_selection + 16, + 256, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim512_t32.cu index 38f0ac3b04..e712de6390 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim512_t32.cu @@ -1,4 +1,3 @@ - /* * Copyright (c) 2023-2024, NVIDIA CORPORATION. * @@ -24,43 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::multi_cta_search { +#include "search_multi_cta.cuh" -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include +namespace raft::neighbors::cagra::detail::multi_cta_search { instantiate_kernel_selection( - 32, 512, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_kernel_selection + 32, + 512, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint64_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint64_dim1024_t32.cu index c462a9d359..282de4a851 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint64_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint64_dim1024_t32.cu @@ -1,4 +1,3 @@ - /* * Copyright (c) 2023-2024, NVIDIA CORPORATION. * @@ -24,43 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::multi_cta_search { +#include "search_multi_cta.cuh" -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include +namespace raft::neighbors::cagra::detail::multi_cta_search { instantiate_kernel_selection( - 32, 1024, half, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_kernel_selection + 32, + 1024, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint64_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint64_dim128_t8.cu index f5b2874e20..71ef968575 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint64_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint64_dim128_t8.cu @@ -1,4 +1,3 @@ - /* * Copyright (c) 2023-2024, NVIDIA CORPORATION. * @@ -24,43 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::multi_cta_search { +#include "search_multi_cta.cuh" -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include +namespace raft::neighbors::cagra::detail::multi_cta_search { instantiate_kernel_selection( - 8, 128, half, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_kernel_selection + 8, + 128, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint64_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint64_dim256_t16.cu index 0b01428b86..7c88406d71 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint64_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint64_dim256_t16.cu @@ -1,4 +1,3 @@ - /* * Copyright (c) 2023-2024, NVIDIA CORPORATION. * @@ -24,43 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::multi_cta_search { +#include "search_multi_cta.cuh" -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include +namespace raft::neighbors::cagra::detail::multi_cta_search { instantiate_kernel_selection( - 16, 256, half, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_kernel_selection + 16, + 256, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint64_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint64_dim512_t32.cu index 70228a129d..360635dddb 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint64_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_half_uint64_dim512_t32.cu @@ -1,4 +1,3 @@ - /* * Copyright (c) 2023-2024, NVIDIA CORPORATION. * @@ -24,43 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::multi_cta_search { +#include "search_multi_cta.cuh" -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include +namespace raft::neighbors::cagra::detail::multi_cta_search { instantiate_kernel_selection( - 32, 512, half, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_kernel_selection + 32, + 512, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim1024_t32.cu index 0254f09ff0..3f129bd7cf 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim1024_t32.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,43 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::multi_cta_search { +#include "search_multi_cta.cuh" -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include +namespace raft::neighbors::cagra::detail::multi_cta_search { instantiate_kernel_selection( - 32, 1024, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_kernel_selection + 32, + 1024, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu index 2b67e7e968..053b73275e 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,43 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::multi_cta_search { +#include "search_multi_cta.cuh" -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include +namespace raft::neighbors::cagra::detail::multi_cta_search { instantiate_kernel_selection( - 8, 128, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_kernel_selection + 8, + 128, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu index 17d6722e58..a1bb20369a 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,43 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::multi_cta_search { +#include "search_multi_cta.cuh" -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include +namespace raft::neighbors::cagra::detail::multi_cta_search { instantiate_kernel_selection( - 16, 256, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_kernel_selection + 16, + 256, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu index 38f02812e2..dbbc8bdd21 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,43 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::multi_cta_search { +#include "search_multi_cta.cuh" -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include +namespace raft::neighbors::cagra::detail::multi_cta_search { instantiate_kernel_selection( - 32, 512, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_kernel_selection + 32, + 512, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim1024_t32.cu index fa111196c6..125499e319 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim1024_t32.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,43 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::multi_cta_search { +#include "search_multi_cta.cuh" -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include +namespace raft::neighbors::cagra::detail::multi_cta_search { instantiate_kernel_selection( - 32, 1024, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_kernel_selection + 32, + 1024, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim128_t8.cu index 1ef3c28aa3..f2117c4f80 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim128_t8.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,43 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::multi_cta_search { +#include "search_multi_cta.cuh" -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include +namespace raft::neighbors::cagra::detail::multi_cta_search { instantiate_kernel_selection( - 8, 128, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_kernel_selection + 8, + 128, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim256_t16.cu index d26cb44843..8e5ba0f98f 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim256_t16.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,43 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::multi_cta_search { +#include "search_multi_cta.cuh" -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include +namespace raft::neighbors::cagra::detail::multi_cta_search { instantiate_kernel_selection( - 16, 256, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_kernel_selection + 16, + 256, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim512_t32.cu index 4d4322f261..bea7d25392 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim512_t32.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,43 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::multi_cta_search { +#include "search_multi_cta.cuh" -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include +namespace raft::neighbors::cagra::detail::multi_cta_search { instantiate_kernel_selection( - 32, 512, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_kernel_selection + 32, + 512, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta.cuh new file mode 100644 index 0000000000..7fb705a2d2 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_single_cta.cuh @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace raft::neighbors::cagra::detail::single_cta_search { + +#define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATASET_DESC_T, SAMPLE_FILTER_T) \ + template void select_and_run( \ + DATASET_DESC_T dataset_desc, \ + raft::device_matrix_view graph, \ + typename DATASET_DESC_T::INDEX_T* const topk_indices_ptr, \ + typename DATASET_DESC_T::DISTANCE_T* const topk_distances_ptr, \ + const typename DATASET_DESC_T::DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const typename DATASET_DESC_T::INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + typename DATASET_DESC_T::INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ + cudaStream_t stream); + +#define COMMA , + +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py b/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py index 1515f43134..0e809e4dc3 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -header = """ -/* +header = """/* * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -38,46 +37,14 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::single_cta_search { +#include "search_single_cta.cuh" -#define instantiate_single_cta_select_and_run( \\ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \\ - template void \\ - select_and_run( \\ - raft::device_matrix_view dataset, \\ - raft::device_matrix_view graph, \\ - INDEX_T* const topk_indices_ptr, \\ - DISTANCE_T* const topk_distances_ptr, \\ - const DATA_T* const queries_ptr, \\ - const uint32_t num_queries, \\ - const INDEX_T* dev_seed_ptr, \\ - uint32_t* const num_executed_iterations, \\ - uint32_t topk, \\ - uint32_t num_itopk_candidates, \\ - uint32_t block_size, \\ - uint32_t smem_size, \\ - int64_t hash_bitlen, \\ - INDEX_T* hashmap_ptr, \\ - size_t small_hash_bitlen, \\ - size_t small_hash_reset_interval, \\ - uint32_t num_random_samplings, \\ - uint64_t rand_xor_mask, \\ - uint32_t num_seeds, \\ - size_t itopk_size, \\ - size_t search_width, \\ - size_t min_iterations, \\ - size_t max_iterations, \\ - SAMPLE_FILTER_T sample_filter, \\ - cudaStream_t stream); +#include +namespace raft::neighbors::cagra::detail::single_cta_search { """ trailer = """ -#undef instantiate_single_cta_search_kernel - } // namespace raft::neighbors::cagra::detail::single_cta_search """ @@ -107,7 +74,7 @@ with open(path, "w") as f: f.write(header) f.write( - f"instantiate_single_cta_select_and_run(\n {team}, {mxdim}, {data_t}, {idx_t}, {distance_t}, raft::neighbors::filtering::none_cagra_sample_filter);\n" + f"instantiate_kernel_selection(\n {team}, {mxdim}, raft::neighbors::cagra::detail::standard_dataset_descriptor_t<{data_t} COMMA {idx_t} COMMA {distance_t}>, raft::neighbors::filtering::none_cagra_sample_filter);\n" ) f.write(trailer) diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim1024_t32.cu index b8c23103ba..8a9fc408ee 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim1024_t32.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,44 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::single_cta_search { +#include "search_single_cta.cuh" -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include -instantiate_single_cta_select_and_run( - 32, 1024, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_single_cta_search_kernel +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection( + 32, + 1024, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim128_t8.cu index 8ab1897119..c6f7c90c69 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim128_t8.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,44 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::single_cta_search { +#include "search_single_cta.cuh" -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include -instantiate_single_cta_select_and_run( - 8, 128, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_single_cta_search_kernel +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection( + 8, + 128, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu index 9fd36b4cb9..2766286673 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,44 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::single_cta_search { +#include "search_single_cta.cuh" -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include -instantiate_single_cta_select_and_run( - 16, 256, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_single_cta_search_kernel +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection( + 16, + 256, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu index a9ee2c864b..98ee189766 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,44 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::single_cta_search { +#include "search_single_cta.cuh" -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include -instantiate_single_cta_select_and_run( - 32, 512, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_single_cta_search_kernel +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection( + 32, + 512, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim1024_t32.cu index dadc574b65..c3ea39a729 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim1024_t32.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,44 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::single_cta_search { +#include "search_single_cta.cuh" -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include -instantiate_single_cta_select_and_run( - 32, 1024, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_single_cta_search_kernel +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection( + 32, + 1024, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim128_t8.cu index 30e043f47e..a53457656c 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim128_t8.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,44 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::single_cta_search { +#include "search_single_cta.cuh" -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include -instantiate_single_cta_select_and_run( - 8, 128, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_single_cta_search_kernel +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection( + 8, + 128, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim256_t16.cu index 089e4c930f..52318efb85 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim256_t16.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,44 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::single_cta_search { +#include "search_single_cta.cuh" -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include -instantiate_single_cta_select_and_run( - 16, 256, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_single_cta_search_kernel +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection( + 16, + 256, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim512_t32.cu index 3e8ffb8bf8..6451fdc7f3 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim512_t32.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,44 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::single_cta_search { +#include "search_single_cta.cuh" -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include -instantiate_single_cta_select_and_run( - 32, 512, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_single_cta_search_kernel +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection( + 32, + 512, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint32_dim1024_t32.cu index 29e7bfa250..e927fd0878 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint32_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint32_dim1024_t32.cu @@ -1,4 +1,3 @@ - /* * Copyright (c) 2023-2024, NVIDIA CORPORATION. * @@ -24,44 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::single_cta_search { +#include "search_single_cta.cuh" -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include -instantiate_single_cta_select_and_run( - 32, 1024, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_single_cta_search_kernel +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection( + 32, + 1024, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint32_dim128_t8.cu index a004f900d0..3f3d22ee08 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint32_dim128_t8.cu @@ -1,4 +1,3 @@ - /* * Copyright (c) 2023-2024, NVIDIA CORPORATION. * @@ -24,44 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::single_cta_search { +#include "search_single_cta.cuh" -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include -instantiate_single_cta_select_and_run( - 8, 128, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_single_cta_search_kernel +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection( + 8, + 128, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint32_dim256_t16.cu index 549849b21d..a84e5b8bd7 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint32_dim256_t16.cu @@ -1,4 +1,3 @@ - /* * Copyright (c) 2023-2024, NVIDIA CORPORATION. * @@ -24,44 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::single_cta_search { +#include "search_single_cta.cuh" -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include -instantiate_single_cta_select_and_run( - 16, 256, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_single_cta_search_kernel +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection( + 16, + 256, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint32_dim512_t32.cu index 3825f572f7..af4248865b 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint32_dim512_t32.cu @@ -1,4 +1,3 @@ - /* * Copyright (c) 2023-2024, NVIDIA CORPORATION. * @@ -24,44 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::single_cta_search { +#include "search_single_cta.cuh" -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include -instantiate_single_cta_select_and_run( - 32, 512, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_single_cta_search_kernel +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection( + 32, + 512, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint64_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint64_dim1024_t32.cu index 31d83f443b..16bd0cb647 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint64_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint64_dim1024_t32.cu @@ -1,4 +1,3 @@ - /* * Copyright (c) 2023-2024, NVIDIA CORPORATION. * @@ -24,44 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::single_cta_search { +#include "search_single_cta.cuh" -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include -instantiate_single_cta_select_and_run( - 32, 1024, half, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_single_cta_search_kernel +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection( + 32, + 1024, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint64_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint64_dim128_t8.cu index 3493ab294c..afc59c8a59 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint64_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint64_dim128_t8.cu @@ -1,4 +1,3 @@ - /* * Copyright (c) 2023-2024, NVIDIA CORPORATION. * @@ -24,44 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::single_cta_search { +#include "search_single_cta.cuh" -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include -instantiate_single_cta_select_and_run( - 8, 128, half, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_single_cta_search_kernel +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection( + 8, + 128, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint64_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint64_dim256_t16.cu index 6e09709994..147d31cf85 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint64_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint64_dim256_t16.cu @@ -1,4 +1,3 @@ - /* * Copyright (c) 2023-2024, NVIDIA CORPORATION. * @@ -24,44 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::single_cta_search { +#include "search_single_cta.cuh" -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include -instantiate_single_cta_select_and_run( - 16, 256, half, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_single_cta_search_kernel +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection( + 16, + 256, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint64_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint64_dim512_t32.cu index 4bc0158f7e..5624a71c3c 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint64_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_half_uint64_dim512_t32.cu @@ -1,4 +1,3 @@ - /* * Copyright (c) 2023-2024, NVIDIA CORPORATION. * @@ -24,44 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::single_cta_search { +#include "search_single_cta.cuh" -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include -instantiate_single_cta_select_and_run( - 32, 512, half, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_single_cta_search_kernel +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection( + 32, + 512, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu index 279587738e..761fb705ba 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,44 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::single_cta_search { +#include "search_single_cta.cuh" -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include -instantiate_single_cta_select_and_run( - 32, 1024, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_single_cta_search_kernel +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection( + 32, + 1024, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu index ef127d3f7d..84b76cba53 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,44 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::single_cta_search { +#include "search_single_cta.cuh" -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include -instantiate_single_cta_select_and_run( - 8, 128, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_single_cta_search_kernel +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection( + 8, + 128, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu index 7fcfdcc28e..598fff9cdf 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,44 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::single_cta_search { +#include "search_single_cta.cuh" -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include -instantiate_single_cta_select_and_run( - 16, 256, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_single_cta_search_kernel +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection( + 16, + 256, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu index a6c606d99b..e7a1a9d9c6 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,44 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::single_cta_search { +#include "search_single_cta.cuh" -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include -instantiate_single_cta_select_and_run( - 32, 512, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_single_cta_search_kernel +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection( + 32, + 512, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu index 0b8be56614..d40b9285fc 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,44 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::single_cta_search { +#include "search_single_cta.cuh" -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include -instantiate_single_cta_select_and_run( - 32, 1024, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_single_cta_search_kernel +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection( + 32, + 1024, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu index 4c193b9408..073bb350da 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,44 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::single_cta_search { +#include "search_single_cta.cuh" -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include -instantiate_single_cta_select_and_run( - 8, 128, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_single_cta_search_kernel +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection( + 8, + 128, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu index bdf16d2f03..29b0224b4d 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,44 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::single_cta_search { +#include "search_single_cta.cuh" -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include -instantiate_single_cta_select_and_run( - 16, 256, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_single_cta_search_kernel +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection( + 16, + 256, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu index 93624df4aa..d9601de2ad 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu @@ -1,6 +1,5 @@ - /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,44 +23,15 @@ * */ -#include -#include - -namespace raft::neighbors::cagra::detail::single_cta_search { +#include "search_single_cta.cuh" -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - cudaStream_t stream); +#include -instantiate_single_cta_select_and_run( - 32, 512, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_single_cta_search_kernel +namespace raft::neighbors::cagra::detail::single_cta_search { +instantiate_kernel_selection( + 32, + 512, + raft::neighbors::cagra::detail::standard_dataset_descriptor_t, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/refine_host_float_float.cpp b/cpp/src/neighbors/detail/refine_host_float_float.cpp index c596200c0a..09dcae9c3a 100644 --- a/cpp/src/neighbors/detail/refine_host_float_float.cpp +++ b/cpp/src/neighbors/detail/refine_host_float_float.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,5 +25,6 @@ distance::DistanceType metric); instantiate_raft_neighbors_refine(int64_t, float, float, int64_t); +instantiate_raft_neighbors_refine(uint32_t, float, float, int64_t); #undef instantiate_raft_neighbors_refine diff --git a/cpp/src/neighbors/refine_float_float.cu b/cpp/src/neighbors/refine_float_float.cu index ea6892d2c5..75851eeedb 100644 --- a/cpp/src/neighbors/refine_float_float.cu +++ b/cpp/src/neighbors/refine_float_float.cu @@ -1,6 +1,6 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,7 +26,7 @@ #include -#define instantiate_raft_neighbors_refine(idx_t, data_t, distance_t, matrix_idx) \ +#define instantiate_raft_neighbors_refine_d(idx_t, data_t, distance_t, matrix_idx) \ template void raft::neighbors::refine( \ raft::resources const& handle, \ raft::device_matrix_view dataset, \ @@ -34,17 +34,21 @@ raft::device_matrix_view neighbor_candidates, \ raft::device_matrix_view indices, \ raft::device_matrix_view distances, \ - raft::distance::DistanceType metric); \ - \ - template void raft::neighbors::refine( \ - raft::resources const& handle, \ - raft::host_matrix_view dataset, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbor_candidates, \ - raft::host_matrix_view indices, \ - raft::host_matrix_view distances, \ raft::distance::DistanceType metric); -instantiate_raft_neighbors_refine(int64_t, float, float, int64_t); +#define instantiate_raft_neighbors_refine_h(idx_t, data_t, distance_t, matrix_idx) \ + template void raft::neighbors::refine( \ + raft::resources const& handle, \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbor_candidates, \ + raft::host_matrix_view indices, \ + raft::host_matrix_view distances, \ + raft::distance::DistanceType metric); + +instantiate_raft_neighbors_refine_d(int64_t, float, float, int64_t); +instantiate_raft_neighbors_refine_h(int64_t, float, float, int64_t); +instantiate_raft_neighbors_refine_h(uint32_t, float, float, int64_t); -#undef instantiate_raft_neighbors_refine +#undef instantiate_raft_neighbors_refine_d +#undef instantiate_raft_neighbors_refine_h diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index ecb871fccc..20ed3bacc7 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -372,6 +372,8 @@ if(BUILD_TESTS) test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu test/neighbors/ann_cagra/test_float_int64_t.cu test/neighbors/ann_cagra/test_half_int64_t.cu + test/neighbors/ann_cagra_vpq/test_float_int64_t.cu + test/neighbors/ann_cagra_vpq/test_float_uint32_t.cu src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim128_t8.cu src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim256_t16.cu src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim512_t32.cu diff --git a/cpp/test/neighbors/ann_cagra/search_kernel_uint64_t.cuh b/cpp/test/neighbors/ann_cagra/search_kernel_uint64_t.cuh index 175e4ef483..5cca6d561a 100644 --- a/cpp/test/neighbors/ann_cagra/search_kernel_uint64_t.cuh +++ b/cpp/test/neighbors/ann_cagra/search_kernel_uint64_t.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,87 +21,133 @@ namespace raft::neighbors::cagra::detail { namespace multi_cta_search { -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - extern template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ +#define instantiate_kernel_selection( \ + DATASET_DESCRIPTOR, TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + extern template void \ + select_and_run, \ + SAMPLE_FILTER_T>( \ + raft::neighbors::cagra::detail::DATASET_DESCRIPTOR dataset_desc, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_kernel_selection( - 32, 1024, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 8, 128, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 16, 256, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 32, 512, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection(standard_dataset_descriptor_t, + 32, + 1024, + float, + uint64_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection(standard_dataset_descriptor_t, + 8, + 128, + float, + uint64_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection(standard_dataset_descriptor_t, + 16, + 256, + float, + uint64_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_kernel_selection(standard_dataset_descriptor_t, + 32, + 512, + float, + uint64_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); #undef instantiate_kernel_selection } // namespace multi_cta_search namespace single_cta_search { -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - extern template void \ - select_and_run( \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ +#define instantiate_single_cta_select_and_run( \ + DATASET_DESCRIPTOR, TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ + extern template void \ + select_and_run, \ + SAMPLE_FILTER_T>( \ + raft::neighbors::cagra::detail::DATASET_DESCRIPTOR dataset_desc, \ + raft::device_matrix_view graph, \ + INDEX_T* const topk_indices_ptr, \ + DISTANCE_T* const topk_distances_ptr, \ + const DATA_T* const queries_ptr, \ + const uint32_t num_queries, \ + const INDEX_T* dev_seed_ptr, \ + uint32_t* const num_executed_iterations, \ + uint32_t topk, \ + uint32_t num_itopk_candidates, \ + uint32_t block_size, \ + uint32_t smem_size, \ + int64_t hash_bitlen, \ + INDEX_T* hashmap_ptr, \ + size_t small_hash_bitlen, \ + size_t small_hash_reset_interval, \ + uint32_t num_random_samplings, \ + uint64_t rand_xor_mask, \ + uint32_t num_seeds, \ + size_t itopk_size, \ + size_t search_width, \ + size_t min_iterations, \ + size_t max_iterations, \ + SAMPLE_FILTER_T sample_filter, \ cudaStream_t stream); -instantiate_single_cta_select_and_run( - 32, 1024, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 8, 128, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 16, 256, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 32, 512, float, uint64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run(standard_dataset_descriptor_t, + 32, + 1024, + float, + uint64_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run(standard_dataset_descriptor_t, + 8, + 128, + float, + uint64_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run(standard_dataset_descriptor_t, + 16, + 256, + float, + uint64_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); +instantiate_single_cta_select_and_run(standard_dataset_descriptor_t, + 32, + 512, + float, + uint64_t, + float, + raft::neighbors::filtering::none_cagra_sample_filter); } // namespace single_cta_search -} // namespace raft::neighbors::cagra::detail \ No newline at end of file +} // namespace raft::neighbors::cagra::detail diff --git a/cpp/test/neighbors/ann_cagra_vpq.cuh b/cpp/test/neighbors/ann_cagra_vpq.cuh new file mode 100755 index 0000000000..503b1a413a --- /dev/null +++ b/cpp/test/neighbors/ann_cagra_vpq.cuh @@ -0,0 +1,336 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "../test_utils.cuh" +#include "ann_utils.cuh" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace { +template +void GenerateDataset(T* const dataset_ptr, + T* const query_ptr, + const std::size_t dataset_size, + const std::size_t query_size, + const std::size_t dim, + const std::size_t num_centers, + cudaStream_t cuda_stream) +{ + auto center_list = raft::make_host_matrix(num_centers, dim); + auto host_dataset = raft::make_host_matrix(std::max(dataset_size, query_size), dim); + + std::normal_distribution dist(0, 1); + std::mt19937 mt(0); + for (std::size_t i = 0; i < center_list.size(); i++) { + center_list.data_handle()[i] = dist(mt); + } + + std::uniform_int_distribution i_dist(0, num_centers - 1); + for (std::size_t i = 0; i < dataset_size; i++) { + const auto center_index = i_dist(mt); + for (std::size_t j = 0; j < dim; j++) { + host_dataset.data_handle()[i * dim + j] = + center_list.data_handle()[center_index + j] + dist(mt) * 1e-1; + } + } + raft::copy(dataset_ptr, host_dataset.data_handle(), dataset_size * dim, cuda_stream); + + for (std::size_t i = 0; i < query_size; i++) { + const auto center_index = i_dist(mt); + for (std::size_t j = 0; j < dim; j++) { + host_dataset.data_handle()[i * dim + j] = + center_list.data_handle()[center_index + j] + dist(mt) * 1e-1; + } + } + raft::copy(query_ptr, host_dataset.data_handle(), query_size * dim, cuda_stream); +} +} // namespace + +namespace raft::neighbors::cagra { +struct AnnCagraVpqInputs { + int n_queries; + int n_rows; + int dim; + int k; + int pq_len; + int pq_bits; + graph_build_algo build_algo; + search_algo algo; + int max_queries; + int team_size; + int itopk_size; + int search_width; + raft::distance::DistanceType metric; + bool host_dataset; + bool include_serialized_dataset; + // std::optional + double min_recall; // = std::nullopt; +}; + +inline ::std::ostream& operator<<(::std::ostream& os, const AnnCagraVpqInputs& p) +{ + std::vector algo = {"single-cta", "multi_cta", "multi_kernel", "auto"}; + std::vector build_algo = {"IVF_PQ", "NN_DESCENT"}; + os << "{n_queries=" << p.n_queries << ", dataset shape=" << p.n_rows << "x" << p.dim + << ", k=" << p.k << ", pq_bits=" << p.pq_bits << ", pq_len=" << p.pq_len << ", " + << algo.at((int)p.algo) << ", max_queries=" << p.max_queries << ", itopk_size=" << p.itopk_size + << ", search_width=" << p.search_width << ", metric=" << static_cast(p.metric) + << (p.host_dataset ? ", host" : ", device") + << ", build_algo=" << build_algo.at((int)p.build_algo) << '}' << std::endl; + return os; +} + +template +class AnnCagraVpqTest : public ::testing::TestWithParam { + public: + AnnCagraVpqTest() + : stream_(resource::get_cuda_stream(handle_)), + ps(::testing::TestWithParam::GetParam()), + database(0, stream_), + search_queries(0, stream_) + { + } + + protected: + void testCagra() + { + size_t queries_size = ps.n_queries * ps.k; + std::vector indices_Cagra(queries_size); + std::vector indices_naive(queries_size); + std::vector distances_Cagra(queries_size); + std::vector distances_naive(queries_size); + + { + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + search_queries.data(), + database.data(), + ps.n_queries, + ps.n_rows, + ps.dim, + ps.k, + ps.metric); + update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); + update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + } + + const auto vpq_k = ps.k * 16; + { + rmm::device_uvector distances_dev(vpq_k * ps.n_queries, stream_); + rmm::device_uvector indices_dev(vpq_k * ps.n_queries, stream_); + + { + if ((ps.dim % ps.pq_len) != 0) { + // TODO: remove this requirement in the algorithm. + GTEST_SKIP() << "(TODO) At the moment dim, (" << ps.dim + << ") must be a multiple of pq_len (" << ps.pq_len << ")"; + } + cagra::index_params index_params; + index_params.compression = vpq_params{.pq_bits = static_cast(ps.pq_bits), + .pq_dim = static_cast(ps.dim / ps.pq_len)}; + index_params.metric = ps.metric; // Note: currently ony the cagra::index_params metric is + // not used for knn_graph building. + index_params.build_algo = ps.build_algo; + cagra::search_params search_params; + search_params.algo = ps.algo; + search_params.max_queries = ps.max_queries; + search_params.team_size = ps.team_size; + search_params.itopk_size = ps.itopk_size; + + auto database_view = + raft::make_device_matrix_view(database.data(), ps.n_rows, ps.dim); + + { + cagra::index index(handle_); + if (ps.host_dataset) { + auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); + raft::copy(database_host.data_handle(), database.data(), database.size(), stream_); + auto database_host_view = raft::make_host_matrix_view( + database_host.data_handle(), ps.n_rows, ps.dim); + index = cagra::build(handle_, index_params, database_host_view); + } else { + index = cagra::build(handle_, index_params, database_view); + }; + cagra::serialize(handle_, "cagra_index", index, ps.include_serialized_dataset); + } + + auto index = cagra::deserialize(handle_, "cagra_index"); + if (!ps.include_serialized_dataset) { index.update_dataset(handle_, database_view); } + + // CAGRA-Q sanity check: we've built the right index type + auto* vpq_dataset = + dynamic_cast*>(&index.data()); + EXPECT_NE(vpq_dataset, nullptr) + << "Expected VPQ dataset, because we're testing CAGRA-Q here."; + + auto search_queries_view = raft::make_device_matrix_view( + search_queries.data(), ps.n_queries, ps.dim); + auto indices_out_view = + raft::make_device_matrix_view(indices_dev.data(), ps.n_queries, vpq_k); + auto dists_out_view = raft::make_device_matrix_view( + distances_dev.data(), ps.n_queries, vpq_k); + + cagra::search( + handle_, search_params, index, search_queries_view, indices_out_view, dists_out_view); + + { + auto host_dataset = raft::make_host_matrix(ps.n_rows, ps.dim); + raft::copy( + host_dataset.data_handle(), (const DataT*)database.data(), ps.n_rows * ps.dim, stream_); + + auto host_queries = raft::make_host_matrix(ps.n_queries, ps.dim); + raft::copy(host_queries.data_handle(), + (const DataT*)search_queries_view.data_handle(), + ps.n_queries * ps.dim, + stream_); + + auto host_index_candidate = raft::make_host_matrix(ps.n_queries, vpq_k); + raft::copy(host_index_candidate.data_handle(), + indices_out_view.data_handle(), + ps.n_queries * vpq_k, + stream_); + + auto host_indices_Cagra_view = + raft::make_host_matrix_view(indices_Cagra.data(), ps.n_queries, ps.k); + + auto host_dists_Cagra_view = + raft::make_host_matrix_view(distances_Cagra.data(), ps.n_queries, ps.k); + + resource::sync_stream(handle_); + + raft::neighbors::refine(handle_, + raft::make_const_mdspan(host_dataset.view()), + raft::make_const_mdspan(host_queries.view()), + raft::make_const_mdspan(host_index_candidate.view()), + host_indices_Cagra_view, + host_dists_Cagra_view, + ps.metric); + + raft::copy(indices_dev.data(), + host_indices_Cagra_view.data_handle(), + ps.k * ps.n_queries, + stream_); + raft::copy(distances_dev.data(), + host_dists_Cagra_view.data_handle(), + ps.k * ps.n_queries, + stream_); + resource::sync_stream(handle_); + } + } + + double min_recall = ps.min_recall; + EXPECT_TRUE(eval_neighbours(indices_naive, + indices_Cagra, + distances_naive, + distances_Cagra, + ps.n_queries, + ps.k, + 0.003, + min_recall)); + EXPECT_TRUE(eval_distances(handle_, + database.data(), + search_queries.data(), + indices_dev.data(), + distances_dev.data(), + ps.n_rows, + ps.dim, + ps.n_queries, + ps.k, + ps.metric, + 1.0e-4)); + } + } + + void SetUp() override + { + database.resize(((size_t)ps.n_rows) * ps.dim, stream_); + search_queries.resize(ps.n_queries * ps.dim, stream_); + GenerateDataset(database.data(), + search_queries.data(), + ps.n_rows, + ps.n_queries, + ps.dim, + static_cast(std::sqrt(ps.n_rows)), + stream_); + resource::sync_stream(handle_); + } + + void TearDown() override + { + resource::sync_stream(handle_); + database.resize(0, stream_); + search_queries.resize(0, stream_); + } + + private: + raft::resources handle_; + rmm::cuda_stream_view stream_; + AnnCagraVpqInputs ps; + rmm::device_uvector database; + rmm::device_uvector search_queries; +}; + +const std::vector vpq_inputs = raft::util::itertools::product( + {100}, // n_queries + {1000, 10000}, // n_rows + {128, 132, 192, 256, 512, 768}, // dim + {8, 12}, // k + {2}, // pq_len + {8}, // pq_bits + {graph_build_algo::NN_DESCENT}, // build_algo + {search_algo::SINGLE_CTA, search_algo::MULTI_CTA}, // algo + {0}, // max_queries + {0}, // team_size + {512}, // itopk_size + {1}, // search_width + {raft::distance::DistanceType::L2Expanded}, // metric + {false}, // host_dataset + {true}, // include_serialized_dataset + {0.8} // min_recall +); + +} // namespace raft::neighbors::cagra diff --git a/cpp/test/neighbors/ann_cagra_vpq/test_float_int64_t.cu b/cpp/test/neighbors/ann_cagra_vpq/test_float_int64_t.cu new file mode 100644 index 0000000000..f60edb5ed6 --- /dev/null +++ b/cpp/test/neighbors/ann_cagra_vpq/test_float_int64_t.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "../ann_cagra_vpq.cuh" + +#include + +namespace raft::neighbors::cagra { + +typedef AnnCagraVpqTest AnnCagraVpqTestF_I64; +TEST_P(AnnCagraVpqTestF_I64, AnnCagraVpq) { this->testCagra(); } + +INSTANTIATE_TEST_CASE_P(AnnCagraVpqTest, AnnCagraVpqTestF_I64, ::testing::ValuesIn(vpq_inputs)); + +} // namespace raft::neighbors::cagra diff --git a/cpp/test/neighbors/ann_cagra_vpq/test_float_uint32_t.cu b/cpp/test/neighbors/ann_cagra_vpq/test_float_uint32_t.cu new file mode 100644 index 0000000000..19d3f32250 --- /dev/null +++ b/cpp/test/neighbors/ann_cagra_vpq/test_float_uint32_t.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../ann_cagra_vpq.cuh" + +#include + +namespace raft::neighbors::cagra { + +typedef AnnCagraVpqTest AnnCagraVpqTestF_U32; +TEST_P(AnnCagraVpqTestF_U32, AnnCagraVpq) { this->testCagra(); } + +INSTANTIATE_TEST_CASE_P(AnnCagraVpqTest, AnnCagraVpqTestF_U32, ::testing::ValuesIn(vpq_inputs)); + +} // namespace raft::neighbors::cagra diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh index 6be2ac7fc7..3e0bead665 100644 --- a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -300,7 +300,7 @@ auto eval_distances(raft::resources const& handle, raft::matrix::copy_rows( handle, - make_device_matrix_view(x, k, n_cols), + make_device_matrix_view(x, n_rows, n_cols), y.view(), make_device_vector_view(neighbors + i * k, k)); From b7734949c8b6ae00a6ccc726289fa2641b1d30c3 Mon Sep 17 00:00:00 2001 From: "Artem M. Chirkin" <9253178+achirkin@users.noreply.github.com> Date: Thu, 21 Mar 2024 07:53:50 +0100 Subject: [PATCH 13/18] Add CAGRA-Q to ANN benchmarks (#2233) Add the relevant options to the CAGRA parameter parser and refinement to the CAGRA ANN benchmark. No changes to the library code. NB: the new option won't work correctly until https://github.com/rapidsai/raft/pull/2206 is merged. Authors: - Artem M. Chirkin (https://github.com/achirkin) Approvers: - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/raft/pull/2233 --- .../src/raft/raft_ann_bench_param_parser.h | 23 +++++ cpp/bench/ann/src/raft/raft_cagra_wrapper.h | 83 ++++++++++++++++++- 2 files changed, 103 insertions(+), 3 deletions(-) diff --git a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h index 2339677340..48bf1d70d8 100644 --- a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h +++ b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h @@ -43,6 +43,7 @@ extern template class raft::bench::ann::RaftIvfPQ; #endif #ifdef RAFT_ANN_BENCH_USE_RAFT_CAGRA extern template class raft::bench::ann::RaftCagra; +extern template class raft::bench::ann::RaftCagra; extern template class raft::bench::ann::RaftCagra; extern template class raft::bench::ann::RaftCagra; #endif @@ -149,6 +150,20 @@ void parse_build_param(const nlohmann::json& conf, } } +inline void parse_build_param(const nlohmann::json& conf, raft::neighbors::vpq_params& param) +{ + if (conf.contains("pq_bits")) { param.pq_bits = conf.at("pq_bits"); } + if (conf.contains("pq_dim")) { param.pq_dim = conf.at("pq_dim"); } + if (conf.contains("vq_n_centers")) { param.vq_n_centers = conf.at("vq_n_centers"); } + if (conf.contains("kmeans_n_iters")) { param.kmeans_n_iters = conf.at("kmeans_n_iters"); } + if (conf.contains("vq_kmeans_trainset_fraction")) { + param.vq_kmeans_trainset_fraction = conf.at("vq_kmeans_trainset_fraction"); + } + if (conf.contains("pq_kmeans_trainset_fraction")) { + param.pq_kmeans_trainset_fraction = conf.at("pq_kmeans_trainset_fraction"); + } +} + nlohmann::json collect_conf_with_prefix(const nlohmann::json& conf, const std::string& prefix, bool remove_prefix = true) @@ -204,6 +219,12 @@ void parse_build_param(const nlohmann::json& conf, } param.nn_descent_params = nn_param; } + nlohmann::json comp_search_conf = collect_conf_with_prefix(conf, "compression_"); + if (!comp_search_conf.empty()) { + raft::neighbors::vpq_params vpq_pams; + parse_build_param(comp_search_conf, vpq_pams); + param.cagra_params.compression.emplace(vpq_pams); + } } raft::bench::ann::AllocatorType parse_allocator(std::string mem_type) @@ -248,5 +269,7 @@ void parse_search_param(const nlohmann::json& conf, if (conf.contains("internal_dataset_memory_type")) { param.dataset_mem = parse_allocator(conf.at("internal_dataset_memory_type")); } + // Same ratio as in IVF-PQ + param.refine_ratio = conf.value("refine_ratio", 1.0f); } #endif diff --git a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h index 25f7f93777..70fd22001e 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -56,6 +57,7 @@ class RaftCagra : public ANN, public AnnGPU { struct SearchParam : public AnnSearchParam { raft::neighbors::experimental::cagra::search_params p; + float refine_ratio; AllocatorType graph_mem = AllocatorType::Device; AllocatorType dataset_mem = AllocatorType::Device; auto needs_dataset() const -> bool override { return true; } @@ -98,6 +100,8 @@ class RaftCagra : public ANN, public AnnGPU { // will be filled with (size_t)-1 void search( const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override; + void search_base( + const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const; [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { @@ -124,6 +128,7 @@ class RaftCagra : public ANN, public AnnGPU { raft::mr::cuda_huge_page_resource mr_huge_page_; AllocatorType graph_mem_; AllocatorType dataset_mem_; + float refine_ratio_; BuildParam index_params_; bool need_dataset_update_; raft::neighbors::cagra::search_params search_params_; @@ -151,6 +156,9 @@ void RaftCagra::build(const T* dataset, size_t nrow) auto& params = index_params_.cagra_params; + // Do include the compressed dataset for the CAGRA-Q + bool shall_include_dataset = params.compression.has_value(); + index_ = std::make_shared>( std::move(raft::neighbors::cagra::detail::build(handle_, params, @@ -159,7 +167,7 @@ void RaftCagra::build(const T* dataset, size_t nrow) index_params_.ivf_pq_refine_rate, index_params_.ivf_pq_build_params, index_params_.ivf_pq_search_params, - false))); + shall_include_dataset))); } inline std::string allocator_to_string(AllocatorType mem_type) @@ -179,6 +187,7 @@ void RaftCagra::set_search_param(const AnnSearchParam& param) { auto search_param = dynamic_cast(param); search_params_ = search_param.p; + refine_ratio_ = search_param.refine_ratio; if (search_param.graph_mem != graph_mem_) { // Move graph to correct memory space graph_mem_ = search_param.graph_mem; @@ -223,12 +232,16 @@ void RaftCagra::set_search_param(const AnnSearchParam& param) template void RaftCagra::set_search_dataset(const T* dataset, size_t nrow) { + using ds_idx_type = decltype(index_->data().n_rows()); + bool is_vpq = + dynamic_cast*>(&index_->data()) || + dynamic_cast*>(&index_->data()); // It can happen that we are re-using a previous algo object which already has // the dataset set. Check if we need update. if (static_cast(input_dataset_v_->extent(0)) != nrow || input_dataset_v_->data_handle() != dataset) { *input_dataset_v_ = make_device_matrix_view(dataset, nrow, this->dim_); - need_dataset_update_ = true; + need_dataset_update_ = !is_vpq; // ignore update if this is a VPQ dataset. } } @@ -258,7 +271,7 @@ std::unique_ptr> RaftCagra::copy() } template -void RaftCagra::search( +void RaftCagra::search_base( const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const { IdxT* neighbors_IdxT; @@ -286,4 +299,68 @@ void RaftCagra::search( raft::resource::get_cuda_stream(handle_)); } } + +template +void RaftCagra::search( + const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const +{ + auto k0 = static_cast(refine_ratio_ * k); + const bool disable_refinement = k0 <= static_cast(k); + const raft::resources& res = handle_; + auto stream = resource::get_cuda_stream(res); + + if (disable_refinement) { + search_base(queries, batch_size, k, neighbors, distances); + } else { + auto candidate_ixs = raft::make_device_matrix(res, batch_size, k0); + auto candidate_dists = raft::make_device_matrix(res, batch_size, k0); + search_base(queries, + batch_size, + k0, + reinterpret_cast(candidate_ixs.data_handle()), + candidate_dists.data_handle()); + + if (raft::get_device_for_address(input_dataset_v_->data_handle()) >= 0) { + auto queries_v = + raft::make_device_matrix_view(queries, batch_size, dimension_); + auto neighours_v = raft::make_device_matrix_view( + reinterpret_cast(neighbors), batch_size, k); + auto distances_v = raft::make_device_matrix_view(distances, batch_size, k); + raft::neighbors::refine( + res, + *input_dataset_v_, + queries_v, + raft::make_const_mdspan(candidate_ixs.view()), + neighours_v, + distances_v, + index_->metric()); + } else { + auto dataset_host = raft::make_host_matrix_view( + input_dataset_v_->data_handle(), input_dataset_v_->extent(0), input_dataset_v_->extent(1)); + auto queries_host = raft::make_host_matrix(batch_size, dimension_); + auto candidates_host = raft::make_host_matrix(batch_size, k0); + auto neighbors_host = raft::make_host_matrix(batch_size, k); + auto distances_host = raft::make_host_matrix(batch_size, k); + + raft::copy(queries_host.data_handle(), queries, queries_host.size(), stream); + raft::copy( + candidates_host.data_handle(), candidate_ixs.data_handle(), candidates_host.size(), stream); + + raft::resource::sync_stream(res); // wait for the queries and candidates + raft::neighbors::refine(res, + dataset_host, + queries_host.view(), + candidates_host.view(), + neighbors_host.view(), + distances_host.view(), + index_->metric()); + + raft::copy(neighbors, + reinterpret_cast(neighbors_host.data_handle()), + neighbors_host.size(), + stream); + raft::copy(distances, distances_host.data_handle(), distances_host.size(), stream); + } + } +} } // namespace raft::bench::ann From fd91efb22a64fa1f75bd83d024e4c1ffed3702f0 Mon Sep 17 00:00:00 2001 From: rhdong Date: Thu, 21 Mar 2024 05:48:55 -0700 Subject: [PATCH 14/18] [FEA] Add support for bitmap_view & the API of `bitmap_to_csr` (#2109) - This PR is one part of the feature of #1969 Authors: - James Rong (https://github.com/rhdong) Approvers: - Ben Frederickson (https://github.com/benfred) - Micka (https://github.com/lowener) - Corey J. Nolet (https://github.com/cjnolet) Authors: - rhdong (https://github.com/rhdong) Approvers: - Corey J. Nolet (https://github.com/cjnolet) - Micka (https://github.com/lowener) URL: https://github.com/rapidsai/raft/pull/2109 --- cpp/bench/prims/CMakeLists.txt | 9 +- cpp/bench/prims/sparse/bitmap_to_csr.cu | 156 +++++++++ cpp/include/raft/core/bitmap.cuh | 127 ++++++++ cpp/include/raft/sparse/convert/csr.cuh | 29 +- .../sparse/convert/detail/bitmap_to_csr.cuh | 300 ++++++++++++++++++ cpp/test/sparse/convert_csr.cu | 243 ++++++++++++++ docs/source/cpp_api/core.rst | 3 +- docs/source/cpp_api/core_bitmap.rst | 15 + 8 files changed, 879 insertions(+), 3 deletions(-) create mode 100644 cpp/bench/prims/sparse/bitmap_to_csr.cu create mode 100644 cpp/include/raft/core/bitmap.cuh create mode 100644 cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh create mode 100644 docs/source/cpp_api/core_bitmap.rst diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 95361e19ca..9f23c44a5c 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -131,7 +131,14 @@ if(BUILD_PRIMS_BENCH) bench/prims/random/rng.cu bench/prims/random/subsample.cu bench/prims/main.cpp ) - ConfigureBench(NAME SPARSE_BENCH PATH bench/prims/sparse/convert_csr.cu bench/prims/main.cpp) + ConfigureBench( + NAME + SPARSE_BENCH + PATH + bench/prims/sparse/bitmap_to_csr.cu + bench/prims/sparse/convert_csr.cu + bench/prims/main.cpp + ) ConfigureBench( NAME diff --git a/cpp/bench/prims/sparse/bitmap_to_csr.cu b/cpp/bench/prims/sparse/bitmap_to_csr.cu new file mode 100644 index 0000000000..ed53df3265 --- /dev/null +++ b/cpp/bench/prims/sparse/bitmap_to_csr.cu @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace raft::bench::sparse { + +template +struct bench_param { + index_t n_rows; + index_t n_cols; + float sparsity; +}; + +template +inline auto operator<<(std::ostream& os, const bench_param& params) -> std::ostream& +{ + os << " rows*cols=" << params.n_rows << "*" << params.n_cols << "\tsparsity=" << params.sparsity; + return os; +} + +template +struct BitmapToCsrBench : public fixture { + BitmapToCsrBench(const bench_param& p) + : fixture(true), + params(p), + handle(stream), + bitmap_d(0, stream), + nnz(0), + indptr_d(0, stream), + indices_d(0, stream), + values_d(0, stream) + { + index_t element = raft::ceildiv(params.n_rows * params.n_cols, index_t(sizeof(bitmap_t) * 8)); + std::vector bitmap_h(element); + nnz = create_sparse_matrix(params.n_rows, params.n_cols, params.sparsity, bitmap_h); + + bitmap_d.resize(bitmap_h.size(), stream); + indptr_d.resize(params.n_rows + 1, stream); + indices_d.resize(nnz, stream); + values_d.resize(nnz, stream); + + update_device(bitmap_d.data(), bitmap_h.data(), bitmap_h.size(), stream); + + resource::sync_stream(handle); + } + + index_t create_sparse_matrix(index_t m, index_t n, float sparsity, std::vector& bitmap) + { + index_t total = static_cast(m * n); + index_t num_ones = static_cast((total * 1.0f) * sparsity); + index_t res = num_ones; + + for (auto& item : bitmap) { + item = static_cast(0); + } + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dis(0, total - 1); + + while (num_ones > 0) { + index_t index = dis(gen); + + bitmap_t& element = bitmap[index / (8 * sizeof(bitmap_t))]; + index_t bit_position = index % (8 * sizeof(bitmap_t)); + + if (((element >> bit_position) & 1) == 0) { + element |= (static_cast(1) << bit_position); + num_ones--; + } + } + return res; + } + + void run_benchmark(::benchmark::State& state) override + { + std::ostringstream label_stream; + label_stream << params; + state.SetLabel(label_stream.str()); + + auto bitmap = + raft::core::bitmap_view(bitmap_d.data(), params.n_rows, params.n_cols); + + auto csr_view = raft::make_device_compressed_structure_view( + indptr_d.data(), indices_d.data(), params.n_rows, params.n_cols, nnz); + auto csr = raft::make_device_csr_matrix(handle, csr_view); + + raft::sparse::convert::bitmap_to_csr(handle, bitmap, csr); + + resource::sync_stream(handle); + loop_on_state(state, [this, &bitmap, &csr]() { + raft::sparse::convert::bitmap_to_csr(handle, bitmap, csr); + }); + } + + protected: + const raft::device_resources handle; + + bench_param params; + + rmm::device_uvector bitmap_d; + rmm::device_uvector indptr_d; + rmm::device_uvector indices_d; + rmm::device_uvector values_d; + + index_t nnz; +}; // struct BitmapToCsrBench + +template +const std::vector> getInputs() +{ + std::vector> param_vec; + struct TestParams { + index_t m; + index_t n; + float sparsity; + }; + + const std::vector params_group = raft::util::itertools::product( + {index_t(10), index_t(1024)}, {index_t(1024 * 1024)}, {0.01f, 0.1f, 0.2f, 0.5f}); + + param_vec.reserve(params_group.size()); + for (TestParams params : params_group) { + param_vec.push_back(bench_param({params.m, params.n, params.sparsity})); + } + return param_vec; +} + +RAFT_BENCH_REGISTER((BitmapToCsrBench), "", getInputs()); +RAFT_BENCH_REGISTER((BitmapToCsrBench), "", getInputs()); + +} // namespace raft::bench::sparse diff --git a/cpp/include/raft/core/bitmap.cuh b/cpp/include/raft/core/bitmap.cuh new file mode 100644 index 0000000000..829c84ed25 --- /dev/null +++ b/cpp/include/raft/core/bitmap.cuh @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace raft::core { +/** + * @defgroup bitmap Bitmap + * @{ + */ +/** + * @brief View of a RAFT Bitmap. + * + * This lightweight structure which represents and manipulates a two-dimensional bitmap matrix view + * with row major order. This class provides functionality for handling a matrix where each element + * is represented as a bit in a bitmap. + * + * @tparam bitmap_t Underlying type of the bitmap array. Default is uint32_t. + * @tparam index_t Indexing type used. Default is uint32_t. + */ +template +struct bitmap_view : public bitset_view { + static_assert((std::is_same::value || + std::is_same::value), + "The bitmap_t must be uint32_t or uint64_t."); + /** + * @brief Create a bitmap view from a device raw pointer. + * + * @param bitmap_ptr Device raw pointer + * @param rows Number of row in the matrix. + * @param cols Number of col in the matrix. + */ + _RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t rows, index_t cols) + : bitset_view(bitmap_ptr, rows * cols), rows_(rows), cols_(cols) + { + } + + /** + * @brief Create a bitmap view from a device vector view of the bitset. + * + * @param bitmap_span Device vector view of the bitmap + * @param rows Number of row in the matrix. + * @param cols Number of col in the matrix. + */ + _RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view bitmap_span, + index_t rows, + index_t cols) + : bitset_view(bitmap_span, rows * cols), rows_(rows), cols_(cols) + { + } + + private: + // Hide the constructors of bitset_view. + _RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t bitmap_len) + : bitset_view(bitmap_ptr, bitmap_len) + { + } + + _RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view bitmap_span, + index_t bitmap_len) + : bitset_view(bitmap_span, bitmap_len) + { + } + + public: + /** + * @brief Device function to test if a given row and col are set in the bitmap. + * + * @param row Row index of the bit to test + * @param col Col index of the bit to test + * @return bool True if index has not been unset in the bitset + */ + inline _RAFT_DEVICE auto test(const index_t row, const index_t col) const -> bool + { + return test(row * cols_ + col); + } + + /** + * @brief Device function to set a given row and col to set_value in the bitset. + * + * @param row Row index of the bit to set + * @param col Col index of the bit to set + * @param new_value Value to set the bit to (true or false) + */ + inline _RAFT_DEVICE void set(const index_t row, const index_t col, bool new_value) const + { + set(row * cols_ + col, &new_value); + } + + /** + * @brief Get the total number of rows + * @return index_t The total number of rows + */ + inline _RAFT_HOST_DEVICE index_t get_n_rows() const { return rows_; } + + /** + * @brief Get the total number of columns + * @return index_t The total number of columns + */ + inline _RAFT_HOST_DEVICE index_t get_n_cols() const { return cols_; } + + private: + index_t rows_; + index_t cols_; +}; + +/** @} */ +} // end namespace raft::core diff --git a/cpp/include/raft/sparse/convert/csr.cuh b/cpp/include/raft/sparse/convert/csr.cuh index 999e64cb0b..081192ed44 100644 --- a/cpp/include/raft/sparse/convert/csr.cuh +++ b/cpp/include/raft/sparse/convert/csr.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,7 +18,10 @@ #pragma once +#include +#include #include +#include #include #include @@ -102,6 +105,30 @@ void adj_to_csr(raft::resources const& handle, detail::adj_to_csr(handle, adj, row_ind, num_rows, num_cols, tmp, out_col_ind); } +/** + * @brief Converts a bitmap matrix to a Compressed Sparse Row (CSR) format matrix. + * + * @tparam bitmap_t The data type of the elements in the bitmap matrix. + * @tparam index_t The data type used for indexing the elements in the matrices. + * @tparam csr_matrix_t Specifies the CSR matrix type, constrained to + * raft::device_csr_matrix. + * + * @param[in] handle The RAFT handle containing the CUDA stream for operations. + * @param[in] bitmap The bitmap matrix view, to be converted to CSR format. + * @param[out] csr Output parameter where the resulting CSR matrix is stored. In the + * bitmap, each '1' bit corresponds to a non-zero element in the CSR matrix. + */ +template >> +void bitmap_to_csr(raft::resources const& handle, + raft::core::bitmap_view bitmap, + csr_matrix_t& csr) +{ + detail::bitmap_to_csr(handle, bitmap, csr); +} + }; // end NAMESPACE convert }; // end NAMESPACE sparse }; // end NAMESPACE raft diff --git a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh new file mode 100644 index 0000000000..b0315486ff --- /dev/null +++ b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh @@ -0,0 +1,300 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // detail::popc +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace cg = cooperative_groups; + +namespace raft { +namespace sparse { +namespace convert { +namespace detail { + +// Threads per block in calc_nnz_by_rows_kernel. +static const constexpr int calc_nnz_by_rows_tpb = 32; + +template +RAFT_KERNEL __launch_bounds__(calc_nnz_by_rows_tpb) calc_nnz_by_rows_kernel(const bitmap_t* bitmap, + index_t num_rows, + index_t num_cols, + index_t bitmap_num, + nnz_t* nnz_per_row) +{ + constexpr bitmap_t FULL_MASK = ~bitmap_t(0u); + constexpr bitmap_t ONE = bitmap_t(1u); + constexpr index_t BITS_PER_BITMAP = sizeof(bitmap_t) * 8; + + auto block = cg::this_thread_block(); + auto tile = cg::tiled_partition<32>(block); + + int lane_id = threadIdx.x & 0x1f; + + for (index_t row = blockIdx.x; row < num_rows; row += gridDim.x) { + index_t offset = 0; + index_t s_bit = row * num_cols; + index_t e_bit = s_bit + num_cols; + index_t l_sum = 0; + + while (offset < num_cols) { + index_t bitmap_idx = lane_id + (s_bit + offset) / BITS_PER_BITMAP; + bitmap_t l_bitmap = bitmap_t(0); + + if (bitmap_idx * BITS_PER_BITMAP < e_bit) { l_bitmap = bitmap[bitmap_idx]; } + + if (s_bit > bitmap_idx * BITS_PER_BITMAP) { + l_bitmap >>= (s_bit - bitmap_idx * BITS_PER_BITMAP); + l_bitmap <<= (s_bit - bitmap_idx * BITS_PER_BITMAP); + } + + if ((bitmap_idx + 1) * BITS_PER_BITMAP > e_bit) { + l_bitmap <<= ((bitmap_idx + 1) * BITS_PER_BITMAP - e_bit); + l_bitmap >>= ((bitmap_idx + 1) * BITS_PER_BITMAP - e_bit); + } + + l_sum += static_cast(raft::detail::popc(l_bitmap)); + offset += BITS_PER_BITMAP * warpSize; + } + + l_sum = cg::reduce(tile, l_sum, cg::plus()); + + if (lane_id == 0) { *(nnz_per_row + row) += static_cast(l_sum); } + } +} + +template +void calc_nnz_by_rows(raft::resources const& handle, + const bitmap_t* bitmap, + index_t num_rows, + index_t num_cols, + nnz_t* nnz_per_row) +{ + auto stream = resource::get_cuda_stream(handle); + const index_t total = num_rows * num_cols; + const index_t bitmap_num = raft::ceildiv(total, index_t(sizeof(bitmap_t) * 8)); + + int dev_id, sm_count, blocks_per_sm; + + cudaGetDevice(&dev_id); + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &blocks_per_sm, calc_nnz_by_rows_kernel, calc_nnz_by_rows_tpb, 0); + + index_t max_active_blocks = sm_count * blocks_per_sm; + auto grid = std::min(max_active_blocks, raft::ceildiv(bitmap_num, index_t(calc_nnz_by_rows_tpb))); + auto block = calc_nnz_by_rows_tpb; + + calc_nnz_by_rows_kernel + <<>>(bitmap, num_rows, num_cols, bitmap_num, nnz_per_row); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +/* + Execute the exclusive_scan within one warp with no inter-warp communication. + This function calculates the exclusive prefix sum of `value` across threads within the same warp. + Each thread in the warp will end up with the sum of all the values of the threads with lower IDs + in the same warp, with the first thread always getting a sum of 0. +*/ +template +RAFT_DEVICE_INLINE_FUNCTION value_t warp_exclusive_scan(value_t value) +{ + int lane_id = threadIdx.x & 0x1f; + value_t shifted_value = __shfl_up_sync(0xffffffff, value, 1, warpSize); + if (lane_id == 0) shifted_value = 0; + + value_t sum = shifted_value; + + for (int i = 1; i < warpSize; i *= 2) { + value_t n = __shfl_up_sync(0xffffffff, sum, i, warpSize); + if (lane_id >= i) { sum += n; } + } + return sum; +} + +// Threads per block in fill_indices_by_rows_kernel. +static const constexpr int fill_indices_by_rows_tpb = 32; + +template +RAFT_KERNEL __launch_bounds__(fill_indices_by_rows_tpb) + fill_indices_by_rows_kernel(const bitmap_t* bitmap, + const index_t* indptr, + index_t num_rows, + index_t num_cols, + nnz_t nnz, + index_t bitmap_num, + index_t* indices) +{ + constexpr bitmap_t FULL_MASK = ~bitmap_t(0u); + constexpr bitmap_t ONE = bitmap_t(1u); + constexpr index_t BITS_PER_BITMAP = sizeof(bitmap_t) * 8; + + int lane_id = threadIdx.x & 0x1f; + + // Ensure the HBM allocated for CSR values is sufficient to handle all non-zero bitmap bits. + // An assert will trigger if the allocated HBM is insufficient when `NDEBUG` isn't defined. + // Note: Assertion is active only if `NDEBUG` is undefined. + if constexpr (check_nnz) { + if (lane_id == 0) { assert(nnz < indptr[num_rows]); } + } + +#pragma unroll + for (index_t row = blockIdx.x; row < num_rows; row += gridDim.x) { + index_t g_sum = 0; + index_t s_bit = row * num_cols; + index_t e_bit = s_bit + num_cols; + index_t indptr_row = indptr[row]; + +#pragma unroll + for (index_t offset = 0; offset < num_cols; offset += BITS_PER_BITMAP * warpSize) { + index_t bitmap_idx = lane_id + (s_bit + offset) / BITS_PER_BITMAP; + bitmap_t l_bitmap = bitmap_t(0); + index_t l_offset = offset + lane_id * BITS_PER_BITMAP - (s_bit % BITS_PER_BITMAP); + + if (bitmap_idx * BITS_PER_BITMAP < e_bit) { l_bitmap = bitmap[bitmap_idx]; } + + if (s_bit > bitmap_idx * BITS_PER_BITMAP) { + l_bitmap >>= (s_bit - bitmap_idx * BITS_PER_BITMAP); + l_bitmap <<= (s_bit - bitmap_idx * BITS_PER_BITMAP); + } + + if ((bitmap_idx + 1) * BITS_PER_BITMAP > e_bit) { + l_bitmap <<= ((bitmap_idx + 1) * BITS_PER_BITMAP - e_bit); + l_bitmap >>= ((bitmap_idx + 1) * BITS_PER_BITMAP - e_bit); + } + + index_t l_sum = + g_sum + warp_exclusive_scan(static_cast(raft::detail::popc(l_bitmap))); + + for (int i = 0; i < BITS_PER_BITMAP; i++) { + if (l_bitmap & (ONE << i)) { + indices[indptr_row + l_sum] = l_offset + i; + l_sum++; + } + } + g_sum = __shfl_sync(0xffffffff, l_sum, warpSize - 1); + } + } +} + +template +void fill_indices_by_rows(raft::resources const& handle, + const bitmap_t* bitmap, + const index_t* indptr, + index_t num_rows, + index_t num_cols, + nnz_t nnz, + index_t* indices) +{ + auto stream = resource::get_cuda_stream(handle); + const index_t total = num_rows * num_cols; + const index_t bitmap_num = raft::ceildiv(total, index_t(sizeof(bitmap_t) * 8)); + + int dev_id, sm_count, blocks_per_sm; + + cudaGetDevice(&dev_id); + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &blocks_per_sm, + fill_indices_by_rows_kernel, + fill_indices_by_rows_tpb, + 0); + + index_t max_active_blocks = sm_count * blocks_per_sm; + auto grid = std::min(max_active_blocks, num_rows); + auto block = fill_indices_by_rows_tpb; + + fill_indices_by_rows_kernel + <<>>(bitmap, indptr, num_rows, num_cols, nnz, bitmap_num, indices); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +template >> +void bitmap_to_csr(raft::resources const& handle, + raft::core::bitmap_view bitmap, + csr_matrix_t& csr) +{ + auto csr_view = csr.structure_view(); + + if (csr_view.get_n_rows() == 0 || csr_view.get_n_cols() == 0 || csr_view.get_nnz() == 0) { + return; + } + + RAFT_EXPECTS(bitmap.get_n_rows() == csr_view.get_n_rows(), + "Number of rows in bitmap must be equal to " + "number of rows in csr"); + + RAFT_EXPECTS(bitmap.get_n_cols() == csr_view.get_n_cols(), + "Number of columns in bitmap must be equal to " + "number of columns in csr"); + + auto thrust_policy = resource::get_thrust_policy(handle); + auto stream = resource::get_cuda_stream(handle); + + index_t* indptr = csr_view.get_indptr().data(); + index_t* indices = csr_view.get_indices().data(); + + RAFT_CUDA_TRY(cudaMemsetAsync(indptr, 0, (csr_view.get_n_rows() + 1) * sizeof(index_t), stream)); + + calc_nnz_by_rows(handle, bitmap.data(), csr_view.get_n_rows(), csr_view.get_n_cols(), indptr); + thrust::exclusive_scan(thrust_policy, indptr, indptr + csr_view.get_n_rows() + 1, indptr); + + if constexpr (is_device_csr_sparsity_owning_v) { + index_t nnz = 0; + RAFT_CUDA_TRY(cudaMemcpyAsync( + &nnz, indptr + csr_view.get_n_rows(), sizeof(index_t), cudaMemcpyDeviceToHost, stream)); + resource::sync_stream(handle); + csr.initialize_sparsity(nnz); + } + constexpr bool check_nnz = is_device_csr_sparsity_preserving_v; + fill_indices_by_rows( + handle, + bitmap.data(), + indptr, + csr_view.get_n_rows(), + csr_view.get_n_cols(), + csr_view.get_nnz(), + indices); + + thrust::fill_n(thrust_policy, + csr.get_elements().data(), + csr_view.get_nnz(), + typename csr_matrix_t::element_type(1)); +} + +}; // end NAMESPACE detail +}; // end NAMESPACE convert +}; // end NAMESPACE sparse +}; // end NAMESPACE raft diff --git a/cpp/test/sparse/convert_csr.cu b/cpp/test/sparse/convert_csr.cu index 4af792a9ea..1cd49b0bbd 100644 --- a/cpp/test/sparse/convert_csr.cu +++ b/cpp/test/sparse/convert_csr.cu @@ -16,6 +16,7 @@ #include "../test_utils.cuh" +#include #include #include #include @@ -218,5 +219,247 @@ INSTANTIATE_TEST_CASE_P(SparseConvertCSRTest, CSRAdjGraphTestL, ::testing::ValuesIn(csradjgraph_inputs_l)); +/******************************** bitmap to csr ********************************/ + +template +struct BitmapToCSRInputs { + index_t n_rows; + index_t n_cols; + float sparsity; + bool owning; +}; + +template +class BitmapToCSRTest : public ::testing::TestWithParam> { + public: + BitmapToCSRTest() + : stream(resource::get_cuda_stream(handle)), + params(::testing::TestWithParam>::GetParam()), + bitmap_d(0, stream), + indices_d(0, stream), + indptr_d(0, stream), + values_d(0, stream), + indptr_expected_d(0, stream), + indices_expected_d(0, stream), + values_expected_d(0, stream) + { + } + + protected: + index_t create_sparse_matrix(index_t m, index_t n, float sparsity, std::vector& bitmap) + { + index_t total = static_cast(m * n); + index_t num_ones = static_cast((total * 1.0f) * sparsity); + index_t res = num_ones; + + for (auto& item : bitmap) { + item = static_cast(0); + } + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dis(0, total - 1); + + while (num_ones > 0) { + index_t index = dis(gen); + + bitmap_t& element = bitmap[index / (8 * sizeof(bitmap_t))]; + index_t bit_position = index % (8 * sizeof(bitmap_t)); + + if (((element >> bit_position) & 1) == 0) { + element |= (static_cast(1) << bit_position); + num_ones--; + } + } + return res; + } + + void cpu_convert_to_csr(std::vector& bitmap, + index_t rows, + index_t cols, + std::vector& indices, + std::vector& indptr) + { + index_t offset_indptr = 0; + index_t offset_values = 0; + indptr[offset_indptr++] = 0; + + index_t index = 0; + bitmap_t element = 0; + index_t bit_position = 0; + + for (index_t i = 0; i < rows; ++i) { + for (index_t j = 0; j < cols; ++j) { + index = i * cols + j; + element = bitmap[index / (8 * sizeof(bitmap_t))]; + bit_position = index % (8 * sizeof(bitmap_t)); + + if (((element >> bit_position) & 1)) { + indices[offset_values] = static_cast(j); + offset_values++; + } + } + indptr[offset_indptr++] = static_cast(offset_values); + } + } + + bool csr_compare(const std::vector& row_ptrs1, + const std::vector& col_indices1, + const std::vector& row_ptrs2, + const std::vector& col_indices2) + { + if (row_ptrs1.size() != row_ptrs2.size()) { return false; } + + if (col_indices1.size() != col_indices2.size()) { return false; } + + if (!std::equal(row_ptrs1.begin(), row_ptrs1.end(), row_ptrs2.begin())) { return false; } + + for (size_t i = 0; i < row_ptrs1.size() - 1; ++i) { + size_t start_idx = row_ptrs1[i]; + size_t end_idx = row_ptrs1[i + 1]; + + std::vector cols1(col_indices1.begin() + start_idx, col_indices1.begin() + end_idx); + std::vector cols2(col_indices2.begin() + start_idx, col_indices2.begin() + end_idx); + + std::sort(cols1.begin(), cols1.end()); + std::sort(cols2.begin(), cols2.end()); + + if (cols1 != cols2) { return false; } + } + + return true; + } + + void SetUp() override + { + index_t element = raft::ceildiv(params.n_rows * params.n_cols, index_t(sizeof(bitmap_t) * 8)); + std::vector bitmap_h(element); + nnz = create_sparse_matrix(params.n_rows, params.n_cols, params.sparsity, bitmap_h); + + std::vector indices_h(nnz); + std::vector indptr_h(params.n_rows + 1); + + cpu_convert_to_csr(bitmap_h, params.n_rows, params.n_cols, indices_h, indptr_h); + + bitmap_d.resize(bitmap_h.size(), stream); + indptr_d.resize(params.n_rows + 1, stream); + indices_d.resize(nnz, stream); + + indptr_expected_d.resize(params.n_rows + 1, stream); + indices_expected_d.resize(nnz, stream); + values_expected_d.resize(nnz, stream); + + thrust::fill_n(resource::get_thrust_policy(handle), values_expected_d.data(), nnz, value_t{1}); + + values_d.resize(nnz, stream); + + update_device(indices_expected_d.data(), indices_h.data(), indices_h.size(), stream); + update_device(indptr_expected_d.data(), indptr_h.data(), indptr_h.size(), stream); + update_device(bitmap_d.data(), bitmap_h.data(), bitmap_h.size(), stream); + + resource::sync_stream(handle); + } + + void Run() + { + auto bitmap = + raft::core::bitmap_view(bitmap_d.data(), params.n_rows, params.n_cols); + + if (params.owning) { + auto csr = + raft::make_device_csr_matrix(handle, params.n_rows, params.n_cols, nnz); + auto csr_view = csr.structure_view(); + + convert::bitmap_to_csr(handle, bitmap, csr); + raft::copy(indptr_d.data(), csr_view.get_indptr().data(), indptr_d.size(), stream); + raft::copy(indices_d.data(), csr_view.get_indices().data(), indices_d.size(), stream); + raft::copy(values_d.data(), csr.get_elements().data(), nnz, stream); + } else { + auto csr_view = raft::make_device_compressed_structure_view( + indptr_d.data(), indices_d.data(), params.n_rows, params.n_cols, nnz); + auto csr = raft::make_device_csr_matrix(handle, csr_view); + + convert::bitmap_to_csr(handle, bitmap, csr); + raft::copy(values_d.data(), csr.get_elements().data(), nnz, stream); + } + resource::sync_stream(handle); + + std::vector indices_h(indices_expected_d.size(), 0); + std::vector indices_expected_h(indices_expected_d.size(), 0); + update_host(indices_h.data(), indices_d.data(), indices_h.size(), stream); + update_host(indices_expected_h.data(), indices_expected_d.data(), indices_h.size(), stream); + + std::vector indptr_h(indptr_expected_d.size(), 0); + std::vector indptr_expected_h(indptr_expected_d.size(), 0); + update_host(indptr_h.data(), indptr_d.data(), indptr_h.size(), stream); + update_host(indptr_expected_h.data(), indptr_expected_d.data(), indptr_h.size(), stream); + + resource::sync_stream(handle); + + ASSERT_TRUE(csr_compare(indptr_h, indices_h, indptr_expected_h, indices_expected_h)); + ASSERT_TRUE(raft::devArrMatch( + values_expected_d.data(), values_d.data(), nnz, raft::Compare(), stream)); + } + + protected: + raft::resources handle; + cudaStream_t stream; + + BitmapToCSRInputs params; + + rmm::device_uvector bitmap_d; + + index_t nnz; + + rmm::device_uvector indptr_d; + rmm::device_uvector indices_d; + rmm::device_uvector values_d; + + rmm::device_uvector indptr_expected_d; + rmm::device_uvector indices_expected_d; + rmm::device_uvector values_expected_d; +}; + +using BitmapToCSRTestI = BitmapToCSRTest; +TEST_P(BitmapToCSRTestI, Result) { Run(); } + +using BitmapToCSRTestL = BitmapToCSRTest; +TEST_P(BitmapToCSRTestL, Result) { Run(); } + +template +const std::vector> bitmaptocsr_inputs = { + {0, 0, 0.2, false}, + {10, 32, 0.4, false}, + {10, 3, 0.2, false}, + {32, 1024, 0.4, false}, + {1024, 1048576, 0.01, false}, + {1024, 1024, 0.4, false}, + {64 * 1024 + 10, 2, 0.3, false}, // 64K + 10 is slightly over maximum of blockDim.y + {16, 16, 0.3, false}, // No peeling-remainder + {17, 16, 0.3, false}, // Check peeling-remainder + {18, 16, 0.3, false}, // Check peeling-remainder + {32 + 9, 33, 0.2, false}, // Check peeling-remainder + {2, 33, 0.2, false}, // Check peeling-remainder + {0, 0, 0.2, true}, + {10, 32, 0.4, true}, + {10, 3, 0.2, true}, + {32, 1024, 0.4, true}, + {1024, 1048576, 0.01, true}, + {1024, 1024, 0.4, true}, + {64 * 1024 + 10, 2, 0.3, true}, // 64K + 10 is slightly over maximum of blockDim.y + {16, 16, 0.3, true}, // No peeling-remainder + {17, 16, 0.3, true}, // Check peeling-remainder + {18, 16, 0.3, true}, // Check peeling-remainder + {32 + 9, 33, 0.2, true}, // Check peeling-remainder + {2, 33, 0.2, true}, // Check peeling-remainder +}; + +INSTANTIATE_TEST_CASE_P(SparseConvertCSRTest, + BitmapToCSRTestI, + ::testing::ValuesIn(bitmaptocsr_inputs)); +INSTANTIATE_TEST_CASE_P(SparseConvertCSRTest, + BitmapToCSRTestL, + ::testing::ValuesIn(bitmaptocsr_inputs)); + } // namespace sparse } // namespace raft diff --git a/docs/source/cpp_api/core.rst b/docs/source/cpp_api/core.rst index 39e57fd69a..4122a18506 100644 --- a/docs/source/cpp_api/core.rst +++ b/docs/source/cpp_api/core.rst @@ -21,4 +21,5 @@ expose in public APIs. core_interruptible.rst core_operators.rst core_math.rst - core_bitset.rst \ No newline at end of file + core_bitset.rst + core_bitmap.rst \ No newline at end of file diff --git a/docs/source/cpp_api/core_bitmap.rst b/docs/source/cpp_api/core_bitmap.rst new file mode 100644 index 0000000000..6c1dc607bf --- /dev/null +++ b/docs/source/cpp_api/core_bitmap.rst @@ -0,0 +1,15 @@ +Bitmap +====== + +.. role:: py(code) + :language: c++ + :class: highlight + +``#include `` + +namespace *raft::core* + +.. doxygengroup:: bitmap + :project: RAFT + :members: + :content-only: \ No newline at end of file From 9637b3c22a3e67d20200886cffb5e804e33473dc Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Thu, 21 Mar 2024 13:32:41 -0400 Subject: [PATCH 15/18] Update pre-commit-hooks to v0.0.3 (#2239) --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6bca1d228e..2b89948ec1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -104,7 +104,7 @@ repos: hooks: - id: check-json - repo: https://github.com/rapidsai/pre-commit-hooks - rev: v0.0.1 + rev: v0.0.3 hooks: - id: verify-copyright files: | From 52e0d7331cb533955f479d82e4656253eaa9ef6f Mon Sep 17 00:00:00 2001 From: Michael Schellenberger Costa Date: Thu, 21 Mar 2024 16:15:41 -0700 Subject: [PATCH 16/18] Replace usages of raw `get_upstream` with `get_upstream_resource()` (#2207) We want to get rid of raw memory resources so move to the new interface instead Authors: - Michael Schellenberger Costa (https://github.com/miscco) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2207 --- cpp/test/core/device_resources_manager.cpp | 16 ++++++++-------- cpp/test/core/handle.cpp | 8 +++++--- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/cpp/test/core/device_resources_manager.cpp b/cpp/test/core/device_resources_manager.cpp index c7c9e175ea..b9b8996a09 100644 --- a/cpp/test/core/device_resources_manager.cpp +++ b/cpp/test/core/device_resources_manager.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include @@ -114,17 +115,16 @@ TEST(DeviceResourcesManager, ObeysSetters) auto* mr = dynamic_cast*>( rmm::mr::get_current_device_resource()); - auto* workspace_mr = - dynamic_cast*>( - dynamic_cast*>( - res.get_workspace_resource()) - ->get_upstream()); + rmm::device_async_resource_ref workspace_mr = + dynamic_cast*>( + res.get_workspace_resource()) + ->get_upstream_resource(); if (upstream_mrs[i % devices.size()] != nullptr) { // Expect that the current memory resource is a pool memory resource as requested EXPECT_NE(mr, nullptr); - // Expect that the upstream workspace memory resource is a pool memory - // resource as requested - EXPECT_NE(workspace_mr, nullptr); + + // We cannot easily check the type of a resource_ref + (void)workspace_mr; } { diff --git a/cpp/test/core/handle.cpp b/cpp/test/core/handle.cpp index 0b0b4b54ab..be18b0d5b4 100644 --- a/cpp/test/core/handle.cpp +++ b/cpp/test/core/handle.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include @@ -281,7 +282,8 @@ TEST(Raft, WorkspaceResource) raft::handle_t handle; // The returned resource is always a limiting adaptor - auto* orig_mr = resource::get_workspace_resource(handle)->get_upstream(); + rmm::device_async_resource_ref orig_mr{ + resource::get_workspace_resource(handle)->get_upstream_resource()}; // Let's create a pooled resource auto pool_mr = std::shared_ptr{new rmm::mr::pool_memory_resource( @@ -295,8 +297,8 @@ TEST(Raft, WorkspaceResource) auto new_mr = resource::get_workspace_resource(handle); // By this point, the orig_mr likely points to a non-existent resource; don't dereference! - ASSERT_NE(orig_mr, new_mr); - ASSERT_EQ(pool_mr.get(), new_mr->get_upstream()); + ASSERT_NE(orig_mr, rmm::device_async_resource_ref{new_mr}); + ASSERT_EQ(rmm::device_async_resource_ref{pool_mr.get()}, new_mr->get_upstream_resource()); // We can safely reset pool_mr, because the shared_ptr to the pool memory stays in the resource pool_mr.reset(); From 03b24cf01c1d583d1352e3607e664dae7c89cb1a Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Tue, 26 Mar 2024 10:13:08 -0700 Subject: [PATCH 17/18] Get rid of `cuco::sentinel` namespace (#2243) This PR removes the use of the deprecated `cuco::sentinel` namespace. Needed by https://github.com/rapidsai/rapids-cmake/pull/569 Authors: - Yunsong Wang (https://github.com/PointKernel) Approvers: - Divye Gala (https://github.com/divyegala) URL: https://github.com/rapidsai/raft/pull/2243 --- .../detail/coo_spmv_strategies/hash_strategy.cuh | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/hash_strategy.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/hash_strategy.cuh index e271f2cdbe..8c267c5e63 100644 --- a/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/hash_strategy.cuh +++ b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/hash_strategy.cuh @@ -236,8 +236,8 @@ class hash_strategy : public coo_spmv_strategy { return insert_type::make_from_uninitialized_slots(cooperative_groups::this_thread_block(), cache, cache_size, - cuco::sentinel::empty_key{value_idx{-1}}, - cuco::sentinel::empty_value{value_t{0}}); + cuco::empty_key{value_idx{-1}}, + cuco::empty_value{value_t{0}}); } __device__ inline void insert(insert_type cache, const value_idx& key, const value_t& value) @@ -247,10 +247,8 @@ class hash_strategy : public coo_spmv_strategy { __device__ inline find_type init_find(smem_type cache, const value_idx& cache_size) { - return find_type(cache, - cache_size, - cuco::sentinel::empty_key{value_idx{-1}}, - cuco::sentinel::empty_value{value_t{0}}); + return find_type( + cache, cache_size, cuco::empty_key{value_idx{-1}}, cuco::empty_value{value_t{0}}); } __device__ inline value_t find(find_type cache, const value_idx& key) From 5f2bd19276d4314556795c86fa2230ccac988583 Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Thu, 28 Mar 2024 12:33:39 -0500 Subject: [PATCH 18/18] Use `conda env create --yes` instead of `--force`. (#2247) --- ci/build_docs.sh | 2 +- ci/check_style.sh | 4 ++-- ci/test_cpp.sh | 2 +- ci/test_python.sh | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ci/build_docs.sh b/ci/build_docs.sh index 3d72c815db..9605b52f8b 100755 --- a/ci/build_docs.sh +++ b/ci/build_docs.sh @@ -11,7 +11,7 @@ rapids-dependency-file-generator \ --file_key docs \ --matrix "cuda=${RAPIDS_CUDA_VERSION%.*};arch=$(arch);py=${RAPIDS_PY_VERSION}" | tee env.yaml -rapids-mamba-retry env create --force -f env.yaml -n docs +rapids-mamba-retry env create --yes -f env.yaml -n docs conda activate docs rapids-print-env diff --git a/ci/check_style.sh b/ci/check_style.sh index 0ee6e88e58..d7baa88e8f 100755 --- a/ci/check_style.sh +++ b/ci/check_style.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. set -euo pipefail @@ -11,7 +11,7 @@ rapids-dependency-file-generator \ --file_key checks \ --matrix "cuda=${RAPIDS_CUDA_VERSION%.*};arch=$(arch);py=${RAPIDS_PY_VERSION}" | tee env.yaml -rapids-mamba-retry env create --force -f env.yaml -n checks +rapids-mamba-retry env create --yes -f env.yaml -n checks conda activate checks # Run pre-commit checks diff --git a/ci/test_cpp.sh b/ci/test_cpp.sh index fb2d025f9c..f83ddf616d 100755 --- a/ci/test_cpp.sh +++ b/ci/test_cpp.sh @@ -14,7 +14,7 @@ rapids-dependency-file-generator \ --file_key test_cpp \ --matrix "cuda=${RAPIDS_CUDA_VERSION%.*};arch=$(arch)" | tee env.yaml -rapids-mamba-retry env create --force -f env.yaml -n test +rapids-mamba-retry env create --yes -f env.yaml -n test # Temporarily allow unbound variables for conda activation. set +u diff --git a/ci/test_python.sh b/ci/test_python.sh index aae8ae03ea..f5b188ca0b 100755 --- a/ci/test_python.sh +++ b/ci/test_python.sh @@ -14,7 +14,7 @@ rapids-dependency-file-generator \ --file_key test_python \ --matrix "cuda=${RAPIDS_CUDA_VERSION%.*};arch=$(arch);py=${RAPIDS_PY_VERSION}" | tee env.yaml -rapids-mamba-retry env create --force -f env.yaml -n test +rapids-mamba-retry env create --yes -f env.yaml -n test # Temporarily allow unbound variables for conda activation. set +u