Skip to content

Commit

Permalink
Support date_format via GPU for non-UTC time zone
Browse files: browse the repository at this point in the history
Signed-off-by: Chong Gao <[email protected]>
  • Loading branch information
Chong Gao committed Dec 5, 2023
1 parent 877130a commit e3941c2
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 4 deletions.
13 changes: 11 additions & 2 deletions integration_tests/src/main/python/date_time_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,13 +467,22 @@ def test_gettimestamp_ansi_exception():

# Format patterns that the date_format tests in this file are parametrized over.
supported_date_formats = [
    'yyyy-MM-dd',
    'yyyy-MM',
    'yyyy/MM/dd',
    'yyyy/MM',
    'dd/MM/yyyy',
    'MM-dd',
    'MM/dd',
    'dd-MM',
    'dd/MM',
]
@pytest.mark.parametrize('date_format', supported_date_formats, ids=idfn)
# The generator must be bound to 'data_gen'; binding it to 'date_format' would
# both duplicate the parametrization above and feed a generator object in as
# the format pattern.
@pytest.mark.parametrize('data_gen', [date_gen], ids=idfn)
@allow_non_gpu(*non_utc_allow)
def test_date_format_for_date(data_gen, date_format):
    """Check date_format() on date input for every supported pattern.

    Builds a single-column DataFrame from ``data_gen`` with ``unary_op_df``,
    applies ``date_format(a, '<date_format>')`` through ``selectExpr`` and hands
    the query to ``assert_gpu_and_cpu_are_equal_collect``.

    Args:
        data_gen: generator for column ``a`` (here only ``date_gen``).
        date_format: one pattern from ``supported_date_formats``.
    """
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, data_gen).selectExpr(
            "date_format(a, '{}')".format(date_format)))

@pytest.mark.parametrize('date_format', supported_date_formats, ids=idfn)
@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn)
@pytest.mark.xfail(is_dst_time_zone(), reason="only support non-DST time zone, refer to https://github.com/NVIDIA/spark-rapids/issues/6839")
def test_date_format_for_time(data_gen, date_format):
    """Check date_format() on timestamp input with non-UTC support enabled.

    The query runs with ``spark.rapids.sql.nonUTC.enabled`` set to True and is
    handed to ``assert_gpu_and_cpu_are_equal_collect``. The test is marked
    xfail whenever ``is_dst_time_zone()`` is true.

    Args:
        data_gen: generator for column ``a`` (here only ``timestamp_gen``).
        date_format: one pattern from ``supported_date_formats``.
    """
    sql_expr = "date_format(a, '{}')".format(date_format)

    def format_column(spark):
        # Single-column DataFrame 'a' with the formatting expression applied.
        return unary_op_df(spark, data_gen).selectExpr(sql_expr)

    assert_gpu_and_cpu_are_equal_collect(
        format_column,
        {'spark.rapids.sql.nonUTC.enabled': True})

@pytest.mark.parametrize('date_format', supported_date_formats, ids=idfn)
# from 0001-02-01 to 9999-12-30
@pytest.mark.parametrize('data_gen', [LongGen(min_val=int(datetime(1, 2, 1).timestamp()), max_val=int(datetime(9999, 12, 30).timestamp()))], ids=idfn)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1681,8 +1681,9 @@ object GpuOverrides extends Logging {
.withPsNote(TypeEnum.STRING, "A limited number of formats are supported"),
TypeSig.STRING)),
(a, conf, p, r) => new UnixTimeExprMeta[DateFormatClass](a, conf, p, r) {
override def isTimeZoneSupported = true
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuDateFormatClass(lhs, rhs, strfFormat)
GpuDateFormatClass(lhs, rhs, strfFormat, a.timeZoneId)
}
),
expr[ToUnixTimestamp](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,15 @@ case class GpuDateFormatClass(timestamp: Expression,
// we aren't using rhs as it was already converted in the GpuOverrides while creating the
// expressions map and passed down here as strfFormat
withResource(lhs.getBase.asTimestampMicroseconds()) { tsVector =>
tsVector.asStrings(strfFormat)
if (GpuOverrides.isUTCTimezone(zoneId)) {
// UTC time zone
tsVector.asStrings(strfFormat)
} else {
// Non-UTC TZ
withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(tsVector, zoneId.normalized())) {
shifted => shifted.asStrings(strfFormat)
}
}
}
}

Expand Down

0 comments on commit e3941c2

Please sign in to comment.