Skip to content

Commit

Permalink
Add date_format test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
Chong Gao committed Dec 8, 2023
1 parent 7e61915 commit 96924eb
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
2 changes: 0 additions & 2 deletions integration_tests/src/main/python/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
import random
import warnings

from time_zone_utils import *

# TODO redo _spark stuff using fixtures
#
# Don't import pyspark / _spark directly in conftest globally
Expand Down
20 changes: 20 additions & 0 deletions integration_tests/src/main/python/date_time_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,26 @@ def test_date_format_for_date(data_gen, date_format):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr("date_format(a, '{}')".format(date_format)))

@pytest.mark.parametrize('date_format', supported_date_formats, ids=idfn)
@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn)
@pytest.mark.skipif(not is_supported_time_zone(), reason="not all time zones are supported now, refer to https://github.com/NVIDIA/spark-rapids/issues/6839, please update after all time zones are supported")
def test_date_format_for_time(data_gen, date_format):
conf = {'spark.rapids.sql.nonUTC.enabled': True}
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr("date_format(a, '{}')".format(date_format)),
conf)

@pytest.mark.parametrize('date_format', supported_date_formats, ids=idfn)
@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn)
@pytest.mark.skipif(is_supported_time_zone(), reason="not all time zones are supported now, refer to https://github.com/NVIDIA/spark-rapids/issues/6839, please update after all time zones are supported")
@allow_non_gpu('ProjectExec')
def test_date_format_for_time_fall_back(data_gen, date_format):
conf = {'spark.rapids.sql.nonUTC.enabled': True}
assert_gpu_fallback_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr("date_format(a, '{}')".format(date_format)),
'ProjectExec',
conf)

@pytest.mark.parametrize('date_format', supported_date_formats, ids=idfn)
# from 0001-02-01 to 9999-12-30 to avoid 'year 0 is out of range'
@pytest.mark.parametrize('data_gen', [LongGen(min_val=int(datetime(1, 2, 1).timestamp()), max_val=int(datetime(9999, 12, 30).timestamp()))], ids=idfn)
Expand Down

0 comments on commit 96924eb

Please sign in to comment.