GpuFromUnixTime supports more formats via post-processing (#10023)
Support format 'yyyyMMdd' for GpuFromUnixTime

Signed-off-by: Firestarman <[email protected]>
firestarman authored Dec 13, 2023
1 parent 2b43851 commit b309f70
Showing 4 changed files with 78 additions and 16 deletions.
integration_tests/src/main/python/date_time_test.py (1 addition, 1 deletion)
@@ -455,7 +455,7 @@ def test_date_format(data_gen, date_format):
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark : unary_op_df(spark, data_gen).selectExpr("date_format(a, '{}')".format(date_format)))
 
-@pytest.mark.parametrize('date_format', supported_date_formats, ids=idfn)
+@pytest.mark.parametrize('date_format', supported_date_formats + ['yyyyMMdd'], ids=idfn)
 # from 0001-02-01 to 9999-12-30 to avoid 'year 0 is out of range'
 @pytest.mark.parametrize('data_gen', [LongGen(min_val=int(datetime(1, 2, 1).timestamp()), max_val=int(datetime(9999, 12, 30).timestamp()))], ids=idfn)
 @pytest.mark.skipif(not is_supported_time_zone(), reason="not all time zones are supported now, refer to https://github.com/NVIDIA/spark-rapids/issues/6839, please update after all time zones are supported")
sql-plugin/src/main/scala/com/nvidia/spark/rapids/DateUtils.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2022, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -217,15 +217,17 @@ object DateUtils {
   def tagAndGetCudfFormat(
       meta: RapidsMeta[_, _, _],
       sparkFormat: String,
-      parseString: Boolean): String = {
+      parseString: Boolean,
+      inputFormat: Option[String] = None): String = {
+    val formatToConvert = inputFormat.getOrElse(sparkFormat)
     var strfFormat: String = null
     if (GpuOverrides.getTimeParserPolicy == LegacyTimeParserPolicy) {
       try {
         // try and convert the format to cuDF format - this will throw an exception if
         // the format contains unsupported characters or words
-        strfFormat = toStrf(sparkFormat, parseString)
+        strfFormat = toStrf(formatToConvert, parseString)
         // format parsed ok but we have no 100% compatible formats in LEGACY mode
-        if (GpuToTimestamp.LEGACY_COMPATIBLE_FORMATS.contains(sparkFormat)) {
+        if (GpuToTimestamp.LEGACY_COMPATIBLE_FORMATS.contains(formatToConvert)) {
           // LEGACY support has a number of issues that mean we cannot guarantee
           // compatibility with CPU
           // - we can only support 4 digit years but Spark supports a wider range
@@ -249,9 +251,9 @@
       try {
         // try and convert the format to cuDF format - this will throw an exception if
         // the format contains unsupported characters or words
-        strfFormat = toStrf(sparkFormat, parseString)
+        strfFormat = toStrf(formatToConvert, parseString)
         // format parsed ok, so it is either compatible (tested/certified) or incompatible
-        if (!GpuToTimestamp.CORRECTED_COMPATIBLE_FORMATS.contains(sparkFormat) &&
+        if (!GpuToTimestamp.CORRECTED_COMPATIBLE_FORMATS.contains(formatToConvert) &&
             !meta.conf.incompatDateFormats) {
           meta.willNotWorkOnGpu(s"CORRECTED format '$sparkFormat' on the GPU is not guaranteed " +
             s"to produce the same results as Spark on CPU. Set " +
sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -1769,12 +1769,7 @@ object GpuOverrides extends Logging {
         ("format", TypeSig.lit(TypeEnum.STRING)
           .withPsNote(TypeEnum.STRING, "Only a limited number of formats are supported"),
           TypeSig.STRING)),
-      (a, conf, p, r) => new UnixTimeExprMeta[FromUnixTime](a, conf, p, r) {
-        override def isTimeZoneSupported = true
-        override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
-          // passing the already converted strf string for a little optimization
-          GpuFromUnixTime(lhs, rhs, strfFormat, a.timeZoneId)
-      }),
+      (a, conf, p, r) => new FromUnixTimeMeta(a ,conf ,p ,r)),
     expr[FromUTCTimestamp](
       "Render the input UTC timestamp in the input timezone",
       ExprChecks.binaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP,
sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala
@@ -27,7 +27,7 @@ import com.nvidia.spark.rapids.RapidsPluginImplicits._
 import com.nvidia.spark.rapids.jni.GpuTimeZoneDB
 import com.nvidia.spark.rapids.shims.ShimBinaryExpression
 
-import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ExpectsInputTypes, Expression, FromUTCTimestamp, ImplicitCastInputTypes, NullIntolerant, TimeZoneAwareExpression}
+import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ExpectsInputTypes, Expression, FromUnixTime, FromUTCTimestamp, ImplicitCastInputTypes, NullIntolerant, TimeZoneAwareExpression}
 import org.apache.spark.sql.catalyst.util.DateTimeConstants
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -914,19 +914,83 @@ case class GpuGetTimestamp(
   override def right: Expression = format
 }
 
+class FromUnixTimeMeta(a: FromUnixTime,
+    override val conf: RapidsConf,
+    val p: Option[RapidsMeta[_, _, _]],
+    r: DataFromReplacementRule) extends UnixTimeExprMeta[FromUnixTime](a, conf, p, r) {
+
+  private type FmtConverter = ColumnView => ColumnVector
+
+  private var colConverter: Option[FmtConverter] = None
+
+  /**
+   * Formats supported via post conversion. The idea is to
+   * 1) map the unsupported target format to a supported format as
+   *    the intermediate format,
+   * 2) call into cuDF with this intermediate format,
+   * 3) run a post conversion to get the right output for the target format.
+   *
+   * NOTE: Remove an entry from this map once cuDF supports its key format.
+   */
+  private val FORMATS_BY_CONVERSION: Map[String, (String, FmtConverter)] = Map(
+    // spark format -> (intermediate format, converter)
+    "yyyyMMdd" -> (("yyyy-MM-dd",
+      col => {
+        withResource(Scalar.fromString("-")) { dashStr =>
+          withResource(Scalar.fromString("")) { emptyStr =>
+            col.stringReplace(dashStr, emptyStr)
+          }
+        }
+      }
+    ))
+  )
+
+  override def tagExprForGpu(): Unit = {
+    extractStringLit(a.right) match {
+      case Some(rightLit) =>
+        sparkFormat = rightLit
+        var inputFormat: Option[String] = None
+        FORMATS_BY_CONVERSION.get(sparkFormat).foreach { case (tempFormat, converter) =>
+          colConverter = Some(converter)
+          inputFormat = Some(tempFormat)
+        }
+        strfFormat = DateUtils.tagAndGetCudfFormat(this, sparkFormat,
+          a.left.dataType == DataTypes.StringType, inputFormat)
+      case None =>
+        willNotWorkOnGpu("format has to be a string literal")
+    }
+  }
+
+  override def isTimeZoneSupported = true
+  override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = {
+    // passing the already converted strf string for a little optimization
+    GpuFromUnixTime(lhs, rhs, strfFormat, colConverter, a.timeZoneId)
+  }
+}
+
 case class GpuFromUnixTime(
     sec: Expression,
     format: Expression,
     strfFormat: String,
-    timeZoneId: Option[String] = None)
+    colConverter: Option[ColumnView => ColumnVector],
+    timeZoneId: Option[String])
   extends GpuBinaryExpressionArgsAnyScalar
     with TimeZoneAwareExpression
     with ImplicitCastInputTypes {
 
+  // Pick the conversion function once, to avoid an "if...else" per input batch
+  private val convertFunc: ColumnVector => ColumnVector = {
+    if (colConverter.isDefined) {
+      col => withResource(col)(colConverter.get.apply)
+    } else {
+      identity[ColumnVector]
+    }
+  }
+
   override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = {
     // we aren't using rhs as it was already converted in the GpuOverrides while creating the
     // expressions map and passed down here as strfFormat
-    withResource(lhs.getBase.asTimestampSeconds) { secondCV =>
+    val ret = withResource(lhs.getBase.asTimestampSeconds) { secondCV =>
       if (GpuOverrides.isUTCTimezone(zoneId)) {
         // UTC time zone
         secondCV.asStrings(strfFormat)
@@ -937,6 +1001,7 @@ case class GpuFromUnixTime(
         }
       }
     }
+    convertFunc(ret)
   }
 
   override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = {