Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1468615 Fix Result Order Issue in the UDTF Suite #112

Merged
merged 1 commit into from
Jun 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 16 additions & 25 deletions src/test/scala/com/snowflake/snowpark_test/UDTFSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,7 @@ class UDTFSuite extends TestData {
Seq(
Row("w1 w2", "g1", "w2", 1),
Row("w1 w2", "g1", "w1", 1),
- Row("w1 w1 w1", "g2", "w1", 3)),
- false)
+ Row("w1 w1 w1", "g2", "w1", 3)))
} finally {
runQuery(s"drop function if exists $funcName(STRING)", session)
}
Expand Down Expand Up @@ -145,8 +144,7 @@ class UDTFSuite extends TestData {
|""".stripMargin)
checkAnswer(
df1,
- Seq(Row("w3", 6), Row("w2", 4), Row("w1", 2), Row("w3", 6), Row("w2", 4), Row("w1", 2)),
- false)
+ Seq(Row("w3", 6), Row("w2", 4), Row("w1", 2), Row("w3", 6), Row("w2", 4), Row("w1", 2)))

// Call the UDTF with funcName and named parameters, result should be the same
val df2 = session
Expand All @@ -162,8 +160,7 @@ class UDTFSuite extends TestData {
|""".stripMargin)
checkAnswer(
df2,
- Seq(Row(6, "w3"), Row(4, "w2"), Row(2, "w1"), Row(6, "w3"), Row(4, "w2"), Row(2, "w1")),
- false)
+ Seq(Row(6, "w3"), Row(4, "w2"), Row(2, "w1"), Row(6, "w3"), Row(4, "w2"), Row(2, "w1")))

// scalastyle:off
// Use UDTF with table join
Expand Down Expand Up @@ -199,8 +196,7 @@ class UDTFSuite extends TestData {
Row(null, null, "g2", 1),
Row(null, null, "w2", 1),
Row(null, null, "g1", 1),
- Row(null, null, "w1", 4)),
- false)
+ Row(null, null, "w1", 4)))

// Use UDTF with table function + over partition
val df4 = session.sql(
Expand All @@ -217,8 +213,7 @@ class UDTFSuite extends TestData {
Row("w1 w1 w1", "g2", "g2", 1),
Row("w1 w1 w1", "g2", "w1", 3),
Row(null, "g2", "g2", 1),
- Row(null, "g2", "w1", 3)),
- false)
+ Row(null, "g2", "w1", 3)))
} finally {
runQuery(s"drop function if exists $funcName(VARCHAR,VARCHAR)", session)
}
Expand All @@ -245,15 +240,15 @@ class UDTFSuite extends TestData {
"""root
| |--C1: Long (nullable = true)
|""".stripMargin)
- checkAnswer(df1, Seq(Row(10), Row(11), Row(12), Row(13), Row(14)), false)
+ checkAnswer(df1, Seq(Row(10), Row(11), Row(12), Row(13), Row(14)))

val df2 = session.tableFunction(TableFunction(funcName), lit(20), lit(5))
assert(
getSchemaString(df2.schema) ==
"""root
| |--C1: Long (nullable = true)
|""".stripMargin)
- checkAnswer(df2, Seq(Row(20), Row(21), Row(22), Row(23), Row(24)), false)
+ checkAnswer(df2, Seq(Row(20), Row(21), Row(22), Row(23), Row(24)))

val df3 = session
.tableFunction(tableFunction, Map("arg1" -> lit(30), "arg2" -> lit(5)))
Expand All @@ -262,7 +257,7 @@ class UDTFSuite extends TestData {
"""root
| |--C1: Long (nullable = true)
|""".stripMargin)
- checkAnswer(df3, Seq(Row(30), Row(31), Row(32), Row(33), Row(34)), false)
+ checkAnswer(df3, Seq(Row(30), Row(31), Row(32), Row(33), Row(34)))
} finally {
runQuery(s"drop function if exists $funcName(NUMBER, NUMBER)", session)
}
Expand Down Expand Up @@ -377,15 +372,15 @@ class UDTFSuite extends TestData {
"""root
| |--C1: Long (nullable = true)
|""".stripMargin)
- checkAnswer(df1, Seq(Row(10), Row(11), Row(20), Row(21), Row(22), Row(23)), false)
+ checkAnswer(df1, Seq(Row(10), Row(11), Row(20), Row(21), Row(22), Row(23)))

val df2 = session.tableFunction(TableFunction(funcName), sourceDF("b"), sourceDF("c"))
assert(
getSchemaString(df2.schema) ==
"""root
| |--C1: Long (nullable = true)
|""".stripMargin)
- checkAnswer(df2, Seq(Row(100), Row(101), Row(200), Row(201), Row(202), Row(203)), false)
+ checkAnswer(df2, Seq(Row(100), Row(101), Row(200), Row(201), Row(202), Row(203)))

// Check table function with df column arguments as Map
val sourceDF2 = Seq(30).toDF("a")
Expand All @@ -396,7 +391,7 @@ class UDTFSuite extends TestData {
"""root
| |--C1: Long (nullable = true)
|""".stripMargin)
- checkAnswer(df3, Seq(Row(30), Row(31), Row(32), Row(33), Row(34)), false)
+ checkAnswer(df3, Seq(Row(30), Row(31), Row(32), Row(33), Row(34)))

// Check table function with nested functions on df column
val df4 = session.tableFunction(tableFunction, abs(ceil(sourceDF("a"))), lit(2))
Expand All @@ -405,7 +400,7 @@ class UDTFSuite extends TestData {
"""root
| |--C1: Long (nullable = true)
|""".stripMargin)
- checkAnswer(df4, Seq(Row(10), Row(11), Row(20), Row(21)), false)
+ checkAnswer(df4, Seq(Row(10), Row(11), Row(20), Row(21)))

// Check result df column filtering with duplicate column names
val sourceDF3 = Seq(30).toDF("C1")
Expand All @@ -419,7 +414,7 @@ class UDTFSuite extends TestData {
"""root
| |--C1: Long (nullable = true)
|""".stripMargin)
- checkAnswer(df, Seq(Row(30), Row(31), Row(32), Row(33), Row(34)), false)
+ checkAnswer(df, Seq(Row(30), Row(31), Row(32), Row(33), Row(34)))
}
} finally {
runQuery(s"drop function if exists $funcName(NUMBER, NUMBER)", session)
Expand Down Expand Up @@ -1866,25 +1861,21 @@ class UDTFSuite extends TestData {
val tableFunction3 = session.udtf.registerTemporary(new ReturnManyColumns(3))
val df3 = session.tableFunction(tableFunction3, lit(10))
assert(df3.schema.length == 3)
- checkAnswer(df3, Seq(Row(11, 12, 13), Row(1, 2, 3)), false)
+ checkAnswer(df3, Seq(Row(11, 12, 13), Row(1, 2, 3)))

// Test UDTF return 100 columns
val tableFunction100 = session.udtf.registerTemporary(new ReturnManyColumns(100))
val df100 = session.tableFunction(tableFunction100, lit(20))
assert(df100.schema.length == 100)
- checkAnswer(
- df100,
- Seq(Row.fromArray((21 to 120).toArray), Row.fromArray((1 to 100).toArray)),
- false)
+ checkAnswer(df100, Seq(Row.fromArray((21 to 120).toArray), Row.fromArray((1 to 100).toArray)))

// Test UDTF return 200 columns
val tableFunction200 = session.udtf.registerTemporary(new ReturnManyColumns(200))
val df200 = session.tableFunction(tableFunction200, lit(100))
assert(df200.schema.length == 200)
checkAnswer(
df200,
- Seq(Row.fromArray((101 to 300).toArray), Row.fromArray((1 to 200).toArray)),
- false)
+ Seq(Row.fromArray((101 to 300).toArray), Row.fromArray((1 to 200).toArray)))
}

test("test output type: basic types") {
Expand Down
Loading