Skip to content

Commit

Permalink
Support minBy on GPU (#44)
Browse files Browse the repository at this point in the history
Signed-off-by: Firestarman <[email protected]>
Co-authored-by: Firestarman <[email protected]>
  • Loading branch information
wjxiz1992 and firestarman authored Jul 5, 2024
1 parent 2ffaf94 commit 251f2d3
Show file tree
Hide file tree
Showing 51 changed files with 555 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/additional-functionality/advanced_configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.Last"></a>spark.rapids.sql.expression.Last|`last_value`, `last`|last aggregate operator|true|None|
<a name="sql.expression.Max"></a>spark.rapids.sql.expression.Max|`max`|Max aggregate operator|true|None|
<a name="sql.expression.Min"></a>spark.rapids.sql.expression.Min|`min`|Min aggregate operator|true|None|
<a name="sql.expression.MinBy"></a>spark.rapids.sql.expression.MinBy|`min_by`|MinBy aggregate operator. It may produce different results than CPU when multiple rows in a group have same minimum value in the ordering column and different associated values in the value column.|true|None|
<a name="sql.expression.Percentile"></a>spark.rapids.sql.expression.Percentile|`percentile`|Aggregation computing exact percentile|true|None|
<a name="sql.expression.PivotFirst"></a>spark.rapids.sql.expression.PivotFirst| |PivotFirst operator|true|None|
<a name="sql.expression.StddevPop"></a>spark.rapids.sql.expression.StddevPop|`stddev_pop`|Aggregation computing population standard deviation|true|None|
Expand Down
236 changes: 236 additions & 0 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -18137,6 +18137,164 @@ are limited.
<td><b>NS</b></td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="6">MinBy</td>
<td rowSpan="6">`min_by`</td>
<td rowSpan="6">MinBy aggregate operator. It may produce different results than CPU when multiple rows in a group have same minimum value in the ordering column and different associated values in the value column.</td>
<td rowSpan="6">None</td>
<td rowSpan="3">aggregation</td>
<td>value</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
<td>ordering</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
<td>result</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
<td rowSpan="3">reduction</td>
<td>value</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
<td>ordering</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
<td>result</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
<td rowSpan="8">Percentile</td>
<td rowSpan="8">`percentile`</td>
<td rowSpan="8">Aggregation computing exact percentile</td>
Expand Down Expand Up @@ -18469,6 +18627,32 @@ are limited.
<td><b>NS</b></td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="6">StddevPop</td>
<td rowSpan="6">`stddev_pop`</td>
<td rowSpan="6">Aggregation computing population standard deviation</td>
Expand Down Expand Up @@ -18894,6 +19078,32 @@ are limited.
<td> </td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="6">VariancePop</td>
<td rowSpan="6">`var_pop`</td>
<td rowSpan="6">Aggregation computing population variance</td>
Expand Down Expand Up @@ -19259,6 +19469,32 @@ are limited.
<td><b>NS</b></td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="2">HiveGenericUDF</td>
<td rowSpan="2"> </td>
<td rowSpan="2">Hive Generic UDF, the UDF can choose to implement a RAPIDS accelerated interface to get better performance</td>
Expand Down
14 changes: 14 additions & 0 deletions integration_tests/src/main/python/hash_aggregate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1276,6 +1276,20 @@ def test_generic_reductions(data_gen):
'count(1)'),
conf=local_conf)

@ignore_order(local=True)
def test_hash_groupby_with_minby():
assert_gpu_and_cpu_are_equal_collect(
lambda spark: three_col_df(spark, int_gen, int_gen, UniqueLongGen())
.groupby('a').agg(f.min_by('b', 'c'))
)

@ignore_order(local=True)
def test_reduction_with_minby():
assert_gpu_and_cpu_are_equal_collect(
lambda spark: two_col_df(spark, int_gen, UniqueLongGen()).selectExpr(
"min_by(a, b)")
)

@pytest.mark.parametrize('data_gen', all_gen + _nested_gens, ids=idfn)
@allow_non_gpu(*non_utc_allow)
def test_count(data_gen):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2266,6 +2266,33 @@ object GpuOverrides extends Logging {
// Last does not overflow, so it doesn't need the ANSI check
override val needsAnsiCheck: Boolean = false
}),
expr[MinBy](
"MinBy aggregate operator. It may produce different results than CPU when " +
"multiple rows in a group have same minimum value in the ordering column and " +
"different associated values in the value column.",
ExprChecks.reductionAndGroupByAgg(
(TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + TypeSig.BINARY +
TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128).nested(),
TypeSig.all,
Seq(
ParamCheck("value", (TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + TypeSig.BINARY
+ TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128).nested(),
TypeSig.all),
ParamCheck("ordering", (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL
+ TypeSig.STRUCT + TypeSig.ARRAY).nested(),
TypeSig.orderable))
),
(minBy, conf, p, r) => new AggExprMeta[MinBy](minBy, conf, p, r) {

override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = {
// Only two children (value expression, ordering expression)
require(childExprs.length == 2)
GpuMinBy(childExprs.head, childExprs.last)
}

// MinBy does not overflow, so it doesn't need the ANSI check
override val needsAnsiCheck: Boolean = false
}),
expr[BRound](
"Round an expression to d decimal places using HALF_EVEN rounding mode",
ExprChecks.binaryProject(
Expand Down
Loading

0 comments on commit 251f2d3

Please sign in to comment.