Skip to content

Commit

Permalink
Add support for the PPL Rare and Top commands
Browse files Browse the repository at this point in the history
top [N] <field-list> [by-clause]

N: optional. Number of results to return. Default: 10.
field-list: mandatory. Comma-delimited list of field names.
by-clause: optional. One or more fields to group the results by.
-------------------------------------------------------------------------------------------

rare <field-list> [by-clause]

field-list: mandatory. Comma-delimited list of field names.
by-clause: optional. One or more fields to group the results by.
-------------------------------------------------------------------------------------------
commands:
 - opensearch-project#461
 - opensearch-project#536
Signed-off-by: YANGDB <[email protected]>
  • Loading branch information
YANG-DB committed Aug 14, 2024
1 parent 4296927 commit 4dde4ac
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 17 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.tree;

import org.opensearch.sql.ast.expression.UnresolvedExpression;

import java.util.Collections;
import java.util.List;

/**
 * Logical plan node for the PPL {@code rare} command.
 *
 * <p>A thin specialization of {@link Aggregation}: it declares no state of its own and simply
 * gives the {@code rare} form of an aggregation a distinct node type.
 */
public class RareAggregation extends Aggregation {

  /**
   * Creates a rare aggregation without a span and without extra arguments.
   *
   * @param aggExprList aggregation expressions
   * @param sortExprList sort expressions
   * @param groupExprList group-by expressions
   */
  public RareAggregation(
      List<UnresolvedExpression> aggExprList,
      List<UnresolvedExpression> sortExprList,
      List<UnresolvedExpression> groupExprList) {
    super(
        aggExprList,
        sortExprList,
        groupExprList,
        null, // no span
        Collections.emptyList()); // no arguments
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.tree;

import org.opensearch.sql.ast.expression.UnresolvedExpression;

import java.util.Collections;
import java.util.List;

/**
 * Logical plan node for the PPL {@code top} command.
 *
 * <p>A thin specialization of {@link Aggregation}: it declares no state of its own and simply
 * gives the {@code top} form of an aggregation a distinct node type.
 */
public class TopAggregation extends Aggregation {

  /**
   * Creates a top aggregation without a span and without extra arguments.
   *
   * @param aggExprList aggregation expressions
   * @param sortExprList sort expressions
   * @param groupExprList group-by expressions
   */
  public TopAggregation(
      List<UnresolvedExpression> aggExprList,
      List<UnresolvedExpression> sortExprList,
      List<UnresolvedExpression> groupExprList) {
    super(
        aggExprList,
        sortExprList,
        groupExprList,
        null, // no span
        Collections.emptyList()); // no arguments
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$;
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$;
import org.apache.spark.sql.catalyst.expressions.Ascending$;
import org.apache.spark.sql.catalyst.expressions.Descending$;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
import org.apache.spark.sql.catalyst.expressions.Predicate;
import org.apache.spark.sql.catalyst.expressions.SortDirection;
import org.apache.spark.sql.catalyst.expressions.SortOrder;
import org.apache.spark.sql.catalyst.plans.logical.Aggregate;
import org.apache.spark.sql.catalyst.plans.logical.Deduplicate;
Expand Down Expand Up @@ -57,9 +60,11 @@
import org.opensearch.sql.ast.tree.Head;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.Project;
import org.opensearch.sql.ast.tree.RareAggregation;
import org.opensearch.sql.ast.tree.RareTopN;
import org.opensearch.sql.ast.tree.Relation;
import org.opensearch.sql.ast.tree.Sort;
import org.opensearch.sql.ast.tree.TopAggregation;
import org.opensearch.sql.ppl.utils.AggregatorTranslator;
import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator;
import org.opensearch.sql.ppl.utils.ComparatorTransformer;
Expand Down Expand Up @@ -174,12 +179,27 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext contex
node.getChild().get(0).accept(this, context);
List<Expression> aggsExpList = visitExpressionList(node.getAggExprList(), context);
List<Expression> groupExpList = visitExpressionList(node.getGroupExprList(), context);

List<Expression> sortExpList = visitExpressionList(node.getSortExprList(), context);
if (!groupExpList.isEmpty()) {
//add group by fields to context
context.getGroupingParseExpressions().addAll(groupExpList);
}

// set sort direction according to command type
List<SortDirection> sortDirections = new ArrayList<>();
if (node instanceof RareAggregation) {
sortDirections.add(Ascending$.MODULE$);
} else if(node instanceof TopAggregation) {
sortDirections.add(Descending$.MODULE$);
}

if (!sortExpList.isEmpty()) {
visitExpressionList(node.getSortExprList(), context);
Seq<SortOrder> sortElements = context.retainAllNamedParseExpressions(exp ->
new SortOrder((NamedExpression) exp, sortDirections.get(0) , sortDirections.get(0).defaultNullOrdering(), seq(new ArrayList<Expression>())));
context.apply(p -> (LogicalPlan) new org.apache.spark.sql.catalyst.plans.logical.Sort(sortElements, true, p));
}

UnresolvedExpression span = node.getSpan();
if (!Objects.isNull(span)) {
span.accept(this, context);
Expand All @@ -188,7 +208,7 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext contex
}
// build the aggregation logical step
return extractedAggregation(context);
}
}

private static LogicalPlan extractedAggregation(CatalystPlanContext context) {
Seq<Expression> groupingExpression = context.retainAllGroupingNamedParseExpressions(p -> p);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.opensearch.flint.spark.ppl.OpenSearchPPLParserBaseVisitor;
import org.opensearch.sql.ast.expression.AggregateFunction;
import org.opensearch.sql.ast.expression.Alias;
import org.opensearch.sql.ast.expression.Argument;
import org.opensearch.sql.ast.expression.DataType;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.FieldsMapping;
Expand All @@ -36,11 +37,13 @@
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.Parse;
import org.opensearch.sql.ast.tree.Project;
import org.opensearch.sql.ast.tree.RareAggregation;
import org.opensearch.sql.ast.tree.RareTopN;
import org.opensearch.sql.ast.tree.Relation;
import org.opensearch.sql.ast.tree.Rename;
import org.opensearch.sql.ast.tree.Sort;
import org.opensearch.sql.ast.tree.TableFunction;
import org.opensearch.sql.ast.tree.TopAggregation;
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.ppl.utils.ArgumentFactory;

Expand Down Expand Up @@ -275,7 +278,8 @@ public UnresolvedPlan visitPatternsCommand(OpenSearchPPLParser.PatternsCommandCo
public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) {
ImmutableList.Builder<UnresolvedExpression> aggListBuilder = new ImmutableList.Builder<>();
ctx.fieldList().fieldExpression().forEach(field -> {
UnresolvedExpression aggExpression = new AggregateFunction("count",internalVisitExpression(field));
UnresolvedExpression aggExpression = new AggregateFunction("count",internalVisitExpression(field),
Collections.singletonList(new Argument("countParam", new Literal(1, DataType.INTEGER))));
String name = field.qualifiedName().getText();
Alias alias = new Alias(name, aggExpression);
aggListBuilder.add(alias);
Expand All @@ -295,23 +299,24 @@ public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx)
.collect(Collectors.toList()))
.orElse(emptyList());

Aggregation aggregation =
new Aggregation(


TopAggregation aggregation =
new TopAggregation(
aggListBuilder.build(),
emptyList(),
groupList,
null,
ArgumentFactory.getArgumentList(ctx));
groupList);
return aggregation;

}

/** Rare command. */
@Override
public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ctx) {
ImmutableList.Builder<UnresolvedExpression> aggListBuilder = new ImmutableList.Builder<>();
ImmutableList.Builder<UnresolvedExpression> sortListBuilder = new ImmutableList.Builder<>();
ctx.fieldList().fieldExpression().forEach(field -> {
UnresolvedExpression aggExpression = new AggregateFunction("count",internalVisitExpression(field));
UnresolvedExpression aggExpression = new AggregateFunction("count",internalVisitExpression(field),
Collections.singletonList(new Argument("countParam", new Literal(1, DataType.INTEGER))));
String name = field.qualifiedName().getText();
Alias alias = new Alias(name, aggExpression);
aggListBuilder.add(alias);
Expand All @@ -330,15 +335,17 @@ public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ct
internalVisitExpression(groupCtx)))
.collect(Collectors.toList()))
.orElse(emptyList());

Aggregation aggregation =
new Aggregation(
//build the sort fields
ctx.fieldList().fieldExpression().forEach(field -> {
sortListBuilder.add(internalVisitExpression(field));
});
RareAggregation aggregation =
new RareAggregation(
aggListBuilder.build(),
emptyList(),
groupList,
null,
ArgumentFactory.getArgumentList(ctx));
sortListBuilder.build(),
groupList);
return aggregation;

}

/** From clause. */
Expand Down

0 comments on commit 4dde4ac

Please sign in to comment.