Skip to content

Commit

Permalink
Fix setting rescore as false in on_disk knn_vector query (#2420)
Browse files Browse the repository at this point in the history
* Fix setting rescore as false in on_disk knn_vector query




* Update CHANGELOG.md




* Reapply Spotless Java




* Add IT for testing rescore enabled and disabled




---------



(cherry picked from commit 387344e)

Signed-off-by: Ethan Emoto <[email protected]>
  • Loading branch information
e-emoto authored Jan 24, 2025
1 parent 8ec1c7e commit 780bf49
Show file tree
Hide file tree
Showing 7 changed files with 228 additions and 11 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Fix shard level rescoring disabled setting flag (#2352)[https://github.com/opensearch-project/k-NN/pull/2352]
* Fix filter rewrite logic which was resulting in getting inconsistent / incorrect results for cases where filter was getting rewritten for shards (#2359)[https://github.com/opensearch-project/k-NN/pull/2359]
* Fixing it to retrieve space_type from index setting when both method and top level don't have the value. [#2374](https://github.com/opensearch-project/k-NN/pull/2374)
* Fixing the bug where setting rescore as false for on_disk knn_vector query is a no-op (#2399)[https://github.com/opensearch-project/k-NN/pull/2399]
### Infrastructure
* Updated C++ version in JNI from c++11 to c++17 [#2259](https://github.com/opensearch-project/k-NN/pull/2259)
* Upgrade bytebuddy and objenesis version to match OpenSearch core and, update github ci runner for macos [#2279](https://github.com/opensearch-project/k-NN/pull/2279)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ public enum CompressionLevel {
x1(1, "1x", null, Collections.emptySet()),
x2(2, "2x", null, Collections.emptySet()),
x4(4, "4x", null, Collections.emptySet()),
x8(8, "8x", new RescoreContext(2.0f, false), Set.of(Mode.ON_DISK)),
x16(16, "16x", new RescoreContext(3.0f, false), Set.of(Mode.ON_DISK)),
x32(32, "32x", new RescoreContext(3.0f, false), Set.of(Mode.ON_DISK)),
x64(64, "64x", new RescoreContext(5.0f, false), Set.of(Mode.ON_DISK));
x8(8, "8x", new RescoreContext(2.0f, false, false), Set.of(Mode.ON_DISK)),
x16(16, "16x", new RescoreContext(3.0f, false, false), Set.of(Mode.ON_DISK)),
x32(32, "32x", new RescoreContext(3.0f, false, false), Set.of(Mode.ON_DISK)),
x64(64, "64x", new RescoreContext(5.0f, false, false), Set.of(Mode.ON_DISK));

public static final CompressionLevel MAX_COMPRESSION_LEVEL = CompressionLevel.x64;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo
List<PerLeafResult> perLeafResults;
RescoreContext rescoreContext = knnQuery.getRescoreContext();
final int finalK = knnQuery.getK();
if (rescoreContext == null) {
if (rescoreContext == null || !rescoreContext.isRescoreEnabled()) {
perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, finalK);
} else {
boolean isShardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;

import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD;
import static org.opensearch.index.query.AbstractQueryBuilder.NAME_FIELD;
Expand All @@ -34,6 +37,7 @@
import static org.opensearch.knn.index.query.KNNQueryBuilder.EXPAND_NESTED_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.RESCORE_FIELD;
import static org.opensearch.knn.index.query.parser.RescoreParser.RESCORE_PARAMETER;
import static org.opensearch.knn.index.query.rescore.RescoreContext.EXPLICITLY_DISABLED_RESCORE_CONTEXT;
import static org.opensearch.knn.index.util.IndexUtil.isClusterOnOrAfterMinRequiredVersion;
import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.IGNORE_UNMAPPED_FIELD;
Expand Down Expand Up @@ -84,12 +88,22 @@ private static ObjectParser<KNNQueryBuilder.Builder, Void> createInternalObjectP
);
internalParser.declareObject(KNNQueryBuilder.Builder::filter, (p, v) -> parseInnerQueryBuilder(p), FILTER_FIELD);

internalParser.declareObjectOrDefault(
KNNQueryBuilder.Builder::rescoreContext,
(p, v) -> RescoreParser.fromXContent(p),
RescoreContext::getDefault,
RESCORE_FIELD
);
internalParser.declareField((p, v, c) -> {
BiConsumer<KNNQueryBuilder.Builder, RescoreContext> consumer = KNNQueryBuilder.Builder::rescoreContext;
BiFunction<XContentParser, Void, RescoreContext> objectParser = (_p, _v) -> RescoreParser.fromXContent(_p);
Supplier<RescoreContext> defaultValue = RescoreContext::getDefault;
if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) {
if (p.booleanValue()) {
consumer.accept(v, defaultValue.get());
} else {
// If the user specifies false, we explicitly set to null so we don't
// accidentally resolve.
consumer.accept(v, EXPLICITLY_DISABLED_RESCORE_CONTEXT);
}
} else {
consumer.accept(v, objectParser.apply(p, c));
}
}, RESCORE_FIELD, ObjectParser.ValueType.OBJECT_OR_BOOLEAN);

internalParser.declareBoolean(KNNQueryBuilder.Builder::expandNested, EXPAND_NESTED_FIELD);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,17 @@ public final class RescoreContext {
@Builder.Default
private boolean userProvided = true;

/**
* Flag to track whether rescoring has been disabled by the query parameters.
*/
@Builder.Default
private boolean rescoreEnabled = true;

public static final RescoreContext EXPLICITLY_DISABLED_RESCORE_CONTEXT = RescoreContext.builder()
.oversampleFactor(DEFAULT_OVERSAMPLE_FACTOR)
.rescoreEnabled(false)
.build();

/**
*
* @return default RescoreContext
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,78 @@ public void testFromXContent_missingQueryVector() throws Exception {
assertTrue(exception.getMessage(), exception.getMessage().contains("[knn] failed to parse field [vector]"));
}

public void testFromXContent_rescoreEnabled() throws Exception {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
RescoreContext explicitRescoreContext = RescoreContext.builder().oversampleFactor(1.5f).build();
// Test with default rescore
KNNQueryBuilder knnQueryBuilderDefaultRescore = KNNQueryBuilder.builder()
.fieldName(FIELD_NAME)
.vector(queryVector)
.k(K)
.rescoreContext(RescoreContext.getDefault())
.build();
XContentBuilder builderDefaultRescore = XContentFactory.jsonBuilder();
builderDefaultRescore.startObject();
builderDefaultRescore.startObject(knnQueryBuilderDefaultRescore.fieldName());
builderDefaultRescore.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilderDefaultRescore.vector());
builderDefaultRescore.field(KNNQueryBuilder.K_FIELD.getPreferredName(), knnQueryBuilderDefaultRescore.getK());
builderDefaultRescore.field(KNNQueryBuilder.RESCORE_FIELD.getPreferredName(), true);
builderDefaultRescore.endObject();
builderDefaultRescore.endObject();
XContentParser contentParserDefaultRescore = createParser(builderDefaultRescore);
contentParserDefaultRescore.nextToken();
KNNQueryBuilder actualBuilderDefaultRescore = KNNQueryBuilderParser.fromXContent(contentParserDefaultRescore);
assertEquals(knnQueryBuilderDefaultRescore, actualBuilderDefaultRescore);

// Test with explicit rescore
KNNQueryBuilder knnQueryBuilderExplicitRescore = KNNQueryBuilder.builder()
.fieldName(FIELD_NAME)
.vector(queryVector)
.k(K)
.rescoreContext(explicitRescoreContext)
.build();
XContentBuilder builderExplicitRescore = XContentFactory.jsonBuilder();
builderExplicitRescore.startObject();
builderExplicitRescore.startObject(knnQueryBuilderExplicitRescore.fieldName());
builderExplicitRescore.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilderExplicitRescore.vector());
builderExplicitRescore.field(KNNQueryBuilder.K_FIELD.getPreferredName(), knnQueryBuilderExplicitRescore.getK());
builderExplicitRescore.startObject(KNNQueryBuilder.RESCORE_FIELD.getPreferredName());
builderExplicitRescore.field(
KNNQueryBuilder.RESCORE_OVERSAMPLE_FIELD.getPreferredName(),
explicitRescoreContext.getOversampleFactor()
);
builderExplicitRescore.endObject();
builderExplicitRescore.endObject();
builderExplicitRescore.endObject();
XContentParser contentParserExplicitRescore = createParser(builderExplicitRescore);
contentParserExplicitRescore.nextToken();
KNNQueryBuilder actualBuilderExplicitRescore = KNNQueryBuilderParser.fromXContent(contentParserExplicitRescore);
assertEquals(knnQueryBuilderExplicitRescore, actualBuilderExplicitRescore);
}

public void testFromXContent_rescoreDisabled() throws Exception {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
// Test with rescore disabled
KNNQueryBuilder knnQueryBuilderRescoreDisabled = KNNQueryBuilder.builder()
.fieldName(FIELD_NAME)
.vector(queryVector)
.k(K)
.rescoreContext(RescoreContext.EXPLICITLY_DISABLED_RESCORE_CONTEXT)
.build();
XContentBuilder builderRescoreDisabled = XContentFactory.jsonBuilder();
builderRescoreDisabled.startObject();
builderRescoreDisabled.startObject(knnQueryBuilderRescoreDisabled.fieldName());
builderRescoreDisabled.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilderRescoreDisabled.vector());
builderRescoreDisabled.field(KNNQueryBuilder.K_FIELD.getPreferredName(), knnQueryBuilderRescoreDisabled.getK());
builderRescoreDisabled.field(KNNQueryBuilder.RESCORE_FIELD.getPreferredName(), false);
builderRescoreDisabled.endObject();
builderRescoreDisabled.endObject();
XContentParser contentParserRescoreDisabled = createParser(builderRescoreDisabled);
contentParserRescoreDisabled.nextToken();
KNNQueryBuilder actualBuilderRescoreDisabled = KNNQueryBuilderParser.fromXContent(contentParserRescoreDisabled);
assertEquals(knnQueryBuilderRescoreDisabled, actualBuilderRescoreDisabled);
}

public void testFromXContent_whenFlat_thenException() throws Exception {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
XContentBuilder builder = XContentFactory.jsonBuilder();
Expand Down
119 changes: 119 additions & 0 deletions src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,125 @@ public void testIndexCreation_whenValid_ThenSucceed() {
}
}

@SneakyThrows
public void testQueryRescoreEnabledAndDisabled() {
XContentBuilder builder;
String mode = Mode.ON_DISK.getName();
String compressionLevel = CompressionLevel.x32.getName();
String indexName = INDEX_NAME + compressionLevel;
builder = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject(FIELD_NAME)
.field("type", "knn_vector")
.field("dimension", DIMENSION)
.field(MODE_PARAMETER, mode)
.field(COMPRESSION_LEVEL_PARAMETER, compressionLevel)
.endObject()
.endObject()
.endObject();
String mapping = builder.toString();
validateIndex(indexName, mapping);
logger.info("Compression level {}", compressionLevel);
// Do exact search and gather right scores for the documents
Response exactSearchResponse = searchKNNIndex(
indexName,
XContentFactory.jsonBuilder()
.startObject()
.startObject("query")
.startObject("script_score")
.startObject("query")
.field("match_all")
.startObject()
.endObject()
.endObject()
.startObject("script")
.field("source", "knn_score")
.field("lang", "knn")
.startObject("params")
.field("field", FIELD_NAME)
.field("query_value", TEST_VECTOR)
.field("space_type", SpaceType.L2.getValue())
.endObject()
.endObject()
.endObject()
.endObject()
.endObject(),
K
);
assertOK(exactSearchResponse);
String exactSearchResponseBody = EntityUtils.toString(exactSearchResponse.getEntity());
List<Float> exactSearchKnnResults = parseSearchResponseScore(exactSearchResponseBody, FIELD_NAME);
assertEquals(NUM_DOCS, exactSearchKnnResults.size());
// Search without rescore
Response response = searchKNNIndex(
indexName,
XContentFactory.jsonBuilder()
.startObject()
.startObject("query")
.startObject("knn")
.startObject(FIELD_NAME)
.field("vector", TEST_VECTOR)
.field("k", K)
.field(RescoreParser.RESCORE_PARAMETER, false)
.endObject()
.endObject()
.endObject()
.endObject(),
K
);
assertOK(response);
String responseBody = EntityUtils.toString(response.getEntity());
List<Float> knnResults = parseSearchResponseScore(responseBody, FIELD_NAME);
assertEquals(K, knnResults.size());
Assert.assertNotEquals(exactSearchKnnResults, knnResults);
// Search with explicit rescore
response = searchKNNIndex(
indexName,
XContentFactory.jsonBuilder()
.startObject()
.startObject("query")
.startObject("knn")
.startObject(FIELD_NAME)
.field("vector", TEST_VECTOR)
.field("k", K)
.startObject(RescoreParser.RESCORE_PARAMETER)
.field(RescoreParser.RESCORE_OVERSAMPLE_PARAMETER, 2.0f)
.endObject()
.endObject()
.endObject()
.endObject()
.endObject(),
K
);
assertOK(response);
responseBody = EntityUtils.toString(response.getEntity());
knnResults = parseSearchResponseScore(responseBody, FIELD_NAME);
assertEquals(K, knnResults.size());
Assert.assertEquals(exactSearchKnnResults, knnResults);
// Search with default rescore
response = searchKNNIndex(
indexName,
XContentFactory.jsonBuilder()
.startObject()
.startObject("query")
.startObject("knn")
.startObject(FIELD_NAME)
.field("vector", TEST_VECTOR)
.field("k", K)
.endObject()
.endObject()
.endObject()
.endObject(),
K
);
assertOK(response);
responseBody = EntityUtils.toString(response.getEntity());
knnResults = parseSearchResponseScore(responseBody, FIELD_NAME);
assertEquals(K, knnResults.size());
Assert.assertEquals(exactSearchKnnResults, knnResults);
}

@SneakyThrows
public void testDeletedDocsWithSegmentMerge_whenValid_ThenSucceed() {
XContentBuilder builder;
Expand Down

0 comments on commit 780bf49

Please sign in to comment.