Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
Refactor the data provider and the try-optimize logic

Signed-off-by: bowenlan-amzn <[email protected]>
  • Loading branch information
bowenlan-amzn committed Jun 19, 2024
1 parent 7d9d57e commit 1a067ba
Show file tree
Hide file tree
Showing 10 changed files with 155 additions and 146 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,14 @@
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.LongUnaryOperator;
import java.util.stream.Collectors;

import static org.opensearch.search.aggregations.MultiBucketConsumerService.MAX_BUCKET_SETTING;

/**
* Main aggregator that aggregates docs from mulitple aggregations
* Main aggregator that aggregates docs from multiple aggregations
*
* @opensearch.internal
*/
Expand All @@ -120,7 +121,6 @@ public final class CompositeAggregator extends BucketsAggregator {

private final OptimizationContext optimizationContext;
private LongKeyedBucketOrds bucketOrds = null;
private Rounding.Prepared preparedRounding = null;

CompositeAggregator(
String name,
Expand Down Expand Up @@ -168,22 +168,19 @@ public final class CompositeAggregator extends BucketsAggregator {

optimizationContext = new OptimizationContext(context, new CompositeAggAggregatorDataProvider());
if (optimizationContext.canOptimize(parent, subAggregators.length)) {
// bucketOrds is used for saving date histogram results
bucketOrds = LongKeyedBucketOrds.build(context.bigArrays(), CardinalityUpperBound.ONE);
preparedRounding = ((CompositeAggAggregatorDataProvider) optimizationContext.getAggregationType()).getRoundingPrepared();
optimizationContext.buildRanges(sourceConfigs[0].fieldType());
}
}

/**
* Currently the filter rewrite is only supported for date histograms
*/
public class CompositeAggAggregatorDataProvider extends AbstractDateHistogramAggAggregatorDataProvider {
private final class CompositeAggAggregatorDataProvider extends AbstractDateHistogramAggAggregatorDataProvider {
private RoundingValuesSource valuesSource;
private long afterKey = -1L;

@Override
public boolean canOptimize() {
protected boolean canOptimize() {
if (sourceConfigs.length != 1 || !(sourceConfigs[0].valuesSource() instanceof RoundingValuesSource)) return false;
if (canOptimize(sourceConfigs[0].missingBucket(), sourceConfigs[0].hasScript(), sourceConfigs[0].fieldType())) {
this.valuesSource = (RoundingValuesSource) sourceConfigs[0].valuesSource();
Expand All @@ -193,16 +190,20 @@ public boolean canOptimize() {
throw new IllegalArgumentException("now() is not supported in [after] key");
});
}

// bucketOrds stores the date histogram results produced by the optimization path
bucketOrds = LongKeyedBucketOrds.build(context.bigArrays(), CardinalityUpperBound.ONE);

return true;
}
return false;
}

public Rounding getRounding(final long low, final long high) {
protected Rounding getRounding(final long low, final long high) {
return valuesSource.getRounding();
}

public Rounding.Prepared getRoundingPrepared() {
protected Rounding.Prepared getRoundingPrepared() {
return valuesSource.getPreparedRounding();
}

Expand All @@ -217,9 +218,14 @@ protected long[] processAfterKey(long[] bounds, long interval) {
}

@Override
public int getSize() {
protected int getSize() {
return size;
}

/**
 * Supplies the function that maps a key to a bucket ordinal.
 * <p>
 * The returned function casts its argument to {@code long}, rounds it with
 * {@link #getRoundingPrepared()} and passes the rounded value to
 * {@code bucketOrds.add} with {@code 0} as the first argument.
 *
 * @return a function yielding whatever {@code bucketOrds.add} returns for the rounded key
 */
@Override
protected Function<Object, Long> bucketOrdProducer() {
    return key -> {
        // The prepared rounding is looked up on every invocation, exactly as before.
        final long roundedKey = getRoundingPrepared().round((long) key);
        return bucketOrds.add(0, roundedKey);
    };
}
}

@Override
Expand Down Expand Up @@ -371,7 +377,7 @@ private boolean isMaybeMultivalued(LeafReaderContext context, SortField sortFiel
return v2 != null && DocValues.unwrapSingleton(v2) == null;

default:
// we have no clue whether the field is multi-valued or not so we assume it is.
// we have no clue whether the field is multivalued or not so we assume it is.
return true;
}
}
Expand Down Expand Up @@ -554,11 +560,7 @@ private void processLeafFromQuery(LeafReaderContext ctx, Sort indexSortPrefix) t

@Override
protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
boolean optimized = optimizationContext.tryFastFilterAggregation(
ctx,
this::incrementBucketDocCount,
(key) -> bucketOrds.add(0, preparedRounding.round((long) key))
);
boolean optimized = optimizationContext.tryFastFilterAggregation(ctx, this::incrementBucketDocCount);
if (optimized) throw new CollectionTerminatedException();

finishLeaf();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,10 @@ private AutoDateHistogramAggregator(
}
}

private class AutoHistogramAggAggregatorDataProvider extends AbstractDateHistogramAggAggregatorDataProvider {
private final class AutoHistogramAggAggregatorDataProvider extends AbstractDateHistogramAggAggregatorDataProvider {

@Override
public boolean canOptimize() {
protected boolean canOptimize() {
return canOptimize(valuesSourceConfig);
}

Expand Down Expand Up @@ -201,6 +201,11 @@ protected Rounding getRounding(final long low, final long high) {
protected Prepared getRoundingPrepared() {
return preparedRounding;
}

/**
 * Supplies the function that maps a key to a bucket ordinal.
 * <p>
 * The returned function casts its argument to {@code long}, rounds it with
 * {@code preparedRounding} and passes the rounded value to
 * {@code getBucketOrds().add} with {@code 0} as the first argument.
 *
 * @return a function yielding whatever {@code getBucketOrds().add} returns for the rounded key
 */
@Override
protected Function<Object, Long> bucketOrdProducer() {
    return key -> {
        // preparedRounding and getBucketOrds() are read at call time, not when the function is created.
        final long roundedKey = preparedRounding.round((long) key);
        return getBucketOrds().add(0, roundedKey);
    };
}
}

protected abstract LongKeyedBucketOrds getBucketOrds();
Expand Down Expand Up @@ -232,11 +237,7 @@ public final LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBuc
return LeafBucketCollector.NO_OP_COLLECTOR;
}

boolean optimized = optimizationContext.tryFastFilterAggregation(
ctx,
this::incrementBucketDocCount,
(key) -> getBucketOrds().add(0, preparedRounding.round((long) key))
);
boolean optimized = optimizationContext.tryFastFilterAggregation(ctx, this::incrementBucketDocCount);
if (optimized) throw new CollectionTerminatedException();

final SortedNumericDocValues values = valuesSource.longValues(ctx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.function.Function;

/**
* An aggregator for date values. Every date is rounded down using a configured
Expand Down Expand Up @@ -125,10 +126,9 @@ class DateHistogramAggregator extends BucketsAggregator implements SizedBucketAg
}
}

private class DateHistogramAggAggregatorDataProvider extends AbstractDateHistogramAggAggregatorDataProvider {

private final class DateHistogramAggAggregatorDataProvider extends AbstractDateHistogramAggAggregatorDataProvider {
@Override
public boolean canOptimize() {
protected boolean canOptimize() {
return canOptimize(valuesSourceConfig);
}

Expand All @@ -146,6 +146,11 @@ protected Rounding.Prepared getRoundingPrepared() {
protected long[] processHardBounds(long[] bounds) {
return super.processHardBounds(bounds, hardBounds);
}

/**
 * Supplies the function that maps a key to a bucket ordinal.
 * <p>
 * The returned function casts its argument to {@code long}, rounds it with
 * {@code preparedRounding} and passes the rounded value to
 * {@code bucketOrds.add} with {@code 0} as the first argument.
 *
 * @return a function yielding whatever {@code bucketOrds.add} returns for the rounded key
 */
@Override
protected Function<Object, Long> bucketOrdProducer() {
    return key -> {
        final long roundedKey = preparedRounding.round((long) key);
        return bucketOrds.add(0, roundedKey);
    };
}
}

@Override
Expand All @@ -162,11 +167,7 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCol
return LeafBucketCollector.NO_OP_COLLECTOR;
}

boolean optimized = optimizationContext.tryFastFilterAggregation(
ctx,
this::incrementBucketDocCount,
(key) -> bucketOrds.add(0, preparedRounding.round((long) key))
);
boolean optimized = optimizationContext.tryFastFilterAggregation(ctx, this::incrementBucketDocCount);
if (optimized) throw new CollectionTerminatedException();

SortedNumericDocValues values = valuesSource.longValues(ctx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.fielddata.SortedNumericDoubleValues;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.aggregations.Aggregator;
import org.opensearch.search.aggregations.AggregatorFactories;
Expand All @@ -68,6 +67,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.function.Function;

import static org.opensearch.core.xcontent.ConstructingObjectParser.optionalConstructorArg;

Expand Down Expand Up @@ -289,15 +289,20 @@ public RangeAggregator(
}
}

class RangeAggregatorDataProvider extends AbstractRangeAggregatorDataProvider {
private final class RangeAggregatorDataProvider extends AbstractRangeAggregatorDataProvider {
@Override
public boolean canOptimize() {
protected boolean canOptimize() {
return canOptimize(valuesSourceConfig, ranges);
}

@Override
public OptimizationContext.Ranges buildRanges(SearchContext ctx, MappedFieldType fieldType) {
return buildRanges(fieldType, ranges);
protected void buildRanges(SearchContext ctx) {
buildRanges(ranges);
}

/**
 * Supplies the function that maps an active range index to a bucket ordinal.
 * <p>
 * The returned function casts its argument to {@code int} and delegates to
 * {@code subBucketOrdinal} with {@code 0} as the first argument.
 *
 * @return a function yielding the result of {@code subBucketOrdinal(0, index)}
 */
@Override
protected Function<Object, Long> bucketOrdProducer() {
    return activeIndex -> {
        final int rangeIndex = (int) activeIndex;
        return subBucketOrdinal(0, rangeIndex);
    };
}
}

Expand All @@ -311,11 +316,7 @@ public ScoreMode scoreMode() {

@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
boolean optimized = optimizationContext.tryFastFilterAggregation(
ctx,
this::incrementBucketDocCount,
(activeIndex) -> subBucketOrdinal(0, (int) activeIndex)
);
boolean optimized = optimizationContext.tryFastFilterAggregation(ctx, this::incrementBucketDocCount);
if (optimized) throw new CollectionTerminatedException();

final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,18 @@
import java.io.IOException;
import java.util.OptionalLong;
import java.util.function.BiConsumer;
import java.util.function.Function;

import static org.opensearch.search.optimization.ranges.OptimizationContext.multiRangesTraverse;

/**
* For date histogram aggregation
*/
public abstract class AbstractDateHistogramAggAggregatorDataProvider implements AggregatorDataProvider {
public abstract class AbstractDateHistogramAggAggregatorDataProvider extends AggregatorDataProvider {
private MappedFieldType fieldType;

public boolean canOptimize(boolean missing, boolean hasScript, MappedFieldType fieldType) {
protected boolean canOptimize(boolean missing, boolean hasScript, MappedFieldType fieldType) {
if (!missing && !hasScript) {
if (fieldType != null && fieldType instanceof DateFieldMapper.DateFieldType) {
if (fieldType instanceof DateFieldMapper.DateFieldType) {
if (fieldType.isSearchable()) {
this.fieldType = fieldType;
return true;
Expand All @@ -43,7 +42,7 @@ public boolean canOptimize(boolean missing, boolean hasScript, MappedFieldType f
return false;
}

public boolean canOptimize(ValuesSourceConfig config) {
protected boolean canOptimize(ValuesSourceConfig config) {
if (config.script() == null && config.missing() == null) {
MappedFieldType fieldType = config.fieldType();
if (fieldType instanceof DateFieldMapper.DateFieldType) {
Expand All @@ -57,21 +56,18 @@ public boolean canOptimize(ValuesSourceConfig config) {
}

@Override
public OptimizationContext.Ranges buildRanges(SearchContext context, MappedFieldType fieldType) throws IOException {
protected void buildRanges(SearchContext context) throws IOException {
long[] bounds = Helper.getDateHistoAggBounds(context, fieldType.name());
// logger.debug("Bounds are {} for shard {}", bounds, context.indexShard().shardId());
return buildRanges(context, bounds);
this.optimizationContext.setRanges(buildRanges(context, bounds));
}

@Override
public OptimizationContext.Ranges buildRanges(LeafReaderContext leaf, SearchContext context, MappedFieldType fieldType)
throws IOException {
protected void buildRanges(LeafReaderContext leaf, SearchContext context) throws IOException {
long[] bounds = Helper.getSegmentBounds(leaf, fieldType.name());
// logger.debug("Bounds are {} for shard {} segment {}", bounds, context.indexShard().shardId(), leaf.ord);
return buildRanges(context, bounds);
this.optimizationContext.setRanges(buildRanges(context, bounds));
}

private OptimizationContext.Ranges buildRanges(SearchContext context, long[] bounds) throws IOException {
private OptimizationContext.Ranges buildRanges(SearchContext context, long[] bounds) {
bounds = processHardBounds(bounds);
if (bounds == null) {
return null;
Expand All @@ -89,12 +85,12 @@ private OptimizationContext.Ranges buildRanges(SearchContext context, long[] bou
bounds = processAfterKey(bounds, interval);

return Helper.createRangesFromAgg(
context,
(DateFieldMapper.DateFieldType) fieldType,
interval,
getRoundingPrepared(),
bounds[0],
bounds[1]
bounds[1],
context.maxAggRewriteFilters()
);
}

Expand Down Expand Up @@ -128,33 +124,29 @@ protected long[] processHardBounds(long[] bounds, LongBounds hardBounds) {
return bounds;
}

public DateFieldMapper.DateFieldType getFieldType() {
private DateFieldMapper.DateFieldType getFieldType() {
assert fieldType instanceof DateFieldMapper.DateFieldType;
return (DateFieldMapper.DateFieldType) fieldType;
}

public int getSize() {
protected int getSize() {
return Integer.MAX_VALUE;
}

@Override
public OptimizationContext.DebugInfo tryFastFilterAggregation(
PointValues values,
OptimizationContext.Ranges ranges,
BiConsumer<Long, Long> incrementDocCount,
Function<Object, Long> bucketOrd
) throws IOException {
protected final void tryFastFilterAggregation(PointValues values, BiConsumer<Long, Long> incrementDocCount) throws IOException {
int size = getSize();
OptimizationContext.Ranges ranges = this.optimizationContext.getRanges();

DateFieldMapper.DateFieldType fieldType = getFieldType();
BiConsumer<Integer, Integer> incrementFunc = (activeIndex, docCount) -> {
long rangeStart = LongPoint.decodeDimension(ranges.lowers[activeIndex], 0);
rangeStart = fieldType.convertNanosToMillis(rangeStart);
long ord = getBucketOrd(bucketOrd.apply(rangeStart));
long ord = getBucketOrd(bucketOrdProducer().apply(rangeStart));
incrementDocCount.accept(ord, (long) docCount);
};

return multiRangesTraverse(values.getPointTree(), ranges, incrementFunc, size);
this.optimizationContext.consumeDebugInfo(multiRangesTraverse(values.getPointTree(), ranges, incrementFunc, size));
}

private static long getBucketOrd(long bucketOrd) {
Expand Down
Loading

0 comments on commit 1a067ba

Please sign in to comment.