Skip to content

Commit

Permalink
Mark some constexpr functions as CUDF_HOST_DEVICE that are needed in …
Browse files Browse the repository at this point in the history
…device code
  • Loading branch information
vyasr committed Dec 10, 2024
1 parent 5306eca commit 13433be
Show file tree
Hide file tree
Showing 31 changed files with 293 additions and 196 deletions.
18 changes: 10 additions & 8 deletions cpp/include/cudf/column/column_device_view.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@
#include <rmm/cuda_stream_view.hpp>

#include <cuda/std/optional>
#include <cuda/std/type_traits>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/pair.h>

#include <algorithm>
#include <type_traits>

/**
* @file column_device_view.cuh
Expand All @@ -56,8 +58,8 @@ namespace CUDF_EXPORT cudf {
*
*/
struct nullate {
struct YES : std::bool_constant<true> {};
struct NO : std::bool_constant<false> {};
struct YES : cuda::std::bool_constant<true> {};
struct NO : cuda::std::bool_constant<false> {};
/**
* @brief `nullate::DYNAMIC` defers the determination of nullability to run time rather than
* compile time. The calling code is responsible for specifying whether or not nulls are
Expand All @@ -80,7 +82,7 @@ struct nullate {
* @return `true` if nulls are expected in the operation in which this object is applied,
* otherwise false
*/
constexpr operator bool() const noexcept { return value; }
CUDF_HOST_DEVICE constexpr operator bool() const noexcept { return value; }
bool value; ///< True if nulls are expected
};
};
Expand Down Expand Up @@ -319,14 +321,14 @@ class alignas(16) column_device_view_base {
}

template <typename C, typename T, typename = void>
struct has_element_accessor_impl : std::false_type {};
struct has_element_accessor_impl : cuda::std::false_type {};

template <typename C, typename T>
struct has_element_accessor_impl<
C,
T,
void_t<decltype(std::declval<C>().template element<T>(std::declval<size_type>()))>>
: std::true_type {};
void_t<decltype(cuda::std::declval<C>().template element<T>(cuda::std::declval<size_type>()))>>
: cuda::std::true_type {};
};
// @cond
// Forward declaration
Expand Down Expand Up @@ -534,7 +536,7 @@ class alignas(16) column_device_view : public detail::column_device_view_base {
* @return `true` if `column_device_view::element<T>()` has a valid overload, `false` otherwise
*/
template <typename T>
static constexpr bool has_element_accessor()
CUDF_HOST_DEVICE static constexpr bool has_element_accessor()
{
return has_element_accessor_impl<column_device_view, T>::value;
}
Expand Down Expand Up @@ -1044,7 +1046,7 @@ class alignas(16) mutable_column_device_view : public detail::column_device_view
* @return `true` if `mutable_column_device_view::element<T>()` has a valid overload, `false`
*/
template <typename T>
static constexpr bool has_element_accessor()
CUDF_HOST_DEVICE static constexpr bool has_element_accessor()
{
return has_element_accessor_impl<mutable_column_device_view, T>::value;
}
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/cudf/detail/aggregation/aggregation.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
namespace cudf {
namespace detail {
template <typename T>
constexpr bool is_product_supported()
CUDF_HOST_DEVICE constexpr bool is_product_supported()
{
return is_numeric<T>();
}
Expand Down
11 changes: 6 additions & 5 deletions cpp/include/cudf/detail/utilities/cuda.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,10 @@ class grid_1d {
* @param num_threads_per_block The number of threads per block
* @return thread_index_type The global thread index
*/
static constexpr thread_index_type global_thread_id(thread_index_type thread_id,
thread_index_type block_id,
thread_index_type num_threads_per_block)
CUDF_HOST_DEVICE static constexpr thread_index_type global_thread_id(
thread_index_type thread_id,
thread_index_type block_id,
thread_index_type num_threads_per_block)
{
return thread_id + block_id * num_threads_per_block;
}
Expand Down Expand Up @@ -114,8 +115,8 @@ class grid_1d {
* @param num_threads_per_block The number of threads per block
* @return thread_index_type The global thread index
*/
static constexpr thread_index_type grid_stride(thread_index_type num_threads_per_block,
thread_index_type num_blocks_per_grid)
CUDF_HOST_DEVICE static constexpr thread_index_type grid_stride(
thread_index_type num_threads_per_block, thread_index_type num_blocks_per_grid)
{
return num_threads_per_block * num_blocks_per_grid;
}
Expand Down
31 changes: 17 additions & 14 deletions cpp/include/cudf/detail/utilities/device_operators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
#include <cudf/utilities/error.hpp>
#include <cudf/utilities/traits.hpp>

#include <cuda/std/__algorithm/max.h>
#include <cuda/std/__algorithm/min.h>

#include <type_traits>

namespace cudf {
Expand All @@ -42,7 +45,7 @@ template <typename LHS,
std::enable_if_t<cudf::is_relationally_comparable<LHS, RHS>()>* = nullptr>
CUDF_HOST_DEVICE inline auto min(LHS const& lhs, RHS const& rhs)
{
return std::min(lhs, rhs);
return cuda::std::min(lhs, rhs);
}

/**
Expand All @@ -53,7 +56,7 @@ template <typename LHS,
std::enable_if_t<cudf::is_relationally_comparable<LHS, RHS>()>* = nullptr>
CUDF_HOST_DEVICE inline auto max(LHS const& lhs, RHS const& rhs)
{
return std::max(lhs, rhs);
return cuda::std::max(lhs, rhs);
}
} // namespace detail

Expand All @@ -68,20 +71,20 @@ struct DeviceSum {
}

template <typename T, std::enable_if_t<cudf::is_timestamp<T>()>* = nullptr>
static constexpr T identity()
CUDF_HOST_DEVICE static constexpr T identity()
{
return T{typename T::duration{0}};
}

template <typename T,
std::enable_if_t<!cudf::is_timestamp<T>() && !cudf::is_fixed_point<T>()>* = nullptr>
static constexpr T identity()
CUDF_HOST_DEVICE static constexpr T identity()
{
return T{0};
}

template <typename T, std::enable_if_t<cudf::is_fixed_point<T>()>* = nullptr>
static constexpr T identity()
CUDF_HOST_DEVICE static constexpr T identity()
{
#ifndef __CUDA_ARCH__
CUDF_FAIL("fixed_point does not yet support device operator identity");
Expand Down Expand Up @@ -109,7 +112,7 @@ struct DeviceCount {
}

template <typename T>
static constexpr T identity()
CUDF_HOST_DEVICE static constexpr T identity()
{
return T{};
}
Expand All @@ -129,7 +132,7 @@ struct DeviceMin {
template <typename T,
std::enable_if_t<!std::is_same_v<T, cudf::string_view> && !cudf::is_dictionary<T>() &&
!cudf::is_fixed_point<T>()>* = nullptr>
static constexpr T identity()
CUDF_HOST_DEVICE static constexpr T identity()
{
// chrono types do not have std::numeric_limits specializations and should use T::max()
// https://eel.is/c++draft/numeric.limits.general#6
Expand All @@ -143,7 +146,7 @@ struct DeviceMin {
}

template <typename T, std::enable_if_t<cudf::is_fixed_point<T>()>* = nullptr>
static constexpr T identity()
CUDF_HOST_DEVICE static constexpr T identity()
{
#ifndef __CUDA_ARCH__
CUDF_FAIL("fixed_point does not yet support DeviceMin identity");
Expand All @@ -161,7 +164,7 @@ struct DeviceMin {
}

template <typename T, std::enable_if_t<cudf::is_dictionary<T>()>* = nullptr>
static constexpr T identity()
CUDF_HOST_DEVICE static constexpr T identity()
{
return static_cast<T>(T::max_value());
}
Expand All @@ -181,7 +184,7 @@ struct DeviceMax {
template <typename T,
std::enable_if_t<!std::is_same_v<T, cudf::string_view> && !cudf::is_dictionary<T>() &&
!cudf::is_fixed_point<T>()>* = nullptr>
static constexpr T identity()
CUDF_HOST_DEVICE static constexpr T identity()
{
// chrono types do not have std::numeric_limits specializations and should use T::min()
// https://eel.is/c++draft/numeric.limits.general#6
Expand All @@ -195,7 +198,7 @@ struct DeviceMax {
}

template <typename T, std::enable_if_t<cudf::is_fixed_point<T>()>* = nullptr>
static constexpr T identity()
CUDF_HOST_DEVICE static constexpr T identity()
{
#ifndef __CUDA_ARCH__
CUDF_FAIL("fixed_point does not yet support DeviceMax identity");
Expand All @@ -212,7 +215,7 @@ struct DeviceMax {
}

template <typename T, std::enable_if_t<cudf::is_dictionary<T>()>* = nullptr>
static constexpr T identity()
CUDF_HOST_DEVICE static constexpr T identity()
{
return static_cast<T>(T::lowest_value());
}
Expand All @@ -229,13 +232,13 @@ struct DeviceProduct {
}

template <typename T, std::enable_if_t<!cudf::is_fixed_point<T>()>* = nullptr>
static constexpr T identity()
CUDF_HOST_DEVICE static constexpr T identity()
{
return T{1};
}

template <typename T, std::enable_if_t<cudf::is_fixed_point<T>()>* = nullptr>
static constexpr T identity()
CUDF_HOST_DEVICE static constexpr T identity()
{
#ifndef __CUDA_ARCH__
CUDF_FAIL("fixed_point does not yet support DeviceProduct identity");
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/cudf/detail/utilities/integer_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ constexpr S round_down_safe(S number_to_round, S modulus) noexcept
* `modulus` is positive and does not check for overflow.
*/
template <typename S>
constexpr S round_up_unsafe(S number_to_round, S modulus) noexcept
CUDF_HOST_DEVICE constexpr S round_up_unsafe(S number_to_round, S modulus) noexcept
{
auto remainder = number_to_round % modulus;
if (remainder == 0) { return number_to_round; }
Expand Down Expand Up @@ -187,7 +187,7 @@ constexpr bool is_a_power_of_two(I val) noexcept
* @return Absolute value if value type is signed.
*/
template <typename T>
constexpr auto absolute_value(T value) -> T
CUDF_HOST_DEVICE constexpr auto absolute_value(T value) -> T
{
if constexpr (cuda::std::is_signed<T>()) return numeric::detail::abs(value);
return value;
Expand Down
7 changes: 4 additions & 3 deletions cpp/include/cudf/fixed_point/detail/floating_conversion.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <cuda/std/cmath>
#include <cuda/std/limits>
#include <cuda/std/type_traits>
#include <cuda/std/utility>

#include <cstring>

Expand Down Expand Up @@ -183,7 +184,7 @@ struct floating_converter {
* @param integer_rep The bit-casted floating value to extract the exponent from
* @return The stored base-2 exponent and significand, shifted for denormals
*/
CUDF_HOST_DEVICE inline static std::pair<IntegralType, int> get_significand_and_pow2(
CUDF_HOST_DEVICE inline static cuda::std::pair<IntegralType, int> get_significand_and_pow2(
IntegralType integer_rep)
{
// Extract the significand
Expand Down Expand Up @@ -1008,7 +1009,7 @@ CUDF_HOST_DEVICE inline auto shift_to_binary_pospow(DecimalRep decimal_rep, int
}

// Our shifting_rep is now the integer mantissa, return it and the powers of 2
return std::pair{shifting_rep, pow2};
return cuda::std::pair{shifting_rep, pow2};
}

/**
Expand Down Expand Up @@ -1075,7 +1076,7 @@ CUDF_HOST_DEVICE inline auto shift_to_binary_negpow(DecimalRep decimal_rep, int
}

// Our shifting_rep is now the integer mantissa, return it and the powers of 2
return std::pair{shifting_rep, pow2};
return cuda::std::pair{shifting_rep, pow2};
}

/**
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/cudf/hashing/detail/hash_functions.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

#include <cudf/utilities/traits.hpp>

#include <limits>
#include <cuda/std/limits>

namespace cudf::hashing::detail {

Expand All @@ -29,7 +29,7 @@ template <typename T>
T __device__ inline normalize_nans(T const& key)
{
if constexpr (cudf::is_floating_point<T>()) {
if (std::isnan(key)) { return std::numeric_limits<T>::quiet_NaN(); }
if (std::isnan(key)) { return cuda::std::numeric_limits<T>::quiet_NaN(); }
}
return key;
}
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/cudf/hashing/detail/hashing.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ std::unique_ptr<column> xxhash_64(table_view const& input,
* @param rhs The second hash value
* @return Combined hash value
*/
constexpr uint32_t hash_combine(uint32_t lhs, uint32_t rhs)
CUDF_HOST_DEVICE constexpr uint32_t hash_combine(uint32_t lhs, uint32_t rhs)
{
return lhs ^ (rhs + 0x9e37'79b9 + (lhs << 6) + (lhs >> 2));
}
Expand Down
21 changes: 12 additions & 9 deletions cpp/include/cudf/strings/detail/utf8.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace strings::detail {
* @param chr Any single byte from a valid UTF-8 character
* @return true if this is not the first byte of the character
*/
constexpr bool is_utf8_continuation_char(unsigned char chr)
CUDF_HOST_DEVICE constexpr bool is_utf8_continuation_char(unsigned char chr)
{
// The (0xC0 & 0x80) bit pattern identifies a continuation byte of a character.
return (chr & 0xC0) == 0x80;
Expand All @@ -43,7 +43,10 @@ constexpr bool is_utf8_continuation_char(unsigned char chr)
* @param chr Any single byte from a valid UTF-8 character
* @return true if this the first byte of the character
*/
constexpr bool is_begin_utf8_char(unsigned char chr) { return not is_utf8_continuation_char(chr); }
CUDF_HOST_DEVICE constexpr bool is_begin_utf8_char(unsigned char chr)
{
return not is_utf8_continuation_char(chr);
}

/**
* @brief This will return true if the passed in byte could be the start of
Expand All @@ -55,7 +58,7 @@ constexpr bool is_begin_utf8_char(unsigned char chr) { return not is_utf8_contin
* @param byte The byte to be tested
* @return true if this can be the first byte of a character
*/
constexpr bool is_valid_begin_utf8_char(uint8_t byte)
CUDF_HOST_DEVICE constexpr bool is_valid_begin_utf8_char(uint8_t byte)
{
// to be the first byte of a valid (up to 4 byte) UTF-8 char, byte must be one of:
// 0b0vvvvvvv a 1 byte character
Expand All @@ -72,7 +75,7 @@ constexpr bool is_valid_begin_utf8_char(uint8_t byte)
* @param character Single character
* @return Number of bytes
*/
constexpr size_type bytes_in_char_utf8(char_utf8 character)
CUDF_HOST_DEVICE constexpr size_type bytes_in_char_utf8(char_utf8 character)
{
return 1 + static_cast<size_type>((character & 0x0000'FF00u) > 0) +
static_cast<size_type>((character & 0x00FF'0000u) > 0) +
Expand All @@ -89,7 +92,7 @@ constexpr size_type bytes_in_char_utf8(char_utf8 character)
* @param byte Byte from an encoded character.
* @return Number of bytes.
*/
constexpr size_type bytes_in_utf8_byte(uint8_t byte)
CUDF_HOST_DEVICE constexpr size_type bytes_in_utf8_byte(uint8_t byte)
{
return 1 + static_cast<size_type>((byte & 0xF0) == 0xF0) // 4-byte character prefix
+ static_cast<size_type>((byte & 0xE0) == 0xE0) // 3-byte character prefix
Expand All @@ -104,7 +107,7 @@ constexpr size_type bytes_in_utf8_byte(uint8_t byte)
* @param[out] character Single char_utf8 value.
* @return The number of bytes in the character
*/
constexpr size_type to_char_utf8(char const* str, char_utf8& character)
CUDF_HOST_DEVICE constexpr size_type to_char_utf8(char const* str, char_utf8& character)
{
size_type const chr_width = bytes_in_utf8_byte(static_cast<uint8_t>(*str));

Expand All @@ -131,7 +134,7 @@ constexpr size_type to_char_utf8(char const* str, char_utf8& character)
* @param[out] str Output array.
* @return The number of bytes in the character
*/
constexpr inline size_type from_char_utf8(char_utf8 character, char* str)
CUDF_HOST_DEVICE constexpr inline size_type from_char_utf8(char_utf8 character, char* str)
{
size_type const chr_width = bytes_in_char_utf8(character);
for (size_type idx = 0; idx < chr_width; ++idx) {
Expand All @@ -148,7 +151,7 @@ constexpr inline size_type from_char_utf8(char_utf8 character, char* str)
* @param utf8_char Single UTF-8 character to convert.
* @return Code-point for the UTF-8 character.
*/
constexpr uint32_t utf8_to_codepoint(cudf::char_utf8 utf8_char)
CUDF_HOST_DEVICE constexpr uint32_t utf8_to_codepoint(cudf::char_utf8 utf8_char)
{
uint32_t unchr = 0;
if (utf8_char < 0x0000'0080) // single-byte pass thru
Expand Down Expand Up @@ -178,7 +181,7 @@ constexpr uint32_t utf8_to_codepoint(cudf::char_utf8 utf8_char)
* @param unchr Character code-point to convert.
* @return Single UTF-8 character.
*/
constexpr cudf::char_utf8 codepoint_to_utf8(uint32_t unchr)
CUDF_HOST_DEVICE constexpr cudf::char_utf8 codepoint_to_utf8(uint32_t unchr)
{
cudf::char_utf8 utf8 = 0;
if (unchr < 0x0000'0080) // single byte utf8
Expand Down
Loading

0 comments on commit 13433be

Please sign in to comment.