Implement SumUnboundedToUnboundedFixer (second attempt) #9097
Conversation
Signed-off-by: Andy Grove <[email protected]> (force-pushed 77d3b3a to 97af861)
@pytest.mark.parametrize('data_gen', numeric_gens, ids=idfn)
def test_numeric_running_sum_window_no_part_unbounded_partitioned(data_gen):
    assert_gpu_and_cpu_are_equal_sql(
        lambda spark: two_col_df(spark, UniqueLongGen(), data_gen).repartition(256),
I don't think increasing the partitions is going to trigger the previous error. To trigger the issue, `previousValue` has to be `None`, which can happen the first time processing starts or if an overflow happened for decimal. But the `samePartitionMask` must also be either `scala.Left(cv)` or `scala.Right(true)`. In the case we saw it was the latter.
That indicates that we hit an overflow when doing decimal processing (which again matches the negative scale), and that the next entire batch is for the same grouping (meaning there is no partition-by key, only an order-by key), which these tests cover.
I think you are more likely to hit the problem by having more rows with very large decimal values in them than by trying to partition the data more.
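For readers following along, here is a minimal sketch of the shape of that state. This is illustrative only: the real fixer operates on cudf ColumnVectors inside GpuWindowExpression.scala, and the boolean array standing in for the mask column, the object name, and `canHitTheBug` are hypothetical.

```scala
// Illustrative model of the condition described above, not the plugin's code.
// previousValue is the running result carried over from the prior batch; None means
// either that no batch has been processed yet or that a decimal overflow produced null.
// samePartitionMask says whether rows of the new batch belong to the same partition-by
// group as that carried-over value, either per row (Left) or for the whole batch (Right).
object FixerStateSketch {
  var previousValue: Option[BigDecimal] = None

  def canHitTheBug(samePartitionMask: Either[Array[Boolean], Boolean]): Boolean =
    previousValue.isEmpty && (samePartitionMask match {
      case Right(sameGroup) => sameGroup            // whole batch is the same group
      case Left(mask)       => mask.exists(identity) // at least some rows match
    })
}
```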
If I revert the fix for the scale, I do see this failure:
23/08/23 17:58:11 WARN TaskSetManager: Lost task 0.0 in stage 31.0 (TID 79) (192.168.0.80 executor driver): java.lang.AssertionError: Type conversion is not allowed from Table{columns=[
ColumnVector{rows=598, type=DECIMAL128 scale:-1, nullCount=Optional.empty, offHeap=(ID: 144448 7f905c03d4a0)},
ColumnVector{rows=598, type=INT64, nullCount=Optional.empty, offHeap=(ID: 144449 7f905c2334a0)},
ColumnVector{rows=598, type=DECIMAL64 scale:0, nullCount=Optional.empty, offHeap=(ID: 144995 7f905c039810)},
ColumnVector{rows=598, type=DECIMAL128 scale:1, nullCount=Optional.empty, offHeap=(ID: 144997 7f905c20b270)}], cudfTable=140257995911808, rows=598}
to [DecimalType(38,1), LongType, DecimalType(10,0), DecimalType(38,1)] columns 0 to 4
> I think you are more likely to hit the problem by having more rows with very large decimal values in them than by trying to partition the data more.

How would I achieve this?
I tried to reproduce it and couldn't, so I'm not really sure. I would need to do some more debugging to see why I am not hitting it when I think I should.
'select '
' sum(b) over '
' (order by a asc '
' rows between UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) as sum_b_asc '
Could we have a few tests with a partition by in the window?
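Something like the following spark-shell sketch shows the shape of such a query (illustrative only: the column `p`, the data layout, and the view name here are made up for the example and this is not a test added by the PR):

```scala
// Same unbounded-to-unbounded sum, but with a partition by clause so each value
// of p keeps its own running state across batches.
spark.range(1024)
  .selectExpr("id % 4 as p", "id as a", "if(id = 0, 9223372036854775807, 1) as b")
  .createOrReplaceTempView("window_agg_table")

spark.sql(
  """select p, a,
    |       sum(b) over
    |         (partition by p order by a asc
    |          rows between UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) as sum_b_asc
    |from window_agg_table""".stripMargin).show(false)
```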
@@ -291,7 +292,7 @@ case class BatchedOps(running: Seq[NamedExpression],
  def hasDoublePass: Boolean = unboundedToUnbounded.nonEmpty
}

object GpuWindowExec {
nit: undo extra whitespace
val decimal = numeric.plus(Decimal(prev.getBigDecimal),
  Decimal(scalar.getBigDecimal)).asInstanceOf[Decimal]
val dt = resultType.asInstanceOf[DecimalType]
previousValue = Option(TrampolineUtil.checkDecimalOverflow(
I did some more testing and unfortunately the overflow is not sticky for decimal values. I think we are going to need to either fall back to the old way for Decimals in non-ANSI mode, or find a way both to know that an overflow happened within a batch and to make the overflow between batches sticky.
Say I have two batches of input. The first batch overflows, and the result is a null, but the second batch does not. Now we try to add null to something that is not null, and we get the not-null value back. But in reality the overflow should be sticky: once we overflow we can never go back.
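To make the intent concrete, here is a minimal sketch of the sticky behaviour being asked for. It is hypothetical (a driver-side class using BigDecimal rather than the plugin's GPU scalars) and not the actual SumUnboundedToUnboundedFixer:

```scala
// "Sticky" means: once any batch has overflowed, the answer for that group stays
// null forever, even if every later batch is well behaved on its own.
class StickySumSketch(precision: Int) {
  private var hasOverflowed = false
  private var running: Option[BigDecimal] = None

  // batchSum is the sum of one incoming batch; None stands for "this batch overflowed"
  // (ignoring, for the moment, that a null batch sum could also mean the batch held
  // only nulls, which is exactly the ambiguity discussed further down this thread).
  def addBatch(batchSum: Option[BigDecimal]): Unit = {
    if (!hasOverflowed) {
      batchSum match {
        case None =>
          hasOverflowed = true
          running = None
        case Some(b) =>
          val next = running.map(_ + b).getOrElse(b)
          if (next.precision > precision) { // the combined value no longer fits
            hasOverflowed = true
            running = None
          } else {
            running = Some(next)
          }
      }
    }
  }

  def result: Option[BigDecimal] = if (hasOverflowed) None else running
}
```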
I had run into this issue during testing and added the `hasOverflowed` flag to handle that. Is the issue around how this is distributed? Are there multiple instances of `SumUnboundedToUnboundedFixer` involved, and is that why we can't catch all cases?
On Spark 3.3.0
spark.conf.set("spark.rapids.sql.batchSizeBytes", "100")
spark.time(spark.range(20).repartition(1).selectExpr("id as a", "if(id = 0, 99999999999999999999999999999999999999, 1) as b").selectExpr("SUM(b) OVER (ORDER BY a ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) as all_sum_b","*").show(false))
But out of the box this is exposing some other bugs. To fix it I had to apply the following patch.
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala
index c96f41400..d13777e4e 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala
@@ -1356,7 +1356,7 @@ class SumUnboundedToUnboundedFixer(resultType: DataType, failOnError: Boolean)
val dt = resultType.asInstanceOf[DecimalType]
previousValue = Option(TrampolineUtil.checkDecimalOverflow(
decimal, dt.precision, dt.scale, failOnError))
- .map(n => Scalar.fromDecimal(n.toJavaBigDecimal))
+ .map(n => GpuScalar.from(n, dt))
if (previousValue.isEmpty) {
hasOverflowed = true
}
The problem is that the overflow code only detects an overflow if it happened between batches. If the overflow happens within a batch, it has no knowledge of that and assumes the returned null means the same thing as if all the data in that batch were nulls.
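One way to state the ambiguity (a hypothetical sketch, not a proposed patch): a null per-batch sum only implies overflow if the batch actually contained non-null values, so distinguishing the two cases needs something like a non-null count computed alongside the sum. The `BatchSummary` type and `batchOverflowed` helper below are made up for illustration.

```scala
object OverflowAmbiguitySketch {
  // A null per-batch sum is ambiguous on its own: it could mean the batch's inputs
  // were all null, or that summing them overflowed. Pairing the sum with a
  // non-null count removes the ambiguity.
  final case class BatchSummary(sum: Option[BigDecimal], nonNullCount: Long)

  def batchOverflowed(b: BatchSummary): Boolean =
    b.sum.isEmpty && b.nonNullCount > 0 // null result despite non-null inputs means overflow
}
```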
I applied this suggested change but was not able to reproduce the issue locally.
diff --git a/integration_tests/src/main/python/window_function_test.py b/integration_tests/src/main/python/window_function_test.py
index 339954177..39d965419 100644
--- a/integration_tests/src/main/python/window_function_test.py
+++ b/integration_tests/src/main/python/window_function_test.py
@@ -292,6 +292,35 @@ def test_numeric_sum_window_unbounded(data_gen, partition_by):
conf = {'spark.rapids.sql.variableFloatAgg.enabled': 'true',
'spark.rapids.sql.batchSizeBytes': '100'})
+@ignore_order
+@pytest.mark.parametrize('batch_size', ['100', '1g'], ids=idfn)
+def test_numeric_sum_window_unbounded_decimal_overflow(batch_size):
+ assert_gpu_and_cpu_are_equal_sql(
+ # the 38 9s are the maximum value that a Decimal(38,0) can hold and Spark will infer that type automatically
+ # so the first batch will overflow, if it has at least two rows in it. This verifies that subsequent batches
+ # can detect the overflow in the first batch and also include that in later batches.
+ lambda spark: spark.range(1024).selectExpr("id as a", "if (id = 0, 99999999999999999999999999999999999999, 1) as b"),
+ 'window_agg_table',
+ 'select '
+ ' sum(b) over '
+ ' (order by a asc '
+ ' rows between UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) as sum_b_asc '
+ 'from window_agg_table',
+ conf = {'spark.rapids.sql.batchSizeBytes': batch_size})
+
+@ignore_order
+@pytest.mark.parametrize('batch_size', ['100', '1g'], ids=idfn)
+def test_numeric_sum_window_unbounded_long_overflow(batch_size):
+ assert_gpu_and_cpu_are_equal_sql(
+ lambda spark: spark.range(1024).selectExpr("id as a", "if (id = 0, 9223372036854775807, 1) as b"),
+ 'window_agg_table',
+ 'select '
+ ' sum(b) over '
+ ' (order by a asc '
+ ' rows between UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) as sum_b_asc '
+ 'from window_agg_table',
+ conf = {'spark.rapids.sql.batchSizeBytes': batch_size})
+
@pytest.mark.xfail(reason="[UNSUPPORTED] Ranges over order by byte column overflow "
"(https://github.com/NVIDIA/spark-rapids/pull/2020#issuecomment-838127070)")
@ignore_order
Here is a new patch that adds a test that fails.
Thanks. I confirmed that `test_numeric_sum_window_unbounded_decimal_overflow` fails for me with the fix reverted. However, it also fails with the fix reinstated, due to different output between CPU and GPU: the CPU produces None while the GPU produces a value. So there is still an issue, but at least I have a failing test for it now.
Closing this for now and will revisit in the future.
Closes #9071
Closes #6560
Compared to the original PR, this PR adds an additional test that uses 256 partitions. This uncovered two bugs, which are fixed in this PR: