Support TimeAdd for non-UTC time zone #10068

Closed · 26 commits · changes shown from 13 commits
integration_tests/src/main/python/date_time_test.py (89 changes: 62 additions & 27 deletions)
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -26,39 +26,76 @@
non_utc_tz_allow = ['ProjectExec'] if not is_utc() else []
# Others work in all supported time zones
non_supported_tz_allow = ['ProjectExec'] if not is_supported_time_zone() else []
non_supported_tz_allow_filter = ['ProjectExec', 'FilterExec'] if not is_supported_time_zone() else []

# We only support literal intervals for TimeSub
vals = [(-584, 1563), (1943, 1101), (2693, 2167), (2729, 0), (44, 1534), (2635, 3319),
(1885, -2828), (0, 2463), (932, 2286), (0, 0)]
vals = [(0, 1)]
@pytest.mark.parametrize('data_gen', vals, ids=idfn)
@allow_non_gpu(*non_utc_allow)
@allow_non_gpu(*non_supported_tz_allow)
def test_timesub(data_gen):
days, seconds = data_gen
assert_gpu_and_cpu_are_equal_collect(
# We are starting at year 0015 to make sure we don't go before year 0001 while doing TimeSub
lambda spark: unary_op_df(spark, TimestampGen(start=datetime(15, 1, 1, tzinfo=timezone.utc)), seed=1)
lambda spark: unary_op_df(spark, TimestampGen())
.selectExpr("a - (interval {} days {} seconds)".format(days, seconds)))

@pytest.mark.parametrize('data_gen', vals, ids=idfn)
@allow_non_gpu(*non_utc_allow)
def test_timeadd(data_gen):
@allow_non_gpu(*non_supported_tz_allow)
def test_timeadd_special(data_gen):
days, seconds = data_gen
assert_gpu_and_cpu_are_equal_collect(
# We are starting at year 0005 to make sure we don't go before year 0001
# and beyond year 10000 while doing TimeAdd
lambda spark: unary_op_df(spark, TimestampGen(start=datetime(5, 1, 1, tzinfo=timezone.utc), end=datetime(15, 1, 1, tzinfo=timezone.utc)), seed=1)
lambda spark: unary_op_df(spark, TimestampGen(start=datetime(1900, 12, 31, 15, 55, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 16, tzinfo=timezone.utc)), length=100)
.selectExpr("a + (interval {} days {} seconds)".format(days, seconds)))

def test_to_utc_timestamp_and_from_utc_timestamp():
aaa = with_cpu_session(lambda spark: unary_op_df(spark, TimestampGen(start=datetime(1900, 12, 31, 15, 55, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 16, tzinfo=timezone.utc))).collect())
bbb = with_cpu_session(lambda spark: unary_op_df(spark, TimestampGen(start=datetime(1900, 12, 31, 15, 55, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 16, tzinfo=timezone.utc))).selectExpr("from_utc_timestamp(to_utc_timestamp(a, 'Asia/Shanghai'), 'Asia/Shanghai')").collect())
assert aaa == bbb

@pytest.mark.parametrize('edge_value', [-pow(2, 63), pow(2, 63)], ids=idfn)
@allow_non_gpu(*non_supported_tz_allow)
def test_timeadd_long_overflow(edge_value):
assert_gpu_and_cpu_error(
lambda spark: unary_op_df(spark, TimestampGen())
.selectExpr("a + (interval {} microseconds)".format(edge_value)),
conf={},
error_message='long overflow')

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@allow_non_gpu(*non_utc_allow)
def test_timeadd_daytime_column():
@allow_non_gpu(*non_supported_tz_allow)
def test_timeadd_daytime_column_normal():
gen_list = [
# timestamp column max year is 1000
('t', TimestampGen(end=datetime(1000, 1, 1, tzinfo=timezone.utc))),
('t', TimestampGen(start=datetime(1900, 12, 31, 15, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 16, tzinfo=timezone.utc))),
# max days is 8000 year, so added result will not be out of range
('d', DayTimeIntervalGen(min_value=timedelta(days=0), max_value=timedelta(days=8000 * 365)))]
('d', DayTimeIntervalGen(min_value=timedelta(seconds=0), max_value=timedelta(seconds=0)))]
assert_gpu_and_cpu_are_equal_collect(
lambda spark: gen_df(spark, gen_list, length=2048).selectExpr("t", "d", "t + d"))

def test_timeadd_cpu_only():
gen_list = [
# timestamp column max year is 1000
('t', TimestampGen(start=datetime(1900, 12, 31, 15, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 16, tzinfo=timezone.utc))),
# max days is 8000 year, so added result will not be out of range
('d', DayTimeIntervalGen(min_value=timedelta(seconds=0), max_value=timedelta(seconds=0)))]
cpu_before = with_cpu_session(lambda spark: gen_df(spark, gen_list, length=2048).collect())
cpu_after = with_cpu_session(lambda spark: gen_df(spark, gen_list, length=2048).selectExpr("t + d").collect())
assert cpu_before == cpu_after

def test_to_utc_timestamp():
assert_gpu_and_cpu_are_equal_collect(
lambda spark: gen_df(spark, gen_list).selectExpr("t + d", "t + INTERVAL '1 02:03:04' DAY TO SECOND"))
lambda spark: unary_op_df(spark, TimestampGen(start=datetime(1900, 12, 31, 15, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 16, tzinfo=timezone.utc)))
.selectExpr("to_utc_timestamp(a, 'Asia/Shanghai')"))

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@allow_non_gpu(*non_supported_tz_allow)
def test_timeadd_daytime_column_long_overflow():
res-life (Collaborator) commented on Dec 29, 2023:

How to ensure the random df will 100% overflow?
Maybe specify some constant variables to ensure overflow.

A collaborator replied:

By not making it actually random.

DayTimeIntervalGen has both a min_value and a max_value. You could set it up so all of the values generated would overflow. You might need to also remove the special cases and disable nulls to be 100% sure of it.

def __init__(self, min_value=MIN_DAY_TIME_INTERVAL, max_value=MAX_DAY_TIME_INTERVAL, start_field="day", end_field="second",

You could also use SetValuesGen with only values in it that would overflow.

class SetValuesGen(DataGen):

The author replied:

Updated to SetValuesGen.

overflow_gen = SetValuesGen(DayTimeIntervalType(),
[timedelta(microseconds=-pow(2, 63)), timedelta(microseconds=(pow(2, 63) - 1))])
gen_list = [('t', TimestampGen()),('d', overflow_gen)]
assert_gpu_and_cpu_error(
lambda spark : gen_df(spark, gen_list).selectExpr("t + d").collect(),
conf={},
error_message='long overflow')

@pytest.mark.skipif(is_before_spark_350(), reason='DayTimeInterval overflow check for seconds is not supported before Spark 3.5.0')
def test_interval_seconds_overflow_exception():
@@ -68,7 +105,7 @@ def test_interval_seconds_overflow_exception():
error_message="IllegalArgumentException")

@pytest.mark.parametrize('data_gen', vals, ids=idfn)
@allow_non_gpu(*non_utc_allow)
@allow_non_gpu(*non_supported_tz_allow_filter)
def test_timeadd_from_subquery(data_gen):

def fun(spark):
@@ -80,7 +117,7 @@ def fun(spark):
assert_gpu_and_cpu_are_equal_collect(fun)

@pytest.mark.parametrize('data_gen', vals, ids=idfn)
@allow_non_gpu(*non_utc_allow)
@allow_non_gpu(*non_supported_tz_allow)
def test_timesub_from_subquery(data_gen):

def fun(spark):
@@ -135,19 +172,17 @@ def test_datediff(data_gen):
'datediff(a, date(null))',
'datediff(a, \'2016-03-02\')'))

hms_fallback = ['ProjectExec'] if not is_supported_time_zone() else []

@allow_non_gpu(*hms_fallback)
@allow_non_gpu(*non_supported_tz_allow)
def test_hour():
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, timestamp_gen).selectExpr('hour(a)'))

@allow_non_gpu(*hms_fallback)
@allow_non_gpu(*non_supported_tz_allow)
def test_minute():
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, timestamp_gen).selectExpr('minute(a)'))

@allow_non_gpu(*hms_fallback)
@allow_non_gpu(*non_supported_tz_allow)
def test_second():
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, timestamp_gen).selectExpr('second(a)'))
@@ -288,7 +323,7 @@ def test_unsupported_fallback_to_unix_timestamp(data_gen):
"ToUnixTimestamp")

@pytest.mark.parametrize('time_zone', ["Asia/Shanghai", "Iran", "UTC", "UTC+0", "UTC-0", "GMT", "GMT+0", "GMT-0"], ids=idfn)
@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn)
@pytest.mark.parametrize('data_gen', [TimestampGen(start=datetime(1900, 12, 31, 7, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 23, tzinfo=timezone.utc))], ids=idfn)
@tz_sensitive_test
@allow_non_gpu(*non_utc_allow)
def test_from_utc_timestamp(data_gen, time_zone):
@@ -297,23 +332,23 @@ def test_from_utc_timestamp(data_gen, time_zone):

@allow_non_gpu('ProjectExec')
@pytest.mark.parametrize('time_zone', ["PST", "NST", "AST", "America/Los_Angeles", "America/New_York", "America/Chicago"], ids=idfn)
@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn)
@pytest.mark.parametrize('data_gen', [TimestampGen(start=datetime(1900, 12, 31, 7, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 23, tzinfo=timezone.utc))], ids=idfn)
@tz_sensitive_test
def test_from_utc_timestamp_unsupported_timezone_fallback(data_gen, time_zone):
assert_gpu_fallback_collect(
lambda spark: unary_op_df(spark, data_gen).select(f.from_utc_timestamp(f.col('a'), time_zone)),
'FromUTCTimestamp')

@pytest.mark.parametrize('time_zone', ["UTC", "Asia/Shanghai", "EST", "MST", "VST"], ids=idfn)
@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn)
@pytest.mark.parametrize('data_gen', [TimestampGen(start=datetime(1900, 12, 31, 7, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 23, tzinfo=timezone.utc))], ids=idfn)
@tz_sensitive_test
@allow_non_gpu(*non_utc_allow)
def test_from_utc_timestamp_supported_timezones(data_gen, time_zone):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).select(f.from_utc_timestamp(f.col('a'), time_zone)))

@allow_non_gpu('ProjectExec')
@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn)
@pytest.mark.parametrize('data_gen', [TimestampGen(start=datetime(1900, 12, 31, 7, tzinfo=timezone.utc), end=datetime(1900, 12, 31, 23, tzinfo=timezone.utc))], ids=idfn)
def test_unsupported_fallback_from_utc_timestamp(data_gen):
time_zone_gen = StringGen(pattern="UTC")
assert_gpu_fallback_collect(
@@ -1645,6 +1645,8 @@ object GpuOverrides extends Logging {
.withPsNote(TypeEnum.CALENDAR, "month intervals are not supported"),
TypeSig.CALENDAR)),
(timeAdd, conf, p, r) => new BinaryExprMeta[TimeAdd](timeAdd, conf, p, r) {
override def isTimeZoneSupported = true

override def tagExprForGpu(): Unit = {
GpuOverrides.extractLit(timeAdd.interval).foreach { lit =>
val intvl = lit.value.asInstanceOf[CalendarInterval]
@@ -1655,7 +1657,7 @@
}

override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuTimeAdd(lhs, rhs)
GpuTimeAdd(lhs, rhs, timeAdd.timeZoneId)
}),
expr[DateAddInterval](
"Adds interval to date",
@@ -473,7 +473,7 @@ object GpuScalar extends Logging {
*
* This class is introduced because many expressions require both the cudf Scalar and its
* corresponding Scala value to complete their computations. e.g. 'GpuStringSplit',
* 'GpuStringLocate', 'GpuDivide', 'GpuDateAddInterval', 'GpuTimeMath' ...
* 'GpuStringLocate', 'GpuDivide', 'GpuDateAddInterval', 'GpuTimeAdd' ...
A collaborator asked:

Q: Why name changed? It seems different over different Spark version. We can comment both in GpuTimeAdd/GpuTimeMath.

The author replied:

GpuTimeMath was an abstract class implemented by GpuTimeAdd and GpuDateAddInterval. I removed it because the two classes actually share very little reusable logic.

* So only either a cudf Scalar or a Scala value can not support such cases, unless copying data
* between the host and the device each time being asked for.
*
@@ -493,7 +493,7 @@
* happens.
*
* Another reason why storing the Scala value in addition to the cudf Scalar is
* `GpuDateAddInterval` and 'GpuTimeMath' have different algorithms with the 3 members of
* `GpuDateAddInterval` and 'GpuTimeAdd' have different algorithms with the 3 members of
* a `CalendarInterval`, which can not be supported by a single cudf Scalar now.
*
* Do not create a GpuScalar from the constructor, instead call the factory APIs above.
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -37,7 +37,8 @@ object AddOverflowChecks {
def basicOpOverflowCheck(
lhs: BinaryOperable,
rhs: BinaryOperable,
ret: ColumnVector): Unit = {
ret: ColumnVector,
msg: String = "One or more rows overflow for Add operation."): Unit = {
// Check overflow. It is true if the arguments have the same sign and
// the sign of the result is different from the sign of x,
// which is equal to "((x ^ r) & (y ^ r)) < 0" in the form of arithmetic.
@@ -54,9 +55,7 @@
withResource(signDiffCV) { signDiff =>
withResource(signDiff.any()) { any =>
if (any.isValid && any.getBoolean) {
throw RapidsErrorUtils.arithmeticOverflowError(
"One or more rows overflow for Add operation."
)
throw RapidsErrorUtils.arithmeticOverflowError(msg)
}
}
}
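
For reference, the sign-bit trick described in the comment above can be checked on plain Longs with a small standalone Scala snippet (an illustration only, not part of this diff): for addition, overflow occurs exactly when both operands differ in sign from the wrapped result.

object SignTrickDemo {
  // ((x ^ r) & (y ^ r)) < 0 is negative only when both x and y differ in sign from r,
  // i.e. the operands share a sign and the wrapped two's-complement sum flipped it.
  def addOverflows(x: Long, y: Long): Boolean = {
    val r = x + y
    ((x ^ r) & (y ^ r)) < 0
  }

  def main(args: Array[String]): Unit = {
    assert(addOverflows(Long.MaxValue, 1L))  // wraps to a negative value: overflow
    assert(!addOverflows(1L, 2L))            // small positive sum: no overflow
    assert(!addOverflows(-5L, 3L))           // mixed signs can never overflow on add
  }
}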
@@ -114,7 +113,8 @@ object SubtractOverflowChecks {
def basicOpOverflowCheck(
lhs: BinaryOperable,
rhs: BinaryOperable,
ret: ColumnVector): Unit = {
ret: ColumnVector,
msg: String = "One or more rows overflow for Subtract operation."): Unit = {
// Check overflow. It is true if the arguments have different signs and
// the sign of the result is different from the sign of x.
// Which is equal to "((x ^ y) & (x ^ r)) < 0" in the form of arithmetic.
@@ -131,8 +131,7 @@
withResource(signDiffCV) { signDiff =>
withResource(signDiff.any()) { any =>
if (any.isValid && any.getBoolean) {
throw RapidsErrorUtils.
arithmeticOverflowError("One or more rows overflow for Subtract operation.")
throw RapidsErrorUtils.arithmeticOverflowError(msg)
}
}
}
@@ -140,15 +140,15 @@ case class GpuYear(child: Expression) extends GpuDateUnaryExpression {
input.getBase.year()
}

abstract class GpuTimeMath(
start: Expression,
case class GpuDateAddInterval(start: Expression,
interval: Expression,
timeZoneId: Option[String] = None)
extends ShimBinaryExpression
with GpuExpression
with TimeZoneAwareExpression
with ExpectsInputTypes
with Serializable {
timeZoneId: Option[String] = None,
ansiEnabled: Boolean = SQLConf.get.ansiEnabled)
extends ShimBinaryExpression
with GpuExpression
with TimeZoneAwareExpression
with ExpectsInputTypes
with Serializable {

def this(start: Expression, interval: Expression) = this(start, interval, None)

@@ -157,61 +157,16 @@ abstract class GpuTimeMath(

override def toString: String = s"$left - $right"
override def sql: String = s"${left.sql} - ${right.sql}"
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType)

override def dataType: DataType = TimestampType

override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess

val microSecondsInOneDay: Long = TimeUnit.DAYS.toMicros(1)

override def columnarEval(batch: ColumnarBatch): GpuColumnVector = {
withResourceIfAllowed(left.columnarEval(batch)) { lhs =>
withResourceIfAllowed(right.columnarEvalAny(batch)) { rhs =>
(lhs, rhs) match {
case (l, intvlS: GpuScalar)
if intvlS.dataType.isInstanceOf[CalendarIntervalType] =>
// Scalar does not support 'CalendarInterval' now, so use
// the Scala value instead.
// Skip the null check because it will be detected by the following calls.
val intvl = intvlS.getValue.asInstanceOf[CalendarInterval]
if (intvl.months != 0) {
throw new UnsupportedOperationException("Months aren't supported at the moment")
}
val usToSub = intvl.days * microSecondsInOneDay + intvl.microseconds
if (usToSub != 0) {
withResource(Scalar.fromLong(usToSub)) { us_s =>
withResource(l.getBase.bitCastTo(DType.INT64)) { us =>
withResource(intervalMath(us_s, us)) { longResult =>
GpuColumnVector.from(longResult.castTo(DType.TIMESTAMP_MICROSECONDS), dataType)
}
}
}
} else {
l.incRefCount()
}
case _ =>
throw new UnsupportedOperationException("only column and interval arguments " +
s"are supported, got left: ${lhs.getClass} right: ${rhs.getClass}")
}
}
}
}

def intervalMath(us_s: Scalar, us: ColumnView): ColumnVector
}

case class GpuDateAddInterval(start: Expression,
interval: Expression,
timeZoneId: Option[String] = None,
ansiEnabled: Boolean = SQLConf.get.ansiEnabled)
extends GpuTimeMath(start, interval, timeZoneId) {

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
copy(timeZoneId = Option(timeZoneId))
}

override def intervalMath(us_s: Scalar, us: ColumnView): ColumnVector = {
def intervalMath(us_s: Scalar, us: ColumnView): ColumnVector = {
us.add(us_s)
}
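
As background for the removed GpuTimeMath code above, here is a standalone Scala sketch of how a literal CalendarInterval is folded into a single microsecond count before cudf can add it to the timestamp column viewed as INT64 (an illustration only, not part of this diff):

import java.util.concurrent.TimeUnit

import org.apache.spark.unsafe.types.CalendarInterval

object IntervalFoldDemo {
  def main(args: Array[String]): Unit = {
    val microSecondsInOneDay: Long = TimeUnit.DAYS.toMicros(1)
    // 0 months, 2 days, 30 seconds; month intervals are rejected, as in tagExprForGpu
    val intvl = new CalendarInterval(0, 2, 30L * 1000L * 1000L)
    require(intvl.months == 0, "Months aren't supported at the moment")
    // The same fold the removed columnarEval performed: days + microseconds -> one Long
    val usToAdd: Long = intvl.days * microSecondsInOneDay + intvl.microseconds
    println(usToAdd) // 172830000000
  }
}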

@@ -0,0 +1,52 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.rapids

import ai.rapids.cudf.{BinaryOp, BinaryOperable, ColumnVector, ColumnView, DType, Scalar}
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}

object datetimeExpressionsUtils {
def timestampAddDuration(cv: ColumnView, duration: BinaryOperable): ColumnVector = {
// Do not use cv.add(duration), because it invokes BinaryOperable.implicitConversion,
// which currently resolves the result type to a plain long rather than a timestamp.
// Instead, specify the return type explicitly as TIMESTAMP_MICROSECONDS.
val resWithOverflow = cv.binaryOp(BinaryOp.ADD, duration, DType.TIMESTAMP_MICROSECONDS)
closeOnExcept(resWithOverflow) { _ =>
withResource(resWithOverflow.castTo(DType.INT64)) { resWithOverflowLong =>
withResource(cv.bitCastTo(DType.INT64)) { cvLong =>
duration match {
case dur: Scalar =>
val durLong = Scalar.fromLong(dur.getLong)
withResource(durLong) { _ =>
AddOverflowChecks.basicOpOverflowCheck(
cvLong, durLong, resWithOverflowLong, "long overflow")
}
case dur: ColumnView =>
withResource(dur.bitCastTo(DType.INT64)) { durationLong =>
AddOverflowChecks.basicOpOverflowCheck(
cvLong, durationLong, resWithOverflowLong, "long overflow")
}
case _ =>
throw new UnsupportedOperationException("only scalar and column arguments " +
s"are supported, got ${duration.getClass}")
}
}
}
}
resWithOverflow
}
}
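
For intuition on the "long overflow" message the Python tests expect: it matches what JVM checked addition throws, which is a reasonable CPU-side analogue of the check timestampAddDuration performs on the bit-cast INT64 values. A minimal Scala sketch under that assumption (not code from this PR):

object CheckedTimestampAdd {
  // Timestamps are microseconds since the epoch backed by a signed 64-bit long,
  // so adding a duration must not wrap around.
  def addMicrosChecked(tsMicros: Long, durMicros: Long): Long =
    Math.addExact(tsMicros, durMicros) // throws ArithmeticException("long overflow") on wrap

  def main(args: Array[String]): Unit = {
    println(addMicrosChecked(1700000000000000L, 86400000000L)) // plus one day: fine
    try {
      addMicrosChecked(Long.MaxValue, 1L)
    } catch {
      case e: ArithmeticException => println(e.getMessage) // "long overflow"
    }
  }
}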