diff --git a/src/main/java/com/snowflake/snowpark_java/TableFunctions.java b/src/main/java/com/snowflake/snowpark_java/TableFunctions.java
index 022ec6df..524c89c0 100644
--- a/src/main/java/com/snowflake/snowpark_java/TableFunctions.java
+++ b/src/main/java/com/snowflake/snowpark_java/TableFunctions.java
@@ -150,4 +150,33 @@ public static Column flatten(
   public static Column flatten(Column input) {
     return new Column(com.snowflake.snowpark.tableFunctions.flatten(input.toScalaColumn()));
   }
+
+  /**
+   * Flattens a given array or map type column into individual rows. The output column is `VALUE`
+   * for an array input column; the output columns are `KEY` and `VALUE` for a map input column.
+   *
+   * <p>Example
+   *
+   * <pre>{@code
+   * DataFrame df =
+   *   getSession()
+   *     .createDataFrame(
+   *       new Row[] {Row.create("{\"a\":1, \"b\":2}")},
+   *       StructType.create(new StructField("col", DataTypes.StringType)));
+   * DataFrame df1 =
+   *   df.select(
+   *     Functions.parse_json(df.col("col"))
+   *       .cast(DataTypes.createMapType(DataTypes.StringType, DataTypes.IntegerType))
+   *       .as("col"));
+   * df1.select(TableFunctions.explode(df1.col("col"))).show();
+   * }</pre>
+   *
+   * @since 1.10.0
+   * @param input The expression that will be unnested into rows. The expression must be of
+   *     MapType or ArrayType.
+   * @return The result Column reference
+   */
+  public static Column explode(Column input) {
+    return new Column(com.snowflake.snowpark.tableFunctions.explode(input.toScalaColumn()));
+  }
 }
diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala
index f7d6909c..54c43c49 100644
--- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala
+++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala
@@ -2,11 +2,13 @@ package com.snowflake.snowpark
 
 import scala.reflect.ClassTag
 import scala.util.Random
+import com.snowflake.snowpark.internal.analyzer.{TableFunction => TF}
 import com.snowflake.snowpark.internal.ErrorMessage
 import com.snowflake.snowpark.internal.{Logging, Utils}
 import com.snowflake.snowpark.internal.analyzer._
 import com.snowflake.snowpark.types._
 import com.github.vertical_blank.sqlformatter.SqlFormatter
+import com.snowflake.snowpark.functions.lit
 import com.snowflake.snowpark.internal.Utils.{
   TempObjectType,
   getTableFunctionExpression,
@@ -1969,23 +1971,49 @@ class DataFrame private[snowpark] (
   private def joinTableFunction(
       func: TableFunctionExpression,
       partitionByOrderBy: Option[WindowSpecDefinition]): DataFrame = {
-    val originalResult = withPlan {
-      TableFunctionJoin(this.plan, func, partitionByOrderBy)
+    func match {
+      // explode is a client-side function
+      case TF(funcName, args) if funcName.toLowerCase().trim.equals("explode") =>
+        // explode takes exactly one argument
+        joinWithExplode(args.head, partitionByOrderBy)
+      case _ =>
+        val originalResult = withPlan {
+          TableFunctionJoin(this.plan, func, partitionByOrderBy)
+        }
+        val resultSchema = originalResult.schema
+        val columnNames = resultSchema.map(_.name)
+        // duplicated names
+        val dup = columnNames.diff(columnNames.distinct).distinct.map(quoteName)
+        // guarantee no duplicated names in the result
+        if (dup.nonEmpty) {
+          val dfPrefix = DataFrame.generatePrefix('o')
+          val renamedDf =
+            this.select(this.output.map(_.name).map(aliasIfNeeded(this, _, dfPrefix, dup.toSet)))
+          withPlan {
+            TableFunctionJoin(renamedDf.plan, func, partitionByOrderBy)
+          }
+        } else {
+          originalResult
+        }
     }
-    val resultSchema = originalResult.schema
-    val columnNames = resultSchema.map(_.name)
-    // duplicated names
-    val dup = columnNames.diff(columnNames.distinct).distinct.map(quoteName)
-    // guarantee no duplicated names in the result
-    if (dup.nonEmpty) {
-      val dfPrefix = DataFrame.generatePrefix('o')
-      val renamedDf =
-        this.select(this.output.map(_.name).map(aliasIfNeeded(this, _, dfPrefix, dup.toSet)))
-      withPlan {
-        TableFunctionJoin(renamedDf.plan, func, partitionByOrderBy)
-      }
-    } else {
-      originalResult
+  }
+
+  private def joinWithExplode(
+      expr: Expression,
+      partitionByOrderBy: Option[WindowSpecDefinition]): DataFrame = {
+    val columns: Seq[Column] = this.output.map(attr => col(attr.name))
+    // check the data type of the input column
+    this.select(Column(expr)).schema.head.dataType match {
+      case _: ArrayType =>
+        joinTableFunction(
+          tableFunctions.flatten.call(Map("input" -> Column(expr), "mode" -> lit("array"))),
+          partitionByOrderBy).select(columns :+ Column("VALUE"))
+      case _: MapType =>
+        joinTableFunction(
+          tableFunctions.flatten.call(Map("input" -> Column(expr), "mode" -> lit("object"))),
+          partitionByOrderBy).select(columns ++ Seq(Column("KEY"), Column("VALUE")))
+      case otherType =>
+        throw ErrorMessage.MISC_INVALID_EXPLODE_ARGUMENT_TYPE(otherType.typeName)
+    }
+  }
 }
diff --git a/src/main/scala/com/snowflake/snowpark/Session.scala b/src/main/scala/com/snowflake/snowpark/Session.scala
index aa8058d0..950e90b0 100644
--- a/src/main/scala/com/snowflake/snowpark/Session.scala
+++ b/src/main/scala/com/snowflake/snowpark/Session.scala
@@ -529,7 +529,10 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log
     // Use df.join to apply function result if args contains a DF column
     val sourceDFs = args.flatMap(_.expr.sourceDFs)
     if (sourceDFs.isEmpty) {
-      DataFrame(this, TableFunctionRelation(func.call(args: _*)))
+      // The explode function requires special handling since it is a client-side function.
+      if (func.funcName.trim.toLowerCase() == "explode") {
+        callExplode(args.head)
+      } else DataFrame(this, TableFunctionRelation(func.call(args: _*)))
     } else if (sourceDFs.toSet.size > 1) {
       throw UDF_CANNOT_ACCEPT_MANY_DF_COLS()
     } else {
@@ -580,6 +583,18 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log
     }
   }
 
+  // process the explode function with literal values
+  private def callExplode(input: Column): DataFrame = {
+    import this.implicits._
+    // To reuse the DataFrame.join function, the input column has to be converted to
+    // a DataFrame column. The best solution is to create an empty DataFrame and
+    // then append this column via the withColumn function. However, Snowpark doesn't
+    // support empty DataFrames, so create a dummy DataFrame instead.
+    val dummyDF = Seq(1).toDF("a")
+    val sourceDF = dummyDF.withColumn("b", input)
+    sourceDF.select(tableFunctions.explode(sourceDF("b")))
+  }
+
   /**
    * Creates a new DataFrame from the given table function.
    *
diff --git a/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala b/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala
index 505b2b6d..21db9dea 100644
--- a/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala
+++ b/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala
@@ -153,7 +153,11 @@ private[snowpark] object ErrorMessage {
     "0420" -> "Invalid RSA private key. The error is: %s",
     "0421" -> "Invalid stage location: %s. Reason: %s.",
     "0422" -> "Internal error: Server fetching is disabled for the parameter %s and there is no default value for it.",
-    "0423" -> "Invalid input argument, Session.tableFunction only supports table function arguments")
+    "0423" -> "Invalid input argument, Session.tableFunction only supports table function arguments",
+    "0424" ->
+      """Invalid input argument type, the input argument type of Explode function should be either Map or Array types.
+        |The input argument type: %s
+        |""".stripMargin)
   // scalastyle:on
 
   /*
@@ -393,6 +397,9 @@ private[snowpark] object ErrorMessage {
   def MISC_INVALID_TABLE_FUNCTION_INPUT(): SnowparkClientException =
     createException("0423")
 
+  def MISC_INVALID_EXPLODE_ARGUMENT_TYPE(argumentType: String): SnowparkClientException =
+    createException("0424", argumentType)
+
   /**
    * Create Snowpark client Exception.
    *
diff --git a/src/main/scala/com/snowflake/snowpark/tableFunctions.scala b/src/main/scala/com/snowflake/snowpark/tableFunctions.scala
index 212a000c..91f40c13 100644
--- a/src/main/scala/com/snowflake/snowpark/tableFunctions.scala
+++ b/src/main/scala/com/snowflake/snowpark/tableFunctions.scala
@@ -197,4 +197,29 @@ object tableFunctions {
       "outer" -> lit(outer),
       "recursive" -> lit(recursive),
       "mode" -> lit(mode)))
+
+  /**
+   * Flattens a given array or map type column into individual rows.
+   * The output column is `VALUE` for an array input column,
+   * and the output columns are `KEY` and `VALUE` for a map input column.
+   *
+   * Example
+   * {{{
+   *   import com.snowflake.snowpark.functions._
+   *
+   *   val df = Seq("""{"a":1, "b": 2}""").toDF("a")
+   *   val df1 = df.select(
+   *     parse_json(df("a"))
+   *       .cast(types.MapType(types.StringType, types.IntegerType))
+   *       .as("a"))
+   *   df1.select(lit(1), tableFunctions.explode(df1("a")), df1("a")("a")).show()
+   * }}}
+   *
+   * @since 1.10.0
+   * @param input The expression that will be unnested into rows.
+   *              The expression must be of MapType or ArrayType.
+   * @return The result Column reference
+   */
+  def explode(input: Column): Column = TableFunction("explode").apply(input)
+
 }
diff --git a/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java
index 75bda890..ee00443c 100644
--- a/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java
+++ b/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java
@@ -82,11 +82,10 @@ public void flatten() {
   }
 
   @Test
-  public void getOrCreate()
-  {
+  public void getOrCreate() {
     String expectedSessionInfo = getSession().getSessionInfo();
     String actualSessionInfo = Session.builder().getOrCreate().getSessionInfo();
-    assert(actualSessionInfo.equals(expectedSessionInfo));
+    assert (actualSessionInfo.equals(expectedSessionInfo));
   }
 
   @Test
diff --git a/src/test/java/com/snowflake/snowpark_test/JavaTableFunctionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaTableFunctionSuite.java
index 12ca8391..849b41b5 100644
--- a/src/test/java/com/snowflake/snowpark_test/JavaTableFunctionSuite.java
+++ b/src/test/java/com/snowflake/snowpark_test/JavaTableFunctionSuite.java
@@ -121,4 +121,63 @@ public void argumentInFlatten() {
             .select("value"),
         new Row[] {Row.create("1"), Row.create("2")});
   }
+
+  @Test
+  public void explodeWithDataFrame() {
+    // select
+    DataFrame df =
+        getSession()
+            .createDataFrame(
+                new Row[] {Row.create("{\"a\":1, \"b\":2}")},
+                StructType.create(new StructField("col", DataTypes.StringType)));
+    DataFrame df1 =
+        df.select(
+            Functions.parse_json(df.col("col"))
+                .cast(DataTypes.createMapType(DataTypes.StringType, DataTypes.IntegerType))
+                .as("col"));
+    checkAnswer(
+        df1.select(TableFunctions.explode(df1.col("col"))),
+        new Row[] {Row.create("a", "1"), Row.create("b", "2")});
+    // join
+    df =
+        getSession()
+            .createDataFrame(
+                new Row[] {Row.create("[1, 2]")},
+                StructType.create(new StructField("col", DataTypes.StringType)));
+    df1 =
+        df.select(
+            Functions.parse_json(df.col("col"))
+                .cast(DataTypes.createArrayType(DataTypes.IntegerType))
+                .as("col"));
+    checkAnswer(
+        df1.join(TableFunctions.explode(df1.col("col"))).select("VALUE"),
+        new Row[] {Row.create("1"), Row.create("2")});
+  }
+
+  @Test
+  public void explodeWithSession() {
+    DataFrame df =
+        getSession()
+            .createDataFrame(
+                new Row[] {Row.create("{\"a\":1, \"b\":2}")},
+                StructType.create(new StructField("col", DataTypes.StringType)));
+    DataFrame df1 =
+        df.select(
+            Functions.parse_json(df.col("col"))
+                .cast(DataTypes.createMapType(DataTypes.StringType, DataTypes.IntegerType))
+                .as("col"));
+    checkAnswer(
+        getSession().tableFunction(TableFunctions.explode(df1.col("col"))).select("KEY", "VALUE"),
+        new Row[] {Row.create("a", "1"), Row.create("b", "2")});
+
+    checkAnswer(
+        getSession()
+            .tableFunction(
+                TableFunctions.explode(
+                    Functions.parse_json(Functions.lit("{\"a\":1, \"b\":2}"))
+                        .cast(
+                            DataTypes.createMapType(DataTypes.StringType, DataTypes.IntegerType))))
+            .select("KEY", "VALUE"),
+        new Row[] {Row.create("a", "1"), Row.create("b", "2")});
+  }
 }
diff --git a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala
index 072c73e8..bfeddd72 100644
--- a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala
@@ -830,4 +830,15 @@ class ErrorMessageSuite extends FunSuite {
       ex.message.startsWith("Error Code: 0423, Error message: Invalid input argument, " +
         "Session.tableFunction only supports table function arguments"))
   }
+
+  test("MISC_INVALID_EXPLODE_ARGUMENT_TYPE") {
+    val ex = ErrorMessage.MISC_INVALID_EXPLODE_ARGUMENT_TYPE(types.IntegerType.typeName)
+    assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0424")))
+    assert(
+      ex.message.startsWith(
+        "Error Code: 0424, Error message: " +
+          "Invalid input argument type, the input argument type of " +
+          "Explode function should be either Map or Array types.\n" +
+          "The input argument type: Integer"))
+  }
 }
diff --git a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala
index fe330255..58392126 100644
--- a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala
@@ -3,9 +3,6 @@ package com.snowflake.snowpark_test
 import com.snowflake.snowpark.functions._
 import com.snowflake.snowpark.{Row, _}
 
-import scala.collection.Seq
-import scala.collection.immutable.Map
-
 class TableFunctionSuite extends TestData {
   import session.implicits._
 
@@ -386,4 +383,64 @@ class TableFunctionSuite extends TestData {
     checkAnswer(result.select(df("value")), Seq(Row("1,2"), Row("1,2"), Row("3,4"), Row("3,4")))
   }
 
+  test("explode with array column") {
+    val df = Seq("[1, 2]").toDF("a")
+    val df1 = df.select(parse_json(df("a")).cast(types.ArrayType(types.IntegerType)).as("a"))
+    checkAnswer(
+      df1.select(lit(1), tableFunctions.explode(df1("a")), df1("a")(1)),
+      Seq(Row(1, "1", "2"), Row(1, "2", "2")))
+  }
+
+  test("explode with map column") {
+    val df = Seq("""{"a":1, "b": 2}""").toDF("a")
+    val df1 = df.select(
+      parse_json(df("a"))
+        .cast(types.MapType(types.StringType, types.IntegerType))
+        .as("a"))
+    checkAnswer(
+      df1.select(lit(1), tableFunctions.explode(df1("a")), df1("a")("a")),
+      Seq(Row(1, "a", "1", "1"), Row(1, "b", "2", "1")))
+  }
+
+  test("explode with other column") {
+    val df = Seq("""{"a":1, "b": 2}""").toDF("a")
+    val df1 = df.select(
+      parse_json(df("a"))
+        .as("a"))
+    val error = intercept[SnowparkClientException] {
+      df1.select(tableFunctions.explode(df1("a"))).show()
+    }
+    assert(
+      error.message.contains(
+        "the input argument type of Explode function should be either Map or Array types"))
+    assert(error.message.contains("The input argument type: Variant"))
+  }
+
+  test("explode with DataFrame.join") {
+    val df = Seq("[1, 2]").toDF("a")
+    val df1 =
+      df.select(parse_json(df("a")).cast(types.ArrayType(types.IntegerType)).as("a"))
+    checkAnswer(
+      df1.join(tableFunctions.explode(df1("a"))).select("VALUE"),
+      Seq(Row("1"), Row("2")))
+  }
+
+  test("explode with session.tableFunction") {
+    // with dataframe column
+    val df = Seq("""{"a":1, "b": 2}""").toDF("a")
+    val df1 = df.select(
+      parse_json(df("a"))
+        .cast(types.MapType(types.StringType, types.IntegerType))
+        .as("a"))
+    checkAnswer(
+      session.tableFunction(tableFunctions.explode(df1("a"))),
+      Seq(Row("a", "1"), Row("b", "2")))
+
+    // with literal value
+    checkAnswer(
+      session.tableFunction(
+        tableFunctions
+          .explode(parse_json(lit("[1, 2]")).cast(types.ArrayType(types.IntegerType)))),
+      Seq(Row("1"), Row("2")))
+  }
+
 }
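
For reference, a minimal end-to-end sketch of the API this patch introduces, assembled from the doc comments and test suites above. It is illustrative only: the `profile.properties` connection setup is an assumption (any configured Snowpark session works), and the `KEY`/`VALUE` output columns follow the flatten-based rewrite in `joinWithExplode`.

```scala
import com.snowflake.snowpark._
import com.snowflake.snowpark.functions._

object ExplodeSketch {
  def main(args: Array[String]): Unit = {
    // Assumption: connection settings are read from profile.properties.
    val session = Session.builder.configFile("profile.properties").create
    import session.implicits._

    // Map input: explode emits one row per entry as KEY and VALUE columns.
    val df = Seq("""{"a":1, "b":2}""").toDF("a")
    val mapDf = df.select(
      parse_json(df("a"))
        .cast(types.MapType(types.StringType, types.IntegerType))
        .as("a"))
    mapDf.select(tableFunctions.explode(mapDf("a"))).show()

    // Array input: explode emits one row per element as a VALUE column;
    // DataFrame.join keeps the original columns alongside VALUE.
    val arrDf = df.select(
      parse_json(lit("[1, 2]")).cast(types.ArrayType(types.IntegerType)).as("a"))
    arrDf.join(tableFunctions.explode(arrDf("a"))).select("VALUE").show()

    // Any other input type raises error 0424 (MISC_INVALID_EXPLODE_ARGUMENT_TYPE).
  }
}
```

All three entry points the patch wires up (`DataFrame.select`, `DataFrame.join`, and `Session.tableFunction`) route through the same FLATTEN rewrite, which is why the output column names match Snowflake's FLATTEN table function.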