Skip to content

Commit

Permalink
Avoid generating NaNs as partition values in test_part_write_round_trip
Browse files Browse the repository at this point in the history
Signed-off-by: Jason Lowe <[email protected]>
  • Loading branch information
jlowe committed Dec 4, 2023
1 parent ea7b7fe commit 894012c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
11 changes: 9 additions & 2 deletions integration_tests/src/main/python/orc_write_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,15 @@ def test_write_round_trip_corner(spark_tmp_path, orc_gen, orc_impl):
@pytest.mark.parametrize('orc_gen', orc_part_write_gens, ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_part_write_round_trip(spark_tmp_path, orc_gen):
gen_list = [('a', RepeatSeqGen(orc_gen, 10)),
('b', orc_gen)]
part_gen = orc_gen
# Avoid generating NaNs for partition values.
# Spark does not handle partition switching properly since NaN != NaN.
if isinstance(part_gen, FloatGen):
part_gen = FloatGen(no_nans=True)
elif isinstance(part_gen, DoubleGen):
part_gen = DoubleGen(no_nans=True)
gen_list = [('a', RepeatSeqGen(part_gen, 10)),
('b', orc_gen)]
data_path = spark_tmp_path + '/ORC_DATA'
assert_gpu_and_cpu_writes_are_equal_collect(
lambda spark, path: gen_df(spark, gen_list).coalesce(1).write.partitionBy('a').orc(path),
Expand Down
11 changes: 9 additions & 2 deletions integration_tests/src/main/python/parquet_write_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,15 @@ def test_write_ts_millis(spark_tmp_path, ts_type, ts_rebase):
@pytest.mark.parametrize('parquet_gen', parquet_part_write_gens, ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_part_write_round_trip(spark_tmp_path, parquet_gen):
gen_list = [('a', RepeatSeqGen(parquet_gen, 10)),
('b', parquet_gen)]
part_gen = parquet_gen
# Avoid generating NaNs for partition values.
# Spark does not handle partition switching properly since NaN != NaN.
if isinstance(part_gen, FloatGen):
part_gen = FloatGen(no_nans=True)
elif isinstance(part_gen, DoubleGen):
part_gen = DoubleGen(no_nans=True)
gen_list = [('a', RepeatSeqGen(part_gen, 10)),
('b', parquet_gen)]
data_path = spark_tmp_path + '/PARQUET_DATA'
assert_gpu_and_cpu_writes_are_equal_collect(
lambda spark, path: gen_df(spark, gen_list).coalesce(1).write.partitionBy('a').parquet(path),
Expand Down

0 comments on commit 894012c

Please sign in to comment.