diff --git a/docs/additional-functionality/advanced_configs.md b/docs/additional-functionality/advanced_configs.md
index 0c4defb32ac..ac073b374fe 100644
--- a/docs/additional-functionality/advanced_configs.md
+++ b/docs/additional-functionality/advanced_configs.md
@@ -110,7 +110,7 @@ Name | Description | Default Value | Applicable at
spark.rapids.sql.format.parquet.reader.type|Sets the Parquet reader type. We support different types that are optimized for different environments. The original Spark style reader can be selected by setting this to PERFILE which individually reads and copies files to the GPU. Loading many small files individually has high overhead, and using either COALESCING or MULTITHREADED is recommended instead. The COALESCING reader is good when using a local file system where the executors are on the same nodes or close to the nodes the data is being read on. This reader coalesces all the files assigned to a task into a single host buffer before sending it down to the GPU. It copies blocks from a single file into a host buffer in separate threads in parallel, see spark.rapids.sql.multiThreadedRead.numThreads. MULTITHREADED is good for cloud environments where you are reading from a blobstore that is totally separate and likely has a higher I/O read cost. Many times the cloud environments also get better throughput when you have multiple readers in parallel. This reader uses multiple threads to read each file in parallel and each file is sent to the GPU separately. This allows the CPU to keep reading while GPU is also doing work. See spark.rapids.sql.multiThreadedRead.numThreads and spark.rapids.sql.format.parquet.multiThreadedRead.maxNumFilesParallel to control the number of threads and amount of memory used. By default this is set to AUTO so we select the reader we think is best. This will either be the COALESCING or the MULTITHREADED based on whether we think the file is in the cloud. See spark.rapids.cloudSchemes.|AUTO|Runtime
spark.rapids.sql.format.parquet.write.enabled|When set to false disables parquet output acceleration|true|Runtime
spark.rapids.sql.format.parquet.writer.int96.enabled|When set to false, disables accelerated parquet write if the spark.sql.parquet.outputTimestampType is set to INT96|true|Runtime
-spark.rapids.sql.formatNumberFloat.enabled|format_number with floating point types on the GPU returns results that have a different precision than the default results of Spark.|false|Runtime
+spark.rapids.sql.formatNumberFloat.enabled|format_number with floating point types on the GPU returns results that have a different precision than the default results of Spark.|true|Runtime
spark.rapids.sql.hasExtendedYearValues|Spark 3.2.0+ extended parsing of years in dates and timestamps to support the full range of possible values. Prior to this it was limited to a positive 4 digit year. The Accelerator does not support the extended range yet. This config indicates if your data includes this extended range or not, or if you don't care about getting the correct values on values with the extended range.|true|Runtime
spark.rapids.sql.hashOptimizeSort.enabled|Whether sorts should be inserted after some hashed operations to improve output ordering. This can improve output file sizes when saving to columnar formats.|false|Runtime
spark.rapids.sql.improvedFloatOps.enabled|For some floating point operations spark uses one way to compute the value and the underlying cudf implementation can use an improved algorithm. In some cases this can result in cudf producing an answer when spark overflows.|true|Runtime
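As a usage sketch for the confs documented above (the application script name and the choice of reader type here are placeholders, not values taken from this change), each entry in the table can be set at submit time like any other Spark conf:

```shell
# Hypothetical submit command; only the --conf keys come from the table above.
spark-submit \
  --conf spark.plugins=com.nvidia.spark.SQLPlugin \
  --conf spark.rapids.sql.formatNumberFloat.enabled=false \
  --conf spark.rapids.sql.format.parquet.reader.type=MULTITHREADED \
  my_job.py
```

Confs marked "Runtime" in the table can also be changed on an active session via `spark.conf.set(...)`.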
diff --git a/docs/compatibility.md b/docs/compatibility.md
index 2a950f9069e..11792b8a2f3 100644
--- a/docs/compatibility.md
+++ b/docs/compatibility.md
@@ -713,13 +713,13 @@ to `false`.
The Rapids Accelerator for Apache Spark uses a method based on [ryu](https://github.com/ulfjack/ryu) when converting floating point data types to strings. As a result, the computed string can differ from the output of Spark in some cases: sometimes the output is shorter (which is arguably more accurate) and sometimes the output may differ in the precise digits output.
-The `format_number` function will retain 10 digits of precision for the GPU when the input is a floating
-point number, but Spark will retain up to 17 digits of precision, i.e. `format_number(1234567890.1234567890, 5)`
-will return `1,234,567,890.00000` on the GPU and `1,234,567,890.12346` on the CPU. To enable this on the GPU, set [`spark.rapids.sql.formatNumberFloat.enabled`](additional-functionality/advanced_configs.md#sql.formatNumberFloat.enabled) to `true`.
-
This configuration is enabled by default. To disable this operation on the GPU, set
[`spark.rapids.sql.castFloatToString.enabled`](additional-functionality/advanced_configs.md#sql.castFloatToString.enabled) to `false`.
+The `format_number` function also uses [ryu](https://github.com/ulfjack/ryu) when formatting floating-point
+values as strings, so its results may differ from Spark's in the same way. To disable this on the GPU, set
+[`spark.rapids.sql.formatNumberFloat.enabled`](additional-functionality/advanced_configs.md#sql.formatNumberFloat.enabled) to `false`.
+
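For intuition on why a shortest-round-trip formatter can disagree with a fixed-precision one, a minimal sketch in plain Python (whose `repr` also produces shortest round-trip strings, similar in spirit to ryu; the 17-significant-digit rendering below is only an analogy for a fixed-precision CPU path, not Spark's exact algorithm):

```python
# repr() emits the shortest string that round-trips to the same double,
# much like ryu; a fixed 17-significant-digit rendering can print trailing
# noise digits instead. Both strings denote the very same double.
shortest = repr(0.1)            # shortest round-trip form
fixed17 = format(0.1, ".17g")   # fixed 17 significant digits

print(shortest)  # 0.1
print(fixed17)   # 0.10000000000000001
assert float(shortest) == float(fixed17) == 0.1
```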
### String to Float
Casting from string to floating-point types on the GPU returns incorrect results when the string
diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py
index 1e871e85bd5..da4777e803f 100644
--- a/integration_tests/src/main/python/string_test.py
+++ b/integration_tests/src/main/python/string_test.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2023, NVIDIA CORPORATION.
+# Copyright (c) 2020-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -822,25 +822,35 @@ def test_format_number_supported(data_gen):
'format_number(a, 100)')
)
-float_format_number_conf = {'spark.rapids.sql.formatNumberFloat.enabled': 'true'}
-format_number_float_gens = [DoubleGen(min_exp=-300, max_exp=15)]
+format_float_special_vals = [float('nan'), float('inf'), float('-inf'), 0.0, -0.0,
+ 1.1234543, 0.0000152, 0.0000252, 0.999999, 999990.0,
+ 0.001234, 0.00000078, 7654321.1234567]
-@pytest.mark.parametrize('data_gen', format_number_float_gens, ids=idfn)
-def test_format_number_float_limited(data_gen):
+@pytest.mark.parametrize('data_gen', [SetValuesGen(FloatType(), format_float_special_vals),
+ SetValuesGen(DoubleType(), format_float_special_vals)], ids=idfn)
+def test_format_number_float_special(data_gen):
gen = data_gen
- assert_gpu_and_cpu_are_equal_collect(
- lambda spark: unary_op_df(spark, gen).selectExpr(
- 'format_number(a, 5)'),
- conf = float_format_number_conf
- )
-
-# format_number for float/double is disabled by default due to compatibility issue
-# GPU will generate result with less precision than CPU
-@allow_non_gpu('ProjectExec')
-@pytest.mark.parametrize('data_gen', [float_gen, double_gen], ids=idfn)
-def test_format_number_float_fallback(data_gen):
- assert_gpu_fallback_collect(
- lambda spark: unary_op_df(spark, data_gen).selectExpr(
- 'format_number(a, 5)'),
- 'FormatNumber'
- )
+ cpu_results = with_cpu_session(lambda spark: unary_op_df(spark, gen).selectExpr(
+ 'format_number(a, 5)').collect())
+ gpu_results = with_gpu_session(lambda spark: unary_op_df(spark, gen).selectExpr(
+ 'format_number(a, 5)').collect())
+ for cpu, gpu in zip(cpu_results, gpu_results):
+ assert cpu[0] == gpu[0]
+
+def test_format_number_double_value():
+ data_gen = DoubleGen(nullable=False, no_nans=True)
+ cpu_results = list(map(lambda x: float(x[0].replace(",", "")), with_cpu_session(
+ lambda spark: unary_op_df(spark, data_gen).selectExpr('format_number(a, 5)').collect())))
+ gpu_results = list(map(lambda x: float(x[0].replace(",", "")), with_gpu_session(
+ lambda spark: unary_op_df(spark, data_gen).selectExpr('format_number(a, 5)').collect())))
+ for cpu, gpu in zip(cpu_results, gpu_results):
+ assert math.isclose(cpu, gpu, abs_tol=1.1e-5)
+
+def test_format_number_float_value():
+ data_gen = FloatGen(nullable=False, no_nans=True)
+ cpu_results = list(map(lambda x: float(x[0].replace(",", "")), with_cpu_session(
+ lambda spark: unary_op_df(spark, data_gen).selectExpr('format_number(a, 5)').collect())))
+ gpu_results = list(map(lambda x: float(x[0].replace(",", "")), with_gpu_session(
+ lambda spark: unary_op_df(spark, data_gen).selectExpr('format_number(a, 5)').collect())))
+ for cpu, gpu in zip(cpu_results, gpu_results):
+ assert math.isclose(cpu, gpu, rel_tol=1e-7) or math.isclose(cpu, gpu, abs_tol=1.1e-5)
\ No newline at end of file
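The tolerance choices in the tests above can be sanity-checked in isolation: `abs_tol=1.1e-5` admits a one-unit difference in the fifth decimal place (the last digit `format_number(a, 5)` prints), while the `rel_tol` branch covers large magnitudes where an absolute bound is too strict:

```python
import math

# format_number(a, 5) keeps 5 decimal places, so CPU and GPU may legitimately
# disagree by one unit in the last printed place: an absolute gap of 1e-5.
assert math.isclose(1.00001, 1.00002, abs_tol=1.1e-5)

# For large values the same absolute bound rejects a harmless last-place
# difference, but a relative check accepts it.
assert not math.isclose(1234567.1, 1234567.2, abs_tol=1.1e-5)
assert math.isclose(1234567.1, 1234567.2, rel_tol=1e-7)
```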
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
index 41b3b920ad1..c260e517e9e 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -761,7 +761,7 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern")
.doc("format_number with floating point types on the GPU returns results that have " +
"a different precision than the default results of Spark.")
.booleanConf
- .createWithDefault(false)
+ .createWithDefault(true)
val ENABLE_CAST_FLOAT_TO_INTEGRAL_TYPES =
conf("spark.rapids.sql.castFloatToIntegralTypes.enabled")
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
index febbf75ba58..d10aa40ce30 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
@@ -17,8 +17,7 @@
package org.apache.spark.sql.rapids
import java.nio.charset.Charset
-import java.text.DecimalFormatSymbols
-import java.util.{Locale, Optional}
+import java.util.Optional
import scala.collection.mutable.ArrayBuffer
@@ -2095,366 +2094,6 @@ case class GpuFormatNumber(x: Expression, d: Expression)
}
}
- private def getZeroCv(size: Int): ColumnVector = {
- withResource(Scalar.fromString("0")) { zero =>
- ColumnVector.fromScalar(zero, size)
- }
- }
-
- private def handleDoublePosExp(cv: ColumnVector, intPart: ColumnVector, decPart: ColumnVector,
- exp: ColumnVector, d: Int): (ColumnVector, ColumnVector) = {
- // handle cases when exp is positive
- // append "0" * zerosNum after end of strings, zerosNum = exp - decLen + d
- val expSubDecLen = withResource(decPart.getCharLengths) { decLen =>
- exp.sub(decLen)
- }
- val zerosNum = withResource(expSubDecLen) { _ =>
- withResource(Scalar.fromInt(d)) { dScalar =>
- expSubDecLen.add(dScalar)
- }
- }
- val zeroCv = withResource(Scalar.fromString("0")) { zero =>
- ColumnVector.fromScalar(zero, cv.getRowCount.toInt)
- }
- val zeros = withResource(zerosNum) { _ =>
- withResource(zeroCv) { _ =>
- zeroCv.repeatStrings(zerosNum)
- }
- }
-
- val intAndDecParts = withResource(zeros) { _ =>
- ColumnVector.stringConcatenate(Array(intPart, decPart, zeros))
- }
- // split intAndDecParts to intPart and decPart with substrings, start = len(intAndDecParts) - d
- closeOnExcept(ArrayBuffer.empty[ColumnVector]) { resourceArray =>
- val (intPartPosExp, decPartPosExpTemp) = withResource(intAndDecParts) { _ =>
- val (start, end) = withResource(intAndDecParts.getCharLengths) { partsLength =>
- (withResource(Scalar.fromInt(d)) { d =>
- partsLength.sub(d)
- }, partsLength.incRefCount())
- }
- withResource(start) { _ =>
- withResource(end) { _ =>
- val zeroIntCv = withResource(Scalar.fromInt(0)) { zero =>
- ColumnVector.fromScalar(zero, cv.getRowCount.toInt)
- }
- val intPart = withResource(zeroIntCv) { _ =>
- intAndDecParts.substring(zeroIntCv, start)
- }
- val decPart = closeOnExcept(intPart) { _ =>
- intAndDecParts.substring(start, end)
- }
- (intPart, decPart)
- }
- }
- }
- resourceArray += intPartPosExp
- // if decLen - exp > d, convert to float/double, round, convert back to string
- // decLen's max value is 9, abs(expPart)'s min value is 7, so it is possible only when d < 2
- // because d is small, we can use double to do the rounding
- val decPartPosExp = if (0 < d && d < 2) {
- val pointCv = closeOnExcept(decPartPosExpTemp) { _ =>
- withResource(Scalar.fromString(".")) { point =>
- ColumnVector.fromScalar(point, cv.getRowCount.toInt)
- }
- }
- val withPoint = withResource(decPartPosExpTemp) { _ =>
- withResource(pointCv) { pointCv =>
- ColumnVector.stringConcatenate(Array(pointCv, decPartPosExpTemp))
- }
- }
- val decimalTypeRounding = DType.create(DType.DTypeEnum.DECIMAL128, -9)
- val withPointDecimal = withResource(withPoint) { _ =>
- withResource(withPoint.castTo(decimalTypeRounding)) { decimal =>
- decimal.round(d, RoundMode.HALF_EVEN)
- }
- }
- val roundedString = withResource(withPointDecimal) { _ =>
- withPointDecimal.castTo(DType.STRING)
- }
- withResource(roundedString) { _ =>
- withResource(roundedString.stringSplit(".", 2)) { splited =>
- splited.getColumn(1).incRefCount()
- }
- }
- } else {
- decPartPosExpTemp
- }
- (intPartPosExp, decPartPosExp)
- }
- }
-
- private def handleDoubleNegExp(cv: ColumnVector, intPart: ColumnVector, decPart: ColumnVector,
- exp: ColumnVector, d: Int): (ColumnVector, ColumnVector) = {
- // handle cases when exp is negative
- // "0." + (- exp - 1) * "0" + intPart + decPart
- // if -1 - d <= exp and decLen - exp > d, need to rounding
- val cond1 = withResource(Scalar.fromInt(-1 - d)) { negOneSubD =>
- exp.greaterOrEqualTo(negOneSubD)
- }
- val cond2 = closeOnExcept(cond1) { _ =>
- val decLenSubExp = withResource(decPart.getCharLengths) { decLen =>
- decLen.sub(exp)
- }
- withResource(decLenSubExp) { _ =>
- withResource(Scalar.fromInt(d)) { d =>
- decLenSubExp.greaterThan(d)
- }
- }
- }
- val needRounding = withResource(cond1) { _ =>
- withResource(cond2) { _ =>
- cond1.and(cond2)
- }
- }
- val anyNeedRounding = withResource(needRounding) { _ =>
- withResource(needRounding.any()) { any =>
- any.isValid && any.getBoolean
- }
- }
- anyNeedRounding match {
- case false =>
- // a shortcut when no need to rounding
- // "0." + (- exp - 1) * "0" + intPart + decPart
- withResource(getZeroCv(cv.getRowCount.toInt)) { zeroCv =>
- val expSubOne = withResource(Scalar.fromInt(-1)) { negOne =>
- negOne.sub(exp)
- }
- val addingZeros = withResource(expSubOne) { _ =>
- zeroCv.repeatStrings(expSubOne)
- }
- val decPartNegExp = withResource(addingZeros) { _ =>
- ColumnVector.stringConcatenate(Array(addingZeros, intPart, decPart))
- }
- val decPartNegSubstr = withResource(decPartNegExp) { _ =>
- decPartNegExp.substring(0, d)
- }
- (zeroCv.incRefCount(), decPartNegSubstr)
- }
- case true =>
- // if -exp <= d + 1 && -exp + decLen + 1 > d, need to rounding
- // dec will be round to (d + exp + 1) digits
- val dExpOne = withResource(Scalar.fromInt(d + 1)) { dExpOne =>
- exp.add(dExpOne)
- }
- // To do a dataframe operation, add some zeros before
- // (intPat + decPart) and round them to 10
- // zerosNumRounding = (10 - (d + exp + 1)) . max(0)
- val tenSubDExpOne = withResource(dExpOne) { _ =>
- withResource(Scalar.fromInt(10)) { ten =>
- ten.sub(dExpOne)
- }
- }
- val zerosNumRounding = withResource(tenSubDExpOne) { _ =>
- withResource(Scalar.fromInt(0)) { zero =>
- withResource(tenSubDExpOne.lessThan(zero)) { lessThanZero =>
- lessThanZero.ifElse(zero, tenSubDExpOne)
- }
- }
- }
- val leadingZeros = withResource(zerosNumRounding) { _ =>
- withResource(getZeroCv(cv.getRowCount.toInt)) { zeroCv =>
- zeroCv.repeatStrings(zerosNumRounding)
- }
- }
- val numberToRoundStr = withResource(leadingZeros) { _ =>
- val zeroPointCv = withResource(Scalar.fromString("0.")) { point =>
- ColumnVector.fromScalar(point, cv.getRowCount.toInt)
- }
- withResource(zeroPointCv) { _ =>
- ColumnVector.stringConcatenate(Array(zeroPointCv, leadingZeros, intPart, decPart))
- }
- }
- // use a decimal type to round, set scale to -20 to keep all digits
- val decimalTypeRounding = DType.create(DType.DTypeEnum.DECIMAL128, -20)
- val numberToRound = withResource(numberToRoundStr) { _ =>
- numberToRoundStr.castTo(decimalTypeRounding)
- }
- // rounding 10 digits
- val rounded = withResource(numberToRound) { _ =>
- numberToRound.round(10, RoundMode.HALF_EVEN)
- }
- val roundedStr = withResource(rounded) { _ =>
- rounded.castTo(DType.STRING)
- }
- // substr 2 to remove "0."
- val roundedDecPart = withResource(roundedStr) { _ =>
- roundedStr.substring(2)
- }
- val decPartStriped = withResource(roundedDecPart) { _ =>
- withResource(Scalar.fromString("0")) { zero =>
- roundedDecPart.lstrip(zero)
- }
- }
- val decPartNegExp = withResource(decPartStriped) { _ =>
- decPartStriped.pad(d, PadSide.LEFT, "0")
- }
- closeOnExcept(decPartNegExp) { _ =>
- (getZeroCv(cv.getRowCount.toInt), decPartNegExp)
- }
- }
- }
-
- private def normalDoubleSplit(cv: ColumnVector, d: Int): (ColumnVector, ColumnVector) = {
- val roundingScale = d.min(10) // cuDF will keep at most 9 digits after decimal point
- val roundedStr = withResource(cv.round(roundingScale, RoundMode.HALF_EVEN)) { rounded =>
- rounded.castTo(DType.STRING)
- }
- val (intPart, decPart) = withResource(roundedStr) { _ =>
- withResource(roundedStr.stringSplit(".", 2)) { intAndDec =>
- (intAndDec.getColumn(0).incRefCount(), intAndDec.getColumn(1).incRefCount())
- }
- }
- val intPartNoNeg = closeOnExcept(decPart) { _ =>
- withResource(intPart) { _ =>
- removeNegSign(intPart)
- }
- }
- val decPartPad = closeOnExcept(intPartNoNeg) { _ =>
- withResource(decPart) { _ =>
- decPart.pad(d, PadSide.RIGHT, "0")
- }
- }
- // a workaround for cuDF float to string, e.g. 12.3 => "12.30000019" instead of "12.3"
- val decPartSubstr = closeOnExcept(intPartNoNeg) { _ =>
- withResource(decPartPad) { _ =>
- decPartPad.substring(0, d)
- }
- }
- (intPartNoNeg, decPartSubstr)
- }
-
- private def expDoubleSplit(cv: ColumnVector, d: Int): (ColumnVector, ColumnVector) = {
- // handle special case: 1.234567e+7 or 1.234567e-6
- // get three parts first:
- val replaceDelimToE = withResource(Scalar.fromString("e")) { e =>
- withResource(Scalar.fromString(".")) { p =>
- cv.stringReplace(e, p)
- }
- }
- // get three parts: 1.234567e+7 -> 1, 234567, +7
- val (intPartSign, decPart, expPart) = withResource(replaceDelimToE) { _ =>
- withResource(replaceDelimToE.stringSplit(".", 3)) { intDecExp =>
- (intDecExp.getColumn(0).incRefCount(),
- intDecExp.getColumn(1).incRefCount(),
- intDecExp.getColumn(2).incRefCount())
- }
- }
- // sign will be handled later, use string-based solution instead abs to avoid overfolw
- val intPart = closeOnExcept(decPart) { _ =>
- closeOnExcept(expPart) { _ =>
- withResource(intPartSign) { _ =>
- removeNegSign(intPartSign)
- }
- }
- }
- val exp = closeOnExcept(decPart) { _ =>
- closeOnExcept(intPart) { _ =>
- withResource(expPart) { _ =>
- expPart.castTo(DType.INT32)
- }
- }
- }
- // handle positive and negative exp separately
- val (intPartPosExp, decPartPosExp) = closeOnExcept(intPart) { _ =>
- closeOnExcept(decPart) { _ =>
- closeOnExcept(exp) { _ =>
- handleDoublePosExp(cv, intPart, decPart, exp, d)
- }
- }
- }
- withResource(ArrayBuffer.empty[ColumnVector]) { resourceArray =>
- val (intPartNegExp, decPartNegExp) = withResource(intPart) { _ =>
- withResource(decPart) { _ =>
- closeOnExcept(exp) { _ =>
- handleDoubleNegExp(cv, intPart, decPart, exp, d)
- }
- }
- }
- resourceArray += intPartNegExp
- resourceArray += decPartNegExp
- val expPos = withResource(exp) { _ =>
- withResource(Scalar.fromInt(0)) { zero =>
- exp.greaterOrEqualTo(zero)
- }
- }
- // combine results
- withResource(expPos) { _ =>
- val intPartExp = withResource(intPartPosExp) { _ =>
- expPos.ifElse(intPartPosExp, intPartNegExp)
- }
- val decPartExp = closeOnExcept(intPartExp) { _ =>
- withResource(decPartPosExp) { _ =>
- expPos.ifElse(decPartPosExp, decPartNegExp)
- }
- }
- (intPartExp, decPartExp)
- }
- }
- }
-
- private def getPartsFromDouble(cv: ColumnVector, d: Int): (ColumnVector, ColumnVector) = {
- // handle normal case: 1234.567
- closeOnExcept(ArrayBuffer.empty[ColumnVector]) { resourceArray =>
- val (normalInt, normalDec) = normalDoubleSplit(cv, d)
- resourceArray += normalInt
- resourceArray += normalDec
- // first check special case
- val cvStr = withResource(cv.castTo(DType.STRING)) { cvStr =>
- cvStr.incRefCount()
- }
- val containsE = closeOnExcept(cvStr) { _ =>
- withResource(Scalar.fromString("e")) { e =>
- cvStr.stringContains(e)
- }
- }
- withResource(containsE) { _ =>
- // if no special case, return normal case directly
- val anyExp = closeOnExcept(cvStr) { _ =>
- withResource(containsE.any()) { any =>
- any.isValid && any.getBoolean
- }
- }
- anyExp match {
- case false => {
- cvStr.safeClose()
- (normalInt, normalDec)
- }
- case true => {
- val noEReplaced = withResource(cvStr) { _ =>
- // replace normal case with 0e0 to avoid error
- withResource(Scalar.fromString("0.0e0")) { default =>
- containsE.ifElse(cvStr, default)
- }
- }
- // handle scientific notation case:
- val (expInt, expDec) = withResource(noEReplaced) { _ =>
- expDoubleSplit(noEReplaced, d)
- }
- // combine results
- // remove normalInt from resourceArray
- resourceArray.remove(0)
- val intPart = closeOnExcept(expDec) { _ =>
- withResource(expInt) { _ =>
- withResource(normalInt) { _ =>
- containsE.ifElse(expInt, normalInt)
- }
- }
- }
- resourceArray.clear()
- resourceArray += intPart
- val decPart = withResource(expDec) { _ =>
- withResource(normalDec) { _ =>
- containsE.ifElse(expDec, normalDec)
- }
- }
- (intPart, decPart)
- }
- }
- }
- }
- }
-
private def getPartsFromDecimal(cv: ColumnVector, d: Int, scale: Int):
(ColumnVector, ColumnVector) = {
// prevent d too large to fit in decimalType
@@ -2523,9 +2162,6 @@ case class GpuFormatNumber(x: Expression, d: Expression)
private def getParts(cv: ColumnVector, d: Int): (ColumnVector, ColumnVector) = {
// get int part and dec part from a column vector, int part will be set to positive
x.dataType match {
- case FloatType | DoubleType => {
- getPartsFromDouble(cv, d)
- }
case DecimalType.Fixed(_, scale) => {
getPartsFromDecimal(cv, d, scale)
}
@@ -2588,69 +2224,8 @@ case class GpuFormatNumber(x: Expression, d: Expression)
}
}
- private def handleInfAndNan(cv: ColumnVector, res: ColumnVector): ColumnVector = {
- // replace inf and nan with infSymbol and nanSymbol in res according to cv
- val symbols = DecimalFormatSymbols.getInstance(Locale.US)
- val nanSymbol = symbols.getNaN
- val infSymbol = symbols.getInfinity
- val negInfSymbol = "-" + infSymbol
- val handleNan = withResource(cv.isNan()) { isNan =>
- withResource(Scalar.fromString(nanSymbol)) { nan =>
- isNan.ifElse(nan, res)
- }
- }
- val isInf = closeOnExcept(handleNan) { _ =>
- x.dataType match {
- case DoubleType => {
- withResource(Scalar.fromDouble(Double.PositiveInfinity)) { inf =>
- cv.equalTo(inf)
- }
- }
- case FloatType => {
- withResource(Scalar.fromFloat(Float.PositiveInfinity)) { inf =>
- cv.equalTo(inf)
- }
- }
- }
- }
- val handleInf = withResource(isInf) { _ =>
- withResource(handleNan) { _ =>
- withResource(Scalar.fromString(infSymbol)) { inf =>
- isInf.ifElse(inf, handleNan)
- }
- }
- }
- val isNegInf = closeOnExcept(handleInf) { _ =>
- x.dataType match {
- case DoubleType => {
- withResource(Scalar.fromDouble(Double.NegativeInfinity)) { negInf =>
- cv.equalTo(negInf)
- }
- }
- case FloatType => {
- withResource(Scalar.fromFloat(Float.NegativeInfinity)) { negInf =>
- cv.equalTo(negInf)
- }
- }
- }
- }
- val handleNegInf = withResource(isNegInf) { _ =>
- withResource(Scalar.fromString(negInfSymbol)) { negInf =>
- withResource(handleInf) { _ =>
- isNegInf.ifElse(negInf, handleInf)
- }
- }
- }
- handleNegInf
- }
-
- override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = {
- // get int d from rhs
- if (!rhs.isValid || rhs.getValue.asInstanceOf[Int] < 0) {
- return GpuColumnVector.columnVectorFromNull(lhs.getRowCount.toInt, StringType)
- }
- val d = rhs.getValue.asInstanceOf[Int]
- val (integerPart, decimalPart) = getParts(lhs.getBase, d)
+ private def formatNumberNonKernel(cv: ColumnVector, d: Int): ColumnVector = {
+ val (integerPart, decimalPart) = getParts(cv, d)
// reverse integer part for adding commas
val resWithDecimalPart = withResource(decimalPart) { _ =>
val reversedIntegerPart = withResource(integerPart) { intPart =>
@@ -2682,13 +2257,13 @@ case class GpuFormatNumber(x: Expression, d: Expression)
}
// add negative sign back
val negCv = withResource(Scalar.fromString("-")) { negativeSign =>
- ColumnVector.fromScalar(negativeSign, lhs.getRowCount.toInt)
+ ColumnVector.fromScalar(negativeSign, cv.getRowCount.toInt)
}
val formated = withResource(resWithDecimalPart) { _ =>
val resWithNeg = withResource(negCv) { _ =>
ColumnVector.stringConcatenate(Array(negCv, resWithDecimalPart))
}
- withResource(negativeCheck(lhs.getBase)) { isNegative =>
+ withResource(negativeCheck(cv)) { isNegative =>
withResource(resWithNeg) { _ =>
isNegative.ifElse(resWithNeg, resWithDecimalPart)
}
@@ -2696,12 +2271,12 @@ case class GpuFormatNumber(x: Expression, d: Expression)
}
// handle null case
val anyNull = closeOnExcept(formated) { _ =>
- lhs.getBase.getNullCount > 0
+ cv.getNullCount > 0
}
val formatedWithNull = anyNull match {
case true => {
withResource(formated) { _ =>
- withResource(lhs.getBase.isNull) { isNull =>
+ withResource(cv.isNull) { isNull =>
withResource(Scalar.fromNull(DType.STRING)) { nullScalar =>
isNull.ifElse(nullScalar, formated)
}
@@ -2710,14 +2285,22 @@ case class GpuFormatNumber(x: Expression, d: Expression)
}
case false => formated
}
- // handle inf and nan
+ formatedWithNull
+ }
+
+ override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = {
+ // get int d from rhs
+ if (!rhs.isValid || rhs.getValue.asInstanceOf[Int] < 0) {
+ return GpuColumnVector.columnVectorFromNull(lhs.getRowCount.toInt, StringType)
+ }
+ val d = rhs.getValue.asInstanceOf[Int]
x.dataType match {
case FloatType | DoubleType => {
- withResource(formatedWithNull) { _ =>
- handleInfAndNan(lhs.getBase, formatedWithNull)
- }
+ CastStrings.fromFloatWithFormat(lhs.getBase, d)
+ }
+ case _ => {
+ formatNumberNonKernel(lhs.getBase, d)
}
- case _ => formatedWithNull
}
}
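The removed double path rounded with `RoundMode.HALF_EVEN` in several places, and the retained decimal path uses the same tie-breaking rule. As a quick reference for what half-even (banker's) rounding does at an exact tie, a sketch with Python's `decimal` module:

```python
from decimal import Decimal, ROUND_HALF_EVEN

# At an exact tie, half-even rounds to the nearest even last digit;
# non-tie cases round to nearest as usual.
q = Decimal("0.01")
assert Decimal("0.125").quantize(q, ROUND_HALF_EVEN) == Decimal("0.12")
assert Decimal("0.135").quantize(q, ROUND_HALF_EVEN) == Decimal("0.14")
assert Decimal("0.126").quantize(q, ROUND_HALF_EVEN) == Decimal("0.13")
```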
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/StringFunctionSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/StringFunctionSuite.scala
index 4f64839acfd..3c3933946c5 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/StringFunctionSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/StringFunctionSuite.scala
@@ -220,35 +220,6 @@ class RegExpUtilsSuite extends AnyFunSuite {
}
}
-class FormatNumberSuite extends SparkQueryCompareTestSuite {
- def testFormatNumberDf(session: SparkSession): DataFrame = {
- import session.sqlContext.implicits._
- Seq[java.lang.Float](
- -0.0f,
- 0.0f,
- Float.PositiveInfinity,
- Float.NegativeInfinity,
- Float.NaN,
- 1.0f,
- 1.2345f,
- 123456789.0f,
- 123456789.123456789f,
- 0.00123456789f,
- 0.0000000123456789f,
- 1.0000000123456789f
- ).toDF("doubles")
- }
-
- testSparkResultsAreEqual("Test format_number float",
- testFormatNumberDf,
- conf = new SparkConf().set("spark.rapids.sql.formatNumberFloat.enabled", "true")) {
- frame => frame.selectExpr("format_number(doubles, -1)",
- "format_number(doubles, 0)",
- "format_number(doubles, 1)",
- "format_number(doubles, 5)")
- }
-}
-
/*
* This isn't actually a test. It's just useful to help visualize what's going on when there are
* differences present.