Support negative preceding/following for ROW-based window functions (#9229)

This commit adds support for negative values for the preceding/following offsets specified for `ROW`-based window functions.

Prior to this commit, window function queries such as the following were not supported:
```SQL
SELECT MIN(x) OVER (PARTITION BY grp 
                    ORDER BY oby 
                    ROWS BETWEEN 5 PRECEDING AND -1 FOLLOWING) min_x 
FROM mytable
```
For this query, the window frame for each row spans from up to 5 rows preceding the current row through the row immediately preceding it.
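
The same frame can also be expressed through the DataFrame API, where `rowsBetween(-5, -1)` should correspond to `ROWS BETWEEN 5 PRECEDING AND -1 FOLLOWING`. A minimal PySpark sketch (the data and column names are illustrative, not taken from this commit):
```python
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.window import Window

spark = SparkSession.builder.getOrCreate()

# Illustrative data mirroring the (grp, oby, x) columns in the SQL example above.
df = spark.createDataFrame([(1, i, i * 10) for i in range(6)], ["grp", "oby", "x"])

# Negative offsets count backwards from the current row, so the frame spans
# from five rows before the current row up to the row immediately before it.
w = Window.partitionBy("grp").orderBy("oby").rowsBetween(-5, -1)
df.withColumn("min_x", F.min("x").over(w)).show()
```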

This functionality is currently supported only for:
1. `AVG`
2. `COUNT(1)`/`COUNT(*)`
3. `MAX`
4. `MIN`
5. `SUM`
6. `COLLECT_LIST`
7. `COLLECT_SET`
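
With a negative bound on one end of the frame, the frame can be empty for the leading rows of a partition: for the first row, `ROWS BETWEEN 2 PRECEDING AND -1 FOLLOWING` covers no rows at all. For such empty frames `COUNT` yields 0 and `COLLECT_LIST`/`COLLECT_SET` yield an empty array rather than null, which is why this change sets their minimum-period count (`minPeriods`) to 0 in the plugin. A small PySpark sketch of the expected behaviour (illustrative only, not part of this commit):
```python
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.window import Window

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, i) for i in range(6)], ["grp", "v"])

# Frame covering rows [-2, -1] relative to the current row; empty for the first row.
w = Window.partitionBy("grp").orderBy("v").rowsBetween(-2, -1)
df.select(
    "v",
    F.count("v").over(w).alias("cnt"),          # expected: 0 on the first row
    F.collect_list("v").over(w).alias("lst"),   # expected: [] on the first row
).show()
```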

Signed-off-by: MithunR <[email protected]>
mythrocks authored Sep 26, 2023
1 parent c5f40d5 commit e976724
Showing 4 changed files with 161 additions and 18 deletions.
99 changes: 98 additions & 1 deletion integration_tests/src/main/python/window_function_test.py
@@ -21,7 +21,7 @@
from pyspark.sql.types import NumericType
from pyspark.sql.window import Window
import pyspark.sql.functions as f
-from spark_session import is_before_spark_320, is_databricks113_or_later
from spark_session import is_before_spark_320, is_databricks113_or_later, spark_version
import warnings

_grpkey_longs_with_no_nulls = [
@@ -1486,6 +1486,103 @@ def test_to_date_with_window_functions():
"""
)


@ignore_order(local=True)
@approximate_float
@pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn)
@pytest.mark.parametrize('data_gen', [_grpkey_longs_with_no_nulls,
_grpkey_longs_with_nulls,
_grpkey_longs_with_dates,
_grpkey_longs_with_nullable_dates,
_grpkey_longs_with_decimals,
_grpkey_longs_with_nullable_decimals,
_grpkey_longs_with_nullable_larger_decimals
], ids=idfn)
@pytest.mark.parametrize('window_spec', ["3 PRECEDING AND -1 FOLLOWING",
"-2 PRECEDING AND 4 FOLLOWING",
"UNBOUNDED PRECEDING AND -1 FOLLOWING",
"-1 PRECEDING AND UNBOUNDED FOLLOWING",
"10 PRECEDING AND -1 FOLLOWING",
"5 PRECEDING AND -2 FOLLOWING"], ids=idfn)
def test_window_aggs_for_negative_rows_partitioned(data_gen, batch_size, window_spec):
conf = {'spark.rapids.sql.batchSizeBytes': batch_size,
'spark.rapids.sql.castFloatToDecimal.enabled': True}
assert_gpu_and_cpu_are_equal_sql(
lambda spark: gen_df(spark, data_gen, length=2048),
"window_agg_table",
'SELECT '
' SUM(c) OVER '
' (PARTITION BY a ORDER BY b,c ASC ROWS BETWEEN {window}) AS sum_c_asc, '
' MAX(c) OVER '
' (PARTITION BY a ORDER BY b DESC, c DESC ROWS BETWEEN {window}) AS max_c_desc, '
' MIN(c) OVER '
' (PARTITION BY a ORDER BY b,c ROWS BETWEEN {window}) AS min_c_asc, '
' COUNT(1) OVER '
' (PARTITION BY a ORDER BY b,c ROWS BETWEEN {window}) AS count_1, '
' COUNT(c) OVER '
' (PARTITION BY a ORDER BY b,c ROWS BETWEEN {window}) AS count_c, '
' AVG(c) OVER '
' (PARTITION BY a ORDER BY b,c ROWS BETWEEN {window}) AS avg_c, '
' COLLECT_LIST(c) OVER '
' (PARTITION BY a ORDER BY b,c ROWS BETWEEN {window}) AS list_c, '
' SORT_ARRAY(COLLECT_SET(c) OVER '
' (PARTITION BY a ORDER BY b,c ROWS BETWEEN {window})) AS sorted_set_c '
'FROM window_agg_table '.format(window=window_spec),
conf=conf)


def spark_bugs_in_decimal_sorting():
"""
    Checks whether this version of Apache Spark has a bug in sorting Decimal columns correctly.
    See https://issues.apache.org/jira/browse/SPARK-40089.
    :return: True, if this Spark version does not sort Decimal(>20, >2) columns correctly; False otherwise.
"""
v = spark_version()
return v < "3.1.4" or v < "3.3.1" or v < "3.2.3" or v < "3.4.0"


@ignore_order(local=True)
@approximate_float
@pytest.mark.parametrize('batch_size', ['1g'], ids=idfn)
@pytest.mark.parametrize('data_gen', [_grpkey_longs_with_no_nulls,
_grpkey_longs_with_nulls,
_grpkey_longs_with_dates,
_grpkey_longs_with_nullable_dates,
_grpkey_longs_with_decimals,
_grpkey_longs_with_nullable_decimals,
pytest.param(_grpkey_longs_with_nullable_larger_decimals,
marks=pytest.mark.skipif(
condition=spark_bugs_in_decimal_sorting(),
reason='https://github.com/NVIDIA/spark-rapids/issues/7429'))],
ids=idfn)
def test_window_aggs_for_negative_rows_unpartitioned(data_gen, batch_size):
conf = {'spark.rapids.sql.batchSizeBytes': batch_size,
'spark.rapids.sql.castFloatToDecimal.enabled': True}

assert_gpu_and_cpu_are_equal_sql(
lambda spark: gen_df(spark, data_gen, length=2048),
"window_agg_table",
'SELECT '
' SUM(c) OVER '
' (ORDER BY b,c,a ROWS BETWEEN 3 PRECEDING AND -1 FOLLOWING) AS sum_c_asc, '
' MAX(c) OVER '
' (ORDER BY b DESC, c DESC, a DESC ROWS BETWEEN -2 PRECEDING AND 4 FOLLOWING) AS max_c_desc, '
' min(c) OVER '
' (ORDER BY b,c,a ROWS BETWEEN UNBOUNDED PRECEDING AND -1 FOLLOWING) AS min_c_asc, '
' COUNT(1) OVER '
' (ORDER BY b,c,a ROWS BETWEEN -1 PRECEDING AND UNBOUNDED FOLLOWING) AS count_1, '
' COUNT(c) OVER '
' (ORDER BY b,c,a ROWS BETWEEN 10 PRECEDING AND -1 FOLLOWING) AS count_c, '
' AVG(c) OVER '
' (ORDER BY b,c,a ROWS BETWEEN -1 PRECEDING AND UNBOUNDED FOLLOWING) AS avg_c, '
' COLLECT_LIST(c) OVER '
' (PARTITION BY a ORDER BY b,c,a ROWS BETWEEN 5 PRECEDING AND -2 FOLLOWING) AS list_c, '
' SORT_ARRAY(COLLECT_SET(c) OVER '
' (PARTITION BY a ORDER BY b,c,a ROWS BETWEEN 5 PRECEDING AND -2 FOLLOWING)) AS set_c '
'FROM window_agg_table ',
conf=conf)


def test_lru_cache_datagen():
# log cache info at the end of integration tests, not related to window functions
info = gen_df_help.cache_info()
@@ -689,12 +689,13 @@ object GroupedAggregations {
private def getWindowOptions(
orderSpec: Seq[SortOrder],
orderPositions: Seq[Int],
-      frame: GpuSpecifiedWindowFrame): WindowOptions = {
frame: GpuSpecifiedWindowFrame,
minPeriods: Int): WindowOptions = {
frame.frameType match {
case RowFrame =>
withResource(getRowBasedLower(frame)) { lower =>
withResource(getRowBasedUpper(frame)) { upper =>
-          val builder = WindowOptions.builder().minPeriods(1)
val builder = WindowOptions.builder().minPeriods(minPeriods)
if (isUnbounded(frame.lower)) builder.unboundedPreceding() else builder.preceding(lower)
if (isUnbounded(frame.upper)) builder.unboundedFollowing() else builder.following(upper)
builder.build
@@ -718,7 +719,7 @@
withResource(asScalarRangeBoundary(orderType, lower)) { preceding =>
withResource(asScalarRangeBoundary(orderType, upper)) { following =>
val windowOptionBuilder = WindowOptions.builder()
-            .minPeriods(1)
.minPeriods(1) // Does not currently support custom minPeriods.
.orderByColumnIndex(orderByIndex)

if (preceding.isEmpty) {
@@ -929,13 +930,18 @@ class GroupedAggregations {
if (frameSpec.frameType == frameType) {
// For now I am going to assume that we don't need to combine calls across frame specs
// because it would just not help that much
-        val result = withResource(
-          getWindowOptions(boundOrderSpec, orderByPositions, frameSpec)) { windowOpts =>
-          val allAggs = functions.map {
-            case (winFunc, _) => winFunc.aggOverWindow(inputCb, windowOpts)
-          }.toSeq
-          withResource(GpuColumnVector.from(inputCb)) { initProjTab =>
-            aggIt(initProjTab.groupBy(partByPositions: _*), allAggs)
val result = {
val allWindowOpts = functions.map { f =>
getWindowOptions(boundOrderSpec, orderByPositions, frameSpec,
f._1.windowFunc.getMinPeriods)
}
withResource(allWindowOpts.toSeq) { allWindowOpts =>
val allAggs = allWindowOpts.zip(functions).map { case (windowOpt, f) =>
f._1.aggOverWindow(inputCb, windowOpt)
}
withResource(GpuColumnVector.from(inputCb)) { initProjTab =>
aggIt(initProjTab.groupBy(partByPositions: _*), allAggs)
}
}
}
withResource(result) { result =>
@@ -28,6 +28,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Average, CollectList, CollectSet, Count, Max, Min, Sum}
import org.apache.spark.sql.rapids.{AddOverflowChecks, GpuAggregateExpression, GpuCount, GpuCreateNamedStruct, GpuDivide, GpuSubtract}
import org.apache.spark.sql.rapids.shims.RapidsErrorUtils
import org.apache.spark.sql.types._
@@ -81,13 +82,25 @@ abstract class GpuWindowExpressionMetaBase(
case _: Lead | _: Lag => // ignored we are good
case _ =>
// need to be sure that the lower/upper are acceptable
-            if (lower > 0) {
-              willNotWorkOnGpu(s"lower-bounds ahead of current row is not supported. " +
-                s"Found $lower")
// Negative bounds are allowed, so long as lower does not exceed upper.
if (upper < lower) {
willNotWorkOnGpu("upper-bounds must equal or exceed the lower bounds. " +
s"Found lower=$lower, upper=$upper ")
}
-            if (upper < 0) {
-              willNotWorkOnGpu(s"upper-bounds behind the current row is not supported. " +
-                s"Found $upper")
// Also check for negative offsets.
if (upper < 0 || lower > 0) {
windowFunction.asInstanceOf[AggregateExpression].aggregateFunction match {
case _: Average => // Supported
case _: CollectList => // Supported
case _: CollectSet => // Supported
case _: Count => // Supported
case _: Max => // Supported
case _: Min => // Supported
case _: Sum => // Supported
case f: AggregateFunction =>
willNotWorkOnGpu("negative row bounds unsupported for specified " +
s"aggregation: ${f.prettyName}")
}
}
}
case RangeFrame =>
@@ -649,7 +662,15 @@ case class GpuSpecialFrameBoundary(boundary : SpecialFrameBoundary)

// This is here for now just to tag an expression as being a GpuWindowFunction and match
// Spark. This may expand in the future if other types of window functions show up.
-trait GpuWindowFunction extends GpuUnevaluable with ShimExpression
trait GpuWindowFunction extends GpuUnevaluable with ShimExpression {
/**
* Get "min-periods" value, i.e. the minimum number of periods/rows
* above which a non-null value is returned for the function.
* Otherwise, null is returned.
* @return Non-negative value for min-periods.
*/
def getMinPeriods: Int = 1
}

/**
* This is a special window function that simply replaces itself with one or more
@@ -1586,6 +1586,13 @@ case class GpuCount(children: Seq[Expression],

override def newUnboundedToUnboundedFixer: BatchedUnboundedToUnboundedWindowFixer =
new CountUnboundedToUnboundedFixer(failOnError)

// minPeriods should be 0.
// Consider the following rows:
// v = [ 0, 1, 2, 3, 4, 5 ]
// A `COUNT` window aggregation over (2, -1) should yield 0, not null,
// for the first row.
override def getMinPeriods: Int = 0
}

object GpuAverage {
@@ -1971,6 +1978,12 @@ case class GpuCollectList(
override def windowAggregation(
inputs: Seq[(ColumnVector, Int)]): RollingAggregationOnColumn =
RollingAggregation.collectList().onColumn(inputs.head._2)

// minPeriods should be 0.
// Consider the following rows: v = [ 0, 1, 2, 3, 4, 5 ]
// A `COLLECT_LIST` window aggregation over (2, -1) should yield an empty array [],
// not null, for the first row.
override def getMinPeriods: Int = 0
}

/**
@@ -2005,6 +2018,12 @@ case class GpuCollectSet(
RollingAggregation.collectSet(NullPolicy.EXCLUDE, NullEquality.EQUAL,
NaNEquality.ALL_EQUAL).onColumn(inputs.head._2)
}

// minPeriods should be 0.
// Consider the following rows: v = [ 0, 1, 2, 3, 4, 5 ]
// A `COLLECT_SET` window aggregation over (2, -1) should yield an empty array [],
// not null, for the first row.
override def getMinPeriods: Int = 0
}

trait CpuToGpuAggregateBufferConverter {