From 4dde4ace06a3b0c1692904c4ce0ce921948f82f8 Mon Sep 17 00:00:00 2001 From: YANGDB Date: Wed, 14 Aug 2024 15:00:02 -0700 Subject: [PATCH] Adding support for Rare & Top PPL top [N] [by-clause] N: number of results to return. Default: 10 field-list: mandatory. comma-delimited list of field names. by-clause: optional. one or more fields to group the results by. ------------------------------------------------------------------------------------------- rare [by-clause] field-list: mandatory. comma-delimited list of field names. by-clause: optional. one or more fields to group the results by. ------------------------------------------------------------------------------------------- commands: - https://github.com/opensearch-project/opensearch-spark/issues/461 - https://github.com/opensearch-project/opensearch-spark/issues/536 Signed-off-by: YANGDB --- .../sql/ast/tree/RareAggregation.java | 23 ++++++++++++ .../sql/ast/tree/TopAggregation.java | 23 ++++++++++++ .../sql/ppl/CatalystQueryPlanVisitor.java | 24 +++++++++++- .../opensearch/sql/ppl/parser/AstBuilder.java | 37 +++++++++++-------- 4 files changed, 90 insertions(+), 17 deletions(-) create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/RareAggregation.java create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/TopAggregation.java diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/RareAggregation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/RareAggregation.java new file mode 100644 index 000000000..55b2e4c43 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/RareAggregation.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.Collections; +import java.util.List; + +/** Logical plan node of Rare (Aggregation) command, the interface for building aggregation actions in queries. */ +public class RareAggregation extends Aggregation { + /** Aggregation Constructor without span and argument. */ + public RareAggregation( + List aggExprList, + List sortExprList, + List groupExprList) { + super(aggExprList, sortExprList, groupExprList, null, Collections.emptyList()); + } + +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/TopAggregation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/TopAggregation.java new file mode 100644 index 000000000..1aaa69dde --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/TopAggregation.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.Collections; +import java.util.List; + +/** Logical plan node of Top (Aggregation) command, the interface for building aggregation actions in queries. */ +public class TopAggregation extends Aggregation { + /** Aggregation Constructor without span and argument. */ + public TopAggregation( + List aggExprList, + List sortExprList, + List groupExprList) { + super(aggExprList, sortExprList, groupExprList, null, Collections.emptyList()); + } + +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 812cbea82..6ac9d3a34 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -9,9 +9,12 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$; import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; +import org.apache.spark.sql.catalyst.expressions.Ascending$; +import org.apache.spark.sql.catalyst.expressions.Descending$; import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.NamedExpression; import org.apache.spark.sql.catalyst.expressions.Predicate; +import org.apache.spark.sql.catalyst.expressions.SortDirection; import org.apache.spark.sql.catalyst.expressions.SortOrder; import org.apache.spark.sql.catalyst.plans.logical.Aggregate; import org.apache.spark.sql.catalyst.plans.logical.Deduplicate; @@ -57,9 +60,11 @@ import org.opensearch.sql.ast.tree.Head; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.Project; +import org.opensearch.sql.ast.tree.RareAggregation; import org.opensearch.sql.ast.tree.RareTopN; import org.opensearch.sql.ast.tree.Relation; import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.ast.tree.TopAggregation; import org.opensearch.sql.ppl.utils.AggregatorTranslator; import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator; import org.opensearch.sql.ppl.utils.ComparatorTransformer; @@ -174,12 +179,27 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext contex node.getChild().get(0).accept(this, context); List aggsExpList = visitExpressionList(node.getAggExprList(), context); List groupExpList = visitExpressionList(node.getGroupExprList(), context); - + List sortExpList = visitExpressionList(node.getSortExprList(), context); if (!groupExpList.isEmpty()) { //add group by fields to context context.getGroupingParseExpressions().addAll(groupExpList); } + // set sort direction according to command type + List sortDirections = new ArrayList<>(); + if (node instanceof RareAggregation) { + sortDirections.add(Ascending$.MODULE$); + } else if(node instanceof TopAggregation) { + sortDirections.add(Descending$.MODULE$); + } + + if (!sortExpList.isEmpty()) { + visitExpressionList(node.getSortExprList(), context); + Seq sortElements = context.retainAllNamedParseExpressions(exp -> + new SortOrder((NamedExpression) exp, sortDirections.get(0) , sortDirections.get(0).defaultNullOrdering(), seq(new ArrayList()))); + context.apply(p -> (LogicalPlan) new org.apache.spark.sql.catalyst.plans.logical.Sort(sortElements, true, p)); + } + UnresolvedExpression span = node.getSpan(); if (!Objects.isNull(span)) { span.accept(this, context); @@ -188,7 +208,7 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext contex } // build the aggregation logical step return extractedAggregation(context); - } +} private static LogicalPlan extractedAggregation(CatalystPlanContext context) { Seq groupingExpression = context.retainAllGroupingNamedParseExpressions(p -> p); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 0313cb930..ca988a7d8 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -14,6 +14,7 @@ import org.opensearch.flint.spark.ppl.OpenSearchPPLParserBaseVisitor; import org.opensearch.sql.ast.expression.AggregateFunction; import org.opensearch.sql.ast.expression.Alias; +import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.FieldsMapping; @@ -36,11 +37,13 @@ import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.Parse; import org.opensearch.sql.ast.tree.Project; +import org.opensearch.sql.ast.tree.RareAggregation; import org.opensearch.sql.ast.tree.RareTopN; import org.opensearch.sql.ast.tree.Relation; import org.opensearch.sql.ast.tree.Rename; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ast.tree.TableFunction; +import org.opensearch.sql.ast.tree.TopAggregation; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.ppl.utils.ArgumentFactory; @@ -275,7 +278,8 @@ public UnresolvedPlan visitPatternsCommand(OpenSearchPPLParser.PatternsCommandCo public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) { ImmutableList.Builder aggListBuilder = new ImmutableList.Builder<>(); ctx.fieldList().fieldExpression().forEach(field -> { - UnresolvedExpression aggExpression = new AggregateFunction("count",internalVisitExpression(field)); + UnresolvedExpression aggExpression = new AggregateFunction("count",internalVisitExpression(field), + Collections.singletonList(new Argument("countParam", new Literal(1, DataType.INTEGER)))); String name = field.qualifiedName().getText(); Alias alias = new Alias(name, aggExpression); aggListBuilder.add(alias); @@ -295,23 +299,24 @@ public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) .collect(Collectors.toList())) .orElse(emptyList()); - Aggregation aggregation = - new Aggregation( + + + TopAggregation aggregation = + new TopAggregation( aggListBuilder.build(), emptyList(), - groupList, - null, - ArgumentFactory.getArgumentList(ctx)); + groupList); return aggregation; - } /** Rare command. */ @Override public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ctx) { ImmutableList.Builder aggListBuilder = new ImmutableList.Builder<>(); + ImmutableList.Builder sortListBuilder = new ImmutableList.Builder<>(); ctx.fieldList().fieldExpression().forEach(field -> { - UnresolvedExpression aggExpression = new AggregateFunction("count",internalVisitExpression(field)); + UnresolvedExpression aggExpression = new AggregateFunction("count",internalVisitExpression(field), + Collections.singletonList(new Argument("countParam", new Literal(1, DataType.INTEGER)))); String name = field.qualifiedName().getText(); Alias alias = new Alias(name, aggExpression); aggListBuilder.add(alias); @@ -330,15 +335,17 @@ public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ct internalVisitExpression(groupCtx))) .collect(Collectors.toList())) .orElse(emptyList()); - - Aggregation aggregation = - new Aggregation( + //build the sort fields + ctx.fieldList().fieldExpression().forEach(field -> { + sortListBuilder.add(internalVisitExpression(field)); + }); + RareAggregation aggregation = + new RareAggregation( aggListBuilder.build(), - emptyList(), - groupList, - null, - ArgumentFactory.getArgumentList(ctx)); + sortListBuilder.build(), + groupList); return aggregation; + } /** From clause. */