From ead75ff1e014b20f9147a07bc79c357a89dd4616 Mon Sep 17 00:00:00 2001 From: Heemin Kim Date: Thu, 12 Dec 2024 11:36:55 -0800 Subject: [PATCH] Support expand_nested_docs parameter for nmslib engine Signed-off-by: Heemin Kim --- CHANGELOG.md | 1 + .../knn/index/engine/KNNEngine.java | 1 - .../knn/index/query/KNNQueryFactory.java | 12 +------ .../GroupedNestedDocIdSetIterator.java | 31 +++++++++++++------ .../knn/index/query/KNNQueryBuilderTests.java | 11 ++++--- .../knn/index/query/KNNQueryFactoryTests.java | 22 ++++++------- .../GroupedNestedDocIdSetIteratorTests.java | 29 +++++++++++++++++ .../knn/integ/ExpandNestedDocsIT.java | 25 +++++++++++++-- 8 files changed, 91 insertions(+), 41 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8cbd0ef2f..4ae86dba1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.18...2.x) ### Features - Add Support for Multi Values in innerHit for Nested k-NN Fields in Lucene and FAISS (#2283)[https://github.com/opensearch-project/k-NN/pull/2283] +- Add expand_nested_docs Parameter support to NMSLIB engine (#2331)[https://github.com/opensearch-project/k-NN/pull/2331] ### Enhancements - Introduced a writing layer in native engines where relies on the writing interface to process IO. 
(#2241)[https://github.com/opensearch-project/k-NN/pull/2241] - Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290] diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java b/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java index f75c7f1d9..1e560a11b 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java @@ -34,7 +34,6 @@ public enum KNNEngine implements KNNLibrary { private static final Set CUSTOM_SEGMENT_FILE_ENGINES = ImmutableSet.of(KNNEngine.NMSLIB, KNNEngine.FAISS); private static final Set ENGINES_SUPPORTING_FILTERS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS); public static final Set ENGINES_SUPPORTING_RADIAL_SEARCH = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS); - public static final Set ENGINES_SUPPORTING_MULTI_VECTORS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS); private static Map MAX_DIMENSIONS_BY_ENGINE = Map.of( KNNEngine.NMSLIB, diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index d01a9aff6..0c1efef88 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -26,7 +26,6 @@ import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES; -import static org.opensearch.knn.index.engine.KNNEngine.ENGINES_SUPPORTING_MULTI_VECTORS; /** * Creates the Lucene k-NN queries @@ -50,7 +49,6 @@ public static Query create(CreateQueryRequest createQueryRequest) { final Query filterQuery = getFilterQuery(createQueryRequest); final Map methodParameters = createQueryRequest.getMethodParameters(); 
final RescoreContext rescoreContext = createQueryRequest.getRescoreContext().orElse(null); - final KNNEngine knnEngine = createQueryRequest.getKnnEngine(); final boolean expandNested = createQueryRequest.isExpandNested(); BitSetProducer parentFilter = null; if (createQueryRequest.getContext().isPresent()) { @@ -110,15 +108,7 @@ public static Query create(CreateQueryRequest createQueryRequest) { .build(); } - if (createQueryRequest.getRescoreContext().isPresent()) { - return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, expandNested); - } - - if (ENGINES_SUPPORTING_MULTI_VECTORS.contains(knnEngine) && expandNested) { - return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, expandNested); - } - - return knnQuery; + return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, expandNested); } Integer requestEfSearch = null; diff --git a/src/main/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIterator.java b/src/main/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIterator.java index 19842a67a..727c508fb 100644 --- a/src/main/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIterator.java @@ -19,9 +19,8 @@ * A `DocIdSetIterator` that iterates over all nested document IDs belongs to the same parent document for a given * set of nested document IDs. * - * The {@link #docIds} should include only a single nested document ID per parent document. Otherwise, the nested documents - * of that parent document will be iterated multiple times. - * + * It is permissible for {@link #docIds} to contain multiple nested document IDs linked to a single parent document. + * In such cases, this iterator will still iterate over each nested document ID only once. 
*/ public class GroupedNestedDocIdSetIterator extends DocIdSetIterator { private final BitSet parentBitSet; @@ -99,9 +98,14 @@ public long cost() { private long calculateCost() { long numDocs = 0; + int lastDocId = -1; for (int docId : docIds) { - for (int i = parentBitSet.prevSetBit(docId) + 1; i < parentBitSet.nextSetBit(docId); i++) { - if (filterBits.get(i)) { + if (docId < lastDocId) { + continue; + } + + for (lastDocId = parentBitSet.prevSetBit(docId) + 1; lastDocId < parentBitSet.nextSetBit(docId); lastDocId++) { + if (filterBits.get(lastDocId)) { numDocs++; } } @@ -111,12 +115,19 @@ private long calculateCost() { private void moveToNextIndex() { currentIndex++; - if (currentIndex >= docIds.size()) { - currentDocId = NO_MORE_DOCS; + while (currentIndex < docIds.size()) { + // Advance currentIndex until the docId at the currentIndex is greater than currentDocId. + // This ensures proper handling when docIds contain multiple entries under the same parent ID + // that have already been iterated. 
+ if (docIds.get(currentIndex) <= currentDocId) { + currentIndex++; + continue; + } + currentDocId = parentBitSet.prevSetBit(docIds.get(currentIndex)) + 1; + currentParentId = parentBitSet.nextSetBit(docIds.get(currentIndex)); + assert currentParentId != NO_MORE_DOCS; return; } - currentDocId = parentBitSet.prevSetBit(docIds.get(currentIndex)) + 1; - currentParentId = parentBitSet.nextSetBit(docIds.get(currentIndex)); - assert currentParentId != NO_MORE_DOCS; + currentDocId = NO_MORE_DOCS; } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index b609bb0df..30c8007e6 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -33,6 +33,7 @@ import org.opensearch.knn.index.mapper.KNNVectorFieldType; import org.opensearch.knn.index.mapper.Mode; import org.opensearch.knn.index.query.lucene.LuceneEngineKnnVectorQuery; +import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery; import org.opensearch.knn.index.query.rescore.RescoreContext; import org.opensearch.knn.index.util.KNNClusterUtil; import org.opensearch.knn.index.engine.KNNMethodContext; @@ -191,7 +192,7 @@ public void testDoToQuery_Normal() throws Exception { when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(getDefaultKNNMethodContext(), 4)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); - KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); + KNNQuery query = ((NativeEngineKnnVectorQuery) knnQueryBuilder.doToQuery(mockQueryShardContext)).getKnnQuery(); assertEquals(knnQueryBuilder.getK(), query.getK()); assertEquals(knnQueryBuilder.fieldName(), query.getField()); assertEquals(knnQueryBuilder.vector(), 
query.getQueryVector()); @@ -599,8 +600,8 @@ public void testDoToQuery_WhenknnQueryWithFilterAndFaissEngine_thenSuccess() { // Then assertNotNull(query); - assertTrue(query.getClass().isAssignableFrom(KNNQuery.class)); - assertEquals(HNSW_METHOD_PARAMS, ((KNNQuery) query).getMethodParameters()); + assertTrue(query.getClass().isAssignableFrom(NativeEngineKnnVectorQuery.class)); + assertEquals(HNSW_METHOD_PARAMS, ((NativeEngineKnnVectorQuery) query).getKnnQuery().getMethodParameters()); } public void testDoToQuery_ThrowsIllegalArgumentExceptionForUnknownMethodParameter() { @@ -670,7 +671,7 @@ public void testDoToQuery_FromModel() { KNNQueryBuilder.initialize(modelDao); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); - KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); + KNNQuery query = ((NativeEngineKnnVectorQuery) knnQueryBuilder.doToQuery(mockQueryShardContext)).getKnnQuery(); assertEquals(knnQueryBuilder.getK(), query.getK()); assertEquals(knnQueryBuilder.fieldName(), query.getField()); assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); @@ -1026,7 +1027,7 @@ public void testDoToQuery_whenBinary_thenValid() throws Exception { when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.BINARY); when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(getDefaultBinaryKNNMethodContext(), 32)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); - KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); + KNNQuery query = ((NativeEngineKnnVectorQuery) knnQueryBuilder.doToQuery(mockQueryShardContext)).getKnnQuery(); assertArrayEquals(expectedQueryVector, query.getByteQueryVector()); assertNull(query.getQueryVector()); } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java index 
eff2ca895..329222636 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java @@ -69,7 +69,7 @@ public void setUp() throws Exception { public void testCreateCustomKNNQuery() { for (KNNEngine knnEngine : KNNEngine.getEnginesThatCreateCustomSegmentFiles()) { - Query query = KNNQueryFactory.create( + Query query = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create( BaseQueryFactory.CreateQueryRequest.builder() .knnEngine(knnEngine) .indexName(testIndexName) @@ -78,14 +78,14 @@ public void testCreateCustomKNNQuery() { .k(testK) .vectorDataType(DEFAULT_VECTOR_DATA_TYPE_FIELD) .build() - ); + )).getKnnQuery(); assertTrue(query instanceof KNNQuery); assertEquals(testIndexName, ((KNNQuery) query).getIndexName()); assertEquals(testFieldName, ((KNNQuery) query).getField()); assertEquals(testQueryVector, ((KNNQuery) query).getQueryVector()); assertEquals(testK, ((KNNQuery) query).getK()); - query = KNNQueryFactory.create( + query = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create( BaseQueryFactory.CreateQueryRequest.builder() .knnEngine(knnEngine) .indexName(testIndexName) @@ -94,7 +94,7 @@ public void testCreateCustomKNNQuery() { .k(testK) .vectorDataType(DEFAULT_VECTOR_DATA_TYPE_FIELD) .build() - ); + )).getKnnQuery(); assertTrue(query instanceof KNNQuery); assertEquals(testIndexName, ((KNNQuery) query).getIndexName()); @@ -269,7 +269,7 @@ public void testCreateFaissQueryWithFilter_withValidValues_thenSuccess() { .filter(FILTER_QUERY_BUILDER) .build(); - final Query actual = KNNQueryFactory.create(createQueryRequest); + final Query actual = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create(createQueryRequest)).getKnnQuery(); // Then assertEquals(expectedQuery, actual); @@ -303,7 +303,7 @@ public void testCreateFaissQueryWithFilter_withValidValues_nullEfSearch_thenSucc .filter(FILTER_QUERY_BUILDER) .build(); - final Query actual = 
KNNQueryFactory.create(createQueryRequest); + final Query actual = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create(createQueryRequest)).getKnnQuery(); // Then assertEquals(expectedQuery, actual); @@ -338,7 +338,7 @@ public void testCreate_whenNestedVectorFiledAndNonNestedFilterField_thenReturnTo .context(mockQueryShardContext) .filter(FILTER_QUERY_BUILDER) .build(); - KNNQuery query = (KNNQuery) KNNQueryFactory.create(createQueryRequest); + KNNQuery query = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create(createQueryRequest)).getKnnQuery(); mockedNestedHelper.close(); assertEquals(ToChildBlockJoinQuery.class, query.getFilterQuery().getClass()); } @@ -367,7 +367,7 @@ public void testCreate_whenNestedVectorAndFilterField_thenReturnSameFilterQuery( .context(mockQueryShardContext) .filter(FILTER_QUERY_BUILDER) .build(); - KNNQuery query = (KNNQuery) KNNQueryFactory.create(createQueryRequest); + KNNQuery query = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create(createQueryRequest)).getKnnQuery(); mockedNestedHelper.close(); assertEquals(FILTER_QUERY.getClass(), query.getFilterQuery().getClass()); } @@ -388,7 +388,7 @@ public void testCreate_whenFaissWithParentFilter_thenSuccess() { .vectorDataType(VectorDataType.FLOAT) .context(mockQueryShardContext) .build(); - final Query query = KNNQueryFactory.create(createQueryRequest); + final Query query = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create(createQueryRequest)).getKnnQuery(); assertTrue(query instanceof KNNQuery); assertEquals(testIndexName, ((KNNQuery) query).getIndexName()); assertEquals(testFieldName, ((KNNQuery) query).getField()); @@ -441,7 +441,7 @@ public void testCreate_whenBinary_thenSuccess() { .context(mockQueryShardContext) .filter(FILTER_QUERY_BUILDER) .build(); - Query query = KNNQueryFactory.create(createQueryRequest); + Query query = ((NativeEngineKnnVectorQuery) KNNQueryFactory.create(createQueryRequest)).getKnnQuery(); assertTrue(query instanceof KNNQuery); assertNotNull(((KNNQuery) 
query).getByteQueryVector()); assertNull(((KNNQuery) query).getQueryVector()); @@ -488,7 +488,7 @@ public void testCreate_whenExpandNestedDocsQueryWithFaiss_thenCreateNativeEngine } public void testCreate_whenExpandNestedDocsQueryWithNmslib_thenCreateKNNQuery() { - testExpandNestedDocsQuery(KNNEngine.NMSLIB, KNNQuery.class, VectorDataType.FLOAT); + testExpandNestedDocsQuery(KNNEngine.NMSLIB, NativeEngineKnnVectorQuery.class, VectorDataType.FLOAT); } public void testCreate_whenExpandNestedDocsQueryWithLucene_thenCreateExpandNestedDocsQuery() { diff --git a/src/test/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIteratorTests.java index 55f3d91d9..976b50ea6 100644 --- a/src/test/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIteratorTests.java +++ b/src/test/java/org/opensearch/knn/index/query/iterators/GroupedNestedDocIdSetIteratorTests.java @@ -70,4 +70,33 @@ public void testGroupedNestedDocIdSetIterator_whenAdvanceIsCalled_thenBehaveAsEx assertEquals(DocIdSetIterator.NO_MORE_DOCS, groupedNestedDocIdSetIterator.docID()); assertEquals(expectedDocIds.size(), groupedNestedDocIdSetIterator.cost()); } + + public void testGroupedNestedDocIdSetIterator_whenGivenMultipleDocsUnderSameParent_thenBehaveAsExpected() throws Exception { + // 0, 1, 2(parent), 3, 4, 5, 6, 7(parent), 8, 9, 10(parent) + BitSet parentBitSet = new FixedBitSet(new long[1], 11); + parentBitSet.set(2); + parentBitSet.set(7); + parentBitSet.set(10); + + BitSet filterBits = new FixedBitSet(new long[1], 11); + filterBits.set(1); + filterBits.set(8); + filterBits.set(9); + + // Run + Set docIds = Set.of(0, 1, 3, 4, 5, 8, 9); + GroupedNestedDocIdSetIterator groupedNestedDocIdSetIterator = new GroupedNestedDocIdSetIterator(parentBitSet, docIds, filterBits); + + // Verify + Set expectedDocIds = Set.of(1, 8, 9); + groupedNestedDocIdSetIterator.advance(1); + assertEquals(1, 
groupedNestedDocIdSetIterator.docID()); + groupedNestedDocIdSetIterator.nextDoc(); + assertEquals(8, groupedNestedDocIdSetIterator.docID()); + groupedNestedDocIdSetIterator.advance(9); + assertEquals(9, groupedNestedDocIdSetIterator.docID()); + groupedNestedDocIdSetIterator.nextDoc(); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, groupedNestedDocIdSetIterator.docID()); + assertEquals(expectedDocIds.size(), groupedNestedDocIdSetIterator.cost()); + } } diff --git a/src/test/java/org/opensearch/knn/integ/ExpandNestedDocsIT.java b/src/test/java/org/opensearch/knn/integ/ExpandNestedDocsIT.java index 164aa7100..47d201d3d 100644 --- a/src/test/java/org/opensearch/knn/integ/ExpandNestedDocsIT.java +++ b/src/test/java/org/opensearch/knn/integ/ExpandNestedDocsIT.java @@ -14,6 +14,7 @@ import org.junit.After; import org.opensearch.client.Request; import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.rest.RestStatus; @@ -73,7 +74,12 @@ public class ExpandNestedDocsIT extends KNNRestTestCase { @After @SneakyThrows public final void cleanUp() { - deleteKNNIndex(INDEX_NAME); + try { + deleteKNNIndex(INDEX_NAME); + } catch (ResponseException e) { + // Index not found exception is acceptable because some test cases do not create an index + assert e.getMessage().contains("index_not_found_exception"); + } } @ParametersFactory(argumentFormatting = "description:%1$s; engine:%2$s, data_type:%3$s, mode:%4$s, dimension:%5$s") @@ -99,13 +105,19 @@ public static Collection parameters() throws IOException { Mode.ON_DISK, // Currently, on disk mode only supports dimension of multiple of 8 dimension * 8 - ) + ), + $("Nmslib with float format and in memory mode", KNNEngine.NMSLIB, VectorDataType.FLOAT, Mode.NOT_CONFIGURED, dimension) ) ); } @SneakyThrows public void 
testExpandNestedDocs_whenFilteredOnParentDoc_thenReturnAllNestedDoc() { + if (engine == KNNEngine.NMSLIB) { + // NMSLIB does not support filtering + return; + } + int numberOfNestedFields = 2; createKnnIndex(engine, mode, dimension, dataType); addRandomVectorsWithTopLevelField(1, numberOfNestedFields, FIELD_NAME_PARKING, FIELD_VALUE_TRUE); @@ -131,6 +143,11 @@ public void testExpandNestedDocs_whenFilteredOnParentDoc_thenReturnAllNestedDoc( @SneakyThrows public void testExpandNestedDocs_whenFilteredOnNestedFieldDoc_thenReturnFilteredNestedDoc() { + if (engine == KNNEngine.NMSLIB) { + // NMSLIB does not support filtering + return; + } + int numberOfNestedFields = 2; createKnnIndex(engine, mode, dimension, dataType); addRandomVectorsWithMetadata(1, numberOfNestedFields, FIELD_NAME_STORAGE, Arrays.asList(FIELD_VALUE_FALSE, FIELD_VALUE_FALSE)); @@ -175,7 +192,9 @@ public void testExpandNestedDocs_whenMultiShards_thenReturnCorrectResult() { // Run Float[] queryVector = createVector(); - Response response = queryNestedFieldWithExpandNestedDocs(INDEX_NAME, numberOfDocuments, queryVector); + // NMSLIB does not deduplicate results per parent document. Therefore, we need to multiply k by the number of nested fields. + int k = engine == KNNEngine.NMSLIB ? numberOfDocuments * numberOfNestedFields : numberOfDocuments; + Response response = queryNestedFieldWithExpandNestedDocs(INDEX_NAME, k, queryVector); // Verify String entity = EntityUtils.toString(response.getEntity());