Skip to content

Commit

Permalink
Adding logic for recursively unwrapping the query
Browse files Browse the repository at this point in the history
  • Loading branch information
jainankitk committed Nov 17, 2023
1 parent 021ed69 commit 6723da7
Showing 1 changed file with 37 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,16 @@
import org.apache.lucene.search.IndexOrDocValuesQuery;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.PointRangeQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.CollectionUtil;
import org.apache.lucene.util.NumericUtils;
import org.opensearch.common.Nullable;
import org.opensearch.common.Rounding;
import org.opensearch.common.lease.Releasables;
import org.opensearch.common.lucene.search.function.FunctionScoreQuery;
import org.opensearch.index.query.DateRangeIncludingNowQuery;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.aggregations.Aggregator;
import org.opensearch.search.aggregations.AggregatorFactories;
Expand All @@ -63,10 +66,12 @@
import java.time.ZoneId;
import java.time.format.TextStyle;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.Function;

/**
* An aggregator for date values. Every date is rounded down using a configured
Expand All @@ -89,6 +94,8 @@ class DateHistogramAggregator extends BucketsAggregator implements SizedBucketAg
private final BucketOrder order;
private final boolean keyed;

private final Map<Class, Function<Query, Query>> queryWrappers;

private final long minDocCount;
private final LongBounds extendedBounds;
private final LongBounds hardBounds;
Expand Down Expand Up @@ -125,6 +132,11 @@ class DateHistogramAggregator extends BucketsAggregator implements SizedBucketAg
// TODO: Stop using null here
this.valuesSource = valuesSourceConfig.hasValues() ? (ValuesSource.Numeric) valuesSourceConfig.getValuesSource() : null;
this.formatter = valuesSourceConfig.format();
this.queryWrappers = new HashMap<>();
queryWrappers.put(ConstantScoreQuery.class, q -> ((ConstantScoreQuery) q).getQuery());
queryWrappers.put(FunctionScoreQuery.class, q -> ((FunctionScoreQuery) q).getSubQuery());
queryWrappers.put(DateRangeIncludingNowQuery.class, q -> ((DateRangeIncludingNowQuery) q).getQuery());
queryWrappers.put(IndexOrDocValuesQuery.class, q -> ((IndexOrDocValuesQuery) q).getIndexQuery());

bucketOrds = LongKeyedBucketOrds.build(context.bigArrays(), cardinality);

Expand All @@ -133,24 +145,18 @@ class DateHistogramAggregator extends BucketsAggregator implements SizedBucketAg
// of point range query and there aren't any parent/sub aggregations
if (parent() == null && subAggregators.length == 0) {
final String fieldName = valuesSourceConfig.fieldContext().field();
if (context.query() instanceof IndexOrDocValuesQuery) {
final IndexOrDocValuesQuery q = (IndexOrDocValuesQuery) context.query();
if (q.getIndexQuery() instanceof PointRangeQuery) {
final PointRangeQuery prq = (PointRangeQuery) q.getIndexQuery();
// Ensure that the query and aggregation are on the same field
if (prq.getField().equals(fieldName)) {
createFilterForAggregations(fieldName,
NumericUtils.sortableBytesToLong(prq.getLowerPoint(), 0),
NumericUtils.sortableBytesToLong(prq.getUpperPoint(), 0));
}
}
} else if (context.query() instanceof ConstantScoreQuery) {
final ConstantScoreQuery csq = (ConstantScoreQuery) context.query();
// Ensure that the constant score query is instance of match all query
if (csq.getQuery() instanceof MatchAllDocsQuery) {
findBoundsAndCreateFilters(fieldName, context);
final Query cq = unwrapIntoConcreteQuery(context.query());
if (cq instanceof PointRangeQuery) {
final PointRangeQuery prq = (PointRangeQuery) cq;
// Ensure that the query and aggregation are on the same field
if (prq.getField().equals(fieldName)) {
createFilterForAggregations(
fieldName,
NumericUtils.sortableBytesToLong(prq.getLowerPoint(), 0),
NumericUtils.sortableBytesToLong(prq.getUpperPoint(), 0)
);
}
} else if (context.query() instanceof MatchAllDocsQuery) {
} else if (cq instanceof MatchAllDocsQuery) {
findBoundsAndCreateFilters(fieldName, context);
}
}
Expand Down Expand Up @@ -306,16 +312,26 @@ private boolean tryFastFilterAggregation(LeafReaderContext ctx, long owningBucke
throw new CollectionTerminatedException();
}

/**
* Recursively unwraps query into the concrete form
* for applying the optimization
*/
private Query unwrapIntoConcreteQuery(Query query) {
while (queryWrappers.containsKey(query.getClass())) {
query = queryWrappers.get(query.getClass()).apply(query);
}

return query;
}

private void findBoundsAndCreateFilters(final String fieldName, final SearchContext context) throws IOException {
final List<LeafReaderContext> leaves = context.searcher().getIndexReader().leaves();
long min = Long.MAX_VALUE, max = Long.MIN_VALUE;
// Since the query does not specify bounds for aggregation, we can
// build the global min/max from local min/max within each segment
for (LeafReaderContext leaf : leaves) {
min = Math.min(min, NumericUtils.sortableBytesToLong(
leaf.reader().getPointValues(fieldName).getMinPackedValue(), 0));
max = Math.max(max, NumericUtils.sortableBytesToLong(
leaf.reader().getPointValues(fieldName).getMaxPackedValue(), 0));
min = Math.min(min, NumericUtils.sortableBytesToLong(leaf.reader().getPointValues(fieldName).getMinPackedValue(), 0));
max = Math.max(max, NumericUtils.sortableBytesToLong(leaf.reader().getPointValues(fieldName).getMaxPackedValue(), 0));
}
System.out.println("Auto min and max for aggregation are : " + min + " : " + max);
createFilterForAggregations(fieldName, min, max);
Expand Down

0 comments on commit 6723da7

Please sign in to comment.