SNOW-966360 Support Explode Function (#69)
* explode array

* explode map type

* add error message

* add error

* add doc

* explode with session table function

* java api

* add java doc

* add comments
sfc-gh-bli authored Nov 30, 2023
1 parent 8a8406c commit e056bd3
Showing 9 changed files with 254 additions and 24 deletions.
29 changes: 29 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/TableFunctions.java
@@ -150,4 +150,33 @@ public static Column flatten(
public static Column flatten(Column input) {
return new Column(com.snowflake.snowpark.tableFunctions.flatten(input.toScalaColumn()));
}

/**
 * Flattens a given array or map type column into individual rows. For an array input column the
 * output column is `VALUE`; for a map input column the output columns are `KEY` and `VALUE`.
*
* <p>Example
*
* <pre>{@code
* DataFrame df =
* getSession()
* .createDataFrame(
* new Row[] {Row.create("{\"a\":1, \"b\":2}")},
* StructType.create(new StructField("col", DataTypes.StringType)));
* DataFrame df1 =
* df.select(
* Functions.parse_json(df.col("col"))
* .cast(DataTypes.createMapType(DataTypes.StringType, DataTypes.IntegerType))
* .as("col"));
* df1.select(TableFunctions.explode(df1.col("col"))).show()
* }</pre>
*
* @since 1.10.0
 * @param input The expression that will be unnested into rows. The expression must be either
 *     MapType or ArrayType data.
* @return The result Column reference
*/
public static Column explode(Column input) {
return new Column(com.snowflake.snowpark.tableFunctions.explode(input.toScalaColumn()));
}
}
60 changes: 44 additions & 16 deletions src/main/scala/com/snowflake/snowpark/DataFrame.scala
@@ -2,11 +2,13 @@ package com.snowflake.snowpark

import scala.reflect.ClassTag
import scala.util.Random
import com.snowflake.snowpark.internal.analyzer.{TableFunction => TF}
import com.snowflake.snowpark.internal.ErrorMessage
import com.snowflake.snowpark.internal.{Logging, Utils}
import com.snowflake.snowpark.internal.analyzer._
import com.snowflake.snowpark.types._
import com.github.vertical_blank.sqlformatter.SqlFormatter
import com.snowflake.snowpark.functions.lit
import com.snowflake.snowpark.internal.Utils.{
TempObjectType,
getTableFunctionExpression,
@@ -1969,23 +1971,49 @@ class DataFrame private[snowpark] (
private def joinTableFunction(
func: TableFunctionExpression,
partitionByOrderBy: Option[WindowSpecDefinition]): DataFrame = {
    func match {
      // explode is a client-side function
      case TF(funcName, args) if funcName.toLowerCase().trim.equals("explode") =>
        // explode has only one argument
        joinWithExplode(args.head, partitionByOrderBy)
      case _ =>
        val originalResult = withPlan {
          TableFunctionJoin(this.plan, func, partitionByOrderBy)
        }
        val resultSchema = originalResult.schema
        val columnNames = resultSchema.map(_.name)
        // find duplicated column names
        val dup = columnNames.diff(columnNames.distinct).distinct.map(quoteName)
        // guarantee no duplicated names in the result
        if (dup.nonEmpty) {
          val dfPrefix = DataFrame.generatePrefix('o')
          val renamedDf =
            this.select(this.output.map(_.name).map(aliasIfNeeded(this, _, dfPrefix, dup.toSet)))
          withPlan {
            TableFunctionJoin(renamedDf.plan, func, partitionByOrderBy)
          }
        } else {
          originalResult
        }
    }
  }

private def joinWithExplode(
expr: Expression,
partitionByOrderBy: Option[WindowSpecDefinition]): DataFrame = {
val columns: Seq[Column] = this.output.map(attr => col(attr.name))
    // check the data type of the input column
this.select(Column(expr)).schema.head.dataType match {
case _: ArrayType =>
joinTableFunction(
tableFunctions.flatten.call(Map("input" -> Column(expr), "mode" -> lit("array"))),
partitionByOrderBy).select(columns :+ Column("VALUE"))
case _: MapType =>
joinTableFunction(
tableFunctions.flatten.call(Map("input" -> Column(expr), "mode" -> lit("object"))),
partitionByOrderBy).select(columns ++ Seq(Column("KEY"), Column("VALUE")))
case otherType =>
throw ErrorMessage.MISC_INVALID_EXPLODE_ARGUMENT_TYPE(otherType.typeName)
}
}

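Because explode is rewritten entirely on the client, a call such as df.join(tableFunctions.explode(df("a"))) reduces to a flatten join plus a projection. Below is a minimal sketch of the equivalent hand-written query for an ArrayType column; the DataFrame and its column name "a" are illustrative, and it assumes session.implicits._ is in scope:

import com.snowflake.snowpark._
import com.snowflake.snowpark.functions._

val df = Seq("[1, 2]").toDF("a")
val arr = df.select(parse_json(df("a")).cast(types.ArrayType(types.IntegerType)).as("a"))
// Roughly what the client-side rewrite of arr.join(tableFunctions.explode(arr("a"))) builds:
// a lateral FLATTEN in "array" mode, keeping the original columns plus VALUE.
val rewritten = arr
  .join(tableFunctions.flatten, Map("input" -> arr("a"), "mode" -> lit("array")))
  .select(arr("a"), col("VALUE"))
rewritten.show()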
17 changes: 16 additions & 1 deletion src/main/scala/com/snowflake/snowpark/Session.scala
@@ -529,7 +529,10 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log
// Use df.join to apply function result if args contains a DF column
val sourceDFs = args.flatMap(_.expr.sourceDFs)
if (sourceDFs.isEmpty) {
      // The explode function requires special handling since it is a client-side function.
      if (func.funcName.trim.toLowerCase() == "explode") {
callExplode(args.head)
} else DataFrame(this, TableFunctionRelation(func.call(args: _*)))
} else if (sourceDFs.toSet.size > 1) {
throw UDF_CANNOT_ACCEPT_MANY_DF_COLS()
} else {
@@ -580,6 +583,18 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log
}
}

// process explode function with literal values
private def callExplode(input: Column): DataFrame = {
import this.implicits._
    // To reuse the DataFrame.join function, the input column has to be converted to
    // a DataFrame column. The ideal solution would be to create an empty DataFrame and
    // then append this column via the withColumn function. However, Snowpark doesn't
    // support empty DataFrames, so a dummy DataFrame is created instead.
val dummyDF = Seq(1).toDF("a")
val sourceDF = dummyDF.withColumn("b", input)
sourceDF.select(tableFunctions.explode(sourceDF("b")))
}

/**
* Creates a new DataFrame from the given table function.
*
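For literal arguments, this means a Session.tableFunction call is ultimately served by a select over the one-row dummy table built above. A minimal sketch of the same expansion written by hand; the literal and the dummy column names "a" and "b" mirror the internals above and are illustrative only (assumes functions._ and session.implicits._ are in scope):

// What session.tableFunction(tableFunctions.explode(<literal>)) builds internally:
val input = parse_json(lit("""{"a":1, "b":2}"""))
  .cast(types.MapType(types.StringType, types.IntegerType))
val dummyDF = Seq(1).toDF("a")                // one-row stand-in for an empty DataFrame
val sourceDF = dummyDF.withColumn("b", input) // attach the literal as a real column
sourceDF.select(tableFunctions.explode(sourceDF("b"))).show() // rows of KEY and VALUE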
@@ -153,7 +153,11 @@ private[snowpark] object ErrorMessage {
"0420" -> "Invalid RSA private key. The error is: %s",
"0421" -> "Invalid stage location: %s. Reason: %s.",
"0422" -> "Internal error: Server fetching is disabled for the parameter %s and there is no default value for it.",
"0423" -> "Invalid input argument, Session.tableFunction only supports table function arguments")
"0423" -> "Invalid input argument, Session.tableFunction only supports table function arguments",
"0424" ->
"""Invalid input argument type, the input argument type of Explode function should be either Map or Array types.
|The input argument type: %s
|""".stripMargin)
// scalastyle:on

/*
@@ -393,6 +397,9 @@ private[snowpark] object ErrorMessage {
def MISC_INVALID_TABLE_FUNCTION_INPUT(): SnowparkClientException =
createException("0423")

def MISC_INVALID_EXPLODE_ARGUMENT_TYPE(argumentType: String): SnowparkClientException =
createException("0424", argumentType)

/**
* Create Snowpark client Exception.
*
25 changes: 25 additions & 0 deletions src/main/scala/com/snowflake/snowpark/tableFunctions.scala
@@ -197,4 +197,29 @@ object tableFunctions {
"outer" -> lit(outer),
"recursive" -> lit(recursive),
"mode" -> lit(mode)))

/**
 * Flattens a given array or map type column into individual rows.
 * For an array input column the output column is `VALUE`;
 * for a map input column the output columns are `KEY` and `VALUE`.
*
* Example
* {{{
* import com.snowflake.snowpark.functions._
*
* val df = Seq("""{"a":1, "b": 2}""").toDF("a")
* val df1 = df.select(
* parse_json(df("a"))
* .cast(types.MapType(types.StringType, types.IntegerType))
* .as("a"))
* df1.select(lit(1), tableFunctions.explode(df1("a")), df1("a")("a")).show()
* }}}
*
* @since 1.10.0
 * @param input The expression that will be unnested into rows.
 *              The expression must be either MapType or ArrayType data.
* @return The result Column reference
*/
def explode(input: Column): Column = TableFunction("explode").apply(input)

}
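For an array input the KEY column is omitted; a minimal usage sketch mirroring the map example in the Scaladoc above (same imports assumed, with session.implicits._ in scope):

val df = Seq("[1, 2]").toDF("a")
val df1 = df.select(parse_json(df("a")).cast(types.ArrayType(types.IntegerType)).as("a"))
// Array input yields a single VALUE output column, one row per element.
df1.select(tableFunctions.explode(df1("a"))).show()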
@@ -82,11 +82,10 @@ public void flatten() {
}

@Test
  public void getOrCreate() {
String expectedSessionInfo = getSession().getSessionInfo();
String actualSessionInfo = Session.builder().getOrCreate().getSessionInfo();
    assert (actualSessionInfo.equals(expectedSessionInfo));
}

@Test
@@ -121,4 +121,63 @@ public void argumentInFlatten() {
.select("value"),
new Row[] {Row.create("1"), Row.create("2")});
}

@Test
public void explodeWithDataFrame() {
// select
DataFrame df =
getSession()
.createDataFrame(
new Row[] {Row.create("{\"a\":1, \"b\":2}")},
StructType.create(new StructField("col", DataTypes.StringType)));
DataFrame df1 =
df.select(
Functions.parse_json(df.col("col"))
.cast(DataTypes.createMapType(DataTypes.StringType, DataTypes.IntegerType))
.as("col"));
checkAnswer(
df1.select(TableFunctions.explode(df1.col("col"))),
new Row[] {Row.create("a", "1"), Row.create("b", "2")});
// join
df =
getSession()
.createDataFrame(
new Row[] {Row.create("[1, 2]")},
StructType.create(new StructField("col", DataTypes.StringType)));
df1 =
df.select(
Functions.parse_json(df.col("col"))
.cast(DataTypes.createArrayType(DataTypes.IntegerType))
.as("col"));
checkAnswer(
df1.join(TableFunctions.explode(df1.col("col"))).select("VALUE"),
new Row[] {Row.create("1"), Row.create("2")});
}

@Test
public void explodeWithSession() {
DataFrame df =
getSession()
.createDataFrame(
new Row[] {Row.create("{\"a\":1, \"b\":2}")},
StructType.create(new StructField("col", DataTypes.StringType)));
DataFrame df1 =
df.select(
Functions.parse_json(df.col("col"))
.cast(DataTypes.createMapType(DataTypes.StringType, DataTypes.IntegerType))
.as("col"));
checkAnswer(
getSession().tableFunction(TableFunctions.explode(df1.col("col"))).select("KEY", "VALUE"),
new Row[] {Row.create("a", "1"), Row.create("b", "2")});

checkAnswer(
getSession()
.tableFunction(
TableFunctions.explode(
Functions.parse_json(Functions.lit("{\"a\":1, \"b\":2}"))
.cast(
DataTypes.createMapType(DataTypes.StringType, DataTypes.IntegerType))))
.select("KEY", "VALUE"),
new Row[] {Row.create("a", "1"), Row.create("b", "2")});
}
}
11 changes: 11 additions & 0 deletions src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala
@@ -830,4 +830,15 @@ class ErrorMessageSuite extends FunSuite {
ex.message.startsWith("Error Code: 0423, Error message: Invalid input argument, " +
"Session.tableFunction only supports table function arguments"))
}

test("MISC_INVALID_EXPLODE_ARGUMENT_TYPE") {
val ex = ErrorMessage.MISC_INVALID_EXPLODE_ARGUMENT_TYPE(types.IntegerType.typeName)
assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0424")))
assert(
ex.message.startsWith(
"Error Code: 0424, Error message: " +
"Invalid input argument type, the input argument type of " +
"Explode function should be either Map or Array types.\n" +
"The input argument type: Integer"))
}
}
@@ -3,9 +3,6 @@ package com.snowflake.snowpark_test
import com.snowflake.snowpark.functions._
import com.snowflake.snowpark.{Row, _}

class TableFunctionSuite extends TestData {
import session.implicits._

@@ -386,4 +383,64 @@ class TableFunctionSuite extends TestData {
checkAnswer(result.select(df("value")), Seq(Row("1,2"), Row("1,2"), Row("3,4"), Row("3,4")))
}

test("explode with array column") {
val df = Seq("[1, 2]").toDF("a")
val df1 = df.select(parse_json(df("a")).cast(types.ArrayType(types.IntegerType)).as("a"))
checkAnswer(
df1.select(lit(1), tableFunctions.explode(df1("a")), df1("a")(1)),
Seq(Row(1, "1", "2"), Row(1, "2", "2")))
}

test("explode with map column") {
val df = Seq("""{"a":1, "b": 2}""").toDF("a")
val df1 = df.select(
parse_json(df("a"))
.cast(types.MapType(types.StringType, types.IntegerType))
.as("a"))
checkAnswer(
df1.select(lit(1), tableFunctions.explode(df1("a")), df1("a")("a")),
Seq(Row(1, "a", "1", "1"), Row(1, "b", "2", "1")))
}

test("explode with other column") {
val df = Seq("""{"a":1, "b": 2}""").toDF("a")
val df1 = df.select(
parse_json(df("a"))
.as("a"))
val error = intercept[SnowparkClientException] {
df1.select(tableFunctions.explode(df1("a"))).show()
}
assert(
error.message.contains(
"the input argument type of Explode function should be either Map or Array types"))
assert(error.message.contains("The input argument type: Variant"))
}

test("explode with DataFrame.join") {
val df = Seq("[1, 2]").toDF("a")
val df1 = df.select(parse_json(df("a")).cast(types.ArrayType(types.IntegerType)).as("a"))
checkAnswer(
df1.join(tableFunctions.explode(df1("a"))).select("VALUE"),
Seq(Row("1"), Row("2")))
}

test("explode with session.tableFunction") {
// with dataframe column
val df = Seq("""{"a":1, "b": 2}""").toDF("a")
val df1 = df.select(
parse_json(df("a"))
.cast(types.MapType(types.StringType, types.IntegerType))
.as("a"))
checkAnswer(
session.tableFunction(tableFunctions.explode(df1("a"))),
Seq(Row("a", "1"), Row("b", "2")))

// with literal value
checkAnswer(
session.tableFunction(
tableFunctions
.explode(parse_json(lit("[1, 2]")).cast(types.ArrayType(types.IntegerType)))),
Seq(Row("1"), Row("2")))
}

}
