Skip to content

Commit

Permalink
Fixes #1137 - Adding query_image to Neural query (#1138)
Browse files Browse the repository at this point in the history
* Fixes #1137

Signed-off-by: uri.nudelman <[email protected]>

* Adds missing documentation. Added changelog

Signed-off-by: uri.nudelman <[email protected]>

* Added deserialization test

Signed-off-by: uri.nudelman <[email protected]>

---------

Signed-off-by: uri.nudelman <[email protected]>
Co-authored-by: uri.nudelman <[email protected]>
(cherry picked from commit 7b3719b)
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
github-actions[bot] and uriofferup committed Aug 15, 2024
1 parent 5b632ff commit c433a8d
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)

## [Unreleased 2.x]
### Added
- Adds `queryImage` (query_image) field to `NeuralQuery`, following definition in ([Neural Query](https://opensearch.org/docs/latest/query-dsl/specialized/neural/)) ([#1137](https://github.com/opensearch-project/opensearch-java/pull/1138))

### Dependencies

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
import org.opensearch.client.json.ObjectBuilderDeserializer;
import org.opensearch.client.json.ObjectDeserializer;
import org.opensearch.client.util.ApiTypeHelper;
import org.opensearch.client.util.MissingRequiredPropertiesException;
import org.opensearch.client.util.ObjectBuilder;

@JsonpDeserializable
public class NeuralQuery extends QueryBase implements QueryVariant {

private final String field;
private final String queryText;
private final String queryImage;
private final int k;
@Nullable
private final String modelId;
Expand All @@ -34,7 +36,11 @@ private NeuralQuery(NeuralQuery.Builder builder) {
super(builder);

this.field = ApiTypeHelper.requireNonNull(builder.field, this, "field");
this.queryText = ApiTypeHelper.requireNonNull(builder.queryText, this, "queryText");
if (builder.queryText == null && builder.queryImage == null && !ApiTypeHelper.requiredPropertiesCheckDisabled()) {
throw new MissingRequiredPropertiesException(this, "queryText", "queryImage");
}
this.queryText = builder.queryText;
this.queryImage = builder.queryImage;
this.k = ApiTypeHelper.requireNonNull(builder.k, this, "k");
this.modelId = builder.modelId;
this.filter = builder.filter;
Expand Down Expand Up @@ -64,14 +70,25 @@ public final String field() {
}

/**
* Required - Search query text.
* Required - The query_text if query_image is not set.
* Optional - The query_text if query_image is set.
*
* @return Search query text.
*/
public final String queryText() {
return this.queryText;
}

/**
* Required - The query_image if query_text is not set.
* Optional - The query_image if query_text is set.
*
* @return Search query image.
*/
public final String queryImage() {
return this.queryImage;
}

/**
* Required - The number of neighbors to return.
*
Expand Down Expand Up @@ -112,7 +129,13 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {

super.serializeInternal(generator, mapper);

generator.write("query_text", this.queryText);
if (this.queryText != null) {
generator.write("query_text", this.queryText);
}

if (this.queryImage != null) {
generator.write("query_image", this.queryImage);
}

if (this.modelId != null) {
generator.write("model_id", this.modelId);
Expand All @@ -129,7 +152,7 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
}

public Builder toBuilder() {
return new Builder().field(field).queryText(queryText).k(k).modelId(modelId).filter(filter);
return new Builder().field(field).queryText(queryText).queryImage(queryImage).k(k).modelId(modelId).filter(filter);
}

/**
Expand All @@ -138,6 +161,7 @@ public Builder toBuilder() {
public static class Builder extends QueryBase.AbstractBuilder<NeuralQuery.Builder> implements ObjectBuilder<NeuralQuery> {
private String field;
private String queryText;
private String queryImage;
private Integer k;
@Nullable
private String modelId;
Expand All @@ -156,7 +180,8 @@ public NeuralQuery.Builder field(@Nullable String field) {
}

/**
* Required - Search query text.
* Required - The query_text if query_image is not set.
* Optional - The query_text if query_image is set.
*
* @param queryText Search query text.
* @return This builder.
Expand All @@ -166,6 +191,18 @@ public NeuralQuery.Builder queryText(@Nullable String queryText) {
return this;
}

/**
* Required - The query_image if query_text is not set.
* Optional - The query_image if query_text is set.
*
* @param queryImage Search query image.
* @return This builder.
*/
public NeuralQuery.Builder queryImage(@Nullable String queryImage) {
this.queryImage = queryImage;
return this;
}

/**
* Optional - The model_id field if the default model for the index or field is set.
* Required - The model_id field if there is no default model set for the index or field.
Expand Down Expand Up @@ -227,6 +264,7 @@ protected static void setupNeuralQueryDeserializer(ObjectDeserializer<NeuralQuer
setupQueryBaseDeserializer(op);

op.add(NeuralQuery.Builder::queryText, JsonpDeserializer.stringDeserializer(), "query_text");
op.add(NeuralQuery.Builder::queryImage, JsonpDeserializer.stringDeserializer(), "query_image");
op.add(NeuralQuery.Builder::modelId, JsonpDeserializer.stringDeserializer(), "model_id");
op.add(NeuralQuery.Builder::k, JsonpDeserializer.integerDeserializer(), "k");
op.add(NeuralQuery.Builder::filter, Query._DESERIALIZER, "filter");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.client.util;

import java.util.StringJoiner;

/**
* Thrown by {@link ObjectBuilder#build()} when one of the required properties is missing.
* <p>
* If you think this is an error and that the reported property is actually optional, a workaround is
* available in {@link ApiTypeHelper} to disable checks. Use with caution.
*/
public class MissingRequiredPropertiesException extends RuntimeException {
private Class<?> clazz;
private String[] properties;

public MissingRequiredPropertiesException(Object obj, String... properties) {
super(
"Missing at least one required property between "
+ buildPropertiesMsg(properties)
+ " in '"
+ obj.getClass().getSimpleName()
+ "'"
);
this.clazz = obj.getClass();
this.properties = properties;
}

/**
* The class where the missing property was found
*/
public Class<?> getObjectClass() {
return clazz;
}

public String[] getPropertiesName() {
return properties;
}

private static String buildPropertiesMsg(String[] properties) {
final StringJoiner sj = new StringJoiner(",", "'", "'");
for (final String property : properties) {
sj.add(property);
}
return sj.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@

import org.junit.Test;
import org.opensearch.client.opensearch.model.ModelTestCase;
import org.opensearch.client.util.MissingRequiredPropertiesException;

public class NeuralQueryTest extends ModelTestCase {
@Test
public void toBuilder() {
public void toBuilder_queryText() {
NeuralQuery origin = new NeuralQuery.Builder().field("field")
.queryText("queryText")
.k(1)
Expand All @@ -23,4 +24,37 @@ public void toBuilder() {

assertEquals(toJson(copied), toJson(origin));
}

@Test
public void toBuilder_queryImage() {
NeuralQuery origin = new NeuralQuery.Builder().field("field")
.queryImage("queryImage")
.k(1)
.filter(IdsQuery.of(builder -> builder.values("Some_ID")).toQuery())
.build();
NeuralQuery copied = origin.toBuilder().build();

assertEquals(toJson(copied), toJson(origin));
}

@Test
public void toBuilder_both() {
NeuralQuery origin = new NeuralQuery.Builder().field("field")
.queryText("queryText")
.queryImage("queryImage")
.k(1)
.filter(IdsQuery.of(builder -> builder.values("Some_ID")).toQuery())
.build();
NeuralQuery copied = origin.toBuilder().build();

assertEquals(toJson(copied), toJson(origin));
}

@Test
public void toBuilder_missing_query() {
assertThrows(
MissingRequiredPropertiesException.class,
() -> new NeuralQuery.Builder().field("field").k(1).filter(IdsQuery.of(builder -> builder.values("Some_ID")).toQuery()).build()
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ public void testNeuralQueryFromJson() {
+ " \"neural\": {\n"
+ " \"passage_embedding\": {\n"
+ " \"query_text\": \"Hi world!\",\n"
+ " \"query_image\": \"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAAAXNSR0IArs4c6QAAAA1JREFUGFdj+L+U4T8ABu8CpCYJ1DQAAAAASUVORK5CYII=\",\n"
+ " \"model_id\": \"bQ1J8ooBpBj3wT4HVUsb\",\n"
+ " \"k\": 100\n"
+ " }\n"
Expand All @@ -245,6 +246,10 @@ public void testNeuralQueryFromJson() {

assertEquals("passage_embedding", searchRequest.query().neural().field());
assertEquals("Hi world!", searchRequest.query().neural().queryText());
assertEquals(
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAAAXNSR0IArs4c6QAAAA1JREFUGFdj+L+U4T8ABu8CpCYJ1DQAAAAASUVORK5CYII=",
searchRequest.query().neural().queryImage()
);
assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().neural().modelId());
assertEquals(100, searchRequest.query().neural().k());
}
Expand Down

0 comments on commit c433a8d

Please sign in to comment.