Skip to content

Commit

Permalink
Move flatten_single_pass_aggs to its own TU
Browse files Browse the repository at this point in the history
  • Loading branch information
PointKernel committed Oct 10, 2024
1 parent 69b0f66 commit eafdc93
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 104 deletions.
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ add_library(
src/filling/repeat.cu
src/filling/sequence.cu
src/groupby/groupby.cu
src/groupby/hash/flatten_single_pass_aggs.cpp
src/groupby/hash/groupby.cu
src/groupby/sort/aggregate.cpp
src/groupby/sort/group_argmax.cu
Expand Down
139 changes: 139 additions & 0 deletions cpp/src/groupby/hash/flatten_single_pass_aggs.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "flatten_single_pass_aggs.hpp"

#include <cudf/aggregation.hpp>
#include <cudf/detail/aggregation/aggregation.hpp>
#include <cudf/dictionary/dictionary_column_view.hpp>
#include <cudf/groupby.hpp>
#include <cudf/types.hpp>
#include <cudf/utilities/error.hpp>
#include <cudf/utilities/span.hpp>

#include <memory>
#include <tuple>
#include <unordered_set>
#include <vector>

namespace cudf::groupby::detail::hash {

class groupby_simple_aggregations_collector final
: public cudf::detail::simple_aggregations_collector {
public:
using cudf::detail::simple_aggregations_collector::visit;

std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
cudf::detail::min_aggregation const&) override
{
std::vector<std::unique_ptr<aggregation>> aggs;
aggs.push_back(col_type.id() == type_id::STRING ? make_argmin_aggregation()
: make_min_aggregation());
return aggs;
}

std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
cudf::detail::max_aggregation const&) override
{
std::vector<std::unique_ptr<aggregation>> aggs;
aggs.push_back(col_type.id() == type_id::STRING ? make_argmax_aggregation()
: make_max_aggregation());
return aggs;
}

std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
cudf::detail::mean_aggregation const&) override
{
(void)col_type;
CUDF_EXPECTS(is_fixed_width(col_type), "MEAN aggregation expects fixed width type");
std::vector<std::unique_ptr<aggregation>> aggs;
aggs.push_back(make_sum_aggregation());
// COUNT_VALID
aggs.push_back(make_count_aggregation());

return aggs;
}

std::vector<std::unique_ptr<aggregation>> visit(data_type,
cudf::detail::var_aggregation const&) override
{
std::vector<std::unique_ptr<aggregation>> aggs;
aggs.push_back(make_sum_aggregation());
// COUNT_VALID
aggs.push_back(make_count_aggregation());

return aggs;
}

std::vector<std::unique_ptr<aggregation>> visit(data_type,
cudf::detail::std_aggregation const&) override
{
std::vector<std::unique_ptr<aggregation>> aggs;
aggs.push_back(make_sum_aggregation());
// COUNT_VALID
aggs.push_back(make_count_aggregation());

return aggs;
}

std::vector<std::unique_ptr<aggregation>> visit(
data_type, cudf::detail::correlation_aggregation const&) override
{
std::vector<std::unique_ptr<aggregation>> aggs;
aggs.push_back(make_sum_aggregation());
// COUNT_VALID
aggs.push_back(make_count_aggregation());

return aggs;
}
};

// flatten aggs to filter in single pass aggs
std::tuple<table_view, std::vector<aggregation::Kind>, std::vector<std::unique_ptr<aggregation>>>
flatten_single_pass_aggs(host_span<aggregation_request const> requests)
{
std::vector<column_view> columns;
std::vector<std::unique_ptr<aggregation>> aggs;
std::vector<aggregation::Kind> agg_kinds;

for (auto const& request : requests) {
auto const& agg_v = request.aggregations;

std::unordered_set<aggregation::Kind> agg_kinds_set;
auto insert_agg = [&](column_view const& request_values, std::unique_ptr<aggregation>&& agg) {
if (agg_kinds_set.insert(agg->kind).second) {
agg_kinds.push_back(agg->kind);
aggs.push_back(std::move(agg));
columns.push_back(request_values);
}
};

auto values_type = cudf::is_dictionary(request.values.type())
? cudf::dictionary_column_view(request.values).keys().type()
: request.values.type();
for (auto&& agg : agg_v) {
groupby_simple_aggregations_collector collector;

for (auto& agg_s : agg->get_simple_aggregations(values_type, collector)) {
insert_agg(request.values, std::move(agg_s));
}
}
}

return std::make_tuple(table_view(columns), std::move(agg_kinds), std::move(aggs));
}

} // namespace cudf::groupby::detail::hash
33 changes: 33 additions & 0 deletions cpp/src/groupby/hash/flatten_single_pass_aggs.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cudf/aggregation.hpp>
#include <cudf/groupby.hpp>
#include <cudf/types.hpp>
#include <cudf/utilities/span.hpp>

#include <memory>
#include <tuple>
#include <vector>

namespace cudf::groupby::detail::hash {

// flatten aggs to filter in single pass aggs
std::tuple<table_view, std::vector<aggregation::Kind>, std::vector<std::unique_ptr<aggregation>>>
flatten_single_pass_aggs(host_span<aggregation_request const> requests);

} // namespace cudf::groupby::detail::hash
105 changes: 1 addition & 104 deletions cpp/src/groupby/hash/groupby.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/

#include "flatten_single_pass_aggs.hpp"
#include "groupby/common/utils.hpp"
#include "groupby/hash/groupby_kernels.cuh"

Expand Down Expand Up @@ -110,76 +111,6 @@ bool constexpr is_hash_aggregation(aggregation::Kind t)
return array_contains(hash_aggregations, t);
}

class groupby_simple_aggregations_collector final
: public cudf::detail::simple_aggregations_collector {
public:
using cudf::detail::simple_aggregations_collector::visit;

std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
cudf::detail::min_aggregation const&) override
{
std::vector<std::unique_ptr<aggregation>> aggs;
aggs.push_back(col_type.id() == type_id::STRING ? make_argmin_aggregation()
: make_min_aggregation());
return aggs;
}

std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
cudf::detail::max_aggregation const&) override
{
std::vector<std::unique_ptr<aggregation>> aggs;
aggs.push_back(col_type.id() == type_id::STRING ? make_argmax_aggregation()
: make_max_aggregation());
return aggs;
}

std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
cudf::detail::mean_aggregation const&) override
{
(void)col_type;
CUDF_EXPECTS(is_fixed_width(col_type), "MEAN aggregation expects fixed width type");
std::vector<std::unique_ptr<aggregation>> aggs;
aggs.push_back(make_sum_aggregation());
// COUNT_VALID
aggs.push_back(make_count_aggregation());

return aggs;
}

std::vector<std::unique_ptr<aggregation>> visit(data_type,
cudf::detail::var_aggregation const&) override
{
std::vector<std::unique_ptr<aggregation>> aggs;
aggs.push_back(make_sum_aggregation());
// COUNT_VALID
aggs.push_back(make_count_aggregation());

return aggs;
}

std::vector<std::unique_ptr<aggregation>> visit(data_type,
cudf::detail::std_aggregation const&) override
{
std::vector<std::unique_ptr<aggregation>> aggs;
aggs.push_back(make_sum_aggregation());
// COUNT_VALID
aggs.push_back(make_count_aggregation());

return aggs;
}

std::vector<std::unique_ptr<aggregation>> visit(
data_type, cudf::detail::correlation_aggregation const&) override
{
std::vector<std::unique_ptr<aggregation>> aggs;
aggs.push_back(make_sum_aggregation());
// COUNT_VALID
aggs.push_back(make_count_aggregation());

return aggs;
}
};

template <typename SetType>
class hash_compound_agg_finalizer final : public cudf::detail::aggregation_finalizer {
column_view col;
Expand Down Expand Up @@ -347,40 +278,6 @@ class hash_compound_agg_finalizer final : public cudf::detail::aggregation_final
dense_results->add_result(col, agg, std::move(result));
}
};
// flatten aggs to filter in single pass aggs
std::tuple<table_view, std::vector<aggregation::Kind>, std::vector<std::unique_ptr<aggregation>>>
flatten_single_pass_aggs(host_span<aggregation_request const> requests)
{
std::vector<column_view> columns;
std::vector<std::unique_ptr<aggregation>> aggs;
std::vector<aggregation::Kind> agg_kinds;

for (auto const& request : requests) {
auto const& agg_v = request.aggregations;

std::unordered_set<aggregation::Kind> agg_kinds_set;
auto insert_agg = [&](column_view const& request_values, std::unique_ptr<aggregation>&& agg) {
if (agg_kinds_set.insert(agg->kind).second) {
agg_kinds.push_back(agg->kind);
aggs.push_back(std::move(agg));
columns.push_back(request_values);
}
};

auto values_type = cudf::is_dictionary(request.values.type())
? cudf::dictionary_column_view(request.values).keys().type()
: request.values.type();
for (auto&& agg : agg_v) {
groupby_simple_aggregations_collector collector;

for (auto& agg_s : agg->get_simple_aggregations(values_type, collector)) {
insert_agg(request.values, std::move(agg_s));
}
}
}

return std::make_tuple(table_view(columns), std::move(agg_kinds), std::move(aggs));
}

/**
* @brief Gather sparse results into dense using `gather_map` and add to
Expand Down

0 comments on commit eafdc93

Please sign in to comment.