Support CPU path for from_utc_timestamp function with timezone (#9689)
* Support CPU path for from_utc_timestamp function

Signed-off-by: Ferdinand Xu <[email protected]>

* Address comments

* Address comments

* Refactor

* Add a test configuration to avoid exposing the CPU backend to users

* Address comments and fix failing UTs

* Address comments

* Fix

---------

Signed-off-by: Ferdinand Xu <[email protected]>
winningsix authored Nov 21, 2023
1 parent c6b0a50 commit 2667941
Showing 5 changed files with 100 additions and 29 deletions.
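At a high level, the commit adds a test-only config, spark.rapids.test.CPU.timezone, that lets from_utc_timestamp with a non-UTC timezone run through a CPU implementation instead of falling back wholesale. A minimal sketch of enabling it in a session, assuming the RAPIDS plugin jars are on the classpath (the builder settings other than the new key are standard plugin setup, not part of this commit):

import org.apache.spark.sql.SparkSession

// Local session sketch; only spark.rapids.test.CPU.timezone is introduced by
// this commit, the other settings are ordinary RAPIDS plugin configuration.
val spark = SparkSession.builder()
  .master("local[*]")
  .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
  .config("spark.rapids.sql.enabled", "true")
  .config("spark.rapids.test.CPU.timezone", "true")
  .getOrCreate()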
10 changes: 9 additions & 1 deletion integration_tests/src/main/python/date_time_test.py
@@ -270,14 +270,22 @@ def test_from_utc_timestamp(data_gen, time_zone):
         lambda spark: unary_op_df(spark, data_gen).select(f.from_utc_timestamp(f.col('a'), time_zone)))

 @allow_non_gpu('ProjectExec')
-@pytest.mark.parametrize('time_zone', ["PST", "MST", "EST", "VST", "NST", "AST"], ids=idfn)
+@pytest.mark.parametrize('time_zone', ["PST", "NST", "AST"], ids=idfn)
 @pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn)
 def test_from_utc_timestamp_unsupported_timezone_fallback(data_gen, time_zone):
     assert_gpu_fallback_collect(
         lambda spark: unary_op_df(spark, data_gen).select(f.from_utc_timestamp(f.col('a'), time_zone)),
         'FromUTCTimestamp')


+@pytest.mark.parametrize('time_zone', ["UTC", "Asia/Shanghai", "EST", "MST", "VST"], ids=idfn)
+@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn)
+def test_from_utc_timestamp_supported_timezones(data_gen, time_zone):
+    # TODO: remove the spark.rapids.test.CPU.timezone configuration once the GPU kernel is ready, so this really tests on the GPU
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark: unary_op_df(spark, data_gen).select(f.from_utc_timestamp(f.col('a'), time_zone)), conf = {"spark.rapids.test.CPU.timezone": "true"})
+
+
 @allow_non_gpu('ProjectExec')
 @pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn)
 def test_unsupported_fallback_from_utc_timestamp(data_gen):
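The split between the two parametrize lists above follows java.time: three-letter IDs resolve through ZoneId.SHORT_IDS, and only zones that are fixed offsets or have no recurring DST transition rules are treated as supported. A self-contained sketch to see where each short ID lands (the zone list is illustrative):

import java.time.ZoneId

// "PST" -> America/Los_Angeles (has DST rules: falls back), while "EST"/"MST"
// resolve to the fixed offsets -05:00/-07:00 (supported).
for (id <- Seq("PST", "NST", "AST", "EST", "MST", "VST")) {
  val rules = ZoneId.of(id, ZoneId.SHORT_IDS).getRules
  val supported = rules.isFixedOffset || rules.getTransitionRules.isEmpty
  println(s"$id -> supported = $supported")
}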
RapidsConf.scala
@@ -2056,6 +2056,12 @@ object RapidsConf {
     .booleanConf
     .createOptional

+  val TEST_USE_TIMEZONE_CPU_BACKEND = conf("spark.rapids.test.CPU.timezone")
+    .doc("Only for tests: enables the CPU backend for timezone-related functions")
+    .internal()
+    .booleanConf
+    .createOptional
+
   private def printSectionHeader(category: String): Unit =
     println(s"\n### $category")
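Since the entry is declared with createOptional, reading it yields an Option[Boolean]; a one-line sketch of the consuming side, mirroring the accessor added later in this commit:

// Inside a RapidsMeta subclass where `conf: RapidsConf` is in scope:
val isOnCPU: Boolean = conf.get(RapidsConf.TEST_USE_TIMEZONE_CPU_BACKEND).getOrElse(false)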
TimeZoneDB.scala
@@ -14,16 +14,36 @@
  * limitations under the License.
  */

-package com.nvidia.spark.rapids.timezone
+package org.apache.spark.sql.rapids

 import java.time.ZoneId

 import ai.rapids.cudf.{ColumnVector, DType, HostColumnVector}
 import com.nvidia.spark.rapids.Arm.withResource
+import com.nvidia.spark.rapids.GpuOverrides

 import org.apache.spark.sql.catalyst.util.DateTimeUtils

 object TimeZoneDB {
+  def isUTCTimezone(timezoneId: ZoneId): Boolean = {
+    timezoneId.normalized() == GpuOverrides.UTC_TIMEZONE_ID
+  }
+
+  // Copied from Spark. Normalizes time zone ID strings in the (+|-)h:mm and (+|-)hh:m forms.
+  def getZoneId(timezoneId: String): ZoneId = {
+    val formattedZoneId = timezoneId
+      // To support the (+|-)h:mm format, because it was supported before Spark 3.0.
+      .replaceFirst("(\\+|\\-)(\\d):", "$10$2:")
+      // To support the (+|-)hh:m format, because it was supported before Spark 3.0.
+      .replaceFirst("(\\+|\\-)(\\d\\d):(\\d)$", "$1$2:0$3")
+    DateTimeUtils.getZoneId(formattedZoneId)
+  }
+
+  // Supports the fixed-offset and no-transition-rule cases
+  def isSupportedTimezone(timezoneId: String): Boolean = {
+    val rules = getZoneId(timezoneId).getRules
+    rules.isFixedOffset || rules.getTransitionRules.isEmpty
+  }
+
   def cacheDatabase(): Unit = {}

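A scratch snippet (not part of the commit) showing what the two helpers above accept: the regex turns pre-Spark-3.0 forms like "+8:00" into "+08:00", and the rules check admits fixed offsets plus region zones without recurring DST rules:

import java.time.ZoneId

// Normalization: "+8:00" (h:mm) becomes "+08:00" (hh:mm).
assert("+8:00".replaceFirst("(\\+|\\-)(\\d):", "$10$2:") == "+08:00")

// Fixed offsets are supported ...
assert(ZoneId.of("+08:00").getRules.isFixedOffset)
// ... as are zones with no recurring transition rules (Asia/Shanghai today) ...
assert(ZoneId.of("Asia/Shanghai").getRules.getTransitionRules.isEmpty)
// ... but DST zones such as America/Los_Angeles are not.
assert(!ZoneId.of("America/Los_Angeles").getRules.getTransitionRules.isEmpty)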
@@ -42,10 +62,14 @@ object TimeZoneDB {
     withResource(HostColumnVector.builder(DType.TIMESTAMP_MICROSECONDS, rowCount)) { builder =>
       var currRow = 0
       while (currRow < rowCount) {
-        val origin = input.getLong(currRow)
-        // Spark implementation
-        val dist = DateTimeUtils.toUTCTime(origin, zoneStr)
-        builder.append(dist)
+        if (input.isNull(currRow)) {
+          builder.appendNull()
+        } else {
+          val origin = input.getLong(currRow)
+          // Spark implementation
+          val dist = DateTimeUtils.toUTCTime(origin, zoneStr)
+          builder.append(dist)
+        }
         currRow += 1
       }
       withResource(builder.build()) { b =>
@@ -72,10 +96,14 @@
     withResource(HostColumnVector.builder(DType.TIMESTAMP_MICROSECONDS, rowCount)) { builder =>
       var currRow = 0
       while (currRow < rowCount) {
-        val origin = input.getLong(currRow)
-        // Spark implementation
-        val dist = DateTimeUtils.fromUTCTime(origin, zoneStr)
-        builder.append(dist)
+        if (input.isNull(currRow)) {
+          builder.appendNull()
+        } else {
+          val origin = input.getLong(currRow)
+          // Spark implementation
+          val dist = DateTimeUtils.fromUTCTime(origin, zoneStr)
+          builder.append(dist)
+        }
         currRow += 1
       }
       withResource(builder.build()) { b =>
@@ -97,10 +125,14 @@
     withResource(HostColumnVector.builder(DType.TIMESTAMP_DAYS, rowCount)) { builder =>
       var currRow = 0
       while (currRow < rowCount) {
-        val origin = input.getLong(currRow)
-        // Spark implementation
-        val dist = DateTimeUtils.microsToDays(origin, currentTimeZone)
-        builder.append(dist)
+        if (input.isNull(currRow)) {
+          builder.appendNull()
+        } else {
+          val origin = input.getLong(currRow)
+          // Spark implementation
+          val dist = DateTimeUtils.microsToDays(origin, currentTimeZone)
+          builder.append(dist)
+        }
         currRow += 1
       }
       withResource(builder.build()) { b =>
@@ -124,10 +156,14 @@
     withResource(HostColumnVector.builder(DType.INT64, rowCount)) { builder =>
       var currRow = 0
       while (currRow < rowCount) {
-        val origin = input.getInt(currRow)
-        // Spark implementation
-        val dist = DateTimeUtils.daysToMicros(origin, desiredTimeZone)
-        builder.append(dist)
+        if (input.isNull(currRow)) {
+          builder.appendNull()
+        } else {
+          val origin = input.getInt(currRow)
+          // Spark implementation
+          val dist = DateTimeUtils.daysToMicros(origin, desiredTimeZone)
+          builder.append(dist)
+        }
         currRow += 1
       }
       withResource(builder.build()) { b =>
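For orientation, Spark's DateTimeUtils conversions used in the loops above shift microsecond timestamps between UTC and the target zone's wall clock. A minimal round-trip sketch (values chosen for illustration):

import java.time.Instant
import org.apache.spark.sql.catalyst.util.DateTimeUtils

// 2023-11-21 00:00:00 UTC as microseconds since the epoch.
val utcMicros = DateTimeUtils.instantToMicros(Instant.parse("2023-11-21T00:00:00Z"))

// from_utc_timestamp semantics: the same instant rendered as Shanghai wall-clock time (+8h).
val shanghaiMicros = DateTimeUtils.fromUTCTime(utcMicros, "Asia/Shanghai")
assert(shanghaiMicros - utcMicros == 8L * 3600 * 1000000L)

// toUTCTime is the inverse for zones without DST ambiguity.
assert(DateTimeUtils.toUTCTime(shanghaiMicros, "Asia/Shanghai") == utcMicros)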
datetimeExpressions.scala
@@ -23,6 +23,7 @@ import ai.rapids.cudf.{BinaryOp, CaptureGroups, ColumnVector, ColumnView, DType,
 import com.nvidia.spark.rapids.{BinaryExprMeta, BoolUtils, DataFromReplacementRule, DateUtils, GpuBinaryExpression, GpuBinaryExpressionArgsAnyScalar, GpuCast, GpuColumnVector, GpuExpression, GpuScalar, GpuUnaryExpression, RapidsConf, RapidsMeta}
 import com.nvidia.spark.rapids.Arm._
 import com.nvidia.spark.rapids.GpuOverrides.{extractStringLit, getTimeParserPolicy}
+import com.nvidia.spark.rapids.RapidsConf.TEST_USE_TIMEZONE_CPU_BACKEND
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
 import com.nvidia.spark.rapids.shims.ShimBinaryExpression

@@ -1044,29 +1045,38 @@ class FromUTCTimestampExprMeta(
     rule: DataFromReplacementRule)
   extends BinaryExprMeta[FromUTCTimestamp](expr, conf, parent, rule) {

+  private[this] var timezoneId: ZoneId = null
+  private[this] val isOnCPU: Boolean = conf.get(TEST_USE_TIMEZONE_CPU_BACKEND).getOrElse(false)
+
   override def tagExprForGpu(): Unit = {
     extractStringLit(expr.right) match {
       case None =>
         willNotWorkOnGpu("timezone input must be a literal string")
       case Some(timezoneShortID) =>
         if (timezoneShortID != null) {
-          val utc = ZoneId.of("UTC").normalized
-          // This is copied from Spark, to convert `(+|-)h:mm` into `(+|-)0h:mm`.
-          val timezone = ZoneId.of(timezoneShortID.replaceFirst("(\\+|\\-)(\\d):", "$10$2:"),
-            ZoneId.SHORT_IDS).normalized
-
-          if (timezone != utc) {
-            willNotWorkOnGpu("only timezones equivalent to UTC are supported")
+          timezoneId = TimeZoneDB.getZoneId(timezoneShortID)
+          // Always pass for the UTC timezone, since the conversion is a no-op.
+          if (!TimeZoneDB.isUTCTimezone(timezoneId)) {
+            // Check the CPU path; mostly for test purposes.
+            if (isOnCPU) {
+              if (!TimeZoneDB.isSupportedTimezone(timezoneShortID)) {
+                willNotWorkOnGpu(s"Unsupported timezone type $timezoneShortID.")
+              }
+            } else {
+              // TODO: remove this once the GPU backend is supported.
+              willNotWorkOnGpu(s"Unsupported timezone type $timezoneShortID.")
+            }
           }
         }
     }
   }

   override def convertToGpu(timestamp: Expression, timezone: Expression): GpuExpression =
-    GpuFromUTCTimestamp(timestamp, timezone)
+    GpuFromUTCTimestamp(timestamp, timezone, timezoneId, isOnCPU)
 }

-case class GpuFromUTCTimestamp(timestamp: Expression, timezone: Expression)
+case class GpuFromUTCTimestamp(
+    timestamp: Expression, timezone: Expression, zoneId: ZoneId, isOnCPU: Boolean)
   extends GpuBinaryExpressionArgsAnyScalar
   with ImplicitCastInputTypes
   with NullIntolerant {
@@ -1078,8 +1088,18 @@ case class GpuFromUTCTimestamp(timestamp: Expression, timezone: Expression)

   override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = {
     if (rhs.getBase.isValid) {
-      // Just a no-op.
-      lhs.getBase.incRefCount()
+      if (TimeZoneDB.isUTCTimezone(zoneId)) {
+        // For the UTC timezone this is a no-op, bypassing GPU computation.
+        lhs.getBase.incRefCount()
+      } else {
+        if (isOnCPU) {
+          TimeZoneDB.fromUtcTimestampToTimestamp(lhs.getBase, zoneId)
+        } else {
+          // TODO: remove this once the GPU backend is supported.
+          throw new UnsupportedOperationException(
+            s"Unsupported timezone type ${zoneId.normalized()}")
+        }
+      }
     } else {
       // All-null output column.
       GpuColumnVector.columnVectorFromNull(lhs.getRowCount.toInt, dataType)
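Putting the doColumnar branches together: UTC passes through, supported zones convert via TimeZoneDB when the test flag is on, and DST zones never reach this point because tagging already forced a fallback. A hedged spark-shell sketch of the observable behavior (assumes the session from the sketch near the top; the output comments are expectations, not captured results):

spark.conf.set("spark.sql.session.timeZone", "UTC")
spark.conf.set("spark.rapids.test.CPU.timezone", "true")

spark.sql(
  "SELECT from_utc_timestamp(timestamp'2023-11-21 00:00:00', 'Asia/Shanghai') AS sh, " +
  "from_utc_timestamp(timestamp'2023-11-21 00:00:00', 'UTC') AS utc").show(false)
// Expected: sh = 2023-11-21 08:00:00 (CPU path), utc = 2023-11-21 00:00:00 (no-op).
// With a DST zone such as 'PST', the whole ProjectExec falls back to Spark.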
TimeZoneSuite.scala
@@ -28,6 +28,7 @@ import com.nvidia.spark.rapids.SparkQueryCompareTestSuite

 import org.apache.spark.SparkConf
 import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+import org.apache.spark.sql.rapids.TimeZoneDB
 import org.apache.spark.sql.types._

 class TimeZoneSuite extends SparkQueryCompareTestSuite {
