diff --git a/cpp/include/raft/util/reduction.cuh b/cpp/include/raft/util/reduction.cuh index 0298d00311..c0d3da7609 100644 --- a/cpp/include/raft/util/reduction.cuh +++ b/cpp/include/raft/util/reduction.cuh @@ -19,7 +19,6 @@ #include #include #include -#include #include #include @@ -242,10 +241,12 @@ DI i_t binaryBlockReduce(i_t val, i_t* shmem) template DI void logicalWarpReduceVector(T* acc, int lane_id, ReduceLambda reduce_op) { - static_assert(raft::is_pow2(vecWidth), "Vec width must be a power of two."); + static_assert(vecWidth > 0, "Vec width must be strictly positive."); + static_assert(!(vecWidth & (vecWidth - 1)), "Vec width must be a power of two."); static_assert(logicalWarpSize >= 2 && logicalWarpSize <= 32, "Logical warp size must be between 2 and 32"); - static_assert(raft::is_pow2(logicalWarpSize), "Logical warp size must be a power of two."); + static_assert(!(logicalWarpSize & (logicalWarpSize - 1)), + "Logical warp size must be a power of two."); constexpr int shflStride = logicalWarpSize / 2; constexpr int nextWarpSize = logicalWarpSize / 2;