Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds in lazy execution for Lucene kNN queries
Browse files Browse the repository at this point in the history
Signed-off-by: Kunal Kotwani <kkotwani@amazon.com>
kotwanikunal committed Dec 4, 2024
1 parent 9276c77 commit 38dc781
Showing 4 changed files with 166 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
- Optimizes lucene query execution to prevent unnecessary rewrites (#2305)[https://github.com/opensearch-project/k-NN/pull/2305]
### Bug Fixes
* Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282]
### Infrastructure
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.lucene.LuceneEngineKnnVectorQuery;
import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery;
import org.opensearch.knn.index.query.rescore.RescoreContext;

@@ -106,9 +107,9 @@ public static Query create(CreateQueryRequest createQueryRequest) {
log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
switch (vectorDataType) {
case BYTE:
return getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter);
return new LuceneEngineKnnVectorQuery(getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter));
case FLOAT:
return getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter);
return new LuceneEngineKnnVectorQuery(getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter));
default:
throw new IllegalArgumentException(
String.format(
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.lucene;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;

import java.io.IOException;

/**
* LuceneEngineKnnVectorQuery is a wrapper around a vector queries for the Lucene engine.
* This enables us to defer rewrites until weight creation to optimize repeated execution
* of Lucene based k-NN queries.
*/
@AllArgsConstructor
@Log4j2
public class LuceneEngineKnnVectorQuery extends Query {
private final Query luceneQuery;

/*
Prevents repeated rewrites of the query for the Lucene engine.
*/
@Override
public Query rewrite(IndexSearcher indexSearcher) {
return this;
}

/*
Rewrites the query just before weight creation.
*/
@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
Query rewrittenQuery = luceneQuery.rewrite(searcher);
return rewrittenQuery.createWeight(searcher, scoreMode, boost);
}

@Override
public String toString(String s) {
return luceneQuery.toString();
}

@Override
public void visit(QueryVisitor queryVisitor) {
queryVisitor.visitLeaf(this);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
LuceneEngineKnnVectorQuery otherQuery = (LuceneEngineKnnVectorQuery) o;
return luceneQuery.equals(otherQuery.luceneQuery);
}

@Override
public int hashCode() {
return luceneQuery.hashCode();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.lucene;

import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.Spy;
import org.opensearch.test.OpenSearchTestCase;

import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.MockitoAnnotations.openMocks;

public class LuceneEngineKnnVectorQueryTests extends OpenSearchTestCase {

@Mock
IndexSearcher indexSearcher;

@Mock
Query luceneQuery;

@Mock
Weight weight;

@Mock
QueryVisitor queryVisitor;

@Spy
@InjectMocks
LuceneEngineKnnVectorQuery objectUnderTest;

@Override
public void setUp() throws Exception {
super.setUp();
openMocks(this);
when(luceneQuery.rewrite(any(IndexSearcher.class))).thenReturn(luceneQuery);
when(luceneQuery.createWeight(any(IndexSearcher.class), any(ScoreMode.class), anyFloat())).thenReturn(weight);
}

public void testRewrite() {
objectUnderTest.rewrite(indexSearcher);
objectUnderTest.rewrite(indexSearcher);
objectUnderTest.rewrite(indexSearcher);
verifyNoInteractions(luceneQuery);
verify(objectUnderTest, times(3)).rewrite(indexSearcher);
}

public void testCreateWeight() throws Exception {
objectUnderTest.rewrite(indexSearcher);
objectUnderTest.rewrite(indexSearcher);
objectUnderTest.rewrite(indexSearcher);
verifyNoInteractions(luceneQuery);
Weight actualWeight = objectUnderTest.createWeight(indexSearcher, ScoreMode.TOP_DOCS, 1.0f);
verify(luceneQuery, times(1)).rewrite(indexSearcher);
verify(objectUnderTest, times(3)).rewrite(indexSearcher);
assertEquals(weight, actualWeight);
}

public void testVisit() {
objectUnderTest.visit(queryVisitor);
verify(queryVisitor).visitLeaf(objectUnderTest);
}

public void testEquals() {
LuceneEngineKnnVectorQuery mainQuery = new LuceneEngineKnnVectorQuery(luceneQuery);
LuceneEngineKnnVectorQuery otherQuery = new LuceneEngineKnnVectorQuery(luceneQuery);
assertEquals(mainQuery, otherQuery);
assertEquals(mainQuery, mainQuery);
assertNotEquals(mainQuery, null);
assertNotEquals(mainQuery, new Object());
LuceneEngineKnnVectorQuery otherQuery2 = new LuceneEngineKnnVectorQuery(null);
assertNotEquals(mainQuery, otherQuery2);
}

public void testHashCode() {
LuceneEngineKnnVectorQuery mainQuery = new LuceneEngineKnnVectorQuery(luceneQuery);
assertEquals(mainQuery.hashCode(), luceneQuery.hashCode());
}

public void testToString() {
LuceneEngineKnnVectorQuery mainQuery = new LuceneEngineKnnVectorQuery(luceneQuery);
assertEquals(mainQuery.toString(), luceneQuery.toString());
}
}

0 comments on commit 38dc781

Please sign in to comment.