Skip to content

Commit

Permalink
add alias
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-zli committed Jan 2, 2024
1 parent 71f54ae commit 76c26ed
Show file tree
Hide file tree
Showing 15 changed files with 145 additions and 22 deletions.
1 change: 1 addition & 0 deletions src/main/scala/com/snowflake/snowpark/Column.scala
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,7 @@ private[snowpark] object Column {
/**
 * Builds a [[Column]] from a column name.
 *
 * "*" becomes a star expression; a name containing "." is treated as a
 * dataframe-alias-qualified reference (resolved later by the analyzer);
 * anything else becomes an unresolved attribute with a quoted name.
 */
def apply(name: String): Column = {
  val expr =
    if (name == "*") Star(Seq.empty)
    else if (name.contains(".")) DfAliasAttribute(name)
    else UnresolvedAttribute(quoteName(name))
  new Column(expr)
}

Expand Down
36 changes: 32 additions & 4 deletions src/main/scala/com/snowflake/snowpark/DataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,23 @@ class DataFrame private[snowpark] (
case _ => Column(resolve(colName))
}

/**
 * Returns the current DataFrame aliased as the input alias name.
 *
 * After aliasing, columns of this DataFrame can be referenced with the
 * "alias.columnName" syntax, as shown in the example below.
 *
 * For example:
 *
 * {{{
 * val df2 = df.alias("A")
 * df2.select(df2.col("A.num"))
 * }}}
 *
 * @group basic
 * @since 1.10.0
 * @param alias The alias name of the dataframe
 * @return a [[DataFrame]]
 */
def alias(alias: String): DataFrame = withPlan(DataframeAlias(alias, plan))

/**
* Returns a new DataFrame with the specified Column expressions as output (similar to SELECT in
* SQL). Only the Columns specified as arguments will be present in the resulting DataFrame.
Expand Down Expand Up @@ -2791,7 +2808,8 @@ class DataFrame private[snowpark] (

// utils
private[snowpark] def resolve(colName: String): NamedExpression = {
val normalizedColName = quoteName(colName)
val (aliasColName, aliasOutput) = resolveAlias(colName, output)
val normalizedColName = quoteName(aliasColName)
def isDuplicatedName: Boolean = {
if (session.conn.hideInternalAlias) {
this.plan.internalRenamedColumns.values.exists(_ == normalizedColName)
Expand All @@ -2800,13 +2818,23 @@ class DataFrame private[snowpark] (
}
}
val col =
output.filter(attr => attr.name.equals(normalizedColName))
aliasOutput.filter(attr => attr.name.equals(normalizedColName))
if (col.length == 1) {
col.head.withName(normalizedColName).withSourceDF(this)
} else if (isDuplicatedName) {
throw ErrorMessage.PLAN_JDBC_REPORT_JOIN_AMBIGUOUS(colName, colName)
throw ErrorMessage.PLAN_JDBC_REPORT_JOIN_AMBIGUOUS(aliasColName, aliasColName)
} else {
throw ErrorMessage.DF_CANNOT_RESOLVE_COLUMN_NAME(aliasColName, aliasOutput.map(_.name))
}
}

// Handle dataframe alias by redirecting output and column name resolution.
// If colName has the form "<alias>.<column>" and the alias is registered in
// the plan's dfAliasMap, resolve against the aliased plan's output attributes;
// otherwise fall back to this DataFrame's own output with the name unchanged.
//
// BUG FIX: the previous else branch threw DF_CANNOT_RESOLVE_COLUMN_NAME before
// the `(colName, output)` result, making that tuple unreachable and causing
// every non-alias-qualified column lookup to fail. The throw is removed so a
// plain column name falls through to normal resolution in resolve().
private def resolveAlias(colName: String, output: Seq[Attribute]): (String, Seq[Attribute]) = {
  // Split into at most two parts: the potential alias prefix and the remainder.
  val colNameSplit = colName.split("\\.", 2)
  if (colNameSplit.length > 1 && plan.dfAliasMap.contains(colNameSplit(0))) {
    (colNameSplit(1), plan.dfAliasMap(colNameSplit(0)))
  } else {
    // Not an alias-qualified name: treat it as a normal column of this DataFrame.
    (colName, output)
  }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ private[snowpark] object ErrorMessage {
"0129" -> "DataFrameWriter doesn't support mode '%s' when writing to a %s.",
"0130" -> "Unsupported join operations, Dataframes can join with other Dataframes or TableFunctions only",
"0131" -> "At most one table function can be called inside select() function",
"0132" -> "Duplicated dataframe alias defined: %s",
// Begin to define UDF related messages
"0200" -> "Incorrect number of arguments passed to the UDF: Expected: %d, Found: %d",
"0201" -> "Attempted to call an unregistered UDF. You must register the UDF before calling it.",
Expand Down Expand Up @@ -252,6 +253,9 @@ private[snowpark] object ErrorMessage {
def DF_MORE_THAN_ONE_TF_IN_SELECT(): SnowparkClientException =
createException("0131")

// Builds the client exception for error code 0132: the same dataframe alias
// was defined more than once. The duplicated aliases are joined with ", ".
def DF_ALIAS_DUPLICATES(duplicatedAlias: scala.collection.Set[String]): SnowparkClientException = {
  val aliasList = duplicatedAlias.mkString(", ")
  createException("0132", aliasList)
}

/*
* 2NN: UDF error code
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,19 @@ private[snowpark] case class UnresolvedAttribute(override val name: String)
this
}

// Placeholder expression for a dataframe-alias-qualified column name
// (e.g. "A.num"). It is replaced during expression analysis, so it never
// renders to SQL directly.
private[snowpark] case class DfAliasAttribute(override val name: String)
  extends Expression with NamedExpression {
  // Leaf expression: nothing to traverse.
  override def children: Seq[Expression] = Seq.empty

  // Never emitted as SQL; resolution happens in the analyzer.
  override def sql: String = ""

  // can't analyze: dependent columns are unknown until the alias is resolved
  override lazy val dependentColumnNames: Option[Set[String]] = None

  override protected def createAnalyzedExpression(analyzedChildren: Seq[Expression]): Expression =
    this
}

private[snowpark] case class ListAgg(col: Expression, delimiter: String, isDistinct: Boolean)
extends Expression {
override def children: Seq[Expression] = Seq(col)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,33 +1,39 @@
package com.snowflake.snowpark.internal.analyzer

import com.snowflake.snowpark.internal.ErrorMessage

import scala.collection.mutable.{Map => MMap}

// Factory methods for ExpressionAnalyzer.
//
// FIX: this span was a diff rendering with the removed (pre-dfAliasMap)
// overloads interleaved with the new ones, leaving duplicate, conflicting
// `apply` definitions that cannot compile. Only the coherent new versions,
// which thread `dfAliasMap` through, are kept.
private[snowpark] object ExpressionAnalyzer {
  // Analyzer seeded with an existing column-alias map and the dataframe-alias map.
  def apply(aliasMap: Map[ExprId, String],
      dfAliasMap: MMap[String, Seq[Attribute]]): ExpressionAnalyzer =
    new ExpressionAnalyzer(aliasMap, dfAliasMap)

  // Analyzer with an empty column-alias map (leaf nodes).
  def apply(dfAliasMap: MMap[String, Seq[Attribute]]): ExpressionAnalyzer =
    new ExpressionAnalyzer(Map.empty, dfAliasMap)

  // create new analyzer by combining two alias maps
  def apply(map1: Map[ExprId, String], map2: Map[ExprId, String],
      dfAliasMap: MMap[String, Seq[Attribute]]): ExpressionAnalyzer = {
    val common = map1.keySet & map2.keySet
    val result = (map1 ++ map2).filter {
      // remove common column, let (df1.join(df2))
      // .join(df2.join(df3)).select(df2) report error
      case (id, _) => !common.contains(id)
    }
    new ExpressionAnalyzer(result, dfAliasMap)
  }

  // Fold a sequence of alias maps into one analyzer, merging pairwise.
  def apply(maps: Seq[Map[ExprId, String]],
      dfAliasMap: MMap[String, Seq[Attribute]]): ExpressionAnalyzer = {
    maps.foldLeft(ExpressionAnalyzer(dfAliasMap)) {
      case (expAnalyzer, map) => ExpressionAnalyzer(expAnalyzer.getAliasMap, map, dfAliasMap)
    }
  }
}

private[snowpark] class ExpressionAnalyzer(aliasMap: Map[ExprId, String]) {
private[snowpark] class ExpressionAnalyzer(aliasMap: Map[ExprId, String],
dfAliasMap: MMap[String, Seq[Attribute]]) {
private val generatedAliasMap: MMap[ExprId, String] = MMap.empty

def analyze(ex: Expression): Expression = ex match {
Expand All @@ -52,6 +58,25 @@ private[snowpark] class ExpressionAnalyzer(aliasMap: Map[ExprId, String]) {
// removed useless alias
case Alias(child: NamedExpression, name, _) if quoteName(child.name) == quoteName(name) =>
child
case DfAliasAttribute(name) =>
val colNameSplit = name.split("\\.", 2)
if (colNameSplit.length > 1 && dfAliasMap.contains(colNameSplit(0))) {
val aliasOutput = dfAliasMap(colNameSplit(0))
val aliasColName = colNameSplit(1)
val normalizedColName = quoteName(aliasColName)
val col = aliasOutput.filter(attr => attr.name.equals(normalizedColName))
if (col.length == 1) {
col.head.withName(normalizedColName)
} else {
throw ErrorMessage.DF_CANNOT_RESOLVE_COLUMN_NAME(aliasColName, aliasOutput.map(_.name))
}
} else {
// if didn't find alias in the map
name match {
case "*" => Star(Seq.empty)
case _ => UnresolvedAttribute(quoteName(name))
}
}
case _ => ex
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ private[snowpark] trait MultiChildrenNode extends LogicalPlan {

protected def updateChildren(newChildren: Seq[LogicalPlan]): MultiChildrenNode

children.foreach(child => addToDataframeAliasMap(child.dfAliasMap))
override protected def analyze: LogicalPlan =
createFromAnalyzedChildren(children.map(_.analyzed))

protected def createFromAnalyzedChildren: Seq[LogicalPlan] => MultiChildrenNode

override protected def analyzer: ExpressionAnalyzer =
ExpressionAnalyzer(children.map(_.aliasMap))
ExpressionAnalyzer(children.map(_.aliasMap), dfAliasMap)

lazy override val internalRenamedColumns: Map[String, String] =
children.map(_.internalRenamedColumns).reduce(_ ++ _)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ case class SnowflakeCreateTable(tableName: String, mode: SaveMode, query: Option
SnowflakeCreateTable(tableName, mode, query.map(_.analyzed))

override protected val analyzer: ExpressionAnalyzer =
ExpressionAnalyzer(query.map(_.aliasMap).getOrElse(Map.empty))
ExpressionAnalyzer(query.map(_.aliasMap).getOrElse(Map.empty), dfAliasMap)

override def updateChildren(func: LogicalPlan => LogicalPlan): LogicalPlan = {
val newQuery = query.map(func)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class SnowflakePlan(
sourcePlan.map(_.analyzed).getOrElse(this)

override protected def analyzer: ExpressionAnalyzer =
ExpressionAnalyzer(sourcePlan.map(_.aliasMap).getOrElse(Map.empty))
ExpressionAnalyzer(sourcePlan.map(_.aliasMap).getOrElse(Map.empty), dfAliasMap)

override def getSnowflakePlan: Option[SnowflakePlan] = Some(this)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.snowflake.snowpark.internal.analyzer

import com.snowflake.snowpark.internal.ErrorMessage
import com.snowflake.snowpark.Row
import scala.collection.mutable.{Map => MMap}

private[snowpark] trait LogicalPlan {
def children: Seq[LogicalPlan] = Seq.empty
Expand All @@ -18,6 +19,23 @@ private[snowpark] trait LogicalPlan {
(analyzedPlan, analyzer.getAliasMap)
}

var dfAliasMap: MMap[String, Seq[Attribute]] = MMap.empty

// map from df alias string to snowflakePlan.output
// - an entry is added to the map when a DataframeAlias node is created from a child
// - maps are merged when analyze is called on a leaf node, unary node, or multi-children node
// - a conflict is reported if the merge finds a collision
// - a DfAliasAttribute expression is created when the input column name contains "."
// - the expression analyzer, on seeing a DfAliasAttribute, splits the name and searches the map:
//   if the map does not contain the alias key, the name is treated as a normal column name;
//   otherwise the Attribute with that name is looked up in the alias's attribute list
// Merges another plan's dataframe-alias map into this one, failing fast when
// the same alias is defined in both (error code 0132).
protected def addToDataframeAliasMap(map: MMap[String, Seq[Attribute]]): Unit = {
  val collisions = map.keySet.intersect(dfAliasMap.keySet)
  if (collisions.nonEmpty) {
    throw ErrorMessage.DF_ALIAS_DUPLICATES(collisions)
  }
  dfAliasMap ++= map
}
protected def analyze: LogicalPlan
protected def analyzer: ExpressionAnalyzer

Expand Down Expand Up @@ -69,7 +87,7 @@ private[snowpark] trait LogicalPlan {

private[snowpark] trait LeafNode extends LogicalPlan {
// create ExpressionAnalyzer with empty alias map
override protected val analyzer: ExpressionAnalyzer = ExpressionAnalyzer()
override protected val analyzer: ExpressionAnalyzer = ExpressionAnalyzer(dfAliasMap)

// leaf node doesn't have child
override def updateChildren(func: LogicalPlan => LogicalPlan): LogicalPlan = this
Expand Down Expand Up @@ -138,8 +156,9 @@ private[snowpark] trait UnaryNode extends LogicalPlan {
lazy protected val analyzedChild: LogicalPlan = child.analyzed
// create expression analyzer from child's alias map
lazy override protected val analyzer: ExpressionAnalyzer =
ExpressionAnalyzer(child.aliasMap)
ExpressionAnalyzer(child.aliasMap, dfAliasMap)

addToDataframeAliasMap(child.dfAliasMap)
override protected def analyze: LogicalPlan = createFromAnalyzedChild(analyzedChild)

protected def createFromAnalyzedChild: LogicalPlan => LogicalPlan
Expand Down Expand Up @@ -192,6 +211,17 @@ private[snowpark] case class Sort(order: Seq[SortOrder], child: LogicalPlan) ext
Sort(order, _)
}

// Logical plan node recording that `child` is aliased as `alias`.
// On construction it registers alias -> child's output attributes in
// dfAliasMap so that "alias.column" references can be resolved later.
private[snowpark] case class DataframeAlias(alias: String, child: LogicalPlan)
extends UnaryNode {
// NOTE(review): `child.getSnowflakePlan.get` assumes every child exposes a
// SnowflakePlan; if it returns None this throws NoSuchElementException at
// construction time — confirm this invariant holds for all child plans.
dfAliasMap += (alias -> child.getSnowflakePlan.get.output)
// Re-wrap the analyzed child under the same alias.
override protected def createFromAnalyzedChild: LogicalPlan => LogicalPlan = child => {
DataframeAlias(alias, child)
}

override protected def updateChild: LogicalPlan => LogicalPlan =
createFromAnalyzedChild
}

private[snowpark] case class Aggregate(
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ private object SqlGenerator extends Logging {
.transformations(transformations)
.options(options)
.createSnowflakePlan()
case DataframeAlias(_, child) => resolveChild(child)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ case class TableDelete(
TableDelete(tableName, condition.map(_.analyze(analyzer.analyze)), sourceData.map(_.analyzed))

override protected def analyzer: ExpressionAnalyzer =
ExpressionAnalyzer(sourceData.map(_.aliasMap).getOrElse(Map.empty))
ExpressionAnalyzer(sourceData.map(_.aliasMap).getOrElse(Map.empty), dfAliasMap)

override def updateChildren(func: LogicalPlan => LogicalPlan): LogicalPlan = {
val newSource = sourceData.map(func)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ case class TableUpdate(
}, condition.map(_.analyze(analyzer.analyze)), sourceData.map(_.analyzed))

override protected def analyzer: ExpressionAnalyzer =
ExpressionAnalyzer(sourceData.map(_.aliasMap).getOrElse(Map.empty))
ExpressionAnalyzer(sourceData.map(_.aliasMap).getOrElse(Map.empty), dfAliasMap)

override def updateChildren(func: LogicalPlan => LogicalPlan): LogicalPlan = {
val newSource = sourceData.map(func)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ private[snowpark] abstract class BinaryNode extends LogicalPlan {
lazy protected val analyzedRight: LogicalPlan = right.analyzed

lazy override protected val analyzer: ExpressionAnalyzer =
ExpressionAnalyzer(left.aliasMap, right.aliasMap)
ExpressionAnalyzer(left.aliasMap, right.aliasMap, dfAliasMap)

addToDataframeAliasMap(left.dfAliasMap)
addToDataframeAliasMap(right.dfAliasMap)
override def analyze: LogicalPlan =
createFromAnalyzedChildren(analyzedLeft, analyzedRight)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,16 @@ private[snowpark] case class Alias(child: Expression, name: String, isInternal:
override protected val createAnalyzedUnary: Expression => Expression = Alias(_, name)
}

// Named expression attaching a dataframe-alias `name` to `child`.
// It renders to no SQL of its own: sqlOperator and toString are empty,
// so the alias is carried through analysis without affecting generated text.
private[snowpark] case class DfAlias(child: Expression, name: String)
extends UnaryExpression
with NamedExpression {
override def sqlOperator: String = ""
override def operatorFirst: Boolean = false
override def toString: String = ""

// Re-attach the same alias name after the child is analyzed.
override protected val createAnalyzedUnary: Expression => Expression = DfAlias(_, name)
}

private[snowpark] case class UnresolvedAlias(
child: Expression,
aliasFunc: Option[Expression => String] = None)
Expand Down
8 changes: 8 additions & 0 deletions src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,14 @@ class ErrorMessageSuite extends FunSuite {
"At most one table function can be called inside select() function"))
}

// Verifies error code 0132: telemetry message matches the registry entry and
// the user-facing message lists the duplicated aliases.
test("DF_ALIAS_DUPLICATES") {
  val exception = ErrorMessage.DF_ALIAS_DUPLICATES(Set("a", "b"))
  assert(exception.telemetryMessage.equals(ErrorMessage.getMessage("0132")))
  val expectedPrefix = "Error Code: 0132, Error message: " +
    "Duplicated dataframe alias defined: a, b"
  assert(exception.message.startsWith(expectedPrefix))
}

test("UDF_INCORRECT_ARGS_NUMBER") {
val ex = ErrorMessage.UDF_INCORRECT_ARGS_NUMBER(1, 2)
assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0200")))
Expand Down

0 comments on commit 76c26ed

Please sign in to comment.