From b309f70a3d51a001ba5bd859dce89e693a6191c2 Mon Sep 17 00:00:00 2001
From: Liangcai Li
Date: Wed, 13 Dec 2023 09:08:57 +0800
Subject: [PATCH] GpuFromUnixTime supports more formats by post process (#10023)

Support format 'yyyyMMdd' for GpuFromUnixTime

Signed-off-by: Firestarman
---
 .../src/main/python/date_time_test.py         |  2 +-
 .../com/nvidia/spark/rapids/DateUtils.scala   | 14 ++--
 .../nvidia/spark/rapids/GpuOverrides.scala    |  7 +-
 .../sql/rapids/datetimeExpressions.scala      | 71 ++++++++++++++++++-
 4 files changed, 78 insertions(+), 16 deletions(-)

diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py
index 1d4ce5e65d8..4434655b8e3 100644
--- a/integration_tests/src/main/python/date_time_test.py
+++ b/integration_tests/src/main/python/date_time_test.py
@@ -455,7 +455,7 @@ def test_date_format(data_gen, date_format):
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark : unary_op_df(spark, data_gen).selectExpr("date_format(a, '{}')".format(date_format)))

-@pytest.mark.parametrize('date_format', supported_date_formats, ids=idfn)
+@pytest.mark.parametrize('date_format', supported_date_formats + ['yyyyMMdd'], ids=idfn)
 # from 0001-02-01 to 9999-12-30 to avoid 'year 0 is out of range'
 @pytest.mark.parametrize('data_gen', [LongGen(min_val=int(datetime(1, 2, 1).timestamp()), max_val=int(datetime(9999, 12, 30).timestamp()))], ids=idfn)
 @pytest.mark.skipif(not is_supported_time_zone(), reason="not all time zones are supported now, refer to https://github.com/NVIDIA/spark-rapids/issues/6839, please update after all time zones are supported")
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DateUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DateUtils.scala
index ccb2e91f57a..771f8ecc695 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DateUtils.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DateUtils.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2022, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -217,15 +217,17 @@ object DateUtils {
   def tagAndGetCudfFormat(
       meta: RapidsMeta[_, _, _],
       sparkFormat: String,
-      parseString: Boolean): String = {
+      parseString: Boolean,
+      inputFormat: Option[String] = None): String = {
+    val formatToConvert = inputFormat.getOrElse(sparkFormat)
     var strfFormat: String = null
     if (GpuOverrides.getTimeParserPolicy == LegacyTimeParserPolicy) {
       try {
         // try and convert the format to cuDF format - this will throw an exception if
         // the format contains unsupported characters or words
-        strfFormat = toStrf(sparkFormat, parseString)
+        strfFormat = toStrf(formatToConvert, parseString)
         // format parsed ok but we have no 100% compatible formats in LEGACY mode
-        if (GpuToTimestamp.LEGACY_COMPATIBLE_FORMATS.contains(sparkFormat)) {
+        if (GpuToTimestamp.LEGACY_COMPATIBLE_FORMATS.contains(formatToConvert)) {
           // LEGACY support has a number of issues that mean we cannot guarantee
           // compatibility with CPU
           // - we can only support 4 digit years but Spark supports a wider range
@@ -249,9 +251,9 @@ object DateUtils {
       try {
         // try and convert the format to cuDF format - this will throw an exception if
         // the format contains unsupported characters or words
-        strfFormat = toStrf(sparkFormat, parseString)
+        strfFormat = toStrf(formatToConvert, parseString)
         // format parsed ok, so it is either compatible (tested/certified) or incompatible
-        if (!GpuToTimestamp.CORRECTED_COMPATIBLE_FORMATS.contains(sparkFormat) &&
+        if (!GpuToTimestamp.CORRECTED_COMPATIBLE_FORMATS.contains(formatToConvert) &&
             !meta.conf.incompatDateFormats) {
           meta.willNotWorkOnGpu(s"CORRECTED format '$sparkFormat' on the GPU is not guaranteed " +
             s"to produce the same results as Spark on CPU. Set " +
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index c63e615e98b..13369a25a15 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -1769,12 +1769,7 @@ object GpuOverrides extends Logging {
         ("format", TypeSig.lit(TypeEnum.STRING)
             .withPsNote(TypeEnum.STRING, "Only a limited number of formats are supported"),
             TypeSig.STRING)),
-      (a, conf, p, r) => new UnixTimeExprMeta[FromUnixTime](a, conf, p, r) {
-        override def isTimeZoneSupported = true
-        override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
-          // passing the already converted strf string for a little optimization
-          GpuFromUnixTime(lhs, rhs, strfFormat, a.timeZoneId)
-      }),
+      (a, conf, p, r) => new FromUnixTimeMeta(a ,conf ,p ,r)),
     expr[FromUTCTimestamp](
       "Render the input UTC timestamp in the input timezone",
       ExprChecks.binaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP,
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala
index 7e4de33bb70..238a65a3a65 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala
@@ -27,7 +27,7 @@ import com.nvidia.spark.rapids.RapidsPluginImplicits._
 import com.nvidia.spark.rapids.jni.GpuTimeZoneDB
 import com.nvidia.spark.rapids.shims.ShimBinaryExpression

-import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ExpectsInputTypes, Expression, FromUTCTimestamp, ImplicitCastInputTypes, NullIntolerant, TimeZoneAwareExpression}
+import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ExpectsInputTypes, Expression, FromUnixTime, FromUTCTimestamp, ImplicitCastInputTypes, NullIntolerant, TimeZoneAwareExpression}
 import org.apache.spark.sql.catalyst.util.DateTimeConstants
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -914,19 +914,83 @@ case class GpuGetTimestamp(
   override def right: Expression = format
 }

+class FromUnixTimeMeta(a: FromUnixTime,
+    override val conf: RapidsConf,
+    val p: Option[RapidsMeta[_, _, _]],
+    r: DataFromReplacementRule) extends UnixTimeExprMeta[FromUnixTime](a, conf, p, r) {
+
+  private type FmtConverter = ColumnView => ColumnVector
+
+  private var colConverter: Option[FmtConverter] = None
+
+  /**
+   * More supported formats by post conversions. The idea is
+   * 1) Map the unsupported target format to a supported format as
+   * the intermediate format,
+   * 2) Call into cuDF with this intermediate format,
+   * 3) Run a post conversion to get the right output for the target format.
+   *
+   * NOTE: Need to remove the entry if the key format is supported by cuDF.
+   */
+  private val FORMATS_BY_CONVERSION: Map[String, (String, FmtConverter)] = Map(
+    // spark format -> (intermediate format, converter)
+    "yyyyMMdd" -> (("yyyy-MM-dd",
+      col => {
+        withResource(Scalar.fromString("-")) { dashStr =>
+          withResource(Scalar.fromString("")) { emptyStr =>
+            col.stringReplace(dashStr, emptyStr)
+          }
+        }
+      }
+    ))
+  )
+
+  override def tagExprForGpu(): Unit = {
+    extractStringLit(a.right) match {
+      case Some(rightLit) =>
+        sparkFormat = rightLit
+        var inputFormat: Option[String] = None
+        FORMATS_BY_CONVERSION.get(sparkFormat).foreach { case (tempFormat, converter) =>
+          colConverter = Some(converter)
+          inputFormat = Some(tempFormat)
+        }
+        strfFormat = DateUtils.tagAndGetCudfFormat(this, sparkFormat,
+          a.left.dataType == DataTypes.StringType, inputFormat)
+      case None =>
+        willNotWorkOnGpu("format has to be a string literal")
+    }
+  }
+
+  override def isTimeZoneSupported = true
+  override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = {
+    // passing the already converted strf string for a little optimization
+    GpuFromUnixTime(lhs, rhs, strfFormat, colConverter, a.timeZoneId)
+  }
+}
+
 case class GpuFromUnixTime(
     sec: Expression,
     format: Expression,
     strfFormat: String,
-    timeZoneId: Option[String] = None)
+    colConverter: Option[ColumnView => ColumnVector],
+    timeZoneId: Option[String])
   extends GpuBinaryExpressionArgsAnyScalar
     with TimeZoneAwareExpression
     with ImplicitCastInputTypes {

+  // To avoid duplicated "if...else" for each input batch
+  private val convertFunc: ColumnVector => ColumnVector = {
+    if (colConverter.isDefined) {
+      col => withResource(col)(colConverter.get.apply)
+    } else {
+      identity[ColumnVector]
+    }
+  }
+
   override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = {
     // we aren't using rhs as it was already converted in the GpuOverrides while creating the
     // expressions map and passed down here as strfFormat
-    withResource(lhs.getBase.asTimestampSeconds) { secondCV =>
+    val ret = withResource(lhs.getBase.asTimestampSeconds) { secondCV =>
       if (GpuOverrides.isUTCTimezone(zoneId)) {
         // UTC time zone
         secondCV.asStrings(strfFormat)
@@ -937,6 +1001,7 @@ case class GpuFromUnixTime(
         }
       }
     }
+    convertFunc(ret)
   }

   override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = {