From 5b90848472bb7511b9c6c9afec12c00bc05da3bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Vil=C3=A1?= Date: Mon, 3 Jun 2024 13:00:43 -0500 Subject: [PATCH] Add text embedding processor (#1007) * Add text embedding processor Signed-off-by: miguel-vila * add changelog entry Signed-off-by: miguel-vila * add (de)serialization test Signed-off-by: miguel-vila * fix files headers Signed-off-by: miguel-vila * fix for java 8 Signed-off-by: miguel-vila --------- Signed-off-by: miguel-vila --- CHANGELOG.md | 2 + .../client/opensearch/ingest/Processor.java | 30 +++ .../opensearch/ingest/ProcessorBuilders.java | 8 + .../ingest/TextEmbeddingProcessor.java | 193 ++++++++++++++++++ .../ingest/TextEmbeddingProcessorTest.java | 59 ++++++ 5 files changed, 292 insertions(+) create mode 100644 java-client/src/main/java/org/opensearch/client/opensearch/ingest/TextEmbeddingProcessor.java create mode 100644 java-client/src/test/java/org/opensearch/client/opensearch/ingest/TextEmbeddingProcessorTest.java diff --git a/CHANGELOG.md b/CHANGELOG.md index b487551405..731794785f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,8 @@ This section is for maintaining a changelog for all breaking changes for the cli ### Added +- Added support for [text embedding processor](https://opensearch.org/docs/latest/ingest-pipelines/processors/text-embedding/) ([#1007](https://github.com/opensearch-project/opensearch-java/pull/1007)) + ### Dependencies ### Changed diff --git a/java-client/src/main/java/org/opensearch/client/opensearch/ingest/Processor.java b/java-client/src/main/java/org/opensearch/client/opensearch/ingest/Processor.java index c315527ee8..3bc8c0afea 100644 --- a/java-client/src/main/java/org/opensearch/client/opensearch/ingest/Processor.java +++ b/java-client/src/main/java/org/opensearch/client/opensearch/ingest/Processor.java @@ -127,6 +127,8 @@ public enum Kind implements JsonEnum { Inference("inference"), + TextEmbedding("text_embedding"), + ; private final String jsonValue; @@ -735,6 +737,23 @@ public InferenceProcessor inference() { return TaggedUnionUtils.get(this, Kind.Inference); } + /** + * Is this variant instance of kind {@code text_embedding}? + */ + public boolean isTextEmbedding() { + return _kind == Kind.TextEmbedding; + } + + /** + * Get the {@code text_embedding} variant value. + * + * @throws IllegalStateException + * if the current variant is not of the {@code text_embedding} kind. + */ + public TextEmbeddingProcessor textEmbedding() { + return TaggedUnionUtils.get(this, Kind.TextEmbedding); + } + @Override @SuppressWarnings("unchecked") public void serialize(JsonGenerator generator, JsonpMapper mapper) { @@ -1086,6 +1105,16 @@ public ObjectBuilder inference(Function textEmbedding(TextEmbeddingProcessor v) { + this._kind = Kind.TextEmbedding; + this._value = v; + return this; + } + + public ObjectBuilder textEmbedding(Function> fn) { + return this.textEmbedding(fn.apply(new TextEmbeddingProcessor.Builder()).build()); + } + public Processor build() { _checkSingleUse(); return new Processor(this); @@ -1128,6 +1157,7 @@ protected static void setupProcessorDeserializer(ObjectDeserializer op) op.add(Builder::drop, DropProcessor._DESERIALIZER, "drop"); op.add(Builder::circle, CircleProcessor._DESERIALIZER, "circle"); op.add(Builder::inference, InferenceProcessor._DESERIALIZER, "inference"); + op.add(Builder::textEmbedding, TextEmbeddingProcessor._DESERIALIZER, "text_embedding"); } diff --git a/java-client/src/main/java/org/opensearch/client/opensearch/ingest/ProcessorBuilders.java b/java-client/src/main/java/org/opensearch/client/opensearch/ingest/ProcessorBuilders.java index 85f95c4772..921ce75f6a 100644 --- a/java-client/src/main/java/org/opensearch/client/opensearch/ingest/ProcessorBuilders.java +++ b/java-client/src/main/java/org/opensearch/client/opensearch/ingest/ProcessorBuilders.java @@ -301,4 +301,12 @@ public static InferenceProcessor.Builder inference() { return new InferenceProcessor.Builder(); } + /** + * Creates a builder for the {@link TextEmbeddingProcessor text_embedding} + * {@code Processor} variant. + */ + public static TextEmbeddingProcessor.Builder textEmbedding() { + return new TextEmbeddingProcessor.Builder(); + } + } diff --git a/java-client/src/main/java/org/opensearch/client/opensearch/ingest/TextEmbeddingProcessor.java b/java-client/src/main/java/org/opensearch/client/opensearch/ingest/TextEmbeddingProcessor.java new file mode 100644 index 0000000000..214500dbb5 --- /dev/null +++ b/java-client/src/main/java/org/opensearch/client/opensearch/ingest/TextEmbeddingProcessor.java @@ -0,0 +1,193 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.client.opensearch.ingest; + +import jakarta.json.stream.JsonGenerator; +import java.util.Map; +import java.util.function.Function; +import javax.annotation.Nullable; +import org.opensearch.client.json.JsonpDeserializable; +import org.opensearch.client.json.JsonpDeserializer; +import org.opensearch.client.json.JsonpMapper; +import org.opensearch.client.json.ObjectBuilderDeserializer; +import org.opensearch.client.json.ObjectDeserializer; +import org.opensearch.client.util.ApiTypeHelper; +import org.opensearch.client.util.ObjectBuilder; + +// typedef: ingest._types.TextEmbeddingProcessor + +@JsonpDeserializable +public class TextEmbeddingProcessor extends ProcessorBase implements ProcessorVariant { + private final String modelId; + + private final Map fieldMap; + + @Nullable + private final String description; + + // --------------------------------------------------------------------------------------------- + + private TextEmbeddingProcessor(Builder builder) { + super(builder); + + this.modelId = ApiTypeHelper.requireNonNull(builder.modelId, this, "modelId"); + this.fieldMap = ApiTypeHelper.unmodifiableRequired(builder.fieldMap, this, "fieldMap"); + this.description = builder.description; + + } + + public static TextEmbeddingProcessor of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + /** + * Processor variant kind. + */ + @Override + public Processor.Kind _processorKind() { + return Processor.Kind.Inference; + } + + /** + * Required - API name: {@code model_id} + */ + public final String modelId() { + return this.modelId; + } + + /** + * API name: {@code field_map} + */ + public final Map fieldMap() { + return this.fieldMap; + } + + /** + * API name: {@code description} + */ + @Nullable + public final String description() { + return this.description; + } + + protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) { + + super.serializeInternal(generator, mapper); + generator.writeKey("model_id"); + generator.write(this.modelId); + + if (ApiTypeHelper.isDefined(this.fieldMap)) { + generator.writeKey("field_map"); + generator.writeStartObject(); + for (Map.Entry item0 : this.fieldMap.entrySet()) { + generator.writeKey(item0.getKey()); + generator.write(item0.getValue()); + } + generator.writeEnd(); + + } + if (this.description != null) { + generator.writeKey("description"); + generator.write(this.description); + } + + } + + // --------------------------------------------------------------------------------------------- + + /** + * Builder for {@link TextEmbeddingProcessor}. + */ + + public static class Builder extends ProcessorBase.AbstractBuilder implements ObjectBuilder { + private String modelId; + + @Nullable + private Map fieldMap; + + @Nullable + private String description; + + /** + * Required - API name: {@code model_id} + */ + public final Builder modelId(String value) { + this.modelId = value; + return this; + } + + /** + * API name: {@code field_map} + *

+ * Adds all entries of map to fieldMap. + */ + public final Builder fieldMap(Map map) { + this.fieldMap = _mapPutAll(this.fieldMap, map); + return this; + } + + /** + * API name: {@code field_map} + *

+ * Adds an entry to fieldMap. + */ + public final Builder fieldMap(String key, String value) { + this.fieldMap = _mapPut(this.fieldMap, key, value); + return this; + } + + /** + * API name: {@code description} + */ + public final Builder description(@Nullable String value) { + this.description = value; + return this; + } + + @Override + protected Builder self() { + return this; + } + + /** + * Builds a {@link TextEmbeddingProcessor}. + * + * @throws NullPointerException + * if some of the required fields are null. + */ + public TextEmbeddingProcessor build() { + _checkSingleUse(); + + return new TextEmbeddingProcessor(this); + } + } + + // --------------------------------------------------------------------------------------------- + + /** + * Json deserializer for {@link TextEmbeddingProcessor} + */ + public static final JsonpDeserializer _DESERIALIZER = ObjectBuilderDeserializer.lazy( + Builder::new, + TextEmbeddingProcessor::setupTextEmbeddingProcessorDeserializer + ); + + protected static void setupTextEmbeddingProcessorDeserializer(ObjectDeserializer op) { + ProcessorBase.setupProcessorBaseDeserializer(op); + op.add(Builder::modelId, JsonpDeserializer.stringDeserializer(), "model_id"); + op.add(Builder::fieldMap, JsonpDeserializer.stringMapDeserializer(JsonpDeserializer.stringDeserializer()), "field_map"); + op.add(Builder::description, JsonpDeserializer.stringDeserializer(), "description"); + } + +} diff --git a/java-client/src/test/java/org/opensearch/client/opensearch/ingest/TextEmbeddingProcessorTest.java b/java-client/src/test/java/org/opensearch/client/opensearch/ingest/TextEmbeddingProcessorTest.java new file mode 100644 index 0000000000..088fc59f82 --- /dev/null +++ b/java-client/src/test/java/org/opensearch/client/opensearch/ingest/TextEmbeddingProcessorTest.java @@ -0,0 +1,59 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.client.opensearch.ingest; + +import java.util.HashMap; +import java.util.Map; +import org.junit.Test; +import org.opensearch.client.opensearch.model.ModelTestCase; + +public class TextEmbeddingProcessorTest extends ModelTestCase { + + private static final Map baseFieldMap = new HashMap<>(); + static { + baseFieldMap.put("input_field", "vector_field"); + } + + private static TextEmbeddingProcessor.Builder baseTextEmbeddingProcessor() { + return new TextEmbeddingProcessor.Builder().modelId("modelId").fieldMap(baseFieldMap).tag("some-tag"); + } + + @Test + public void testJsonRoundtripWithDescription() { + Processor processor = new Processor.Builder().textEmbedding( + baseTextEmbeddingProcessor().description("processor-description").build() + ).build(); + String json = + "{\"text_embedding\":{\"tag\":\"some-tag\",\"model_id\":\"modelId\",\"field_map\":{\"input_field\":\"vector_field\"},\"description\":\"processor-description\"}}"; + TextEmbeddingProcessor deserialized = checkJsonRoundtrip(processor, json).textEmbedding(); + + assertEquals("modelId", deserialized.modelId()); + assertEquals(baseFieldMap, deserialized.fieldMap()); + assertEquals("processor-description", deserialized.description()); + assertEquals("some-tag", deserialized.tag()); + } + + @Test + public void testJsonRoundtripWithoutDescription() { + Processor processor = new Processor.Builder().textEmbedding(baseTextEmbeddingProcessor().build()).build(); + String json = + "{\"text_embedding\":{\"tag\":\"some-tag\",\"model_id\":\"modelId\",\"field_map\":{\"input_field\":\"vector_field\"}}}"; + TextEmbeddingProcessor deserialized = checkJsonRoundtrip(processor, json).textEmbedding(); + + assertEquals("modelId", deserialized.modelId()); + assertEquals(baseFieldMap, deserialized.fieldMap()); + assertNull(deserialized.description()); + assertEquals("some-tag", deserialized.tag()); + } +}