diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py
index 5631f13f13d..6ca0e1a1967 100644
--- a/integration_tests/src/main/python/string_test.py
+++ b/integration_tests/src/main/python/string_test.py
@@ -820,6 +820,22 @@ def test_conv_dec_to_from_hex(from_base, to_base, pattern):
         conf={'spark.rapids.sql.expression.Conv': True}
     )
 
+@pytest.mark.parametrize('from_base,to_base,expected_err_msg_prefix',
+                         [
+                             pytest.param(10, 15, '15 is not a supported target radix', id='to_base_unsupported'),
+                             pytest.param(11, 16, '11 is not a supported source radix', id='from_base_unsupported'),
+                             pytest.param(9, 17, 'both 9 and 17 are not a supported radix', id='both_base_unsupported')
+                         ])
+def test_conv_unsupported_base(from_base, to_base, expected_err_msg_prefix):
+    def do_conv(spark):
+        gen = StringGen()
+        df = unary_op_df(spark, gen).select('a', f.conv(f.col('a'), from_base, to_base))
+        explain_str = spark.sparkContext._jvm.com.nvidia.spark.rapids.ExplainPlan.explainPotentialGpuPlan(df._jdf, "ALL")
+        unsupported_base_str = f'{expected_err_msg_prefix}, only literal 10 or 16 are supported for source and target radixes'
+        assert unsupported_base_str in explain_str
+
+    with_cpu_session(do_conv)
+
 format_number_gens = integral_gens + [DecimalGen(precision=7, scale=7), DecimalGen(precision=18, scale=0),
                                       DecimalGen(precision=18, scale=3), DecimalGen(precision=36, scale=5),
                                       DecimalGen(precision=36, scale=-5), DecimalGen(precision=38, scale=10),
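The new test drives the fallback path: it builds a plan with unsupported radixes and asserts on the driver-side explain output rather than executing anything on the GPU. For reference, Spark's conv converts the string representation of a number between radixes, and under this change only the literal bases 10 and 16 stay on the GPU. A minimal sketch of the supported behavior on a plain Spark session (illustrative only, not part of the patch):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    // Decimal-to-hex is one of the accelerated combinations:
    spark.sql("SELECT conv('255', 10, 16) AS hex").show()  // prints FF
    // conv('255', 10, 15) would still run, but on the CPU, with the
    // "not a supported target radix" reason surfaced in the explain output.
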
diff --git a/integration_tests/src/main/python/window_function_test.py b/integration_tests/src/main/python/window_function_test.py
index af8bbbb55b3..44bc2a07d57 100644
--- a/integration_tests/src/main/python/window_function_test.py
+++ b/integration_tests/src/main/python/window_function_test.py
@@ -165,6 +165,8 @@ def test_float_window_min_max_all_nans(data_gen):
             .withColumn("max_b", f.max('a').over(w))
     )
 
+
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 @ignore_order
 @pytest.mark.parametrize('data_gen', [decimal_gen_128bit], ids=idfn)
 def test_decimal128_count_window(data_gen):
@@ -177,6 +179,8 @@
         ' rows between 2 preceding and 10 following) as count_c_asc '
         'from window_agg_table')
 
+
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 @ignore_order
 @pytest.mark.parametrize('data_gen', [decimal_gen_128bit], ids=idfn)
 def test_decimal128_count_window_no_part(data_gen):
@@ -189,6 +193,8 @@
         ' rows between 2 preceding and 10 following) as count_b_asc '
         'from window_agg_table')
 
+
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 @ignore_order
 @pytest.mark.parametrize('data_gen', decimal_gens, ids=idfn)
 def test_decimal_sum_window(data_gen):
@@ -201,6 +207,8 @@
         ' rows between 2 preceding and 10 following) as sum_c_asc '
         'from window_agg_table')
 
+
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 @ignore_order
 @pytest.mark.parametrize('data_gen', decimal_gens, ids=idfn)
 def test_decimal_sum_window_no_part(data_gen):
@@ -214,6 +222,7 @@
         'from window_agg_table')
 
 
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 @ignore_order
 @pytest.mark.parametrize('data_gen', decimal_gens, ids=idfn)
 def test_decimal_running_sum_window(data_gen):
@@ -227,6 +236,8 @@
         'from window_agg_table',
         conf = {'spark.rapids.sql.batchSizeBytes': '100'})
 
+
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 @ignore_order
 @pytest.mark.parametrize('data_gen', decimal_gens, ids=idfn)
 def test_decimal_running_sum_window_no_part(data_gen):
@@ -302,6 +313,7 @@ def test_window_aggs_for_ranges_numeric_long_overflow(data_gen):
         'from window_agg_table')
 
 
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 # In a distributed setup the order of the partitions returned might be different, so we must ignore the order
 # but small batch sizes can make sort very slow, so do the final order by locally
 @ignore_order(local=True)
@@ -352,6 +364,7 @@ def test_window_aggs_for_range_numeric_date(data_gen, batch_size):
         conf = conf)
 
 
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 # In a distributed setup the order of the partitions returned might be different, so we must ignore the order
 # but small batch sizes can make sort very slow, so do the final order by locally
 @ignore_order(local=True)
@@ -396,6 +409,7 @@ def test_window_aggs_for_rows(data_gen, batch_size):
         conf = conf)
 
 
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 @ignore_order(local=True)
 @pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn)
 @pytest.mark.parametrize('data_gen', [
@@ -482,6 +496,8 @@ def test_window_batched_unbounded(b_gen, batch_size):
         validate_execs_in_gpu_plan = ['GpuCachedDoublePassWindowExec'],
         conf = conf)
 
+
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 # This is for aggregations that work with a running window optimization. They don't need to be batched
 # specially, but it only works if all of the aggregations can support this.
 # the order returned should be consistent because the data ends up in a single task (no partitioning)
@@ -520,6 +536,7 @@ def test_rows_based_running_window_unpartitioned(b_gen, batch_size):
         conf = conf)
 
 
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 @pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) # Testing multiple batch sizes.
 @pytest.mark.parametrize('a_gen', integral_gens + [string_gen, date_gen, timestamp_gen], ids=meta_idfn('data:'))
 @allow_non_gpu(*non_utc_allow)
@@ -694,6 +711,7 @@ def test_window_running_rank(data_gen):
         conf = conf)
 
 
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 # This is for aggregations that work with a running window optimization. They don't need to be batched
 # specially, but it only works if all of the aggregations can support this.
 # In a distributed setup the order of the partitions returned might be different, so we must ignore the order
@@ -738,6 +756,8 @@ def test_rows_based_running_window_partitioned(b_gen, c_gen, batch_size):
         conf = conf)
 
 
+
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 @ignore_order(local=True)
 @pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) # Test different batch sizes.
 @pytest.mark.parametrize('part_gen', [int_gen, long_gen], ids=idfn) # Partitioning is not really the focus of the test.
@@ -805,6 +825,7 @@ def must_test_sum_aggregation(gen):
         conf=conf)
 
 
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 # Test that we can do a running window sum on floats and doubles and decimal. This becomes problematic because we do the agg in parallel
 # which means that the result can switch back and forth from Inf to not Inf depending on the order of aggregations.
 # We test this by limiting the range of the values in the sum to never hit Inf, and by using abs so we don't have
@@ -836,6 +857,7 @@ def test_window_running_float_decimal_sum(batch_size):
         conf = conf)
 
 
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 @approximate_float
 @ignore_order(local=True)
 @pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) # Test different batch sizes.
@@ -879,6 +901,7 @@ def window(oby_column):
         conf=conf)
 
 
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 # In a distributed setup the order of the partitions returned might be different, so we must ignore the order
 # but small batch sizes can make sort very slow, so do the final order by locally
 @ignore_order(local=True)
@@ -1000,6 +1023,7 @@ def test_window_aggs_for_rows_lead_lag_on_arrays(a_gen, b_gen, c_gen, d_gen):
         ''')
 
 
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 # lead and lag don't currently work for string columns, so redo the tests, but just for strings
 # without lead and lag
 # In a distributed setup the order of the partitions returned might be different, so we must ignore the order
@@ -1107,6 +1131,8 @@ def test_window_aggs_lag_ignore_nulls_fallback(a_gen, b_gen, c_gen, d_gen):
         FROM window_agg_table
         ''')
 
+
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 # Test for RANGE queries, with timestamp order-by expressions.
 # In a distributed setup the order of the partitions returned might be different, so we must ignore the order
 # but small batch sizes can make sort very slow, so do the final order by locally
@@ -1155,6 +1181,7 @@ def test_window_aggs_for_ranges_timestamps(data_gen):
         conf = {'spark.rapids.sql.castFloatToDecimal.enabled': True})
 
 
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 # In a distributed setup the order of the partitions returned might be different, so we must ignore the order
 # but small batch sizes can make sort very slow, so do the final order by locally
 @ignore_order(local=True)
@@ -1201,6 +1228,7 @@ def test_window_aggregations_for_decimal_and_float_ranges(data_gen):
         conf={})
 
 
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 # In a distributed setup the order of the partitions returned might be different, so we must ignore the order
 # but small batch sizes can make sort very slow, so do the final order by locally
 @ignore_order(local=True)
@@ -1306,6 +1334,7 @@ def test_window_aggs_for_rows_collect_list():
         conf={'spark.rapids.sql.window.collectList.enabled': True})
 
 
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 # SortExec does not support array type, so sort the result locally.
 @ignore_order(local=True)
 # This test is more directed at Databricks and their running window optimization instead of ours
@@ -1347,6 +1376,8 @@ def test_running_window_function_exec_for_all_aggs():
         ''',
         conf={'spark.rapids.sql.window.collectList.enabled': True})
 
+
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 # Test the Databricks WindowExec which combines a WindowExec with a ProjectExec and provides the output
 # fields that we need to handle with an extra GpuProjectExec and we need the input expressions to compute
 # a window function of another window function case
@@ -1668,6 +1699,8 @@ def do_it(spark):
     assert_gpu_fallback_collect(do_it, 'WindowExec')
 
 
+
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 @ignore_order(local=True)
 # single-level structs (no nested structs) are now supported by the plugin
 @pytest.mark.parametrize('part_gen', [StructGen([["a", long_gen]])], ids=meta_idfn('partBy:'))
@@ -1731,6 +1764,8 @@ def do_it(spark):
     assert_gpu_and_cpu_are_equal_collect(do_it)
 
 
+
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 @ignore_order
 def test_unbounded_to_unbounded_window():
     # This is specifically to test a bug that caused overflow issues when calculating
@@ -1784,6 +1819,7 @@ def test_window_first_last_nth_ignore_nulls(data_gen):
         'FROM window_agg_table')
 
 
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 @tz_sensitive_test
 @allow_non_gpu(*non_supported_tz_allow)
 @ignore_order(local=True)
@@ -1825,6 +1861,7 @@ def test_to_date_with_window_functions():
     )
 
 
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 @ignore_order(local=True)
 @approximate_float
 @pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn)
@@ -1881,6 +1918,7 @@ def spark_bugs_in_decimal_sorting():
     return v < "3.1.4" or v < "3.3.1" or v < "3.2.3" or v < "3.4.0"
 
 
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 @ignore_order(local=True)
 @approximate_float
 @pytest.mark.parametrize('batch_size', ['1g'], ids=idfn)
@@ -1925,6 +1963,7 @@ def test_window_aggs_for_negative_rows_unpartitioned(data_gen, batch_size):
         conf=conf)
 
 
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 @ignore_order(local=True)
 @pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn)
 @pytest.mark.parametrize('data_gen', [
@@ -1964,6 +2003,7 @@ def test_window_aggs_for_batched_finite_row_windows_partitioned(data_gen, batch_size):
         conf=conf)
 
 
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 @ignore_order(local=True)
 @pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn)
 @pytest.mark.parametrize('data_gen', [
@@ -2003,6 +2043,7 @@ def test_window_aggs_for_batched_finite_row_windows_unpartitioned(data_gen, batch_size):
         conf=conf)
 
 
+@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114
 @ignore_order(local=True)
 @pytest.mark.parametrize('data_gen', [_grpkey_int_with_nulls,], ids=idfn)
 def test_window_aggs_for_batched_finite_row_windows_fallback(data_gen):
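All of the @disable_ansi_mode markers above pin these suites to non-ANSI semantics and point at the same tracking issue (#5114): the GPU window aggregations do not yet implement ANSI-mode error behavior. The difference is easiest to see with decimal overflow on a stock Spark session; a small sketch of the assumed CPU behavior (not plugin code):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    // Non-ANSI (what these tests force): decimal overflow quietly yields NULL.
    spark.conf.set("spark.sql.ansi.enabled", "false")
    spark.sql("SELECT CAST(9e37 AS DECIMAL(38,0)) * 10 AS v").show()  // v is NULL
    // ANSI: the same expression throws an ArithmeticException instead, which
    // the GPU versions of these window aggregations cannot yet mirror.
    spark.conf.set("spark.sql.ansi.enabled", "true")
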
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpandExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpandExec.scala
index 0fc7defd063..c0e20bfaebc 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpandExec.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpandExec.scala
@@ -20,6 +20,7 @@ import scala.util.Random
 
 import ai.rapids.cudf.NvtxColor
 import com.nvidia.spark.rapids.Arm.withResource
+import com.nvidia.spark.rapids.GpuExpressionsUtils.NullVecCache
 import com.nvidia.spark.rapids.GpuMetric._
 import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion
 import com.nvidia.spark.rapids.shims.ShimUnaryExecNode
@@ -54,7 +55,9 @@ class GpuExpandExecMeta(
     val projections = gpuProjections.map(_.map(_.convertToGpu()))
     GpuExpandExec(projections, expand.output, childPlans.head.convertIfNeeded())(
       useTieredProject = conf.isTieredProjectEnabled,
-      preprojectEnabled = conf.isExpandPreprojectEnabled)
+      preprojectEnabled = conf.isExpandPreprojectEnabled,
+      cacheNullMaxCount = conf.expandCachingNullVecMaxCount,
+      coalesceAfter = conf.isCoalesceAfterExpandEnabled)
   }
 }
@@ -72,11 +75,17 @@ case class GpuExpandExec(
     output: Seq[Attribute],
     child: SparkPlan)(
     useTieredProject: Boolean = false,
-    preprojectEnabled: Boolean = false) extends ShimUnaryExecNode with GpuExec {
+    preprojectEnabled: Boolean = false,
+    cacheNullMaxCount: Int = 0,
+    override val coalesceAfter: Boolean = true
+) extends ShimUnaryExecNode with GpuExec {
 
   override def otherCopyArgs: Seq[AnyRef] = Seq[AnyRef](
     useTieredProject.asInstanceOf[java.lang.Boolean],
-    preprojectEnabled.asInstanceOf[java.lang.Boolean])
+    preprojectEnabled.asInstanceOf[java.lang.Boolean],
+    cacheNullMaxCount.asInstanceOf[java.lang.Integer],
+    coalesceAfter.asInstanceOf[java.lang.Boolean]
+  )
 
   private val PRE_PROJECT_TIME = "preprojectTime"
   override val outputRowsLevel: MetricsLevel = ESSENTIAL_LEVEL
@@ -127,7 +136,7 @@ case class GpuExpandExec(
     }
 
     child.executeColumnar().mapPartitions { it =>
-      new GpuExpandIterator(boundProjections, metricsMap, preprojectIter(it))
+      new GpuExpandIterator(boundProjections, metricsMap, preprojectIter(it), cacheNullMaxCount)
     }
   }
@@ -191,7 +200,8 @@ case class GpuExpandExec(
 class GpuExpandIterator(
     boundProjections: Seq[GpuTieredProject],
     metrics: Map[String, GpuMetric],
-    it: Iterator[ColumnarBatch])
+    it: Iterator[ColumnarBatch],
+    cacheNullMaxCount: Int)
   extends Iterator[ColumnarBatch] {
 
   private var sb: Option[SpillableColumnarBatch] = None
@@ -206,9 +216,20 @@ class GpuExpandIterator(
   Option(TaskContext.get()).foreach { tc =>
     onTaskCompletion(tc) {
       sb.foreach(_.close())
+
+      if (cacheNullMaxCount > 0) {
+        import scala.collection.JavaConverters._
+        GpuExpressionsUtils.cachedNullVectors.get().values().asScala.foreach(_.close())
+        GpuExpressionsUtils.cachedNullVectors.get().clear()
+      }
     }
   }
 
+  if (cacheNullMaxCount > 0 && GpuExpressionsUtils.cachedNullVectors.get() == null) {
+    GpuExpressionsUtils.cachedNullVectors.set(new NullVecCache(cacheNullMaxCount))
+  }
+
+
   override def hasNext: Boolean = sb.isDefined || it.hasNext
 
   override def next(): ColumnarBatch = {
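GpuExpandIterator installs the cache lazily on the task's thread and registers a completion callback that closes every cached vector, so device memory cannot outlive the task. A generic sketch of that thread-local-with-cleanup pattern (hypothetical Resource type and registerCleanup hook, simplified from the code above):

    import scala.collection.mutable

    trait Resource { def close(): Unit }

    // One cache per task thread, mirroring GpuExpressionsUtils.cachedNullVectors.
    val cache = new ThreadLocal[mutable.Map[String, Resource]]()

    def installForTask(registerCleanup: (() => Unit) => Unit): Unit = {
      if (cache.get() == null) cache.set(mutable.Map.empty[String, Resource])
      registerCleanup { () =>
        cache.get().values.foreach(_.close())  // release device memory
        cache.get().clear()
      }
    }
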
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpressions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpressions.scala
index c46862ab2aa..2400b364b5a 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpressions.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpressions.scala
@@ -21,11 +21,12 @@ import com.nvidia.spark.Retryable
 import com.nvidia.spark.rapids.Arm.{withResource, withResourceIfAllowed}
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
 import com.nvidia.spark.rapids.shims.{ShimBinaryExpression, ShimExpression, ShimTernaryExpression, ShimUnaryExpression}
+import java.util
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.types.{DataType, StringType}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.unsafe.types.UTF8String
@@ -52,6 +53,40 @@ object GpuExpressionsUtils {
       "implemented and should have been disabled")
   }
 
+  // This is only for ExpandExec, which will generate a lot of null vectors
+  case class NullVecKey(d: DataType, n: Int)
+
+  class NullVecCache(private val maxNulls: Int)
+    extends util.LinkedHashMap[NullVecKey, GpuColumnVector](100, 0.75f, true) {
+    private var totalNulls: Long = 0L
+
+    override def clear(): Unit = {
+      super.clear()
+      totalNulls = 0
+    }
+
+    override def put(key: NullVecKey, v: GpuColumnVector): GpuColumnVector = {
+      if (v.getRowCount > maxNulls) {
+        throw new IllegalStateException(s"spark.rapids.sql.expandCachingNullVec.maxNulls " +
+          s"($maxNulls) is set too small to hold a single vector with ${v.getRowCount} rows.")
+      }
+      val iter = entrySet().iterator()
+      while (iter.hasNext && totalNulls > maxNulls - v.getRowCount) {
+        val entry = iter.next()
+        iter.remove()
+        totalNulls -= entry.getValue.getRowCount
+      }
+
+      val ret = super.put(key, v)
+      totalNulls += v.getRowCount
+      ret
+    }
+
+    override def remove(key: Any): GpuColumnVector = throw new UnsupportedOperationException()
+  }
+
+  val cachedNullVectors = new ThreadLocal[NullVecCache]()
+
   /**
    * Tries to resolve a `GpuColumnVector` from a Scala `Any`.
    *
@@ -73,7 +108,18 @@
   def resolveColumnVector(any: Any, numRows: Int): GpuColumnVector = {
     withResourceIfAllowed(any) {
       case c: GpuColumnVector => c.incRefCount()
-      case s: GpuScalar => GpuColumnVector.from(s, numRows, s.dataType)
+      case s: GpuScalar =>
+        if (!s.isValid && cachedNullVectors.get() != null) {
+          if (!cachedNullVectors.get().containsKey(NullVecKey(s.dataType, numRows))) {
+            cachedNullVectors.get().put(NullVecKey(s.dataType, numRows),
+              GpuColumnVector.from(s, numRows, s.dataType))
+          }
+
+          val ret = cachedNullVectors.get().get(NullVecKey(s.dataType, numRows))
+          ret.incRefCount()
+        } else {
+          GpuColumnVector.from(s, numRows, s.dataType)
+        }
       case other => throw new IllegalArgumentException(s"Cannot resolve a ColumnVector from the value:" +
         s" $other. Please convert it to a GpuScalar or a GpuColumnVector before returning.")
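NullVecCache gets its LRU behavior from java.util.LinkedHashMap: the third constructor argument (accessOrder = true) makes iteration run from least to most recently accessed, so the eviction loop in put always drops the coldest vectors until the incoming one fits under maxNulls. A standalone sketch of that JDK property (not plugin code):

    import java.util.LinkedHashMap

    // Same constructor shape as NullVecCache: capacity 100, load factor 0.75,
    // accessOrder = true (false, the default, would keep insertion order).
    val m = new LinkedHashMap[String, Integer](100, 0.75f, true)
    m.put("a", 1); m.put("b", 2); m.put("c", 3)
    m.get("a")  // touching "a" moves it to the most-recently-used end
    val it = m.entrySet().iterator()
    println(it.next().getKey)  // prints "b": the coldest entry comes first
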
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala
index 48f9de5a61a..eef083bb93d 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala
@@ -360,6 +360,7 @@ class GpuTransitionOverrides extends Rule[SparkPlan] {
       case _: GpuDataSourceScanExec => true
       case _: DataSourceV2ScanExecBase => true
       case _: RDDScanExec => true // just in case an RDD was reading in data
+      case _: ExpandExec => true
       case _ => false
     }
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
index 5203e926efa..826d398dbc9 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -1219,6 +1219,20 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern")
     .booleanConf
     .createWithDefault(true)
 
+  val ENABLE_COALESCE_AFTER_EXPAND = conf("spark.rapids.sql.coalesceAfterExpand.enabled")
+    .doc("When set to false, disables coalescing batches after GPU Expand.")
+    .internal()
+    .booleanConf
+    .createWithDefault(false)
+
+  val EXPAND_CACHING_NULL_VEC_MAX_NULL_COUNT =
+    conf("spark.rapids.sql.expandCachingNullVec.maxNulls")
+      .doc("Max number of null scalars in null vectors to cache for GPU Expand. " +
+        "If the number of null scalars exceeds this value, the null vectors will not be " +
+        "cached. The value has to be positive for caching to be enabled.")
+      .internal().integerConf
+      .createWithDefault(0)
+
 val ENABLE_ORC_FLOAT_TYPES_TO_STRING =
   conf("spark.rapids.sql.format.orc.floatTypesToString.enable")
     .doc("When reading an ORC file, the source data schemas(schemas of ORC file) may differ " +
@@ -2762,6 +2776,10 @@ class RapidsConf(conf: Map[String, String]) extends Logging {
 
   lazy val isExpandPreprojectEnabled: Boolean = get(ENABLE_EXPAND_PREPROJECT)
 
+  lazy val isCoalesceAfterExpandEnabled: Boolean = get(ENABLE_COALESCE_AFTER_EXPAND)
+
+  lazy val expandCachingNullVecMaxCount: Int = get(EXPAND_CACHING_NULL_VEC_MAX_NULL_COUNT)
+
   lazy val multiThreadReadNumThreads: Int = {
     // Use the largest value set among all the options.
     val deprecatedConfs = Seq(
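Both knobs land as internal settings with conservative defaults: coalescing after Expand is off (false), and null-vector caching stays disabled until maxNulls is set to a positive value. A sketch of turning them on for experimentation, assuming a session with the plugin installed:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
      // Re-enable coalescing of the batches Expand emits.
      .config("spark.rapids.sql.coalesceAfterExpand.enabled", "true")
      // Cache null vectors holding up to one million nulls per task thread.
      .config("spark.rapids.sql.expandCachingNullVec.maxNulls", "1000000")
      .getOrCreate()
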
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
index dc2845e4461..a435988686d 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
@@ -2084,11 +2084,25 @@ class GpuConvMeta(
   override def tagExprForGpu(): Unit = {
     val fromBaseLit = GpuOverrides.extractLit(expr.fromBaseExpr)
     val toBaseLit = GpuOverrides.extractLit(expr.toBaseExpr)
+    val errorPostfix = "only literal 10 or 16 are supported for source and target radixes"
     (fromBaseLit, toBaseLit) match {
-      case (Some(Literal(fromBaseVal, IntegerType)), Some(Literal(toBaseVal, IntegerType)))
-        if Set(fromBaseVal, toBaseVal).subsetOf(Set(10, 16)) => ()
+      case (Some(Literal(fromBaseVal, IntegerType)), Some(Literal(toBaseVal, IntegerType))) =>
+        def isBaseSupported(base: Any): Boolean = base == 10 || base == 16
+        if (!isBaseSupported(fromBaseVal) && !isBaseSupported(toBaseVal)) {
+          willNotWorkOnGpu(because = s"both ${fromBaseVal} and ${toBaseVal} are not " +
+            s"a supported radix, ${errorPostfix}")
+        } else if (!isBaseSupported(fromBaseVal)) {
+          willNotWorkOnGpu(because = s"${fromBaseVal} is not a supported source radix, " +
+            s"${errorPostfix}")
+        } else if (!isBaseSupported(toBaseVal)) {
+          willNotWorkOnGpu(because = s"${toBaseVal} is not a supported target radix, " +
+            s"${errorPostfix}")
+        }
       case _ =>
-        willNotWorkOnGpu(because = "only literal 10 or 16 for from_base and to_base are supported")
+        // This will never happen in production, since the function signature enforces
+        // integer types for the bases, but it is good to handle the edge case.
+        willNotWorkOnGpu(because = "either source radix or target radix is not an integer " +
+          "literal, " + errorPostfix)
     }
   }
diff --git a/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/spark330/GpuBucketingUtils.scala b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/spark330/GpuBucketingUtils.scala
index feb562fa9b8..0f7c9b4fd62 100644
--- a/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/spark330/GpuBucketingUtils.scala
+++ b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/spark330/GpuBucketingUtils.scala
@@ -31,6 +31,7 @@
 {"spark": "343"}
 {"spark": "350"}
 {"spark": "351"}
+{"spark": "400"}
 spark-rapids-shim-json-lines ***/
 package com.nvidia.spark.rapids.shims