diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala index 12417787..7601f286 100644 --- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala +++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala @@ -533,7 +533,7 @@ class DataFrame private[snowpark] ( * @param alias The alias name of the dataframe * @return a [[DataFrame]] */ - def alias(alias: String): DataFrame = withPlan(DataframeAlias(alias, plan)) + def alias(alias: String): DataFrame = withPlan(DataframeAlias(alias, plan, output)) /** * Returns a new DataFrame with the specified Column expressions as output (similar to SELECT in diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/ExpressionAnalyzer.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/ExpressionAnalyzer.scala index 826b4703..6192e49c 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/ExpressionAnalyzer.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/ExpressionAnalyzer.scala @@ -71,7 +71,8 @@ private[snowpark] class ExpressionAnalyzer( val normalizedColName = quoteName(aliasColName) val col = aliasOutput.filter(attr => attr.name.equals(normalizedColName)) if (col.length == 1) { - col.head.withName(normalizedColName) + // analyze new attributes at the same time + analyze(col.head.withName(normalizedColName)) } else { throw ErrorMessage.DF_CANNOT_RESOLVE_COLUMN_NAME(aliasColName, aliasOutput.map(_.name)) } diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala index 961cea4f..51fef0f4 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala @@ -195,12 +195,18 @@ private[snowpark] case class Sort(order: Seq[SortOrder], child: LogicalPlan) ext Sort(order, _) } -private[snowpark] case class DataframeAlias(alias: String, child: LogicalPlan) extends UnaryNode { +// requires childOutput when creating, +// since child's SnowflakePlan can be empty +private[snowpark] case class DataframeAlias( + alias: String, + child: LogicalPlan, + childOutput: Seq[Attribute]) + extends UnaryNode { override lazy val dfAliasMap: Map[String, Seq[Attribute]] = - Utils.addToDataframeAliasMap(Map(alias -> child.getSnowflakePlan.get.output), child) + Utils.addToDataframeAliasMap(Map(alias -> childOutput), child) override protected def createFromAnalyzedChild: LogicalPlan => LogicalPlan = - DataframeAlias(alias, _) + DataframeAlias(alias, _, childOutput) override protected def updateChild: LogicalPlan => LogicalPlan = createFromAnalyzedChild diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SqlGenerator.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SqlGenerator.scala index 7c2add81..9539809e 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SqlGenerator.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SqlGenerator.scala @@ -157,7 +157,7 @@ private object SqlGenerator extends Logging { .transformations(transformations) .options(options) .createSnowflakePlan() - case DataframeAlias(_, child) => resolveChild(child) + case DataframeAlias(_, child, _) => resolveChild(child) } } diff --git a/src/test/scala/com/snowflake/snowpark_test/DataFrameAliasSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataFrameAliasSuite.scala index 044bc270..5ca4ca9e 100644 --- a/src/test/scala/com/snowflake/snowpark_test/DataFrameAliasSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/DataFrameAliasSuite.scala @@ -94,13 +94,6 @@ class DataFrameAliasSuite extends TestData with BeforeAndAfterEach with EagerSes .join(df2, df1.col("id") === df2.col("id")) .select(df2.col("B.num")), Seq(Row(7), Row(8), Row(9))) - - // The following use case is out of the scope of supporting alias - // We still follow the old ambiguity resolving policy and require DF to be used - assertThrows[SnowparkClientException]( - df1 - .join(df2, df1.col("id") === df2.col("id")) - .select($"A.num")) } test("Test for alias conflict") { @@ -113,4 +106,35 @@ class DataFrameAliasSuite extends TestData with BeforeAndAfterEach with EagerSes .join(df2, df1.col("id") === df2.col("id")) .select(df1.col("A.num"))) } + + test("snow-1335123") { + val df1 = Seq((1, 2, 3, 4), (11, 12, 13, 14), (21, 12, 23, 24), (11, 32, 33, 34)).toDF( + "col_a", + "col_b", + "col_c", + "col_d") + + val df2 = Seq((1, 2, 5, 6), (11, 12, 15, 16), (41, 12, 25, 26), (11, 42, 35, 36)).toDF( + "col_a", + "col_b", + "col_e", + "col_f") + + val df3 = df1 + .alias("a") + .join( + df2.alias("b"), + col("a.col_a") === col("b.col_a") + && col("a.col_b") === col("b.col_b"), + "left") + .select("a.col_a", "a.col_b", "col_c", "col_d", "col_e", "col_f") + + checkAnswer( + df3, + Seq( + Row(1, 2, 3, 4, 5, 6), + Row(11, 12, 13, 14, 15, 16), + Row(11, 32, 33, 34, null, null), + Row(21, 12, 23, 24, null, null))) + } } diff --git a/src/test/scala/com/snowflake/snowpark_test/DataFrameSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataFrameSuite.scala index e07b94a9..45aec853 100644 --- a/src/test/scala/com/snowflake/snowpark_test/DataFrameSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/DataFrameSuite.scala @@ -13,7 +13,9 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { val tableName = randomName() val viewName = randomName() val samplingDeviation = 0.4 + import session.implicits._ + override def afterEach(): Unit = { dropTable(tableName) dropView(viewName) @@ -104,15 +106,15 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { assert( getShowString(df, 2) == """------------------------------- - ||"A" |"B" | - |------------------------------- - ||line1 |NULL | - ||line2 | | - ||single line |NotNull | - || |one more line | - || |last line | - |------------------------------- - |""".stripMargin) + ||"A" |"B" | + |------------------------------- + ||line1 |NULL | + ||line2 | | + ||single line |NotNull | + || |one more line | + || |last line | + |------------------------------- + |""".stripMargin) } test("show") { @@ -404,11 +406,11 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { assert( getSchemaString(nullData3.na.replace("flo", Map(Double.NaN -> null)).schema) == """root - | |--FLO: Double (nullable = true) - | |--INT: Long (nullable = true) - | |--BOO: Boolean (nullable = true) - | |--STR: String (nullable = true) - |""".stripMargin) + | |--FLO: Double (nullable = true) + | |--INT: Long (nullable = true) + | |--BOO: Boolean (nullable = true) + | |--STR: String (nullable = true) + |""".stripMargin) } @@ -535,69 +537,86 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { assert( getShowString(monthlySales.stat.sampleBy(col("empid"), Map(1 -> 0.0, 2 -> 1.0)), 10) == """-------------------------------- - ||"EMPID" |"AMOUNT" |"MONTH" | - |-------------------------------- - ||2 |4500 |JAN | - ||2 |35000 |JAN | - ||2 |200 |FEB | - ||2 |90500 |FEB | - ||2 |2500 |MAR | - ||2 |9500 |MAR | - ||2 |800 |APR | - ||2 |4500 |APR | - |-------------------------------- - |""".stripMargin) + ||"EMPID" |"AMOUNT" |"MONTH" | + |-------------------------------- + ||2 |4500 |JAN | + ||2 |35000 |JAN | + ||2 |200 |FEB | + ||2 |90500 |FEB | + ||2 |2500 |MAR | + ||2 |9500 |MAR | + ||2 |800 |APR | + ||2 |4500 |APR | + |-------------------------------- + |""".stripMargin) assert( getShowString(monthlySales.stat.sampleBy(col("month"), Map("JAN" -> 1.0)), 10) == """-------------------------------- - ||"EMPID" |"AMOUNT" |"MONTH" | - |-------------------------------- - ||1 |10000 |JAN | - ||1 |400 |JAN | - ||2 |4500 |JAN | - ||2 |35000 |JAN | - |-------------------------------- - |""".stripMargin) + ||"EMPID" |"AMOUNT" |"MONTH" | + |-------------------------------- + ||1 |10000 |JAN | + ||1 |400 |JAN | + ||2 |4500 |JAN | + ||2 |35000 |JAN | + |-------------------------------- + |""".stripMargin) assert( getShowString(monthlySales.stat.sampleBy(col("month"), Map()), 10) == """-------------------------------- - ||"EMPID" |"AMOUNT" |"MONTH" | - |-------------------------------- - |-------------------------------- - |""".stripMargin) + ||"EMPID" |"AMOUNT" |"MONTH" | + |-------------------------------- + |-------------------------------- + |""".stripMargin) } // On GitHub Action this test time out. But locally it passed. ignore("df.stat.pivot max column test") { def randomString(n: Int): String = Random.alphanumeric.filter(_.isLetter).take(n).mkString + // Local execution time: 1000 -> 25s, 3000 -> 2.5 min, 5000 -> 10 min. - val df1 = Seq.fill(1000) { (randomString(230), randomString(230)) }.toDF("a", "b") + val df1 = Seq + .fill(1000) { + (randomString(230), randomString(230)) + } + .toDF("a", "b") getShowString(df1.stat.crosstab("a", "b"), 1) - val df2 = Seq.fill(1001) { (randomString(230), randomString(230)) }.toDF("a", "b") + val df2 = Seq + .fill(1001) { + (randomString(230), randomString(230)) + } + .toDF("a", "b") assertThrows[SnowparkClientException](getShowString(df2.stat.crosstab("a", "b"), 1)) - val df3 = Seq.fill(1000) { (1, 1) }.toDF("a", "b") + val df3 = Seq + .fill(1000) { + (1, 1) + } + .toDF("a", "b") assert( getShowString(df3.stat.crosstab("a", "b"), 10) == """----------------------------------- - ||"A" |"CAST(1 AS NUMBER(38,0))" | - |----------------------------------- - ||1 |1000 | - |----------------------------------- - |""".stripMargin) + ||"A" |"CAST(1 AS NUMBER(38,0))" | + |----------------------------------- + ||1 |1000 | + |----------------------------------- + |""".stripMargin) - val df4 = Seq.fill(1001) { (1, 1) }.toDF("a", "b") + val df4 = Seq + .fill(1001) { + (1, 1) + } + .toDF("a", "b") assert( getShowString(df4.stat.crosstab("a", "b"), 10) == """----------------------------------- - ||"A" |"CAST(1 AS NUMBER(38,0))" | - |----------------------------------- - ||1 |1001 | - |----------------------------------- - |""".stripMargin) + ||"A" |"CAST(1 AS NUMBER(38,0))" | + |----------------------------------- + ||1 |1001 | + |----------------------------------- + |""".stripMargin) } test("select *") { @@ -793,13 +812,23 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { val values = Seq((1, null), (2, "NotNull"), (3, null)) // toDF(String, String*) with invalid args count - assertThrows[IllegalArgumentException]({ values.toDF("a") }) - assertThrows[IllegalArgumentException]({ values.toDF("a", "b", "c") }) + assertThrows[IllegalArgumentException]({ + values.toDF("a") + }) + assertThrows[IllegalArgumentException]({ + values.toDF("a", "b", "c") + }) // toDF(Seq[String]) with invalid args count - assertThrows[IllegalArgumentException]({ values.toDF(Seq.empty) }) - assertThrows[IllegalArgumentException]({ values.toDF(Seq("a")) }) - assertThrows[IllegalArgumentException]({ values.toDF(Seq("a", "b", "c")) }) + assertThrows[IllegalArgumentException]({ + values.toDF(Seq.empty) + }) + assertThrows[IllegalArgumentException]({ + values.toDF(Seq("a")) + }) + assertThrows[IllegalArgumentException]({ + values.toDF(Seq("a", "b", "c")) + }) } test("test sort()") { @@ -836,7 +865,9 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { } // Negative test: sort() needs at least one sort expression. - assertThrows[SnowparkClientException]({ df.sort(Seq.empty) }) + assertThrows[SnowparkClientException]({ + df.sort(Seq.empty) + }) } test("test select()") { @@ -874,14 +905,26 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { val df = Seq((1, "a", 10), (2, "b", 20), (3, "c", 30)).toDF("a", "b", "c") // Select with empty Seq - assertThrows[IllegalArgumentException]({ df.select(Seq.empty[String]) }) - assertThrows[IllegalArgumentException]({ df.select(Seq.empty[Column]) }) + assertThrows[IllegalArgumentException]({ + df.select(Seq.empty[String]) + }) + assertThrows[IllegalArgumentException]({ + df.select(Seq.empty[Column]) + }) // select column which doesn't exist. - assertThrows[SnowflakeSQLException]({ df.select("not_exist_column").collect() }) - assertThrows[SnowflakeSQLException]({ df.select(Seq("not_exist_column")).collect() }) - assertThrows[SnowflakeSQLException]({ df.select(col("not_exist_column")).collect() }) - assertThrows[SnowflakeSQLException]({ df.select(Seq(col("not_exist_column"))).collect() }) + assertThrows[SnowflakeSQLException]({ + df.select("not_exist_column").collect() + }) + assertThrows[SnowflakeSQLException]({ + df.select(Seq("not_exist_column")).collect() + }) + assertThrows[SnowflakeSQLException]({ + df.select(col("not_exist_column")).collect() + }) + assertThrows[SnowflakeSQLException]({ + df.select(Seq(col("not_exist_column"))).collect() + }) } test("drop() and dropColumns()") { @@ -916,10 +959,18 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { checkAnswer(df.drop(Seq(col("c"), col("b"))), expectedResult) // drop all columns (negative test) - assertThrows[SnowparkClientException]({ df.drop("a", "b", "c").collect() }) - assertThrows[SnowparkClientException]({ df.drop(Seq("a", "b", "c")).collect() }) - assertThrows[SnowparkClientException]({ df.drop(col("a"), col("b"), col("c")).collect() }) - assertThrows[SnowparkClientException]({ df.drop(Seq(col("a"), col("b"), col("c"))).collect() }) + assertThrows[SnowparkClientException]({ + df.drop("a", "b", "c").collect() + }) + assertThrows[SnowparkClientException]({ + df.drop(Seq("a", "b", "c")).collect() + }) + assertThrows[SnowparkClientException]({ + df.drop(col("a"), col("b"), col("c")).collect() + }) + assertThrows[SnowparkClientException]({ + df.drop(Seq(col("a"), col("b"), col("c"))).collect() + }) } test("DataFrame.agg()") { @@ -1313,12 +1364,12 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { "{\n \"'\": 1\n}", "1", Geography.fromGeoJSON("""{ - | "coordinates": [ - | 30, - | 10 - | ], - | "type": "Point" - |}""".stripMargin), + | "coordinates": [ + | 30, + | 10 + | ], + | "type": "Point" + |}""".stripMargin), Geometry.fromGeoJSON("""{ | "coordinates": [ | 2.000000000000000e+01, @@ -1431,12 +1482,12 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { Row( "1", Geography.fromGeoJSON("""{ - | "coordinates": [ - | 10, - | 10 - | ], - | "type": "Point" - |}""".stripMargin), + | "coordinates": [ + | 10, + | 10 + | ], + | "type": "Point" + |}""".stripMargin), Geometry.fromGeoJSON("""{ | "coordinates": [ | 2.000000000000000e+01, @@ -1451,11 +1502,13 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { test("create nullable dataFrame with schema inference") { val df = Seq((1, Some(1), None), (2, Some(3), Some(true))) .toDF("a", "b", "c") - assert(getSchemaString(df.schema) == """root - | |--A: Long (nullable = false) - | |--B: Long (nullable = false) - | |--C: Boolean (nullable = true) - |""".stripMargin) + assert( + getSchemaString(df.schema) == + """root + | |--A: Long (nullable = false) + | |--B: Long (nullable = false) + | |--C: Boolean (nullable = true) + |""".stripMargin) checkAnswer(df, Seq(Row(1, 1, null), Row(2, 3, true)), sort = false) }