Commit

Fix cases
Chong Gao committed Nov 29, 2023
1 parent a743a7a commit bb1cbca
Showing 5 changed files with 16 additions and 9 deletions.
3 changes: 3 additions & 0 deletions integration_tests/src/main/python/aqe_test.py
@@ -21,6 +21,9 @@
from marks import ignore_order, allow_non_gpu
from spark_session import with_cpu_session, is_databricks113_or_later

+# allow non gpu when time zone is non-UTC because of https://github.com/NVIDIA/spark-rapids/issues/9653
+non_utc_allow=['HashAggregateExec', 'ProjectExec', 'FilterExec', 'FileSourceScanExec', 'BatchScanExec', 'CollectLimitExec', 'DeserializeToObjectExec', 'DataWritingCommandExec', 'WriteFilesExec', 'ShuffleExchangeExec'] if is_not_utc() else []

_adaptive_conf = { "spark.sql.adaptive.enabled": "true" }

def create_skew_df(spark, length):
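The aqe_test.py hunk introduces the pattern the rest of this commit reuses: build the fallback allow-list only when the session time zone is not UTC, then splat it into @allow_non_gpu so UTC runs still require fully-GPU execution. A minimal sketch of that pattern follows; the import locations for is_not_utc and the assertion helper, and the toy query, are assumptions for illustration, not code from the commit.

    # Sketch only: conditional CPU-fallback allowance for non-UTC time zones.
    from asserts import assert_gpu_and_cpu_are_equal_collect  # assumed import path
    from conftest import is_not_utc                           # assumed import path
    from marks import allow_non_gpu

    # Empty in UTC runs, so those runs still demand full GPU execution;
    # in non-UTC runs the listed execs may legitimately fall back to the CPU.
    non_utc_allow_example = ['ProjectExec', 'FilterExec'] if is_not_utc() else []

    @allow_non_gpu(*non_utc_allow_example)
    def test_date_filter_example():
        assert_gpu_and_cpu_are_equal_collect(
            lambda spark: spark.range(100)
                .selectExpr("date_add(date '2000-01-01', cast(id as int)) as a")
                .filter("a > date '2000-02-01'"))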
7 changes: 5 additions & 2 deletions integration_tests/src/main/python/csv_test.py
@@ -334,6 +334,7 @@ def test_csv_fallback(spark_tmp_path, read_func, disable_conf, spark_tmp_table_f
'CORRECTED',
'EXCEPTION'
])
+@allow_non_gpu(*non_utc_allow)
def test_date_formats_round_trip(spark_tmp_path, date_format, v1_enabled_list, ansi_enabled, time_parser_policy):
gen = StructGen([('a', DateGen())], nullable=False)
data_path = spark_tmp_path + '/CSV_DATA'
@@ -365,13 +366,15 @@ def test_date_formats_round_trip(spark_tmp_path, date_format, v1_enabled_list, a
.csv(data_path),
conf=updated_conf)


+non_utc_allow_for_test_read_valid_and_invalid_dates=['FileSourceScanExec', 'BatchScanExec'] if is_not_utc() else []
@pytest.mark.parametrize('filename', ["date.csv"])
@pytest.mark.parametrize('v1_enabled_list', ["", "csv"])
@pytest.mark.parametrize('ansi_enabled', ["true", "false"])
@pytest.mark.parametrize('time_parser_policy', [
pytest.param('LEGACY', marks=pytest.mark.allow_non_gpu('BatchScanExec,FileSourceScanExec')),
-'CORRECTED',
-'EXCEPTION'
+pytest.param('CORRECTED', marks=pytest.mark.allow_non_gpu(*non_utc_allow_for_test_read_valid_and_invalid_dates)),
+pytest.param('EXCEPTION', marks=pytest.mark.allow_non_gpu(*non_utc_allow_for_test_read_valid_and_invalid_dates))
])
def test_read_valid_and_invalid_dates(std_input_path, filename, v1_enabled_list, ansi_enabled, time_parser_policy):
data_path = std_input_path + '/' + filename
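For test_read_valid_and_invalid_dates the allowance is attached per parameter value rather than to the whole test, so only the policies that actually reach the scan tolerate CPU fallback. A hedged sketch of that pytest.param usage (the policy names mirror the diff; the import path and test body are placeholders):

    import pytest
    from conftest import is_not_utc  # assumed import path

    scan_fallback = ['FileSourceScanExec', 'BatchScanExec'] if is_not_utc() else []

    @pytest.mark.parametrize('time_parser_policy', [
        # LEGACY keeps its unconditional allowance, unchanged by this commit.
        pytest.param('LEGACY', marks=pytest.mark.allow_non_gpu('BatchScanExec,FileSourceScanExec')),
        # CORRECTED and EXCEPTION only need the allowance when the zone is not UTC.
        pytest.param('CORRECTED', marks=pytest.mark.allow_non_gpu(*scan_fallback)),
        pytest.param('EXCEPTION', marks=pytest.mark.allow_non_gpu(*scan_fallback)),
    ])
    def test_time_parser_policy_example(time_parser_policy):
        ...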
6 changes: 2 additions & 4 deletions integration_tests/src/main/python/hash_aggregate_test.py
@@ -810,12 +810,11 @@ def spark_fn(spark_session):
@allow_non_gpu('ObjectHashAggregateExec', 'SortAggregateExec',
'ShuffleExchangeExec', 'HashPartitioning', 'SortExec',
'SortArray', 'Alias', 'Literal', 'Count', 'CollectList', 'CollectSet',
-'AggregateExpression', 'ProjectExec')
+'AggregateExpression', 'ProjectExec', *non_utc_allow)
@pytest.mark.parametrize('data_gen', _full_gen_data_for_collect_op, ids=idfn)
@pytest.mark.parametrize('replace_mode', _replace_modes_non_distinct, ids=idfn)
@pytest.mark.parametrize('aqe_enabled', ['false', 'true'], ids=idfn)
@pytest.mark.parametrize('use_obj_hash_agg', ['false', 'true'], ids=idfn)
-@allow_non_gpu(*non_utc_allow)
def test_hash_groupby_collect_partial_replace_fallback(data_gen,
replace_mode,
aqe_enabled,
@@ -857,13 +856,12 @@ def test_hash_groupby_collect_partial_replace_fallback(data_gen,
@allow_non_gpu('ObjectHashAggregateExec', 'SortAggregateExec',
'ShuffleExchangeExec', 'HashPartitioning', 'SortExec',
'SortArray', 'Alias', 'Literal', 'Count', 'CollectList', 'CollectSet',
-'AggregateExpression', 'ProjectExec')
+'AggregateExpression', 'ProjectExec', *non_utc_allow)
@pytest.mark.parametrize('data_gen', _full_gen_data_for_collect_op, ids=idfn)
@pytest.mark.parametrize('replace_mode', _replace_modes_single_distinct, ids=idfn)
@pytest.mark.parametrize('aqe_enabled', ['false', 'true'], ids=idfn)
@pytest.mark.parametrize('use_obj_hash_agg', ['false', 'true'], ids=idfn)
@pytest.mark.xfail(condition=is_databricks104_or_later(), reason='https://github.com/NVIDIA/spark-rapids/issues/4963')
-@allow_non_gpu(*non_utc_allow)
def test_hash_groupby_collect_partial_replace_with_distinct_fallback(data_gen,
replace_mode,
aqe_enabled,
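In hash_aggregate_test.py the commit folds *non_utc_allow into the @allow_non_gpu call that already lists the aggregate-related execs, instead of stacking it as a second decorator. A short sketch of the consolidated form (exec names abridged, import path assumed):

    from conftest import is_not_utc  # assumed import path
    from marks import allow_non_gpu

    non_utc_allow = ['HashAggregateExec', 'ProjectExec'] if is_not_utc() else []

    # One decorator carries both the always-allowed execs and the conditional non-UTC ones.
    @allow_non_gpu('ObjectHashAggregateExec', 'SortAggregateExec',
                   'ShuffleExchangeExec', 'AggregateExpression', *non_utc_allow)
    def test_collect_fallback_example():
        ...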
7 changes: 4 additions & 3 deletions integration_tests/src/main/python/hive_delimited_text_test.py
@@ -121,6 +121,7 @@ def read_impl(spark):
return read_impl


+non_utc_allow_for_test_basic_hive_text_read=['HiveTableScanExec'] if is_not_utc() else []
@pytest.mark.skipif(is_spark_cdh(),
reason="Hive text reads are disabled on CDH, as per "
"https://github.com/NVIDIA/spark-rapids/pull/7628")
@@ -187,7 +188,7 @@ def read_impl(spark):
('hive-delim-text/carriage-return', StructType([StructField("str", StringType())]), {}),
('hive-delim-text/carriage-return-err', StructType([StructField("str", StringType())]), {}),
], ids=idfn)
-@allow_non_gpu(*non_utc_allow)
+@allow_non_gpu(*non_utc_allow_for_test_basic_hive_text_read)
def test_basic_hive_text_read(std_input_path, name, schema, spark_tmp_table_factory, options):
assert_gpu_and_cpu_are_equal_collect(read_hive_text_sql(std_input_path + '/' + name,
schema, spark_tmp_table_factory, options),
@@ -240,7 +241,7 @@ def read_hive_text_table(spark, text_table_name, fields="my_field"):
"https://github.com/NVIDIA/spark-rapids/pull/7628")
@approximate_float
@pytest.mark.parametrize('data_gen', hive_text_supported_gens, ids=idfn)
-@allow_non_gpu(*non_utc_allow)
+@allow_non_gpu(*non_utc_allow_for_test_basic_hive_text_read)
def test_hive_text_round_trip(spark_tmp_path, data_gen, spark_tmp_table_factory):
gen = StructGen([('my_field', data_gen)], nullable=False)
data_path = spark_tmp_path + '/hive_text_table'
@@ -527,7 +528,7 @@ def create_table_with_compressed_files(spark):
('hive-delim-text/carriage-return', StructType([StructField("str", StringType())]), {}),
('hive-delim-text/carriage-return-err', StructType([StructField("str", StringType())]), {}),
], ids=idfn)
-@allow_non_gpu(*non_utc_allow)
+@allow_non_gpu(*non_utc_allow_for_test_basic_hive_text_read)
def test_basic_hive_text_write(std_input_path, input_dir, schema, spark_tmp_table_factory, mode, options):
# Configure table options, including schema.
if options is None:
2 changes: 2 additions & 0 deletions integration_tests/src/main/python/json_test.py
@@ -152,10 +152,12 @@ def test_json_input_meta(spark_tmp_path, v1_enabled_list):
'input_file_block_length()'),
conf=updated_conf)

+allow_non_gpu_for_json_scan = ['FileSourceScanExec', 'BatchScanExec'] if is_not_utc() else []
json_supported_date_formats = ['yyyy-MM-dd', 'yyyy/MM/dd', 'yyyy-MM', 'yyyy/MM',
'MM-yyyy', 'MM/yyyy', 'MM-dd-yyyy', 'MM/dd/yyyy', 'dd-MM-yyyy', 'dd/MM/yyyy']
@pytest.mark.parametrize('date_format', json_supported_date_formats, ids=idfn)
@pytest.mark.parametrize('v1_enabled_list', ["", "json"])
+@allow_non_gpu(*allow_non_gpu_for_json_scan)
def test_json_date_formats_round_trip(spark_tmp_path, date_format, v1_enabled_list):
gen = StructGen([('a', DateGen())], nullable=False)
data_path = spark_tmp_path + '/JSON_DATA'
