diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 3a99ad3ac..29fefbf9b 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,2 +1,2 @@ # This should match the owning team set up in https://github.com/orgs/opensearch-project/teams -* @heemin32 @navneet1v @VijayanB @vamshin @jmazanec15 @naveentatikonda @junqiu-lei @martin-gaievski @ryanbogan @luyuncheng @shatejas +* @heemin32 @navneet1v @VijayanB @vamshin @jmazanec15 @naveentatikonda @junqiu-lei @martin-gaievski @ryanbogan @luyuncheng @shatejas @0ctopus13prime diff --git a/CHANGELOG.md b/CHANGELOG.md index b02742d3b..8cfc29dce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,10 +23,16 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241] - Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290] - Optimizes lucene query execution to prevent unnecessary rewrites (#2305)[https://github.com/opensearch-project/k-NN/pull/2305] +- Add check to directly use ANN Search when filters match all docs. 
(#2320)[https://github.com/opensearch-project/k-NN/pull/2320] +- Use one formula to calculate cosine similarity (#2357)[https://github.com/opensearch-project/k-NN/pull/2357] ### Bug Fixes * Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282] * Fixing the bug where search fails with "fields" parameter for an index with a knn_vector field (#2314)[https://github.com/opensearch-project/k-NN/pull/2314] +* Fix for NPE while merging segments after all the vector fields docs are deleted (#2365)[https://github.com/opensearch-project/k-NN/pull/2365] * Allow validation for non knn index only after 2.17.0 (#2315)[https://github.com/opensearch-project/k-NN/pull/2315] +* Release query vector memory after execution (#2346)[https://github.com/opensearch-project/k-NN/pull/2346] +* Fix shard level rescoring disabled setting flag (#2352)[https://github.com/opensearch-project/k-NN/pull/2352] +* Fix filter rewrite logic which was resulting in getting inconsistent / incorrect results for cases where filter was getting rewritten for shards (#2359)[https://github.com/opensearch-project/k-NN/pull/2359] ### Infrastructure * Updated C++ version in JNI from c++11 to c++17 [#2259](https://github.com/opensearch-project/k-NN/pull/2259) * Upgrade bytebuddy and objenesis version to match OpenSearch core and, update github ci runner for macos [#2279](https://github.com/opensearch-project/k-NN/pull/2279) diff --git a/MAINTAINERS.md b/MAINTAINERS.md index 36c99b96a..fd18eee5c 100644 --- a/MAINTAINERS.md +++ b/MAINTAINERS.md @@ -6,6 +6,7 @@ This document contains a list of maintainers in this repo. 
See [opensearch-proje | Maintainer | GitHub ID | Affiliation | |-------------------------|-------------------------------------------------------|-------------| +| Doo Yong Kim | [0ctopus13prime](https://github.com/0ctopus13prime) | Amazon | | Heemin Kim | [heemin32](https://github.com/heemin32) | Amazon | | Jack Mazanec | [jmazanec15](https://github.com/jmazanec15) | Amazon | | Junqiu Lei | [junqiu-lei](https://github.com/junqiu-lei) | Amazon | diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 98c40cf6b..c02c410c1 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -1180,6 +1180,7 @@ jobjectArray knn_jni::faiss_wrapper::RangeSearchWithFilter(knn_jni::JNIUtilInter jniUtil->ReleaseLongArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT); throw; } + jniUtil->ReleaseLongArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT); } else { faiss::SearchParameters *searchParameters = nullptr; faiss::SearchParametersHNSW hnswParams; @@ -1202,6 +1203,7 @@ jobjectArray knn_jni::faiss_wrapper::RangeSearchWithFilter(knn_jni::JNIUtilInter throw; } } + jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryVector, JNI_ABORT); // lims is structured to support batched queries, it has a length of nq + 1 (where nq is the number of queries), // lims[i] - lims[i-1] gives the number of results for the i-th query. 
With a single query we used in k-NN, diff --git a/src/main/java/org/opensearch/knn/common/KNNVectorUtil.java b/src/main/java/org/opensearch/knn/common/KNNVectorUtil.java index 778cc164d..ea826c1ff 100644 --- a/src/main/java/org/opensearch/knn/common/KNNVectorUtil.java +++ b/src/main/java/org/opensearch/knn/common/KNNVectorUtil.java @@ -7,9 +7,7 @@ import lombok.AccessLevel; import lombok.NoArgsConstructor; -import org.opensearch.knn.index.vectorvalues.KNNVectorValues; -import java.io.IOException; import java.util.List; import java.util.Objects; @@ -62,17 +60,4 @@ public static int[] intListToArray(final List integerList) { } return intArray; } - - /** - * Iterates vector values once if it is not at start of the location, - * Intended to be done to make sure dimension and bytesPerVector are available - * @param vectorValues - * @throws IOException - */ - public static void iterateVectorValuesOnce(final KNNVectorValues vectorValues) throws IOException { - if (vectorValues.docId() == -1) { - vectorValues.nextDoc(); - vectorValues.getVector(); - } - } } diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index b81a54124..6dc72a22b 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -577,7 +577,7 @@ public static Integer getFilteredExactSearchThreshold(final String indexName) { .getAsInt(ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE); } - public static boolean isShardLevelRescoringEnabledForDiskBasedVector(String indexName) { + public static boolean isShardLevelRescoringDisabledForDiskBasedVector(String indexName) { return KNNSettings.state().clusterService.state() .getMetadata() .index(indexName) diff --git a/src/main/java/org/opensearch/knn/index/SpaceType.java b/src/main/java/org/opensearch/knn/index/SpaceType.java index abe265a01..5d90071e8 100644 --- 
a/src/main/java/org/opensearch/knn/index/SpaceType.java +++ b/src/main/java/org/opensearch/knn/index/SpaceType.java @@ -60,9 +60,21 @@ public float scoreToDistanceTranslation(float score) { } }, COSINESIMIL("cosinesimil") { + /** + * Cosine similarity has a range of [-1, 1], where -1 means the vectors are diametrically opposed and 1 means + * they are identical in direction and perfectly similar. In Lucene, scores have to be in the range of [0, Float.MAX_VALUE]. + * Hence, to move the range from [-1, 1] into [0, 1], we convert using the following formula, which is adopted + * by Lucene as mentioned here + * https://github.com/apache/lucene/blob/0494c824e0ac8049b757582f60d085932a890800/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java#L73 + * We expect raw score = 1 - cosine(x,y). If the underlying library returns a different range or a different kind of raw score, + * it should override this method to either provide a valid range or convert the raw score to the form 1 - cosine and call this method + * + * @param rawScore score returned from underlying library + * @return Lucene scaled score + */ @Override public float scoreTranslation(float rawScore) { - return 1 / (1 + rawScore); + return Math.max((2.0F - rawScore) / 2.0F, 0.0F); } @Override diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java index 23c3ba116..15d38a079 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java @@ -23,7 +23,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNVectorUtil.intListToArray; -import static
org.opensearch.knn.common.KNNVectorUtil.iterateVectorValuesOnce; +import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues; import static org.opensearch.knn.index.codec.transfer.OffHeapVectorTransferFactory.getVectorTransfer; /** @@ -52,7 +52,7 @@ public static DefaultIndexBuildStrategy getInstance() { public void buildAndWriteIndex(final BuildIndexParams indexInfo) throws IOException { final KNNVectorValues knnVectorValues = indexInfo.getVectorValues(); // Needed to make sure we don't get 0 dimensions while initializing index - iterateVectorValuesOnce(knnVectorValues); + initializeVectorValues(knnVectorValues); IndexBuildSetup indexBuildSetup = QuantizationIndexUtils.prepareIndexBuild(knnVectorValues, indexInfo); try ( diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java index 81f5915a7..2864be6d2 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java @@ -22,7 +22,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.opensearch.knn.common.KNNVectorUtil.intListToArray; -import static org.opensearch.knn.common.KNNVectorUtil.iterateVectorValuesOnce; +import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues; import static org.opensearch.knn.index.codec.transfer.OffHeapVectorTransferFactory.getVectorTransfer; /** @@ -53,7 +53,7 @@ public static MemOptimizedNativeIndexBuildStrategy getInstance() { public void buildAndWriteIndex(final BuildIndexParams indexInfo) throws IOException { final KNNVectorValues knnVectorValues = indexInfo.getVectorValues(); // Needed to make sure we don't get 0 dimensions while initializing index - 
iterateVectorValuesOnce(knnVectorValues); + initializeVectorValues(knnVectorValues); KNNEngine engine = indexInfo.getKnnEngine(); Map indexParameters = indexInfo.getParameters(); IndexBuildSetup indexBuildSetup = QuantizationIndexUtils.prepareIndexBuild(knnVectorValues, indexInfo); diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java index 27a1ecfb6..7078645e5 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java @@ -43,7 +43,7 @@ import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; -import static org.opensearch.knn.common.KNNVectorUtil.iterateVectorValuesOnce; +import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues; import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName; import static org.opensearch.knn.index.engine.faiss.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; @@ -100,7 +100,7 @@ public static NativeIndexWriter getWriter( * @throws IOException */ public void flushIndex(final KNNVectorValues knnVectorValues, int totalLiveDocs) throws IOException { - iterateVectorValuesOnce(knnVectorValues); + initializeVectorValues(knnVectorValues); buildAndWriteIndex(knnVectorValues, totalLiveDocs); recordRefreshStats(); } @@ -111,7 +111,7 @@ public void flushIndex(final KNNVectorValues knnVectorValues, int totalLiveDo * @throws IOException */ public void mergeIndex(final KNNVectorValues knnVectorValues, int totalLiveDocs) throws IOException { - iterateVectorValuesOnce(knnVectorValues); + initializeVectorValues(knnVectorValues); if (knnVectorValues.docId() == NO_MORE_DOCS) { // This is in place so we do not add metrics 
log.debug("Skipping mergeIndex, vector values are already iterated for {}", fieldInfo.name); diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index 3ccfc3c2b..f2db1f25d 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -9,12 +9,15 @@ import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.search.DocIdSetIterator; import org.opensearch.knn.common.FieldInfoExtractor; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.KNN80Codec.KNN80BinaryDocValues; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import java.io.IOException; import java.util.Comparator; import java.util.List; import java.util.stream.Collectors; @@ -116,6 +119,31 @@ public static String getNativeEngineFileFromFieldInfo(FieldInfo field, SegmentIn } } + /** + * Positions the vectorValuesIterator to the first vector document ID if not already positioned there. + * This initialization is crucial for setting up vector dimensions and other properties in VectorValues. + *

+ * If the VectorValues contains no vector documents, the iterator will be positioned at + * {@link DocIdSetIterator#NO_MORE_DOCS} + * + * @param vectorValues {@link KNNVectorValues} + * @throws IOException if there is an error while accessing the vector values + */ + public static void initializeVectorValues(final KNNVectorValues vectorValues) throws IOException { + // The docId will be set to -1 if next doc has never been called yet. If it has already been called, + // no need to advance the vector values + if (vectorValues.docId() != -1) { + return; + } + // Advance to the first vector document; the docId becomes NO_MORE_DOCS when there are no vector documents + vectorValues.nextDoc(); + if (vectorValues.docId() == DocIdSetIterator.NO_MORE_DOCS) { + // Ensure that we are not getting the vector if there are no more docs + return; + } + vectorValues.getVector(); + } + /** * Get KNNEngine From FieldInfo * diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNMappingConfig.java b/src/main/java/org/opensearch/knn/index/mapper/KNNMappingConfig.java index cd77ebd9a..eeaef9847 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNMappingConfig.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNMappingConfig.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.mapper; +import org.opensearch.Version; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.engine.qframe.QuantizationConfig; @@ -62,4 +63,12 @@ default QuantizationConfig getQuantizationConfig() { * @return the dimension of the index; for model based indices, it will be null */ int getDimension(); + + /** + * Returns the version with which the index was created; the default implementation returns {@link Version#CURRENT} + * @return Version + */ + default Version getIndexCreatedVersion() { + return Version.CURRENT; + } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index 7990fdcab..4ceb9b4b2 100644 ---
a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -17,6 +17,7 @@ import org.apache.lucene.document.FieldType; import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnFloatVectorField; +import org.opensearch.Version; import org.opensearch.common.Explicit; import org.opensearch.knn.index.KNNVectorSimilarityFunction; import org.opensearch.knn.index.VectorDataType; @@ -73,6 +74,11 @@ public Mode getMode() { public CompressionLevel getCompressionLevel() { return knnMethodConfigContext.getCompressionLevel(); } + + @Override + public Version getIndexCreatedVersion() { + return knnMethodConfigContext.getVersionCreated(); + } } ); diff --git a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java index bf5bc2b51..755439ce6 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java @@ -8,6 +8,7 @@ import org.apache.lucene.document.FieldType; import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.VectorEncoding; +import org.opensearch.Version; import org.opensearch.common.Explicit; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.knn.index.SpaceType; @@ -86,6 +87,11 @@ public CompressionLevel getCompressionLevel() { public QuantizationConfig getQuantizationConfig() { return quantizationConfig; } + + @Override + public Version getIndexCreatedVersion() { + return knnMethodConfigContext.getVersionCreated(); + } } ); return new MethodFieldMapper( diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index 013cb0c53..cbc7520cf 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ 
b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -107,6 +107,11 @@ public QuantizationConfig getQuantizationConfig() { return quantizationConfig; } + @Override + public Version getIndexCreatedVersion() { + return indexCreatedVersion; + } + // ModelMetadata relies on cluster state which may not be available during field mapper creation. Thus, // we lazily initialize it. private void initFromModelMetadata() { diff --git a/src/main/java/org/opensearch/knn/index/query/FilterIdsSelector.java b/src/main/java/org/opensearch/knn/index/query/FilterIdsSelector.java index bf06e8c5e..12711911a 100644 --- a/src/main/java/org/opensearch/knn/index/query/FilterIdsSelector.java +++ b/src/main/java/org/opensearch/knn/index/query/FilterIdsSelector.java @@ -78,7 +78,10 @@ public enum FilterIdsSelectorType { public static FilterIdsSelector getFilterIdSelector(final BitSet filterIdsBitSet, final int cardinality) throws IOException { long[] filterIds; FilterIdsSelector.FilterIdsSelectorType filterType; - if (filterIdsBitSet instanceof FixedBitSet) { + if (filterIdsBitSet == null) { + filterIds = null; + filterType = FilterIdsSelector.FilterIdsSelectorType.BITMAP; + } else if (filterIdsBitSet instanceof FixedBitSet) { /** * When filterIds is dense filter, using fixed bitset */ diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index d2b169b2a..ee18394f6 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -662,9 +662,24 @@ public String getWriteableName() { @Override protected QueryBuilder doRewrite(QueryRewriteContext queryShardContext) throws IOException { - // rewrite filter query if it exists to avoid runtime errors in next steps of query phase + QueryBuilder rewrittenFilter; if (Objects.nonNull(filter)) { - filter = filter.rewrite(queryShardContext); + 
rewrittenFilter = filter.rewrite(queryShardContext); + if (rewrittenFilter != filter) { + KNNQueryBuilder rewrittenQueryBuilder = KNNQueryBuilder.builder() + .fieldName(this.fieldName) + .vector(this.vector) + .k(this.k) + .maxDistance(this.maxDistance) + .minScore(this.minScore) + .methodParameters(this.methodParameters) + .filter(rewrittenFilter) + .ignoreUnmapped(this.ignoreUnmapped) + .rescoreContext(this.rescoreContext) + .expandNested(this.expandNested) + .build(); + return rewrittenQueryBuilder; + } } return super.doRewrite(queryShardContext); } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 891f9325c..37b5cc9ad 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -129,6 +129,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException { */ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOException { final BitSet filterBitSet = getFilteredDocsBitSet(context); + final int maxDoc = context.reader().maxDoc(); int cardinality = filterBitSet.cardinality(); // We don't need to go to JNI layer if no documents are found which satisfy the filters // We should give this condition a deeper look that where it should be placed. For now I feel this is a good @@ -145,7 +146,14 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep Map result = doExactSearch(context, new BitSetIterator(filterBitSet, cardinality), cardinality, k); return new PerLeafResult(filterWeight == null ? null : filterBitSet, result); } - Map docIdsToScoreMap = doANNSearch(context, filterBitSet, cardinality, k); + + /* + * If filters match all docs in this segment, then null should be passed as filterBitSet + * so that it will not do a bitset look up in bottom search layer. + */ + final BitSet annFilter = (filterWeight != null && cardinality == maxDoc) ? 
null : filterBitSet; + final Map docIdsToScoreMap = doANNSearch(context, annFilter, cardinality, k); + // See whether we have to perform exact search based on approx search results // This is required if there are no native engine files or if approximate search returned // results less than K, though we have more than k filtered docs diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java index 47ea215f3..5b4d6e7a1 100644 --- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -63,11 +63,11 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo if (rescoreContext == null) { perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, finalK); } else { - boolean isShardLevelRescoringEnabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(knnQuery.getIndexName()); + boolean isShardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName()); int dimension = knnQuery.getQueryVector().length; - int firstPassK = rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension); + int firstPassK = rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension); perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, firstPassK); - if (isShardLevelRescoringEnabled == true) { + if (isShardLevelRescoringDisabled == false) { ResultUtil.reduceToTopK(perLeafResults, firstPassK); } diff --git a/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java b/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java index 09aeb7591..03c03a87f 100644 --- a/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java +++ 
b/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java @@ -61,17 +61,17 @@ public static RescoreContext getDefault() { * based on the vector dimension if shard-level rescoring is disabled. * * @param finalK The final number of results to return for the entire shard. - * @param isShardLevelRescoringEnabled A boolean flag indicating whether shard-level rescoring is enabled. - * If true, the dimension-based oversampling logic is bypassed. + * @param isShardLevelRescoringDisabled A boolean flag indicating whether shard-level rescoring is disabled. + * If false, the dimension-based oversampling logic is bypassed. * @param dimension The dimension of the vector. This is used to determine the oversampling factor when * shard-level rescoring is disabled. * @return The number of results to return for the first pass of rescoring, adjusted by the oversample factor. */ - public int getFirstPassK(int finalK, boolean isShardLevelRescoringEnabled, int dimension) { + public int getFirstPassK(int finalK, boolean isShardLevelRescoringDisabled, int dimension) { // Only apply default dimension-based oversampling logic when: // 1. Shard-level rescoring is disabled // 2. 
The oversample factor was not provided by the user - if (!isShardLevelRescoringEnabled && !userProvided) { + if (isShardLevelRescoringDisabled && !userProvided) { // Apply new dimension-based oversampling logic when shard-level rescoring is disabled if (dimension >= DIMENSION_THRESHOLD_1000) { oversampleFactor = OVERSAMPLE_FACTOR_1000; // No oversampling for dimensions >= 1000 diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java index 71616c9fd..b613efab2 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java @@ -8,6 +8,7 @@ import lombok.Getter; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.IndexSearcher; +import org.opensearch.Version; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; @@ -69,7 +70,7 @@ public KNNFieldSpace( ) { KNNVectorFieldType knnVectorFieldType = toKNNVectorFieldType(fieldType, spaceName, supportingVectorDataTypes); this.processedQuery = getProcessedQuery(query, knnVectorFieldType); - this.scoringMethod = getScoringMethod(this.processedQuery); + this.scoringMethod = getScoringMethod(this.processedQuery, knnVectorFieldType.getKnnMappingConfig().getIndexCreatedVersion()); } public ScoreScript getScoreScript( @@ -122,6 +123,10 @@ protected float[] getProcessedQuery(final Object query, final KNNVectorFieldType protected abstract BiFunction getScoringMethod(final float[] processedQuery); + protected BiFunction getScoringMethod(final float[] processedQuery, Version indexCreatedVersion) { + return getScoringMethod(processedQuery); + } + } class L2 extends KNNFieldSpace { @@ -141,9 +146,29 @@ public CosineSimilarity(Object query, MappedFieldType fieldType) { } @Override - protected BiFunction getScoringMethod(final float[] 
processedQuery) { + protected BiFunction getScoringMethod(float[] processedQuery) { + return getScoringMethod(processedQuery, Version.CURRENT); + } + + @Override + protected BiFunction getScoringMethod(final float[] processedQuery, Version indexCreatedVersion) { SpaceType.COSINESIMIL.validateVector(processedQuery); float qVectorSquaredMagnitude = getVectorMagnitudeSquared(processedQuery); + if (indexCreatedVersion.onOrAfter(Version.V_2_19_0)) { + // To be consistent, we will be using same formula used by lucene as mentioned below + // https://github.com/apache/lucene/blob/0494c824e0ac8049b757582f60d085932a890800/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java#L73 + // for indices that are created on or after 2.19.0 + // + // OS Score = ( 2 − cosineSimil) / 2 + // However cosineSimil = 1 - cos θ, after applying this to above formula, + // OS Score = ( 2 − ( 1 − cos θ ) ) / 2 + // which simplifies to + // OS Score = ( 1 + cos θ ) / 2 + return (float[] q, float[] v) -> Math.max( + ((1 + KNNScoringUtil.cosinesimilOptimized(q, v, qVectorSquaredMagnitude)) / 2.0F), + 0 + ); + } return (float[] q, float[] v) -> 1 + KNNScoringUtil.cosinesimilOptimized(q, v, qVectorSquaredMagnitude); } } diff --git a/src/test/java/org/opensearch/knn/common/KNNVectorUtilTests.java b/src/test/java/org/opensearch/knn/common/KNNVectorUtilTests.java index d64b73c9a..5712a7f3a 100644 --- a/src/test/java/org/opensearch/knn/common/KNNVectorUtilTests.java +++ b/src/test/java/org/opensearch/knn/common/KNNVectorUtilTests.java @@ -11,17 +11,10 @@ package org.opensearch.knn.common; -import lombok.SneakyThrows; import org.opensearch.knn.KNNTestCase; -import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.vectorvalues.KNNVectorValues; -import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; -import org.opensearch.knn.index.vectorvalues.TestVectorValues; import java.util.List; -import static 
org.opensearch.knn.common.KNNVectorUtil.iterateVectorValuesOnce; - public class KNNVectorUtilTests extends KNNTestCase { public void testByteZeroVector() { assertTrue(KNNVectorUtil.isZeroVector(new byte[] { 0, 0, 0 })); @@ -38,23 +31,4 @@ public void testIntListToArray() { assertNull(KNNVectorUtil.intListToArray(List.of())); assertNull(KNNVectorUtil.intListToArray(null)); } - - @SneakyThrows - public void testInit() { - // Give - final List floatArray = List.of(new float[] { 1, 2 }, new float[] { 2, 3 }); - final int dimension = floatArray.get(0).length; - final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( - floatArray - ); - final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); - - // When - iterateVectorValuesOnce(knnVectorValues); - - // Then - assertNotEquals(-1, knnVectorValues.docId()); - assertArrayEquals(floatArray.get(0), knnVectorValues.getVector(), 0.001f); - assertEquals(dimension, knnVectorValues.dimension()); - } } diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index f6aef8cb1..c2e75ecb2 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -412,6 +412,51 @@ public void testRadialQuery_withFilter_thenSuccess() { deleteKNNIndex(INDEX_NAME); } + @SneakyThrows + public void testQueryWithFilterMultipleShards() { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD_NAME) + .startObject(FIELD_NAME) + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, "3") + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, METHOD_HNSW) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .endObject() + .endObject() + 
.startObject(INTEGER_FIELD_NAME) + .field(TYPE_FIELD_NAME, FILED_TYPE_INTEGER) + .endObject() + .endObject() + .endObject(); + String mapping = builder.toString(); + + createIndex(INDEX_NAME, Settings.builder().put("number_of_shards", 10).put("number_of_replicas", 0).put("index.knn", true).build()); + putMappingRequest(INDEX_NAME, mapping); + + addKnnDocWithAttributes("doc1", new float[] { 7.0f, 7.0f, 3.0f }, ImmutableMap.of("dateReceived", "2024-10-01")); + + refreshIndex(INDEX_NAME); + + final float[] searchVector = { 6.0f, 7.0f, 3.0f }; + final Response response = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder( + FIELD_NAME, + searchVector, + 1, + QueryBuilders.boolQuery().must(QueryBuilders.rangeQuery("dateReceived").gte("2023-11-01")) + ), + 10 + ); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List knnResults = parseSearchResponse(responseBody, FIELD_NAME); + + assertEquals(1, knnResults.size()); + } + @SneakyThrows public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() { String indexName = "test-index"; diff --git a/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java b/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java index c7a8e7ed8..24990dd36 100644 --- a/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java @@ -159,7 +159,7 @@ public void testGetEfSearch_whenEFSearchValueSetByUser_thenReturnValue() { } @SneakyThrows - public void testShardLevelRescoringEnabled_whenNoValuesProvidedByUser_thenDefaultSettingsUsed() { + public void testShardLevelRescoringDisabled_whenNoValuesProvidedByUser_thenDefaultSettingsUsed() { Node mockNode = createMockNode(Collections.emptyMap()); mockNode.start(); ClusterService clusterService = mockNode.injector().getInstance(ClusterService.class); @@ -167,7 +167,7 @@ public void testShardLevelRescoringEnabled_whenNoValuesProvidedByUser_thenDefaul mockNode.client().admin().indices().create(new 
CreateIndexRequest(INDEX_NAME)).actionGet(); KNNSettings.state().setClusterService(clusterService); - boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(INDEX_NAME); + boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME); mockNode.close(); assertFalse(shardLevelRescoringDisabled); } @@ -188,7 +188,7 @@ public void testShardLevelRescoringDisabled_whenValueProvidedByUser_thenSettingA mockNode.client().admin().indices().updateSettings(new UpdateSettingsRequest(rescoringDisabledSetting, INDEX_NAME)).actionGet(); - boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(INDEX_NAME); + boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME); mockNode.close(); assertEquals(userDefinedRescoringDisabled, shardLevelRescoringDisabled); } diff --git a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java index 688d22e74..b478415a0 100644 --- a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java +++ b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java @@ -15,6 +15,7 @@ import org.opensearch.client.Response; import org.opensearch.client.ResponseException; import org.opensearch.common.Nullable; +import org.opensearch.common.settings.Settings; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.index.query.QueryBuilder; @@ -282,6 +283,51 @@ public void testQueryWithFilterUsingByteVectorDataType() { validateQueryResultsWithFilters(searchVector, 5, 1, expectedDocIdsKGreaterThanFilterResult, expectedDocIdsKLimitsFilterResult); } + @SneakyThrows + public void testQueryWithFilterMultipleShards() { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD_NAME) + .startObject(FIELD_NAME) 
+ .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, DIMENSION) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, METHOD_HNSW) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) + .field(KNNConstants.KNN_ENGINE, KNNEngine.LUCENE.getName()) + .endObject() + .endObject() + .startObject(INTEGER_FIELD_NAME) + .field(TYPE_FIELD_NAME, FILED_TYPE_INTEGER) + .endObject() + .endObject() + .endObject(); + String mapping = builder.toString(); + + createIndex(INDEX_NAME, Settings.builder().put("number_of_shards", 10).put("number_of_replicas", 0).put("index.knn", true).build()); + putMappingRequest(INDEX_NAME, mapping); + + addKnnDocWithAttributes("doc1", new float[] { 7.0f, 7.0f, 3.0f }, ImmutableMap.of("dateReceived", "2024-10-01")); + + refreshIndex(INDEX_NAME); + + final float[] searchVector = { 6.0f, 7.0f, 3.0f }; + final Response response = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder( + FIELD_NAME, + searchVector, + 1, + QueryBuilders.boolQuery().must(QueryBuilders.rangeQuery("dateReceived").gte("2023-11-01")) + ), + 10 + ); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List knnResults = parseSearchResponse(responseBody, FIELD_NAME); + + assertEquals(1, knnResults.size()); + } + @SneakyThrows public void testQueryWithFilter_whenNonExistingFieldUsedInFilter_thenSuccessful() { XContentBuilder builder = XContentFactory.jsonBuilder() diff --git a/src/test/java/org/opensearch/knn/index/NmslibIT.java b/src/test/java/org/opensearch/knn/index/NmslibIT.java index 8ca436bf4..e0ba58eb1 100644 --- a/src/test/java/org/opensearch/knn/index/NmslibIT.java +++ b/src/test/java/org/opensearch/knn/index/NmslibIT.java @@ -195,6 +195,61 @@ public void testEndToEnd() throws Exception { fail("Graphs are not getting evicted"); } + public void testEndToEnd_withApproxAndExactSearch_inSameIndex_ForCosineSpaceType() throws Exception { + String indexName = "test-index-1"; + String fieldName = 
"test-field-1"; + SpaceType spaceType = SpaceType.COSINESIMIL; + Integer dimension = testData.indexData.vectors[0].length; + + // Create an index + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNConstants.METHOD_HNSW) + .field(KNNConstants.KNN_ENGINE, KNNEngine.NMSLIB.getName()) + .endObject() + .endObject() + .endObject() + .endObject(); + + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + + createKnnIndex(indexName, buildKNNIndexSettings(0), mapping); + + // Index one document + addKnnDoc(indexName, randomAlphaOfLength(5), fieldName, Floats.asList(testData.indexData.vectors[0]).toArray()); + + // Assert we have the right number of documents in the index + refreshAllIndices(); + assertEquals(1, getDocCount(indexName)); + // update threshold setting to skip building graph + updateIndexSettings(indexName, Settings.builder().put(KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD, -1)); + // add duplicate document with different id + addKnnDoc(indexName, randomAlphaOfLength(5), fieldName, Floats.asList(testData.indexData.vectors[0]).toArray()); + assertEquals(2, getDocCount(indexName)); + final int k = 2; + // search index + Response response = searchKNNIndex( + indexName, + KNNQueryBuilder.builder().fieldName(fieldName).vector(testData.queries[0]).k(k).build(), + k + ); + String responseBody = EntityUtils.toString(response.getEntity()); + List knnResults = parseSearchResponse(responseBody, fieldName); + assertEquals(k, knnResults.size()); + + List actualScores = parseSearchResponseScore(responseBody, fieldName); + + // both document should have identical score + assertEquals(actualScores.get(0), actualScores.get(1), 0.001); + } + 
@SneakyThrows private void validateSearch( final String indexName, diff --git a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java index 86e22cd88..b6680925e 100644 --- a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java @@ -6,16 +6,24 @@ package org.opensearch.knn.index.codec.util; import junit.framework.TestCase; +import lombok.SneakyThrows; import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.search.DocIdSetIterator; +import org.junit.Assert; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import org.opensearch.knn.index.vectorvalues.TestVectorValues; +import java.util.Collections; import java.util.List; import java.util.Set; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.opensearch.knn.index.codec.util.KNNCodecUtil.calculateArraySize; +import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues; public class KNNCodecUtilTests extends TestCase { @@ -46,4 +54,38 @@ public void testGetKNNEngines() { assertEquals(engineFiles.size(), 2); assertTrue(engineFiles.get(0).equals("_0_2011_target_field.faissc")); } + + @SneakyThrows + public void testInitializeVectorValues_whenValidVectorValues_thenSuccess() { + // Give + final List floatArray = List.of(new float[] { 1, 2 }, new float[] { 2, 3 }); + final int dimension = floatArray.get(0).length; + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + floatArray + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, 
randomVectorValues); + + // When + initializeVectorValues(knnVectorValues); + + // Then + Assert.assertNotEquals(-1, knnVectorValues.docId()); + Assert.assertArrayEquals(floatArray.get(0), knnVectorValues.getVector(), 0.001f); + assertEquals(dimension, knnVectorValues.dimension()); + } + + @SneakyThrows + public void testInitializeVectorValues_whenNoDocs_thenSuccess() { + // Give + final List floatArray = Collections.emptyList(); + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + floatArray + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + + // When + initializeVectorValues(knnVectorValues); + // Then + Assert.assertEquals(DocIdSetIterator.NO_MORE_DOCS, knnVectorValues.docId()); + } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 30c8007e6..1333d616e 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -1066,11 +1066,14 @@ public void testDoRewrite_whenFilterSet_thenSuccessful() { .filter(rewrittenFilter) .k(K) .build(); + // When KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).filter(filter).k(K).build(); QueryBuilder actual = knnQueryBuilder.rewrite(context); + assertEquals(knnQueryBuilder, KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).filter(filter).k(K).build()); + // Then assertEquals(expected, actual); } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 511895026..8011cc08c 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ 
b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -671,7 +671,7 @@ public void validateANNWithFilterQuery_whenDoingANN_thenSuccess(final boolean is when(liveDocsBits.length()).thenReturn(1000); final SegmentReader reader = mockSegmentReader(); - when(reader.maxDoc()).thenReturn(filterDocIds.length); + when(reader.maxDoc()).thenReturn(filterDocIds.length + 1); when(reader.getLiveDocs()).thenReturn(liveDocsBits); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); @@ -758,6 +758,88 @@ public void validateANNWithFilterQuery_whenDoingANN_thenSuccess(final boolean is assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); } + @SneakyThrows + public void testANNWithFilterQuery_whenFiltersMatchAllDocs_thenSuccess() { + // Given + int k = 3; + final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 }; + FixedBitSet filterBitSet = new FixedBitSet(filterDocIds.length); + for (int docId : filterDocIds) { + filterBitSet.set(docId); + } + + jniServiceMockedStatic.when( + () -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), eq(null), anyInt(), any()) + ).thenReturn(getFilteredKNNQueryResults()); + + final Bits liveDocsBits = mock(Bits.class); + for (int filterDocId : filterDocIds) { + when(liveDocsBits.get(filterDocId)).thenReturn(true); + } + when(liveDocsBits.length()).thenReturn(1000); + + final SegmentReader reader = mockSegmentReader(); + when(reader.maxDoc()).thenReturn(filterDocIds.length); + when(reader.getLiveDocs()).thenReturn(liveDocsBits); + + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + when(leafReaderContext.reader()).thenReturn(reader); + + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(QUERY_VECTOR) + .k(k) + .indexName(INDEX_NAME) + .filterQuery(FILTER_QUERY) + .methodParameters(HNSW_METHOD_PARAMETERS) + .build(); + + final Weight filterQueryWeight = mock(Weight.class); + final Scorer 
filterScorer = mock(Scorer.class); + when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); + // Just to make sure that we are not hitting the exact search condition + when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(filterDocIds.length + 1)); + + final float boost = (float) randomDoubleBetween(0, 10, true); + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); + + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + final Map attributesMap = ImmutableMap.of( + KNN_ENGINE, + KNNEngine.FAISS.getName(), + SPACE_TYPE, + SpaceType.L2.getValue() + ); + + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); + when(fieldInfo.attributes()).thenReturn(attributesMap); + + // When + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + + // Then + assertNotNull(knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + assertNotNull(docIdSetIterator); + assertEquals(FILTERED_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); + + jniServiceMockedStatic.verify( + () -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), any(), anyInt(), any()), + times(1) + ); + + final List actualDocIds = new ArrayList<>(); + final Map translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + assertEquals(translatedScores.get(docId) * boost, knnScorer.score(), 0.01f); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + } + private SegmentReader mockSegmentReader() { Path path = mock(Path.class); @@ -815,7 +897,7 @@ public void validateANNWithFilterQuery_whenExactSearch_thenSuccess(final boolean 
when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); // scorer will return 2 documents when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(1)); - when(reader.maxDoc()).thenReturn(1); + when(reader.maxDoc()).thenReturn(2); final Bits liveDocsBits = mock(Bits.class); when(reader.getLiveDocs()).thenReturn(liveDocsBits); when(liveDocsBits.get(filterDocId)).thenReturn(true); @@ -891,6 +973,7 @@ public void testRadialSearch_whenNoEngineFiles_thenPerformExactSearch() { final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); when(leafReaderContext.reader()).thenReturn(reader); + when(reader.maxDoc()).thenReturn(1); final FSDirectory directory = mock(FSDirectory.class); when(reader.directory()).thenReturn(directory); @@ -968,7 +1051,7 @@ public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenS when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); // scorer will return 2 documents when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(1)); - when(reader.maxDoc()).thenReturn(1); + when(reader.maxDoc()).thenReturn(2); final Bits liveDocsBits = mock(Bits.class); when(reader.getLiveDocs()).thenReturn(liveDocsBits); when(liveDocsBits.get(filterDocId)).thenReturn(true); @@ -1168,6 +1251,7 @@ public void testANNWithFilterQuery_whenEmptyFilterIds_thenReturnEarly() { final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); when(leafReaderContext.reader()).thenReturn(reader); + when(reader.maxDoc()).thenReturn(1); final Weight filterQueryWeight = mock(Weight.class); final Scorer filterScorer = mock(Scorer.class); @@ -1202,7 +1286,7 @@ public void testANNWithParentsFilter_whenExactSearch_thenSuccess() { // We will have 0, 1 for filteredIds and 2 will be the parent id for both of them final Scorer filterScorer = mock(Scorer.class); 
when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(2)); - when(reader.maxDoc()).thenReturn(2); + when(reader.maxDoc()).thenReturn(3); // Query vector is {1.8f, 2.4f}, therefore, second vector {1.9f, 2.5f} should be returned in a result final List vectors = Arrays.asList(new float[] { 0.1f, 0.3f }, new float[] { 1.9f, 2.5f }); diff --git a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java index 789bd1054..87c4a5014 100644 --- a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java @@ -183,7 +183,7 @@ public void testRescoreWhenShardLevelRescoringEnabled() { ) { // When shard-level re-scoring is enabled - mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(any())).thenReturn(true); + mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(any())).thenReturn(false); // Mock ResultUtil to return valid TopDocs mockedResultUtil.when(() -> ResultUtil.resultMapToTopDocs(any(), anyInt())) @@ -265,7 +265,7 @@ public void testRescore() { ) { // When shard-level re-scoring is enabled - mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(any())).thenReturn(true); + mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(any())).thenReturn(false); mockedResultUtil.when(() -> ResultUtil.reduceToTopK(any(), anyInt())).thenAnswer(InvocationOnMock::callRealMethod); mockedResultUtil.when(() -> ResultUtil.resultMapToDocIds(any(), anyInt())).thenAnswer(InvocationOnMock::callRealMethod); diff --git a/src/test/java/org/opensearch/knn/index/query/rescore/RescoreContextTests.java b/src/test/java/org/opensearch/knn/index/query/rescore/RescoreContextTests.java index 
2b309e4ab..a0a5cc546 100644 --- a/src/test/java/org/opensearch/knn/index/query/rescore/RescoreContextTests.java +++ b/src/test/java/org/opensearch/knn/index/query/rescore/RescoreContextTests.java @@ -16,23 +16,23 @@ public void testGetFirstPassK() { float oversample = 2.6f; RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(oversample).userProvided(true).build(); int finalK = 100; - boolean isShardLevelRescoringEnabled = true; + boolean isShardLevelRescoringDisabled = false; int dimension = 500; // Case 1: Test with standard oversample factor when shard-level rescoring is enabled - assertEquals(260, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension)); + assertEquals(260, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension)); // Case 2: Test with a very small finalK that should result in a value less than MIN_FIRST_PASS_RESULTS finalK = 1; - assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension)); + assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension)); // Case 3: Test with finalK = 0, should return MIN_FIRST_PASS_RESULTS finalK = 0; - assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension)); + assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension)); // Case 4: Test with finalK = MAX_FIRST_PASS_RESULTS, should cap at MAX_FIRST_PASS_RESULTS finalK = MAX_FIRST_PASS_RESULTS; - assertEquals(MAX_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension)); + assertEquals(MAX_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension)); } public void testGetFirstPassKWithDimensionBasedOversampling() { @@ -42,44 +42,44 @@ public void testGetFirstPassKWithDimensionBasedOversampling() { 
// Case 1: Test no oversampling for dimensions >= 1000 when shard-level rescoring is disabled dimension = 1000; RescoreContext rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensuring dimension-based logic applies - assertEquals(100, rescoreContext.getFirstPassK(finalK, false, dimension)); // No oversampling + assertEquals(100, rescoreContext.getFirstPassK(finalK, true, dimension)); // No oversampling // Case 2: Test 2x oversampling for dimensions >= 768 but < 1000 when shard-level rescoring is disabled dimension = 800; rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensure previous values don't carry over - assertEquals(200, rescoreContext.getFirstPassK(finalK, false, dimension)); // 2x oversampling + assertEquals(200, rescoreContext.getFirstPassK(finalK, true, dimension)); // 2x oversampling // Case 3: Test 3x oversampling for dimensions < 768 when shard-level rescoring is disabled dimension = 700; rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensure previous values don't carry over - assertEquals(300, rescoreContext.getFirstPassK(finalK, false, dimension)); // 3x oversampling + assertEquals(300, rescoreContext.getFirstPassK(finalK, true, dimension)); // 3x oversampling // Case 4: Shard-level rescoring enabled, oversample factor should be used as provided by the user (ignore dimension) rescoreContext = RescoreContext.builder().oversampleFactor(5.0f).userProvided(true).build(); // Provided by user dimension = 500; - assertEquals(500, rescoreContext.getFirstPassK(finalK, true, dimension)); // User-defined oversample factor should be used + assertEquals(500, rescoreContext.getFirstPassK(finalK, false, dimension)); // User-defined oversample factor should be used // Case 5: Test finalK where oversampling factor results in a value less than MIN_FIRST_PASS_RESULTS finalK = 10; dimension = 700; rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensure dimension-based 
logic applies - assertEquals(100, rescoreContext.getFirstPassK(finalK, false, dimension)); // 3x oversampling results in 30 + assertEquals(100, rescoreContext.getFirstPassK(finalK, true, dimension)); // 3x oversampling results in 30 } public void testGetFirstPassKWithMinPassK() { float oversample = 0.5f; RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(oversample).userProvided(true).build(); // User provided - boolean isShardLevelRescoringEnabled = false; + boolean isShardLevelRescoringDisabled = true; // Case 1: Test where finalK * oversample is smaller than MIN_FIRST_PASS_RESULTS int finalK = 10; int dimension = 700; - assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension)); + assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension)); // Case 2: Test where finalK * oversample results in exactly MIN_FIRST_PASS_RESULTS finalK = 100; oversample = 1.0f; // This will result in exactly 100 (MIN_FIRST_PASS_RESULTS) rescoreContext = RescoreContext.builder().oversampleFactor(oversample).userProvided(true).build(); // User provided - assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension)); + assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension)); } } diff --git a/src/test/java/org/opensearch/knn/integ/FilteredSearchANNSearchIT.java b/src/test/java/org/opensearch/knn/integ/FilteredSearchANNSearchIT.java new file mode 100644 index 000000000..191ab944c --- /dev/null +++ b/src/test/java/org/opensearch/knn/integ/FilteredSearchANNSearchIT.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.integ; + +import com.google.common.collect.ImmutableMap; +import lombok.SneakyThrows; +import lombok.extern.log4j.Log4j2; +import 
org.apache.hc.core5.http.io.entity.EntityUtils; +import org.opensearch.client.Response; +import org.opensearch.common.settings.Settings; +import org.opensearch.knn.KNNJsonQueryBuilder; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.index.KNNSettings; +import java.util.List; + +import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; + +@Log4j2 +public class FilteredSearchANNSearchIT extends KNNRestTestCase { + @SneakyThrows + public void testFilteredSearchWithFaissHnsw_whenFiltersMatchAllDocs_thenReturnCorrectResults() { + String filterFieldName = "color"; + final int expectResultSize = randomIntBetween(1, 3); + final String filterValue = "red"; + createKnnIndex(INDEX_NAME, getKNNDefaultIndexSettings(), createKnnIndexMapping(FIELD_NAME, 3, METHOD_HNSW, FAISS_NAME)); + + // ingest 4 vector docs into the index with the same field {"color": "red"} + for (int i = 0; i < 4; i++) { + addKnnDocWithAttributes(String.valueOf(i), new float[] { i, i, i }, ImmutableMap.of(filterFieldName, filterValue)); + } + + refreshIndex(INDEX_NAME); + forceMergeKnnIndex(INDEX_NAME); + + updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, 0)); + + Float[] queryVector = { 3f, 3f, 3f }; + // All docs in one segment will match the filters value + String query = KNNJsonQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(expectResultSize) + .filterFieldName(filterFieldName) + .filterValue(filterValue) + .build() + .getQueryString(); + Response response = searchKNNIndex(INDEX_NAME, query, expectResultSize); + String entity = EntityUtils.toString(response.getEntity()); + List docIds = parseIds(entity); + assertEquals(expectResultSize, docIds.size()); + assertEquals(expectResultSize, parseTotalSearchHits(entity)); + } +} diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java 
b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java index 4fc549d6b..99e847eea 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java @@ -10,6 +10,7 @@ import java.util.Locale; import lombok.SneakyThrows; +import org.apache.lucene.index.VectorSimilarityFunction; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.engine.KNNMethodContext; @@ -86,7 +87,11 @@ public void testCosineSimilarity_whenValid_thenSucceed() { getMappingConfigForMethodMapping(knnMethodContext, 3) ); KNNScoringSpace.CosineSimilarity cosineSimilarity = new KNNScoringSpace.CosineSimilarity(arrayListQueryObject, fieldType); - assertEquals(2F, cosineSimilarity.getScoringMethod().apply(arrayFloat2, arrayFloat), 0.1F); + assertEquals( + VectorSimilarityFunction.COSINE.compare(arrayFloat2, arrayFloat), + cosineSimilarity.getScoringMethod().apply(arrayFloat2, arrayFloat), + 0.1F + ); // invalid zero vector final List queryZeroVector = List.of(0.0f, 0.0f, 0.0f);