From 6c3e68f8eb85fba3ca5c5127ff9ebcc7b4be8343 Mon Sep 17 00:00:00 2001 From: YeonghyeonKo <46114393+YeonghyeonKO@users.noreply.github.com> Date: Tue, 19 Nov 2024 05:54:08 +0900 Subject: [PATCH] Add `batch_size` param for `text_embedding` processor (#1298) * Add batchSize parameter for text_embedding processor Signed-off-by: YeonghyeonKO * throw IllegalArgumentException when batchSize is not a positive integer Signed-off-by: YeonghyeonKO * test: add test cases for BatchSize param Signed-off-by: YeonghyeonKO * test: exception when batchSize is zero or negative integer Signed-off-by: YeonghyeonKO * refactor: use assertNotNull for readability & convention Signed-off-by: YeonghyeonKO * update CHANGELOG about #1298 PR Signed-off-by: YeonghyeonKO * apply code convention to keep the codes spotless Signed-off-by: YeonghyeonKO --------- Signed-off-by: YeonghyeonKO Signed-off-by: Thomas Farr Co-authored-by: Thomas Farr --- CHANGELOG.md | 1 + .../ingest/TextEmbeddingProcessor.java | 33 +++++++++++++- .../ingest/TextEmbeddingProcessorTest.java | 43 ++++++++++++++++--- 3 files changed, 70 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 03137df4c2..f62376127d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,7 @@ This section is for maintaining a changelog for all breaking changes for the cli ### Added - Added support for disabling typed keys serialization ([#1296](https://github.com/opensearch-project/opensearch-java/pull/1296)) +- Added support for the `batch_size` param on the `text_embedding` processor ([#1298](https://github.com/opensearch-project/opensearch-java/pull/1298)) ### Dependencies diff --git a/java-client/src/main/java/org/opensearch/client/opensearch/ingest/TextEmbeddingProcessor.java b/java-client/src/main/java/org/opensearch/client/opensearch/ingest/TextEmbeddingProcessor.java index 0f610a8cd4..58d623a22f 100644 --- a/java-client/src/main/java/org/opensearch/client/opensearch/ingest/TextEmbeddingProcessor.java +++ b/java-client/src/main/java/org/opensearch/client/opensearch/ingest/TextEmbeddingProcessor.java @@ -31,6 +31,9 @@ public class TextEmbeddingProcessor extends ProcessorBase implements ProcessorVa @Nullable private final String description; + @Nullable + private final Integer batchSize; + // --------------------------------------------------------------------------------------------- private TextEmbeddingProcessor(Builder builder) { @@ -39,7 +42,7 @@ private TextEmbeddingProcessor(Builder builder) { this.modelId = ApiTypeHelper.requireNonNull(builder.modelId, this, "modelId"); this.fieldMap = ApiTypeHelper.unmodifiableRequired(builder.fieldMap, this, "fieldMap"); this.description = builder.description; - + this.batchSize = builder.batchSize; } public static TextEmbeddingProcessor of(Function> fn) { @@ -76,6 +79,14 @@ public final String description() { return this.description; } + /** + * API name: {@code batch_size} + */ + @Nullable + public final Integer batchSize() { + return this.batchSize; + } + protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) { super.serializeInternal(generator, mapper); @@ -96,7 +107,10 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) { generator.writeKey("description"); generator.write(this.description); } - + if (this.batchSize != null) { + generator.writeKey("batch_size"); + generator.write(this.batchSize); + } } // --------------------------------------------------------------------------------------------- @@ -114,6 +128,9 @@ public static class Builder extends ProcessorBase.AbstractBuilder imple @Nullable private String description; + @Nullable + private Integer batchSize; + /** * Required - API name: {@code model_id} */ @@ -150,6 +167,17 @@ public final Builder description(@Nullable String value) { return this; } + /** + * API name: {@code batch_size} + */ + public final Builder batchSize(@Nullable Integer value) { + if (value != null && value <= 0) { + throw new IllegalArgumentException("batchSize must be a positive integer"); + } + this.batchSize = value; + return this; + } + @Override protected Builder self() { return this; @@ -183,6 +211,7 @@ protected static void setupTextEmbeddingProcessorDeserializer(ObjectDeserializer op.add(Builder::modelId, JsonpDeserializer.stringDeserializer(), "model_id"); op.add(Builder::fieldMap, JsonpDeserializer.stringMapDeserializer(JsonpDeserializer.stringDeserializer()), "field_map"); op.add(Builder::description, JsonpDeserializer.stringDeserializer(), "description"); + op.add(Builder::batchSize, JsonpDeserializer.integerDeserializer(), "batch_size"); } } diff --git a/java-client/src/test/java/org/opensearch/client/opensearch/ingest/TextEmbeddingProcessorTest.java b/java-client/src/test/java/org/opensearch/client/opensearch/ingest/TextEmbeddingProcessorTest.java index 46a10827ea..fb6bd73e0e 100644 --- a/java-client/src/test/java/org/opensearch/client/opensearch/ingest/TextEmbeddingProcessorTest.java +++ b/java-client/src/test/java/org/opensearch/client/opensearch/ingest/TextEmbeddingProcessorTest.java @@ -25,30 +25,63 @@ private static TextEmbeddingProcessor.Builder baseTextEmbeddingProcessor() { } @Test - public void testJsonRoundtripWithDescription() { + public void testJsonRoundtripWithDescriptionAndBatchSize() { Processor processor = new Processor.Builder().textEmbedding( - baseTextEmbeddingProcessor().description("processor-description").build() + baseTextEmbeddingProcessor().description("processor-description").batchSize(1).build() ).build(); String json = - "{\"text_embedding\":{\"tag\":\"some-tag\",\"model_id\":\"modelId\",\"field_map\":{\"input_field\":\"vector_field\"},\"description\":\"processor-description\"}}"; + "{\"text_embedding\":{\"tag\":\"some-tag\",\"model_id\":\"modelId\",\"field_map\":{\"input_field\":\"vector_field\"},\"description\":\"processor-description\",\"batch_size\":1}}"; TextEmbeddingProcessor deserialized = checkJsonRoundtrip(processor, json).textEmbedding(); assertEquals("modelId", deserialized.modelId()); assertEquals(baseFieldMap, deserialized.fieldMap()); assertEquals("processor-description", deserialized.description()); assertEquals("some-tag", deserialized.tag()); + assertNotNull(deserialized.batchSize()); + assertEquals(1, deserialized.batchSize().intValue()); } @Test public void testJsonRoundtripWithoutDescription() { - Processor processor = new Processor.Builder().textEmbedding(baseTextEmbeddingProcessor().build()).build(); + Processor processor = new Processor.Builder().textEmbedding(baseTextEmbeddingProcessor().batchSize(1).build()).build(); String json = - "{\"text_embedding\":{\"tag\":\"some-tag\",\"model_id\":\"modelId\",\"field_map\":{\"input_field\":\"vector_field\"}}}"; + "{\"text_embedding\":{\"tag\":\"some-tag\",\"model_id\":\"modelId\",\"field_map\":{\"input_field\":\"vector_field\"},\"batch_size\":1}}"; TextEmbeddingProcessor deserialized = checkJsonRoundtrip(processor, json).textEmbedding(); assertEquals("modelId", deserialized.modelId()); assertEquals(baseFieldMap, deserialized.fieldMap()); assertNull(deserialized.description()); assertEquals("some-tag", deserialized.tag()); + assertNotNull(deserialized.batchSize()); + assertEquals(1, deserialized.batchSize().intValue()); + } + + @Test + public void testJsonRoundtripWithoutBatchSize() { + Processor processor = new Processor.Builder().textEmbedding( + baseTextEmbeddingProcessor().description("processor-description").build() + ).build(); + String json = + "{\"text_embedding\":{\"tag\":\"some-tag\",\"model_id\":\"modelId\",\"field_map\":{\"input_field\":\"vector_field\"},\"description\":\"processor-description\"}}"; + TextEmbeddingProcessor deserialized = checkJsonRoundtrip(processor, json).textEmbedding(); + + assertEquals("modelId", deserialized.modelId()); + assertEquals(baseFieldMap, deserialized.fieldMap()); + assertEquals("processor-description", deserialized.description()); + assertEquals("some-tag", deserialized.tag()); + assertNull(deserialized.batchSize()); + } + + @Test + public void testInvalidBatchSizeThrowsException() { + IllegalArgumentException exceptionWhenBatchSizeIsZero = assertThrows(IllegalArgumentException.class, () -> { + new Processor.Builder().textEmbedding(baseTextEmbeddingProcessor().batchSize(0).build()).build(); + }); + assertEquals("batchSize must be a positive integer", exceptionWhenBatchSizeIsZero.getMessage()); + + IllegalArgumentException exceptionWhenBatchSizeIsNegative = assertThrows(IllegalArgumentException.class, () -> { + new Processor.Builder().textEmbedding(baseTextEmbeddingProcessor().batchSize(-1).build()).build(); + }); + assertEquals("batchSize must be a positive integer", exceptionWhenBatchSizeIsNegative.getMessage()); } }