From f9011769bc31653f5ce7c4b1485a241f0e77c393 Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Thu, 5 Oct 2023 13:16:37 -0500 Subject: [PATCH] Put back in full decimal support for format_number (#9351) Signed-off-by: Robert (Bobby) Evans --- integration_tests/src/main/python/string_test.py | 16 ++++------------ .../com/nvidia/spark/rapids/GpuOverrides.scala | 13 ++----------- 2 files changed, 6 insertions(+), 23 deletions(-) diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py index 316a427db94..23800edfbca 100644 --- a/integration_tests/src/main/python/string_test.py +++ b/integration_tests/src/main/python/string_test.py @@ -803,7 +803,8 @@ def test_conv_dec_to_from_hex(from_base, to_base, pattern): DecimalGen(precision=36, scale=-5), DecimalGen(precision=38, scale=10), DecimalGen(precision=38, scale=-10), DecimalGen(precision=38, scale=30, special_cases=[Decimal('0.000125')]), - DecimalGen(precision=38, scale=32, special_cases=[Decimal('0.000125')])] + DecimalGen(precision=38, scale=32, special_cases=[Decimal('0.000125')]), + DecimalGen(precision=38, scale=37, special_cases=[Decimal('0.000125')])] @pytest.mark.parametrize('data_gen', format_number_gens, ids=idfn) def test_format_number_supported(data_gen): @@ -815,6 +816,8 @@ def test_format_number_supported(data_gen): 'format_number(a, 1)', 'format_number(a, 5)', 'format_number(a, 10)', + 'format_number(a, 13)', + 'format_number(a, 30)', 'format_number(a, 100)') ) @@ -840,14 +843,3 @@ def test_format_number_float_fallback(data_gen): 'format_number(a, 5)'), 'FormatNumber' ) - -# fallback due to https://github.com/NVIDIA/spark-rapids/issues/9309 -@allow_non_gpu('ProjectExec') -@pytest.mark.parametrize('data_gen', [float_gen, double_gen], ids=idfn) -def test_format_number_decimal_big_scale_fallback(data_gen): - data_gen = DecimalGen(precision=38, scale=37) - assert_gpu_fallback_collect( - lambda spark: unary_op_df(spark, data_gen).selectExpr( - 'format_number(a, 5)'), - 'FormatNumber' - ) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index a512a837656..91f873e3f28 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -3093,20 +3093,11 @@ object GpuOverrides extends Logging { (in, conf, p, r) => new BinaryExprMeta[FormatNumber](in, conf, p, r) { override def tagExprForGpu(): Unit = { in.children.head.dataType match { - case _: FloatType | DoubleType => { - if (!conf.isFloatFormatNumberEnabled) { - willNotWorkOnGpu("format_number with floating point types on the GPU returns " + + case FloatType | DoubleType if !conf.isFloatFormatNumberEnabled => + willNotWorkOnGpu("format_number with floating point types on the GPU returns " + "results that have a different precision than the default results of Spark. " + "To enable this operation on the GPU, set" + s" ${RapidsConf.ENABLE_FLOAT_FORMAT_NUMBER} to true.") - } - } - case dt: DecimalType => { - if (dt.scale > 32) { - willNotWorkOnGpu("format_number will generate results mismatched from Spark " + - "when the scale is larger than 32.") - } - } case _ => } }