Skip to content

Commit

Permalink
Support raise_error() on [databricks] 14.3, Spark 4. (#11969)
Browse files Browse the repository at this point in the history
Fixes #10969.

This commit adds support for `raise_error()` on Databricks 14.3 and
Spark 4.0.

On these new Spark versions, the `RaiseError` expression (that powers
the `raise_error()` API function) was changed from a Unary expression to
a Binary one. This was done without modifying the arity of
`raise_error()`. The ostensible reason seems to have been to eventually
allow user-code to raise custom errors via `raise_error()`.

This commit allows `raise_error()` to work on the GPU as it currently
does on the CPU: as a unary function powered by a binary expression in
the background.

The tests have been modified to verify both the new behaviour and the
legacy one on new platforms, while continuing to run as before on legacy
platforms.

---------

Signed-off-by: MithunR <[email protected]>
  • Loading branch information
mythrocks authored Jan 17, 2025
1 parent 1ba64a8 commit 8ace562
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 16 deletions.
72 changes: 60 additions & 12 deletions integration_tests/src/main/python/misc_expr_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
# Copyright (c) 2020-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,7 +14,7 @@

import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_and_cpu_error
from data_gen import *
from marks import incompat, approximate_float
from pyspark.sql.types import *
Expand All @@ -33,27 +33,75 @@ def test_part_id():
f.col('a'),
f.spark_partition_id()))

# Spark conf key for choosing legacy error semantics.
legacy_semantics_key = "spark.sql.legacy.raiseErrorWithoutErrorClass"

def raise_error_test_impl(test_conf):
use_new_error_semantics = legacy_semantics_key in test_conf and test_conf[legacy_semantics_key] == False

@pytest.mark.skipif(condition=is_spark_400_or_later() or is_databricks_version_or_later(14, 3),
reason="raise_error() not currently implemented for Spark 4.0, or Databricks 14.3. "
"See https://github.com/NVIDIA/spark-rapids/issues/10107.")
def test_raise_error():
data_gen = ShortGen(nullable=False, min_val=0, max_val=20, special_cases=[])

# Test for "when" selecting the "raise_error()" expression (null-type).
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen, num_slices=2).select(
f.when(f.col('a') > 30, f.raise_error("unexpected"))))
f.when(f.col('a') > 30, f.raise_error("unexpected"))),
conf=test_conf)

# Test for if/else, with raise_error in the else.
# This should test if the data-type of raise_error() interferes with
# the result-type of the parent expression (if/else).
assert_gpu_and_cpu_are_equal_sql(
lambda spark: unary_op_df(spark, data_gen, num_slices=2),
'test_table',
"""
SELECT IF( a < 30, a, raise_error('unexpected') )
FROM test_table
""",
conf=test_conf)

assert_gpu_and_cpu_are_equal_collect(
lambda spark: spark.range(0).select(f.raise_error(f.col("id"))))
lambda spark: spark.range(0).select(f.raise_error(f.col("id"))),
conf=test_conf)

error_fragment = "org.apache.spark.SparkRuntimeException" if use_new_error_semantics \
else "java.lang.RuntimeException"
assert_gpu_and_cpu_error(
lambda spark: unary_op_df(spark, null_gen, length=2, num_slices=1).select(
f.raise_error(f.col('a'))).collect(),
conf={},
error_message="java.lang.RuntimeException")
conf=test_conf,
error_message=error_fragment)

error_fragment = error_fragment + (": [USER_RAISED_EXCEPTION] unexpected" if use_new_error_semantics
else ": unexpected")
assert_gpu_and_cpu_error(
lambda spark: unary_op_df(spark, short_gen, length=2, num_slices=1).select(
f.raise_error(f.lit("unexpected"))).collect(),
conf={},
error_message="java.lang.RuntimeException: unexpected")
conf=test_conf,
error_message=error_fragment)


def test_raise_error_legacy_semantics():
"""
Tests the "legacy" semantics of raise_error(), i.e. where the error
does not include an error class.
"""
if is_spark_400_or_later() or is_databricks_version_or_later(14, 3):
# Spark 4+ and Databricks 14.3+ support RaiseError with error-classes included.
# Must test "legacy" mode, where error-classes are excluded.
raise_error_test_impl(test_conf={legacy_semantics_key: True})
else:
# Spark versions preceding 4.0, or Databricks 14.3 do not support RaiseError with
# error-classes. No legacy mode need be selected.
raise_error_test_impl(test_conf={})


@pytest.mark.skipif(condition=not (is_spark_400_or_later() or is_databricks_version_or_later(14, 3)),
reason="RaiseError semantics with error-classes are only supported "
"on Spark 4.0+ and Databricks 14.3+.")
def test_raise_error_new_semantics():
"""
Tests the "new" semantics of raise_error(), i.e. where the error
includes an error class. Unsupported in Spark versions predating
Spark 4.0, Databricks 14.3.
"""
raise_error_test_impl(test_conf={legacy_semantics_key: False})
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,10 +19,24 @@
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

import com.nvidia.spark.rapids.ExprRule
import com.nvidia.spark.rapids.{ExprRule, GpuOverrides}
import com.nvidia.spark.rapids.{ExprChecks, TypeEnum, TypeSig}

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{Expression, RaiseError}
import org.apache.spark.sql.rapids.shims.RaiseErrorMeta

object RaiseErrorShim {
val exprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Map.empty
val exprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
Seq(GpuOverrides.expr[RaiseError](
"Throw an exception",
ExprChecks.binaryProject(
TypeSig.NULL, TypeSig.NULL,
// In Databricks 14.3 and Spark 4.0, RaiseError forwards the lhs expression
// (i.e. the error-class) as a scalar value. A vector/column here would be surprising.
("errorClass", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING),
("errorParams", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.STRING)),
),
(a, conf, p, r) => new RaiseErrorMeta(a, conf, p, r)
)).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*** spark-rapids-shim-json-lines
{"spark": "350db143"}
{"spark": "400"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

import ai.rapids.cudf.{ColumnVector, ColumnView, Scalar}
import com.nvidia.spark.rapids.{BinaryExprMeta, DataFromReplacementRule, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuMapUtils, GpuScalar, RapidsConf, RapidsMeta}
import com.nvidia.spark.rapids.Arm.withResource

import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, RaiseError}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, MapData}
import org.apache.spark.sql.errors.QueryExecutionErrors.raiseError
import org.apache.spark.sql.types.{AbstractDataType, DataType, NullType, StringType}
import org.apache.spark.unsafe.types.UTF8String

/**
* Implements `raise_error()` for Databricks 14.3 and Spark 4.0.
* Note that while the arity `raise_error()` remains 1 for all user-facing APIs
* (SQL, Scala, Python). But internally, the implementation uses a binary expression,
* where the first argument indicates the "error-class" for the error being raised.
*/
case class GpuRaiseError(left: Expression, right: Expression, dataType: DataType)
extends GpuBinaryExpression with ExpectsInputTypes {

val errorClass: Expression = left
val errorParams: Expression = right

override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
override def toString: String = s"raise_error($errorClass, $errorParams)"

/** Could evaluating this expression cause side-effects, such as throwing an exception? */
override def hasSideEffects: Boolean = true

override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector =
throw new UnsupportedOperationException("Expected errorClass (lhs) to be a String literal")

override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector =
throw new UnsupportedOperationException("Expected errorClass (lhs) to be a String literal")

private def extractScalaUTF8String(stringScalar: Scalar): UTF8String = {
// This is guaranteed to be a string scalar.
GpuScalar.extract(stringScalar).asInstanceOf[UTF8String]
}

private def extractStrings(stringsColumn: ColumnView): Array[UTF8String] = {
val size = stringsColumn.getRowCount.asInstanceOf[Int] // Already checked if exceeds threshold.
val output: Array[UTF8String] = new Array[UTF8String](size)
for (i <- 0 until size) {
output(i) = withResource(stringsColumn.getScalarElement(i)) {
extractScalaUTF8String(_)
}
}
output
}

private def makeMapData(listOfStructs: ColumnView): MapData = {
val THRESHOLD: Int = 10 // Avoiding surprises with large maps.
// All testing indicates a map with 1 entry.
val mapSize = listOfStructs.getRowCount

if (mapSize > THRESHOLD) {
throw new UnsupportedOperationException("Unexpectedly large error-parameter map")
}

val outputKeys: Array[UTF8String] =
withResource(GpuMapUtils.getKeysAsListView(listOfStructs)) { listOfKeys =>
withResource(listOfKeys.getChildColumnView(0)) { // Strings child of LIST column.
extractStrings(_)
}
}

val outputVals: Array[UTF8String] =
withResource(GpuMapUtils.getValuesAsListView(listOfStructs)) { listOfVals =>
withResource(listOfVals.getChildColumnView(0)) { // Strings child of LIST column.
extractStrings(_)
}
}

ArrayBasedMapData(outputKeys, outputVals)
}

override def doColumnar(lhs: GpuScalar, rhs: GpuColumnVector): ColumnVector = {
if (rhs.getRowCount <= 0) {
// For the case: when(condition, raise_error(col("a"))
// When `condition` selects no rows, a vector of nulls should be returned,
// instead of throwing.
return GpuColumnVector.columnVectorFromNull(0, NullType)
}

val lhsErrorClass = lhs.getValue.asInstanceOf[UTF8String]

val rhsMapData = withResource(rhs.getBase.slice(0,1)) { slices =>
val firstRhsRow = slices(0)
makeMapData(firstRhsRow)
}

throw raiseError(lhsErrorClass, rhsMapData)
}

override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = {
if (numRows <= 0) {
// For the case: when(condition, raise_error(col("a"))
// When `condition` selects no rows, a vector of nulls should be returned,
// instead of throwing.
return GpuColumnVector.columnVectorFromNull(0, NullType)
}

val errorClass = lhs.getValue.asInstanceOf[UTF8String]
// TODO (future): Check if the map-data needs to be extracted differently.
// All testing indicates that the host value of the map literal is set always pre-set.
// But if it isn't, then GpuScalar.getValue might extract it incorrectly.
// https://github.com/NVIDIA/spark-rapids/issues/11974
val errorParams = rhs.getValue.asInstanceOf[MapData]
throw raiseError(errorClass, errorParams)
}
}

class RaiseErrorMeta(r: RaiseError,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule )
extends BinaryExprMeta[RaiseError](r, conf, parent, rule) {
override def convertToGpu(lhsErrorClass: Expression, rhsErrorParams: Expression): GpuExpression
= GpuRaiseError(lhsErrorClass, rhsErrorParams, r.dataType)
}
1 change: 1 addition & 0 deletions tools/generated_files/400/operatorsScore.csv
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ PythonUDAF,4
PythonUDF,4
Quarter,4
RLike,4
RaiseError,4
Rand,4
Rank,4
RegExpExtract,4
Expand Down
3 changes: 3 additions & 0 deletions tools/generated_files/400/supportedExprs.csv
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,9 @@ Quarter,S,`quarter`,None,project,result,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
RLike,S,`regexp_like`; `regexp`; `rlike`,None,project,str,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
RLike,S,`regexp_like`; `regexp`; `rlike`,None,project,regexp,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
RLike,S,`regexp_like`; `regexp`; `rlike`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
RaiseError,S, ,None,project,errorClass,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
RaiseError,S, ,None,project,errorParams,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA
RaiseError,S, ,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA
Rand,S,`rand`; `random`,None,project,seed,NA,NA,NA,S,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
Rand,S,`rand`; `random`,None,project,result,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
Rank,S,`rank`,None,window,ordering,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NS,NS,NS,NS,NS
Expand Down

0 comments on commit 8ace562

Please sign in to comment.