diff --git a/gradle.properties b/gradle.properties index a8727c71..7d03d1ee 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,4 +1,4 @@ -version=0.11.0-SNAPSHOT -kestraVersion=0.11.+ +version=0.12.0-SNAPSHOT +kestraVersion=[0.12,) micronautVersion=3.9.3 lombokVersion=1.18.28 diff --git a/src/main/java/io/kestra/plugin/gcp/cli/GCloudCLI.java b/src/main/java/io/kestra/plugin/gcp/cli/GCloudCLI.java index 264eb448..1f112c95 100644 --- a/src/main/java/io/kestra/plugin/gcp/cli/GCloudCLI.java +++ b/src/main/java/io/kestra/plugin/gcp/cli/GCloudCLI.java @@ -66,22 +66,24 @@ } ) public class GCloudCLI extends Task implements RunnableTask { + private static final String DEFAULT_IMAGE = "google/cloud-sdk"; + @NotNull @NotEmpty @Schema( - title = "The full service account JSON key to use to authenticate to gcloud" + title = "The full service account JSON key to use to authenticate to gcloud" ) @PluginProperty(dynamic = true) protected String serviceAccount; @Schema( - title = "The project id to scope the commands to" + title = "The project id to scope the commands to" ) @PluginProperty(dynamic = true) protected String projectId; @Schema( - title = "The commands to run" + title = "The commands to run" ) @PluginProperty(dynamic = true) @NotNull @@ -89,7 +91,7 @@ public class GCloudCLI extends Task implements RunnableTask { protected List commands; @Schema( - title = "Additional environment variables for the current process." + title = "Additional environment variables for the current process." ) @PluginProperty( additionalProperties = String.class, @@ -98,13 +100,12 @@ public class GCloudCLI extends Task implements RunnableTask { protected Map env; @Schema( - title = "Docker options when for the `DOCKER` runner" + title = "Docker options when for the `DOCKER` runner", + defaultValue = "{image=" + DEFAULT_IMAGE + ", pullPolicy=ALWAYS}" ) @PluginProperty @Builder.Default - protected DockerOptions docker = DockerOptions.builder() - .image("google/cloud-sdk") - .build(); + protected DockerOptions docker = DockerOptions.builder().build(); @Override public ScriptOutput run(RunContext runContext) throws Exception { @@ -112,7 +113,7 @@ public ScriptOutput run(RunContext runContext) throws Exception { CommandsWrapper commands = new CommandsWrapper(runContext) .withWarningOnStdErr(true) .withRunnerType(RunnerType.DOCKER) - .withDockerOptions(this.docker) + .withDockerOptions(injectDefaults(getDocker())) .withCommands( ScriptService.scriptCommands( List.of("/bin/sh", "-c"), @@ -125,6 +126,15 @@ public ScriptOutput run(RunContext runContext) throws Exception { return commands.run(); } + private DockerOptions injectDefaults(DockerOptions original) { + var builder = original.toBuilder(); + if (original.getImage() == null) { + builder.image(DEFAULT_IMAGE); + } + + return builder.build(); + } + private Map getEnv(RunContext runContext) throws IOException, IllegalVariableEvaluationException { Map envs = new HashMap<>(); if (serviceAccount != null) { diff --git a/src/main/java/io/kestra/plugin/gcp/pubsub/Consume.java b/src/main/java/io/kestra/plugin/gcp/pubsub/Consume.java index 4087a19f..d1c3398d 100644 --- a/src/main/java/io/kestra/plugin/gcp/pubsub/Consume.java +++ b/src/main/java/io/kestra/plugin/gcp/pubsub/Consume.java @@ -1,5 +1,6 @@ package io.kestra.plugin.gcp.pubsub; +import com.google.api.gax.core.FixedCredentialsProvider; import com.google.cloud.pubsub.v1.MessageReceiver; import com.google.cloud.pubsub.v1.Subscriber; import io.kestra.core.models.annotations.Example; @@ -10,6 +11,7 @@ import io.kestra.core.runners.RunContext; import io.kestra.core.serializers.FileSerde; import io.kestra.plugin.gcp.pubsub.model.Message; +import io.kestra.plugin.gcp.pubsub.model.SerdeType; import io.swagger.v3.oas.annotations.media.Schema; import lombok.*; import lombok.experimental.SuperBuilder; @@ -22,6 +24,8 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import javax.validation.constraints.NotNull; + import static io.kestra.core.utils.Rethrow.throwRunnable; @SuperBuilder @@ -31,7 +35,7 @@ @NoArgsConstructor @Schema( title = "Consume messages from a Pub/Sub topic.", - description = "Required a maxDuration or a maxRecords." + description = "Requires a maxDuration or a maxRecords." ) @Plugin( examples = { @@ -46,14 +50,14 @@ public class Consume extends AbstractPubSub implements RunnableTask { @Schema( - title = "The Pub/Sub subscription", + title = "The Pub/Sub subscription.", description = "The Pub/Sub subscription. It will be created automatically if it didn't exist and 'autoCreateSubscription' is enabled." ) @PluginProperty(dynamic = true) private String subscription; @Schema( - title = "Whether the Pub/Sub subscription should be created if not exist" + title = "Whether the Pub/Sub subscription should be created if not exists." ) @PluginProperty @Builder.Default @@ -67,6 +71,11 @@ public class Consume extends AbstractPubSub implements RunnableTask threadException = new AtomicReference<>(); MessageReceiver receiver = (message, consumer) -> { try { - FileSerde.write(outputFile, Message.of(message)); + FileSerde.write(outputFile, Message.of(message, serdeType)); total.getAndIncrement(); consumer.ack(); } @@ -92,7 +101,9 @@ public Output run(RunContext runContext) throws Exception { consumer.nack(); } }; - var subscriber = Subscriber.newBuilder(subscriptionName, receiver).build(); + var subscriber = Subscriber.newBuilder(subscriptionName, receiver) + .setCredentialsProvider(FixedCredentialsProvider.create(this.credentials(runContext))) + .build(); subscriber.startAsync().awaitRunning(); while (!this.ended(total, started)) { diff --git a/src/main/java/io/kestra/plugin/gcp/pubsub/Publish.java b/src/main/java/io/kestra/plugin/gcp/pubsub/Publish.java index 73820b8d..6fd5bf7f 100644 --- a/src/main/java/io/kestra/plugin/gcp/pubsub/Publish.java +++ b/src/main/java/io/kestra/plugin/gcp/pubsub/Publish.java @@ -10,6 +10,7 @@ import io.kestra.core.serializers.FileSerde; import io.kestra.core.serializers.JacksonMapper; import io.kestra.plugin.gcp.pubsub.model.Message; +import io.kestra.plugin.gcp.pubsub.model.SerdeType; import io.reactivex.BackpressureStrategy; import io.reactivex.Flowable; import io.swagger.v3.oas.annotations.media.Schema; @@ -58,6 +59,12 @@ public class Publish extends AbstractPubSub implements RunnableTask buildFlowable(Flowable flowable, Publisher publisher, RunContext runContext) { return flowable .map(message -> { - publisher.publish(message.to(runContext)); + publisher.publish(message.to(runContext, this.serdeType)); return 1; }); } diff --git a/src/main/java/io/kestra/plugin/gcp/pubsub/Trigger.java b/src/main/java/io/kestra/plugin/gcp/pubsub/Trigger.java index 6abe119c..040ec451 100644 --- a/src/main/java/io/kestra/plugin/gcp/pubsub/Trigger.java +++ b/src/main/java/io/kestra/plugin/gcp/pubsub/Trigger.java @@ -13,6 +13,7 @@ import io.kestra.core.models.triggers.TriggerOutput; import io.kestra.core.runners.RunContext; import io.kestra.core.utils.IdUtils; +import io.kestra.plugin.gcp.pubsub.model.SerdeType; import io.swagger.v3.oas.annotations.media.Schema; import lombok.*; import lombok.experimental.SuperBuilder; @@ -22,6 +23,7 @@ import java.util.Collections; import java.util.List; import java.util.Optional; +import javax.validation.constraints.NotNull; @SuperBuilder @@ -78,6 +80,12 @@ public class Trigger extends AbstractTrigger implements PollingTriggerInterface, @Schema(title = "Max duration in the Duration ISO format, after that the task will end.") private Duration maxDuration; + @Builder.Default + @PluginProperty + @NotNull + @Schema(title = "The serializer/deserializer to use.") + private SerdeType serdeType = SerdeType.STRING; + @Override public Optional evaluate(ConditionContext conditionContext, TriggerContext context) throws Exception { RunContext runContext = conditionContext.getRunContext(); @@ -92,6 +100,7 @@ public Optional evaluate(ConditionContext conditionContext, TriggerCo .scopes(this.scopes) .maxRecords(this.maxRecords) .maxDuration(this.maxDuration) + .serdeType(this.serdeType) .build(); Consume.Output run = task.run(runContext); diff --git a/src/main/java/io/kestra/plugin/gcp/pubsub/model/Message.java b/src/main/java/io/kestra/plugin/gcp/pubsub/model/Message.java index f6d9d372..216756c2 100644 --- a/src/main/java/io/kestra/plugin/gcp/pubsub/model/Message.java +++ b/src/main/java/io/kestra/plugin/gcp/pubsub/model/Message.java @@ -10,6 +10,8 @@ import lombok.Getter; import lombok.extern.jackson.Jacksonized; +import java.io.IOException; +import java.util.Base64; import java.util.Map; import static io.kestra.core.utils.Rethrow.throwBiConsumer; @@ -19,11 +21,14 @@ @Jacksonized public class Message { - @Schema(title = "The message data, must be base64 encoded") + @Schema( + title = "The message data, must be a string if serde type is 'STRING', otherwise a JSON object", + description = "If it's a string, it can be a dynamic property otherwise not." + ) @PluginProperty(dynamic = true) - private String data; + private Object data; - @Schema(title = "The message attribute map") + @Schema(title = "The message attributes map") @PluginProperty(dynamic = true) private Map attributes; @@ -35,10 +40,17 @@ public class Message { @PluginProperty(dynamic = true) private String orderingKey; - public PubsubMessage to(RunContext runContext) throws IllegalVariableEvaluationException { + public PubsubMessage to(RunContext runContext, SerdeType serdeType) throws IllegalVariableEvaluationException, IOException { var builder = PubsubMessage.newBuilder(); if(data != null) { - builder.setData(ByteString.copyFrom(runContext.render(data).getBytes())); + byte[] serializedData; + if (data instanceof String dataStr) { + var rendered = runContext.render(dataStr); + serializedData = rendered.getBytes(); + } else { + serializedData = serdeType.serialize(data); + } + builder.setData(ByteString.copyFrom(Base64.getEncoder().encode(serializedData))); } if(attributes != null && !attributes.isEmpty()) { attributes.forEach(throwBiConsumer((key, value) -> builder.putAttributes(runContext.render(key), runContext.render(value)))); @@ -52,12 +64,17 @@ public PubsubMessage to(RunContext runContext) throws IllegalVariableEvaluationE return builder.build(); } - public static Message of(PubsubMessage message) { - return Message.builder() + public static Message of(PubsubMessage message, SerdeType serdeType) throws IOException { + var builder = Message.builder() .messageId(message.getMessageId()) - .data(message.getData().toString()) .attributes(message.getAttributesMap()) - .orderingKey(message.getOrderingKey()) - .build(); + .orderingKey(message.getOrderingKey()); + + if (message.getData() != null) { + var decodedData = Base64.getDecoder().decode(message.getData().toByteArray()); + builder.data(serdeType.deserialize(decodedData)); + } + + return builder.build(); } } diff --git a/src/main/java/io/kestra/plugin/gcp/pubsub/model/SerdeType.java b/src/main/java/io/kestra/plugin/gcp/pubsub/model/SerdeType.java new file mode 100644 index 00000000..56b7e376 --- /dev/null +++ b/src/main/java/io/kestra/plugin/gcp/pubsub/model/SerdeType.java @@ -0,0 +1,29 @@ +package io.kestra.plugin.gcp.pubsub.model; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.kestra.core.serializers.JacksonMapper; + +import java.io.IOException; + +public enum SerdeType { + STRING, + JSON; + + private static final ObjectMapper OBJECT_MAPPER = JacksonMapper.ofJson(false); + + public Object deserialize(byte[] message) throws IOException { + if (this == SerdeType.JSON) { + return OBJECT_MAPPER.readValue(message, Object.class); + } else { + return message; + } + } + + public byte[] serialize(Object message) throws IOException { + if (this == SerdeType.JSON) { + return OBJECT_MAPPER.writeValueAsBytes(message); + } else { + return message.toString().getBytes(); + } + } +} diff --git a/src/main/java/io/kestra/plugin/gcp/vertexai/AbstractGenerativeAi.java b/src/main/java/io/kestra/plugin/gcp/vertexai/AbstractGenerativeAi.java new file mode 100644 index 00000000..e2e93489 --- /dev/null +++ b/src/main/java/io/kestra/plugin/gcp/vertexai/AbstractGenerativeAi.java @@ -0,0 +1,168 @@ +package io.kestra.plugin.gcp.vertexai; + +import io.kestra.core.exceptions.IllegalVariableEvaluationException; +import io.kestra.core.models.annotations.PluginProperty; +import io.kestra.core.models.executions.metrics.Counter; +import io.kestra.core.runners.RunContext; +import io.kestra.plugin.gcp.AbstractTask; +import io.micronaut.http.MediaType; +import io.micronaut.http.MutableHttpRequest; +import io.micronaut.http.client.DefaultHttpClientConfiguration; +import io.micronaut.http.client.HttpClient; +import io.micronaut.http.client.exceptions.HttpClientResponseException; +import io.micronaut.http.client.netty.DefaultHttpClient; +import io.micronaut.http.client.netty.NettyHttpClientFactory; +import io.micronaut.http.codec.MediaTypeCodecRegistry; +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.ToString; +import lombok.experimental.SuperBuilder; + +import java.io.IOException; +import java.net.MalformedURLException; +import java.net.URI; +import java.util.List; +import javax.validation.constraints.Max; +import javax.validation.constraints.Min; +import javax.validation.constraints.NotNull; +import javax.validation.constraints.Positive; + +@SuperBuilder +@ToString +@EqualsAndHashCode +@Getter +@NoArgsConstructor +abstract class AbstractGenerativeAi extends AbstractTask { + private static final NettyHttpClientFactory FACTORY = new NettyHttpClientFactory(); + private static final String URI_PATTERN = "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict"; + + @Schema( + title = "The region" + ) + @PluginProperty(dynamic = true) + @NotNull + private String region; + + @Builder.Default + @Schema( + title = "The model parameters" + ) + @PluginProperty + private ModelParameter parameters = ModelParameter.builder().build(); + + T call(RunContext runContext, Class responseClass) { + try { + var auth = credentials(runContext); + auth.refreshIfExpired(); + + var request = getPredictionRequest(runContext) + .contentType(MediaType.APPLICATION_JSON) + .bearerAuth(auth.getAccessToken().getTokenValue()); + + try (HttpClient client = this.client(runContext)) { + var response = client.toBlocking().exchange(request, responseClass); + var predictionResponse = response.body(); + if (predictionResponse == null) { + throw new RuntimeException("Received an empty response from the Vertex.ai prediction API"); + } + + return predictionResponse; + } + } catch (HttpClientResponseException e) { + throw new HttpClientResponseException( + "Request failed '" + e.getStatus().getCode() + "' and body '" + e.getResponse().getBody(String.class).orElse("null") + "'", + e, + e.getResponse() + ); + } catch (IllegalVariableEvaluationException | IOException e) { + throw new RuntimeException(e); + } + } + + protected abstract MutableHttpRequest getPredictionRequest(RunContext runContext) throws IllegalVariableEvaluationException; + + protected void sendMetrics(RunContext runContext, Metadata metadata) { + runContext.metric(Counter.of("input.token.total.tokens", metadata.tokenMetadata.inputTokenCount.totalTokens)); + runContext.metric(Counter.of("input.token.total.billable.characters", metadata.tokenMetadata.inputTokenCount.totalBillableCharacters)); + runContext.metric(Counter.of("output.token.total.tokens", metadata.tokenMetadata.outputTokenCount.totalTokens)); + runContext.metric(Counter.of("output.token.total.billable.characters", metadata.tokenMetadata.outputTokenCount.totalBillableCharacters)); + } + + protected URI getPredictionURI(RunContext runContext, String modelId) throws IllegalVariableEvaluationException { + var formatted = URI_PATTERN.formatted(runContext.render(getRegion()), runContext.render(getProjectId()), runContext.render(getRegion()), modelId); + runContext.logger().debug("Calling Vertex.AI prediction API {}", formatted); + return URI.create(formatted); + } + + private HttpClient client(RunContext runContext) throws MalformedURLException { + MediaTypeCodecRegistry mediaTypeCodecRegistry = runContext.getApplicationContext().getBean(MediaTypeCodecRegistry.class); + var httpConfig = new DefaultHttpClientConfiguration(); + httpConfig.setMaxContentLength(Integer.MAX_VALUE); + + DefaultHttpClient client = (DefaultHttpClient) FACTORY.createClient(null, httpConfig); + client.setMediaTypeCodecRegistry(mediaTypeCodecRegistry); + return client; + } + + @Builder + @Getter + public static class ModelParameter { + @Builder.Default + @PluginProperty + @Positive + @Max(1) + @Schema( + title = "Temperature used for sampling during the response generation, which occurs when topP and topK are applied.", + description = "Temperature controls the degree of randomness in token selection. Lower temperatures are good for prompts that require a more deterministic and less open-ended or creative response, while higher temperatures can lead to more diverse or creative results. A temperature of 0 is deterministic: the highest probability response is always selected. For most use cases, try starting with a temperature of 0.2." + ) + private Float temperature = 0.2F; + + @Builder.Default + @PluginProperty + @Min(1) + @Max(1024) + @Schema( + title = "Maximum number of tokens that can be generated in the response", + description = """ + Specify a lower value for shorter responses and a higher value for longer responses. + A token may be smaller than a word. A token is approximately four characters. 100 tokens correspond to roughly 60-80 words.""" + ) + private Integer maxOutputTokens = 128; + + @Builder.Default + @PluginProperty + @Min(1) + @Max(40) + @Schema( + title = "Top-k changes how the model selects tokens for output", + description = """ + A top-k of 1 means the selected token is the most probable among all tokens in the model's vocabulary (also called greedy decoding), while a top-k of 3 means that the next token is selected from among the 3 most probable tokens (using temperature). + For each token selection step, the top K tokens with the highest probabilities are sampled. Then tokens are further filtered based on topP with the final token selected using temperature sampling. + Specify a lower value for less random responses and a higher value for more random responses.""" + ) + private Integer topK = 40; + + @Builder.Default + @PluginProperty + @Positive + @Max(1) + @Schema( + title = "Top-p changes how the model selects tokens for output", + description = """ + Tokens are selected from most K (see topK parameter) probable to least until the sum of their probabilities equals the top-p value. For example, if tokens A, B, and C have a probability of 0.3, 0.2, and 0.1 and the top-p value is 0.5, then the model will select either A or B as the next token (using temperature) and doesn't consider C. The default top-p value is 0.95. + Specify a lower value for less random responses and a higher value for more random responses.""" + ) + private Float topP = 0.95F; + } + + // common response objects + public record CitationMetadata(List citations) {} + public record Citation(List citations) {} + public record SafetyAttributes(List scores, List categories, Boolean blocked) {} + public record Metadata(TokenMetadata tokenMetadata) {} + public record TokenMetadata(TokenCount outputTokenCount, TokenCount inputTokenCount) {} + public record TokenCount(Integer totalTokens, Integer totalBillableCharacters) {} +} diff --git a/src/main/java/io/kestra/plugin/gcp/vertexai/ChatCompletion.java b/src/main/java/io/kestra/plugin/gcp/vertexai/ChatCompletion.java new file mode 100644 index 00000000..504f82b7 --- /dev/null +++ b/src/main/java/io/kestra/plugin/gcp/vertexai/ChatCompletion.java @@ -0,0 +1,138 @@ +package io.kestra.plugin.gcp.vertexai; + +import io.kestra.core.exceptions.IllegalVariableEvaluationException; +import io.kestra.core.models.annotations.Example; +import io.kestra.core.models.annotations.Plugin; +import io.kestra.core.models.annotations.PluginProperty; +import io.kestra.core.models.tasks.RunnableTask; +import io.kestra.core.runners.RunContext; +import io.micronaut.http.HttpRequest; +import io.micronaut.http.MutableHttpRequest; +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.ToString; +import lombok.experimental.SuperBuilder; + +import java.util.List; +import javax.validation.constraints.NotEmpty; +import javax.validation.constraints.NotNull; + +import static io.kestra.core.utils.Rethrow.throwFunction; + +@SuperBuilder +@ToString +@EqualsAndHashCode +@Getter +@NoArgsConstructor +@Schema( + title = "Chat completion using the Vertex AI PaLM API for Google's PaLM 2 large language models (LLM)", + description = "See [Generative AI quickstart using the Vertex AI API](https://cloud.google.com/vertex-ai/docs/generative-ai/start/quickstarts/api-quickstart) for more information." +) +@Plugin( + examples = { + @Example( + title = "Chat completion using the Vertex AI PaLM API", + code = { + """ + region: us-central1 + projectId: my-project + context: I love jokes that talk about sport + messages: + - author: user + content: Please tell me a joke""" + } + ) + } +) +public class ChatCompletion extends AbstractGenerativeAi implements RunnableTask { + private static final String MODEL_ID = "chat-bison"; + + @PluginProperty(dynamic = true) + @Schema( + title = "Context shapes how the model responds throughout the conversation", + description = "For example, you can use context to specify words the model can or cannot use, topics to focus on or avoid, or the response format or style." + ) + private String context; + + @PluginProperty(dynamic = true) + @Schema( + title = "List of structured messages to the model to learn how to respond to the conversation" + ) + private List examples; + + @PluginProperty(dynamic = true) + @Schema( + title = "Conversation history provided to the model in a structured alternate-author form", + description = "Messages appear in chronological order: oldest first, newest last. When the history of messages causes the input to exceed the maximum length, the oldest messages are removed until the entire prompt is within the allowed limit." + ) + @NotEmpty + private List messages; + + @Override + public Output run(RunContext runContext) throws Exception { + var response = call(runContext, PredictionResponse.class); + sendMetrics(runContext, response.metadata); + + return Output.builder() + .predictions(response.predictions) + .build(); + } + + @Override + protected MutableHttpRequest getPredictionRequest(RunContext runContext) throws IllegalVariableEvaluationException { + List chatExamples = examples == null ? null : + examples.stream().map(throwFunction(ex -> new ChatExample(new ChatContent(runContext.render(ex.input)), new ChatContent(runContext.render(ex.output))))).toList(); + List chatMessages = messages.stream().map(throwFunction(msg -> new Message(runContext.render(msg.author), runContext.render(msg.content)))).toList(); + + var request = new ChatPromptRequest(List.of(new ChatPromptInstance(runContext.render(context), chatExamples, chatMessages)), getParameters()); + return HttpRequest.POST(getPredictionURI(runContext, MODEL_ID), request); + } + + // request objects + public record ChatPromptRequest(List instances, ModelParameter parameters) {} + public record ChatPromptInstance(String context, List examples, List messages) {} + public record ChatExample(ChatContent input, ChatContent output) {} + public record ChatContent(String content) {} + + // response objects + public record PredictionResponse(List predictions, Metadata metadata) {} + public record Prediction(List candidates, List citationMetadata, List safetyAttributes) {} + public record Candidate(String content, String author) {} + + + @Builder + @Getter + public static class Example { + @PluginProperty(dynamic = true) + @NotNull + private String input; + + @PluginProperty(dynamic = true) + @NotNull + private String output; + } + + @Builder + @Getter + @AllArgsConstructor + public static class Message { + @PluginProperty(dynamic = true) + @NotNull + private String author; + + @PluginProperty(dynamic = true) + @NotNull + private String content; + } + + @Builder + @Getter + public static class Output implements io.kestra.core.models.tasks.Output { + @Schema(title = "List of text predictions made by the model") + private List predictions; + } +} diff --git a/src/main/java/io/kestra/plugin/gcp/vertexai/TextCompletion.java b/src/main/java/io/kestra/plugin/gcp/vertexai/TextCompletion.java new file mode 100644 index 00000000..6f0bd223 --- /dev/null +++ b/src/main/java/io/kestra/plugin/gcp/vertexai/TextCompletion.java @@ -0,0 +1,83 @@ +package io.kestra.plugin.gcp.vertexai; + +import io.kestra.core.exceptions.IllegalVariableEvaluationException; +import io.kestra.core.models.annotations.Example; +import io.kestra.core.models.annotations.Plugin; +import io.kestra.core.models.annotations.PluginProperty; +import io.kestra.core.models.tasks.RunnableTask; +import io.kestra.core.runners.RunContext; +import io.micronaut.http.HttpRequest; +import io.micronaut.http.MutableHttpRequest; +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.ToString; +import lombok.experimental.SuperBuilder; + +import java.util.List; + +@SuperBuilder +@ToString +@EqualsAndHashCode +@Getter +@NoArgsConstructor +@Schema( + title = "Text completion using the Vertex AI PaLM API for Google's PaLM 2 large language models (LLM)", + description = "See [Generative AI quickstart using the Vertex AI API](https://cloud.google.com/vertex-ai/docs/generative-ai/start/quickstarts/api-quickstart) for more information." +) +@Plugin( + examples = { + @Example( + title = "Text completion using the Vertex AI PaLM API", + code = { + """ + region: us-central1 + projectId: my-project + prompt: Please tell me a joke""" + } + ) + } +) +public class TextCompletion extends AbstractGenerativeAi implements RunnableTask { + private static final String MODEL_ID = "text-bison"; + + @PluginProperty(dynamic = true) + @Schema( + title = "Text input to generate model response", + description = "Prompts can include preamble, questions, suggestions, instructions, or examples." + ) + private String prompt; + + @Override + public Output run(RunContext runContext) throws Exception { + var response = call(runContext, PredictionResponse.class); + sendMetrics(runContext, response.metadata); + + return Output.builder() + .predictions(response.predictions) + .build(); + } + + @Override + protected MutableHttpRequest getPredictionRequest(RunContext runContext) throws IllegalVariableEvaluationException { + var request = new TextPromptRequest(List.of(new TextPromptInstance(runContext.render(prompt))), getParameters()); + return HttpRequest.POST(getPredictionURI(runContext, MODEL_ID), request); + } + + // request objects + public record TextPromptRequest(List instances, ModelParameter parameters) {} + public record TextPromptInstance(String prompt) {} + + // response objects + public record PredictionResponse(List predictions, Metadata metadata) {} + public record Prediction(SafetyAttributes safetyAttributes, CitationMetadata citationMetadata, String content) {} + + @Builder + @Getter + public static class Output implements io.kestra.core.models.tasks.Output { + @Schema(title = "List of text predictions made by the model") + private List predictions; + } +} diff --git a/src/main/resources/icons/io.kestra.plugin.gcp.auth.svg b/src/main/resources/icons/io.kestra.plugin.gcp.auth.svg index f725bf8f..38ed9e18 100644 --- a/src/main/resources/icons/io.kestra.plugin.gcp.auth.svg +++ b/src/main/resources/icons/io.kestra.plugin.gcp.auth.svg @@ -1 +1,23 @@ - \ No newline at end of file + + + + + + + + + + + + + + + + + + diff --git a/src/main/resources/icons/io.kestra.plugin.gcp.dataproc.svg b/src/main/resources/icons/io.kestra.plugin.gcp.dataproc.batches.svg similarity index 100% rename from src/main/resources/icons/io.kestra.plugin.gcp.dataproc.svg rename to src/main/resources/icons/io.kestra.plugin.gcp.dataproc.batches.svg diff --git a/src/main/resources/icons/io.kestra.plugin.gcp.vertexai.svg b/src/main/resources/icons/io.kestra.plugin.gcp.vertexai.svg index 20f9e1b9..76c8ee67 100644 --- a/src/main/resources/icons/io.kestra.plugin.gcp.vertexai.svg +++ b/src/main/resources/icons/io.kestra.plugin.gcp.vertexai.svg @@ -1,305 +1,25 @@ - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/test/java/io/kestra/plugin/gcp/bigquery/TriggerTest.java b/src/test/java/io/kestra/plugin/gcp/bigquery/TriggerTest.java index 602174e7..464dd034 100644 --- a/src/test/java/io/kestra/plugin/gcp/bigquery/TriggerTest.java +++ b/src/test/java/io/kestra/plugin/gcp/bigquery/TriggerTest.java @@ -76,10 +76,10 @@ void flow() throws Exception { // wait for execution executionQueue.receive(TriggerTest.class, execution -> { - last.set(execution); + last.set(execution.getLeft()); queueCount.countDown(); - assertThat(execution.getFlowId(), is("bigquery-listen")); + assertThat(execution.getLeft().getFlowId(), is("bigquery-listen")); }); diff --git a/src/test/java/io/kestra/plugin/gcp/gcs/TriggerTest.java b/src/test/java/io/kestra/plugin/gcp/gcs/TriggerTest.java index 07c21b5e..00a93709 100644 --- a/src/test/java/io/kestra/plugin/gcp/gcs/TriggerTest.java +++ b/src/test/java/io/kestra/plugin/gcp/gcs/TriggerTest.java @@ -82,10 +82,10 @@ void flow() throws Exception { // wait for execution executionQueue.receive(TriggerTest.class, execution -> { - last.set(execution); + last.set(execution.getLeft()); queueCount.countDown(); - assertThat(execution.getFlowId(), is("gcs-listen")); + assertThat(execution.getLeft().getFlowId(), is("gcs-listen")); }); diff --git a/src/test/java/io/kestra/plugin/gcp/pubsub/PublishThenConsumeTest.java b/src/test/java/io/kestra/plugin/gcp/pubsub/PublishThenConsumeTest.java index ea123ae6..de4cd4a7 100644 --- a/src/test/java/io/kestra/plugin/gcp/pubsub/PublishThenConsumeTest.java +++ b/src/test/java/io/kestra/plugin/gcp/pubsub/PublishThenConsumeTest.java @@ -6,6 +6,7 @@ import io.kestra.core.storages.StorageInterface; import io.kestra.core.utils.IdUtils; import io.kestra.plugin.gcp.pubsub.model.Message; +import io.kestra.plugin.gcp.pubsub.model.SerdeType; import io.micronaut.context.annotation.Value; import io.micronaut.test.extensions.junit5.annotation.MicronautTest; import jakarta.inject.Inject; @@ -16,7 +17,6 @@ import java.io.FileOutputStream; import java.io.OutputStream; import java.net.URI; -import java.util.Base64; import java.util.List; import java.util.Map; @@ -43,7 +43,7 @@ void runWithList() throws Exception { .topic("test-topic") .from( List.of( - Message.builder().data(Base64.getEncoder().encodeToString("Hello World".getBytes())).build(), + Message.builder().data("Hello World").build(), Message.builder().attributes(Map.of("key", "value")).build() ) ) @@ -63,6 +63,37 @@ void runWithList() throws Exception { assertThat(consumeOutput.getCount(), is(2)); } + @Test + void runWithJson() throws Exception { + var runContext = runContextFactory.of(); + + var publish = Publish.builder() + .projectId(project) + .topic("test-topic") + .serdeType(SerdeType.JSON) + .from( + List.of( + Message.builder().data(""" + {"hello": "world"}""").build() + ) + ) + .build(); + + var publishOutput = publish.run(runContext); + assertThat(publishOutput.getMessagesCount(), is(1)); + + var consume = Consume.builder() + .projectId(project) + .topic("test-topic") + .serdeType(SerdeType.JSON) + .subscription("test-subscription") + .maxRecords(1) + .build(); + + var consumeOutput = consume.run(runContextFactory.of()); + assertThat(consumeOutput.getCount(), is(1)); + } + @Test void runWithFile() throws Exception { var runContext = runContextFactory.of(); @@ -94,7 +125,7 @@ private URI createTestFile(RunContext runContext) throws Exception { OutputStream output = new FileOutputStream(tempFile); FileSerde.write(output, - Message.builder().data(Base64.getEncoder().encodeToString("Hello World".getBytes())).build()); + Message.builder().data("Hello World".getBytes()).build()); FileSerde.write(output, Message.builder().attributes(Map.of("key", "value")).build()); return storageInterface.put(URI.create("/" + IdUtils.create() + ".ion"), new FileInputStream(tempFile)); diff --git a/src/test/java/io/kestra/plugin/gcp/pubsub/TriggerTest.java b/src/test/java/io/kestra/plugin/gcp/pubsub/TriggerTest.java index 68ba6c10..141bdd91 100644 --- a/src/test/java/io/kestra/plugin/gcp/pubsub/TriggerTest.java +++ b/src/test/java/io/kestra/plugin/gcp/pubsub/TriggerTest.java @@ -20,7 +20,6 @@ import jakarta.inject.Named; import org.junit.jupiter.api.Test; -import java.util.Base64; import java.util.List; import java.util.Map; import java.util.Objects; @@ -75,10 +74,10 @@ void flow() throws Exception { // wait for execution executionQueue.receive(TriggerTest.class, execution -> { - last.set(execution); + last.set(execution.getLeft()); queueCount.countDown(); - assertThat(execution.getFlowId(), is("pubsub-listen")); + assertThat(execution.getLeft().getFlowId(), is("pubsub-listen")); }); @@ -95,7 +94,7 @@ void flow() throws Exception { .projectId(this.project) .from( List.of( - Message.builder().data(Base64.getEncoder().encodeToString("Hello World".getBytes())).build(), + Message.builder().data("Hello World".getBytes()).build(), Message.builder().attributes(Map.of("key", "value")).build() ) ) diff --git a/src/test/resources/application.yml b/src/test/resources/application.yml index cde5d0f2..66e5e009 100644 --- a/src/test/resources/application.yml +++ b/src/test/resources/application.yml @@ -22,6 +22,9 @@ kestra: project: "kestra-unit-test" dataproc: project: "kestra-unit-test" + vertexai: + project: "kestra-unit-test" + region: "us-central1" variables: globals: bucket: kestra-unit-test