From 95085ffc9f6bedb9a58b385b8b227ac11717096e Mon Sep 17 00:00:00 2001 From: Junwei Dai Date: Mon, 6 Jan 2025 13:08:22 -0800 Subject: [PATCH] Add synchronous execution option to workflow provisioning Signed-off-by: Junwei Dai --- .../flowframework/common/CommonValue.java | 6 + .../rest/RestCreateWorkflowAction.java | 6 +- .../rest/RestProvisionWorkflowAction.java | 5 +- .../CreateWorkflowTransportAction.java | 16 +- .../ProvisionWorkflowTransportAction.java | 127 +++++++++++++- .../transport/WorkflowRequest.java | 48 +++++- .../transport/WorkflowResponse.java | 39 ++++- .../rest/RestCreateWorkflowActionTests.java | 35 ++++ .../RestProvisionWorkflowActionTests.java | 22 +++ .../CreateWorkflowTransportActionTests.java | 156 ++++++++++++++++-- .../WorkflowRequestResponseTests.java | 40 ++++- 11 files changed, 476 insertions(+), 24 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 9c88788b3..2fe46996b 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -8,6 +8,8 @@ */ package org.opensearch.flowframework.common; +import org.opensearch.common.unit.TimeValue; + /** * Representation of common values that are used across project */ @@ -55,6 +57,8 @@ private CommonValue() {} /** The last provisioned time field */ public static final String LAST_PROVISIONED_TIME_FIELD = "last_provisioned_time"; + public static final TimeValue DEFAULT_WAIT_FOR_COMPLETION_TIMEOUT = TimeValue.timeValueSeconds(1); + /* * Constants associated with Rest or Transport actions */ @@ -74,6 +78,8 @@ private CommonValue() {} public static final String PROVISION_WORKFLOW = "provision"; /** The param name for update workflow field in create API */ public static final String UPDATE_WORKFLOW_FIELDS = "update_fields"; + /** The param name for specifying the timeout duration in seconds to wait for workflow completion */ + public static final String WAIT_FOR_COMPLETION_TIMEOUT = "wait_for_completion_timeout"; /** The field name for workflow steps. This field represents the name of the workflow steps to be fetched. */ public static final String WORKFLOW_STEP = "workflow_step"; /** The param name for default use case, used by the create workflow API */ diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index 4abedc365..1e0ad5088 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.client.node.NodeClient; +import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; @@ -43,6 +44,7 @@ import static org.opensearch.flowframework.common.CommonValue.UPDATE_WORKFLOW_FIELDS; import static org.opensearch.flowframework.common.CommonValue.USE_CASE; import static org.opensearch.flowframework.common.CommonValue.VALIDATION; +import static org.opensearch.flowframework.common.CommonValue.WAIT_FOR_COMPLETION_TIMEOUT; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; @@ -87,6 +89,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli boolean provision = request.paramAsBoolean(PROVISION_WORKFLOW, false); boolean reprovision = request.paramAsBoolean(REPROVISION_WORKFLOW, false); boolean updateFields = request.paramAsBoolean(UPDATE_WORKFLOW_FIELDS, false); + TimeValue waitForCompletionTimeout = request.paramAsTime(WAIT_FOR_COMPLETION_TIMEOUT, null); String useCase = request.param(USE_CASE); // If provisioning, consume all other params and pass to provision transport action @@ -226,7 +229,8 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli validation, provision || updateFields, params, - reprovision + reprovision, + waitForCompletionTimeout ); return channel -> client.execute(CreateWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { diff --git a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java index 6ae56905c..502bf9423 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.client.node.NodeClient; +import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; @@ -33,6 +34,7 @@ import java.util.stream.Collectors; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.flowframework.common.CommonValue.WAIT_FOR_COMPLETION_TIMEOUT; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; @@ -73,6 +75,7 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { String workflowId = request.param(WORKFLOW_ID); + TimeValue waitForCompletionTimeout = request.paramAsTime(WAIT_FOR_COMPLETION_TIMEOUT, null); try { Map params = parseParamsAndContent(request); if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) { @@ -86,7 +89,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); } // Create request and provision - WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null, params); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null, params, waitForCompletionTimeout); return channel -> client.execute(ProvisionWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index 813613a32..eda7e42f3 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -251,7 +251,8 @@ private void createExecute(WorkflowRequest request, User user, ActionListener { - listener.onResponse(new WorkflowResponse(provisionResponse.getWorkflowId())); + if (request.getWaitForCompletionTimeout() != null) { + listener.onResponse( + new WorkflowResponse( + provisionResponse.getWorkflowId(), + provisionResponse.getWorkflowState() + ) + ); + } else { + listener.onResponse( + new WorkflowResponse(provisionResponse.getWorkflowId()) + ); + } }, exception -> { String errorMessage = "Provisioning failed."; logger.error(errorMessage, exception); diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index 45f374161..841c76cf5 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -45,6 +45,8 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; import static org.opensearch.flowframework.common.CommonValue.ERROR_FIELD; @@ -210,14 +212,27 @@ private void executeProvisionRequest( ), ActionListener.wrap(updateResponse -> { logger.info("updated workflow {} state to {}", request.getWorkflowId(), State.PROVISIONING); - executeWorkflowAsync(workflowId, provisionProcessSequence, listener); + if (request.getWaitForCompletionTimeout() != null) { + executeWorkflowSync( + workflowId, + provisionProcessSequence, + listener, + request.getWaitForCompletionTimeout().getMillis() + ); + } else { + executeWorkflowAsync(workflowId, provisionProcessSequence, listener); + } // update last provisioned field in template Template newTemplate = Template.builder(template).lastProvisionedTime(Instant.now()).build(); flowFrameworkIndicesHandler.updateTemplateInGlobalContext( request.getWorkflowId(), newTemplate, ActionListener.wrap(templateResponse -> { - listener.onResponse(new WorkflowResponse(request.getWorkflowId())); + if (request.getWaitForCompletionTimeout() != null) { + logger.info("Waiting for workflow completion"); + } else { + listener.onResponse(new WorkflowResponse(request.getWorkflowId())); + } }, exception -> { String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( "Failed to update use case template {}", @@ -275,18 +290,105 @@ private void executeProvisionRequest( */ private void executeWorkflowAsync(String workflowId, List workflowSequence, ActionListener listener) { try { - threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL).execute(() -> { executeWorkflow(workflowSequence, workflowId); }); + threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL) + .execute(() -> { executeWorkflow(workflowSequence, workflowId, listener, false); }); } catch (Exception exception) { listener.onFailure(new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(exception))); } } + /** + * Retrieves a thread from the provision thread pool to execute a workflow with a timeout mechanism. + * If the execution exceeds the specified timeout, it will return the current status of the workflow. + * + * @param workflowId The id of the workflow + * @param workflowSequence The sorted workflow to execute + * @param listener ActionListener for any failures or responses + * @param timeout The timeout duration in milliseconds + */ + private void executeWorkflowSync( + String workflowId, + List workflowSequence, + ActionListener listener, + long timeout + ) { + PlainActionFuture workflowFuture = new PlainActionFuture<>(); + AtomicBoolean isResponseSent = new AtomicBoolean(false); + CompletableFuture.runAsync(() -> { + try { + executeWorkflow(workflowSequence, workflowId, new ActionListener<>() { + @Override + public void onResponse(WorkflowResponse workflowResponse) { + if (isResponseSent.get()) { + logger.info("Ignoring onResponse for workflowId: {} as timeout already occurred", workflowId); + return; + } + isResponseSent.set(true); + workflowFuture.onResponse(null); + listener.onResponse(new WorkflowResponse(workflowResponse.getWorkflowId(), workflowResponse.getWorkflowState())); + } + + @Override + public void onFailure(Exception e) { + if (isResponseSent.get()) { + logger.info("Ignoring onFailure for workflowId: {} as timeout already occurred", workflowId); + return; + } + isResponseSent.set(true); + workflowFuture.onFailure( + new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(e)) + ); + listener.onFailure( + new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(e)) + ); + } + }, true); + } catch (Exception ex) { + if (!isResponseSent.get()) { + isResponseSent.set(true); + workflowFuture.onFailure( + new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(ex)) + ); + listener.onFailure(new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(ex))); + } + } + }, threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL)); + + threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL).execute(() -> { + try { + Thread.sleep(timeout); + if (isResponseSent.compareAndSet(false, true)) { + logger.warn("Workflow execution timed out for workflowId: {}", workflowId); + client.execute( + GetWorkflowStateAction.INSTANCE, + new GetWorkflowStateRequest(workflowId, false), + ActionListener.wrap( + response -> listener.onResponse(new WorkflowResponse(workflowId, response.getWorkflowState())), + exception -> listener.onFailure( + new FlowFrameworkException("Failed to get workflow state after timeout", ExceptionsHelper.status(exception)) + ) + ) + ); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + } + /** * Executes the given workflow sequence * @param workflowSequence The topologically sorted workflow to execute * @param workflowId The workflowId associated with the workflow that is executing + * @param listener The ActionListener to handle the workflow response or failure + * @param isSyncExecution Flag indicating whether the workflow should be executed synchronously (true) or asynchronously (false) */ - private void executeWorkflow(List workflowSequence, String workflowId) { + private void executeWorkflow( + List workflowSequence, + String workflowId, + ActionListener listener, + boolean isSyncExecution + ) { String currentStepId = ""; try { Map> workflowFutureMap = new LinkedHashMap<>(); @@ -324,6 +426,23 @@ private void executeWorkflow(List workflowSequence, String workflow ), ActionListener.wrap(updateResponse -> { logger.info("updated workflow {} state to {}", workflowId, State.COMPLETED); + if (isSyncExecution) { + client.execute( + GetWorkflowStateAction.INSTANCE, + new GetWorkflowStateRequest(workflowId, false), + ActionListener.wrap(response -> { + listener.onResponse(new WorkflowResponse(workflowId, response.getWorkflowState())); + }, exception -> { + String errorMessage = "Failed to get workflow state."; + logger.error(errorMessage, exception); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + } + }) + ); + } }, exception -> { logger.error("Failed to update workflow state for workflow {}", workflowId, exception); }) ); } catch (Exception ex) { diff --git a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java index 97f032e31..9c480be65 100644 --- a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java +++ b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java @@ -11,6 +11,7 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.Nullable; +import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.flowframework.model.Template; @@ -62,13 +63,20 @@ public class WorkflowRequest extends ActionRequest { */ private Map params; + /** + * The timeout duration to wait for workflow completion. + * If null, the request will respond immediately with the workflowId. + */ + @Nullable + private TimeValue waitForCompletionTimeout; + /** * Instantiates a new WorkflowRequest, set validation to all, no provisioning * @param workflowId the documentId of the workflow * @param template the use case template which describes the workflow */ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) { - this(workflowId, template, new String[] { "all" }, false, Collections.emptyMap(), false); + this(workflowId, template, new String[] { "all" }, false, Collections.emptyMap(), false, null); } /** @@ -78,7 +86,27 @@ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) * @param params The parameters from the REST path */ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, Map params) { - this(workflowId, template, new String[] { "all" }, true, params, false); + this(workflowId, template, new String[] { "all" }, true, params, false, null); + } + + /** + * Instantiates a new WorkflowRequest with a specified wait-for-completion timeout. + * This constructor allows the caller to specify a custom timeout for the workflow execution, + * which determines how long the system will wait for the workflow to complete before returning a response. + * By default, the validation is set to "all", and provisioning is set to true. + * @param workflowId The unique document ID of the workflow. Can be null for new workflows. + * @param template The use case template that defines the structure and logic of the workflow. Can be null if not provided. + * @param params A map of parameters extracted from the REST request path, used to customize the workflow execution. + * @param waitForCompletionTimeout The maximum duration to wait for the workflow execution to complete. + * If the workflow does not complete within this timeout, the request will return a timeout response. + */ + public WorkflowRequest( + @Nullable String workflowId, + @Nullable Template template, + Map params, + TimeValue waitForCompletionTimeout + ) { + this(workflowId, template, new String[] { "all" }, true, params, false, waitForCompletionTimeout); } /** @@ -89,6 +117,7 @@ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, * @param provisionOrUpdate provision or updateFields flag. Only one may be true, the presence of update_fields key in map indicates if updating fields, otherwise true means it's provisioning. * @param params map of REST path params. If provisionOrUpdate is false, must be an empty map. If update_fields key is present, must be only key. * @param reprovision flag to indicate if request is to reprovision + * @param waitForCompletionTimeout the timeout duration (in milliseconds) to wait for workflow completion */ public WorkflowRequest( @Nullable String workflowId, @@ -96,7 +125,8 @@ public WorkflowRequest( String[] validation, boolean provisionOrUpdate, Map params, - boolean reprovision + boolean reprovision, + TimeValue waitForCompletionTimeout ) { this.workflowId = workflowId; this.template = template; @@ -108,6 +138,7 @@ public WorkflowRequest( } this.params = this.updateFields ? Collections.emptyMap() : params; this.reprovision = reprovision; + this.waitForCompletionTimeout = waitForCompletionTimeout; } /** @@ -133,6 +164,7 @@ public WorkflowRequest(StreamInput in) throws IOException { this.params = Collections.emptyMap(); } this.reprovision = !provision && Boolean.parseBoolean(params.get(REPROVISION_WORKFLOW)); + this.waitForCompletionTimeout = in.readOptionalTimeValue(); } /** @@ -193,6 +225,15 @@ public boolean isReprovision() { return this.reprovision; } + /** + * Gets the timeout duration (in milliseconds) to wait for workflow completion. + * @return the timeout duration, or null if the request should return immediately + */ + @Nullable + public TimeValue getWaitForCompletionTimeout() { + return this.waitForCompletionTimeout; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); @@ -207,6 +248,7 @@ public void writeTo(StreamOutput out) throws IOException { } else if (reprovision) { out.writeMap(Map.of(REPROVISION_WORKFLOW, "true"), StreamOutput::writeString, StreamOutput::writeString); } + out.writeOptionalTimeValue(waitForCompletionTimeout); } @Override diff --git a/src/main/java/org/opensearch/flowframework/transport/WorkflowResponse.java b/src/main/java/org/opensearch/flowframework/transport/WorkflowResponse.java index 20a7700a3..8a9f21d93 100644 --- a/src/main/java/org/opensearch/flowframework/transport/WorkflowResponse.java +++ b/src/main/java/org/opensearch/flowframework/transport/WorkflowResponse.java @@ -13,6 +13,7 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.model.WorkflowState; import java.io.IOException; @@ -27,6 +28,8 @@ public class WorkflowResponse extends ActionResponse implements ToXContentObject * The documentId of the workflow entry within the Global Context index */ private String workflowId; + /** The workflow state */ + private WorkflowState workflowState; /** * Instantiates a new WorkflowResponse from params @@ -44,6 +47,8 @@ public WorkflowResponse(String workflowId) { public WorkflowResponse(StreamInput in) throws IOException { super(in); this.workflowId = in.readString(); + this.workflowState = in.readOptionalWriteable(WorkflowState::new); + } /** @@ -54,14 +59,46 @@ public String getWorkflowId() { return this.workflowId; } + /** + * Gets the workflowState of this repsonse + * @return the workflowState + */ + public WorkflowState getWorkflowState() { + return this.workflowState; + } + + /** + * Constructs a new WorkflowResponse object with the specified workflowId and workflowState. + * The WorkflowResponse is typically returned as part of a `wait_for_completion` request, + * indicating the final state of a workflow after execution. + * @param workflowId The unique identifier for the workflow. + * @param workflowState The current state of the workflow, including status, errors (if any), + * and resources created as part of the workflow execution. + */ + public WorkflowResponse(String workflowId, WorkflowState workflowState) { + this.workflowId = workflowId; + this.workflowState = WorkflowState.builder() + .workflowId(workflowId) + .error(workflowState.getError()) + .state(workflowState.getState()) + .resourcesCreated(workflowState.resourcesCreated()) + .build(); + + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(workflowId); + out.writeOptionalWriteable(workflowState); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return builder.startObject().field(WORKFLOW_ID, this.workflowId).endObject(); + if (workflowState != null) { + return workflowState.toXContent(builder, params); + } else { + return builder.startObject().field(WORKFLOW_ID, this.workflowId).endObject(); + } } } diff --git a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java index f6b1a5fc7..fd31cb823 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java @@ -128,6 +128,41 @@ public void testCreateWorkflowRequestWithParamsAndProvision() throws Exception { assertTrue(channel.capturedResponse().content().utf8ToString().contains("id-123")); } + public void testRestCreateWorkflow_withWaitForCompletionTimeout() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withParams(Map.of("wait_for_completion_timeout", "5s")) + .withContent(new BytesArray(validTemplate), MediaTypeRegistry.JSON) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(new WorkflowResponse("workflow_1")); + return null; + }).when(nodeClient).execute(any(), any(WorkflowRequest.class), any()); + + createWorkflowRestAction.handleRequest(request, channel, nodeClient); + + assertEquals(RestStatus.CREATED, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("workflow_1")); + } + + public void testInvalidValueForCompletionTimeout() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withParams(Map.of("wait_for_completion_timeout", "invalid_value")) + .withContent(new BytesArray(validTemplate), MediaTypeRegistry.JSON) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> { + createWorkflowRestAction.handleRequest(request, channel, nodeClient); + }); + + assertTrue(exception.getMessage().contains("failed to parse setting [wait_for_completion_timeout] with value [invalid_value]")); + } + public void testCreateWorkflowRequestWithParamsButNoProvision() throws Exception { RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) .withPath(this.createWorkflowPath) diff --git a/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java index fd5cd478d..625e48e34 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java @@ -144,4 +144,26 @@ public void testFeatureFlagNotEnabled() throws Exception { assertEquals(RestStatus.FORBIDDEN, channel.capturedResponse().status()); assertTrue(channel.capturedResponse().content().utf8ToString().contains("This API is disabled.")); } + + public void testProvisionWorkflowWithValidWaitForCompletionTimeout() throws Exception { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.provisionWorkflowPath) + .withParams(Map.of("workflow_id", "abc", "wait_for_completion_timeout", "5s")) + .withContent(new BytesArray("{\"foo\": \"bar\"}"), MediaTypeRegistry.JSON) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(new WorkflowResponse("workflow_1")); + return null; + }).when(nodeClient).execute(any(), any(WorkflowRequest.class), any()); + + provisionWorkflowRestAction.handleRequest(request, channel, nodeClient); + + assertEquals(RestStatus.OK, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("workflow_1")); + } + } diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index ba76bc833..86a950175 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -38,6 +38,7 @@ import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.flowframework.model.WorkflowState; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.plugins.PluginsService; import org.opensearch.search.SearchHit; @@ -48,6 +49,7 @@ import org.opensearch.transport.TransportService; import java.io.IOException; +import java.time.Instant; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -252,7 +254,7 @@ public void testMaxWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap(), false); + WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap(), false,null); doAnswer(invocation -> { ActionListener searchListener = invocation.getArgument(1); @@ -289,7 +291,15 @@ public void onFailure(Exception e) { public void testFailedToCreateNewWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap(), false); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + template, + new String[] { "off" }, + false, + Collections.emptyMap(), + false, + null + ); // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> { @@ -320,7 +330,15 @@ public void testFailedToCreateNewWorkflow() { public void testCreateNewWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap(), false); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + template, + new String[] { "off" }, + false, + Collections.emptyMap(), + false, + null + ); // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> { @@ -384,7 +402,15 @@ public void testCreateWithUserAndFilterOn() { ); ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap(), false); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + template, + new String[] { "off" }, + false, + Collections.emptyMap(), + false, + null + ); // Bypass checkMaxWorkflows and force onResponse doAnswer(invocation -> { @@ -448,7 +474,15 @@ public void testFailedToCreateNewWorkflowWithNullUser() { ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap(), false); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + template, + new String[] { "off" }, + false, + Collections.emptyMap(), + false, + null + ); createWorkflowTransportAction1.doExecute(mock(Task.class), workflowRequest, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); @@ -483,7 +517,15 @@ public void testFailedToCreateNewWorkflowWithNoBackendRoleUser() { ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap(), false); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + template, + new String[] { "off" }, + false, + Collections.emptyMap(), + false, + null + ); createWorkflowTransportAction1.doExecute(mock(Task.class), workflowRequest, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); @@ -497,7 +539,15 @@ public void testFailedToCreateNewWorkflowWithNoBackendRoleUser() { public void testUpdateWorkflowWithReprovision() throws IOException { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest("1", template, new String[] { "off" }, false, Collections.emptyMap(), true); + WorkflowRequest workflowRequest = new WorkflowRequest( + "1", + template, + new String[] { "off" }, + false, + Collections.emptyMap(), + true, + null + ); doAnswer(invocation -> { ActionListener getListener = invocation.getArgument(1); @@ -541,7 +591,15 @@ public void testUpdateWorkflowWithReprovision() throws IOException { public void testFailedToUpdateWorkflowWithReprovision() throws IOException { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest workflowRequest = new WorkflowRequest("1", template, new String[] { "off" }, false, Collections.emptyMap(), true); + WorkflowRequest workflowRequest = new WorkflowRequest( + "1", + template, + new String[] { "off" }, + false, + Collections.emptyMap(), + true, + null + ); doAnswer(invocation -> { ActionListener getListener = invocation.getArgument(1); @@ -841,7 +899,8 @@ public void testCreateWorkflow_withValidation_withProvision_Success() throws Exc new String[] { "all" }, true, Collections.emptyMap(), - false + false, + null ); // Bypass checkMaxWorkflows and force onResponse @@ -888,6 +947,82 @@ public void testCreateWorkflow_withValidation_withProvision_Success() throws Exc assertEquals("1", workflowResponseCaptor.getValue().getWorkflowId()); } + public void testCreateWorkflow_withValidation_withWaitForCompletion_withProvision_Success() throws Exception { + + Template validTemplate = generateValidTemplate(); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + doNothing().when(workflowProcessSorter).validate(any(), any()); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + validTemplate, + new String[] { "all" }, + true, + Collections.emptyMap(), + false, + TimeValue.timeValueSeconds(5) + ); + + // Bypass checkMaxWorkflows and force onResponse + doAnswer(invocation -> { + ActionListener checkMaxWorkflowListener = invocation.getArgument(2); + checkMaxWorkflowListener.onResponse(true); + return null; + }).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), any()); + + // Bypass initializeConfigIndex and force onResponse + doAnswer(invocation -> { + ActionListener initalizeMasterKeyIndexListener = invocation.getArgument(0); + initalizeMasterKeyIndexListener.onResponse(true); + return null; + }).when(flowFrameworkIndicesHandler).initializeConfigIndex(any()); + + // Bypass putTemplateToGlobalContext and force onResponse + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onResponse(new IndexResponse(new ShardId(GLOBAL_CONTEXT_INDEX, "", 1), "1", 1L, 1L, 1L, true)); + return null; + }).when(flowFrameworkIndicesHandler).putTemplateToGlobalContext(any(), any()); + + // Bypass putInitialStateToWorkflowState and force on response + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(new IndexResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "1", 1L, 1L, 1L, true)); + return null; + }).when(flowFrameworkIndicesHandler).putInitialStateToWorkflowState(any(), any(), any()); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + WorkflowResponse response = mock(WorkflowResponse.class); + when(response.getWorkflowId()).thenReturn("1"); + when(response.getWorkflowState()).thenReturn( + new WorkflowState( + "1", + "test", + "PROVISIONING", + "IN_PROGRESS", + Instant.now(), + Instant.now(), + TestHelpers.randomUser(), + Collections.emptyMap(), + Collections.emptyList() + ) + ); + responseListener.onResponse(response); + return null; + }).when(client).execute(eq(ProvisionWorkflowAction.INSTANCE), any(WorkflowRequest.class), any(ActionListener.class)); + + ArgumentCaptor workflowResponseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); + + createWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + + verify(listener, times(1)).onResponse(workflowResponseCaptor.capture()); + assertEquals("1", workflowResponseCaptor.getValue().getWorkflowId()); + assertEquals("PROVISIONING", workflowResponseCaptor.getValue().getWorkflowState().getState()); + } + public void testCreateWorkflow_withValidation_withProvision_FailedProvisioning() throws Exception { Template validTemplate = generateValidTemplate(); @@ -901,7 +1036,8 @@ public void testCreateWorkflow_withValidation_withProvision_FailedProvisioning() new String[] { "all" }, true, Collections.emptyMap(), - false + false, + null ); // Bypass checkMaxWorkflows and force onResponse diff --git a/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java b/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java index e92255e0f..50c60a19e 100644 --- a/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java @@ -21,9 +21,11 @@ import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.flowframework.model.WorkflowState; import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; +import java.time.Instant; import java.util.Collections; import java.util.List; import java.util.Map; @@ -156,7 +158,7 @@ public void testWorkflowRequestWithParams() throws IOException { public void testWorkflowRequestWithParamsNoProvision() throws IOException { IllegalArgumentException ex = assertThrows( IllegalArgumentException.class, - () -> new WorkflowRequest("123", template, new String[] { "all" }, false, Map.of("foo", "bar"), false) + () -> new WorkflowRequest("123", template, new String[] { "all" }, false, Map.of("foo", "bar"), false, null) ); assertEquals("Params may only be included when provisioning.", ex.getMessage()); } @@ -168,7 +170,8 @@ public void testWorkflowRequestWithOnlyUpdateParamNoProvision() throws IOExcepti new String[] { "all" }, true, Map.of(UPDATE_WORKFLOW_FIELDS, "true"), - false + false, + null ); assertNotNull(workflowRequest.getWorkflowId()); assertEquals(template, workflowRequest.getTemplate()); @@ -208,4 +211,37 @@ public void testWorkflowResponse() throws IOException { assertEquals("{\"workflow_id\":\"123\"}", builder.toString()); } + public void testWorkflowResponseWithWaitForCompletionTimeOut() throws IOException { + WorkflowState workFlowState = new WorkflowState( + "123", + "test", + "PROVISIONING", + "IN_PROGRESS", + Instant.now(), + Instant.now(), + TestHelpers.randomUser(), + Collections.emptyMap(), + Collections.emptyList() + ); + + WorkflowResponse response = new WorkflowResponse("123", workFlowState); + assertEquals("123", response.getWorkflowId()); + assertEquals("PROVISIONING", response.getWorkflowState().getState()); + + BytesStreamOutput out = new BytesStreamOutput(); + response.writeTo(out); + BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes())); + WorkflowResponse streamInputResponse = new WorkflowResponse(in); + + assertEquals(response.getWorkflowId(), streamInputResponse.getWorkflowId()); + assertEquals(response.getWorkflowState().getState(), streamInputResponse.getWorkflowState().getState()); + + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + + assertNotNull(builder); + assertTrue(builder.toString().contains("\"workflow_id\":\"123\"")); + assertTrue(builder.toString().contains("\"state\":\"PROVISIONING\"")); + } + }