Skip to content

Commit

Permalink
Avoid circular include dependency in reduction.cuh
Browse files Browse the repository at this point in the history
  • Loading branch information
Nyrio committed Apr 10, 2024
1 parent 102dc33 commit 591a9e3
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions cpp/include/raft/util/reduction.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include <raft/core/cudart_utils.hpp>
#include <raft/core/operators.hpp>
#include <raft/util/cuda_dev_essentials.cuh>
#include <raft/util/pow2_utils.cuh>
#include <raft/util/warp_primitives.cuh>

#include <stdint.h>
Expand Down Expand Up @@ -242,10 +241,12 @@ DI i_t binaryBlockReduce(i_t val, i_t* shmem)
template <int logicalWarpSize, int vecWidth, typename T, typename ReduceLambda>
DI void logicalWarpReduceVector(T* acc, int lane_id, ReduceLambda reduce_op)
{
static_assert(raft::is_pow2(vecWidth), "Vec width must be a power of two.");
static_assert(vecWidth > 0, "Vec width must be strictly positive.");
static_assert(!(vecWidth & (vecWidth - 1)), "Vec width must be a power of two.");
static_assert(logicalWarpSize >= 2 && logicalWarpSize <= 32,
"Logical warp size must be between 2 and 32");
static_assert(raft::is_pow2(logicalWarpSize), "Logical warp size must be a power of two.");
static_assert(!(logicalWarpSize & (logicalWarpSize - 1)),
"Logical warp size must be a power of two.");

constexpr int shflStride = logicalWarpSize / 2;
constexpr int nextWarpSize = logicalWarpSize / 2;
Expand Down

0 comments on commit 591a9e3

Please sign in to comment.