diff --git a/python/cudf_polars/cudf_polars/experimental/groupby.py b/python/cudf_polars/cudf_polars/experimental/groupby.py index 35829420b95..282c16856bd 100644 --- a/python/cudf_polars/cudf_polars/experimental/groupby.py +++ b/python/cudf_polars/cudf_polars/experimental/groupby.py @@ -119,6 +119,8 @@ def lower_groupby_node( ir.options, *children, ) + child_count = partition_info[children[0]].count + partition_info[gb_pwise] = PartitionInfo(count=child_count) gb_tree = GroupByTree( ir.schema, @@ -128,6 +130,7 @@ def lower_groupby_node( ir.options, gb_pwise, ) + partition_info[gb_tree] = PartitionInfo(count=1) schema = ir.schema output_exprs = []