diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java
index ff3b5de45a534..26727ec41c502 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java
@@ -38,6 +38,8 @@
 import org.apache.lucene.index.Term;
 import org.apache.lucene.index.Terms;
 import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.search.BulkScorer;
+import org.apache.lucene.search.CollectionTerminatedException;
 import org.apache.lucene.search.DisiPriorityQueue;
 import org.apache.lucene.search.DisiWrapper;
 import org.apache.lucene.search.DocIdSetIterator;
@@ -45,6 +47,7 @@
 import org.apache.lucene.search.Scorer;
 import org.apache.lucene.search.TermQuery;
 import org.apache.lucene.search.Weight;
+import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.FixedBitSet;
 import org.apache.lucene.util.RamUsageEstimator;
@@ -159,7 +162,14 @@ private Collector pickCollector(LeafReaderContext ctx) throws IOException {
         if (parent == null && subAggregators.length == 0 && valuesSourceConfig.missing() == null && valuesSourceConfig.script() == null) {
             Terms terms = ctx.reader().terms(valuesSourceConfig.fieldContext().field());
             if (terms != null) {
+                Weight weight = context.searcher().createWeight(context.searcher().rewrite(context.query()), ScoreMode.TOP_DOCS, 1f);
+                Bits liveDocs = ctx.reader().getLiveDocs();
+                BulkScorer scorer = weight.bulkScorer(ctx);
                 collector = new PruningCollector(collector, terms.iterator(), ctx, context, valuesSourceConfig.fieldContext().field());
+                scorer.score(collector, liveDocs);
+                collector.postCollect();
+                Releasables.close(collector);
+                throw new CollectionTerminatedException();
             }
         }
 
@@ -240,7 +250,7 @@ private abstract static class Collector extends LeafBucketCollector implements R
     private static class PruningCollector extends Collector {
 
         private final Collector delegate;
-        private final DisiPriorityQueue pq;
+        private final DisiPriorityQueue queue;
 
         PruningCollector(Collector delegate, TermsEnum terms, LeafReaderContext ctx, SearchContext context, String field)
             throws IOException {
@@ -255,9 +265,9 @@ private static class PruningCollector extends Collector {
                 postingMap.put(term, scorer);
             }
 
-            this.pq = new DisiPriorityQueue(postingMap.size());
+            this.queue = new DisiPriorityQueue(postingMap.size());
             for (Map.Entry<BytesRef, Scorer> entry : postingMap.entrySet()) {
-                pq.add(new DisiWrapper(entry.getValue()));
+                queue.add(new DisiWrapper(entry.getValue()));
             }
         }
 
@@ -269,11 +279,27 @@ public void close() {
         @Override
         public void collect(int doc, long owningBucketOrd) throws IOException {
             delegate.collect(doc, owningBucketOrd);
+
+            DisiWrapper top = queue.top();
+            int curTopDoc = top.doc;
+            if (curTopDoc == doc) {
+                do {
+                    queue.pop();
+                    top = queue.updateTop();
+                } while (queue.size() > 1 && top.doc == curTopDoc);
+            }
+            // After pruning, the queue's top doc may sit ahead of the doc currently being collected.
+            // The bulk scorer, however, visits docIDs strictly in order: for every doc it collects and
+            // then moves to the next. When a competitiveIterator is provided, the scorer expects it to
+            // be positioned at or before the current doc: if it is behind, the scorer advances it; if
+            // it then lands on the current doc, the doc is collected and iteration continues; if it
+            // jumps past the current doc, the scorer skips ahead to the iterator's position without
+            // collecting. SlowDocIdPropagatorDISI keeps the docID we report within this contract.
         }
 
         @Override
         public DocIdSetIterator competitiveIterator() {
-            return new DisjunctionDISIWithPruning(pq);
+            return new SlowDocIdPropagatorDISI(new DisjunctionDISI(queue), -1);
         }
 
@@ -282,11 +308,11 @@ public void postCollect() throws IOException {
             delegate.postCollect();
         }
     }
 
-    private static class DisjunctionDISIWithPruning extends DocIdSetIterator {
+    private static class DisjunctionDISI extends DocIdSetIterator {
 
         final DisiPriorityQueue queue;
 
-        public DisjunctionDISIWithPruning(DisiPriorityQueue queue) {
+        public DisjunctionDISI(DisiPriorityQueue queue) {
             this.queue = queue;
         }
 
@@ -310,16 +336,6 @@ public int nextDoc() throws IOException {
         @Override
         public int advance(int target) throws IOException {
             DisiWrapper top = queue.top();
-            // don't do the pruning if this iterator hasn't been used yet
-            if (top.doc != -1) {
-                int curTopDoc = top.doc;
-                do {
-                    top.doc = top.approximation.advance(Integer.MAX_VALUE); // prune
-                    top = queue.updateTop();
-                } while (top.doc == curTopDoc); // there may be multiple subScorers on current doc
-            }
-
-            if (top.doc >= target) return top.doc;
             do {
                 top.doc = top.approximation.advance(target);
                 top = queue.updateTop();
@@ -335,6 +351,45 @@ public long cost() {
         }
     }
 
+    private static class SlowDocIdPropagatorDISI extends DocIdSetIterator {
+        DocIdSetIterator disi;
+
+        Integer curDocId;
+
+        SlowDocIdPropagatorDISI(DocIdSetIterator disi, Integer curDocId) {
+            this.disi = disi;
+            this.curDocId = curDocId;
+        }
+
+        @Override
+        public int docID() {
+            assert curDocId <= disi.docID();
+            return curDocId;
+        }
+
+        @Override
+        public int nextDoc() throws IOException {
+            return advance(curDocId + 1);
+        }
+
+        @Override
+        public int advance(int i) throws IOException {
+            if (i <= disi.docID()) {
+                // since docIDs are propagated slowly, the wrapped disi may already sit on a docID higher than i;
+                // in that case simply report the disi's current position and update curDocId
+                curDocId = disi.docID();
+                return disi.docID();
+            }
+            curDocId = disi.advance(i);
+            return curDocId;
+        }
+
+        @Override
+        public long cost() {
+            return disi.cost();
+        }
+    }
+
     /**
      * Empty Collector for the Cardinality agg
      *
diff --git a/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java b/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java
index d21e7f6ed8550..6db5449883106 100644
--- a/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java
+++ b/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java
@@ -39,12 +39,19 @@
 import org.apache.lucene.document.NumericDocValuesField;
 import org.apache.lucene.document.SortedNumericDocValuesField;
 import org.apache.lucene.document.SortedSetDocValuesField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.IndexWriterConfig;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.search.DocValuesFieldExistsQuery;
+import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.MatchAllDocsQuery;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.tests.util.TestUtil;
 import org.apache.lucene.util.BytesRef;
 import org.opensearch.common.CheckedConsumer;
 import org.opensearch.common.geo.GeoPoint;
@@ -104,7 +111,7 @@ public void testDynamicPruningOrdinalCollector() throws IOException {
         MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType(fieldName);
         final CardinalityAggregationBuilder aggregationBuilder = new CardinalityAggregationBuilder("_name").field(fieldName);
 
-        testAggregation(aggregationBuilder, new TermQuery(new Term(filterFieldName, "foo")), iw -> {
+        testCase2(aggregationBuilder, new TermQuery(new Term(filterFieldName, "foo")), iw -> {
             iw.addDocument(
                 asList(
                     new KeywordField(fieldName, "1", Field.Store.NO),
@@ -271,4 +278,25 @@ private void testAggregation(
     ) throws IOException {
         testCase(aggregationBuilder, query, buildIndex, verify, fieldType);
     }
+
+    protected void testCase2(
+        AggregationBuilder aggregationBuilder,
+        Query query,
+        CheckedConsumer<IndexWriter, IOException> buildIndex,
+        Consumer<InternalCardinality> verify,
+        MappedFieldType... fieldTypes
+    ) throws IOException {
+        try (Directory directory = newDirectory()) {
+            IndexWriter indexWriter = new IndexWriter(directory, new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec()));
+            buildIndex.accept(indexWriter);
+            indexWriter.close();
+
+            try (DirectoryReader unwrapped = DirectoryReader.open(directory); IndexReader indexReader = wrapDirectoryReader(unwrapped)) {
+                IndexSearcher indexSearcher = newIndexSearcher(indexReader);
+
+                InternalCardinality agg = searchAndReduce(indexSearcher, query, aggregationBuilder, fieldTypes);
+                verify.accept(agg);
+            }
+        }
+    }
 }
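
Note for reviewers (not part of the patch): below is a minimal, self-contained sketch of the contract SlowDocIdPropagatorDISI restores. The IntArrayDISI helper and the demo class are hypothetical, introduced only for illustration, and assume nothing beyond lucene-core's DocIdSetIterator. The scenario mirrors the comment in collect(): pruning advances the underlying disjunction past the doc being collected, and the wrapper keeps the externally visible docID consistent by snapping a "late" advance() to the position the wrapped iterator already reached.

import java.io.IOException;

import org.apache.lucene.search.DocIdSetIterator;

public class SlowDocIdPropagatorDemo {

    /** Hypothetical fixed-list iterator standing in for the pruned disjunction over term postings. */
    static final class IntArrayDISI extends DocIdSetIterator {
        private final int[] docs;
        private int idx = -1;

        IntArrayDISI(int... docs) {
            this.docs = docs;
        }

        @Override
        public int docID() {
            if (idx < 0) return -1;
            return idx < docs.length ? docs[idx] : NO_MORE_DOCS;
        }

        @Override
        public int nextDoc() {
            return advance(docID() + 1);
        }

        @Override
        public int advance(int target) {
            // move to the first doc >= target
            do {
                idx++;
            } while (idx < docs.length && docs[idx] < target);
            return docID();
        }

        @Override
        public long cost() {
            return docs.length;
        }
    }

    /**
     * Same shape as the patch's SlowDocIdPropagatorDISI: reports a docID that never runs ahead
     * of what the caller has observed, even when the wrapped iterator was advanced by pruning.
     */
    static final class SlowDocIdPropagatorDISI extends DocIdSetIterator {
        private final DocIdSetIterator disi;
        private int curDocId;

        SlowDocIdPropagatorDISI(DocIdSetIterator disi, int curDocId) {
            this.disi = disi;
            this.curDocId = curDocId;
        }

        @Override
        public int docID() {
            return curDocId;
        }

        @Override
        public int nextDoc() throws IOException {
            return advance(curDocId + 1);
        }

        @Override
        public int advance(int target) throws IOException {
            if (target <= disi.docID()) {
                // the wrapped iterator is already at or past the target (pruned ahead):
                // report its current position instead of advancing it again
                curDocId = disi.docID();
                return curDocId;
            }
            curDocId = disi.advance(target);
            return curDocId;
        }

        @Override
        public long cost() {
            return disi.cost();
        }
    }

    public static void main(String[] args) throws IOException {
        IntArrayDISI disjunction = new IntArrayDISI(3, 7, 12);
        DocIdSetIterator competitive = new SlowDocIdPropagatorDISI(disjunction, -1);

        System.out.println(competitive.nextDoc());  // 3: normal forward iteration
        disjunction.advance(12);                    // pruning moved the underlying iterator ahead
        System.out.println(competitive.advance(7)); // 12: target is behind the wrapped iterator,
                                                    // so the wrapper snaps to its position
        System.out.println(competitive.docID());    // 12: docID() stays consistent with advance()
    }
}

This is also why the patch constructs the wrapper with -1 as the initial curDocId: the iterator starts unpositioned, per DocIdSetIterator's convention, and only ever moves forward from the docID the caller has actually consumed.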