Skip to content

Commit

Permalink
Fix collection_ops_tests for Spark 4.0 [databricks] (NVIDIA#11414)
Browse files Browse the repository at this point in the history
* Fix collection_ops_tests for Spark 4.0.

Fixes NVIDIA#11011.

This commit fixes the failures in `collection_ops_tests` on Spark 4.0.

On all versions of Spark, when a Sequence is collected with rows that exceed MAX_INT,
an exception is thrown indicating that the collected Sequence/array is
larger than permissible. The different versions of Spark vary in the
contents of the exception message.

On Spark 4, one sees that the error message now contains more
information than all prior versions, including:
1. The name of the op causing the error
2. The errant sequence size

This commit introduces a shim to make this new information available in
the exception.

Note that this shim does not fit cleanly in RapidsErrorUtils, because
the message differs even between patch releases of the same Spark
version line. For instance, Spark 3.4.0 and 3.4.1 have a different
message from 3.4.2 and 3.4.3. Likewise, there are differences among
3.5.0, 3.5.1, and 3.5.2.

Signed-off-by: MithunR <[email protected]>

* Fixed formatting error.

* Review comments.

This moves the construction of the long-sequence error strings into
RapidsErrorUtils.  The process involved introducing many new RapidsErrorUtils
classes, and using mix-ins of concrete implementations for the error-string
construction.

* Added missing shim tag for 3.5.2.

* Review comments: Fixed code style.

* Reformatting, per project guideline.

* Fixed missed whitespace problem.

---------

Signed-off-by: MithunR <[email protected]>
  • Loading branch information
mythrocks authored Oct 12, 2024
1 parent 4866941 commit aca15ab
Show file tree
Hide file tree
Showing 17 changed files with 392 additions and 145 deletions.
11 changes: 8 additions & 3 deletions integration_tests/src/main/python/collection_ops_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021-2023, NVIDIA CORPORATION.
# Copyright (c) 2021-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,8 @@
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error
from data_gen import *
from pyspark.sql.types import *

from spark_session import is_before_spark_400
from string_test import mk_str_gen
import pyspark.sql.functions as f
import pyspark.sql.utils
Expand Down Expand Up @@ -326,8 +328,11 @@ def test_sequence_illegal_boundaries(start_gen, stop_gen, step_gen):
@pytest.mark.parametrize('stop_gen', sequence_too_long_length_gens, ids=idfn)
@allow_non_gpu(*non_utc_allow)
def test_sequence_too_long_sequence(stop_gen):
msg = "Too long sequence" if is_before_spark_334() or (not is_before_spark_340() and is_before_spark_342()) \
or is_spark_350() else "Unsuccessful try to create array with"
msg = "Too long sequence" if is_before_spark_334() \
or (not is_before_spark_340() and is_before_spark_342()) \
or is_spark_350() \
else "Can't create array" if not is_before_spark_400() \
else "Unsuccessful try to create array with"
assert_gpu_and_cpu_error(
# To avoid OOM, reduce the row number to 1, it is enough to verify this case.
lambda spark:unary_op_df(spark, stop_gen, 1).selectExpr(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.rapids
import java.util.Optional

import ai.rapids.cudf
import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, Scalar, SegmentedReductionAggregation, Table}
import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, ReductionAggregation, Scalar, SegmentedReductionAggregation, Table}
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm._
import com.nvidia.spark.rapids.ArrayIndexUtils.firstIndexAndNumElementUnchecked
Expand Down Expand Up @@ -1651,7 +1651,8 @@ object GpuSequenceUtil {
def computeSequenceSize(
start: ColumnVector,
stop: ColumnVector,
step: ColumnVector): ColumnVector = {
step: ColumnVector,
functionName: String): ColumnVector = {
checkSequenceInputs(start, stop, step)
val actualSize = GetSequenceSize(start, stop, step)
val sizeAsLong = withResource(actualSize) { _ =>
Expand All @@ -1673,7 +1674,12 @@ object GpuSequenceUtil {
// check max size
withResource(Scalar.fromInt(MAX_ROUNDED_ARRAY_LENGTH)) { maxLen =>
withResource(sizeAsLong.lessOrEqualTo(maxLen)) { allValid =>
require(isAllValidTrue(allValid), GetSequenceSize.TOO_LONG_SEQUENCE)
withResource(sizeAsLong.reduce(ReductionAggregation.max())) { maxSizeScalar =>
require(isAllValidTrue(allValid),
RapidsErrorUtils.getTooLongSequenceErrorString(
maxSizeScalar.getLong.asInstanceOf[Int],
functionName))
}
}
}
// cast to int and return
Expand Down Expand Up @@ -1713,7 +1719,7 @@ case class GpuSequence(start: Expression, stop: Expression, stepOpt: Option[Expr
val steps = stepGpuColOpt.map(_.getBase.incRefCount())
.getOrElse(defaultStepsFunc(startCol, stopCol))
closeOnExcept(steps) { _ =>
(computeSequenceSize(startCol, stopCol, steps), steps)
(computeSequenceSize(startCol, stopCol, steps, prettyName), steps)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,7 @@ package com.nvidia.spark.rapids.shims
import ai.rapids.cudf._
import com.nvidia.spark.rapids.Arm._

import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

object GetSequenceSize {
val TOO_LONG_SEQUENCE = s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH"
/**
* Compute the size of each sequence according to 'start', 'stop' and 'step'.
* A row (Row[start, stop, step]) contains at least one null element will produce
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "320"}
{"spark": "321"}
{"spark": "321cdh"}
{"spark": "322"}
{"spark": "323"}
{"spark": "324"}
{"spark": "330"}
{"spark": "330cdh"}
{"spark": "330db"}
{"spark": "331"}
{"spark": "332"}
{"spark": "332cdh"}
{"spark": "332db"}
{"spark": "333"}
{"spark": "340"}
{"spark": "341"}
{"spark": "341db"}
{"spark": "350"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

/**
 * Builds the "sequence too long" error text for the Spark versions listed in this
 * file's shim tags.
 */
trait SequenceSizeTooLongErrorBuilder {

  /**
   * Returns the message used when a sequence is longer than the permitted array length.
   *
   * @param sequenceSize size of the offending sequence; not used in this message
   * @param functionName name of the function that produced the sequence; not used in
   *                     this message
   * @return fixed text ending with the `MAX_ROUNDED_ARRAY_LENGTH` limit
   */
  def getTooLongSequenceErrorString(sequenceSize: Int, functionName: String): String = {
    // Neither argument is interpolated: for these Spark versions the exception
    // message carries only the array-length limit.
    val limit = MAX_ROUNDED_ARRAY_LENGTH
    s"Too long sequence found. Should be <= $limit"
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}

object RapidsErrorUtils extends RapidsQueryErrorUtils {
object RapidsErrorUtils extends RapidsQueryErrorUtils with SequenceSizeTooLongErrorBuilder {
def invalidArrayIndexError(index: Int, numElements: Int,
isElementAtF: Boolean = false): ArrayIndexOutOfBoundsException = {
// Follow the Spark string format before 3.3.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,64 +21,9 @@
{"spark": "332"}
{"spark": "332cdh"}
{"spark": "333"}
{"spark": "334"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

import org.apache.spark.SparkDateTimeException
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}
object RapidsErrorUtils extends RapidsErrorUtils330To334Base
with SequenceSizeTooLongErrorBuilder

object RapidsErrorUtils extends RapidsErrorUtilsFor330plus with RapidsQueryErrorUtils {

def mapKeyNotExistError(
key: String,
keyType: DataType,
origin: Origin): NoSuchElementException = {
QueryExecutionErrors.mapKeyNotExistError(key, keyType, origin.context)
}

def invalidArrayIndexError(index: Int, numElements: Int,
isElementAtF: Boolean = false): ArrayIndexOutOfBoundsException = {
if (isElementAtF) {
QueryExecutionErrors.invalidElementAtIndexError(index, numElements)
} else {
QueryExecutionErrors.invalidArrayIndexError(index, numElements)
}
}

def arithmeticOverflowError(
message: String,
hint: String = "",
errorContext: String = ""): ArithmeticException = {
QueryExecutionErrors.arithmeticOverflowError(message, hint, errorContext)
}

def cannotChangeDecimalPrecisionError(
value: Decimal,
toType: DecimalType,
context: String = ""): ArithmeticException = {
QueryExecutionErrors.cannotChangeDecimalPrecisionError(
value, toType.precision, toType.scale, context
)
}

def overflowInIntegralDivideError(context: String = ""): ArithmeticException = {
QueryExecutionErrors.arithmeticOverflowError(
"Overflow in integral divide", "try_divide", context
)
}

def sparkDateTimeException(infOrNan: String): SparkDateTimeException = {
// These are the arguments required by SparkDateTimeException class to create error message.
val errorClass = "CAST_INVALID_INPUT"
val messageParameters = Array("DOUBLE", "TIMESTAMP", SQLConf.ANSI_ENABLED.key)
new SparkDateTimeException(errorClass, Array(infOrNan) ++ messageParameters)
}

def sqlArrayIndexNotStartAtOneError(): RuntimeException = {
new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "330"}
{"spark": "330cdh"}
{"spark": "331"}
{"spark": "332"}
{"spark": "332cdh"}
{"spark": "333"}
{"spark": "334"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

import org.apache.spark.SparkDateTimeException
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}

/**
 * Error-construction helpers shared by the Spark versions listed in this file's shim
 * tags. Most members forward directly to Spark's `QueryExecutionErrors`.
 */
trait RapidsErrorUtils330To334Base extends RapidsErrorUtilsFor330plus with RapidsQueryErrorUtils {

  /** Forwards to `QueryExecutionErrors.mapKeyNotExistError`, passing `origin.context`. */
  def mapKeyNotExistError(
      key: String,
      keyType: DataType,
      origin: Origin): NoSuchElementException =
    QueryExecutionErrors.mapKeyNotExistError(key, keyType, origin.context)

  /**
   * Builds the out-of-bounds error for an array lookup.
   *
   * @param index        the index that was requested
   * @param numElements  number of elements in the array
   * @param isElementAtF when true the `invalidElementAtIndexError` variant is used,
   *                     otherwise `invalidArrayIndexError`
   */
  def invalidArrayIndexError(index: Int, numElements: Int,
      isElementAtF: Boolean = false): ArrayIndexOutOfBoundsException = {
    if (!isElementAtF) {
      QueryExecutionErrors.invalidArrayIndexError(index, numElements)
    } else {
      QueryExecutionErrors.invalidElementAtIndexError(index, numElements)
    }
  }

  /** Forwards all three arguments to `QueryExecutionErrors.arithmeticOverflowError`. */
  def arithmeticOverflowError(
      message: String,
      hint: String = "",
      errorContext: String = ""): ArithmeticException =
    QueryExecutionErrors.arithmeticOverflowError(message, hint, errorContext)

  /**
   * Forwards to `QueryExecutionErrors.cannotChangeDecimalPrecisionError`, unpacking the
   * precision and scale of `toType`.
   */
  def cannotChangeDecimalPrecisionError(
      value: Decimal,
      toType: DecimalType,
      context: String = ""): ArithmeticException = {
    val targetPrecision = toType.precision
    val targetScale = toType.scale
    QueryExecutionErrors.cannotChangeDecimalPrecisionError(
      value, targetPrecision, targetScale, context)
  }

  /** Arithmetic-overflow error for integral divide, with `try_divide` passed as the hint. */
  def overflowInIntegralDivideError(context: String = ""): ArithmeticException = {
    val overflowMessage = "Overflow in integral divide"
    val suggestedFunction = "try_divide"
    QueryExecutionErrors.arithmeticOverflowError(overflowMessage, suggestedFunction, context)
  }

  /**
   * Creates a `SparkDateTimeException` with error class `CAST_INVALID_INPUT`.
   *
   * @param infOrNan first message parameter passed to the exception
   */
  def sparkDateTimeException(infOrNan: String): SparkDateTimeException = {
    // These are the arguments required by SparkDateTimeException class to create error message:
    // the offending value first, then the remaining message parameters.
    val castParams = Array(infOrNan, "DOUBLE", "TIMESTAMP", SQLConf.ANSI_ENABLED.key)
    new SparkDateTimeException("CAST_INVALID_INPUT", castParams)
  }

  /** Error with the message "SQL array indices start at 1". */
  def sqlArrayIndexNotStartAtOneError(): RuntimeException =
    new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ package org.apache.spark.sql.rapids.shims

import org.apache.spark.sql.errors.QueryExecutionErrors

object RapidsErrorUtils extends RapidsErrorUtilsBase with RapidsQueryErrorUtils {
object RapidsErrorUtils extends RapidsErrorUtilsBase
with RapidsQueryErrorUtils with SequenceSizeTooLongErrorBuilder {
def sqlArrayIndexNotStartAtOneError(): RuntimeException = {
QueryExecutionErrors.elementAtByIndexZeroError(context = null)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ import org.apache.spark.sql.rapids.{AddOverflowChecks, SubtractOverflowChecks}
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

object GetSequenceSize {
val TOO_LONG_SEQUENCE = "Unsuccessful try to create array with elements exceeding the array " +
s"size limit $MAX_ROUNDED_ARRAY_LENGTH"
/**
* Compute the size of each sequence according to 'start', 'stop' and 'step'.
* A row (Row[start, stop, step]) contains at least one null element will produce
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "334"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

/**
 * `RapidsErrorUtils` for the Spark version tagged in this file: combines
 * `RapidsErrorUtils330To334Base` with the too-long-sequence message supplied by
 * `SequenceSizeTooLongUnsuccessfulErrorBuilder`.
 */
object RapidsErrorUtils
  extends RapidsErrorUtils330To334Base
  with SequenceSizeTooLongUnsuccessfulErrorBuilder

Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "334"}
{"spark": "342"}
{"spark": "343"}
{"spark": "351"}
{"spark": "352"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

/**
 * Builds the "Unsuccessful try to create array" error text for the Spark versions
 * listed in this file's shim tags.
 */
trait SequenceSizeTooLongUnsuccessfulErrorBuilder {

  /**
   * Returns the message used when a sequence is longer than the permitted array length.
   *
   * @param sequenceSize size of the offending sequence; not used in this message
   * @param functionName name of the function that produced the sequence; not used in
   *                     this message
   * @return fixed text ending with the `MAX_ROUNDED_ARRAY_LENGTH` limit
   */
  def getTooLongSequenceErrorString(sequenceSize: Int, functionName: String): String = {
    // The errant function's name does not feature in the exception message
    // prior to Spark 4.0. Neither does the attempted allocation size.
    val limit = MAX_ROUNDED_ARRAY_LENGTH
    s"Unsuccessful try to create array with elements exceeding the array size limit $limit"
  }
}
Loading

0 comments on commit aca15ab

Please sign in to comment.