diff --git a/python/cudf/cudf/_lib/pylibcudf/aggregation.pxd b/python/cudf/cudf/_lib/pylibcudf/aggregation.pxd index 6e78cbbb6b1..1b7da5a5532 100644 --- a/python/cudf/cudf/_lib/pylibcudf/aggregation.pxd +++ b/python/cudf/cudf/_lib/pylibcudf/aggregation.pxd @@ -31,17 +31,11 @@ ctypedef groupby_scan_aggregation * gbsa_ptr ctypedef reduce_aggregation * ra_ptr ctypedef scan_aggregation * sa_ptr -ctypedef fused agg_ptr: - gba_ptr - gbsa_ptr - ra_ptr - sa_ptr - cdef class Aggregation: cdef unique_ptr[aggregation] c_obj cpdef kind(self) - cdef agg_ptr _raise_if_null(self, agg_ptr ptr, str alg) + cdef void _unsupported_agg_error(self, str alg) cdef unique_ptr[groupby_aggregation] clone_underlying_as_groupby(self) except * cdef unique_ptr[groupby_scan_aggregation] clone_underlying_as_groupby_scan( self diff --git a/python/cudf/cudf/_lib/pylibcudf/aggregation.pyx b/python/cudf/cudf/_lib/pylibcudf/aggregation.pyx index 551da4f240c..0020a0c681d 100644 --- a/python/cudf/cudf/_lib/pylibcudf/aggregation.pyx +++ b/python/cudf/cudf/_lib/pylibcudf/aggregation.pyx @@ -83,21 +83,19 @@ cdef class Aggregation: """Get the kind of the aggregation.""" return dereference(self.c_obj).kind - cdef agg_ptr _raise_if_null(self, agg_ptr ptr, str alg): - # The functions calling this all use a dynamic cast between aggregation types, + cdef void _unsupported_agg_error(self, str alg): + # Te functions calling this all use a dynamic cast between aggregation types, # and the cast returning a null pointer is how we capture whether or not # libcudf supports a given aggregation for a particular algorithm. - if ptr is NULL: - agg_repr = str(self.kind()).split(".")[1].title() - raise TypeError(f"{agg_repr} aggregations are not supported by {alg}") - return ptr + agg_repr = str(self.kind()).split(".")[1].title() + raise TypeError(f"{agg_repr} aggregations are not supported by {alg}") cdef unique_ptr[groupby_aggregation] clone_underlying_as_groupby(self) except *: """Make a copy of the aggregation that can be used in a groupby.""" cdef unique_ptr[aggregation] agg = dereference(self.c_obj).clone() - cdef groupby_aggregation *agg_cast = self._raise_if_null( - dynamic_cast[gba_ptr](agg.get()), "groupby" - ) + cdef groupby_aggregation *agg_cast = dynamic_cast[gba_ptr](agg.get()) + if agg_cast is NULL: + self._unsupported_agg_error("groupby") agg.release() return unique_ptr[groupby_aggregation](agg_cast) @@ -106,19 +104,25 @@ cdef class Aggregation: ) except *: """Make a copy of the aggregation that can be used in a groupby scan.""" cdef unique_ptr[aggregation] agg = dereference(self.c_obj).clone() - cdef groupby_scan_aggregation *agg_cast = self._raise_if_null( - dynamic_cast[gbsa_ptr](agg.get()), "groupby scan" - ) + cdef groupby_scan_aggregation *agg_cast = dynamic_cast[gbsa_ptr](agg.get()) + if agg_cast is NULL: + self._unsupported_agg_error("groupby_scan") agg.release() return unique_ptr[groupby_scan_aggregation](agg_cast) cdef const reduce_aggregation* view_underlying_as_reduce(self) except *: """View the underlying aggregation as a reduce_aggregation.""" - return self._raise_if_null(dynamic_cast[ra_ptr](self.c_obj.get()), "reduce") + cdef reduce_aggregation *agg_cast = dynamic_cast[ra_ptr](self.c_obj.get()) + if agg_cast is NULL: + self._unsupported_agg_error("reduce") + return agg_cast cdef const scan_aggregation* view_underlying_as_scan(self) except *: """View the underlying aggregation as a scan_aggregation.""" - return self._raise_if_null(dynamic_cast[sa_ptr](self.c_obj.get()), "scan") + cdef scan_aggregation *agg_cast = dynamic_cast[sa_ptr](self.c_obj.get()) + if agg_cast is NULL: + self._unsupported_agg_error("scan") + return agg_cast @staticmethod cdef Aggregation from_libcudf(unique_ptr[aggregation] agg):