From 4415f61d4e2f1b7d4da71cc877ddd6a44721b5bb Mon Sep 17 00:00:00 2001 From: Allen Xu Date: Wed, 26 Jun 2024 14:09:58 +0800 Subject: [PATCH] optimzing Expand+Aggregate in sqlw with many count distinct (#22) * optimzing Expand+Aggregate in sqlw with many count distinct Signed-off-by: Hongbin Ma (Mahone) * Add GpuBucketingUtils shim to Spark 4.0.0 (#11092) * Add GpuBucketingUtils shim to Spark 4.0.0 * Signing off Signed-off-by: Raza Jafri --------- Signed-off-by: Raza Jafri * Improve the diagnostics for 'conv' fallback explain (#11076) * Improve the diagnostics for 'conv' fallback explain Signed-off-by: Jihoon Son * don't use nil Signed-off-by: Jihoon Son * the bases should not be an empty string in the error message when the user input is not Signed-off-by: Jihoon Son * more user-friendly message * Update sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala Co-authored-by: Gera Shegalov --------- Signed-off-by: Jihoon Son Co-authored-by: Gera Shegalov * Disable ANSI mode for window function tests [databricks] (#11073) * Disable ANSI mode for window function tests. Fixes #11019. Window function tests fail on Spark 4.0 because of #5114 (and #5120 broadly), because spark-rapids does not support SUM, COUNT, and certain other aggregations in ANSI mode. This commit disables ANSI mode tests for the failing window function tests. These may be revisited, once error/overflow checking is available for ANSI mode in spark-rapids. Signed-off-by: MithunR * Switch from @ansi_mode_disabled to @disable_ansi_mode. --------- Signed-off-by: MithunR --------- Signed-off-by: Hongbin Ma (Mahone) Signed-off-by: Raza Jafri Signed-off-by: Jihoon Son Signed-off-by: MithunR Co-authored-by: Hongbin Ma (Mahone) Co-authored-by: Raza Jafri Co-authored-by: Jihoon Son Co-authored-by: Gera Shegalov Co-authored-by: MithunR --- .../src/main/python/string_test.py | 16 ++++++ .../src/main/python/window_function_test.py | 41 +++++++++++++++ .../nvidia/spark/rapids/GpuExpandExec.scala | 31 ++++++++++-- .../nvidia/spark/rapids/GpuExpressions.scala | 50 ++++++++++++++++++- .../spark/rapids/GpuTransitionOverrides.scala | 1 + .../com/nvidia/spark/rapids/RapidsConf.scala | 18 +++++++ .../spark/sql/rapids/stringFunctions.scala | 20 ++++++-- .../shims/spark330/GpuBucketingUtils.scala | 1 + 8 files changed, 168 insertions(+), 10 deletions(-) diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py index 5631f13f13d..6ca0e1a1967 100644 --- a/integration_tests/src/main/python/string_test.py +++ b/integration_tests/src/main/python/string_test.py @@ -820,6 +820,22 @@ def test_conv_dec_to_from_hex(from_base, to_base, pattern): conf={'spark.rapids.sql.expression.Conv': True} ) +@pytest.mark.parametrize('from_base,to_base,expected_err_msg_prefix', + [ + pytest.param(10, 15, '15 is not a supported target radix', id='to_base_unsupported'), + pytest.param(11, 16, '11 is not a supported source radix', id='from_base_unsupported'), + pytest.param(9, 17, 'both 9 and 17 are not a supported radix', id='both_base_unsupported') + ]) +def test_conv_unsupported_base(from_base, to_base, expected_err_msg_prefix): + def do_conv(spark): + gen = StringGen() + df = unary_op_df(spark, gen).select('a', f.conv(f.col('a'), from_base, to_base)) + explain_str = spark.sparkContext._jvm.com.nvidia.spark.rapids.ExplainPlan.explainPotentialGpuPlan(df._jdf, "ALL") + unsupported_base_str = f'{expected_err_msg_prefix}, only literal 10 or 16 are supported for source and target radixes' + assert unsupported_base_str in explain_str + + with_cpu_session(do_conv) + format_number_gens = integral_gens + [DecimalGen(precision=7, scale=7), DecimalGen(precision=18, scale=0), DecimalGen(precision=18, scale=3), DecimalGen(precision=36, scale=5), DecimalGen(precision=36, scale=-5), DecimalGen(precision=38, scale=10), diff --git a/integration_tests/src/main/python/window_function_test.py b/integration_tests/src/main/python/window_function_test.py index af8bbbb55b3..44bc2a07d57 100644 --- a/integration_tests/src/main/python/window_function_test.py +++ b/integration_tests/src/main/python/window_function_test.py @@ -165,6 +165,8 @@ def test_float_window_min_max_all_nans(data_gen): .withColumn("max_b", f.max('a').over(w)) ) + +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 @ignore_order @pytest.mark.parametrize('data_gen', [decimal_gen_128bit], ids=idfn) def test_decimal128_count_window(data_gen): @@ -177,6 +179,8 @@ def test_decimal128_count_window(data_gen): ' rows between 2 preceding and 10 following) as count_c_asc ' 'from window_agg_table') + +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 @ignore_order @pytest.mark.parametrize('data_gen', [decimal_gen_128bit], ids=idfn) def test_decimal128_count_window_no_part(data_gen): @@ -189,6 +193,8 @@ def test_decimal128_count_window_no_part(data_gen): ' rows between 2 preceding and 10 following) as count_b_asc ' 'from window_agg_table') + +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 @ignore_order @pytest.mark.parametrize('data_gen', decimal_gens, ids=idfn) def test_decimal_sum_window(data_gen): @@ -201,6 +207,8 @@ def test_decimal_sum_window(data_gen): ' rows between 2 preceding and 10 following) as sum_c_asc ' 'from window_agg_table') + +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 @ignore_order @pytest.mark.parametrize('data_gen', decimal_gens, ids=idfn) def test_decimal_sum_window_no_part(data_gen): @@ -214,6 +222,7 @@ def test_decimal_sum_window_no_part(data_gen): 'from window_agg_table') +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 @ignore_order @pytest.mark.parametrize('data_gen', decimal_gens, ids=idfn) def test_decimal_running_sum_window(data_gen): @@ -227,6 +236,8 @@ def test_decimal_running_sum_window(data_gen): 'from window_agg_table', conf = {'spark.rapids.sql.batchSizeBytes': '100'}) + +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 @ignore_order @pytest.mark.parametrize('data_gen', decimal_gens, ids=idfn) def test_decimal_running_sum_window_no_part(data_gen): @@ -302,6 +313,7 @@ def test_window_aggs_for_ranges_numeric_long_overflow(data_gen): 'from window_agg_table') +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 # In a distributed setup the order of the partitions returned might be different, so we must ignore the order # but small batch sizes can make sort very slow, so do the final order by locally @ignore_order(local=True) @@ -352,6 +364,7 @@ def test_window_aggs_for_range_numeric_date(data_gen, batch_size): conf = conf) +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 # In a distributed setup the order of the partitions returned might be different, so we must ignore the order # but small batch sizes can make sort very slow, so do the final order by locally @ignore_order(local=True) @@ -396,6 +409,7 @@ def test_window_aggs_for_rows(data_gen, batch_size): conf = conf) +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 @ignore_order(local=True) @pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) @pytest.mark.parametrize('data_gen', [ @@ -482,6 +496,8 @@ def test_window_batched_unbounded(b_gen, batch_size): validate_execs_in_gpu_plan = ['GpuCachedDoublePassWindowExec'], conf = conf) + +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 # This is for aggregations that work with a running window optimization. They don't need to be batched # specially, but it only works if all of the aggregations can support this. # the order returned should be consistent because the data ends up in a single task (no partitioning) @@ -520,6 +536,7 @@ def test_rows_based_running_window_unpartitioned(b_gen, batch_size): conf = conf) +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 @pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) # Testing multiple batch sizes. @pytest.mark.parametrize('a_gen', integral_gens + [string_gen, date_gen, timestamp_gen], ids=meta_idfn('data:')) @allow_non_gpu(*non_utc_allow) @@ -694,6 +711,7 @@ def test_window_running_rank(data_gen): conf = conf) +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 # This is for aggregations that work with a running window optimization. They don't need to be batched # specially, but it only works if all of the aggregations can support this. # In a distributed setup the order of the partitions returned might be different, so we must ignore the order @@ -738,6 +756,8 @@ def test_rows_based_running_window_partitioned(b_gen, c_gen, batch_size): conf = conf) + +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 @ignore_order(local=True) @pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) # Test different batch sizes. @pytest.mark.parametrize('part_gen', [int_gen, long_gen], ids=idfn) # Partitioning is not really the focus of the test. @@ -805,6 +825,7 @@ def must_test_sum_aggregation(gen): conf=conf) +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 # Test that we can do a running window sum on floats and doubles and decimal. This becomes problematic because we do the agg in parallel # which means that the result can switch back and forth from Inf to not Inf depending on the order of aggregations. # We test this by limiting the range of the values in the sum to never hit Inf, and by using abs so we don't have @@ -836,6 +857,7 @@ def test_window_running_float_decimal_sum(batch_size): conf = conf) +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 @approximate_float @ignore_order(local=True) @pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) # Test different batch sizes. @@ -879,6 +901,7 @@ def window(oby_column): conf=conf) +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 # In a distributed setup the order of the partitions returned might be different, so we must ignore the order # but small batch sizes can make sort very slow, so do the final order by locally @ignore_order(local=True) @@ -1000,6 +1023,7 @@ def test_window_aggs_for_rows_lead_lag_on_arrays(a_gen, b_gen, c_gen, d_gen): ''') +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 # lead and lag don't currently work for string columns, so redo the tests, but just for strings # without lead and lag # In a distributed setup the order of the partitions returned might be different, so we must ignore the order @@ -1107,6 +1131,8 @@ def test_window_aggs_lag_ignore_nulls_fallback(a_gen, b_gen, c_gen, d_gen): FROM window_agg_table ''') + +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 # Test for RANGE queries, with timestamp order-by expressions. # In a distributed setup the order of the partitions returned might be different, so we must ignore the order # but small batch sizes can make sort very slow, so do the final order by locally @@ -1155,6 +1181,7 @@ def test_window_aggs_for_ranges_timestamps(data_gen): conf = {'spark.rapids.sql.castFloatToDecimal.enabled': True}) +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 # In a distributed setup the order of the partitions returned might be different, so we must ignore the order # but small batch sizes can make sort very slow, so do the final order by locally @ignore_order(local=True) @@ -1201,6 +1228,7 @@ def test_window_aggregations_for_decimal_and_float_ranges(data_gen): conf={}) +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 # In a distributed setup the order of the partitions returned might be different, so we must ignore the order # but small batch sizes can make sort very slow, so do the final order by locally @ignore_order(local=True) @@ -1306,6 +1334,7 @@ def test_window_aggs_for_rows_collect_list(): conf={'spark.rapids.sql.window.collectList.enabled': True}) +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 # SortExec does not support array type, so sort the result locally. @ignore_order(local=True) # This test is more directed at Databricks and their running window optimization instead of ours @@ -1347,6 +1376,8 @@ def test_running_window_function_exec_for_all_aggs(): ''', conf={'spark.rapids.sql.window.collectList.enabled': True}) + +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 # Test the Databricks WindowExec which combines a WindowExec with a ProjectExec and provides the output # fields that we need to handle with an extra GpuProjectExec and we need the input expressions to compute # a window function of another window function case @@ -1668,6 +1699,8 @@ def do_it(spark): assert_gpu_fallback_collect(do_it, 'WindowExec') + +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 @ignore_order(local=True) # single-level structs (no nested structs) are now supported by the plugin @pytest.mark.parametrize('part_gen', [StructGen([["a", long_gen]])], ids=meta_idfn('partBy:')) @@ -1731,6 +1764,8 @@ def do_it(spark): assert_gpu_and_cpu_are_equal_collect(do_it) + +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 @ignore_order def test_unbounded_to_unbounded_window(): # This is specifically to test a bug that caused overflow issues when calculating @@ -1784,6 +1819,7 @@ def test_window_first_last_nth_ignore_nulls(data_gen): 'FROM window_agg_table') +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 @tz_sensitive_test @allow_non_gpu(*non_supported_tz_allow) @ignore_order(local=True) @@ -1825,6 +1861,7 @@ def test_to_date_with_window_functions(): ) +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 @ignore_order(local=True) @approximate_float @pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) @@ -1881,6 +1918,7 @@ def spark_bugs_in_decimal_sorting(): return v < "3.1.4" or v < "3.3.1" or v < "3.2.3" or v < "3.4.0" +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 @ignore_order(local=True) @approximate_float @pytest.mark.parametrize('batch_size', ['1g'], ids=idfn) @@ -1925,6 +1963,7 @@ def test_window_aggs_for_negative_rows_unpartitioned(data_gen, batch_size): conf=conf) +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 @ignore_order(local=True) @pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) @pytest.mark.parametrize('data_gen', [ @@ -1964,6 +2003,7 @@ def test_window_aggs_for_batched_finite_row_windows_partitioned(data_gen, batch_ conf=conf) +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 @ignore_order(local=True) @pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) @pytest.mark.parametrize('data_gen', [ @@ -2003,6 +2043,7 @@ def test_window_aggs_for_batched_finite_row_windows_unpartitioned(data_gen, batc conf=conf) +@disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 @ignore_order(local=True) @pytest.mark.parametrize('data_gen', [_grpkey_int_with_nulls,], ids=idfn) def test_window_aggs_for_batched_finite_row_windows_fallback(data_gen): diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpandExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpandExec.scala index 0fc7defd063..c0e20bfaebc 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpandExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpandExec.scala @@ -20,6 +20,7 @@ import scala.util.Random import ai.rapids.cudf.NvtxColor import com.nvidia.spark.rapids.Arm.withResource +import com.nvidia.spark.rapids.GpuExpressionsUtils.NullVecCache import com.nvidia.spark.rapids.GpuMetric._ import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion import com.nvidia.spark.rapids.shims.ShimUnaryExecNode @@ -54,7 +55,9 @@ class GpuExpandExecMeta( val projections = gpuProjections.map(_.map(_.convertToGpu())) GpuExpandExec(projections, expand.output, childPlans.head.convertIfNeeded())( useTieredProject = conf.isTieredProjectEnabled, - preprojectEnabled = conf.isExpandPreprojectEnabled) + preprojectEnabled = conf.isExpandPreprojectEnabled, + cacheNullMaxCount = conf.expandCachingNullVecMaxCount, + coalesceAfter = conf.isCoalesceAfterExpandEnabled) } } @@ -72,11 +75,17 @@ case class GpuExpandExec( output: Seq[Attribute], child: SparkPlan)( useTieredProject: Boolean = false, - preprojectEnabled: Boolean = false) extends ShimUnaryExecNode with GpuExec { + preprojectEnabled: Boolean = false, + cacheNullMaxCount: Int = 0, + override val coalesceAfter: Boolean = true +) extends ShimUnaryExecNode with GpuExec { override def otherCopyArgs: Seq[AnyRef] = Seq[AnyRef]( useTieredProject.asInstanceOf[java.lang.Boolean], - preprojectEnabled.asInstanceOf[java.lang.Boolean]) + preprojectEnabled.asInstanceOf[java.lang.Boolean], + cacheNullMaxCount.asInstanceOf[java.lang.Integer], + coalesceAfter.asInstanceOf[java.lang.Boolean] + ) private val PRE_PROJECT_TIME = "preprojectTime" override val outputRowsLevel: MetricsLevel = ESSENTIAL_LEVEL @@ -127,7 +136,7 @@ case class GpuExpandExec( } child.executeColumnar().mapPartitions { it => - new GpuExpandIterator(boundProjections, metricsMap, preprojectIter(it)) + new GpuExpandIterator(boundProjections, metricsMap, preprojectIter(it), cacheNullMaxCount) } } @@ -191,7 +200,8 @@ case class GpuExpandExec( class GpuExpandIterator( boundProjections: Seq[GpuTieredProject], metrics: Map[String, GpuMetric], - it: Iterator[ColumnarBatch]) + it: Iterator[ColumnarBatch], + cacheNullMaxCount: Int) extends Iterator[ColumnarBatch] { private var sb: Option[SpillableColumnarBatch] = None @@ -206,9 +216,20 @@ class GpuExpandIterator( Option(TaskContext.get()).foreach { tc => onTaskCompletion(tc) { sb.foreach(_.close()) + + if (cacheNullMaxCount > 0) { + import scala.collection.JavaConverters._ + GpuExpressionsUtils.cachedNullVectors.get().values().asScala.foreach(_.close()) + GpuExpressionsUtils.cachedNullVectors.get().clear() + } } } + if (cacheNullMaxCount > 0 && GpuExpressionsUtils.cachedNullVectors.get() == null) { + GpuExpressionsUtils.cachedNullVectors.set(new NullVecCache(cacheNullMaxCount)) + } + + override def hasNext: Boolean = sb.isDefined || it.hasNext override def next(): ColumnarBatch = { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpressions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpressions.scala index c46862ab2aa..2400b364b5a 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpressions.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpressions.scala @@ -21,11 +21,12 @@ import com.nvidia.spark.Retryable import com.nvidia.spark.rapids.Arm.{withResource, withResourceIfAllowed} import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.shims.{ShimBinaryExpression, ShimExpression, ShimTernaryExpression, ShimUnaryExpression} +import java.util import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.{DataType, StringType} import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.unsafe.types.UTF8String @@ -52,6 +53,40 @@ object GpuExpressionsUtils { "implemented and should have been disabled") } + // This is only for ExpandExec which will generate a lot of null vectors + case class NullVecKey(d: DataType, n: Int) + + class NullVecCache(private val maxNulls: Int) + extends util.LinkedHashMap[NullVecKey, GpuColumnVector](100, 0.75f, true) { + private var totalNulls: Long = 0L + + override def clear(): Unit = { + super.clear() + totalNulls = 0 + } + + override def put(key: NullVecKey, v: GpuColumnVector): GpuColumnVector = { + if (v.getRowCount > maxNulls) { + throw new IllegalStateException(s"spark.rapids.sql.expandCachingNullVec.maxNulls" + + s"($maxNulls) is set too small to hold single vector with ${v.getRowCount} rows.") + } + val iter = entrySet().iterator() + while (iter.hasNext && totalNulls > maxNulls - v.getRowCount) { + val entry = iter.next() + iter.remove() + totalNulls -= entry.getValue.getRowCount + } + + val ret = super.put(key, v) + totalNulls += v.getRowCount + ret + } + + override def remove(key: Any): GpuColumnVector = throw new UnsupportedOperationException() + } + + val cachedNullVectors = new ThreadLocal[NullVecCache]() + /** * Tries to resolve a `GpuColumnVector` from a Scala `Any`. * @@ -73,7 +108,18 @@ object GpuExpressionsUtils { def resolveColumnVector(any: Any, numRows: Int): GpuColumnVector = { withResourceIfAllowed(any) { case c: GpuColumnVector => c.incRefCount() - case s: GpuScalar => GpuColumnVector.from(s, numRows, s.dataType) + case s: GpuScalar => + if (!s.isValid && cachedNullVectors.get() != null) { + if (!cachedNullVectors.get.containsKey(NullVecKey.apply(s.dataType, numRows))) { + cachedNullVectors.get.put(NullVecKey.apply(s.dataType, numRows), + GpuColumnVector.from(s, numRows, s.dataType)) + } + + val ret = cachedNullVectors.get().get(NullVecKey.apply(s.dataType, numRows)) + ret.incRefCount() + } else { + GpuColumnVector.from(s, numRows, s.dataType) + } case other => throw new IllegalArgumentException(s"Cannot resolve a ColumnVector from the value:" + s" $other. Please convert it to a GpuScalar or a GpuColumnVector before returning.") diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala index 48f9de5a61a..eef083bb93d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala @@ -360,6 +360,7 @@ class GpuTransitionOverrides extends Rule[SparkPlan] { case _: GpuDataSourceScanExec => true case _: DataSourceV2ScanExecBase => true case _: RDDScanExec => true // just in case an RDD was reading in data + case _: ExpandExec => true case _ => false } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index 5203e926efa..826d398dbc9 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -1219,6 +1219,20 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern") .booleanConf .createWithDefault(true) + val ENABLE_COALESCE_AFTER_EXPAND = conf("spark.rapids.sql.coalesceAfterExpand.enabled") + .doc("When set to false disables the coalesce after GPU Expand. ") + .internal() + .booleanConf + .createWithDefault(false) + + val EXPAND_CACHING_NULL_VEC_MAX_NULL_COUNT = + conf("spark.rapids.sql.expandCachingNullVec.maxNulls") + .doc("Max number of null scalar in null vectors to cache for GPU Expand. " + + "If the number of null scala exceeds this value, the null vectors will not be cached." + + "The value has to be positive for caching to be enabled.") + .internal().integerConf + .createWithDefault(0) + val ENABLE_ORC_FLOAT_TYPES_TO_STRING = conf("spark.rapids.sql.format.orc.floatTypesToString.enable") .doc("When reading an ORC file, the source data schemas(schemas of ORC file) may differ " + @@ -2762,6 +2776,10 @@ class RapidsConf(conf: Map[String, String]) extends Logging { lazy val isExpandPreprojectEnabled: Boolean = get(ENABLE_EXPAND_PREPROJECT) + lazy val isCoalesceAfterExpandEnabled: Boolean = get(ENABLE_COALESCE_AFTER_EXPAND) + + lazy val expandCachingNullVecMaxCount: Int = get(EXPAND_CACHING_NULL_VEC_MAX_NULL_COUNT) + lazy val multiThreadReadNumThreads: Int = { // Use the largest value set among all the options. val deprecatedConfs = Seq( diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index dc2845e4461..a435988686d 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -2084,11 +2084,25 @@ class GpuConvMeta( override def tagExprForGpu(): Unit = { val fromBaseLit = GpuOverrides.extractLit(expr.fromBaseExpr) val toBaseLit = GpuOverrides.extractLit(expr.toBaseExpr) + val errorPostfix = "only literal 10 or 16 are supported for source and target radixes" (fromBaseLit, toBaseLit) match { - case (Some(Literal(fromBaseVal, IntegerType)), Some(Literal(toBaseVal, IntegerType))) - if Set(fromBaseVal, toBaseVal).subsetOf(Set(10, 16)) => () + case (Some(Literal(fromBaseVal, IntegerType)), Some(Literal(toBaseVal, IntegerType))) => + def isBaseSupported(base: Any): Boolean = base == 10 || base == 16 + if (!isBaseSupported(fromBaseVal) && !isBaseSupported(toBaseVal)) { + willNotWorkOnGpu(because = s"both ${fromBaseVal} and ${toBaseVal} are not " + + s"a supported radix, ${errorPostfix}") + } else if (!isBaseSupported(fromBaseVal)) { + willNotWorkOnGpu(because = s"${fromBaseVal} is not a supported source radix, " + + s"${errorPostfix}") + } else if (!isBaseSupported(toBaseVal)) { + willNotWorkOnGpu(because = s"${toBaseVal} is not a supported target radix, " + + s"${errorPostfix}") + } case _ => - willNotWorkOnGpu(because = "only literal 10 or 16 for from_base and to_base are supported") + // This will never happen in production as the function signature enforces + // integer types for the bases, but nice to have an edge case handling. + willNotWorkOnGpu(because = "either source radix or target radix is not an integer " + + "literal, " + errorPostfix) } } diff --git a/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/spark330/GpuBucketingUtils.scala b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/spark330/GpuBucketingUtils.scala index feb562fa9b8..0f7c9b4fd62 100644 --- a/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/spark330/GpuBucketingUtils.scala +++ b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/spark330/GpuBucketingUtils.scala @@ -31,6 +31,7 @@ {"spark": "343"} {"spark": "350"} {"spark": "351"} +{"spark": "400"} spark-rapids-shim-json-lines ***/ package com.nvidia.spark.rapids.shims