Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds in lazy execution for Lucene kNN queries #2305

Merged
merged 1 commit into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Enhancements
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
- Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290]
- Optimizes lucene query execution to prevent unnecessary rewrites (#2305)[https://github.com/opensearch-project/k-NN/pull/2305]
### Bug Fixes
* Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282]
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.common.QueryUtils;
import org.opensearch.knn.index.query.lucenelib.NestedKnnVectorQueryFactory;
import org.opensearch.knn.index.query.lucene.LuceneEngineKnnVectorQuery;
import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery;
import org.opensearch.knn.index.query.rescore.RescoreContext;

Expand Down Expand Up @@ -128,9 +129,13 @@ public static Query create(CreateQueryRequest createQueryRequest) {
log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
switch (vectorDataType) {
case BYTE:
return getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter, expandNested);
return new LuceneEngineKnnVectorQuery(
getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter, expandNested)
);
case FLOAT:
return getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter, expandNested);
return new LuceneEngineKnnVectorQuery(
getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter, expandNested)
);
default:
throw new IllegalArgumentException(
String.format(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.lucene;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;

import java.io.IOException;

/**
* LuceneEngineKnnVectorQuery is a wrapper around a vector queries for the Lucene engine.
* This enables us to defer rewrites until weight creation to optimize repeated execution
* of Lucene based k-NN queries.
*/
@AllArgsConstructor
@Log4j2
public class LuceneEngineKnnVectorQuery extends Query {
@Getter
private final Query luceneQuery;

/*
Prevents repeated rewrites of the query for the Lucene engine.
*/
@Override
public Query rewrite(IndexSearcher indexSearcher) {
return this;
}

/*
Rewrites the query just before weight creation.
*/
@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
Query rewrittenQuery = luceneQuery.rewrite(searcher);
return rewrittenQuery.createWeight(searcher, scoreMode, boost);
}

@Override
public String toString(String s) {
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
return luceneQuery.toString();
}

@Override
public void visit(QueryVisitor queryVisitor) {
queryVisitor.visitLeaf(this);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
LuceneEngineKnnVectorQuery otherQuery = (LuceneEngineKnnVectorQuery) o;
return luceneQuery.equals(otherQuery.luceneQuery);
}

@Override
public int hashCode() {
return luceneQuery.hashCode();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import com.google.common.collect.ImmutableMap;
import lombok.SneakyThrows;
import org.apache.lucene.search.FloatVectorSimilarityQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.junit.Before;
Expand All @@ -33,6 +32,7 @@
import org.opensearch.knn.index.mapper.KNNMappingConfig;
import org.opensearch.knn.index.mapper.KNNVectorFieldType;
import org.opensearch.knn.index.mapper.Mode;
import org.opensearch.knn.index.query.lucene.LuceneEngineKnnVectorQuery;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.knn.index.util.KNNClusterUtil;
import org.opensearch.knn.index.engine.KNNMethodContext;
Expand Down Expand Up @@ -512,7 +512,7 @@ public void testDoToQuery_KnnQueryWithFilter_Lucene() throws Exception {

// Then
assertNotNull(query);
assertTrue(query.getClass().isAssignableFrom(KnnFloatVectorQuery.class));
assertTrue(query.getClass().isAssignableFrom(LuceneEngineKnnVectorQuery.class));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a way to check the type of the underlying lucene query, looks like this is what we're asserting in today's version? How about adding a package private getter for the lucene query, or a public method that returns the type of the query class.

}

@SneakyThrows
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery;
import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery;
import org.apache.lucene.search.join.ToChildBlockJoinQuery;
import org.junit.Before;
import org.mockito.Mock;
Expand All @@ -30,6 +28,7 @@
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.lucenelib.ExpandNestedDocsQuery;
import org.opensearch.knn.index.query.lucene.LuceneEngineKnnVectorQuery;
import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery;
import org.opensearch.knn.index.query.rescore.RescoreContext;

Expand Down Expand Up @@ -120,7 +119,7 @@ public void testCreateLuceneDefaultQuery() {
.vectorDataType(DEFAULT_VECTOR_DATA_TYPE_FIELD)
.build()
);
assertEquals(KnnFloatVectorQuery.class, query.getClass());
assertEquals(LuceneEngineKnnVectorQuery.class, query.getClass());
}
}

Expand All @@ -138,7 +137,7 @@ public void testLuceneFloatVectorQuery() {
);

// efsearch > k
Query expectedQuery1 = new KnnFloatVectorQuery(testFieldName, testQueryVector, 100, null);
Query expectedQuery1 = new LuceneEngineKnnVectorQuery(new KnnFloatVectorQuery(testFieldName, testQueryVector, 100, null));
assertEquals(expectedQuery1, actualQuery1);

// efsearch < k
Expand All @@ -153,7 +152,7 @@ public void testLuceneFloatVectorQuery() {
.vectorDataType(VectorDataType.FLOAT)
.build()
);
expectedQuery1 = new KnnFloatVectorQuery(testFieldName, testQueryVector, testK, null);
expectedQuery1 = new LuceneEngineKnnVectorQuery(new KnnFloatVectorQuery(testFieldName, testQueryVector, testK, null));
assertEquals(expectedQuery1, actualQuery1);

actualQuery1 = KNNQueryFactory.create(
Expand All @@ -166,7 +165,7 @@ public void testLuceneFloatVectorQuery() {
.vectorDataType(VectorDataType.FLOAT)
.build()
);
expectedQuery1 = new KnnFloatVectorQuery(testFieldName, testQueryVector, testK, null);
expectedQuery1 = new LuceneEngineKnnVectorQuery(new KnnFloatVectorQuery(testFieldName, testQueryVector, testK, null));
assertEquals(expectedQuery1, actualQuery1);
}

Expand All @@ -184,7 +183,7 @@ public void testLuceneByteVectorQuery() {
);

// efsearch > k
Query expectedQuery1 = new KnnByteVectorQuery(testFieldName, testByteQueryVector, 100, null);
Query expectedQuery1 = new LuceneEngineKnnVectorQuery(new KnnByteVectorQuery(testFieldName, testByteQueryVector, 100, null));
assertEquals(expectedQuery1, actualQuery1);

// efsearch < k
Expand All @@ -199,7 +198,7 @@ public void testLuceneByteVectorQuery() {
.vectorDataType(VectorDataType.BYTE)
.build()
);
expectedQuery1 = new KnnByteVectorQuery(testFieldName, testByteQueryVector, testK, null);
expectedQuery1 = new LuceneEngineKnnVectorQuery(new KnnByteVectorQuery(testFieldName, testByteQueryVector, testK, null));
assertEquals(expectedQuery1, actualQuery1);

actualQuery1 = KNNQueryFactory.create(
Expand All @@ -212,7 +211,7 @@ public void testLuceneByteVectorQuery() {
.vectorDataType(VectorDataType.BYTE)
.build()
);
expectedQuery1 = new KnnByteVectorQuery(testFieldName, testByteQueryVector, testK, null);
expectedQuery1 = new LuceneEngineKnnVectorQuery(new KnnByteVectorQuery(testFieldName, testByteQueryVector, testK, null));
assertEquals(expectedQuery1, actualQuery1);
}

Expand All @@ -235,7 +234,7 @@ public void testCreateLuceneQueryWithFilter() {
.filter(FILTER_QUERY_BUILDER)
.build();
Query query = KNNQueryFactory.create(createQueryRequest);
assertEquals(KnnFloatVectorQuery.class, query.getClass());
assertEquals(LuceneEngineKnnVectorQuery.class, query.getClass());
}
}

Expand Down Expand Up @@ -311,8 +310,8 @@ public void testCreateFaissQueryWithFilter_withValidValues_nullEfSearch_thenSucc
}

public void testCreate_whenLuceneWithParentFilter_thenReturnDiversifyingQuery() {
validateDiversifyingQueryWithParentFilter(VectorDataType.BYTE, DiversifyingChildrenByteKnnVectorQuery.class);
validateDiversifyingQueryWithParentFilter(VectorDataType.FLOAT, DiversifyingChildrenFloatKnnVectorQuery.class);
validateDiversifyingQueryWithParentFilter(VectorDataType.BYTE, LuceneEngineKnnVectorQuery.class);
validateDiversifyingQueryWithParentFilter(VectorDataType.FLOAT, LuceneEngineKnnVectorQuery.class);
}

public void testCreate_whenNestedVectorFiledAndNonNestedFilterField_thenReturnToChildBlockJoinQueryForFilters() {
Expand Down Expand Up @@ -515,7 +514,11 @@ private void testExpandNestedDocsQuery(KNNEngine knnEngine, Class klass, VectorD
.build();
Query query = KNNQueryFactory.create(createQueryRequest);

// Then
assertEquals(klass, query.getClass());
if (knnEngine == KNNEngine.LUCENE) {
assertEquals(klass, ((LuceneEngineKnnVectorQuery) query).getLuceneQuery().getClass());
} else {
// Then
assertEquals(klass, query.getClass());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.lucene;

import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.Spy;
import org.opensearch.test.OpenSearchTestCase;

import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.MockitoAnnotations.openMocks;

public class LuceneEngineKnnVectorQueryTests extends OpenSearchTestCase {

@Mock
IndexSearcher indexSearcher;

@Mock
Query luceneQuery;

@Mock
Weight weight;

@Mock
QueryVisitor queryVisitor;

@Spy
@InjectMocks
LuceneEngineKnnVectorQuery objectUnderTest;

@Override
public void setUp() throws Exception {
super.setUp();
openMocks(this);
when(luceneQuery.rewrite(any(IndexSearcher.class))).thenReturn(luceneQuery);
when(luceneQuery.createWeight(any(IndexSearcher.class), any(ScoreMode.class), anyFloat())).thenReturn(weight);
}

public void testRewrite() {
objectUnderTest.rewrite(indexSearcher);
objectUnderTest.rewrite(indexSearcher);
objectUnderTest.rewrite(indexSearcher);
verifyNoInteractions(luceneQuery);
verify(objectUnderTest, times(3)).rewrite(indexSearcher);
}

public void testCreateWeight() throws Exception {
objectUnderTest.rewrite(indexSearcher);
objectUnderTest.rewrite(indexSearcher);
objectUnderTest.rewrite(indexSearcher);
verifyNoInteractions(luceneQuery);
Weight actualWeight = objectUnderTest.createWeight(indexSearcher, ScoreMode.TOP_DOCS, 1.0f);
verify(luceneQuery, times(1)).rewrite(indexSearcher);
verify(objectUnderTest, times(3)).rewrite(indexSearcher);
assertEquals(weight, actualWeight);
}

public void testVisit() {
objectUnderTest.visit(queryVisitor);
verify(queryVisitor).visitLeaf(objectUnderTest);
}

public void testEquals() {
LuceneEngineKnnVectorQuery mainQuery = new LuceneEngineKnnVectorQuery(luceneQuery);
LuceneEngineKnnVectorQuery otherQuery = new LuceneEngineKnnVectorQuery(luceneQuery);
assertEquals(mainQuery, otherQuery);
assertEquals(mainQuery, mainQuery);
assertNotEquals(mainQuery, null);
assertNotEquals(mainQuery, new Object());
LuceneEngineKnnVectorQuery otherQuery2 = new LuceneEngineKnnVectorQuery(null);
assertNotEquals(mainQuery, otherQuery2);
}

public void testHashCode() {
LuceneEngineKnnVectorQuery mainQuery = new LuceneEngineKnnVectorQuery(luceneQuery);
assertEquals(mainQuery.hashCode(), luceneQuery.hashCode());
}

public void testToString() {
LuceneEngineKnnVectorQuery mainQuery = new LuceneEngineKnnVectorQuery(luceneQuery);
assertEquals(mainQuery.toString(), luceneQuery.toString());
}
}
Loading