-
Notifications
You must be signed in to change notification settings - Fork 237
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Move timezone check to each operator [databricks] #9482
Changes from 2 commits
d8e77b2
3f781a4
d5a6d7a
b3fa3ee
c31b2e3
a7c8996
2878c5c
705f8b5
882b751
aec893c
7f81644
bcc1f5b
505b72e
07942ea
3033bc3
a852455
0358cd4
f6ccadd
21d5a69
e2aa9da
9eab476
e231a80
71928a0
ca23932
ee60bea
d403c59
dd5ad0b
058e13e
fc3a678
938c649
befa39d
cf2c621
c298d5f
09e772c
f43a8f9
5882cc3
7a53dc2
7bd9ef8
9817c4e
f8505b7
fa1c84d
fbbbd5b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,10 +11,10 @@ | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import pytest | ||
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect, assert_gpu_and_cpu_error | ||
|
||
from data_gen import * | ||
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_fallback_collect, assert_gpu_and_cpu_error, assert_spark_exception, with_gpu_session | ||
from datetime import date, datetime, timezone | ||
from marks import ignore_order, incompat, allow_non_gpu | ||
from pyspark.sql.types import * | ||
|
@@ -558,3 +558,87 @@ def test_timestamp_millis_long_overflow(): | |
def test_timestamp_micros(data_gen): | ||
assert_gpu_and_cpu_are_equal_collect( | ||
lambda spark : unary_op_df(spark, data_gen).selectExpr("timestamp_micros(a)")) | ||
|
||
|
||
# used by timezone test cases | ||
def get_timezone_df(spark): | ||
schema = StructType([ | ||
StructField("ts_str_col", StringType()), | ||
StructField("long_col", LongType()), | ||
StructField("ts_col", TimestampType()), | ||
StructField("date_col", DateType()), | ||
StructField("date_str_col", StringType()), | ||
]) | ||
data = [ | ||
('1970-01-01 00:00:00', 0, datetime(1970, 1, 1), date(1970, 1, 1), '1970-01-01'), | ||
('1970-01-01 00:00:00', 0, datetime(1970, 1, 1), date(1970, 1, 1), '1970-01-01'), | ||
] | ||
return spark.createDataFrame(SparkContext.getOrCreate().parallelize(data),schema) | ||
|
||
# used by timezone test cases, specify all the sqls that will be impacted by non-utc timezone | ||
time_zone_sql_conf_pairs = [ | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: There are some functions related to time zones (not yet supported) that are mentioned on the Spark built-in functions website. We could add some comments mentioning them here.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For current_timezone, it just returns the session timezone, so we can ignore it for this PR. For MakeTimestamp and ConvertTimezone, it's recorded in this follow-on issue: #9570 |
||
("select minute(ts_col) from tab", {}), | ||
("select second(ts_col) from tab", {}), | ||
("select hour(ts_col) from tab", {}), | ||
("select date_col + (interval 10 days 3 seconds) from tab", {}), | ||
("select date_format(ts_col, 'yyyy-MM-dd HH:mm:ss') from tab", {}), | ||
("select unix_timestamp(ts_col) from tab", {"spark.rapids.sql.improvedTimeOps.enabled": "true"}), | ||
("select to_unix_timestamp(ts_str_col) from tab", {"spark.rapids.sql.improvedTimeOps.enabled": "false"}), | ||
("select to_unix_timestamp(ts_col) from tab", {"spark.rapids.sql.improvedTimeOps.enabled": "true"}), | ||
("select to_date(date_str_col, 'yyyy-MM-dd') from tab", {}), # test GpuGetTimestamp | ||
("select to_date(date_str_col) from tab", {}), | ||
("select from_unixtime(long_col, 'yyyy-MM-dd HH:mm:ss') from tab", {}), | ||
("select cast(ts_col as string) from tab", {}), # cast | ||
("select cast(ts_col as date) from tab", {}), # cast | ||
("select cast(date_col as TIMESTAMP) from tab", {}), # cast | ||
("select to_timestamp(ts_str_col) from tab", {"spark.rapids.sql.improvedTimeOps.enabled": "false"}), | ||
("select to_timestamp(ts_str_col) from tab", {"spark.rapids.sql.improvedTimeOps.enabled": "true"}), | ||
] | ||
|
||
|
||
@allow_non_gpu("ProjectExec") | ||
@pytest.mark.parametrize('sql, extra_conf', time_zone_sql_conf_pairs) | ||
def test_timezone_for_operators_with_non_utc(sql, extra_conf): | ||
# timezone is non-utc, should fallback to CPU | ||
timezone_conf = {"spark.sql.session.timeZone": "+08:00", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we make the time zone string a param to the test? Just because I would like to test a few more time zones than just There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
"spark.rapids.sql.hasExtendedYearValues": "false", | ||
"spark.rapids.sql.castStringToTimestamp.enabled": "true"} | ||
all_conf = copy_and_update(timezone_conf, extra_conf) | ||
def gen_sql_df(spark): | ||
df = get_timezone_df(spark) | ||
df.createOrReplaceTempView("tab") | ||
return spark.sql(sql) | ||
assert_gpu_fallback_collect(gen_sql_df, "ProjectExec", all_conf) | ||
|
||
|
||
@pytest.mark.parametrize('sql, conf', time_zone_sql_conf_pairs) | ||
def test_timezone_for_operators_with_utc(sql, conf): | ||
# timezone is utc, should be supported by GPU | ||
timezone_conf = {"spark.sql.session.timeZone": "UTC", | ||
"spark.rapids.sql.hasExtendedYearValues": "false", | ||
"spark.rapids.sql.castStringToTimestamp.enabled": "true",} | ||
conf = copy_and_update(timezone_conf, conf) | ||
def gen_sql_df(spark): | ||
df = get_timezone_df(spark) | ||
df.createOrReplaceTempView("tab") | ||
return spark.sql(sql) | ||
assert_gpu_and_cpu_are_equal_collect(gen_sql_df, conf) | ||
|
||
|
||
@allow_non_gpu("ProjectExec") | ||
def test_timezone_for_operator_from_utc_timestamp_with_non_utc(): | ||
# timezone is non-utc, should fallback to CPU | ||
def gen_sql_df(spark): | ||
df = get_timezone_df(spark) | ||
df.createOrReplaceTempView("tab") | ||
return spark.sql("select from_utc_timestamp(ts_col, '+08:00') from tab") | ||
assert_gpu_fallback_collect(gen_sql_df, "ProjectExec") | ||
|
||
|
||
def test_timezone_for_operator_from_utc_timestamp_with_utc(): | ||
# timezone is utc, should be supported by GPU | ||
def gen_sql_df(spark): | ||
df = get_timezone_df(spark) | ||
df.createOrReplaceTempView("tab") | ||
return spark.sql("select from_utc_timestamp(ts_col, '+00:00') from tab").collect() | ||
with_gpu_session(gen_sql_df) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1045,7 +1045,7 @@ | |
<arg>-Yno-adapted-args</arg> | ||
<arg>-Ywarn-unused:imports,locals,patvars,privates</arg> | ||
<arg>-Xlint:missing-interpolator</arg> | ||
<arg>-Xfatal-warnings</arg> | ||
<!-- <arg>-Xfatal-warnings</arg> --> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: Revert this back when we try to commit it. |
||
</args> | ||
<jvmArgs> | ||
<jvmArg>-Xms1024m</jvmArg> | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -669,9 +669,7 @@ object GpuOverrides extends Logging { | |
case FloatType => true | ||
case DoubleType => true | ||
case DateType => true | ||
case TimestampType => | ||
TypeChecks.areTimestampsSupported(ZoneId.systemDefault()) && | ||
TypeChecks.areTimestampsSupported(SQLConf.get.sessionLocalTimeZone) | ||
case TimestampType => true | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to consider the timezone check for scan and writer parts? AFAIK, when scanning data from Parquet, If applies, we should add some python tests as well. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This check is used by There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we will need to check these. For me, anything that does not have a test that shows it works fully in at least one other time zone must fall back to the CPU if it sees a timestamp that is not UTC. Parquet for example has the rebase mode for older timestamps that requires knowing the timezone to do properly. |
||
case StringType => true | ||
case dt: DecimalType if allowDecimal => dt.precision <= DType.DECIMAL64_MAX_PRECISION | ||
case NullType => allowNull | ||
|
@@ -1655,6 +1653,9 @@ object GpuOverrides extends Logging { | |
willNotWorkOnGpu("interval months isn't supported") | ||
} | ||
} | ||
|
||
// need timezone support, here check timezone | ||
checkTimeZoneId(dateAddInterval.zoneId) | ||
} | ||
|
||
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = | ||
|
@@ -1668,6 +1669,12 @@ object GpuOverrides extends Logging { | |
.withPsNote(TypeEnum.STRING, "A limited number of formats are supported"), | ||
TypeSig.STRING)), | ||
(a, conf, p, r) => new UnixTimeExprMeta[DateFormatClass](a, conf, p, r) { | ||
|
||
override def tagExprForGpu(): Unit = { | ||
// need timezone support, here check timezone | ||
checkTimeZoneId(a.zoneId) | ||
} | ||
|
||
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = | ||
GpuDateFormatClass(lhs, rhs, strfFormat) | ||
} | ||
|
@@ -1682,6 +1689,12 @@ object GpuOverrides extends Logging { | |
.withPsNote(TypeEnum.STRING, "A limited number of formats are supported"), | ||
TypeSig.STRING)), | ||
(a, conf, p, r) => new UnixTimeExprMeta[ToUnixTimestamp](a, conf, p, r) { | ||
|
||
override def tagExprForGpu(): Unit = { | ||
// need timezone support, here check timezone | ||
checkTimeZoneId(a.zoneId) | ||
} | ||
|
||
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = { | ||
if (conf.isImprovedTimestampOpsEnabled) { | ||
// passing the already converted strf string for a little optimization | ||
|
@@ -1701,6 +1714,12 @@ object GpuOverrides extends Logging { | |
.withPsNote(TypeEnum.STRING, "A limited number of formats are supported"), | ||
TypeSig.STRING)), | ||
(a, conf, p, r) => new UnixTimeExprMeta[UnixTimestamp](a, conf, p, r) { | ||
|
||
override def tagExprForGpu(): Unit = { | ||
// need timezone support, here check timezone | ||
checkTimeZoneId(a.zoneId) | ||
} | ||
|
||
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = { | ||
if (conf.isImprovedTimestampOpsEnabled) { | ||
// passing the already converted strf string for a little optimization | ||
|
@@ -1715,6 +1734,11 @@ object GpuOverrides extends Logging { | |
ExprChecks.unaryProject(TypeSig.INT, TypeSig.INT, | ||
TypeSig.TIMESTAMP, TypeSig.TIMESTAMP), | ||
(hour, conf, p, r) => new UnaryExprMeta[Hour](hour, conf, p, r) { | ||
|
||
override def tagExprForGpu(): Unit = { | ||
// need timezone support, here check timezone | ||
checkTimeZoneId(hour.zoneId) | ||
} | ||
|
||
override def convertToGpu(expr: Expression): GpuExpression = GpuHour(expr) | ||
}), | ||
|
@@ -1724,6 +1748,11 @@ object GpuOverrides extends Logging { | |
TypeSig.TIMESTAMP, TypeSig.TIMESTAMP), | ||
(minute, conf, p, r) => new UnaryExprMeta[Minute](minute, conf, p, r) { | ||
|
||
override def tagExprForGpu(): Unit = { | ||
// need timezone support, here check timezone | ||
checkTimeZoneId(minute.zoneId) | ||
} | ||
|
||
override def convertToGpu(expr: Expression): GpuExpression = | ||
GpuMinute(expr) | ||
}), | ||
|
@@ -1733,6 +1762,11 @@ object GpuOverrides extends Logging { | |
TypeSig.TIMESTAMP, TypeSig.TIMESTAMP), | ||
(second, conf, p, r) => new UnaryExprMeta[Second](second, conf, p, r) { | ||
|
||
override def tagExprForGpu(): Unit = { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we try and have a TimeZoneAwareExprMeta, or something similar that makes it super simple to do this? We might even be able to back it into ExprMeta itself, just by checking if the class that this wraps is also There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm guessing the best approach is to put it directly in ExprMeta since otherwise we would have to mixin the TimeZoneAwareExprMeta for the different functions. I'm guessing that functions requiring timezone will span the gamut of Unary/Binary/Ternary/Quaternary/Agg/etc. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe wrap the check in a method and override it whenever a function starts supporting alternate timezones. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looks like GpuCast will be a first exception to this idea: #6835 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm testing all the existing test cases with adding non-UTC time zone config to identify all the failed cases:
Then I'll update the failing cases. |
||
// need timezone support, here check timezone | ||
checkTimeZoneId(second.zoneId) | ||
} | ||
|
||
override def convertToGpu(expr: Expression): GpuExpression = | ||
GpuSecond(expr) | ||
}), | ||
|
@@ -1767,6 +1801,12 @@ object GpuOverrides extends Logging { | |
.withPsNote(TypeEnum.STRING, "Only a limited number of formats are supported"), | ||
TypeSig.STRING)), | ||
(a, conf, p, r) => new UnixTimeExprMeta[FromUnixTime](a, conf, p, r) { | ||
|
||
override def tagExprForGpu(): Unit = { | ||
// need timezone support, here check timezone | ||
checkTimeZoneId(a.zoneId) | ||
} | ||
|
||
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = | ||
// passing the already converted strf string for a little optimization | ||
GpuFromUnixTime(lhs, rhs, strfFormat) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -374,13 +374,12 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging { | |
case Some(value) => ZoneId.of(value) | ||
case None => throw new RuntimeException(s"Driver time zone cannot be determined.") | ||
} | ||
if (TypeChecks.areTimestampsSupported(driverTimezone)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This may be off-topic. Considering the configuration There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Here But for our
I think yes, because we want to avoid the issue |
||
val executorTimezone = ZoneId.systemDefault() | ||
if (executorTimezone.normalized() != driverTimezone.normalized()) { | ||
throw new RuntimeException(s" Driver and executor timezone mismatch. " + | ||
s"Driver timezone is $driverTimezone and executor timezone is " + | ||
s"$executorTimezone. Set executor timezone to $driverTimezone.") | ||
} | ||
|
||
val executorTimezone = ZoneId.systemDefault() | ||
if (executorTimezone.normalized() != driverTimezone.normalized()) { | ||
throw new RuntimeException(s" Driver and executor timezone mismatch. " + | ||
s"Driver timezone is $driverTimezone and executor timezone is " + | ||
s"$executorTimezone. Set executor timezone to $driverTimezone.") | ||
} | ||
|
||
GpuCoreDumpHandler.executorInit(conf, pluginContext) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -363,8 +363,7 @@ final class TypeSig private( | |
case FloatType => check.contains(TypeEnum.FLOAT) | ||
case DoubleType => check.contains(TypeEnum.DOUBLE) | ||
case DateType => check.contains(TypeEnum.DATE) | ||
case TimestampType if check.contains(TypeEnum.TIMESTAMP) => | ||
TypeChecks.areTimestampsSupported() | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Originally invoked by shuffle meta, FileFormatChecks, tag AST, and others.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Took a quick look at cudf. For AST, I noticed timezone info is not respected yet. |
||
case TimestampType => check.contains(TypeEnum.TIMESTAMP) | ||
case StringType => check.contains(TypeEnum.STRING) | ||
case dt: DecimalType => | ||
check.contains(TypeEnum.DECIMAL) && | ||
|
@@ -840,7 +839,7 @@ object TypeChecks { | |
areTimestampsSupported(ZoneId.systemDefault()) && | ||
areTimestampsSupported(SQLConf.get.sessionLocalTimeZone) | ||
} | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: extra space. |
||
def isTimezoneSensitiveType(dataType: DataType): Boolean = { | ||
dataType == TimestampType | ||
} | ||
|
@@ -1502,7 +1501,20 @@ class CastChecks extends ExprChecks { | |
|
||
def gpuCanCast(from: DataType, to: DataType): Boolean = { | ||
val (checks, _) = getChecksAndSigs(from) | ||
checks.isSupportedByPlugin(to) | ||
checks.isSupportedByPlugin(to) && gpuCanCastConsiderTimezone(from, to) | ||
} | ||
|
||
def gpuCanCastConsiderTimezone(from: DataType, to: DataType) = { | ||
// need timezone support, here check timezone | ||
(from, to) match { | ||
case (_:StringType, _:TimestampType) => TypeChecks.areTimestampsSupported() | ||
case (_:TimestampType, _:StringType) => TypeChecks.areTimestampsSupported() | ||
case (_:StringType, _:DateType) => TypeChecks.areTimestampsSupported() | ||
case (_:DateType, _:StringType) => TypeChecks.areTimestampsSupported() | ||
case (_:TimestampType, _:DateType) => TypeChecks.areTimestampsSupported() | ||
case (_:DateType, _:TimestampType) => TypeChecks.areTimestampsSupported() | ||
case _ => true | ||
} | ||
} | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
/* | ||
* Copyright (c) 2021-2022, NVIDIA CORPORATION. | ||
* Copyright (c) 2021-2023, NVIDIA CORPORATION. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not touched? |
||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
|
@@ -38,6 +38,12 @@ object TimeStamp { | |
.withPsNote(TypeEnum.STRING, "A limited number of formats are supported"), | ||
TypeSig.STRING)), | ||
(a, conf, p, r) => new UnixTimeExprMeta[GetTimestamp](a, conf, p, r) { | ||
|
||
override def tagExprForGpu(): Unit = { | ||
// need timezone support, here check timezone | ||
checkTimeZoneId(a.zoneId) | ||
} | ||
|
||
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = { | ||
GpuGetTimestamp(lhs, rhs, sparkFormat, strfFormat) | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: one line break after license header.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done