Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support raise_error() on [databricks] 14.3, Spark 4. #11969

Merged
merged 10 commits into from
Jan 17, 2025
53 changes: 42 additions & 11 deletions integration_tests/src/main/python/misc_expr_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
# Copyright (c) 2020-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -33,27 +33,58 @@ def test_part_id():
f.col('a'),
f.spark_partition_id()))

# Spark conf key for choosing legacy error semantics.
legacy_semantics_key = "spark.sql.legacy.raiseErrorWithoutErrorClass"
is_new_raise_error_semantics_version=is_spark_400_or_later() or is_databricks_version_or_later(14, 3)

mythrocks marked this conversation as resolved.
Show resolved Hide resolved
def raise_error_test_impl(test_conf):
use_new_error_semantics = legacy_semantics_key in test_conf and test_conf[legacy_semantics_key] == False

@pytest.mark.skipif(condition=is_spark_400_or_later() or is_databricks_version_or_later(14, 3),
reason="raise_error() not currently implemented for Spark 4.0, or Databricks 14.3. "
"See https://github.com/NVIDIA/spark-rapids/issues/10107.")
def test_raise_error():
data_gen = ShortGen(nullable=False, min_val=0, max_val=20, special_cases=[])
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen, num_slices=2).select(
f.when(f.col('a') > 30, f.raise_error("unexpected"))))
f.when(f.col('a') > 30, f.raise_error("unexpected"))),
conf=test_conf)

assert_gpu_and_cpu_are_equal_collect(
lambda spark: spark.range(0).select(f.raise_error(f.col("id"))))
lambda spark: spark.range(0).select(f.raise_error(f.col("id"))),
conf=test_conf)

error_fragment = "org.apache.spark.SparkRuntimeException" if use_new_error_semantics \
else "java.lang.RuntimeException"
assert_gpu_and_cpu_error(
lambda spark: unary_op_df(spark, null_gen, length=2, num_slices=1).select(
f.raise_error(f.col('a'))).collect(),
conf={},
error_message="java.lang.RuntimeException")
conf=test_conf,
error_message=error_fragment)

error_fragment = error_fragment + (": [USER_RAISED_EXCEPTION] unexpected" if use_new_error_semantics
else ": unexpected")
assert_gpu_and_cpu_error(
lambda spark: unary_op_df(spark, short_gen, length=2, num_slices=1).select(
f.raise_error(f.lit("unexpected"))).collect(),
conf={},
error_message="java.lang.RuntimeException: unexpected")
conf=test_conf,
error_message=error_fragment)


def test_raise_error_legacy_semantics():
"""
Tests the "legacy" semantics of raise_error(), i.e. where the error
does not include an error class.
"""
if is_new_raise_error_semantics_version:
raise_error_test_impl(test_conf={legacy_semantics_key: True})
else:
raise_error_test_impl(test_conf={})


@pytest.mark.skipif(condition=not is_new_raise_error_semantics_version,
reason="New raise_error semantics (with error-class) is only available "
"on Spark 4.0 and Databricks 14.3.")
def test_raise_error_new_semantics():
"""
Tests the "new" semantics of raise_error(), i.e. where the error
includes an error class. Unsupported in Spark versions predating
Spark 4.0, Databricks 14.3.
"""
raise_error_test_impl(test_conf={legacy_semantics_key: False})
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,10 +19,34 @@
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

import com.nvidia.spark.rapids.ExprRule
import com.nvidia.spark.rapids.{ExprRule, GpuOverrides}
import com.nvidia.spark.rapids.{ExprChecks, GpuExpression, TypeSig, BinaryExprMeta}

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RaiseError}
import org.apache.spark.sql.rapids.shims.GpuRaiseError

object RaiseErrorShim {
val exprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Map.empty
val exprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
Seq(GpuOverrides.expr[RaiseError](
"Throw an exception",
ExprChecks.binaryProject(
TypeSig.NULL, TypeSig.NULL,
("errorClass", TypeSig.STRING, TypeSig.STRING),
("errorParams", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.STRING)),
),
(a, conf, p, r) => new BinaryExprMeta[RaiseError](a, conf, p, r) {
mythrocks marked this conversation as resolved.
Show resolved Hide resolved

override def tagExprForGpu(): Unit = {
// In Databricks 14.3 and Spark 4.0, RaiseError forwards the lhs expression (i.e. the error-class)
// as a scalar value. A vector/column here would be surprising.
a.errorClass match {
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
case _: Literal => // Supported.
case _ => willNotWorkOnGpu(s"expected error-class to be a STRING literal")
}
}

override def convertToGpu(lhsErrorClass: Expression, rhsErrorParams: Expression): GpuExpression =
GpuRaiseError(lhsErrorClass, rhsErrorParams)
})).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*** spark-rapids-shim-json-lines
{"spark": "350db143"}
{"spark": "400"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

import ai.rapids.cudf.{ColumnVector, ColumnView, Scalar}
import com.nvidia.spark.rapids.{GpuColumnVector, GpuBinaryExpression, GpuMapUtils, GpuScalar}
import com.nvidia.spark.rapids.Arm.withResource

import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, MapData}
import org.apache.spark.sql.errors.QueryExecutionErrors.raiseError
import org.apache.spark.sql.types.{AbstractDataType, DataType, NullType, StringType}
import org.apache.spark.unsafe.types.UTF8String

/**
* Implements `raise_error()` for Databricks 14.3 and Spark 4.0.
* Note that while the arity `raise_error()` remains 1 for all user-facing APIs (SQL, Scala, Python).
* But internally, the implementation uses a binary expression, where the first argument indicates
* the "error-class" for the error being raised.
*/
case class GpuRaiseError(left: Expression, right: Expression) extends GpuBinaryExpression with ExpectsInputTypes {

val errorClass: Expression = left
val errorParams: Expression = right

override def dataType: DataType = NullType
override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
override def toString: String = s"raise_error($errorClass, $errorParams)"

/** Could evaluating this expression cause side-effects, such as throwing an exception? */
override def hasSideEffects: Boolean = true

override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector =
throw new UnsupportedOperationException("Expected errorClass (lhs) to be a String literal")

override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector =
throw new UnsupportedOperationException("Expected errorClass (lhs) to be a String literal")

private def extractScalaUTF8String(stringScalar: Scalar): UTF8String = {
// This is guaranteed to be a string scalar.
GpuScalar.extract(stringScalar).asInstanceOf[UTF8String]
}

private def extractStrings(stringsColumn: ColumnView): Array[UTF8String] = {
val size = stringsColumn.getRowCount.asInstanceOf[Int] // Already checked if exceeds threshold.
val output: Array[UTF8String] = new Array[UTF8String](size)
for (i <- 0 until size) {
output(i) = withResource(stringsColumn.getScalarElement(i)) {
extractScalaUTF8String(_)
}
}
output
}

private def makeMapData(listOfStructs: ColumnView): MapData = {
val THRESHOLD: Int = 10 // Avoiding surprises with large maps. All testing indicates a map with 1 entry.
val mapSize = listOfStructs.getRowCount

if (mapSize > THRESHOLD)
throw new UnsupportedOperationException("Unexpectedly large error-parameter map")

val outputKeys: Array[UTF8String] =
withResource(GpuMapUtils.getKeysAsListView(listOfStructs)) { listOfKeys =>
withResource(listOfKeys.getChildColumnView(0)) { // Strings child of LIST column.
extractStrings(_)
}
}

val outputVals: Array[UTF8String] =
withResource(GpuMapUtils.getValuesAsListView(listOfStructs)) { listOfVals =>
withResource(listOfVals.getChildColumnView(0)) { // Strings child of LIST column.
extractStrings(_)
}
}

ArrayBasedMapData(outputKeys, outputVals)
}

override def doColumnar(lhs: GpuScalar, rhs: GpuColumnVector): ColumnVector = {
if (rhs.getRowCount <= 0) {
// For the case: when(condition, raise_error(col("a"))
// When `condition` selects no rows, a vector of nulls should be returned, instead of throwing.
return GpuColumnVector.columnVectorFromNull(0, NullType)
}

val lhsErrorClass = lhs.getValue.asInstanceOf[UTF8String]

val rhsMapData = withResource(rhs.getBase.slice(0,1)) { slices =>
val firstRhsRow = slices(0)
makeMapData(firstRhsRow)
}

throw raiseError(lhsErrorClass, rhsMapData)
}

override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = {
if (numRows <= 0) {
// For the case: when(condition, raise_error(col("a"))
// When `condition` selects no rows, a vector of nulls should be returned, instead of throwing.
return GpuColumnVector.columnVectorFromNull(0, NullType)
}

val errorClass = lhs.getValue.asInstanceOf[UTF8String]
val errorParams = rhs.getValue.asInstanceOf[MapData]
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
throw raiseError(errorClass, errorParams)
}
}
Loading