From 630c5934720d5e1b1f4a9d7b8923455bfc2fafa1 Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Wed, 4 Dec 2024 15:35:35 -0800 Subject: [PATCH 01/18] Implement GpuOverride Signed-off-by: Nghia Truong --- .../nvidia/spark/rapids/GpuOverrides.scala | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 45905f0b9e0..fa3ceb10f25 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -1821,6 +1821,53 @@ object GpuOverrides extends Logging { ParamCheck("round", TypeSig.lit(TypeEnum.BOOLEAN), TypeSig.BOOLEAN))), (a, conf, p, r) => new MonthsBetweenExprMeta(a, conf, p, r) ), + expr[TruncDate]( + "Truncate the date to the unit specified by the given string format", + ExprChecks.binaryProject(TypeSig.DATE, TypeSig.DATE, + ("date", TypeSig.DATE, TypeSig.DATE), + ("format", TypeSig.lit(TypeEnum.STRING) + .withPsNote(TypeEnum.STRING, "\"QUARTER\" and \"WEEK\" are not supported"), + TypeSig.STRING)), + (a, conf, p, r) => new BinaryExprMeta[TruncDate](a, conf, p, r) { + override def tagExprForGpu(): Unit = { + def isSupported(format: String): Boolean = { + format.toUpperCase match { + case "YEAR" | "YYYY" | "YY" | "MONTH" | "MM" | "MON" => true + case _ => false + } + } + extractStringLit(a.format) match { + case Some(format) if isSupported(format) => + case _ => + willNotWorkOnGpu("Truncation format is not supported") + } + } + override def convertToGpu(child: Expression): GpuExpression = GpuTruncDate(child) + }), + expr[TruncTimestamp]( + "Truncate the timestamp to the unit specified by the given string format", + ExprChecks.binaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP, + ("date", TypeSig.TIMESTAMP, TypeSig.TIMESTAMP), + ("format", TypeSig.lit(TypeEnum.STRING) + .withPsNote(TypeEnum.STRING, "\"QUARTER\" and \"WEEK\" are not supported"), + TypeSig.STRING)), + (a, conf, p, r) => new BinaryExprMeta[TruncDate](a, conf, p, r) { + override def tagExprForGpu(): Unit = { + def isSupported(format: String): Boolean = { + format.toUpperCase match { + case "YEAR" | "YYYY" | "YY" | "MONTH" | "MM" | "MON" | "DAY" | "DD" | + "HOUR" | "MINUTE" | "SECOND" | "MILLISECOND" | "MICROSECOND" => true + case _ => false + } + } + extractStringLit(a.format) match { + case Some(format) if isSupported(format) => + case _ => + willNotWorkOnGpu("Truncation format is not supported") + } + } + override def convertToGpu(child: Expression): GpuExpression = GpuTruncTimestamp(child) + }), expr[Pmod]( "Pmod", // Decimal support disabled https://github.com/NVIDIA/spark-rapids/issues/7553 From 075fe1e1a73fbcd70a5be587969408ec1a74b3e8 Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Wed, 4 Dec 2024 16:23:56 -0800 Subject: [PATCH 02/18] Implement `GpuTruncDateTime` Signed-off-by: Nghia Truong --- .../nvidia/spark/rapids/GpuOverrides.scala | 8 ++- .../sql/rapids/datetimeExpressions.scala | 69 ++++++++++++++++--- 2 files changed, 65 insertions(+), 12 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index fa3ceb10f25..1bf03b8a3b1 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -1842,7 +1842,8 @@ object GpuOverrides extends Logging { willNotWorkOnGpu("Truncation format is not supported") } } - override def convertToGpu(child: Expression): GpuExpression = GpuTruncDate(child) + override def convertToGpu(date: Expression, format: Expression): GpuExpression = + GpuTruncDate(date, format) }), expr[TruncTimestamp]( "Truncate the timestamp to the unit specified by the given string format", @@ -1851,7 +1852,7 @@ object GpuOverrides extends Logging { ("format", TypeSig.lit(TypeEnum.STRING) .withPsNote(TypeEnum.STRING, "\"QUARTER\" and \"WEEK\" are not supported"), TypeSig.STRING)), - (a, conf, p, r) => new BinaryExprMeta[TruncDate](a, conf, p, r) { + (a, conf, p, r) => new BinaryExprMeta[TruncTimestamp](a, conf, p, r) { override def tagExprForGpu(): Unit = { def isSupported(format: String): Boolean = { format.toUpperCase match { @@ -1866,7 +1867,8 @@ object GpuOverrides extends Logging { willNotWorkOnGpu("Truncation format is not supported") } } - override def convertToGpu(child: Expression): GpuExpression = GpuTruncTimestamp(child) + override def convertToGpu(format: Expression, timestamp: Expression): GpuExpression = + GpuTruncTimestamp(format, timestamp) }), expr[Pmod]( "Pmod", diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala index 0f382a7b6e6..a7bbf2cf700 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala @@ -27,7 +27,7 @@ import com.nvidia.spark.rapids.Arm._ import com.nvidia.spark.rapids.ExprMeta import com.nvidia.spark.rapids.GpuOverrides.{extractStringLit, getTimeParserPolicy} import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.jni.GpuTimeZoneDB +import com.nvidia.spark.rapids.jni.{DateTimeUtils, GpuTimeZoneDB} import com.nvidia.spark.rapids.shims.{NullIntolerantShim, ShimBinaryExpression, ShimExpression} import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ExpectsInputTypes, Expression, FromUnixTime, FromUTCTimestamp, ImplicitCastInputTypes, MonthsBetween, TimeZoneAwareExpression, ToUTCTimestamp} @@ -417,12 +417,11 @@ abstract class UnixTimeExprMeta[A <: BinaryExpression with TimeZoneAwareExpressi } trait GpuNumberToTimestampUnaryExpression extends GpuUnaryExpression { - override def dataType: DataType = TimestampType override def outputTypeOverride: DType = DType.TIMESTAMP_MICROSECONDS /** - * Test whether if input * multiplier will cause Long-overflow. In Math.multiplyExact, + * Test whether if input * multiplier will cause Long-overflow. In Math.multiplyExact, * if there is an integer-overflow, then it will throw an ArithmeticException "long overflow" */ def checkLongMultiplicationOverflow(input: ColumnVector, multiplier: Long): Unit = { @@ -439,7 +438,7 @@ trait GpuNumberToTimestampUnaryExpression extends GpuUnaryExpression { } protected val convertTo : GpuColumnVector => ColumnVector - + override def doColumnar(input: GpuColumnVector): ColumnVector = { convertTo(input) } @@ -543,14 +542,14 @@ case class GpuSecondsToTimestamp(child: Expression) extends GpuNumberToTimestamp longs.asTimestampSeconds() } case _ => - throw new UnsupportedOperationException(s"Unsupport type ${child.dataType} " + + throw new UnsupportedOperationException(s"Unsupport type ${child.dataType} " + s"for SecondsToTimestamp ") } } case class GpuMillisToTimestamp(child: Expression) extends GpuNumberToTimestampUnaryExpression { protected lazy val convertTo: GpuColumnVector => ColumnVector = child.dataType match { - case LongType => + case LongType => (input: GpuColumnVector) => { checkLongMultiplicationOverflow(input.getBase, DateTimeConstants.MICROS_PER_MILLIS) input.getBase.asTimestampMilliseconds() @@ -563,7 +562,7 @@ case class GpuMillisToTimestamp(child: Expression) extends GpuNumberToTimestampU } } case _ => - throw new UnsupportedOperationException(s"Unsupport type ${child.dataType} " + + throw new UnsupportedOperationException(s"Unsupport type ${child.dataType} " + s"for MillisToTimestamp ") } } @@ -581,7 +580,7 @@ case class GpuMicrosToTimestamp(child: Expression) extends GpuNumberToTimestampU } } case _ => - throw new UnsupportedOperationException(s"Unsupport type ${child.dataType} " + + throw new UnsupportedOperationException(s"Unsupport type ${child.dataType} " + s"for MicrosToTimestamp ") } } @@ -1108,7 +1107,7 @@ abstract class ConvertUTCTimestampExprMetaBase[INPUT <: BinaryExpression]( rule: DataFromReplacementRule) extends BinaryExprMeta[INPUT](expr, conf, parent, rule) { - protected[this] var timezoneId: ZoneId = null + protected[this] var timezoneId: ZoneId = null override def tagExprForGpu(): Unit = { extractStringLit(expr.right) match { @@ -1525,3 +1524,55 @@ case class GpuLastDay(startDate: Expression) override protected def doColumnar(input: GpuColumnVector): ColumnVector = input.getBase.lastDayOfMonth() } + +abstract class GpuTruncDateTime(datetime: Expression, format: Expression) + extends GpuBinaryExpression with ImplicitCastInputTypes { + override def left: Expression = datetime + + override def right: Expression = format + + override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector = { + // We always store date/time to the left and format to the right expression. + DateTimeUtils.truncate(lhs.getBase, rhs.getBase) + } + + override def doColumnar(lhs: GpuScalar, rhs: GpuColumnVector): ColumnVector = { + withResource(GpuColumnVector.from(lhs, rhs.getRowCount.toInt, lhs.dataType)) { left => + doColumnar(left, rhs) + } + } + + override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = { + withResource(GpuColumnVector.from(rhs, lhs.getRowCount.toInt, rhs.dataType)) { right => + doColumnar(lhs, right) + } + } + + override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = { + withResource(GpuColumnVector.from(lhs, numRows, lhs.dataType)) { left => + withResource(GpuColumnVector.from(rhs, numRows, rhs.dataType)) { right => + doColumnar(left, right) + } + } + } +} + +case class GpuTruncDate(date: Expression, format: Expression) + extends GpuTruncDateTime(date, format) { + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) + + override def dataType: DataType = DateType + + override def prettyName: String = "trunc" +} + +case class GpuTruncTimestamp(format: Expression, timestamp: Expression) + extends GpuTruncDateTime(timestamp, format) { + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, TimestampType) + + override def dataType: DataType = TimestampType + + override def prettyName: String = "date_trunc" +} From f512b559f76da505944c6a1f1f1da956101c2bce Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Thu, 5 Dec 2024 09:53:49 -0800 Subject: [PATCH 03/18] Do not fallback Signed-off-by: Nghia Truong --- .../nvidia/spark/rapids/GpuOverrides.scala | 27 ------------------- .../sql/rapids/datetimeExpressions.scala | 4 +-- 2 files changed, 2 insertions(+), 29 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 1bf03b8a3b1..93f516e7071 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -1829,19 +1829,6 @@ object GpuOverrides extends Logging { .withPsNote(TypeEnum.STRING, "\"QUARTER\" and \"WEEK\" are not supported"), TypeSig.STRING)), (a, conf, p, r) => new BinaryExprMeta[TruncDate](a, conf, p, r) { - override def tagExprForGpu(): Unit = { - def isSupported(format: String): Boolean = { - format.toUpperCase match { - case "YEAR" | "YYYY" | "YY" | "MONTH" | "MM" | "MON" => true - case _ => false - } - } - extractStringLit(a.format) match { - case Some(format) if isSupported(format) => - case _ => - willNotWorkOnGpu("Truncation format is not supported") - } - } override def convertToGpu(date: Expression, format: Expression): GpuExpression = GpuTruncDate(date, format) }), @@ -1853,20 +1840,6 @@ object GpuOverrides extends Logging { .withPsNote(TypeEnum.STRING, "\"QUARTER\" and \"WEEK\" are not supported"), TypeSig.STRING)), (a, conf, p, r) => new BinaryExprMeta[TruncTimestamp](a, conf, p, r) { - override def tagExprForGpu(): Unit = { - def isSupported(format: String): Boolean = { - format.toUpperCase match { - case "YEAR" | "YYYY" | "YY" | "MONTH" | "MM" | "MON" | "DAY" | "DD" | - "HOUR" | "MINUTE" | "SECOND" | "MILLISECOND" | "MICROSECOND" => true - case _ => false - } - } - extractStringLit(a.format) match { - case Some(format) if isSupported(format) => - case _ => - willNotWorkOnGpu("Truncation format is not supported") - } - } override def convertToGpu(format: Expression, timestamp: Expression): GpuExpression = GpuTruncTimestamp(format, timestamp) }), diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala index a7bbf2cf700..e6a4047fc28 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala @@ -1527,12 +1527,12 @@ case class GpuLastDay(startDate: Expression) abstract class GpuTruncDateTime(datetime: Expression, format: Expression) extends GpuBinaryExpression with ImplicitCastInputTypes { - override def left: Expression = datetime + // We always store date/time to the left and format to the right expressions. + override def left: Expression = datetime override def right: Expression = format override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector = { - // We always store date/time to the left and format to the right expression. DateTimeUtils.truncate(lhs.getBase, rhs.getBase) } From 863e67409a1674d849d4128e567f3083550c40f2 Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Fri, 6 Dec 2024 11:41:20 -0800 Subject: [PATCH 04/18] All variant input will be converted to just one Signed-off-by: Nghia Truong --- .../spark/sql/rapids/datetimeExpressions.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala index e6a4047fc28..d6972a931b7 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala @@ -1537,24 +1537,28 @@ abstract class GpuTruncDateTime(datetime: Expression, format: Expression) } override def doColumnar(lhs: GpuScalar, rhs: GpuColumnVector): ColumnVector = { - withResource(GpuColumnVector.from(lhs, rhs.getRowCount.toInt, lhs.dataType)) { left => + withResource(fromScalar(lhs)) { left => doColumnar(left, rhs) } } override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = { - withResource(GpuColumnVector.from(rhs, lhs.getRowCount.toInt, rhs.dataType)) { right => + withResource(fromScalar(rhs)) { right => doColumnar(lhs, right) } } override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = { - withResource(GpuColumnVector.from(lhs, numRows, lhs.dataType)) { left => - withResource(GpuColumnVector.from(rhs, numRows, rhs.dataType)) { right => + withResource(fromScalar(lhs, numRows)) { left => + withResource(fromScalar(rhs, numRows)) { right => doColumnar(left, right) } } } + + private def fromScalar(input: GpuScalar, numRows: Int = 1) : GpuColumnVector = { + GpuColumnVector.from(input, numRows, input.dataType) + } } case class GpuTruncDate(date: Expression, format: Expression) From 59f1051f6e57469a1c0386b9985c896a89d18fc7 Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Fri, 6 Dec 2024 11:41:28 -0800 Subject: [PATCH 05/18] Add generated docs Signed-off-by: Nghia Truong --- .../advanced_configs.md | 2 + docs/supported_ops.md | 632 +++++++++++------- tools/generated_files/320/operatorsScore.csv | 2 + tools/generated_files/320/supportedExprs.csv | 6 + tools/generated_files/operatorsScore.csv | 2 + tools/generated_files/supportedExprs.csv | 6 + 6 files changed, 422 insertions(+), 228 deletions(-) diff --git a/docs/additional-functionality/advanced_configs.md b/docs/additional-functionality/advanced_configs.md index a4427d9495a..5519e56b419 100644 --- a/docs/additional-functionality/advanced_configs.md +++ b/docs/additional-functionality/advanced_configs.md @@ -386,6 +386,8 @@ Name | SQL Function(s) | Description | Default Value | Notes spark.rapids.sql.expression.ToUnixTimestamp|`to_unix_timestamp`|Returns the UNIX timestamp of the given time|true|None| spark.rapids.sql.expression.TransformKeys|`transform_keys`|Transform keys in a map using a transform function|true|None| spark.rapids.sql.expression.TransformValues|`transform_values`|Transform values in a map using a transform function|true|None| +spark.rapids.sql.expression.TruncDate|`trunc`|Truncate the date to the unit specified by the given string format|true|None| +spark.rapids.sql.expression.TruncTimestamp|`date_trunc`|Truncate the timestamp to the unit specified by the given string format|true|None| spark.rapids.sql.expression.UnaryMinus|`negative`|Negate a numeric value|true|None| spark.rapids.sql.expression.UnaryPositive|`positive`|A numeric value with a + in front of it|true|None| spark.rapids.sql.expression.UnboundedFollowing$| |Special boundary for a window frame, indicating all rows preceding the current row|true|None| diff --git a/docs/supported_ops.md b/docs/supported_ops.md index acf7133af40..1072c0d8062 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -17702,6 +17702,154 @@ are limited. YEARMONTH +TruncDate +`trunc` +Truncate the date to the unit specified by the given string format +None +project +date + + + + + + + +S + + + + + + + + + + + + + + +format + + + + + + + + + +PS
"QUARTER" and "WEEK" are not supported;
Literal value only
+ + + + + + + + + + + + +result + + + + + + + +S + + + + + + + + + + + + + + +TruncTimestamp +`date_trunc` +Truncate the timestamp to the unit specified by the given string format +None +project +date + + + + + + + + +PS
UTC is only supported TZ for TIMESTAMP
+ + + + + + + + + + + + + +format + + + + + + + + + +PS
"QUARTER" and "WEEK" are not supported;
Literal value only
+ + + + + + + + + + + + +result + + + + + + + + +PS
UTC is only supported TZ for TIMESTAMP
+ + + + + + + + + + + + + UnaryMinus `negative` Negate a numeric value @@ -17926,6 +18074,34 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT +DAYTIME +YEARMONTH + + UnboundedPreceding$ Special boundary for a window frame, indicating all rows preceding the current row @@ -18079,34 +18255,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT -DAYTIME -YEARMONTH - - Upper `ucase`, `upper` String uppercase operator @@ -18357,6 +18505,34 @@ are limited. NS +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT +DAYTIME +YEARMONTH + + XxHash64 `xxhash64` xxhash64 hash operator @@ -18673,34 +18849,6 @@ are limited. S -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT -DAYTIME -YEARMONTH - - ApproximatePercentile `approx_percentile`, `percentile_approx` Approximate percentile @@ -18891,6 +19039,34 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT +DAYTIME +YEARMONTH + + Average `avg`, `mean` Average aggregate operator @@ -19181,34 +19357,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT -DAYTIME -YEARMONTH - - CollectSet `collect_set` Collect a set of unique elements, not supported in reduction @@ -19354,6 +19502,34 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT +DAYTIME +YEARMONTH + + Count `count` Count aggregate operator @@ -19638,38 +19814,10 @@ are limited. NS PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT, DAYTIME, YEARMONTH
PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT, DAYTIME, YEARMONTH
-PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT, DAYTIME, YEARMONTH
-NS -NS -NS - - -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT -DAYTIME -YEARMONTH +PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT, DAYTIME, YEARMONTH
+NS +NS +NS Last @@ -19817,6 +19965,34 @@ are limited. NS +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT +DAYTIME +YEARMONTH + + Max `max` Max aggregate operator @@ -20106,34 +20282,6 @@ are limited. NS -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT -DAYTIME -YEARMONTH - - Min `min` Min aggregate operator @@ -20279,6 +20427,34 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT +DAYTIME +YEARMONTH + + MinBy `min_by` MinBy aggregate operator. It may produce different results than CPU when multiple rows in a group have same minimum value in the ordering column and different associated values in the value column. @@ -20613,34 +20789,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT -DAYTIME -YEARMONTH - - PivotFirst PivotFirst operator @@ -20785,6 +20933,34 @@ are limited. NS +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT +DAYTIME +YEARMONTH + + StddevPop `stddev_pop` Aggregation computing population standard deviation @@ -21075,34 +21251,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT -DAYTIME -YEARMONTH - - Sum `sum` Sum aggregate operator @@ -21248,6 +21396,34 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT +DAYTIME +YEARMONTH + + VariancePop `var_pop` Aggregation computing population variance @@ -21538,34 +21714,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT -DAYTIME -YEARMONTH - - NormalizeNaNAndZero Normalize NaN and zero @@ -21645,6 +21793,34 @@ are limited. NS +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT +DAYTIME +YEARMONTH + + HiveGenericUDF Hive Generic UDF, the UDF can choose to implement a RAPIDS accelerated interface to get better performance diff --git a/tools/generated_files/320/operatorsScore.csv b/tools/generated_files/320/operatorsScore.csv index 19c999aa796..d8c4ca63adc 100644 --- a/tools/generated_files/320/operatorsScore.csv +++ b/tools/generated_files/320/operatorsScore.csv @@ -265,6 +265,8 @@ ToUTCTimestamp,4 ToUnixTimestamp,4 TransformKeys,4 TransformValues,4 +TruncDate,4 +TruncTimestamp,4 UnaryMinus,4 UnaryPositive,4 UnboundedFollowing$,4 diff --git a/tools/generated_files/320/supportedExprs.csv b/tools/generated_files/320/supportedExprs.csv index e4a4db760b0..29b5b8d9150 100644 --- a/tools/generated_files/320/supportedExprs.csv +++ b/tools/generated_files/320/supportedExprs.csv @@ -606,6 +606,12 @@ TransformKeys,S,`transform_keys`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA, TransformValues,S,`transform_values`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA TransformValues,S,`transform_values`,None,project,function,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS,NS,NS TransformValues,S,`transform_values`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA +TruncDate,S,`trunc`,None,project,date,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncDate,S,`trunc`,None,project,format,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncDate,S,`trunc`,None,project,result,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncTimestamp,S,`date_trunc`,None,project,date,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncTimestamp,S,`date_trunc`,None,project,format,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncTimestamp,S,`date_trunc`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA UnaryMinus,S,`negative`,None,project,input,NA,S,S,S,S,S,S,NA,NA,NA,S,NA,NA,NS,NA,NA,NA,NA,NS,NS UnaryMinus,S,`negative`,None,project,result,NA,S,S,S,S,S,S,NA,NA,NA,S,NA,NA,NS,NA,NA,NA,NA,NS,NS UnaryMinus,S,`negative`,None,AST,input,NA,NS,NS,S,S,S,S,NA,NA,NA,NS,NA,NA,NS,NA,NA,NA,NA,NS,NS diff --git a/tools/generated_files/operatorsScore.csv b/tools/generated_files/operatorsScore.csv index 19c999aa796..d8c4ca63adc 100644 --- a/tools/generated_files/operatorsScore.csv +++ b/tools/generated_files/operatorsScore.csv @@ -265,6 +265,8 @@ ToUTCTimestamp,4 ToUnixTimestamp,4 TransformKeys,4 TransformValues,4 +TruncDate,4 +TruncTimestamp,4 UnaryMinus,4 UnaryPositive,4 UnboundedFollowing$,4 diff --git a/tools/generated_files/supportedExprs.csv b/tools/generated_files/supportedExprs.csv index e4a4db760b0..29b5b8d9150 100644 --- a/tools/generated_files/supportedExprs.csv +++ b/tools/generated_files/supportedExprs.csv @@ -606,6 +606,12 @@ TransformKeys,S,`transform_keys`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA, TransformValues,S,`transform_values`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA TransformValues,S,`transform_values`,None,project,function,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS,NS,NS TransformValues,S,`transform_values`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA +TruncDate,S,`trunc`,None,project,date,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncDate,S,`trunc`,None,project,format,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncDate,S,`trunc`,None,project,result,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncTimestamp,S,`date_trunc`,None,project,date,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncTimestamp,S,`date_trunc`,None,project,format,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncTimestamp,S,`date_trunc`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA UnaryMinus,S,`negative`,None,project,input,NA,S,S,S,S,S,S,NA,NA,NA,S,NA,NA,NS,NA,NA,NA,NA,NS,NS UnaryMinus,S,`negative`,None,project,result,NA,S,S,S,S,S,S,NA,NA,NA,S,NA,NA,NS,NA,NA,NA,NA,NS,NS UnaryMinus,S,`negative`,None,AST,input,NA,NS,NS,S,S,S,S,NA,NA,NA,NS,NA,NA,NS,NA,NA,NA,NA,NS,NS From e3e4c9cb919bd0e1ff1e53a90ac0e0cfeadee40e Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Fri, 6 Dec 2024 13:22:20 -0800 Subject: [PATCH 06/18] Add tests Signed-off-by: Nghia Truong --- .../src/main/python/date_time_test.py | 72 +++++++++++++++++-- 1 file changed, 68 insertions(+), 4 deletions(-) diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py index 1a7024dac85..6c38ad8fa37 100644 --- a/integration_tests/src/main/python/date_time_test.py +++ b/integration_tests/src/main/python/date_time_test.py @@ -321,7 +321,7 @@ def test_unsupported_fallback_to_unix_timestamp(data_gen): spark, [("a", data_gen), ("b", string_gen)], length=10).selectExpr( "to_unix_timestamp(a, b)"), "ToUnixTimestamp") - + supported_timezones = ["Asia/Shanghai", "UTC", "UTC+0", "UTC-0", "GMT", "GMT+0", "GMT-0", "EST", "MST", "VST"] unsupported_timezones = ["PST", "NST", "AST", "America/Los_Angeles", "America/New_York", "America/Chicago"] @@ -681,7 +681,7 @@ def test_unsupported_fallback_to_date(): conf) -# (-62135510400, 253402214400) is the range of seconds that can be represented by timestamp_seconds +# (-62135510400, 253402214400) is the range of seconds that can be represented by timestamp_seconds # considering the influence of time zone. ts_float_gen = SetValuesGen(FloatType(), [0.0, -0.0, 1.0, -1.0, 1.234567, -1.234567, 16777215.0, float('inf'), float('-inf'), float('nan')]) seconds_gens = [LongGen(min_val=-62135510400, max_val=253402214400), IntegerGen(), ShortGen(), ByteGen(), @@ -710,7 +710,7 @@ def test_timestamp_seconds_rounding_necessary(data_gen): lambda spark : unary_op_df(spark, data_gen).selectExpr("timestamp_seconds(a)").collect(), conf={}, error_message='Rounding necessary') - + @pytest.mark.parametrize('data_gen', [DecimalGen(19, 6), DecimalGen(20, 6)], ids=idfn) @allow_non_gpu(*non_utc_allow) def test_timestamp_seconds_decimal_overflow(data_gen): @@ -725,7 +725,7 @@ def test_timestamp_seconds_decimal_overflow(data_gen): def test_timestamp_millis(data_gen): assert_gpu_and_cpu_are_equal_collect( lambda spark : unary_op_df(spark, data_gen).selectExpr("timestamp_millis(a)")) - + @allow_non_gpu(*non_utc_allow) def test_timestamp_millis_long_overflow(): assert_gpu_and_cpu_error( @@ -751,3 +751,67 @@ def test_date_to_timestamp(parser_policy): conf = { "spark.sql.legacy.timeParserPolicy": parser_policy, "spark.rapids.sql.incompatibleDateFormats.enabled": True}) + +# Generate format strings, which are case insensitive and have some garbage rows. +trunc_date_format_gen = StringGen('(?i:YEAR|YYYY|YY|QUARTER|MONTH|MM|MON|WEEK)') \ + .with_special_pattern('invalid', weight=50) +trunc_timestamp_format_gen = StringGen('(?i:YEAR|YYYY|YY|QUARTER|MONTH|MM|MON|WEEK|DAY|DD|HOUR|MINUTE|SECOND|MILLISECOND|MICROSECOND)') \ + .with_special_pattern('invalid', weight=50) + +@pytest.mark.parametrize('data_gen', [date_gen], ids=idfn) +@pytest.mark.parametrize('format_gen', [trunc_date_format_gen], ids=idfn) +def test_trunc_date_full_input(data_gen, format_gen): + assert_gpu_and_cpu_are_equal_collect( + lambda spark : two_col_df(spark, data_gen, format_gen).selectExpr('trunc(a, b)')) + +@pytest.mark.parametrize('format_gen', [trunc_timestamp_format_gen], ids=idfn) +@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn) +def test_trunc_timestamp_full_input(format_gen, data_gen): + assert_gpu_and_cpu_are_equal_collect( + lambda spark : two_col_df(spark, format_gen, data_gen).selectExpr('date_trunc(a, b)')) + +@pytest.mark.parametrize('format_gen', [trunc_date_format_gen], ids=idfn) +def test_trunc_date_single_value(format_gen): + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, format_gen).selectExpr('trunc("1980-05-18", a)')) + +@pytest.mark.parametrize('format_gen', [trunc_timestamp_format_gen], ids=idfn) +def test_trunc_timestamp_single_value(format_gen): + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, format_gen).selectExpr( + 'date_trunc(a, "1980-05-18T09:32:05.359")')) + +@pytest.mark.parametrize('data_gen', [date_gen], ids=idfn) +def test_trunc_date_single_format(data_gen): + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, data_gen).selectExpr( + 'trunc(a, "YEAR")', + 'trunc(a, "YYYY")', + 'trunc(a, "YY")', + 'trunc(a, "QUARTER")', + 'trunc(a, "MONTH")', + 'trunc(a, "MM")', + 'trunc(a, "MON")', + 'trunc(a, "WEEK")', + 'trunc(a, "invalid")')) + +@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn) +def test_trunc_date_single_format(data_gen): + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, data_gen).selectExpr( + 'date_trunc("YEAR", a)', + 'date_trunc("YYYY", a)', + 'date_trunc("YY", a)', + 'date_trunc("QUARTER", a)', + 'date_trunc("MONTH", a)', + 'date_trunc("MM", a)', + 'date_trunc("MON", a)', + 'date_trunc("WEEK", a)', + 'date_trunc("DAY", a)', + 'date_trunc("DD", a)', + 'date_trunc("HOUR", a)', + 'date_trunc("MINUTE", a)', + 'date_trunc("SECOND", a)', + 'date_trunc("MILLISECOND", a)', + 'date_trunc("MICROSECOND", a)', + 'date_trunc("invalid", a)')) From 0b6339e6cd2f744a2b9144d03cbeb20dedd4ecba Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Fri, 6 Dec 2024 13:52:59 -0800 Subject: [PATCH 07/18] Fix parameter types Signed-off-by: Nghia Truong --- .../scala/com/nvidia/spark/rapids/GpuOverrides.scala | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 93f516e7071..a0eeb3f9056 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -1825,9 +1825,7 @@ object GpuOverrides extends Logging { "Truncate the date to the unit specified by the given string format", ExprChecks.binaryProject(TypeSig.DATE, TypeSig.DATE, ("date", TypeSig.DATE, TypeSig.DATE), - ("format", TypeSig.lit(TypeEnum.STRING) - .withPsNote(TypeEnum.STRING, "\"QUARTER\" and \"WEEK\" are not supported"), - TypeSig.STRING)), + ("format", TypeSig.STRING, TypeSig.STRING)), (a, conf, p, r) => new BinaryExprMeta[TruncDate](a, conf, p, r) { override def convertToGpu(date: Expression, format: Expression): GpuExpression = GpuTruncDate(date, format) @@ -1835,10 +1833,8 @@ object GpuOverrides extends Logging { expr[TruncTimestamp]( "Truncate the timestamp to the unit specified by the given string format", ExprChecks.binaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP, - ("date", TypeSig.TIMESTAMP, TypeSig.TIMESTAMP), - ("format", TypeSig.lit(TypeEnum.STRING) - .withPsNote(TypeEnum.STRING, "\"QUARTER\" and \"WEEK\" are not supported"), - TypeSig.STRING)), + ("format", TypeSig.STRING, TypeSig.STRING), + ("date", TypeSig.TIMESTAMP, TypeSig.TIMESTAMP)), (a, conf, p, r) => new BinaryExprMeta[TruncTimestamp](a, conf, p, r) { override def convertToGpu(format: Expression, timestamp: Expression): GpuExpression = GpuTruncTimestamp(format, timestamp) From 11b302c9270bd6dfcd67557fc6bac6eeca39d1fc Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Fri, 6 Dec 2024 14:25:32 -0800 Subject: [PATCH 08/18] Fix test Signed-off-by: Nghia Truong --- integration_tests/src/main/python/date_time_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py index 6c38ad8fa37..98d32b5993b 100644 --- a/integration_tests/src/main/python/date_time_test.py +++ b/integration_tests/src/main/python/date_time_test.py @@ -796,7 +796,7 @@ def test_trunc_date_single_format(data_gen): 'trunc(a, "invalid")')) @pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn) -def test_trunc_date_single_format(data_gen): +def test_trunc_timestamp_single_format(data_gen): assert_gpu_and_cpu_are_equal_collect( lambda spark : unary_op_df(spark, data_gen).selectExpr( 'date_trunc("YEAR", a)', From 02a0a35fb9c802aee28f6e4edf6fd3d2e3867c26 Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Fri, 6 Dec 2024 14:25:51 -0800 Subject: [PATCH 09/18] Change abstract class to trait Signed-off-by: Nghia Truong --- .../sql/rapids/datetimeExpressions.scala | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala index d6972a931b7..4e66615366f 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala @@ -1525,13 +1525,7 @@ case class GpuLastDay(startDate: Expression) input.getBase.lastDayOfMonth() } -abstract class GpuTruncDateTime(datetime: Expression, format: Expression) - extends GpuBinaryExpression with ImplicitCastInputTypes { - - // We always store date/time to the left and format to the right expressions. - override def left: Expression = datetime - override def right: Expression = format - +trait GpuTruncDateTime extends GpuBinaryExpression with ImplicitCastInputTypes { override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector = { DateTimeUtils.truncate(lhs.getBase, rhs.getBase) } @@ -1561,8 +1555,11 @@ abstract class GpuTruncDateTime(datetime: Expression, format: Expression) } } -case class GpuTruncDate(date: Expression, format: Expression) - extends GpuTruncDateTime(date, format) { +case class GpuTruncDate(date: Expression, format: Expression) extends GpuTruncDateTime { + // We always store date/time to the left and format to the right expressions. + // This is to make sure `doColumnar` will call `DateTimeUtils.truncate` with the correct order. + override def left: Expression = date + override def right: Expression = format override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) @@ -1571,8 +1568,11 @@ case class GpuTruncDate(date: Expression, format: Expression) override def prettyName: String = "trunc" } -case class GpuTruncTimestamp(format: Expression, timestamp: Expression) - extends GpuTruncDateTime(timestamp, format) { +case class GpuTruncTimestamp(format: Expression, timestamp: Expression) extends GpuTruncDateTime { + // We always store date/time to the left and format to the right expressions. + // This is to make sure `doColumnar` will call `DateTimeUtils.truncate` with the correct order. + override def left: Expression = timestamp + override def right: Expression = format override def inputTypes: Seq[AbstractDataType] = Seq(StringType, TimestampType) From 776b04f06257284bcadbd840ffd7e90b625c9660 Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Fri, 6 Dec 2024 21:58:33 -0800 Subject: [PATCH 10/18] Fix expression order and implement `TimeZoneAwareExpression` Signed-off-by: Nghia Truong --- .../nvidia/spark/rapids/GpuOverrides.scala | 2 +- .../sql/rapids/datetimeExpressions.scala | 38 +++++++++++-------- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index a0eeb3f9056..b9a46f7772f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -1837,7 +1837,7 @@ object GpuOverrides extends Logging { ("date", TypeSig.TIMESTAMP, TypeSig.TIMESTAMP)), (a, conf, p, r) => new BinaryExprMeta[TruncTimestamp](a, conf, p, r) { override def convertToGpu(format: Expression, timestamp: Expression): GpuExpression = - GpuTruncTimestamp(format, timestamp) + GpuTruncTimestamp(format, timestamp, a.timeZoneId) }), expr[Pmod]( "Pmod", diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala index 4e66615366f..53b72cc29c0 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala @@ -1526,38 +1526,32 @@ case class GpuLastDay(startDate: Expression) } trait GpuTruncDateTime extends GpuBinaryExpression with ImplicitCastInputTypes { - override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector = { - DateTimeUtils.truncate(lhs.getBase, rhs.getBase) - } - override def doColumnar(lhs: GpuScalar, rhs: GpuColumnVector): ColumnVector = { - withResource(fromScalar(lhs)) { left => + withResource(scalarToColumn(lhs)) { left => doColumnar(left, rhs) } } override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = { - withResource(fromScalar(rhs)) { right => + withResource(scalarToColumn(rhs)) { right => doColumnar(lhs, right) } } override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = { - withResource(fromScalar(lhs, numRows)) { left => - withResource(fromScalar(rhs, numRows)) { right => + withResource(scalarToColumn(lhs, numRows)) { left => + withResource(scalarToColumn(rhs, numRows)) { right => doColumnar(left, right) } } } - private def fromScalar(input: GpuScalar, numRows: Int = 1) : GpuColumnVector = { + private def scalarToColumn(input: GpuScalar, numRows: Int = 1) : GpuColumnVector = { GpuColumnVector.from(input, numRows, input.dataType) } } case class GpuTruncDate(date: Expression, format: Expression) extends GpuTruncDateTime { - // We always store date/time to the left and format to the right expressions. - // This is to make sure `doColumnar` will call `DateTimeUtils.truncate` with the correct order. override def left: Expression = date override def right: Expression = format @@ -1566,17 +1560,29 @@ case class GpuTruncDate(date: Expression, format: Expression) extends GpuTruncDa override def dataType: DataType = DateType override def prettyName: String = "trunc" + + override def doColumnar(dateCol: GpuColumnVector, formatCol: GpuColumnVector): ColumnVector = { + DateTimeUtils.truncate(dateCol.getBase, formatCol.getBase) + } } -case class GpuTruncTimestamp(format: Expression, timestamp: Expression) extends GpuTruncDateTime { - // We always store date/time to the left and format to the right expressions. - // This is to make sure `doColumnar` will call `DateTimeUtils.truncate` with the correct order. - override def left: Expression = timestamp - override def right: Expression = format +case class GpuTruncTimestamp(format: Expression, timestamp: Expression, + timeZoneId: Option[String] = None) + extends GpuTruncDateTime with TimeZoneAwareExpression { + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = { + copy(timeZoneId = Option(timeZoneId)) + } + + override def left: Expression = format + override def right: Expression = timestamp override def inputTypes: Seq[AbstractDataType] = Seq(StringType, TimestampType) override def dataType: DataType = TimestampType override def prettyName: String = "date_trunc" + + override def doColumnar(formatCol: GpuColumnVector, tsCol: GpuColumnVector): ColumnVector = { + DateTimeUtils.truncate(tsCol.getBase, formatCol.getBase) + } } From 19d69be01b0aa8ff569d38b44ccca25f72c0bdec Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Fri, 6 Dec 2024 23:35:44 -0800 Subject: [PATCH 11/18] Update generated docs Signed-off-by: Nghia Truong --- docs/supported_ops.md | 10 +++++----- tools/generated_files/320/supportedExprs.csv | 4 ++-- tools/generated_files/supportedExprs.csv | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 1072c0d8062..1be5008888b 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -17740,7 +17740,7 @@ are limited. -PS
"QUARTER" and "WEEK" are not supported;
Literal value only
+S @@ -17781,7 +17781,7 @@ are limited. Truncate the timestamp to the unit specified by the given string format None project -date +format @@ -17790,8 +17790,8 @@ are limited. -PS
UTC is only supported TZ for TIMESTAMP
+S @@ -17804,7 +17804,7 @@ are limited. -format +date @@ -17813,8 +17813,8 @@ are limited. +PS
UTC is only supported TZ for TIMESTAMP
-PS
"QUARTER" and "WEEK" are not supported;
Literal value only
diff --git a/tools/generated_files/320/supportedExprs.csv b/tools/generated_files/320/supportedExprs.csv index 29b5b8d9150..80fc939ee68 100644 --- a/tools/generated_files/320/supportedExprs.csv +++ b/tools/generated_files/320/supportedExprs.csv @@ -607,10 +607,10 @@ TransformValues,S,`transform_values`,None,project,argument,NA,NA,NA,NA,NA,NA,NA, TransformValues,S,`transform_values`,None,project,function,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS,NS,NS TransformValues,S,`transform_values`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA TruncDate,S,`trunc`,None,project,date,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA -TruncDate,S,`trunc`,None,project,format,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncDate,S,`trunc`,None,project,format,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA TruncDate,S,`trunc`,None,project,result,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncTimestamp,S,`date_trunc`,None,project,format,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA TruncTimestamp,S,`date_trunc`,None,project,date,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA -TruncTimestamp,S,`date_trunc`,None,project,format,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA TruncTimestamp,S,`date_trunc`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA UnaryMinus,S,`negative`,None,project,input,NA,S,S,S,S,S,S,NA,NA,NA,S,NA,NA,NS,NA,NA,NA,NA,NS,NS UnaryMinus,S,`negative`,None,project,result,NA,S,S,S,S,S,S,NA,NA,NA,S,NA,NA,NS,NA,NA,NA,NA,NS,NS diff --git a/tools/generated_files/supportedExprs.csv b/tools/generated_files/supportedExprs.csv index 29b5b8d9150..80fc939ee68 100644 --- a/tools/generated_files/supportedExprs.csv +++ b/tools/generated_files/supportedExprs.csv @@ -607,10 +607,10 @@ TransformValues,S,`transform_values`,None,project,argument,NA,NA,NA,NA,NA,NA,NA, TransformValues,S,`transform_values`,None,project,function,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS,NS,NS TransformValues,S,`transform_values`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA TruncDate,S,`trunc`,None,project,date,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA -TruncDate,S,`trunc`,None,project,format,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncDate,S,`trunc`,None,project,format,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA TruncDate,S,`trunc`,None,project,result,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncTimestamp,S,`date_trunc`,None,project,format,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA TruncTimestamp,S,`date_trunc`,None,project,date,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA -TruncTimestamp,S,`date_trunc`,None,project,format,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA TruncTimestamp,S,`date_trunc`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA UnaryMinus,S,`negative`,None,project,input,NA,S,S,S,S,S,S,NA,NA,NA,S,NA,NA,NS,NA,NA,NA,NA,NS,NS UnaryMinus,S,`negative`,None,project,result,NA,S,S,S,S,S,S,NA,NA,NA,S,NA,NA,NS,NA,NA,NA,NA,NS,NS From 9e7fb91c7cfe3132c2589caa041b903202a20267 Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Fri, 6 Dec 2024 23:46:55 -0800 Subject: [PATCH 12/18] Update generated docs Signed-off-by: Nghia Truong --- tools/generated_files/342/operatorsScore.csv | 2 ++ tools/generated_files/342/supportedExprs.csv | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/tools/generated_files/342/operatorsScore.csv b/tools/generated_files/342/operatorsScore.csv index b1e9198e58b..16ac93a02ba 100644 --- a/tools/generated_files/342/operatorsScore.csv +++ b/tools/generated_files/342/operatorsScore.csv @@ -277,6 +277,8 @@ ToUTCTimestamp,4 ToUnixTimestamp,4 TransformKeys,4 TransformValues,4 +TruncDate,4 +TruncTimestamp,4 UnaryMinus,4 UnaryPositive,4 UnboundedFollowing$,4 diff --git a/tools/generated_files/342/supportedExprs.csv b/tools/generated_files/342/supportedExprs.csv index 01a48b40249..f2d5a4f0736 100644 --- a/tools/generated_files/342/supportedExprs.csv +++ b/tools/generated_files/342/supportedExprs.csv @@ -629,6 +629,12 @@ TransformKeys,S,`transform_keys`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA, TransformValues,S,`transform_values`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA TransformValues,S,`transform_values`,None,project,function,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS,NS,NS TransformValues,S,`transform_values`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA +TruncDate,S,`trunc`,None,project,date,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncDate,S,`trunc`,None,project,format,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncDate,S,`trunc`,None,project,result,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncTimestamp,S,`date_trunc`,None,project,format,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncTimestamp,S,`date_trunc`,None,project,date,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncTimestamp,S,`date_trunc`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA UnaryMinus,S,`negative`,None,project,input,NA,S,S,S,S,S,S,NA,NA,NA,S,NA,NA,NS,NA,NA,NA,NA,S,S UnaryMinus,S,`negative`,None,project,result,NA,S,S,S,S,S,S,NA,NA,NA,S,NA,NA,NS,NA,NA,NA,NA,S,S UnaryMinus,S,`negative`,None,AST,input,NA,NS,NS,S,S,S,S,NA,NA,NA,NS,NA,NA,NS,NA,NA,NA,NA,NS,NS From 96804163ceaff784473e534efa41e6c357c6c432 Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Tue, 10 Dec 2024 10:14:02 -0800 Subject: [PATCH 13/18] Add generated docs Signed-off-by: Nghia Truong --- tools/generated_files/340/operatorsScore.csv | 2 ++ tools/generated_files/340/supportedExprs.csv | 6 ++++++ tools/generated_files/400/operatorsScore.csv | 2 ++ tools/generated_files/400/supportedExprs.csv | 6 ++++++ 4 files changed, 16 insertions(+) diff --git a/tools/generated_files/340/operatorsScore.csv b/tools/generated_files/340/operatorsScore.csv index b1e9198e58b..16ac93a02ba 100644 --- a/tools/generated_files/340/operatorsScore.csv +++ b/tools/generated_files/340/operatorsScore.csv @@ -277,6 +277,8 @@ ToUTCTimestamp,4 ToUnixTimestamp,4 TransformKeys,4 TransformValues,4 +TruncDate,4 +TruncTimestamp,4 UnaryMinus,4 UnaryPositive,4 UnboundedFollowing$,4 diff --git a/tools/generated_files/340/supportedExprs.csv b/tools/generated_files/340/supportedExprs.csv index 01a48b40249..f2d5a4f0736 100644 --- a/tools/generated_files/340/supportedExprs.csv +++ b/tools/generated_files/340/supportedExprs.csv @@ -629,6 +629,12 @@ TransformKeys,S,`transform_keys`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA, TransformValues,S,`transform_values`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA TransformValues,S,`transform_values`,None,project,function,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS,NS,NS TransformValues,S,`transform_values`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA +TruncDate,S,`trunc`,None,project,date,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncDate,S,`trunc`,None,project,format,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncDate,S,`trunc`,None,project,result,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncTimestamp,S,`date_trunc`,None,project,format,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncTimestamp,S,`date_trunc`,None,project,date,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncTimestamp,S,`date_trunc`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA UnaryMinus,S,`negative`,None,project,input,NA,S,S,S,S,S,S,NA,NA,NA,S,NA,NA,NS,NA,NA,NA,NA,S,S UnaryMinus,S,`negative`,None,project,result,NA,S,S,S,S,S,S,NA,NA,NA,S,NA,NA,NS,NA,NA,NA,NA,S,S UnaryMinus,S,`negative`,None,AST,input,NA,NS,NS,S,S,S,S,NA,NA,NA,NS,NA,NA,NS,NA,NA,NA,NA,NS,NS diff --git a/tools/generated_files/400/operatorsScore.csv b/tools/generated_files/400/operatorsScore.csv index 53791a06705..0a099fc2233 100644 --- a/tools/generated_files/400/operatorsScore.csv +++ b/tools/generated_files/400/operatorsScore.csv @@ -278,6 +278,8 @@ ToUTCTimestamp,4 ToUnixTimestamp,4 TransformKeys,4 TransformValues,4 +TruncDate,4 +TruncTimestamp,4 UnaryMinus,4 UnaryPositive,4 UnboundedFollowing$,4 diff --git a/tools/generated_files/400/supportedExprs.csv b/tools/generated_files/400/supportedExprs.csv index 4cfa1020889..92cb5327d10 100644 --- a/tools/generated_files/400/supportedExprs.csv +++ b/tools/generated_files/400/supportedExprs.csv @@ -635,6 +635,12 @@ TransformKeys,S,`transform_keys`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA, TransformValues,S,`transform_values`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA TransformValues,S,`transform_values`,None,project,function,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS,NS,NS TransformValues,S,`transform_values`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA +TruncDate,S,`trunc`,None,project,date,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncDate,S,`trunc`,None,project,format,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncDate,S,`trunc`,None,project,result,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncTimestamp,S,`date_trunc`,None,project,format,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncTimestamp,S,`date_trunc`,None,project,date,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncTimestamp,S,`date_trunc`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA UnaryMinus,S,`negative`,None,project,input,NA,S,S,S,S,S,S,NA,NA,NA,S,NA,NA,NS,NA,NA,NA,NA,S,S UnaryMinus,S,`negative`,None,project,result,NA,S,S,S,S,S,S,NA,NA,NA,S,NA,NA,NS,NA,NA,NA,NA,S,S UnaryMinus,S,`negative`,None,AST,input,NA,NS,NS,S,S,S,S,NA,NA,NA,NS,NA,NA,NS,NA,NA,NA,NA,NS,NS From 35fea2be95324dbc45b303393b25472dd0ea504b Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Tue, 10 Dec 2024 11:01:55 -0800 Subject: [PATCH 14/18] Add generated docs Signed-off-by: Nghia Truong --- tools/generated_files/321/operatorsScore.csv | 2 ++ tools/generated_files/321/supportedExprs.csv | 6 ++++++ tools/generated_files/330/operatorsScore.csv | 2 ++ tools/generated_files/330/supportedExprs.csv | 6 ++++++ tools/generated_files/331/operatorsScore.csv | 2 ++ tools/generated_files/331/supportedExprs.csv | 6 ++++++ 6 files changed, 24 insertions(+) diff --git a/tools/generated_files/321/operatorsScore.csv b/tools/generated_files/321/operatorsScore.csv index 19c999aa796..d8c4ca63adc 100644 --- a/tools/generated_files/321/operatorsScore.csv +++ b/tools/generated_files/321/operatorsScore.csv @@ -265,6 +265,8 @@ ToUTCTimestamp,4 ToUnixTimestamp,4 TransformKeys,4 TransformValues,4 +TruncDate,4 +TruncTimestamp,4 UnaryMinus,4 UnaryPositive,4 UnboundedFollowing$,4 diff --git a/tools/generated_files/321/supportedExprs.csv b/tools/generated_files/321/supportedExprs.csv index e4a4db760b0..80fc939ee68 100644 --- a/tools/generated_files/321/supportedExprs.csv +++ b/tools/generated_files/321/supportedExprs.csv @@ -606,6 +606,12 @@ TransformKeys,S,`transform_keys`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA, TransformValues,S,`transform_values`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA TransformValues,S,`transform_values`,None,project,function,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS,NS,NS TransformValues,S,`transform_values`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA +TruncDate,S,`trunc`,None,project,date,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncDate,S,`trunc`,None,project,format,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncDate,S,`trunc`,None,project,result,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncTimestamp,S,`date_trunc`,None,project,format,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncTimestamp,S,`date_trunc`,None,project,date,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncTimestamp,S,`date_trunc`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA UnaryMinus,S,`negative`,None,project,input,NA,S,S,S,S,S,S,NA,NA,NA,S,NA,NA,NS,NA,NA,NA,NA,NS,NS UnaryMinus,S,`negative`,None,project,result,NA,S,S,S,S,S,S,NA,NA,NA,S,NA,NA,NS,NA,NA,NA,NA,NS,NS UnaryMinus,S,`negative`,None,AST,input,NA,NS,NS,S,S,S,S,NA,NA,NA,NS,NA,NA,NS,NA,NA,NA,NA,NS,NS diff --git a/tools/generated_files/330/operatorsScore.csv b/tools/generated_files/330/operatorsScore.csv index e5978fb9f1a..e86e30e606c 100644 --- a/tools/generated_files/330/operatorsScore.csv +++ b/tools/generated_files/330/operatorsScore.csv @@ -275,6 +275,8 @@ ToUTCTimestamp,4 ToUnixTimestamp,4 TransformKeys,4 TransformValues,4 +TruncDate,4 +TruncTimestamp,4 UnaryMinus,4 UnaryPositive,4 UnboundedFollowing$,4 diff --git a/tools/generated_files/330/supportedExprs.csv b/tools/generated_files/330/supportedExprs.csv index 0073281cb32..38a6042dc5c 100644 --- a/tools/generated_files/330/supportedExprs.csv +++ b/tools/generated_files/330/supportedExprs.csv @@ -627,6 +627,12 @@ TransformKeys,S,`transform_keys`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA, TransformValues,S,`transform_values`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA TransformValues,S,`transform_values`,None,project,function,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS,NS,NS TransformValues,S,`transform_values`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA +TruncDate,S,`trunc`,None,project,date,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncDate,S,`trunc`,None,project,format,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncDate,S,`trunc`,None,project,result,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncTimestamp,S,`date_trunc`,None,project,format,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncTimestamp,S,`date_trunc`,None,project,date,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncTimestamp,S,`date_trunc`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA UnaryMinus,S,`negative`,None,project,input,NA,S,S,S,S,S,S,NA,NA,NA,S,NA,NA,NS,NA,NA,NA,NA,S,S UnaryMinus,S,`negative`,None,project,result,NA,S,S,S,S,S,S,NA,NA,NA,S,NA,NA,NS,NA,NA,NA,NA,S,S UnaryMinus,S,`negative`,None,AST,input,NA,NS,NS,S,S,S,S,NA,NA,NA,NS,NA,NA,NS,NA,NA,NA,NA,NS,NS diff --git a/tools/generated_files/331/operatorsScore.csv b/tools/generated_files/331/operatorsScore.csv index b988344e702..229201ba885 100644 --- a/tools/generated_files/331/operatorsScore.csv +++ b/tools/generated_files/331/operatorsScore.csv @@ -276,6 +276,8 @@ ToUTCTimestamp,4 ToUnixTimestamp,4 TransformKeys,4 TransformValues,4 +TruncDate,4 +TruncTimestamp,4 UnaryMinus,4 UnaryPositive,4 UnboundedFollowing$,4 diff --git a/tools/generated_files/331/supportedExprs.csv b/tools/generated_files/331/supportedExprs.csv index f62af4c9513..12a4d1e0cf4 100644 --- a/tools/generated_files/331/supportedExprs.csv +++ b/tools/generated_files/331/supportedExprs.csv @@ -629,6 +629,12 @@ TransformKeys,S,`transform_keys`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA, TransformValues,S,`transform_values`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA TransformValues,S,`transform_values`,None,project,function,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS,NS,NS TransformValues,S,`transform_values`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA +TruncDate,S,`trunc`,None,project,date,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncDate,S,`trunc`,None,project,format,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncDate,S,`trunc`,None,project,result,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncTimestamp,S,`date_trunc`,None,project,format,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncTimestamp,S,`date_trunc`,None,project,date,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +TruncTimestamp,S,`date_trunc`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA UnaryMinus,S,`negative`,None,project,input,NA,S,S,S,S,S,S,NA,NA,NA,S,NA,NA,NS,NA,NA,NA,NA,S,S UnaryMinus,S,`negative`,None,project,result,NA,S,S,S,S,S,S,NA,NA,NA,S,NA,NA,NS,NA,NA,NA,NA,S,S UnaryMinus,S,`negative`,None,AST,input,NA,NS,NS,S,S,S,S,NA,NA,NA,NS,NA,NA,NS,NA,NA,NA,NA,NS,NS From a5c5b7fb67366bdfba7b50b5cf681b01027a6960 Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Wed, 11 Dec 2024 19:36:37 -0800 Subject: [PATCH 15/18] Allow non-utc timezone for timestamp tests Signed-off-by: Nghia Truong --- integration_tests/src/main/python/date_time_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py index 98d32b5993b..d29e6a5b5fa 100644 --- a/integration_tests/src/main/python/date_time_test.py +++ b/integration_tests/src/main/python/date_time_test.py @@ -764,6 +764,7 @@ def test_trunc_date_full_input(data_gen, format_gen): assert_gpu_and_cpu_are_equal_collect( lambda spark : two_col_df(spark, data_gen, format_gen).selectExpr('trunc(a, b)')) +@allow_non_gpu(*non_utc_tz_allow) @pytest.mark.parametrize('format_gen', [trunc_timestamp_format_gen], ids=idfn) @pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn) def test_trunc_timestamp_full_input(format_gen, data_gen): @@ -775,6 +776,7 @@ def test_trunc_date_single_value(format_gen): assert_gpu_and_cpu_are_equal_collect( lambda spark : unary_op_df(spark, format_gen).selectExpr('trunc("1980-05-18", a)')) +@allow_non_gpu(*non_utc_tz_allow) @pytest.mark.parametrize('format_gen', [trunc_timestamp_format_gen], ids=idfn) def test_trunc_timestamp_single_value(format_gen): assert_gpu_and_cpu_are_equal_collect( @@ -795,6 +797,7 @@ def test_trunc_date_single_format(data_gen): 'trunc(a, "WEEK")', 'trunc(a, "invalid")')) +@allow_non_gpu(*non_utc_tz_allow) @pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn) def test_trunc_timestamp_single_format(data_gen): assert_gpu_and_cpu_are_equal_collect( From 1f470821b0e36ec85760a3beab5a6577baa7a2c4 Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Wed, 11 Dec 2024 20:29:23 -0800 Subject: [PATCH 16/18] Adopt to JNI changes Signed-off-by: Nghia Truong --- .../sql/rapids/datetimeExpressions.scala | 59 ++++++++++--------- 1 file changed, 32 insertions(+), 27 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala index d72aa5e5053..16dc0069e6a 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala @@ -1527,33 +1527,9 @@ case class GpuLastDay(startDate: Expression) input.getBase.lastDayOfMonth() } -trait GpuTruncDateTime extends GpuBinaryExpression with ImplicitCastInputTypes { - override def doColumnar(lhs: GpuScalar, rhs: GpuColumnVector): ColumnVector = { - withResource(scalarToColumn(lhs)) { left => - doColumnar(left, rhs) - } - } - - override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = { - withResource(scalarToColumn(rhs)) { right => - doColumnar(lhs, right) - } - } - - override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = { - withResource(scalarToColumn(lhs, numRows)) { left => - withResource(scalarToColumn(rhs, numRows)) { right => - doColumnar(left, right) - } - } - } - - private def scalarToColumn(input: GpuScalar, numRows: Int = 1) : GpuColumnVector = { - GpuColumnVector.from(input, numRows, input.dataType) - } -} +case class GpuTruncDate(date: Expression, format: Expression) + extends GpuBinaryExpression with ImplicitCastInputTypes { -case class GpuTruncDate(date: Expression, format: Expression) extends GpuTruncDateTime { override def left: Expression = date override def right: Expression = format @@ -1566,11 +1542,26 @@ case class GpuTruncDate(date: Expression, format: Expression) extends GpuTruncDa override def doColumnar(dateCol: GpuColumnVector, formatCol: GpuColumnVector): ColumnVector = { DateTimeUtils.truncate(dateCol.getBase, formatCol.getBase) } + + override def doColumnar(dateVal: GpuScalar, formatCol: GpuColumnVector): ColumnVector = { + DateTimeUtils.truncate(dateVal.getBase, formatCol.getBase) + } + + override def doColumnar(dateCol: GpuColumnVector, formatVal: GpuScalar): ColumnVector = { + DateTimeUtils.truncate(dateCol.getBase, formatVal.getBase) + } + + override def doColumnar(numRows: Int, dateVal: GpuScalar, formatVal: GpuScalar): ColumnVector = { + withResource(DateTimeUtils.truncate(dateVal.getBase, formatVal.getBase)) { truncated => + ColumnVector.fromScalar(truncated, numRows) + } + } } case class GpuTruncTimestamp(format: Expression, timestamp: Expression, timeZoneId: Option[String] = None) - extends GpuTruncDateTime with TimeZoneAwareExpression { + extends GpuBinaryExpression with ImplicitCastInputTypes with TimeZoneAwareExpression { + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = { copy(timeZoneId = Option(timeZoneId)) } @@ -1587,4 +1578,18 @@ case class GpuTruncTimestamp(format: Expression, timestamp: Expression, override def doColumnar(formatCol: GpuColumnVector, tsCol: GpuColumnVector): ColumnVector = { DateTimeUtils.truncate(tsCol.getBase, formatCol.getBase) } + + override def doColumnar(formatVal: GpuScalar, tsCol: GpuColumnVector): ColumnVector = { + DateTimeUtils.truncate(tsCol.getBase, formatVal.getBase) + } + + override def doColumnar(formatCol: GpuColumnVector, tsVal: GpuScalar): ColumnVector = { + DateTimeUtils.truncate(tsVal.getBase, formatCol.getBase) + } + + override def doColumnar(numRows: Int, formatVal: GpuScalar, tsVal: GpuScalar): ColumnVector = { + withResource(DateTimeUtils.truncate(tsVal.getBase, formatVal.getBase)) { truncated => + ColumnVector.fromScalar(truncated, numRows) + } + } } From e56824eecb1aec5c47703c1dd2da08665a137cce Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Thu, 12 Dec 2024 10:02:33 -0800 Subject: [PATCH 17/18] Rewrite all classes Signed-off-by: Nghia Truong --- .../nvidia/spark/rapids/GpuOverrides.scala | 12 +- .../sql/rapids/datetimeExpressions.scala | 127 ++++++++++++++---- 2 files changed, 103 insertions(+), 36 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index d24528731b3..0c7aa046a7c 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -1826,19 +1826,15 @@ object GpuOverrides extends Logging { ExprChecks.binaryProject(TypeSig.DATE, TypeSig.DATE, ("date", TypeSig.DATE, TypeSig.DATE), ("format", TypeSig.STRING, TypeSig.STRING)), - (a, conf, p, r) => new BinaryExprMeta[TruncDate](a, conf, p, r) { - override def convertToGpu(date: Expression, format: Expression): GpuExpression = - GpuTruncDate(date, format) - }), + (a, conf, p, r) => new TruncDateExprMeta(a, conf, p, r) + ), expr[TruncTimestamp]( "Truncate the timestamp to the unit specified by the given string format", ExprChecks.binaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP, ("format", TypeSig.STRING, TypeSig.STRING), ("date", TypeSig.TIMESTAMP, TypeSig.TIMESTAMP)), - (a, conf, p, r) => new BinaryExprMeta[TruncTimestamp](a, conf, p, r) { - override def convertToGpu(format: Expression, timestamp: Expression): GpuExpression = - GpuTruncTimestamp(format, timestamp, a.timeZoneId) - }), + (a, conf, p, r) => new TruncTimestampExprMeta(a, conf, p, r) + ), expr[Pmod]( "Pmod", // Decimal support disabled https://github.com/NVIDIA/spark-rapids/issues/7553 diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala index 16dc0069e6a..14ec6bbd090 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala @@ -30,7 +30,7 @@ import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.jni.{DateTimeUtils, GpuTimeZoneDB} import com.nvidia.spark.rapids.shims.{NullIntolerantShim, ShimBinaryExpression, ShimExpression} -import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ExpectsInputTypes, Expression, FromUnixTime, FromUTCTimestamp, ImplicitCastInputTypes, MonthsBetween, TimeZoneAwareExpression, ToUTCTimestamp} +import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ExpectsInputTypes, Expression, FromUnixTime, FromUTCTimestamp, ImplicitCastInputTypes, MonthsBetween, TimeZoneAwareExpression, ToUTCTimestamp, TruncDate, TruncTimestamp} import org.apache.spark.sql.catalyst.util.DateTimeConstants import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -1527,11 +1527,54 @@ case class GpuLastDay(startDate: Expression) input.getBase.lastDayOfMonth() } -case class GpuTruncDate(date: Expression, format: Expression) - extends GpuBinaryExpression with ImplicitCastInputTypes { +abstract class GpuTruncDateTime(fmtStr: Option[String]) extends GpuBinaryExpression + with ImplicitCastInputTypes with Serializable { + override def nullable: Boolean = true + + protected def truncate(datetimeCol: GpuColumnVector, fmtCol: GpuColumnVector): ColumnVector = { + DateTimeUtils.truncate(datetimeCol.getBase, fmtCol.getBase) + } + + protected def truncate(datetimeVal: GpuScalar, formatCol: GpuColumnVector): ColumnVector = { + withResource(ColumnVector.fromScalar(datetimeVal.getBase, 1)) { datetimeCol => + DateTimeUtils.truncate(datetimeCol, formatCol.getBase) + } + } + + protected def truncate(datetimeCol: GpuColumnVector, fmtVal: GpuScalar): ColumnVector = { + // fmtVal is unused, as it was extracted to `fmtStr` before. + fmtStr match { + case Some(fmt) => DateTimeUtils.truncate(datetimeCol.getBase, fmt) + case None => throw new IllegalArgumentException("Invalid format string.") + } + } + + protected def truncate(numRows: Int, datetimeVal: GpuScalar, fmtVal: GpuScalar): ColumnVector = { + // fmtVal is unused, as it was extracted to `fmtStr` before. + fmtStr match { + case Some(fmt) => + withResource(ColumnVector.fromScalar(datetimeVal.getBase, 1)) { datetimeCol => + val truncated = DateTimeUtils.truncate(datetimeCol, fmt) + if (numRows == 1) { + truncated + } else { + withResource(truncated) { _ => + withResource(truncated.getScalarElement(0)) { truncatedScalar => + ColumnVector.fromScalar(truncatedScalar, numRows) + } + } + } + } + case None => throw new IllegalArgumentException("Invalid format string.") + } + } +} +case class GpuTruncDate(date: Expression, fmt: Expression, fmtStr: Option[String]) + extends GpuTruncDateTime(fmtStr) { override def left: Expression = date - override def right: Expression = format + + override def right: Expression = fmt override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) @@ -1539,34 +1582,33 @@ case class GpuTruncDate(date: Expression, format: Expression) override def prettyName: String = "trunc" - override def doColumnar(dateCol: GpuColumnVector, formatCol: GpuColumnVector): ColumnVector = { - DateTimeUtils.truncate(dateCol.getBase, formatCol.getBase) + override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector = { + truncate(lhs, rhs) } - override def doColumnar(dateVal: GpuScalar, formatCol: GpuColumnVector): ColumnVector = { - DateTimeUtils.truncate(dateVal.getBase, formatCol.getBase) + override def doColumnar(lhs: GpuScalar, rhs: GpuColumnVector): ColumnVector = { + truncate(lhs, rhs) } - override def doColumnar(dateCol: GpuColumnVector, formatVal: GpuScalar): ColumnVector = { - DateTimeUtils.truncate(dateCol.getBase, formatVal.getBase) + override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = { + truncate(lhs, rhs) } - override def doColumnar(numRows: Int, dateVal: GpuScalar, formatVal: GpuScalar): ColumnVector = { - withResource(DateTimeUtils.truncate(dateVal.getBase, formatVal.getBase)) { truncated => - ColumnVector.fromScalar(truncated, numRows) - } + override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = { + truncate(numRows, lhs, rhs) } } -case class GpuTruncTimestamp(format: Expression, timestamp: Expression, - timeZoneId: Option[String] = None) - extends GpuBinaryExpression with ImplicitCastInputTypes with TimeZoneAwareExpression { +case class GpuTruncTimestamp(fmt: Expression, timestamp: Expression, timeZoneId: Option[String], + fmtStr: Option[String]) + extends GpuTruncDateTime(fmtStr) with TimeZoneAwareExpression { override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = { copy(timeZoneId = Option(timeZoneId)) } - override def left: Expression = format + override def left: Expression = fmt + override def right: Expression = timestamp override def inputTypes: Seq[AbstractDataType] = Seq(StringType, TimestampType) @@ -1575,21 +1617,50 @@ case class GpuTruncTimestamp(format: Expression, timestamp: Expression, override def prettyName: String = "date_trunc" - override def doColumnar(formatCol: GpuColumnVector, tsCol: GpuColumnVector): ColumnVector = { - DateTimeUtils.truncate(tsCol.getBase, formatCol.getBase) + // Since the input order of this class is opposite compared to the `GpuTruncDate` class, + // we need to switch `lhs` and `rhs` in the `doColumnar` methods below. + + override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector = { + truncate(rhs, lhs) } - override def doColumnar(formatVal: GpuScalar, tsCol: GpuColumnVector): ColumnVector = { - DateTimeUtils.truncate(tsCol.getBase, formatVal.getBase) + override def doColumnar(lhs: GpuScalar, rhs: GpuColumnVector): ColumnVector = { + truncate(rhs, lhs) } - override def doColumnar(formatCol: GpuColumnVector, tsVal: GpuScalar): ColumnVector = { - DateTimeUtils.truncate(tsVal.getBase, formatCol.getBase) + override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = { + truncate(rhs, lhs) } - override def doColumnar(numRows: Int, formatVal: GpuScalar, tsVal: GpuScalar): ColumnVector = { - withResource(DateTimeUtils.truncate(tsVal.getBase, formatVal.getBase)) { truncated => - ColumnVector.fromScalar(truncated, numRows) - } + override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = { + truncate(numRows, rhs, lhs) + } +} + +class TruncDateExprMeta(expr: TruncDate, + override val conf: RapidsConf, + override val parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule) + extends BinaryExprMeta[TruncDate](expr, conf, parent, rule) { + + // Store the format string as we need to process it on the CPU later on. + private val fmtStr = extractStringLit(expr.format) + + override def convertToGpu(date: Expression, format: Expression): GpuExpression = { + GpuTruncDate(date, format, fmtStr) + } +} + +class TruncTimestampExprMeta(expr: TruncTimestamp, + override val conf: RapidsConf, + override val parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule) + extends BinaryExprMeta[TruncTimestamp](expr, conf, parent, rule) { + + // Store the format string as we need to process it on the CPU later on. + private val fmtStr = extractStringLit(expr.format) + + override def convertToGpu(format: Expression, timestamp: Expression): GpuExpression = { + GpuTruncTimestamp(format, timestamp, expr.timeZoneId, fmtStr) } } From 2201fde199fa3b4af1ed2c97056a81ed209a6e10 Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Thu, 12 Dec 2024 13:27:12 -0800 Subject: [PATCH 18/18] Rename variable Signed-off-by: Nghia Truong --- .../org/apache/spark/sql/rapids/datetimeExpressions.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala index 14ec6bbd090..ad14ab400dd 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala @@ -1535,9 +1535,9 @@ abstract class GpuTruncDateTime(fmtStr: Option[String]) extends GpuBinaryExpress DateTimeUtils.truncate(datetimeCol.getBase, fmtCol.getBase) } - protected def truncate(datetimeVal: GpuScalar, formatCol: GpuColumnVector): ColumnVector = { + protected def truncate(datetimeVal: GpuScalar, fmtCol: GpuColumnVector): ColumnVector = { withResource(ColumnVector.fromScalar(datetimeVal.getBase, 1)) { datetimeCol => - DateTimeUtils.truncate(datetimeCol, formatCol.getBase) + DateTimeUtils.truncate(datetimeCol, fmtCol.getBase) } }