Support TimeAdd for non-UTC time zone #10068

Closed
wants to merge 26 commits
Changes from 12 commits
44 changes: 29 additions & 15 deletions integration_tests/src/main/python/date_time_test.py
@@ -26,31 +26,38 @@
non_utc_tz_allow = ['ProjectExec'] if not is_utc() else []
# Others work in all supported time zones
non_supported_tz_allow = ['ProjectExec'] if not is_supported_time_zone() else []
non_supported_tz_allow_filter = ['ProjectExec', 'FilterExec'] if not is_supported_time_zone() else []

# We only support literal intervals for TimeSub
vals = [(-584, 1563), (1943, 1101), (2693, 2167), (2729, 0), (44, 1534), (2635, 3319),
(1885, -2828), (0, 2463), (932, 2286), (0, 0)]
@pytest.mark.parametrize('data_gen', vals, ids=idfn)
@allow_non_gpu(*non_utc_allow)
@allow_non_gpu(*non_supported_tz_allow)
def test_timesub(data_gen):
days, seconds = data_gen
assert_gpu_and_cpu_are_equal_collect(
# We are starting at year 0015 to make sure we don't go before year 0001 while doing TimeSub
lambda spark: unary_op_df(spark, TimestampGen(start=datetime(15, 1, 1, tzinfo=timezone.utc)), seed=1)
lambda spark: unary_op_df(spark, TimestampGen())
.selectExpr("a - (interval {} days {} seconds)".format(days, seconds)))

@pytest.mark.parametrize('data_gen', vals, ids=idfn)
@allow_non_gpu(*non_utc_allow)
@allow_non_gpu(*non_supported_tz_allow)
def test_timeadd(data_gen):
days, seconds = data_gen
assert_gpu_and_cpu_are_equal_collect(
# We are starting at year 0005 to make sure we don't go before year 0001
# and beyond year 10000 while doing TimeAdd
lambda spark: unary_op_df(spark, TimestampGen(start=datetime(5, 1, 1, tzinfo=timezone.utc), end=datetime(15, 1, 1, tzinfo=timezone.utc)), seed=1)
lambda spark: unary_op_df(spark, TimestampGen())
.selectExpr("a + (interval {} days {} seconds)".format(days, seconds)))

@pytest.mark.parametrize('data_gen', [-pow(2, 63), pow(2, 63)], ids=idfn)
@allow_non_gpu(*non_supported_tz_allow)
def test_timeadd_long_overflow(data_gen):
assert_gpu_and_cpu_error(
lambda spark: unary_op_df(spark, TimestampGen())
.selectExpr("a + (interval {} microseconds)".format(data_gen)),
conf={},
error_message='long overflow')

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@allow_non_gpu(*non_utc_allow)
@allow_non_gpu(*non_supported_tz_allow)
def test_timeadd_daytime_column():
gen_list = [
# timestamp column max year is 1000
@@ -60,6 +67,15 @@ def test_timeadd_daytime_column():
assert_gpu_and_cpu_are_equal_collect(
lambda spark: gen_df(spark, gen_list).selectExpr("t + d", "t + INTERVAL '1 02:03:04' DAY TO SECOND"))

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@allow_non_gpu(*non_supported_tz_allow)
def test_timeadd_daytime_column_long_overflow():
res-life (Collaborator) commented on Dec 29, 2023:

How do we ensure the randomly generated df will overflow 100% of the time?
Maybe specify some constant values to guarantee overflow.

Collaborator replied:

By not making it actually random.

DayTimeIntervalGen has both a min_value and a max_value. You could set it up so all of the values generated would overflow. You might also need to remove the special cases and disable nulls to be 100% sure of it.

    def __init__(self, min_value=MIN_DAY_TIME_INTERVAL, max_value=MAX_DAY_TIME_INTERVAL, start_field="day", end_field="second",

You could also use SetValuesGen with only values in it that would overflow.

    class SetValuesGen(DataGen):

Author replied:

Updated to SetValuesGen.

gen_list = [('t', TimestampGen()),('d', DayTimeIntervalGen())]
assert_gpu_and_cpu_error(
lambda spark : gen_df(spark, gen_list).selectExpr("t + d").collect(),
conf={},
error_message='long overflow')
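
The diff at this commit still uses DayTimeIntervalGen(); a minimal sketch of the SetValuesGen approach the reviewer suggested is shown below. It is a hedged example, not necessarily the final test in later commits: it assumes SetValuesGen takes a Spark data type plus a list of values, and that MAX_DAY_TIME_INTERVAL and the pyspark type classes are already in scope via the data_gen helpers.

    @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
    @allow_non_gpu(*non_supported_tz_allow)
    def test_timeadd_daytime_column_long_overflow():
        # Pin both columns to single values whose sum is guaranteed to exceed the
        # microsecond range of a long, so every row overflows deterministically.
        max_ts = datetime(9999, 12, 31, 23, 59, 59, tzinfo=timezone.utc)
        gen_list = [('t', SetValuesGen(TimestampType(), [max_ts])),
                    ('d', SetValuesGen(DayTimeIntervalType(), [MAX_DAY_TIME_INTERVAL]))]
        assert_gpu_and_cpu_error(
            lambda spark: gen_df(spark, gen_list).selectExpr("t + d").collect(),
            conf={},
            error_message='long overflow')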

@pytest.mark.skipif(is_before_spark_350(), reason='DayTimeInterval overflow check for seconds is not supported before Spark 3.5.0')
def test_interval_seconds_overflow_exception():
assert_gpu_and_cpu_error(
@@ -68,7 +84,7 @@ def test_interval_seconds_overflow_exception():
error_message="IllegalArgumentException")

@pytest.mark.parametrize('data_gen', vals, ids=idfn)
@allow_non_gpu(*non_utc_allow)
@allow_non_gpu(*non_supported_tz_allow_filter)
def test_timeadd_from_subquery(data_gen):

def fun(spark):
@@ -80,7 +96,7 @@ def fun(spark):
assert_gpu_and_cpu_are_equal_collect(fun)

@pytest.mark.parametrize('data_gen', vals, ids=idfn)
@allow_non_gpu(*non_utc_allow)
@allow_non_gpu(*non_supported_tz_allow)
def test_timesub_from_subquery(data_gen):

def fun(spark):
@@ -135,19 +151,17 @@ def test_datediff(data_gen):
'datediff(a, date(null))',
'datediff(a, \'2016-03-02\')'))

hms_fallback = ['ProjectExec'] if not is_supported_time_zone() else []

@allow_non_gpu(*hms_fallback)
@allow_non_gpu(*non_supported_tz_allow)
def test_hour():
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, timestamp_gen).selectExpr('hour(a)'))

@allow_non_gpu(*hms_fallback)
@allow_non_gpu(*non_supported_tz_allow)
def test_minute():
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, timestamp_gen).selectExpr('minute(a)'))

@allow_non_gpu(*hms_fallback)
@allow_non_gpu(*non_supported_tz_allow)
def test_second():
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, timestamp_gen).selectExpr('second(a)'))
@@ -1645,6 +1645,8 @@ object GpuOverrides extends Logging {
.withPsNote(TypeEnum.CALENDAR, "month intervals are not supported"),
TypeSig.CALENDAR)),
(timeAdd, conf, p, r) => new BinaryExprMeta[TimeAdd](timeAdd, conf, p, r) {
override def isTimeZoneSupported = true

override def tagExprForGpu(): Unit = {
GpuOverrides.extractLit(timeAdd.interval).foreach { lit =>
val intvl = lit.value.asInstanceOf[CalendarInterval]
@@ -1655,7 +1657,7 @@
}

override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuTimeAdd(lhs, rhs)
GpuTimeAdd(lhs, rhs, timeAdd.timeZoneId)
}),
expr[DateAddInterval](
"Adds interval to date",
@@ -473,7 +473,7 @@ object GpuScalar extends Logging {
*
* This class is introduced because many expressions require both the cudf Scalar and its
* corresponding Scala value to complete their computations. e.g. 'GpuStringSplit',
* 'GpuStringLocate', 'GpuDivide', 'GpuDateAddInterval', 'GpuTimeMath' ...
* 'GpuStringLocate', 'GpuDivide', 'GpuDateAddInterval', 'GpuTimeAdd' ...
Collaborator asked:

Q: Why was the name changed? It seems to differ across Spark versions. We could mention both GpuTimeAdd and GpuTimeMath in the comment.

Author replied:

GpuTimeMath was an abstract class implemented by GpuTimeAdd and GpuDateAddInterval. I removed it because the two classes actually share very little reusable code.

* So a cudf Scalar alone or a Scala value alone cannot support such cases, unless data is
* copied between the host and the device each time it is asked for.
*
@@ -493,7 +493,7 @@
* happens.
*
* Another reason why storing the Scala value in addition to the cudf Scalar is
* `GpuDateAddInterval` and 'GpuTimeMath' have different algorithms with the 3 members of
* `GpuDateAddInterval` and 'GpuTimeAdd' have different algorithms with the 3 members of
* a `CalendarInterval`, which can not be supported by a single cudf Scalar now.
*
* Do not create a GpuScalar from the constructor, instead call the factory APIs above.
@@ -140,15 +140,15 @@ case class GpuYear(child: Expression) extends GpuDateUnaryExpression {
input.getBase.year()
}

abstract class GpuTimeMath(
start: Expression,
case class GpuDateAddInterval(start: Expression,
interval: Expression,
timeZoneId: Option[String] = None)
extends ShimBinaryExpression
with GpuExpression
with TimeZoneAwareExpression
with ExpectsInputTypes
with Serializable {
timeZoneId: Option[String] = None,
ansiEnabled: Boolean = SQLConf.get.ansiEnabled)
extends ShimBinaryExpression
with GpuExpression
with TimeZoneAwareExpression
with ExpectsInputTypes
with Serializable {

def this(start: Expression, interval: Expression) = this(start, interval, None)

@@ -157,61 +157,16 @@ abstract class GpuTimeMath(

override def toString: String = s"$left - $right"
override def sql: String = s"${left.sql} - ${right.sql}"
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType)

override def dataType: DataType = TimestampType

override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess

val microSecondsInOneDay: Long = TimeUnit.DAYS.toMicros(1)

override def columnarEval(batch: ColumnarBatch): GpuColumnVector = {
withResourceIfAllowed(left.columnarEval(batch)) { lhs =>
withResourceIfAllowed(right.columnarEvalAny(batch)) { rhs =>
(lhs, rhs) match {
case (l, intvlS: GpuScalar)
if intvlS.dataType.isInstanceOf[CalendarIntervalType] =>
// Scalar does not support 'CalendarInterval' now, so use
// the Scala value instead.
// Skip the null check because it will be detected by the following calls.
val intvl = intvlS.getValue.asInstanceOf[CalendarInterval]
if (intvl.months != 0) {
throw new UnsupportedOperationException("Months aren't supported at the moment")
}
val usToSub = intvl.days * microSecondsInOneDay + intvl.microseconds
if (usToSub != 0) {
withResource(Scalar.fromLong(usToSub)) { us_s =>
withResource(l.getBase.bitCastTo(DType.INT64)) { us =>
withResource(intervalMath(us_s, us)) { longResult =>
GpuColumnVector.from(longResult.castTo(DType.TIMESTAMP_MICROSECONDS), dataType)
}
}
}
} else {
l.incRefCount()
}
case _ =>
throw new UnsupportedOperationException("only column and interval arguments " +
s"are supported, got left: ${lhs.getClass} right: ${rhs.getClass}")
}
}
}
}

def intervalMath(us_s: Scalar, us: ColumnView): ColumnVector
}

case class GpuDateAddInterval(start: Expression,
interval: Expression,
timeZoneId: Option[String] = None,
ansiEnabled: Boolean = SQLConf.get.ansiEnabled)
extends GpuTimeMath(start, interval, timeZoneId) {

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
copy(timeZoneId = Option(timeZoneId))
}

override def intervalMath(us_s: Scalar, us: ColumnView): ColumnVector = {
def intervalMath(us_s: Scalar, us: ColumnView): ColumnVector = {
us.add(us_s)
}

@@ -21,21 +21,142 @@
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

import ai.rapids.cudf.{ColumnVector, ColumnView, Scalar}
import java.util.concurrent.TimeUnit

import org.apache.spark.sql.catalyst.expressions.{Expression, TimeZoneAwareExpression}
import org.apache.spark.sql.rapids.GpuTimeMath
import ai.rapids.cudf.{BinaryOp, BinaryOperable, ColumnVector, ColumnView, DType, Scalar}
import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuScalar}
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource, withResourceIfAllowed}
import com.nvidia.spark.rapids.GpuOverrides
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.jni.GpuTimeZoneDB
import com.nvidia.spark.rapids.shims.ShimBinaryExpression

case class GpuTimeAdd(start: Expression,
interval: Expression,
timeZoneId: Option[String] = None)
extends GpuTimeMath(start, interval, timeZoneId) {
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, TimeZoneAwareExpression}
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.unsafe.types.CalendarInterval

case class GpuTimeAdd(
start: Expression,
interval: Expression,
timeZoneId: Option[String] = None)
extends ShimBinaryExpression
with GpuExpression
with TimeZoneAwareExpression
with ExpectsInputTypes
with Serializable {

def this(start: Expression, interval: Expression) = this(start, interval, None)

override def left: Expression = start
override def right: Expression = interval

override def toString: String = s"$left - $right"
Collaborator asked:

If this is an add, why do we show it as left - right?

Author replied:

done.

override def sql: String = s"${left.sql} - ${right.sql}"
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType)

override def dataType: DataType = TimestampType

override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess

val microSecondsInOneDay: Long = TimeUnit.DAYS.toMicros(1)

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
copy(timeZoneId = Option(timeZoneId))
}

override def intervalMath(us_s: Scalar, us: ColumnView): ColumnVector = {
us.add(us_s)
override def columnarEval(batch: ColumnarBatch): GpuColumnVector = {
withResourceIfAllowed(left.columnarEval(batch)) { lhs =>
withResourceIfAllowed(right.columnarEvalAny(batch)) { rhs =>
// lhs is start, rhs is interval
(lhs, rhs) match {
case (l, intvlS: GpuScalar) if intvlS.dataType.isInstanceOf[CalendarIntervalType] =>
// Scalar does not support 'CalendarInterval' now, so use
// the Scala value instead.
// Skip the null check because it will be detected by the following calls.
val intvl = intvlS.getValue.asInstanceOf[CalendarInterval]
if (intvl.months != 0) {
throw new UnsupportedOperationException("Months aren't supported at the moment")
}
val usToSub = intvl.days * microSecondsInOneDay + intvl.microseconds
if (usToSub != 0) {
val res = if (GpuOverrides.isUTCTimezone(zoneId)) {
withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, usToSub)) { d =>
timestampAddDuration(l.getBase, d)
}
} else {
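// For a non-UTC session time zone: convert the UTC-based timestamps to the local
// time of zoneId, add the interval there, then convert the result back to UTC.
// The two conversion helpers used below come from GpuTimeZoneDB.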
val utcRes = withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(l.getBase,
zoneId)) { utcTimestamp =>
withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, usToSub)) {
d => timestampAddDuration(utcTimestamp, d)
}
}
withResource(utcRes) { _ =>
GpuTimeZoneDB.fromTimestampToUtcTimestamp(utcRes, zoneId)
}
}
GpuColumnVector.from(res, dataType)
res-life (Collaborator) commented on Dec 20, 2023:

Be careful, it looks like res is leaked here.

} else {
l.incRefCount()
}
case _ =>
throw new UnsupportedOperationException("only column and interval arguments " +
s"are supported, got left: ${lhs.getClass} right: ${rhs.getClass}")
}
}
}
}

// A tricky way to check overflow: the addition overflows when positive + positive yields a
// negative, or negative + negative yields a positive. So we flag overflow when the two
// operands share a sign but the result's sign differs from theirs.
private def timestampAddDuration(cv: ColumnView, duration: BinaryOperable): ColumnVector = {
// Do not use cv.add(duration), because it invokes BinaryOperable.implicitConversion,
// which currently returns a Long-typed result.
// Instead, directly specify the return type as TIMESTAMP_MICROSECONDS.
val resWithOverflow = cv.binaryOp(BinaryOp.ADD, duration, DType.TIMESTAMP_MICROSECONDS)
closeOnExcept(resWithOverflow) { _ =>
val isCvPos = withResource(
Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, 0)) { zero =>
cv.greaterOrEqualTo(zero)
}
val sameSignal = closeOnExcept(isCvPos) { isCvPos =>
val isDurationPos = duration match {
case durScalar: Scalar =>
val isPosBool = durScalar.isValid && durScalar.getLong >= 0
Scalar.fromBool(isPosBool)
case dur : AutoCloseable =>
withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, 0)) { zero =>
dur.greaterOrEqualTo(zero)
}
}
withResource(isDurationPos) { _ =>
isCvPos.equalTo(isDurationPos)
}
}
val isOverflow = withResource(sameSignal) { _ =>
val sameSignalWithRes = withResource(isCvPos) { _ =>
val isResNeg = withResource(
Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, 0)) { zero =>
resWithOverflow.lessThan(zero)
}
withResource(isResNeg) { _ =>
isCvPos.equalTo(isResNeg)
}
}
withResource(sameSignalWithRes) { _ =>
sameSignal.and(sameSignalWithRes)
}
}
val anyOverflow = withResource(isOverflow) { _ =>
isOverflow.any()
}
withResource(anyOverflow) { _ =>
if (anyOverflow.isValid && anyOverflow.getBoolean) {
throw new ArithmeticException("long overflow")
}
}
}
resWithOverflow
}
}
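
For readers who find the columnar sign logic above hard to follow, the rule it implements is small. A minimal, self-contained Python illustration (not plugin code) of the same check on single 64-bit values, emulating two's-complement wraparound, might look like this:

    INT64_MIN, INT64_MAX = -2**63, 2**63 - 1

    def add_with_overflow_check(ts_us, dur_us):
        # emulate the 64-bit two's-complement wraparound that the GPU addition produces
        res = (ts_us + dur_us + 2**63) % 2**64 - 2**63
        same_sign = (ts_us >= 0) == (dur_us >= 0)
        sign_flipped = (ts_us >= 0) != (res >= 0)
        # overflow happened iff the operands share a sign and the result's sign differs
        if same_sign and sign_flipped:
            raise ArithmeticError('long overflow')
        return res

    assert add_with_overflow_check(INT64_MAX, -1) == INT64_MAX - 1
    # add_with_overflow_check(INT64_MAX, 1) raises ArithmeticError('long overflow')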
@@ -210,6 +210,9 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging {
("interval", TypeSig.lit(TypeEnum.DAYTIME) + TypeSig.lit(TypeEnum.CALENDAR),
TypeSig.DAYTIME + TypeSig.CALENDAR)),
(timeAdd, conf, p, r) => new BinaryExprMeta[TimeAdd](timeAdd, conf, p, r) {

override def isTimeZoneSupported = true

override def tagExprForGpu(): Unit = {
GpuOverrides.extractLit(timeAdd.interval).foreach { lit =>
lit.dataType match {
@@ -224,7 +227,7 @@
}

override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuTimeAdd(lhs, rhs)
GpuTimeAdd(lhs, rhs, timeAdd.timeZoneId)
}),
GpuOverrides.expr[SpecifiedWindowFrame](
"Specification of the width of the group (or \"frame\") of input rows " +
Expand Down
Loading