From 69d580acd4909a3cc1d10655a140e9fb3d3ed3fb Mon Sep 17 00:00:00 2001
From: Xinrong Meng
Date: Thu, 28 Nov 2024 15:39:48 +0800
Subject: [PATCH] Replace assertEqual with assertDataFrameEqual in test_udf.py

---
 python/pyspark/sql/tests/test_udf.py | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index 78aa2546128a1..a045004cbe444 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -220,7 +220,7 @@ def test_udf_in_filter_on_top_of_outer_join(self):
         right = self.spark.createDataFrame([Row(a=1)])
         df = left.join(right, on="a", how="left_outer")
         df = df.withColumn("b", udf(lambda x: "x")(df.a))
-        self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b="x")])
+        self.assertDataFrameEqual(df.filter('b = "x"'), [Row(a=1, b="x")])
 
     def test_udf_in_filter_on_top_of_join(self):
         # regression test for SPARK-18589
@@ -228,7 +228,7 @@ def test_udf_in_filter_on_top_of_join(self):
         right = self.spark.createDataFrame([Row(b=1)])
         f = udf(lambda a, b: a == b, BooleanType())
         df = left.crossJoin(right).filter(f("a", "b"))
-        self.assertEqual(df.collect(), [Row(a=1, b=1)])
+        self.assertDataFrameEqual(df, [Row(a=1, b=1)])
 
     def test_udf_in_join_condition(self):
         # regression test for SPARK-25314
@@ -243,7 +243,7 @@ def test_udf_in_join_condition(self):
             df.collect()
         with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
             df = left.join(right, f("a", "b"))
-            self.assertEqual(df.collect(), [Row(a=1, b=1)])
+            self.assertDataFrameEqual(df, [Row(a=1, b=1)])
 
     def test_udf_in_left_outer_join_condition(self):
         # regression test for SPARK-26147
@@ -256,7 +256,7 @@ def test_udf_in_left_outer_join_condition(self):
         # The Python UDF only refer to attributes from one side, so it's evaluable.
         df = left.join(right, f("a") == col("b").cast("string"), how="left_outer")
         with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
-            self.assertEqual(df.collect(), [Row(a=1, b=1)])
+            self.assertDataFrameEqual(df, [Row(a=1, b=1)])
 
     def test_udf_and_common_filter_in_join_condition(self):
         # regression test for SPARK-25314
@@ -266,7 +266,7 @@ def test_udf_and_common_filter_in_join_condition(self):
         f = udf(lambda a, b: a == b, BooleanType())
         df = left.join(right, [f("a", "b"), left.a1 == right.b1])
         # do not need spark.sql.crossJoin.enabled=true for udf is not the only join condition.
-        self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
+        self.assertDataFrameEqual(df, [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
 
     def test_udf_not_supported_in_join_condition(self):
         # regression test for SPARK-25314
@@ -294,7 +294,7 @@ def test_udf_as_join_condition(self):
 
         f = udf(lambda a: a, IntegerType())
         df = left.join(right, [f("a") == f("b"), left.a1 == right.b1])
-        self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
+        self.assertDataFrameEqual(df, [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
 
     def test_udf_without_arguments(self):
         self.spark.catalog.registerFunction("foo", lambda: "bar")
@@ -331,7 +331,7 @@ def test_udf_with_filter_function(self):
 
         my_filter = udf(lambda a: a < 2, BooleanType())
         sel = df.select(col("key"), col("value")).filter((my_filter(col("key"))) & (df.value < "2"))
-        self.assertEqual(sel.collect(), [Row(key=1, value="1")])
+        self.assertDataFrameEqual(sel, [Row(key=1, value="1")])
 
     def test_udf_with_variant_input(self):
         df = self.spark.range(0, 10).selectExpr("parse_json(cast(id as string)) v")
@@ -461,7 +461,7 @@ def test_udf_with_aggregate_function(self):
 
         my_filter = udf(lambda a: a == 1, BooleanType())
         sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
-        self.assertEqual(sel.collect(), [Row(key=1)])
+        self.assertDataFrameEqual(sel, [Row(key=1)])
 
         my_copy = udf(lambda x: x, IntegerType())
         my_add = udf(lambda a, b: int(a + b), IntegerType())
@@ -471,7 +471,7 @@ def test_udf_with_aggregate_function(self):
             .agg(sum(my_strlen(col("value"))).alias("s"))
             .select(my_add(col("k"), col("s")).alias("t"))
         )
-        self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])
+        self.assertDataFrameEqual(sel, [Row(t=4), Row(t=3)])
 
     def test_udf_in_generate(self):
         from pyspark.sql.functions import explode
@@ -505,7 +505,7 @@ def test_udf_with_order_by_and_limit(self):
         my_copy = udf(lambda x: x, IntegerType())
         df = self.spark.range(10).orderBy("id")
         res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1)
-        self.assertEqual(res.collect(), [Row(id=0, copy=0)])
+        self.assertDataFrameEqual(res, [Row(id=0, copy=0)])
 
     def test_udf_registration_returns_udf(self):
         df = self.spark.range(10)
@@ -838,12 +838,12 @@ def test_datasource_with_udf(self):
         for df in [filesource_df, datasource_df, datasource_v2_df]:
             result = df.withColumn("c", c1)
             expected = df.withColumn("c", lit(2))
-            self.assertEqual(expected.collect(), result.collect())
+            self.assertDataFrameEqual(expected, result)
 
         for df in [filesource_df, datasource_df, datasource_v2_df]:
             result = df.withColumn("c", c2)
             expected = df.withColumn("c", col("i") + 1)
-            self.assertEqual(expected.collect(), result.collect())
+            self.assertDataFrameEqual(expected, result)
 
         for df in [filesource_df, datasource_df, datasource_v2_df]:
             for f in [f1, f2]:
@@ -902,7 +902,7 @@ def test_udf_in_subquery(self):
         result = self.spark.sql(
             "select i from values(0L) as data(i) where i in (select id from v)"
         )
-        self.assertEqual(result.collect(), [Row(i=0)])
+        self.assertDataFrameEqual(result, [Row(i=0)])
 
     def test_udf_globals_not_overwritten(self):
         @udf("string")
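
Note on the API used above: assertDataFrameEqual is the PySpark testing utility
from pyspark.testing (available since Spark 3.5). Unlike the old
self.assertEqual(df.collect(), rows) pattern, it compares a DataFrame against
another DataFrame or a list of Rows without an explicit .collect(), ignores row
order by default (checkRowOrder=False), and prints a per-row diff on failure.
The patch calls it as self.assertDataFrameEqual, which assumes a helper exposed
by the test base class; the module-level function shown in the minimal,
runnable sketch below is the documented public API with the same semantics:

    from pyspark.sql import Row, SparkSession
    from pyspark.testing import assertDataFrameEqual

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    df = spark.createDataFrame([Row(a=1, b="x"), Row(a=2, b="y")])

    # DataFrame vs. list-of-Rows, as in most hunks above; row order is
    # ignored by default, so the reversed expected list still passes.
    assertDataFrameEqual(df, [Row(a=2, b="y"), Row(a=1, b="x")])

    # DataFrame vs. DataFrame, as in test_datasource_with_udf.
    expected = spark.createDataFrame([Row(a=1, b="x"), Row(a=2, b="y")])
    assertDataFrameEqual(df, expected)

    # Pass checkRowOrder=True where ordering is part of the contract
    # (e.g. assertions following orderBy/limit).
    assertDataFrameEqual(df.orderBy("a"), expected.orderBy("a"), checkRowOrder=True)

    spark.stop()

One behavioral difference worth flagging for review: the replaced assertions
compared collected lists and were therefore order-sensitive, while the default
assertDataFrameEqual comparison is not; for single-row results this is
equivalent, but tests like test_udf_with_order_by_and_limit lose the implicit
ordering check unless checkRowOrder=True is passed.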