From 8fc9beebf00e9d55ff83303684e582480470f469 Mon Sep 17 00:00:00 2001 From: bowenlan-amzn <bowenlan23@gmail.com> Date: Fri, 24 May 2024 10:53:25 -0700 Subject: [PATCH] utilize competitive iterator api to perform pruning Signed-off-by: bowenlan-amzn <bowenlan23@gmail.com> --- .../metrics/CardinalityAggregator.java | 116 +++++++++++++++++- .../DynamicPruningCollectorWrapper.java | 3 +- .../metrics/CardinalityAggregatorTests.java | 14 ++- 3 files changed, 125 insertions(+), 8 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java index dd2e7458d81d2..7138342ec2ca4 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java @@ -33,9 +33,19 @@ package org.opensearch.search.aggregations.metrics; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.PostingsEnum; import org.apache.lucene.index.SortedNumericDocValues; import org.apache.lucene.index.SortedSetDocValues; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.search.DisiPriorityQueue; +import org.apache.lucene.search.DisiWrapper; +import org.apache.lucene.search.DisjunctionDISIApproximation; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.Weight; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.RamUsageEstimator; @@ -59,6 +69,7 @@ import org.opensearch.search.internal.SearchContext; import java.io.IOException; +import java.util.HashMap; import java.util.Map; import java.util.function.BiConsumer; @@ -137,8 +148,10 @@ 
private Collector pickCollector(LeafReaderContext ctx) throws IOException { // only use ordinals if they don't increase memory usage by more than 25% if (ordinalsMemoryUsage < countsMemoryUsage / 4) { ordinalsCollectorsUsed++; - return new DynamicPruningCollectorWrapper(new OrdinalsCollector(counts, ordinalValues, context.bigArrays()), - context, ctx, fieldContext, source); + // previously: new DynamicPruningCollectorWrapper(new OrdinalsCollector(counts, ordinalValues, context.bigArrays()), context, ctx, fieldContext, source); + // the OrdinalsCollector is now wrapped in CompetitiveCollector, which prunes through competitiveIterator(); NOTE(review): its constructor wraps every subWeight.scorer(ctx) in a DisiWrapper without a null check - confirm scorer(ctx) cannot return null here + return new CompetitiveCollector(new OrdinalsCollector(counts, ordinalValues, context.bigArrays()), + source, ctx, context, fieldContext); } ordinalsCollectorsOverheadTooHigh++; } @@ -217,6 +230,105 @@ abstract static class Collector extends LeafBucketCollector implements Releasabl } + private static class CompetitiveCollector extends Collector { + + private final Collector delegate; + private final DisiPriorityQueue pq; + + CompetitiveCollector(Collector delegate, ValuesSource.Bytes.WithOrdinals source, LeafReaderContext ctx, + SearchContext context, FieldContext fieldContext) throws IOException { + this.delegate = delegate; + + final SortedSetDocValues ordinalValues = source.ordinalsValues(ctx); + TermsEnum terms = ordinalValues.termsEnum(); + Map<BytesRef, Scorer> postingMap = new HashMap<>(); + while (terms.next() != null) { + BytesRef term = terms.term(); + + TermQuery termQuery = new TermQuery(new Term(fieldContext.field(), term)); + Weight subWeight = context.searcher().createWeight(termQuery, ScoreMode.COMPLETE_NO_SCORES, 1f); + Scorer scorer = subWeight.scorer(ctx); + + postingMap.put(term, scorer); + } + this.pq = new DisiPriorityQueue(postingMap.size()); + for (Map.Entry<BytesRef, Scorer> entry : postingMap.entrySet()) { + pq.add(new DisiWrapper(entry.getValue())); + } + } + + @Override + public void close() { + delegate.close(); + } + + @Override + public void collect(int doc, long owningBucketOrd) throws IOException { + 
delegate.collect(doc, owningBucketOrd); + } + + @Override + public DocIdSetIterator competitiveIterator() throws IOException { + return new DisjunctionDISIWithPruning(pq); + } + + @Override + public void postCollect() throws IOException { + delegate.postCollect(); + } + } + + private static class DisjunctionDISIWithPruning extends DocIdSetIterator { + + final DisiPriorityQueue queue; + + public DisjunctionDISIWithPruning(DisiPriorityQueue queue) { + this.queue = queue; + } + + @Override + public int docID() { + return queue.top().doc; + } + + @Override + public int nextDoc() throws IOException { + // not expected to be called: only advance(int) is implemented by this iterator + throw new UnsupportedOperationException(); + } + + @Override + public int advance(int target) throws IOException { + // besides advancing to the first doc >= target, + // this method also prunes the iterators positioned on the current doc + + DisiWrapper top = queue.top(); + + // after collecting the doc, before advancing to target + // we can safely remove all the iterators that have this doc, by advancing them to Integer.MAX_VALUE; NOTE(review): queue.top() is dereferenced here and in docID() without an empty-queue check, the loop below has no guard for the case where every iterator is already exhausted, and the pruning relies on the current doc having actually been collected - confirm all three + if (top.doc != -1) { + int curTopDoc = top.doc; + do { + top.doc = top.approximation.advance(Integer.MAX_VALUE); + top = queue.updateTop(); + } while (top.doc == curTopDoc); + } + + if (top.doc >= target) return top.doc; + do { + top.doc = top.approximation.advance(target); + top = queue.updateTop(); + } while (top.doc < target); + return top.doc; + } + + @Override + public long cost() { + // not expected to be called: no cost estimate is provided by this iterator + throw new UnsupportedOperationException(); + } + } + /** * Empty Collector for the Cardinality agg * diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/DynamicPruningCollectorWrapper.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/DynamicPruningCollectorWrapper.java index f4c3d59a3833f..5b18168c9d874 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/DynamicPruningCollectorWrapper.java +++ 
b/server/src/main/java/org/opensearch/search/aggregations/metrics/DynamicPruningCollectorWrapper.java @@ -10,6 +10,7 @@ import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.PostingsEnum; import org.apache.lucene.index.SortedSetDocValues; import org.apache.lucene.index.Term; import org.apache.lucene.index.TermsEnum; @@ -52,7 +53,7 @@ class DynamicPruningCollectorWrapper extends CardinalityAggregator.Collector { // this logic should be pluggable depending on the type of leaf bucket collector by CardinalityAggregator TermsEnum terms = ordinalValues.termsEnum(); Weight weight = context.searcher().createWeight(context.searcher().rewrite(context.query()), ScoreMode.COMPLETE_NO_SCORES, 1f); - Map<Long, Boolean> found = new HashMap<>(); + Map<Long, Boolean> found = new HashMap<>(); // ord : found or not List<Scorer> subScorers = new ArrayList<>(); while (terms.next() != null && !found.containsKey(terms.ord())) { // TODO can we get rid of terms previously encountered in other segments? 
diff --git a/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java b/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java index a9966c9e70e76..73e5093618bb7 100644 --- a/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java +++ b/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java @@ -104,7 +104,9 @@ public void testDynamicPruningOrdinalCollector() throws IOException { MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType(fieldName); final CardinalityAggregationBuilder aggregationBuilder = new CardinalityAggregationBuilder("_name").field(fieldName); - testAggregation(aggregationBuilder, new TermQuery(new Term(filterFieldName, "foo")), iw -> { + testAggregation(aggregationBuilder, + new TermQuery(new Term(filterFieldName, "foo")), + iw -> { iw.addDocument(asList( new KeywordField(fieldName, "1", Field.Store.NO), new KeywordField(fieldName, "2", Field.Store.NO), @@ -142,10 +144,12 @@ public void testDynamicPruningOrdinalCollector() throws IOException { new KeywordField(filterFieldName, "bar", Field.Store.NO), new SortedSetDocValuesField(fieldName, new BytesRef("5")) )); - }, card -> { - assertEquals(3.0, card.getValue(), 0); - assertTrue(AggregationInspectionHelper.hasValue(card)); - }, fieldType); + }, + card -> { + assertEquals(3.0, card.getValue(), 0); + assertTrue(AggregationInspectionHelper.hasValue(card)); + }, + fieldType); } public void testNoMatchingField() throws IOException {