diff --git a/integration_tests/src/main/python/orc_cast_test.py b/integration_tests/src/main/python/orc_cast_test.py index 48dbad54e51..c10d0b01570 100644 --- a/integration_tests/src/main/python/orc_cast_test.py +++ b/integration_tests/src/main/python/orc_cast_test.py @@ -102,8 +102,6 @@ def test_casting_from_float_and_double(spark_tmp_path, to_type): lambda spark: spark.read.schema(schema_str).orc(orc_path) ) - -@datagen_overrides(seed=0, reason='https://github.com/NVIDIA/spark-rapids/issues/10017') @pytest.mark.parametrize('data_gen', [DoubleGen(max_exp=32, special_cases=None), DoubleGen(max_exp=32, special_cases=[8.88e9, 9.99e10, 1.314e11])]) @allow_non_gpu(*non_utc_allow_orc_scan) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala index cabd9258d11..f638a7ba26a 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala @@ -363,17 +363,39 @@ object GpuOrcScan { // val doubleMillis = doubleValue * 1000, // val milliseconds = Math.round(doubleMillis) // if (noOverflow) { milliseconds } else { null } + + // java.lang.Math.round is a true half up, meaning rounding towards positive infinity + // even for negative numbers + // assert(Math.round(-1.5) = -1 + // assert(Math.round(1.5) = 2 + // + // libcudf, Spark implement it half up in a half away from zero fashion + // >> sql("SELECT ROUND(-1.5D, 0), ROUND(-0.5D, 0), ROUND(0.5D, 0)").show(truncate=False) + // +--------------+--------------+-------------+ + // |round(-1.5, 0)|round(-0.5, 0)|round(0.5, 0)| + // +--------------+--------------+-------------+ + // |-2.0 |-1.0 |1.0 | + // +--------------+--------------+-------------+ + // + // Math.round half up can be implemented in terms of floor + // Math.round(x) = n iff x is in [n-0.5, n+0.5) iff x+0.5 is in [n,n+1) iff floor(x+0.5) = n + // val milliseconds = withResource(Scalar.fromDouble(DateTimeConstants.MILLIS_PER_SECOND)) { thousand => // ORC assumes value is in seconds withResource(col.mul(thousand, DType.FLOAT64)) { doubleMillis => - withResource(doubleMillis.round()) { millis => - withResource(getOverflowFlags(doubleMillis, millis)) { overflowFlags => - millis.copyWithBooleanColumnAsValidity(overflowFlags) + withResource(Scalar.fromDouble(0.5)) { half => + withResource(doubleMillis.add(half)) { doubleMillisPlusHalf => + withResource(doubleMillisPlusHalf.floor()) { millis => + withResource(getOverflowFlags(doubleMillis, millis)) { overflowFlags => + millis.copyWithBooleanColumnAsValidity(overflowFlags) + } + } } } } } + // Cast milli-seconds to micro-seconds // We need to pay attention that when convert (milliSeconds * 1000) to INT64, there may be // INT64-overflow.