diff --git a/src/main/java/com/snowflake/snowpark_java/DataFrame.java b/src/main/java/com/snowflake/snowpark_java/DataFrame.java index 2e6d23ec..56e416e0 100644 --- a/src/main/java/com/snowflake/snowpark_java/DataFrame.java +++ b/src/main/java/com/snowflake/snowpark_java/DataFrame.java @@ -802,6 +802,25 @@ public Column col(String colName) { return new Column(df.col(colName)); } + /** + * Returns the current DataFrame aliased as the input alias name. + * + * For example: + * + * {{{ + * val df2 = df.alias("A") + * df2.select(df2.col("A.num")) + * }}} + * + * @group basic + * @since 1.10.0 + * @param alias The alias name of the dataframe + * @return a [[DataFrame]] + */ + public DataFrame alias(String alias) { + return new DataFrame(this.df.alias(alias)); + } + /** * Executes the query representing this DataFrame and returns the result as an array of Row * objects. diff --git a/src/main/scala/com/snowflake/snowpark/Column.scala b/src/main/scala/com/snowflake/snowpark/Column.scala index 1e37ed9d..56996aa9 100644 --- a/src/main/scala/com/snowflake/snowpark/Column.scala +++ b/src/main/scala/com/snowflake/snowpark/Column.scala @@ -732,6 +732,7 @@ private[snowpark] object Column { def apply(name: String): Column = new Column(name match { case "*" => Star(Seq.empty) + case c if c.contains(".") => UnresolvedDFAliasAttribute(name) case _ => UnresolvedAttribute(quoteName(name)) }) diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala index 54c43c49..12417787 100644 --- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala +++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala @@ -518,6 +518,23 @@ class DataFrame private[snowpark] ( case _ => Column(resolve(colName)) } + /** + * Returns the current DataFrame aliased as the input alias name. + * + * For example: + * + * {{{ + * val df2 = df.alias("A") + * df2.select(df2.col("A.num")) + * }}} + * + * @group basic + * @since 1.10.0 + * @param alias The alias name of the dataframe + * @return a [[DataFrame]] + */ + def alias(alias: String): DataFrame = withPlan(DataframeAlias(alias, plan)) + /** * Returns a new DataFrame with the specified Column expressions as output (similar to SELECT in * SQL). Only the Columns specified as arguments will be present in the resulting DataFrame. @@ -2791,7 +2808,8 @@ class DataFrame private[snowpark] ( // utils private[snowpark] def resolve(colName: String): NamedExpression = { - val normalizedColName = quoteName(colName) + val (aliasColName, aliasOutput) = resolveAlias(colName, output) + val normalizedColName = quoteName(aliasColName) def isDuplicatedName: Boolean = { if (session.conn.hideInternalAlias) { this.plan.internalRenamedColumns.values.exists(_ == normalizedColName) @@ -2800,13 +2818,23 @@ class DataFrame private[snowpark] ( } } val col = - output.filter(attr => attr.name.equals(normalizedColName)) + aliasOutput.filter(attr => attr.name.equals(normalizedColName)) if (col.length == 1) { col.head.withName(normalizedColName).withSourceDF(this) } else if (isDuplicatedName) { - throw ErrorMessage.PLAN_JDBC_REPORT_JOIN_AMBIGUOUS(colName, colName) + throw ErrorMessage.PLAN_JDBC_REPORT_JOIN_AMBIGUOUS(aliasColName, aliasColName) + } else { + throw ErrorMessage.DF_CANNOT_RESOLVE_COLUMN_NAME(aliasColName, aliasOutput.map(_.name)) + } + } + + // Handle dataframe alias by redirecting output and column name resolution + private def resolveAlias(colName: String, output: Seq[Attribute]): (String, Seq[Attribute]) = { + val colNameSplit = colName.split("\\.", 2) + if (colNameSplit.length > 1 && plan.dfAliasMap.contains(colNameSplit(0))) { + (colNameSplit(1), plan.dfAliasMap(colNameSplit(0))) } else { - throw ErrorMessage.DF_CANNOT_RESOLVE_COLUMN_NAME(colName, output.map(_.name)) + (colName, output) } } diff --git a/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala b/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala index 21db9dea..fbb7966c 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala @@ -69,6 +69,7 @@ private[snowpark] object ErrorMessage { "0129" -> "DataFrameWriter doesn't support mode '%s' when writing to a %s.", "0130" -> "Unsupported join operations, Dataframes can join with other Dataframes or TableFunctions only", "0131" -> "At most one table function can be called inside select() function", + "0132" -> "Duplicated dataframe alias defined: %s", // Begin to define UDF related messages "0200" -> "Incorrect number of arguments passed to the UDF: Expected: %d, Found: %d", "0201" -> "Attempted to call an unregistered UDF. You must register the UDF before calling it.", @@ -252,6 +253,9 @@ private[snowpark] object ErrorMessage { def DF_MORE_THAN_ONE_TF_IN_SELECT(): SnowparkClientException = createException("0131") + def DF_ALIAS_DUPLICATES(duplicatedAlias: scala.collection.Set[String]): SnowparkClientException = + createException("0132", duplicatedAlias.mkString(", ")) + /* * 2NN: UDF error code */ diff --git a/src/main/scala/com/snowflake/snowpark/internal/Utils.scala b/src/main/scala/com/snowflake/snowpark/internal/Utils.scala index 76f06f73..92c8173d 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/Utils.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/Utils.scala @@ -1,7 +1,7 @@ package com.snowflake.snowpark.internal import com.snowflake.snowpark.Column -import com.snowflake.snowpark.internal.analyzer.{Attribute, TableFunctionExpression, singleQuote} +import com.snowflake.snowpark.internal.analyzer.{Attribute, LogicalPlan, TableFunctionExpression, singleQuote} import java.io.{File, FileInputStream} import java.lang.invoke.SerializedLambda @@ -99,6 +99,20 @@ object Utils extends Logging { lastInternalLine + "\n" + stackTrace.take(stackDepth).mkString("\n") } + def addToDataframeAliasMap(result: Map[String, Seq[Attribute]], child: LogicalPlan) + : Map[String, Seq[Attribute]] = { + if (child != null) { + val map = child.dfAliasMap + val duplicatedAlias = result.keySet.intersect(map.keySet) + if (duplicatedAlias.nonEmpty) { + throw ErrorMessage.DF_ALIAS_DUPLICATES(duplicatedAlias) + } + result ++ map + } else { + result + } + } + def logTime[T](f: => T, funcDescription: String): T = { logInfo(funcDescription) val start = System.currentTimeMillis() diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/Expression.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/Expression.scala index 0ea99977..adbddd38 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/Expression.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/Expression.scala @@ -390,6 +390,19 @@ private[snowpark] case class UnresolvedAttribute(override val name: String) this } +private[snowpark] case class UnresolvedDFAliasAttribute(override val name: String) + extends Expression with NamedExpression { + override def sql: String = "" + + override def children: Seq[Expression] = Seq.empty + + // can't analyze + override lazy val dependentColumnNames: Option[Set[String]] = None + + override protected def createAnalyzedExpression(analyzedChildren: Seq[Expression]): Expression = + this +} + private[snowpark] case class ListAgg(col: Expression, delimiter: String, isDistinct: Boolean) extends Expression { override def children: Seq[Expression] = Seq(col) diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/ExpressionAnalyzer.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/ExpressionAnalyzer.scala index bd412766..76bed5ef 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/ExpressionAnalyzer.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/ExpressionAnalyzer.scala @@ -1,33 +1,39 @@ package com.snowflake.snowpark.internal.analyzer +import com.snowflake.snowpark.internal.ErrorMessage + import scala.collection.mutable.{Map => MMap} private[snowpark] object ExpressionAnalyzer { - def apply(aliasMap: Map[ExprId, String]): ExpressionAnalyzer = - new ExpressionAnalyzer(aliasMap) + def apply(aliasMap: Map[ExprId, String], + dfAliasMap: Map[String, Seq[Attribute]]): ExpressionAnalyzer = + new ExpressionAnalyzer(aliasMap, dfAliasMap) def apply(): ExpressionAnalyzer = - new ExpressionAnalyzer(Map.empty) + new ExpressionAnalyzer(Map.empty, Map.empty) // create new analyzer by combining two alias maps - def apply(map1: Map[ExprId, String], map2: Map[ExprId, String]): ExpressionAnalyzer = { + def apply(map1: Map[ExprId, String], map2: Map[ExprId, String], + dfAliasMap: Map[String, Seq[Attribute]]): ExpressionAnalyzer = { val common = map1.keySet & map2.keySet val result = (map1 ++ map2).filter { // remove common column, let (df1.join(df2)) // .join(df2.join(df3)).select(df2) report error case (id, _) => !common.contains(id) } - new ExpressionAnalyzer(result) + new ExpressionAnalyzer(result, dfAliasMap) } - def apply(maps: Seq[Map[ExprId, String]]): ExpressionAnalyzer = { + def apply(maps: Seq[Map[ExprId, String]], + dfAliasMap: Map[String, Seq[Attribute]]): ExpressionAnalyzer = { maps.foldLeft(ExpressionAnalyzer()) { - case (expAnalyzer, map) => ExpressionAnalyzer(expAnalyzer.getAliasMap, map) + case (expAnalyzer, map) => ExpressionAnalyzer(expAnalyzer.getAliasMap, map, dfAliasMap) } } } -private[snowpark] class ExpressionAnalyzer(aliasMap: Map[ExprId, String]) { +private[snowpark] class ExpressionAnalyzer(aliasMap: Map[ExprId, String], + dfAliasMap: Map[String, Seq[Attribute]]) { private val generatedAliasMap: MMap[ExprId, String] = MMap.empty def analyze(ex: Expression): Expression = ex match { @@ -52,6 +58,25 @@ private[snowpark] class ExpressionAnalyzer(aliasMap: Map[ExprId, String]) { // removed useless alias case Alias(child: NamedExpression, name, _) if quoteName(child.name) == quoteName(name) => child + case UnresolvedDFAliasAttribute(name) => + val colNameSplit = name.split("\\.", 2) + if (colNameSplit.length > 1 && dfAliasMap.contains(colNameSplit(0))) { + val aliasOutput = dfAliasMap(colNameSplit(0)) + val aliasColName = colNameSplit(1) + val normalizedColName = quoteName(aliasColName) + val col = aliasOutput.filter(attr => attr.name.equals(normalizedColName)) + if (col.length == 1) { + col.head.withName(normalizedColName) + } else { + throw ErrorMessage.DF_CANNOT_RESOLVE_COLUMN_NAME(aliasColName, aliasOutput.map(_.name)) + } + } else { + // if didn't find alias in the map + name match { + case "*" => Star(Seq.empty) + case _ => UnresolvedAttribute(quoteName(name)) + } + } case _ => ex } diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/MultiChildrenNode.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/MultiChildrenNode.scala index 7239e1dd..b5594184 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/MultiChildrenNode.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/MultiChildrenNode.scala @@ -1,5 +1,7 @@ package com.snowflake.snowpark.internal.analyzer +import com.snowflake.snowpark.internal.Utils + private[snowpark] trait MultiChildrenNode extends LogicalPlan { override def updateChildren(func: LogicalPlan => LogicalPlan): LogicalPlan = { val newChildren: Seq[LogicalPlan] = children.map(func) @@ -11,13 +13,19 @@ private[snowpark] trait MultiChildrenNode extends LogicalPlan { protected def updateChildren(newChildren: Seq[LogicalPlan]): MultiChildrenNode + + override lazy val dfAliasMap: Map[String, Seq[Attribute]] = + children.foldLeft(Map.empty[String, Seq[Attribute]]) { + case (map, child) => Utils.addToDataframeAliasMap(map, child) + } + override protected def analyze: LogicalPlan = createFromAnalyzedChildren(children.map(_.analyzed)) protected def createFromAnalyzedChildren: Seq[LogicalPlan] => MultiChildrenNode override protected def analyzer: ExpressionAnalyzer = - ExpressionAnalyzer(children.map(_.aliasMap)) + ExpressionAnalyzer(children.map(_.aliasMap), dfAliasMap) lazy override val internalRenamedColumns: Map[String, String] = children.map(_.internalRenamedColumns).reduce(_ ++ _) diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakeCreateTable.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakeCreateTable.scala index 8b30eb9c..ff44f927 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakeCreateTable.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakeCreateTable.scala @@ -10,7 +10,7 @@ case class SnowflakeCreateTable(tableName: String, mode: SaveMode, query: Option SnowflakeCreateTable(tableName, mode, query.map(_.analyzed)) override protected val analyzer: ExpressionAnalyzer = - ExpressionAnalyzer(query.map(_.aliasMap).getOrElse(Map.empty)) + ExpressionAnalyzer(query.map(_.aliasMap).getOrElse(Map.empty), dfAliasMap) override def updateChildren(func: LogicalPlan => LogicalPlan): LogicalPlan = { val newQuery = query.map(func) diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala index f90d1dd3..a3218758 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala @@ -106,7 +106,7 @@ class SnowflakePlan( sourcePlan.map(_.analyzed).getOrElse(this) override protected def analyzer: ExpressionAnalyzer = - ExpressionAnalyzer(sourcePlan.map(_.aliasMap).getOrElse(Map.empty)) + ExpressionAnalyzer(sourcePlan.map(_.aliasMap).getOrElse(Map.empty), dfAliasMap) override def getSnowflakePlan: Option[SnowflakePlan] = Some(this) diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala index 12b54e41..3cebd228 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlanNode.scala @@ -1,6 +1,6 @@ package com.snowflake.snowpark.internal.analyzer -import com.snowflake.snowpark.internal.ErrorMessage +import com.snowflake.snowpark.internal.{ErrorMessage, Utils} import com.snowflake.snowpark.Row private[snowpark] trait LogicalPlan { @@ -18,6 +18,8 @@ private[snowpark] trait LogicalPlan { (analyzedPlan, analyzer.getAliasMap) } + lazy val dfAliasMap: Map[String, Seq[Attribute]] = Map.empty + protected def analyze: LogicalPlan protected def analyzer: ExpressionAnalyzer @@ -138,8 +140,9 @@ private[snowpark] trait UnaryNode extends LogicalPlan { lazy protected val analyzedChild: LogicalPlan = child.analyzed // create expression analyzer from child's alias map lazy override protected val analyzer: ExpressionAnalyzer = - ExpressionAnalyzer(child.aliasMap) + ExpressionAnalyzer(child.aliasMap, dfAliasMap) + override lazy val dfAliasMap: Map[String, Seq[Attribute]] = child.dfAliasMap override protected def analyze: LogicalPlan = createFromAnalyzedChild(analyzedChild) protected def createFromAnalyzedChild: LogicalPlan => LogicalPlan @@ -192,6 +195,18 @@ private[snowpark] case class Sort(order: Seq[SortOrder], child: LogicalPlan) ext Sort(order, _) } +private[snowpark] case class DataframeAlias(alias: String, child: LogicalPlan) + extends UnaryNode { + + override lazy val dfAliasMap: Map[String, Seq[Attribute]] = + Utils.addToDataframeAliasMap(Map(alias -> child.getSnowflakePlan.get.output), child) + override protected def createFromAnalyzedChild: LogicalPlan => LogicalPlan = + DataframeAlias(alias, _) + + override protected def updateChild: LogicalPlan => LogicalPlan = + createFromAnalyzedChild +} + private[snowpark] case class Aggregate( groupingExpressions: Seq[Expression], aggregateExpressions: Seq[NamedExpression], diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SqlGenerator.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SqlGenerator.scala index a7a5f655..7c2add81 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SqlGenerator.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SqlGenerator.scala @@ -157,6 +157,7 @@ private object SqlGenerator extends Logging { .transformations(transformations) .options(options) .createSnowflakePlan() + case DataframeAlias(_, child) => resolveChild(child) } } diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/TableDelete.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/TableDelete.scala index f50900d8..9c4922dd 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/TableDelete.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/TableDelete.scala @@ -14,7 +14,7 @@ case class TableDelete( TableDelete(tableName, condition.map(_.analyze(analyzer.analyze)), sourceData.map(_.analyzed)) override protected def analyzer: ExpressionAnalyzer = - ExpressionAnalyzer(sourceData.map(_.aliasMap).getOrElse(Map.empty)) + ExpressionAnalyzer(sourceData.map(_.aliasMap).getOrElse(Map.empty), dfAliasMap) override def updateChildren(func: LogicalPlan => LogicalPlan): LogicalPlan = { val newSource = sourceData.map(func) diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/TableUpdate.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/TableUpdate.scala index 07faa247..ddee926c 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/TableUpdate.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/TableUpdate.scala @@ -17,7 +17,7 @@ case class TableUpdate( }, condition.map(_.analyze(analyzer.analyze)), sourceData.map(_.analyzed)) override protected def analyzer: ExpressionAnalyzer = - ExpressionAnalyzer(sourceData.map(_.aliasMap).getOrElse(Map.empty)) + ExpressionAnalyzer(sourceData.map(_.aliasMap).getOrElse(Map.empty), dfAliasMap) override def updateChildren(func: LogicalPlan => LogicalPlan): LogicalPlan = { val newSource = sourceData.map(func) diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/binaryPlanNodes.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/binaryPlanNodes.scala index e0b9b4f4..98ff26d5 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/binaryPlanNodes.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/binaryPlanNodes.scala @@ -1,8 +1,7 @@ package com.snowflake.snowpark.internal.analyzer import java.util.Locale - -import com.snowflake.snowpark.internal.ErrorMessage +import com.snowflake.snowpark.internal.{ErrorMessage, Utils} private[snowpark] abstract class BinaryNode extends LogicalPlan { def left: LogicalPlan @@ -14,7 +13,11 @@ private[snowpark] abstract class BinaryNode extends LogicalPlan { lazy protected val analyzedRight: LogicalPlan = right.analyzed lazy override protected val analyzer: ExpressionAnalyzer = - ExpressionAnalyzer(left.aliasMap, right.aliasMap) + ExpressionAnalyzer(left.aliasMap, right.aliasMap, dfAliasMap) + + + override lazy val dfAliasMap: Map[String, Seq[Attribute]] = + Utils.addToDataframeAliasMap(Utils.addToDataframeAliasMap(Map.empty, left), right) override def analyze: LogicalPlan = createFromAnalyzedChildren(analyzedLeft, analyzedRight) diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/unaryExpressions.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/unaryExpressions.scala index 572ea890..5db1bfef 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/unaryExpressions.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/unaryExpressions.scala @@ -75,6 +75,16 @@ private[snowpark] case class Alias(child: Expression, name: String, isInternal: override protected val createAnalyzedUnary: Expression => Expression = Alias(_, name) } +private[snowpark] case class DfAlias(child: Expression, name: String) + extends UnaryExpression + with NamedExpression { + override def sqlOperator: String = "" + override def operatorFirst: Boolean = false + override def toString: String = "" + + override protected val createAnalyzedUnary: Expression => Expression = DfAlias(_, name) +} + private[snowpark] case class UnresolvedAlias( child: Expression, aliasFunc: Option[Expression => String] = None) diff --git a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala index bfeddd72..2a3df7f7 100644 --- a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala @@ -304,6 +304,14 @@ class ErrorMessageSuite extends FunSuite { "At most one table function can be called inside select() function")) } + test("DF_ALIAS_DUPLICATES") { + val ex = ErrorMessage.DF_ALIAS_DUPLICATES(Set("a", "b")) + assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0132"))) + assert( + ex.message.startsWith("Error Code: 0132, Error message: " + + "Duplicated dataframe alias defined: a, b")) + } + test("UDF_INCORRECT_ARGS_NUMBER") { val ex = ErrorMessage.UDF_INCORRECT_ARGS_NUMBER(1, 2) assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0200"))) diff --git a/src/test/scala/com/snowflake/snowpark_test/DataFrameAliasSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataFrameAliasSuite.scala new file mode 100644 index 00000000..5deca2c9 --- /dev/null +++ b/src/test/scala/com/snowflake/snowpark_test/DataFrameAliasSuite.scala @@ -0,0 +1,96 @@ +package com.snowflake.snowpark_test + +import com.snowflake.snowpark._ +import com.snowflake.snowpark.functions._ +import com.snowflake.snowpark.internal.analyzer._ +import com.snowflake.snowpark.types._ +import net.snowflake.client.jdbc.SnowflakeSQLException +import org.scalatest.BeforeAndAfterEach +import java.sql.{Date, Time, Timestamp} +import scala.util.Random + +class DataFrameAliasSuite extends TestData with BeforeAndAfterEach with EagerSession { + val tableName1: String = randomName() + val tableName2: String = randomName() + import session.implicits._ + + override def afterEach(): Unit = { + dropTable(tableName1) + dropTable(tableName2) + super.afterEach() + } + + test("Test for alias with df.col, col and $") { + createTable(tableName1, "num int") + runQuery(s"insert into $tableName1 values(1),(2),(3)", session) + val df = session.table(tableName1).alias("A") + checkAnswer(df.select(df.col("A.num")), Seq(Row(1), Row(2), Row(3))) + checkAnswer(df.select(col("A.num")), Seq(Row(1), Row(2), Row(3))) + checkAnswer(df.select($"A.num"), Seq(Row(1), Row(2), Row(3))) + + val df1 = df.alias("B") + checkAnswer(df1.select(df1.col("A.num")), Seq(Row(1), Row(2), Row(3))) + checkAnswer(df1.select(col("A.num")), Seq(Row(1), Row(2), Row(3))) + checkAnswer(df1.select($"A.num"), Seq(Row(1), Row(2), Row(3))) + + checkAnswer(df1.select(df1.col("B.num")), Seq(Row(1), Row(2), Row(3))) + checkAnswer(df1.select(col("B.num")), Seq(Row(1), Row(2), Row(3))) + checkAnswer(df1.select($"B.num"), Seq(Row(1), Row(2), Row(3))) + } + + test("Test for alias with dot in column name") { + createTable(tableName1, "\"num.col\" int") + runQuery(s"insert into $tableName1 values(1),(2),(3)", session) + val df = session.table(tableName1).alias("A") + checkAnswer(df.select(df.col("A.num.col")), Seq(Row(1), Row(2), Row(3))) + checkAnswer(df.select(col("A.num.col")), Seq(Row(1), Row(2), Row(3))) + checkAnswer(df.select($"A.num.col"), Seq(Row(1), Row(2), Row(3))) + } + + test("Test for alias with join") { + createTable(tableName1, "id1 int, num1 int") + createTable(tableName2, "id2 int, num2 int") + runQuery(s"insert into $tableName1 values(1, 4),(2, 5),(3, 6)", session) + runQuery(s"insert into $tableName2 values(1, 7),(2, 8),(3, 9)", session) + val df1 = session.table(tableName1).alias("A") + val df2 = session.table(tableName2).alias("B") + checkAnswer(df1.join(df2, $"id1" === $"id2") + .select(df1.col("A.num1")), Seq(Row(4), Row(5), Row(6))) + checkAnswer(df1.join(df2, $"id1" === $"id2") + .select(df2.col("B.num2")), Seq(Row(7), Row(8), Row(9))) + + checkAnswer(df1.join(df2, $"id1" === $"id2") + .select($"A.num1"), Seq(Row(4), Row(5), Row(6))) + checkAnswer(df1.join(df2, $"id1" === $"id2") + .select($"B.num2"), Seq(Row(7), Row(8), Row(9))) + } + + test("Test for alias with join with column renaming") { + createTable(tableName1, "id int, num int") + createTable(tableName2, "id int, num int") + runQuery(s"insert into $tableName1 values(1, 4),(2, 5),(3, 6)", session) + runQuery(s"insert into $tableName2 values(1, 7),(2, 8),(3, 9)", session) + val df1 = session.table(tableName1).alias("A") + val df2 = session.table(tableName2).alias("B") + checkAnswer(df1.join(df2, df1.col("id") === df2.col("id")) + .select(df1.col("A.num")), Seq(Row(4), Row(5), Row(6))) + checkAnswer(df1.join(df2, df1.col("id") === df2.col("id")) + .select(df2.col("B.num")), Seq(Row(7), Row(8), Row(9))) + + // The following use case is out of the scope of supporting alias + // We still follow the old ambiguity resolving policy and require DF to be used + assertThrows[SnowparkClientException]( + df1.join(df2, df1.col("id") === df2.col("id")) + .select($"A.num")) + } + + test("Test for alias conflict") { + createTable(tableName1, "id int, num int") + createTable(tableName2, "id int, num int") + val df1 = session.table(tableName1).alias("A") + val df2 = session.table(tableName2).alias("A") + assertThrows[SnowparkClientException]( + df1.join(df2, df1.col("id") === df2.col("id")) + .select(df1.col("A.num"))) + } +}