From 80c3e22ef995644dcce8480b82c1c54a5841c843 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Tue, 22 Oct 2024 14:43:20 +1300 Subject: [PATCH] Support Radial Search (#1166) (#1240) * Generate the ml namespace (#1158) * Generate ml.register_model_group * Start neural search sample * Re-generate ShardStatistics * Re-generate ShardFailure * Re-generate Result * Re-generate WriteResponseBase * Generate ml.delete_model_group * Generate ml.register_model * Exclude legacy license from ml namespace * Generate ml.get_task * Generate ml.delete_task * Generate ml.delete_model * Generate ml.deploy_model * Generate ml.undeploy_model * Complete neural search sample * Generate ml.get_model * Add changelog entry * note * Fix tests --------- * Fix copy-paste mistake in NeuralSearch sample (#1161) * Support Radial Search Add minScore, maxDistance parameters to KnnQuery in order to support Radial Search, which was introduced in OpenSearch 2.14 https://opensearch.org/docs/latest/search-plugins/knn/radial-search-knn/ * Update CHANGELOG.md * Update changelog post releasing 2.14.0 (#1162) (#1167) (cherry picked from commit 2a362a62455115ad6f47fe1790ddbddc0fe32eb5) * Reduce required release approvals (#1168) --------- (cherry picked from commit 3902aef3c6e7d02a0e46d2329a9f76f18a8f57cd) Signed-off-by: Thomas Farr Signed-off-by: Alex Keeler Signed-off-by: alex-keeler <59743435+alex-keeler@users.noreply.github.com> Signed-off-by: github-actions[bot] Co-authored-by: github-actions[bot] Co-authored-by: Thomas Farr --- CHANGELOG.md | 1 + .../opensearch/_types/query_dsl/KnnQuery.java | 73 +++++++++++++++++-- .../_types/query_dsl/KnnQueryTest.java | 2 +- .../client/opensearch/model/VariantsTest.java | 4 +- 4 files changed, 71 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 332a809294..78aea6681c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ## [Unreleased 2.x] ### Added +- Added `minScore` and `maxDistance` to `KnnQuery` ([#1166](https://github.com/opensearch-project/opensearch-java/pull/1166)) ### Dependencies diff --git a/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/KnnQuery.java b/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/KnnQuery.java index 596752f47c..59f03cc1ab 100644 --- a/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/KnnQuery.java +++ b/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/KnnQuery.java @@ -23,7 +23,12 @@ public class KnnQuery extends QueryBase implements QueryVariant { private final String field; private final float[] vector; - private final int k; + @Nullable + private final Integer k; + @Nullable + private final Float minScore; + @Nullable + private final Float maxDistance; @Nullable private final Query filter; @@ -32,7 +37,9 @@ private KnnQuery(Builder builder) { this.field = ApiTypeHelper.requireNonNull(builder.field, this, "field"); this.vector = ApiTypeHelper.requireNonNull(builder.vector, this, "vector"); - this.k = ApiTypeHelper.requireNonNull(builder.k, this, "k"); + this.k = builder.k; + this.minScore = builder.minScore; + this.maxDistance = builder.maxDistance; this.filter = builder.filter; } @@ -66,13 +73,29 @@ public final float[] vector() { } /** - * Required - The number of neighbors the search of each graph will return. + * Optional - The number of neighbors the search of each graph will return. * @return The number of neighbors to return. */ - public final int k() { + public final Integer k() { return this.k; } + /** + * Optional - The minimum score allowed for the returned search results. + * @return The minimum score allowed for the returned search results. + */ + private final Float minScore() { + return this.minScore; + } + + /** + * Optional - The maximum distance allowed between the vector and each of the returned search results. + * @return The maximum distance allowed between the vector and each ofthe returned search results. + */ + private final Float maxDistance() { + return this.maxDistance; + } + /** * Optional - A query to filter the results of the query. * @return The filter query. @@ -97,7 +120,17 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) { } generator.writeEnd(); - generator.write("k", this.k); + if (this.k != null) { + generator.write("k", this.k); + } + + if (this.minScore != null) { + generator.write("min_score", this.minScore); + } + + if (this.maxDistance != null) { + generator.write("max_distance", this.maxDistance); + } if (this.filter != null) { generator.writeKey("filter"); @@ -108,7 +141,7 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) { } public Builder toBuilder() { - return toBuilder(new Builder()).field(field).vector(vector).k(k).filter(filter); + return toBuilder(new Builder()).field(field).vector(vector).k(k).minScore(minScore).maxDistance(maxDistance).filter(filter); } /** @@ -122,6 +155,10 @@ public static class Builder extends QueryBase.AbstractBuilder implement @Nullable private Integer k; @Nullable + private Float minScore; + @Nullable + private Float maxDistance; + @Nullable private Query filter; /** @@ -156,6 +193,28 @@ public Builder k(@Nullable Integer k) { return this; } + /** + * Optional - The minimum score allowed for the returned search results. + * + * @param minScore The minimum score allowed for the returned search results. + * @return This builder. + */ + public Builder minScore(@Nullable Float minScore) { + this.minScore = minScore; + return this; + } + + /** + * Optional - The maximum distance allowed between the vector and each of the returned search results. + * + * @param maxDistance The maximum distance allowed between the vector and each ofthe returned search results. + * @return This builder. + */ + public Builder maxDistance(@Nullable Float maxDistance) { + this.maxDistance = maxDistance; + return this; + } + /** * Optional - A query to filter the results of the knn query. * @@ -201,6 +260,8 @@ protected static void setupKnnQueryDeserializer(ObjectDeserializer op) b.vector(vector); }, JsonpDeserializer.arrayDeserializer(JsonpDeserializer.floatDeserializer()), "vector"); op.add(Builder::k, JsonpDeserializer.integerDeserializer(), "k"); + op.add(Builder::minScore, JsonpDeserializer.floatDeserializer(), "min_score"); + op.add(Builder::maxDistance, JsonpDeserializer.floatDeserializer(), "max_distance"); op.add(Builder::filter, Query._DESERIALIZER, "filter"); op.setKey(Builder::field, JsonpDeserializer.stringDeserializer()); diff --git a/java-client/src/test/java/org/opensearch/client/opensearch/_types/query_dsl/KnnQueryTest.java b/java-client/src/test/java/org/opensearch/client/opensearch/_types/query_dsl/KnnQueryTest.java index a8a3fd779b..941f5224d7 100644 --- a/java-client/src/test/java/org/opensearch/client/opensearch/_types/query_dsl/KnnQueryTest.java +++ b/java-client/src/test/java/org/opensearch/client/opensearch/_types/query_dsl/KnnQueryTest.java @@ -14,7 +14,7 @@ public class KnnQueryTest extends ModelTestCase { @Test public void toBuilder() { - KnnQuery origin = new KnnQuery.Builder().field("field").vector(new float[] { 1.0f }).k(1).build(); + KnnQuery origin = new KnnQuery.Builder().field("field").vector(new float[] { 1.0f }).k(1).minScore(0.0f).maxDistance(1.0f).build(); KnnQuery copied = origin.toBuilder().build(); assertEquals(toJson(copied), toJson(origin)); diff --git a/java-client/src/test/java/org/opensearch/client/opensearch/model/VariantsTest.java b/java-client/src/test/java/org/opensearch/client/opensearch/model/VariantsTest.java index f647ae56f0..2251ff4c4a 100644 --- a/java-client/src/test/java/org/opensearch/client/opensearch/model/VariantsTest.java +++ b/java-client/src/test/java/org/opensearch/client/opensearch/model/VariantsTest.java @@ -282,7 +282,7 @@ public void testHybridQuery() { assertEquals(100, searchRequest.query().hybrid().queries().get(1).neural().k()); assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(2).knn().field()); assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().vector().length); - assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().k()); + assertEquals(Integer.valueOf(2), searchRequest.query().hybrid().queries().get(2).knn().k()); } @Test @@ -304,6 +304,6 @@ public void testHybridQueryFromJson() { assertEquals(100, searchRequest.query().hybrid().queries().get(1).neural().k()); assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(2).knn().field()); assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().vector().length); - assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().k()); + assertEquals(Integer.valueOf(2), searchRequest.query().hybrid().queries().get(2).knn().k()); } }