Skip to content

Commit

Permalink
Fix zero-scale floor and ceil tests
Browse files Browse the repository at this point in the history
Signed-off-by: Jason Lowe <[email protected]>
  • Loading branch information
jlowe committed Nov 28, 2023
1 parent 26c9e37 commit 0d433c8
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions integration_tests/src/main/python/arithmetic_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,17 +587,15 @@ def test_floor(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr('floor(a)'))

@datagen_overrides(seed=0, reason='https://github.com/NVIDIA/spark-rapids/issues/9722')
@pytest.mark.skipif(is_before_spark_330(), reason='scale parameter in Floor function is not supported before Spark 3.3.0')
@pytest.mark.parametrize('data_gen', double_n_long_gens + _arith_decimal_gens_no_neg_scale, ids=idfn)
@pytest.mark.parametrize('data_gen', [long_gen] + _arith_decimal_gens_no_neg_scale, ids=idfn)
def test_floor_scale_zero(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr('floor(a, 0)'),
conf={'spark.rapids.sql.castFloatToDecimal.enabled':'true'})
lambda spark : unary_op_df(spark, data_gen).selectExpr('floor(a, 0)'))

@pytest.mark.skipif(is_before_spark_330(), reason='scale parameter in Floor function is not supported before Spark 3.3.0')
@allow_non_gpu('ProjectExec')
@pytest.mark.parametrize('data_gen', double_n_long_gens + _arith_decimal_gens_no_neg_scale_38_0_overflow, ids=idfn)
@pytest.mark.parametrize('data_gen', [long_gen] + _arith_decimal_gens_no_neg_scale_38_0_overflow, ids=idfn)
def test_floor_scale_nonzero(data_gen):
assert_gpu_fallback_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr('floor(a, -1)'), 'RoundFloor')
Expand All @@ -607,13 +605,11 @@ def test_ceil(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr('ceil(a)'))

@datagen_overrides(seed=0, reason='https://github.com/NVIDIA/spark-rapids/issues/9846')
@pytest.mark.skipif(is_before_spark_330(), reason='scale parameter in Ceil function is not supported before Spark 3.3.0')
@pytest.mark.parametrize('data_gen', double_n_long_gens + _arith_decimal_gens_no_neg_scale, ids=idfn)
@pytest.mark.parametrize('data_gen', [long_gen] + _arith_decimal_gens_no_neg_scale, ids=idfn)
def test_ceil_scale_zero(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr('ceil(a, 0)'),
conf={'spark.rapids.sql.castFloatToDecimal.enabled':'true'})
lambda spark : unary_op_df(spark, data_gen).selectExpr('ceil(a, 0)'))

@pytest.mark.parametrize('data_gen', [_decimal_gen_36_neg5, _decimal_gen_38_neg10], ids=idfn)
def test_floor_ceil_overflow(data_gen):
Expand Down

0 comments on commit 0d433c8

Please sign in to comment.