From 8fc9beebf00e9d55ff83303684e582480470f469 Mon Sep 17 00:00:00 2001
From: bowenlan-amzn <bowenlan23@gmail.com>
Date: Fri, 24 May 2024 10:53:25 -0700
Subject: [PATCH] Utilize the competitive iterator API to perform pruning

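Dynamic pruning for the cardinality aggregation: once a term has been
collected, its postings can no longer change the result, so they no
longer need to be visited. This change wraps the ordinals collector in
a CompetitiveCollector that exposes the per-term postings as a
disjunctive iterator through LeafCollector#competitiveIterator. The
searcher intersects the query with that iterator, and exhausting a
term's postings once the term is collected lets the intersection skip
over ranges of documents.

Roughly, collection driven by a competitive iterator works like the
following simplified sketch (illustrative only, with placeholder names;
Lucene performs this intersection internally):

    int doc = main.nextDoc();                        // the query's iterator
    while (doc != DocIdSetIterator.NO_MORE_DOCS) {
        if (competitive.docID() < doc) {
            competitive.advance(doc);
        }
        if (competitive.docID() == doc) {
            leafCollector.collect(doc);              // doc is still competitive
            doc = main.nextDoc();
        } else {
            doc = main.advance(competitive.docID()); // skip non-competitive docs
        }
    }
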
Signed-off-by: bowenlan-amzn <bowenlan23@gmail.com>
---
 .../metrics/CardinalityAggregator.java        | 116 +++++++++++++++++-
 .../DynamicPruningCollectorWrapper.java       |   3 +-
 .../metrics/CardinalityAggregatorTests.java   |  14 ++-
 3 files changed, 125 insertions(+), 8 deletions(-)

diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java
index dd2e7458d81d2..7138342ec2ca4 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java
@@ -33,9 +33,19 @@
 package org.opensearch.search.aggregations.metrics;
 
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.SortedNumericDocValues;
 import org.apache.lucene.index.SortedSetDocValues;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.search.DisiPriorityQueue;
+import org.apache.lucene.search.DisiWrapper;
+import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.search.Weight;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.FixedBitSet;
 import org.apache.lucene.util.RamUsageEstimator;
@@ -59,6 +69,7 @@
 import org.opensearch.search.internal.SearchContext;
 
 import java.io.IOException;
+import java.util.HashMap;
 import java.util.Map;
 import java.util.function.BiConsumer;
 
@@ -137,8 +148,10 @@ private Collector pickCollector(LeafReaderContext ctx) throws IOException {
             // only use ordinals if they don't increase memory usage by more than 25%
             if (ordinalsMemoryUsage < countsMemoryUsage / 4) {
                 ordinalsCollectorsUsed++;
-                return new DynamicPruningCollectorWrapper(new OrdinalsCollector(counts, ordinalValues, context.bigArrays()),
-                    context, ctx, fieldContext, source);
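+                // wrap the ordinals collector so that the searcher can skip documents
+                // whose terms have all been counted already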
+                return new CompetitiveCollector(new OrdinalsCollector(counts, ordinalValues, context.bigArrays()),
+                    source, ctx, context, fieldContext);
             }
             ordinalsCollectorsOverheadTooHigh++;
         }
@@ -217,6 +230,105 @@ abstract static class Collector extends LeafBucketCollector implements Releasabl
 
     }
 
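+    /**
+     * Wraps an ordinals collector and exposes a disjunction over the per-term
+     * postings of the collected field as a competitive iterator. Once a document
+     * containing a term has been collected, that term's postings are exhausted
+     * (see DisjunctionDISIWithPruning#advance), so the searcher can skip
+     * documents that contain only already-counted terms.
+     */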
+    private static class CompetitiveCollector extends Collector {
+
+        private final Collector delegate;
+        private final DisiPriorityQueue pq;
+
+        CompetitiveCollector(Collector delegate, ValuesSource.Bytes.WithOrdinals source, LeafReaderContext ctx,
+                             SearchContext context, FieldContext fieldContext) throws IOException {
+            this.delegate = delegate;
+
+            final SortedSetDocValues ordinalValues = source.ordinalsValues(ctx);
+            TermsEnum terms = ordinalValues.termsEnum();
+            Map<BytesRef, Scorer> postingMap = new HashMap<>();
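+            // build one postings iterator (a TermQuery scorer) per distinct term in
+            // this segment; each term comes from the segment's own doc values, so
+            // the scorer is expected to be non-null as long as the field is indexed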
+            while (terms.next() != null) {
+                BytesRef term = terms.term();
+
+                TermQuery termQuery = new TermQuery(new Term(fieldContext.field(), term));
+                Weight subWeight = context.searcher().createWeight(termQuery, ScoreMode.COMPLETE_NO_SCORES, 1f);
+                Scorer scorer = subWeight.scorer(ctx);
+
+                // deep-copy the term, TermsEnum#term reuses its BytesRef across next() calls
+                postingMap.put(BytesRef.deepCopyOf(term), scorer);
+            }
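+            // the disjunction's priority queue orders the per-term iterators by current doc id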
+            this.pq = new DisiPriorityQueue(postingMap.size());
+            for (Map.Entry<BytesRef, Scorer> entry : postingMap.entrySet()) {
+                pq.add(new DisiWrapper(entry.getValue()));
+            }
+        }
+
+        @Override
+        public void close() {
+            delegate.close();
+        }
+
+        @Override
+        public void collect(int doc, long owningBucketOrd) throws IOException {
+            delegate.collect(doc, owningBucketOrd);
+        }
+
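+        // the searcher intersects the query's iterator with the iterator returned
+        // here, so documents that every remaining per-term iterator has already
+        // moved past are skipped without being collected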
+        @Override
+        public DocIdSetIterator competitiveIterator() throws IOException {
+            return new DisjunctionDISIWithPruning(pq);
+        }
+
+        @Override
+        public void postCollect() throws IOException {
+            delegate.postCollect();
+        }
+    }
+
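+    /**
+     * A disjunction over per-term postings iterators that also prunes: on each
+     * advance, every iterator still positioned on the previously collected doc is
+     * exhausted, because that term is already counted and its remaining postings
+     * can no longer change the cardinality.
+     */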
+    private static class DisjunctionDISIWithPruning extends DocIdSetIterator {
+
+        final DisiPriorityQueue queue;
+        final long cost;
+
+        public DisjunctionDISIWithPruning(DisiPriorityQueue queue) {
+            this.queue = queue;
+            // sum of the costs of all per-term iterators, reported through cost()
+            long cost = 0;
+            for (DisiWrapper w : queue) {
+                cost += w.cost;
+            }
+            this.cost = cost;
+        }
+
+        @Override
+        public int docID() {
+            return queue.top().doc;
+        }
+
+        @Override
+        public int nextDoc() throws IOException {
+            // the conjunction the searcher builds around this iterator may use it as
+            // the lead clause, so nextDoc() needs to work; it is simply an advance
+            // past the current doc
+            return advance(docID() + 1);
+        }
+
+        @Override
+        public int advance(int target) throws IOException {
+            // besides advancing to the next doc >= target, this also prunes: the doc
+            // currently at the top of the queue has just been collected, so every
+            // iterator still positioned on it can be exhausted, since its term is
+            // already counted and cannot change the cardinality anymore (note the
+            // assumption that advance() is only called after the doc at the queue
+            // top has actually been collected)
+
+            DisiWrapper top = queue.top();
+
+            if (top.doc != -1) {
+                int curTopDoc = top.doc;
+                do {
+                    // NO_MORE_DOCS == Integer.MAX_VALUE, this exhausts the iterator
+                    top.doc = top.approximation.advance(NO_MORE_DOCS);
+                    top = queue.updateTop();
+                } while (top.doc == curTopDoc);
+            }
+
+            if (top.doc >= target) {
+                return top.doc;
+            }
+            do {
+                top.doc = top.approximation.advance(target);
+                top = queue.updateTop();
+            } while (top.doc < target);
+            return top.doc;
+        }
+
+        @Override
+        public long cost() {
+            // the intersection that the searcher builds around this iterator sorts
+            // its clauses by cost, so this must return a real value rather than throw
+            return cost;
+        }
+    }
+
     /**
      * Empty Collector for the Cardinality agg
      *
diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/DynamicPruningCollectorWrapper.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/DynamicPruningCollectorWrapper.java
index f4c3d59a3833f..5b18168c9d874 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/metrics/DynamicPruningCollectorWrapper.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/DynamicPruningCollectorWrapper.java
@@ -10,6 +10,7 @@
 
 import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.SortedSetDocValues;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.index.TermsEnum;
@@ -52,7 +53,7 @@ class DynamicPruningCollectorWrapper extends CardinalityAggregator.Collector {
             // this logic should be pluggable depending on the type of leaf bucket collector by CardinalityAggregator
             TermsEnum terms = ordinalValues.termsEnum();
             Weight weight = context.searcher().createWeight(context.searcher().rewrite(context.query()), ScoreMode.COMPLETE_NO_SCORES, 1f);
-            Map<Long, Boolean> found = new HashMap<>();
+            Map<Long, Boolean> found = new HashMap<>(); // term ordinal -> whether a matching doc was found
             List<Scorer> subScorers = new ArrayList<>();
             while (terms.next() != null && !found.containsKey(terms.ord())) {
                 // TODO can we get rid of terms previously encountered in other segments?
diff --git a/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java b/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java
index a9966c9e70e76..73e5093618bb7 100644
--- a/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java
+++ b/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java
@@ -104,7 +104,9 @@ public void testDynamicPruningOrdinalCollector() throws IOException {
 
         MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType(fieldName);
         final CardinalityAggregationBuilder aggregationBuilder = new CardinalityAggregationBuilder("_name").field(fieldName);
-        testAggregation(aggregationBuilder, new TermQuery(new Term(filterFieldName, "foo")), iw -> {
+        testAggregation(aggregationBuilder,
+            new TermQuery(new Term(filterFieldName, "foo")),
+            iw -> {
             iw.addDocument(asList(
                 new KeywordField(fieldName, "1", Field.Store.NO),
                 new KeywordField(fieldName, "2", Field.Store.NO),
@@ -142,10 +144,12 @@ public void testDynamicPruningOrdinalCollector() throws IOException {
                 new KeywordField(filterFieldName, "bar", Field.Store.NO),
                 new SortedSetDocValuesField(fieldName, new BytesRef("5"))
             ));
-            }, card -> {
-            assertEquals(3.0, card.getValue(), 0);
-            assertTrue(AggregationInspectionHelper.hasValue(card));
-        }, fieldType);
+            },
+            card -> {
+                assertEquals(3.0, card.getValue(), 0);
+                assertTrue(AggregationInspectionHelper.hasValue(card));
+            },
+            fieldType);
     }
 
     public void testNoMatchingField() throws IOException {