Add device aggregators used by shared memory groupby (#17031)
This work is part of splitting the original bulk shared memory groupby PR #16619.

It introduces two device-side element aggregators:

- `shmem_element_aggregator`: aggregates data from global memory sources to shared memory targets,
- `gmem_element_aggregator`: aggregates data from shared memory sources to global memory targets.

Both aggregators mirror the existing `elementwise_aggregator` functionality. Follow-up work is tracked in #17032.
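For orientation, here is a minimal sketch of the two-phase shape these aggregators support. Everything in it (`shmem_aggregate_sum`, `gmem_aggregate_sum`, the SUM-only simplification) is an illustrative assumption, not the committed API: the real functors dispatch over cudf types and all aggregation kinds, as the diff below shows.

```cpp
#include <cuda/atomic>

// Illustrative sketch only, not the committed cudf code.
// Phase 1 (cf. shmem_element_aggregator): combine a value read from a
// global memory source into a shared memory target. Block scope suffices
// because only threads of this block touch the shared target.
template <typename T>
__device__ void shmem_aggregate_sum(T* shared_target, T source_value)
{
  cuda::atomic_ref<T, cuda::thread_scope_block> target{*shared_target};
  target.fetch_add(source_value);
}

// Phase 2 (cf. gmem_element_aggregator): fold the block's partial result
// from shared memory into the global memory target, now with device scope.
template <typename T>
__device__ void gmem_aggregate_sum(T* global_target, T shared_value)
{
  cuda::atomic_ref<T, cuda::thread_scope_device> target{*global_target};
  target.fetch_add(shared_value);
}
```

Staging partial results in shared memory pays off because block-scoped atomics on shared memory are considerably cheaper than device-scoped atomics on global memory.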

Authors:
  - Yunsong Wang (https://github.com/PointKernel)

Approvers:
  - Muhammad Haseeb (https://github.com/mhaseeb123)
  - David Wendt (https://github.com/davidwendt)

URL: #17031
PointKernel authored Oct 18, 2024
1 parent ce93c36 commit 8ebf0d4
Showing 3 changed files with 570 additions and 33 deletions.
63 changes: 30 additions & 33 deletions cpp/include/cudf/detail/aggregation/device_aggregators.cuh
@@ -13,7 +13,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 #pragma once
 
 #include <cudf/aggregation.hpp>
@@ -29,12 +28,31 @@
 #include <cuda/std/type_traits>
 
 namespace cudf::detail {
+/// Checks if an aggregation kind needs to operate on the underlying storage type
+template <aggregation::Kind k>
+__device__ constexpr bool uses_underlying_type()
+{
+  return k == aggregation::MIN or k == aggregation::MAX or k == aggregation::SUM;
+}
+
+/// Gets the underlying target type for the given source type and aggregation kind
+template <typename Source, aggregation::Kind k>
+using underlying_target_t =
+  cuda::std::conditional_t<uses_underlying_type<k>(),
+                           cudf::device_storage_type_t<cudf::detail::target_type_t<Source, k>>,
+                           cudf::detail::target_type_t<Source, k>>;
+
+/// Gets the underlying source type for the given source type and aggregation kind
+template <typename Source, aggregation::Kind k>
+using underlying_source_t =
+  cuda::std::conditional_t<uses_underlying_type<k>(), cudf::device_storage_type_t<Source>, Source>;
+
 template <typename Source, aggregation::Kind k, typename Enable = void>
 struct update_target_element {
-  __device__ void operator()(mutable_column_device_view target,
-                             size_type target_index,
-                             column_device_view source,
-                             size_type source_index) const noexcept
+  __device__ void operator()(mutable_column_device_view,
+                             size_type,
+                             column_device_view,
+                             size_type) const noexcept
   {
     CUDF_UNREACHABLE("Invalid source type and aggregation combination.");
   }
@@ -51,8 +69,6 @@ struct update_target_element<
                              column_device_view source,
                              size_type source_index) const noexcept
   {
-    if (source.is_null(source_index)) { return; }
-
     using Target = target_type_t<Source, aggregation::MIN>;
     cudf::detail::atomic_min(&target.element<Target>(target_index),
                              static_cast<Target>(source.element<Source>(source_index)));
@@ -72,8 +88,6 @@ struct update_target_element<
                              column_device_view source,
                              size_type source_index) const noexcept
   {
-    if (source.is_null(source_index)) { return; }
-
     using Target       = target_type_t<Source, aggregation::MIN>;
     using DeviceTarget = device_storage_type_t<Target>;
     using DeviceSource = device_storage_type_t<Source>;
@@ -96,8 +110,6 @@
                              column_device_view source,
                              size_type source_index) const noexcept
   {
-    if (source.is_null(source_index)) { return; }
-
     using Target = target_type_t<Source, aggregation::MAX>;
     cudf::detail::atomic_max(&target.element<Target>(target_index),
                              static_cast<Target>(source.element<Source>(source_index)));
@@ -117,8 +129,6 @@
                              column_device_view source,
                              size_type source_index) const noexcept
   {
-    if (source.is_null(source_index)) { return; }
-
     using Target       = target_type_t<Source, aggregation::MAX>;
     using DeviceTarget = device_storage_type_t<Target>;
     using DeviceSource = device_storage_type_t<Source>;
@@ -141,8 +151,6 @@
                              column_device_view source,
                              size_type source_index) const noexcept
   {
-    if (source.is_null(source_index)) { return; }
-
     using Target = target_type_t<Source, aggregation::SUM>;
     cudf::detail::atomic_add(&target.element<Target>(target_index),
                              static_cast<Target>(source.element<Source>(source_index)));
@@ -162,8 +170,6 @@
                              column_device_view source,
                              size_type source_index) const noexcept
   {
-    if (source.is_null(source_index)) { return; }
-
     using Target       = target_type_t<Source, aggregation::SUM>;
     using DeviceTarget = device_storage_type_t<Target>;
     using DeviceSource = device_storage_type_t<Source>;
@@ -197,10 +203,10 @@ struct update_target_from_dictionary {
   template <typename Source,
             aggregation::Kind k,
             cuda::std::enable_if_t<is_dictionary<Source>()>* = nullptr>
-  __device__ void operator()(mutable_column_device_view target,
-                             size_type target_index,
-                             column_device_view source,
-                             size_type source_index) const noexcept
+  __device__ void operator()(mutable_column_device_view,
+                             size_type,
+                             column_device_view,
+                             size_type) const noexcept
   {
   }
 };
@@ -227,8 +233,6 @@ struct update_target_element<
                              column_device_view source,
                              size_type source_index) const noexcept
   {
-    if (source.is_null(source_index)) { return; }
-
     dispatch_type_and_aggregation(
       source.child(cudf::dictionary_column_view::keys_column_index).type(),
       k,
@@ -249,8 +253,6 @@ struct update_target_element<Source,
                              column_device_view source,
                              size_type source_index) const noexcept
   {
-    if (source.is_null(source_index)) { return; }
-
     using Target = target_type_t<Source, aggregation::SUM_OF_SQUARES>;
     auto value   = static_cast<Target>(source.element<Source>(source_index));
     cudf::detail::atomic_add(&target.element<Target>(target_index), value * value);
@@ -267,8 +269,6 @@ struct update_target_element<Source,
                              column_device_view source,
                              size_type source_index) const noexcept
   {
-    if (source.is_null(source_index)) { return; }
-
     using Target = target_type_t<Source, aggregation::PRODUCT>;
     cudf::detail::atomic_mul(&target.element<Target>(target_index),
                              static_cast<Target>(source.element<Source>(source_index)));
@@ -286,8 +286,6 @@ struct update_target_element<
                              column_device_view source,
                              size_type source_index) const noexcept
   {
-    if (source.is_null(source_index)) { return; }
-
     using Target = target_type_t<Source, aggregation::COUNT_VALID>;
     cudf::detail::atomic_add(&target.element<Target>(target_index), Target{1});
 
@@ -323,8 +321,6 @@
                              column_device_view source,
                              size_type source_index) const noexcept
   {
-    if (source.is_null(source_index)) { return; }
-
     using Target = target_type_t<Source, aggregation::ARGMAX>;
     auto old     = cudf::detail::atomic_cas(
       &target.element<Target>(target_index), ARGMAX_SENTINEL, source_index);
@@ -349,8 +345,6 @@
                              column_device_view source,
                              size_type source_index) const noexcept
   {
-    if (source.is_null(source_index)) { return; }
-
     using Target = target_type_t<Source, aggregation::ARGMIN>;
     auto old     = cudf::detail::atomic_cas(
       &target.element<Target>(target_index), ARGMIN_SENTINEL, source_index);
@@ -376,6 +370,9 @@ struct elementwise_aggregator {
                              column_device_view source,
                              size_type source_index) const noexcept
   {
+    if constexpr (k != cudf::aggregation::COUNT_ALL) {
+      if (source.is_null(source_index)) { return; }
+    }
     update_target_element<Source, k>{}(target, target_index, source, source_index);
   }
 };
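The new `uses_underlying_type`, `underlying_target_t`, and `underlying_source_t` helpers let callers operate on the physical storage type for MIN, MAX, and SUM; for fixed-point sources that is the integer representation. Below is a sketch of how the aliases are expected to resolve; the `static_assert`s are illustrative checks written for this note, not part of the commit.

```cpp
#include <cudf/detail/aggregation/device_aggregators.cuh>
#include <cudf/fixed_point/fixed_point.hpp>

#include <cuda/std/type_traits>

// decimal32 is stored as int32_t on device and MIN preserves the source
// type, so both aliases should collapse to the underlying storage type.
static_assert(cuda::std::is_same_v<
              cudf::detail::underlying_target_t<numeric::decimal32, cudf::aggregation::MIN>,
              int32_t>);
static_assert(cuda::std::is_same_v<
              cudf::detail::underlying_source_t<numeric::decimal32, cudf::aggregation::MIN>,
              int32_t>);

// For a non-fixed-point source the storage-type detour is a no-op: SUM of
// int32_t targets int64_t in cudf, whose storage type is int64_t itself.
static_assert(cuda::std::is_same_v<
              cudf::detail::underlying_target_t<int32_t, cudf::aggregation::SUM>,
              int64_t>);
```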
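The behavioral change is in the last hunk: the `if (source.is_null(source_index)) { return; }` guard, previously repeated in every `update_target_element` specialization, now runs once in `elementwise_aggregator::operator()`. The `if constexpr` exempts `COUNT_ALL`, the one aggregation kind that must count null rows as well; for every other kind the guard behaves exactly as before. For context, the functor is reached through the same type-and-aggregation double dispatch that appears in the dictionary hunk above; the wrapper below is a hedged sketch (the name `aggregate_one` and the exact include set are assumptions, not cudf API).

```cpp
#include <cudf/column/column_device_view.cuh>
#include <cudf/detail/aggregation/aggregation.cuh>
#include <cudf/detail/aggregation/device_aggregators.cuh>

// Hypothetical wrapper: apply one aggregation to one (target, source)
// element pair. dispatch_type_and_aggregation instantiates
// elementwise_aggregator for the runtime data type and aggregation kind.
__device__ void aggregate_one(cudf::mutable_column_device_view target,
                              cudf::size_type target_index,
                              cudf::column_device_view source,
                              cudf::size_type source_index,
                              cudf::aggregation::Kind k)
{
  cudf::detail::dispatch_type_and_aggregation(source.type(),
                                              k,
                                              cudf::detail::elementwise_aggregator{},
                                              target,
                                              target_index,
                                              source,
                                              source_index);
}
```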