From 58e5281e81d8f6c284994521d6496baa01f7cece Mon Sep 17 00:00:00 2001 From: bowenlan-amzn Date: Wed, 5 Jun 2024 15:00:58 -0700 Subject: [PATCH] refactor Signed-off-by: bowenlan-amzn --- .../bucket/FastFilterRewriteHelper.java | 69 ++++------- .../bucket/range/RangeAggregator.java | 1 - .../bucket/range/RangeAggregatorTests.java | 115 ++++++++++++++++-- 3 files changed, 132 insertions(+), 53 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/FastFilterRewriteHelper.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/FastFilterRewriteHelper.java index 6f858a0ff9ec9..c2ad56f46d0c3 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/FastFilterRewriteHelper.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/FastFilterRewriteHelper.java @@ -195,7 +195,6 @@ public static class FastFilterContext { private AggregationType aggregationType; private final SearchContext context; - private String fieldName; private MappedFieldType fieldType; private Ranges ranges; @@ -228,7 +227,6 @@ public boolean isRewriteable(final Object parent, final int subAggLength) { public void buildRanges(MappedFieldType fieldType) throws IOException { assert ranges == null : "Ranges should only be built once at shard level, but they are already built"; - this.fieldName = fieldType.name(); this.fieldType = fieldType; this.ranges = this.aggregationType.buildRanges(context, fieldType); if (ranges != null) { @@ -249,6 +247,9 @@ private Ranges buildRanges(LeafReaderContext leaf) throws IOException { * Try to populate the bucket doc counts for aggregation *

* Usage: invoked at segment level — in getLeafCollector of aggregator + * + * @param bucketOrd bucket ordinal producer + * @param incrementDocCount consume the doc_count results for certain ordinal */ public boolean tryFastFilterAggregation( final LeafReaderContext ctx, @@ -262,7 +263,7 @@ public boolean tryFastFilterAggregation( if (ctx.reader().hasDeletions()) return false; - PointValues values = ctx.reader().getPointValues(this.fieldName); + PointValues values = ctx.reader().getPointValues(this.fieldType.name()); if (values == null) return false; // only proceed if every document corresponds to exactly one point if (values.getDocCount() != values.size()) return false; @@ -458,13 +459,11 @@ public DebugInfo tryFastFilterAggregation( */ public static class RangeAggregationType implements AggregationType { - private final ValuesSource.Numeric source; private final ValuesSourceConfig config; private final Range[] ranges; private FieldTypeEnum fieldTypeEnum; public RangeAggregationType(ValuesSourceConfig config, Range[] ranges) { - this.source = (ValuesSource.Numeric) config.getValuesSource(); this.config = config; this.ranges = ranges; } @@ -482,7 +481,7 @@ public boolean isRewriteable(Object parent, int subAggLength) { return false; } - if (source instanceof ValuesSource.Numeric.FieldData) { + if (config.getValuesSource() instanceof ValuesSource.Numeric.FieldData) { // ranges are already sorted by from and then to // we want ranges not overlapping with each other double prevTo = ranges[0].getTo(); @@ -499,7 +498,7 @@ public boolean isRewriteable(Object parent, int subAggLength) { } @Override - public Ranges buildRanges(SearchContext ctx, MappedFieldType fieldType) throws IOException { + public Ranges buildRanges(SearchContext ctx, MappedFieldType fieldType) { int byteLen = this.fieldTypeEnum.getByteLen(); String pointType = this.fieldTypeEnum.getPointType(); @@ -604,26 +603,8 @@ static FieldTypeEnum fromTypeName(String typeName) { } } - public static BigInteger convertDoubleToBigInteger(double value) { - // we use big integer to represent unsigned long - BigInteger maxUnsignedLong = BigInteger.valueOf(2).pow(64).subtract(BigInteger.ONE); - - if (Double.isNaN(value)) { - return BigInteger.ZERO; - } else if (Double.isInfinite(value)) { - if (value > 0) { - return maxUnsignedLong; - } else { - return BigInteger.ZERO; - } - } else { - BigDecimal bigDecimal = BigDecimal.valueOf(value); - return bigDecimal.toBigInteger(); - } - } - @Override - public Ranges buildRanges(LeafReaderContext leaf, SearchContext ctx, MappedFieldType fieldType) throws IOException { + public Ranges buildRanges(LeafReaderContext leaf, SearchContext ctx, MappedFieldType fieldType) { throw new UnsupportedOperationException("Range aggregation should not build ranges at segment level"); } @@ -645,6 +626,24 @@ public DebugInfo tryFastFilterAggregation( } } + public static BigInteger convertDoubleToBigInteger(double value) { + // we use big integer to represent unsigned long + BigInteger maxUnsignedLong = BigInteger.valueOf(2).pow(64).subtract(BigInteger.ONE); + + if (Double.isNaN(value)) { + return BigInteger.ZERO; + } else if (Double.isInfinite(value)) { + if (value > 0) { + return maxUnsignedLong; + } else { + return BigInteger.ZERO; + } + } else { + BigDecimal bigDecimal = BigDecimal.valueOf(value); + return bigDecimal.toBigInteger(); + } + } + public static boolean isCompositeAggRewriteable(CompositeValuesSourceConfig[] sourceConfigs) { return sourceConfigs.length == 1 && sourceConfigs[0].valuesSource() instanceof RoundingValuesSource; } @@ -776,13 +775,6 @@ public int firstRangeIndex(byte[] globalMin, byte[] globalMax) { int i = 0; while (compareByteValue(uppers[i], globalMin) <= 0) { i++; - // special case - // lower and upper may be same for the last range - // if (i == size - 1) { - // if (compareByteValue(lowers[i], globalMin) >= 0) { - // return i; - // } - // } if (i >= size) { return -1; } @@ -957,27 +949,18 @@ private boolean withinLowerBound(byte[] value) { } private boolean withinUpperBound(byte[] value) { - // special case - // lower and upper may be same for the last range - // if (activeIndex == ranges.size - 1) { - // return Ranges.compareByteValue(value, activeRange[1]) <= 0; - // } return Ranges.withinUpperBound(value, activeRange[1]); } private boolean withinRange(byte[] value) { return withinLowerBound(value) && withinUpperBound(value); } - - private boolean cellCross(byte[] min, byte[] max) { - return Ranges.compareByteValue(activeRange[0], min) > 0 || withinUpperBound(max); - } } /** * Contains debug info of BKD traversal to show in profile */ - public static class DebugInfo { + private static class DebugInfo { private int leaf = 0; // leaf node visited private int inner = 0; // inner node visited diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java index 9140174c74323..6673ec645bf6c 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java @@ -296,7 +296,6 @@ public ScoreMode scoreMode() { @Override public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException { - boolean optimized = fastFilterContext.tryFastFilterAggregation( ctx, this::incrementBucketDocCount, diff --git a/server/src/test/java/org/opensearch/search/aggregations/bucket/range/RangeAggregatorTests.java b/server/src/test/java/org/opensearch/search/aggregations/bucket/range/RangeAggregatorTests.java index fe2feec1d3597..50b26fae3cac9 100644 --- a/server/src/test/java/org/opensearch/search/aggregations/bucket/range/RangeAggregatorTests.java +++ b/server/src/test/java/org/opensearch/search/aggregations/bucket/range/RangeAggregatorTests.java @@ -50,22 +50,32 @@ import org.apache.lucene.tests.util.TestUtil; import org.apache.lucene.util.BytesRef; import org.opensearch.common.CheckedConsumer; +import org.opensearch.core.common.breaker.CircuitBreaker; +import org.opensearch.core.indices.breaker.NoneCircuitBreakerService; import org.opensearch.index.mapper.DateFieldMapper; import org.opensearch.index.mapper.KeywordFieldMapper; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.NumberFieldMapper; +import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.aggregations.AggregatorTestCase; import org.opensearch.search.aggregations.CardinalityUpperBound; +import org.opensearch.search.aggregations.InternalAggregation; +import org.opensearch.search.aggregations.MultiBucketConsumerService; +import org.opensearch.search.aggregations.bucket.FastFilterRewriteHelper; +import org.opensearch.search.aggregations.pipeline.PipelineAggregator; import org.opensearch.search.aggregations.support.AggregationInspectionHelper; import java.io.IOException; +import java.math.BigInteger; import java.time.ZoneOffset; import java.time.ZonedDateTime; import java.util.Collections; import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; import static java.util.Collections.singleton; +import static org.opensearch.test.InternalAggregationTestCase.DEFAULT_MAX_BUCKETS; import static org.hamcrest.Matchers.equalTo; public class RangeAggregatorTests extends AggregatorTestCase { @@ -74,6 +84,10 @@ public class RangeAggregatorTests extends AggregatorTestCase { private static final String DATE_FIELD_NAME = "date"; private static final String DOUBLE_FIELD_NAME = "double"; + private static final String FLOAT_FIELD_NAME = "float"; + private static final String HALF_FLOAT_FIELD_NAME = "half_float"; + private static final String UNSIGNED_LONG_FIELD_NAME = "unsigned_long"; + private static final String SCALED_FLOAT_FIELD_NAME = "scaled_float"; public void testNoMatchingField() throws IOException { testCase(new MatchAllDocsQuery(), iw -> { @@ -313,15 +327,38 @@ public void testSubAggCollectsFromManyBucketsIfManyRanges() throws IOException { }); } - public void testDoubleType() throws IOException { + public void testOverlappingRanges() throws IOException { RangeAggregationBuilder aggregationBuilder = new RangeAggregationBuilder("range").field(DOUBLE_FIELD_NAME) .addRange(1, 2) - .addRange(2, 3); + .addRange(1, 1.5) + .addRange(0, 0.5); testRewriteOptimizationCase(aggregationBuilder, DoublePoint.newRangeQuery(DOUBLE_FIELD_NAME, 0, 5), indexWriter -> { indexWriter.addDocument(singleton(new DoubleField(DOUBLE_FIELD_NAME, 0.1, Field.Store.NO))); indexWriter.addDocument(singleton(new DoubleField(DOUBLE_FIELD_NAME, 1.1, Field.Store.NO))); indexWriter.addDocument(singleton(new DoubleField(DOUBLE_FIELD_NAME, 2.1, Field.Store.NO))); + }, range -> { + List ranges = range.getBuckets(); + assertEquals(3, ranges.size()); + assertEquals("0.0-0.5", ranges.get(0).getKeyAsString()); + assertEquals(1, ranges.get(0).getDocCount()); + assertEquals("1.0-1.5", ranges.get(1).getKeyAsString()); + assertEquals(1, ranges.get(1).getDocCount()); + assertEquals("1.0-2.0", ranges.get(2).getKeyAsString()); + assertEquals(1, ranges.get(2).getDocCount()); + assertTrue(AggregationInspectionHelper.hasValue(range)); + }, new NumberFieldMapper.NumberFieldType(DOUBLE_FIELD_NAME, NumberFieldMapper.NumberType.DOUBLE), false); + } + + public void testDoubleType() throws IOException { + RangeAggregationBuilder aggregationBuilder = new RangeAggregationBuilder("range").field(DOUBLE_FIELD_NAME) + .addRange(1, 2) + .addRange(2, 3); + + testRewriteOptimizationCase(aggregationBuilder, new MatchAllDocsQuery(), indexWriter -> { + indexWriter.addDocument(NumberFieldMapper.NumberType.DOUBLE.createFields(DOUBLE_FIELD_NAME, 0.1, true, true, false)); + indexWriter.addDocument(NumberFieldMapper.NumberType.DOUBLE.createFields(DOUBLE_FIELD_NAME, 1.1, true, true, false)); + indexWriter.addDocument(NumberFieldMapper.NumberType.DOUBLE.createFields(DOUBLE_FIELD_NAME, 2.1, true, true, false)); }, range -> { List ranges = range.getBuckets(); assertEquals(2, ranges.size()); @@ -330,7 +367,26 @@ public void testDoubleType() throws IOException { assertEquals("2.0-3.0", ranges.get(1).getKeyAsString()); assertEquals(1, ranges.get(1).getDocCount()); assertTrue(AggregationInspectionHelper.hasValue(range)); - }, new NumberFieldMapper.NumberFieldType(DOUBLE_FIELD_NAME, NumberFieldMapper.NumberType.DOUBLE)); + }, new NumberFieldMapper.NumberFieldType(DOUBLE_FIELD_NAME, NumberFieldMapper.NumberType.DOUBLE), true); + } + + public void testConvertDoubleToBigInteger() { + double value = Double.NaN; + BigInteger result = FastFilterRewriteHelper.convertDoubleToBigInteger(value); + assertEquals(BigInteger.ZERO, result); + + value = Double.POSITIVE_INFINITY; + result = FastFilterRewriteHelper.convertDoubleToBigInteger(value); + BigInteger maxUnsignedLong = BigInteger.valueOf(2).pow(64).subtract(BigInteger.ONE); + assertEquals(maxUnsignedLong, result); + + value = Double.NEGATIVE_INFINITY; + result = FastFilterRewriteHelper.convertDoubleToBigInteger(value); + assertEquals(BigInteger.ZERO, result); + + value = 123.456; + result = FastFilterRewriteHelper.convertDoubleToBigInteger(value); + assertEquals("123", result.toString()); } private void testCase( @@ -391,7 +447,8 @@ private void testRewriteOptimizationCase( Query query, CheckedConsumer buildIndex, Consumer> verify, - MappedFieldType fieldType + MappedFieldType fieldType, + boolean optimized ) throws IOException { try (Directory directory = newDirectory()) { try (IndexWriter indexWriter = new IndexWriter(directory, new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec()))) { @@ -401,14 +458,54 @@ private void testRewriteOptimizationCase( try (IndexReader indexReader = DirectoryReader.open(directory)) { IndexSearcher indexSearcher = newSearcher(indexReader, true, true); - InternalRange agg = searchAndReduce( - indexSearcher, - query, - aggregationBuilder, - fieldType + CountingAggregator aggregator = createCountingAggregator(query, aggregationBuilder, indexSearcher, fieldType); + aggregator.preCollection(); + indexSearcher.search(query, aggregator); + aggregator.postCollection(); + + MultiBucketConsumerService.MultiBucketConsumer reduceBucketConsumer = new MultiBucketConsumerService.MultiBucketConsumer( + Integer.MAX_VALUE, + new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST) ); + InternalAggregation.ReduceContext context = InternalAggregation.ReduceContext.forFinalReduction( + aggregator.context().bigArrays(), + getMockScriptService(), + reduceBucketConsumer, + PipelineAggregator.PipelineTree.EMPTY + ); + InternalRange topLevel = (InternalRange) aggregator.buildTopLevel(); + InternalRange agg = (InternalRange) topLevel.reduce(Collections.singletonList(topLevel), context); + doAssertReducedMultiBucketConsumer(agg, reduceBucketConsumer); + verify.accept(agg); + + if (optimized) { + assertEquals(0, aggregator.getCollectCount().get()); + } else { + assertTrue(aggregator.getCollectCount().get() > 0); + } } } } + + protected CountingAggregator createCountingAggregator( + Query query, + AggregationBuilder builder, + IndexSearcher searcher, + MappedFieldType... fieldTypes + ) throws IOException { + return new CountingAggregator( + new AtomicInteger(), + createAggregator( + query, + builder, + searcher, + new MultiBucketConsumerService.MultiBucketConsumer( + DEFAULT_MAX_BUCKETS, + new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST) + ), + fieldTypes + ) + ); + } }