diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala index 7666d9e8..f7d6909c 100644 --- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala +++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala @@ -561,18 +561,35 @@ class DataFrame private[snowpark] ( "Provide at least one column expression for select(). " + s"This DataFrame has column names (${output.length}): " + s"${output.map(_.name).mkString(", ")}\n") - - val resultDF = withPlan { Project(columns.map(_.named), plan) } - // do not rename back if this project contains internal alias. - // because no named duplicated if just renamed. - val hasInternalAlias: Boolean = columns.map(_.expr).exists { - case Alias(_, _, true) => true - case _ => false - } - if (hasInternalAlias) { - resultDF - } else { - renameBackIfDeduped(resultDF) + // todo: error message + val tf = columns.filter(_.expr.isInstanceOf[TableFunctionExpression]) + tf.size match { + case 0 => // no table function + val resultDF = withPlan { + Project(columns.map(_.named), plan) + } + // do not rename back if this project contains internal alias. + // because no named duplicated if just renamed. + val hasInternalAlias: Boolean = columns.map(_.expr).exists { + case Alias(_, _, true) => true + case _ => false + } + if (hasInternalAlias) { + resultDF + } else { + renameBackIfDeduped(resultDF) + } + case 1 => // 1 table function + val base = this.join(tf.head) + val baseColumns = base.schema.map(field => base(field.name)) + val inputDFColumnSize = this.schema.size + val tfColumns = baseColumns.splitAt(inputDFColumnSize)._2 + val (beforeTf, afterTf) = columns.span(_ != tf.head) + val resultColumns = beforeTf ++ tfColumns ++ afterTf.tail + base.select(resultColumns) + case _ => + // more than 1 TF + throw ErrorMessage.DF_MORE_THAN_ONE_TF_IN_SELECT() } } @@ -1788,9 +1805,8 @@ class DataFrame private[snowpark] ( * object or an object that you create from the [[TableFunction]] class. * @param args A list of arguments to pass to the specified table function. */ - def join(func: TableFunction, args: Seq[Column]): DataFrame = withPlan { - TableFunctionJoin(this.plan, func.call(args: _*), None) - } + def join(func: TableFunction, args: Seq[Column]): DataFrame = + joinTableFunction(func.call(args: _*), None) /** * Joins the current DataFrame with the output of the specified user-defined table @@ -1822,12 +1838,10 @@ class DataFrame private[snowpark] ( func: TableFunction, args: Seq[Column], partitionBy: Seq[Column], - orderBy: Seq[Column]): DataFrame = withPlan { - TableFunctionJoin( - this.plan, + orderBy: Seq[Column]): DataFrame = + joinTableFunction( func.call(args: _*), Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition)) - } /** * Joins the current DataFrame with the output of the specified table function `func` that takes @@ -1859,9 +1873,8 @@ class DataFrame private[snowpark] ( * Some functions, like `flatten`, have named parameters. * Use this map to specify the parameter names and their corresponding values. */ - def join(func: TableFunction, args: Map[String, Column]): DataFrame = withPlan { - TableFunctionJoin(this.plan, func.call(args), None) - } + def join(func: TableFunction, args: Map[String, Column]): DataFrame = + joinTableFunction(func.call(args), None) /** * Joins the current DataFrame with the output of the specified user-defined table function @@ -1900,12 +1913,10 @@ class DataFrame private[snowpark] ( func: TableFunction, args: Map[String, Column], partitionBy: Seq[Column], - orderBy: Seq[Column]): DataFrame = withPlan { - TableFunctionJoin( - this.plan, + orderBy: Seq[Column]): DataFrame = + joinTableFunction( func.call(args), Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition)) - } /** * Joins the current DataFrame with the output of the specified table function `func`. @@ -1929,9 +1940,8 @@ class DataFrame private[snowpark] ( * @param func [[TableFunction]] object, which can be one of the values in the [[tableFunctions]] * object or an object that you create from the [[TableFunction.apply()]]. */ - def join(func: Column): DataFrame = withPlan { - TableFunctionJoin(this.plan, getTableFunctionExpression(func), None) - } + def join(func: Column): DataFrame = + joinTableFunction(getTableFunctionExpression(func), None) /** * Joins the current DataFrame with the output of the specified user-defined table function @@ -1951,11 +1961,32 @@ class DataFrame private[snowpark] ( * @param partitionBy A list of columns partitioned by. * @param orderBy A list of columns ordered by. */ - def join(func: Column, partitionBy: Seq[Column], orderBy: Seq[Column]): DataFrame = withPlan { - TableFunctionJoin( - this.plan, + def join(func: Column, partitionBy: Seq[Column], orderBy: Seq[Column]): DataFrame = + joinTableFunction( getTableFunctionExpression(func), Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition)) + + private def joinTableFunction( + func: TableFunctionExpression, + partitionByOrderBy: Option[WindowSpecDefinition]): DataFrame = { + val originalResult = withPlan { + TableFunctionJoin(this.plan, func, partitionByOrderBy) + } + val resultSchema = originalResult.schema + val columnNames = resultSchema.map(_.name) + // duplicated names + val dup = columnNames.diff(columnNames.distinct).distinct.map(quoteName) + // guarantee no duplicated names in the result + if (dup.nonEmpty) { + val dfPrefix = DataFrame.generatePrefix('o') + val renamedDf = + this.select(this.output.map(_.name).map(aliasIfNeeded(this, _, dfPrefix, dup.toSet))) + withPlan { + TableFunctionJoin(renamedDf.plan, func, partitionByOrderBy) + } + } else { + originalResult + } } /** diff --git a/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala b/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala index 3dbd5b89..505b2b6d 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala @@ -68,6 +68,7 @@ private[snowpark] object ErrorMessage { "0128" -> "DataFrameWriter doesn't support to set option '%s' as '%s' in '%s' mode when writing to a %s.", "0129" -> "DataFrameWriter doesn't support mode '%s' when writing to a %s.", "0130" -> "Unsupported join operations, Dataframes can join with other Dataframes or TableFunctions only", + "0131" -> "At most one table function can be called inside select() function", // Begin to define UDF related messages "0200" -> "Incorrect number of arguments passed to the UDF: Expected: %d, Found: %d", "0201" -> "Attempted to call an unregistered UDF. You must register the UDF before calling it.", @@ -244,6 +245,9 @@ private[snowpark] object ErrorMessage { def DF_JOIN_WITH_WRONG_ARGUMENT(): SnowparkClientException = createException("0130") + def DF_MORE_THAN_ONE_TF_IN_SELECT(): SnowparkClientException = + createException("0131") + /* * 2NN: UDF error code */ diff --git a/src/test/java/com/snowflake/snowpark_test/JavaUDTFSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaUDTFSuite.java index 4c2d3aae..226842e0 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaUDTFSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaUDTFSuite.java @@ -46,7 +46,7 @@ public void basicTypes() { "create or replace temp table " + tableName + "(i1 smallint, i2 int, l1 bigint, f1 float, d1 double, " - + "decimal number(38, 18), b boolean, s string, bi binary)"; + + "de number(38, 18), b boolean, s string, bi binary)"; runQuery(crt); String insert = "insert into " @@ -68,7 +68,7 @@ public void basicTypes() { col("l1"), col("f1"), col("d1"), - col("decimal"), + col("de"), col("b"), col("s"), col("bi")); @@ -82,7 +82,7 @@ public void basicTypes() { .append("|--L1: Long (nullable = true)") .append("|--F1: Double (nullable = true)") .append("|--D1: Double (nullable = true)") - .append("|--DECIMAL: Decimal(38, 18) (nullable = true)") + .append("|--DE: Decimal(38, 18) (nullable = true)") .append("|--B: Boolean (nullable = true)") .append("|--S: String (nullable = true)") .append("|--BI: Binary (nullable = true)") diff --git a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala index d5ede212..072c73e8 100644 --- a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala @@ -296,6 +296,14 @@ class ErrorMessageSuite extends FunSuite { " or TableFunctions only")) } + test("DF_MORE_THAN_ONE_TF_IN_SELECT") { + val ex = ErrorMessage.DF_MORE_THAN_ONE_TF_IN_SELECT() + assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0131"))) + assert( + ex.message.startsWith("Error Code: 0131, Error message: " + + "At most one table function can be called inside select() function")) + } + test("UDF_INCORRECT_ARGS_NUMBER") { val ex = ErrorMessage.UDF_INCORRECT_ARGS_NUMBER(1, 2) assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0200"))) diff --git a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala index 46f0028f..fe330255 100644 --- a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala @@ -337,4 +337,53 @@ class TableFunctionSuite extends TestData { .select("value"), Seq(Row("77"), Row("88"))) } + + test("table function in select") { + val df = Seq((1, "1,2"), (2, "3,4")).toDF("idx", "data") + // only tf + val result1 = df.select(tableFunctions.split_to_table(df("data"), ",")) + assert(result1.schema.map(_.name) == Seq("SEQ", "INDEX", "VALUE")) + checkAnswer(result1, Seq(Row(1, 1, "1"), Row(1, 2, "2"), Row(2, 1, "3"), Row(2, 2, "4"))) + + // columns + tf + val result2 = df.select(df("idx"), tableFunctions.split_to_table(df("data"), ",")) + assert(result2.schema.map(_.name) == Seq("IDX", "SEQ", "INDEX", "VALUE")) + checkAnswer( + result2, + Seq(Row(1, 1, 1, "1"), Row(1, 1, 2, "2"), Row(2, 2, 1, "3"), Row(2, 2, 2, "4"))) + + // columns + tf + columns + val result3 = df.select(df("idx"), tableFunctions.split_to_table(df("data"), ","), df("idx")) + assert(result3.schema.map(_.name) == Seq("IDX", "SEQ", "INDEX", "VALUE", "IDX")) + checkAnswer( + result3, + Seq(Row(1, 1, 1, "1", 1), Row(1, 1, 2, "2", 1), Row(2, 2, 1, "3", 2), Row(2, 2, 2, "4", 2))) + + // tf + other express + val result4 = df.select(tableFunctions.split_to_table(df("data"), ","), df("idx") + 100) + checkAnswer( + result4, + Seq(Row(1, 1, "1", 101), Row(1, 2, "2", 101), Row(2, 1, "3", 102), Row(2, 2, "4", 102))) + } + + test("table function join with duplicated column name") { + val df = Seq((1, "1,2"), (2, "3,4")).toDF("idx", "value") + val result = df.join(tableFunctions.split_to_table(df("value"), lit(","))) + // only one VALUE in the result + checkAnswer(result.select("value"), Seq(Row("1"), Row("2"), Row("3"), Row("4"))) + checkAnswer(result.select(result("value")), Seq(Row("1"), Row("2"), Row("3"), Row("4"))) + checkAnswer(result.select(df("value")), Seq(Row("1,2"), Row("1,2"), Row("3,4"), Row("3,4"))) + } + + test("table function select with duplicated column name") { + val df = Seq((1, "1,2"), (2, "3,4")).toDF("idx", "value") + val result1 = df.select(tableFunctions.split_to_table(df("value"), lit(","))) + checkAnswer(result1, Seq(Row(1, 1, "1"), Row(1, 2, "2"), Row(2, 1, "3"), Row(2, 2, "4"))) + val result = df.select(df("value"), tableFunctions.split_to_table(df("value"), lit(","))) + // only one VALUE in the result + checkAnswer(result.select("value"), Seq(Row("1"), Row("2"), Row("3"), Row("4"))) + checkAnswer(result.select(result("value")), Seq(Row("1"), Row("2"), Row("3"), Row("4"))) + checkAnswer(result.select(df("value")), Seq(Row("1,2"), Row("1,2"), Row("3,4"), Row("3,4"))) + } + }