Change binops for-each kernel to thrust::for_each_n (#17419)
Replaces the custom `for_each_kernel` in `binary_ops.cuh` with `thrust::for_each_n`

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Nghia Truong (https://github.com/ttnghia)
  - Karthikeyan (https://github.com/karthikeyann)
  - Bradley Dice (https://github.com/bdice)

URL: #17419
davidwendt authored Nov 26, 2024
1 parent f05e89d commit 4e3afd2
Showing 2 changed files with 31 additions and 63 deletions.
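For orientation before the diff: the change swaps a hand-written grid-stride kernel (launched via cudaOccupancyMaxPotentialBlockSize) for a direct thrust::for_each_n call over a counting iterator. Below is a minimal standalone sketch of that target pattern, not code from this commit: saxpy_fn, run_saxpy, and the raw cudaStream_t are illustrative stand-ins for the dispatcher functors and rmm::cuda_stream_view used in libcudf, and thrust::cuda::par_nosync (Thrust 1.16+) plays the role of rmm::exec_policy_nosync.

#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/iterator/counting_iterator.h>

#include <cuda_runtime.h>

// Illustrative element-wise functor; in the commit this role is played by
// binary_op_device_dispatcher / binary_op_double_device_dispatcher.
struct saxpy_fn {
  float a;
  float const* x;
  float const* y;
  float* out;
  __device__ void operator()(int i) const { out[i] = a * x[i] + y[i]; }
};

void run_saxpy(float a, float const* x, float const* y, float* out, int n, cudaStream_t stream)
{
  // One functor call per index in [0, n). Thrust chooses the launch
  // configuration, replacing the hand-rolled for_each_kernel plus
  // cudaOccupancyMaxPotentialBlockSize launch path removed below.
  thrust::for_each_n(thrust::cuda::par_nosync.on(stream),
                     thrust::counting_iterator<int>(0),
                     n,
                     saxpy_fn{a, x, y, out});
}

The functor body is untouched by the switch; only the launch mechanism changes, since the counting iterator feeds the same element index the custom kernel used to compute.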
56 changes: 10 additions & 46 deletions cpp/src/binaryop/compiled/binary_ops.cuh
@@ -244,44 +244,6 @@ struct binary_op_double_device_dispatcher {
   }
 };
 
-/**
- * @brief Simplified for_each kernel
- *
- * @param size number of elements to process.
- * @param f Functor object to call for each element.
- */
-template <typename Functor>
-CUDF_KERNEL void for_each_kernel(cudf::size_type size, Functor f)
-{
-  auto start = cudf::detail::grid_1d::global_thread_id();
-  auto const stride = cudf::detail::grid_1d::grid_stride();
-
-#pragma unroll
-  for (auto i = start; i < size; i += stride) {
-    f(i);
-  }
-}
-
-/**
- * @brief Launches Simplified for_each kernel with maximum occupancy grid dimensions.
- *
- * @tparam Functor
- * @param stream CUDA stream used for device memory operations and kernel launches.
- * @param size number of elements to process.
- * @param f Functor object to call for each element.
- */
-template <typename Functor>
-void for_each(rmm::cuda_stream_view stream, cudf::size_type size, Functor f)
-{
-  int block_size;
-  int min_grid_size;
-  CUDF_CUDA_TRY(
-    cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, for_each_kernel<decltype(f)>));
-  auto grid = cudf::detail::grid_1d(size, block_size, 2 /* elements_per_thread */);
-  for_each_kernel<<<grid.num_blocks, grid.num_threads_per_block, 0, stream.value()>>>(
-    size, std::forward<Functor&&>(f));
-}
-
 template <class BinaryOperator>
 void apply_binary_op(mutable_column_view& out,
                      column_view const& lhs,
@@ -298,16 +260,18 @@ void apply_binary_op(mutable_column_view& out,
   // Create binop functor instance
   if (common_dtype) {
     // Execute it on every element
-    for_each(stream,
-             out.size(),
-             binary_op_device_dispatcher<BinaryOperator>{
-               *common_dtype, *outd, *lhsd, *rhsd, is_lhs_scalar, is_rhs_scalar});
+    thrust::for_each_n(rmm::exec_policy_nosync(stream),
+                       thrust::counting_iterator<size_type>(0),
+                       out.size(),
+                       binary_op_device_dispatcher<BinaryOperator>{
+                         *common_dtype, *outd, *lhsd, *rhsd, is_lhs_scalar, is_rhs_scalar});
   } else {
     // Execute it on every element
-    for_each(stream,
-             out.size(),
-             binary_op_double_device_dispatcher<BinaryOperator>{
-               *outd, *lhsd, *rhsd, is_lhs_scalar, is_rhs_scalar});
+    thrust::for_each_n(rmm::exec_policy_nosync(stream),
+                       thrust::counting_iterator<size_type>(0),
+                       out.size(),
+                       binary_op_double_device_dispatcher<BinaryOperator>{
+                         *outd, *lhsd, *rhsd, is_lhs_scalar, is_rhs_scalar});
   }
 }
 
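For comparison with the thrust::for_each_n call above, the removed for_each_kernel/for_each pair amounts to roughly the following standalone sketch, with CUDF_KERNEL, CUDF_CUDA_TRY, and cudf::detail::grid_1d expanded by hand. The grid sizing (about two elements per thread) and the omitted error handling are assumptions for illustration, not the actual cudf helpers.

#include <cuda_runtime.h>

// Standalone approximation of the removed helpers (error checking omitted).
template <typename Functor>
__global__ void for_each_kernel(int size, Functor f)
{
  auto const start  = static_cast<int>(blockDim.x * blockIdx.x + threadIdx.x);  // grid_1d::global_thread_id()
  auto const stride = static_cast<int>(gridDim.x * blockDim.x);                 // grid_1d::grid_stride()
  for (auto i = start; i < size; i += stride) {
    f(i);
  }
}

template <typename Functor>
void for_each(cudaStream_t stream, int size, Functor f)
{
  if (size <= 0) { return; }
  int block_size{};
  int min_grid_size{};
  // Pick a block size that maximizes occupancy for this kernel instantiation.
  cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, for_each_kernel<Functor>);
  // Roughly two elements per thread, mirroring grid_1d(size, block_size, 2).
  auto const num_blocks = (size + 2 * block_size - 1) / (2 * block_size);
  for_each_kernel<<<num_blocks, block_size, 0, stream>>>(size, f);
}

Thrust chooses its own launch configuration internally, which is what lets this custom kernel and launcher be deleted outright.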
38 changes: 21 additions & 17 deletions cpp/src/binaryop/compiled/equality_ops.cu
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -34,27 +34,31 @@ void dispatch_equality_op(mutable_column_view& out,
   auto rhsd = column_device_view::create(rhs, stream);
   if (common_dtype) {
     if (op == binary_operator::EQUAL) {
-      for_each(stream,
-               out.size(),
-               binary_op_device_dispatcher<ops::Equal>{
-                 *common_dtype, *outd, *lhsd, *rhsd, is_lhs_scalar, is_rhs_scalar});
+      thrust::for_each_n(rmm::exec_policy_nosync(stream),
+                         thrust::counting_iterator<size_type>(0),
+                         out.size(),
+                         binary_op_device_dispatcher<ops::Equal>{
+                           *common_dtype, *outd, *lhsd, *rhsd, is_lhs_scalar, is_rhs_scalar});
     } else if (op == binary_operator::NOT_EQUAL) {
-      for_each(stream,
-               out.size(),
-               binary_op_device_dispatcher<ops::NotEqual>{
-                 *common_dtype, *outd, *lhsd, *rhsd, is_lhs_scalar, is_rhs_scalar});
+      thrust::for_each_n(rmm::exec_policy_nosync(stream),
+                         thrust::counting_iterator<size_type>(0),
+                         out.size(),
+                         binary_op_device_dispatcher<ops::NotEqual>{
+                           *common_dtype, *outd, *lhsd, *rhsd, is_lhs_scalar, is_rhs_scalar});
     }
   } else {
     if (op == binary_operator::EQUAL) {
-      for_each(stream,
-               out.size(),
-               binary_op_double_device_dispatcher<ops::Equal>{
-                 *outd, *lhsd, *rhsd, is_lhs_scalar, is_rhs_scalar});
+      thrust::for_each_n(rmm::exec_policy_nosync(stream),
+                         thrust::counting_iterator<size_type>(0),
+                         out.size(),
+                         binary_op_double_device_dispatcher<ops::Equal>{
+                           *outd, *lhsd, *rhsd, is_lhs_scalar, is_rhs_scalar});
     } else if (op == binary_operator::NOT_EQUAL) {
-      for_each(stream,
-               out.size(),
-               binary_op_double_device_dispatcher<ops::NotEqual>{
-                 *outd, *lhsd, *rhsd, is_lhs_scalar, is_rhs_scalar});
+      thrust::for_each_n(rmm::exec_policy_nosync(stream),
+                         thrust::counting_iterator<size_type>(0),
+                         out.size(),
+                         binary_op_double_device_dispatcher<ops::NotEqual>{
+                           *outd, *lhsd, *rhsd, is_lhs_scalar, is_rhs_scalar});
     }
   }
 }
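As a closing usage sketch (hypothetical iota_fn functor, plain Thrust with the default device policy rather than an RMM stream policy), here is the same for_each_n-over-counting-iterator idiom end to end; it compiles as-is with nvcc.

#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/host_vector.h>
#include <thrust/iterator/counting_iterator.h>

#include <cstdio>

// Hypothetical functor: write each index into the output buffer.
struct iota_fn {
  int* out;
  __device__ void operator()(int i) const { out[i] = i; }
};

int main()
{
  thrust::device_vector<int> out(8, 0);

  // Same idiom as the commit: the counting iterator supplies the index,
  // for_each_n bounds the iteration count, and the functor does the work.
  thrust::for_each_n(thrust::device,
                     thrust::counting_iterator<int>(0),
                     static_cast<int>(out.size()),
                     iota_fn{thrust::raw_pointer_cast(out.data())});

  thrust::host_vector<int> h = out;           // implicit device-to-host copy
  for (int v : h) { std::printf("%d ", v); }  // prints: 0 1 2 3 4 5 6 7
  std::printf("\n");
  return 0;
}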