diff --git a/integration_tests/src/main/python/cast_test.py b/integration_tests/src/main/python/cast_test.py index f9c54de3400..f8ca81a91a2 100644 --- a/integration_tests/src/main/python/cast_test.py +++ b/integration_tests/src/main/python/cast_test.py @@ -540,6 +540,11 @@ def test_cast_timestamp_to_string(): lambda spark: unary_op_df(spark, timestamp_gen) .selectExpr("cast(a as string)")) +def test_cast_timestamp_to_date(): + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, timestamp_gen) + .selectExpr("cast(a as date)")) + @pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') def test_cast_day_time_interval_to_string(): _assert_cast_to_string_equal(DayTimeIntervalGen(start_field='day', end_field='day', special_cases=[MIN_DAY_TIME_INTERVAL, MAX_DAY_TIME_INTERVAL, timedelta(seconds=0)]), {}) @@ -692,9 +697,9 @@ def test_cast_int_to_string_not_UTC(): lambda spark: unary_op_df(spark, int_gen, 100).selectExpr("a", "CAST(a AS STRING) as str"), {"spark.sql.session.timeZone": "+08"}) -not_utc_fallback_test_params = [(timestamp_gen, 'STRING'), (timestamp_gen, 'DATE'), +not_utc_fallback_test_params = [(timestamp_gen, 'STRING'), # python does not like year 0, and with time zones the default start date can become year 0 :( - (DateGen(start=date(1, 1, 3)), 'TIMESTAMP'), + (DateGen(start=date(1, 1, 1)), 'TIMESTAMP'), (SetValuesGen(StringType(), ['2023-03-20 10:38:50', '2023-03-20 10:39:02']), 'TIMESTAMP')] @allow_non_gpu('ProjectExec') diff --git a/integration_tests/src/main/python/data_gen.py b/integration_tests/src/main/python/data_gen.py index dd2819a7832..383a24018af 100644 --- a/integration_tests/src/main/python/data_gen.py +++ b/integration_tests/src/main/python/data_gen.py @@ -24,7 +24,7 @@ from spark_session import is_before_spark_340, with_cpu_session import sre_yield import struct -from conftest import skip_unless_precommit_tests,get_datagen_seed, is_not_utc +from conftest import skip_unless_precommit_tests, get_datagen_seed, is_not_utc import time import os from functools import lru_cache diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py index ca62a5a75e0..8eab7f1e231 100644 --- a/integration_tests/src/main/python/date_time_test.py +++ b/integration_tests/src/main/python/date_time_test.py @@ -14,7 +14,7 @@ import pytest from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect, assert_gpu_and_cpu_error -from conftest import is_utc, is_supported_time_zone +from conftest import is_utc, is_supported_time_zone, get_test_tz from data_gen import * from datetime import date, datetime, timezone from marks import ignore_order, incompat, allow_non_gpu, datagen_overrides, tz_sensitive_test @@ -379,6 +379,7 @@ def fun(spark): assert_gpu_and_cpu_are_equal_collect(fun, conf=copy_and_update(parser_policy_dic, ansi_enabled_conf)) + @pytest.mark.parametrize('ansi_enabled', [True, False], ids=['ANSI_ON', 'ANSI_OFF']) @pytest.mark.parametrize('data_gen', date_n_time_gens, ids=idfn) @tz_sensitive_test @@ -427,22 +428,42 @@ def test_string_unix_timestamp_ansi_exception(): error_message="Exception", conf=ansi_enabled_conf) -@pytest.mark.parametrize('data_gen', [StringGen('[0-9]{4}-0[1-9]-[0-2][1-8]')], ids=idfn) -@pytest.mark.parametrize('ansi_enabled', [True, False], ids=['ANSI_ON', 'ANSI_OFF']) -@allow_non_gpu(*non_utc_allow) -def test_gettimestamp(data_gen, ansi_enabled): +@tz_sensitive_test +@pytest.mark.skipif(not is_supported_time_zone(), reason="not all time zones are supported now, refer to https://github.com/NVIDIA/spark-rapids/issues/6839, please update after all time zones are supported") +@pytest.mark.parametrize('parser_policy', ["CORRECTED", "EXCEPTION"], ids=idfn) +def test_to_timestamp(parser_policy): + gen = StringGen("[0-9]{3}[1-9]-(0[1-9]|1[0-2])-(0[1-9]|[1-2][0-9]) ([0-1][0-9]|2[0-3]):[0-5][0-9]:[0-5][0-9]") + if get_test_tz() == "Asia/Shanghai": + # ensure some times around transition are tested + gen = gen.with_special_case("1991-04-14 02:00:00")\ + .with_special_case("1991-04-14 02:30:00")\ + .with_special_case("1991-04-14 03:00:00")\ + .with_special_case("1991-09-15 02:00:00")\ + .with_special_case("1991-09-15 02:30:00")\ + .with_special_case("1991-09-15 03:00:00") + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, gen) + .select(f.col("a"), f.to_timestamp(f.col("a"), "yyyy-MM-dd HH:mm:ss")), + { "spark.sql.legacy.timeParserPolicy": parser_policy}) + + +@tz_sensitive_test +@pytest.mark.skipif(not is_supported_time_zone(), reason="not all time zones are supported now, refer to https://github.com/NVIDIA/spark-rapids/issues/6839, please update after all time zones are supported") +@pytest.mark.parametrize("ansi_enabled", [True, False], ids=['ANSI_ON', 'ANSI_OFF']) +def test_to_date(ansi_enabled): assert_gpu_and_cpu_are_equal_collect( - lambda spark : unary_op_df(spark, data_gen).select(f.to_date(f.col("a"), "yyyy-MM-dd")), + lambda spark : unary_op_df(spark, date_gen) + .select(f.to_date(f.col("a").cast('string'), "yyyy-MM-dd")), {'spark.sql.ansi.enabled': ansi_enabled}) - +@tz_sensitive_test +@pytest.mark.skipif(not is_supported_time_zone(), reason="not all time zones are supported now, refer to https://github.com/NVIDIA/spark-rapids/issues/6839, please update after all time zones are supported") @pytest.mark.parametrize('data_gen', [StringGen('0[1-9][0-9]{4}')], ids=idfn) -@allow_non_gpu(*non_utc_allow) -def test_gettimestamp_format_MMyyyy(data_gen): +def test_to_date_format_MMyyyy(data_gen): assert_gpu_and_cpu_are_equal_collect( lambda spark: unary_op_df(spark, data_gen).select(f.to_date(f.col("a"), "MMyyyy"))) -def test_gettimestamp_ansi_exception(): +def test_to_date_ansi_exception(): assert_gpu_and_cpu_error( lambda spark : invalid_date_string_df(spark).select(f.to_date(f.col("a"), "yyyy-MM-dd")).collect(), error_message="Exception", diff --git a/integration_tests/src/main/python/window_function_test.py b/integration_tests/src/main/python/window_function_test.py index 39c7018a2e7..6f422c8afbf 100644 --- a/integration_tests/src/main/python/window_function_test.py +++ b/integration_tests/src/main/python/window_function_test.py @@ -1686,7 +1686,6 @@ def test_window_first_last_nth_ignore_nulls(data_gen): @ignore_order(local=True) -@allow_non_gpu(*non_utc_allow) def test_to_date_with_window_functions(): """ This test ensures that date expressions participating alongside window aggregations diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala index 6c0bdbfe41e..9bf9144db0e 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala @@ -26,13 +26,14 @@ import ai.rapids.cudf.{BinaryOp, CaptureGroups, ColumnVector, ColumnView, Decima import ai.rapids.cudf import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.jni.CastStrings +import com.nvidia.spark.rapids.jni.{CastStrings, GpuTimeZoneDB} import com.nvidia.spark.rapids.shims.{AnsiUtil, GpuCastShims, GpuIntervalUtils, GpuTypeShims, SparkShimImpl, YearParseUtil} import org.apache.commons.text.StringEscapeUtils import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, NullIntolerant, TimeZoneAwareExpression, UnaryExpression} import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_SECOND +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.rapids.GpuToTimestamp.replaceSpecialDates import org.apache.spark.sql.rapids.shims.RapidsErrorUtils @@ -86,6 +87,13 @@ abstract class CastExprMetaBase[INPUT <: UnaryExpression with TimeZoneAwareExpre val fromType: DataType = cast.child.dataType val toType: DataType = cast.dataType + override def isTimeZoneSupported: Boolean = { + (fromType, toType) match { + case (TimestampType, DateType) => true // this is for to_date(...) + case _ => false + } + } + override def tagExprForGpu(): Unit = { recursiveTagExprForGpuCheck() } @@ -209,13 +217,16 @@ object CastOptions { * @param ansiMode Whether the cast should be ANSI compliant * @param stringToDateAnsiMode Whether to cast String to Date using ANSI compliance * @param castToJsonString Whether to use JSON format when casting to String + * @param ignoreNullFieldsInStructs Whether to omit null values when converting to JSON + * @param timeZoneId If cast is timezone aware, the timezone needed */ class CastOptions( legacyCastComplexTypesToString: Boolean, ansiMode: Boolean, stringToDateAnsiMode: Boolean, val castToJsonString: Boolean = false, - val ignoreNullFieldsInStructs: Boolean = true) extends Serializable { + val ignoreNullFieldsInStructs: Boolean = true, + val timeZoneId: Option[String] = Option.empty[String]) extends Serializable { /** * Retuns the left bracket to use when surrounding brackets when converting @@ -614,6 +625,12 @@ object GpuCast { case (_: IntegerType | ShortType | ByteType, ym: DataType) if GpuTypeShims.isSupportedYearMonthType(ym) => GpuIntervalUtils.intToYearMonthInterval(input, ym) + case (TimestampType, DateType) if options.timeZoneId.isDefined => + val zoneId = DateTimeUtils.getZoneId(options.timeZoneId.get) + withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(input.asInstanceOf[ColumnVector], + zoneId.normalized())) { + shifted => shifted.castTo(GpuColumnVector.getNonNestedRapidsType(toDataType)) + } case _ => input.castTo(GpuColumnVector.getNonNestedRapidsType(toDataType)) } @@ -1807,7 +1824,8 @@ case class GpuCast( import GpuCast._ private val options: CastOptions = - new CastOptions(legacyCastComplexTypesToString, ansiMode, stringToDateAnsiModeEnabled) + new CastOptions(legacyCastComplexTypesToString, ansiMode, stringToDateAnsiModeEnabled, + timeZoneId = timeZoneId) // when ansi mode is enabled, some cast expressions can throw exceptions on invalid inputs override def hasSideEffects: Boolean = super.hasSideEffects || { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala index bc73338ec87..f37b59f709a 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala @@ -1126,8 +1126,9 @@ abstract class BaseExprMeta[INPUT <: Expression]( if (!isTimeZoneSupported) return checkUTCTimezone(this) // Level 3 check - if (!GpuTimeZoneDB.isSupportedTimeZone(getZoneId())) { - willNotWorkOnGpu(TimeZoneDB.timezoneNotSupportedStr(this.wrapped.getClass.toString)) + val zoneId = getZoneId() + if (!GpuTimeZoneDB.isSupportedTimeZone(zoneId)) { + willNotWorkOnGpu(TimeZoneDB.timezoneNotSupportedStr(zoneId.toString)) } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/expressions/rapids/Timestamp.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/expressions/rapids/Timestamp.scala index b441627c928..f6be1cca6bc 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/expressions/rapids/Timestamp.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/expressions/rapids/Timestamp.scala @@ -38,8 +38,9 @@ object TimeStamp { .withPsNote(TypeEnum.STRING, "A limited number of formats are supported"), TypeSig.STRING)), (a, conf, p, r) => new UnixTimeExprMeta[GetTimestamp](a, conf, p, r) { + override def isTimeZoneSupported = true override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = { - GpuGetTimestamp(lhs, rhs, sparkFormat, strfFormat) + GpuGetTimestamp(lhs, rhs, sparkFormat, strfFormat, a.timeZoneId) } }) ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala index 3a338c91a09..09d4a977084 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala @@ -856,7 +856,7 @@ abstract class GpuToTimestamp val tmp = lhs.dataType match { case _: StringType => // rhs is ignored we already parsed the format - if (getTimeParserPolicy == LegacyTimeParserPolicy) { + val res = if (getTimeParserPolicy == LegacyTimeParserPolicy) { parseStringAsTimestampWithLegacyParserPolicy( lhs, sparkFormat, @@ -871,6 +871,11 @@ abstract class GpuToTimestamp DType.TIMESTAMP_MICROSECONDS, failOnError) } + if (GpuOverrides.isUTCTimezone(zoneId)) { + res + } else { + GpuTimeZoneDB.fromTimestampToUtcTimestamp(res, zoneId) + } case _: DateType => timeZoneId match { case Some(_) =>