[WIP] Move timezone check to each operator [databricks] #9482

Closed

Changes from 14 commits

Commits (42)
d8e77b2  Add test cases for timezone aware operators (Oct 19, 2023)
3f781a4  Move timezone check to each operator (Oct 19, 2023)
d5a6d7a  Merge branch 23.12 (Oct 27, 2023)
b3fa3ee  Update (Oct 27, 2023)
c31b2e3  debug (Oct 27, 2023)
a7c8996  debug (Oct 27, 2023)
2878c5c  Add timezone test mark (Oct 27, 2023)
705f8b5  Minor update (Nov 1, 2023)
882b751  Fix failed cmp case on Spark311; Restore a python import; minor changes (Nov 1, 2023)
aec893c  Fix failure on Databricks (Nov 2, 2023)
7f81644  Update test cases for Databricks (Nov 2, 2023)
bcc1f5b  Update test cases for Databricks (Nov 2, 2023)
505b72e  Fix delta lake test cases. (Nov 3, 2023)
07942ea  Fix delta lake test cases. (Nov 3, 2023)
3033bc3  Remove the skip logic when time zone is not UTC (Nov 7, 2023)
a852455  Add time zone config to set non-UTC (Nov 7, 2023)
0358cd4  Add fallback case for cast_test.py (Nov 7, 2023)
f6ccadd  Add fallback case for cast_test.py (Nov 7, 2023)
21d5a69  Add fallback case for cast_test.py (Nov 8, 2023)
e2aa9da  Add fallback case for cast_test.py (Nov 8, 2023)
9eab476  Update split_list (Nov 8, 2023)
e231a80  Add fallback case for cast_test.py (Nov 8, 2023)
71928a0  Add fallback case for cast_test.py (Nov 8, 2023)
ca23932  Add fallback cases for cmp_test.py (Nov 9, 2023)
ee60bea  Add fallback tests for json_test.py (firestarman, Nov 9, 2023)
d403c59  add non_utc fallback for parquet_write qa_select and window_function … (thirtiseven, Nov 9, 2023)
dd5ad0b  Add fallback tests for conditionals_test.py (winningsix, Nov 9, 2023)
058e13e  Add fallback cases for collection_ops_test.py (Nov 9, 2023)
fc3a678  add fallback tests for date_time_test (thirtiseven, Nov 9, 2023)
938c649  clean up spark_session.py (thirtiseven, Nov 9, 2023)
befa39d  Add fallback tests for explain_test and csv_test (winningsix, Nov 9, 2023)
cf2c621  Update test case (Nov 9, 2023)
c298d5f  update test case (Nov 9, 2023)
09e772c  Add default value (Nov 10, 2023)
f43a8f9  Remove useless is_tz_utc (Nov 10, 2023)
5882cc3  Fix fallback cases (Nov 10, 2023)
7a53dc2  Add bottom check for time zone; Fix ORC check (Nov 13, 2023)
7bd9ef8  By default, ExecCheck do not check UTC time zone (Nov 13, 2023)
9817c4e  For common expr like AttributeReference, just skip the UTC checking (Nov 13, 2023)
f8505b7  For common expr like AttributeReference, just skip the UTC checking (Nov 13, 2023)
fa1c84d  For common expr like AttributeReference, just skip the UTC checking (Nov 13, 2023)
fbbbd5b  Update test cases (Nov 14, 2023)
3 changes: 2 additions & 1 deletion integration_tests/src/main/python/aqe_test.py
@@ -17,7 +17,7 @@
from pyspark.sql.types import *
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_cpu_and_gpu_are_equal_collect_with_capture
from data_gen import *
from marks import ignore_order, allow_non_gpu
from marks import ignore_order, allow_non_gpu, disable_timezone_test
from spark_session import with_cpu_session, is_databricks113_or_later

_adaptive_conf = { "spark.sql.adaptive.enabled": "true" }
@@ -195,6 +195,7 @@ def do_it(spark):
@ignore_order(local=True)
@allow_non_gpu('BroadcastNestedLoopJoinExec', 'Cast', 'DateSub', *db_113_cpu_bnlj_join_allow)
@pytest.mark.parametrize('join', joins, ids=idfn)
@disable_timezone_test
def test_aqe_join_reused_exchange_inequality_condition(spark_tmp_path, join):
data_path = spark_tmp_path + '/PARQUET_DATA'
def prep(spark):
13 changes: 10 additions & 3 deletions integration_tests/src/main/python/cast_test.py
@@ -18,7 +18,7 @@
from data_gen import *
from spark_session import is_before_spark_320, is_before_spark_330, is_spark_340_or_later, \
is_databricks113_or_later
from marks import allow_non_gpu, approximate_float
from marks import allow_non_gpu, approximate_float, disable_timezone_test
from pyspark.sql.types import *
from spark_init_internal import spark_version
from datetime import date, datetime
@@ -60,7 +60,7 @@ def test_cast_empty_string_to_int():
def test_cast_nested(data_gen, to_type):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).select(f.col('a').cast(to_type)))

@disable_timezone_test
def test_cast_string_date_valid_format():
# In Spark 3.2.0+ the valid format changed, and we cannot support all of the format.
# This provides values that are valid in all of those formats.
@@ -99,6 +99,7 @@ def test_cast_string_date_invalid_ansi_before_320():

# test Spark versions >= 320 and databricks, ANSI mode, valid values
@pytest.mark.skipif(is_before_spark_320(), reason="Spark versions(< 320) not support Ansi mode when casting string to date")
@disable_timezone_test
def test_cast_string_date_valid_ansi():
data_rows = [(v,) for v in valid_values_string_to_date]
assert_gpu_and_cpu_are_equal_collect(
@@ -109,6 +110,7 @@ def test_cast_string_date_valid_ansi():
# test Spark versions >= 320, ANSI mode
@pytest.mark.skipif(is_before_spark_320(), reason="ansi cast(string as date) throws exception only in 3.2.0+")
@pytest.mark.parametrize('invalid', invalid_values_string_to_date)
@disable_timezone_test
def test_cast_string_date_invalid_ansi(invalid):
assert_gpu_and_cpu_error(
lambda spark: spark.createDataFrame([(invalid,)], "a string").select(f.col('a').cast(DateType())).collect(),
@@ -139,6 +141,7 @@ def test_try_cast_fallback_340(invalid):
conf={'spark.rapids.sql.hasExtendedYearValues': False,
'spark.sql.ansi.enabled': True})

@disable_timezone_test
# test all Spark versions, non ANSI mode, invalid value will be converted to NULL
def test_cast_string_date_non_ansi():
data_rows = [(v,) for v in values_string_to_data]
@@ -150,6 +153,7 @@ def test_cast_string_date_non_ansi():
StringGen('[0-9]{1,4}-[0-3][0-9]-[0-5][0-9][ |T][0-3][0-9]:[0-6][0-9]:[0-6][0-9]'),
StringGen('[0-9]{1,4}-[0-3][0-9]-[0-5][0-9][ |T][0-3][0-9]:[0-6][0-9]:[0-6][0-9].[0-9]{0,6}Z?')],
ids=idfn)
@disable_timezone_test
def test_cast_string_ts_valid_format(data_gen):
# In Spark 3.2.0+ the valid format changed, and we cannot support all of the format.
# This provides values that are valid in all of those formats.
@@ -297,6 +301,7 @@ def _assert_cast_to_string_equal (data_gen, conf):

@pytest.mark.parametrize('data_gen', all_array_gens_for_cast_to_string, ids=idfn)
@pytest.mark.parametrize('legacy', ['true', 'false'])
@disable_timezone_test
def test_cast_array_to_string(data_gen, legacy):
_assert_cast_to_string_equal(
data_gen,
@@ -316,6 +321,7 @@ def test_cast_array_with_unmatched_element_to_string(data_gen, legacy):

@pytest.mark.parametrize('data_gen', basic_map_gens_for_cast_to_string, ids=idfn)
@pytest.mark.parametrize('legacy', ['true', 'false'])
@disable_timezone_test
def test_cast_map_to_string(data_gen, legacy):
_assert_cast_to_string_equal(
data_gen,
@@ -335,6 +341,7 @@ def test_cast_map_with_unmatched_element_to_string(data_gen, legacy):

@pytest.mark.parametrize('data_gen', [StructGen([[str(i), gen] for i, gen in enumerate(basic_array_struct_gens_for_cast_to_string)] + [["map", MapGen(ByteGen(nullable=False), null_gen)]])], ids=idfn)
@pytest.mark.parametrize('legacy', ['true', 'false'])
@disable_timezone_test
def test_cast_struct_to_string(data_gen, legacy):
_assert_cast_to_string_equal(
data_gen,
@@ -505,7 +512,7 @@ def test_cast_timestamp_to_numeric_non_ansi():
lambda spark: unary_op_df(spark, timestamp_gen)
.selectExpr("cast(a as byte)", "cast(a as short)", "cast(a as int)", "cast(a as long)",
"cast(a as float)", "cast(a as double)"))

@disable_timezone_test
def test_cast_timestamp_to_string():
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, timestamp_gen)
2 changes: 2 additions & 0 deletions integration_tests/src/main/python/cmp_test.py
@@ -16,6 +16,7 @@

from asserts import assert_gpu_and_cpu_are_equal_collect
from data_gen import *
from marks import disable_timezone_test
from spark_session import with_cpu_session, is_before_spark_330
from pyspark.sql.types import *
import pyspark.sql.functions as f
@@ -291,6 +292,7 @@ def test_filter_with_project(data_gen):
# and some constants that then make it so all we need is the number of rows
# of input.
@pytest.mark.parametrize('op', ['>', '<'])
@disable_timezone_test
def test_empty_filter(op, spark_tmp_path):

def do_it(spark):
6 changes: 6 additions & 0 deletions integration_tests/src/main/python/collection_ops_test.py
@@ -16,6 +16,7 @@

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error
from data_gen import *
from marks import disable_timezone_test
from pyspark.sql.types import *
from string_test import mk_str_gen
import pyspark.sql.functions as f
@@ -248,6 +249,7 @@ def test_sort_array_normalize_nans():
sequence_normal_no_step_integral_gens = [(gens[0], gens[1]) for
gens in sequence_normal_integral_gens]

@disable_timezone_test
@pytest.mark.parametrize('start_gen,stop_gen', sequence_normal_no_step_integral_gens, ids=idfn)
def test_sequence_without_step(start_gen, stop_gen):
assert_gpu_and_cpu_are_equal_collect(
@@ -257,6 +259,7 @@ def test_sequence_without_step(start_gen, stop_gen):
"sequence(20, b)"))

@pytest.mark.parametrize('start_gen,stop_gen,step_gen', sequence_normal_integral_gens, ids=idfn)
@disable_timezone_test
def test_sequence_with_step(start_gen, stop_gen, step_gen):
# Get a step scalar from the 'step_gen' which follows the rules.
step_gen.start(random.Random(0))
@@ -299,6 +302,7 @@ def test_sequence_with_step(start_gen, stop_gen, step_gen):
IntegerGen(min_val=0, max_val=0, special_cases=[]))
]

@disable_timezone_test
@pytest.mark.parametrize('start_gen,stop_gen,step_gen', sequence_illegal_boundaries_integral_gens, ids=idfn)
def test_sequence_illegal_boundaries(start_gen, stop_gen, step_gen):
assert_gpu_and_cpu_error(
@@ -314,6 +318,7 @@ def test_sequence_illegal_boundaries(start_gen, stop_gen, step_gen):
]

@pytest.mark.parametrize('stop_gen', sequence_too_long_length_gens, ids=idfn)
@disable_timezone_test
def test_sequence_too_long_sequence(stop_gen):
assert_gpu_and_cpu_error(
# To avoid OOM, reduce the row number to 1, it is enough to verify this case.
@@ -355,6 +360,7 @@ def get_sequence_data(gen, len):
mixed_schema)

# test for 3 cases mixed in a single dataset
@disable_timezone_test
def test_sequence_with_step_mixed_cases():
assert_gpu_and_cpu_are_equal_collect(
lambda spark: get_sequence_cases_mixed_df(spark)
3 changes: 3 additions & 0 deletions integration_tests/src/main/python/conditionals_test.py
@@ -16,6 +16,7 @@

from asserts import assert_gpu_and_cpu_are_equal_collect
from data_gen import *
from marks import disable_timezone_test
from spark_session import is_before_spark_320, is_jvm_charset_utf8
from pyspark.sql.types import *
import pyspark.sql.functions as f
@@ -230,6 +231,7 @@ def test_conditional_with_side_effects_case_when(data_gen):
conf = test_conf)

@pytest.mark.parametrize('data_gen', [mk_str_gen('[a-z]{0,3}')], ids=idfn)
@disable_timezone_test
Collaborator comment:
Can we have a follow-on issue to enable/test sequence for all time zones? We don't support timestamps for sequence currently, and there are a lot of tests that are failing/skipped for no good reason.

def test_conditional_with_side_effects_sequence(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr(
@@ -240,6 +242,7 @@ def test_conditional_with_side_effects_sequence(data_gen):

@pytest.mark.skipif(is_before_spark_320(), reason='Earlier versions of Spark cannot cast sequence to string')
@pytest.mark.parametrize('data_gen', [mk_str_gen('[a-z]{0,3}')], ids=idfn)
@disable_timezone_test
def test_conditional_with_side_effects_sequence_cast(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr(
11 changes: 11 additions & 0 deletions integration_tests/src/main/python/conftest.py
@@ -226,6 +226,9 @@ def pytest_runtest_setup(item):
if not item.config.getoption('pyarrow_test'):
pytest.skip('tests for pyarrow not configured to run')

if item.get_closest_marker('disable_timezone_test'):
pytest.skip('Skip because this case is not ready for non UTC time zone')

def pytest_configure(config):
global _runtime_env
_runtime_env = config.getoption('runtime_env')
@@ -415,3 +418,11 @@ def enable_fuzz_test(request):
if not enable_fuzz_test:
# fuzz tests are not required for any test runs
pytest.skip("fuzz_test not configured to run")

# Whether to add a non-UTC timezone test for all the existing test cases.
# By default, non-UTC timezone testing is enabled.
_enable_timezone_test = True

def disable_timezone_test():
global _enable_timezone_test
return _enable_timezone_test is False
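
Note: the `disable_timezone_test` marker itself is defined in marks.py, which is not among the hunks shown in this diff view. A minimal sketch only, assuming it follows the same pytest.mark pattern as the existing ignore_order and allow_non_gpu marks; the hypothetical test name below is for illustration and is not part of this PR.

# Sketch only; the actual marks.py change is not shown in this diff view.
import pytest

# Hypothetical declaration in integration_tests/src/main/python/marks.py
disable_timezone_test = pytest.mark.disable_timezone_test

# A test tagged this way is then skipped by the new check in pytest_runtest_setup above,
# where item.get_closest_marker('disable_timezone_test') finds this marker.
@disable_timezone_test
def test_example_not_ready_for_non_utc():
    pass
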
4 changes: 4 additions & 0 deletions integration_tests/src/main/python/csv_test.py
@@ -249,6 +249,7 @@ def read_impl(spark):
@pytest.mark.parametrize('read_func', [read_csv_df, read_csv_sql])
@pytest.mark.parametrize('v1_enabled_list', ["", "csv"])
@pytest.mark.parametrize('ansi_enabled', ["true", "false"])
@disable_timezone_test
def test_basic_csv_read(std_input_path, name, schema, options, read_func, v1_enabled_list, ansi_enabled, spark_tmp_table_factory):
updated_conf=copy_and_update(_enable_all_types_conf, {
'spark.sql.sources.useV1SourceList': v1_enabled_list,
@@ -289,6 +290,7 @@ def test_csv_read_small_floats(std_input_path, name, schema, options, read_func,
@approximate_float
@pytest.mark.parametrize('data_gen', csv_supported_gens, ids=idfn)
@pytest.mark.parametrize('v1_enabled_list', ["", "csv"])
@disable_timezone_test
def test_round_trip(spark_tmp_path, data_gen, v1_enabled_list):
gen = StructGen([('a', data_gen)], nullable=False)
data_path = spark_tmp_path + '/CSV_DATA'
@@ -405,6 +407,7 @@ def test_read_valid_and_invalid_dates(std_input_path, filename, v1_enabled_list,
@pytest.mark.parametrize('ts_part', csv_supported_ts_parts)
@pytest.mark.parametrize('date_format', csv_supported_date_formats)
@pytest.mark.parametrize('v1_enabled_list', ["", "csv"])
@disable_timezone_test
def test_ts_formats_round_trip(spark_tmp_path, date_format, ts_part, v1_enabled_list):
full_format = date_format + ts_part
data_gen = TimestampGen()
@@ -619,6 +622,7 @@ def do_read(spark):

@allow_non_gpu('FileSourceScanExec', 'CollectLimitExec', 'DeserializeToObjectExec')
@pytest.mark.skipif(is_before_spark_340(), reason='`preferDate` is only supported in Spark 340+')
@disable_timezone_test
def test_csv_prefer_date_with_infer_schema(spark_tmp_path):
# start date ""0001-01-02" required due to: https://github.com/NVIDIA/spark-rapids/issues/5606
data_gens = [byte_gen, short_gen, int_gen, long_gen, boolean_gen, timestamp_gen, DateGen(start=date(1, 1, 2))]
11 changes: 11 additions & 0 deletions integration_tests/src/main/python/data_gen.py
@@ -1173,3 +1173,14 @@ def get_25_partitions_df(spark):
StructField("c3", IntegerType())])
data = [[i, j, k] for i in range(0, 5) for j in range(0, 5) for k in range(0, 100)]
return spark.createDataFrame(data, schema)

# If the timezone is non-UTC and the rebase mode is LEGACY, writing to Parquet will fail because
# the GPU currently does not support it. On Databricks, the default datetime rebase mode is
# LEGACY, which differs from regular Spark, so some cases will fail if the timezone is non-UTC
# on DB. The following configs are for DB and ensure the rebase mode is not LEGACY there.
writer_confs_for_DB = {
Collaborator comment:

This is a huge change for our testing. I think what we really want for now is to have all of the file read/write operators fall back to the CPU if there is a timestamp involved at all. It does not matter if we have LEGACY or not. Until the code has been checked and updated, if needed, to work for other time zones, we need to fall back to the CPU and skip the tests like we are doing in other places.

I am not saying that the changes you have done to the plugin are wrong. I am just saying that I don't think this is the right way to fix the tests, and I don't have time right now to fully review that all of the tests you updated didn't lose some kind of coverage that we need/want when this happened.

Collaborator comment:

I still don't want this in here. I am fine if we xfail the DB write tests and point to why when the timezone is not UTC. But I don't want to change what we are testing unless we go through each test and verify that we are not losing coverage. This PR is already big enough. If you want to do this change it should be split off into another PR.

'spark.sql.parquet.datetimeRebaseModeInWrite': 'CORRECTED',
'spark.sql.parquet.datetimeRebaseModeInRead': 'CORRECTED',
'spark.sql.parquet.int96RebaseModeInWrite' : 'CORRECTED',
'spark.sql.parquet.int96RebaseModeInRead' : 'CORRECTED'
}
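
For illustration only, not part of this PR's diff: a Databricks-specific write test could merge these configs into its session conf. The test name, the `is_databricks_runtime` check, and the assert helper used below are assumptions based on the existing integration test utilities, not code from this change.

# Hypothetical usage sketch; names other than writer_confs_for_DB are assumptions.
from asserts import assert_gpu_and_cpu_writes_are_equal_collect
from data_gen import *
from spark_session import is_databricks_runtime

def test_parquet_write_timestamp_example(spark_tmp_path):
    data_path = spark_tmp_path + '/PARQUET_DATA'
    # On Databricks, force the CORRECTED rebase modes so a non-UTC session does not
    # hit the unsupported LEGACY rebase path described in the comment above.
    conf = writer_confs_for_DB if is_databricks_runtime() else {}
    assert_gpu_and_cpu_writes_are_equal_collect(
        lambda spark, path: unary_op_df(spark, timestamp_gen).write.parquet(path),
        lambda spark, path: spark.read.parquet(path),
        data_path,
        conf=conf)
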