Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
Extract segment match-all logic into the aggregator bridge

Signed-off-by: bowenlan-amzn <[email protected]>
  • Loading branch information
bowenlan-amzn committed Jun 20, 2024
1 parent e8e9ad3 commit 7c491b9
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,11 @@ protected int getSize() {
/** Maps a raw key to its bucket ordinal: round the key, then register it under owning ordinal 0. */
protected Function<Object, Long> bucketOrdProducer() {
    return key -> {
        long rounded = getRoundingPrepared().round((long) key);
        return bucketOrds.add(0, rounded);
    };
}

@Override
protected boolean segmentMatchAll(LeafReaderContext leafCtx) throws IOException {
    // Delegate to the shared helper that compares the query weight's count with the segment doc count.
    return segmentMatchAll(context, leafCtx);
}
}

@Override
Expand Down Expand Up @@ -565,7 +570,7 @@ private void processLeafFromQuery(LeafReaderContext ctx, Sort indexSortPrefix) t

@Override
protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
boolean optimized = optimizationContext.tryFastFilterAggregation(ctx, this::incrementBucketDocCount, context);
boolean optimized = optimizationContext.tryFastFilterAggregation(ctx, this::incrementBucketDocCount);
if (optimized) throw new CollectionTerminatedException();

finishLeaf();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,11 @@ protected Prepared getRoundingPrepared() {
/** Maps a raw key to its bucket ordinal: round the key, then register it under owning ordinal 0. */
protected Function<Object, Long> bucketOrdProducer() {
    return key -> {
        long rounded = preparedRounding.round((long) key);
        return getBucketOrds().add(0, rounded);
    };
}

@Override
protected boolean segmentMatchAll(LeafReaderContext leafCtx) throws IOException {
    // Delegate to the shared helper that compares the query weight's count with the segment doc count.
    return segmentMatchAll(context, leafCtx);
}
}

protected abstract LongKeyedBucketOrds getBucketOrds();
Expand Down Expand Up @@ -241,7 +246,7 @@ public final LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBuc
return LeafBucketCollector.NO_OP_COLLECTOR;
}

boolean optimized = optimizationContext.tryFastFilterAggregation(ctx, this::incrementBucketDocCount, context);
boolean optimized = optimizationContext.tryFastFilterAggregation(ctx, this::incrementBucketDocCount);
if (optimized) throw new CollectionTerminatedException();

final SortedNumericDocValues values = valuesSource.longValues(ctx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,11 @@ protected long[] processHardBounds(long[] bounds) {
/** Maps a raw key to its bucket ordinal: round the key, then register it under owning ordinal 0. */
protected Function<Object, Long> bucketOrdProducer() {
    return key -> {
        long rounded = preparedRounding.round((long) key);
        return bucketOrds.add(0, rounded);
    };
}

@Override
protected boolean segmentMatchAll(LeafReaderContext leafCtx) throws IOException {
    // Delegate to the shared helper that compares the query weight's count with the segment doc count.
    return segmentMatchAll(context, leafCtx);
}
}

@Override
Expand All @@ -171,7 +176,7 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCol
return LeafBucketCollector.NO_OP_COLLECTOR;
}

boolean optimized = optimizationContext.tryFastFilterAggregation(ctx, this::incrementBucketDocCount, context);
boolean optimized = optimizationContext.tryFastFilterAggregation(ctx, this::incrementBucketDocCount);
if (optimized) throw new CollectionTerminatedException();

SortedNumericDocValues values = valuesSource.longValues(ctx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ public ScoreMode scoreMode() {

@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
boolean optimized = optimizationContext.tryFastFilterAggregation(ctx, this::incrementBucketDocCount, context);
boolean optimized = optimizationContext.tryFastFilterAggregation(ctx, this::incrementBucketDocCount);
if (optimized) throw new CollectionTerminatedException();

final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.opensearch.common.Rounding;
import org.opensearch.index.mapper.DateFieldMapper;
import org.opensearch.index.mapper.MappedFieldType;
Expand Down Expand Up @@ -154,4 +156,9 @@ private static long getBucketOrd(long bucketOrd) {

return bucketOrd;
}

/**
 * Determines whether the search query functionally matches every document in the given segment.
 *
 * Rewrites the top-level query and creates its {@link Weight} (no scoring needed), then compares
 * the weight's cheap per-segment count against the segment's live doc count.
 *
 * @param ctx     the search context providing the query and searcher
 * @param leafCtx the segment to check
 * @return true if every document in the segment matches the query
 * @throws IOException if query rewrite or weight creation fails
 */
protected boolean segmentMatchAll(SearchContext ctx, LeafReaderContext leafCtx) throws IOException {
    Weight weight = ctx.query().rewrite(ctx.searcher()).createWeight(ctx.searcher(), ScoreMode.COMPLETE_NO_SCORES, 1f);
    // Weight#count returns -1 when the count cannot be computed cheaply; -1 never equals numDocs,
    // so an unknown count is conservatively treated as "not match-all".
    return weight != null && weight.count(leafCtx) == leafCtx.reader().numDocs();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,8 @@ void setOptimizationContext(OptimizationContext optimizationContext) {
abstract void tryFastFilterAggregation(PointValues values, BiConsumer<Long, Long> incrementDocCount, Ranges ranges) throws IOException;

protected abstract Function<Object, Long> bucketOrdProducer();

/**
 * Hook for subclasses to report that the query functionally matches all documents in a segment,
 * which allows the optimization to build ranges at segment level even when none were built at
 * shard level. The default is conservative: assume the segment is not match-all.
 *
 * @param leaf the segment to check
 * @return false by default; overriding bridges may return true when the segment is match-all
 */
protected boolean segmentMatchAll(LeafReaderContext leaf) throws IOException {
    return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
import org.apache.lucene.index.PointValues;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.ArrayUtil;
import org.opensearch.common.CheckedRunnable;
import org.opensearch.index.mapper.DocCountFieldMapper;
Expand Down Expand Up @@ -65,13 +63,13 @@ public boolean canOptimize(final Object parent, final int subAggLength, SearchCo
if (parent != null || subAggLength != 0) return false;

boolean rewriteable = aggregatorBridge.canOptimize();
logger.debug("Fast filter rewriteable: {} for shard {}", rewriteable, context.indexShard().shardId());
this.rewriteable = rewriteable;
if (rewriteable) {
aggregatorBridge.setOptimizationContext(this);
this.maxAggRewriteFilters = context.maxAggRewriteFilters();
this.shardId = context.indexShard().shardId().toString();
}
logger.debug("Fast filter rewriteable: {} for shard {}", rewriteable, shardId);
return rewriteable;
}

Expand All @@ -98,11 +96,8 @@ void setRanges(Ranges ranges) {
*
* @param incrementDocCount consume the doc_count results for certain ordinal
*/
public boolean tryFastFilterAggregation(
final LeafReaderContext leafCtx,
final BiConsumer<Long, Long> incrementDocCount,
SearchContext context
) throws IOException {
public boolean tryFastFilterAggregation(final LeafReaderContext leafCtx, final BiConsumer<Long, Long> incrementDocCount)
throws IOException {
segments++;
if (!rewriteable) {
return false;
Expand All @@ -125,20 +120,8 @@ public boolean tryFastFilterAggregation(
return false;
}

// even if no ranges built at shard level, we can still perform the optimization
// when functionally match-all at segment level
if (!rangesBuiltAtShardLevel && !segmentMatchAll(context, leafCtx)) {
return false;
}

Ranges ranges = this.ranges;
if (ranges == null) { // not built at shard level but segment match all
logger.debug("Shard {} segment {} functionally match all documents. Build the fast filter", shardId, leafCtx.ord);
ranges = buildRanges(leafCtx);
if (ranges == null) {
return false;
}
}
Ranges ranges = tryGetRangesFromSegment(leafCtx);
if (ranges == null) return false;

aggregatorBridge.tryFastFilterAggregation(values, incrementDocCount, ranges);

Expand All @@ -148,9 +131,21 @@ public boolean tryFastFilterAggregation(
return true;
}

public static boolean segmentMatchAll(SearchContext ctx, LeafReaderContext leafCtx) throws IOException {
Weight weight = ctx.query().rewrite(ctx.searcher()).createWeight(ctx.searcher(), ScoreMode.COMPLETE_NO_SCORES, 1f);
return weight != null && weight.count(leafCtx) == leafCtx.reader().numDocs();
/**
 * Resolves the ranges to use for this segment. Even when ranges could not be built at shard
 * level, the optimization can still apply if the query functionally matches all documents in
 * this segment — in that case ranges are built per segment.
 *
 * @param leafCtx the segment being collected
 * @return the ranges to use, or null if the optimization cannot be applied to this segment
 */
private Ranges tryGetRangesFromSegment(LeafReaderContext leafCtx) throws IOException {
    // No shard-level ranges and the segment is not match-all: nothing to optimize here.
    // (segmentMatchAll is only evaluated when shard-level ranges are absent.)
    if (!rangesBuiltAtShardLevel && !aggregatorBridge.segmentMatchAll(leafCtx)) {
        return null;
    }
    if (this.ranges != null) {
        return this.ranges;
    }
    // Ranges were not built at shard level, but the segment matches all docs: build them now.
    logger.debug("Shard {} segment {} functionally match all documents. Build the fast filter", shardId, leafCtx.ord);
    return buildRanges(leafCtx);
}

/**
Expand Down

0 comments on commit 7c491b9

Please sign in to comment.