diff --git a/integration_tests/src/main/python/datasourcev2_write_test.py b/integration_tests/src/main/python/datasourcev2_write_test.py index 1f4bc133d2a..4fffd10ab44 100644 --- a/integration_tests/src/main/python/datasourcev2_write_test.py +++ b/integration_tests/src/main/python/datasourcev2_write_test.py @@ -18,7 +18,7 @@ from data_gen import gen_df, decimal_gens, non_utc_allow from marks import * from spark_session import is_hive_available, is_spark_330_or_later, with_cpu_session, with_gpu_session -from hive_parquet_write_test import _hive_bucket_gens, _hive_array_gens, _hive_struct_gens +from hive_parquet_write_test import _hive_bucket_gens_sans_bools, _hive_array_gens, _hive_struct_gens from hive_parquet_write_test import read_single_bucket _hive_write_conf = { @@ -33,9 +33,11 @@ @allow_non_gpu(*non_utc_allow) def test_write_hive_bucketed_table(spark_tmp_table_factory, file_format): num_rows = 2048 - + # Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and + # https://github.com/rapidsai/cudf/issues/6763 . + # Once the first issue is fixed, add back boolean_gen def gen_table(spark): - gen_list = [('_c' + str(i), gen) for i, gen in enumerate(_hive_bucket_gens)] + gen_list = [('_c' + str(i), gen) for i, gen in enumerate(_hive_bucket_gens_sans_bools)] types_sql_str = ','.join('{} {}'.format( name, gen.data_type.simpleString()) for name, gen in gen_list) col_names_str = ','.join(name for name, gen in gen_list) diff --git a/integration_tests/src/main/python/hive_parquet_write_test.py b/integration_tests/src/main/python/hive_parquet_write_test.py index e66b889a986..540db74a1ad 100644 --- a/integration_tests/src/main/python/hive_parquet_write_test.py +++ b/integration_tests/src/main/python/hive_parquet_write_test.py @@ -25,9 +25,10 @@ # "GpuInsertIntoHiveTable" for Parquet write. _write_to_hive_conf = {"spark.sql.hive.convertMetastoreParquet": False} -_hive_bucket_gens = [ - boolean_gen, byte_gen, short_gen, int_gen, long_gen, string_gen, float_gen, double_gen, +_hive_bucket_gens_sans_bools = [ + byte_gen, short_gen, int_gen, long_gen, string_gen, float_gen, double_gen, DateGen(start=date(1590, 1, 1)), _restricted_timestamp()] +_hive_bucket_gens = [boolean_gen] + _hive_bucket_gens_sans_bools _hive_basic_gens = _hive_bucket_gens + [ DecimalGen(precision=19, scale=1, nullable=True), diff --git a/integration_tests/src/main/python/hive_write_test.py b/integration_tests/src/main/python/hive_write_test.py index 945cc4806fb..af825a99810 100644 --- a/integration_tests/src/main/python/hive_write_test.py +++ b/integration_tests/src/main/python/hive_write_test.py @@ -29,8 +29,11 @@ def _restricted_timestamp(nullable=True): end=datetime(2262, 4, 11, tzinfo=timezone.utc), nullable=nullable) +# Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and +# https://github.com/rapidsai/cudf/issues/6763 . +# Once the first issue is fixed, add back boolean_gen _basic_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, - string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)), + string_gen, DateGen(start=date(1590, 1, 1)), _restricted_timestamp() ] + decimal_gens @@ -45,8 +48,11 @@ def _restricted_timestamp(nullable=True): ArrayGen(ArrayGen(string_gen, max_length=10), max_length=10), ArrayGen(StructGen([['child0', byte_gen], ['child1', string_gen], ['child2', float_gen]]))] +# Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and +# https://github.com/rapidsai/cudf/issues/6763 . 
+# Once the first issue is fixed, add back boolean_gen _map_gens = [simple_string_to_string_map_gen] + [MapGen(f(nullable=False), f()) for f in [ - BooleanGen, ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen, + ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen, lambda nullable=True: _restricted_timestamp(nullable=nullable), lambda nullable=True: DateGen(start=date(1590, 1, 1), nullable=nullable), lambda nullable=True: DecimalGen(precision=15, scale=1, nullable=nullable), diff --git a/integration_tests/src/main/python/orc_test.py b/integration_tests/src/main/python/orc_test.py index 618004ee60d..19894d29aa6 100644 --- a/integration_tests/src/main/python/orc_test.py +++ b/integration_tests/src/main/python/orc_test.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -112,8 +112,11 @@ def test_basic_read(std_input_path, name, read_func, v1_enabled_list, orc_impl, #E at org.apache.orc.TypeDescription.parseInt(TypeDescription.java:244) #E at org.apache.orc.TypeDescription.parseType(TypeDescription.java:362) # ... +# Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and +# https://github.com/rapidsai/cudf/issues/6763 . +# Once the first issue is fixed, add back boolean_gen orc_basic_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, - string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)), + string_gen, DateGen(start=date(1590, 1, 1)), orc_timestamp_gen] + decimal_gens orc_basic_struct_gen = StructGen([['child'+str(ind), sub_gen] for ind, sub_gen in enumerate(orc_basic_gens)]) @@ -201,8 +204,11 @@ def test_read_round_trip(spark_tmp_path, orc_gens, read_func, reader_confs, v1_e read_func(data_path), conf=all_confs) +# Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and +# https://github.com/rapidsai/cudf/issues/6763 . +# Once the first issue is fixed, add back boolean_gen orc_pred_push_gens = [ - byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, boolean_gen, + byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, string_gen, # Once https://github.com/NVIDIA/spark-rapids/issues/139 is fixed replace this with # date_gen @@ -277,8 +283,11 @@ def test_compress_read_round_trip(spark_tmp_path, compress, v1_enabled_list, rea def test_simple_partitioned_read(spark_tmp_path, v1_enabled_list, reader_confs): # Once https://github.com/NVIDIA/spark-rapids/issues/131 is fixed # we should go with a more standard set of generators + # Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and + # https://github.com/rapidsai/cudf/issues/6763 . 
+ # Once the first issue is fixed, add back boolean_gen orc_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, - string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)), + string_gen, DateGen(start=date(1590, 1, 1)), orc_timestamp_gen] gen_list = [('_c' + str(i), gen) for i, gen in enumerate(orc_gens)] first_data_path = spark_tmp_path + '/ORC_DATA/key=0/key2=20' @@ -344,8 +353,11 @@ def test_partitioned_read_just_partitions(spark_tmp_path, v1_enabled_list, reade def test_merge_schema_read(spark_tmp_path, v1_enabled_list, reader_confs): # Once https://github.com/NVIDIA/spark-rapids/issues/131 is fixed # we should go with a more standard set of generators + # Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and + # https://github.com/rapidsai/cudf/issues/6763 . + # Once the first issue is fixed, add back boolean_gen orc_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, - string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)), + string_gen, DateGen(start=date(1590, 1, 1)), orc_timestamp_gen] first_gen_list = [('_c' + str(i), gen) for i, gen in enumerate(orc_gens)] first_data_path = spark_tmp_path + '/ORC_DATA/key=0' @@ -825,8 +837,11 @@ def test_read_round_trip_for_multithreaded_combining(spark_tmp_path, gens, keep_ @pytest.mark.parametrize('keep_order', [True, pytest.param(False, marks=pytest.mark.ignore_order(local=True))]) @allow_non_gpu(*non_utc_allow_orc_scan) def test_simple_partitioned_read_for_multithreaded_combining(spark_tmp_path, keep_order): + # Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and + # https://github.com/rapidsai/cudf/issues/6763 . + # Once the first issue is fixed, add back boolean_gen orc_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, - string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)), + string_gen, DateGen(start=date(1590, 1, 1)), orc_timestamp_gen] gen_list = [('_c' + str(i), gen) for i, gen in enumerate(orc_gens)] first_data_path = spark_tmp_path + '/ORC_DATA/key=0/key2=20' @@ -927,7 +942,10 @@ def test_orc_column_name_with_dots(spark_tmp_path, reader_confs): ("f.g", int_gen), ("h", string_gen)])), ("i.j", long_gen)])), - ("k", boolean_gen)] + # Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and + # https://github.com/rapidsai/cudf/issues/6763 . + # Once the first issue is fixed, add back boolean_gen for column k + ("k", int_gen)] with_cpu_session(lambda spark: gen_df(spark, gens).write.orc(data_path)) assert_gpu_and_cpu_are_equal_collect(lambda spark: reader(spark), conf=all_confs) assert_gpu_and_cpu_are_equal_collect(lambda spark: reader(spark).selectExpr("`a.b`"), conf=all_confs) @@ -945,7 +963,10 @@ def test_orc_with_null_column(spark_tmp_path, reader_confs): def gen_null_df(spark): return spark.createDataFrame( [(None, None, None, None, None)], - "c1 int, c2 long, c3 float, c4 double, c5 boolean") + # Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and + # https://github.com/rapidsai/cudf/issues/6763 . 
+ # Once the first issue is fixed, add back boolean_gen + "c1 int, c2 long, c3 float, c4 double, c5 int") assert_gpu_and_cpu_writes_are_equal_collect( lambda spark, path: gen_null_df(spark).write.orc(path), @@ -966,7 +987,10 @@ def test_orc_with_null_column_with_1m_rows(spark_tmp_path, reader_confs): def gen_null_df(spark): return spark.createDataFrame( data, - "c1 int, c2 long, c3 float, c4 double, c5 boolean") + # Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and + # https://github.com/rapidsai/cudf/issues/6763 . + # Once the first issue is fixed, add back boolean_gen + "c1 int, c2 long, c3 float, c4 double, c5 int") assert_gpu_and_cpu_writes_are_equal_collect( lambda spark, path: gen_null_df(spark).write.orc(path), lambda spark, path: spark.read.orc(path), diff --git a/integration_tests/src/main/python/orc_write_test.py b/integration_tests/src/main/python/orc_write_test.py index ddb69524ac4..7e415c79a46 100644 --- a/integration_tests/src/main/python/orc_write_test.py +++ b/integration_tests/src/main/python/orc_write_test.py @@ -24,9 +24,11 @@ from pyspark.sql.types import * pytestmark = pytest.mark.nightly_resource_consuming_test - +# Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and +# https://github.com/rapidsai/cudf/issues/6763 . +# Once the first issue is fixed, add back boolean_gen. orc_write_basic_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, - string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)), + string_gen, DateGen(start=date(1590, 1, 1)), TimestampGen(start=datetime(1970, 1, 1, tzinfo=timezone.utc)) ] + \ decimal_gens @@ -52,7 +54,8 @@ all_nulls_map_gen, all_empty_map_gen] -orc_write_basic_struct_gen = StructGen([['child'+str(ind), sub_gen] for ind, sub_gen in enumerate(orc_write_basic_gens)]) +orc_write_basic_struct_gen = StructGen( + [['child'+str(ind), sub_gen] for ind, sub_gen in enumerate(orc_write_basic_gens)]) orc_write_struct_gens_sample = [orc_write_basic_struct_gen, StructGen([['child0', byte_gen], ['child1', orc_write_basic_struct_gen]]), @@ -62,15 +65,18 @@ ArrayGen(ArrayGen(short_gen, max_length=10), max_length=10), ArrayGen(ArrayGen(string_gen, max_length=10), max_length=10), ArrayGen(StructGen([['child0', byte_gen], ['child1', string_gen], ['child2', float_gen]]))] - +# Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and +# https://github.com/rapidsai/cudf/issues/6763 . +# Once the first issue is fixed, add back boolean_gen. orc_write_basic_map_gens = [simple_string_to_string_map_gen] + [MapGen(f(nullable=False), f()) for f in [ - BooleanGen, ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen, + ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen, # Using timestamps from 1970 to work around a cudf ORC bug # https://github.com/NVIDIA/spark-rapids/issues/140. 
 lambda nullable=True: TimestampGen(start=datetime(1970, 1, 1, tzinfo=timezone.utc), nullable=nullable),
     lambda nullable=True: DateGen(start=date(1590, 1, 1), nullable=nullable),
     lambda nullable=True: DecimalGen(precision=15, scale=1, nullable=nullable),
-    lambda nullable=True: DecimalGen(precision=36, scale=5, nullable=nullable)]]
+    lambda nullable=True: DecimalGen(precision=36, scale=5, nullable=nullable)]] + [MapGen(
+    f(nullable=False), f(nullable=False)) for f in [IntegerGen]]
 
 orc_write_gens_list = [orc_write_basic_gens,
                        orc_write_struct_gens_sample,
@@ -79,6 +85,7 @@
     pytest.param([date_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/139')),
     pytest.param([timestamp_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/140'))]
 
+bool_gen = [BooleanGen(nullable=True), BooleanGen(nullable=False)]
 @pytest.mark.parametrize('orc_gens', orc_write_gens_list, ids=idfn)
 @pytest.mark.parametrize('orc_impl', ["native", "hive"])
 @allow_non_gpu(*non_utc_allow)
@@ -91,6 +98,30 @@ def test_write_round_trip(spark_tmp_path, orc_gens, orc_impl):
             data_path,
             conf={'spark.sql.orc.impl': orc_impl, 'spark.rapids.sql.format.orc.write.enabled': True})
 
+@pytest.mark.parametrize('orc_gens', [bool_gen], ids=idfn)
+@pytest.mark.parametrize('orc_impl', ["native", "hive"])
+@allow_non_gpu('ExecutedCommandExec', 'DataWritingCommandExec', 'WriteFilesExec')
+def test_write_round_trip_bools_only_fallback(spark_tmp_path, orc_gens, orc_impl):
+    gen_list = [('_c' + str(i), gen) for i, gen in enumerate(orc_gens)]
+    data_path = spark_tmp_path + '/ORC_DATA'
+    assert_gpu_and_cpu_writes_are_equal_collect(
+            lambda spark, path: gen_df(spark, gen_list).coalesce(1).write.orc(path),
+            lambda spark, path: spark.read.orc(path),
+            data_path,
+            conf={'spark.sql.orc.impl': orc_impl, 'spark.rapids.sql.format.orc.write.enabled': True})
+
+@pytest.mark.parametrize('orc_gens', [bool_gen], ids=idfn)
+@pytest.mark.parametrize('orc_impl', ["native", "hive"])
+def test_write_round_trip_bools_only_no_fallback(spark_tmp_path, orc_gens, orc_impl):
+    gen_list = [('_c' + str(i), gen) for i, gen in enumerate(orc_gens)]
+    data_path = spark_tmp_path + '/ORC_DATA'
+    assert_gpu_and_cpu_writes_are_equal_collect(
+            lambda spark, path: gen_df(spark, gen_list).coalesce(1).write.orc(path),
+            lambda spark, path: spark.read.orc(path),
+            data_path,
+            conf={'spark.sql.orc.impl': orc_impl, 'spark.rapids.sql.format.orc.write.enabled': True,
+                  'spark.rapids.sql.format.orc.write.boolType.enabled': True})
+
 @pytest.mark.parametrize('orc_gen', orc_write_odd_empty_strings_gens_sample, ids=idfn)
 @pytest.mark.parametrize('orc_impl', ["native", "hive"])
 def test_write_round_trip_corner(spark_tmp_path, orc_gen, orc_impl):
@@ -103,7 +134,8 @@ def test_write_round_trip_corner(spark_tmp_path, orc_gen, orc_impl):
             conf={'spark.sql.orc.impl': orc_impl, 'spark.rapids.sql.format.orc.write.enabled': True})
 
 orc_part_write_gens = [
-    byte_gen, short_gen, int_gen, long_gen, boolean_gen,
+    # Add back boolean_gen when https://github.com/rapidsai/cudf/issues/6763 is fixed
+    byte_gen, short_gen, int_gen, long_gen,
     # Some file systems have issues with UTF8 strings so to help the test pass even there
     StringGen('(\\w| ){0,50}'),
     # Once https://github.com/NVIDIA/spark-rapids/issues/139 is fixed replace this with
     # date_gen
@@ -345,7 +377,10 @@ def test_orc_write_column_name_with_dots(spark_tmp_path):
                             ("f.g", int_gen),
                             ("h", string_gen)])),
              ("i.j", long_gen)])),
-        ("k", boolean_gen)]
+        # Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and
+        # https://github.com/rapidsai/cudf/issues/6763 .
+        # Once the first issue is fixed, add back boolean_gen for column k
+        ("k", int_gen)]
     assert_gpu_and_cpu_writes_are_equal_collect(
         lambda spark, path: gen_df(spark, gens).coalesce(1).write.orc(path),
         lambda spark, path: spark.read.orc(path),
diff --git a/integration_tests/src/main/python/schema_evolution_test.py b/integration_tests/src/main/python/schema_evolution_test.py
index ff501324cc0..57af4a1126e 100644
--- a/integration_tests/src/main/python/schema_evolution_test.py
+++ b/integration_tests/src/main/python/schema_evolution_test.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -34,7 +34,9 @@
 # List of additional column data generators to use when adding columns
 _additional_gens = [
-    boolean_gen,
+    # Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and
+    # https://github.com/rapidsai/cudf/issues/6763 .
+    # Once the first issue is fixed, add back boolean_gen
     byte_gen,
     short_gen,
     int_gen,
@@ -49,7 +51,10 @@
     #                simple_string_to_string_map_gen),
     ArrayGen(_custom_date_gen),
     struct_gen_decimal128,
-    StructGen([("c0", ArrayGen(long_gen)), ("c1", boolean_gen)]),
+    # Use every type except boolean, see https://github.com/NVIDIA/spark-rapids/issues/11762 and
+    # https://github.com/rapidsai/cudf/issues/6763 .
+    # Once the first issue is fixed, switch c1 back to boolean_gen from int_gen
+    StructGen([("c0", ArrayGen(long_gen)), ("c1", int_gen)]),
 ]
 
 def get_additional_columns():
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
index 406aeb0365b..e750f5688ce 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -1268,6 +1268,14 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern")
     .booleanConf
     .createWithDefault(true)
 
+  val ENABLE_ORC_BOOL = conf("spark.rapids.sql.format.orc.write.boolType.enabled")
+    .doc("When set to false, ORC writes of boolean columns fall back to the CPU. " +
+      "Set to true if you want to experiment. " +
+      "See https://github.com/NVIDIA/spark-rapids/issues/11736.")
+    .internal()
+    .booleanConf
+    .createWithDefault(false)
+
   val ENABLE_EXPAND_PREPROJECT = conf("spark.rapids.sql.expandPreproject.enabled")
     .doc("When set to false disables the pre-projection for GPU Expand. " +
       "Pre-projection leverages the tiered projection to evaluate expressions that " +
@@ -3028,6 +3036,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging {
 
   lazy val maxNumOrcFilesParallel: Int = get(ORC_MULTITHREAD_READ_MAX_NUM_FILES_PARALLEL)
 
+  lazy val isOrcBoolTypeEnabled: Boolean = get(ENABLE_ORC_BOOL)
+
   lazy val isCsvEnabled: Boolean = get(ENABLE_CSV)
 
   lazy val isCsvReadEnabled: Boolean = get(ENABLE_CSV_READ)
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala
index d2f4380646c..1d4bc66a1da 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.execution.datasources.FileFormat
 import org.apache.spark.sql.execution.datasources.orc.{OrcFileFormat, OrcOptions, OrcUtils}
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.rapids.execution.TrampolineUtil
 import org.apache.spark.sql.types._
 
 object GpuOrcFileFormat extends Logging {
@@ -83,6 +84,11 @@ object GpuOrcFileFormat extends Logging {
     //        [[org.apache.spark.sql.execution.datasources.DaysWritable]] object
     //        which is a subclass of [[org.apache.hadoop.hive.serde2.io.DateWritable]].
     val types = schema.map(_.dataType).toSet
+    val hasBools = schema.exists { field =>
+      TrampolineUtil.dataTypeExistsRecursively(field.dataType, t =>
+        t.isInstanceOf[BooleanType])
+    }
+
     if (types.exists(GpuOverrides.isOrContainsDateOrTimestamp(_))) {
       if (!GpuOverrides.isUTCTimezone()) {
         meta.willNotWorkOnGpu("Only UTC timezone is supported for ORC. " +
@@ -91,6 +97,10 @@
       }
     }
 
+    if (hasBools && !meta.conf.isOrcBoolTypeEnabled) {
+      meta.willNotWorkOnGpu("Nullable booleans cannot work in certain cases with the " +
+        "ORC writer. See https://github.com/rapidsai/cudf/issues/6763")
+    }
     FileFormatChecks.tag(meta, schema, OrcFormatType, WriteFileOp)
 
     val sqlConf = spark.sessionState.conf
diff --git a/tests/src/test/scala/org/apache/spark/sql/rapids/OrcFilterSuite.scala b/tests/src/test/scala/org/apache/spark/sql/rapids/OrcFilterSuite.scala
index fe86900b32f..6d067800dde 100644
--- a/tests/src/test/scala/org/apache/spark/sql/rapids/OrcFilterSuite.scala
+++ b/tests/src/test/scala/org/apache/spark/sql/rapids/OrcFilterSuite.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023, NVIDIA CORPORATION.
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,10 +18,11 @@
 package org.apache.spark.sql.rapids
 
 import java.sql.Timestamp
 
-import com.nvidia.spark.rapids.{GpuFilterExec, SparkQueryCompareTestSuite}
+import com.nvidia.spark.rapids.{GpuFilterExec, RapidsConf, SparkQueryCompareTestSuite}
 
+import org.apache.spark.SparkConf
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.execution.FilterExec
+import org.apache.spark.sql.execution.{FilterExec, SparkPlan}
 
 class OrcFilterSuite extends SparkQueryCompareTestSuite {
 
@@ -39,22 +40,42 @@ class OrcFilterSuite extends SparkQueryCompareTestSuite {
 
   test("Support for pushing down filters for boolean types gpu write gpu read") {
     withTempPath { file =>
-      withGpuSparkSession(spark => {
-        val data = (0 until 10).map(i => Tuple1(i == 2))
-        val df = spark.createDataFrame(data).toDF("a")
-        df.repartition(10).write.orc(file.getCanonicalPath)
-        checkPredicatePushDown(spark, file.getCanonicalPath, 10, "a == true")
-      })
+      var gpuPlans: Array[SparkPlan] = Array.empty
+      val testConf = new SparkConf().set(
+        RapidsConf.TEST_ALLOWED_NONGPU.key,
+        "DataWritingCommandExec,ShuffleExchangeExec,WriteFilesExec")
+      ExecutionPlanCaptureCallback.startCapture()
+      try {
+        withGpuSparkSession(spark => {
+          val data = (0 until 10).map(i => Tuple1(i == 2))
+          val df = spark.createDataFrame(data).toDF("a")
+          df.repartition(10).write.orc(file.getCanonicalPath)
+          checkPredicatePushDown(spark, file.getCanonicalPath, 10, "a == true")
+        }, testConf)
+      } finally {
+        gpuPlans = ExecutionPlanCaptureCallback.getResultsWithTimeout()
+      }
+      ExecutionPlanCaptureCallback.assertDidFallBack(gpuPlans.head, "DataWritingCommandExec")
     }
   }
 
   test("Support for pushing down filters for boolean types gpu write cpu read") {
     withTempPath { file =>
-      withGpuSparkSession(spark => {
-        val data = (0 until 10).map(i => Tuple1(i == 2))
-        val df = spark.createDataFrame(data).toDF("a")
-        df.repartition(10).write.orc(file.getCanonicalPath)
-      })
+      var gpuPlans: Array[SparkPlan] = Array.empty
+      val testConf = new SparkConf().set(
+        RapidsConf.TEST_ALLOWED_NONGPU.key,
+        "DataWritingCommandExec,ShuffleExchangeExec,WriteFilesExec")
+      ExecutionPlanCaptureCallback.startCapture()
+      try {
+        withGpuSparkSession(spark => {
+          val data = (0 until 10).map(i => Tuple1(i == 2))
+          val df = spark.createDataFrame(data).toDF("a")
+          df.repartition(10).write.orc(file.getCanonicalPath)
+        }, testConf)
+      } finally {
+        gpuPlans = ExecutionPlanCaptureCallback.getResultsWithTimeout()
+      }
+      ExecutionPlanCaptureCallback.assertDidFallBack(gpuPlans.head, "DataWritingCommandExec")
       withCpuSparkSession(spark => {
         checkPredicatePushDown(spark, file.getCanonicalPath, 10, "a == true")
       })
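
Reviewer note: for anyone who wants to try the new flag outside the test harness, here is a minimal PySpark sketch. It is illustrative only and not part of the patch; the config key and its default of false come from the RapidsConf.scala change above, while the app name, output path, and session wiring are assumptions (it presumes a spark-rapids plugin jar is already on the classpath).

```python
# Minimal sketch: exercising the new ORC boolean-write flag from PySpark.
# Assumes the spark-rapids plugin jar is on the driver/executor classpath;
# the app name and output path below are placeholders.
from pyspark.sql import SparkSession

spark = (SparkSession.builder
         .appName("orc-bool-write-check")
         .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
         # Default is false: ORC writes containing boolean columns are tagged
         # willNotWorkOnGpu by GpuOrcFileFormat and fall back to the CPU.
         .config("spark.rapids.sql.format.orc.write.boolType.enabled", "true")
         .getOrCreate())

df = spark.createDataFrame([(i, i % 2 == 0) for i in range(10)],
                           "id int, flag boolean")
# With the flag set to true this write is attempted on the GPU; with the
# default of false the write runs on the CPU via DataWritingCommandExec,
# which is exactly what the updated OrcFilterSuite tests assert.
df.write.mode("overwrite").orc("/tmp/orc_bool_check")
```

The same toggle is what `test_write_round_trip_bools_only_no_fallback` sets through its `conf` dictionary, while `test_write_round_trip_bools_only_fallback` leaves it at the default and allows the CPU write execs instead.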