Skip to content

Commit

Permalink
Adds support for leveraging StarTree index in conjunction with filtered aggregations (apache#11886)
Browse files Browse the repository at this point in the history
  • Loading branch information
egalpin authored Nov 18, 2023
1 parent 6aecd41 commit f7f8260
Show file tree
Hide file tree
Showing 13 changed files with 335 additions and 192 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.pinot.core.query.aggregation.AggregationExecutor;
import org.apache.pinot.core.query.aggregation.DefaultAggregationExecutor;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils.AggregationInfo;
import org.apache.pinot.core.query.request.context.QueryContext;
import org.apache.pinot.core.startree.executor.StarTreeAggregationExecutor;

Expand All @@ -42,18 +43,17 @@ public class AggregationOperator extends BaseOperator<AggregationResultsBlock> {
private final QueryContext _queryContext;
private final AggregationFunction[] _aggregationFunctions;
private final BaseProjectOperator<?> _projectOperator;
private final long _numTotalDocs;
private final boolean _useStarTree;
private final long _numTotalDocs;

private int _numDocsScanned = 0;

public AggregationOperator(QueryContext queryContext, BaseProjectOperator<?> projectOperator, long numTotalDocs,
boolean useStarTree) {
public AggregationOperator(QueryContext queryContext, AggregationInfo aggregationInfo, long numTotalDocs) {
_queryContext = queryContext;
_aggregationFunctions = queryContext.getAggregationFunctions();
_projectOperator = projectOperator;
_projectOperator = aggregationInfo.getProjectOperator();
_useStarTree = aggregationInfo.isUseStarTree();
_numTotalDocs = numTotalDocs;
_useStarTree = useStarTree;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import java.util.IdentityHashMap;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.pinot.core.common.Operator;
import org.apache.pinot.core.operator.BaseOperator;
import org.apache.pinot.core.operator.BaseProjectOperator;
Expand All @@ -32,7 +31,9 @@
import org.apache.pinot.core.query.aggregation.AggregationExecutor;
import org.apache.pinot.core.query.aggregation.DefaultAggregationExecutor;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils.AggregationInfo;
import org.apache.pinot.core.query.request.context.QueryContext;
import org.apache.pinot.core.startree.executor.StarTreeAggregationExecutor;


/**
Expand All @@ -47,18 +48,18 @@ public class FilteredAggregationOperator extends BaseOperator<AggregationResults

private final QueryContext _queryContext;
private final AggregationFunction[] _aggregationFunctions;
private final List<Pair<AggregationFunction[], BaseProjectOperator<?>>> _projectOperators;
private final List<AggregationInfo> _aggregationInfos;
private final long _numTotalDocs;

private long _numDocsScanned;
private long _numEntriesScannedInFilter;
private long _numEntriesScannedPostFilter;

public FilteredAggregationOperator(QueryContext queryContext,
List<Pair<AggregationFunction[], BaseProjectOperator<?>>> projectOperators, long numTotalDocs) {
public FilteredAggregationOperator(QueryContext queryContext, List<AggregationInfo> aggregationInfos,
long numTotalDocs) {
_queryContext = queryContext;
_aggregationFunctions = queryContext.getAggregationFunctions();
_projectOperators = projectOperators;
_aggregationInfos = aggregationInfos;
_numTotalDocs = numTotalDocs;
}

Expand All @@ -71,10 +72,16 @@ protected AggregationResultsBlock getNextBlock() {
resultIndexMap.put(_aggregationFunctions[i], i);
}

for (Pair<AggregationFunction[], BaseProjectOperator<?>> pair : _projectOperators) {
AggregationFunction[] aggregationFunctions = pair.getLeft();
AggregationExecutor aggregationExecutor = new DefaultAggregationExecutor(aggregationFunctions);
BaseProjectOperator<?> projectOperator = pair.getRight();
for (AggregationInfo aggregationInfo : _aggregationInfos) {
AggregationFunction[] aggregationFunctions = aggregationInfo.getFunctions();
BaseProjectOperator<?> projectOperator = aggregationInfo.getProjectOperator();
AggregationExecutor aggregationExecutor;
if (aggregationInfo.isUseStarTree()) {
aggregationExecutor = new StarTreeAggregationExecutor(aggregationFunctions);
} else {
aggregationExecutor = new DefaultAggregationExecutor(aggregationFunctions);
}

ValueBlock valueBlock;
int numDocsScanned = 0;
while ((valueBlock = projectOperator.nextBlock()) != null) {
Expand All @@ -95,7 +102,7 @@ protected AggregationResultsBlock getNextBlock() {

@Override
public List<Operator> getChildOperators() {
return _projectOperators.stream().map(Pair::getRight).collect(Collectors.toList());
return _aggregationInfos.stream().map(AggregationInfo::getProjectOperator).collect(Collectors.toList());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@
import org.apache.pinot.core.operator.blocks.results.GroupByResultsBlock;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils.AggregationInfo;
import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult;
import org.apache.pinot.core.query.aggregation.groupby.DefaultGroupByExecutor;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator;
import org.apache.pinot.core.query.request.context.QueryContext;
import org.apache.pinot.core.startree.executor.StarTreeGroupByExecutor;
import org.apache.pinot.core.util.GroupByUtils;
import org.apache.pinot.spi.trace.Tracing;

Expand All @@ -56,22 +58,21 @@ public class FilteredGroupByOperator extends BaseOperator<GroupByResultsBlock> {
private final QueryContext _queryContext;
private final AggregationFunction[] _aggregationFunctions;
private final ExpressionContext[] _groupByExpressions;
private final List<Pair<AggregationFunction[], BaseProjectOperator<?>>> _projectOperators;
private final List<AggregationInfo> _aggregationInfos;
private final long _numTotalDocs;
private final DataSchema _dataSchema;

private long _numDocsScanned;
private long _numEntriesScannedInFilter;
private long _numEntriesScannedPostFilter;

public FilteredGroupByOperator(QueryContext queryContext,
List<Pair<AggregationFunction[], BaseProjectOperator<?>>> projectOperators, long numTotalDocs) {
public FilteredGroupByOperator(QueryContext queryContext, List<AggregationInfo> aggregationInfos, long numTotalDocs) {
assert queryContext.getAggregationFunctions() != null && queryContext.getFilteredAggregationFunctions() != null
&& queryContext.getGroupByExpressions() != null;
_queryContext = queryContext;
_aggregationFunctions = queryContext.getAggregationFunctions();
_groupByExpressions = queryContext.getGroupByExpressions().toArray(new ExpressionContext[0]);
_projectOperators = projectOperators;
_aggregationInfos = aggregationInfos;
_numTotalDocs = numTotalDocs;

// NOTE: The indexedTable expects that the data schema will have group by columns before aggregation columns
Expand All @@ -82,7 +83,7 @@ public FilteredGroupByOperator(QueryContext queryContext,
DataSchema.ColumnDataType[] columnDataTypes = new DataSchema.ColumnDataType[numColumns];

// Extract column names and data types for group-by columns
BaseProjectOperator<?> projectOperator = projectOperators.get(0).getRight();
BaseProjectOperator<?> projectOperator = aggregationInfos.get(0).getProjectOperator();
for (int i = 0; i < numGroupByExpressions; i++) {
ExpressionContext groupByExpression = _groupByExpressions[i];
columnNames[i] = groupByExpression.toString();
Expand All @@ -105,9 +106,7 @@ public FilteredGroupByOperator(QueryContext queryContext,

@Override
protected GroupByResultsBlock getNextBlock() {
// TODO(egalpin): Support Startree query resolution when possible, even with FILTER expressions
int numAggregations = _aggregationFunctions.length;

GroupByResultHolder[] groupByResultHolders = new GroupByResultHolder[numAggregations];
IdentityHashMap<AggregationFunction, Integer> resultHolderIndexMap =
new IdentityHashMap<>(_aggregationFunctions.length);
Expand All @@ -116,9 +115,9 @@ protected GroupByResultsBlock getNextBlock() {
}

GroupKeyGenerator groupKeyGenerator = null;
for (Pair<AggregationFunction[], BaseProjectOperator<?>> pair : _projectOperators) {
AggregationFunction[] aggregationFunctions = pair.getLeft();
BaseProjectOperator<?> projectOperator = pair.getRight();
for (AggregationInfo aggregationInfo : _aggregationInfos) {
AggregationFunction[] aggregationFunctions = aggregationInfo.getFunctions();
BaseProjectOperator<?> projectOperator = aggregationInfo.getProjectOperator();

// Perform aggregation group-by on all the blocks
DefaultGroupByExecutor groupByExecutor;
Expand All @@ -130,13 +129,24 @@ protected GroupByResultsBlock getNextBlock() {
// the GroupByExecutor to have sole ownership of the GroupKeyGenerator. Therefore, we allow constructing a
// GroupByExecutor with a pre-existing GroupKeyGenerator so that the GroupKeyGenerator can be shared across
// loop iterations i.e. across all aggs.
groupByExecutor =
new DefaultGroupByExecutor(_queryContext, aggregationFunctions, _groupByExpressions, projectOperator);
if (aggregationInfo.isUseStarTree()) {
groupByExecutor =
new StarTreeGroupByExecutor(_queryContext, aggregationFunctions, _groupByExpressions, projectOperator);
} else {
groupByExecutor =
new DefaultGroupByExecutor(_queryContext, aggregationFunctions, _groupByExpressions, projectOperator);
}
groupKeyGenerator = groupByExecutor.getGroupKeyGenerator();
} else {
groupByExecutor =
new DefaultGroupByExecutor(_queryContext, aggregationFunctions, _groupByExpressions, projectOperator,
groupKeyGenerator);
if (aggregationInfo.isUseStarTree()) {
groupByExecutor =
new StarTreeGroupByExecutor(_queryContext, aggregationFunctions, _groupByExpressions, projectOperator,
groupKeyGenerator);
} else {
groupByExecutor =
new DefaultGroupByExecutor(_queryContext, aggregationFunctions, _groupByExpressions, projectOperator,
groupKeyGenerator);
}
}

int numDocsScanned = 0;
Expand Down Expand Up @@ -191,7 +201,7 @@ protected GroupByResultsBlock getNextBlock() {

@Override
public List<Operator> getChildOperators() {
return _projectOperators.stream().map(Pair::getRight).collect(Collectors.toList());
return _aggregationInfos.stream().map(AggregationInfo::getProjectOperator).collect(Collectors.toList());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.apache.pinot.core.operator.blocks.ValueBlock;
import org.apache.pinot.core.operator.blocks.results.GroupByResultsBlock;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils.AggregationInfo;
import org.apache.pinot.core.query.aggregation.groupby.DefaultGroupByExecutor;
import org.apache.pinot.core.query.aggregation.groupby.GroupByExecutor;
import org.apache.pinot.core.query.request.context.QueryContext;
Expand All @@ -51,31 +52,31 @@ public class GroupByOperator extends BaseOperator<GroupByResultsBlock> {
private final AggregationFunction[] _aggregationFunctions;
private final ExpressionContext[] _groupByExpressions;
private final BaseProjectOperator<?> _projectOperator;
private final long _numTotalDocs;
private final boolean _useStarTree;
private final long _numTotalDocs;
private final DataSchema _dataSchema;

private int _numDocsScanned = 0;

public GroupByOperator(QueryContext queryContext, ExpressionContext[] groupByExpressions,
BaseProjectOperator<?> projectOperator, long numTotalDocs, boolean useStarTree) {
public GroupByOperator(QueryContext queryContext, AggregationInfo aggregationInfo, long numTotalDocs) {
assert queryContext.getAggregationFunctions() != null && queryContext.getGroupByExpressions() != null;
_queryContext = queryContext;
_aggregationFunctions = queryContext.getAggregationFunctions();
_groupByExpressions = groupByExpressions;
_projectOperator = projectOperator;
_groupByExpressions = queryContext.getGroupByExpressions().toArray(new ExpressionContext[0]);
_projectOperator = aggregationInfo.getProjectOperator();
_useStarTree = aggregationInfo.isUseStarTree();
_numTotalDocs = numTotalDocs;
_useStarTree = useStarTree;

// NOTE: The indexedTable expects that the the data schema will have group by columns before aggregation columns
int numGroupByExpressions = groupByExpressions.length;
// NOTE: The indexedTable expects that the data schema will have group by columns before aggregation columns
int numGroupByExpressions = _groupByExpressions.length;
int numAggregationFunctions = _aggregationFunctions.length;
int numColumns = numGroupByExpressions + numAggregationFunctions;
String[] columnNames = new String[numColumns];
DataSchema.ColumnDataType[] columnDataTypes = new DataSchema.ColumnDataType[numColumns];

// Extract column names and data types for group-by columns
for (int i = 0; i < numGroupByExpressions; i++) {
ExpressionContext groupByExpression = groupByExpressions[i];
ExpressionContext groupByExpression = _groupByExpressions[i];
columnNames[i] = groupByExpression.toString();
columnDataTypes[i] = DataSchema.ColumnDataType.fromDataTypeSV(
_projectOperator.getResultColumnContext(groupByExpression).getDataType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,8 @@

import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.core.common.Operator;
import org.apache.pinot.core.operator.BaseProjectOperator;
import org.apache.pinot.core.operator.blocks.results.AggregationResultsBlock;
import org.apache.pinot.core.operator.filter.BaseFilterOperator;
import org.apache.pinot.core.operator.query.AggregationOperator;
Expand All @@ -34,15 +30,11 @@
import org.apache.pinot.core.operator.query.NonScanBasedAggregationOperator;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils.AggregationInfo;
import org.apache.pinot.core.query.request.context.QueryContext;
import org.apache.pinot.core.startree.CompositePredicateEvaluator;
import org.apache.pinot.core.startree.StarTreeUtils;
import org.apache.pinot.core.startree.plan.StarTreeProjectPlanNode;
import org.apache.pinot.segment.spi.AggregationFunctionType;
import org.apache.pinot.segment.spi.IndexSegment;
import org.apache.pinot.segment.spi.datasource.DataSource;
import org.apache.pinot.segment.spi.index.startree.AggregationFunctionColumnPair;
import org.apache.pinot.segment.spi.index.startree.StarTreeV2;

import static org.apache.pinot.segment.spi.AggregationFunctionType.*;

Expand Down Expand Up @@ -81,9 +73,8 @@ public Operator<AggregationResultsBlock> run() {
* Build the operator to be used for filtered aggregations
*/
private FilteredAggregationOperator buildFilteredAggOperator() {
List<Pair<AggregationFunction[], BaseProjectOperator<?>>> projectOperators =
AggregationFunctionUtils.buildFilteredAggregateProjectOperators(_indexSegment, _queryContext);
return new FilteredAggregationOperator(_queryContext, projectOperators,
return new FilteredAggregationOperator(_queryContext,
AggregationFunctionUtils.buildFilteredAggregationInfos(_indexSegment, _queryContext),
_indexSegment.getSegmentMetadata().getTotalDocs());
}

Expand All @@ -93,11 +84,10 @@ private FilteredAggregationOperator buildFilteredAggOperator() {
* aggregates code will be invoked
*/
public Operator<AggregationResultsBlock> buildNonFilteredAggOperator() {
assert _queryContext.getAggregationFunctions() != null;

int numTotalDocs = _indexSegment.getSegmentMetadata().getTotalDocs();
AggregationFunction[] aggregationFunctions = _queryContext.getAggregationFunctions();
assert aggregationFunctions != null;

int numTotalDocs = _indexSegment.getSegmentMetadata().getTotalDocs();
FilterPlanNode filterPlanNode = new FilterPlanNode(_indexSegment, _queryContext);
BaseFilterOperator filterOperator = filterPlanNode.run();

Expand All @@ -117,38 +107,12 @@ public Operator<AggregationResultsBlock> buildNonFilteredAggOperator() {
}
return new NonScanBasedAggregationOperator(_queryContext, dataSources, numTotalDocs);
}

// Use star-tree to solve the query if possible
List<StarTreeV2> starTrees = _indexSegment.getStarTrees();
if (!filterOperator.isResultEmpty() && starTrees != null && !_queryContext.isSkipStarTree()) {
AggregationFunctionColumnPair[] aggregationFunctionColumnPairs =
StarTreeUtils.extractAggregationFunctionPairs(aggregationFunctions);
if (aggregationFunctionColumnPairs != null) {
Map<String, List<CompositePredicateEvaluator>> predicateEvaluatorsMap =
StarTreeUtils.extractPredicateEvaluatorsMap(_indexSegment, _queryContext.getFilter(),
filterPlanNode.getPredicateEvaluators());
if (predicateEvaluatorsMap != null) {
for (StarTreeV2 starTreeV2 : starTrees) {
if (StarTreeUtils.isFitForStarTree(starTreeV2.getMetadata(), aggregationFunctionColumnPairs, null,
predicateEvaluatorsMap.keySet())) {
BaseProjectOperator<?> projectOperator =
new StarTreeProjectPlanNode(_queryContext, starTreeV2, aggregationFunctionColumnPairs, null,
predicateEvaluatorsMap).run();
return new AggregationOperator(_queryContext, projectOperator, numTotalDocs, true);
}
}
}
}
}
}

// TODO: Do not create ProjectOperator when filter result is empty
Set<ExpressionContext> expressionsToTransform =
AggregationFunctionUtils.collectExpressionsToTransform(aggregationFunctions, null);
BaseProjectOperator<?> projectOperator =
new ProjectPlanNode(_indexSegment, _queryContext, expressionsToTransform, DocIdSetPlanNode.MAX_DOC_PER_CALL,
filterOperator).run();
return new AggregationOperator(_queryContext, projectOperator, numTotalDocs, false);
AggregationInfo aggregationInfo =
AggregationFunctionUtils.buildAggregationInfo(_indexSegment, _queryContext, aggregationFunctions,
_queryContext.getFilter(), filterOperator, filterPlanNode.getPredicateEvaluators());
return new AggregationOperator(_queryContext, aggregationInfo, numTotalDocs);
}

/**
Expand Down
Loading

0 comments on commit f7f8260

Please sign in to comment.