-
Notifications
You must be signed in to change notification settings - Fork 237
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support TimeAdd for non-UTC time zone #10068
Changes from 4 commits
212a0b1
4f52067
c7dc304
89c9305
4c9485f
64a0232
f4e85d7
7f4237c
2de0e6e
3d5f2b9
4574ef1
2014495
c6102f3
66f3dd6
0801287
26152f5
cf10350
a4334d9
be5f813
06c6c47
ec7c4b0
511f8ee
a3ab495
0fb8ecb
a9deb35
1f3410f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -473,7 +473,7 @@ object GpuScalar extends Logging { | |
* | ||
* This class is introduced because many expressions require both the cudf Scalar and its | ||
* corresponding Scala value to complete their computations. e.g. 'GpuStringSplit', | ||
* 'GpuStringLocate', 'GpuDivide', 'GpuDateAddInterval', 'GpuTimeMath' ... | ||
* 'GpuStringLocate', 'GpuDivide', 'GpuDateAddInterval', 'GpuTimeAdd' ... | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Q: Why name changed? It seems different over different Spark version. We can comment both in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
* So only either a cudf Scalar or a Scala value can not support such cases, unless copying data | ||
* between the host and the device each time being asked for. | ||
* | ||
|
@@ -493,7 +493,7 @@ object GpuScalar extends Logging { | |
* happens. | ||
* | ||
* Another reason why storing the Scala value in addition to the cudf Scalar is | ||
* `GpuDateAddInterval` and 'GpuTimeMath' have different algorithms with the 3 members of | ||
* `GpuDateAddInterval` and 'GpuTimeAdd' have different algorithms with the 3 members of | ||
* a `CalendarInterval`, which can not be supported by a single cudf Scalar now. | ||
* | ||
* Do not create a GpuScalar from the constructor, instead call the factory APIs above. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,21 +21,96 @@ | |
spark-rapids-shim-json-lines ***/ | ||
package org.apache.spark.sql.rapids.shims | ||
|
||
import ai.rapids.cudf.{ColumnVector, ColumnView, Scalar} | ||
import java.util.concurrent.TimeUnit | ||
|
||
import org.apache.spark.sql.catalyst.expressions.{Expression, TimeZoneAwareExpression} | ||
import org.apache.spark.sql.rapids.GpuTimeMath | ||
import ai.rapids.cudf.{BinaryOp, BinaryOperable, ColumnVector, ColumnView, DType, Scalar} | ||
import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuScalar} | ||
import com.nvidia.spark.rapids.Arm.{withResource, withResourceIfAllowed} | ||
import com.nvidia.spark.rapids.GpuOverrides | ||
import com.nvidia.spark.rapids.RapidsPluginImplicits._ | ||
import com.nvidia.spark.rapids.jni.GpuTimeZoneDB | ||
import com.nvidia.spark.rapids.shims.ShimBinaryExpression | ||
|
||
case class GpuTimeAdd(start: Expression, | ||
interval: Expression, | ||
timeZoneId: Option[String] = None) | ||
extends GpuTimeMath(start, interval, timeZoneId) { | ||
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, TimeZoneAwareExpression} | ||
import org.apache.spark.sql.types._ | ||
import org.apache.spark.sql.vectorized.ColumnarBatch | ||
import org.apache.spark.unsafe.types.CalendarInterval | ||
|
||
case class GpuTimeAdd( | ||
start: Expression, | ||
interval: Expression, | ||
timeZoneId: Option[String] = None) | ||
extends ShimBinaryExpression | ||
with GpuExpression | ||
with TimeZoneAwareExpression | ||
with ExpectsInputTypes | ||
with Serializable { | ||
|
||
def this(start: Expression, interval: Expression) = this(start, interval, None) | ||
|
||
override def left: Expression = start | ||
override def right: Expression = interval | ||
|
||
override def toString: String = s"$left - $right" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If this is an add why do we show it as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done. |
||
override def sql: String = s"${left.sql} - ${right.sql}" | ||
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType) | ||
|
||
override def dataType: DataType = TimestampType | ||
|
||
override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess | ||
|
||
val microSecondsInOneDay: Long = TimeUnit.DAYS.toMicros(1) | ||
|
||
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = { | ||
copy(timeZoneId = Option(timeZoneId)) | ||
} | ||
|
||
override def intervalMath(us_s: Scalar, us: ColumnView): ColumnVector = { | ||
us.add(us_s) | ||
override def columnarEval(batch: ColumnarBatch): GpuColumnVector = { | ||
withResourceIfAllowed(left.columnarEval(batch)) { lhs => | ||
revans2 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
withResourceIfAllowed(right.columnarEvalAny(batch)) { rhs => | ||
// lhs is start, rhs is interval | ||
(lhs, rhs) match { | ||
case (l, intvlS: GpuScalar) if intvlS.dataType.isInstanceOf[CalendarIntervalType] => | ||
// Scalar does not support 'CalendarInterval' now, so use | ||
// the Scala value instead. | ||
// Skip the null check because it will be detected by the following calls. | ||
val intvl = intvlS.getValue.asInstanceOf[CalendarInterval] | ||
if (intvl.months != 0) { | ||
throw new UnsupportedOperationException("Months aren't supported at the moment") | ||
} | ||
val usToSub = intvl.days * microSecondsInOneDay + intvl.microseconds | ||
if (usToSub != 0) { | ||
val res = if (GpuOverrides.isUTCTimezone(zoneId)) { | ||
withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, usToSub)) { d => | ||
timestampAddDuration(l.getBase, d) | ||
} | ||
} else { | ||
val utcRes = withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(l.getBase, | ||
zoneId)) { utcTimestamp => | ||
withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, usToSub)) { | ||
d => timestampAddDuration(utcTimestamp, d) | ||
} | ||
} | ||
withResource(utcRes) { _ => | ||
GpuTimeZoneDB.fromTimestampToUtcTimestamp(utcRes, zoneId) | ||
} | ||
} | ||
GpuColumnVector.from(res, dataType) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
} else { | ||
l.incRefCount() | ||
} | ||
case _ => | ||
throw new UnsupportedOperationException("only column and interval arguments " + | ||
s"are supported, got left: ${lhs.getClass} right: ${rhs.getClass}") | ||
} | ||
} | ||
} | ||
} | ||
|
||
private def timestampAddDuration(cv: ColumnView, duration: BinaryOperable): ColumnVector = { | ||
// Do not use cv.add(duration), because it invokes BinaryOperable.implicitConversion, | ||
// and currently BinaryOperable.implicitConversion returns Long. | ||
// Directly specify the return type as TIMESTAMP_MICROSECONDS. | ||
cv.binaryOp(BinaryOp.ADD, duration, DType.TIMESTAMP_MICROSECONDS) | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't need this configuration any more.