From 317a61cad71b5f83424d2481eb23b1d1b5817f40 Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Mon, 22 Apr 2024 18:37:41 +0200 Subject: [PATCH] Improve coalesced reduction performance for tall and thin matrices (up to 2.6x faster) (#2259) This PR implements two optimizations to `coalescedReductionThinKernel` which is used for coalesced reductions of tall matrices (many rows) and/or thin (few columns): 1. Process multiple rows per warp to increase bytes in flight and amortize load latencies. 2. Use a vectorized reduction to avoid the LSU bottleneck and have fewer global stores (and at least partially coalesced). The benchmark below shows the achieved SOL percentage on A30. I also measured that on H200, it achieved 84% SOL for 32 columns and up to 94% for 512 columns. ![2024-04-09_coalesced_reduction_vec](https://github.com/rapidsai/raft/assets/17441062/73dabe9a-e3ad-4708-9ef8-77ca4a4c9166) Authors: - Louis Sugy (https://github.com/Nyrio) - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Tamas Bela Feher (https://github.com/tfeher) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2259 --- .../linalg/detail/coalesced_reduction-inl.cuh | 117 ++++++++++++++---- cpp/include/raft/util/pow2_utils.cuh | 13 +- cpp/include/raft/util/reduction.cuh | 104 +++++++++++++++- cpp/test/linalg/coalesced_reduction.cu | 46 +++++-- 4 files changed, 244 insertions(+), 36 deletions(-) diff --git a/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh b/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh index d580ea72c1..9f3be7ce0e 100644 --- a/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh +++ b/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh @@ -28,11 +28,18 @@ namespace raft { namespace linalg { namespace detail { -template +template struct ReductionThinPolicy { - static constexpr int LogicalWarpSize = warpSize; - static constexpr int RowsPerBlock = rpb; - static constexpr int ThreadsPerBlock = LogicalWarpSize * RowsPerBlock; + static_assert(tpb % warpSize == 0); + + static constexpr int LogicalWarpSize = warpSize; + static constexpr int ThreadsPerBlock = tpb; + static constexpr int RowsPerLogicalWarp = rpw; + static constexpr int NumLogicalWarps = ThreadsPerBlock / LogicalWarpSize; + static constexpr int RowsPerBlock = NumLogicalWarps * RowsPerLogicalWarp; + + // Whether D (run-time arg) will be smaller than warpSize (compile-time parameter) + static constexpr bool NoSequentialReduce = noLoop; }; template (blockIdx.x)); - if (i >= N) return; + /* The strategy to achieve near-SOL memory bandwidth differs based on D: + * - For small D, we need to process multiple rows per logical warp in order to have + * multiple loads per thread and increase bytes in flight and amortize latencies. + * - For large D, we start with a sequential reduction. The compiler partially unrolls + * that loop (e.g. first a loop of stride 16, then 8, 4, and 1). + */ + IdxType i0 = threadIdx.y + (Policy::RowsPerBlock * static_cast(blockIdx.x)); + if (i0 >= N) return; - OutType acc = init; - for (IdxType j = threadIdx.x; j < D; j += Policy::LogicalWarpSize) { - acc = reduce_op(acc, main_op(data[j + (D * i)], j)); + OutType acc[Policy::RowsPerLogicalWarp]; +#pragma unroll + for (int k = 0; k < Policy::RowsPerLogicalWarp; k++) { + acc[k] = init; } - acc = raft::logicalWarpReduce(acc, reduce_op); - if (threadIdx.x == 0) { + + if constexpr (Policy::NoSequentialReduce) { + IdxType j = threadIdx.x; + if (j < D) { +#pragma unroll + for (IdxType k = 0; k < Policy::RowsPerLogicalWarp; k++) { + // Only the first row is known to be within bounds. Clamp to avoid out-of-mem read. + const IdxType i = raft::min(i0 + k * Policy::NumLogicalWarps, N - 1); + acc[k] = reduce_op(acc[k], main_op(data[j + (D * i)], j)); + } + } + } else { + for (IdxType j = threadIdx.x; j < D; j += Policy::LogicalWarpSize) { +#pragma unroll + for (IdxType k = 0; k < Policy::RowsPerLogicalWarp; k++) { + const IdxType i = raft::min(i0 + k * Policy::NumLogicalWarps, N - 1); + acc[k] = reduce_op(acc[k], main_op(data[j + (D * i)], j)); + } + } + } + + /* This vector reduction has two benefits compared to naive separate reductions: + * - It avoids the LSU bottleneck when the number of columns is around 32 (e.g. for 32, 5 shuffles + * are required and there is no initial sequential reduction to amortize that cost). + * - It distributes the outputs to multiple threads, enabling a coalesced store when the number of + * rows per logical warp and logical warp size are equal. + */ + raft::logicalWarpReduceVector( + acc, threadIdx.x, reduce_op); + + constexpr int reducOutVecWidth = + std::max(1, Policy::RowsPerLogicalWarp / Policy::LogicalWarpSize); + constexpr int reducOutGroupSize = + std::max(1, Policy::LogicalWarpSize / Policy::RowsPerLogicalWarp); + constexpr int reducNumGroups = Policy::LogicalWarpSize / reducOutGroupSize; + + if (threadIdx.x % reducOutGroupSize == 0) { + const int groupId = threadIdx.x / reducOutGroupSize; if (inplace) { - dots[i] = final_op(reduce_op(dots[i], acc)); +#pragma unroll + for (int k = 0; k < reducOutVecWidth; k++) { + const int reductionId = k * reducNumGroups + groupId; + const IdxType i = i0 + reductionId * Policy::NumLogicalWarps; + if (i < N) { dots[i] = final_op(reduce_op(dots[i], acc[k])); } + } } else { - dots[i] = final_op(acc); +#pragma unroll + for (int k = 0; k < reducOutVecWidth; k++) { + const int reductionId = k * reducNumGroups + groupId; + const IdxType i = i0 + reductionId * Policy::NumLogicalWarps; + if (i < N) { dots[i] = final_op(acc[k]); } + } } } } @@ -89,8 +149,12 @@ void coalescedReductionThin(OutType* dots, FinalLambda final_op = raft::identity_op()) { common::nvtx::range fun_scope( - "coalescedReductionThin<%d,%d>", Policy::LogicalWarpSize, Policy::RowsPerBlock); - dim3 threads(Policy::LogicalWarpSize, Policy::RowsPerBlock, 1); + "coalescedReductionThin<%d,%d,%d,%d>", + Policy::LogicalWarpSize, + Policy::ThreadsPerBlock, + Policy::RowsPerLogicalWarp, + static_cast(Policy::NoSequentialReduce)); + dim3 threads(Policy::LogicalWarpSize, Policy::NumLogicalWarps, 1); dim3 blocks(ceildiv(N, Policy::RowsPerBlock), 1, 1); coalescedReductionThinKernel <<>>(dots, data, D, N, init, main_op, reduce_op, final_op, inplace); @@ -115,19 +179,28 @@ void coalescedReductionThinDispatcher(OutType* dots, FinalLambda final_op = raft::identity_op()) { if (D <= IdxType(2)) { - coalescedReductionThin>( + coalescedReductionThin>( dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } else if (D <= IdxType(4)) { - coalescedReductionThin>( + coalescedReductionThin>( dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } else if (D <= IdxType(8)) { - coalescedReductionThin>( + coalescedReductionThin>( dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } else if (D <= IdxType(16)) { - coalescedReductionThin>( + coalescedReductionThin>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else if (D <= IdxType(32)) { + coalescedReductionThin>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else if (D < IdxType(128)) { + coalescedReductionThin>( dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } else { - coalescedReductionThin>( + // For D=128 (included) and above, the 4x-unrolled loading loop is used + // and multiple rows per warp are counter-productive in terms of cache-friendliness + // and register use. + coalescedReductionThin>( dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } } @@ -319,10 +392,10 @@ void coalescedReductionThickDispatcher(OutType* dots, // Note: multiple elements per thread to take advantage of the sequential reduction and loop // unrolling if (D < IdxType(32768)) { - coalescedReductionThick, ReductionThinPolicy<32, 4>>( + coalescedReductionThick, ReductionThinPolicy<32, 128, 1>>( dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } else { - coalescedReductionThick, ReductionThinPolicy<32, 4>>( + coalescedReductionThick, ReductionThinPolicy<32, 128, 1>>( dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } } diff --git a/cpp/include/raft/util/pow2_utils.cuh b/cpp/include/raft/util/pow2_utils.cuh index 68b35837b6..0c740ac5f6 100644 --- a/cpp/include/raft/util/pow2_utils.cuh +++ b/cpp/include/raft/util/pow2_utils.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,15 @@ namespace raft { +/** + * Checks whether an integer is a power of 2. + */ +template +constexpr HDI std::enable_if_t, bool> is_pow2(T v) +{ + return (v && !(v & (v - 1))); +} + /** * @brief Fast arithmetics and alignment checks for power-of-two values known at compile time. * @@ -33,7 +42,7 @@ struct Pow2 { static constexpr Type Mask = Value - 1; static_assert(std::is_integral::value, "Value must be integral."); - static_assert(Value && !(Value & Mask), "Value must be power of two."); + static_assert(is_pow2(Value), "Value must be power of two."); #define Pow2_FUNC_QUALIFIER static constexpr __host__ __device__ __forceinline__ #define Pow2_WHEN_INTEGRAL(I) std::enable_if_t diff --git a/cpp/include/raft/util/reduction.cuh b/cpp/include/raft/util/reduction.cuh index 2c2b1aa228..c0d3da7609 100644 --- a/cpp/include/raft/util/reduction.cuh +++ b/cpp/include/raft/util/reduction.cuh @@ -39,8 +39,8 @@ DI T logicalWarpReduce(T val, ReduceLambda reduce_op) { #pragma unroll for (int i = logicalWarpSize / 2; i > 0; i >>= 1) { - T tmp = shfl_xor(val, i); - val = reduce_op(val, tmp); + const T tmp = shfl_xor(val, i, logicalWarpSize); + val = reduce_op(val, tmp); } return val; } @@ -197,4 +197,104 @@ DI i_t binaryBlockReduce(i_t val, i_t* shmem) } } +/** + * @brief Executes a collaborative vector reduction per sub-warp + * + * This uses fewer shuffles than naively reducing each element independently. + * Better performance is achieved with a larger vector width, up to vecWidth == warpSize/2. + * For example, for logicalWarpSize == 32 and vecWidth == 16, the naive method requires 80 + * shuffles, this one only 31, 2.58x fewer. + * + * However, the output of the reduction is not broadcasted. The vector is modified in place and + * each thread holds a part of the output vector. The outputs are distributed in a round-robin + * pattern between the threads to facilitate coalesced IO. There are 2 possible layouts based on + * which of logicalWarpSize and vecWidth is larger: + * - If vecWidth >= logicalWarpSize, each thread has vecWidth/logicalWarpSize outputs. + * - If logicalWarpSize > vecWidth, logicalWarpSize/vecWidth threads have a copy of the same output. + * + * Example 1: logicalWarpSize == 4, vecWidth == 8, v = a+b+c+d + * IN OUT + * lane 0 | a0 a1 a2 a3 a4 a5 a6 a7 | v0 v4 - - - - - - + * lane 1 | b0 b1 b2 b3 b4 b5 b6 b7 | v1 v5 - - - - - - + * lane 2 | c0 c1 c2 c3 c4 c5 c6 c7 | v2 v6 - - - - - - + * lane 3 | d0 d1 d2 d3 d4 d5 d6 d7 | v3 v7 - - - - - - + * + * Example 2: logicalWarpSize == 8, vecWidth == 4, v = a+b+c+d+e+f+g+h + * IN OUT + * lane 0 | a0 a1 a2 a3 | v0 - - - + * lane 1 | b0 b1 b2 b3 | v0 - - - + * lane 2 | c0 c1 c2 c3 | v1 - - - + * lane 3 | d0 d1 d2 d3 | v1 - - - + * lane 4 | e0 e1 e2 e3 | v2 - - - + * lane 5 | f0 f1 f2 f3 | v2 - - - + * lane 6 | g0 g1 g2 g3 | v3 - - - + * lane 7 | h0 h1 h2 h3 | v3 - - - + * + * @tparam logicalWarpSize Sub-warp size. Must be 2, 4, 8, 16 or 32. + * @tparam vecWidth Vector width. Must be a power of two. + * @tparam T Vector element type. + * @tparam ReduceLambda Reduction operator type. + * @param[in,out] acc Pointer to a vector of size vecWidth or more in registers + * @param[in] lane_id Lane id between 0 and logicalWarpSize-1 + * @param[in] reduce_op Reduction operator, assumed to be commutative and associative. + */ +template +DI void logicalWarpReduceVector(T* acc, int lane_id, ReduceLambda reduce_op) +{ + static_assert(vecWidth > 0, "Vec width must be strictly positive."); + static_assert(!(vecWidth & (vecWidth - 1)), "Vec width must be a power of two."); + static_assert(logicalWarpSize >= 2 && logicalWarpSize <= 32, + "Logical warp size must be between 2 and 32"); + static_assert(!(logicalWarpSize & (logicalWarpSize - 1)), + "Logical warp size must be a power of two."); + + constexpr int shflStride = logicalWarpSize / 2; + constexpr int nextWarpSize = logicalWarpSize / 2; + + // One step of the butterfly reduction, applied to each element of the vector. +#pragma unroll + for (int k = 0; k < vecWidth; k++) { + const T tmp = shfl_xor(acc[k], shflStride, logicalWarpSize); + acc[k] = reduce_op(acc[k], tmp); + } + + constexpr int nextVecWidth = std::max(1, vecWidth / 2); + + /* Split into 2 smaller logical warps and distribute half of the data to each for the next step. + * The distribution pattern is designed so that at the end the outputs are coalesced/round-robin. + * The idea is to distribute contiguous "chunks" of the vectors based on the new warp size. These + * chunks will be halved in the next step and so on. + * + * Example for logicalWarpSize == 4, vecWidth == 8: + * lane 0 | 0 1 2 3 4 5 6 7 | [0 1] [4 5] - - - - | [0] [4] - - - - - - + * lane 1 | 0 1 2 3 4 5 6 7 | [0 1] [4 5] - - - - | [1] [5] - - - - - - + * lane 2 | 0 1 2 3 4 5 6 7 | [2 3] [6 7] - - - - | [2] [6] - - - - - - + * lane 3 | 0 1 2 3 4 5 6 7 | [2 3] [6 7] - - - - | [3] [7] - - - - - - + * chunkSize=2 chunkSize=1 + */ + if constexpr (nextVecWidth < vecWidth) { + T tmp[nextVecWidth]; + const bool firstHalf = (lane_id % logicalWarpSize) < nextWarpSize; + constexpr int chunkSize = std::min(nextVecWidth, nextWarpSize); + constexpr int numChunks = nextVecWidth / chunkSize; +#pragma unroll + for (int c = 0; c < numChunks; c++) { +#pragma unroll + for (int i = 0; i < chunkSize; i++) { + const int k = c * chunkSize + i; + tmp[k] = firstHalf ? acc[2 * c * chunkSize + i] : acc[(2 * c + 1) * chunkSize + i]; + } + } +#pragma unroll + for (int k = 0; k < nextVecWidth; k++) { + acc[k] = tmp[k]; + } + } + + // Recursively call with smaller sub-warps and possibly smaller vector width. + if constexpr (nextWarpSize > 1) { + logicalWarpReduceVector(acc, lane_id % nextWarpSize, reduce_op); + } +} + } // namespace raft diff --git a/cpp/test/linalg/coalesced_reduction.cu b/cpp/test/linalg/coalesced_reduction.cu index 2061f28d36..28f5ff5f60 100644 --- a/cpp/test/linalg/coalesced_reduction.cu +++ b/cpp/test/linalg/coalesced_reduction.cu @@ -39,7 +39,8 @@ struct coalescedReductionInputs { template ::std::ostream& operator<<(::std::ostream& os, const coalescedReductionInputs& dims) { - return os; + return os << "{ " << dims.tolerance << ", " << dims.rows << ", " << dims.cols << ", " + << dims.seed; } // Or else, we get the following compilation error @@ -113,15 +114,40 @@ class coalescedReductionTest : public ::testing::TestWithParam dots_act; }; -const std::vector> inputsf = {{0.000002f, 1024, 32, 1234ULL}, - {0.000002f, 1024, 64, 1234ULL}, - {0.000002f, 1024, 128, 1234ULL}, - {0.000002f, 1024, 256, 1234ULL}}; - -const std::vector> inputsd = {{0.000000001, 1024, 32, 1234ULL}, - {0.000000001, 1024, 64, 1234ULL}, - {0.000000001, 1024, 128, 1234ULL}, - {0.000000001, 1024, 256, 1234ULL}}; +// Note: it's important to have a variety of rows/columns combinations to test all possible code +// paths: thin (few cols or many rows), medium, thick (many cols, very few rows). + +const std::vector> inputsf = {{0.000002f, 50, 2, 1234ULL}, + {0.000002f, 50, 3, 1234ULL}, + {0.000002f, 50, 7, 1234ULL}, + {0.000002f, 50, 9, 1234ULL}, + {0.000002f, 50, 20, 1234ULL}, + {0.000002f, 50, 55, 1234ULL}, + {0.000002f, 50, 100, 1234ULL}, + {0.000002f, 50, 270, 1234ULL}, + {0.000002f, 10000, 3, 1234ULL}, + {0.000002f, 10000, 9, 1234ULL}, + {0.000002f, 10000, 20, 1234ULL}, + {0.000002f, 10000, 55, 1234ULL}, + {0.000002f, 10000, 100, 1234ULL}, + {0.000002f, 10000, 270, 1234ULL}, + {0.0001f, 10, 25000, 1234ULL}}; + +const std::vector> inputsd = {{0.000000001, 50, 2, 1234ULL}, + {0.000000001, 50, 3, 1234ULL}, + {0.000000001, 50, 7, 1234ULL}, + {0.000000001, 50, 9, 1234ULL}, + {0.000000001, 50, 20, 1234ULL}, + {0.000000001, 50, 55, 1234ULL}, + {0.000000001, 50, 100, 1234ULL}, + {0.000000001, 50, 270, 1234ULL}, + {0.000000001, 10000, 3, 1234ULL}, + {0.000000001, 10000, 9, 1234ULL}, + {0.000000001, 10000, 20, 1234ULL}, + {0.000000001, 10000, 55, 1234ULL}, + {0.000000001, 10000, 100, 1234ULL}, + {0.000000001, 10000, 270, 1234ULL}, + {0.0000001, 10, 25000, 1234ULL}}; typedef coalescedReductionTest coalescedReductionTestF; TEST_P(coalescedReductionTestF, Result)