Adding logic for optimizing auto date histogram
Signed-off-by: Ankit Jain <[email protected]>
jainankitk committed Nov 18, 2023
1 parent 88b7cca commit 0c27add
Showing 3 changed files with 123 additions and 33 deletions.
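At a high level, this change lets the auto date histogram answer a segment from precomputed per-bucket range filters, counted via Lucene's Weight#count, instead of visiting every document, and it moves the bound-detection logic shared with DateHistogramAggregator into FilterRewriteHelper. Below is a minimal, self-contained sketch of that counting idea in plain Lucene; the class and method names are illustrative, not the OpenSearch code shown in the diff.

import java.io.IOException;
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;

// Sketch only: count documents per histogram bucket from index metadata.
// The bucket bounds are assumed to have been derived from the rounding,
// as the commit does in FilterRewriteHelper#createFilterForAggregations.
final class BucketCountSketch {
    static int[] countPerBucket(IndexSearcher searcher, String field, long[][] bucketBounds) throws IOException {
        int[] counts = new int[bucketBounds.length];
        for (int i = 0; i < bucketBounds.length; i++) {
            Query range = LongPoint.newRangeQuery(field, bucketBounds[i][0], bucketBounds[i][1]);
            Weight weight = searcher.createWeight(searcher.rewrite(range), ScoreMode.COMPLETE_NO_SCORES, 1f);
            for (LeafReaderContext leaf : searcher.getIndexReader().leaves()) {
                int count = weight.count(leaf);
                if (count == -1) {
                    // The segment cannot be counted cheaply (e.g. it has deletes);
                    // a real implementation falls back to normal doc-by-doc collection.
                    return null;
                }
                counts[i] += count;
            }
        }
        return counts;
    }
}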
AutoDateHistogramAggregator.java
@@ -33,8 +33,12 @@

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.PointRangeQuery;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.CollectionUtil;
import org.apache.lucene.util.NumericUtils;
import org.opensearch.common.Rounding;
import org.opensearch.common.Rounding.Prepared;
import org.opensearch.common.lease.Releasables;
@@ -169,14 +173,14 @@ public final DeferringBucketCollector getDeferringCollector() {
return deferringCollector;
}

protected abstract LeafBucketCollector getLeafCollector(SortedNumericDocValues values, LeafBucketCollector sub) throws IOException;
protected abstract LeafBucketCollector getLeafCollector2(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException;

@Override
public final LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
if (valuesSource == null) {
return LeafBucketCollector.NO_OP_COLLECTOR;
}
return getLeafCollector(valuesSource.longValues(ctx), sub);
return getLeafCollector2(ctx, sub);
}

protected final InternalAggregation[] buildAggregations(
@@ -263,6 +267,9 @@ private static class FromSingle extends AutoDateHistogramAggregator {
private long min = Long.MAX_VALUE;
private long max = Long.MIN_VALUE;

private Weight[] filters = null;
private final ValuesSource.Numeric valuesSource;

FromSingle(
String name,
AggregatorFactories factories,
@@ -288,13 +295,80 @@ private static class FromSingle extends AutoDateHistogramAggregator {

preparedRounding = prepareRounding(0);
bucketOrds = new LongKeyedBucketOrds.FromSingle(context.bigArrays());

this.valuesSource = valuesSourceConfig.hasValues() ? (ValuesSource.Numeric) valuesSourceConfig.getValuesSource() : null;
// Create the filters for fast aggregation only if the query is a point
// range or match-all query and there are no parent/sub aggregations
if (parent() == null && subAggregators.length == 0) {
final String fieldName = valuesSourceConfig.fieldContext().field();
final long[] bounds = FilterRewriteHelper.getAggregationBounds(context, fieldName);
if (bounds != null) {
final Rounding rounding = getMinimumRounding(bounds[0], bounds[1]);
filters = FilterRewriteHelper.createFilterForAggregations(context, rounding, preparedRounding, fieldName, bounds[0], bounds[1]);
}
}
}

private Rounding getMinimumRounding(final long low, final long high) {
// (max - min) / targetBuckets = bestDuration
// Find the rounding whose largest inner interval covers bestDuration.
// Since we cannot exceed targetBuckets, bestDuration can only be rounded up,
// so the chosen inner interval must be an upper bound for it.
long bestDuration = (high - low) / targetBuckets;
while (roundingIdx < roundingInfos.length - 1) {
final RoundingInfo curRoundingInfo = roundingInfos[roundingIdx];
final int temp = curRoundingInfo.innerIntervals[curRoundingInfo.innerIntervals.length - 1];
// If the interval duration is covered by the maximum inner interval,
// we can start with this outer interval for creating the buckets
if (bestDuration <= temp * curRoundingInfo.roughEstimateDurationMillis) {
break;
}
roundingIdx++;
}

preparedRounding = prepareRounding(roundingIdx);
return roundingInfos[roundingIdx].rounding;
}
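As a worked example of getMinimumRounding (assuming the usual auto_date_histogram rounding ladder, e.g. hour-based roundings with inner intervals {1, 3, 12} and day-based roundings with {1, 7}; the exact constants live elsewhere in OpenSearch): a 30-day range with 10 target buckets gives a bestDuration of about 3 days, so the hour roundings are skipped and the day rounding is selected.

// Standalone arithmetic illustrating the selection loop above (assumed constants).
public class RoundingSelectionExample {
    public static void main(String[] args) {
        long low = 0L;
        long high = 30L * 24 * 60 * 60 * 1000;             // ~30 days in millis
        int targetBuckets = 10;
        long bestDuration = (high - low) / targetBuckets;  // ~3 days
        long maxHourSpan = 12L * 60 * 60 * 1000;           // 12 * 1h rough estimate
        long maxDaySpan = 7L * 24 * 60 * 60 * 1000;        // 7 * 24h rough estimate
        System.out.println(bestDuration <= maxHourSpan);   // false -> roundingIdx++
        System.out.println(bestDuration <= maxDaySpan);    // true  -> day rounding wins
    }
}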

boolean tryFastFilterAggregation(LeafReaderContext ctx, long owningBucketOrd) throws IOException {
final int[] counts = new int[filters.length];
int i;
for (i = 0; i < filters.length; i++) {
counts[i] = filters[i].count(ctx);
if (counts[i] == -1) {
// Cannot use the optimization if any of the counts
// is -1 indicating the segment might have deleted documents
return false;
}
}

for (i = 0; i < filters.length; i++) {
long bucketOrd = bucketOrds.add(
owningBucketOrd,
preparedRounding.round(NumericUtils.sortableBytesToLong(((PointRangeQuery) filters[i].getQuery()).getLowerPoint(), 0))
);
if (bucketOrd < 0) { // already seen
bucketOrd = -1 - bucketOrd;
}
incrementBucketDocCount(bucketOrd, counts[i]);
}
throw new CollectionTerminatedException();
}
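Throwing CollectionTerminatedException after the precomputed counts have been applied is the standard Lucene signal that the current segment needs no further collection; IndexSearcher catches it and moves on to the next leaf. A generic sketch of that pattern, not the aggregator above (the callback is a hypothetical stand-in for applying the counts):

import java.io.IOException;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.SimpleCollector;

// Sketch: a collector that answers a leaf from precomputed results and then
// asks the searcher to stop sending it documents for that leaf.
final class EarlyTerminatingCollector extends SimpleCollector {
    private final Runnable applyPrecomputedCounts; // hypothetical callback

    EarlyTerminatingCollector(Runnable applyPrecomputedCounts) {
        this.applyPrecomputedCounts = applyPrecomputedCounts;
    }

    @Override
    public void collect(int doc) throws IOException {
        applyPrecomputedCounts.run();
        throw new CollectionTerminatedException(); // caught by IndexSearcher
    }

    @Override
    public ScoreMode scoreMode() {
        return ScoreMode.COMPLETE_NO_SCORES;
    }
}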

@Override
protected LeafBucketCollector getLeafCollector(SortedNumericDocValues values, LeafBucketCollector sub) throws IOException {
protected LeafBucketCollector getLeafCollector2(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
SortedNumericDocValues values = valuesSource.longValues(ctx);

final boolean[] useOpt = new boolean[1];
useOpt[0] = filters != null;

return new LeafBucketCollectorBase(sub, values) {
@Override
public void collect(int doc, long owningBucketOrd) throws IOException {
if (useOpt[0]) {
useOpt[0] = tryFastFilterAggregation(ctx, owningBucketOrd);
}

assert owningBucketOrd == 0;
if (false == values.advanceExact(doc)) {
return;
@@ -471,6 +545,8 @@ private static class FromMany extends AutoDateHistogramAggregator {
*/
private int rebucketCount = 0;

private final ValuesSource.Numeric valuesSource;

FromMany(
String name,
AggregatorFactories factories,
@@ -505,10 +581,12 @@ private static class FromMany extends AutoDateHistogramAggregator {
preparedRoundings[0] = roundingPreparer.apply(roundingInfos[0].rounding);
bucketOrds = new LongKeyedBucketOrds.FromMany(context.bigArrays());
liveBucketCountUnderestimate = context.bigArrays().newIntArray(1, true);
this.valuesSource = valuesSourceConfig.hasValues() ? (ValuesSource.Numeric) valuesSourceConfig.getValuesSource() : null;
}

@Override
protected LeafBucketCollector getLeafCollector(SortedNumericDocValues values, LeafBucketCollector sub) throws IOException {
protected LeafBucketCollector getLeafCollector2(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
SortedNumericDocValues values = valuesSource.longValues(ctx);
return new LeafBucketCollectorBase(sub, values) {
@Override
public void collect(int doc, long owningBucketOrd) throws IOException {
DateHistogramAggregator.java
@@ -34,9 +34,7 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.PointRangeQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.CollectionUtil;
@@ -122,22 +120,16 @@ class DateHistogramAggregator extends BucketsAggregator implements SizedBucketAggregator {
// of a point range or match-all query and there are no parent/sub aggregations
if (parent() == null && subAggregators.length == 0) {
final String fieldName = valuesSourceConfig.fieldContext().field();
final Query cq = FilterRewriteHelper.unwrapIntoConcreteQuery(context.query());
if (cq instanceof PointRangeQuery) {
final PointRangeQuery prq = (PointRangeQuery) cq;
// Ensure that the query and aggregation are on the same field
if (prq.getField().equals(fieldName)) {
filters = FilterRewriteHelper.createFilterForAggregations(
context,
rounding,
preparedRounding,
fieldName,
NumericUtils.sortableBytesToLong(prq.getLowerPoint(), 0),
NumericUtils.sortableBytesToLong(prq.getUpperPoint(), 0)
);
}
} else if (cq instanceof MatchAllDocsQuery) {
filters = FilterRewriteHelper.findBoundsAndCreateFilters(context, rounding, preparedRounding, fieldName);
final long[] bounds = FilterRewriteHelper.getAggregationBounds(context, fieldName);
if (bounds != null) {
filters = FilterRewriteHelper.createFilterForAggregations(
context,
rounding,
preparedRounding,
fieldName,
bounds[0],
bounds[1]
);
}
}
}
FilterRewriteHelper.java
@@ -11,6 +11,7 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.IndexOrDocValuesQuery;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.PointRangeQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
@@ -63,14 +64,8 @@ public static Query unwrapIntoConcreteQuery(Query query) {

/**
* Finds the min and max bounds for segments within the passed search context
* and creates the weight filters using range queries within those bounds
*/
public static Weight[] findBoundsAndCreateFilters(
final SearchContext context,
final Rounding rounding,
final Rounding.Prepared preparedRounding,
final String fieldName
) throws IOException {
private static long[] getIndexBoundsFromLeaves(final SearchContext context, final String fieldName) throws IOException {
final List<LeafReaderContext> leaves = context.searcher().getIndexReader().leaves();
long min = Long.MAX_VALUE, max = Long.MIN_VALUE;
// Since the query does not specify bounds for aggregation, we can
@@ -79,8 +74,29 @@ public static Weight[] findBoundsAndCreateFilters(
min = Math.min(min, NumericUtils.sortableBytesToLong(leaf.reader().getPointValues(fieldName).getMinPackedValue(), 0));
max = Math.max(max, NumericUtils.sortableBytesToLong(leaf.reader().getPointValues(fieldName).getMaxPackedValue(), 0));
}

return createFilterForAggregations(context, rounding, preparedRounding, fieldName, min, max);

return new long[] { min, max };
}
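For reference, the per-segment bounds above come straight from point-values metadata, so they are cheap to read. A self-contained sketch of the same lookup follows; the null check is an extra guard for segments without points for the field and is an assumption of the sketch, not part of the diff.

import java.io.IOException;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.util.NumericUtils;

// Sketch: derive index-wide [min, max] for a long-valued point field.
final class IndexBoundsSketch {
    static long[] bounds(IndexReader reader, String field) throws IOException {
        long min = Long.MAX_VALUE, max = Long.MIN_VALUE;
        for (LeafReaderContext leaf : reader.leaves()) {
            PointValues points = leaf.reader().getPointValues(field);
            if (points == null) {
                continue; // segment has no points for this field
            }
            min = Math.min(min, NumericUtils.sortableBytesToLong(points.getMinPackedValue(), 0));
            max = Math.max(max, NumericUtils.sortableBytesToLong(points.getMaxPackedValue(), 0));
        }
        return min <= max ? new long[] { min, max } : null;
    }
}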

public static long[] getAggregationBounds(final SearchContext context, final String fieldName) throws IOException {
final Query cq = unwrapIntoConcreteQuery(context.query());
final long[] indexBounds = getIndexBoundsFromLeaves(context, fieldName);
if (cq instanceof PointRangeQuery) {
final PointRangeQuery prq = (PointRangeQuery) cq;
// Ensure that the query and aggregation are on the same field
if (prq.getField().equals(fieldName)) {
return new long[] {
// Minimum bound for aggregation is the max between query and global
Math.max(NumericUtils.sortableBytesToLong(prq.getLowerPoint(), 0), indexBounds[0]),
// Maximum bound for aggregation is the min between query and global
Math.min(NumericUtils.sortableBytesToLong(prq.getUpperPoint(), 0), indexBounds[1])
};
}
} else if (cq instanceof MatchAllDocsQuery) {
return indexBounds;
}

return null;
}

/**
@@ -123,12 +139,16 @@ public static Weight[] createFilterForAggregations(

// Calculate the number of buckets using range and interval
long roundedLow = preparedRounding.round(low);
long prevRounded = roundedLow;
int bucketCount = 0;
while (roundedLow < high) {
bucketCount++;
// The extra rounding is needed because adding the interval can produce
// non-rounded values for calendar intervals such as months
roundedLow = preparedRounding.round(roundedLow + interval);
if (prevRounded == roundedLow)
break;
prevRounded = roundedLow;
}

Weight[] filters = null;
@@ -139,13 +159,13 @@
while (i < bucketCount) {
// Calculate the lower bucket bound
final byte[] lower = new byte[8];
NumericUtils.longToSortableBytes(Math.max(roundedLow, low), lower, 0);
NumericUtils.longToSortableBytes(i == 0 ? low : roundedLow, lower, 0);
// Calculate the upper bucket bound
final byte[] upper = new byte[8];
roundedLow = preparedRounding.round(roundedLow + interval);
// Subtract 1 from roundedLow for the upper bound, as roundedLow itself
// belongs to the next bucket; the last bucket's upper bound is clamped to high
NumericUtils.longToSortableBytes(Math.min(roundedLow - 1, high), upper, 0);
NumericUtils.longToSortableBytes(i + 1 == bucketCount ? high : roundedLow - 1, upper, 0);
filters[i++] = context.searcher().createWeight(new PointRangeQuery(field, lower, upper, 1) {
@Override
protected String toString(int dimension, byte[] value) {
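To see what the boundary changes above do: the first filter now starts exactly at low and the last one ends exactly at high, while the prevRounded check stops the bucket count from looping forever when a rounding no longer advances. A standalone sketch of the same arithmetic over plain longs, where a fixed interval stands in for the prepared rounding and the values are illustrative:

import java.util.ArrayList;
import java.util.List;

// Sketch: build [lower, upper] bucket bounds the way createFilterForAggregations does,
// clamping the first bucket to `low` and the last bucket to `high`.
public class BucketBoundsExample {
    public static void main(String[] args) {
        long low = 5, high = 47, interval = 10;
        long roundedLow = (low / interval) * interval;        // stand-in for preparedRounding.round(low)

        // Count buckets, guarding against a rounding that stops advancing.
        int bucketCount = 0;
        long cursor = roundedLow, prev = cursor;
        while (cursor < high) {
            bucketCount++;
            cursor = ((cursor + interval) / interval) * interval;
            if (cursor == prev) {
                break;
            }
            prev = cursor;
        }

        List<long[]> buckets = new ArrayList<>();
        cursor = roundedLow;
        for (int i = 0; i < bucketCount; i++) {
            long lower = (i == 0) ? low : cursor;                     // first bucket starts at low
            cursor = ((cursor + interval) / interval) * interval;
            long upper = (i + 1 == bucketCount) ? high : cursor - 1;  // last bucket ends at high
            buckets.add(new long[] { lower, upper });
        }
        for (long[] b : buckets) {
            System.out.println("[" + b[0] + ", " + b[1] + "]");       // [5, 9] [10, 19] ... [40, 47]
        }
    }
}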
