SNOW-966360 Support Explode Function #69

Merged (10 commits, Nov 30, 2023)
29 changes: 29 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/TableFunctions.java
@@ -150,4 +150,33 @@ public static Column flatten(
public static Column flatten(Column input) {
return new Column(com.snowflake.snowpark.tableFunctions.flatten(input.toScalaColumn()));
}

/**
* Flattens a given array or map type column into individual rows. The output column is
* `VALUE` for an array input column, and `KEY` and `VALUE` for a map input column.
*
* <p>Example
*
* <pre>{@code
* DataFrame df =
* getSession()
* .createDataFrame(
* new Row[] {Row.create("{\"a\":1, \"b\":2}")},
* StructType.create(new StructField("col", DataTypes.StringType)));
* DataFrame df1 =
* df.select(
* Functions.parse_json(df.col("col"))
* .cast(DataTypes.createMapType(DataTypes.StringType, DataTypes.IntegerType))
* .as("col"));
* df1.select(TableFunctions.explode(df1.col("col"))).show()
* }</pre>
*
* @since 1.10.0
* @param input The expression that will be unnested into rows. The expression must be either
* MapType or ArrayType data.
* @return The result Column reference
*/
public static Column explode(Column input) {
return new Column(com.snowflake.snowpark.tableFunctions.explode(input.toScalaColumn()));
}
}
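For reference, the explodeWithDataFrame test added later in this PR confirms this example's output: selecting TableFunctions.explode over the map column returns the rows ("a", "1") and ("b", "2") as KEY/VALUE pairs.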
60 changes: 44 additions & 16 deletions src/main/scala/com/snowflake/snowpark/DataFrame.scala
@@ -2,11 +2,13 @@ package com.snowflake.snowpark

import scala.reflect.ClassTag
import scala.util.Random
import com.snowflake.snowpark.internal.analyzer.{TableFunction => TF}
import com.snowflake.snowpark.internal.ErrorMessage
import com.snowflake.snowpark.internal.{Logging, Utils}
import com.snowflake.snowpark.internal.analyzer._
import com.snowflake.snowpark.types._
import com.github.vertical_blank.sqlformatter.SqlFormatter
import com.snowflake.snowpark.functions.lit
import com.snowflake.snowpark.internal.Utils.{
TempObjectType,
getTableFunctionExpression,
@@ -1969,23 +1971,49 @@
private def joinTableFunction(
func: TableFunctionExpression,
partitionByOrderBy: Option[WindowSpecDefinition]): DataFrame = {
val originalResult = withPlan {
TableFunctionJoin(this.plan, func, partitionByOrderBy)
func match {
// explode is a client side function
case TF(funcName, args) if funcName.toLowerCase().trim.equals("explode") =>
// explode has only one argument
joinWithExplode(args.head, partitionByOrderBy)
case _ =>
val originalResult = withPlan {
TableFunctionJoin(this.plan, func, partitionByOrderBy)
}
val resultSchema = originalResult.schema
val columnNames = resultSchema.map(_.name)
// duplicated names
val dup = columnNames.diff(columnNames.distinct).distinct.map(quoteName)
// guarantee no duplicated names in the result
if (dup.nonEmpty) {
val dfPrefix = DataFrame.generatePrefix('o')
val renamedDf =
this.select(this.output.map(_.name).map(aliasIfNeeded(this, _, dfPrefix, dup.toSet)))
withPlan {
TableFunctionJoin(renamedDf.plan, func, partitionByOrderBy)
}
} else {
originalResult
}
}
val resultSchema = originalResult.schema
val columnNames = resultSchema.map(_.name)
// duplicated names
val dup = columnNames.diff(columnNames.distinct).distinct.map(quoteName)
// guarantee no duplicated names in the result
if (dup.nonEmpty) {
val dfPrefix = DataFrame.generatePrefix('o')
val renamedDf =
this.select(this.output.map(_.name).map(aliasIfNeeded(this, _, dfPrefix, dup.toSet)))
withPlan {
TableFunctionJoin(renamedDf.plan, func, partitionByOrderBy)
}
} else {
originalResult
}

private def joinWithExplode(
expr: Expression,
partitionByOrderBy: Option[WindowSpecDefinition]): DataFrame = {
val columns: Seq[Column] = this.output.map(attr => col(attr.name))
// check the column type of input column
this.select(Column(expr)).schema.head.dataType match {
case _: ArrayType =>
joinTableFunction(
tableFunctions.flatten.call(Map("input" -> Column(expr), "mode" -> lit("array"))),
partitionByOrderBy).select(columns :+ Column("VALUE"))
case _: MapType =>
joinTableFunction(
tableFunctions.flatten.call(Map("input" -> Column(expr), "mode" -> lit("object"))),
partitionByOrderBy).select(columns ++ Seq(Column("KEY"), Column("VALUE")))
case otherType =>
throw ErrorMessage.MISC_INVALID_EXPLODE_ARGUMENT_TYPE(otherType.typeName)
}
}
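In other words, explode is client-side sugar over Snowflake's FLATTEN table function: the input column's type selects the mode, and the remaining FLATTEN outputs are projected away. Below is a minimal user-level sketch of that equivalence, assuming a DataFrame df1 with an ArrayType column "a" as in the tests later in this PR; it is an illustration, not the exact generated plan.

import com.snowflake.snowpark.functions._

// Sketch: df1.join(tableFunctions.explode(df1("a"))) behaves like a FLATTEN
// join in "array" mode, keeping the original columns plus VALUE; for a
// MapType column the mode is "object" and KEY is kept alongside VALUE.
val equivalent = df1
  .join(tableFunctions.flatten, Map("input" -> df1("a"), "mode" -> lit("array")))
  .select(df1("a"), col("VALUE"))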

17 changes: 16 additions & 1 deletion src/main/scala/com/snowflake/snowpark/Session.scala
@@ -529,7 +529,10 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log
// Use df.join to apply function result if args contains a DF column
val sourceDFs = args.flatMap(_.expr.sourceDFs)
if (sourceDFs.isEmpty) {
DataFrame(this, TableFunctionRelation(func.call(args: _*)))
// The explode function requires special handling since it is a client-side function.
if (func.funcName.trim.toLowerCase() == "explode") {
callExplode(args.head)
} else DataFrame(this, TableFunctionRelation(func.call(args: _*)))
} else if (sourceDFs.toSet.size > 1) {
throw UDF_CANNOT_ACCEPT_MANY_DF_COLS()
} else {
@@ -580,6 +583,18 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log
}
}

// process explode function with literal values
private def callExplode(input: Column): DataFrame = {
import this.implicits._
// to reuse the DataFrame.join function, the input column has to be converted to
// a DataFrame column. The best solution is to create an empty DataFrame and
// then append this column via the withColumn function. However, Snowpark doesn't
// support empty DataFrames, so we create a dummy DataFrame instead.
val dummyDF = Seq(1).toDF("a")
Contributor: I don't understand here why we need to create a DF with a dummy value and column names. This is not even in a test.

Collaborator (Author): I want to reuse the DataFrame.join function since it supports the explode function. So I want to create an empty DataFrame, and then add the input to this DataFrame via withColumn. However, Snowpark doesn't support empty DataFrames, therefore I created a dummy DataFrame and remove the dummy column later.

val sourceDF = dummyDF.withColumn("b", input)
sourceDF.select(tableFunctions.explode(sourceDF("b")))
}
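To make the review discussion above concrete, here is a minimal standalone sketch of the workaround from user code; it assumes an active session value in scope, and the names input, dummy, and source are illustrative.

import com.snowflake.snowpark._
import com.snowflake.snowpark.functions._
import session.implicits._

// Snowpark cannot build a zero-column DataFrame, so a one-row, one-column
// frame serves purely as a carrier for the literal input column.
val input = parse_json(lit("""{"a":1, "b":2}"""))
  .cast(types.MapType(types.StringType, types.IntegerType))
val dummy = Seq(1).toDF("a")               // placeholder row and column
val source = dummy.withColumn("b", input)  // attach the literal as a real column
// Selecting only the explode output drops both dummy columns again:
source.select(tableFunctions.explode(source("b"))).show() // prints KEY and VALUE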

/**
* Creates a new DataFrame from the given table function.
*
@@ -153,7 +153,11 @@ private[snowpark] object ErrorMessage {
"0420" -> "Invalid RSA private key. The error is: %s",
"0421" -> "Invalid stage location: %s. Reason: %s.",
"0422" -> "Internal error: Server fetching is disabled for the parameter %s and there is no default value for it.",
"0423" -> "Invalid input argument, Session.tableFunction only supports table function arguments")
"0423" -> "Invalid input argument, Session.tableFunction only supports table function arguments",
"0424" ->
"""Invalid input argument type, the input argument type of Explode function should be either Map or Array types.
|The input argument type: %s
|""".stripMargin)
// scalastyle:on

/*
@@ -393,6 +397,9 @@
def MISC_INVALID_TABLE_FUNCTION_INPUT(): SnowparkClientException =
createException("0423")

def MISC_INVALID_EXPLODE_ARGUMENT_TYPE(argumentType: String): SnowparkClientException =
createException("0424", argumentType)

/**
* Create Snowpark client Exception.
*
25 changes: 25 additions & 0 deletions src/main/scala/com/snowflake/snowpark/tableFunctions.scala
@@ -197,4 +197,29 @@
"outer" -> lit(outer),
"recursive" -> lit(recursive),
"mode" -> lit(mode)))

/**
* Flattens a given array or map type column into individual rows.
* The output column is `VALUE` for an array input column,
* and `KEY` and `VALUE` for a map input column.
*
* Example
* {{{
* import com.snowflake.snowpark.functions._
*
* val df = Seq("""{"a":1, "b": 2}""").toDF("a")
* val df1 = df.select(
* parse_json(df("a"))
* .cast(types.MapType(types.StringType, types.IntegerType))
* .as("a"))
* df1.select(lit(1), tableFunctions.explode(df1("a")), df1("a")("a")).show()
* }}}
*
* @since 1.10.0
* @param input The expression that will be unnested into rows.
* The expression must be either MapType or ArrayType data.
* @return The result Column reference
*/
def explode(input: Column): Column = TableFunction("explode").apply(input)

}
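For reference, the "explode with map column" test later in this PR pins down this example's output: Row(1, "a", "1", "1") and Row(1, "b", "2", "1"). KEY and VALUE are returned as strings, as is the variant lookup df1("a")("a").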
@@ -82,11 +82,10 @@ public void flatten() {
}

@Test
public void getOrCreate()
{
public void getOrCreate() {
String expectedSessionInfo = getSession().getSessionInfo();
String actualSessionInfo = Session.builder().getOrCreate().getSessionInfo();
assert(actualSessionInfo.equals(expectedSessionInfo));
assert (actualSessionInfo.equals(expectedSessionInfo));
}

@Test
@@ -121,4 +121,63 @@ public void argumentInFlatten() {
.select("value"),
new Row[] {Row.create("1"), Row.create("2")});
}

@Test
public void explodeWithDataFrame() {
// select
DataFrame df =
getSession()
.createDataFrame(
new Row[] {Row.create("{\"a\":1, \"b\":2}")},
StructType.create(new StructField("col", DataTypes.StringType)));
DataFrame df1 =
df.select(
Functions.parse_json(df.col("col"))
.cast(DataTypes.createMapType(DataTypes.StringType, DataTypes.IntegerType))
.as("col"));
checkAnswer(
df1.select(TableFunctions.explode(df1.col("col"))),
new Row[] {Row.create("a", "1"), Row.create("b", "2")});
// join
df =
getSession()
.createDataFrame(
new Row[] {Row.create("[1, 2]")},
StructType.create(new StructField("col", DataTypes.StringType)));
df1 =
df.select(
Functions.parse_json(df.col("col"))
.cast(DataTypes.createArrayType(DataTypes.IntegerType))
.as("col"));
checkAnswer(
df1.join(TableFunctions.explode(df1.col("col"))).select("VALUE"),
new Row[] {Row.create("1"), Row.create("2")});
}

@Test
public void explodeWithSession() {
DataFrame df =
getSession()
.createDataFrame(
new Row[] {Row.create("{\"a\":1, \"b\":2}")},
StructType.create(new StructField("col", DataTypes.StringType)));
DataFrame df1 =
df.select(
Functions.parse_json(df.col("col"))
.cast(DataTypes.createMapType(DataTypes.StringType, DataTypes.IntegerType))
.as("col"));
checkAnswer(
getSession().tableFunction(TableFunctions.explode(df1.col("col"))).select("KEY", "VALUE"),
new Row[] {Row.create("a", "1"), Row.create("b", "2")});

checkAnswer(
getSession()
.tableFunction(
TableFunctions.explode(
Functions.parse_json(Functions.lit("{\"a\":1, \"b\":2}"))
.cast(
DataTypes.createMapType(DataTypes.StringType, DataTypes.IntegerType))))
.select("KEY", "VALUE"),
new Row[] {Row.create("a", "1"), Row.create("b", "2")});
}
}
11 changes: 11 additions & 0 deletions src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala
@@ -830,4 +830,15 @@ class ErrorMessageSuite extends FunSuite {
ex.message.startsWith("Error Code: 0423, Error message: Invalid input argument, " +
"Session.tableFunction only supports table function arguments"))
}

test("MISC_INVALID_EXPLODE_ARGUMENT_TYPE") {
val ex = ErrorMessage.MISC_INVALID_EXPLODE_ARGUMENT_TYPE(types.IntegerType.typeName)
assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0424")))
assert(
ex.message.startsWith(
"Error Code: 0424, Error message: " +
"Invalid input argument type, the input argument type of " +
"Explode function should be either Map or Array types.\n" +
"The input argument type: Integer"))
}
}
@@ -3,9 +3,6 @@ package com.snowflake.snowpark_test
import com.snowflake.snowpark.functions._
import com.snowflake.snowpark.{Row, _}

import scala.collection.Seq
import scala.collection.immutable.Map

class TableFunctionSuite extends TestData {
import session.implicits._

@@ -386,4 +383,64 @@ class TableFunctionSuite extends TestData {
checkAnswer(result.select(df("value")), Seq(Row("1,2"), Row("1,2"), Row("3,4"), Row("3,4")))
}

test("explode with array column") {
val df = Seq("[1, 2]").toDF("a")
val df1 = df.select(parse_json(df("a")).cast(types.ArrayType(types.IntegerType)).as("a"))
checkAnswer(
df1.select(lit(1), tableFunctions.explode(df1("a")), df1("a")(1)),
Seq(Row(1, "1", "2"), Row(1, "2", "2")))
}

test("explode with map column") {
val df = Seq("""{"a":1, "b": 2}""").toDF("a")
val df1 = df.select(
parse_json(df("a"))
.cast(types.MapType(types.StringType, types.IntegerType))
.as("a"))
checkAnswer(
df1.select(lit(1), tableFunctions.explode(df1("a")), df1("a")("a")),
Seq(Row(1, "a", "1", "1"), Row(1, "b", "2", "1")))
}

test("explode with other column") {
val df = Seq("""{"a":1, "b": 2}""").toDF("a")
val df1 = df.select(
parse_json(df("a"))
.as("a"))
val error = intercept[SnowparkClientException] {
df1.select(tableFunctions.explode(df1("a"))).show()
}
assert(
error.message.contains(
"the input argument type of Explode function should be either Map or Array types"))
assert(error.message.contains("The input argument type: Variant"))
}

test("explode with DataFrame.join") {
val df = Seq("[1, 2]").toDF("a")
val df1 = df.select(parse_json(df("a")).cast(types.ArrayType(types.IntegerType)).as("a"))
checkAnswer(
df1.join(tableFunctions.explode(df1("a"))).select("VALUE"),
Seq(Row("1"), Row("2")))
}

test("explode with session.tableFunction") {
// with dataframe column
val df = Seq("""{"a":1, "b": 2}""").toDF("a")
val df1 = df.select(
parse_json(df("a"))
.cast(types.MapType(types.StringType, types.IntegerType))
.as("a"))
checkAnswer(
session.tableFunction(tableFunctions.explode(df1("a"))),
Seq(Row("a", "1"), Row("b", "2")))

// with literal value
checkAnswer(
session.tableFunction(
tableFunctions
.explode(parse_json(lit("[1, 2]")).cast(types.ArrayType(types.IntegerType)))),
Seq(Row("1"), Row("2")))
}

}
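A note on the string-typed expectations throughout these tests: the flatten-based rewrite returns VALUE as a VARIANT column (a Snowflake FLATTEN output), so exploded integers come back rendered as "1" and "2" rather than as numeric values.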