diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala index d6972a931b7..4e66615366f 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala @@ -1525,13 +1525,7 @@ case class GpuLastDay(startDate: Expression) input.getBase.lastDayOfMonth() } -abstract class GpuTruncDateTime(datetime: Expression, format: Expression) - extends GpuBinaryExpression with ImplicitCastInputTypes { - - // We always store date/time to the left and format to the right expressions. - override def left: Expression = datetime - override def right: Expression = format - +trait GpuTruncDateTime extends GpuBinaryExpression with ImplicitCastInputTypes { override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector = { DateTimeUtils.truncate(lhs.getBase, rhs.getBase) } @@ -1561,8 +1555,11 @@ abstract class GpuTruncDateTime(datetime: Expression, format: Expression) } } -case class GpuTruncDate(date: Expression, format: Expression) - extends GpuTruncDateTime(date, format) { +case class GpuTruncDate(date: Expression, format: Expression) extends GpuTruncDateTime { + // We always store date/time to the left and format to the right expressions. + // This is to make sure `doColumnar` will call `DateTimeUtils.truncate` with the correct order. + override def left: Expression = date + override def right: Expression = format override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) @@ -1571,8 +1568,11 @@ case class GpuTruncDate(date: Expression, format: Expression) override def prettyName: String = "trunc" } -case class GpuTruncTimestamp(format: Expression, timestamp: Expression) - extends GpuTruncDateTime(timestamp, format) { +case class GpuTruncTimestamp(format: Expression, timestamp: Expression) extends GpuTruncDateTime { + // We always store date/time to the left and format to the right expressions. + // This is to make sure `doColumnar` will call `DateTimeUtils.truncate` with the correct order. + override def left: Expression = timestamp + override def right: Expression = format override def inputTypes: Seq[AbstractDataType] = Seq(StringType, TimestampType)