Skip to content

Commit

Permalink
Applying the optimizations for match all query as well
Browse files Browse the repository at this point in the history
Signed-off-by: Ankit Jain <[email protected]>
  • Loading branch information
jainankitk committed Nov 6, 2023
1 parent d1b5380 commit e57424e
Showing 1 changed file with 32 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.IndexOrDocValuesQuery;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.PointRangeQuery;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
Expand All @@ -61,6 +63,7 @@
import java.time.ZoneId;
import java.time.format.TextStyle;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.BiConsumer;
Expand Down Expand Up @@ -135,12 +138,28 @@ class DateHistogramAggregator extends BucketsAggregator implements SizedBucketAg
if (q.getIndexQuery() instanceof PointRangeQuery) {
final PointRangeQuery prq = (PointRangeQuery) q.getIndexQuery();
// Ensure that the query and aggregation are on the same field
if (valuesSource != null && prq.getField().equals(fieldName)) {
long low = NumericUtils.sortableBytesToLong(prq.getLowerPoint(), 0);
long high = NumericUtils.sortableBytesToLong(prq.getUpperPoint(), 0);
createFilterForAggregations(fieldName, low, high);
if (prq.getField().equals(fieldName)) {
createFilterForAggregations(fieldName,
NumericUtils.sortableBytesToLong(prq.getLowerPoint(), 0),
NumericUtils.sortableBytesToLong(prq.getUpperPoint(), 0));
}
}
} else if (context.query() instanceof ConstantScoreQuery) {
final ConstantScoreQuery csq = (ConstantScoreQuery) context.query();
// Ensure that the constant score query is instance of match all query
if (csq.getQuery() instanceof MatchAllDocsQuery) {
final List<LeafReaderContext> leaves = aggregationContext.searcher().getIndexReader().leaves();
long min = Long.MAX_VALUE, max = Long.MIN_VALUE;
// Since the query does not specify bounds for aggregation, we can
// build the global min/max from local min/max within each segment
for (LeafReaderContext leaf : leaves) {
min = Math.min(min, NumericUtils.sortableBytesToLong(
leaf.reader().getPointValues(fieldName).getMinPackedValue(), 0));
max = Math.max(max, NumericUtils.sortableBytesToLong(
leaf.reader().getPointValues(fieldName).getMaxPackedValue(), 0));
}
createFilterForAggregations(fieldName, min, max);
}
}
}
}
Expand Down Expand Up @@ -299,8 +318,8 @@ private boolean isUTCTimeZone(final ZoneId zoneId) {
return "Z".equals(zoneId.getDisplayName(TextStyle.FULL, Locale.ENGLISH));
}

private void createFilterForAggregations(String field, long low, long high) throws IOException {
long interval = Long.MAX_VALUE;
private void createFilterForAggregations(final String field, final long low, final long high) throws IOException {
long interval;
if (rounding instanceof Rounding.TimeUnitRounding) {
interval = (((Rounding.TimeUnitRounding) rounding).unit).extraLocalOffsetLookup();
if (!isUTCTimeZone(((Rounding.TimeUnitRounding) rounding).timeZone)) {
Expand All @@ -323,21 +342,21 @@ private void createFilterForAggregations(String field, long low, long high) thro
int bucketCount = 0;
while (roundedLow < high) {
bucketCount++;
roundedLow += interval;
// Below rounding is needed as the interval could return in
// non-rounded values for something like calendar month
roundedLow = preparedRounding.round(roundedLow);
roundedLow = preparedRounding.round(roundedLow + interval);
}

if (bucketCount > 0 && bucketCount <= MAX_NUM_FILTER_BUCKETS) {
int i = 0;
filters = new Weight[bucketCount];
roundedLow = preparedRounding.round(low);
while (i < bucketCount) {
byte[] lower = new byte[8];
NumericUtils.longToSortableBytes(low, lower, 0);
byte[] upper = new byte[8];
// Calculate the upper bucket
roundedLow = preparedRounding.round(low);
// Calculate the lower bucket bound
final byte[] lower = new byte[8];
NumericUtils.longToSortableBytes(Math.max(roundedLow, low), lower, 0);
// Calculate the upper bucket bound
final byte[] upper = new byte[8];
roundedLow = preparedRounding.round(roundedLow + interval);
// Subtract -1 if the minimum is roundedLow as roundedLow itself
// is included in the next bucket
Expand Down

0 comments on commit e57424e

Please sign in to comment.