Merge branch 'branch-24.06' of https://github.com/rapidsai/raft into cagra-dists
tarang-jain committed Apr 23, 2024
2 parents 77ff0c2 + 317a61c commit b61c1c3
Showing 4 changed files with 244 additions and 36 deletions.
117 changes: 95 additions & 22 deletions cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh
@@ -28,11 +28,18 @@ namespace raft {
namespace linalg {
namespace detail {

-template <int warpSize, int rpb>
template <int warpSize, int tpb, int rpw, bool noLoop = false>
struct ReductionThinPolicy {
-  static constexpr int LogicalWarpSize = warpSize;
-  static constexpr int RowsPerBlock    = rpb;
-  static constexpr int ThreadsPerBlock = LogicalWarpSize * RowsPerBlock;
  static_assert(tpb % warpSize == 0);

  static constexpr int LogicalWarpSize    = warpSize;
  static constexpr int ThreadsPerBlock    = tpb;
  static constexpr int RowsPerLogicalWarp = rpw;
  static constexpr int NumLogicalWarps    = ThreadsPerBlock / LogicalWarpSize;
  static constexpr int RowsPerBlock       = NumLogicalWarps * RowsPerLogicalWarp;

  // Whether D (run-time arg) will be smaller than warpSize (compile-time parameter)
  static constexpr bool NoSequentialReduce = noLoop;
};
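A quick sanity check of the derived constants for one instantiation used by the dispatcher below (an illustrative sketch, not part of the commit):

using ExamplePolicy = ReductionThinPolicy<32, 128, 8, true>;
static_assert(ExamplePolicy::NumLogicalWarps == 4);  // 128 threads / 32 lanes per logical warp
static_assert(ExamplePolicy::RowsPerBlock == 32);    // 4 logical warps * 8 rows each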

template <typename Policy,
@@ -53,19 +60,72 @@ RAFT_KERNEL __launch_bounds__(Policy::ThreadsPerBlock)
FinalLambda final_op,
bool inplace = false)
{
-  IdxType i = threadIdx.y + (Policy::RowsPerBlock * static_cast<IdxType>(blockIdx.x));
-  if (i >= N) return;
  /* The strategy to achieve near-SOL memory bandwidth differs based on D:
   *  - For small D, we need to process multiple rows per logical warp in order to have
   *    multiple loads per thread and increase bytes in flight and amortize latencies.
   *  - For large D, we start with a sequential reduction. The compiler partially unrolls
   *    that loop (e.g. first a loop of stride 16, then 8, 4, and 1).
   */
  IdxType i0 = threadIdx.y + (Policy::RowsPerBlock * static_cast<IdxType>(blockIdx.x));
  if (i0 >= N) return;
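  // Illustration: with Policy::RowsPerBlock == 32, block 1 covers rows 32..63; the logical
  // warp with threadIdx.y == 2 starts at row i0 = 34 and also handles rows
  // i0 + k * NumLogicalWarps for k = 1..RowsPerLogicalWarp-1 in the loops below.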

-  OutType acc = init;
-  for (IdxType j = threadIdx.x; j < D; j += Policy::LogicalWarpSize) {
-    acc = reduce_op(acc, main_op(data[j + (D * i)], j));
  OutType acc[Policy::RowsPerLogicalWarp];
#pragma unroll
  for (int k = 0; k < Policy::RowsPerLogicalWarp; k++) {
    acc[k] = init;
  }
-  acc = raft::logicalWarpReduce<Policy::LogicalWarpSize>(acc, reduce_op);
-  if (threadIdx.x == 0) {

  if constexpr (Policy::NoSequentialReduce) {
    IdxType j = threadIdx.x;
    if (j < D) {
#pragma unroll
      for (IdxType k = 0; k < Policy::RowsPerLogicalWarp; k++) {
        // Only the first row is known to be within bounds. Clamp to avoid out-of-mem read.
        const IdxType i = raft::min(i0 + k * Policy::NumLogicalWarps, N - 1);
        acc[k] = reduce_op(acc[k], main_op(data[j + (D * i)], j));
      }
    }
  } else {
    for (IdxType j = threadIdx.x; j < D; j += Policy::LogicalWarpSize) {
#pragma unroll
      for (IdxType k = 0; k < Policy::RowsPerLogicalWarp; k++) {
        const IdxType i = raft::min(i0 + k * Policy::NumLogicalWarps, N - 1);
        acc[k] = reduce_op(acc[k], main_op(data[j + (D * i)], j));
      }
    }
  }

  /* This vector reduction has two benefits compared to naive separate reductions:
   *  - It avoids the LSU bottleneck when the number of columns is around 32 (e.g. for 32,
   *    5 shuffles are required and there is no initial sequential reduction to amortize
   *    that cost).
   *  - It distributes the outputs to multiple threads, enabling a coalesced store when the
   *    number of rows per logical warp and logical warp size are equal.
   */
  raft::logicalWarpReduceVector<Policy::LogicalWarpSize, Policy::RowsPerLogicalWarp>(
    acc, threadIdx.x, reduce_op);

  constexpr int reducOutVecWidth =
    std::max(1, Policy::RowsPerLogicalWarp / Policy::LogicalWarpSize);
  constexpr int reducOutGroupSize =
    std::max(1, Policy::LogicalWarpSize / Policy::RowsPerLogicalWarp);
  constexpr int reducNumGroups = Policy::LogicalWarpSize / reducOutGroupSize;
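  // E.g. with LogicalWarpSize == 32 and RowsPerLogicalWarp == 8: reducOutVecWidth == 1,
  // reducOutGroupSize == 4 and reducNumGroups == 8, so lanes 0, 4, ..., 28 each store the
  // result of one of the logical warp's 8 rows below.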

  if (threadIdx.x % reducOutGroupSize == 0) {
    const int groupId = threadIdx.x / reducOutGroupSize;
    if (inplace) {
-      dots[i] = final_op(reduce_op(dots[i], acc));
#pragma unroll
      for (int k = 0; k < reducOutVecWidth; k++) {
        const int reductionId = k * reducNumGroups + groupId;
        const IdxType i       = i0 + reductionId * Policy::NumLogicalWarps;
        if (i < N) { dots[i] = final_op(reduce_op(dots[i], acc[k])); }
      }
    } else {
-      dots[i] = final_op(acc);
#pragma unroll
      for (int k = 0; k < reducOutVecWidth; k++) {
        const int reductionId = k * reducNumGroups + groupId;
        const IdxType i       = i0 + reductionId * Policy::NumLogicalWarps;
        if (i < N) { dots[i] = final_op(acc[k]); }
      }
    }
  }
}
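Net effect of the kernel, per row i < N: dots[i] = final_op(acc_i), where acc_i reduces main_op(data[i * D + j], j) over all columns j with reduce_op, seeded with init; when inplace is true, the previous dots[i] is folded in with reduce_op before final_op.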
@@ -89,8 +149,12 @@ void coalescedReductionThin(OutType* dots,
FinalLambda final_op = raft::identity_op())
{
  common::nvtx::range<common::nvtx::domain::raft> fun_scope(
-    "coalescedReductionThin<%d,%d>", Policy::LogicalWarpSize, Policy::RowsPerBlock);
-  dim3 threads(Policy::LogicalWarpSize, Policy::RowsPerBlock, 1);
    "coalescedReductionThin<%d,%d,%d,%d>",
    Policy::LogicalWarpSize,
    Policy::ThreadsPerBlock,
    Policy::RowsPerLogicalWarp,
    static_cast<int>(Policy::NoSequentialReduce));
  dim3 threads(Policy::LogicalWarpSize, Policy::NumLogicalWarps, 1);
  dim3 blocks(ceildiv<IdxType>(N, Policy::RowsPerBlock), 1, 1);
  coalescedReductionThinKernel<Policy>
    <<<blocks, threads, 0, stream>>>(dots, data, D, N, init, main_op, reduce_op, final_op, inplace);
@@ -115,19 +179,28 @@ void coalescedReductionThinDispatcher(OutType* dots,
FinalLambda final_op = raft::identity_op())
{
  if (D <= IdxType(2)) {
-    coalescedReductionThin<ReductionThinPolicy<2, 64>>(
    coalescedReductionThin<ReductionThinPolicy<2, 128, 8, true>>(
      dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
  } else if (D <= IdxType(4)) {
-    coalescedReductionThin<ReductionThinPolicy<4, 32>>(
    coalescedReductionThin<ReductionThinPolicy<4, 128, 8, true>>(
      dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
  } else if (D <= IdxType(8)) {
-    coalescedReductionThin<ReductionThinPolicy<8, 16>>(
    coalescedReductionThin<ReductionThinPolicy<8, 128, 8, true>>(
      dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
  } else if (D <= IdxType(16)) {
-    coalescedReductionThin<ReductionThinPolicy<16, 8>>(
    coalescedReductionThin<ReductionThinPolicy<16, 128, 8, true>>(
      dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
  } else if (D <= IdxType(32)) {
    coalescedReductionThin<ReductionThinPolicy<32, 128, 8, true>>(
      dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
  } else if (D < IdxType(128)) {
    coalescedReductionThin<ReductionThinPolicy<32, 128, 4, false>>(
      dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
  } else {
-    coalescedReductionThin<ReductionThinPolicy<32, 4>>(
    // For D=128 (included) and above, the 4x-unrolled loading loop is used
    // and multiple rows per warp are counter-productive in terms of cache-friendliness
    // and register use.
    coalescedReductionThin<ReductionThinPolicy<32, 128, 1, false>>(
      dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
  }
}
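For example, D = 20 takes the D <= 32 branch (a single load per lane, no sequential loop), while D = 500 falls through to the final branch (sequential loads with one row per logical warp).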
@@ -319,10 +392,10 @@ void coalescedReductionThickDispatcher(OutType* dots,
// Note: multiple elements per thread to take advantage of the sequential reduction and loop
// unrolling
  if (D < IdxType(32768)) {
-    coalescedReductionThick<ReductionThickPolicy<256, 32>, ReductionThinPolicy<32, 4>>(
    coalescedReductionThick<ReductionThickPolicy<256, 32>, ReductionThinPolicy<32, 128, 1>>(
      dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
  } else {
-    coalescedReductionThick<ReductionThickPolicy<256, 64>, ReductionThinPolicy<32, 4>>(
    coalescedReductionThick<ReductionThickPolicy<256, 64>, ReductionThinPolicy<32, 128, 1>>(
      dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
  }
}
}
13 changes: 11 additions & 2 deletions cpp/include/raft/util/pow2_utils.cuh
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023, NVIDIA CORPORATION.
 * Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -20,6 +20,15 @@

namespace raft {

/**
* Checks whether an integer is a power of 2.
*/
template <typename T>
constexpr HDI std::enable_if_t<std::is_integral_v<T>, bool> is_pow2(T v)
{
return (v && !(v & (v - 1)));
}
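// Illustrative checks (not part of the commit):
static_assert(is_pow2(1) && is_pow2(64));
static_assert(!is_pow2(0) && !is_pow2(48));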

/**
* @brief Fast arithmetics and alignment checks for power-of-two values known at compile time.
*
@@ -33,7 +42,7 @@ struct Pow2 {
static constexpr Type Mask = Value - 1;

static_assert(std::is_integral<Type>::value, "Value must be integral.");
-  static_assert(Value && !(Value & Mask), "Value must be power of two.");
  static_assert(is_pow2(Value), "Value must be power of two.");

#define Pow2_FUNC_QUALIFIER static constexpr __host__ __device__ __forceinline__
#define Pow2_WHEN_INTEGRAL(I) std::enable_if_t<Pow2_IS_REPRESENTABLE_AS(I), I>
104 changes: 102 additions & 2 deletions cpp/include/raft/util/reduction.cuh
@@ -39,8 +39,8 @@ DI T logicalWarpReduce(T val, ReduceLambda reduce_op)
{
#pragma unroll
for (int i = logicalWarpSize / 2; i > 0; i >>= 1) {
-    T tmp = shfl_xor(val, i);
-    val   = reduce_op(val, tmp);
    const T tmp = shfl_xor(val, i, logicalWarpSize);
    val         = reduce_op(val, tmp);
}
return val;
}
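As a usage illustration of the fixed-width shuffle (a hypothetical kernel, assuming raft::add_op from raft/core/operators.hpp): each group of four consecutive lanes sums its four inputs in log2(4) = 2 butterfly steps, and every lane of the group ends up holding the group total.

__global__ void groupSum4(const float* in, float* out, int num_groups)
{
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= 4 * num_groups) return;
  // Butterfly reduction across lane groups {0..3}, {4..7}, ...; the explicit width
  // argument keeps the shuffles confined to each 4-lane logical warp.
  const float total = raft::logicalWarpReduce<4>(in[tid], raft::add_op{});
  if (tid % 4 == 0) { out[tid / 4] = total; }
}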
@@ -197,4 +197,104 @@ DI i_t binaryBlockReduce(i_t val, i_t* shmem)
}
}

/**
* @brief Executes a collaborative vector reduction per sub-warp
*
* This uses fewer shuffles than naively reducing each element independently.
* Better performance is achieved with a larger vector width, up to vecWidth == warpSize/2.
* For example, for logicalWarpSize == 32 and vecWidth == 16, the naive method requires 80
* shuffles, this one only 31, 2.58x fewer.
*
* However, the output of the reduction is not broadcasted. The vector is modified in place and
* each thread holds a part of the output vector. The outputs are distributed in a round-robin
* pattern between the threads to facilitate coalesced IO. There are 2 possible layouts based on
* which of logicalWarpSize and vecWidth is larger:
* - If vecWidth >= logicalWarpSize, each thread has vecWidth/logicalWarpSize outputs.
* - If logicalWarpSize > vecWidth, logicalWarpSize/vecWidth threads have a copy of the same output.
*
* Example 1: logicalWarpSize == 4, vecWidth == 8, v = a+b+c+d
* IN OUT
* lane 0 | a0 a1 a2 a3 a4 a5 a6 a7 | v0 v4 - - - - - -
* lane 1 | b0 b1 b2 b3 b4 b5 b6 b7 | v1 v5 - - - - - -
* lane 2 | c0 c1 c2 c3 c4 c5 c6 c7 | v2 v6 - - - - - -
* lane 3 | d0 d1 d2 d3 d4 d5 d6 d7 | v3 v7 - - - - - -
*
* Example 2: logicalWarpSize == 8, vecWidth == 4, v = a+b+c+d+e+f+g+h
* IN OUT
* lane 0 | a0 a1 a2 a3 | v0 - - -
* lane 1 | b0 b1 b2 b3 | v0 - - -
* lane 2 | c0 c1 c2 c3 | v1 - - -
* lane 3 | d0 d1 d2 d3 | v1 - - -
* lane 4 | e0 e1 e2 e3 | v2 - - -
* lane 5 | f0 f1 f2 f3 | v2 - - -
* lane 6 | g0 g1 g2 g3 | v3 - - -
* lane 7 | h0 h1 h2 h3 | v3 - - -
*
* @tparam logicalWarpSize Sub-warp size. Must be 2, 4, 8, 16 or 32.
* @tparam vecWidth Vector width. Must be a power of two.
* @tparam T Vector element type.
* @tparam ReduceLambda Reduction operator type.
* @param[in,out] acc Pointer to a vector of size vecWidth or more in registers
* @param[in] lane_id Lane id between 0 and logicalWarpSize-1
* @param[in] reduce_op Reduction operator, assumed to be commutative and associative.
*/
template <int logicalWarpSize, int vecWidth, typename T, typename ReduceLambda>
DI void logicalWarpReduceVector(T* acc, int lane_id, ReduceLambda reduce_op)
{
  static_assert(vecWidth > 0, "Vec width must be strictly positive.");
  static_assert(!(vecWidth & (vecWidth - 1)), "Vec width must be a power of two.");
  static_assert(logicalWarpSize >= 2 && logicalWarpSize <= 32,
                "Logical warp size must be between 2 and 32");
  static_assert(!(logicalWarpSize & (logicalWarpSize - 1)),
                "Logical warp size must be a power of two.");

  constexpr int shflStride   = logicalWarpSize / 2;
  constexpr int nextWarpSize = logicalWarpSize / 2;

  // One step of the butterfly reduction, applied to each element of the vector.
#pragma unroll
  for (int k = 0; k < vecWidth; k++) {
    const T tmp = shfl_xor(acc[k], shflStride, logicalWarpSize);
    acc[k]      = reduce_op(acc[k], tmp);
  }

  constexpr int nextVecWidth = std::max(1, vecWidth / 2);

  /* Split into 2 smaller logical warps and distribute half of the data to each for the next step.
   * The distribution pattern is designed so that at the end the outputs are coalesced/round-robin.
   * The idea is to distribute contiguous "chunks" of the vectors based on the new warp size. These
   * chunks will be halved in the next step and so on.
   *
   * Example for logicalWarpSize == 4, vecWidth == 8:
   *  lane 0 | 0 1 2 3 4 5 6 7 | [0 1] [4 5] - - - - | [0] [4] - - - - - -
   *  lane 1 | 0 1 2 3 4 5 6 7 | [0 1] [4 5] - - - - | [1] [5] - - - - - -
   *  lane 2 | 0 1 2 3 4 5 6 7 | [2 3] [6 7] - - - - | [2] [6] - - - - - -
   *  lane 3 | 0 1 2 3 4 5 6 7 | [2 3] [6 7] - - - - | [3] [7] - - - - - -
   *                              chunkSize=2           chunkSize=1
   */
  if constexpr (nextVecWidth < vecWidth) {
    T tmp[nextVecWidth];
    const bool firstHalf    = (lane_id % logicalWarpSize) < nextWarpSize;
    constexpr int chunkSize = std::min(nextVecWidth, nextWarpSize);
    constexpr int numChunks = nextVecWidth / chunkSize;
#pragma unroll
    for (int c = 0; c < numChunks; c++) {
#pragma unroll
      for (int i = 0; i < chunkSize; i++) {
        const int k = c * chunkSize + i;
        tmp[k]      = firstHalf ? acc[2 * c * chunkSize + i] : acc[(2 * c + 1) * chunkSize + i];
      }
    }
#pragma unroll
    for (int k = 0; k < nextVecWidth; k++) {
      acc[k] = tmp[k];
    }
  }

  // Recursively call with smaller sub-warps and possibly smaller vector width.
  if constexpr (nextWarpSize > 1) {
    logicalWarpReduceVector<nextWarpSize, nextVecWidth>(acc, lane_id % nextWarpSize, reduce_op);
  }
}

} // namespace raft
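And a hypothetical device-side sketch matching Example 1 of the new header comment (logicalWarpSize == 4, vecWidth == 8, again assuming raft::add_op): each lane of a 4-lane group contributes eight partial sums, and after the call lane l holds outputs v[l] and v[l + 4] in acc[0] and acc[1], the round-robin layout documented above.

__device__ void reduceEightSums(const float (&partials)[8], float* out, int lane_id)
{
  float acc[8];
#pragma unroll
  for (int k = 0; k < 8; k++) {
    acc[k] = partials[k];  // per-lane partial sums of 8 independent reductions
  }
  raft::logicalWarpReduceVector<4, 8>(acc, lane_id, raft::add_op{});
  // vecWidth / logicalWarpSize == 2 outputs per lane; the stores below are coalesced.
  out[lane_id]     = acc[0];  // v[lane_id]
  out[lane_id + 4] = acc[1];  // v[lane_id + 4]
}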
46 changes: 36 additions & 10 deletions cpp/test/linalg/coalesced_reduction.cu
@@ -39,7 +39,8 @@ struct coalescedReductionInputs {
template <typename T>
::std::ostream& operator<<(::std::ostream& os, const coalescedReductionInputs<T>& dims)
{
-  return os;
  return os << "{ " << dims.tolerance << ", " << dims.rows << ", " << dims.cols << ", "
            << dims.seed;
}

// Or else, we get the following compilation error
@@ -113,15 +114,40 @@ class coalescedReductionTest : public ::testing::TestWithParam<coalescedReductio
rmm::device_uvector<T> dots_act;
};

-const std::vector<coalescedReductionInputs<float>> inputsf = {{0.000002f, 1024, 32, 1234ULL},
-                                                              {0.000002f, 1024, 64, 1234ULL},
-                                                              {0.000002f, 1024, 128, 1234ULL},
-                                                              {0.000002f, 1024, 256, 1234ULL}};

-const std::vector<coalescedReductionInputs<double>> inputsd = {{0.000000001, 1024, 32, 1234ULL},
-                                                               {0.000000001, 1024, 64, 1234ULL},
-                                                               {0.000000001, 1024, 128, 1234ULL},
-                                                               {0.000000001, 1024, 256, 1234ULL}};
// Note: it's important to have a variety of rows/columns combinations to test all possible code
// paths: thin (few cols or many rows), medium, thick (many cols, very few rows).
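// For instance, {10, 25000} (very few rows, many columns) should exercise the thick kernel,
// while {10000, 20} (many rows, few columns) should exercise the thin kernel above.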

const std::vector<coalescedReductionInputs<float>> inputsf = {{0.000002f, 50, 2, 1234ULL},
{0.000002f, 50, 3, 1234ULL},
{0.000002f, 50, 7, 1234ULL},
{0.000002f, 50, 9, 1234ULL},
{0.000002f, 50, 20, 1234ULL},
{0.000002f, 50, 55, 1234ULL},
{0.000002f, 50, 100, 1234ULL},
{0.000002f, 50, 270, 1234ULL},
{0.000002f, 10000, 3, 1234ULL},
{0.000002f, 10000, 9, 1234ULL},
{0.000002f, 10000, 20, 1234ULL},
{0.000002f, 10000, 55, 1234ULL},
{0.000002f, 10000, 100, 1234ULL},
{0.000002f, 10000, 270, 1234ULL},
{0.0001f, 10, 25000, 1234ULL}};

const std::vector<coalescedReductionInputs<double>> inputsd = {{0.000000001, 50, 2, 1234ULL},
{0.000000001, 50, 3, 1234ULL},
{0.000000001, 50, 7, 1234ULL},
{0.000000001, 50, 9, 1234ULL},
{0.000000001, 50, 20, 1234ULL},
{0.000000001, 50, 55, 1234ULL},
{0.000000001, 50, 100, 1234ULL},
{0.000000001, 50, 270, 1234ULL},
{0.000000001, 10000, 3, 1234ULL},
{0.000000001, 10000, 9, 1234ULL},
{0.000000001, 10000, 20, 1234ULL},
{0.000000001, 10000, 55, 1234ULL},
{0.000000001, 10000, 100, 1234ULL},
{0.000000001, 10000, 270, 1234ULL},
{0.0000001, 10, 25000, 1234ULL}};

typedef coalescedReductionTest<float> coalescedReductionTestF;
TEST_P(coalescedReductionTestF, Result)
