Add expand_nested_docs Parameter support to NMSLIB engine #2331

Merged · 1 commit · Dec 16, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
- Add Support for Multi Values in innerHit for Nested k-NN Fields in Lucene and FAISS (#2283)[https://github.com/opensearch-project/k-NN/pull/2283]
- Add binary index support for Lucene engine. (#2292)[https://github.com/opensearch-project/k-NN/pull/2292]
- Add expand_nested_docs Parameter support to NMSLIB engine (#2331)[https://github.com/opensearch-project/k-NN/pull/2331]
### Enhancements
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
- Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290]
@@ -34,7 +34,6 @@ public enum KNNEngine implements KNNLibrary {
private static final Set<KNNEngine> CUSTOM_SEGMENT_FILE_ENGINES = ImmutableSet.of(KNNEngine.NMSLIB, KNNEngine.FAISS);
private static final Set<KNNEngine> ENGINES_SUPPORTING_FILTERS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS);
public static final Set<KNNEngine> ENGINES_SUPPORTING_RADIAL_SEARCH = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS);
public static final Set<KNNEngine> ENGINES_SUPPORTING_MULTI_VECTORS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS);

private static Map<KNNEngine, Integer> MAX_DIMENSIONS_BY_ENGINE = Map.of(
KNNEngine.NMSLIB,
@@ -26,7 +26,6 @@
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES;
import static org.opensearch.knn.index.engine.KNNEngine.ENGINES_SUPPORTING_MULTI_VECTORS;

/**
* Creates the Lucene k-NN queries
@@ -110,15 +109,7 @@ public static Query create(CreateQueryRequest createQueryRequest) {
.build();
}

if (createQueryRequest.getRescoreContext().isPresent()) {
return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, expandNested);
}

if (ENGINES_SUPPORTING_MULTI_VECTORS.contains(knnEngine) && expandNested) {
return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, expandNested);
}

return knnQuery;
return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, expandNested);
}

Integer requestEfSearch = null;
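For context, the hunk above collapses the previous three-way branch: instead of wrapping the KNNQuery in a NativeEngineKnnVectorQuery only when a rescore context is present or when the engine supports multi-vectors with expand_nested_docs, the factory now always wraps it. The sketch below is illustrative only; KNNQuery, QueryUtils and NativeEngineKnnVectorQuery are the plugin types shown in the diff, while the standalone helper method itself is assumed rather than the plugin's actual code.

```java
// Illustrative sketch of the control-flow change in KNNQueryFactory.create();
// not the actual plugin method.
private static Query wrapForNativeEngine(KNNQuery knnQuery, boolean expandNested) {
    // Before this PR (simplified):
    //   if (rescoreContext.isPresent()) { return new NativeEngineKnnVectorQuery(...); }
    //   if (ENGINES_SUPPORTING_MULTI_VECTORS.contains(knnEngine) && expandNested) { return new NativeEngineKnnVectorQuery(...); }
    //   return knnQuery;
    //
    // After this PR: native-engine queries (FAISS and now NMSLIB) always go
    // through NativeEngineKnnVectorQuery, which is where expand_nested_docs
    // handling lives, so ENGINES_SUPPORTING_MULTI_VECTORS is no longer needed.
    return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, expandNested);
}
```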
@@ -19,9 +19,8 @@
* A `DocIdSetIterator` that iterates over all nested document IDs belongs to the same parent document for a given
* set of nested document IDs.
*
* The {@link #docIds} should include only a single nested document ID per parent document. Otherwise, the nested documents
* of that parent document will be iterated multiple times.
*
* It is permissible for {@link #docIds} to contain multiple nested document IDs linked to a single parent document.
* In such cases, this iterator will still iterate over each nested document ID only once.
*/
public class GroupedNestedDocIdSetIterator extends DocIdSetIterator {
private final BitSet parentBitSet;
@@ -99,9 +98,14 @@ public long cost() {

private long calculateCost() {
long numDocs = 0;
int lastDocId = -1;
for (int docId : docIds) {
for (int i = parentBitSet.prevSetBit(docId) + 1; i < parentBitSet.nextSetBit(docId); i++) {
if (filterBits.get(i)) {
if (docId < lastDocId) {
continue;
}

for (lastDocId = parentBitSet.prevSetBit(docId) + 1; lastDocId < parentBitSet.nextSetBit(docId); lastDocId++) {
if (filterBits.get(lastDocId)) {
numDocs++;
}
}
@@ -111,12 +115,19 @@ private long calculateCost() {

private void moveToNextIndex() {
currentIndex++;
if (currentIndex >= docIds.size()) {
currentDocId = NO_MORE_DOCS;
while (currentIndex < docIds.size()) {
// Advance currentIndex until the docId at the currentIndex is greater than currentDocId.
// This ensures proper handling when docIds contain multiple entries under the same parent ID
// that have already been iterated.
if (docIds.get(currentIndex) <= currentDocId) {
currentIndex++;
continue;
}
currentDocId = parentBitSet.prevSetBit(docIds.get(currentIndex)) + 1;
currentParentId = parentBitSet.nextSetBit(docIds.get(currentIndex));
assert currentParentId != NO_MORE_DOCS;
return;
}
currentDocId = parentBitSet.prevSetBit(docIds.get(currentIndex)) + 1;
currentParentId = parentBitSet.nextSetBit(docIds.get(currentIndex));
assert currentParentId != NO_MORE_DOCS;
currentDocId = NO_MORE_DOCS;
}
}
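The calculateCost() and moveToNextIndex() changes above both enforce the same rule: once a parent block has been expanded, any further candidate doc IDs that fall inside that block are skipped. A small self-contained sketch of that rule is shown below; it uses java.util.BitSet instead of Lucene's BitSet, omits the filter bits, and the doc IDs are made up for illustration.

```java
import java.util.BitSet;
import java.util.List;

// Illustration (not the plugin class) of the dedup rule in GroupedNestedDocIdSetIterator:
// several candidate doc IDs may map to the same parent block, but each block's
// children are visited only once. Lucene-style layout puts a parent right after
// its children (parents at 2, 7, 10 -> children 0-1, 3-6, 8-9).
public class ParentBlockDedupDemo {
    public static void main(String[] args) {
        BitSet parents = new BitSet(11);
        parents.set(2);
        parents.set(7);
        parents.set(10);

        // Candidate nested doc IDs; 3, 4 and 5 all belong to the same parent (7).
        List<Integer> docIds = List.of(0, 1, 3, 4, 5, 8, 9);

        int lastVisited = -1;
        for (int docId : docIds) {
            if (docId <= lastVisited) {
                continue; // this parent block was already expanded; skip the duplicate entry
            }
            int firstChild = parents.previousSetBit(docId) + 1; // start of the block
            int parent = parents.nextSetBit(docId);             // the parent itself (exclusive bound)
            for (int child = firstChild; child < parent; child++) {
                System.out.println("visit nested doc " + child);
                lastVisited = child;
            }
        }
        // Prints 0, 1, 3, 4, 5, 6, 8, 9 exactly once each, even though docIds
        // holds three entries from the 3-6 block.
    }
}
```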
@@ -33,6 +33,7 @@
import org.opensearch.knn.index.mapper.KNNVectorFieldType;
import org.opensearch.knn.index.mapper.Mode;
import org.opensearch.knn.index.query.lucene.LuceneEngineKnnVectorQuery;
import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.knn.index.util.KNNClusterUtil;
import org.opensearch.knn.index.engine.KNNMethodContext;
@@ -191,7 +192,7 @@ public void testDoToQuery_Normal() throws Exception {
when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT);
when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(getDefaultKNNMethodContext(), 4));
when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField);
KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext);
KNNQuery query = ((NativeEngineKnnVectorQuery) knnQueryBuilder.doToQuery(mockQueryShardContext)).getKnnQuery();
assertEquals(knnQueryBuilder.getK(), query.getK());
assertEquals(knnQueryBuilder.fieldName(), query.getField());
assertEquals(knnQueryBuilder.vector(), query.getQueryVector());
@@ -599,8 +600,8 @@ public void testDoToQuery_WhenknnQueryWithFilterAndFaissEngine_thenSuccess() {

// Then
assertNotNull(query);
assertTrue(query.getClass().isAssignableFrom(KNNQuery.class));
assertEquals(HNSW_METHOD_PARAMS, ((KNNQuery) query).getMethodParameters());
assertTrue(query.getClass().isAssignableFrom(NativeEngineKnnVectorQuery.class));
assertEquals(HNSW_METHOD_PARAMS, ((NativeEngineKnnVectorQuery) query).getKnnQuery().getMethodParameters());
}

public void testDoToQuery_ThrowsIllegalArgumentExceptionForUnknownMethodParameter() {
@@ -670,7 +671,7 @@ public void testDoToQuery_FromModel() {
KNNQueryBuilder.initialize(modelDao);

when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField);
KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext);
KNNQuery query = ((NativeEngineKnnVectorQuery) knnQueryBuilder.doToQuery(mockQueryShardContext)).getKnnQuery();
assertEquals(knnQueryBuilder.getK(), query.getK());
assertEquals(knnQueryBuilder.fieldName(), query.getField());
assertEquals(knnQueryBuilder.vector(), query.getQueryVector());
@@ -1026,7 +1027,7 @@ public void testDoToQuery_whenBinary_thenValid() throws Exception {
when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.BINARY);
when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(getDefaultBinaryKNNMethodContext(), 32));
when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField);
KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext);
KNNQuery query = ((NativeEngineKnnVectorQuery) knnQueryBuilder.doToQuery(mockQueryShardContext)).getKnnQuery();
assertArrayEquals(expectedQueryVector, query.getByteQueryVector());
assertNull(query.getQueryVector());
}
@@ -69,7 +69,7 @@ public void setUp() throws Exception {

public void testCreateCustomKNNQuery() {
for (KNNEngine knnEngine : KNNEngine.getEnginesThatCreateCustomSegmentFiles()) {
Query query = KNNQueryFactory.create(
Query query = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create(
BaseQueryFactory.CreateQueryRequest.builder()
.knnEngine(knnEngine)
.indexName(testIndexName)
@@ -78,14 +78,14 @@ public void testCreateCustomKNNQuery() {
.k(testK)
.vectorDataType(DEFAULT_VECTOR_DATA_TYPE_FIELD)
.build()
);
)).getKnnQuery();
assertTrue(query instanceof KNNQuery);
assertEquals(testIndexName, ((KNNQuery) query).getIndexName());
assertEquals(testFieldName, ((KNNQuery) query).getField());
assertEquals(testQueryVector, ((KNNQuery) query).getQueryVector());
assertEquals(testK, ((KNNQuery) query).getK());

query = KNNQueryFactory.create(
query = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create(
BaseQueryFactory.CreateQueryRequest.builder()
.knnEngine(knnEngine)
.indexName(testIndexName)
@@ -94,7 +94,7 @@ public void testCreateCustomKNNQuery() {
.k(testK)
.vectorDataType(DEFAULT_VECTOR_DATA_TYPE_FIELD)
.build()
);
)).getKnnQuery();

assertTrue(query instanceof KNNQuery);
assertEquals(testIndexName, ((KNNQuery) query).getIndexName());
@@ -269,7 +269,7 @@ public void testCreateFaissQueryWithFilter_withValidValues_thenSuccess() {
.filter(FILTER_QUERY_BUILDER)
.build();

final Query actual = KNNQueryFactory.create(createQueryRequest);
final Query actual = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create(createQueryRequest)).getKnnQuery();

// Then
assertEquals(expectedQuery, actual);
@@ -303,7 +303,7 @@ public void testCreateFaissQueryWithFilter_withValidValues_nullEfSearch_thenSucc
.filter(FILTER_QUERY_BUILDER)
.build();

final Query actual = KNNQueryFactory.create(createQueryRequest);
final Query actual = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create(createQueryRequest)).getKnnQuery();

// Then
assertEquals(expectedQuery, actual);
@@ -338,7 +338,7 @@ public void testCreate_whenNestedVectorFiledAndNonNestedFilterField_thenReturnTo
.context(mockQueryShardContext)
.filter(FILTER_QUERY_BUILDER)
.build();
KNNQuery query = (KNNQuery) KNNQueryFactory.create(createQueryRequest);
KNNQuery query = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create(createQueryRequest)).getKnnQuery();
mockedNestedHelper.close();
assertEquals(ToChildBlockJoinQuery.class, query.getFilterQuery().getClass());
}
@@ -367,7 +367,7 @@ public void testCreate_whenNestedVectorAndFilterField_thenReturnSameFilterQuery(
.context(mockQueryShardContext)
.filter(FILTER_QUERY_BUILDER)
.build();
KNNQuery query = (KNNQuery) KNNQueryFactory.create(createQueryRequest);
KNNQuery query = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create(createQueryRequest)).getKnnQuery();
mockedNestedHelper.close();
assertEquals(FILTER_QUERY.getClass(), query.getFilterQuery().getClass());
}
@@ -388,7 +388,7 @@ public void testCreate_whenFaissWithParentFilter_thenSuccess() {
.vectorDataType(VectorDataType.FLOAT)
.context(mockQueryShardContext)
.build();
final Query query = KNNQueryFactory.create(createQueryRequest);
final Query query = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create(createQueryRequest)).getKnnQuery();
assertTrue(query instanceof KNNQuery);
assertEquals(testIndexName, ((KNNQuery) query).getIndexName());
assertEquals(testFieldName, ((KNNQuery) query).getField());
@@ -441,7 +441,7 @@ public void testCreate_whenBinary_thenSuccess() {
.context(mockQueryShardContext)
.filter(FILTER_QUERY_BUILDER)
.build();
Query query = KNNQueryFactory.create(createQueryRequest);
Query query = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create(createQueryRequest)).getKnnQuery();
assertTrue(query instanceof KNNQuery);
assertNotNull(((KNNQuery) query).getByteQueryVector());
assertNull(((KNNQuery) query).getQueryVector());
@@ -488,7 +488,7 @@ public void testCreate_whenExpandNestedDocsQueryWithFaiss_thenCreateNativeEngine
}

public void testCreate_whenExpandNestedDocsQueryWithNmslib_thenCreateKNNQuery() {
testExpandNestedDocsQuery(KNNEngine.NMSLIB, KNNQuery.class, VectorDataType.FLOAT);
testExpandNestedDocsQuery(KNNEngine.NMSLIB, NativeEngineKnnVectorQuery.class, VectorDataType.FLOAT);
}

public void testCreate_whenExpandNestedDocsQueryWithLucene_thenCreateExpandNestedDocsQuery() {
@@ -70,4 +70,33 @@ public void testGroupedNestedDocIdSetIterator_whenAdvanceIsCalled_thenBehaveAsEx
assertEquals(DocIdSetIterator.NO_MORE_DOCS, groupedNestedDocIdSetIterator.docID());
assertEquals(expectedDocIds.size(), groupedNestedDocIdSetIterator.cost());
}

public void testGroupedNestedDocIdSetIterator_whenGivenMultipleDocsUnderSameParent_thenBehaveAsExpected() throws Exception {
// 0, 1, 2(parent), 3, 4, 5, 6, 7(parent), 8, 9, 10(parent)
BitSet parentBitSet = new FixedBitSet(new long[1], 11);
parentBitSet.set(2);
parentBitSet.set(7);
parentBitSet.set(10);

BitSet filterBits = new FixedBitSet(new long[1], 11);
filterBits.set(1);
filterBits.set(8);
filterBits.set(9);

// Run
Set<Integer> docIds = Set.of(0, 1, 3, 4, 5, 8, 9);
GroupedNestedDocIdSetIterator groupedNestedDocIdSetIterator = new GroupedNestedDocIdSetIterator(parentBitSet, docIds, filterBits);

// Verify
Set<Integer> expectedDocIds = Set.of(1, 8, 9);
groupedNestedDocIdSetIterator.advance(1);
assertEquals(1, groupedNestedDocIdSetIterator.docID());
groupedNestedDocIdSetIterator.nextDoc();
assertEquals(8, groupedNestedDocIdSetIterator.docID());
groupedNestedDocIdSetIterator.advance(9);
assertEquals(9, groupedNestedDocIdSetIterator.docID());
groupedNestedDocIdSetIterator.nextDoc();
assertEquals(DocIdSetIterator.NO_MORE_DOCS, groupedNestedDocIdSetIterator.docID());
assertEquals(expectedDocIds.size(), groupedNestedDocIdSetIterator.cost());
}
}
10 changes: 0 additions & 10 deletions src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java
@@ -12,7 +12,6 @@
import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang.ArrayUtils;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.junit.After;
import org.junit.BeforeClass;
import org.opensearch.client.Response;
import org.opensearch.common.settings.Settings;
@@ -68,15 +67,6 @@ public static void setUpClass() throws IOException {
testData = new TestUtils.TestData(testIndexVectors.getPath(), testQueries.getPath(), groundTruthValues.getPath());
}

@After
public void cleanUp() {
try {
deleteKNNIndex(INDEX_NAME);
} catch (Exception e) {
log.error(e);
}
}

@SneakyThrows
public void testHnswBinary_whenSmallDataSet_thenCreateIngestQueryWorks() {
// Create Index
@@ -8,7 +8,6 @@
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.junit.After;
import org.opensearch.knn.KNNJsonIndexMappingsBuilder;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.index.VectorDataType;
@@ -29,15 +28,6 @@
@Log4j2
@AllArgsConstructor
public class BinaryIndexInvalidMappingIT extends KNNRestTestCase {
@After
public void cleanUp() {
try {
deleteKNNIndex(INDEX_NAME);
} catch (Exception e) {
log.error(e);
}
}

private String description;
private String indexMapping;
private String expectedExceptionMessage;
24 changes: 15 additions & 9 deletions src/test/java/org/opensearch/knn/integ/ExpandNestedDocsIT.java
@@ -11,7 +11,6 @@
import lombok.SneakyThrows;
import lombok.extern.log4j.Log4j2;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.junit.After;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.common.settings.Settings;
@@ -70,12 +69,6 @@ public class ExpandNestedDocsIT extends KNNRestTestCase {
private Mode mode;
private Integer dimension;

@After
@SneakyThrows
public final void cleanUp() {
deleteKNNIndex(INDEX_NAME);
}

@ParametersFactory(argumentFormatting = "description:%1$s; engine:%2$s, data_type:%3$s, mode:%4$s, dimension:%5$s")
public static Collection<Object[]> parameters() throws IOException {
int dimension = 1;
@@ -99,13 +92,19 @@ public static Collection<Object[]> parameters() throws IOException {
Mode.ON_DISK,
// Currently, on disk mode only supports dimension of multiple of 8
dimension * 8
)
),
$("Nmslib with float format and in memory mode", KNNEngine.NMSLIB, VectorDataType.FLOAT, Mode.NOT_CONFIGURED, dimension)
)
);
}

@SneakyThrows
public void testExpandNestedDocs_whenFilteredOnParentDoc_thenReturnAllNestedDoc() {
if (engine == KNNEngine.NMSLIB) {
// NMSLIB does not support filtering
return;
}

int numberOfNestedFields = 2;
createKnnIndex(engine, mode, dimension, dataType);
addRandomVectorsWithTopLevelField(1, numberOfNestedFields, FIELD_NAME_PARKING, FIELD_VALUE_TRUE);
@@ -131,6 +130,11 @@

@SneakyThrows
public void testExpandNestedDocs_whenFilteredOnNestedFieldDoc_thenReturnFilteredNestedDoc() {
if (engine == KNNEngine.NMSLIB) {
// NMSLIB does not support filtering
return;
}

int numberOfNestedFields = 2;
createKnnIndex(engine, mode, dimension, dataType);
addRandomVectorsWithMetadata(1, numberOfNestedFields, FIELD_NAME_STORAGE, Arrays.asList(FIELD_VALUE_FALSE, FIELD_VALUE_FALSE));
@@ -175,7 +179,9 @@ public void testExpandNestedDocs_whenMultiShards_thenReturnCorrectResult() {

// Run
Float[] queryVector = createVector();
Response response = queryNestedFieldWithExpandNestedDocs(INDEX_NAME, numberOfDocuments, queryVector);
// NMSLIB does not support dedup per parent documents. Therefore, we need to multiply the k by number of nestedFields.
int k = engine == KNNEngine.NMSLIB ? numberOfDocuments * numberOfNestedFields : numberOfDocuments;
Response response = queryNestedFieldWithExpandNestedDocs(INDEX_NAME, k, queryVector);

// Verify
String entity = EntityUtils.toString(response.getEntity());
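The multi-shard test above adjusts k because, per the inline comment, NMSLIB does not deduplicate hits per parent document: every nested vector can come back as its own hit. A tiny sketch of that sizing logic, using hypothetical counts rather than the test's actual values, is below.

```java
// Hypothetical numbers, mirroring the k-sizing logic in the test above.
// With NMSLIB each nested vector may occupy a slot in the top-k result,
// so k must be large enough to cover every nested doc of every parent.
int numberOfDocuments = 5;      // parent documents expected in the result
int numberOfNestedFields = 2;   // nested vectors per parent
boolean usesNmslib = true;

int k = usesNmslib ? numberOfDocuments * numberOfNestedFields : numberOfDocuments;
// FAISS/Lucene: k = 5; NMSLIB: k = 10
```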