diff --git a/integration_tests/src/main/python/conftest.py b/integration_tests/src/main/python/conftest.py index cf21c405899..a9b2f6146ec 100644 --- a/integration_tests/src/main/python/conftest.py +++ b/integration_tests/src/main/python/conftest.py @@ -154,7 +154,14 @@ def pytest_runtest_setup(item): _inject_oom = item.get_closest_marker('inject_oom') datagen_overrides = item.get_closest_marker('datagen_overrides') if datagen_overrides: - _test_datagen_random_seed = datagen_overrides.kwargs.get('seed', _test_datagen_random_seed) + try: + seed = datagen_overrides.kwargs["seed"] + except KeyError: + raise Exception("datagen_overrides requires an override seed value") + + override_seed = datagen_overrides.kwargs.get('condition', True) + if override_seed: + _test_datagen_random_seed = seed order = item.get_closest_marker('ignore_order') if order: diff --git a/integration_tests/src/main/python/map_test.py b/integration_tests/src/main/python/map_test.py index b35789b62f5..d0e064535d5 100644 --- a/integration_tests/src/main/python/map_test.py +++ b/integration_tests/src/main/python/map_test.py @@ -189,7 +189,9 @@ def query_map_scalar(spark): @allow_non_gpu('WindowLocalExec') -@datagen_overrides(seed=0, reason='https://github.com/NVIDIA/spark-rapids/issues/9683') +@datagen_overrides(seed=0, condition=is_before_spark_314() + or (not is_before_spark_320() and is_before_spark_323()) + or (not is_before_spark_330() and is_before_spark_331()), reason="https://issues.apache.org/jira/browse/SPARK-40089") @pytest.mark.parametrize('data_gen', supported_key_map_gens, ids=idfn) @allow_non_gpu(*non_utc_allow) def test_map_scalars_supported_key_types(data_gen): diff --git a/integration_tests/src/main/python/spark_session.py b/integration_tests/src/main/python/spark_session.py index 50eaa7c49a9..606f9a31dc4 100644 --- a/integration_tests/src/main/python/spark_session.py +++ b/integration_tests/src/main/python/spark_session.py @@ -128,12 +128,18 @@ def is_before_spark_312(): def is_before_spark_313(): return spark_version() < "3.1.3" +def is_before_spark_314(): + return spark_version() < "3.1.4" + def is_before_spark_320(): return spark_version() < "3.2.0" def is_before_spark_322(): return spark_version() < "3.2.2" +def is_before_spark_323(): + return spark_version() < "3.2.3" + def is_before_spark_330(): return spark_version() < "3.3.0"