Skip to content

Commit

Permalink
[SPARK-34003][SQL][FOLLOWUP] Avoid pushing modified Char/Varchar sort…
Browse files Browse the repository at this point in the history
… attributes into aggregate for existing ones

### What changes were proposed in this pull request?

In apache@0f8e5dd, we partially fix the rule conflicts between `PaddingAndLengthCheckForCharVarchar` and `ResolveAggregateFunctions`, as error still exists in

sql like ```SELECT substr(v, 1, 2), sum(i) FROM t GROUP BY v ORDER BY substr(v, 1, 2)```

```sql
[info]   Failed to analyze query: org.apache.spark.sql.AnalysisException: expression 'spark_catalog.default.t.`v`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.;
[info]   Project [substr(v, 1, 2)Kyligence#100, sum(i)#101L]
[info]   +- Sort [aggOrder#102 ASC NULLS FIRST], true
[info]      +- !Aggregate [v#106], [substr(v#106, 1, 2) AS substr(v, 1, 2)Kyligence#100, sum(cast(i#98 as bigint)) AS sum(i)#101L, substr(v#103, 1, 2) AS aggOrder#102
[info]         +- SubqueryAlias spark_catalog.default.t
[info]            +- Project [if ((length(v#97) <= 3)) v#97 else if ((length(rtrim(v#97, None)) > 3)) cast(raise_error(concat(input string of length , cast(length(v#97) as string),  exceeds varchar type length limitation: 3)) as string) else rpad(rtrim(v#97, None), 3,  ) AS v#106, i#98]
[info]               +- Relation[v#97,i#98] parquet
[info]
[info]   Project [substr(v, 1, 2)Kyligence#100, sum(i)#101L]
[info]   +- Sort [aggOrder#102 ASC NULLS FIRST], true
[info]      +- !Aggregate [v#106], [substr(v#106, 1, 2) AS substr(v, 1, 2)Kyligence#100, sum(cast(i#98 as bigint)) AS sum(i)#101L, substr(v#103, 1, 2) AS aggOrder#102
[info]         +- SubqueryAlias spark_catalog.default.t
[info]            +- Project [if ((length(v#97) <= 3)) v#97 else if ((length(rtrim(v#97, None)) > 3)) cast(raise_error(concat(input string of length , cast(length(v#97) as string),  exceeds varchar type length limitation: 3)) as string) else rpad(rtrim(v#97, None), 3,  ) AS v#106, i#98]
[info]               +- Relation[v#97,i#98] parquet

```
We need to look recursively into children to find char/varchars.

In this PR,  we try to resolve the full attributes including the original `Aggregate` expressions and the candidates in `SortOrder` together, then use the new re-resolved `Aggregate` expressions to determine which candidate in the `SortOrder` shall be pushed. This can avoid mismatch for the same attributes w/o this change, as the expressions returned by `executeSameContext` will change when `PaddingAndLengthCheckForCharVarchar` takes effects. W/ this change, the expressions can be matched correctly.

For those unmatched, w need to look recursively into children to find char/varchars instead of the expression itself only.

### Why are the changes needed?

bugfix

### Does this PR introduce _any_ user-facing change?

no
### How was this patch tested?

add new tests

Closes apache#31129 from yaooqinn/SPARK-34003-F.

Authored-by: Kent Yao <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
yaooqinn authored and cloud-fan committed Jan 12, 2021
1 parent 02a17e9 commit 99f8489
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2381,13 +2381,16 @@ class Analyzer(override val catalogManager: CatalogManager)
val unresolvedSortOrders = sortOrder.filter { s =>
!s.resolved || !s.references.subsetOf(aggregate.outputSet) || containsAggregate(s)
}
val aliasedOrdering =
unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")())
val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")())

val aggregateWithExtraOrdering = aggregate.copy(
aggregateExpressions = aggregate.aggregateExpressions ++ aliasedOrdering)

val resolvedAggregate: Aggregate =
executeSameContext(aggregatedOrdering).asInstanceOf[Aggregate]
val resolvedAliasedOrdering: Seq[Alias] =
resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]]
executeSameContext(aggregateWithExtraOrdering).asInstanceOf[Aggregate]

val (reResolvedAggExprs, resolvedAliasedOrdering) =
resolvedAggregate.aggregateExpressions.splitAt(aggregate.aggregateExpressions.length)

// If we pass the analysis check, then the ordering expressions should only reference to
// aggregate expressions or grouping expressions, and it's safe to push them down to
Expand All @@ -2401,24 +2404,25 @@ class Analyzer(override val catalogManager: CatalogManager)
// expression instead.
val needsPushDown = ArrayBuffer.empty[NamedExpression]
val orderToAlias = unresolvedSortOrders.zip(aliasedOrdering)
val evaluatedOrderings = resolvedAliasedOrdering.zip(orderToAlias).map {
case (evaluated, (order, aliasOrder)) =>
val index = originalAggExprs.indexWhere {
case Alias(child, _) => child semanticEquals evaluated.child
case other => other semanticEquals evaluated.child
}
val evaluatedOrderings =
resolvedAliasedOrdering.asInstanceOf[Seq[Alias]].zip(orderToAlias).map {
case (evaluated, (order, aliasOrder)) =>
val index = reResolvedAggExprs.indexWhere {
case Alias(child, _) => child semanticEquals evaluated.child
case other => other semanticEquals evaluated.child
}

if (index == -1) {
if (CharVarcharUtils.getRawType(evaluated.metadata).nonEmpty) {
needsPushDown += aliasOrder
order.copy(child = aliasOrder)
if (index == -1) {
if (hasCharVarchar(evaluated)) {
needsPushDown += aliasOrder
order.copy(child = aliasOrder)
} else {
needsPushDown += evaluated
order.copy(child = evaluated.toAttribute)
}
} else {
needsPushDown += evaluated
order.copy(child = evaluated.toAttribute)
order.copy(child = originalAggExprs(index).toAttribute)
}
} else {
order.copy(child = originalAggExprs(index).toAttribute)
}
}

val sortOrdersMap = unresolvedSortOrders
Expand All @@ -2443,6 +2447,13 @@ class Analyzer(override val catalogManager: CatalogManager)
}
}

def hasCharVarchar(expr: Alias): Boolean = {
expr.find {
case ne: NamedExpression => CharVarcharUtils.getRawType(ne.metadata).nonEmpty
case _ => false
}.nonEmpty
}

def containsAggregate(condition: Expression): Boolean = {
condition.find(_.isInstanceOf[AggregateExpression]).isDefined
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ Input [8]: [ws_sold_date_sk#1, ws_item_sk#2, ws_web_page_sk#3, ws_order_number#4

(10) Exchange
Input [6]: [ws_sold_date_sk#1, ws_item_sk#2, ws_order_number#4, ws_quantity#5, ws_sales_price#6, ws_net_profit#7]
Arguments: hashpartitioning(cast(ws_item_sk#2 as bigint), cast(ws_order_number#4 as bigint), 5), true, [id=#10]
Arguments: hashpartitioning(cast(ws_item_sk#2 as bigint), cast(ws_order_number#4 as bigint), 5), ENSURE_REQUIREMENTS, [id=#10]

(11) Sort [codegen id : 3]
Input [6]: [ws_sold_date_sk#1, ws_item_sk#2, ws_order_number#4, ws_quantity#5, ws_sales_price#6, ws_net_profit#7]
Expand All @@ -123,7 +123,7 @@ Condition : (((((isnotnull(wr_item_sk#11) AND isnotnull(wr_order_number#16)) AND

(15) Exchange
Input [8]: [wr_item_sk#11, wr_refunded_cdemo_sk#12, wr_refunded_addr_sk#13, wr_returning_cdemo_sk#14, wr_reason_sk#15, wr_order_number#16, wr_fee#17, wr_refunded_cash#18]
Arguments: hashpartitioning(wr_item_sk#11, wr_order_number#16, 5), true, [id=#19]
Arguments: hashpartitioning(wr_item_sk#11, wr_order_number#16, 5), ENSURE_REQUIREMENTS, [id=#19]

(16) Sort [codegen id : 5]
Input [8]: [wr_item_sk#11, wr_refunded_cdemo_sk#12, wr_refunded_addr_sk#13, wr_returning_cdemo_sk#14, wr_reason_sk#15, wr_order_number#16, wr_fee#17, wr_refunded_cash#18]
Expand Down Expand Up @@ -229,7 +229,7 @@ Input [11]: [ws_quantity#5, ws_sales_price#6, ws_net_profit#7, wr_refunded_cdemo

(39) Exchange
Input [7]: [ws_quantity#5, ws_sales_price#6, wr_refunded_cdemo_sk#12, wr_returning_cdemo_sk#14, wr_fee#17, wr_refunded_cash#18, r_reason_desc#24]
Arguments: hashpartitioning(wr_refunded_cdemo_sk#12, wr_returning_cdemo_sk#14, 5), true, [id=#30]
Arguments: hashpartitioning(wr_refunded_cdemo_sk#12, wr_returning_cdemo_sk#14, 5), ENSURE_REQUIREMENTS, [id=#30]

(40) Sort [codegen id : 10]
Input [7]: [ws_quantity#5, ws_sales_price#6, wr_refunded_cdemo_sk#12, wr_returning_cdemo_sk#14, wr_fee#17, wr_refunded_cash#18, r_reason_desc#24]
Expand Down Expand Up @@ -278,7 +278,7 @@ Input [6]: [cd_demo_sk#31, cd_marital_status#32, cd_education_status#33, cd_demo

(50) Exchange
Input [4]: [cd_demo_sk#31, cd_marital_status#32, cd_education_status#33, cd_demo_sk#35]
Arguments: hashpartitioning(cast(cd_demo_sk#31 as bigint), cast(cd_demo_sk#35 as bigint), 5), true, [id=#38]
Arguments: hashpartitioning(cast(cd_demo_sk#31 as bigint), cast(cd_demo_sk#35 as bigint), 5), ENSURE_REQUIREMENTS, [id=#38]

(51) Sort [codegen id : 13]
Input [4]: [cd_demo_sk#31, cd_marital_status#32, cd_education_status#33, cd_demo_sk#35]
Expand All @@ -302,16 +302,16 @@ Results [7]: [r_reason_desc#24, sum#45, count#46, sum#47, count#48, sum#49, coun

(55) Exchange
Input [7]: [r_reason_desc#24, sum#45, count#46, sum#47, count#48, sum#49, count#50]
Arguments: hashpartitioning(r_reason_desc#24, 5), true, [id=#51]
Arguments: hashpartitioning(r_reason_desc#24, 5), ENSURE_REQUIREMENTS, [id=#51]

(56) HashAggregate [codegen id : 15]
Input [7]: [r_reason_desc#24, sum#45, count#46, sum#47, count#48, sum#49, count#50]
Keys [1]: [r_reason_desc#24]
Functions [3]: [avg(cast(ws_quantity#5 as bigint)), avg(UnscaledValue(wr_refunded_cash#18)), avg(UnscaledValue(wr_fee#17))]
Aggregate Attributes [3]: [avg(cast(ws_quantity#5 as bigint))#52, avg(UnscaledValue(wr_refunded_cash#18))#53, avg(UnscaledValue(wr_fee#17))#54]
Results [5]: [substr(r_reason_desc#24, 1, 20) AS substr(r_reason_desc, 1, 20)#55, avg(cast(ws_quantity#5 as bigint))#52 AS avg(ws_quantity)#56, cast((avg(UnscaledValue(wr_refunded_cash#18))#53 / 100.0) as decimal(11,6)) AS avg(wr_refunded_cash)#57, cast((avg(UnscaledValue(wr_fee#17))#54 / 100.0) as decimal(11,6)) AS avg(wr_fee)#58, avg(cast(ws_quantity#5 as bigint))#52 AS aggOrder#59]
Results [4]: [substr(r_reason_desc#24, 1, 20) AS substr(r_reason_desc, 1, 20)#55, avg(cast(ws_quantity#5 as bigint))#52 AS avg(ws_quantity)#56, cast((avg(UnscaledValue(wr_refunded_cash#18))#53 / 100.0) as decimal(11,6)) AS avg(wr_refunded_cash)#57, cast((avg(UnscaledValue(wr_fee#17))#54 / 100.0) as decimal(11,6)) AS avg(wr_fee)#58]

(57) TakeOrderedAndProject
Input [5]: [substr(r_reason_desc, 1, 20)#55, avg(ws_quantity)#56, avg(wr_refunded_cash)#57, avg(wr_fee)#58, aggOrder#59]
Arguments: 100, [substr(r_reason_desc, 1, 20)#55 ASC NULLS FIRST, aggOrder#59 ASC NULLS FIRST, avg(wr_refunded_cash)#57 ASC NULLS FIRST, avg(wr_fee)#58 ASC NULLS FIRST], [substr(r_reason_desc, 1, 20)#55, avg(ws_quantity)#56, avg(wr_refunded_cash)#57, avg(wr_fee)#58]
Input [4]: [substr(r_reason_desc, 1, 20)#55, avg(ws_quantity)#56, avg(wr_refunded_cash)#57, avg(wr_fee)#58]
Arguments: 100, [substr(r_reason_desc, 1, 20)#55 ASC NULLS FIRST, avg(ws_quantity)#56 ASC NULLS FIRST, avg(wr_refunded_cash)#57 ASC NULLS FIRST, avg(wr_fee)#58 ASC NULLS FIRST], [substr(r_reason_desc, 1, 20)#55, avg(ws_quantity)#56, avg(wr_refunded_cash)#57, avg(wr_fee)#58]

Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
TakeOrderedAndProject [substr(r_reason_desc, 1, 20),aggOrder,avg(wr_refunded_cash),avg(wr_fee),avg(ws_quantity)]
TakeOrderedAndProject [substr(r_reason_desc, 1, 20),avg(ws_quantity),avg(wr_refunded_cash),avg(wr_fee)]
WholeStageCodegen (15)
HashAggregate [r_reason_desc,sum,count,sum,count,sum,count] [avg(cast(ws_quantity as bigint)),avg(UnscaledValue(wr_refunded_cash)),avg(UnscaledValue(wr_fee)),substr(r_reason_desc, 1, 20),avg(ws_quantity),avg(wr_refunded_cash),avg(wr_fee),aggOrder,sum,count,sum,count,sum,count]
HashAggregate [r_reason_desc,sum,count,sum,count,sum,count] [avg(cast(ws_quantity as bigint)),avg(UnscaledValue(wr_refunded_cash)),avg(UnscaledValue(wr_fee)),substr(r_reason_desc, 1, 20),avg(ws_quantity),avg(wr_refunded_cash),avg(wr_fee),sum,count,sum,count,sum,count]
InputAdapter
Exchange [r_reason_desc] #1
WholeStageCodegen (14)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,16 +272,16 @@ Results [7]: [r_reason_desc#35, sum#43, count#44, sum#45, count#46, sum#47, coun

(49) Exchange
Input [7]: [r_reason_desc#35, sum#43, count#44, sum#45, count#46, sum#47, count#48]
Arguments: hashpartitioning(r_reason_desc#35, 5), true, [id=#49]
Arguments: hashpartitioning(r_reason_desc#35, 5), ENSURE_REQUIREMENTS, [id=#49]

(50) HashAggregate [codegen id : 9]
Input [7]: [r_reason_desc#35, sum#43, count#44, sum#45, count#46, sum#47, count#48]
Keys [1]: [r_reason_desc#35]
Functions [3]: [avg(cast(ws_quantity#5 as bigint)), avg(UnscaledValue(wr_refunded_cash#15)), avg(UnscaledValue(wr_fee#14))]
Aggregate Attributes [3]: [avg(cast(ws_quantity#5 as bigint))#50, avg(UnscaledValue(wr_refunded_cash#15))#51, avg(UnscaledValue(wr_fee#14))#52]
Results [5]: [substr(r_reason_desc#35, 1, 20) AS substr(r_reason_desc, 1, 20)#53, avg(cast(ws_quantity#5 as bigint))#50 AS avg(ws_quantity)#54, cast((avg(UnscaledValue(wr_refunded_cash#15))#51 / 100.0) as decimal(11,6)) AS avg(wr_refunded_cash)#55, cast((avg(UnscaledValue(wr_fee#14))#52 / 100.0) as decimal(11,6)) AS avg(wr_fee)#56, avg(cast(ws_quantity#5 as bigint))#50 AS aggOrder#57]
Results [4]: [substr(r_reason_desc#35, 1, 20) AS substr(r_reason_desc, 1, 20)#53, avg(cast(ws_quantity#5 as bigint))#50 AS avg(ws_quantity)#54, cast((avg(UnscaledValue(wr_refunded_cash#15))#51 / 100.0) as decimal(11,6)) AS avg(wr_refunded_cash)#55, cast((avg(UnscaledValue(wr_fee#14))#52 / 100.0) as decimal(11,6)) AS avg(wr_fee)#56]

(51) TakeOrderedAndProject
Input [5]: [substr(r_reason_desc, 1, 20)#53, avg(ws_quantity)#54, avg(wr_refunded_cash)#55, avg(wr_fee)#56, aggOrder#57]
Arguments: 100, [substr(r_reason_desc, 1, 20)#53 ASC NULLS FIRST, aggOrder#57 ASC NULLS FIRST, avg(wr_refunded_cash)#55 ASC NULLS FIRST, avg(wr_fee)#56 ASC NULLS FIRST], [substr(r_reason_desc, 1, 20)#53, avg(ws_quantity)#54, avg(wr_refunded_cash)#55, avg(wr_fee)#56]
Input [4]: [substr(r_reason_desc, 1, 20)#53, avg(ws_quantity)#54, avg(wr_refunded_cash)#55, avg(wr_fee)#56]
Arguments: 100, [substr(r_reason_desc, 1, 20)#53 ASC NULLS FIRST, avg(ws_quantity)#54 ASC NULLS FIRST, avg(wr_refunded_cash)#55 ASC NULLS FIRST, avg(wr_fee)#56 ASC NULLS FIRST], [substr(r_reason_desc, 1, 20)#53, avg(ws_quantity)#54, avg(wr_refunded_cash)#55, avg(wr_fee)#56]

Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
TakeOrderedAndProject [substr(r_reason_desc, 1, 20),aggOrder,avg(wr_refunded_cash),avg(wr_fee),avg(ws_quantity)]
TakeOrderedAndProject [substr(r_reason_desc, 1, 20),avg(ws_quantity),avg(wr_refunded_cash),avg(wr_fee)]
WholeStageCodegen (9)
HashAggregate [r_reason_desc,sum,count,sum,count,sum,count] [avg(cast(ws_quantity as bigint)),avg(UnscaledValue(wr_refunded_cash)),avg(UnscaledValue(wr_fee)),substr(r_reason_desc, 1, 20),avg(ws_quantity),avg(wr_refunded_cash),avg(wr_fee),aggOrder,sum,count,sum,count,sum,count]
HashAggregate [r_reason_desc,sum,count,sum,count,sum,count] [avg(cast(ws_quantity as bigint)),avg(UnscaledValue(wr_refunded_cash)),avg(UnscaledValue(wr_fee)),substr(r_reason_desc, 1, 20),avg(ws_quantity),avg(wr_refunded_cash),avg(wr_fee),sum,count,sum,count,sum,count]
InputAdapter
Exchange [r_reason_desc] #1
WholeStageCodegen (8)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,17 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
checkAnswer(sql("SELECT v, sum(i) FROM t GROUP BY v ORDER BY v"), Row("c", 1))
}
}

test("SPARK-34003: fix char/varchar fails w/ order by functions") {
withTable("t") {
sql(s"CREATE TABLE t(v VARCHAR(3), i INT) USING $format")
sql("INSERT INTO t VALUES ('c', 1)")
checkAnswer(sql("SELECT substr(v, 1, 2), sum(i) FROM t GROUP BY v ORDER BY substr(v, 1, 2)"),
Row("c", 1))
checkAnswer(sql("SELECT sum(i) FROM t GROUP BY v ORDER BY substr(v, 1, 2)"),
Row(1))
}
}
}

// Some basic char/varchar tests which doesn't rely on table implementation.
Expand Down

0 comments on commit 99f8489

Please sign in to comment.