From 894012ce4dac659eaad116664713bcabda4ea3eb Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Mon, 4 Dec 2023 16:29:31 -0600 Subject: [PATCH] Avoid generating NaNs as partition values in test_part_write_round_trip Signed-off-by: Jason Lowe --- integration_tests/src/main/python/orc_write_test.py | 11 +++++++++-- .../src/main/python/parquet_write_test.py | 11 +++++++++-- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/integration_tests/src/main/python/orc_write_test.py b/integration_tests/src/main/python/orc_write_test.py index 5617f8e20e5..912004ed8c7 100644 --- a/integration_tests/src/main/python/orc_write_test.py +++ b/integration_tests/src/main/python/orc_write_test.py @@ -118,8 +118,15 @@ def test_write_round_trip_corner(spark_tmp_path, orc_gen, orc_impl): @pytest.mark.parametrize('orc_gen', orc_part_write_gens, ids=idfn) @pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653') def test_part_write_round_trip(spark_tmp_path, orc_gen): - gen_list = [('a', RepeatSeqGen(orc_gen, 10)), - ('b', orc_gen)] + part_gen = orc_gen + # Avoid generating NaNs for partition values. + # Spark does not handle partition switching properly since NaN != NaN. + if isinstance(part_gen, FloatGen): + part_gen = FloatGen(no_nans=True) + elif isinstance(part_gen, DoubleGen): + part_gen = DoubleGen(no_nans=True) + gen_list = [('a', RepeatSeqGen(part_gen, 10)), + ('b', orc_gen)] data_path = spark_tmp_path + '/ORC_DATA' assert_gpu_and_cpu_writes_are_equal_collect( lambda spark, path: gen_df(spark, gen_list).coalesce(1).write.partitionBy('a').orc(path), diff --git a/integration_tests/src/main/python/parquet_write_test.py b/integration_tests/src/main/python/parquet_write_test.py index 9584f2a3520..015e4481700 100644 --- a/integration_tests/src/main/python/parquet_write_test.py +++ b/integration_tests/src/main/python/parquet_write_test.py @@ -176,8 +176,15 @@ def test_write_ts_millis(spark_tmp_path, ts_type, ts_rebase): @pytest.mark.parametrize('parquet_gen', parquet_part_write_gens, ids=idfn) @pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653') def test_part_write_round_trip(spark_tmp_path, parquet_gen): - gen_list = [('a', RepeatSeqGen(parquet_gen, 10)), - ('b', parquet_gen)] + part_gen = parquet_gen + # Avoid generating NaNs for partition values. + # Spark does not handle partition switching properly since NaN != NaN. + if isinstance(part_gen, FloatGen): + part_gen = FloatGen(no_nans=True) + elif isinstance(part_gen, DoubleGen): + part_gen = DoubleGen(no_nans=True) + gen_list = [('a', RepeatSeqGen(part_gen, 10)), + ('b', parquet_gen)] data_path = spark_tmp_path + '/PARQUET_DATA' assert_gpu_and_cpu_writes_are_equal_collect( lambda spark, path: gen_df(spark, gen_list).coalesce(1).write.partitionBy('a').parquet(path),