Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support raise_error() on [databricks] 14.3, Spark 4. #11969

Merged
merged 10 commits into from
Jan 17, 2025
72 changes: 60 additions & 12 deletions integration_tests/src/main/python/misc_expr_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
# Copyright (c) 2020-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,7 +14,7 @@

import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_and_cpu_error
from data_gen import *
from marks import incompat, approximate_float
from pyspark.sql.types import *
Expand All @@ -33,27 +33,75 @@ def test_part_id():
f.col('a'),
f.spark_partition_id()))

# Spark conf key for choosing legacy error semantics.
# The raise_error() tests below pass this key in their test conf: when it is set to
# False they expect errors that carry an error class ([USER_RAISED_EXCEPTION]);
# otherwise they expect a plain java.lang.RuntimeException.
legacy_semantics_key = "spark.sql.legacy.raiseErrorWithoutErrorClass"

def raise_error_test_impl(test_conf):
    """
    Shared body of the raise_error() tests.

    :param test_conf: Spark conf dict applied to every CPU/GPU comparison below.
        If it maps ``legacy_semantics_key`` to False, the expected errors are
        the ones that include an error class; otherwise the legacy
        ``java.lang.RuntimeException`` messages are expected.
    """
    use_new_error_semantics = legacy_semantics_key in test_conf and test_conf[legacy_semantics_key] == False

    data_gen = ShortGen(nullable=False, min_val=0, max_val=20, special_cases=[])

    # Test for "when" selecting the "raise_error()" expression (null-type).
    # The generated values never exceed 20, so the raise_error() branch is never selected.
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, data_gen, num_slices=2).select(
            f.when(f.col('a') > 30, f.raise_error("unexpected"))),
        conf=test_conf)

    # Test for if/else, with raise_error in the else.
    # This should test if the data-type of raise_error() interferes with
    # the result-type of the parent expression (if/else).
    assert_gpu_and_cpu_are_equal_sql(
        lambda spark: unary_op_df(spark, data_gen, num_slices=2),
        'test_table',
        """
        SELECT IF( a < 30, a, raise_error('unexpected') )
        FROM test_table
        """,
        conf=test_conf)

    # Empty input: raise_error() over zero rows must not raise.
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: spark.range(0).select(f.raise_error(f.col("id"))),
        conf=test_conf)

    # Error raised from a (null) column argument: only the exception type is checked.
    error_fragment = "org.apache.spark.SparkRuntimeException" if use_new_error_semantics \
        else "java.lang.RuntimeException"
    assert_gpu_and_cpu_error(
        lambda spark: unary_op_df(spark, null_gen, length=2, num_slices=1).select(
            f.raise_error(f.col('a'))).collect(),
        conf=test_conf,
        error_message=error_fragment)

    # Error raised from a literal argument: the message text is checked as well.
    error_fragment = error_fragment + (": [USER_RAISED_EXCEPTION] unexpected" if use_new_error_semantics
                                       else ": unexpected")
    assert_gpu_and_cpu_error(
        lambda spark: unary_op_df(spark, short_gen, length=2, num_slices=1).select(
            f.raise_error(f.lit("unexpected"))).collect(),
        conf=test_conf,
        error_message=error_fragment)


def test_raise_error_legacy_semantics():
    """
    Checks raise_error() under its "legacy" semantics, i.e. the error
    that is raised carries no error class.
    """
    # Spark 4+ and Databricks 14.3+ raise errors with error-classes included, so
    # "legacy" mode (error-classes excluded) has to be requested explicitly there.
    # Earlier versions have no error-class support in RaiseError; no conf is needed.
    error_classes_supported = is_spark_400_or_later() or is_databricks_version_or_later(14, 3)
    legacy_conf = {legacy_semantics_key: True} if error_classes_supported else {}
    raise_error_test_impl(test_conf=legacy_conf)


@pytest.mark.skipif(
    condition=not is_spark_400_or_later() and not is_databricks_version_or_later(14, 3),
    reason="RaiseError semantics with error-classes are only supported "
           "on Spark 4.0+ and Databricks 14.3+.")
def test_raise_error_new_semantics():
    """
    Checks raise_error() under its "new" semantics, i.e. the error that is
    raised carries an error class. Skipped on versions predating Spark 4.0
    and Databricks 14.3, which do not support this.
    """
    new_semantics_conf = {legacy_semantics_key: False}
    raise_error_test_impl(test_conf=new_semantics_conf)
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,10 +19,24 @@
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

import com.nvidia.spark.rapids.ExprRule
import com.nvidia.spark.rapids.{ExprRule, GpuOverrides}
import com.nvidia.spark.rapids.{ExprChecks, TypeEnum, TypeSig}

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{Expression, RaiseError}
import org.apache.spark.sql.rapids.shims.RaiseErrorMeta

/**
 * Shim that registers the GPU replacement rule for `RaiseError`.
 *
 * `exprs` maps the `RaiseError` expression class to the rule that builds a
 * [[RaiseErrorMeta]] for it.
 */
object RaiseErrorShim {
  val exprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
    Seq(GpuOverrides.expr[RaiseError](
      "Throw an exception",
      ExprChecks.binaryProject(
        TypeSig.NULL, TypeSig.NULL,
        // In Databricks 14.3 and Spark 4.0, RaiseError forwards the lhs expression
        // (i.e. the error-class) as a scalar value. A vector/column here would be surprising.
        ("errorClass", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING),
        ("errorParams", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.STRING))
      ),
      (a, conf, p, r) => new RaiseErrorMeta(a, conf, p, r)
    )).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*** spark-rapids-shim-json-lines
{"spark": "350db143"}
{"spark": "400"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

import ai.rapids.cudf.{ColumnVector, ColumnView, Scalar}
import com.nvidia.spark.rapids.{BinaryExprMeta, DataFromReplacementRule, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuMapUtils, GpuScalar, RapidsConf, RapidsMeta}
import com.nvidia.spark.rapids.Arm.withResource

import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, RaiseError}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, MapData}
import org.apache.spark.sql.errors.QueryExecutionErrors.raiseError
import org.apache.spark.sql.types.{AbstractDataType, DataType, NullType, StringType}
import org.apache.spark.unsafe.types.UTF8String

/**
 * Implements `raise_error()` for Databricks 14.3 and Spark 4.0.
 *
 * While the arity of `raise_error()` remains 1 for all user-facing APIs
 * (SQL, Scala, Python), the implementation internally uses a binary expression:
 * the first argument is the "error-class" of the error being raised, and the
 * second is the map of error parameters.
 *
 * @param left     the error-class expression; expected to evaluate to a string scalar
 * @param right    the error-parameters expression (a string-to-string map)
 * @param dataType the result type reported for this expression
 */
case class GpuRaiseError(left: Expression, right: Expression, dataType: DataType)
  extends GpuBinaryExpression with ExpectsInputTypes {

  val errorClass: Expression = left
  val errorParams: Expression = right

  // NOTE(review): only the type of the first child (the error class) is declared here;
  // the map-typed second child has no entry. Confirm whether that is intentional.
  override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
  override def toString: String = s"raise_error($errorClass, $errorParams)"

  /** Could evaluating this expression cause side-effects, such as throwing an exception? */
  override def hasSideEffects: Boolean = true

  /** Unsupported: the error class (lhs) must be a scalar, not a column. */
  override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector =
    throw new UnsupportedOperationException("Expected errorClass (lhs) to be a String literal")

  /** Unsupported: the error class (lhs) must be a scalar, not a column. */
  override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector =
    throw new UnsupportedOperationException("Expected errorClass (lhs) to be a String literal")

  /** Converts a cudf scalar into a Spark `UTF8String`. */
  private def extractScalaUTF8String(stringScalar: Scalar): UTF8String = {
    // This is guaranteed to be a string scalar.
    GpuScalar.extract(stringScalar).asInstanceOf[UTF8String]
  }

  /**
   * Copies every element of `stringsColumn` into an array, one scalar at a time.
   * Callers are expected to have bounded the column's size beforehand.
   */
  private def extractStrings(stringsColumn: ColumnView): Array[UTF8String] = {
    val size = stringsColumn.getRowCount.toInt // Bounded by the caller.
    Array.tabulate(size) { i =>
      withResource(stringsColumn.getScalarElement(i)) {
        extractScalaUTF8String(_)
      }
    }
  }

  /**
   * Builds the error-parameter map from a map column view (a LIST of key/value STRUCTs).
   *
   * @param listOfStructs the view holding the map row(s) to convert
   * @return the keys and values as an array-based `MapData`
   * @throws UnsupportedOperationException if there are more entries than expected
   */
  private def makeMapData(listOfStructs: ColumnView): MapData = {
    // Avoiding surprises with large maps. All testing indicates a map with 1 entry.
    val maxEntries = 10

    // Extracts the strings child of a LIST column. The size guard is applied to that
    // child because its row count is the number of entries about to be copied; the
    // row count of the enclosing LIST column only counts map rows, not entries.
    def boundedStrings(listOfStrings: ColumnView): Array[UTF8String] =
      withResource(listOfStrings.getChildColumnView(0)) { strings =>
        if (strings.getRowCount > maxEntries) {
          throw new UnsupportedOperationException("Unexpectedly large error-parameter map")
        }
        extractStrings(strings)
      }

    val outputKeys: Array[UTF8String] =
      withResource(GpuMapUtils.getKeysAsListView(listOfStructs)) { listOfKeys =>
        boundedStrings(listOfKeys)
      }

    val outputVals: Array[UTF8String] =
      withResource(GpuMapUtils.getValuesAsListView(listOfStructs)) { listOfVals =>
        boundedStrings(listOfVals)
      }

    ArrayBasedMapData(outputKeys, outputVals)
  }

  /**
   * Scalar error class with a column of error parameters.
   * Returns an empty null column for empty input; otherwise raises the error
   * described by the first row of `rhs`.
   */
  override def doColumnar(lhs: GpuScalar, rhs: GpuColumnVector): ColumnVector = {
    if (rhs.getRowCount <= 0) {
      // For the case: when(condition, raise_error(col("a")))
      // When `condition` selects no rows, a vector of nulls should be returned,
      // instead of throwing.
      GpuColumnVector.columnVectorFromNull(0, NullType)
    } else {
      val lhsErrorClass = lhs.getValue.asInstanceOf[UTF8String]

      // Only the first row is needed to build the error.
      val rhsMapData = withResource(rhs.getBase.slice(0, 1)) { slices =>
        makeMapData(slices(0))
      }

      throw raiseError(lhsErrorClass, rhsMapData)
    }
  }

  /**
   * Scalar error class with scalar error parameters.
   * Returns an empty null column when `numRows` is not positive; otherwise raises the error.
   */
  override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = {
    if (numRows <= 0) {
      // For the case: when(condition, raise_error(col("a")))
      // When `condition` selects no rows, a vector of nulls should be returned,
      // instead of throwing.
      GpuColumnVector.columnVectorFromNull(0, NullType)
    } else {
      val errorClassName = lhs.getValue.asInstanceOf[UTF8String]
      // TODO (future): Check if the map-data needs to be extracted differently.
      //   All testing indicates that the host value of the map literal is always pre-set.
      //   But if it isn't, then GpuScalar.getValue might extract it incorrectly.
      //   https://github.com/NVIDIA/spark-rapids/issues/11974
      val errorParamsMap = rhs.getValue.asInstanceOf[MapData]
      throw raiseError(errorClassName, errorParamsMap)
    }
  }
}

/**
 * Replacement-rule metadata for `RaiseError`: converts the CPU expression into
 * a `GpuRaiseError` over the already-converted child expressions.
 */
class RaiseErrorMeta(
    r: RaiseError,
    conf: RapidsConf,
    parent: Option[RapidsMeta[_, _, _]],
    rule: DataFromReplacementRule)
  extends BinaryExprMeta[RaiseError](r, conf, parent, rule) {

  /** Builds the GPU expression, carrying over the data type of the wrapped `RaiseError`. */
  override def convertToGpu(
      lhsErrorClass: Expression,
      rhsErrorParams: Expression): GpuExpression = {
    val resultType = r.dataType
    GpuRaiseError(lhsErrorClass, rhsErrorParams, resultType)
  }
}
1 change: 1 addition & 0 deletions tools/generated_files/400/operatorsScore.csv
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ PythonUDAF,4
PythonUDF,4
Quarter,4
RLike,4
RaiseError,4
Rand,4
Rank,4
RegExpExtract,4
Expand Down
3 changes: 3 additions & 0 deletions tools/generated_files/400/supportedExprs.csv
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,9 @@ Quarter,S,`quarter`,None,project,result,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
RLike,S,`regexp_like`; `regexp`; `rlike`,None,project,str,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
RLike,S,`regexp_like`; `regexp`; `rlike`,None,project,regexp,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
RLike,S,`regexp_like`; `regexp`; `rlike`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
RaiseError,S, ,None,project,errorClass,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
RaiseError,S, ,None,project,errorParams,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA
RaiseError,S, ,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA
Rand,S,`rand`; `random`,None,project,seed,NA,NA,NA,S,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
Rand,S,`rand`; `random`,None,project,result,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
Rank,S,`rank`,None,window,ordering,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NS,NS,NS,NS,NS
Expand Down
Loading