Fixed array_tests for Spark 4.0.0 [databricks] (NVIDIA#11048)
* Fixed array_tests

* Signing off

Signed-off-by: Raza Jafri <[email protected]>

* Disable ANSI for failing tests

---------

Signed-off-by: Raza Jafri <[email protected]>
razajafri authored Jun 28, 2024
1 parent 7dc52bc commit dd62000
Showing 1 changed file with 15 additions and 5 deletions.
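
Why the marker matters: Spark 4.0.0 enables ANSI mode (spark.sql.ansi.enabled) by default, which changes both results and raised exception types for several of these array tests, so the failing tests are opted out of ANSI mode. As a minimal sketch, assuming a pytest 'spark' session fixture exists, a disable_ansi_mode marker could be honored like this; the fixture below is hypothetical illustration, not the repository's actual plumbing:

import pytest

disable_ansi_mode = pytest.mark.disable_ansi_mode

@pytest.fixture(autouse=True)
def _ansi_off_for_marked_tests(request):
    # Hypothetical fixture: only acts on tests carrying @disable_ansi_mode.
    if request.node.get_closest_marker('disable_ansi_mode') is None:
        yield
        return
    spark = request.getfixturevalue('spark')  # assumed SparkSession fixture
    previous = spark.conf.get('spark.sql.ansi.enabled', 'false')
    spark.conf.set('spark.sql.ansi.enabled', 'false')
    try:
        yield  # run the test with ANSI mode forced off
    finally:
        spark.conf.set('spark.sql.ansi.enabled', previous)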
integration_tests/src/main/python/array_test.py
@@ -17,8 +17,8 @@
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_and_cpu_error, assert_gpu_fallback_collect
from data_gen import *
from conftest import is_databricks_runtime
-from marks import incompat, allow_non_gpu
-from spark_session import is_before_spark_313, is_before_spark_330, is_databricks113_or_later, is_spark_330_or_later, is_databricks104_or_later, is_spark_33X, is_spark_340_or_later, is_spark_330, is_spark_330cdh
+from marks import incompat, allow_non_gpu, disable_ansi_mode
+from spark_session import *
from pyspark.sql.types import *
from pyspark.sql.types import IntegralType
from pyspark.sql.functions import array_contains, col, element_at, lit, array
@@ -103,11 +103,13 @@

@pytest.mark.parametrize('data_gen', array_item_test_gens, ids=idfn)
@pytest.mark.parametrize('index_gen', array_index_gens, ids=idfn)
+@disable_ansi_mode
def test_array_item(data_gen, index_gen):
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: two_col_df(spark, data_gen, index_gen).selectExpr('a[b]'))

@pytest.mark.parametrize('data_gen', array_item_test_gens, ids=idfn)
+@disable_ansi_mode
def test_array_item_lit_ordinal(data_gen):
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, data_gen).selectExpr(
@@ -145,8 +147,10 @@ def test_array_item_with_strict_index(strict_index_enabled, index):

# No need to test this for multiple data types for array. Only one is enough, but with two kinds of invalid index.
@pytest.mark.parametrize('index', [-2, 100, array_neg_index_gen, array_out_index_gen], ids=idfn)
+@disable_ansi_mode
def test_array_item_ansi_fail_invalid_index(index):
message = "SparkArrayIndexOutOfBoundsException" if (is_databricks104_or_later() or is_spark_330_or_later()) else "java.lang.ArrayIndexOutOfBoundsException"
message = "SparkArrayIndexOutOfBoundsException" if (is_databricks104_or_later() or is_spark_330_or_later() and is_before_spark_400()) else \
"ArrayIndexOutOfBoundsException"
    if isinstance(index, int):
        test_func = lambda spark: unary_op_df(spark, ArrayGen(int_gen)).select(col('a')[index]).collect()
    else:
@@ -171,6 +175,7 @@ def test_array_item_ansi_not_fail_all_null_data():
decimal_gen_32bit, decimal_gen_64bit, decimal_gen_128bit, binary_gen,
StructGen([['child0', StructGen([['child01', IntegerGen()]])], ['child1', string_gen], ['child2', float_gen]], nullable=False),
StructGen([['child0', byte_gen], ['child1', string_gen], ['child2', float_gen]], nullable=False)], ids=idfn)
+@disable_ansi_mode
def test_make_array(data_gen):
    (s1, s2) = with_cpu_session(
        lambda spark: gen_scalars_for_sql(data_gen, 2, force_no_nulls=not isinstance(data_gen, NullGen)))
@@ -212,6 +217,7 @@ def test_orderby_array_of_structs(data_gen):
@pytest.mark.parametrize('data_gen', [byte_gen, short_gen, int_gen, long_gen,
float_gen, double_gen,
string_gen, boolean_gen, date_gen, timestamp_gen], ids=idfn)
+@disable_ansi_mode
def test_array_contains(data_gen):
    arr_gen = ArrayGen(data_gen)
    literal = with_cpu_session(lambda spark: gen_scalar(data_gen, force_no_nulls=True))
@@ -239,6 +245,7 @@ def test_array_contains_for_nans(data_gen):


@pytest.mark.parametrize('data_gen', array_item_test_gens, ids=idfn)
+@disable_ansi_mode
def test_array_element_at(data_gen):
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: two_col_df(spark, data_gen, array_no_zero_index_gen).selectExpr(
@@ -252,8 +259,9 @@

# No need tests for multiple data types for list data. Only one is enough.
@pytest.mark.parametrize('index', [100, array_out_index_gen], ids=idfn)
+@disable_ansi_mode
def test_array_element_at_ansi_fail_invalid_index(index):
message = "ArrayIndexOutOfBoundsException" if is_before_spark_330() else "SparkArrayIndexOutOfBoundsException"
message = "ArrayIndexOutOfBoundsException" if is_before_spark_330() or not is_before_spark_400() else "SparkArrayIndexOutOfBoundsException"
    if isinstance(index, int):
        test_func = lambda spark: unary_op_df(spark, ArrayGen(int_gen)).select(
            element_at(col('a'), index)).collect()
@@ -282,9 +290,10 @@ def test_array_element_at_ansi_not_fail_all_null_data():

@pytest.mark.parametrize('index', [0, array_zero_index_gen], ids=idfn)
@pytest.mark.parametrize('ansi_enabled', [False, True], ids=idfn)
+@disable_ansi_mode
def test_array_element_at_zero_index_fail(index, ansi_enabled):
    if is_spark_340_or_later():
message = "org.apache.spark.SparkRuntimeException: [INVALID_INDEX_OF_ZERO] The index 0 is invalid"
message = "SparkRuntimeException: [INVALID_INDEX_OF_ZERO] The index 0 is invalid"
    elif is_databricks113_or_later():
        message = "org.apache.spark.SparkRuntimeException: [ELEMENT_AT_BY_INDEX_ZERO] The index 0 is invalid"
    else:
@@ -303,6 +312,7 @@ def test_array_element_at_zero_index_fail(index, ansi_enabled):


@pytest.mark.parametrize('data_gen', array_gens_sample, ids=idfn)
+@disable_ansi_mode
def test_array_transform(data_gen):
    def do_it(spark):
        columns = ['a', 'b',
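
Usage note on the two expected-message changes above: both follow the same version gate, where the exception text the test matches depends on the Spark release under test. A hedged illustration of that gating, using a hypothetical helper and a (major, minor, patch) tuple in place of the real spark_session checks, and ignoring the Databricks-specific branch in test_array_item_ansi_fail_invalid_index:

# Hypothetical helper mirroring the gate in test_array_element_at_ansi_fail_invalid_index:
# Spark 3.3.x through 3.5.x wrap the error as SparkArrayIndexOutOfBoundsException,
# while earlier releases and Spark 4.0.0+ match on the plain class name.
def expected_index_oob_fragment(version):
    # version is a (major, minor, patch) tuple, e.g. (4, 0, 0)
    if (3, 3, 0) <= version < (4, 0, 0):
        return "SparkArrayIndexOutOfBoundsException"
    return "ArrayIndexOutOfBoundsException"

assert expected_index_oob_fragment((3, 5, 1)) == "SparkArrayIndexOutOfBoundsException"
assert expected_index_oob_fragment((4, 0, 0)) == "ArrayIndexOutOfBoundsException"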
