Skip to content

Commit

Permalink
Add GpuCheckOverflowInTableInsert to Databricks 11.3+ (#9800)
Browse files Browse the repository at this point in the history
Signed-off-by: Jason Lowe <[email protected]>
  • Loading branch information
jlowe authored Nov 22, 2023
1 parent bdc45cb commit 908e986
Show file tree
Hide file tree
Showing 11 changed files with 126 additions and 93 deletions.
21 changes: 21 additions & 0 deletions integration_tests/src/main/python/parquet_write_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,3 +818,24 @@ def test_parquet_write_column_name_with_dots(spark_tmp_path):
lambda spark, path: gen_df(spark, gens).coalesce(1).write.parquet(path),
lambda spark, path: spark.read.parquet(path),
data_path)

@ignore_order
def test_parquet_append_with_downcast(spark_tmp_table_factory, spark_tmp_path):
    """Append LongGen data to parquet tables created from int_gen data, on CPU and GPU.

    Two tables (one per engine) are first created on the CPU under
    ``<spark_tmp_path>/PARQUET_DATA/CPU`` and ``.../GPU``. The append is then run through
    ``assert_gpu_and_cpu_writes_are_equal_collect``, which is handed the append function,
    a parquet reader and the base data path.
    """
    data_path = spark_tmp_path + "/PARQUET_DATA"
    cpu_table = spark_tmp_table_factory.get()
    gpu_table = spark_tmp_table_factory.get()

    def setup_tables(spark):
        # The same generated DataFrame backs both tables; CPU is written first, then GPU.
        source_df = unary_op_df(spark, int_gen, length=10)
        for suffix, table_name in (("/CPU", cpu_table), ("/GPU", gpu_table)):
            writer = source_df.write.format("parquet").option("path", data_path + suffix)
            writer.saveAsTable(table_name)

    with_cpu_session(setup_tables)

    def do_append(spark, path):
        # The target table is chosen from the path suffix passed in by the assert helper.
        target = gpu_table if path.endswith("/GPU") else cpu_table
        append_df = unary_op_df(
            spark, LongGen(min_val=0, max_val=128, special_cases=[]), length=10)
        append_df.write.mode("append").saveAsTable(target)

    def read_back(spark, path):
        return spark.read.parquet(path)

    assert_gpu_and_cpu_writes_are_equal_collect(do_append, read_back, data_path)
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "330db"}
{"spark": "332db"}
{"spark": "341db"}
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

import com.nvidia.spark.rapids._

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan}
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.exchange.{EXECUTOR_BROADCAST, ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.rapids.{GpuCheckOverflowInTableInsert, GpuElementAtMeta}
import org.apache.spark.sql.rapids.execution.{GpuBroadcastHashJoinExec, GpuBroadcastNestedLoopJoinExec}

/**
 * Shim behaviour shared by the Databricks versions named in this file's
 * shim-json-lines header (330db, 332db, 341db). Extends [[Spark321PlusDBShims]].
 */
trait Spark330PlusDBShims extends Spark321PlusDBShims {

  /**
   * The inherited expression rules combined with the rules declared here, followed by
   * `DayTimeIntervalShims.exprs` and `RoundingShims.exprs`. The maps are merged with `++`,
   * so for a duplicated expression class the right-most map wins.
   */
  override def getExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
    val localRules: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq(
      GpuOverrides.expr[CheckOverflowInTableInsert](
        "Casting a numeric value as another numeric type in store assignment",
        ExprChecks.unaryProjectInputMatchesOutput(TypeSig.all, TypeSig.all),
        (overflowCheck, rapidsConf, parentMeta, replacementRule) =>
          new UnaryExprMeta[CheckOverflowInTableInsert](
              overflowCheck, rapidsConf, parentMeta, replacementRule) {
            // The GPU expression is built around a GpuCast child; any other child is an error.
            override def convertToGpu(child: Expression): GpuExpression = child match {
              case cast: GpuCast =>
                GpuCheckOverflowInTableInsert(cast, overflowCheck.columnName)
              case _ =>
                throw new IllegalStateException("Expression child is not of Type GpuCast")
            }
          }),
      GpuElementAtMeta.elementAtRule(true)
    ).map { rule =>
      // Key every rule by the expression class it replaces.
      val exprClass = rule.getClassFor.asSubclass(classOf[Expression])
      (exprClass, rule)
    }.toMap

    super.getExprs ++ localRules ++ DayTimeIntervalShims.exprs ++ RoundingShims.exprs
  }

  /** The inherited exec rules extended with `PythonMapInArrowExecShims.execs`. */
  override def getExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = {
    super.getExecs ++ PythonMapInArrowExecShims.execs
  }

  /** Always `false` for these shims. */
  override def reproduceEmptyStringBug: Boolean = false

  /** True when the origin of `shuffle` is `EXECUTOR_BROADCAST`. */
  override def isExecutorBroadcastShuffle(shuffle: ShuffleExchangeLike): Boolean =
    shuffle.shuffleOrigin.equals(EXECUTOR_BROADCAST)

  /**
   * True only when `parent` is a GPU broadcast hash join or a GPU broadcast nested loop
   * join and the origin of `shuffle` is `EXECUTOR_BROADCAST`; false for any other parent.
   */
  override def shuffleParentReadsShuffleData(
      shuffle: ShuffleExchangeLike,
      parent: SparkPlan): Boolean = parent match {
    case _: GpuBroadcastHashJoinExec | _: GpuBroadcastNestedLoopJoinExec =>
      shuffle.shuffleOrigin.equals(EXECUTOR_BROADCAST)
    case _ =>
      false
  }

  /**
   * Looks at the plan returned by `GpuTransitionOverrides.getNonQueryStagePlan(sqse)`.
   * If it is a shuffle exchange whose origin is `EXECUTOR_BROADCAST`, `c2r` is wrapped in
   * a new `ShuffleExchangeExec` using `SinglePartition` and the `EXECUTOR_BROADCAST`
   * origin; otherwise `c2r` is returned as is.
   */
  override def addRowShuffleToQueryStageTransitionIfNeeded(
      c2r: ColumnarToRowTransition,
      sqse: ShuffleQueryStageExec): SparkPlan = {
    GpuTransitionOverrides.getNonQueryStagePlan(sqse) match {
      case exchange: ShuffleExchangeLike
          if exchange.shuffleOrigin.equals(EXECUTOR_BROADCAST) =>
        ShuffleExchangeExec(SinglePartition, c2r, EXECUTOR_BROADCAST)
      case _ =>
        c2r
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,13 @@ package com.nvidia.spark.rapids.shims

import com.nvidia.spark.rapids._

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan}
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, DataWritingCommand, RunnableCommand}
import org.apache.spark.sql.execution.exchange.{EXECUTOR_BROADCAST, ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.rapids.GpuElementAtMeta
import org.apache.spark.sql.rapids.execution.{GpuBroadcastHashJoinExec, GpuBroadcastNestedLoopJoinExec}

object SparkShimImpl extends Spark321PlusDBShims {
object SparkShimImpl extends Spark330PlusDBShims {
// AnsiCast is removed from Spark3.4.0
override def ansiCastRule: ExprRule[_ <: Expression] = null

override def getExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
val elementAtExpr: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq(
GpuElementAtMeta.elementAtRule(true)
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
super.getExprs ++ DayTimeIntervalShims.exprs ++ RoundingShims.exprs ++ elementAtExpr
}

override def getExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] =
super.getExecs ++ PythonMapInArrowExecShims.execs

override def getDataWriteCmds: Map[Class[_ <: DataWritingCommand],
DataWritingCommandRule[_ <: DataWritingCommand]] = {
Seq(GpuOverrides.dataWriteCmd[CreateDataSourceTableAsSelectCommand](
Expand All @@ -56,32 +40,4 @@ object SparkShimImpl extends Spark321PlusDBShims {
RunnableCommandRule[_ <: RunnableCommand]] = {
Map.empty
}

override def reproduceEmptyStringBug: Boolean = false

override def isExecutorBroadcastShuffle(shuffle: ShuffleExchangeLike): Boolean = {
shuffle.shuffleOrigin.equals(EXECUTOR_BROADCAST)
}

override def shuffleParentReadsShuffleData(shuffle: ShuffleExchangeLike,
parent: SparkPlan): Boolean = {
parent match {
case _: GpuBroadcastHashJoinExec =>
shuffle.shuffleOrigin.equals(EXECUTOR_BROADCAST)
case _: GpuBroadcastNestedLoopJoinExec =>
shuffle.shuffleOrigin.equals(EXECUTOR_BROADCAST)
case _ => false
}
}

override def addRowShuffleToQueryStageTransitionIfNeeded(c2r: ColumnarToRowTransition,
sqse: ShuffleQueryStageExec): SparkPlan = {
val plan = GpuTransitionOverrides.getNonQueryStagePlan(sqse)
plan match {
case shuffle: ShuffleExchangeLike if shuffle.shuffleOrigin.equals(EXECUTOR_BROADCAST) =>
ShuffleExchangeExec(SinglePartition, c2r, EXECUTOR_BROADCAST)
case _ =>
c2r
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@
*/

/*** spark-rapids-shim-json-lines
{"spark": "330db"}
{"spark": "331"}
{"spark": "332"}
{"spark": "332cdh"}
{"spark": "332db"}
{"spark": "333"}
{"spark": "340"}
{"spark": "341"}
{"spark": "341db"}
{"spark": "350"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import com.nvidia.spark.rapids.{ExprChecks, ExprRule, GpuCast, GpuExpression, Gp
import org.apache.spark.sql.catalyst.expressions.{CheckOverflowInTableInsert, Expression}
import org.apache.spark.sql.rapids.GpuCheckOverflowInTableInsert

trait Spark331PlusShims extends Spark330PlusNonDBShims {
trait Spark331PlusNonDBShims extends Spark330PlusNonDBShims {
override def getExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
val map: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq(
// Add expression CheckOverflowInTableInsert starting Spark-3.3.1+
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import com.nvidia.spark.rapids._

import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, DataWritingCommand, RunnableCommand}

object SparkShimImpl extends Spark331PlusShims with AnsiCastRuleShims {
object SparkShimImpl extends Spark331PlusNonDBShims with AnsiCastRuleShims {
override def getDataWriteCmds: Map[Class[_ <: DataWritingCommand],
DataWritingCommandRule[_ <: DataWritingCommand]] = {
Seq(GpuOverrides.dataWriteCmd[CreateDataSourceTableAsSelectCommand](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

object SparkShimImpl extends Spark33cdhShims with Spark331PlusShims {}
object SparkShimImpl extends Spark33cdhShims with Spark331PlusNonDBShims {}
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,11 @@ package com.nvidia.spark.rapids.shims
import com.nvidia.spark.rapids._

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan}
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, DataWritingCommand, RunnableCommand}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.exchange.{EXECUTOR_BROADCAST, ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.rapids.GpuElementAtMeta
import org.apache.spark.sql.rapids.execution.{GpuBroadcastHashJoinExec, GpuBroadcastNestedLoopJoinExec}

trait Spark332PlusDBShims extends Spark321PlusDBShims {
trait Spark332PlusDBShims extends Spark330PlusDBShims {
// AnsiCast is removed from Spark3.4.0
override def ansiCastRule: ExprRule[_ <: Expression] = null

Expand All @@ -45,10 +40,9 @@ trait Spark332PlusDBShims extends Spark321PlusDBShims {
(a, conf, p, r) => new UnaryExprMeta[KnownNullable](a, conf, p, r) {
override def convertToGpu(child: Expression): GpuExpression = GpuKnownNullable(child)
}
),
GpuElementAtMeta.elementAtRule(true)
)
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
super.getExprs ++ shimExprs ++ DayTimeIntervalShims.exprs ++ RoundingShims.exprs
super.getExprs ++ shimExprs
}

private val shimExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = Seq(
Expand All @@ -63,7 +57,7 @@ trait Spark332PlusDBShims extends Spark321PlusDBShims {
).map(r => (r.getClassFor.asSubclass(classOf[SparkPlan]), r)).toMap

override def getExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] =
super.getExecs ++ shimExecs ++ PythonMapInArrowExecShims.execs
super.getExecs ++ shimExecs

override def getDataWriteCmds: Map[Class[_ <: DataWritingCommand],
DataWritingCommandRule[_ <: DataWritingCommand]] = {
Expand All @@ -78,32 +72,4 @@ trait Spark332PlusDBShims extends Spark321PlusDBShims {
(a, conf, p, r) => new CreateDataSourceTableAsSelectCommandMeta(a, conf, p, r))
).map(r => (r.getClassFor.asSubclass(classOf[RunnableCommand]), r)).toMap
}

override def reproduceEmptyStringBug: Boolean = false

override def isExecutorBroadcastShuffle(shuffle: ShuffleExchangeLike): Boolean = {
shuffle.shuffleOrigin.equals(EXECUTOR_BROADCAST)
}

override def shuffleParentReadsShuffleData(shuffle: ShuffleExchangeLike,
parent: SparkPlan): Boolean = {
parent match {
case _: GpuBroadcastHashJoinExec =>
shuffle.shuffleOrigin.equals(EXECUTOR_BROADCAST)
case _: GpuBroadcastNestedLoopJoinExec =>
shuffle.shuffleOrigin.equals(EXECUTOR_BROADCAST)
case _ => false
}
}

override def addRowShuffleToQueryStageTransitionIfNeeded(c2r: ColumnarToRowTransition,
sqse: ShuffleQueryStageExec): SparkPlan = {
val plan = GpuTransitionOverrides.getNonQueryStagePlan(sqse)
plan match {
case shuffle: ShuffleExchangeLike if shuffle.shuffleOrigin.equals(EXECUTOR_BROADCAST) =>
ShuffleExchangeExec(SinglePartition, c2r, EXECUTOR_BROADCAST)
case _ =>
c2r
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.exchange.ENSURE_REQUIREMENTS
import org.apache.spark.sql.rapids.GpuElementAtMeta
import org.apache.spark.sql.rapids.GpuV1WriteUtils.GpuEmpty2Null

trait Spark340PlusShims extends Spark331PlusShims {
trait Spark340PlusNonDBShims extends Spark331PlusNonDBShims {

private val shimExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = Seq(
GpuOverrides.exec[GlobalLimitExec](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

object SparkShimImpl extends Spark340PlusShims
object SparkShimImpl extends Spark340PlusNonDBShims
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, PythonUDAF, ToPret
import org.apache.spark.sql.rapids.execution.python.GpuPythonUDAF
import org.apache.spark.sql.types.StringType

object SparkShimImpl extends Spark340PlusShims {
object SparkShimImpl extends Spark340PlusNonDBShims {

override def getExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
val shimExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq(
Expand Down

0 comments on commit 908e986

Please sign in to comment.