From bd50c37f52eeff459ca032bd0ef5f7920c2bcc8d Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Tue, 19 Mar 2024 10:51:23 +0100 Subject: [PATCH 1/4] Fix ANN bench ground truth generation for k>1024 (#2180) Generating ANN bench ground truth is affected by bug #2171, when k>1024. This PR fixes the issue for the ground truth generation. Authors: - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2180 --- .../raft/neighbors/detail/knn_merge_parts.cuh | 3 +++ .../raft-ann-bench/generate_groundtruth/__main__.py | 13 ++++--------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh b/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh index 0395e86e43..33324714fd 100644 --- a/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh +++ b/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include #include @@ -168,5 +169,7 @@ inline void knn_merge_parts(const value_t* inK, else if (k <= 1024) knn_merge_parts_impl( inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else + THROW("Unimplemented for k=%d, knn_merge_parts works for k<=1024", k); } } // namespace raft::neighbors::detail diff --git a/python/raft-ann-bench/src/raft-ann-bench/generate_groundtruth/__main__.py b/python/raft-ann-bench/src/raft-ann-bench/generate_groundtruth/__main__.py index f4d97edea5..a5ebb76635 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/generate_groundtruth/__main__.py +++ b/python/raft-ann-bench/src/raft-ann-bench/generate_groundtruth/__main__.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -62,17 +62,12 @@ def calc_truth(dataset, queries, k, metric="sqeuclidean"): X = cp.asarray(dataset[i : i + n_batch, :], cp.float32) - D, Ind = knn( - X, - queries, - k, - metric=metric, - handle=handle, - global_id_offset=i, # shift neighbor index by offset i - ) + D, Ind = knn(X, queries, k, metric=metric, handle=handle) handle.sync() D, Ind = cp.asarray(D), cp.asarray(Ind) + Ind += i # shift neighbor index by offset i + if distances is None: distances = D indices = Ind From e53aa0c18d77f263ed6aa058dff87d3a6dd43590 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Akif=20=C3=87=C3=96RD=C3=9CK?= Date: Tue, 19 Mar 2024 15:26:15 +0100 Subject: [PATCH 2/4] Fix bug in blockRankedReduce (#2226) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit There was a bug appearing for negative floating point numbers with a max reduce operation. The `std::numeric_limits::min()` is greater than the negative floating point values whereas we want it to be smaller than all representable values. This PR replaces the `min` with the `lowest`. Authors: - Akif ÇÖRDÜK (https://github.com/akifcorduk) - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Bradley Dice (https://github.com/bdice) - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/raft/pull/2226 --- cpp/include/raft/util/reduction.cuh | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/cpp/include/raft/util/reduction.cuh b/cpp/include/raft/util/reduction.cuh index 7e897e67f5..2c2b1aa228 100644 --- a/cpp/include/raft/util/reduction.cuh +++ b/cpp/include/raft/util/reduction.cuh @@ -157,11 +157,10 @@ DI std::pair blockRankedReduce(T val, val = values[lane]; idx = indices[lane]; } else { - // get the min if it is a max op, get the max if it is a min op - val = reduce_op(std::numeric_limits::min(), std::numeric_limits::max()) == - std::numeric_limits::min() - ? std::numeric_limits::max() - : std::numeric_limits::min(); + // get the lower_bound of the type if it is a max op, + // get the upper bound of the type if it is a min op + val = reduce_op(lower_bound(), upper_bound()) == lower_bound() ? upper_bound() + : lower_bound(); idx = -1; } __syncthreads(); From 413e34e786b1c6b50a9dc6d76ff57d1f8a31555d Mon Sep 17 00:00:00 2001 From: Mahesh Doijade <36705640+mdoijade@users.noreply.github.com> Date: Tue, 19 Mar 2024 22:33:54 +0530 Subject: [PATCH 3/4] Add fused cosine 1-NN cutlass based kernel (#2125) - Adds cosine 1-NN cutlass based kernel for SM 8.0 or higher using tensor cores. - based on 3x TF32 - unifies the fusedDistanceNN kernels for L2/cosine. - expose this API in pylibraft as `fused_distance_nn_arg_min` supporting cosine & L2 distance metrics. Authors: - Mahesh Doijade (https://github.com/mdoijade) - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Tamas Bela Feher (https://github.com/tfeher) - Ben Frederickson (https://github.com/benfred) URL: https://github.com/rapidsai/raft/pull/2125 --- cpp/CMakeLists.txt | 4 +- cpp/bench/prims/CMakeLists.txt | 12 +- .../distance/detail/fused_distance_nn.cuh | 97 ++++ .../custom_epilogue_with_broadcast.h | 1 + .../detail/fused_distance_nn/cutlass_base.cuh | 18 +- .../epilogue_elementwise.cuh | 9 +- .../fused_distance_nn/fused_cosine_nn.cuh | 136 ++++++ .../detail/fused_distance_nn/fused_l2_nn.cuh | 146 ++++++ .../fused_distance_nn/helper_structs.cuh | 146 ++++++ .../fused_distance_nn/persistent_gemm.h | 4 +- .../predicated_tile_iterator_reduced_vec.h | 99 ++--- .../detail/fused_distance_nn/simt_kernel.cuh | 187 ++++++++ .../raft/distance/detail/masked_nn.cuh | 2 +- .../raft/distance/fused_distance_nn-ext.cuh | 84 ++++ .../raft/distance/fused_distance_nn-inl.cuh | 328 ++++++++++++++ .../raft/distance/fused_distance_nn.cuh | 24 + ...pers.cuh => fused_distance_nn_helpers.cuh} | 4 +- cpp/include/raft/distance/fused_l2_nn-ext.cuh | 8 +- cpp/include/raft/distance/fused_l2_nn-inl.cuh | 4 +- .../distance/fused_distance_nn.hpp | 62 +++ .../raft_runtime/distance/fused_l2_nn.hpp | 36 +- cpp/src/distance/fused_distance_nn.cu | 53 +++ .../distance/fused_distance_min_arg.cu | 56 +++ .../distance/fused_distance_min_arg.hpp | 144 ++++++ .../raft_runtime/distance/fused_l2_min_arg.cu | 87 +--- cpp/test/CMakeLists.txt | 1 + cpp/test/distance/fused_cosine_nn.cu | 420 ++++++++++++++++++ cpp/test/distance/fused_l2_nn.cu | 1 - .../pylibraft/distance/CMakeLists.txt | 4 +- .../pylibraft/pylibraft/distance/__init__.py | 9 +- .../pylibraft/distance/fused_distance_nn.pyx | 200 +++++++++ .../test/test_fused_distance_argmin.py | 69 +++ 32 files changed, 2281 insertions(+), 174 deletions(-) create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn.cuh create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn/fused_l2_nn.cuh create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn/helper_structs.cuh create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn/simt_kernel.cuh create mode 100644 cpp/include/raft/distance/fused_distance_nn-ext.cuh create mode 100644 cpp/include/raft/distance/fused_distance_nn-inl.cuh create mode 100755 cpp/include/raft/distance/fused_distance_nn.cuh rename cpp/include/raft/distance/{fused_l2_nn_helpers.cuh => fused_distance_nn_helpers.cuh} (92%) create mode 100644 cpp/include/raft_runtime/distance/fused_distance_nn.hpp create mode 100644 cpp/src/distance/fused_distance_nn.cu create mode 100644 cpp/src/raft_runtime/distance/fused_distance_min_arg.cu create mode 100644 cpp/src/raft_runtime/distance/fused_distance_min_arg.hpp create mode 100644 cpp/test/distance/fused_cosine_nn.cu create mode 100644 python/pylibraft/pylibraft/distance/fused_distance_nn.pyx create mode 100755 python/pylibraft/pylibraft/test/test_fused_distance_argmin.py diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 638ceb3b45..6107b9325a 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -256,7 +256,7 @@ endif() if(RAFT_NVTX) # This enables NVTX within the project with no option to disable it downstream. - target_link_libraries(raft INTERFACE CUDA::nvToolsExt) + target_link_libraries(raft INTERFACE CUDA::nvtx3) target_compile_definitions(raft INTERFACE NVTX_ENABLED) else() # Allow enable NVTX downstream if not set here. This creates a new option at build/install time, @@ -324,6 +324,7 @@ if(RAFT_COMPILE_LIBRARY) src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu src/distance/distance.cu src/distance/fused_l2_nn.cu + src/distance/fused_distance_nn.cu src/linalg/detail/coalesced_reduction.cu src/matrix/detail/select_k_double_int64_t.cu src/matrix/detail/select_k_double_uint32_t.cu @@ -422,6 +423,7 @@ if(RAFT_COMPILE_LIBRARY) src/raft_runtime/cluster/update_centroids.cuh src/raft_runtime/cluster/update_centroids_double.cu src/raft_runtime/cluster/update_centroids_float.cu + src/raft_runtime/distance/fused_distance_min_arg.cu src/raft_runtime/distance/fused_l2_min_arg.cu src/raft_runtime/distance/pairwise_distance.cu src/raft_runtime/matrix/select_k_float_int64_t.cu diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 5577881ef7..903f4e4347 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -122,16 +122,8 @@ if(BUILD_PRIMS_BENCH) ) ConfigureBench( - NAME - MATRIX_BENCH - PATH - bench/prims/matrix/argmin.cu - bench/prims/matrix/gather.cu - bench/prims/matrix/select_k.cu - bench/prims/matrix/main.cpp - OPTIONAL - LIB - EXPLICIT_INSTANTIATE_ONLY + NAME MATRIX_BENCH PATH bench/prims/matrix/argmin.cu bench/prims/matrix/gather.cu + bench/prims/matrix/select_k.cu bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureBench( diff --git a/cpp/include/raft/distance/detail/fused_distance_nn.cuh b/cpp/include/raft/distance/detail/fused_distance_nn.cuh new file mode 100644 index 0000000000..4fbfdc8755 --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn.cuh @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // raft::KeyValuePair +#include // raft::identity_op +#include // ops::l2_exp_distance_op +#include +#include +#include +#include +#include +#include // PairwiseDistances +#include +#include // Policy +#include // raft::util::arch::SM_* +#include // raft::ceildiv, raft::shfl + +#include // size_t +#include // std::numeric_limits + +namespace raft { +namespace distance { + +namespace detail { + +template +void fusedDistanceNNImpl(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + int* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + bool initOutBuffer, + bool isRowMajor, + raft::distance::DistanceType metric, + float metric_arg, + cudaStream_t stream) +{ + // The kernel policy is determined by fusedDistanceNN. + typedef Policy P; + + dim3 blk(P::Nthreads); + auto nblks = raft::ceildiv(m, P::Nthreads); + constexpr auto maxVal = std::numeric_limits::max(); + typedef KeyValuePair KVPair; + + RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); + if (initOutBuffer) { + initKernel + <<>>(min, m, maxVal, redOp); + RAFT_CUDA_TRY(cudaGetLastError()); + } + + switch (metric) { + case raft::distance::DistanceType::CosineExpanded: + fusedCosineNN( + min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, stream); + break; + case raft::distance::DistanceType::L2SqrtExpanded: + case raft::distance::DistanceType::L2Expanded: + // initOutBuffer is take care by fusedDistanceNNImpl() so we set it false to fusedL2NNImpl. + fusedL2NNImpl( + min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, false, stream); + break; + default: assert("only cosine/l2 metric is supported with fusedDistanceNN\n"); break; + } +} + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h index 32e9214a0d..186715851b 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h @@ -611,6 +611,7 @@ class EpilogueWithBroadcastCustom : public EpilogueBase +#include + #include #include #include @@ -46,6 +48,14 @@ namespace raft { namespace distance { namespace detail { +template +RAFT_KERNEL initBinMutexKernel(cuda::binary_semaphore* mut, IdxT m) +{ + auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; + + if (tid < m) { mut[tid].release(); } +} + template ; constexpr int batch_count = 1; + rmm::device_uvector> bin_mutex(m, stream); + + int blks_ = (m / 256) + 1; + + initBinMutexKernel<<>>(bin_mutex.data(), m); + typename EpilogueOutputOp::Params epilog_op_param( - dist_op, cg_reduce_op, redOp, pairRedOp, mutexes); + dist_op, cg_reduce_op, redOp, pairRedOp, mutexes, bin_mutex.data()); // Number of pipelines you want to use constexpr int NumStages = 3; diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh index f9ab394585..e69b2486df 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh +++ b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh @@ -56,6 +56,8 @@ #pragma once +#include + #include #include #include @@ -121,6 +123,7 @@ class FusedDistanceNNEpilogueElementwise { KVPReduceOpT_ pair_redop_; ReduceOpT_ red_op_; int* mutexes_; + cuda::binary_semaphore* bin_mutex_; using CGReduceT = CGReduceOp_; // // Methods @@ -130,12 +133,14 @@ class FusedDistanceNNEpilogueElementwise { CGReduceOp cg_reduce_op, ReduceOpT_ red_op, KVPReduceOpT_ pair_redop, - int* mutexes) + int* mutexes, + cuda::binary_semaphore* bin_mutex) : cg_reduce_op(cg_reduce_op), dist_op_(dist_op), pair_redop_(pair_redop), red_op_(red_op), - mutexes_(mutexes) + mutexes_(mutexes), + bin_mutex_(bin_mutex) { } diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh new file mode 100644 index 0000000000..f29c8b4d4c --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh @@ -0,0 +1,136 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // raft::KeyValuePair +#include // raft::identity_op +#include // ops::l2_exp_distance_op +#include +#include +#include +#include // PairwiseDistances +#include // Policy +#include // raft::util::arch::SM_* +#include // raft::ceildiv, raft::shfl + +#include // size_t +#include // std::numeric_limits + +namespace raft { +namespace distance { + +namespace detail { + +template +void fusedCosineNN(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + int* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + cudaStream_t stream) +{ + // The kernel policy is determined by fusedL2NN. + typedef Policy P; + + dim3 blk(P::Nthreads); + constexpr auto maxVal = std::numeric_limits::max(); + typedef KeyValuePair KVPair; + + namespace arch = raft::util::arch; + using AccT = DataT; + ops::cosine_distance_op distance_op{}; + + raft::identity_op fin_op{}; + + auto kernel = fusedDistanceNNkernel; + + // Get pointer to fp32 SIMT kernel to determine the runtime architecture of the + // current system. Other methods to determine the architecture (that do not + // require a pointer) can be error prone. See: + // https://github.com/NVIDIA/cub/issues/545 + void* kernel_ptr = reinterpret_cast(kernel); + auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); + auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); + + if (cutlass_range.contains(runtime_arch)) { + // If device is SM_80 or later, use CUTLASS-based kernel. + using cosineOp = raft::distance::detail::ops::cosine_cutlass_op; + using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op; + kvp_cg_min_reduce_op_ cg_reduce_op; + cosineOp cosine_dist_op; + + IdxT lda, ldb, ldd; + lda = k, ldb = k, ldd = n; + + cutlassFusedDistanceNN(x, + y, + xn, + yn, + m, + n, + k, + lda, + ldb, + ldd, + min, + workspace, + cg_reduce_op, + cosine_dist_op, + redOp, + pairRedOp, + stream); + } else { + // If device less than SM_80, use fp32 SIMT kernel. + constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); + dim3 grid = launchConfigGenerator

(m, n, shmemSize, kernel); + + kernel<<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); + RAFT_CUDA_TRY(cudaGetLastError()); + } +} + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/fused_l2_nn.cuh new file mode 100644 index 0000000000..65475e73c7 --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn/fused_l2_nn.cuh @@ -0,0 +1,146 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // raft::KeyValuePair +#include // raft::identity_op +#include // ops::l2_exp_distance_op +#include +#include +#include +#include // PairwiseDistances +#include // Policy +#include // raft::util::arch::SM_* +#include // raft::ceildiv, raft::shfl + +#include // size_t +#include // std::numeric_limits + +namespace raft { +namespace distance { + +namespace detail { + +template +void fusedL2NNImpl(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + int* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream) +{ + // The kernel policy is determined by fusedL2NN. + typedef Policy P; + + dim3 blk(P::Nthreads); + auto nblks = raft::ceildiv(m, P::Nthreads); + constexpr auto maxVal = std::numeric_limits::max(); + typedef KeyValuePair KVPair; + + RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); + if (initOutBuffer) { + initKernel + <<>>(min, m, maxVal, redOp); + RAFT_CUDA_TRY(cudaGetLastError()); + } + + namespace arch = raft::util::arch; + using AccT = DataT; + ops::l2_exp_distance_op distance_op{sqrt}; + + raft::identity_op fin_op{}; + + auto kernel = fusedDistanceNNkernel; + + // Get pointer to fp32 SIMT kernel to determine the best compute architecture + // out of all for which the kernel was compiled for that matches closely + // to the current device. Other methods to determine the architecture (that do not + // require a pointer) can be error prone. See: + // https://github.com/NVIDIA/cub/issues/545 + void* kernel_ptr = reinterpret_cast(kernel); + auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); + auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); + + if (cutlass_range.contains(runtime_arch)) { + // If device is SM_80 or later, use CUTLASS-based kernel. + using L2Op = raft::distance::detail::ops::l2_exp_cutlass_op; + using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op; + kvp_cg_min_reduce_op_ cg_reduce_op; + L2Op L2_dist_op(sqrt); + + IdxT lda, ldb, ldd; + lda = k, ldb = k, ldd = n; + + cutlassFusedDistanceNN(x, + y, + xn, + yn, + m, + n, + k, + lda, + ldb, + ldd, + min, + workspace, + cg_reduce_op, + L2_dist_op, + redOp, + pairRedOp, + stream); + } else { + // If device less than SM_80, use fp32 SIMT kernel. + constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); + dim3 grid = launchConfigGenerator

(m, n, shmemSize, kernel); + + kernel<<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); + RAFT_CUDA_TRY(cudaGetLastError()); + } +} + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/helper_structs.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/helper_structs.cuh new file mode 100644 index 0000000000..e056c5d397 --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn/helper_structs.cuh @@ -0,0 +1,146 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // raft::KeyValuePair +#include // raft::identity_op +#include // ops::l2_exp_distance_op +#include +#include +#include // PairwiseDistances +#include // Policy +#include // raft::util::arch::SM_* +#include // raft::ceildiv, raft::shfl +#include + +#include // size_t +#include // std::numeric_limits + +namespace raft { +namespace distance { + +namespace detail { + +template +struct KVPMinReduceImpl { + typedef raft::KeyValuePair KVP; + DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } + DI KVP operator()(const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } + +}; // KVPMinReduce + +template +struct MinAndDistanceReduceOpImpl { + typedef typename raft::KeyValuePair KVP; + + DI void operator()(LabelT rid, KVP* out, const KVP& other) const + { + if (other.value < out->value) { + out->key = other.key; + out->value = other.value; + } + } + DI void operator()(LabelT rid, volatile KVP* out, const KVP& other) const + { + if (other.value < out->value) { + out->key = other.key; + out->value = other.value; + } + } + + DI void operator()(LabelT rid, DataT* out, const KVP& other) const + { + if (other.value < *out) { *out = other.value; } + } + + DI void operator()(LabelT rid, volatile DataT* out, const KVP& other) const + { + if (other.value < *out) { *out = other.value; } + } + + DI void operator()(LabelT rid, DataT* out, const DataT& other) const + { + if (other < *out) { *out = other; } + } + + DI void operator()(LabelT rid, volatile DataT* out, const DataT& other) const + { + if (other < *out) { *out = other; } + } + + DI void init(DataT* out, DataT maxVal) const { *out = maxVal; } + DI void init(KVP* out, DataT maxVal) const + { + out->value = maxVal; + out->key = 0xfffffff0; + } + + DI void init_key(DataT& out, LabelT idx) const { return; } + DI void init_key(KVP& out, LabelT idx) const { out.key = idx; } + + DI DataT get_value(KVP& out) const { return out.value; } + DI DataT get_value(DataT& out) const { return out; } +}; + +template +struct MinReduceOpImpl { + typedef typename raft::KeyValuePair KVP; + DI void operator()(LabelT rid, DataT* out, const KVP& other) + { + if (other.value < *out) { *out = other.value; } + } + + DI void init(DataT* out, DataT maxVal) { *out = maxVal; } +}; + +template +RAFT_KERNEL initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) +{ + auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; + if (tid < m) { redOp.init(min + tid, maxVal); } +} + +template +void initialize(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp, cudaStream_t stream) +{ + auto blks = raft::ceildiv(m, 256); + initKernel<<>>(min, m, maxVal, redOp); +} + +// cg::reduce functor for FusedDistanceNN used in its cutlass version +// to output the min distance value & key(loc id). +// This is used in fused_distance_nn/predicated_tile_iterator_reduced_vec.h +// store_with_byte_offset() passed to cg::reduce() & select_reduce. +template +struct kvp_cg_min_reduce_op { + typedef typename raft::KeyValuePair KVP; + + __host__ __device__ kvp_cg_min_reduce_op() noexcept {}; + + using AccTypeT = AccType; + using IndexT = Index; + // functor signature. + __host__ __device__ KVP operator()(KVP a, KVP b) const { return a.value < b.value ? a : b; } + + __host__ __device__ AccType operator()(AccType a, AccType b) const { return min(a, b); } + + __host__ __device__ bool isAmin(AccType a, AccType b) const { return a < b ? true : false; } +}; + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h index 4f05251705..f1a7c728e9 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h @@ -180,8 +180,7 @@ struct FusedDistanceNNPersistent { /// Default ctor CUTLASS_HOST_DEVICE Arguments() - : // problem_count(0), - threadblock_count(0), + : threadblock_count(0), ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), @@ -235,7 +234,6 @@ struct FusedDistanceNNPersistent { /// Parameters structure struct Params { - // typename ProblemVisitor::Params problem_visitor; temp_problem_visitor problem_visitor; int threadblock_count; diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h index 936b9f2c89..d61018593f 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h @@ -322,10 +322,8 @@ class PredicatedTileIteratorReducedVec { /// Parameters structure containing reference and precomputed state. Params params_; - /// Byte-level pointer - uint8_t* byte_pointer_; /// Byte-level pointer first tile offset of this threadblock. - uint8_t* first_tile_byte_pointer_; + volatile uint8_t* first_tile_byte_pointer_; /// Array of boolean values to contain steady-state predicates Mask mask_; @@ -350,6 +348,8 @@ class PredicatedTileIteratorReducedVec { /// Scatter indices int const* indices_; + const int do_gmem_reduction_; + // // Static asserts about internal strides // @@ -360,7 +360,6 @@ class PredicatedTileIteratorReducedVec { protected: SharedStorage& shared_storage_; - const bool& do_gmem_reduction_; private: // @@ -374,10 +373,10 @@ class PredicatedTileIteratorReducedVec { CUTLASS_DEVICE PredicatedTileIteratorReducedVec(SharedStorage& shared_storage, Params const& params, - Element* pointer, + volatile Element* pointer, TensorCoord extent, int thread_idx, - const bool& do_gmem_reduction, + const bool do_gmem_reduction, TensorCoord threadblock_offset = TensorCoord(), int const* indices = nullptr) : params_(params), @@ -409,6 +408,7 @@ class PredicatedTileIteratorReducedVec { EpilogueOpParams const& user_params = params_.user_param; shared_storage_.initSmem(user_params); } + __syncthreads(); // Null pointer performs no accesses if (!pointer) { mask_.clear(); } @@ -416,66 +416,53 @@ class PredicatedTileIteratorReducedVec { if (ScatterD && !indices) { mask_.clear(); } // Initialize pointer - first_tile_byte_pointer_ = reinterpret_cast(pointer) + + first_tile_byte_pointer_ = reinterpret_cast(pointer) + LongIndex(block_offset.row()) * LongIndex(params_.stride); - if (ScatterD) { - byte_pointer_ = reinterpret_cast(pointer) + - LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; - } - // Initialize internal state counter state_[0] = state_[1] = state_[2] = 0; } - /// Destructor - CUTLASS_DEVICE - ~PredicatedTileIteratorReducedVec() + CUTLASS_DEVICE void dumpToGmem() { + if (block_start_row_first_tile_ >= extent_row_) { return; } + if (do_gmem_reduction_) { EpilogueOpParams const& user_params = params_.user_param; - auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); - Element* shared_elem_arr = shared_storage_.data(); const uint32_t mutex_id = (block_start_row_first_tile_ / total_rows); - bool useGmemMutex = (gridDim.x != ((extent_row_ - 1 + total_rows) / total_rows)); - // If this is not optimal grid size perform mutex based gmem reduce. - if (useGmemMutex) { - // single lock per block for multiple rows - if (threadIdx.x == 0 && block_start_row_first_tile_ < extent_row_) { - // acquire mutex lock. - unsigned int ns = 8; - while (atomicCAS(user_params.mutexes_ + mutex_id, 0, 1) == 1) { - __nanosleep(ns); - if (ns < 256) { ns *= 2; } - } + const bool useGmemMutex = (gridDim.x != ((extent_row_ - 1 + total_rows) / total_rows)); + int row = threadIdx.x; + Element* shared_elem_arr = shared_storage_.data(); + Element row_local_min; + if (row < total_rows) { row_local_min = shared_elem_arr[row]; } + + // single lock per block for multiple rows + if (useGmemMutex && threadIdx.x == 0) { user_params.bin_mutex_[mutex_id].acquire(); } + __syncthreads(); + + if (row < total_rows) { + volatile Element* gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); + + if ((block_start_row_first_tile_ + row) < extent_row_) { + user_params.red_op_(block_start_row_first_tile_ + row, (gmem_ptr + row), row_local_min); } } __syncthreads(); - for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { - if (block_start_row_first_tile_ + row < extent_row_) { - user_params.red_op_( - block_start_row_first_tile_ + row, &gmem_ptr[row], shared_elem_arr[row]); - } - } + __threadfence(); - if (useGmemMutex) { - __threadfence(); - __syncthreads(); - if (threadIdx.x == 0 && block_start_row_first_tile_ < extent_row_) { - // release mutex lock. - atomicExch(user_params.mutexes_ + mutex_id, 0); - } + if (useGmemMutex && (threadIdx.x == 0)) { + // release mutex lock. + user_params.bin_mutex_[mutex_id].release(); } + shared_storage_.initSmem(user_params); + __syncthreads(); } } - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) - { - byte_pointer_ += pointer_offset * sizeof_bits::value / 8; - } + /// Destructor + CUTLASS_DEVICE + ~PredicatedTileIteratorReducedVec() {} /// Performs reduction and Stores a reduced output to memory CUTLASS_DEVICE @@ -515,9 +502,6 @@ class PredicatedTileIteratorReducedVec { user_params.red_op_.init(&red_val, maxVal); if (row_guard) { - const int iter_row = (row_id % total_rows); - const auto prev_red_val = user_params.red_op_.get_value(shared_elem_arr[iter_row]); - CUTLASS_PRAGMA_UNROLL for (int column = 0; column < ThreadMap::Iterations::kColumn * kElementsPerAccess; ++column) { @@ -536,6 +520,10 @@ class PredicatedTileIteratorReducedVec { user_params.red_op_(row_id, &red_val, this_val); } } + } + const int iter_row = (row_id % total_rows); + const auto prev_red_val = user_params.red_op_.get_value(shared_elem_arr[iter_row]); + if (row_guard) { // select_reduce doesn't need to use `red_op_` as at the warp level we use cg_reduce_op, // this satisfies the requirement of mst/single linkage of checking colors buffer. select_reduce red_obj( @@ -544,6 +532,7 @@ class PredicatedTileIteratorReducedVec { } } } + __syncthreads(); } /// Stores a fragment to memory @@ -575,14 +564,11 @@ class PredicatedTileIteratorReducedVec { { ++state_[0]; - if (!ScatterD) { byte_pointer_ += params_.advance_row; } - thread_start_row_ += ThreadMap::Shape::kRow; if (state_[0] == ThreadMap::Count::kRow) { state_[0] = 0; ++state_[1]; - byte_pointer_ += params_.advance_group; thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; @@ -590,18 +576,13 @@ class PredicatedTileIteratorReducedVec { if (state_[1] == ThreadMap::Count::kGroup) { state_[1] = 0; ++state_[2]; - byte_pointer_ += params_.advance_cluster; thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; - if (state_[2] == ThreadMap::Count::kCluster) { - state_[2] = 0; - byte_pointer_ += params_.advance_tile; - } + if (state_[2] == ThreadMap::Count::kCluster) { state_[2] = 0; } } } - return *this; } diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/simt_kernel.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/simt_kernel.cuh new file mode 100644 index 0000000000..7417fd5dac --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn/simt_kernel.cuh @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // raft::KeyValuePair +#include // ops::l2_exp_distance_op +#include // PairwiseDistances +#include // Policy + +#include // size_t +#include // std::numeric_limits + +namespace raft { +namespace distance { +namespace detail { + +// TODO: specialize this function for MinAndDistanceReduceOp +// with atomicCAS of 64 bit which will eliminate mutex and shfls +template +DI void updateReducedVal( + int* mutex, OutT* min, KVPair* val, ReduceOpT red_op, IdxT m, IdxT gridStrideY) +{ + const auto lid = threadIdx.x % raft::WarpSize; + const auto accrowid = threadIdx.x / P::AccThCols; + + // Update each output row in order within a warp. This will resolve hang + // issues with pre-Volta architectures +#pragma unroll + for (int j = 0; j < (raft::WarpSize / P::AccThCols); j++) { + if (lid == j * P::AccThCols) { +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + auto rid = gridStrideY + accrowid + i * P::AccThRows; + if (rid < m) { + auto value = val[i]; + while (atomicCAS(mutex + rid, 0, 1) == 1) + ; + __threadfence(); + red_op(rid, min + rid, value); + __threadfence(); + atomicCAS(mutex + rid, 1, 0); + } + } + } + } +} + +template +__launch_bounds__(P::Nthreads, 2) RAFT_KERNEL fusedDistanceNNkernel(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + DataT maxVal, + int* mutex, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + OpT distance_op, + FinalLambda fin_op) +{ +// compile only if below non-ampere arch. +#if __CUDA_ARCH__ < 800 + extern __shared__ char smem[]; + + typedef KeyValuePair KVPair; + KVPair val[P::AccRowsPerTh]; +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {0, maxVal}; + } + + // epilogue operation lambda for final value calculation + auto epilog_lambda = [n, pairRedOp, &val, maxVal] __device__( + DataT acc[P::AccRowsPerTh][P::AccColsPerTh], + DataT * regxn, + DataT * regyn, + IdxT gridStrideX, + IdxT gridStrideY) { + KVPReduceOpT pairRed_op(pairRedOp); + + // intra thread reduce + const auto acccolid = threadIdx.x % P::AccThCols; + const auto accrowid = threadIdx.x / P::AccThCols; +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + auto tmpkey = acccolid + j * P::AccThCols + gridStrideX; + KVPair tmp = {tmpkey, acc[i][j]}; + if (tmpkey < n) { + val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); + } + } + } + }; + + auto rowEpilog_lambda = + [m, mutex, min, pairRedOp, redOp, &val, maxVal] __device__(IdxT gridStrideY) { + KVPReduceOpT pairRed_op(pairRedOp); + ReduceOpT red_op(redOp); + + const auto accrowid = threadIdx.x / P::AccThCols; + const auto lid = raft::laneId(); + + // reduce +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = P::AccThCols / 2; j > 0; j >>= 1) { + // Actually, the srcLane (lid +j) should be (lid +j) % P:AccThCols, + // but the shfl op applies the modulo internally. + auto tmpkey = raft::shfl(val[i].key, lid + j, P::AccThCols); + auto tmpvalue = raft::shfl(val[i].value, lid + j, P::AccThCols); + KVPair tmp = {tmpkey, tmpvalue}; + val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); + } + } + + updateReducedVal(mutex, min, val, red_op, m, gridStrideY); + + // reset the val array. +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {0, maxVal}; + } + }; + + IdxT lda = k, ldb = k, ldd = n; + constexpr bool row_major = true; + constexpr bool write_out = false; + PairwiseDistances + obj(x, + y, + m, + n, + k, + lda, + ldb, + ldd, + xn, + yn, + nullptr, // Output pointer + smem, + distance_op, + epilog_lambda, + fin_op, + rowEpilog_lambda); + obj.run(); +#endif +} + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/detail/masked_nn.cuh b/cpp/include/raft/distance/detail/masked_nn.cuh index 536e4937ab..3e3699766f 100644 --- a/cpp/include/raft/distance/detail/masked_nn.cuh +++ b/cpp/include/raft/distance/detail/masked_nn.cuh @@ -19,7 +19,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/cpp/include/raft/distance/fused_distance_nn-ext.cuh b/cpp/include/raft/distance/fused_distance_nn-ext.cuh new file mode 100644 index 0000000000..263bbcea81 --- /dev/null +++ b/cpp/include/raft/distance/fused_distance_nn-ext.cuh @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // raft::KeyValuePair +#include // raft::resources +#include // include initialize and reduce operations +#include // RAFT_EXPLICIT + +#include // int64_t + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft { +namespace distance { + +template +void fusedDistanceNNMinReduce(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + bool sqrt, + bool initOutBuffer, + bool isRowMajor, + raft::distance::DistanceType metric, + float metric_arg, + cudaStream_t stream) RAFT_EXPLICIT; + +} // namespace distance +} // namespace raft + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_distance_fusedDistanceNNMinReduce(DataT, OutT, IdxT) \ + extern template void raft::distance::fusedDistanceNNMinReduce( \ + OutT * min, \ + const DataT* x, \ + const DataT* y, \ + const DataT* xn, \ + const DataT* yn, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + bool sqrt, \ + bool initOutBuffer, \ + bool isRowMajor, \ + raft::distance::DistanceType metric, \ + float metric_arg, \ + cudaStream_t stream) + +instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int); +instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int64_t); + +// We can't have comma's in the macro expansion, so we use the COMMA macro: +#define COMMA , + +instantiate_raft_distance_fusedDistanceNNMinReduce(float, raft::KeyValuePair, int); +instantiate_raft_distance_fusedDistanceNNMinReduce(float, + raft::KeyValuePair, + int64_t); + +#undef COMMA + +#undef instantiate_raft_distance_fusedDistanceNNMinReduce diff --git a/cpp/include/raft/distance/fused_distance_nn-inl.cuh b/cpp/include/raft/distance/fused_distance_nn-inl.cuh new file mode 100644 index 0000000000..ffe86a1c04 --- /dev/null +++ b/cpp/include/raft/distance/fused_distance_nn-inl.cuh @@ -0,0 +1,328 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __FUSED_DISTANCE_NN_H +#define __FUSED_DISTANCE_NN_H + +#pragma once + +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include + +namespace raft { +namespace distance { + +/** + * \ingroup fused_l2_nn + * @{ + */ +/** + * @brief Fused L2 distance and 1-nearest-neighbor computation in a single call. + * + * The benefits of such a call are 2-fold: 1) eliminate the need for an + * intermediate buffer to store the output of gemm 2) reduce the memory read + * traffic on this intermediate buffer, otherwise needed during the reduction + * phase for 1-NN. + * + * @tparam DataT data type + * @tparam OutT output type to either store 1-NN indices and their minimum + * distances or store only the min distances. Accordingly, one + * has to pass an appropriate `ReduceOpT` + * @tparam IdxT indexing arithmetic type + * @tparam ReduceOpT A struct to perform the final needed reduction operation + * and also to initialize the output array elements with the + * appropriate initial value needed for reduction. + * @tparam KVPReduceOpT A struct providing functions for key-value pair comparison. + * + * @param[out] min will contain the reduced output (Length = `m`) + * (on device) + * @param[in] x first matrix. Row major. Dim = `m x k`. + * (on device). + * @param[in] y second matrix. Row major. Dim = `n x k`. + * (on device). + * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). + * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) + * @param[in] m gemm m + * @param[in] n gemm n + * @param[in] k gemm k + * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) + * @param[in] redOp reduction operator in the epilogue + * @param[in] pairRedOp reduction operation on key value pairs + * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt + * @param[in] initOutBuffer whether to initialize the output buffer before the + * main kernel launch + * @param[in] isRowMajor whether the input/output is row or column major. + * @param[in] metric Distance metric to be used (supports L2, cosine) + * @param[in] metric_arg power argument for distances like Minkowski (not supported for now) + * @param[in] stream cuda stream + */ +template +void fusedDistanceNN(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + bool initOutBuffer, + bool isRowMajor, + raft::distance::DistanceType metric, + float metric_arg, + cudaStream_t stream) +{ + ASSERT(isRowMajor, "fusedDistanceNN only supports row major inputs"); + // When k is smaller than 32, the Policy4x4 results in redundant calculations + // as it uses tiles that have k=32. Therefore, use a "skinny" policy instead + // that uses tiles with a smaller value of k. + bool is_skinny = k < 32; + + size_t bytes = sizeof(DataT) * k; + auto px = reinterpret_cast(x); + auto py = reinterpret_cast(y); + if (16 % sizeof(DataT) == 0 && bytes % 16 == 0 && px % 16 == 0 && py % 16 == 0) { + if (is_skinny) { + detail::fusedDistanceNNImpl< + DataT, + OutT, + IdxT, + typename linalg::Policy4x4Skinny::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } else { + detail::fusedDistanceNNImpl::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } + } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0 && px % 8 == 0 && py % 8 == 0) { + if (is_skinny) { + detail::fusedDistanceNNImpl< + DataT, + OutT, + IdxT, + typename linalg::Policy4x4Skinny::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } else { + detail::fusedDistanceNNImpl::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } + } else { + if (is_skinny) { + detail::fusedDistanceNNImpl::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } else { + detail::fusedDistanceNNImpl::Policy, + ReduceOpT>(min, + x, + y, + xn, + yn, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); + } + } +} + +/** + * @brief Wrapper around fusedDistanceNN with minimum reduction operators. + * + * fusedDistanceNN cannot be compiled in the distance library due to the lambda + * operators, so this wrapper covers the most common case (minimum). + * + * @tparam DataT data type + * @tparam OutT output type to either store 1-NN indices and their minimum + * distances (e.g. raft::KeyValuePair) or store only the min + * distances. + * @tparam IdxT indexing arithmetic type + * @param[out] min will contain the reduced output (Length = `m`) + * (on device) + * @param[in] x first matrix. Row major. Dim = `m x k`. + * (on device). + * @param[in] y second matrix. Row major. Dim = `n x k`. + * (on device). + * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). + * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) + * @param[in] m gemm m + * @param[in] n gemm n + * @param[in] k gemm k + * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) + * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt + * @param[in] initOutBuffer whether to initialize the output buffer before the + * main kernel launch + * @param[in] isRowMajor whether the input/output is row or column major. + * @param[in] metric Distance metric to be used (supports L2, cosine) + * @param[in] metric_arg power argument for distances like Minkowski (not supported for now) + * @param[in] stream cuda stream + */ +template +void fusedDistanceNNMinReduce(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + bool sqrt, + bool initOutBuffer, + bool isRowMajor, + raft::distance::DistanceType metric, + float metric_arg, + cudaStream_t stream) +{ + MinAndDistanceReduceOp redOp; + KVPMinReduce pairRedOp; + + fusedDistanceNN(min, + x, + y, + xn, + yn, + m, + n, + k, + workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + stream); +} + +/** @} */ + +} // namespace distance +} // namespace raft + +#endif diff --git a/cpp/include/raft/distance/fused_distance_nn.cuh b/cpp/include/raft/distance/fused_distance_nn.cuh new file mode 100755 index 0000000000..04c42e49a1 --- /dev/null +++ b/cpp/include/raft/distance/fused_distance_nn.cuh @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "fused_distance_nn-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +#include "fused_distance_nn-ext.cuh" +#endif diff --git a/cpp/include/raft/distance/fused_l2_nn_helpers.cuh b/cpp/include/raft/distance/fused_distance_nn_helpers.cuh similarity index 92% rename from cpp/include/raft/distance/fused_l2_nn_helpers.cuh rename to cpp/include/raft/distance/fused_distance_nn_helpers.cuh index 996f696ef6..3a570c681c 100644 --- a/cpp/include/raft/distance/fused_l2_nn_helpers.cuh +++ b/cpp/include/raft/distance/fused_distance_nn_helpers.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ #pragma once #include -#include +#include namespace raft::distance { diff --git a/cpp/include/raft/distance/fused_l2_nn-ext.cuh b/cpp/include/raft/distance/fused_l2_nn-ext.cuh index c7dc4c5fc6..d0ac83cd51 100644 --- a/cpp/include/raft/distance/fused_l2_nn-ext.cuh +++ b/cpp/include/raft/distance/fused_l2_nn-ext.cuh @@ -16,10 +16,10 @@ #pragma once -#include // raft::KeyValuePair -#include // raft::resources -#include // include initialize and reduce operations -#include // RAFT_EXPLICIT +#include // raft::KeyValuePair +#include // raft::resources +#include // include initialize and reduce operations +#include // RAFT_EXPLICIT #include // int64_t diff --git a/cpp/include/raft/distance/fused_l2_nn-inl.cuh b/cpp/include/raft/distance/fused_l2_nn-inl.cuh index 6f4db15c72..bf9a49d813 100644 --- a/cpp/include/raft/distance/fused_l2_nn-inl.cuh +++ b/cpp/include/raft/distance/fused_l2_nn-inl.cuh @@ -20,8 +20,8 @@ #pragma once #include -#include -#include +#include +#include #include #include diff --git a/cpp/include/raft_runtime/distance/fused_distance_nn.hpp b/cpp/include/raft_runtime/distance/fused_distance_nn.hpp new file mode 100644 index 0000000000..7c309d6fc7 --- /dev/null +++ b/cpp/include/raft_runtime/distance/fused_distance_nn.hpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace raft::runtime::distance { + +/** + * @defgroup fused_distance_nn_min_arg_runtime Fused Distance 1NN Runtime API + * @{ + */ + +/** + * @brief Wrapper around fusedDistanceNN with minimum reduction operators. + * + * fusedDistanceNN cannot be compiled in the distance library due to the lambda + * operators, so this wrapper covers the most common case (minimum). + * + * @param[in] handle raft handle + * @param[out] min will contain the reduced output (Length = `m`) + * (on device) + * @param[in] x first matrix. Row major. Dim = `m x k`. + * (on device). + * @param[in] y second matrix. Row major. Dim = `n x k`. + * (on device). + * @param[in] m gemm m + * @param[in] n gemm n + * @param[in] k gemm k + * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt + * @param[in] metric Distance metric to be used (supports L2, cosine) + * @param[in] isRowMajor whether the input/output is row or column major. + * @param[in] metric_arg power argument for distances like Minkowski (not supported for now) + */ +void fused_distance_nn_min_arg(raft::resources const& handle, + int* min, + const float* x, + const float* y, + int m, + int n, + int k, + bool sqrt, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg); + +/** @} */ // end group fused_distance_nn_min_arg_runtime + +} // end namespace raft::runtime::distance diff --git a/cpp/include/raft_runtime/distance/fused_l2_nn.hpp b/cpp/include/raft_runtime/distance/fused_l2_nn.hpp index 6154e03f4c..e46b3c5271 100644 --- a/cpp/include/raft_runtime/distance/fused_l2_nn.hpp +++ b/cpp/include/raft_runtime/distance/fused_l2_nn.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -42,23 +42,25 @@ namespace raft::runtime::distance { * @param[in] k gemm k * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt */ -void fused_l2_nn_min_arg(raft::resources const& handle, - int* min, - const float* x, - const float* y, - int m, - int n, - int k, - bool sqrt); +[[deprecated("use fused_distance_nn_min_arg instead")]] void fused_l2_nn_min_arg( + raft::resources const& handle, + int* min, + const float* x, + const float* y, + int m, + int n, + int k, + bool sqrt); -void fused_l2_nn_min_arg(raft::resources const& handle, - int* min, - const double* x, - const double* y, - int m, - int n, - int k, - bool sqrt); +[[deprecated("use fused_distance_nn_min_arg instead")]] void fused_l2_nn_min_arg( + raft::resources const& handle, + int* min, + const double* x, + const double* y, + int m, + int n, + int k, + bool sqrt); /** @} */ // end group fused_l2_nn_min_arg_runtime diff --git a/cpp/src/distance/fused_distance_nn.cu b/cpp/src/distance/fused_distance_nn.cu new file mode 100644 index 0000000000..dc722d929c --- /dev/null +++ b/cpp/src/distance/fused_distance_nn.cu @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // raft::KeyValuePair +#include + +#include // int64_t + +#define instantiate_raft_distance_fusedDistanceNNMinReduce(DataT, OutT, IdxT) \ + template void raft::distance::fusedDistanceNNMinReduce( \ + OutT * min, \ + const DataT* x, \ + const DataT* y, \ + const DataT* xn, \ + const DataT* yn, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + bool sqrt, \ + bool initOutBuffer, \ + bool isRowMajor, \ + raft::distance::DistanceType metric, \ + float metric_arg, \ + cudaStream_t stream) + +instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int); +instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int64_t); + +// We can't have comma's in the macro expansion, so we use the COMMA macro: +#define COMMA , + +instantiate_raft_distance_fusedDistanceNNMinReduce(float, raft::KeyValuePair, int); +instantiate_raft_distance_fusedDistanceNNMinReduce(float, + raft::KeyValuePair, + int64_t); + +#undef COMMA + +#undef instantiate_raft_distance_fusedDistanceNNMinReduce diff --git a/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu b/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu new file mode 100644 index 0000000000..dfdff4e94b --- /dev/null +++ b/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "fused_distance_min_arg.hpp" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace raft::runtime::distance { + +void fused_distance_nn_min_arg(raft::resources const& handle, + int* min, + const float* x, + const float* y, + int m, + int n, + int k, + bool sqrt, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg) +{ + switch (metric) { + case raft::distance::DistanceType::CosineExpanded: + compute_fused_cosine_nn_min_arg(handle, min, x, y, m, n, k, sqrt); + break; + case raft::distance::DistanceType::L2Expanded: + case raft::distance::DistanceType::L2SqrtExpanded: + compute_fused_l2_nn_min_arg(handle, min, x, y, m, n, k, sqrt); + break; + default: assert("only Cosine/L2 metric is supported with fusedDistanceNN\n"); break; + } +} + +} // end namespace raft::runtime::distance diff --git a/cpp/src/raft_runtime/distance/fused_distance_min_arg.hpp b/cpp/src/raft_runtime/distance/fused_distance_min_arg.hpp new file mode 100644 index 0000000000..6452752a79 --- /dev/null +++ b/cpp/src/raft_runtime/distance/fused_distance_min_arg.hpp @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace raft::runtime::distance { + +template +struct KeyValueIndexOp { + __host__ __device__ __forceinline__ IndexT + operator()(const raft::KeyValuePair& a) const + { + return a.key; + } +}; + +template +void compute_fused_l2_nn_min_arg(raft::resources const& handle, + idx_t* min, + const value_t* x, + const value_t* y, + idx_t m, + idx_t n, + idx_t k, + bool sqrt) +{ + rmm::device_uvector workspace(m, resource::get_cuda_stream(handle)); + auto kvp = raft::make_device_vector>(handle, m); + constexpr bool is_row_major = true; + + rmm::device_uvector x_norms(m, resource::get_cuda_stream(handle)); + rmm::device_uvector y_norms(n, resource::get_cuda_stream(handle)); + raft::linalg::rowNorm( + x_norms.data(), x, k, m, raft::linalg::L2Norm, is_row_major, resource::get_cuda_stream(handle)); + raft::linalg::rowNorm( + y_norms.data(), y, k, n, raft::linalg::L2Norm, is_row_major, resource::get_cuda_stream(handle)); + + raft::distance::fusedL2NNMinReduce(kvp.data_handle(), + x, + y, + x_norms.data(), + y_norms.data(), + m, + n, + k, + (void*)workspace.data(), + sqrt, + true, + resource::get_cuda_stream(handle)); + + KeyValueIndexOp conversion_op; + thrust::transform(resource::get_thrust_policy(handle), + kvp.data_handle(), + kvp.data_handle() + m, + min, + conversion_op); + resource::sync_stream(handle); +} + +template +void compute_fused_cosine_nn_min_arg(raft::resources const& handle, + idx_t* min, + const value_t* x, + const value_t* y, + idx_t m, + idx_t n, + idx_t k, + bool sqrt) +{ + rmm::device_uvector workspace(m, resource::get_cuda_stream(handle)); + auto kvp = raft::make_device_vector>(handle, m); + + rmm::device_uvector x_norms(m, resource::get_cuda_stream(handle)); + rmm::device_uvector y_norms(n, resource::get_cuda_stream(handle)); + constexpr bool is_row_major = true; + raft::linalg::rowNorm(x_norms.data(), + x, + k, + m, + raft::linalg::L2Norm, + is_row_major, + resource::get_cuda_stream(handle), + raft::sqrt_op{}); + raft::linalg::rowNorm(y_norms.data(), + y, + k, + n, + raft::linalg::L2Norm, + is_row_major, + resource::get_cuda_stream(handle), + raft::sqrt_op{}); + + raft::distance::fusedDistanceNNMinReduce(kvp.data_handle(), + x, + y, + x_norms.data(), + y_norms.data(), + m, + n, + k, + (void*)workspace.data(), + sqrt, + true, + is_row_major, + raft::distance::DistanceType::CosineExpanded, + 0.0f, + resource::get_cuda_stream(handle)); + + KeyValueIndexOp conversion_op; + thrust::transform(resource::get_thrust_policy(handle), + kvp.data_handle(), + kvp.data_handle() + m, + min, + conversion_op); + resource::sync_stream(handle); +} + +} // end namespace raft::runtime::distance diff --git a/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu b/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu index bf4b1a431a..870757dca1 100644 --- a/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu +++ b/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu @@ -14,6 +14,8 @@ * limitations under the License. */ +#include "fused_distance_min_arg.hpp" + #include #include #include @@ -28,77 +30,28 @@ namespace raft::runtime::distance { -template -struct KeyValueIndexOp { - __host__ __device__ __forceinline__ IndexT - operator()(const raft::KeyValuePair& a) const - { - return a.key; - } -}; - -template -void compute_fused_l2_nn_min_arg(raft::resources const& handle, - idx_t* min, - const value_t* x, - const value_t* y, - idx_t m, - idx_t n, - idx_t k, - bool sqrt) -{ - rmm::device_uvector workspace(m, resource::get_cuda_stream(handle)); - auto kvp = raft::make_device_vector>(handle, m); - - rmm::device_uvector x_norms(m, resource::get_cuda_stream(handle)); - rmm::device_uvector y_norms(n, resource::get_cuda_stream(handle)); - raft::linalg::rowNorm( - x_norms.data(), x, k, m, raft::linalg::L2Norm, true, resource::get_cuda_stream(handle)); - raft::linalg::rowNorm( - y_norms.data(), y, k, n, raft::linalg::L2Norm, true, resource::get_cuda_stream(handle)); - - raft::distance::fusedL2NNMinReduce(kvp.data_handle(), - x, - y, - x_norms.data(), - y_norms.data(), - m, - n, - k, - (void*)workspace.data(), - sqrt, - true, - resource::get_cuda_stream(handle)); - - KeyValueIndexOp conversion_op; - thrust::transform(resource::get_thrust_policy(handle), - kvp.data_handle(), - kvp.data_handle() + m, - min, - conversion_op); - resource::sync_stream(handle); -} - -void fused_l2_nn_min_arg(raft::resources const& handle, - int* min, - const float* x, - const float* y, - int m, - int n, - int k, - bool sqrt) +[[deprecated("use fused_distance_nn_min_arg instead")]] void fused_l2_nn_min_arg( + raft::resources const& handle, + int* min, + const float* x, + const float* y, + int m, + int n, + int k, + bool sqrt) { compute_fused_l2_nn_min_arg(handle, min, x, y, m, n, k, sqrt); } -void fused_l2_nn_min_arg(raft::resources const& handle, - int* min, - const double* x, - const double* y, - int m, - int n, - int k, - bool sqrt) +[[deprecated("use fused_distance_nn_min_arg instead")]] void fused_l2_nn_min_arg( + raft::resources const& handle, + int* min, + const double* x, + const double* y, + int m, + int n, + int k, + bool sqrt) { compute_fused_l2_nn_min_arg(handle, min, x, y, m, n, k, sqrt); } diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 65d6e738a2..bf44cf9c60 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -171,6 +171,7 @@ if(BUILD_TESTS) test/distance/masked_nn.cu test/distance/masked_nn_compress_to_bits.cu test/distance/fused_l2_nn.cu + test/distance/fused_cosine_nn.cu test/distance/gram.cu LIB EXPLICIT_INSTANTIATE_ONLY diff --git a/cpp/test/distance/fused_cosine_nn.cu b/cpp/test/distance/fused_cosine_nn.cu new file mode 100644 index 0000000000..d4d632e1dc --- /dev/null +++ b/cpp/test/distance/fused_cosine_nn.cu @@ -0,0 +1,420 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Search with filter instantiation + +#include "../test_utils.cuh" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace raft { +namespace distance { + +template +struct RaftKVPMinReduce { + typedef raft::KeyValuePair KVP; + + DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } + + DI KVP operator()(const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } + +}; // KVPMinReduce + +template +__global__ void naiveCosKernel(raft::KeyValuePair* min, + DataT* x, + DataT* y, + int m, + int n, + int k, + int* workspace, + DataT maxVal) +{ + int midx = threadIdx.y + blockIdx.y * blockDim.y; + int nidx = threadIdx.x + blockIdx.x * blockDim.x; + DataT acc_a = DataT(0); + DataT acc_b = DataT(0); + DataT acc_ab = DataT(0); + // if (midx >= m || nidx >= n) { return; } + + for (int i = 0; i < k; ++i) { + int xidx = i + midx * k; + int yidx = i + nidx * k; + auto a = x[xidx]; + auto b = y[yidx]; + acc_a += a * a; + acc_b += b * b; + acc_ab += a * b; + } + + // Use 1.0 - (cosine similarity) to calc the distance + DataT acc = maxVal; + if (midx < m || nidx < n) { acc = (DataT)1.0 - acc_ab / (raft::sqrt(acc_a) * raft::sqrt(acc_b)); } + + ReduceOpT redOp; + typedef cub::WarpReduce> WarpReduce; + __shared__ typename WarpReduce::TempStorage temp[NWARPS]; + int warpId = threadIdx.x / raft::WarpSize; + raft::KeyValuePair tmp; + tmp.key = nidx; + tmp.value = midx >= m || nidx >= n ? maxVal : acc; + tmp = WarpReduce(temp[warpId]).Reduce(tmp, RaftKVPMinReduce()); + if (threadIdx.x % raft::WarpSize == 0 && midx < m) { + while (atomicCAS(workspace + midx, 0, 1) == 1) + ; + __threadfence(); + redOp(midx, min + midx, tmp); + __threadfence(); + atomicCAS(workspace + midx, 1, 0); + } +} + +template +void naive(raft::KeyValuePair* min, + DataT* x, + DataT* y, + int m, + int n, + int k, + int* workspace, + cudaStream_t stream) +{ + static const dim3 TPB(32, 16, 1); + dim3 nblks(raft::ceildiv(n, (int)TPB.x), raft::ceildiv(m, (int)TPB.y), 1); + RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); + auto blks = raft::ceildiv(m, 256); + MinAndDistanceReduceOp op; + detail::initKernel, int> + <<>>(min, m, std::numeric_limits::max(), op); + RAFT_CUDA_TRY(cudaGetLastError()); + naiveCosKernel, 16> + <<>>(min, x, y, m, n, k, workspace, std::numeric_limits::max()); + RAFT_CUDA_TRY(cudaGetLastError()); +} + +template +struct Inputs { + DataT tolerance; + int m, n, k; + unsigned long long int seed; + + friend std::ostream& operator<<(std::ostream& os, const Inputs& p) + { + return os << "m: " << p.m + << ", " + "n: " + << p.n + << ", " + "k: " + << p.k + << ", " + "seed: " + << p.seed + << ", " + "tol: " + << p.tolerance; + } +}; + +template +class FusedCosineNNTest : public ::testing::TestWithParam> { + public: + FusedCosineNNTest() + : params(::testing::TestWithParam>::GetParam()), + stream(resource::get_cuda_stream(handle)), + x(params.m * params.k, stream), + y(params.n * params.k, stream), + xn(params.m, stream), + yn(params.n, stream), + min(params.m, stream), + min_ref(params.m, stream), + workspace(params.m * sizeof(int), stream) + { + } + + protected: + void SetUp() override + { + raft::random::RngState r(params.seed); + int m = params.m; + int n = params.n; + int k = params.k; + uniform(handle, r, x.data(), m * k, DataT(-1.0), DataT(1.0)); + uniform(handle, r, y.data(), n * k, DataT(-1.0), DataT(1.0)); + generateGoldenResult(); + raft::linalg::rowNorm( + xn.data(), x.data(), k, m, raft::linalg::L2Norm, true, stream, raft::sqrt_op{}); + raft::linalg::rowNorm( + yn.data(), y.data(), k, n, raft::linalg::L2Norm, true, stream, raft::sqrt_op{}); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } + + protected: + raft::resources handle; + cudaStream_t stream; + Inputs params; + rmm::device_uvector x; + rmm::device_uvector y; + rmm::device_uvector xn; + rmm::device_uvector yn; + rmm::device_uvector> min; + rmm::device_uvector> min_ref; + rmm::device_uvector workspace; + + virtual void generateGoldenResult() + { + int m = params.m; + int n = params.n; + int k = params.k; + naive(min_ref.data(), x.data(), y.data(), m, n, k, (int*)workspace.data(), stream); + } + + void runTest(raft::KeyValuePair* out) + { + int m = params.m; + int n = params.n; + int k = params.k; + raft::distance::DistanceType metric = raft::distance::DistanceType::CosineExpanded; + constexpr bool init_out_buffer = true; + fusedDistanceNNMinReduce, int>(out, + x.data(), + y.data(), + xn.data(), + yn.data(), + m, + n, + k, + (void*)workspace.data(), + false, + init_out_buffer, + true, + metric, + 0.0f, + stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } +}; + +template +struct CompareApproxAbsKVP { + typedef typename raft::KeyValuePair KVP; + CompareApproxAbsKVP(T eps_) : eps(eps_) {} + bool operator()(const KVP& a, const KVP& b) const + { + T diff = std::abs(std::abs(a.value) - std::abs(b.value)); + T m = std::max(std::abs(a.value), std::abs(b.value)); + T ratio = m >= eps ? diff / m : diff; + return (ratio <= eps); + } + + private: + T eps; +}; + +template +struct CompareExactKVP { + typedef typename raft::KeyValuePair KVP; + bool operator()(const KVP& a, const KVP& b) const + { + if (a.value != b.value) return false; + return true; + } +}; + +template +::testing::AssertionResult devArrMatch(const raft::KeyValuePair* expected, + const raft::KeyValuePair* actual, + size_t size, + L eq_compare, + cudaStream_t stream = 0) +{ + typedef typename raft::KeyValuePair KVP; + std::shared_ptr exp_h(new KVP[size]); + std::shared_ptr act_h(new KVP[size]); + raft::update_host(exp_h.get(), expected, size, stream); + raft::update_host(act_h.get(), actual, size, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + for (size_t i(0); i < size; ++i) { + auto exp = exp_h.get()[i]; + auto act = act_h.get()[i]; + if (!eq_compare(exp, act)) { + return ::testing::AssertionFailure() + << "actual=" << act.key << "," << act.value << " != expected=" << exp.key << "," + << exp.value << " @" << i; + } + } + return ::testing::AssertionSuccess(); +} + +const std::vector> inputsf = { + {0.001f, 32, 32, 32, 1234ULL}, + {0.001f, 32, 64, 32, 1234ULL}, + {0.001f, 64, 32, 32, 1234ULL}, + {0.001f, 64, 64, 32, 1234ULL}, + {0.001f, 128, 32, 32, 1234ULL}, + {0.001f, 128, 64, 32, 1234ULL}, + {0.001f, 128, 128, 64, 1234ULL}, + {0.001f, 64, 128, 128, 1234ULL}, + + {0.001f, 32, 32, 34, 1234ULL}, + {0.001f, 32, 64, 34, 1234ULL}, + {0.001f, 64, 32, 34, 1234ULL}, + {0.001f, 64, 64, 34, 1234ULL}, + {0.001f, 128, 32, 34, 1234ULL}, + {0.001f, 128, 64, 34, 1234ULL}, + {0.001f, 128, 128, 66, 1234ULL}, + {0.001f, 64, 128, 130, 1234ULL}, + + {0.001f, 32, 32, 33, 1234ULL}, + {0.001f, 32, 64, 33, 1234ULL}, + {0.001f, 64, 32, 33, 1234ULL}, + {0.001f, 64, 64, 33, 1234ULL}, + {0.001f, 128, 32, 33, 1234ULL}, + {0.001f, 128, 64, 33, 1234ULL}, + {0.001f, 128, 128, 65, 1234ULL}, + {0.001f, 64, 128, 129, 1234ULL}, + {0.006f, 1805, 134, 2, 1234ULL}, + {0.006f, 8192, 1024, 64, 1234ULL}, + {0.006f, 8192, 1025, 64, 1234ULL}, + + // Repeat with smaller values of k + {0.006f, 32, 32, 1, 1234ULL}, + {0.001f, 32, 64, 2, 1234ULL}, + {0.001f, 64, 32, 3, 1234ULL}, + {0.001f, 64, 64, 4, 1234ULL}, + {0.001f, 128, 32, 5, 1234ULL}, + {0.001f, 128, 64, 6, 1234ULL}, + {0.001f, 128, 128, 7, 1234ULL}, + {0.001f, 64, 128, 8, 1234ULL}, + + {0.001f, 32, 32, 9, 1234ULL}, + {0.001f, 32, 64, 10, 1234ULL}, + {0.001f, 64, 32, 11, 1234ULL}, + {0.001f, 64, 64, 12, 1234ULL}, + {0.001f, 128, 32, 13, 1234ULL}, + {0.001f, 128, 64, 14, 1234ULL}, + {0.001f, 128, 128, 15, 1234ULL}, + {0.001f, 64, 128, 16, 1234ULL}, + + {0.001f, 32, 32, 17, 1234ULL}, + {0.001f, 32, 64, 18, 1234ULL}, + {0.001f, 64, 32, 19, 1234ULL}, + {0.001f, 64, 64, 20, 1234ULL}, + {0.001f, 128, 32, 21, 1234ULL}, + {0.001f, 128, 64, 22, 1234ULL}, + {0.001f, 128, 128, 23, 1234ULL}, + {0.00001, 64, 128, 24, 1234ULL}, + {0.001f, 1805, 134, 25, 1234ULL}, + {0.006f, 8192, 1024, 25, 1234ULL}, + {0.006f, 8192, 1024, 66, 1234ULL}, +}; +typedef FusedCosineNNTest FusedCosineNNTestF; +TEST_P(FusedCosineNNTestF, Result) +{ + runTest(min.data()); + ASSERT_TRUE(devArrMatch( + min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(FusedCosineNNTests, FusedCosineNNTestF, ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.00001, 32, 32, 32, 1234ULL}, {0.00001, 32, 64, 32, 1234ULL}, + {0.00001, 64, 32, 32, 1234ULL}, {0.00001, 64, 64, 32, 1234ULL}, + {0.00001, 128, 32, 32, 1234ULL}, {0.00001, 128, 64, 32, 1234ULL}, + {0.00001, 128, 128, 64, 1234ULL}, {0.00001, 64, 128, 128, 1234ULL}, + + {0.00001, 32, 32, 34, 1234ULL}, {0.00001, 32, 64, 34, 1234ULL}, + {0.00001, 64, 32, 34, 1234ULL}, {0.00001, 64, 64, 34, 1234ULL}, + {0.00001, 128, 32, 34, 1234ULL}, {0.00001, 128, 64, 34, 1234ULL}, + {0.00001, 128, 128, 66, 1234ULL}, {0.00001, 64, 128, 130, 1234ULL}, + + {0.00001, 32, 32, 33, 1234ULL}, {0.00001, 32, 64, 33, 1234ULL}, + {0.00001, 64, 32, 33, 1234ULL}, {0.00001, 64, 64, 33, 1234ULL}, + {0.00001, 128, 32, 33, 1234ULL}, {0.00001, 128, 64, 33, 1234ULL}, + {0.00001, 128, 128, 65, 1234ULL}, {0.00001, 64, 128, 129, 1234ULL}, + + {0.00001, 1805, 134, 2, 1234ULL}, {0.00001, 8192, 1024, 25, 1234ULL}, +}; +typedef FusedCosineNNTest FusedCosineNNTestD; +TEST_P(FusedCosineNNTestD, Result) +{ + runTest(min.data()); + ASSERT_TRUE(devArrMatch( + min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(FusedCosineNNTests, FusedCosineNNTestD, ::testing::ValuesIn(inputsd)); + +/// This is to test output determinism of the prim +template +class FusedCosineNNDetTest : public FusedCosineNNTest { + public: + FusedCosineNNDetTest() : stream(resource::get_cuda_stream(handle)), min1(0, stream) {} + + void SetUp() override + { + FusedCosineNNTest::SetUp(); + int m = this->params.m; + min1.resize(m, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } + + void TearDown() override { FusedCosineNNTest::TearDown(); } + + protected: + raft::resources handle; + cudaStream_t stream; + + rmm::device_uvector> min1; + + static const int NumRepeats = 3; + + void generateGoldenResult() override {} +}; + +typedef FusedCosineNNDetTest FusedCosineNNDetTestF; +TEST_P(FusedCosineNNDetTestF, Result) +{ + runTest(min.data()); // assumed to be golden + for (int i = 0; i < NumRepeats; ++i) { + runTest(min1.data()); + ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); + cudaMemsetAsync(min1.data(), 0, sizeof(*min.data()) * params.m, stream); + } +} +INSTANTIATE_TEST_CASE_P(FusedCosineNNDetTests, FusedCosineNNDetTestF, ::testing::ValuesIn(inputsf)); + +typedef FusedCosineNNDetTest FusedCosineNNDetTestD; +TEST_P(FusedCosineNNDetTestD, Result) +{ + runTest(min.data()); // assumed to be golden + for (int i = 0; i < NumRepeats; ++i) { + runTest(min1.data()); + ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); + } +} +INSTANTIATE_TEST_CASE_P(FusedCosineNNDetTests, FusedCosineNNDetTestD, ::testing::ValuesIn(inputsd)); + +} // end namespace distance +} // end namespace raft diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index d21a525d88..6fd8f15808 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -18,7 +18,6 @@ #include #include -#include #include #include #include diff --git a/python/pylibraft/pylibraft/distance/CMakeLists.txt b/python/pylibraft/pylibraft/distance/CMakeLists.txt index 14f0cc441a..2530e07a98 100644 --- a/python/pylibraft/pylibraft/distance/CMakeLists.txt +++ b/python/pylibraft/pylibraft/distance/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -13,7 +13,7 @@ # ============================================================================= # Set the list of Cython files to build -set(cython_sources pairwise_distance.pyx fused_l2_nn.pyx) +set(cython_sources pairwise_distance.pyx fused_l2_nn.pyx fused_distance_nn.pyx) set(linked_libraries raft::raft raft::compiled) # Build all of the Cython targets diff --git a/python/pylibraft/pylibraft/distance/__init__.py b/python/pylibraft/pylibraft/distance/__init__.py index f059b5f3dd..d16ab30b2f 100644 --- a/python/pylibraft/pylibraft/distance/__init__.py +++ b/python/pylibraft/pylibraft/distance/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,12 @@ # limitations under the License. # +from .fused_distance_nn import fused_distance_nn_argmin from .fused_l2_nn import fused_l2_nn_argmin from .pairwise_distance import DISTANCE_TYPES, distance as pairwise_distance -__all__ = ["fused_l2_nn_argmin", "pairwise_distance"] +__all__ = [ + "fused_distance_nn_argmin", + "fused_l2_nn_argmin", + "pairwise_distance", +] diff --git a/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx b/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx new file mode 100644 index 0000000000..0e9fa4b366 --- /dev/null +++ b/python/pylibraft/pylibraft/distance/fused_distance_nn.pyx @@ -0,0 +1,200 @@ +# +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +import numpy as np + +from cython.operator cimport dereference as deref +from libc.stdint cimport uintptr_t +from libcpp cimport bool + +from .distance_type cimport DistanceType + +from pylibraft.common import ( + Handle, + auto_convert_output, + cai_wrapper, + device_ndarray, +) +from pylibraft.common.handle import auto_sync_handle + +from pylibraft.common.handle cimport device_resources + + +cdef extern from "raft_runtime/distance/fused_distance_nn.hpp" \ + namespace "raft::runtime::distance" nogil: + + void fused_distance_nn_min_arg( + const device_resources &handle, + int* min, + const float* x, + const float* y, + int m, + int n, + int k, + bool sqrt, + DistanceType metric, + bool isRowMajor, + float metric_arg) except + + + +from pylibraft.distance.pairwise_distance import DISTANCE_TYPES + +SUPPORTED_DISTANCES = ["euclidean", "l2", "cosine", "sqeuclidean"] + + +@auto_sync_handle +@auto_convert_output +def fused_distance_nn_argmin(X, Y, out=None, sqrt=True, metric="euclidean", + handle=None): + """ + Compute the 1-nearest neighbors between X and Y using the distance metrics + + Valid values for metric: + ["euclidean", "l2", "cosine", "sqeuclidean"] + + Parameters + ---------- + + X : CUDA array interface compliant matrix shape (m, k) + Y : CUDA array interface compliant matrix shape (n, k) + out : Writable CUDA array interface matrix shape (m, 1) + metric : string denoting the metric type (default="euclidean") + + {handle_docstring} + + Examples + -------- + To compute the 1-nearest neighbors argmin: + + >>> import cupy as cp + >>> from pylibraft.common import Handle + >>> from pylibraft.distance import fused_distance_nn_argmin + >>> n_samples = 5000 + >>> n_clusters = 5 + >>> n_features = 50 + >>> in1 = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> in2 = cp.random.random_sample((n_clusters, n_features), + ... dtype=cp.float32) + >>> # A single RAFT handle can optionally be reused across + >>> # pylibraft functions. + >>> handle = Handle() + + >>> output = fused_distance_nn_argmin(in1, in2, handle=handle) + + >>> # pylibraft functions are often asynchronous so the + >>> # handle needs to be explicitly synchronized + >>> handle.sync() + + The output can also be computed in-place on a preallocated + array: + + >>> import cupy as cp + >>> from pylibraft.common import Handle + >>> from pylibraft.distance import fused_distance_nn_argmin + >>> n_samples = 5000 + >>> n_clusters = 5 + >>> n_features = 50 + >>> in1 = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> in2 = cp.random.random_sample((n_clusters, n_features), + ... dtype=cp.float32) + >>> output = cp.empty((n_samples, 1), dtype=cp.int32) + >>> # A single RAFT handle can optionally be reused across + >>> # pylibraft functions. + >>> handle = Handle() + + >>> fused_distance_nn_argmin(in1, in2, out=output, handle=handle) + array(...) + + >>> # pylibraft functions are often asynchronous so the + >>> # handle needs to be explicitly synchronized + >>> handle.sync() + """ + + x_cai = cai_wrapper(X) + y_cai = cai_wrapper(Y) + + x_dt = x_cai.dtype + y_dt = y_cai.dtype + + m = x_cai.shape[0] + n = y_cai.shape[0] + + if out is None: + output = device_ndarray.empty((m,), dtype="int32") + else: + output = out + + output_cai = cai_wrapper(output) + + x_k = x_cai.shape[1] + y_k = y_cai.shape[1] + + if x_k != y_k: + raise ValueError("Inputs must have same number of columns. " + "a=%s, b=%s" % (x_k, y_k)) + + if metric not in SUPPORTED_DISTANCES: + raise ValueError("metric %s is not supported" % metric) + + cdef DistanceType distance_type = DISTANCE_TYPES[metric] + + x_ptr = x_cai.data + y_ptr = y_cai.data + + d_ptr = output_cai.data + + handle = handle if handle is not None else Handle() + cdef device_resources *h = handle.getHandle() + + d_dt = output_cai.dtype + + x_c_contiguous = x_cai.c_contiguous + y_c_contiguous = y_cai.c_contiguous + + if x_c_contiguous != y_c_contiguous: + raise ValueError("Inputs must have matching strides") + + if not x_c_contiguous: + raise ValueError("Inputs must be C contiguous") + + if x_dt != y_dt: + raise ValueError("Inputs must have the same dtypes") + if d_dt != np.int32: + raise ValueError("Output array must be int32") + # unused arg for now. + metric_arg = 0.0 + if x_dt == np.float32: + fused_distance_nn_min_arg(deref(h), + d_ptr, + x_ptr, + y_ptr, + m, + n, + x_k, + sqrt, + distance_type, + x_c_contiguous, + metric_arg) + else: + raise ValueError("dtype %s not supported" % x_dt) + + return output diff --git a/python/pylibraft/pylibraft/test/test_fused_distance_argmin.py b/python/pylibraft/pylibraft/test/test_fused_distance_argmin.py new file mode 100755 index 0000000000..6736128242 --- /dev/null +++ b/python/pylibraft/pylibraft/test/test_fused_distance_argmin.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np +import pytest +from scipy.spatial.distance import cdist + +from pylibraft.common import DeviceResources, device_ndarray +from pylibraft.distance import fused_distance_nn_argmin + + +@pytest.mark.parametrize("inplace", [True, False]) +@pytest.mark.parametrize("n_rows", [10, 100]) +@pytest.mark.parametrize("n_clusters", [50, 100]) +@pytest.mark.parametrize("n_cols", [128, 31]) +@pytest.mark.parametrize("dtype", [np.float32]) +@pytest.mark.parametrize( + "metric", + [ + "euclidean", + "cosine", + "sqeuclidean", + ], +) +def test_fused_distance_nn_minarg( + n_rows, n_cols, n_clusters, dtype, inplace, metric +): + input1 = np.random.random_sample((n_rows, n_cols)) + input1 = np.asarray(input1, order="C").astype(dtype) + + input2 = np.random.random_sample((n_clusters, n_cols)) + input2 = np.asarray(input2, order="C").astype(dtype) + + output = np.zeros((n_rows), dtype="int32") + expected = cdist(input1, input2, metric) + + expected = expected.argmin(axis=1) + + input1_device = device_ndarray(input1) + input2_device = device_ndarray(input2) + output_device = device_ndarray(output) if inplace else None + + is_sqrt = True if metric == "sqeuclidean" else False + handle = DeviceResources() + ret_output = fused_distance_nn_argmin( + input1_device, + input2_device, + output_device, + is_sqrt, + metric, + handle=handle, + ) + handle.sync() + output_device = ret_output if not inplace else output_device + actual = output_device.copy_to_host() + + assert np.allclose(expected, actual, rtol=1e-4) From 0b9692b25f78cd1b27631e354e3f8921a976645c Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Tue, 19 Mar 2024 18:07:21 +0100 Subject: [PATCH 4/4] random sampling of dataset rows with improved memory utilization (#2155) The random sampling of IVF methods was reverted (#2144) due to large memory utilization #2141. This PR improves the memory consumption of subsamling: it is O(n_train) where n_train is the size of the subsampled dataset. This PR adds the following new APIs: - random::excess_sampling (todo may just call as sample_without_replacement) - matrix::sample_rows - matrix::gather for host input matrix Authors: - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Artem M. Chirkin (https://github.com/achirkin) - Ben Frederickson (https://github.com/benfred) URL: https://github.com/rapidsai/raft/pull/2155 --- cpp/bench/prims/CMakeLists.txt | 2 +- cpp/bench/prims/matrix/gather.cu | 38 ++++- cpp/bench/prims/random/subsample.cu | 112 ++++++++++++++ cpp/include/raft/matrix/detail/gather.cuh | 87 +++++++++++ .../raft/matrix/detail/sample_rows.cuh | 57 +++++++ cpp/include/raft/matrix/sample_rows.cuh | 75 ++++++++++ cpp/include/raft/random/detail/rng_impl.cuh | 138 ++++++++++++++++- cpp/include/raft/random/rng.cuh | 26 ++++ cpp/test/CMakeLists.txt | 2 + cpp/test/matrix/sample_rows.cu | 140 ++++++++++++++++++ cpp/test/random/excess_sampling.cu | 114 ++++++++++++++ 11 files changed, 786 insertions(+), 5 deletions(-) create mode 100644 cpp/bench/prims/random/subsample.cu create mode 100644 cpp/include/raft/matrix/detail/sample_rows.cuh create mode 100644 cpp/include/raft/matrix/sample_rows.cuh create mode 100644 cpp/test/matrix/sample_rows.cu create mode 100644 cpp/test/random/excess_sampling.cu diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 903f4e4347..95361e19ca 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -128,7 +128,7 @@ if(BUILD_PRIMS_BENCH) ConfigureBench( NAME RANDOM_BENCH PATH bench/prims/random/make_blobs.cu bench/prims/random/permute.cu - bench/prims/random/rng.cu bench/prims/main.cpp + bench/prims/random/rng.cu bench/prims/random/subsample.cu bench/prims/main.cpp ) ConfigureBench(NAME SPARSE_BENCH PATH bench/prims/sparse/convert_csr.cu bench/prims/main.cpp) diff --git a/cpp/bench/prims/matrix/gather.cu b/cpp/bench/prims/matrix/gather.cu index e6f26ba925..078f9e6198 100644 --- a/cpp/bench/prims/matrix/gather.cu +++ b/cpp/bench/prims/matrix/gather.cu @@ -16,34 +16,48 @@ #include +#include +#include #include #include #include #include #include +#include namespace raft::bench::matrix { template struct GatherParams { IdxT rows, cols, map_length; + bool host; }; template inline auto operator<<(std::ostream& os, const GatherParams& p) -> std::ostream& { - os << p.rows << "#" << p.cols << "#" << p.map_length; + os << p.rows << "#" << p.cols << "#" << p.map_length << (p.host ? "#host" : "#device"); return os; } template struct Gather : public fixture { Gather(const GatherParams& p) - : params(p), matrix(this->handle), map(this->handle), out(this->handle), stencil(this->handle) + : params(p), + old_mr(rmm::mr::get_current_device_resource()), + pool_mr(rmm::mr::get_current_device_resource(), 2 * (1ULL << 30)), + matrix(this->handle), + map(this->handle), + out(this->handle), + stencil(this->handle), + matrix_h(this->handle) { + rmm::mr::set_current_device_resource(&pool_mr); } + ~Gather() { rmm::mr::set_current_device_resource(old_mr); } + void allocate_data(const ::benchmark::State& state) override { matrix = raft::make_device_matrix(handle, params.rows, params.cols); @@ -59,6 +73,11 @@ struct Gather : public fixture { if constexpr (Conditional) { raft::random::uniform(handle, rng, stencil.data_handle(), params.map_length, T(-1), T(1)); } + + if (params.host) { + matrix_h = raft::make_host_matrix(handle, params.rows, params.cols); + raft::copy(matrix_h.data_handle(), matrix.data_handle(), matrix.size(), stream); + } resource::sync_stream(handle, stream); } @@ -77,14 +96,22 @@ struct Gather : public fixture { raft::matrix::gather_if( handle, matrix_const_view, out.view(), map_const_view, stencil_const_view, pred_op); } else { - raft::matrix::gather(handle, matrix_const_view, map_const_view, out.view()); + if (params.host) { + raft::matrix::detail::gather( + handle, make_const_mdspan(matrix_h.view()), map_const_view, out.view()); + } else { + raft::matrix::gather(handle, matrix_const_view, map_const_view, out.view()); + } } }); } private: GatherParams params; + rmm::mr::device_memory_resource* old_mr; + rmm::mr::pool_memory_resource pool_mr; raft::device_matrix matrix, out; + raft::host_matrix matrix_h; raft::device_vector stencil; raft::device_vector map; }; // struct Gather @@ -100,4 +127,9 @@ RAFT_BENCH_REGISTER((Gather), "", gather_inputs_i64); RAFT_BENCH_REGISTER((Gather), "", gather_inputs_i64); RAFT_BENCH_REGISTER((GatherIf), "", gather_inputs_i64); RAFT_BENCH_REGISTER((GatherIf), "", gather_inputs_i64); + +auto inputs_host = raft::util::itertools::product>( + {10000000}, {100}, {1000, 1000000, 10000000}, {true}); +RAFT_BENCH_REGISTER((Gather), "Host", inputs_host); + } // namespace raft::bench::matrix diff --git a/cpp/bench/prims/random/subsample.cu b/cpp/bench/prims/random/subsample.cu new file mode 100644 index 0000000000..4c8ca2bf31 --- /dev/null +++ b/cpp/bench/prims/random/subsample.cu @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +namespace raft::bench::random { + +struct sample_inputs { + int n_samples; + int n_train; + int method; +}; // struct sample_inputs + +inline auto operator<<(std::ostream& os, const sample_inputs& p) -> std::ostream& +{ + os << p.n_samples << "#" << p.n_train << "#" << p.method; + return os; +} + +// Sample with replacement. We use this as a baseline. +template +auto bernoulli_subsample(raft::resources const& res, IdxT n_samples, IdxT n_subsamples, int seed) + -> raft::device_vector +{ + RAFT_EXPECTS(n_subsamples <= n_samples, "Cannot have more training samples than dataset vectors"); + + auto indices = raft::make_device_vector(res, n_subsamples); + raft::random::RngState state(123456ULL); + raft::random::uniformInt( + res, state, indices.data_handle(), n_subsamples, IdxT(0), IdxT(n_samples)); + return indices; +} + +template +struct sample : public fixture { + sample(const sample_inputs& p) + : params(p), + old_mr(rmm::mr::get_current_device_resource()), + pool_mr(rmm::mr::get_current_device_resource(), 2 * GiB), + in(make_device_vector(res, p.n_samples)), + out(make_device_vector(res, p.n_train)) + { + rmm::mr::set_current_device_resource(&pool_mr); + raft::random::RngState r(123456ULL); + } + + ~sample() { rmm::mr::set_current_device_resource(old_mr); } + void run_benchmark(::benchmark::State& state) override + { + std::ostringstream label_stream; + label_stream << params; + state.SetLabel(label_stream.str()); + + raft::random::RngState r(123456ULL); + loop_on_state(state, [this, &r]() { + if (params.method == 1) { + this->out = + bernoulli_subsample(this->res, this->params.n_samples, this->params.n_train, 137); + } else if (params.method == 2) { + this->out = raft::random::excess_subsample( + this->res, r, this->params.n_samples, this->params.n_train); + } + }); + } + + private: + float GiB = 1073741824.0f; + raft::device_resources res; + rmm::mr::device_memory_resource* old_mr; + rmm::mr::pool_memory_resource pool_mr; + sample_inputs params; + raft::device_vector out, in; +}; // struct sample + +const std::vector input_vecs = {{100000000, 10000000, 1}, + {100000000, 50000000, 1}, + {100000000, 100000000, 1}, + {100000000, 10000000, 2}, + {100000000, 50000000, 2}, + {100000000, 100000000, 2}}; + +RAFT_BENCH_REGISTER(sample, "", input_vecs); + +} // namespace raft::bench::random diff --git a/cpp/include/raft/matrix/detail/gather.cuh b/cpp/include/raft/matrix/detail/gather.cuh index 651fec81c3..05cc9204bf 100644 --- a/cpp/include/raft/matrix/detail/gather.cuh +++ b/cpp/include/raft/matrix/detail/gather.cuh @@ -16,9 +16,19 @@ #pragma once +#include +#include +#include +#include +#include #include +#include +#include +#include #include +#include + #include namespace raft { @@ -336,6 +346,83 @@ void gather_if(const InputIteratorT in, gatherImpl(in, D, N, map, stencil, map_length, out, pred_op, transform_op, stream); } +/** + * Helper function to gather a set of vectors from a (host) dataset. + */ +template +void gather_buff(host_matrix_view dataset, + host_vector_view indices, + MatIdxT offset, + pinned_matrix_view buff) +{ + raft::common::nvtx::range fun_scope("gather_host_buff"); + IdxT batch_size = std::min(buff.extent(0), indices.extent(0) - offset); + +#pragma omp for + for (IdxT i = 0; i < batch_size; i++) { + IdxT in_idx = indices(offset + i); + for (IdxT k = 0; k < buff.extent(1); k++) { + buff(i, k) = dataset(in_idx, k); + } + } +} + +template +void gather(raft::resources const& res, + host_matrix_view dataset, + device_vector_view indices, + raft::device_matrix_view output) +{ + raft::common::nvtx::range fun_scope("gather"); + IdxT n_dim = output.extent(1); + IdxT n_train = output.extent(0); + auto indices_host = raft::make_host_vector(n_train); + raft::copy( + indices_host.data_handle(), indices.data_handle(), n_train, resource::get_cuda_stream(res)); + resource::sync_stream(res); + + const size_t buffer_size = 32768 * 1024; // bytes + const size_t max_batch_size = + std::min(round_up_safe(buffer_size / n_dim, 32), n_train); + RAFT_LOG_DEBUG("Gathering data with batch size %zu", max_batch_size); + + // Gather the vector on the host in tmp buffers. We use two buffers to overlap H2D sync + // and gathering the data. + auto out_tmp1 = raft::make_pinned_matrix(res, max_batch_size, n_dim); + auto out_tmp2 = raft::make_pinned_matrix(res, max_batch_size, n_dim); + + // Usually a limited number of threads provide sufficient bandwidth for gathering data. + int n_threads = std::min(omp_get_max_threads(), 32); + + // The gather_buff function has a parallel for loop. We start the the omp parallel + // region here, to avoid repeated overhead within the device_offset loop. +#pragma omp parallel num_threads(n_threads) + { + auto view1 = out_tmp1.view(); + auto view2 = out_tmp2.view(); + gather_buff(dataset, make_const_mdspan(indices_host.view()), (MatIdxT)0, view1); + for (MatIdxT device_offset = 0; device_offset < n_train; device_offset += max_batch_size) { + MatIdxT batch_size = std::min(max_batch_size, n_train - device_offset); + +#pragma omp master + raft::copy(output.data_handle() + device_offset * n_dim, + view1.data_handle(), + batch_size * n_dim, + resource::get_cuda_stream(res)); + // Start gathering the next batch on the host. + MatIdxT host_offset = device_offset + batch_size; + batch_size = std::min(max_batch_size, n_train - host_offset); + if (batch_size > 0) { + gather_buff(dataset, make_const_mdspan(indices_host.view()), host_offset, view2); + } +#pragma omp master + resource::sync_stream(res); +#pragma omp barrier + std::swap(view1, view2); + } + } +} + } // namespace detail } // namespace matrix } // namespace raft diff --git a/cpp/include/raft/matrix/detail/sample_rows.cuh b/cpp/include/raft/matrix/detail/sample_rows.cuh new file mode 100644 index 0000000000..e28ad648da --- /dev/null +++ b/cpp/include/raft/matrix/detail/sample_rows.cuh @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::matrix::detail { + +/** Select rows randomly from input and copy to output. */ +template +void sample_rows(raft::resources const& res, + random::RngState random_state, + const T* input, + IdxT n_rows_input, + raft::device_matrix_view output) +{ + IdxT n_dim = output.extent(1); + IdxT n_samples = output.extent(0); + + raft::device_vector train_indices = + raft::random::excess_subsample(res, random_state, n_rows_input, n_samples); + + cudaPointerAttributes attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, input)); + T* ptr = reinterpret_cast(attr.devicePointer); + if (ptr != nullptr) { + raft::matrix::gather(res, + raft::make_device_matrix_view(ptr, n_rows_input, n_dim), + raft::make_const_mdspan(train_indices.view()), + output); + } else { + auto dataset = raft::make_host_matrix_view(input, n_rows_input, n_dim); + raft::matrix::detail::gather(res, dataset, make_const_mdspan(train_indices.view()), output); + } +} +} // namespace raft::matrix::detail diff --git a/cpp/include/raft/matrix/sample_rows.cuh b/cpp/include/raft/matrix/sample_rows.cuh new file mode 100644 index 0000000000..7925d344e4 --- /dev/null +++ b/cpp/include/raft/matrix/sample_rows.cuh @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace raft::matrix { + +/** @brief Select rows randomly from input and copy to output. + * + * The rows are selected randomly. The random sampling method does not guarantee completely unique + * selection of rows, but it is close to being unique. + * + * @param res RAFT resource handle + * @param random_state + * @param dataset input dataset + * @param output subsampled dataset + */ +template +void sample_rows(raft::resources const& res, + random::RngState random_state, + mdspan, row_major, accessor> dataset, + raft::device_matrix_view output) +{ + RAFT_EXPECTS(dataset.extent(1) == output.extent(1), + "dataset dims must match, but received %ld vs %ld", + static_cast(dataset.extent(1)), + static_cast(output.extent(1))); + detail::sample_rows(res, random_state, dataset.data_handle(), dataset.extent(0), output); +} + +/** @brief Select rows randomly from input and copy to output. + * + * The rows are selected randomly. The random sampling method does not guarantee completely unique + * selection of rows, but it is close to being unique. + * + * @param res RAFT resource handle + * @param random_state + * @param dataset input dataset + * @param n_samples number of rows in the returned matrix + * + * @return subsampled dataset + * */ +template +raft::device_matrix sample_rows( + raft::resources const& res, + random::RngState random_state, + mdspan, row_major, accessor> dataset, + IdxT n_samples) +{ + auto output = raft::make_device_matrix(res, n_samples, dataset.extent(1)); + sample_rows(res, random_state, dataset, output.view()); + return output; +} + +} // namespace raft::matrix diff --git a/cpp/include/raft/random/detail/rng_impl.cuh b/cpp/include/raft/random/detail/rng_impl.cuh index 57f4c8d33d..61a944e9b6 100644 --- a/cpp/include/raft/random/detail/rng_impl.cuh +++ b/cpp/include/raft/random/detail/rng_impl.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * Copyright (c) 2018-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,12 +17,20 @@ #pragma once #include +#include +#include +#include +#include #include #include #include #include #include +#include + +#include + namespace raft { namespace random { namespace detail { @@ -278,6 +286,7 @@ std::enable_if_t> discrete(RngState& rng_state, len); } +/** Note the memory space requirements are O(4*len) */ template void sampleWithoutReplacement(RngState& rng_state, DataT* out, @@ -328,6 +337,133 @@ void affine_transform_params(RngState const& rng_state, IdxT n, IdxT& a, IdxT& b b = mt_rng() % n; } +/** @brief Sample without replacement from range 0..N-1. + * + * Elements are sampled uniformly. + * The algorithm will allocate a workspace of size O(4*n_samples) internally. + * + * We use max N random numbers. Depending on how large n_samples is w.r.t to N, we + * either use rejection sampling, or sort the [0..N-1] values using random keys. + * + * @tparam IdxT type of indices that we sample + * @tparam MatIdxT extent type of the returned mdarray + * + * @param res RAFT resource handle + * @param state random number generator state + * @param N number of elements to sample from. We will sample values in range 0..N-1 + * @param n_samples number of samples to return + * + * @return device mdarray with the random samples + */ +template +auto excess_subsample(raft::resources const& res, RngState& state, IdxT N, IdxT n_samples) + -> raft::device_vector +{ + RAFT_EXPECTS(n_samples <= N, "Cannot have more training samples than dataset vectors"); + + // Number of samples we'll need to sample (with replacement), to expect 'k' + // unique samples from 'n' is given by the following equation: log(1 - k/n)/log(1 - 1/n) ref: + // https://stats.stackexchange.com/questions/296005/the-expected-number-of-unique-elements-drawn-with-replacement + IdxT n_excess_samples = + n_samples < N + ? std::ceil(raft::log(1 - double(n_samples) / double(N)) / (raft::log(1 - 1 / double(N)))) + : N; + + // There is a variance of n_excess_samples, we take 10% more elements. + n_excess_samples += std::max(0.1 * n_samples, 100); + + // n_excess_sampless will be larger than N around k = 0.64*N. When we reach N, then instead of + // doing rejection sampling, we simply shuffle the range [0..N-1] using N random numbers. + n_excess_samples = std::min(n_excess_samples, N); + auto rnd_idx = raft::make_device_vector(res, n_excess_samples); + + auto linear_idx = raft::make_device_vector(res, rnd_idx.size()); + raft::linalg::map_offset(res, linear_idx.view(), identity_op()); + + uniformInt(res, state, rnd_idx.data_handle(), rnd_idx.size(), IdxT(0), IdxT(N)); + + // Sort indices according to rnd keys + size_t workspace_size = 0; + auto stream = resource::get_cuda_stream(res); + cub::DeviceMergeSort::SortPairs(nullptr, + workspace_size, + rnd_idx.data_handle(), + linear_idx.data_handle(), + rnd_idx.size(), + raft::less_op{}, + stream); + auto workspace = raft::make_device_vector(res, workspace_size); + cub::DeviceMergeSort::SortPairs(workspace.data_handle(), + workspace_size, + rnd_idx.data_handle(), + linear_idx.data_handle(), + rnd_idx.size(), + raft::less_op{}, + stream); + + if (rnd_idx.size() == static_cast(N)) { + // We shuffled the linear_idx array by sorting it according to rnd_idx. + // We return the first n_samples elements. + if (n_samples == N) { return linear_idx; } + rnd_idx = raft::make_device_vector(res, n_samples); + raft::copy(rnd_idx.data_handle(), linear_idx.data_handle(), n_samples, stream); + return rnd_idx; + } + // Else we do a rejection sampling (or excess sampling): we generated more random indices than + // needed and reject the duplicates. + auto keys_out = raft::make_device_vector(res, rnd_idx.size()); + auto values_out = raft::make_device_vector(res, rnd_idx.size()); + rmm::device_scalar num_selected(stream); + size_t worksize2 = 0; + cub::DeviceSelect::UniqueByKey(nullptr, + worksize2, + rnd_idx.data_handle(), + linear_idx.data_handle(), + keys_out.data_handle(), + values_out.data_handle(), + num_selected.data(), + rnd_idx.size(), + stream); + + if (worksize2 > workspace.size()) { + workspace = raft::make_device_vector(res, worksize2); + workspace_size = workspace.size(); + } + + cub::DeviceSelect::UniqueByKey(workspace.data_handle(), + workspace_size, + rnd_idx.data_handle(), + linear_idx.data_handle(), + keys_out.data_handle(), + values_out.data_handle(), + num_selected.data(), + rnd_idx.size(), + stream); + + IdxT selected = num_selected.value(stream); + + if (selected < n_samples) { + RAFT_LOG_DEBUG("Subsampling returned with less unique indices (%zu) than requested (%zu)", + (size_t)selected, + (size_t)n_samples); + + // We continue to select n_samples elements, this will now contains a few duplicates. + } + + // After duplicates are removed, we need to shuffle back to random order + cub::DeviceMergeSort::SortPairs(workspace.data_handle(), + workspace_size, + values_out.data_handle(), + keys_out.data_handle(), + n_samples, + raft::less_op{}, + stream); + + values_out = raft::make_device_vector(res, n_samples); + raft::copy(values_out.data_handle(), keys_out.data_handle(), n_samples, stream); + return values_out; +} + }; // end namespace detail }; // end namespace random }; // end namespace raft diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index 4e63669f98..7fd461980f 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -813,6 +813,32 @@ void sampleWithoutReplacement(raft::resources const& handle, rng_state, out, outIdx, in, wts, sampledLen, len, resource::get_cuda_stream(handle)); } +/** @brief Sample from range 0..N-1. + * + * Elements are sampled uniformly. The method aims to sample without replacement, + * but there is a small probability of a few having duplicate elements. + * + * The algorithm will allocate a workspace of size 4*n_samples*sizeof(IdxT) internally. + * + * We use max N random numbers. Depending on how large n_samples is w.r.t to N, we + * either use rejection sampling, or sort the [0..N-1] values using random keys. + * + * @tparam IdxT type of indices that we sample + * @tparam MatIdxT extent type of the returned mdarray + * + * @param res RAFT resource handle + * @param state random number generator state + * @param N number of elements to sample from. We will sample values in range 0..N-1. + * @param n_samples number of samples to return + * + * @return device mdarray with the random samples + */ +template +auto excess_subsample(raft::resources const& res, RngState& state, IdxT N, IdxT n_samples) +{ + return detail::excess_subsample(res, state, N, n_samples); +} + /** * @brief Generates the 'a' and 'b' parameters for a modulo affine * transformation equation: `(ax + b) % n` diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index bf44cf9c60..ecb871fccc 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -267,6 +267,7 @@ if(BUILD_TESTS) test/matrix/matrix.cu test/matrix/norm.cu test/matrix/reverse.cu + test/matrix/sample_rows.cu test/matrix/slice.cu test/matrix/triangular.cu test/sparse/spectral_matrix.cu @@ -294,6 +295,7 @@ if(BUILD_TESTS) test/random/rng_int.cu test/random/rmat_rectangular_generator.cu test/random/sample_without_replacement.cu + test/random/excess_sampling.cu ) ConfigureTest( diff --git a/cpp/test/matrix/sample_rows.cu b/cpp/test/matrix/sample_rows.cu new file mode 100644 index 0000000000..e332a918fe --- /dev/null +++ b/cpp/test/matrix/sample_rows.cu @@ -0,0 +1,140 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +namespace raft { +namespace matrix { + +struct inputs { + int N; + int dim; + int n_samples; + bool host; +}; + +::std::ostream& operator<<(::std::ostream& os, const inputs p) +{ + os << p.N << "#" << p.dim << "#" << p.n_samples << (p.host ? "#host" : "#device"); + return os; +} + +template +class SampleRowsTest : public ::testing::TestWithParam { + public: + SampleRowsTest() + : params(::testing::TestWithParam::GetParam()), + stream(resource::get_cuda_stream(res)), + state{137ULL}, + in(make_device_matrix(res, params.N, params.dim)), + out(make_device_matrix(res, 0, 0)), + in_h(make_host_matrix(res, params.N, params.dim)), + out_h(make_host_matrix(res, params.n_samples, params.dim)) + { + raft::random::uniform(res, state, in.data_handle(), in.size(), T(-1.0), T(1.0)); + for (int64_t i = 0; i < params.N; i++) { + for (int64_t k = 0; k < params.dim; k++) + in_h(i, k) = i * 1000 + k; + } + raft::copy(in.data_handle(), in_h.data_handle(), in_h.size(), stream); + } + + void check() + { + if (params.host) { + out = raft::matrix::sample_rows( + res, state, make_const_mdspan(in_h.view()), (int64_t)params.n_samples); + } else { + out = raft::matrix::sample_rows( + res, state, make_const_mdspan(in.view()), (int64_t)params.n_samples); + } + + raft::copy(out_h.data_handle(), out.data_handle(), out.size(), stream); + resource::sync_stream(res, stream); + + ASSERT_TRUE(out.extent(0) == params.n_samples); + ASSERT_TRUE(out.extent(1) == params.dim); + + std::unordered_set occurrence; + + for (int64_t i = 0; i < params.n_samples; ++i) { + T val = out_h(i, 0) / 1000; + ASSERT_TRUE(0 <= val && val < params.N) + << "out-of-range index @i=" << i << " val=" << val << " params=" << params; + EXPECT_TRUE(occurrence.find(val) == occurrence.end()) + << "repeated index @i=" << i << " idx=" << val << " params=" << params; + occurrence.insert(val); + for (int64_t k = 0; k < params.dim; k++) { + ASSERT_TRUE(raft::match(out_h(i, k), val * 1000 + k, raft::CompareApprox(1e-6))); + } + } + } + + protected: + inputs params; + raft::resources res; + cudaStream_t stream; + random::RngState state; + device_matrix in, out; + host_matrix in_h, out_h; +}; + +inline std::vector generate_inputs() +{ + std::vector input1 = + raft::util::itertools::product({10}, {1, 17, 96}, {1, 6, 9, 10}, {false}); + + std::vector input2 = + raft::util::itertools::product({137}, {1, 17, 128}, {1, 10, 100, 137}, {false}); + input1.insert(input1.end(), input2.begin(), input2.end()); + + input2 = raft::util::itertools::product( + {100000}, {1, 42}, {1, 137, 1000, 10000, 50000, 62000, 100000}, {false}); + + input1.insert(input1.end(), input2.begin(), input2.end()); + + int n = input1.size(); + // Add same tests for host data + for (int i = 0; i < n; i++) { + inputs x = input1[i]; + x.host = true; + input1.push_back(x); + } + return input1; +} + +const std::vector inputs1 = generate_inputs(); + +using SampleRowsTestInt64 = SampleRowsTest; +TEST_P(SampleRowsTestInt64, SamplingTest) { check(); } +INSTANTIATE_TEST_SUITE_P(SampleRowsTests, SampleRowsTestInt64, ::testing::ValuesIn(inputs1)); + +} // namespace matrix +} // namespace raft diff --git a/cpp/test/random/excess_sampling.cu b/cpp/test/random/excess_sampling.cu new file mode 100644 index 0000000000..e86436fb7d --- /dev/null +++ b/cpp/test/random/excess_sampling.cu @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace raft { +namespace random { + +using namespace raft::random; + +struct inputs { + int64_t N; + int64_t n_samples; +}; + +template +::std::ostream& operator<<(::std::ostream& os, const inputs p) +{ + os << p.N << "/" << p.n_samples; + return os; +} + +template +class ExcessSamplingTest : public ::testing::TestWithParam { + public: + ExcessSamplingTest() + : params(::testing::TestWithParam::GetParam()), + stream(resource::get_cuda_stream(res)), + state{137ULL} + { + } + + void check() + { + device_vector out = + raft::random::excess_subsample(res, state, params.N, params.n_samples); + ASSERT_TRUE(out.extent(0) == params.n_samples); + + auto h_out = make_host_vector(res, params.n_samples); + raft::copy(h_out.data_handle(), out.data_handle(), out.size(), stream); + resource::sync_stream(res, stream); + + std::unordered_set occurrence; + int64_t sum = 0; + for (int64_t i = 0; i < params.n_samples; ++i) { + T val = h_out(i); + sum += val; + ASSERT_TRUE(0 <= val && val < params.N) + << "out-of-range index @i=" << i << " val=" << val << " n_samples=" << params.n_samples; + ASSERT_TRUE(occurrence.find(val) == occurrence.end()) + << "repeated index @i=" << i << " idx=" << val; + occurrence.insert(val); + } + float avg = sum / (float)params.n_samples; + if (params.n_samples >= 100 && params.N / params.n_samples < 100) { + ASSERT_TRUE(raft::match(avg, (params.N - 1) / 2.0f, raft::CompareApprox(0.2))) + << "non-uniform sample"; + } + } + + protected: + inputs params; + raft::resources res; + cudaStream_t stream; + RngState state; +}; + +const std::vector input1 = {{1, 0}, + {1, 1}, + {10, 0}, + {10, 1}, + {10, 2}, + {10, 10}, + {137, 42}, + {200, 0}, + {200, 1}, + {200, 100}, + {200, 130}, + {200, 200}, + {10000, 893}, + {10000000000, 1023}}; + +using ExcessSamplingTestInt64 = ExcessSamplingTest; +TEST_P(ExcessSamplingTestInt64, SamplingTest) { check(); } +INSTANTIATE_TEST_SUITE_P(ExcessSamplingTests, ExcessSamplingTestInt64, ::testing::ValuesIn(input1)); + +} // namespace random +} // namespace raft