diff --git a/CHANGELOG.md b/CHANGELOG.md index 50d625a61..3bb27884c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Fix shard level rescoring disabled setting flag (#2352)[https://github.com/opensearch-project/k-NN/pull/2352] * Fix filter rewrite logic which was resulting in getting inconsistent / incorrect results for cases where filter was getting rewritten for shards (#2359)[https://github.com/opensearch-project/k-NN/pull/2359] * Fixing it to retrieve space_type from index setting when both method and top level don't have the value. [#2374](https://github.com/opensearch-project/k-NN/pull/2374) +* Fixing the bug where setting rescore as false for on_disk knn_vector query is a no-op [#2399](https://github.com/opensearch-project/k-NN/pull/2399) ### Infrastructure * Updated C++ version in JNI from c++11 to c++17 [#2259](https://github.com/opensearch-project/k-NN/pull/2259) * Upgrade bytebuddy and objenesis version to match OpenSearch core and, update github ci runner for macos [#2279](https://github.com/opensearch-project/k-NN/pull/2279) diff --git a/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java b/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java index 99f74c246..c14bb9d82 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java +++ b/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java @@ -25,10 +25,11 @@ public enum CompressionLevel { x1(1, "1x", null, Collections.emptySet()), x2(2, "2x", null, Collections.emptySet()), x4(4, "4x", null, Collections.emptySet()), - x8(8, "8x", new RescoreContext(2.0f, false), Set.of(Mode.ON_DISK)), - x16(16, "16x", new RescoreContext(3.0f, false), Set.of(Mode.ON_DISK)), - x32(32, "32x", new RescoreContext(3.0f, false), Set.of(Mode.ON_DISK)), - x64(64, "64x", new RescoreContext(5.0f, false), Set.of(Mode.ON_DISK)); + // Third constructor argument is rescoreEnabled: these on_disk defaults are not user provided but must keep rescoring enabled. + x8(8, "8x", new RescoreContext(2.0f, false, 
true), Set.of(Mode.ON_DISK)), + x16(16, "16x", new RescoreContext(3.0f, false, true), Set.of(Mode.ON_DISK)), + x32(32, "32x", new RescoreContext(3.0f, false, true), Set.of(Mode.ON_DISK)), + x64(64, "64x", new RescoreContext(5.0f, false, true), Set.of(Mode.ON_DISK)); public static final CompressionLevel MAX_COMPRESSION_LEVEL = CompressionLevel.x64; diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java index 5b4d6e7a1..1ffaa804d 100644 --- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -60,7 +60,7 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo List perLeafResults; RescoreContext rescoreContext = knnQuery.getRescoreContext(); final int finalK = knnQuery.getK(); - if (rescoreContext == null) { + if (rescoreContext == null || !rescoreContext.isRescoreEnabled()) { perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, finalK); } else { boolean isShardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName()); diff --git a/src/main/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParser.java b/src/main/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParser.java index 4d085da53..2d74a2052 100644 --- a/src/main/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParser.java +++ b/src/main/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParser.java @@ -24,7 +24,10 @@ import java.util.List; import java.util.Locale; import java.util.Objects; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; import java.util.function.Function; +import java.util.function.Supplier; import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD; 
import static org.opensearch.index.query.AbstractQueryBuilder.NAME_FIELD; @@ -34,6 +37,7 @@ import static org.opensearch.knn.index.query.KNNQueryBuilder.EXPAND_NESTED_FIELD; import static org.opensearch.knn.index.query.KNNQueryBuilder.RESCORE_FIELD; import static org.opensearch.knn.index.query.parser.RescoreParser.RESCORE_PARAMETER; +import static org.opensearch.knn.index.query.rescore.RescoreContext.EXPLICITLY_DISABLED_RESCORE_CONTEXT; import static org.opensearch.knn.index.util.IndexUtil.isClusterOnOrAfterMinRequiredVersion; import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD; import static org.opensearch.knn.index.query.KNNQueryBuilder.IGNORE_UNMAPPED_FIELD; @@ -84,12 +88,22 @@ private static ObjectParser createInternalObjectP ); internalParser.declareObject(KNNQueryBuilder.Builder::filter, (p, v) -> parseInnerQueryBuilder(p), FILTER_FIELD); - internalParser.declareObjectOrDefault( - KNNQueryBuilder.Builder::rescoreContext, - (p, v) -> RescoreParser.fromXContent(p), - RescoreContext::getDefault, - RESCORE_FIELD - ); + internalParser.declareField((p, v, c) -> { + BiConsumer consumer = KNNQueryBuilder.Builder::rescoreContext; + BiFunction objectParser = (_p, _v) -> RescoreParser.fromXContent(_p); + Supplier defaultValue = RescoreContext::getDefault; + if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) { + if (p.booleanValue()) { + consumer.accept(v, defaultValue.get()); + } else { + // If the user specifies false, use the explicitly-disabled context (not null) so we don't + // accidentally resolve a default rescore context. 
+ consumer.accept(v, EXPLICITLY_DISABLED_RESCORE_CONTEXT); + } + } else { + consumer.accept(v, objectParser.apply(p, c)); + } + }, RESCORE_FIELD, ObjectParser.ValueType.OBJECT_OR_BOOLEAN); internalParser.declareBoolean(KNNQueryBuilder.Builder::expandNested, EXPAND_NESTED_FIELD); diff --git a/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java b/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java index 03c03a87f..4e89b1b04 100644 --- a/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java +++ b/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java @@ -47,6 +47,17 @@ public final class RescoreContext { @Builder.Default private boolean userProvided = true; + /** + * Flag to track whether rescoring is enabled; false when the query explicitly sets rescore to false. + */ + @Builder.Default + private boolean rescoreEnabled = true; + + public static final RescoreContext EXPLICITLY_DISABLED_RESCORE_CONTEXT = RescoreContext.builder() + .oversampleFactor(DEFAULT_OVERSAMPLE_FACTOR) + .rescoreEnabled(false) + .build(); + /** * * @return default RescoreContext diff --git a/src/test/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParserTests.java b/src/test/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParserTests.java index 6cac5580b..4bd9acc3f 100644 --- a/src/test/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParserTests.java +++ b/src/test/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParserTests.java @@ -334,6 +334,78 @@ public void testFromXContent_missingQueryVector() throws Exception { assertTrue(exception.getMessage(), exception.getMessage().contains("[knn] failed to parse field [vector]")); } + public void testFromXContent_rescoreEnabled() throws Exception { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + RescoreContext explicitRescoreContext = RescoreContext.builder().oversampleFactor(1.5f).build(); + // Test with default rescore + KNNQueryBuilder 
knnQueryBuilderDefaultRescore = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(K) + .rescoreContext(RescoreContext.getDefault()) + .build(); + XContentBuilder builderDefaultRescore = XContentFactory.jsonBuilder(); + builderDefaultRescore.startObject(); + builderDefaultRescore.startObject(knnQueryBuilderDefaultRescore.fieldName()); + builderDefaultRescore.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilderDefaultRescore.vector()); + builderDefaultRescore.field(KNNQueryBuilder.K_FIELD.getPreferredName(), knnQueryBuilderDefaultRescore.getK()); + builderDefaultRescore.field(KNNQueryBuilder.RESCORE_FIELD.getPreferredName(), true); + builderDefaultRescore.endObject(); + builderDefaultRescore.endObject(); + XContentParser contentParserDefaultRescore = createParser(builderDefaultRescore); + contentParserDefaultRescore.nextToken(); + KNNQueryBuilder actualBuilderDefaultRescore = KNNQueryBuilderParser.fromXContent(contentParserDefaultRescore); + assertEquals(knnQueryBuilderDefaultRescore, actualBuilderDefaultRescore); + + // Test with an explicit rescore object carrying a user-supplied oversample factor + KNNQueryBuilder knnQueryBuilderExplicitRescore = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(K) + .rescoreContext(explicitRescoreContext) + .build(); + XContentBuilder builderExplicitRescore = XContentFactory.jsonBuilder(); + builderExplicitRescore.startObject(); + builderExplicitRescore.startObject(knnQueryBuilderExplicitRescore.fieldName()); + builderExplicitRescore.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilderExplicitRescore.vector()); + builderExplicitRescore.field(KNNQueryBuilder.K_FIELD.getPreferredName(), knnQueryBuilderExplicitRescore.getK()); + builderExplicitRescore.startObject(KNNQueryBuilder.RESCORE_FIELD.getPreferredName()); + builderExplicitRescore.field( + KNNQueryBuilder.RESCORE_OVERSAMPLE_FIELD.getPreferredName(), + explicitRescoreContext.getOversampleFactor() + ); + 
builderExplicitRescore.endObject(); + builderExplicitRescore.endObject(); + builderExplicitRescore.endObject(); + XContentParser contentParserExplicitRescore = createParser(builderExplicitRescore); + contentParserExplicitRescore.nextToken(); + KNNQueryBuilder actualBuilderExplicitRescore = KNNQueryBuilderParser.fromXContent(contentParserExplicitRescore); + assertEquals(knnQueryBuilderExplicitRescore, actualBuilderExplicitRescore); + } + + public void testFromXContent_rescoreDisabled() throws Exception { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + // Test that rescore=false parses to EXPLICITLY_DISABLED_RESCORE_CONTEXT + KNNQueryBuilder knnQueryBuilderRescoreDisabled = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(K) + .rescoreContext(RescoreContext.EXPLICITLY_DISABLED_RESCORE_CONTEXT) + .build(); + XContentBuilder builderRescoreDisabled = XContentFactory.jsonBuilder(); + builderRescoreDisabled.startObject(); + builderRescoreDisabled.startObject(knnQueryBuilderRescoreDisabled.fieldName()); + builderRescoreDisabled.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilderRescoreDisabled.vector()); + builderRescoreDisabled.field(KNNQueryBuilder.K_FIELD.getPreferredName(), knnQueryBuilderRescoreDisabled.getK()); + builderRescoreDisabled.field(KNNQueryBuilder.RESCORE_FIELD.getPreferredName(), false); + builderRescoreDisabled.endObject(); + builderRescoreDisabled.endObject(); + XContentParser contentParserRescoreDisabled = createParser(builderRescoreDisabled); + contentParserRescoreDisabled.nextToken(); + KNNQueryBuilder actualBuilderRescoreDisabled = KNNQueryBuilderParser.fromXContent(contentParserRescoreDisabled); + assertEquals(knnQueryBuilderRescoreDisabled, actualBuilderRescoreDisabled); + } + public void testFromXContent_whenFlat_thenException() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; XContentBuilder builder = XContentFactory.jsonBuilder(); diff --git a/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java 
b/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java index 262221a4c..8b2cf5d2b 100644 --- a/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java +++ b/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java @@ -206,6 +206,125 @@ public void testIndexCreation_whenValid_ThenSucceed() { } } + @SneakyThrows + public void testQueryRescoreEnabledAndDisabled() { + XContentBuilder builder; + String mode = Mode.ON_DISK.getName(); + String compressionLevel = CompressionLevel.x32.getName(); + String indexName = INDEX_NAME + compressionLevel; + builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", DIMENSION) + .field(MODE_PARAMETER, mode) + .field(COMPRESSION_LEVEL_PARAMETER, compressionLevel) + .endObject() + .endObject() + .endObject(); + String mapping = builder.toString(); + validateIndex(indexName, mapping); + logger.info("Compression level {}", compressionLevel); + // Run an exact (script_score) search to collect the reference scores for the documents + Response exactSearchResponse = searchKNNIndex( + indexName, + XContentFactory.jsonBuilder() + .startObject() + .startObject("query") + .startObject("script_score") + .startObject("query") + .field("match_all") + .startObject() + .endObject() + .endObject() + .startObject("script") + .field("source", "knn_score") + .field("lang", "knn") + .startObject("params") + .field("field", FIELD_NAME) + .field("query_value", TEST_VECTOR) + .field("space_type", SpaceType.L2.getValue()) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(), + K + ); + assertOK(exactSearchResponse); + String exactSearchResponseBody = EntityUtils.toString(exactSearchResponse.getEntity()); + List exactSearchKnnResults = parseSearchResponseScore(exactSearchResponseBody, FIELD_NAME); + assertEquals(NUM_DOCS, exactSearchKnnResults.size()); + // Search with rescore explicitly disabled; scores are expected to differ from the exact ones + Response response = searchKNNIndex( + indexName, + 
XContentFactory.jsonBuilder() + .startObject() + .startObject("query") + .startObject("knn") + .startObject(FIELD_NAME) + .field("vector", TEST_VECTOR) + .field("k", K) + .field(RescoreParser.RESCORE_PARAMETER, false) + .endObject() + .endObject() + .endObject() + .endObject(), + K + ); + assertOK(response); + String responseBody = EntityUtils.toString(response.getEntity()); + List knnResults = parseSearchResponseScore(responseBody, FIELD_NAME); + assertEquals(K, knnResults.size()); + Assert.assertNotEquals(exactSearchKnnResults, knnResults); + // Search with an explicit rescore oversample factor; results are expected to match the exact search + response = searchKNNIndex( + indexName, + XContentFactory.jsonBuilder() + .startObject() + .startObject("query") + .startObject("knn") + .startObject(FIELD_NAME) + .field("vector", TEST_VECTOR) + .field("k", K) + .startObject(RescoreParser.RESCORE_PARAMETER) + .field(RescoreParser.RESCORE_OVERSAMPLE_PARAMETER, 2.0f) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(), + K + ); + assertOK(response); + responseBody = EntityUtils.toString(response.getEntity()); + knnResults = parseSearchResponseScore(responseBody, FIELD_NAME); + assertEquals(K, knnResults.size()); + Assert.assertEquals(exactSearchKnnResults, knnResults); + // Search with no rescore clause (default rescore); results are expected to match the exact search + response = searchKNNIndex( + indexName, + XContentFactory.jsonBuilder() + .startObject() + .startObject("query") + .startObject("knn") + .startObject(FIELD_NAME) + .field("vector", TEST_VECTOR) + .field("k", K) + .endObject() + .endObject() + .endObject() + .endObject(), + K + ); + assertOK(response); + responseBody = EntityUtils.toString(response.getEntity()); + knnResults = parseSearchResponseScore(responseBody, FIELD_NAME); + assertEquals(K, knnResults.size()); + Assert.assertEquals(exactSearchKnnResults, knnResults); + } + @SneakyThrows public void testDeletedDocsWithSegmentMerge_whenValid_ThenSucceed() { XContentBuilder builder;