Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support to_utc_timestamp [databricks] #10144

Merged
merged 6 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/additional-functionality/advanced_configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.TimeAdd"></a>spark.rapids.sql.expression.TimeAdd| |Adds interval to timestamp|true|None|
<a name="sql.expression.ToDegrees"></a>spark.rapids.sql.expression.ToDegrees|`degrees`|Converts radians to degrees|true|None|
<a name="sql.expression.ToRadians"></a>spark.rapids.sql.expression.ToRadians|`radians`|Converts degrees to radians|true|None|
<a name="sql.expression.ToUTCTimestamp"></a>spark.rapids.sql.expression.ToUTCTimestamp|`to_utc_timestamp`|Render the input timestamp in UTC|true|None|
<a name="sql.expression.ToUnixTimestamp"></a>spark.rapids.sql.expression.ToUnixTimestamp|`to_unix_timestamp`|Returns the UNIX timestamp of the given time|true|None|
<a name="sql.expression.TransformKeys"></a>spark.rapids.sql.expression.TransformKeys|`transform_keys`|Transform keys in a map using a transform function|true|None|
<a name="sql.expression.TransformValues"></a>spark.rapids.sql.expression.TransformValues|`transform_values`|Transform values in a map using a transform function|true|None|
Expand Down
172 changes: 120 additions & 52 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -15441,6 +15441,74 @@ are limited.
<td> </td>
</tr>
<tr>
<td rowSpan="3">ToUTCTimestamp</td>
<td rowSpan="3">`to_utc_timestamp`</td>
<td rowSpan="3">Render the input timestamp in UTC</td>
<td rowSpan="3">None</td>
<td rowSpan="3">project</td>
<td>timestamp</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>timezone</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td rowSpan="3">ToUnixTimestamp</td>
<td rowSpan="3">`to_unix_timestamp`</td>
<td rowSpan="3">Returns the UNIX timestamp of the given time</td>
Expand Down Expand Up @@ -15645,6 +15713,32 @@ are limited.
<td> </td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="4">UnaryMinus</td>
<td rowSpan="4">`negative`</td>
<td rowSpan="4">Negate a numeric value</td>
Expand Down Expand Up @@ -15735,32 +15829,6 @@ are limited.
<td> </td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="4">UnaryPositive</td>
<td rowSpan="4">`positive`</td>
<td rowSpan="4">A numeric value with a + in front of it</td>
Expand Down Expand Up @@ -16018,6 +16086,32 @@ are limited.
<td> </td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="2">Upper</td>
<td rowSpan="2">`upper`, `ucase`</td>
<td rowSpan="2">String uppercase operator</td>
Expand Down Expand Up @@ -16112,32 +16206,6 @@ are limited.
<td> </td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="3">WindowExpression</td>
<td rowSpan="3"> </td>
<td rowSpan="3">Calculates a return value for every input row of a table based on a group (or "window") of rows</td>
Expand Down
53 changes: 32 additions & 21 deletions integration_tests/src/main/python/date_time_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -286,41 +286,52 @@ def test_unsupported_fallback_to_unix_timestamp(data_gen):
spark, [("a", data_gen), ("b", string_gen)], length=10).selectExpr(
"to_unix_timestamp(a, b)"),
"ToUnixTimestamp")

supported_timezones = ["Asia/Shanghai", "UTC", "UTC+0", "UTC-0", "GMT", "GMT+0", "GMT-0", "EST", "MST", "VST"]
unsupported_timezones = ["PST", "NST", "AST", "America/Los_Angeles", "America/New_York", "America/Chicago"]

@pytest.mark.parametrize('time_zone', ["Asia/Shanghai", "UTC", "UTC+0", "UTC-0", "GMT", "GMT+0", "GMT-0"], ids=idfn)
@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn)
@tz_sensitive_test
Copy link
Collaborator Author

@thirtiseven thirtiseven Jan 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my understanding it's not a tz_sensitive_test case because it will only run on gpu under utc timezone and fallback for all other timezones. We can add the timezones we want to test to the supported_timezones list.

@pytest.mark.parametrize('time_zone', supported_timezones, ids=idfn)
@allow_non_gpu(*non_utc_allow)
def test_from_utc_timestamp(data_gen, time_zone):
def test_from_utc_timestamp(time_zone):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).select(f.from_utc_timestamp(f.col('a'), time_zone)))
lambda spark: unary_op_df(spark, timestamp_gen).select(f.from_utc_timestamp(f.col('a'), time_zone)))

@allow_non_gpu('ProjectExec')
@pytest.mark.parametrize('time_zone', ["PST", "NST", "AST", "America/Los_Angeles", "America/New_York", "America/Chicago"], ids=idfn)
@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn)
@tz_sensitive_test
def test_from_utc_timestamp_unsupported_timezone_fallback(data_gen, time_zone):
@pytest.mark.parametrize('time_zone', unsupported_timezones, ids=idfn)
def test_from_utc_timestamp_unsupported_timezone_fallback(time_zone):
assert_gpu_fallback_collect(
lambda spark: unary_op_df(spark, data_gen).select(f.from_utc_timestamp(f.col('a'), time_zone)),
lambda spark: unary_op_df(spark, timestamp_gen).select(f.from_utc_timestamp(f.col('a'), time_zone)),
'FromUTCTimestamp')

@pytest.mark.parametrize('time_zone', ["UTC", "Asia/Shanghai", "EST", "MST", "VST"], ids=idfn)
@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn)
@tz_sensitive_test
@allow_non_gpu('ProjectExec')
def test_unsupported_fallback_from_utc_timestamp():
time_zone_gen = StringGen(pattern="UTC")
assert_gpu_fallback_collect(
lambda spark: gen_df(spark, [("a", timestamp_gen), ("tzone", time_zone_gen)]).selectExpr(
"from_utc_timestamp(a, tzone)"),
'FromUTCTimestamp')

@allow_non_gpu(*non_utc_allow)
def test_from_utc_timestamp_supported_timezones(data_gen, time_zone):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This case looks duplicated with test_from_utc_timestamp so I combined them.

@pytest.mark.parametrize('time_zone', supported_timezones, ids=idfn)
def test_to_utc_timestamp(time_zone):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).select(f.from_utc_timestamp(f.col('a'), time_zone)))
lambda spark: unary_op_df(spark, timestamp_gen).select(f.to_utc_timestamp(f.col('a'), time_zone)))

@allow_non_gpu('ProjectExec')
@pytest.mark.parametrize('time_zone', unsupported_timezones, ids=idfn)
@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn)
def test_unsupported_fallback_from_utc_timestamp(data_gen):
time_zone_gen = StringGen(pattern="UTC")
def test_to_utc_timestamp_unsupported_timezone_fallback(data_gen, time_zone):
assert_gpu_fallback_collect(
lambda spark: gen_df(spark, [("a", data_gen), ("tzone", time_zone_gen)]).selectExpr(
"from_utc_timestamp(a, tzone)"),
'FromUTCTimestamp')
lambda spark: unary_op_df(spark, data_gen).select(f.to_utc_timestamp(f.col('a'), time_zone)),
'ToUTCTimestamp')

@allow_non_gpu('ProjectExec')
def test_unsupported_fallback_to_utc_timestamp():
time_zone_gen = StringGen(pattern="UTC")
assert_gpu_fallback_collect(
lambda spark: gen_df(spark, [("a", timestamp_gen), ("tzone", time_zone_gen)]).selectExpr(
"to_utc_timestamp(a, tzone)"),
'ToUTCTimestamp')

@allow_non_gpu('ProjectExec')
@pytest.mark.parametrize('data_gen', [long_gen], ids=idfn)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -1801,6 +1801,14 @@ object GpuOverrides extends Logging {
TypeSig.lit(TypeEnum.STRING))),
(a, conf, p, r) => new FromUTCTimestampExprMeta(a, conf, p, r)
),
expr[ToUTCTimestamp](
"Render the input timestamp in UTC",
ExprChecks.binaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP,
("timestamp", TypeSig.TIMESTAMP, TypeSig.TIMESTAMP),
("timezone", TypeSig.lit(TypeEnum.STRING),
TypeSig.lit(TypeEnum.STRING))),
(a, conf, p, r) => new ToUTCTimestampExprMeta(a, conf, p, r)
),
expr[Pmod](
"Pmod",
// Decimal support disabled https://github.com/NVIDIA/spark-rapids/issues/7553
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -27,7 +27,7 @@ import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.jni.GpuTimeZoneDB
import com.nvidia.spark.rapids.shims.ShimBinaryExpression

import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ExpectsInputTypes, Expression, FromUnixTime, FromUTCTimestamp, ImplicitCastInputTypes, NullIntolerant, TimeZoneAwareExpression}
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ExpectsInputTypes, Expression, FromUnixTime, FromUTCTimestamp, ImplicitCastInputTypes, NullIntolerant, TimeZoneAwareExpression, ToUTCTimestamp}
import org.apache.spark.sql.catalyst.util.DateTimeConstants
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -1132,6 +1132,65 @@ case class GpuFromUTCTimestamp(
}
}

class ToUTCTimestampExprMeta(
expr: ToUTCTimestamp,
override val conf: RapidsConf,
override val parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends BinaryExprMeta[ToUTCTimestamp](expr, conf, parent, rule) {

private[this] var timezoneId: ZoneId = null

override def tagExprForGpu(): Unit = {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like ToUTCTimestampExprMeta has basically the same logic for whether it will run on the GPU as FromUTCTimestampExprMeta. I think these 2 classes should be refactored to share this logic

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

extractStringLit(expr.right) match {
case None =>
willNotWorkOnGpu("timezone input must be a literal string")
case Some(timezoneShortID) =>
if (timezoneShortID != null) {
timezoneId = GpuTimeZoneDB.getZoneId(timezoneShortID)
if (!GpuTimeZoneDB.isSupportedTimeZone(timezoneId)) {
willNotWorkOnGpu(s"Not supported timezone type $timezoneShortID.")
}
}
}
}

override def convertToGpu(timestamp: Expression, timezone: Expression): GpuExpression =
GpuToUTCTimestamp(timestamp, timezone, timezoneId)
}

case class GpuToUTCTimestamp(
timestamp: Expression, timezone: Expression, zoneId: ZoneId)
extends GpuBinaryExpressionArgsAnyScalar
with ImplicitCastInputTypes
with NullIntolerant {

override def left: Expression = timestamp
override def right: Expression = timezone
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType)
override def dataType: DataType = TimestampType

override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = {
if (rhs.getBase.isValid) {
if (GpuOverrides.isUTCTimezone(zoneId)) {
// For UTC timezone, just a no-op bypassing GPU computation.
lhs.getBase.incRefCount()
} else {
GpuTimeZoneDB.fromTimestampToUtcTimestamp(lhs.getBase, zoneId)
}
} else {
// All-null output column.
GpuColumnVector.columnVectorFromNull(lhs.getRowCount.toInt, dataType)
}
}

override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = {
withResource(GpuColumnVector.from(lhs, numRows, left.dataType)) { lhsCol =>
doColumnar(lhsCol, rhs)
}
}
}

trait GpuDateMathBase extends GpuBinaryExpression with ExpectsInputTypes {
override def inputTypes: Seq[AbstractDataType] =
Seq(DateType, TypeCollection(IntegerType, ShortType, ByteType))
Expand Down
Loading