Address concurrent segment search concern
To save the ranges per segment, change to a map that stores the ranges for each segment separately.

The document-count increment function "incrementBucketDocCount" should already be thread safe, as it is the same method used by the normal aggregation execution path.
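
As a standalone illustration of the idea described above, here is a minimal sketch of storing ranges per segment keyed by leaf ordinal. The class, the Ranges record, and the use of ConcurrentHashMap are hypothetical and exist only to keep this demo self-contained and thread safe; they are not the OpenSearch code in the diff below.

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class PerSegmentRangesDemo {

    // Stand-in for the per-segment Ranges built from a segment's bounds.
    record Ranges(long lower, long upper) {}

    // Keyed by leaf (segment) ordinal, mirroring the map introduced in this commit.
    private final Map<Integer, Ranges> rangesFromSegment = new ConcurrentHashMap<>();

    void setRangesFromSegment(int leafOrd, Ranges ranges) {
        rangesFromSegment.put(leafOrd, ranges);
    }

    Ranges getRanges(int leafOrd) {
        return rangesFromSegment.get(leafOrd);
    }

    void clearRangesFromSegment(int leafOrd) {
        rangesFromSegment.remove(leafOrd);
    }

    public static void main(String[] args) throws InterruptedException {
        PerSegmentRangesDemo ctx = new PerSegmentRangesDemo();

        // Two segments searched concurrently, each touching only its own ordinal,
        // so neither can overwrite the other's ranges.
        Runnable segment0 = () -> {
            ctx.setRangesFromSegment(0, new Ranges(0, 100));
            System.out.println("segment 0 reads " + ctx.getRanges(0));
            ctx.clearRangesFromSegment(0);
        };
        Runnable segment3 = () -> {
            ctx.setRangesFromSegment(3, new Ranges(200, 300));
            System.out.println("segment 3 reads " + ctx.getRanges(3));
            ctx.clearRangesFromSegment(3);
        };

        Thread a = new Thread(segment0);
        Thread b = new Thread(segment3);
        a.start();
        b.start();
        a.join();
        b.join();
    }
}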

Signed-off-by: bowenlan-amzn <[email protected]>
bowenlan-amzn committed Aug 7, 2024
1 parent 778f1ce commit 234eb44
Showing 4 changed files with 55 additions and 46 deletions.
@@ -55,12 +55,13 @@ void setOptimizationContext(FilterRewriteOptimizationContext context) {

/**
* Prepares the optimization at shard level after checking aggregator is optimizable.
* <p>
* For example, figure out what are the ranges from the aggregation to do the optimization later
*/
protected abstract void prepare() throws IOException;

/**
* Prepares the optimization for a specific segment and ignore whatever built at shard level
* Prepares the optimization for a specific segment when the segment is functionally matching all docs
*
* @param leaf the leaf reader context for the segment
*/
@@ -69,8 +70,9 @@ void setOptimizationContext(FilterRewriteOptimizationContext context) {
/**
* Attempts to build aggregation results for a segment
*
* @param values the point values (index structure for numeric values) for a segment
* @param incrementDocCount a consumer to increment the document count for a range bucket. The First parameter is document count, the second is the key of the bucket
* @param values the point values (index structure for numeric values) for a segment
* @param incrementDocCount a consumer to increment the document count for a range bucket. The First parameter is document count, the second is the key of the bucket
* @param leafOrd
*/
protected abstract void tryOptimize(PointValues values, BiConsumer<Long, Long> incrementDocCount) throws IOException;
protected abstract void tryOptimize(PointValues values, BiConsumer<Long, Long> incrementDocCount, int leafOrd) throws IOException;
}
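
For context on the incrementDocCount parameter documented above, a hedged sketch of the kind of consumer an aggregator supplies. The lambda shape is an assumption and not part of this diff; incrementBucketDocCount is the method the commit message calls out as already thread safe.

// Hypothetical caller-side wiring (illustration only): the aggregator hands the
// bridge a BiConsumer that delegates to its existing incrementBucketDocCount, the
// same method used by the normal aggregation execution path.
BiConsumer<Long, Long> incrementDocCount =
    (bucketOrd, docCount) -> incrementBucketDocCount(bucketOrd, docCount);
// The concrete bridges below call incrementDocCount.accept(bucketOrd, docCount)
// once per matching range for the segment.
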
@@ -53,7 +53,7 @@ protected void buildRanges(SearchContext context) throws IOException {
@Override
protected void prepareFromSegment(LeafReaderContext leaf) throws IOException {
long[] bounds = Helper.getSegmentBounds(leaf, fieldType.name());
filterRewriteOptimizationContext.setRangesFromSegment(buildRanges(bounds));
filterRewriteOptimizationContext.setRangesFromSegment(leaf.ord, buildRanges(bounds));
}

private Ranges buildRanges(long[] bounds) {
@@ -123,19 +123,20 @@ protected int getSize() {
}

@Override
protected final void tryOptimize(PointValues values, BiConsumer<Long, Long> incrementDocCount) throws IOException {
protected final void tryOptimize(PointValues values, BiConsumer<Long, Long> incrementDocCount, int leafOrd) throws IOException {
int size = getSize();

DateFieldMapper.DateFieldType fieldType = getFieldType();
Ranges ranges = filterRewriteOptimizationContext.getRanges(leafOrd);
BiConsumer<Integer, Integer> incrementFunc = (activeIndex, docCount) -> {
long rangeStart = LongPoint.decodeDimension(filterRewriteOptimizationContext.getRanges().lowers[activeIndex], 0);
long rangeStart = LongPoint.decodeDimension(ranges.lowers[activeIndex], 0);
rangeStart = fieldType.convertNanosToMillis(rangeStart);
long ord = getBucketOrd(bucketOrdProducer().apply(rangeStart));
incrementDocCount.accept(ord, (long) docCount);
long bucketOrd = getBucketOrd(bucketOrdProducer().apply(rangeStart));
incrementDocCount.accept(bucketOrd, (long) docCount);
};

filterRewriteOptimizationContext.consumeDebugInfo(
multiRangesTraverse(values.getPointTree(), filterRewriteOptimizationContext.getRanges(), incrementFunc, size)
multiRangesTraverse(values.getPointTree(), filterRewriteOptimizationContext.getRanges(leafOrd), incrementFunc, size)
);
}

@@ -18,6 +18,8 @@
import org.opensearch.search.internal.SearchContext;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.function.BiConsumer;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
@@ -36,16 +38,16 @@ public final class FilterRewriteOptimizationContext {
private final boolean canOptimize;
private boolean preparedAtShardLevel = false;

final AggregatorBridge aggregatorBridge;
private final AggregatorBridge aggregatorBridge;
int maxAggRewriteFilters;
String shardId;
private String shardId;

private Ranges ranges;
private Ranges rangesFromSegment;
private final Map<Integer, Ranges> rangesFromSegment = new HashMap<>(); // map of segment ordinal to its ranges

// debug info related fields
private int leaf;
private int inner;
private int leafNodeVisited;
private int innerNodeVisited;
private int segments;
private int optimizedSegments;

@@ -59,6 +61,10 @@ public FilterRewriteOptimizationContext(
this.canOptimize = this.canOptimize(parent, subAggLength, context);
}

/**
* common logic for checking whether the optimization can be applied and prepare at shard level
* if the aggregation has any special logic, it should be done using {@link AggregatorBridge}
*/
private boolean canOptimize(final Object parent, final int subAggLength, SearchContext context) throws IOException {
if (context.maxAggRewriteFilters() == 0) return false;

@@ -69,31 +75,32 @@ private boolean canOptimize(final Object parent, final int subAggLength, SearchC
aggregatorBridge.setOptimizationContext(this);
this.maxAggRewriteFilters = context.maxAggRewriteFilters();
this.shardId = context.indexShard().shardId().toString();
this.prepare();

assert ranges == null : "Ranges should only be built once at shard level, but they are already built";
aggregatorBridge.prepare();
if (ranges != null) {
preparedAtShardLevel = true;
}
}
logger.debug("Fast filter rewriteable: {} for shard {}", canOptimize, shardId);

return canOptimize;
}

private void prepare() throws IOException {
assert ranges == null : "Ranges should only be built once at shard level, but they are already built";
aggregatorBridge.prepare();
if (ranges != null) {
preparedAtShardLevel = true;
}
}

void setRanges(Ranges ranges) {
this.ranges = ranges;
}

void setRangesFromSegment(Ranges ranges) {
this.rangesFromSegment = ranges;
void setRangesFromSegment(int leafOrd, Ranges ranges) {
this.rangesFromSegment.put(leafOrd, ranges);
}

void clearRangesFromSegment(int leafOrd) {
this.rangesFromSegment.remove(leafOrd);
}

Ranges getRanges() {
if (rangesFromSegment != null) return rangesFromSegment;
Ranges getRanges(int leafOrd) {
if (!preparedAtShardLevel) return rangesFromSegment.get(leafOrd);
return ranges;
}

@@ -132,13 +139,13 @@ public boolean tryOptimize(final LeafReaderContext leafCtx, final BiConsumer<Lon
Ranges ranges = tryBuildRangesFromSegment(leafCtx, segmentMatchAll);
if (ranges == null) return false;

aggregatorBridge.tryOptimize(values, incrementDocCount);
aggregatorBridge.tryOptimize(values, incrementDocCount, leafCtx.ord);

optimizedSegments++;
logger.debug("Fast filter optimization applied to shard {} segment {}", shardId, leafCtx.ord);
logger.debug("crossed leaf nodes: {}, inner nodes: {}", leaf, inner);
logger.debug("Crossed leaf nodes: {}, inner nodes: {}", leafNodeVisited, innerNodeVisited);

rangesFromSegment = null;
clearRangesFromSegment(leafCtx.ord);
return true;
}

@@ -151,41 +158,40 @@ private Ranges tryBuildRangesFromSegment(LeafReaderContext leafCtx, boolean segm
return null;
}

if (ranges == null) { // not built at shard level but segment match all
if (!preparedAtShardLevel) { // not built at shard level but segment match all
logger.debug("Shard {} segment {} functionally match all documents. Build the fast filter", shardId, leafCtx.ord);
aggregatorBridge.prepareFromSegment(leafCtx);
return rangesFromSegment;
}
return ranges;
return getRanges(leafCtx.ord);
}

/**
* Contains debug info of BKD traversal to show in profile
*/
static class DebugInfo {
private int leaf = 0; // leaf node visited
private int inner = 0; // inner node visited
private int leafNodeVisited = 0; // leaf node visited
private int innerNodeVisited = 0; // inner node visited

void visitLeaf() {
leaf++;
leafNodeVisited++;
}

void visitInner() {
inner++;
innerNodeVisited++;
}
}

void consumeDebugInfo(DebugInfo debug) {
leaf += debug.leaf;
inner += debug.inner;
leafNodeVisited += debug.leafNodeVisited;
innerNodeVisited += debug.innerNodeVisited;
}

public void populateDebugInfo(BiConsumer<String, Object> add) {
if (optimizedSegments > 0) {
add.accept("optimized_segments", optimizedSegments);
add.accept("unoptimized_segments", segments - optimizedSegments);
add.accept("leaf_visited", leaf);
add.accept("inner_visited", inner);
add.accept("leaf_node_visited", leafNodeVisited);
add.accept("inner_node_visited", innerNodeVisited);
}
}
}
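
Condensed for orientation, a hedged sketch of the per-leaf sequence this class now follows; the method below is hypothetical and simplified from the tryOptimize shown above, but every step is keyed by leafCtx.ord, which is what keeps concurrently searched segments independent.

// Simplified per-leaf flow (illustrative only, not the exact method body):
boolean optimizeLeaf(LeafReaderContext leafCtx, PointValues values,
                     BiConsumer<Long, Long> incrementDocCount, boolean segmentMatchAll) throws IOException {
    // May call aggregatorBridge.prepareFromSegment(leafCtx), which stores ranges
    // under this leaf's ordinal via setRangesFromSegment(leafCtx.ord, ...).
    Ranges ranges = tryBuildRangesFromSegment(leafCtx, segmentMatchAll);
    if (ranges == null) return false;

    // The bridge looks up this segment's ranges with getRanges(leafCtx.ord).
    aggregatorBridge.tryOptimize(values, incrementDocCount, leafCtx.ord);

    // Drop this segment's entry once it has been consumed.
    clearRangesFromSegment(leafCtx.ord);
    return true;
}
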
@@ -74,16 +74,16 @@ public void prepareFromSegment(LeafReaderContext leaf) {
}

@Override
protected final void tryOptimize(PointValues values, BiConsumer<Long, Long> incrementDocCount) throws IOException {
protected final void tryOptimize(PointValues values, BiConsumer<Long, Long> incrementDocCount, int leafOrd) throws IOException {
int size = Integer.MAX_VALUE;

BiConsumer<Integer, Integer> incrementFunc = (activeIndex, docCount) -> {
long ord = bucketOrdProducer().apply(activeIndex);
incrementDocCount.accept(ord, (long) docCount);
long bucketOrd = bucketOrdProducer().apply(activeIndex);
incrementDocCount.accept(bucketOrd, (long) docCount);
};

filterRewriteOptimizationContext.consumeDebugInfo(
multiRangesTraverse(values.getPointTree(), filterRewriteOptimizationContext.getRanges(), incrementFunc, size)
multiRangesTraverse(values.getPointTree(), filterRewriteOptimizationContext.getRanges(leafOrd), incrementFunc, size)
);
}
