Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support date_format via Gpu for non-UTC time zone [databricks] #9721

Merged
merged 11 commits into from
Dec 14, 2023
Merged
28 changes: 24 additions & 4 deletions integration_tests/src/main/python/date_time_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,12 +449,32 @@ def test_gettimestamp_ansi_exception():
supported_date_formats = ['yyyy-MM-dd', 'yyyy-MM', 'yyyy/MM/dd', 'yyyy/MM', 'dd/MM/yyyy',
'MM-dd', 'MM/dd', 'dd-MM', 'dd/MM']
@pytest.mark.parametrize('date_format', supported_date_formats, ids=idfn)
@pytest.mark.parametrize('data_gen', date_n_time_gens, ids=idfn)
@allow_non_gpu(*non_utc_allow)
def test_date_format(data_gen, date_format):
@pytest.mark.parametrize('data_gen', [date_gen], ids=idfn)
@allow_non_gpu('ProjectExec')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to check whether it's supported timezone?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

date_format(date) will introduce cast(date as timestamp) which is not supported in non-UTC now.
After we support cast(date as timestamp), then we will update this case.

def test_date_format_for_date(data_gen, date_format):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr("date_format(a, '{}')".format(date_format)))

@pytest.mark.parametrize('date_format', supported_date_formats, ids=idfn)
@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn)
@pytest.mark.skipif(not is_supported_time_zone(), reason="not all time zones are supported now, refer to https://github.com/NVIDIA/spark-rapids/issues/6839, please update after all time zones are supported")
def test_date_format_for_time(data_gen, date_format):
conf = {'spark.rapids.sql.nonUTC.enabled': True}
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr("date_format(a, '{}')".format(date_format)),
conf)

@pytest.mark.parametrize('date_format', supported_date_formats, ids=idfn)
@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn)
@pytest.mark.skipif(is_supported_time_zone(), reason="not all time zones are supported now, refer to https://github.com/NVIDIA/spark-rapids/issues/6839, please update after all time zones are supported")
@allow_non_gpu('ProjectExec')
def test_date_format_for_time_fall_back(data_gen, date_format):
conf = {'spark.rapids.sql.nonUTC.enabled': True}
assert_gpu_fallback_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr("date_format(a, '{}')".format(date_format)),
'ProjectExec',
conf)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


@pytest.mark.parametrize('date_format', supported_date_formats + ['yyyyMMdd'], ids=idfn)
# from 0001-02-01 to 9999-12-30 to avoid 'year 0 is out of range'
@pytest.mark.parametrize('data_gen', [LongGen(min_val=int(datetime(1, 2, 1).timestamp()), max_val=int(datetime(9999, 12, 30).timestamp()))], ids=idfn)
Expand Down Expand Up @@ -576,7 +596,7 @@ def test_timestamp_seconds_long_overflow():
error_message='long overflow')

# For Decimal(20, 7) case, the data is both 'Overflow' and 'Rounding necessary', this case is to verify
# that 'Rounding necessary' check is before 'Overflow' check. So we should make sure that every decimal
# that 'Rounding necessary' check is before 'Overflow' check. So we should make sure that every decimal
# value in test data is 'Rounding necessary' by setting full_precision=True to avoid leading and trailing zeros.
# Otherwise, the test data will bypass the 'Rounding necessary' check and throw an 'Overflow' error.
@pytest.mark.parametrize('data_gen', [DecimalGen(7, 7, full_precision=True), DecimalGen(20, 7, full_precision=True)], ids=idfn)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1681,8 +1681,9 @@ object GpuOverrides extends Logging {
.withPsNote(TypeEnum.STRING, "A limited number of formats are supported"),
TypeSig.STRING)),
(a, conf, p, r) => new UnixTimeExprMeta[DateFormatClass](a, conf, p, r) {
override def isTimeZoneSupported = true
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is due to the fact that DateFormatClass is not considered a TimeZoneAwareExpression, but requires support for non-UTC timezones. It's the last check for timezone requirement in Expressions. By default, this flag is false.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ttnghia FYI. Some original notes around this.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems not a good practice to set some value as false by default then override true like this. Can we compute that value based on the input in the object constructor?

override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: code style. not sure whether a line break need here.

GpuDateFormatClass(lhs, rhs, strfFormat)
GpuDateFormatClass(lhs, rhs, strfFormat, a.timeZoneId)
}
),
expr[ToUnixTimestamp](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,15 @@ case class GpuDateFormatClass(timestamp: Expression,
// we aren't using rhs as it was already converted in the GpuOverrides while creating the
// expressions map and passed down here as strfFormat
withResource(lhs.getBase.asTimestampMicroseconds()) { tsVector =>
tsVector.asStrings(strfFormat)
if (GpuOverrides.isUTCTimezone(zoneId)) {
// UTC time zone
tsVector.asStrings(strfFormat)
} else {
// Non-UTC TZ
withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(tsVector, zoneId.normalized())) {
shifted => shifted.asStrings(strfFormat)
}
}
}
}

Expand Down