Skip to content

Commit

Permalink
Implement Math.round using floor on GPU
Browse files Browse the repository at this point in the history
Fixes #10017

Spark and libcudf round are more like rould half away from zero

Reimplementing Math.round from Orc TimestampFromDouble conversion
using floor

Signed-off-by: Gera Shegalov <[email protected]>
  • Loading branch information
gerashegalov committed Dec 23, 2023
1 parent bb235c9 commit f91b7e9
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
2 changes: 0 additions & 2 deletions integration_tests/src/main/python/orc_cast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,6 @@ def test_casting_from_float_and_double(spark_tmp_path, to_type):
lambda spark: spark.read.schema(schema_str).orc(orc_path)
)


@datagen_overrides(seed=0, reason='https://github.com/NVIDIA/spark-rapids/issues/10017')
@pytest.mark.parametrize('data_gen', [DoubleGen(max_exp=32, special_cases=None),
DoubleGen(max_exp=32, special_cases=[8.88e9, 9.99e10, 1.314e11])])
@allow_non_gpu(*non_utc_allow_orc_scan)
Expand Down
28 changes: 25 additions & 3 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala
Original file line number Diff line number Diff line change
Expand Up @@ -363,17 +363,39 @@ object GpuOrcScan {
// val doubleMillis = doubleValue * 1000,
// val milliseconds = Math.round(doubleMillis)
// if (noOverflow) { milliseconds } else { null }

// java.lang.Math.round is a true half up, meaning rounding towards positive infinity
// even for negative numbers
// assert(Math.round(-1.5) = -1
// assert(Math.round(1.5) = 2
//
// libcudf, Spark implement it half up in a half away from zero fashion
// >> sql("SELECT ROUND(-1.5D, 0), ROUND(-0.5D, 0), ROUND(0.5D, 0)").show(truncate=False)
// +--------------+--------------+-------------+
// |round(-1.5, 0)|round(-0.5, 0)|round(0.5, 0)|
// +--------------+--------------+-------------+
// |-2.0 |-1.0 |1.0 |
// +--------------+--------------+-------------+
//
// Math.round half up can be implemented in terms of floor
// Math.round(x) = n iff x is in [n-0.5, n+0.5) iff x+0.5 is in [n,n+1) iff floor(x+0.5) = n
//
val milliseconds = withResource(Scalar.fromDouble(DateTimeConstants.MILLIS_PER_SECOND)) {
thousand =>
// ORC assumes value is in seconds
withResource(col.mul(thousand, DType.FLOAT64)) { doubleMillis =>
withResource(doubleMillis.round()) { millis =>
withResource(getOverflowFlags(doubleMillis, millis)) { overflowFlags =>
millis.copyWithBooleanColumnAsValidity(overflowFlags)
withResource(Scalar.fromDouble(0.5)) { half =>
withResource(doubleMillis.add(half)) { doubleMillisPlusHalf =>
withResource(doubleMillisPlusHalf.floor()) { millis =>
withResource(getOverflowFlags(doubleMillis, millis)) { overflowFlags =>
millis.copyWithBooleanColumnAsValidity(overflowFlags)
}
}
}
}
}
}

// Cast milli-seconds to micro-seconds
// We need to pay attention that when convert (milliSeconds * 1000) to INT64, there may be
// INT64-overflow.
Expand Down

0 comments on commit f91b7e9

Please sign in to comment.