Skip to content

Commit

Permalink
Put back in full decimal support for format_number (#9351)
Browse files Browse the repository at this point in the history
Signed-off-by: Robert (Bobby) Evans <[email protected]>
  • Loading branch information
revans2 authored Oct 5, 2023
1 parent 7e899f4 commit f901176
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 23 deletions.
16 changes: 4 additions & 12 deletions integration_tests/src/main/python/string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,8 @@ def test_conv_dec_to_from_hex(from_base, to_base, pattern):
DecimalGen(precision=36, scale=-5), DecimalGen(precision=38, scale=10),
DecimalGen(precision=38, scale=-10),
DecimalGen(precision=38, scale=30, special_cases=[Decimal('0.000125')]),
DecimalGen(precision=38, scale=32, special_cases=[Decimal('0.000125')])]
DecimalGen(precision=38, scale=32, special_cases=[Decimal('0.000125')]),
DecimalGen(precision=38, scale=37, special_cases=[Decimal('0.000125')])]

@pytest.mark.parametrize('data_gen', format_number_gens, ids=idfn)
def test_format_number_supported(data_gen):
Expand All @@ -815,6 +816,8 @@ def test_format_number_supported(data_gen):
'format_number(a, 1)',
'format_number(a, 5)',
'format_number(a, 10)',
'format_number(a, 13)',
'format_number(a, 30)',
'format_number(a, 100)')
)

Expand All @@ -840,14 +843,3 @@ def test_format_number_float_fallback(data_gen):
'format_number(a, 5)'),
'FormatNumber'
)

# fallback due to https://github.com/NVIDIA/spark-rapids/issues/9309
# format_number on the GPU produces results mismatched from Spark when the
# decimal scale is larger than 32, so the plan must fall back to the CPU
# (ProjectExec is allowed on CPU for this reason).
#
# NOTE: the previous @pytest.mark.parametrize('data_gen', [float_gen, double_gen])
# was vestigial copy-paste from the float fallback test: the parameter was
# immediately shadowed by the DecimalGen below, so the identical test ran twice
# under misleading float/double ids. The parametrization is removed; the single
# decimal generator is the only input actually exercised.
@allow_non_gpu('ProjectExec')
def test_format_number_decimal_big_scale_fallback():
    # precision=38 / scale=37 exceeds the scale-32 GPU limit, forcing fallback
    data_gen = DecimalGen(precision=38, scale=37)
    assert_gpu_fallback_collect(
        lambda spark: unary_op_df(spark, data_gen).selectExpr(
            'format_number(a, 5)'),
        'FormatNumber'
    )
Original file line number Diff line number Diff line change
Expand Up @@ -3093,20 +3093,11 @@ object GpuOverrides extends Logging {
(in, conf, p, r) => new BinaryExprMeta[FormatNumber](in, conf, p, r) {
override def tagExprForGpu(): Unit = {
in.children.head.dataType match {
case _: FloatType | DoubleType => {
if (!conf.isFloatFormatNumberEnabled) {
willNotWorkOnGpu("format_number with floating point types on the GPU returns " +
case FloatType | DoubleType if !conf.isFloatFormatNumberEnabled =>
willNotWorkOnGpu("format_number with floating point types on the GPU returns " +
"results that have a different precision than the default results of Spark. " +
"To enable this operation on the GPU, set" +
s" ${RapidsConf.ENABLE_FLOAT_FORMAT_NUMBER} to true.")
}
}
case dt: DecimalType => {
if (dt.scale > 32) {
willNotWorkOnGpu("format_number will generate results mismatched from Spark " +
"when the scale is larger than 32.")
}
}
case _ =>
}
}
Expand Down

0 comments on commit f901176

Please sign in to comment.