diff --git a/python/cudf/cudf/core/udf/groupby_typing.py b/python/cudf/cudf/core/udf/groupby_typing.py index 37381a95fdf..bc6a084f2b4 100644 --- a/python/cudf/cudf/core/udf/groupby_typing.py +++ b/python/cudf/cudf/core/udf/groupby_typing.py @@ -17,9 +17,14 @@ index_default_type = types.int64 group_size_type = types.int64 -SUPPORTED_GROUPBY_NUMBA_TYPES = [types.int64, types.float64] +SUPPORTED_GROUPBY_NUMBA_TYPES = [ + types.int32, + types.int64, + types.float32, + types.float64, +] SUPPORTED_GROUPBY_NUMPY_TYPES = [ - numpy_support.as_dtype(dt) for dt in [types.int64, types.float64] + numpy_support.as_dtype(dt) for dt in SUPPORTED_GROUPBY_NUMBA_TYPES ] @@ -133,6 +138,25 @@ def caller(data, index, size): call_cuda_functions[funcname.lower()][type_key] = caller +def _make_unary_attr(funcname): + class GroupUnaryReductionAttrTyping(AbstractTemplate): + key = f"GroupType.{funcname}" + + def generic(self, args, kws): + for retty, inputty in call_cuda_functions[funcname.lower()].keys(): + if self.this.group_scalar_type == inputty: + return nb_signature(retty, recvr=self.this) + return None + + def _attr(self, mod): + return types.BoundFunction( + GroupUnaryReductionAttrTyping, + GroupType(mod.group_scalar_type, mod.index_type), + ) + + return _attr + + def _create_reduction_attr(name, retty=None): class Attr(AbstractTemplate): key = name @@ -171,9 +195,13 @@ def generic(self, args, kws): class GroupAttr(AttributeTemplate): key = GroupType - resolve_max = _create_reduction_attr("GroupType.max") - resolve_min = _create_reduction_attr("GroupType.min") - resolve_sum = _create_reduction_attr("GroupType.sum") + resolve_max = _make_unary_attr("max") + resolve_min = _make_unary_attr("min") + resolve_sum = _make_unary_attr("sum") + + resolve_mean = _make_unary_attr("mean") + resolve_var = _make_unary_attr("var") + resolve_std = _make_unary_attr("std") resolve_size = _create_reduction_attr( "GroupType.size", retty=group_size_type @@ -181,11 +209,6 @@ class GroupAttr(AttributeTemplate): resolve_count = _create_reduction_attr( "GroupType.count", retty=types.int64 ) - resolve_mean = _create_reduction_attr( - "GroupType.mean", retty=types.float64 - ) - resolve_var = _create_reduction_attr("GroupType.var", retty=types.float64) - resolve_std = _create_reduction_attr("GroupType.std", retty=types.float64) def resolve_idxmax(self, mod): return types.BoundFunction( @@ -201,13 +224,30 @@ def resolve_idxmin(self, mod): for ty in SUPPORTED_GROUPBY_NUMBA_TYPES: _register_cuda_reduction_caller("Max", ty, ty) _register_cuda_reduction_caller("Min", ty, ty) - _register_cuda_reduction_caller("Sum", ty, ty) - _register_cuda_reduction_caller("Mean", ty, types.float64) - _register_cuda_reduction_caller("Std", ty, types.float64) - _register_cuda_reduction_caller("Var", ty, types.float64) _register_cuda_idx_reduction_caller("IdxMax", ty) _register_cuda_idx_reduction_caller("IdxMin", ty) +_register_cuda_reduction_caller("Sum", types.int32, types.int64) +_register_cuda_reduction_caller("Sum", types.int64, types.int64) +_register_cuda_reduction_caller("Sum", types.float32, types.float32) +_register_cuda_reduction_caller("Sum", types.float64, types.float64) + + +_register_cuda_reduction_caller("Mean", types.int32, types.float64) +_register_cuda_reduction_caller("Mean", types.int64, types.float64) +_register_cuda_reduction_caller("Mean", types.float32, types.float32) +_register_cuda_reduction_caller("Mean", types.float64, types.float64) + +_register_cuda_reduction_caller("Std", types.int32, types.float64) +_register_cuda_reduction_caller("Std", types.int64, types.float64) +_register_cuda_reduction_caller("Std", types.float32, types.float32) +_register_cuda_reduction_caller("Std", types.float64, types.float64) + +_register_cuda_reduction_caller("Var", types.int32, types.float64) +_register_cuda_reduction_caller("Var", types.int64, types.float64) +_register_cuda_reduction_caller("Var", types.float32, types.float32) +_register_cuda_reduction_caller("Var", types.float64, types.float64) + for attr in ("group_data", "index", "size"): make_attribute_wrapper(GroupType, attr, attr) diff --git a/python/cudf/cudf/tests/test_groupby.py b/python/cudf/cudf/tests/test_groupby.py index dde80639fc7..48092be390d 100644 --- a/python/cudf/cudf/tests/test_groupby.py +++ b/python/cudf/cudf/tests/test_groupby.py @@ -402,7 +402,11 @@ def run_groupby_apply_jit_test(data, func, keys, *args): assert_groupby_results_equal(cudf_jit_result, pandas_result) -@pytest.mark.parametrize("dtype", SUPPORTED_GROUPBY_NUMPY_TYPES) +@pytest.mark.parametrize( + "dtype", + SUPPORTED_GROUPBY_NUMPY_TYPES, + ids=[str(t) for t in SUPPORTED_GROUPBY_NUMPY_TYPES], +) @pytest.mark.parametrize( "func", ["min", "max", "sum", "mean", "var", "std", "idxmin", "idxmax"] ) diff --git a/python/cudf/udf_cpp/shim.cu b/python/cudf/udf_cpp/shim.cu index 63ad1039da6..a81c8238f76 100644 --- a/python/cudf/udf_cpp/shim.cu +++ b/python/cudf/udf_cpp/shim.cu @@ -630,17 +630,34 @@ extern "C" { return 0; \ } +make_definition(BlockSum, int32, int32_t, int64_t); make_definition(BlockSum, int64, int64_t, int64_t); +make_definition(BlockSum, float32, float, float); make_definition(BlockSum, float64, double, double); + +make_definition(BlockMean, int32, int32_t, double); make_definition(BlockMean, int64, int64_t, double); +make_definition(BlockMean, float32, float, float); make_definition(BlockMean, float64, double, double); + +make_definition(BlockStd, int32, int32_t, double); make_definition(BlockStd, int64, int64_t, double); +make_definition(BlockStd, float32, float, float); make_definition(BlockStd, float64, double, double); + make_definition(BlockVar, int64, int64_t, double); +make_definition(BlockVar, int32, int32_t, double); +make_definition(BlockVar, float32, float, float); make_definition(BlockVar, float64, double, double); + +make_definition(BlockMin, int32, int32_t, int32_t); make_definition(BlockMin, int64, int64_t, int64_t); +make_definition(BlockMin, float32, float, float); make_definition(BlockMin, float64, double, double); + +make_definition(BlockMax, int32, int32_t, int32_t); make_definition(BlockMax, int64, int64_t, int64_t); +make_definition(BlockMax, float32, float, float); make_definition(BlockMax, float64, double, double); #undef make_definition } @@ -656,9 +673,14 @@ extern "C" { return 0; \ } +make_definition_idx(BlockIdxMin, int32, int32_t); make_definition_idx(BlockIdxMin, int64, int64_t); +make_definition_idx(BlockIdxMin, float32, float); make_definition_idx(BlockIdxMin, float64, double); + +make_definition_idx(BlockIdxMax, int32, int32_t); make_definition_idx(BlockIdxMax, int64, int64_t); +make_definition_idx(BlockIdxMax, float32, float); make_definition_idx(BlockIdxMax, float64, double); #undef make_definition_idx }