[SPARK-50436][PYTHON][TESTS] Use assertDataFrameEqual in pyspark.sql.tests.test_udf #49001

Status: Open. Wants to merge 1 commit into base: master.
26 changes: 13 additions & 13 deletions python/pyspark/sql/tests/test_udf.py
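
For context: assertDataFrameEqual is the DataFrame comparison utility introduced in pyspark.testing in Spark 3.5, and it is what this PR switches the assertions to. The diff calls it as a method on the test case, which presumably wraps the module-level function shown below; the DataFrame in this sketch is illustrative, not taken from test_udf.py.

    from pyspark.sql import Row, SparkSession
    from pyspark.testing import assertDataFrameEqual

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([Row(a=1, b="x")])

    # Before: materialize with collect() and compare plain lists; this is
    # order-sensitive and raises a bare AssertionError on mismatch.
    assert df.collect() == [Row(a=1, b="x")]

    # After: pass the DataFrame (or a list of Rows) directly; rows are
    # compared ignoring order by default (checkRowOrder=False), and a
    # mismatch is reported as a readable row-by-row diff.
    assertDataFrameEqual(df, [Row(a=1, b="x")])

Note one behavioral difference: because checkRowOrder defaults to False, tests that previously asserted an exact collect() order (e.g. test_udf_with_order_by_and_limit below) no longer check row order unless checkRowOrder=True is passed.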
@@ -220,15 +220,15 @@ def test_udf_in_filter_on_top_of_outer_join(self):
right = self.spark.createDataFrame([Row(a=1)])
df = left.join(right, on="a", how="left_outer")
df = df.withColumn("b", udf(lambda x: "x")(df.a))
- self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b="x")])
+ self.assertDataFrameEqual(df.filter('b = "x"'), [Row(a=1, b="x")])

def test_udf_in_filter_on_top_of_join(self):
# regression test for SPARK-18589
left = self.spark.createDataFrame([Row(a=1)])
right = self.spark.createDataFrame([Row(b=1)])
f = udf(lambda a, b: a == b, BooleanType())
df = left.crossJoin(right).filter(f("a", "b"))
- self.assertEqual(df.collect(), [Row(a=1, b=1)])
+ self.assertDataFrameEqual(df, [Row(a=1, b=1)])

def test_udf_in_join_condition(self):
# regression test for SPARK-25314
@@ -243,7 +243,7 @@ def test_udf_in_join_condition(self):
df.collect()
with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
df = left.join(right, f("a", "b"))
- self.assertEqual(df.collect(), [Row(a=1, b=1)])
+ self.assertDataFrameEqual(df, [Row(a=1, b=1)])

def test_udf_in_left_outer_join_condition(self):
# regression test for SPARK-26147
@@ -256,7 +256,7 @@ def test_udf_in_left_outer_join_condition(self):
# The Python UDF only refer to attributes from one side, so it's evaluable.
df = left.join(right, f("a") == col("b").cast("string"), how="left_outer")
with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
- self.assertEqual(df.collect(), [Row(a=1, b=1)])
+ self.assertDataFrameEqual(df, [Row(a=1, b=1)])

def test_udf_and_common_filter_in_join_condition(self):
# regression test for SPARK-25314
@@ -266,7 +266,7 @@ def test_udf_and_common_filter_in_join_condition(self):
f = udf(lambda a, b: a == b, BooleanType())
df = left.join(right, [f("a", "b"), left.a1 == right.b1])
# do not need spark.sql.crossJoin.enabled=true for udf is not the only join condition.
- self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
+ self.assertDataFrameEqual(df, [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])

def test_udf_not_supported_in_join_condition(self):
# regression test for SPARK-25314
@@ -294,7 +294,7 @@ def test_udf_as_join_condition(self):
f = udf(lambda a: a, IntegerType())

df = left.join(right, [f("a") == f("b"), left.a1 == right.b1])
- self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
+ self.assertDataFrameEqual(df, [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])

def test_udf_without_arguments(self):
self.spark.catalog.registerFunction("foo", lambda: "bar")
@@ -331,7 +331,7 @@ def test_udf_with_filter_function(self):

my_filter = udf(lambda a: a < 2, BooleanType())
sel = df.select(col("key"), col("value")).filter((my_filter(col("key"))) & (df.value < "2"))
- self.assertEqual(sel.collect(), [Row(key=1, value="1")])
+ self.assertDataFrameEqual(sel, [Row(key=1, value="1")])

def test_udf_with_variant_input(self):
df = self.spark.range(0, 10).selectExpr("parse_json(cast(id as string)) v")
@@ -461,7 +461,7 @@ def test_udf_with_aggregate_function(self):

my_filter = udf(lambda a: a == 1, BooleanType())
sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
- self.assertEqual(sel.collect(), [Row(key=1)])
+ self.assertDataFrameEqual(sel, [Row(key=1)])

my_copy = udf(lambda x: x, IntegerType())
my_add = udf(lambda a, b: int(a + b), IntegerType())
@@ -471,7 +471,7 @@
.agg(sum(my_strlen(col("value"))).alias("s"))
.select(my_add(col("k"), col("s")).alias("t"))
)
- self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])
+ self.assertDataFrameEqual(sel, [Row(t=4), Row(t=3)])

def test_udf_in_generate(self):
from pyspark.sql.functions import explode
@@ -505,7 +505,7 @@ def test_udf_with_order_by_and_limit(self):
my_copy = udf(lambda x: x, IntegerType())
df = self.spark.range(10).orderBy("id")
res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1)
- self.assertEqual(res.collect(), [Row(id=0, copy=0)])
+ self.assertDataFrameEqual(res, [Row(id=0, copy=0)])

def test_udf_registration_returns_udf(self):
df = self.spark.range(10)
@@ -838,12 +838,12 @@ def test_datasource_with_udf(self):
for df in [filesource_df, datasource_df, datasource_v2_df]:
result = df.withColumn("c", c1)
expected = df.withColumn("c", lit(2))
- self.assertEqual(expected.collect(), result.collect())
+ self.assertDataFrameEqual(expected, result)

for df in [filesource_df, datasource_df, datasource_v2_df]:
result = df.withColumn("c", c2)
expected = df.withColumn("c", col("i") + 1)
- self.assertEqual(expected.collect(), result.collect())
+ self.assertDataFrameEqual(expected, result)

for df in [filesource_df, datasource_df, datasource_v2_df]:
for f in [f1, f2]:
@@ -902,7 +902,7 @@ def test_udf_in_subquery(self):
result = self.spark.sql(
"select i from values(0L) as data(i) where i in (select id from v)"
)
- self.assertEqual(result.collect(), [Row(i=0)])
+ self.assertDataFrameEqual(result, [Row(i=0)])

def test_udf_globals_not_overwritten(self):
@udf("string")