Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support format_number #9281

Merged
merged 22 commits into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
62dc4a9
wip
thirtiseven Sep 8, 2023
dc4419c
Merge branch 'NVIDIA:branch-23.10' into format_number
thirtiseven Sep 8, 2023
59af888
wip
thirtiseven Sep 12, 2023
0ca68db
Merge branch 'NVIDIA:branch-23.10' into format_number
thirtiseven Sep 13, 2023
e273c6a
support format_number for integral and decimal type
thirtiseven Sep 21, 2023
2f664a6
support double/float normal cases
thirtiseven Sep 22, 2023
11290b8
Merge branch 'NVIDIA:branch-23.10' into format_number
thirtiseven Sep 22, 2023
40f48c2
support scientific notation double/float with positive exp
thirtiseven Sep 22, 2023
4e7af76
support scientific notation double/float with negative exp
thirtiseven Sep 25, 2023
e60dfb9
bug fixed and clean up
thirtiseven Sep 25, 2023
8d0d6a4
refactor and memory leak fix
thirtiseven Sep 26, 2023
2f14e18
Handle resource pair as a whole
thirtiseven Sep 26, 2023
845984e
fix more memory leak
thirtiseven Sep 27, 2023
68a3b2f
address some comments
thirtiseven Sep 27, 2023
2708cf7
add a config to control float/double enabling
thirtiseven Sep 27, 2023
9c4eff8
fixed a bug in neg exp get parts
thirtiseven Sep 27, 2023
28d06ac
fixed another bug and add float scala test
thirtiseven Sep 27, 2023
ed12c40
add some comments and use lstrip to remove neg sign
thirtiseven Sep 27, 2023
0889332
fix memory leaks
thirtiseven Sep 28, 2023
2fb9430
minor changes
thirtiseven Sep 28, 2023
f5d4000
fallback decimal with high scale
thirtiseven Sep 28, 2023
c3f1004
Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides…
thirtiseven Sep 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/additional-functionality/advanced_configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.Expm1"></a>spark.rapids.sql.expression.Expm1|`expm1`|Euler's number e raised to a power minus 1|true|None|
<a name="sql.expression.Flatten"></a>spark.rapids.sql.expression.Flatten|`flatten`|Creates a single array from an array of arrays|true|None|
<a name="sql.expression.Floor"></a>spark.rapids.sql.expression.Floor|`floor`|Floor of a number|true|None|
<a name="sql.expression.FormatNumber"></a>spark.rapids.sql.expression.FormatNumber|`format_number`|Formats the number x like '#,###,###.##', rounded to d decimal places.|true|None|
<a name="sql.expression.FromUTCTimestamp"></a>spark.rapids.sql.expression.FromUTCTimestamp|`from_utc_timestamp`|Render the input UTC timestamp in the input timezone|true|None|
<a name="sql.expression.FromUnixTime"></a>spark.rapids.sql.expression.FromUnixTime|`from_unixtime`|Get the string from a unix timestamp|true|None|
<a name="sql.expression.GetArrayItem"></a>spark.rapids.sql.expression.GetArrayItem| |Gets the field at `ordinal` in the Array|true|None|
Expand Down
4 changes: 4 additions & 0 deletions docs/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,10 @@ The GPU will use different precision than Java's toString method when converting
types to strings. The GPU uses a lowercase `e` prefix for an exponent while Spark uses uppercase
`E`. As a result the computed string can differ from the default behavior in Spark.

The `format_number` function will retain 10 digits of precision for the GPU when the input is a floating
revans2 marked this conversation as resolved.
Show resolved Hide resolved
point number, but Spark will retain up to 17 digits of precision, i.e. `format_number(1234567890.1234567890, 5)`
will return `1,234,567,890.00000` on the GPU and `1,234,567,890.12346` on the CPU.

Starting from 22.06 this conf is enabled by default, to disable this operation on the GPU, set
[`spark.rapids.sql.castFloatToString.enabled`](configs.md#sql.castFloatToString.enabled) to `false`.

Expand Down
150 changes: 109 additions & 41 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -6461,23 +6461,23 @@ are limited.
<td> </td>
</tr>
<tr>
<td rowSpan="3">FromUTCTimestamp</td>
<td rowSpan="3">`from_utc_timestamp`</td>
<td rowSpan="3">Render the input UTC timestamp in the input timezone</td>
<td rowSpan="3">FormatNumber</td>
<td rowSpan="3">`format_number`</td>
<td rowSpan="3">Formats the number x like '#,###,###.##', rounded to d decimal places.</td>
<td rowSpan="3">None</td>
<td rowSpan="3">project</td>
<td>timestamp</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td>x</td>
<td> </td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td> </td>
<td> </td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
Expand All @@ -6487,17 +6487,17 @@ are limited.
<td> </td>
</tr>
<tr>
<td>timezone</td>
<td> </td>
<td>d</td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>Literal value only</em></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>Only timezones equivalent to UTC are supported</em></td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
Expand All @@ -6517,8 +6517,8 @@ are limited.
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td> </td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
Expand Down Expand Up @@ -6555,6 +6555,74 @@ are limited.
<th>UDT</th>
</tr>
<tr>
<td rowSpan="3">FromUTCTimestamp</td>
<td rowSpan="3">`from_utc_timestamp`</td>
<td rowSpan="3">Render the input UTC timestamp in the input timezone</td>
<td rowSpan="3">None</td>
<td rowSpan="3">project</td>
<td>timestamp</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>timezone</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>Only timezones equivalent to UTC are supported</em></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td rowSpan="3">FromUnixTime</td>
<td rowSpan="3">`from_unixtime`</td>
<td rowSpan="3">Get the string from a unix timestamp</td>
Expand Down Expand Up @@ -6874,6 +6942,32 @@ are limited.
<td><b>NS</b></td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="2">GetStructField</td>
<td rowSpan="2"> </td>
<td rowSpan="2">Gets the named field of the struct</td>
Expand Down Expand Up @@ -6921,32 +7015,6 @@ are limited.
<td><b>NS</b></td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="3">GetTimestamp</td>
<td rowSpan="3"> </td>
<td rowSpan="3">Gets timestamps from strings using given pattern.</td>
Expand Down
28 changes: 28 additions & 0 deletions integration_tests/src/main/python/string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,3 +797,31 @@ def test_conv_dec_to_from_hex(from_base, to_base, pattern):
lambda spark: unary_op_df(spark, gen).select('a', f.conv(f.col('a'), from_base, to_base)),
conf={'spark.rapids.sql.expression.Conv': True}
)

format_number_gens = integral_gens + [DecimalGen(precision=7, scale=7), DecimalGen(precision=18, scale=0),
DecimalGen(precision=18, scale=3), DecimalGen(precision=36, scale=5),
DecimalGen(precision=36, scale=-5), DecimalGen(precision=38, scale=10),
DecimalGen(precision=38, scale=-10)]

@pytest.mark.parametrize('data_gen', format_number_gens, ids=idfn)
def test_format_number_supported(data_gen):
gen = data_gen
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, gen).selectExpr(
'format_number(a, -2)',
'format_number(a, 0)',
'format_number(a, 1)',
'format_number(a, 5)',
'format_number(a, 10)',
'format_number(a, 100)')
)

format_number_float_gens = [DoubleGen(min_exp=-300, max_exp=-32), DoubleGen(min_exp=-13, max_exp=15)]

@pytest.mark.parametrize('data_gen', format_number_float_gens, ids=idfn)
def test_format_number_float(data_gen):
gen = data_gen
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, gen).selectExpr(
'format_number(a, 5)')
)
Original file line number Diff line number Diff line change
Expand Up @@ -3086,6 +3086,16 @@ object GpuOverrides extends Logging {
|For instance decimal strings not longer than 18 characters / hexadecimal strings
|not longer than 15 characters disregarding the sign cannot cause an overflow.
""".stripMargin.replaceAll("\n", " ")),
expr[FormatNumber](
"Formats the number x like '#,###,###.##', rounded to d decimal places.",
ExprChecks.binaryProject(TypeSig.STRING, TypeSig.STRING,
("x", TypeSig.gpuNumeric, TypeSig.cpuNumeric),
firestarman marked this conversation as resolved.
Show resolved Hide resolved
("d", TypeSig.lit(TypeEnum.INT), TypeSig.INT+TypeSig.STRING)),
(in, conf, p, r) => new BinaryExprMeta[FormatNumber](in, conf, p, r) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuFormatNumber(lhs, rhs)
}
),
expr[MapConcat](
"Returns the union of all the given maps",
ExprChecks.projectOnly(TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 +
Expand Down
Loading