diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala
index ee33762e..49ee2bfd 100644
--- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala
+++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala
@@ -7,7 +7,11 @@ import com.snowflake.snowpark.internal.{Logging, Utils}
 import com.snowflake.snowpark.internal.analyzer._
 import com.snowflake.snowpark.types._
 import com.github.vertical_blank.sqlformatter.SqlFormatter
-import com.snowflake.snowpark.internal.Utils.{TempObjectType, getTableFunctionExpression, randomNameForTempObject}
+import com.snowflake.snowpark.internal.Utils.{
+  TempObjectType,
+  getTableFunctionExpression,
+  randomNameForTempObject
+}
 import javax.xml.bind.DatatypeConverter
 
 import scala.collection.JavaConverters._
@@ -1909,7 +1913,9 @@ class DataFrame private[snowpark] (
 
   // todo: add test with UDTF
   def join(func: Column, partitionBy: Seq[Column], orderBy: Seq[Column]): DataFrame = withPlan {
-    TableFunctionJoin(this.plan, getTableFunctionExpression(func),
+    TableFunctionJoin(
+      this.plan,
+      getTableFunctionExpression(func),
       Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition))
   }
 
diff --git a/src/main/scala/com/snowflake/snowpark/Session.scala b/src/main/scala/com/snowflake/snowpark/Session.scala
index c5a75cbb..dc0489b9 100644
--- a/src/main/scala/com/snowflake/snowpark/Session.scala
+++ b/src/main/scala/com/snowflake/snowpark/Session.scala
@@ -10,9 +10,17 @@ import com.snowflake.snowpark.internal._
 import com.snowflake.snowpark.internal.analyzer.{TableFunction => TFunction}
 import com.snowflake.snowpark.types._
 import com.snowflake.snowpark.functions._
-import com.snowflake.snowpark.internal.ErrorMessage.{UDF_CANNOT_ACCEPT_MANY_DF_COLS, UDF_UNEXPECTED_COLUMN_ORDER}
+import com.snowflake.snowpark.internal.ErrorMessage.{
+  UDF_CANNOT_ACCEPT_MANY_DF_COLS,
+  UDF_UNEXPECTED_COLUMN_ORDER
+}
 import com.snowflake.snowpark.internal.ParameterUtils.ClosureCleanerMode
-import com.snowflake.snowpark.internal.Utils.{TempObjectNamePattern, TempObjectType, getTableFunctionExpression, randomNameForTempObject}
+import com.snowflake.snowpark.internal.Utils.{
+  TempObjectNamePattern,
+  TempObjectType,
+  getTableFunctionExpression,
+  randomNameForTempObject
+}
 import net.snowflake.client.jdbc.{SnowflakeConnectionV1, SnowflakeDriver, SnowflakeSQLException}
 
 import scala.concurrent.{ExecutionContext, Future}
diff --git a/src/main/scala/com/snowflake/snowpark/tableFunctions.scala b/src/main/scala/com/snowflake/snowpark/tableFunctions.scala
index 31ebbff5..5f765e64 100644
--- a/src/main/scala/com/snowflake/snowpark/tableFunctions.scala
+++ b/src/main/scala/com/snowflake/snowpark/tableFunctions.scala
@@ -111,8 +111,12 @@ object tableFunctions {
 
   def flatten(input: Column): Column = Column(flatten.apply(input))
 
-  def flatten(input: Column,
-      path: String, outer: Boolean, recursive: Boolean, mode: String): Column =
+  def flatten(
+      input: Column,
+      path: String,
+      outer: Boolean,
+      recursive: Boolean,
+      mode: String): Column =
     Column(
       flatten.apply(
         Map(
diff --git a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala
index 30124185..d5ede212 100644
--- a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala
@@ -819,8 +819,7 @@ class ErrorMessageSuite extends FunSuite {
     val ex = ErrorMessage.MISC_INVALID_TABLE_FUNCTION_INPUT()
     assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0423")))
     assert(
-      ex.message.startsWith(
-        "Error Code: 0423, Error message: Invalid input argument, " +
-          "Session.tableFunction only supports table function arguments"))
+      ex.message.startsWith("Error Code: 0423, Error message: Invalid input argument, " +
+        "Session.tableFunction only supports table function arguments"))
   }
 }
diff --git a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala
index 2de216ec..a28e4913 100644
--- a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala
@@ -204,30 +204,78 @@ class TableFunctionSuite extends TestData {
   test("Argument in table function: flatten2") {
     val df1 = Seq("{\"a\":1, \"b\":[77, 88]}").toDF("col")
     checkAnswer(
-      df1.join(tableFunctions.flatten(input = parse_json(df1("col")),
-        path = "b", outer = true, recursive = true, mode = "both")).select("value"),
+      df1
+        .join(
+          tableFunctions.flatten(
+            input = parse_json(df1("col")),
+            path = "b",
+            outer = true,
+            recursive = true,
+            mode = "both"))
+        .select("value"),
       Seq(Row("77"), Row("88")))
 
     val df2 = Seq("[]").toDF("col")
-    checkAnswer(df2.join(tableFunctions.flatten(input = parse_json(df1("col")),
-      path = "", outer = true, recursive = true, mode = "both")).select("value"),
+    checkAnswer(
+      df2
+        .join(
+          tableFunctions.flatten(
+            input = parse_json(df1("col")),
+            path = "",
+            outer = true,
+            recursive = true,
+            mode = "both"))
+        .select("value"),
       Seq(Row(null)))
 
-    assert(df1.join(tableFunctions.flatten(input = parse_json(df1("col")),
-      path = "", outer = true, recursive = true, mode = "both")).count() == 4)
-    assert(df1.join(tableFunctions.flatten(input = parse_json(df1("col")),
-      path = "", outer = true, recursive = false, mode = "both")).count() == 2)
-    assert(df1.join(tableFunctions.flatten(input = parse_json(df1("col")),
-      path = "", outer = true, recursive = true, mode = "array")).count() == 1)
-    assert(df1.join(tableFunctions.flatten(input = parse_json(df1("col")),
-      path = "", outer = true, recursive = true, mode = "object")).count() == 2)
+    assert(
+      df1
+        .join(
+          tableFunctions.flatten(
+            input = parse_json(df1("col")),
+            path = "",
+            outer = true,
+            recursive = true,
+            mode = "both"))
+        .count() == 4)
+    assert(
+      df1
+        .join(
+          tableFunctions.flatten(
+            input = parse_json(df1("col")),
+            path = "",
+            outer = true,
+            recursive = false,
+            mode = "both"))
+        .count() == 2)
+    assert(
+      df1
+        .join(
+          tableFunctions.flatten(
+            input = parse_json(df1("col")),
+            path = "",
+            outer = true,
+            recursive = true,
+            mode = "array"))
+        .count() == 1)
+    assert(
+      df1
+        .join(
+          tableFunctions.flatten(
+            input = parse_json(df1("col")),
+            path = "",
+            outer = true,
+            recursive = true,
+            mode = "object"))
+        .count() == 2)
   }
 
   test("Argument in table function: flatten - session") {
     val df = Seq(
       (1, Array(1, 2, 3), Map("a" -> "b", "c" -> "d")),
       (2, Array(11, 22, 33), Map("a1" -> "b1", "c1" -> "d1"))).toDF("idx", "arr", "map")
-    checkAnswer(session.tableFunction(tableFunctions.flatten(df("arr"))).select("value"),
+    checkAnswer(
+      session.tableFunction(tableFunctions.flatten(df("arr"))).select("value"),
       Seq(Row("1"), Row("2"), Row("3"), Row("11"), Row("22"), Row("33")))
     // error if it is not a table function
     val error1 = intercept[SnowparkClientException] {
@@ -237,4 +285,19 @@ class TableFunctionSuite extends TestData {
       error1.message.contains("Invalid input argument, " +
        "Session.tableFunction only supports table function arguments"))
   }
+
+  test("Argument in table function: flatten - session 2") {
+    val df1 = Seq("{\"a\":1, \"b\":[77, 88]}").toDF("col")
+    checkAnswer(
+      session
+        .tableFunction(
+          tableFunctions.flatten(
+            input = parse_json(df1("col")),
+            path = "b",
+            outer = true,
+            recursive = true,
+            mode = "both"))
+        .select("value"),
+      Seq(Row("77"), Row("88")))
+  }
 }