Skip to content

Commit

Permalink
Support more numeric types in Groupby.apply with engine='jit' (rapidsai#13729)
Browse files Browse the repository at this point in the history

draft

This PR adds support for additional numeric dtypes (`int32` and `float32`, alongside the existing `int64` and `float64`) to `GroupBy.apply` with `engine='jit'`.

Authors:
  - https://github.com/brandon-b-miller

Approvers:
  - Bradley Dice (https://github.com/bdice)

URL: rapidsai#13729
  • Loading branch information
brandon-b-miller authored Jul 25, 2023
1 parent 2a590db commit 43aca00
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 15 deletions.
68 changes: 54 additions & 14 deletions python/cudf/cudf/core/udf/groupby_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,14 @@

index_default_type = types.int64
group_size_type = types.int64
SUPPORTED_GROUPBY_NUMBA_TYPES = [types.int64, types.float64]
SUPPORTED_GROUPBY_NUMBA_TYPES = [
types.int32,
types.int64,
types.float32,
types.float64,
]
SUPPORTED_GROUPBY_NUMPY_TYPES = [
numpy_support.as_dtype(dt) for dt in [types.int64, types.float64]
numpy_support.as_dtype(dt) for dt in SUPPORTED_GROUPBY_NUMBA_TYPES
]


Expand Down Expand Up @@ -133,6 +138,25 @@ def caller(data, index, size):
call_cuda_functions[funcname.lower()][type_key] = caller


def _make_unary_attr(funcname):
    """Build the ``resolve_<funcname>`` method for a unary group reduction.

    The returned function takes ``(self, mod)`` and produces a
    ``types.BoundFunction`` that binds a typing template for
    ``GroupType.<funcname>`` to a ``GroupType`` rebuilt from ``mod``'s
    ``group_scalar_type`` and ``index_type``.

    The template's return type is not fixed. It is looked up at typing
    time in ``call_cuda_functions[funcname.lower()]``, whose keys are
    unpacked as ``(return type, input type)`` pairs.
    """
    registry_name = funcname.lower()

    class GroupUnaryReductionAttrTyping(AbstractTemplate):
        key = f"GroupType.{funcname}"

        def generic(self, args, kws):
            # Look the registry up lazily, on every call, so that callers
            # registered after this template was created are still seen.
            for return_type, input_type in call_cuda_functions[registry_name]:
                if self.this.group_scalar_type == input_type:
                    return nb_signature(return_type, recvr=self.this)
            # No registered caller accepts this group's scalar type.
            # NOTE(review): presumably Numba treats ``None`` as "no matching
            # signature" -- confirm against the typing-template docs.
            return None

    def _attr(self, mod):
        receiver = GroupType(mod.group_scalar_type, mod.index_type)
        return types.BoundFunction(GroupUnaryReductionAttrTyping, receiver)

    return _attr


def _create_reduction_attr(name, retty=None):
class Attr(AbstractTemplate):
key = name
Expand Down Expand Up @@ -171,21 +195,20 @@ def generic(self, args, kws):
class GroupAttr(AttributeTemplate):
key = GroupType

resolve_max = _create_reduction_attr("GroupType.max")
resolve_min = _create_reduction_attr("GroupType.min")
resolve_sum = _create_reduction_attr("GroupType.sum")
resolve_max = _make_unary_attr("max")
resolve_min = _make_unary_attr("min")
resolve_sum = _make_unary_attr("sum")

resolve_mean = _make_unary_attr("mean")
resolve_var = _make_unary_attr("var")
resolve_std = _make_unary_attr("std")

resolve_size = _create_reduction_attr(
"GroupType.size", retty=group_size_type
)
resolve_count = _create_reduction_attr(
"GroupType.count", retty=types.int64
)
resolve_mean = _create_reduction_attr(
"GroupType.mean", retty=types.float64
)
resolve_var = _create_reduction_attr("GroupType.var", retty=types.float64)
resolve_std = _create_reduction_attr("GroupType.std", retty=types.float64)

def resolve_idxmax(self, mod):
return types.BoundFunction(
Expand All @@ -201,13 +224,30 @@ def resolve_idxmin(self, mod):
for ty in SUPPORTED_GROUPBY_NUMBA_TYPES:
_register_cuda_reduction_caller("Max", ty, ty)
_register_cuda_reduction_caller("Min", ty, ty)
_register_cuda_reduction_caller("Sum", ty, ty)
_register_cuda_reduction_caller("Mean", ty, types.float64)
_register_cuda_reduction_caller("Std", ty, types.float64)
_register_cuda_reduction_caller("Var", ty, types.float64)
_register_cuda_idx_reduction_caller("IdxMax", ty)
_register_cuda_idx_reduction_caller("IdxMin", ty)

_register_cuda_reduction_caller("Sum", types.int32, types.int64)
_register_cuda_reduction_caller("Sum", types.int64, types.int64)
_register_cuda_reduction_caller("Sum", types.float32, types.float32)
_register_cuda_reduction_caller("Sum", types.float64, types.float64)


_register_cuda_reduction_caller("Mean", types.int32, types.float64)
_register_cuda_reduction_caller("Mean", types.int64, types.float64)
_register_cuda_reduction_caller("Mean", types.float32, types.float32)
_register_cuda_reduction_caller("Mean", types.float64, types.float64)

_register_cuda_reduction_caller("Std", types.int32, types.float64)
_register_cuda_reduction_caller("Std", types.int64, types.float64)
_register_cuda_reduction_caller("Std", types.float32, types.float32)
_register_cuda_reduction_caller("Std", types.float64, types.float64)

_register_cuda_reduction_caller("Var", types.int32, types.float64)
_register_cuda_reduction_caller("Var", types.int64, types.float64)
_register_cuda_reduction_caller("Var", types.float32, types.float32)
_register_cuda_reduction_caller("Var", types.float64, types.float64)


for attr in ("group_data", "index", "size"):
make_attribute_wrapper(GroupType, attr, attr)
6 changes: 5 additions & 1 deletion python/cudf/cudf/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,11 @@ def run_groupby_apply_jit_test(data, func, keys, *args):
assert_groupby_results_equal(cudf_jit_result, pandas_result)


@pytest.mark.parametrize("dtype", SUPPORTED_GROUPBY_NUMPY_TYPES)
@pytest.mark.parametrize(
"dtype",
SUPPORTED_GROUPBY_NUMPY_TYPES,
ids=[str(t) for t in SUPPORTED_GROUPBY_NUMPY_TYPES],
)
@pytest.mark.parametrize(
"func", ["min", "max", "sum", "mean", "var", "std", "idxmin", "idxmax"]
)
Expand Down
22 changes: 22 additions & 0 deletions python/cudf/udf_cpp/shim.cu
Original file line number Diff line number Diff line change
Expand Up @@ -630,17 +630,34 @@ extern "C" {
return 0; \
}

make_definition(BlockSum, int32, int32_t, int64_t);
make_definition(BlockSum, int64, int64_t, int64_t);
make_definition(BlockSum, float32, float, float);
make_definition(BlockSum, float64, double, double);

make_definition(BlockMean, int32, int32_t, double);
make_definition(BlockMean, int64, int64_t, double);
make_definition(BlockMean, float32, float, float);
make_definition(BlockMean, float64, double, double);

make_definition(BlockStd, int32, int32_t, double);
make_definition(BlockStd, int64, int64_t, double);
make_definition(BlockStd, float32, float, float);
make_definition(BlockStd, float64, double, double);

make_definition(BlockVar, int64, int64_t, double);
make_definition(BlockVar, int32, int32_t, double);
make_definition(BlockVar, float32, float, float);
make_definition(BlockVar, float64, double, double);

make_definition(BlockMin, int32, int32_t, int32_t);
make_definition(BlockMin, int64, int64_t, int64_t);
make_definition(BlockMin, float32, float, float);
make_definition(BlockMin, float64, double, double);

make_definition(BlockMax, int32, int32_t, int32_t);
make_definition(BlockMax, int64, int64_t, int64_t);
make_definition(BlockMax, float32, float, float);
make_definition(BlockMax, float64, double, double);
#undef make_definition
}
Expand All @@ -656,9 +673,14 @@ extern "C" {
return 0; \
}

make_definition_idx(BlockIdxMin, int32, int32_t);
make_definition_idx(BlockIdxMin, int64, int64_t);
make_definition_idx(BlockIdxMin, float32, float);
make_definition_idx(BlockIdxMin, float64, double);

make_definition_idx(BlockIdxMax, int32, int32_t);
make_definition_idx(BlockIdxMax, int64, int64_t);
make_definition_idx(BlockIdxMax, float32, float);
make_definition_idx(BlockIdxMax, float64, double);
#undef make_definition_idx
}

0 comments on commit 43aca00

Please sign in to comment.