diff --git a/integration_tests/src/main/python/parquet_write_test.py b/integration_tests/src/main/python/parquet_write_test.py index 2acf3984f64..775b4a9d1cb 100644 --- a/integration_tests/src/main/python/parquet_write_test.py +++ b/integration_tests/src/main/python/parquet_write_test.py @@ -29,6 +29,11 @@ pytestmark = pytest.mark.nightly_resource_consuming_test +conf_key_parquet_datetimeRebaseModeInWrite = 'spark.sql.parquet.datetimeRebaseModeInWrite' +conf_key_parquet_int96RebaseModeInWrite = 'spark.sql.parquet.int96RebaseModeInWrite' +conf_key_parquet_datetimeRebaseModeInRead = 'spark.sql.parquet.datetimeRebaseModeInRead' +conf_key_parquet_int96RebaseModeInRead = 'spark.sql.parquet.int96RebaseModeInRead' + # test with original parquet file reader, the multi-file parallel reader for cloud, and coalesce file reader for # non-cloud original_parquet_file_reader_conf={'spark.rapids.sql.format.parquet.reader.type': 'PERFILE'} @@ -37,11 +42,8 @@ reader_opt_confs = [original_parquet_file_reader_conf, multithreaded_parquet_file_reader_conf, coalesce_parquet_file_reader_conf] parquet_decimal_struct_gen= StructGen([['child'+str(ind), sub_gen] for ind, sub_gen in enumerate(decimal_gens)]) -legacy_parquet_datetimeRebaseModeInWrite='spark.sql.parquet.datetimeRebaseModeInWrite' if is_spark_400_or_later() else 'spark.sql.legacy.parquet.datetimeRebaseModeInWrite' -legacy_parquet_int96RebaseModeInWrite='spark.sql.parquet.int96RebaseModeInWrite' if is_spark_400_or_later() else 'spark.sql.legacy.parquet.int96RebaseModeInWrite' -legacy_parquet_int96RebaseModeInRead='spark.sql.parquet.int96RebaseModeInRead' if is_spark_400_or_later() else 'spark.sql.legacy.parquet.int96RebaseModeInRead' -writer_confs={legacy_parquet_datetimeRebaseModeInWrite: 'CORRECTED', - legacy_parquet_int96RebaseModeInWrite: 'CORRECTED'} +writer_confs={conf_key_parquet_datetimeRebaseModeInWrite: 'CORRECTED', + conf_key_parquet_int96RebaseModeInWrite: 'CORRECTED'} parquet_basic_gen =[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, string_gen, boolean_gen, date_gen, TimestampGen(), binary_gen] @@ -161,8 +163,8 @@ def test_write_ts_millis(spark_tmp_path, ts_type, ts_rebase): lambda spark, path: unary_op_df(spark, gen).write.parquet(path), lambda spark, path: spark.read.parquet(path), data_path, - conf={legacy_parquet_datetimeRebaseModeInWrite: ts_rebase, - legacy_parquet_int96RebaseModeInWrite: ts_rebase, + conf={conf_key_parquet_datetimeRebaseModeInWrite: ts_rebase, + conf_key_parquet_int96RebaseModeInWrite: ts_rebase, 'spark.sql.parquet.outputTimestampType': ts_type}) @@ -288,8 +290,8 @@ def test_write_sql_save_table(spark_tmp_path, parquet_gens, spark_tmp_table_fact def writeParquetUpgradeCatchException(spark, df, data_path, spark_tmp_table_factory, int96_rebase, datetime_rebase, ts_write): spark.conf.set('spark.sql.parquet.outputTimestampType', ts_write) - spark.conf.set(legacy_parquet_datetimeRebaseModeInWrite, datetime_rebase) - spark.conf.set(legacy_parquet_int96RebaseModeInWrite, int96_rebase) # for spark 310 + spark.conf.set(conf_key_parquet_datetimeRebaseModeInWrite, datetime_rebase) + spark.conf.set(conf_key_parquet_int96RebaseModeInWrite, int96_rebase) # for spark 310 with pytest.raises(Exception) as e_info: df.coalesce(1).write.format("parquet").mode('overwrite').option("path", data_path).saveAsTable(spark_tmp_table_factory.get()) assert e_info.match(r".*SparkUpgradeException.*") @@ -547,8 +549,8 @@ def generate_map_with_empty_validity(spark, path): def test_parquet_write_fails_legacy_datetime(spark_tmp_path, data_gen, ts_write, ts_rebase_write): data_path = spark_tmp_path + '/PARQUET_DATA' all_confs = {'spark.sql.parquet.outputTimestampType': ts_write, - legacy_parquet_datetimeRebaseModeInWrite: ts_rebase_write, - legacy_parquet_int96RebaseModeInWrite: ts_rebase_write} + conf_key_parquet_datetimeRebaseModeInWrite: ts_rebase_write, + conf_key_parquet_int96RebaseModeInWrite: ts_rebase_write} def writeParquetCatchException(spark, data_gen, data_path): with pytest.raises(Exception) as e_info: unary_op_df(spark, data_gen).coalesce(1).write.parquet(data_path) @@ -566,12 +568,12 @@ def test_parquet_write_roundtrip_datetime_with_legacy_rebase(spark_tmp_path, dat ts_rebase_write, ts_rebase_read): data_path = spark_tmp_path + '/PARQUET_DATA' all_confs = {'spark.sql.parquet.outputTimestampType': ts_write, - legacy_parquet_datetimeRebaseModeInWrite: ts_rebase_write[0], - legacy_parquet_int96RebaseModeInWrite: ts_rebase_write[1], + conf_key_parquet_datetimeRebaseModeInWrite: ts_rebase_write[0], + conf_key_parquet_int96RebaseModeInWrite: ts_rebase_write[1], # The rebase modes in read configs should be ignored and overridden by the same # modes in write configs, which are retrieved from the written files. - 'spark.sql.legacy.parquet.datetimeRebaseModeInRead': ts_rebase_read[0], - legacy_parquet_int96RebaseModeInRead: ts_rebase_read[1]} + conf_key_parquet_datetimeRebaseModeInRead: ts_rebase_read[0], + conf_key_parquet_int96RebaseModeInRead: ts_rebase_read[1]} assert_gpu_and_cpu_writes_are_equal_collect( lambda spark, path: unary_op_df(spark, data_gen).coalesce(1).write.parquet(path), lambda spark, path: spark.read.parquet(path), @@ -600,7 +602,7 @@ def test_it(spark): spark.sql("CREATE TABLE {} LOCATION '{}/ctas' AS SELECT * FROM {}".format( ctas_with_existing_name, data_path, src_name)) except pyspark.sql.utils.AnalysisException as e: - description = e._desc if is_spark_400_or_later() else e.desc + description = e._desc if (is_spark_400_or_later() or is_databricks_version_or_later(14, 3)) else e.desc if allow_non_empty or description.find('non-empty directory') == -1: raise e with_gpu_session(test_it, conf) @@ -829,8 +831,8 @@ def write_partitions(spark, table_path): ) def hive_timestamp_value(spark_tmp_table_factory, spark_tmp_path, ts_rebase, func): - conf={legacy_parquet_datetimeRebaseModeInWrite: ts_rebase, - legacy_parquet_int96RebaseModeInWrite: ts_rebase} + conf={conf_key_parquet_datetimeRebaseModeInWrite: ts_rebase, + conf_key_parquet_int96RebaseModeInWrite: ts_rebase} def create_table(spark, path): tmp_table = spark_tmp_table_factory.get()