diff --git a/server/src/internalClusterTest/java/org/opensearch/search/aggregations/metrics/CardinalityIT.java b/server/src/internalClusterTest/java/org/opensearch/search/aggregations/metrics/CardinalityIT.java index db4ee3571d141..b2ed689622e7d 100644 --- a/server/src/internalClusterTest/java/org/opensearch/search/aggregations/metrics/CardinalityIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/search/aggregations/metrics/CardinalityIT.java @@ -34,6 +34,7 @@ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsResponse; import org.opensearch.action.index.IndexRequestBuilder; import org.opensearch.action.search.SearchResponse; import org.opensearch.common.settings.Settings; @@ -59,6 +60,7 @@ import static java.util.Collections.emptyMap; import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; import static org.opensearch.index.query.QueryBuilders.matchAllQuery; +import static org.opensearch.search.SearchService.CARDINALITY_AGGREGATION_PRUNING_THRESHOLD; import static org.opensearch.search.SearchService.CLUSTER_CONCURRENT_SEGMENT_SEARCH_SETTING; import static org.opensearch.search.aggregations.AggregationBuilders.cardinality; import static org.opensearch.search.aggregations.AggregationBuilders.global; @@ -255,6 +257,36 @@ public void testSingleValuedString() throws Exception { assertCount(count, numDocs); } + public void testDisableDynamicPruning() throws Exception { + SearchResponse response = client().prepareSearch("idx") + .addAggregation(cardinality("cardinality").precisionThreshold(precisionThreshold).field("str_value")) + .get(); + assertSearchResponse(response); + + Cardinality count1 = response.getAggregations().get("cardinality"); + + final ClusterUpdateSettingsResponse updateSettingResponse = client().admin() + .cluster() + .prepareUpdateSettings() + .setTransientSettings(Settings.builder().put(CARDINALITY_AGGREGATION_PRUNING_THRESHOLD.getKey(), 0)) + .get(); + assertEquals(updateSettingResponse.getTransientSettings().get(CARDINALITY_AGGREGATION_PRUNING_THRESHOLD.getKey()), "0"); + + response = client().prepareSearch("idx") + .addAggregation(cardinality("cardinality").precisionThreshold(precisionThreshold).field("str_value")) + .get(); + assertSearchResponse(response); + Cardinality count2 = response.getAggregations().get("cardinality"); + + assertEquals(count1, count2); + + client().admin() + .cluster() + .prepareUpdateSettings() + .setTransientSettings(Settings.builder().putNull(CARDINALITY_AGGREGATION_PRUNING_THRESHOLD.getKey())) + .get(); + } + public void testSingleValuedNumeric() throws Exception { SearchResponse response = client().prepareSearch("idx") .addAggregation(cardinality("cardinality").precisionThreshold(precisionThreshold).field(singleNumericField())) diff --git a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java index e4cd3c729389b..7ea04acf00415 100644 --- a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java +++ b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java @@ -540,6 +540,7 @@ public void apply(Settings value, Settings current, Settings previous) { SearchService.MAX_OPEN_PIT_CONTEXT, SearchService.MAX_PIT_KEEPALIVE_SETTING, SearchService.MAX_AGGREGATION_REWRITE_FILTERS, + SearchService.CARDINALITY_AGGREGATION_PRUNING_THRESHOLD, CreatePitController.PIT_INIT_KEEP_ALIVE, Node.WRITE_PORTS_FILE_SETTING, Node.NODE_NAME_SETTING, diff --git a/server/src/main/java/org/opensearch/search/DefaultSearchContext.java b/server/src/main/java/org/opensearch/search/DefaultSearchContext.java index cd8714f6b556a..abb968c2de245 100644 --- a/server/src/main/java/org/opensearch/search/DefaultSearchContext.java +++ b/server/src/main/java/org/opensearch/search/DefaultSearchContext.java @@ -106,6 +106,7 @@ import java.util.function.Function; import java.util.function.LongSupplier; +import static org.opensearch.search.SearchService.CARDINALITY_AGGREGATION_PRUNING_THRESHOLD; import static org.opensearch.search.SearchService.CLUSTER_CONCURRENT_SEGMENT_SEARCH_SETTING; import static org.opensearch.search.SearchService.MAX_AGGREGATION_REWRITE_FILTERS; @@ -189,6 +190,7 @@ final class DefaultSearchContext extends SearchContext { private final boolean concurrentSearchSettingsEnabled; private final SetOnce requestShouldUseConcurrentSearch = new SetOnce<>(); private final int maxAggRewriteFilters; + private final int cardinalityAggregationPruningThreshold; DefaultSearchContext( ReaderContext readerContext, @@ -244,6 +246,7 @@ final class DefaultSearchContext extends SearchContext { this.requestToAggReduceContextBuilder = requestToAggReduceContextBuilder; this.maxAggRewriteFilters = evaluateFilterRewriteSetting(); + this.cardinalityAggregationPruningThreshold = evaluateCardinalityAggregationPruningThreshold(); } @Override @@ -1010,4 +1013,16 @@ private int evaluateFilterRewriteSetting() { } return 0; } + + @Override + public int cardinalityAggregationPruningThreshold() { + return cardinalityAggregationPruningThreshold; + } + + private int evaluateCardinalityAggregationPruningThreshold() { + if (clusterService != null) { + return clusterService.getClusterSettings().get(CARDINALITY_AGGREGATION_PRUNING_THRESHOLD); + } + return 0; + } } diff --git a/server/src/main/java/org/opensearch/search/SearchService.java b/server/src/main/java/org/opensearch/search/SearchService.java index 45f111d889522..135af91912e5d 100644 --- a/server/src/main/java/org/opensearch/search/SearchService.java +++ b/server/src/main/java/org/opensearch/search/SearchService.java @@ -288,6 +288,15 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv Property.NodeScope ); + // value 0 can disable dynamic pruning optimization in cardinality aggregation + public static final Setting CARDINALITY_AGGREGATION_PRUNING_THRESHOLD = Setting.intSetting( + "search.dynamic_pruning.cardinality_aggregation.max_allowed_cardinality", + 100, + 0, + Property.Dynamic, + Property.NodeScope + ); + public static final int DEFAULT_SIZE = 10; public static final int DEFAULT_FROM = 0; diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java index 99c4eaac4b777..0f3d975960364 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java @@ -32,13 +32,28 @@ package org.opensearch.search.aggregations.metrics; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.SortedNumericDocValues; import org.apache.lucene.index.SortedSetDocValues; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.Terms; +import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.search.BulkScorer; +import org.apache.lucene.search.CollectionTerminatedException; +import org.apache.lucene.search.DisiPriorityQueue; +import org.apache.lucene.search.DisiWrapper; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.Weight; +import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.RamUsageEstimator; +import org.opensearch.OpenSearchStatusException; import org.opensearch.common.Nullable; import org.opensearch.common.hash.MurmurHash3; import org.opensearch.common.lease.Releasable; @@ -48,6 +63,7 @@ import org.opensearch.common.util.BitMixer; import org.opensearch.common.util.LongArray; import org.opensearch.common.util.ObjectArray; +import org.opensearch.core.rest.RestStatus; import org.opensearch.index.fielddata.SortedBinaryDocValues; import org.opensearch.index.fielddata.SortedNumericDoubleValues; import org.opensearch.search.aggregations.Aggregator; @@ -58,9 +74,12 @@ import org.opensearch.search.internal.SearchContext; import java.io.IOException; +import java.util.HashMap; import java.util.Map; import java.util.function.BiConsumer; +import static org.opensearch.search.SearchService.CARDINALITY_AGGREGATION_PRUNING_THRESHOLD; + /** * An aggregator that computes approximate counts of unique values. * @@ -68,9 +87,13 @@ */ public class CardinalityAggregator extends NumericMetricsAggregator.SingleValue { + private static final Logger logger = LogManager.getLogger(CardinalityAggregator.class); + private final int precision; private final ValuesSource valuesSource; + private final ValuesSourceConfig valuesSourceConfig; + // Expensive to initialize, so we only initialize it when we have an actual value source @Nullable private HyperLogLogPlusPlus counts; @@ -82,6 +105,7 @@ public class CardinalityAggregator extends NumericMetricsAggregator.SingleValue private int ordinalsCollectorsUsed; private int ordinalsCollectorsOverheadTooHigh; private int stringHashingCollectorsUsed; + private int dynamicPrunedSegments; public CardinalityAggregator( String name, @@ -96,6 +120,7 @@ public CardinalityAggregator( this.valuesSource = valuesSourceConfig.hasValues() ? valuesSourceConfig.getValuesSource() : null; this.precision = precision; this.counts = valuesSource == null ? null : new HyperLogLogPlusPlus(precision, context.bigArrays(), 1); + this.valuesSourceConfig = valuesSourceConfig; } @Override @@ -118,6 +143,7 @@ private Collector pickCollector(LeafReaderContext ctx) throws IOException { return new DirectCollector(counts, hashValues); } + Collector collector = null; if (valuesSource instanceof ValuesSource.Bytes.WithOrdinals) { ValuesSource.Bytes.WithOrdinals source = (ValuesSource.Bytes.WithOrdinals) valuesSource; final SortedSetDocValues ordinalValues = source.ordinalsValues(ctx); @@ -125,20 +151,109 @@ private Collector pickCollector(LeafReaderContext ctx) throws IOException { if (maxOrd == 0) { emptyCollectorsUsed++; return new EmptyCollector(); + } else { + final long ordinalsMemoryUsage = OrdinalsCollector.memoryOverhead(maxOrd); + final long countsMemoryUsage = HyperLogLogPlusPlus.memoryUsage(precision); + // only use ordinals if they don't increase memory usage by more than 25% + if (ordinalsMemoryUsage < countsMemoryUsage / 4) { + ordinalsCollectorsUsed++; + collector = new OrdinalsCollector(counts, ordinalValues, context.bigArrays()); + } else { + ordinalsCollectorsOverheadTooHigh++; + } + } + } + + if (collector == null) { // not able to build an OrdinalsCollector + stringHashingCollectorsUsed++; + collector = new DirectCollector(counts, MurmurHash3Values.hash(valuesSource.bytesValues(ctx))); + } + + if (canPrune(parent, subAggregators, valuesSourceConfig)) { + Terms terms = ctx.reader().terms(valuesSourceConfig.fieldContext().field()); + if (terms == null) return collector; + if (exceedMaxThreshold(terms)) { + return collector; + } + + Collector pruningCollector = tryWrapWithPruningCollector(collector, terms, ctx); + if (pruningCollector == null) { + return collector; } - final long ordinalsMemoryUsage = OrdinalsCollector.memoryOverhead(maxOrd); - final long countsMemoryUsage = HyperLogLogPlusPlus.memoryUsage(precision); - // only use ordinals if they don't increase memory usage by more than 25% - if (ordinalsMemoryUsage < countsMemoryUsage / 4) { - ordinalsCollectorsUsed++; - return new OrdinalsCollector(counts, ordinalValues, context.bigArrays()); + if (!tryScoreWithPruningCollector(ctx, pruningCollector)) { + return collector; + } + logger.debug("Dynamic pruned segment {} of shard {}", ctx.ord, context.indexShard().shardId()); + dynamicPrunedSegments++; + + return getNoOpCollector(); + } + + return collector; + } + + private boolean canPrune(Aggregator parent, Aggregator[] subAggregators, ValuesSourceConfig valuesSourceConfig) { + return parent == null && subAggregators.length == 0 && valuesSourceConfig.missing() == null && valuesSourceConfig.script() == null; + } + + private boolean exceedMaxThreshold(Terms terms) throws IOException { + if (terms.size() > context.cardinalityAggregationPruningThreshold()) { + logger.debug( + "Cannot prune because terms size {} is greater than the threshold {}", + terms.size(), + context.cardinalityAggregationPruningThreshold() + ); + return true; + } + return false; + } + + private Collector tryWrapWithPruningCollector(Collector collector, Terms terms, LeafReaderContext ctx) { + try { + return new PruningCollector(collector, terms.iterator(), ctx, context, valuesSourceConfig.fieldContext().field()); + } catch (Exception e) { + logger.warn("Failed to build collector for dynamic pruning.", e); + return null; + } + } + + private boolean tryScoreWithPruningCollector(LeafReaderContext ctx, Collector pruningCollector) throws IOException { + try { + Weight weight = context.query().rewrite(context.searcher()).createWeight(context.searcher(), ScoreMode.TOP_DOCS, 1f); + BulkScorer scorer = weight.bulkScorer(ctx); + if (scorer == null) { + return false; } - ordinalsCollectorsOverheadTooHigh++; + Bits liveDocs = ctx.reader().getLiveDocs(); + scorer.score(pruningCollector, liveDocs); + pruningCollector.postCollect(); + Releasables.close(pruningCollector); + } catch (Exception e) { + throw new OpenSearchStatusException( + "Failed when performing dynamic pruning in cardinality aggregation. You can set cluster setting [" + + CARDINALITY_AGGREGATION_PRUNING_THRESHOLD.getKey() + + "] to 0 to disable.", + RestStatus.INTERNAL_SERVER_ERROR, + e + ); } + return true; + } - stringHashingCollectorsUsed++; - return new DirectCollector(counts, MurmurHash3Values.hash(valuesSource.bytesValues(ctx))); + private Collector getNoOpCollector() { + return new Collector() { + @Override + public void close() {} + + @Override + public void postCollect() throws IOException {} + + @Override + public void collect(int doc, long owningBucketOrd) throws IOException { + throw new CollectionTerminatedException(); + } + }; } @Override @@ -175,7 +290,7 @@ public InternalAggregation buildAggregation(long owningBucketOrdinal) { if (counts == null || owningBucketOrdinal >= counts.maxOrd() || counts.cardinality(owningBucketOrdinal) == 0) { return buildEmptyAggregation(); } - // We need to build a copy because the returned Aggregation needs remain usable after + // We need to build a copy because the returned Aggregation needs to remain usable after // this Aggregator (and its HLL++ counters) is released. AbstractHyperLogLogPlusPlus copy = counts.clone(owningBucketOrdinal, BigArrays.NON_RECYCLING_INSTANCE); return new InternalCardinality(name, copy, metadata()); @@ -199,6 +314,7 @@ public void collectDebugInfo(BiConsumer add) { add.accept("ordinals_collectors_used", ordinalsCollectorsUsed); add.accept("ordinals_collectors_overhead_too_high", ordinalsCollectorsOverheadTooHigh); add.accept("string_hashing_collectors_used", stringHashingCollectorsUsed); + add.accept("dynamic_pruned_segments", dynamicPrunedSegments); } /** @@ -212,6 +328,130 @@ private abstract static class Collector extends LeafBucketCollector implements R } + /** + * This collector enhance the delegate collector with pruning ability on term field + * The iterators of term field values are wrapped into a priority queue, and able to + * pop/prune the values after being collected + */ + private static class PruningCollector extends Collector { + + private final Collector delegate; + private final DisiPriorityQueue queue; + private final DocIdSetIterator competitiveIterator; + + PruningCollector(Collector delegate, TermsEnum terms, LeafReaderContext ctx, SearchContext context, String field) + throws IOException { + this.delegate = delegate; + + Map postingMap = new HashMap<>(); + while (terms.next() != null) { + BytesRef term = terms.term(); + TermQuery termQuery = new TermQuery(new Term(field, term)); + Weight subWeight = termQuery.createWeight(context.searcher(), ScoreMode.COMPLETE_NO_SCORES, 1f); + Scorer scorer = subWeight.scorer(ctx); + postingMap.put(term, scorer); + } + + this.queue = new DisiPriorityQueue(postingMap.size()); + for (Scorer scorer : postingMap.values()) { + queue.add(new DisiWrapper(scorer)); + } + + competitiveIterator = new DisjunctionDISI(queue); + } + + @Override + public void close() { + delegate.close(); + } + + @Override + public void collect(int doc, long owningBucketOrd) throws IOException { + delegate.collect(doc, owningBucketOrd); + prune(doc); + } + + /** + * Note: the queue may be empty or the queue top may be null after pruning + */ + private void prune(int doc) { + DisiWrapper top = queue.top(); + int curTopDoc = top.doc; + if (curTopDoc == doc) { + do { + queue.pop(); + top = queue.updateTop(); + } while (queue.size() > 1 && top.doc == curTopDoc); + } + } + + @Override + public DocIdSetIterator competitiveIterator() { + return competitiveIterator; + } + + @Override + public void postCollect() throws IOException { + delegate.postCollect(); + } + } + + /** + * This DISI is a disjunction of all terms in a segment + * And it will be the competitive iterator of the leaf pruning collector + * After pruning done after collect, queue top doc may exceed the next doc of (lead) iterator + * To still providing a docID slower than the lead iterator for the next iteration + * We keep track of a slowDocId that will be updated later during advance + */ + private static class DisjunctionDISI extends DocIdSetIterator { + private final DisiPriorityQueue queue; + private int slowDocId = -1; + + public DisjunctionDISI(DisiPriorityQueue queue) { + this.queue = queue; + } + + @Override + public int docID() { + return slowDocId; + } + + @Override + public int advance(int target) throws IOException { + DisiWrapper top = queue.top(); + if (top == null) { + return slowDocId = NO_MORE_DOCS; + } + + // This would be the outcome of last pruning + // this DISI's docID is already making to the target + if (top.doc >= target) { + slowDocId = top.doc; + return top.doc; + } + + do { + top.doc = top.approximation.advance(target); + top = queue.updateTop(); + } while (top.doc < target); + slowDocId = queue.size() == 0 ? NO_MORE_DOCS : queue.top().doc; + + return slowDocId; + } + + @Override + public int nextDoc() { + // don't expect this to be called based on its usage in DefaultBulkScorer + throw new UnsupportedOperationException(); + } + + @Override + public long cost() { + // don't expect this to be called based on its usage in DefaultBulkScorer + throw new UnsupportedOperationException(); + } + } + /** * Empty Collector for the Cardinality agg * diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/InternalCardinality.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/InternalCardinality.java index 7e9511ffdd379..9f9ad63220fea 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/InternalCardinality.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/InternalCardinality.java @@ -117,7 +117,6 @@ public InternalAggregation reduce(List aggregations, Reduce return aggregations.get(0); } else { return new InternalCardinality(name, reduced, getMetadata()); - } } diff --git a/server/src/main/java/org/opensearch/search/internal/SearchContext.java b/server/src/main/java/org/opensearch/search/internal/SearchContext.java index 0c8240d3a8322..bc4b7058651dd 100644 --- a/server/src/main/java/org/opensearch/search/internal/SearchContext.java +++ b/server/src/main/java/org/opensearch/search/internal/SearchContext.java @@ -522,4 +522,8 @@ public String toString() { public int maxAggRewriteFilters() { return 0; } + + public int cardinalityAggregationPruningThreshold() { + return 0; + } } diff --git a/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java b/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java index cdd17e2fa7dd6..b5dd27e37c332 100644 --- a/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java +++ b/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java @@ -33,30 +33,56 @@ package org.opensearch.search.aggregations.metrics; import org.apache.lucene.document.BinaryDocValuesField; +import org.apache.lucene.document.Field; import org.apache.lucene.document.IntPoint; +import org.apache.lucene.document.KeywordField; import org.apache.lucene.document.NumericDocValuesField; import org.apache.lucene.document.SortedNumericDocValuesField; +import org.apache.lucene.document.SortedSetDocValuesField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.NoMergePolicy; +import org.apache.lucene.index.Term; import org.apache.lucene.search.DocValuesFieldExistsQuery; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.BytesRef; import org.opensearch.common.CheckedConsumer; import org.opensearch.common.geo.GeoPoint; +import org.opensearch.core.common.breaker.CircuitBreaker; +import org.opensearch.core.indices.breaker.NoneCircuitBreakerService; +import org.opensearch.index.mapper.KeywordFieldMapper; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.index.mapper.RangeFieldMapper; import org.opensearch.index.mapper.RangeType; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.aggregations.AggregatorTestCase; +import org.opensearch.search.aggregations.InternalAggregation; +import org.opensearch.search.aggregations.MultiBucketConsumerService; +import org.opensearch.search.aggregations.pipeline.PipelineAggregator; import org.opensearch.search.aggregations.support.AggregationInspectionHelper; import java.io.IOException; import java.util.Arrays; +import java.util.Collections; import java.util.HashSet; +import java.util.List; import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; +import static java.util.Arrays.asList; import static java.util.Collections.singleton; +import static org.opensearch.test.InternalAggregationTestCase.DEFAULT_MAX_BUCKETS; +import static org.mockito.Mockito.when; public class CardinalityAggregatorTests extends AggregatorTestCase { @@ -199,4 +225,276 @@ private void testAggregation( ) throws IOException { testCase(aggregationBuilder, query, buildIndex, verify, fieldType); } + + public void testDynamicPruningDisabledWhenExceedingThreshold() throws IOException { + final String fieldName = "testField"; + final String filterFieldName = "filterField"; + + MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType(fieldName); + final CardinalityAggregationBuilder aggregationBuilder = new CardinalityAggregationBuilder("_name").field(fieldName); + + int randomCardinality = randomIntBetween(20, 100); + AtomicInteger counter = new AtomicInteger(); + + testDynamicPruning(aggregationBuilder, new TermQuery(new Term(filterFieldName, "foo")), iw -> { + for (int i = 0; i < randomCardinality; i++) { + String filterValue = "foo"; + if (randomBoolean()) { + filterValue = "bar"; + counter.getAndIncrement(); + } + iw.addDocument( + asList( + new KeywordField(filterFieldName, filterValue, Field.Store.NO), + new KeywordField(fieldName, String.valueOf(i), Field.Store.NO), + new SortedSetDocValuesField(fieldName, new BytesRef(String.valueOf(i))) + ) + ); + } + }, + card -> { assertEquals(randomCardinality - counter.get(), card.getValue(), 0); }, + fieldType, + 10, + (collectCount) -> assertEquals(randomCardinality - counter.get(), (int) collectCount) + ); + } + + public void testDynamicPruningFixedValues() throws IOException { + final String fieldName = "testField"; + final String filterFieldName = "filterField"; + + MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType(fieldName); + final CardinalityAggregationBuilder aggregationBuilder = new CardinalityAggregationBuilder("_name").field(fieldName); + testDynamicPruning(aggregationBuilder, new TermQuery(new Term(filterFieldName, "foo")), iw -> { + iw.addDocument( + asList( + new KeywordField(fieldName, "1", Field.Store.NO), + new KeywordField(fieldName, "2", Field.Store.NO), + new KeywordField(filterFieldName, "foo", Field.Store.NO), + new SortedSetDocValuesField(fieldName, new BytesRef("1")), + new SortedSetDocValuesField(fieldName, new BytesRef("2")) + ) + ); + iw.addDocument( + asList( + new KeywordField(fieldName, "2", Field.Store.NO), + new KeywordField(filterFieldName, "foo", Field.Store.NO), + new SortedSetDocValuesField(fieldName, new BytesRef("2")) + ) + ); + iw.addDocument( + asList( + new KeywordField(fieldName, "1", Field.Store.NO), + new KeywordField(filterFieldName, "foo", Field.Store.NO), + new SortedSetDocValuesField(fieldName, new BytesRef("1")) + ) + ); + iw.addDocument( + asList( + new KeywordField(fieldName, "2", Field.Store.NO), + new KeywordField(filterFieldName, "foo", Field.Store.NO), + new SortedSetDocValuesField(fieldName, new BytesRef("2")) + ) + ); + iw.addDocument( + asList( + new KeywordField(fieldName, "3", Field.Store.NO), + new KeywordField(filterFieldName, "foo", Field.Store.NO), + new SortedSetDocValuesField(fieldName, new BytesRef("3")) + ) + ); + iw.addDocument( + asList( + new KeywordField(fieldName, "4", Field.Store.NO), + new KeywordField(filterFieldName, "bar", Field.Store.NO), + new SortedSetDocValuesField(fieldName, new BytesRef("4")) + ) + ); + iw.addDocument( + asList( + new KeywordField(fieldName, "5", Field.Store.NO), + new KeywordField(filterFieldName, "bar", Field.Store.NO), + new SortedSetDocValuesField(fieldName, new BytesRef("5")) + ) + ); + }, card -> { + assertEquals(3.0, card.getValue(), 0); + assertTrue(AggregationInspectionHelper.hasValue(card)); + }, fieldType, 100, (collectCount) -> assertEquals(0, (int) collectCount)); + } + + public void testDynamicPruningRandomValues() throws IOException { + final String fieldName = "testField"; + final String filterFieldName = "filterField"; + + MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType(fieldName); + final CardinalityAggregationBuilder aggregationBuilder = new CardinalityAggregationBuilder("_name").field(fieldName); + + int randomCardinality = randomIntBetween(1, 100); + AtomicInteger counter = new AtomicInteger(); + + testDynamicPruning(aggregationBuilder, new TermQuery(new Term(filterFieldName, "foo")), iw -> { + for (int i = 0; i < randomCardinality; i++) { + String filterValue = "foo"; + if (randomBoolean()) { + filterValue = "bar"; + counter.getAndIncrement(); + } + iw.addDocument( + asList( + new KeywordField(filterFieldName, filterValue, Field.Store.NO), + new KeywordField(fieldName, String.valueOf(i), Field.Store.NO), + new SortedSetDocValuesField(fieldName, new BytesRef(String.valueOf(i))) + ) + ); + } + }, card -> { + logger.info("expected {}, cardinality: {}", randomCardinality - counter.get(), card.getValue()); + assertEquals(randomCardinality - counter.get(), card.getValue(), 0); + }, fieldType, 100, (collectCount) -> assertEquals(0, (int) collectCount)); + } + + public void testDynamicPruningRandomDelete() throws IOException { + final String fieldName = "testField"; + + MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType(fieldName); + final CardinalityAggregationBuilder aggregationBuilder = new CardinalityAggregationBuilder("_name").field(fieldName); + + int randomCardinality = randomIntBetween(1, 100); + AtomicInteger counter = new AtomicInteger(); + + testDynamicPruning(aggregationBuilder, new MatchAllDocsQuery(), iw -> { + for (int i = 0; i < randomCardinality; i++) { + iw.addDocument( + asList( + new KeywordField(fieldName, String.valueOf(i), Field.Store.NO), + new SortedSetDocValuesField(fieldName, new BytesRef(String.valueOf(i))) + ) + ); + if (randomBoolean()) { + iw.deleteDocuments(new Term(fieldName, String.valueOf(i))); + counter.getAndIncrement(); + } + } + }, + card -> { assertEquals(randomCardinality - counter.get(), card.getValue(), 0); }, + fieldType, + 100, + (collectCount) -> assertEquals(0, (int) collectCount) + ); + } + + public void testDynamicPruningFieldMissingInSegment() throws IOException { + final String fieldName = "testField"; + final String fieldName2 = "testField2"; + + MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType(fieldName); + final CardinalityAggregationBuilder aggregationBuilder = new CardinalityAggregationBuilder("_name").field(fieldName); + + int randomNumSegments = randomIntBetween(1, 50); + logger.info("Indexing [{}] segments", randomNumSegments); + + testDynamicPruning(aggregationBuilder, new MatchAllDocsQuery(), iw -> { + for (int i = 0; i < randomNumSegments; i++) { + iw.addDocument( + asList( + new KeywordField(fieldName, String.valueOf(i), Field.Store.NO), + new SortedSetDocValuesField(fieldName, new BytesRef(String.valueOf(i))) + ) + ); + iw.commit(); + } + iw.addDocument(List.of(new KeywordField(fieldName2, "100", Field.Store.NO))); + iw.addDocument(List.of(new KeywordField(fieldName2, "101", Field.Store.NO))); + iw.addDocument(List.of(new KeywordField(fieldName2, "102", Field.Store.NO))); + iw.commit(); + }, + card -> { assertEquals(randomNumSegments, card.getValue(), 0); }, + fieldType, + 100, + (collectCount) -> assertEquals(3, (int) collectCount) + ); + } + + private void testDynamicPruning( + AggregationBuilder aggregationBuilder, + Query query, + CheckedConsumer buildIndex, + Consumer verify, + MappedFieldType fieldType, + int pruningThreshold, + Consumer verifyCollectCount + ) throws IOException { + try (Directory directory = newDirectory()) { + try ( + IndexWriter indexWriter = new IndexWriter( + directory, + new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec()).setMergePolicy(NoMergePolicy.INSTANCE) + ) + ) { + // disable merge so segment number is same as commit times + buildIndex.accept(indexWriter); + } + + try (IndexReader indexReader = DirectoryReader.open(directory)) { + IndexSearcher indexSearcher = newSearcher(indexReader, true, true); + + CountingAggregator aggregator = createCountingAggregator( + query, + aggregationBuilder, + indexSearcher, + fieldType, + pruningThreshold + ); + aggregator.preCollection(); + indexSearcher.search(query, aggregator); + aggregator.postCollection(); + + MultiBucketConsumerService.MultiBucketConsumer reduceBucketConsumer = new MultiBucketConsumerService.MultiBucketConsumer( + Integer.MAX_VALUE, + new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST) + ); + InternalAggregation.ReduceContext context = InternalAggregation.ReduceContext.forFinalReduction( + aggregator.context().bigArrays(), + getMockScriptService(), + reduceBucketConsumer, + PipelineAggregator.PipelineTree.EMPTY + ); + InternalCardinality topLevel = (InternalCardinality) aggregator.buildTopLevel(); + InternalCardinality card = (InternalCardinality) topLevel.reduce(Collections.singletonList(topLevel), context); + doAssertReducedMultiBucketConsumer(card, reduceBucketConsumer); + + verify.accept(card); + + logger.info("aggregator collect count {}", aggregator.getCollectCount().get()); + verifyCollectCount.accept(aggregator.getCollectCount().get()); + } + } + } + + protected CountingAggregator createCountingAggregator( + Query query, + AggregationBuilder builder, + IndexSearcher searcher, + MappedFieldType fieldType, + int pruningThreshold + ) throws IOException { + return new CountingAggregator( + new AtomicInteger(), + createAggregatorWithCustomizableSearchContext( + query, + builder, + searcher, + createIndexSettings(), + new MultiBucketConsumerService.MultiBucketConsumer( + DEFAULT_MAX_BUCKETS, + new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST) + ), + (searchContext) -> { + when(searchContext.cardinalityAggregationPruningThreshold()).thenReturn(pruningThreshold); + }, + fieldType + ) + ); + } } diff --git a/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java b/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java index 02e5d22e147d5..50b27ec000615 100644 --- a/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java +++ b/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java @@ -304,6 +304,20 @@ protected A createAggregator( return createAggregator(aggregationBuilder, searchContext); } + protected A createAggregatorWithCustomizableSearchContext( + Query query, + AggregationBuilder aggregationBuilder, + IndexSearcher indexSearcher, + IndexSettings indexSettings, + MultiBucketConsumer bucketConsumer, + Consumer customizeSearchContext, + MappedFieldType... fieldTypes + ) throws IOException { + SearchContext searchContext = createSearchContext(indexSearcher, indexSettings, query, bucketConsumer, fieldTypes); + customizeSearchContext.accept(searchContext); + return createAggregator(aggregationBuilder, searchContext); + } + protected A createAggregator(AggregationBuilder aggregationBuilder, SearchContext searchContext) throws IOException { @SuppressWarnings("unchecked")