From 212a0b1ad0aae9378352a92bdb951ecd7564d66f Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 14 Dec 2023 13:29:57 +0800 Subject: [PATCH 01/21] wip --- .../src/main/python/date_time_test.py | 23 +++++++++++-------- .../nvidia/spark/rapids/GpuOverrides.scala | 4 +++- .../rapids/shims/Spark320PlusShims.scala | 5 +++- .../rapids/shims/DayTimeIntervalShims.scala | 5 +++- 4 files changed, 24 insertions(+), 13 deletions(-) diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py index 8b0ff3e5c68..c65593dd204 100644 --- a/integration_tests/src/main/python/date_time_test.py +++ b/integration_tests/src/main/python/date_time_test.py @@ -31,7 +31,7 @@ vals = [(-584, 1563), (1943, 1101), (2693, 2167), (2729, 0), (44, 1534), (2635, 3319), (1885, -2828), (0, 2463), (932, 2286), (0, 0)] @pytest.mark.parametrize('data_gen', vals, ids=idfn) -@allow_non_gpu(*non_utc_allow) +# @allow_non_gpu(*non_utc_allow) def test_timesub(data_gen): days, seconds = data_gen assert_gpu_and_cpu_are_equal_collect( @@ -40,25 +40,28 @@ def test_timesub(data_gen): .selectExpr("a - (interval {} days {} seconds)".format(days, seconds))) @pytest.mark.parametrize('data_gen', vals, ids=idfn) -@allow_non_gpu(*non_utc_allow) +# @allow_non_gpu(*non_utc_allow) def test_timeadd(data_gen): days, seconds = data_gen assert_gpu_and_cpu_are_equal_collect( # We are starting at year 0005 to make sure we don't go before year 0001 # and beyond year 10000 while doing TimeAdd lambda spark: unary_op_df(spark, TimestampGen(start=datetime(5, 1, 1, tzinfo=timezone.utc), end=datetime(15, 1, 1, tzinfo=timezone.utc)), seed=1) - .selectExpr("a + (interval {} days {} seconds)".format(days, seconds))) + .selectExpr("a + (interval {} days {} seconds)".format(days, seconds)), + conf = {'spark.rapids.sql.nonUTC.enabled': True}) @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') -@allow_non_gpu(*non_utc_allow) +# @allow_non_gpu(*non_utc_allow) def test_timeadd_daytime_column(): gen_list = [ # timestamp column max year is 1000 ('t', TimestampGen(end=datetime(1000, 1, 1, tzinfo=timezone.utc))), # max days is 8000 year, so added result will not be out of range - ('d', DayTimeIntervalGen(min_value=timedelta(days=0), max_value=timedelta(days=8000 * 365)))] + ('d', DayTimeIntervalGen(min_value=timedelta(days=1000 * 365), max_value=timedelta(days=1250 * 365)))] assert_gpu_and_cpu_are_equal_collect( - lambda spark: gen_df(spark, gen_list).selectExpr("t + d", "t + INTERVAL '1 02:03:04' DAY TO SECOND")) + lambda spark: gen_df(spark, gen_list, length=200).selectExpr("t + d", "t", "d"), + # lambda spark: gen_df(spark, gen_list).selectExpr("t + d", "t + INTERVAL '1 02:03:04' DAY TO SECOND"), + conf = {'spark.rapids.sql.nonUTC.enabled': True}) @pytest.mark.skipif(is_before_spark_350(), reason='DayTimeInterval overflow check for seconds is not supported before Spark 3.5.0') def test_interval_seconds_overflow_exception(): @@ -68,7 +71,7 @@ def test_interval_seconds_overflow_exception(): error_message="IllegalArgumentException") @pytest.mark.parametrize('data_gen', vals, ids=idfn) -@allow_non_gpu(*non_utc_allow) +# @allow_non_gpu(*non_utc_allow) def test_timeadd_from_subquery(data_gen): def fun(spark): @@ -77,10 +80,10 @@ def fun(spark): spark.sql("select a, ((select max(a) from testTime) + interval 1 day) as datePlus from testTime").createOrReplaceTempView("testTime2") return spark.sql("select * from testTime2 where datePlus > current_timestamp") - assert_gpu_and_cpu_are_equal_collect(fun) + assert_gpu_and_cpu_are_equal_collect(fun, conf = {'spark.rapids.sql.nonUTC.enabled': True}) @pytest.mark.parametrize('data_gen', vals, ids=idfn) -@allow_non_gpu(*non_utc_allow) +# @allow_non_gpu(*non_utc_allow) def test_timesub_from_subquery(data_gen): def fun(spark): @@ -89,7 +92,7 @@ def fun(spark): spark.sql("select a, ((select min(a) from testTime) - interval 1 day) as dateMinus from testTime").createOrReplaceTempView("testTime2") return spark.sql("select * from testTime2 where dateMinus < current_timestamp") - assert_gpu_and_cpu_are_equal_collect(fun) + assert_gpu_and_cpu_are_equal_collect(fun, conf = {'spark.rapids.sql.nonUTC.enabled': True}) # Should specify `spark.sql.legacy.interval.enabled` to test `DateAddInterval` after Spark 3.2.0, # refer to https://issues.apache.org/jira/browse/SPARK-34896 diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 379134f68cd..81bdd65b144 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -1645,6 +1645,8 @@ object GpuOverrides extends Logging { .withPsNote(TypeEnum.CALENDAR, "month intervals are not supported"), TypeSig.CALENDAR)), (timeAdd, conf, p, r) => new BinaryExprMeta[TimeAdd](timeAdd, conf, p, r) { + override def isTimeZoneSupported = true + override def tagExprForGpu(): Unit = { GpuOverrides.extractLit(timeAdd.interval).foreach { lit => val intvl = lit.value.asInstanceOf[CalendarInterval] @@ -1655,7 +1657,7 @@ object GpuOverrides extends Logging { } override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = - GpuTimeAdd(lhs, rhs) + GpuTimeAdd(lhs, rhs, timeAdd.timeZoneId) }), expr[DateAddInterval]( "Adds interval to date", diff --git a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala index 78b495f0fcf..0beef391ad0 100644 --- a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala +++ b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala @@ -208,6 +208,9 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging { ("interval", TypeSig.lit(TypeEnum.DAYTIME) + TypeSig.lit(TypeEnum.CALENDAR), TypeSig.DAYTIME + TypeSig.CALENDAR)), (timeAdd, conf, p, r) => new BinaryExprMeta[TimeAdd](timeAdd, conf, p, r) { + + override def isTimeZoneSupported = true + override def tagExprForGpu(): Unit = { GpuOverrides.extractLit(timeAdd.interval).foreach { lit => lit.dataType match { @@ -222,7 +225,7 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging { } override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = - GpuTimeAdd(lhs, rhs) + GpuTimeAdd(lhs, rhs, timeAdd.timeZoneId) }), GpuOverrides.expr[SpecifiedWindowFrame]( "Specification of the width of the group (or \"frame\") of input rows " + diff --git a/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/DayTimeIntervalShims.scala b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/DayTimeIntervalShims.scala index 369deda4e5e..fc3fd0d65a9 100644 --- a/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/DayTimeIntervalShims.scala +++ b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/DayTimeIntervalShims.scala @@ -50,6 +50,9 @@ object DayTimeIntervalShims { .withPsNote(TypeEnum.CALENDAR, "month intervals are not supported"), TypeSig.DAYTIME + TypeSig.CALENDAR)), (timeAdd, conf, p, r) => new BinaryExprMeta[TimeAdd](timeAdd, conf, p, r) { + + override def isTimeZoneSupported = true + override def tagExprForGpu(): Unit = { GpuOverrides.extractLit(timeAdd.interval).foreach { lit => lit.dataType match { @@ -64,7 +67,7 @@ object DayTimeIntervalShims { } override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = - GpuTimeAdd(lhs, rhs) + GpuTimeAdd(lhs, rhs, timeAdd.timeZoneId) }), GpuOverrides.expr[Abs]( "Absolute value", From 4f520677561a1ab9e689d3019939344934a48f99 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Mon, 18 Dec 2023 16:31:39 +0800 Subject: [PATCH 02/21] Suport TimeAdd for non-UTC time zone Signed-off-by: Haoyang Li --- .../src/main/python/date_time_test.py | 23 +++-- .../sql/rapids/datetimeExpressions.scala | 63 ++----------- .../rapids/shims/datetimeExpressions.scala | 93 +++++++++++++++++-- .../rapids/shims/datetimeExpressions.scala | 39 +++++++- 4 files changed, 139 insertions(+), 79 deletions(-) diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py index c65593dd204..4c3c8b5e188 100644 --- a/integration_tests/src/main/python/date_time_test.py +++ b/integration_tests/src/main/python/date_time_test.py @@ -26,12 +26,13 @@ non_utc_tz_allow = ['ProjectExec'] if not is_utc() else [] # Others work in all supported time zones non_supported_tz_allow = ['ProjectExec'] if not is_supported_time_zone() else [] +non_supported_tz_allow_filter = ['ProjectExec', 'FilterExec'] if not is_supported_time_zone() else [] # We only support literal intervals for TimeSub vals = [(-584, 1563), (1943, 1101), (2693, 2167), (2729, 0), (44, 1534), (2635, 3319), (1885, -2828), (0, 2463), (932, 2286), (0, 0)] @pytest.mark.parametrize('data_gen', vals, ids=idfn) -# @allow_non_gpu(*non_utc_allow) +@allow_non_gpu(*non_supported_tz_allow) def test_timesub(data_gen): days, seconds = data_gen assert_gpu_and_cpu_are_equal_collect( @@ -40,7 +41,7 @@ def test_timesub(data_gen): .selectExpr("a - (interval {} days {} seconds)".format(days, seconds))) @pytest.mark.parametrize('data_gen', vals, ids=idfn) -# @allow_non_gpu(*non_utc_allow) +@allow_non_gpu(*non_supported_tz_allow) def test_timeadd(data_gen): days, seconds = data_gen assert_gpu_and_cpu_are_equal_collect( @@ -51,15 +52,15 @@ def test_timeadd(data_gen): conf = {'spark.rapids.sql.nonUTC.enabled': True}) @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') -# @allow_non_gpu(*non_utc_allow) +@allow_non_gpu(*non_supported_tz_allow) def test_timeadd_daytime_column(): gen_list = [ # timestamp column max year is 1000 ('t', TimestampGen(end=datetime(1000, 1, 1, tzinfo=timezone.utc))), # max days is 8000 year, so added result will not be out of range - ('d', DayTimeIntervalGen(min_value=timedelta(days=1000 * 365), max_value=timedelta(days=1250 * 365)))] + ('d', DayTimeIntervalGen(min_value=timedelta(days=1000 * 365), max_value=timedelta(days=1005 * 365)))] assert_gpu_and_cpu_are_equal_collect( - lambda spark: gen_df(spark, gen_list, length=200).selectExpr("t + d", "t", "d"), + lambda spark: gen_df(spark, gen_list, length=2048).selectExpr("t + d", "t", "d"), # lambda spark: gen_df(spark, gen_list).selectExpr("t + d", "t + INTERVAL '1 02:03:04' DAY TO SECOND"), conf = {'spark.rapids.sql.nonUTC.enabled': True}) @@ -71,7 +72,7 @@ def test_interval_seconds_overflow_exception(): error_message="IllegalArgumentException") @pytest.mark.parametrize('data_gen', vals, ids=idfn) -# @allow_non_gpu(*non_utc_allow) +@allow_non_gpu(*non_supported_tz_allow_filter) def test_timeadd_from_subquery(data_gen): def fun(spark): @@ -83,7 +84,7 @@ def fun(spark): assert_gpu_and_cpu_are_equal_collect(fun, conf = {'spark.rapids.sql.nonUTC.enabled': True}) @pytest.mark.parametrize('data_gen', vals, ids=idfn) -# @allow_non_gpu(*non_utc_allow) +@allow_non_gpu(*non_supported_tz_allow) def test_timesub_from_subquery(data_gen): def fun(spark): @@ -138,21 +139,19 @@ def test_datediff(data_gen): 'datediff(a, date(null))', 'datediff(a, \'2016-03-02\')')) -hms_fallback = ['ProjectExec'] if not is_supported_time_zone() else [] - -@allow_non_gpu(*hms_fallback) +@allow_non_gpu(*non_supported_tz_allow) def test_hour(): assert_gpu_and_cpu_are_equal_collect( lambda spark : unary_op_df(spark, timestamp_gen).selectExpr('hour(a)'), conf = {'spark.rapids.sql.nonUTC.enabled': True}) -@allow_non_gpu(*hms_fallback) +@allow_non_gpu(*non_supported_tz_allow) def test_minute(): assert_gpu_and_cpu_are_equal_collect( lambda spark : unary_op_df(spark, timestamp_gen).selectExpr('minute(a)'), conf = {'spark.rapids.sql.nonUTC.enabled': True}) -@allow_non_gpu(*hms_fallback) +@allow_non_gpu(*non_supported_tz_allow) def test_second(): assert_gpu_and_cpu_are_equal_collect( lambda spark : unary_op_df(spark, timestamp_gen).selectExpr('second(a)'), diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala index 0cf4213bd4e..a66ac53f9d4 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala @@ -140,15 +140,15 @@ case class GpuYear(child: Expression) extends GpuDateUnaryExpression { input.getBase.year() } -abstract class GpuTimeMath( - start: Expression, +case class GpuDateAddInterval(start: Expression, interval: Expression, - timeZoneId: Option[String] = None) - extends ShimBinaryExpression - with GpuExpression - with TimeZoneAwareExpression - with ExpectsInputTypes - with Serializable { + timeZoneId: Option[String] = None, + ansiEnabled: Boolean = SQLConf.get.ansiEnabled) + extends ShimBinaryExpression + with GpuExpression + with TimeZoneAwareExpression + with ExpectsInputTypes + with Serializable { def this(start: Expression, interval: Expression) = this(start, interval, None) @@ -157,61 +157,16 @@ abstract class GpuTimeMath( override def toString: String = s"$left - $right" override def sql: String = s"${left.sql} - ${right.sql}" - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType) - - override def dataType: DataType = TimestampType override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess val microSecondsInOneDay: Long = TimeUnit.DAYS.toMicros(1) - override def columnarEval(batch: ColumnarBatch): GpuColumnVector = { - withResourceIfAllowed(left.columnarEval(batch)) { lhs => - withResourceIfAllowed(right.columnarEvalAny(batch)) { rhs => - (lhs, rhs) match { - case (l, intvlS: GpuScalar) - if intvlS.dataType.isInstanceOf[CalendarIntervalType] => - // Scalar does not support 'CalendarInterval' now, so use - // the Scala value instead. - // Skip the null check because it wll be detected by the following calls. - val intvl = intvlS.getValue.asInstanceOf[CalendarInterval] - if (intvl.months != 0) { - throw new UnsupportedOperationException("Months aren't supported at the moment") - } - val usToSub = intvl.days * microSecondsInOneDay + intvl.microseconds - if (usToSub != 0) { - withResource(Scalar.fromLong(usToSub)) { us_s => - withResource(l.getBase.bitCastTo(DType.INT64)) { us => - withResource(intervalMath(us_s, us)) { longResult => - GpuColumnVector.from(longResult.castTo(DType.TIMESTAMP_MICROSECONDS), dataType) - } - } - } - } else { - l.incRefCount() - } - case _ => - throw new UnsupportedOperationException("only column and interval arguments " + - s"are supported, got left: ${lhs.getClass} right: ${rhs.getClass}") - } - } - } - } - - def intervalMath(us_s: Scalar, us: ColumnView): ColumnVector -} - -case class GpuDateAddInterval(start: Expression, - interval: Expression, - timeZoneId: Option[String] = None, - ansiEnabled: Boolean = SQLConf.get.ansiEnabled) - extends GpuTimeMath(start, interval, timeZoneId) { - override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = { copy(timeZoneId = Option(timeZoneId)) } - override def intervalMath(us_s: Scalar, us: ColumnView): ColumnVector = { + def intervalMath(us_s: Scalar, us: ColumnView): ColumnVector = { us.add(us_s) } diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala index 8a37bd63ca5..4d7db37b17e 100644 --- a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala +++ b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala @@ -21,21 +21,96 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids.shims -import ai.rapids.cudf.{ColumnVector, ColumnView, Scalar} +import java.util.concurrent.TimeUnit -import org.apache.spark.sql.catalyst.expressions.{Expression, TimeZoneAwareExpression} -import org.apache.spark.sql.rapids.GpuTimeMath +import ai.rapids.cudf.{BinaryOp, BinaryOperable, ColumnVector, ColumnView, DType, Scalar} +import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuScalar} +import com.nvidia.spark.rapids.Arm.{withResource, withResourceIfAllowed} +import com.nvidia.spark.rapids.GpuOverrides +import com.nvidia.spark.rapids.RapidsPluginImplicits._ +import com.nvidia.spark.rapids.jni.GpuTimeZoneDB +import com.nvidia.spark.rapids.shims.ShimBinaryExpression -case class GpuTimeAdd(start: Expression, - interval: Expression, - timeZoneId: Option[String] = None) - extends GpuTimeMath(start, interval, timeZoneId) { +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, TimeZoneAwareExpression} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.unsafe.types.CalendarInterval + +case class GpuTimeAdd( + start: Expression, + interval: Expression, + timeZoneId: Option[String] = None) + extends ShimBinaryExpression + with GpuExpression + with TimeZoneAwareExpression + with ExpectsInputTypes + with Serializable { + + def this(start: Expression, interval: Expression) = this(start, interval, None) + + override def left: Expression = start + override def right: Expression = interval + + override def toString: String = s"$left - $right" + override def sql: String = s"${left.sql} - ${right.sql}" + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType) + + override def dataType: DataType = TimestampType + + override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess + + val microSecondsInOneDay: Long = TimeUnit.DAYS.toMicros(1) override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = { copy(timeZoneId = Option(timeZoneId)) } - override def intervalMath(us_s: Scalar, us: ColumnView): ColumnVector = { - us.add(us_s) + override def columnarEval(batch: ColumnarBatch): GpuColumnVector = { + withResourceIfAllowed(left.columnarEval(batch)) { lhs => + withResourceIfAllowed(right.columnarEvalAny(batch)) { rhs => + // lhs is start, rhs is interval + (lhs, rhs) match { + case (l, intvlS: GpuScalar) if intvlS.dataType.isInstanceOf[CalendarIntervalType] => + // Scalar does not support 'CalendarInterval' now, so use + // the Scala value instead. + // Skip the null check because it wll be detected by the following calls. + val intvl = intvlS.getValue.asInstanceOf[CalendarInterval] + if (intvl.months != 0) { + throw new UnsupportedOperationException("Months aren't supported at the moment") + } + val usToSub = intvl.days * microSecondsInOneDay + intvl.microseconds + if (usToSub != 0) { + val res = if (GpuOverrides.isUTCTimezone(zoneId)) { + withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, usToSub)) { d => + timestampAddDuration(l.getBase, d) + } + } else { + val utcRes = withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(l.getBase, + zoneId)) { utcTimestamp => + withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, usToSub)) { + d => timestampAddDuration(utcTimestamp, d) + } + } + withResource(utcRes) { _ => + GpuTimeZoneDB.fromTimestampToUtcTimestamp(utcRes, zoneId) + } + } + GpuColumnVector.from(res, dataType) + } else { + l.incRefCount() + } + case _ => + throw new UnsupportedOperationException("only column and interval arguments " + + s"are supported, got left: ${lhs.getClass} right: ${rhs.getClass}") + } + } + } + } + + private def timestampAddDuration(cv: ColumnView, duration: BinaryOperable): ColumnVector = { + // Not use cv.add(duration), because of it invoke BinaryOperable.implicitConversion, + // and currently BinaryOperable.implicitConversion return Long + // Directly specify the return type is TIMESTAMP_MICROSECONDS + cv.binaryOp(BinaryOp.ADD, duration, DType.TIMESTAMP_MICROSECONDS) } } diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala index 0bb4f596927..a96ada87044 100644 --- a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala +++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala @@ -37,12 +37,15 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids.shims +import java.time.ZoneId import java.util.concurrent.TimeUnit import ai.rapids.cudf.{BinaryOp, BinaryOperable, ColumnVector, ColumnView, DType, Scalar} import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuScalar} import com.nvidia.spark.rapids.Arm.{withResource, withResourceIfAllowed} +import com.nvidia.spark.rapids.GpuOverrides import com.nvidia.spark.rapids.RapidsPluginImplicits._ +import com.nvidia.spark.rapids.jni.GpuTimeZoneDB import com.nvidia.spark.rapids.shims.ShimBinaryExpression import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, TimeZoneAwareExpression} @@ -104,9 +107,23 @@ case class GpuTimeAdd(start: Expression, // add interval if (interval != 0) { - withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, interval)) { d => - GpuColumnVector.from(timestampAddDuration(l.getBase, d), dataType) + val zoneID = ZoneId.of(timeZoneId.getOrElse("UTC")) + val resCv = if (GpuOverrides.isUTCTimezone(zoneId)) { + withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, interval)) { d => + timestampAddDuration(l.getBase, d) + } + } else { + val utcRes = withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(l.getBase, + zoneID)) { utcTimestamp => + withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, interval)) { + d => timestampAddDuration(utcTimestamp, d) + } + } + withResource(utcRes) { _ => + GpuTimeZoneDB.fromTimestampToUtcTimestamp(utcRes, zoneID) + } } + GpuColumnVector.from(resCv, dataType) } else { l.incRefCount() } @@ -115,9 +132,23 @@ case class GpuTimeAdd(start: Expression, case (_: TimestampType, _: DayTimeIntervalType) => // DayTimeIntervalType is stored as long // bitCastTo is similar to reinterpret_cast, it's fast, the time can be ignored. - withResource(r.getBase.bitCastTo(DType.DURATION_MICROSECONDS)) { duration => - GpuColumnVector.from(timestampAddDuration(l.getBase, duration), dataType) + val zoneID = ZoneId.of(timeZoneId.getOrElse("UTC")) + val resCv = if (GpuOverrides.isUTCTimezone(zoneId)) { + withResource(r.getBase.bitCastTo(DType.DURATION_MICROSECONDS)) { duration => + timestampAddDuration(l.getBase, duration) + } + } else { + val utcRes = withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(l.getBase, + zoneID)) { utcTimestamp => + withResource(r.getBase.bitCastTo(DType.DURATION_MICROSECONDS)) { duration => + timestampAddDuration(utcTimestamp, duration) + } + } + withResource(utcRes) { utc => + GpuTimeZoneDB.fromTimestampToUtcTimestamp(utc, zoneID) + } } + GpuColumnVector.from(resCv, dataType) case _ => throw new UnsupportedOperationException( "GpuTimeAdd takes column and interval as an argument only") From c7dc3046beb0d091a7597c8e84d83e846dddd5b2 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Mon, 18 Dec 2023 16:43:55 +0800 Subject: [PATCH 03/21] clean up tests Signed-off-by: Haoyang Li --- integration_tests/src/main/python/date_time_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py index 4c3c8b5e188..1fc13357406 100644 --- a/integration_tests/src/main/python/date_time_test.py +++ b/integration_tests/src/main/python/date_time_test.py @@ -81,7 +81,7 @@ def fun(spark): spark.sql("select a, ((select max(a) from testTime) + interval 1 day) as datePlus from testTime").createOrReplaceTempView("testTime2") return spark.sql("select * from testTime2 where datePlus > current_timestamp") - assert_gpu_and_cpu_are_equal_collect(fun, conf = {'spark.rapids.sql.nonUTC.enabled': True}) + assert_gpu_and_cpu_are_equal_collect(fun) @pytest.mark.parametrize('data_gen', vals, ids=idfn) @allow_non_gpu(*non_supported_tz_allow) @@ -93,7 +93,7 @@ def fun(spark): spark.sql("select a, ((select min(a) from testTime) - interval 1 day) as dateMinus from testTime").createOrReplaceTempView("testTime2") return spark.sql("select * from testTime2 where dateMinus < current_timestamp") - assert_gpu_and_cpu_are_equal_collect(fun, conf = {'spark.rapids.sql.nonUTC.enabled': True}) + assert_gpu_and_cpu_are_equal_collect(fun) # Should specify `spark.sql.legacy.interval.enabled` to test `DateAddInterval` after Spark 3.2.0, # refer to https://issues.apache.org/jira/browse/SPARK-34896 From 89c93050250219c310d15bee84568ddd28a210a1 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Mon, 18 Dec 2023 16:47:37 +0800 Subject: [PATCH 04/21] clean up Signed-off-by: Haoyang Li --- .../src/main/scala/com/nvidia/spark/rapids/literals.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala index 1bcdc58c612..f4f18491745 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala @@ -473,7 +473,7 @@ object GpuScalar extends Logging { * * This class is introduced because many expressions require both the cudf Scalar and its * corresponding Scala value to complete their computations. e.g. 'GpuStringSplit', - * 'GpuStringLocate', 'GpuDivide', 'GpuDateAddInterval', 'GpuTimeMath' ... + * 'GpuStringLocate', 'GpuDivide', 'GpuDateAddInterval', 'GpuTimeAdd' ... * So only either a cudf Scalar or a Scala value can not support such cases, unless copying data * between the host and the device each time being asked for. * @@ -493,7 +493,7 @@ object GpuScalar extends Logging { * happens. * * Another reason why storing the Scala value in addition to the cudf Scalar is - * `GpuDateAddInterval` and 'GpuTimeMath' have different algorithms with the 3 members of + * `GpuDateAddInterval` and 'GpuTimeAdd' have different algorithms with the 3 members of * a `CalendarInterval`, which can not be supported by a single cudf Scalar now. * * Do not create a GpuScalar from the constructor, instead call the factory APIs above. From 4c9485f91a5d8ed3229d195a8f626f0aa3df68ac Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 20 Dec 2023 09:44:15 +0800 Subject: [PATCH 05/21] remove config Signed-off-by: Haoyang Li --- integration_tests/src/main/python/date_time_test.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py index 1fc13357406..e1228e243f2 100644 --- a/integration_tests/src/main/python/date_time_test.py +++ b/integration_tests/src/main/python/date_time_test.py @@ -48,8 +48,7 @@ def test_timeadd(data_gen): # We are starting at year 0005 to make sure we don't go before year 0001 # and beyond year 10000 while doing TimeAdd lambda spark: unary_op_df(spark, TimestampGen(start=datetime(5, 1, 1, tzinfo=timezone.utc), end=datetime(15, 1, 1, tzinfo=timezone.utc)), seed=1) - .selectExpr("a + (interval {} days {} seconds)".format(days, seconds)), - conf = {'spark.rapids.sql.nonUTC.enabled': True}) + .selectExpr("a + (interval {} days {} seconds)".format(days, seconds))) @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') @allow_non_gpu(*non_supported_tz_allow) @@ -60,9 +59,7 @@ def test_timeadd_daytime_column(): # max days is 8000 year, so added result will not be out of range ('d', DayTimeIntervalGen(min_value=timedelta(days=1000 * 365), max_value=timedelta(days=1005 * 365)))] assert_gpu_and_cpu_are_equal_collect( - lambda spark: gen_df(spark, gen_list, length=2048).selectExpr("t + d", "t", "d"), - # lambda spark: gen_df(spark, gen_list).selectExpr("t + d", "t + INTERVAL '1 02:03:04' DAY TO SECOND"), - conf = {'spark.rapids.sql.nonUTC.enabled': True}) + lambda spark: gen_df(spark, gen_list).selectExpr("t + d", "t + INTERVAL '1 02:03:04' DAY TO SECOND")) @pytest.mark.skipif(is_before_spark_350(), reason='DayTimeInterval overflow check for seconds is not supported before Spark 3.5.0') def test_interval_seconds_overflow_exception(): From f4e85d78969af8ea5536f6cbd186b53a52f5eb26 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 20 Dec 2023 09:54:50 +0800 Subject: [PATCH 06/21] clean up Signed-off-by: Haoyang Li --- integration_tests/src/main/python/date_time_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py index 142f509a903..0d3b6a5cb45 100644 --- a/integration_tests/src/main/python/date_time_test.py +++ b/integration_tests/src/main/python/date_time_test.py @@ -57,7 +57,7 @@ def test_timeadd_daytime_column(): # timestamp column max year is 1000 ('t', TimestampGen(end=datetime(1000, 1, 1, tzinfo=timezone.utc))), # max days is 8000 year, so added result will not be out of range - ('d', DayTimeIntervalGen(min_value=timedelta(days=1000 * 365), max_value=timedelta(days=1005 * 365)))] + ('d', DayTimeIntervalGen(min_value=timedelta(days=0), max_value=timedelta(days=8000 * 365)))] assert_gpu_and_cpu_are_equal_collect( lambda spark: gen_df(spark, gen_list).selectExpr("t + d", "t + INTERVAL '1 02:03:04' DAY TO SECOND")) From 2de0e6eb03606ebbaf88acf53d81fa06417cd2d0 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 28 Dec 2023 18:53:41 +0800 Subject: [PATCH 07/21] Add long overflow check for 320+ Signed-off-by: Haoyang Li --- .../src/main/python/date_time_test.py | 9 ++++ .../rapids/shims/datetimeExpressions.scala | 50 ++++++++++++++++++- 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py index fa2180fb204..27b6f597827 100644 --- a/integration_tests/src/main/python/date_time_test.py +++ b/integration_tests/src/main/python/date_time_test.py @@ -61,6 +61,15 @@ def test_timeadd_daytime_column(): assert_gpu_and_cpu_are_equal_collect( lambda spark: gen_df(spark, gen_list).selectExpr("t + d", "t + INTERVAL '1 02:03:04' DAY TO SECOND")) +@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') +@allow_non_gpu(*non_supported_tz_allow) +def test_timeadd_daytime_column_long_overflow(): + gen_list = [('t', TimestampGen()),('d', DayTimeIntervalGen())] + assert_gpu_and_cpu_error( + lambda spark : gen_df(spark, gen_list).selectExpr("t + d").collect(), + conf={}, + error_message='long overflow') + @pytest.mark.skipif(is_before_spark_350(), reason='DayTimeInterval overflow check for seconds is not supported before Spark 3.5.0') def test_interval_seconds_overflow_exception(): assert_gpu_and_cpu_error( diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala index 488e5fc1f23..e3411f9bc9f 100644 --- a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala +++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala @@ -44,7 +44,7 @@ import java.util.concurrent.TimeUnit import ai.rapids.cudf.{BinaryOp, BinaryOperable, ColumnVector, ColumnView, DType, Scalar} import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuScalar} -import com.nvidia.spark.rapids.Arm.{withResource, withResourceIfAllowed} +import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource, withResourceIfAllowed} import com.nvidia.spark.rapids.GpuOverrides import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.jni.GpuTimeZoneDB @@ -164,10 +164,56 @@ case class GpuTimeAdd(start: Expression, } } + // A tricky way to check overflow. The result is overflow when positive + positive = negative + // or negative + negative = positive, so we can check the sign of the result is the same as + // the sign of the operands. private def timestampAddDuration(cv: ColumnView, duration: BinaryOperable): ColumnVector = { // Not use cv.add(duration), because of it invoke BinaryOperable.implicitConversion, // and currently BinaryOperable.implicitConversion return Long // Directly specify the return type is TIMESTAMP_MICROSECONDS - cv.binaryOp(BinaryOp.ADD, duration, DType.TIMESTAMP_MICROSECONDS) + val resWithOverflow = cv.binaryOp(BinaryOp.ADD, duration, DType.TIMESTAMP_MICROSECONDS) + closeOnExcept(resWithOverflow) { _ => + val isCvPos = withResource( + Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, 0)) { zero => + cv.greaterOrEqualTo(zero) + } + val sameSignal = closeOnExcept(isCvPos) { isCvPos => + val isDurationPos = duration match { + case durScalar: Scalar => + val isPosBool = durScalar.isValid && durScalar.getLong >= 0 + Scalar.fromBool(isPosBool) + case dur : AutoCloseable => + withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, 0)) { zero => + dur.greaterOrEqualTo(zero) + } + } + withResource(isDurationPos) { _ => + isCvPos.equalTo(isDurationPos) + } + } + val isOverflow = withResource(sameSignal) { _ => + val sameSignalWithRes = withResource(isCvPos) { _ => + val isResNeg = withResource( + Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, 0)) { zero => + resWithOverflow.lessThan(zero) + } + withResource(isResNeg) { _ => + isCvPos.equalTo(isResNeg) + } + } + withResource(sameSignalWithRes) { _ => + sameSignal.and(sameSignalWithRes) + } + } + val anyOverflow = withResource(isOverflow) { _ => + isOverflow.any() + } + withResource(anyOverflow) { _ => + if (anyOverflow.isValid && anyOverflow.getBoolean) { + throw new ArithmeticException("long overflow") + } + } + } + resWithOverflow } } From 3d5f2b9632e4833cdf3d49845d5a38abd9fc47d9 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 29 Dec 2023 10:50:03 +0800 Subject: [PATCH 08/21] Add long overflow check for lower versions too Signed-off-by: Haoyang Li --- .../src/main/python/date_time_test.py | 16 ++++-- .../rapids/shims/datetimeExpressions.scala | 50 ++++++++++++++++++- 2 files changed, 59 insertions(+), 7 deletions(-) diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py index 27b6f597827..cebfcf75ad6 100644 --- a/integration_tests/src/main/python/date_time_test.py +++ b/integration_tests/src/main/python/date_time_test.py @@ -36,8 +36,7 @@ def test_timesub(data_gen): days, seconds = data_gen assert_gpu_and_cpu_are_equal_collect( - # We are starting at year 0015 to make sure we don't go before year 0001 while doing TimeSub - lambda spark: unary_op_df(spark, TimestampGen(start=datetime(15, 1, 1, tzinfo=timezone.utc)), seed=1) + lambda spark: unary_op_df(spark, TimestampGen()) .selectExpr("a - (interval {} days {} seconds)".format(days, seconds))) @pytest.mark.parametrize('data_gen', vals, ids=idfn) @@ -45,11 +44,18 @@ def test_timesub(data_gen): def test_timeadd(data_gen): days, seconds = data_gen assert_gpu_and_cpu_are_equal_collect( - # We are starting at year 0005 to make sure we don't go before year 0001 - # and beyond year 10000 while doing TimeAdd - lambda spark: unary_op_df(spark, TimestampGen(start=datetime(5, 1, 1, tzinfo=timezone.utc), end=datetime(15, 1, 1, tzinfo=timezone.utc)), seed=1) + lambda spark: unary_op_df(spark, TimestampGen()) .selectExpr("a + (interval {} days {} seconds)".format(days, seconds))) +@pytest.mark.parametrize('data_gen', [-pow(2, 63), pow(2, 63)], ids=idfn) +@allow_non_gpu(*non_supported_tz_allow) +def test_timeadd_long_overflow(data_gen): + assert_gpu_and_cpu_error( + lambda spark: unary_op_df(spark, TimestampGen()) + .selectExpr("a + (interval {} microseconds)".format(data_gen)), + conf={}, + error_message='long overflow') + @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') @allow_non_gpu(*non_supported_tz_allow) def test_timeadd_daytime_column(): diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala index 4d7db37b17e..276f391bc4b 100644 --- a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala +++ b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala @@ -25,7 +25,7 @@ import java.util.concurrent.TimeUnit import ai.rapids.cudf.{BinaryOp, BinaryOperable, ColumnVector, ColumnView, DType, Scalar} import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuScalar} -import com.nvidia.spark.rapids.Arm.{withResource, withResourceIfAllowed} +import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource, withResourceIfAllowed} import com.nvidia.spark.rapids.GpuOverrides import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.jni.GpuTimeZoneDB @@ -107,10 +107,56 @@ case class GpuTimeAdd( } } + // A tricky way to check overflow. The result is overflow when positive + positive = negative + // or negative + negative = positive, so we can check the sign of the result is the same as + // the sign of the operands. private def timestampAddDuration(cv: ColumnView, duration: BinaryOperable): ColumnVector = { // Not use cv.add(duration), because of it invoke BinaryOperable.implicitConversion, // and currently BinaryOperable.implicitConversion return Long // Directly specify the return type is TIMESTAMP_MICROSECONDS - cv.binaryOp(BinaryOp.ADD, duration, DType.TIMESTAMP_MICROSECONDS) + val resWithOverflow = cv.binaryOp(BinaryOp.ADD, duration, DType.TIMESTAMP_MICROSECONDS) + closeOnExcept(resWithOverflow) { _ => + val isCvPos = withResource( + Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, 0)) { zero => + cv.greaterOrEqualTo(zero) + } + val sameSignal = closeOnExcept(isCvPos) { isCvPos => + val isDurationPos = duration match { + case durScalar: Scalar => + val isPosBool = durScalar.isValid && durScalar.getLong >= 0 + Scalar.fromBool(isPosBool) + case dur : AutoCloseable => + withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, 0)) { zero => + dur.greaterOrEqualTo(zero) + } + } + withResource(isDurationPos) { _ => + isCvPos.equalTo(isDurationPos) + } + } + val isOverflow = withResource(sameSignal) { _ => + val sameSignalWithRes = withResource(isCvPos) { _ => + val isResNeg = withResource( + Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, 0)) { zero => + resWithOverflow.lessThan(zero) + } + withResource(isResNeg) { _ => + isCvPos.equalTo(isResNeg) + } + } + withResource(sameSignalWithRes) { _ => + sameSignal.and(sameSignalWithRes) + } + } + val anyOverflow = withResource(isOverflow) { _ => + isOverflow.any() + } + withResource(anyOverflow) { _ => + if (anyOverflow.isValid && anyOverflow.getBoolean) { + throw new ArithmeticException("long overflow") + } + } + } + resWithOverflow } } From 20144958dabccbfa7224dd878362b07c39592046 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 29 Dec 2023 16:40:00 +0800 Subject: [PATCH 09/21] Add perf test Signed-off-by: Haoyang Li --- .../rapids/timezone/TimeZonePerfSuite.scala | 35 ++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZonePerfSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZonePerfSuite.scala index 49ee394e904..99245829b75 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZonePerfSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZonePerfSuite.scala @@ -132,7 +132,7 @@ class TimeZonePerfSuite extends SparkQueryCompareTestSuite with BeforeAndAfterAl println(s"test,type,zone,used MS") for (zoneStr <- zones) { // run 6 rounds, but ignore the first round. - for (i <- 1 to 6) { + val elapses = (1 to 6).map { i => // run on Cpu val startOnCpu = System.nanoTime() withCpuSparkSession( @@ -153,8 +153,16 @@ class TimeZonePerfSuite extends SparkQueryCompareTestSuite with BeforeAndAfterAl val elapseOnGpuMS = (endOnGpu - startOnGpu) / 1000000L if (i != 1) { println(s"$testName,Gpu,$zoneStr,$elapseOnGpuMS") + (elapseOnCpuMS, elapseOnGpuMS) + } else { + (0L, 0L) // skip the first round } } + val meanCpu = elapses.map(_._1).sum / 5.0 + val meanGpu = elapses.map(_._2).sum / 5.0 + val speedup = meanCpu.toDouble / meanGpu.toDouble + println(f"$testName, $zoneStr: mean cpu time: $meanCpu%.2f ms, " + + f"mean gpu time: $meanGpu%.2f ms, speedup: $speedup%.2f x") } } @@ -173,4 +181,29 @@ class TimeZonePerfSuite extends SparkQueryCompareTestSuite with BeforeAndAfterAl runAndRecordTime("from_utc_timestamp", perfTest) } + + test("test timeadd") { + assume(enablePerfTest) + + // cache time zone DB in advance + GpuTimeZoneDB.cacheDatabase() + Thread.sleep(5L) + + def perfTest(spark: SparkSession, zone: String): DataFrame = { + spark.read.parquet(path).selectExpr( + "count(c_ts - (interval -584 days 1563 seconds))", + "count(c_ts - (interval 1943 days 1101 seconds))", + "count(c_ts - (interval 2693 days 2167 seconds))", + "count(c_ts - (interval 2729 days 0 seconds))", + "count(c_ts - (interval 44 days 1534 seconds))", + "count(c_ts - (interval 2635 days 3319 seconds))", + "count(c_ts - (interval 1885 days -2828 seconds))", + "count(c_ts - (interval 0 days 2463 seconds))", + "count(c_ts - (interval 932 days 2286 seconds))", + "count(c_ts - (interval 0 days 0 seconds))" + ) + } + + runAndRecordTime("time_add", perfTest) + } } From c6102f3d92f2c99d772445b9152a77613818a954 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 3 Jan 2024 17:33:36 +0800 Subject: [PATCH 10/21] wip Signed-off-by: Haoyang Li --- .../src/main/python/date_time_test.py | 55 ++++++++++----- .../apache/spark/sql/rapids/arithmetic.scala | 15 ++-- .../sql/rapids/datetimeExpressionsUtils.scala | 52 ++++++++++++++ .../rapids/shims/datetimeExpressions.scala | 66 ++---------------- .../rapids/shims/datetimeExpressions.scala | 68 +++---------------- 5 files changed, 112 insertions(+), 144 deletions(-) create mode 100644 sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py index cebfcf75ad6..b612851483d 100644 --- a/integration_tests/src/main/python/date_time_test.py +++ b/integration_tests/src/main/python/date_time_test.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,8 +29,7 @@ non_supported_tz_allow_filter = ['ProjectExec', 'FilterExec'] if not is_supported_time_zone() else [] # We only support literal intervals for TimeSub -vals = [(-584, 1563), (1943, 1101), (2693, 2167), (2729, 0), (44, 1534), (2635, 3319), - (1885, -2828), (0, 2463), (932, 2286), (0, 0)] +vals = [(0, 1)] @pytest.mark.parametrize('data_gen', vals, ids=idfn) @allow_non_gpu(*non_supported_tz_allow) def test_timesub(data_gen): @@ -41,36 +40,58 @@ def test_timesub(data_gen): @pytest.mark.parametrize('data_gen', vals, ids=idfn) @allow_non_gpu(*non_supported_tz_allow) -def test_timeadd(data_gen): +def test_timeadd_special(data_gen): days, seconds = data_gen assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, TimestampGen()) + lambda spark: unary_op_df(spark, TimestampGen(start=datetime(1900, 12, 31, 15, 55, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 16, tzinfo=timezone.utc)), length=100) .selectExpr("a + (interval {} days {} seconds)".format(days, seconds))) -@pytest.mark.parametrize('data_gen', [-pow(2, 63), pow(2, 63)], ids=idfn) +def test_to_utc_timestamp_and_from_utc_timestamp(): + aaa = with_cpu_session(lambda spark: unary_op_df(spark, TimestampGen(start=datetime(1900, 12, 31, 15, 55, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 16, tzinfo=timezone.utc))).collect()) + bbb = with_cpu_session(lambda spark: unary_op_df(spark, TimestampGen(start=datetime(1900, 12, 31, 15, 55, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 16, tzinfo=timezone.utc))).selectExpr("from_utc_timestamp(to_utc_timestamp(a, 'Asia/Shanghai'), 'Asia/Shanghai')").collect()) + assert aaa == bbb + +@pytest.mark.parametrize('edge_value', [-pow(2, 63), pow(2, 63)], ids=idfn) @allow_non_gpu(*non_supported_tz_allow) -def test_timeadd_long_overflow(data_gen): +def test_timeadd_long_overflow(edge_value): assert_gpu_and_cpu_error( lambda spark: unary_op_df(spark, TimestampGen()) - .selectExpr("a + (interval {} microseconds)".format(data_gen)), + .selectExpr("a + (interval {} microseconds)".format(edge_value)), conf={}, error_message='long overflow') @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') @allow_non_gpu(*non_supported_tz_allow) -def test_timeadd_daytime_column(): +def test_timeadd_daytime_column_normal(): + gen_list = [ + # timestamp column max year is 1000 + ('t', TimestampGen(start=datetime(1900, 12, 31, 15, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 16, tzinfo=timezone.utc))), + # max days is 8000 year, so added result will not be out of range + ('d', DayTimeIntervalGen(min_value=timedelta(seconds=0), max_value=timedelta(seconds=0)))] + assert_gpu_and_cpu_are_equal_collect( + lambda spark: gen_df(spark, gen_list, length=2048).selectExpr("t", "d", "t + d")) + +def test_timeadd_cpu_only(): gen_list = [ # timestamp column max year is 1000 - ('t', TimestampGen(end=datetime(1000, 1, 1, tzinfo=timezone.utc))), + ('t', TimestampGen(start=datetime(1900, 12, 31, 15, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 16, tzinfo=timezone.utc))), # max days is 8000 year, so added result will not be out of range - ('d', DayTimeIntervalGen(min_value=timedelta(days=0), max_value=timedelta(days=8000 * 365)))] + ('d', DayTimeIntervalGen(min_value=timedelta(seconds=0), max_value=timedelta(seconds=0)))] + cpu_before = with_cpu_session(lambda spark: gen_df(spark, gen_list, length=2048).collect()) + cpu_after = with_cpu_session(lambda spark: gen_df(spark, gen_list, length=2048).selectExpr("t + d").collect()) + assert cpu_before == cpu_after + +def test_to_utc_timestamp(): assert_gpu_and_cpu_are_equal_collect( - lambda spark: gen_df(spark, gen_list).selectExpr("t + d", "t + INTERVAL '1 02:03:04' DAY TO SECOND")) + lambda spark: unary_op_df(spark, TimestampGen(start=datetime(1900, 12, 31, 15, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 16, tzinfo=timezone.utc))) + .selectExpr("to_utc_timestamp(a, 'Asia/Shanghai')")) @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') @allow_non_gpu(*non_supported_tz_allow) def test_timeadd_daytime_column_long_overflow(): - gen_list = [('t', TimestampGen()),('d', DayTimeIntervalGen())] + overflow_gen = SetValuesGen(DayTimeIntervalType(), + [timedelta(microseconds=-pow(2, 63)), timedelta(microseconds=(pow(2, 63) - 1))]) + gen_list = [('t', TimestampGen()),('d', overflow_gen)] assert_gpu_and_cpu_error( lambda spark : gen_df(spark, gen_list).selectExpr("t + d").collect(), conf={}, @@ -302,7 +323,7 @@ def test_unsupported_fallback_to_unix_timestamp(data_gen): "ToUnixTimestamp") @pytest.mark.parametrize('time_zone', ["Asia/Shanghai", "Iran", "UTC", "UTC+0", "UTC-0", "GMT", "GMT+0", "GMT-0"], ids=idfn) -@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn) +@pytest.mark.parametrize('data_gen', [TimestampGen(start=datetime(1900, 12, 31, 7, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 23, tzinfo=timezone.utc))], ids=idfn) @tz_sensitive_test @allow_non_gpu(*non_utc_allow) def test_from_utc_timestamp(data_gen, time_zone): @@ -311,7 +332,7 @@ def test_from_utc_timestamp(data_gen, time_zone): @allow_non_gpu('ProjectExec') @pytest.mark.parametrize('time_zone', ["PST", "NST", "AST", "America/Los_Angeles", "America/New_York", "America/Chicago"], ids=idfn) -@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn) +@pytest.mark.parametrize('data_gen', [TimestampGen(start=datetime(1900, 12, 31, 7, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 23, tzinfo=timezone.utc))], ids=idfn) @tz_sensitive_test def test_from_utc_timestamp_unsupported_timezone_fallback(data_gen, time_zone): assert_gpu_fallback_collect( @@ -319,7 +340,7 @@ def test_from_utc_timestamp_unsupported_timezone_fallback(data_gen, time_zone): 'FromUTCTimestamp') @pytest.mark.parametrize('time_zone', ["UTC", "Asia/Shanghai", "EST", "MST", "VST"], ids=idfn) -@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn) +@pytest.mark.parametrize('data_gen', [TimestampGen(start=datetime(1900, 12, 31, 7, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 23, tzinfo=timezone.utc))], ids=idfn) @tz_sensitive_test @allow_non_gpu(*non_utc_allow) def test_from_utc_timestamp_supported_timezones(data_gen, time_zone): @@ -327,7 +348,7 @@ def test_from_utc_timestamp_supported_timezones(data_gen, time_zone): lambda spark: unary_op_df(spark, data_gen).select(f.from_utc_timestamp(f.col('a'), time_zone))) @allow_non_gpu('ProjectExec') -@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn) +@pytest.mark.parametrize('data_gen', [TimestampGen(start=datetime(1900, 12, 31, 7, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 23, tzinfo=timezone.utc))], ids=idfn) def test_unsupported_fallback_from_utc_timestamp(data_gen): time_zone_gen = StringGen(pattern="UTC") assert_gpu_fallback_collect( diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/arithmetic.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/arithmetic.scala index 3e96fa7d419..2954046aa55 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/arithmetic.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/arithmetic.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,7 +37,8 @@ object AddOverflowChecks { def basicOpOverflowCheck( lhs: BinaryOperable, rhs: BinaryOperable, - ret: ColumnVector): Unit = { + ret: ColumnVector, + msg: String = "One or more rows overflow for Add operation."): Unit = { // Check overflow. It is true if the arguments have different signs and // the sign of the result is different from the sign of x. // Which is equal to "((x ^ r) & (y ^ r)) < 0" in the form of arithmetic. @@ -54,9 +55,7 @@ object AddOverflowChecks { withResource(signDiffCV) { signDiff => withResource(signDiff.any()) { any => if (any.isValid && any.getBoolean) { - throw RapidsErrorUtils.arithmeticOverflowError( - "One or more rows overflow for Add operation." - ) + throw RapidsErrorUtils.arithmeticOverflowError(msg) } } } @@ -114,7 +113,8 @@ object SubtractOverflowChecks { def basicOpOverflowCheck( lhs: BinaryOperable, rhs: BinaryOperable, - ret: ColumnVector): Unit = { + ret: ColumnVector, + msg: String = "One or more rows overflow for Add operation."): Unit = { // Check overflow. It is true if the arguments have different signs and // the sign of the result is different from the sign of x. // Which is equal to "((x ^ y) & (x ^ r)) < 0" in the form of arithmetic. @@ -131,8 +131,7 @@ object SubtractOverflowChecks { withResource(signDiffCV) { signDiff => withResource(signDiff.any()) { any => if (any.isValid && any.getBoolean) { - throw RapidsErrorUtils. - arithmeticOverflowError("One or more rows overflow for Subtract operation.") + throw RapidsErrorUtils.arithmeticOverflowError(msg) } } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala new file mode 100644 index 00000000000..5d8de3f7d70 --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids + +import ai.rapids.cudf.{BinaryOp, BinaryOperable, ColumnVector, ColumnView, DType, Scalar} +import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} + +object datetimeExpressionsUtils { + def timestampAddDuration(cv: ColumnView, duration: BinaryOperable): ColumnVector = { + // Not use cv.add(duration), because of it invoke BinaryOperable.implicitConversion, + // and currently BinaryOperable.implicitConversion return Long + // Directly specify the return type is TIMESTAMP_MICROSECONDS + val resWithOverflow = cv.binaryOp(BinaryOp.ADD, duration, DType.TIMESTAMP_MICROSECONDS) + closeOnExcept(resWithOverflow) { _ => + withResource(resWithOverflow.castTo(DType.INT64)) { resWithOverflowLong => + withResource(cv.bitCastTo(DType.INT64)) { cvLong => + duration match { + case dur: Scalar => + val durLong = Scalar.fromLong(dur.getLong) + withResource(durLong) { _ => + AddOverflowChecks.basicOpOverflowCheck( + cvLong, durLong, resWithOverflowLong, "long overflow") + } + case dur: ColumnView => + withResource(dur.bitCastTo(DType.INT64)) { durationLong => + AddOverflowChecks.basicOpOverflowCheck( + cvLong, durationLong, resWithOverflowLong, "long overflow") + } + case _ => + throw new UnsupportedOperationException("only scalar and column arguments " + + s"are supported, got ${duration.getClass}") + } + } + } + } + resWithOverflow + } +} \ No newline at end of file diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala index 276f391bc4b..94c4f9ea525 100644 --- a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala +++ b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,15 +23,16 @@ package org.apache.spark.sql.rapids.shims import java.util.concurrent.TimeUnit -import ai.rapids.cudf.{BinaryOp, BinaryOperable, ColumnVector, ColumnView, DType, Scalar} +import ai.rapids.cudf.{DType, Scalar} import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuScalar} -import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource, withResourceIfAllowed} +import com.nvidia.spark.rapids.Arm.{withResource, withResourceIfAllowed} import com.nvidia.spark.rapids.GpuOverrides import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.jni.GpuTimeZoneDB import com.nvidia.spark.rapids.shims.ShimBinaryExpression import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, TimeZoneAwareExpression} +import org.apache.spark.sql.rapids.datetimeExpressionsUtils import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.unsafe.types.CalendarInterval @@ -66,7 +67,7 @@ case class GpuTimeAdd( } override def columnarEval(batch: ColumnarBatch): GpuColumnVector = { - withResourceIfAllowed(left.columnarEval(batch)) { lhs => + withResource(left.columnarEval(batch)) { lhs => withResourceIfAllowed(right.columnarEvalAny(batch)) { rhs => // lhs is start, rhs is interval (lhs, rhs) match { @@ -82,13 +83,13 @@ case class GpuTimeAdd( if (usToSub != 0) { val res = if (GpuOverrides.isUTCTimezone(zoneId)) { withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, usToSub)) { d => - timestampAddDuration(l.getBase, d) + datetimeExpressionsUtils.timestampAddDuration(l.getBase, d) } } else { val utcRes = withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(l.getBase, zoneId)) { utcTimestamp => withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, usToSub)) { - d => timestampAddDuration(utcTimestamp, d) + d => datetimeExpressionsUtils.timestampAddDuration(utcTimestamp, d) } } withResource(utcRes) { _ => @@ -106,57 +107,4 @@ case class GpuTimeAdd( } } } - - // A tricky way to check overflow. The result is overflow when positive + positive = negative - // or negative + negative = positive, so we can check the sign of the result is the same as - // the sign of the operands. - private def timestampAddDuration(cv: ColumnView, duration: BinaryOperable): ColumnVector = { - // Not use cv.add(duration), because of it invoke BinaryOperable.implicitConversion, - // and currently BinaryOperable.implicitConversion return Long - // Directly specify the return type is TIMESTAMP_MICROSECONDS - val resWithOverflow = cv.binaryOp(BinaryOp.ADD, duration, DType.TIMESTAMP_MICROSECONDS) - closeOnExcept(resWithOverflow) { _ => - val isCvPos = withResource( - Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, 0)) { zero => - cv.greaterOrEqualTo(zero) - } - val sameSignal = closeOnExcept(isCvPos) { isCvPos => - val isDurationPos = duration match { - case durScalar: Scalar => - val isPosBool = durScalar.isValid && durScalar.getLong >= 0 - Scalar.fromBool(isPosBool) - case dur : AutoCloseable => - withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, 0)) { zero => - dur.greaterOrEqualTo(zero) - } - } - withResource(isDurationPos) { _ => - isCvPos.equalTo(isDurationPos) - } - } - val isOverflow = withResource(sameSignal) { _ => - val sameSignalWithRes = withResource(isCvPos) { _ => - val isResNeg = withResource( - Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, 0)) { zero => - resWithOverflow.lessThan(zero) - } - withResource(isResNeg) { _ => - isCvPos.equalTo(isResNeg) - } - } - withResource(sameSignalWithRes) { _ => - sameSignal.and(sameSignalWithRes) - } - } - val anyOverflow = withResource(isOverflow) { _ => - isOverflow.any() - } - withResource(anyOverflow) { _ => - if (anyOverflow.isValid && anyOverflow.getBoolean) { - throw new ArithmeticException("long overflow") - } - } - } - resWithOverflow - } } diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala index e3411f9bc9f..d0651de26a0 100644 --- a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala +++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -42,15 +42,16 @@ package org.apache.spark.sql.rapids.shims import java.time.ZoneId import java.util.concurrent.TimeUnit -import ai.rapids.cudf.{BinaryOp, BinaryOperable, ColumnVector, ColumnView, DType, Scalar} +import ai.rapids.cudf.{DType, Scalar} import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuScalar} -import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource, withResourceIfAllowed} +import com.nvidia.spark.rapids.Arm.{withResource, withResourceIfAllowed} import com.nvidia.spark.rapids.GpuOverrides import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.jni.GpuTimeZoneDB import com.nvidia.spark.rapids.shims.ShimBinaryExpression import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, TimeZoneAwareExpression} +import org.apache.spark.sql.rapids.datetimeExpressionsUtils import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.unsafe.types.CalendarInterval @@ -112,13 +113,13 @@ case class GpuTimeAdd(start: Expression, val zoneID = ZoneId.of(timeZoneId.getOrElse("UTC")) val resCv = if (GpuOverrides.isUTCTimezone(zoneId)) { withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, interval)) { d => - timestampAddDuration(l.getBase, d) + datetimeExpressionsUtils.timestampAddDuration(l.getBase, d) } } else { val utcRes = withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(l.getBase, zoneID)) { utcTimestamp => withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, interval)) { - d => timestampAddDuration(utcTimestamp, d) + d => datetimeExpressionsUtils.timestampAddDuration(utcTimestamp, d) } } withResource(utcRes) { _ => @@ -137,13 +138,13 @@ case class GpuTimeAdd(start: Expression, val zoneID = ZoneId.of(timeZoneId.getOrElse("UTC")) val resCv = if (GpuOverrides.isUTCTimezone(zoneId)) { withResource(r.getBase.bitCastTo(DType.DURATION_MICROSECONDS)) { duration => - timestampAddDuration(l.getBase, duration) + datetimeExpressionsUtils.timestampAddDuration(l.getBase, duration) } } else { val utcRes = withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(l.getBase, zoneID)) { utcTimestamp => withResource(r.getBase.bitCastTo(DType.DURATION_MICROSECONDS)) { duration => - timestampAddDuration(utcTimestamp, duration) + datetimeExpressionsUtils.timestampAddDuration(utcTimestamp, duration) } } withResource(utcRes) { utc => @@ -163,57 +164,4 @@ case class GpuTimeAdd(start: Expression, } } } - - // A tricky way to check overflow. The result is overflow when positive + positive = negative - // or negative + negative = positive, so we can check the sign of the result is the same as - // the sign of the operands. - private def timestampAddDuration(cv: ColumnView, duration: BinaryOperable): ColumnVector = { - // Not use cv.add(duration), because of it invoke BinaryOperable.implicitConversion, - // and currently BinaryOperable.implicitConversion return Long - // Directly specify the return type is TIMESTAMP_MICROSECONDS - val resWithOverflow = cv.binaryOp(BinaryOp.ADD, duration, DType.TIMESTAMP_MICROSECONDS) - closeOnExcept(resWithOverflow) { _ => - val isCvPos = withResource( - Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, 0)) { zero => - cv.greaterOrEqualTo(zero) - } - val sameSignal = closeOnExcept(isCvPos) { isCvPos => - val isDurationPos = duration match { - case durScalar: Scalar => - val isPosBool = durScalar.isValid && durScalar.getLong >= 0 - Scalar.fromBool(isPosBool) - case dur : AutoCloseable => - withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, 0)) { zero => - dur.greaterOrEqualTo(zero) - } - } - withResource(isDurationPos) { _ => - isCvPos.equalTo(isDurationPos) - } - } - val isOverflow = withResource(sameSignal) { _ => - val sameSignalWithRes = withResource(isCvPos) { _ => - val isResNeg = withResource( - Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, 0)) { zero => - resWithOverflow.lessThan(zero) - } - withResource(isResNeg) { _ => - isCvPos.equalTo(isResNeg) - } - } - withResource(sameSignalWithRes) { _ => - sameSignal.and(sameSignalWithRes) - } - } - val anyOverflow = withResource(isOverflow) { _ => - isOverflow.any() - } - withResource(anyOverflow) { _ => - if (anyOverflow.isValid && anyOverflow.getBoolean) { - throw new ArithmeticException("long overflow") - } - } - } - resWithOverflow - } } From 66f3dd678aaecf0c0c3176839d8b0c1cdc07ea60 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 3 Jan 2024 22:29:35 +0800 Subject: [PATCH 11/21] revert test code Signed-off-by: Haoyang Li --- .../src/main/python/date_time_test.py | 51 ++++++------------- 1 file changed, 16 insertions(+), 35 deletions(-) diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py index b612851483d..dfbc185001b 100644 --- a/integration_tests/src/main/python/date_time_test.py +++ b/integration_tests/src/main/python/date_time_test.py @@ -29,7 +29,8 @@ non_supported_tz_allow_filter = ['ProjectExec', 'FilterExec'] if not is_supported_time_zone() else [] # We only support literal intervals for TimeSub -vals = [(0, 1)] +vals = [(-584, 1563), (1943, 1101), (2693, 2167), (2729, 0), (44, 1534), (2635, 3319), + (1885, -2828), (0, 2463), (932, 2286), (0, 0)] @pytest.mark.parametrize('data_gen', vals, ids=idfn) @allow_non_gpu(*non_supported_tz_allow) def test_timesub(data_gen): @@ -40,51 +41,31 @@ def test_timesub(data_gen): @pytest.mark.parametrize('data_gen', vals, ids=idfn) @allow_non_gpu(*non_supported_tz_allow) -def test_timeadd_special(data_gen): +def test_timeadd(data_gen): days, seconds = data_gen assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, TimestampGen(start=datetime(1900, 12, 31, 15, 55, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 16, tzinfo=timezone.utc)), length=100) + lambda spark: unary_op_df(spark, TimestampGen()) .selectExpr("a + (interval {} days {} seconds)".format(days, seconds))) -def test_to_utc_timestamp_and_from_utc_timestamp(): - aaa = with_cpu_session(lambda spark: unary_op_df(spark, TimestampGen(start=datetime(1900, 12, 31, 15, 55, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 16, tzinfo=timezone.utc))).collect()) - bbb = with_cpu_session(lambda spark: unary_op_df(spark, TimestampGen(start=datetime(1900, 12, 31, 15, 55, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 16, tzinfo=timezone.utc))).selectExpr("from_utc_timestamp(to_utc_timestamp(a, 'Asia/Shanghai'), 'Asia/Shanghai')").collect()) - assert aaa == bbb - -@pytest.mark.parametrize('edge_value', [-pow(2, 63), pow(2, 63)], ids=idfn) +@pytest.mark.parametrize('edge_vals', [-pow(2, 63), pow(2, 63)], ids=idfn) @allow_non_gpu(*non_supported_tz_allow) -def test_timeadd_long_overflow(edge_value): +def test_timeadd_long_overflow(edge_vals): assert_gpu_and_cpu_error( lambda spark: unary_op_df(spark, TimestampGen()) - .selectExpr("a + (interval {} microseconds)".format(edge_value)), + .selectExpr("a + (interval {} microseconds)".format(edge_vals)), conf={}, error_message='long overflow') @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') @allow_non_gpu(*non_supported_tz_allow) -def test_timeadd_daytime_column_normal(): - gen_list = [ - # timestamp column max year is 1000 - ('t', TimestampGen(start=datetime(1900, 12, 31, 15, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 16, tzinfo=timezone.utc))), - # max days is 8000 year, so added result will not be out of range - ('d', DayTimeIntervalGen(min_value=timedelta(seconds=0), max_value=timedelta(seconds=0)))] - assert_gpu_and_cpu_are_equal_collect( - lambda spark: gen_df(spark, gen_list, length=2048).selectExpr("t", "d", "t + d")) - -def test_timeadd_cpu_only(): +def test_timeadd_daytime_column(): gen_list = [ # timestamp column max year is 1000 - ('t', TimestampGen(start=datetime(1900, 12, 31, 15, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 16, tzinfo=timezone.utc))), + ('t', TimestampGen(end=datetime(1000, 1, 1, tzinfo=timezone.utc))), # max days is 8000 year, so added result will not be out of range - ('d', DayTimeIntervalGen(min_value=timedelta(seconds=0), max_value=timedelta(seconds=0)))] - cpu_before = with_cpu_session(lambda spark: gen_df(spark, gen_list, length=2048).collect()) - cpu_after = with_cpu_session(lambda spark: gen_df(spark, gen_list, length=2048).selectExpr("t + d").collect()) - assert cpu_before == cpu_after - -def test_to_utc_timestamp(): + ('d', DayTimeIntervalGen(min_value=timedelta(days=0), max_value=timedelta(days=8000 * 365)))] assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, TimestampGen(start=datetime(1900, 12, 31, 15, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 16, tzinfo=timezone.utc))) - .selectExpr("to_utc_timestamp(a, 'Asia/Shanghai')")) + lambda spark: gen_df(spark, gen_list).selectExpr("t + d", "t + INTERVAL '1 02:03:04' DAY TO SECOND")) @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') @allow_non_gpu(*non_supported_tz_allow) @@ -323,7 +304,7 @@ def test_unsupported_fallback_to_unix_timestamp(data_gen): "ToUnixTimestamp") @pytest.mark.parametrize('time_zone', ["Asia/Shanghai", "Iran", "UTC", "UTC+0", "UTC-0", "GMT", "GMT+0", "GMT-0"], ids=idfn) -@pytest.mark.parametrize('data_gen', [TimestampGen(start=datetime(1900, 12, 31, 7, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 23, tzinfo=timezone.utc))], ids=idfn) +@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn) @tz_sensitive_test @allow_non_gpu(*non_utc_allow) def test_from_utc_timestamp(data_gen, time_zone): @@ -332,7 +313,7 @@ def test_from_utc_timestamp(data_gen, time_zone): @allow_non_gpu('ProjectExec') @pytest.mark.parametrize('time_zone', ["PST", "NST", "AST", "America/Los_Angeles", "America/New_York", "America/Chicago"], ids=idfn) -@pytest.mark.parametrize('data_gen', [TimestampGen(start=datetime(1900, 12, 31, 7, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 23, tzinfo=timezone.utc))], ids=idfn) +@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn) @tz_sensitive_test def test_from_utc_timestamp_unsupported_timezone_fallback(data_gen, time_zone): assert_gpu_fallback_collect( @@ -340,7 +321,7 @@ def test_from_utc_timestamp_unsupported_timezone_fallback(data_gen, time_zone): 'FromUTCTimestamp') @pytest.mark.parametrize('time_zone', ["UTC", "Asia/Shanghai", "EST", "MST", "VST"], ids=idfn) -@pytest.mark.parametrize('data_gen', [TimestampGen(start=datetime(1900, 12, 31, 7, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 23, tzinfo=timezone.utc))], ids=idfn) +@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn) @tz_sensitive_test @allow_non_gpu(*non_utc_allow) def test_from_utc_timestamp_supported_timezones(data_gen, time_zone): @@ -348,7 +329,7 @@ def test_from_utc_timestamp_supported_timezones(data_gen, time_zone): lambda spark: unary_op_df(spark, data_gen).select(f.from_utc_timestamp(f.col('a'), time_zone))) @allow_non_gpu('ProjectExec') -@pytest.mark.parametrize('data_gen', [TimestampGen(start=datetime(1900, 12, 31, 7, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 23, tzinfo=timezone.utc))], ids=idfn) +@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn) def test_unsupported_fallback_from_utc_timestamp(data_gen): time_zone_gen = StringGen(pattern="UTC") assert_gpu_fallback_collect( @@ -668,4 +649,4 @@ def test_timestamp_millis_long_overflow(): @allow_non_gpu(*non_utc_allow) def test_timestamp_micros(data_gen): assert_gpu_and_cpu_are_equal_collect( - lambda spark : unary_op_df(spark, data_gen).selectExpr("timestamp_micros(a)")) + lambda spark : unary_op_df(spark, data_gen).selectExpr("timestamp_micros(a)")) \ No newline at end of file From 0801287ea08664ead6e6d058f61ffa0213a5fecf Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 4 Jan 2024 09:56:35 +0800 Subject: [PATCH 12/21] clean up Signed-off-by: Haoyang Li --- integration_tests/src/main/python/date_time_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py index dfbc185001b..72611818223 100644 --- a/integration_tests/src/main/python/date_time_test.py +++ b/integration_tests/src/main/python/date_time_test.py @@ -649,4 +649,4 @@ def test_timestamp_millis_long_overflow(): @allow_non_gpu(*non_utc_allow) def test_timestamp_micros(data_gen): assert_gpu_and_cpu_are_equal_collect( - lambda spark : unary_op_df(spark, data_gen).selectExpr("timestamp_micros(a)")) \ No newline at end of file + lambda spark : unary_op_df(spark, data_gen).selectExpr("timestamp_micros(a)")) From 26152f55df12ca39fa1cdf172a025b6ef6e638e6 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 10 Jan 2024 11:10:03 +0800 Subject: [PATCH 13/21] wip Signed-off-by: Haoyang Li --- .../src/main/python/date_time_test.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py index 72611818223..cef7d0e26a4 100644 --- a/integration_tests/src/main/python/date_time_test.py +++ b/integration_tests/src/main/python/date_time_test.py @@ -55,6 +55,25 @@ def test_timeadd_long_overflow(edge_vals): .selectExpr("a + (interval {} microseconds)".format(edge_vals)), conf={}, error_message='long overflow') + +@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') +@allow_non_gpu(*non_supported_tz_allow) +def test_timeadd_daytime_column_normal(): + gen_list = [ + # timestamp column max year is 1000 + ('t', TimestampGen(start=datetime(1900, 12, 31, 15, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 16, tzinfo=timezone.utc))), + # max days is 8000 year, so added result will not be out of range + ('d', DayTimeIntervalGen(min_value=timedelta(seconds=0), max_value=timedelta(seconds=0)))] + assert_gpu_and_cpu_are_equal_collect( + lambda spark: gen_df(spark, gen_list, length=2048).selectExpr("t", "d", "t + d")) + +@pytest.mark.parametrize('data_gen', [(0, 1)], ids=idfn) +@allow_non_gpu(*non_supported_tz_allow) +def test_timeadd_special(data_gen): + days, seconds = data_gen + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, TimestampGen(start=datetime(1900, 12, 31, 15, 55, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 16, tzinfo=timezone.utc)), length=100) + .selectExpr("a + (interval {} days {} seconds)".format(days, seconds))) @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') @allow_non_gpu(*non_supported_tz_allow) From a4334d931863535b90b3ef01f21880191cf6aba2 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 12 Jan 2024 13:45:48 +0800 Subject: [PATCH 14/21] wip Signed-off-by: Haoyang Li --- .../src/main/python/date_time_test.py | 21 +++++++--- .../rapids/shims/datetimeExpressions.scala | 38 ++++++++++--------- .../rapids/timezone/TimeZonePerfSuite.scala | 11 +----- 3 files changed, 38 insertions(+), 32 deletions(-) diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py index 605760e1783..4e9ceb088ec 100644 --- a/integration_tests/src/main/python/date_time_test.py +++ b/integration_tests/src/main/python/date_time_test.py @@ -39,13 +39,13 @@ def test_timesub(data_gen): lambda spark: unary_op_df(spark, TimestampGen()) .selectExpr("a - (interval {} days {} seconds)".format(days, seconds))) -@pytest.mark.parametrize('data_gen', vals, ids=idfn) +@pytest.mark.parametrize('data_gen', [(2002335, 66506, 226873)], ids=idfn) @allow_non_gpu(*non_supported_tz_allow) -def test_timeadd(data_gen): - days, seconds = data_gen +def test_timeadd_debug(data_gen): + days, seconds, mircos = data_gen assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, TimestampGen()) - .selectExpr("a + (interval {} days {} seconds)".format(days, seconds))) + lambda spark: unary_op_df(spark, TimestampGen(), length=500) + .selectExpr("a", "a + (interval {} days {} seconds {} microseconds)".format(days, seconds, mircos))) @pytest.mark.parametrize('edge_vals', [-pow(2, 63), pow(2, 63)], ids=idfn) @allow_non_gpu(*non_supported_tz_allow) @@ -86,6 +86,17 @@ def test_timeadd_daytime_column(): assert_gpu_and_cpu_are_equal_collect( lambda spark: gen_df(spark, gen_list).selectExpr("t + d", "t + INTERVAL '1 02:03:04' DAY TO SECOND")) +@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') +@allow_non_gpu(*non_supported_tz_allow) +def test_timeadd_daytime_column_debug(): + gen_list = [ + # timestamp column max year is 1000 + ('t', TimestampGen(end=datetime(1000, 1, 1, tzinfo=timezone.utc))), + # max days is 8000 year, so added result will not be out of range + ('d', DayTimeIntervalGen(min_value=timedelta(days=0), max_value=timedelta(days=8000 * 365)))] + assert_gpu_and_cpu_are_equal_collect( + lambda spark: gen_df(spark, gen_list, length=10).selectExpr("t", "d", "t + d")) + @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') @allow_non_gpu(*non_supported_tz_allow) def test_timeadd_daytime_column_long_overflow(): diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala index e7ee20c2574..8552ba5c750 100644 --- a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala +++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala @@ -117,15 +117,16 @@ case class GpuTimeAdd(start: Expression, datetimeExpressionsUtils.timestampAddDuration(l.getBase, d) } } else { - val utcRes = withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(l.getBase, - zoneID)) { utcTimestamp => - withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, interval)) { - d => datetimeExpressionsUtils.timestampAddDuration(utcTimestamp, d) - } - } - withResource(utcRes) { _ => - GpuTimeZoneDB.fromTimestampToUtcTimestamp(utcRes, zoneID) - } + // val utcRes = withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(l.getBase, + // zoneID)) { utcTimestamp => + // withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, interval)) { + // d => datetimeExpressionsUtils.timestampAddDuration(utcTimestamp, d) + // } + // } + // withResource(utcRes) { _ => + // GpuTimeZoneDB.fromTimestampToUtcTimestamp(utcRes, zoneID) + // } + GpuTimeZoneDB.timeAdd(l.getBase, interval, zoneID) } GpuColumnVector.from(resCv, dataType) } else { @@ -142,14 +143,17 @@ case class GpuTimeAdd(start: Expression, datetimeExpressionsUtils.timestampAddDuration(l.getBase, duration) } } else { - val utcRes = withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(l.getBase, - zoneID)) { utcTimestamp => - withResource(r.getBase.bitCastTo(DType.DURATION_MICROSECONDS)) { duration => - datetimeExpressionsUtils.timestampAddDuration(utcTimestamp, duration) - } - } - withResource(utcRes) { utc => - GpuTimeZoneDB.fromTimestampToUtcTimestamp(utc, zoneID) + // val utcRes = withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(l.getBase, + // zoneID)) { utcTimestamp => + // withResource(r.getBase.bitCastTo(DType.DURATION_MICROSECONDS)) { duration => + // datetimeExpressionsUtils.timestampAddDuration(utcTimestamp, duration) + // } + // } + // withResource(utcRes) { utc => + // GpuTimeZoneDB.fromTimestampToUtcTimestamp(utc, zoneID) + // } + withResource(r.getBase.bitCastTo(DType.INT64)) { duration => + GpuTimeZoneDB.timeAdd(l.getBase, duration, zoneID) } } GpuColumnVector.from(resCv, dataType) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZonePerfSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZonePerfSuite.scala index 544684afc14..4ce76a22b10 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZonePerfSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZonePerfSuite.scala @@ -132,7 +132,6 @@ class TimeZonePerfSuite extends SparkQueryCompareTestSuite with BeforeAndAfterAl println(s"test,type,zone,used MS") for (zoneStr <- zones) { // run 6 rounds, but ignore the first round. - val elapses = (1 to 6).map { i => val elapses = (1 to 6).map { i => // run on Cpu val startOnCpu = System.nanoTime() @@ -155,9 +154,6 @@ class TimeZonePerfSuite extends SparkQueryCompareTestSuite with BeforeAndAfterAl if (i != 1) { println(s"$testName,Gpu,$zoneStr,$elapseOnGpuMS") (elapseOnCpuMS, elapseOnGpuMS) - } else { - (0L, 0L) // skip the first round - (elapseOnCpuMS, elapseOnGpuMS) } else { (0L, 0L) // skip the first round } @@ -167,11 +163,6 @@ class TimeZonePerfSuite extends SparkQueryCompareTestSuite with BeforeAndAfterAl val speedup = meanCpu.toDouble / meanGpu.toDouble println(f"$testName, $zoneStr: mean cpu time: $meanCpu%.2f ms, " + f"mean gpu time: $meanGpu%.2f ms, speedup: $speedup%.2f x") - val meanCpu = elapses.map(_._1).sum / 5.0 - val meanGpu = elapses.map(_._2).sum / 5.0 - val speedup = meanCpu.toDouble / meanGpu.toDouble - println(f"$testName, $zoneStr: mean cpu time: $meanCpu%.2f ms, " + - f"mean gpu time: $meanGpu%.2f ms, speedup: $speedup%.2f x") } } @@ -231,4 +222,4 @@ class TimeZonePerfSuite extends SparkQueryCompareTestSuite with BeforeAndAfterAl runAndRecordTime("to_utc_timestamp", perfTest) } -} +} \ No newline at end of file From be5f813439075aca506b2bfa3867768e5c0ec97e Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 16 Jan 2024 16:14:38 +0800 Subject: [PATCH 15/21] Use timeadd kernel Signed-off-by: Haoyang Li --- .../src/main/python/date_time_test.py | 40 +++-------------- .../sql/rapids/datetimeExpressionsUtils.scala | 23 +++++++++- .../rapids/shims/datetimeExpressions.scala | 17 +------ .../rapids/shims/datetimeExpressions.scala | 45 ++++--------------- 4 files changed, 35 insertions(+), 90 deletions(-) diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py index 4e9ceb088ec..c237e85bb17 100644 --- a/integration_tests/src/main/python/date_time_test.py +++ b/integration_tests/src/main/python/date_time_test.py @@ -39,13 +39,13 @@ def test_timesub(data_gen): lambda spark: unary_op_df(spark, TimestampGen()) .selectExpr("a - (interval {} days {} seconds)".format(days, seconds))) -@pytest.mark.parametrize('data_gen', [(2002335, 66506, 226873)], ids=idfn) +@pytest.mark.parametrize('data_gen', vals, ids=idfn) @allow_non_gpu(*non_supported_tz_allow) -def test_timeadd_debug(data_gen): - days, seconds, mircos = data_gen +def test_timeadd(data_gen): + days, seconds = data_gen assert_gpu_and_cpu_are_equal_collect( lambda spark: unary_op_df(spark, TimestampGen(), length=500) - .selectExpr("a", "a + (interval {} days {} seconds {} microseconds)".format(days, seconds, mircos))) + .selectExpr("a + (interval {} days {} seconds)".format(days, seconds))) @pytest.mark.parametrize('edge_vals', [-pow(2, 63), pow(2, 63)], ids=idfn) @allow_non_gpu(*non_supported_tz_allow) @@ -55,25 +55,6 @@ def test_timeadd_long_overflow(edge_vals): .selectExpr("a + (interval {} microseconds)".format(edge_vals)), conf={}, error_message='long overflow') - -@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') -@allow_non_gpu(*non_supported_tz_allow) -def test_timeadd_daytime_column_normal(): - gen_list = [ - # timestamp column max year is 1000 - ('t', TimestampGen(start=datetime(1900, 12, 31, 15, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 16, tzinfo=timezone.utc))), - # max days is 8000 year, so added result will not be out of range - ('d', DayTimeIntervalGen(min_value=timedelta(seconds=0), max_value=timedelta(seconds=0)))] - assert_gpu_and_cpu_are_equal_collect( - lambda spark: gen_df(spark, gen_list, length=2048).selectExpr("t", "d", "t + d")) - -@pytest.mark.parametrize('data_gen', [(0, 1)], ids=idfn) -@allow_non_gpu(*non_supported_tz_allow) -def test_timeadd_special(data_gen): - days, seconds = data_gen - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, TimestampGen(start=datetime(1900, 12, 31, 15, 55, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 16, tzinfo=timezone.utc)), length=100) - .selectExpr("a + (interval {} days {} seconds)".format(days, seconds))) @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') @allow_non_gpu(*non_supported_tz_allow) @@ -82,21 +63,10 @@ def test_timeadd_daytime_column(): # timestamp column max year is 1000 ('t', TimestampGen(end=datetime(1000, 1, 1, tzinfo=timezone.utc))), # max days is 8000 year, so added result will not be out of range - ('d', DayTimeIntervalGen(min_value=timedelta(days=0), max_value=timedelta(days=8000 * 365)))] + ('d', DayTimeIntervalGen(min_value=timedelta(days=1000 * 365), max_value=timedelta(days=8000 * 365)))] assert_gpu_and_cpu_are_equal_collect( lambda spark: gen_df(spark, gen_list).selectExpr("t + d", "t + INTERVAL '1 02:03:04' DAY TO SECOND")) -@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') -@allow_non_gpu(*non_supported_tz_allow) -def test_timeadd_daytime_column_debug(): - gen_list = [ - # timestamp column max year is 1000 - ('t', TimestampGen(end=datetime(1000, 1, 1, tzinfo=timezone.utc))), - # max days is 8000 year, so added result will not be out of range - ('d', DayTimeIntervalGen(min_value=timedelta(days=0), max_value=timedelta(days=8000 * 365)))] - assert_gpu_and_cpu_are_equal_collect( - lambda spark: gen_df(spark, gen_list, length=10).selectExpr("t", "d", "t + d")) - @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') @allow_non_gpu(*non_supported_tz_allow) def test_timeadd_daytime_column_long_overflow(): diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala index 5d8de3f7d70..27598f22984 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala @@ -16,15 +16,34 @@ package org.apache.spark.sql.rapids +import java.time.ZoneId + import ai.rapids.cudf.{BinaryOp, BinaryOperable, ColumnVector, ColumnView, DType, Scalar} import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} +import com.nvidia.spark.rapids.GpuOverrides.isUTCTimezone +import com.nvidia.spark.rapids.jni.GpuTimeZoneDB object datetimeExpressionsUtils { - def timestampAddDuration(cv: ColumnView, duration: BinaryOperable): ColumnVector = { + def timestampAddDuration(cv: ColumnVector, duration: BinaryOperable, + zoneId: ZoneId): ColumnVector = { // Not use cv.add(duration), because of it invoke BinaryOperable.implicitConversion, // and currently BinaryOperable.implicitConversion return Long // Directly specify the return type is TIMESTAMP_MICROSECONDS - val resWithOverflow = cv.binaryOp(BinaryOp.ADD, duration, DType.TIMESTAMP_MICROSECONDS) + val resWithOverflow = if (isUTCTimezone(zoneId)) { + duration match { + case durS: Scalar => cv.binaryOp(BinaryOp.ADD, durS, DType.TIMESTAMP_MICROSECONDS) + case durC: ColumnView => { + withResource(durC.bitCastTo(DType.DURATION_MICROSECONDS)) { durMirco => + cv.binaryOp(BinaryOp.ADD, durMirco, DType.TIMESTAMP_MICROSECONDS) + } + } + } + } else { + duration match { + case durS: Scalar => GpuTimeZoneDB.timeAdd(cv, durS, zoneId) + case durC: ColumnView => GpuTimeZoneDB.timeAdd(cv, durC, zoneId) + } + } closeOnExcept(resWithOverflow) { _ => withResource(resWithOverflow.castTo(DType.INT64)) { resWithOverflowLong => withResource(cv.bitCastTo(DType.INT64)) { cvLong => diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala index 94c4f9ea525..6f289785b1e 100644 --- a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala +++ b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala @@ -28,7 +28,6 @@ import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuScalar} import com.nvidia.spark.rapids.Arm.{withResource, withResourceIfAllowed} import com.nvidia.spark.rapids.GpuOverrides import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.jni.GpuTimeZoneDB import com.nvidia.spark.rapids.shims.ShimBinaryExpression import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, TimeZoneAwareExpression} @@ -81,21 +80,7 @@ case class GpuTimeAdd( } val usToSub = intvl.days * microSecondsInOneDay + intvl.microseconds if (usToSub != 0) { - val res = if (GpuOverrides.isUTCTimezone(zoneId)) { - withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, usToSub)) { d => - datetimeExpressionsUtils.timestampAddDuration(l.getBase, d) - } - } else { - val utcRes = withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(l.getBase, - zoneId)) { utcTimestamp => - withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, usToSub)) { - d => datetimeExpressionsUtils.timestampAddDuration(utcTimestamp, d) - } - } - withResource(utcRes) { _ => - GpuTimeZoneDB.fromTimestampToUtcTimestamp(utcRes, zoneId) - } - } + val res = datetimeExpressionsUtils.timestampAddDuration(l.getBase, d, zoneId) GpuColumnVector.from(res, dataType) } else { l.incRefCount() diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala index 8552ba5c750..be51290d85b 100644 --- a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala +++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala @@ -46,9 +46,7 @@ import java.util.concurrent.TimeUnit import ai.rapids.cudf.{DType, Scalar} import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuScalar} import com.nvidia.spark.rapids.Arm.{withResource, withResourceIfAllowed} -import com.nvidia.spark.rapids.GpuOverrides import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.jni.GpuTimeZoneDB import com.nvidia.spark.rapids.shims.ShimBinaryExpression import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, TimeZoneAwareExpression} @@ -111,22 +109,11 @@ case class GpuTimeAdd(start: Expression, // add interval if (interval != 0) { - val zoneID = ZoneId.of(timeZoneId.getOrElse("UTC")) - val resCv = if (GpuOverrides.isUTCTimezone(zoneId)) { - withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, interval)) { d => - datetimeExpressionsUtils.timestampAddDuration(l.getBase, d) - } - } else { - // val utcRes = withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(l.getBase, - // zoneID)) { utcTimestamp => - // withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, interval)) { - // d => datetimeExpressionsUtils.timestampAddDuration(utcTimestamp, d) - // } - // } - // withResource(utcRes) { _ => - // GpuTimeZoneDB.fromTimestampToUtcTimestamp(utcRes, zoneID) - // } - GpuTimeZoneDB.timeAdd(l.getBase, interval, zoneID) + val zoneId = ZoneId.of(timeZoneId.getOrElse("UTC")) + val resCv = withResource(Scalar.durationFromLong( + DType.DURATION_MICROSECONDS, interval)) { duration => + datetimeExpressionsUtils.timestampAddDuration( + l.getBase, duration, zoneId) } GpuColumnVector.from(resCv, dataType) } else { @@ -137,25 +124,9 @@ case class GpuTimeAdd(start: Expression, case (_: TimestampType, _: DayTimeIntervalType) => // DayTimeIntervalType is stored as long // bitCastTo is similar to reinterpret_cast, it's fast, the time can be ignored. - val zoneID = ZoneId.of(timeZoneId.getOrElse("UTC")) - val resCv = if (GpuOverrides.isUTCTimezone(zoneId)) { - withResource(r.getBase.bitCastTo(DType.DURATION_MICROSECONDS)) { duration => - datetimeExpressionsUtils.timestampAddDuration(l.getBase, duration) - } - } else { - // val utcRes = withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(l.getBase, - // zoneID)) { utcTimestamp => - // withResource(r.getBase.bitCastTo(DType.DURATION_MICROSECONDS)) { duration => - // datetimeExpressionsUtils.timestampAddDuration(utcTimestamp, duration) - // } - // } - // withResource(utcRes) { utc => - // GpuTimeZoneDB.fromTimestampToUtcTimestamp(utc, zoneID) - // } - withResource(r.getBase.bitCastTo(DType.INT64)) { duration => - GpuTimeZoneDB.timeAdd(l.getBase, duration, zoneID) - } - } + val zoneId = ZoneId.of(timeZoneId.getOrElse("UTC")) + val resCv = datetimeExpressionsUtils.timestampAddDuration( + l.getBase, r.getBase, zoneId) GpuColumnVector.from(resCv, dataType) case _ => throw new UnsupportedOperationException( From 06c6c475fd5c7d7b2ee9ea46a015256e30c65a3e Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 16 Jan 2024 16:29:42 +0800 Subject: [PATCH 16/21] clean up Signed-off-by: Haoyang Li --- integration_tests/src/main/python/date_time_test.py | 4 ++-- .../spark/sql/rapids/datetimeExpressionsUtils.scala | 8 ++++---- .../nvidia/spark/rapids/timezone/TimeZonePerfSuite.scala | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py index c237e85bb17..64ed1fc0bd4 100644 --- a/integration_tests/src/main/python/date_time_test.py +++ b/integration_tests/src/main/python/date_time_test.py @@ -44,7 +44,7 @@ def test_timesub(data_gen): def test_timeadd(data_gen): days, seconds = data_gen assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, TimestampGen(), length=500) + lambda spark: unary_op_df(spark, TimestampGen()) .selectExpr("a + (interval {} days {} seconds)".format(days, seconds))) @pytest.mark.parametrize('edge_vals', [-pow(2, 63), pow(2, 63)], ids=idfn) @@ -63,7 +63,7 @@ def test_timeadd_daytime_column(): # timestamp column max year is 1000 ('t', TimestampGen(end=datetime(1000, 1, 1, tzinfo=timezone.utc))), # max days is 8000 year, so added result will not be out of range - ('d', DayTimeIntervalGen(min_value=timedelta(days=1000 * 365), max_value=timedelta(days=8000 * 365)))] + ('d', DayTimeIntervalGen(min_value=timedelta(days=-1000 * 365), max_value=timedelta(days=8000 * 365)))] assert_gpu_and_cpu_are_equal_collect( lambda spark: gen_df(spark, gen_list).selectExpr("t + d", "t + INTERVAL '1 02:03:04' DAY TO SECOND")) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala index 27598f22984..cbef3afce9d 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala @@ -26,10 +26,10 @@ import com.nvidia.spark.rapids.jni.GpuTimeZoneDB object datetimeExpressionsUtils { def timestampAddDuration(cv: ColumnVector, duration: BinaryOperable, zoneId: ZoneId): ColumnVector = { - // Not use cv.add(duration), because of it invoke BinaryOperable.implicitConversion, - // and currently BinaryOperable.implicitConversion return Long - // Directly specify the return type is TIMESTAMP_MICROSECONDS val resWithOverflow = if (isUTCTimezone(zoneId)) { + // Not use cv.add(duration), because of it invoke BinaryOperable.implicitConversion, + // and currently BinaryOperable.implicitConversion return Long + // Directly specify the return type is TIMESTAMP_MICROSECONDS duration match { case durS: Scalar => cv.binaryOp(BinaryOp.ADD, durS, DType.TIMESTAMP_MICROSECONDS) case durC: ColumnView => { @@ -68,4 +68,4 @@ object datetimeExpressionsUtils { } resWithOverflow } -} \ No newline at end of file +} diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZonePerfSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZonePerfSuite.scala index 4ce76a22b10..dab577534ca 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZonePerfSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/timezone/TimeZonePerfSuite.scala @@ -222,4 +222,4 @@ class TimeZonePerfSuite extends SparkQueryCompareTestSuite with BeforeAndAfterAl runAndRecordTime("to_utc_timestamp", perfTest) } -} \ No newline at end of file +} From ec7c4b07c37d92e524db6334c89f791a742bb326 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 18 Jan 2024 22:23:12 +0800 Subject: [PATCH 17/21] fix 311 build and address comments Signed-off-by: Haoyang Li --- .../additional-functionality/advanced_configs.md | 1 + .../spark/sql/rapids/datetimeExpressions.scala | 4 ++-- .../sql/rapids/datetimeExpressionsUtils.scala | 8 +++++++- .../sql/rapids/shims/datetimeExpressions.scala | 16 +++++++++------- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/docs/additional-functionality/advanced_configs.md b/docs/additional-functionality/advanced_configs.md index 4ac0a8d3bee..f577daaf10f 100644 --- a/docs/additional-functionality/advanced_configs.md +++ b/docs/additional-functionality/advanced_configs.md @@ -33,6 +33,7 @@ Name | Description | Default Value | Applicable at spark.rapids.filecache.blockPathRegexp|A regular expression to decide which paths will not be cached when the file cache is enabled. If a path is blocked by this regexp but is allowed by spark.rapids.filecache.allowPathRegexp, then the path is blocked.|None|Startup spark.rapids.filecache.checkStale|Controls whether the cached is checked for being out of date with respect to the input file. When enabled, the data that has been cached locally for a file will be invalidated if the file is updated after being cached. This feature is only necessary if an input file for a Spark application can be changed during the lifetime of the application. If an individual input file will not be overwritten during the Spark application then performance may be improved by setting this to false.|true|Startup spark.rapids.filecache.maxBytes|Controls the maximum amount of data that will be cached locally. If left unspecified, it will use half of the available disk space detected on startup for the configured Spark local disks.|None|Startup +spark.rapids.filecache.useChecksums|Whether to write out and verify checksums for the cached local files.|false|Startup spark.rapids.gpu.resourceName|The name of the Spark resource that represents a GPU that you want the plugin to use if using custom resources with Spark.|gpu|Startup spark.rapids.memory.gpu.allocFraction|The fraction of available (free) GPU memory that should be allocated for pooled memory. This must be less than or equal to the maximum limit configured via spark.rapids.memory.gpu.maxAllocFraction, and greater than or equal to the minimum limit configured via spark.rapids.memory.gpu.minAllocFraction.|1.0|Startup spark.rapids.memory.gpu.debug|Provides a log of GPU memory allocations and frees. If set to STDOUT or STDERR the logging will go there. Setting it to NONE disables logging. All other values are reserved for possible future expansion and in the mean time will disable logging.|NONE|Startup diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala index 40f64c99c32..5e1cf2ed560 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala @@ -155,8 +155,8 @@ case class GpuDateAddInterval(start: Expression, override def left: Expression = start override def right: Expression = interval - override def toString: String = s"$left - $right" - override def sql: String = s"${left.sql} - ${right.sql}" + override def toString: String = s"$left + $right" + override def sql: String = s"${left.sql} + ${right.sql}" override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala index cbef3afce9d..8e573b3e6d0 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala @@ -26,12 +26,18 @@ import com.nvidia.spark.rapids.jni.GpuTimeZoneDB object datetimeExpressionsUtils { def timestampAddDuration(cv: ColumnVector, duration: BinaryOperable, zoneId: ZoneId): ColumnVector = { + assert(cv.getType == DType.TIMESTAMP_MICROSECONDS, + "cv should be TIMESTAMP_MICROSECONDS type but got " + cv.getType) + assert(duration.getType == DType.DURATION_MICROSECONDS, + "duration should be DURATION_MICROSECONDS type but got " + duration.getType) val resWithOverflow = if (isUTCTimezone(zoneId)) { // Not use cv.add(duration), because of it invoke BinaryOperable.implicitConversion, // and currently BinaryOperable.implicitConversion return Long // Directly specify the return type is TIMESTAMP_MICROSECONDS duration match { - case durS: Scalar => cv.binaryOp(BinaryOp.ADD, durS, DType.TIMESTAMP_MICROSECONDS) + case durS: Scalar => { + cv.binaryOp(BinaryOp.ADD, durS, DType.TIMESTAMP_MICROSECONDS) + } case durC: ColumnView => { withResource(durC.bitCastTo(DType.DURATION_MICROSECONDS)) { durMirco => cv.binaryOp(BinaryOp.ADD, durMirco, DType.TIMESTAMP_MICROSECONDS) diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala index 6f289785b1e..414407297e1 100644 --- a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala +++ b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala @@ -26,7 +26,6 @@ import java.util.concurrent.TimeUnit import ai.rapids.cudf.{DType, Scalar} import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuScalar} import com.nvidia.spark.rapids.Arm.{withResource, withResourceIfAllowed} -import com.nvidia.spark.rapids.GpuOverrides import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.shims.ShimBinaryExpression @@ -51,8 +50,8 @@ case class GpuTimeAdd( override def left: Expression = start override def right: Expression = interval - override def toString: String = s"$left - $right" - override def sql: String = s"${left.sql} - ${right.sql}" + override def toString: String = s"$left + $right" + override def sql: String = s"${left.sql} + ${right.sql}" override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType) override def dataType: DataType = TimestampType @@ -78,10 +77,13 @@ case class GpuTimeAdd( if (intvl.months != 0) { throw new UnsupportedOperationException("Months aren't supported at the moment") } - val usToSub = intvl.days * microSecondsInOneDay + intvl.microseconds - if (usToSub != 0) { - val res = datetimeExpressionsUtils.timestampAddDuration(l.getBase, d, zoneId) - GpuColumnVector.from(res, dataType) + val interval = intvl.days * microSecondsInOneDay + intvl.microseconds + if (interval != 0) { + val resCv = withResource(Scalar.durationFromLong( + DType.DURATION_MICROSECONDS, interval)) { duration => + datetimeExpressionsUtils.timestampAddDuration(l.getBase, duration, zoneId) + } + GpuColumnVector.from(resCv, dataType) } else { l.incRefCount() } From 511f8ee0e11548531ea083f7313b57833b130ff5 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 23 Jan 2024 10:00:21 +0800 Subject: [PATCH 18/21] wip Signed-off-by: Haoyang Li --- integration_tests/src/main/python/date_time_test.py | 11 +++++++++++ .../spark/sql/rapids/datetimeExpressionsUtils.scala | 4 +--- .../spark/sql/rapids/shims/datetimeExpressions.scala | 5 +++-- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py index 64ed1fc0bd4..ceae5157158 100644 --- a/integration_tests/src/main/python/date_time_test.py +++ b/integration_tests/src/main/python/date_time_test.py @@ -67,6 +67,17 @@ def test_timeadd_daytime_column(): assert_gpu_and_cpu_are_equal_collect( lambda spark: gen_df(spark, gen_list).selectExpr("t + d", "t + INTERVAL '1 02:03:04' DAY TO SECOND")) +@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') +@allow_non_gpu(*non_supported_tz_allow) +def test_timeadd_daytime_column_debug(): + gen_list = [ + # timestamp column max year is 1000 + ('t', TimestampGen(end=datetime(1000, 1, 1, tzinfo=timezone.utc))), + # max days is 8000 year, so added result will not be out of range + ('d', DayTimeIntervalGen(min_value=timedelta(days=-1000 * 365), max_value=timedelta(days=8000 * 365)))] + assert_gpu_and_cpu_are_equal_collect( + lambda spark: gen_df(spark, gen_list, length=2000000).selectExpr("t", "d", "t + d")) + @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') @allow_non_gpu(*non_supported_tz_allow) def test_timeadd_daytime_column_long_overflow(): diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala index 8e573b3e6d0..6d85a3f1bd1 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala @@ -39,9 +39,7 @@ object datetimeExpressionsUtils { cv.binaryOp(BinaryOp.ADD, durS, DType.TIMESTAMP_MICROSECONDS) } case durC: ColumnView => { - withResource(durC.bitCastTo(DType.DURATION_MICROSECONDS)) { durMirco => - cv.binaryOp(BinaryOp.ADD, durMirco, DType.TIMESTAMP_MICROSECONDS) - } + cv.binaryOp(BinaryOp.ADD, durC, DType.TIMESTAMP_MICROSECONDS) } } } else { diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala index be51290d85b..d897f2af7db 100644 --- a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala +++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala @@ -125,8 +125,9 @@ case class GpuTimeAdd(start: Expression, // DayTimeIntervalType is stored as long // bitCastTo is similar to reinterpret_cast, it's fast, the time can be ignored. val zoneId = ZoneId.of(timeZoneId.getOrElse("UTC")) - val resCv = datetimeExpressionsUtils.timestampAddDuration( - l.getBase, r.getBase, zoneId) + val resCv = withResource(r.getBase.bitCastTo(DType.DURATION_MICROSECONDS)) { dur => + datetimeExpressionsUtils.timestampAddDuration(l.getBase, dur, zoneId) + } GpuColumnVector.from(resCv, dataType) case _ => throw new UnsupportedOperationException( From a3ab495dd5d1b2be1633f9999d7f7e2eed954f50 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 25 Jan 2024 00:24:53 +0800 Subject: [PATCH 19/21] match Calendar types Signed-off-by: Haoyang Li --- .../src/main/python/date_time_test.py | 10 +-- .../sql/rapids/datetimeExpressionsUtils.scala | 78 ++++++++++++++----- .../rapids/shims/datetimeExpressions.scala | 18 +---- .../rapids/shims/datetimeExpressions.scala | 34 ++++---- 4 files changed, 81 insertions(+), 59 deletions(-) diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py index ceae5157158..671eb50e2a9 100644 --- a/integration_tests/src/main/python/date_time_test.py +++ b/integration_tests/src/main/python/date_time_test.py @@ -30,7 +30,7 @@ # We only support literal intervals for TimeSub vals = [(-584, 1563), (1943, 1101), (2693, 2167), (2729, 0), (44, 1534), (2635, 3319), - (1885, -2828), (0, 2463), (932, 2286), (0, 0)] + (1885, -2828), (0, 2463), (932, 2286), (0, 0), (0, 86400), (1, 86401), (1, 8640000)] @pytest.mark.parametrize('data_gen', vals, ids=idfn) @allow_non_gpu(*non_supported_tz_allow) def test_timesub(data_gen): @@ -44,7 +44,7 @@ def test_timesub(data_gen): def test_timeadd(data_gen): days, seconds = data_gen assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, TimestampGen()) + lambda spark: unary_op_df(spark, TimestampGen(), length=200000) .selectExpr("a + (interval {} days {} seconds)".format(days, seconds))) @pytest.mark.parametrize('edge_vals', [-pow(2, 63), pow(2, 63)], ids=idfn) @@ -58,14 +58,14 @@ def test_timeadd_long_overflow(edge_vals): @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') @allow_non_gpu(*non_supported_tz_allow) -def test_timeadd_daytime_column(): +def test_timeadd_daytime_column_invalid(): gen_list = [ # timestamp column max year is 1000 - ('t', TimestampGen(end=datetime(1000, 1, 1, tzinfo=timezone.utc))), + ('t', TimestampGen(end=datetime(2000, 1, 1, tzinfo=timezone.utc))), # max days is 8000 year, so added result will not be out of range ('d', DayTimeIntervalGen(min_value=timedelta(days=-1000 * 365), max_value=timedelta(days=8000 * 365)))] assert_gpu_and_cpu_are_equal_collect( - lambda spark: gen_df(spark, gen_list).selectExpr("t + d", "t + INTERVAL '1 02:03:04' DAY TO SECOND")) + lambda spark: gen_df(spark, gen_list).selectExpr("t + INTERVAL 0 DAYS 86400000001 MICROSECONDS")) @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') @allow_non_gpu(*non_supported_tz_allow) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala index 6d85a3f1bd1..4f55fff1539 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala @@ -17,14 +17,18 @@ package org.apache.spark.sql.rapids import java.time.ZoneId +import java.util.concurrent.TimeUnit import ai.rapids.cudf.{BinaryOp, BinaryOperable, ColumnVector, ColumnView, DType, Scalar} -import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} +import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.GpuOverrides.isUTCTimezone import com.nvidia.spark.rapids.jni.GpuTimeZoneDB object datetimeExpressionsUtils { - def timestampAddDuration(cv: ColumnVector, duration: BinaryOperable, + + val microSecondsInOneDay: Long = TimeUnit.DAYS.toMicros(1) + + def timestampAddDurationUs(cv: ColumnVector, duration: BinaryOperable, zoneId: ZoneId): ColumnVector = { assert(cv.getType == DType.TIMESTAMP_MICROSECONDS, "cv should be TIMESTAMP_MICROSECONDS type but got " + cv.getType) @@ -48,28 +52,60 @@ object datetimeExpressionsUtils { case durC: ColumnView => GpuTimeZoneDB.timeAdd(cv, durC, zoneId) } } - closeOnExcept(resWithOverflow) { _ => - withResource(resWithOverflow.castTo(DType.INT64)) { resWithOverflowLong => - withResource(cv.bitCastTo(DType.INT64)) { cvLong => - duration match { - case dur: Scalar => - val durLong = Scalar.fromLong(dur.getLong) - withResource(durLong) { _ => - AddOverflowChecks.basicOpOverflowCheck( - cvLong, durLong, resWithOverflowLong, "long overflow") - } - case dur: ColumnView => - withResource(dur.bitCastTo(DType.INT64)) { durationLong => - AddOverflowChecks.basicOpOverflowCheck( - cvLong, durationLong, resWithOverflowLong, "long overflow") - } - case _ => - throw new UnsupportedOperationException("only scalar and column arguments " + - s"are supported, got ${duration.getClass}") + timeAddOverflowCheck(cv, duration, resWithOverflow) + resWithOverflow + } + + def timestampAddDurationCalendar(cv: ColumnVector, days: Long, + microseconds: Long, zoneId: ZoneId): ColumnVector = { + val interval = days * microSecondsInOneDay + microseconds + if (interval == 0) { + return cv.incRefCount() + } + val resWithOverflow = if (isUTCTimezone(zoneId)) { + cv.binaryOp(BinaryOp.ADD, Scalar.durationFromLong(DType.DURATION_MICROSECONDS, + interval), DType.TIMESTAMP_MICROSECONDS) + } else { + val daysScalar = Scalar.durationFromLong(DType.DURATION_MICROSECONDS, + days * microSecondsInOneDay) + val resDays = withResource(daysScalar) { _ => + GpuTimeZoneDB.timeAdd(cv, daysScalar, zoneId) + } + withResource(resDays) { _ => + resDays.binaryOp(BinaryOp.ADD, Scalar.durationFromLong(DType.DURATION_MICROSECONDS, + microseconds), DType.TIMESTAMP_MICROSECONDS) + } + } + withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, + interval)) { duration => + timeAddOverflowCheck(cv, duration, resWithOverflow) + } + resWithOverflow + } + + def timeAddOverflowCheck( + cv: ColumnVector, + duration: BinaryOperable, + resWithOverflow: ColumnVector): Unit = { + withResource(resWithOverflow.castTo(DType.INT64)) { resWithOverflowLong => + withResource(cv.bitCastTo(DType.INT64)) { cvLong => + duration match { + case dur: Scalar => + val durLong = Scalar.fromLong(dur.getLong) + withResource(durLong) { _ => + AddOverflowChecks.basicOpOverflowCheck( + cvLong, durLong, resWithOverflowLong, "long overflow") } + case dur: ColumnView => + withResource(dur.bitCastTo(DType.INT64)) { durationLong => + AddOverflowChecks.basicOpOverflowCheck( + cvLong, durationLong, resWithOverflowLong, "long overflow") + } + case _ => + throw new UnsupportedOperationException("only scalar and column arguments " + + s"are supported, got ${duration.getClass}") } } } - resWithOverflow } } diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala index 414407297e1..405ada9b61e 100644 --- a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala +++ b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala @@ -21,9 +21,6 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids.shims -import java.util.concurrent.TimeUnit - -import ai.rapids.cudf.{DType, Scalar} import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuScalar} import com.nvidia.spark.rapids.Arm.{withResource, withResourceIfAllowed} import com.nvidia.spark.rapids.RapidsPluginImplicits._ @@ -58,8 +55,6 @@ case class GpuTimeAdd( override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess - val microSecondsInOneDay: Long = TimeUnit.DAYS.toMicros(1) - override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = { copy(timeZoneId = Option(timeZoneId)) } @@ -77,16 +72,9 @@ case class GpuTimeAdd( if (intvl.months != 0) { throw new UnsupportedOperationException("Months aren't supported at the moment") } - val interval = intvl.days * microSecondsInOneDay + intvl.microseconds - if (interval != 0) { - val resCv = withResource(Scalar.durationFromLong( - DType.DURATION_MICROSECONDS, interval)) { duration => - datetimeExpressionsUtils.timestampAddDuration(l.getBase, duration, zoneId) - } - GpuColumnVector.from(resCv, dataType) - } else { - l.incRefCount() - } + val resCv = datetimeExpressionsUtils.timestampAddDurationCalendar(l.getBase, + intvl.days, intvl.microseconds, zoneId) + GpuColumnVector.from(resCv, dataType) case _ => throw new UnsupportedOperationException("only column and interval arguments " + s"are supported, got left: ${lhs.getClass} right: ${rhs.getClass}") diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala index d897f2af7db..aae3aad6def 100644 --- a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala +++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala @@ -89,8 +89,7 @@ case class GpuTimeAdd(start: Expression, // lhs is start, rhs is interval (lhs, rhs) match { case (l, intervalS: GpuScalar) => - // get long type interval - val interval = intervalS.dataType match { + intervalS.dataType match { case CalendarIntervalType => // Scalar does not support 'CalendarInterval' now, so use // the Scala value instead. @@ -99,26 +98,25 @@ case class GpuTimeAdd(start: Expression, if (calendarI.months != 0) { throw new UnsupportedOperationException("Months aren't supported at the moment") } - calendarI.days * microSecondsInOneDay + calendarI.microseconds + timestampAddDurationCalendar(l, calendarI.days, calendarI.microseconds, timeZone) case _: DayTimeIntervalType => - intervalS.getValue.asInstanceOf[Long] + val interval = intervalS.getValue.asInstanceOf[Long] + // add interval + if (interval != 0) { + val zoneId = ZoneId.of(timeZoneId.getOrElse("UTC")) + val resCv = withResource(Scalar.durationFromLong( + DType.DURATION_MICROSECONDS, interval)) { duration => + datetimeExpressionsUtils.timestampAddDurationUs( + l.getBase, duration, zoneId) + } + GpuColumnVector.from(resCv, dataType) + } else { + l.incRefCount() + } case _ => throw new UnsupportedOperationException( "GpuTimeAdd unsupported data type: " + intervalS.dataType) } - - // add interval - if (interval != 0) { - val zoneId = ZoneId.of(timeZoneId.getOrElse("UTC")) - val resCv = withResource(Scalar.durationFromLong( - DType.DURATION_MICROSECONDS, interval)) { duration => - datetimeExpressionsUtils.timestampAddDuration( - l.getBase, duration, zoneId) - } - GpuColumnVector.from(resCv, dataType) - } else { - l.incRefCount() - } case (l, r: GpuColumnVector) => (l.dataType(), r.dataType) match { case (_: TimestampType, _: DayTimeIntervalType) => @@ -126,7 +124,7 @@ case class GpuTimeAdd(start: Expression, // bitCastTo is similar to reinterpret_cast, it's fast, the time can be ignored. val zoneId = ZoneId.of(timeZoneId.getOrElse("UTC")) val resCv = withResource(r.getBase.bitCastTo(DType.DURATION_MICROSECONDS)) { dur => - datetimeExpressionsUtils.timestampAddDuration(l.getBase, dur, zoneId) + datetimeExpressionsUtils.timestampAddDurationUs(l.getBase, dur, zoneId) } GpuColumnVector.from(resCv, dataType) case _ => From 0fb8ecb609b1a91a4de47557e8f30b48e60eb55f Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 25 Jan 2024 01:00:12 +0800 Subject: [PATCH 20/21] fix build Signed-off-by: Haoyang Li --- .../apache/spark/sql/rapids/datetimeExpressionsUtils.scala | 6 ++++++ .../apache/spark/sql/rapids/shims/datetimeExpressions.scala | 4 +++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala index 4f55fff1539..a77e907b46e 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala @@ -58,6 +58,8 @@ object datetimeExpressionsUtils { def timestampAddDurationCalendar(cv: ColumnVector, days: Long, microseconds: Long, zoneId: ZoneId): ColumnVector = { + assert(cv.getType == DType.TIMESTAMP_MICROSECONDS, + "cv should be TIMESTAMP_MICROSECONDS type but got " + cv.getType) val interval = days * microSecondsInOneDay + microseconds if (interval == 0) { return cv.incRefCount() @@ -66,6 +68,9 @@ object datetimeExpressionsUtils { cv.binaryOp(BinaryOp.ADD, Scalar.durationFromLong(DType.DURATION_MICROSECONDS, interval), DType.TIMESTAMP_MICROSECONDS) } else { + // For CalendarInterval, microseconds could be larger than 1 day or negative, + // and microseconds in TimeAdd is not affected by timezone, so we need to + // calculate days and microseconds separately. val daysScalar = Scalar.durationFromLong(DType.DURATION_MICROSECONDS, days * microSecondsInOneDay) val resDays = withResource(daysScalar) { _ => @@ -76,6 +81,7 @@ object datetimeExpressionsUtils { microseconds), DType.TIMESTAMP_MICROSECONDS) } } + // The sign of duration will be unchanged considering the impact of timezone withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, interval)) { duration => timeAddOverflowCheck(cv, duration, resWithOverflow) diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala index aae3aad6def..b43293413af 100644 --- a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala +++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala @@ -98,7 +98,9 @@ case class GpuTimeAdd(start: Expression, if (calendarI.months != 0) { throw new UnsupportedOperationException("Months aren't supported at the moment") } - timestampAddDurationCalendar(l, calendarI.days, calendarI.microseconds, timeZone) + val resCv = datetimeExpressionsUtils.timestampAddDurationCalendar(l.getBase, + calendarI.days, calendarI.microseconds, zoneId) + GpuColumnVector.from(resCv, dataType) case _: DayTimeIntervalType => val interval = intervalS.getValue.asInstanceOf[Long] // add interval From a9deb35ab76c1fa10d92f6dab875e3c907177b9e Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 25 Jan 2024 13:39:10 +0800 Subject: [PATCH 21/21] clean up Signed-off-by: Haoyang Li --- .../src/main/python/date_time_test.py | 17 +++----------- .../sql/rapids/datetimeExpressionsUtils.scala | 22 +++++++++++-------- .../rapids/shims/datetimeExpressions.scala | 3 +-- 3 files changed, 17 insertions(+), 25 deletions(-) diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py index 671eb50e2a9..638913185d3 100644 --- a/integration_tests/src/main/python/date_time_test.py +++ b/integration_tests/src/main/python/date_time_test.py @@ -44,7 +44,7 @@ def test_timesub(data_gen): def test_timeadd(data_gen): days, seconds = data_gen assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, TimestampGen(), length=200000) + lambda spark: unary_op_df(spark, TimestampGen()) .selectExpr("a + (interval {} days {} seconds)".format(days, seconds))) @pytest.mark.parametrize('edge_vals', [-pow(2, 63), pow(2, 63)], ids=idfn) @@ -58,25 +58,14 @@ def test_timeadd_long_overflow(edge_vals): @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') @allow_non_gpu(*non_supported_tz_allow) -def test_timeadd_daytime_column_invalid(): +def test_timeadd_daytime_column(): gen_list = [ # timestamp column max year is 1000 ('t', TimestampGen(end=datetime(2000, 1, 1, tzinfo=timezone.utc))), # max days is 8000 year, so added result will not be out of range ('d', DayTimeIntervalGen(min_value=timedelta(days=-1000 * 365), max_value=timedelta(days=8000 * 365)))] assert_gpu_and_cpu_are_equal_collect( - lambda spark: gen_df(spark, gen_list).selectExpr("t + INTERVAL 0 DAYS 86400000001 MICROSECONDS")) - -@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') -@allow_non_gpu(*non_supported_tz_allow) -def test_timeadd_daytime_column_debug(): - gen_list = [ - # timestamp column max year is 1000 - ('t', TimestampGen(end=datetime(1000, 1, 1, tzinfo=timezone.utc))), - # max days is 8000 year, so added result will not be out of range - ('d', DayTimeIntervalGen(min_value=timedelta(days=-1000 * 365), max_value=timedelta(days=8000 * 365)))] - assert_gpu_and_cpu_are_equal_collect( - lambda spark: gen_df(spark, gen_list, length=2000000).selectExpr("t", "d", "t + d")) + lambda spark: gen_df(spark, gen_list).selectExpr("t + d", "t + INTERVAL '1 02:03:04' DAY TO SECOND")) @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') @allow_non_gpu(*non_supported_tz_allow) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala index a77e907b46e..e34e1cf346c 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressionsUtils.scala @@ -20,7 +20,7 @@ import java.time.ZoneId import java.util.concurrent.TimeUnit import ai.rapids.cudf.{BinaryOp, BinaryOperable, ColumnVector, ColumnView, DType, Scalar} -import com.nvidia.spark.rapids.Arm.withResource +import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.GpuOverrides.isUTCTimezone import com.nvidia.spark.rapids.jni.GpuTimeZoneDB @@ -52,7 +52,9 @@ object datetimeExpressionsUtils { case durC: ColumnView => GpuTimeZoneDB.timeAdd(cv, durC, zoneId) } } - timeAddOverflowCheck(cv, duration, resWithOverflow) + closeOnExcept(resWithOverflow) { _ => + timeAddOverflowCheck(cv, duration, resWithOverflow) + } resWithOverflow } @@ -82,9 +84,11 @@ object datetimeExpressionsUtils { } } // The sign of duration will be unchanged considering the impact of timezone - withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, - interval)) { duration => - timeAddOverflowCheck(cv, duration, resWithOverflow) + closeOnExcept(resWithOverflow) { _ => + withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, + interval)) { duration => + timeAddOverflowCheck(cv, duration, resWithOverflow) + } } resWithOverflow } @@ -99,13 +103,13 @@ object datetimeExpressionsUtils { case dur: Scalar => val durLong = Scalar.fromLong(dur.getLong) withResource(durLong) { _ => - AddOverflowChecks.basicOpOverflowCheck( - cvLong, durLong, resWithOverflowLong, "long overflow") + AddOverflowChecks.basicOpOverflowCheck( + cvLong, durLong, resWithOverflowLong, "long overflow") } case dur: ColumnView => withResource(dur.bitCastTo(DType.INT64)) { durationLong => - AddOverflowChecks.basicOpOverflowCheck( - cvLong, durationLong, resWithOverflowLong, "long overflow") + AddOverflowChecks.basicOpOverflowCheck( + cvLong, durationLong, resWithOverflowLong, "long overflow") } case _ => throw new UnsupportedOperationException("only scalar and column arguments " + diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala index b43293413af..1e78a638a07 100644 --- a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala +++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala @@ -108,8 +108,7 @@ case class GpuTimeAdd(start: Expression, val zoneId = ZoneId.of(timeZoneId.getOrElse("UTC")) val resCv = withResource(Scalar.durationFromLong( DType.DURATION_MICROSECONDS, interval)) { duration => - datetimeExpressionsUtils.timestampAddDurationUs( - l.getBase, duration, zoneId) + datetimeExpressionsUtils.timestampAddDurationUs(l.getBase, duration, zoneId) } GpuColumnVector.from(resCv, dataType) } else {