From ec5239edd40988023ddd14eb22b4eaa1b07211bf Mon Sep 17 00:00:00 2001 From: Ahmed Hussein <50450311+amahussein@users.noreply.github.com> Date: Thu, 15 Feb 2024 10:21:17 -0600 Subject: [PATCH] Incorrect parsing of aggregates in DB queries (#790) * Incorrect parsing of aggregates in DB queries Fixes #786 DB uses `finalmerge_` as a prefix for final merge while Spark uses an empty prefix This PR is to replace the prefixes as follows: `finalmerge_`, `partial_`, `merge_` --------- Signed-off-by: Ahmed Hussein (amahussein) --- .../tool/planparser/SQLPlanParser.scala | 35 +++++++++++++++---- .../tool/qualification/QualOutputWriter.scala | 5 ++- 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala index 683ed18c1..31fad2de5 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala @@ -306,8 +306,14 @@ object SQLPlanParser extends Logging { val windowFunctionPattern = """(\w+)\(""".r + val aggregatePrefixes = Set( + "finalmerge_", // DB specific prefix for final merge agg functions + "partial_", // used for partials + "merge_" // Used for partial merge + ) + val ignoreExpressions = Set("any", "cast", "ansi_cast", "decimal", "decimaltype", "every", - "some", "merge_max", "merge_min", "merge_sum", "merge_count", "merge_avg", "merge_first", + "some", "list", // some ops turn into literals and they should not cause any fallbacks "current_database", "current_user", "current_timestamp", @@ -625,11 +631,30 @@ object SQLPlanParser extends Logging { } private def getAllFunctionNames(regPattern: Regex, expr: String, - groupInd: Int = 1): Set[String] = { + groupInd: Int = 1, isAggr: Boolean = true): Set[String] = { // Returns all matches in an expression. This can be used when the SQL expression is not // tokenized. val newExpr = processSpecialFunctions(expr) - regPattern.findAllMatchIn(newExpr).map(_.group(groupInd)).toSet.filterNot(ignoreExpression(_)) + + // first get all the functionNames + val exprss = + regPattern.findAllMatchIn(newExpr).map(_.group(groupInd)).toSet + + // For aggregate expressions we want to process the results to remove the prefix + // DB: remove the "^partial_" and "^finalmerge_" prefixes + // TODO: + // for performance sake, we can turn off the aggregate processing by enabling it only + // when needed. However, for now, we always do this processing until we are confident we know + // the correct place to turn on/off that flag.we can use the argument isAgg only when needed + val results = if (isAggr) { + exprss.collect { + case func => + aggregatePrefixes.find(func.startsWith(_)).map(func.replaceFirst(_, "")).getOrElse(func) + } + } else { + exprss + } + results.filterNot(ignoreExpression(_)) } def parseProjectExpressions(exprStr: String): Array[String] = { @@ -665,10 +690,8 @@ object SQLPlanParser extends Logging { val group_value = m.group(group_ind) if (patternMap.getOrElse(group_value, false)) { val clauseExpr = m.group(group_ind + 1) - // Here "partial_" and "merge_" is removed and only function name is preserved. - val processedExpr = clauseExpr.replaceAll("partial_", "").replaceAll("merge_", "") // No need to split the expr any further because we are only interested in function names - val used_functions = getAllFunctionNames(functionPrefixPattern, processedExpr) + val used_functions = getAllFunctionNames(functionPrefixPattern, clauseExpr) parsedExpressions ++= used_functions } } diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualOutputWriter.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualOutputWriter.scala index 2089ba59c..3f755d65c 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualOutputWriter.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualOutputWriter.scala @@ -974,9 +974,8 @@ object QualOutputWriter { sumInfo: QualificationSummaryInfo): Set[ExecInfo] = { sumInfo.planInfo.map(_.execInfo).collect { case execInfos => - // No need to flatten the execs because by definition wholeCodeGen execs should not be part - // of that list - execInfos.filter(exec => exec.stages.isEmpty && !exec.isSupported) + val allExecs = flattenedExecs(execInfos) + allExecs.filter(exec => exec.stages.isEmpty && !exec.isSupported) }.flatten.toSet }