Commit 268e1cd

now fix to do all the scoring within Cardinality
Signed-off-by: bowenlan-amzn <[email protected]>
bowenlan-amzn committed Jun 6, 2024
1 parent f35ac3a commit 268e1cd
Showing 2 changed files with 100 additions and 17 deletions.

@@ -38,13 +38,16 @@
 import org.apache.lucene.index.Term;
 import org.apache.lucene.index.Terms;
 import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.search.BulkScorer;
 import org.apache.lucene.search.CollectionTerminatedException;
 import org.apache.lucene.search.DisiPriorityQueue;
 import org.apache.lucene.search.DisiWrapper;
 import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.Scorer;
 import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.search.Weight;
+import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.FixedBitSet;
 import org.apache.lucene.util.RamUsageEstimator;
@@ -159,7 +162,14 @@ private Collector pickCollector(LeafReaderContext ctx) throws IOException {
         if (parent == null && subAggregators.length == 0 && valuesSourceConfig.missing() == null && valuesSourceConfig.script() == null) {
             Terms terms = ctx.reader().terms(valuesSourceConfig.fieldContext().field());
             if (terms != null) {
+                Weight weight = context.searcher().createWeight(context.searcher().rewrite(context.query()), ScoreMode.TOP_DOCS, 1f);
+                Bits liveDocs = ctx.reader().getLiveDocs();
+                BulkScorer scorer = weight.bulkScorer(ctx);
                 collector = new PruningCollector(collector, terms.iterator(), ctx, context, valuesSourceConfig.fieldContext().field());
+                scorer.score(collector, liveDocs);
+                collector.postCollect();
+                Releasables.close(collector);
+                throw new CollectionTerminatedException();
             }
         }
 
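For orientation, a minimal sketch (not part of this commit) of the eager-scoring pattern the hunk above introduces: build a Weight for the query, bulk-score the whole segment into the collector, then terminate collection for this leaf. The method name and parameters are hypothetical; the real code additionally calls postCollect() and releases the collector.

```java
import java.io.IOException;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Bits;

static void scoreLeafEagerly(IndexSearcher searcher, Query query, LeafReaderContext ctx, LeafCollector collector)
    throws IOException {
    Weight weight = searcher.createWeight(searcher.rewrite(query), ScoreMode.TOP_DOCS, 1f);
    BulkScorer scorer = weight.bulkScorer(ctx);     // may be null when nothing in this leaf matches
    if (scorer != null) {
        Bits liveDocs = ctx.reader().getLiveDocs(); // null when the segment has no deletions
        scorer.score(collector, liveDocs);          // drives collector.collect(doc) for every match
    }
    throw new CollectionTerminatedException();      // signal: this leaf is fully collected
}
```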
@@ -240,7 +250,7 @@ private abstract static class Collector extends LeafBucketCollector implements R
     private static class PruningCollector extends Collector {
 
         private final Collector delegate;
-        private final DisiPriorityQueue pq;
+        private final DisiPriorityQueue queue;
 
         PruningCollector(Collector delegate, TermsEnum terms, LeafReaderContext ctx, SearchContext context, String field)
             throws IOException {
@@ -255,9 +265,9 @@ private static class PruningCollector extends Collector {
                 postingMap.put(term, scorer);
             }
 
-            this.pq = new DisiPriorityQueue(postingMap.size());
+            this.queue = new DisiPriorityQueue(postingMap.size());
             for (Map.Entry<BytesRef, Scorer> entry : postingMap.entrySet()) {
-                pq.add(new DisiWrapper(entry.getValue()));
+                queue.add(new DisiWrapper(entry.getValue()));
             }
         }
 
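The constructor body is only partially visible in this hunk. As a hypothetical sketch of its overall shape, assuming postingMap maps each distinct term of the field to a per-term Scorer built from a TermQuery (the Map type, helper name, and parameters are assumptions; only the postingMap.put call and the queue-building loop are confirmed by the diff):

```java
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.DisiPriorityQueue;
import org.apache.lucene.search.DisiWrapper;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.BytesRef;

static DisiPriorityQueue buildQueue(IndexSearcher searcher, TermsEnum termsEnum, LeafReaderContext ctx, String field)
    throws IOException {
    Map<BytesRef, Scorer> postingMap = new HashMap<>();
    for (BytesRef term = termsEnum.next(); term != null; term = termsEnum.next()) {
        TermQuery termQuery = new TermQuery(new Term(field, BytesRef.deepCopyOf(term)));
        Weight weight = searcher.createWeight(termQuery, ScoreMode.COMPLETE_NO_SCORES, 1f);
        Scorer scorer = weight.scorer(ctx);            // null if the term matches nothing in this leaf
        if (scorer != null) {
            postingMap.put(BytesRef.deepCopyOf(term), scorer);
        }
    }
    DisiPriorityQueue queue = new DisiPriorityQueue(postingMap.size());
    for (Map.Entry<BytesRef, Scorer> entry : postingMap.entrySet()) {
        queue.add(new DisiWrapper(entry.getValue()));  // heap ordered by each iterator's current doc
    }
    return queue;
}
```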
@@ -269,11 +279,27 @@ public void close() {
         @Override
         public void collect(int doc, long owningBucketOrd) throws IOException {
             delegate.collect(doc, owningBucketOrd);
+
+            DisiWrapper top = queue.top();
+            int curTopDoc = top.doc;
+            if (curTopDoc == doc) {
+                do {
+                    queue.pop();
+                    top = queue.updateTop();
+                } while (queue.size() > 1 && top.doc == curTopDoc);
+            }
+            // After pruning, the queue's top doc will be past the current doc.
+            // The scoring loop has a contract with the competitive iterator:
+            // on every call to collect, it must be at or before the current doc.
+            // If it is behind, the scorer advances it to the current doc.
+            // If it then lands on the doc, the doc is collected and we move on.
+            // If it overshoots, the doc is skipped and the main iterator is
+            // advanced to the competitive iterator's position instead.
         }
 
         @Override
         public DocIdSetIterator competitiveIterator() {
-            return new DisjunctionDISIWithPruning(pq);
+            return new SlowDocIdPropagatorDISI(new DisjunctionDISI(queue), -1);
         }
 
         @Override
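For reference, a simplified sketch of the other side of that contract: how a bulk scorer's loop might consume a collector's competitive iterator. This follows the comments above; Lucene's actual implementation additionally handles deleted docs and two-phase matches.

```java
import java.io.IOException;

import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.LeafCollector;

static void scoreWithCompetitiveIterator(DocIdSetIterator iterator, LeafCollector collector) throws IOException {
    DocIdSetIterator competitive = collector.competitiveIterator(); // null means everything stays competitive
    int doc = iterator.nextDoc();
    while (doc != DocIdSetIterator.NO_MORE_DOCS) {
        if (competitive != null) {
            assert competitive.docID() <= doc;       // contract: never ahead of the main iterator
            if (competitive.docID() < doc) {
                int next = competitive.advance(doc); // catch the competitive iterator up
                if (next > doc) {
                    // doc is no longer competitive: skip it and jump the main
                    // iterator forward to the competitive iterator's position
                    doc = iterator.advance(next);
                    continue;
                }
            }
        }
        collector.collect(doc);                      // competitive at doc: collect it
        doc = iterator.nextDoc();
    }
}
```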
@@ -282,11 +308,11 @@ public void postCollect() throws IOException {
         }
     }
 
-    private static class DisjunctionDISIWithPruning extends DocIdSetIterator {
+    private static class DisjunctionDISI extends DocIdSetIterator {
 
         final DisiPriorityQueue queue;
 
-        public DisjunctionDISIWithPruning(DisiPriorityQueue queue) {
+        public DisjunctionDISI(DisiPriorityQueue queue) {
             this.queue = queue;
         }
 
@@ -310,16 +336,6 @@ public int nextDoc() throws IOException {
         @Override
         public int advance(int target) throws IOException {
            DisiWrapper top = queue.top();
-            // don't do the pruning if this iterator hasn't been used yet
-            if (top.doc != -1) {
-                int curTopDoc = top.doc;
-                do {
-                    top.doc = top.approximation.advance(Integer.MAX_VALUE); // prune
-                    top = queue.updateTop();
-                } while (top.doc == curTopDoc); // there may be multiple subScorers on current doc
-            }
-
-            if (top.doc >= target) return top.doc;
             do {
                 top.doc = top.approximation.advance(target);
                 top = queue.updateTop();
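With this commit, the pruning that advance() used to perform lazily (advancing every sub-iterator sitting on the current top doc to Integer.MAX_VALUE) happens eagerly in PruningCollector.collect(), leaving the renamed DisjunctionDISI as a plain disjunction over the remaining per-term iterators.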
@@ -335,6 +351,45 @@ public long cost() {
         }
     }
 
+    private static class SlowDocIdPropagatorDISI extends DocIdSetIterator {
+        DocIdSetIterator disi;
+
+        Integer curDocId;
+
+        SlowDocIdPropagatorDISI(DocIdSetIterator disi, Integer curDocId) {
+            this.disi = disi;
+            this.curDocId = curDocId;
+        }
+
+        @Override
+        public int docID() {
+            assert curDocId <= disi.docID();
+            return curDocId;
+        }
+
+        @Override
+        public int nextDoc() throws IOException {
+            return advance(curDocId + 1);
+        }
+
+        @Override
+        public int advance(int i) throws IOException {
+            if (i <= disi.docID()) {
+                // since doc IDs are propagated slowly, the disi may already be positioned past i;
+                // in that case, simply return the disi's current docID and record it as curDocId
+                curDocId = disi.docID();
+                return disi.docID();
+            }
+            curDocId = disi.advance(i);
+            return curDocId;
+        }
+
+        @Override
+        public long cost() {
+            return disi.cost();
+        }
+    }
+
     /**
      * Empty Collector for the Cardinality agg
      *
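Before the test-file changes, a toy walk-through (doc ids invented) of the slow-propagation behavior, assuming the private SlowDocIdPropagatorDISI class above were visible: the wrapped iterator may already sit past the requested target because collect() pruned the queue, and the wrapper only reveals that position once the caller advances.

```java
import java.io.IOException;

import org.apache.lucene.search.DocIdSetIterator;

static void demo() throws IOException {
    DocIdSetIterator inner = DocIdSetIterator.all(100); // matches docs 0..99
    inner.advance(42);                // pruning in collect() pushed it ahead
    DocIdSetIterator slow = new SlowDocIdPropagatorDISI(inner, 10);
    assert slow.docID() == 10;        // externally still reports the old position
    assert slow.advance(20) == 42;    // target already passed: report inner position
    assert slow.docID() == 42;        // only now does the reported doc id catch up
}
```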
@@ -39,12 +39,19 @@
 import org.apache.lucene.document.NumericDocValuesField;
 import org.apache.lucene.document.SortedNumericDocValuesField;
 import org.apache.lucene.document.SortedSetDocValuesField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.IndexWriterConfig;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.search.DocValuesFieldExistsQuery;
+import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.MatchAllDocsQuery;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.apache.lucene.tests.util.TestUtil;
 import org.apache.lucene.util.BytesRef;
 import org.opensearch.common.CheckedConsumer;
 import org.opensearch.common.geo.GeoPoint;
@@ -104,7 +111,7 @@ public void testDynamicPruningOrdinalCollector() throws IOException {
 
         MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType(fieldName);
         final CardinalityAggregationBuilder aggregationBuilder = new CardinalityAggregationBuilder("_name").field(fieldName);
-        testAggregation(aggregationBuilder, new TermQuery(new Term(filterFieldName, "foo")), iw -> {
+        testCase2(aggregationBuilder, new TermQuery(new Term(filterFieldName, "foo")), iw -> {
             iw.addDocument(
                 asList(
                     new KeywordField(fieldName, "1", Field.Store.NO),
@@ -271,4 +278,25 @@ private void testAggregation(
     ) throws IOException {
         testCase(aggregationBuilder, query, buildIndex, verify, fieldType);
     }
+
+    protected void testCase2(
+        AggregationBuilder aggregationBuilder,
+        Query query,
+        CheckedConsumer<IndexWriter, IOException> buildIndex,
+        Consumer<InternalCardinality> verify,
+        MappedFieldType... fieldTypes
+    ) throws IOException {
+        try (Directory directory = newDirectory()) {
+            IndexWriter indexWriter = new IndexWriter(directory, new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec()));
+            buildIndex.accept(indexWriter);
+            indexWriter.close();
+
+            try (DirectoryReader unwrapped = DirectoryReader.open(directory); IndexReader indexReader = wrapDirectoryReader(unwrapped)) {
+                IndexSearcher indexSearcher = newIndexSearcher(indexReader);
+
+                InternalCardinality agg = searchAndReduce(indexSearcher, query, aggregationBuilder, fieldTypes);
+                verify.accept(agg);
+            }
+        }
+    }
 }
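The testCase2 helper mirrors the existing testCase flow but pins the codec to TestUtil.getDefaultCodec() rather than the test framework's randomized codec, presumably so the dynamic-pruning path always runs against a standard terms/postings implementation.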
