Skip to content

Commit

Permalink
Address PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
vyasr committed Feb 6, 2024
1 parent 68a0a76 commit 6ce8f45
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 21 deletions.
8 changes: 1 addition & 7 deletions python/cudf/cudf/_lib/pylibcudf/aggregation.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,11 @@ ctypedef groupby_scan_aggregation * gbsa_ptr
ctypedef reduce_aggregation * ra_ptr
ctypedef scan_aggregation * sa_ptr

ctypedef fused agg_ptr:
gba_ptr
gbsa_ptr
ra_ptr
sa_ptr


cdef class Aggregation:
cdef unique_ptr[aggregation] c_obj
cpdef kind(self)
cdef agg_ptr _raise_if_null(self, agg_ptr ptr, str alg)
cdef void _unsupported_agg_error(self, str alg)
cdef unique_ptr[groupby_aggregation] clone_underlying_as_groupby(self) except *
cdef unique_ptr[groupby_scan_aggregation] clone_underlying_as_groupby_scan(
self
Expand Down
32 changes: 18 additions & 14 deletions python/cudf/cudf/_lib/pylibcudf/aggregation.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -83,21 +83,19 @@ cdef class Aggregation:
"""Get the kind of the aggregation."""
return dereference(self.c_obj).kind

cdef agg_ptr _raise_if_null(self, agg_ptr ptr, str alg):
# The functions calling this all use a dynamic cast between aggregation types,
cdef void _unsupported_agg_error(self, str alg):
# Te functions calling this all use a dynamic cast between aggregation types,
# and the cast returning a null pointer is how we capture whether or not
# libcudf supports a given aggregation for a particular algorithm.
if ptr is NULL:
agg_repr = str(self.kind()).split(".")[1].title()
raise TypeError(f"{agg_repr} aggregations are not supported by {alg}")
return ptr
agg_repr = str(self.kind()).split(".")[1].title()
raise TypeError(f"{agg_repr} aggregations are not supported by {alg}")

cdef unique_ptr[groupby_aggregation] clone_underlying_as_groupby(self) except *:
"""Make a copy of the aggregation that can be used in a groupby."""
cdef unique_ptr[aggregation] agg = dereference(self.c_obj).clone()
cdef groupby_aggregation *agg_cast = self._raise_if_null(
dynamic_cast[gba_ptr](agg.get()), "groupby"
)
cdef groupby_aggregation *agg_cast = dynamic_cast[gba_ptr](agg.get())
if agg_cast is NULL:
self._unsupported_agg_error("groupby")
agg.release()
return unique_ptr[groupby_aggregation](agg_cast)

Expand All @@ -106,19 +104,25 @@ cdef class Aggregation:
) except *:
"""Make a copy of the aggregation that can be used in a groupby scan."""
cdef unique_ptr[aggregation] agg = dereference(self.c_obj).clone()
cdef groupby_scan_aggregation *agg_cast = self._raise_if_null(
dynamic_cast[gbsa_ptr](agg.get()), "groupby scan"
)
cdef groupby_scan_aggregation *agg_cast = dynamic_cast[gbsa_ptr](agg.get())
if agg_cast is NULL:
self._unsupported_agg_error("groupby_scan")
agg.release()
return unique_ptr[groupby_scan_aggregation](agg_cast)

cdef const reduce_aggregation* view_underlying_as_reduce(self) except *:
"""View the underlying aggregation as a reduce_aggregation."""
return self._raise_if_null(dynamic_cast[ra_ptr](self.c_obj.get()), "reduce")
cdef reduce_aggregation *agg_cast = dynamic_cast[ra_ptr](self.c_obj.get())
if agg_cast is NULL:
self._unsupported_agg_error("reduce")
return agg_cast

cdef const scan_aggregation* view_underlying_as_scan(self) except *:
"""View the underlying aggregation as a scan_aggregation."""
return self._raise_if_null(dynamic_cast[sa_ptr](self.c_obj.get()), "scan")
cdef scan_aggregation *agg_cast = dynamic_cast[sa_ptr](self.c_obj.get())
if agg_cast is NULL:
self._unsupported_agg_error("scan")
return agg_cast

@staticmethod
cdef Aggregation from_libcudf(unique_ptr[aggregation] agg):
Expand Down

0 comments on commit 6ce8f45

Please sign in to comment.