
Commit

Restructuring the code for making it more reusable and unit testable
Signed-off-by: Ankit Jain <[email protected]>
jainankitk committed Nov 17, 2023
1 parent f12ed77 commit d9bbd1f
Showing 2 changed files with 159 additions and 106 deletions.
server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java

@@ -34,8 +34,6 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.search.CollectionTerminatedException;
-import org.apache.lucene.search.ConstantScoreQuery;
-import org.apache.lucene.search.IndexOrDocValuesQuery;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.PointRangeQuery;
import org.apache.lucene.search.Query;
@@ -46,8 +44,6 @@
import org.opensearch.common.Nullable;
import org.opensearch.common.Rounding;
import org.opensearch.common.lease.Releasables;
-import org.opensearch.common.lucene.search.function.FunctionScoreQuery;
-import org.opensearch.index.query.DateRangeIncludingNowQuery;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.aggregations.Aggregator;
import org.opensearch.search.aggregations.AggregatorFactories;
@@ -63,15 +59,9 @@
import org.opensearch.search.internal.SearchContext;

import java.io.IOException;
-import java.time.ZoneId;
-import java.time.format.TextStyle;
import java.util.Collections;
-import java.util.HashMap;
import java.util.List;
-import java.util.Locale;
-import java.util.Map;
import java.util.function.BiConsumer;
-import java.util.function.Function;

/**
* An aggregator for date values. Every date is rounded down using a configured
@@ -82,8 +72,6 @@
* @opensearch.internal
*/
class DateHistogramAggregator extends BucketsAggregator implements SizedBucketAggregator {

-private static final int MAX_NUM_FILTER_BUCKETS = 1024;
private final ValuesSource.Numeric valuesSource;
private final DocValueFormat formatter;
private final Rounding rounding;
@@ -93,15 +81,10 @@ class DateHistogramAggregator extends BucketsAggregator implements SizedBucketAggregator {
private final Rounding.Prepared preparedRounding;
private final BucketOrder order;
private final boolean keyed;

-private final Map<Class, Function<Query, Query>> queryWrappers;

private final long minDocCount;
private final LongBounds extendedBounds;
private final LongBounds hardBounds;

private Weight[] filters = null;

private final LongKeyedBucketOrds bucketOrds;

DateHistogramAggregator(
@@ -132,31 +115,29 @@ class DateHistogramAggregator extends BucketsAggregator implements SizedBucketAggregator {
// TODO: Stop using null here
this.valuesSource = valuesSourceConfig.hasValues() ? (ValuesSource.Numeric) valuesSourceConfig.getValuesSource() : null;
this.formatter = valuesSourceConfig.format();
-this.queryWrappers = new HashMap<>();
-queryWrappers.put(ConstantScoreQuery.class, q -> ((ConstantScoreQuery) q).getQuery());
-queryWrappers.put(FunctionScoreQuery.class, q -> ((FunctionScoreQuery) q).getSubQuery());
-queryWrappers.put(DateRangeIncludingNowQuery.class, q -> ((DateRangeIncludingNowQuery) q).getQuery());
-queryWrappers.put(IndexOrDocValuesQuery.class, q -> ((IndexOrDocValuesQuery) q).getIndexQuery());

bucketOrds = LongKeyedBucketOrds.build(context.bigArrays(), cardinality);

// Create the filters for fast aggregation only if the query is an instance
// of point range query and there aren't any parent/sub aggregations
if (parent() == null && subAggregators.length == 0) {
final String fieldName = valuesSourceConfig.fieldContext().field();
-final Query cq = unwrapIntoConcreteQuery(context.query());
+final Query cq = FilterRewriteHelper.unwrapIntoConcreteQuery(context.query());
if (cq instanceof PointRangeQuery) {
final PointRangeQuery prq = (PointRangeQuery) cq;
// Ensure that the query and aggregation are on the same field
if (prq.getField().equals(fieldName)) {
-createFilterForAggregations(
+FilterRewriteHelper.createFilterForAggregations(
+aggregationContext,
+rounding,
+preparedRounding,
fieldName,
NumericUtils.sortableBytesToLong(prq.getLowerPoint(), 0),
NumericUtils.sortableBytesToLong(prq.getUpperPoint(), 0)
);
}
} else if (cq instanceof MatchAllDocsQuery) {
-findBoundsAndCreateFilters(fieldName, context);
+FilterRewriteHelper.findBoundsAndCreateFilters(context, rounding, preparedRounding, fieldName);
}
}
}
@@ -308,85 +289,4 @@ private boolean tryFastFilterAggregation(LeafReaderContext ctx, long owningBucketOrd)
}
throw new CollectionTerminatedException();
}

-/**
-* Recursively unwraps query into the concrete form
-* for applying the optimization
-*/
-private Query unwrapIntoConcreteQuery(Query query) {
-while (queryWrappers.containsKey(query.getClass())) {
-query = queryWrappers.get(query.getClass()).apply(query);
-}
-
-return query;
-}
-
-private void findBoundsAndCreateFilters(final String fieldName, final SearchContext context) throws IOException {
-final List<LeafReaderContext> leaves = context.searcher().getIndexReader().leaves();
-long min = Long.MAX_VALUE, max = Long.MIN_VALUE;
-// Since the query does not specify bounds for aggregation, we can
-// build the global min/max from local min/max within each segment
-for (LeafReaderContext leaf : leaves) {
-min = Math.min(min, NumericUtils.sortableBytesToLong(leaf.reader().getPointValues(fieldName).getMinPackedValue(), 0));
-max = Math.max(max, NumericUtils.sortableBytesToLong(leaf.reader().getPointValues(fieldName).getMaxPackedValue(), 0));
-}
-createFilterForAggregations(fieldName, min, max);
-}
-
-private boolean isUTCTimeZone(final ZoneId zoneId) {
-return "Z".equals(zoneId.getDisplayName(TextStyle.FULL, Locale.ENGLISH));
-}
-
-private void createFilterForAggregations(final String field, final long low, final long high) throws IOException {
-long interval;
-if (rounding instanceof Rounding.TimeUnitRounding) {
-interval = (((Rounding.TimeUnitRounding) rounding).unit).extraLocalOffsetLookup();
-if (!isUTCTimeZone(((Rounding.TimeUnitRounding) rounding).timeZone)) {
-// Fast filter aggregation cannot be used if it needs time zone rounding
-return;
-}
-} else if (rounding instanceof Rounding.TimeIntervalRounding) {
-interval = ((Rounding.TimeIntervalRounding) rounding).interval;
-if (!isUTCTimeZone(((Rounding.TimeIntervalRounding) rounding).timeZone)) {
-// Fast filter aggregation cannot be used if it needs time zone rounding
-return;
-}
-} else {
-// Unexpected scenario, exit and fall back to original
-return;
-}
-
-// Calculate the number of buckets using range and interval
-long roundedLow = preparedRounding.round(low);
-int bucketCount = 0;
-while (roundedLow < high) {
-bucketCount++;
-// Below rounding is needed as the interval could return in
-// non-rounded values for something like calendar month
-roundedLow = preparedRounding.round(roundedLow + interval);
-}
-
-if (bucketCount > 0 && bucketCount <= MAX_NUM_FILTER_BUCKETS) {
-int i = 0;
-filters = new Weight[bucketCount];
-roundedLow = preparedRounding.round(low);
-while (i < bucketCount) {
-// Calculate the lower bucket bound
-final byte[] lower = new byte[8];
-NumericUtils.longToSortableBytes(Math.max(roundedLow, low), lower, 0);
-// Calculate the upper bucket bound
-final byte[] upper = new byte[8];
-roundedLow = preparedRounding.round(roundedLow + interval);
-// Subtract -1 if the minimum is roundedLow as roundedLow itself
-// is included in the next bucket
-NumericUtils.longToSortableBytes(Math.min(roundedLow - 1, high), upper, 0);
-filters[i++] = context.searcher().createWeight(new PointRangeQuery(field, lower, upper, 1) {
-@Override
-protected String toString(int dimension, byte[] value) {
-return null;
-}
-}, ScoreMode.COMPLETE_NO_SCORES, 1);
-}
-}
-}
}
server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/FilterRewriteHelper.java

@@ -0,0 +1,153 @@
/*
 * SPDX-License-Identifier: Apache-2.0
 *
 * The OpenSearch Contributors require contributions made to
 * this file be licensed under the Apache-2.0 license or a
 * compatible open source license.
 */

package org.opensearch.search.aggregations.bucket.histogram;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.IndexOrDocValuesQuery;
import org.apache.lucene.search.PointRangeQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.NumericUtils;
import org.opensearch.common.Rounding;
import org.opensearch.common.lucene.search.function.FunctionScoreQuery;
import org.opensearch.index.query.DateRangeIncludingNowQuery;
import org.opensearch.search.internal.SearchContext;

import java.io.IOException;
import java.time.ZoneId;
import java.time.format.TextStyle;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Function;

public class FilterRewriteHelper {
    private static final int MAX_NUM_FILTER_BUCKETS = 1024;
    private static final Map<Class, Function<Query, Query>> queryWrappers;

    // Initialize the wrappers map for unwrapping the query
    static {
        queryWrappers = new HashMap<>();
        queryWrappers.put(ConstantScoreQuery.class, q -> ((ConstantScoreQuery) q).getQuery());
        queryWrappers.put(FunctionScoreQuery.class, q -> ((FunctionScoreQuery) q).getSubQuery());
        queryWrappers.put(DateRangeIncludingNowQuery.class, q -> ((DateRangeIncludingNowQuery) q).getQuery());
        queryWrappers.put(IndexOrDocValuesQuery.class, q -> ((IndexOrDocValuesQuery) q).getIndexQuery());
    }
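
    // Design note (editorial comment, not in the original source): keying the
    // unwrap logic on the wrapper's class keeps all known wrapper queries in one
    // table, so supporting a new wrapper means adding a single map entry rather
    // than another instanceof branch in the aggregator.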

    /**
     * Recursively unwraps the query into its concrete form
     * for applying the optimization
     */
    public static Query unwrapIntoConcreteQuery(Query query) {
        while (queryWrappers.containsKey(query.getClass())) {
            query = queryWrappers.get(query.getClass()).apply(query);
        }

        return query;
    }
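
    // Illustrative example (hypothetical query shape, not from this commit): a
    // query built as
    //   new ConstantScoreQuery(new IndexOrDocValuesQuery(pointQuery, dvQuery))
    // unwraps first to the IndexOrDocValuesQuery and then to its index-side
    // point query, which the date histogram fast path can inspect for bounds.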

    /**
     * Finds the min and max bounds for segments within the passed search context
     * and creates the weight filters using range queries within those bounds
     */
    public static void findBoundsAndCreateFilters(
        final SearchContext context,
        final Rounding rounding,
        final Rounding.Prepared preparedRounding,
        final String fieldName
    ) throws IOException {
        final List<LeafReaderContext> leaves = context.searcher().getIndexReader().leaves();
        long min = Long.MAX_VALUE, max = Long.MIN_VALUE;
        // Since the query does not specify bounds for aggregation, we can
        // build the global min/max from local min/max within each segment
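        // Editorial note: this assumes every segment has point values for the
        // field; Lucene's LeafReader#getPointValues returns null for a segment
        // without them.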
        for (LeafReaderContext leaf : leaves) {
            min = Math.min(min, NumericUtils.sortableBytesToLong(leaf.reader().getPointValues(fieldName).getMinPackedValue(), 0));
            max = Math.max(max, NumericUtils.sortableBytesToLong(leaf.reader().getPointValues(fieldName).getMaxPackedValue(), 0));
        }
        createFilterForAggregations(context, rounding, preparedRounding, fieldName, min, max);
    }

    /**
     * Helper function for checking if the time zone requested for the date histogram
     * aggregation is UTC or not
     */
    private static boolean isUTCTimeZone(final ZoneId zoneId) {
        return "Z".equals(zoneId.getDisplayName(TextStyle.FULL, Locale.ENGLISH));
    }
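
    // For example, ZoneOffset.UTC has the display name "Z" and passes this check,
    // while a region-based zone such as ZoneId.of("America/New_York") does not,
    // so non-UTC zones fall back to the regular aggregation path.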

    /**
     * Creates the range query filters for aggregations using the interval, min/max
     * bounds and the rounding values
     */
    public static Weight[] createFilterForAggregations(
        final SearchContext context,
        final Rounding rounding,
        final Rounding.Prepared preparedRounding,
        final String field,
        final long low,
        final long high
    ) throws IOException {
        long interval;
        if (rounding instanceof Rounding.TimeUnitRounding) {
            interval = (((Rounding.TimeUnitRounding) rounding).unit).extraLocalOffsetLookup();
            if (!isUTCTimeZone(((Rounding.TimeUnitRounding) rounding).timeZone)) {
                // Fast filter aggregation cannot be used if it needs time zone rounding
                return null;
            }
        } else if (rounding instanceof Rounding.TimeIntervalRounding) {
            interval = ((Rounding.TimeIntervalRounding) rounding).interval;
            if (!isUTCTimeZone(((Rounding.TimeIntervalRounding) rounding).timeZone)) {
                // Fast filter aggregation cannot be used if it needs time zone rounding
                return null;
            }
        } else {
            // Unexpected scenario: exit and fall back to the original aggregation
            return null;
        }

        // Calculate the number of buckets using the range and the interval
        long roundedLow = preparedRounding.round(low);
        int bucketCount = 0;
        while (roundedLow < high) {
            bucketCount++;
            // The rounding below is needed because adding the interval can land on
            // a non-rounded value for calendar units such as months
            roundedLow = preparedRounding.round(roundedLow + interval);
        }
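        // Worked example (illustrative): with a one-day interval and points
        // spanning Jan 1 00:00 through Jan 3 12:00, the loop visits the Jan 1,
        // Jan 2 and Jan 3 bucket starts, leaving bucketCount = 3.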

        Weight[] filters = null;
        if (bucketCount > 0 && bucketCount <= MAX_NUM_FILTER_BUCKETS) {
            int i = 0;
            filters = new Weight[bucketCount];
            roundedLow = preparedRounding.round(low);
            while (i < bucketCount) {
                // Calculate the lower bucket bound
                final byte[] lower = new byte[8];
                NumericUtils.longToSortableBytes(Math.max(roundedLow, low), lower, 0);
                // Calculate the upper bucket bound
                final byte[] upper = new byte[8];
                roundedLow = preparedRounding.round(roundedLow + interval);
                // Subtract 1 for the upper bound, since the new roundedLow itself
                // is the start of the next bucket
                NumericUtils.longToSortableBytes(Math.min(roundedLow - 1, high), upper, 0);
                filters[i++] = context.searcher().createWeight(new PointRangeQuery(field, lower, upper, 1) {
                    @Override
                    protected String toString(int dimension, byte[] value) {
                        return null;
                    }
                }, ScoreMode.COMPLETE_NO_SCORES, 1);
            }
        }

        return filters;
    }
}
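
Since the stated goal of this commit is making the logic unit testable, here is a minimal sketch of the kind of test the extracted helper now allows. It assumes plain JUnit 4 and the same package as FilterRewriteHelper; the test class and method names are hypothetical and not part of this commit.

package org.opensearch.search.aggregations.bucket.histogram;

import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.junit.Assert;
import org.junit.Test;

public class FilterRewriteHelperTests {
    @Test
    public void unwrapsNestedWrapperQueries() {
        // Two levels of ConstantScoreQuery wrapping should be removed recursively
        Query inner = new MatchAllDocsQuery();
        Query wrapped = new ConstantScoreQuery(new ConstantScoreQuery(inner));
        Assert.assertSame(inner, FilterRewriteHelper.unwrapIntoConcreteQuery(wrapped));
    }
}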
