Commit: replace
xinrong-meng committed Nov 28, 2024
1 parent bb994d1 commit 69d580a
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions python/pyspark/sql/tests/test_udf.py
@@ -220,15 +220,15 @@ def test_udf_in_filter_on_top_of_outer_join(self):
right = self.spark.createDataFrame([Row(a=1)])
df = left.join(right, on="a", how="left_outer")
df = df.withColumn("b", udf(lambda x: "x")(df.a))
- self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b="x")])
+ self.assertDataFrameEqual(df.filter('b = "x"'), [Row(a=1, b="x")])

def test_udf_in_filter_on_top_of_join(self):
# regression test for SPARK-18589
left = self.spark.createDataFrame([Row(a=1)])
right = self.spark.createDataFrame([Row(b=1)])
f = udf(lambda a, b: a == b, BooleanType())
df = left.crossJoin(right).filter(f("a", "b"))
- self.assertEqual(df.collect(), [Row(a=1, b=1)])
+ self.assertDataFrameEqual(df, [Row(a=1, b=1)])

def test_udf_in_join_condition(self):
# regression test for SPARK-25314
@@ -243,7 +243,7 @@ def test_udf_in_join_condition(self):
df.collect()
with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
df = left.join(right, f("a", "b"))
- self.assertEqual(df.collect(), [Row(a=1, b=1)])
+ self.assertDataFrameEqual(df, [Row(a=1, b=1)])

def test_udf_in_left_outer_join_condition(self):
# regression test for SPARK-26147
@@ -256,7 +256,7 @@ def test_udf_in_left_outer_join_condition(self):
# The Python UDF only refers to attributes from one side, so it's evaluable.
df = left.join(right, f("a") == col("b").cast("string"), how="left_outer")
with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
- self.assertEqual(df.collect(), [Row(a=1, b=1)])
+ self.assertDataFrameEqual(df, [Row(a=1, b=1)])

def test_udf_and_common_filter_in_join_condition(self):
# regression test for SPARK-25314
@@ -266,7 +266,7 @@ def test_udf_and_common_filter_in_join_condition(self):
f = udf(lambda a, b: a == b, BooleanType())
df = left.join(right, [f("a", "b"), left.a1 == right.b1])
# spark.sql.crossJoin.enabled=true is not needed because the UDF is not the only join condition.
- self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
+ self.assertDataFrameEqual(df, [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])

def test_udf_not_supported_in_join_condition(self):
# regression test for SPARK-25314
@@ -294,7 +294,7 @@ def test_udf_as_join_condition(self):
f = udf(lambda a: a, IntegerType())

df = left.join(right, [f("a") == f("b"), left.a1 == right.b1])
- self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
+ self.assertDataFrameEqual(df, [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])

def test_udf_without_arguments(self):
self.spark.catalog.registerFunction("foo", lambda: "bar")
@@ -331,7 +331,7 @@ def test_udf_with_filter_function(self):

my_filter = udf(lambda a: a < 2, BooleanType())
sel = df.select(col("key"), col("value")).filter((my_filter(col("key"))) & (df.value < "2"))
- self.assertEqual(sel.collect(), [Row(key=1, value="1")])
+ self.assertDataFrameEqual(sel, [Row(key=1, value="1")])

def test_udf_with_variant_input(self):
df = self.spark.range(0, 10).selectExpr("parse_json(cast(id as string)) v")
@@ -461,7 +461,7 @@ def test_udf_with_aggregate_function(self):

my_filter = udf(lambda a: a == 1, BooleanType())
sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
- self.assertEqual(sel.collect(), [Row(key=1)])
+ self.assertDataFrameEqual(sel, [Row(key=1)])

my_copy = udf(lambda x: x, IntegerType())
my_add = udf(lambda a, b: int(a + b), IntegerType())
@@ -471,7 +471,7 @@
.agg(sum(my_strlen(col("value"))).alias("s"))
.select(my_add(col("k"), col("s")).alias("t"))
)
- self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])
+ self.assertDataFrameEqual(sel, [Row(t=4), Row(t=3)])

def test_udf_in_generate(self):
from pyspark.sql.functions import explode
@@ -505,7 +505,7 @@ def test_udf_with_order_by_and_limit(self):
my_copy = udf(lambda x: x, IntegerType())
df = self.spark.range(10).orderBy("id")
res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1)
- self.assertEqual(res.collect(), [Row(id=0, copy=0)])
+ self.assertDataFrameEqual(res, [Row(id=0, copy=0)])

def test_udf_registration_returns_udf(self):
df = self.spark.range(10)
@@ -838,12 +838,12 @@ def test_datasource_with_udf(self):
for df in [filesource_df, datasource_df, datasource_v2_df]:
result = df.withColumn("c", c1)
expected = df.withColumn("c", lit(2))
- self.assertEqual(expected.collect(), result.collect())
+ self.assertDataFrameEqual(expected, result)

for df in [filesource_df, datasource_df, datasource_v2_df]:
result = df.withColumn("c", c2)
expected = df.withColumn("c", col("i") + 1)
- self.assertEqual(expected.collect(), result.collect())
+ self.assertDataFrameEqual(expected, result)

for df in [filesource_df, datasource_df, datasource_v2_df]:
for f in [f1, f2]:
@@ -902,7 +902,7 @@ def test_udf_in_subquery(self):
result = self.spark.sql(
"select i from values(0L) as data(i) where i in (select id from v)"
)
- self.assertEqual(result.collect(), [Row(i=0)])
+ self.assertDataFrameEqual(result, [Row(i=0)])

def test_udf_globals_not_overwritten(self):
@udf("string")
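The commit swaps exact, collect()-based list comparisons for assertDataFrameEqual. Below is a minimal, self-contained sketch of the standalone pyspark.testing.assertDataFrameEqual helper (available since Spark 3.5), which the bound self.assertDataFrameEqual calls in the diff are assumed to wrap; the SparkSession setup and sample data here are illustrative and not taken from the test suite.

# Sketch only: assumes the test suite's self.assertDataFrameEqual delegates to
# pyspark.testing.assertDataFrameEqual.
from pyspark.sql import Row, SparkSession
from pyspark.testing import assertDataFrameEqual

spark = SparkSession.builder.master("local[1]").appName("assert-demo").getOrCreate()

left = spark.createDataFrame([Row(a=1)])
right = spark.createDataFrame([Row(b=1)])
df = left.crossJoin(right)

# Compare a DataFrame against a list of expected Rows; row order is ignored
# by default (checkRowOrder=False).
assertDataFrameEqual(df, [Row(a=1, b=1)])

# Or compare two DataFrames directly; data and schema must both match.
assertDataFrameEqual(df, spark.createDataFrame([Row(a=1, b=1)]))

spark.stop()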
