Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Add batch_size param for text_embedding processor #1312

Merged
merged 1 commit into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
## [Unreleased 2.x]
### Added
- Added support for disabling typed keys serialization ([#1296](https://github.com/opensearch-project/opensearch-java/pull/1296))
- Added support for the `batch_size` param on the `text_embedding` processor ([#1298](https://github.com/opensearch-project/opensearch-java/pull/1298))

### Dependencies

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ public class TextEmbeddingProcessor extends ProcessorBase implements ProcessorVa
@Nullable
private final String description;

@Nullable
private final Integer batchSize;

// ---------------------------------------------------------------------------------------------

private TextEmbeddingProcessor(Builder builder) {
Expand All @@ -39,7 +42,7 @@ private TextEmbeddingProcessor(Builder builder) {
this.modelId = ApiTypeHelper.requireNonNull(builder.modelId, this, "modelId");
this.fieldMap = ApiTypeHelper.unmodifiableRequired(builder.fieldMap, this, "fieldMap");
this.description = builder.description;

this.batchSize = builder.batchSize;
}

public static TextEmbeddingProcessor of(Function<Builder, ObjectBuilder<TextEmbeddingProcessor>> fn) {
Expand Down Expand Up @@ -76,6 +79,14 @@ public final String description() {
return this.description;
}

/**
* API name: {@code batch_size}
*/
@Nullable
public final Integer batchSize() {
return this.batchSize;
}

protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {

super.serializeInternal(generator, mapper);
Expand All @@ -96,7 +107,10 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
generator.writeKey("description");
generator.write(this.description);
}

if (this.batchSize != null) {
generator.writeKey("batch_size");
generator.write(this.batchSize);
}
}

// ---------------------------------------------------------------------------------------------
Expand All @@ -114,6 +128,9 @@ public static class Builder extends ProcessorBase.AbstractBuilder<Builder> imple
@Nullable
private String description;

@Nullable
private Integer batchSize;

/**
* Required - API name: {@code model_id}
*/
Expand Down Expand Up @@ -150,6 +167,17 @@ public final Builder description(@Nullable String value) {
return this;
}

/**
* API name: {@code batch_size}
*/
public final Builder batchSize(@Nullable Integer value) {
if (value != null && value <= 0) {
throw new IllegalArgumentException("batchSize must be a positive integer");
}
this.batchSize = value;
return this;
}

@Override
protected Builder self() {
return this;
Expand Down Expand Up @@ -183,6 +211,7 @@ protected static void setupTextEmbeddingProcessorDeserializer(ObjectDeserializer
op.add(Builder::modelId, JsonpDeserializer.stringDeserializer(), "model_id");
op.add(Builder::fieldMap, JsonpDeserializer.stringMapDeserializer(JsonpDeserializer.stringDeserializer()), "field_map");
op.add(Builder::description, JsonpDeserializer.stringDeserializer(), "description");
op.add(Builder::batchSize, JsonpDeserializer.integerDeserializer(), "batch_size");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,30 +25,63 @@ private static TextEmbeddingProcessor.Builder baseTextEmbeddingProcessor() {
}

@Test
public void testJsonRoundtripWithDescription() {
public void testJsonRoundtripWithDescriptionAndBatchSize() {
Processor processor = new Processor.Builder().textEmbedding(
baseTextEmbeddingProcessor().description("processor-description").build()
baseTextEmbeddingProcessor().description("processor-description").batchSize(1).build()
).build();
String json =
"{\"text_embedding\":{\"tag\":\"some-tag\",\"model_id\":\"modelId\",\"field_map\":{\"input_field\":\"vector_field\"},\"description\":\"processor-description\"}}";
"{\"text_embedding\":{\"tag\":\"some-tag\",\"model_id\":\"modelId\",\"field_map\":{\"input_field\":\"vector_field\"},\"description\":\"processor-description\",\"batch_size\":1}}";
TextEmbeddingProcessor deserialized = checkJsonRoundtrip(processor, json).textEmbedding();

assertEquals("modelId", deserialized.modelId());
assertEquals(baseFieldMap, deserialized.fieldMap());
assertEquals("processor-description", deserialized.description());
assertEquals("some-tag", deserialized.tag());
assertNotNull(deserialized.batchSize());
assertEquals(1, deserialized.batchSize().intValue());
}

@Test
public void testJsonRoundtripWithoutDescription() {
Processor processor = new Processor.Builder().textEmbedding(baseTextEmbeddingProcessor().build()).build();
Processor processor = new Processor.Builder().textEmbedding(baseTextEmbeddingProcessor().batchSize(1).build()).build();
String json =
"{\"text_embedding\":{\"tag\":\"some-tag\",\"model_id\":\"modelId\",\"field_map\":{\"input_field\":\"vector_field\"}}}";
"{\"text_embedding\":{\"tag\":\"some-tag\",\"model_id\":\"modelId\",\"field_map\":{\"input_field\":\"vector_field\"},\"batch_size\":1}}";
TextEmbeddingProcessor deserialized = checkJsonRoundtrip(processor, json).textEmbedding();

assertEquals("modelId", deserialized.modelId());
assertEquals(baseFieldMap, deserialized.fieldMap());
assertNull(deserialized.description());
assertEquals("some-tag", deserialized.tag());
assertNotNull(deserialized.batchSize());
assertEquals(1, deserialized.batchSize().intValue());
}

@Test
public void testJsonRoundtripWithoutBatchSize() {
Processor processor = new Processor.Builder().textEmbedding(
baseTextEmbeddingProcessor().description("processor-description").build()
).build();
String json =
"{\"text_embedding\":{\"tag\":\"some-tag\",\"model_id\":\"modelId\",\"field_map\":{\"input_field\":\"vector_field\"},\"description\":\"processor-description\"}}";
TextEmbeddingProcessor deserialized = checkJsonRoundtrip(processor, json).textEmbedding();

assertEquals("modelId", deserialized.modelId());
assertEquals(baseFieldMap, deserialized.fieldMap());
assertEquals("processor-description", deserialized.description());
assertEquals("some-tag", deserialized.tag());
assertNull(deserialized.batchSize());
}

@Test
public void testInvalidBatchSizeThrowsException() {
IllegalArgumentException exceptionWhenBatchSizeIsZero = assertThrows(IllegalArgumentException.class, () -> {
new Processor.Builder().textEmbedding(baseTextEmbeddingProcessor().batchSize(0).build()).build();
});
assertEquals("batchSize must be a positive integer", exceptionWhenBatchSizeIsZero.getMessage());

IllegalArgumentException exceptionWhenBatchSizeIsNegative = assertThrows(IllegalArgumentException.class, () -> {
new Processor.Builder().textEmbedding(baseTextEmbeddingProcessor().batchSize(-1).build()).build();
});
assertEquals("batchSize must be a positive integer", exceptionWhenBatchSizeIsNegative.getMessage());
}
}
Loading