diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index 08386b797e..74c7f4387d 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -83,6 +83,7 @@ import org.opensearch.sql.spark.cluster.ClusterManagerEventListener; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadataServiceImpl; +import org.opensearch.sql.spark.flint.FlintIndexStateModelService; import org.opensearch.sql.spark.rest.RestAsyncQueryManagementAction; import org.opensearch.sql.spark.storage.SparkStorageFactory; import org.opensearch.sql.spark.transport.TransportCancelAsyncQueryRequestAction; @@ -229,7 +230,8 @@ public Collection<Object> createComponents( dataSourceService, injector.getInstance(FlintIndexMetadataServiceImpl.class), injector.getInstance(StateStore.class), - injector.getInstance(EMRServerlessClientFactory.class)); + injector.getInstance(EMRServerlessClientFactory.class), + injector.getInstance(FlintIndexStateModelService.class)); return ImmutableList.of( dataSourceService, injector.getInstance(AsyncQueryExecutorService.class), diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java index cef3b6ede2..439e2a602e 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java @@ -7,7 +7,7 @@ package org.opensearch.sql.spark.asyncquery; -import static org.opensearch.sql.spark.execution.statestore.StateStore.createJobMetaData; +import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; import java.util.Optional; import lombok.RequiredArgsConstructor; @@ -17,6 +17,7 @@ import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.execution.xcontent.AsyncQueryJobMetadataXContentSerializer; /** Opensearch implementation of {@link AsyncQueryJobMetadataStorageService} */ @RequiredArgsConstructor @@ -31,15 +32,22 @@ public class OpensearchAsyncQueryJobMetadataStorageService @Override public void storeJobMetadata(AsyncQueryJobMetadata asyncQueryJobMetadata) { AsyncQueryId queryId = asyncQueryJobMetadata.getQueryId(); - createJobMetaData(stateStore, queryId.getDataSourceName()).apply(asyncQueryJobMetadata); + stateStore.create( + asyncQueryJobMetadata, + AsyncQueryJobMetadata::copy, + DATASOURCE_TO_REQUEST_INDEX.apply(queryId.getDataSourceName())); } @Override public Optional<AsyncQueryJobMetadata> getJobMetadata(String qid) { try { AsyncQueryId queryId = new AsyncQueryId(qid); - return StateStore.getJobMetaData(stateStore, queryId.getDataSourceName()) - .apply(queryId.docId()); + AsyncQueryJobMetadataXContentSerializer asyncQueryJobMetadataXContentSerializer = + new AsyncQueryJobMetadataXContentSerializer(); + return stateStore.get( + queryId.docId(), + asyncQueryJobMetadataXContentSerializer::fromXContent, + DATASOURCE_TO_REQUEST_INDEX.apply(queryId.getDataSourceName())); } catch (Exception e) { LOGGER.error("Error while fetching the job metadata.", e); throw new AsyncQueryNotFoundException(String.format("Invalid QueryId: %s", qid)); diff --git a/spark/src/main/java/org/opensearch/sql/spark/cluster/ClusterManagerEventListener.java b/spark/src/main/java/org/opensearch/sql/spark/cluster/ClusterManagerEventListener.java index f04c6cb830..b7ecc8bbfd 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/cluster/ClusterManagerEventListener.java +++ b/spark/src/main/java/org/opensearch/sql/spark/cluster/ClusterManagerEventListener.java @@ -24,6 +24,7 @@ import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadataService; +import org.opensearch.sql.spark.flint.FlintIndexStateModelService; import org.opensearch.threadpool.Scheduler.Cancellable; import org.opensearch.threadpool.ThreadPool; @@ -39,6 +40,8 @@ public class ClusterManagerEventListener implements LocalNodeClusterManagerListe private FlintIndexMetadataService flintIndexMetadataService; private StateStore stateStore; private EMRServerlessClientFactory emrServerlessClientFactory; + + private FlintIndexStateModelService flintIndexStateModelService; private Duration sessionTtlDuration; private Duration resultTtlDuration; private TimeValue streamingJobHouseKeepingInterval; @@ -57,7 +60,8 @@ public ClusterManagerEventListener( DataSourceService dataSourceService, FlintIndexMetadataService flintIndexMetadataService, StateStore stateStore, - EMRServerlessClientFactory emrServerlessClientFactory) { + EMRServerlessClientFactory emrServerlessClientFactory, + FlintIndexStateModelService flintIndexStateModelService) { this.clusterService = clusterService; this.threadPool = threadPool; this.client = client; @@ -70,7 +74,7 @@ public ClusterManagerEventListener( this.sessionTtlDuration = toDuration(sessionTtl.get(settings)); this.resultTtlDuration = toDuration(resultTtl.get(settings)); this.streamingJobHouseKeepingInterval = streamingJobHouseKeepingInterval.get(settings); - + this.flintIndexStateModelService = flintIndexStateModelService; clusterService .getClusterSettings() .addSettingsUpdateConsumer( @@ -153,7 +157,7 @@ private void initializeStreamingJobHouseKeeperCron() { new FlintStreamingJobHouseKeeperTask( dataSourceService, flintIndexMetadataService, - stateStore, + flintIndexStateModelService, emrServerlessClientFactory), streamingJobHouseKeepingInterval, executorName()); diff --git a/spark/src/main/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTask.java b/spark/src/main/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTask.java index 27221f1b72..7cd0a869da 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTask.java +++ b/spark/src/main/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTask.java @@ -19,9 +19,9 @@ import org.opensearch.sql.legacy.metrics.Metrics; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexMetadataService; +import org.opensearch.sql.spark.flint.FlintIndexStateModelService; import org.opensearch.sql.spark.flint.operation.FlintIndexOpAlter; import org.opensearch.sql.spark.flint.operation.FlintIndexOpDrop; @@ -31,7 +31,7 @@ public class FlintStreamingJobHouseKeeperTask implements Runnable { private final DataSourceService dataSourceService; private final FlintIndexMetadataService flintIndexMetadataService; - private final StateStore stateStore; + private final FlintIndexStateModelService flintIndexStateModelService; private final EMRServerlessClientFactory emrServerlessClientFactory; private static final Logger LOGGER = LogManager.getLogger(FlintStreamingJobHouseKeeperTask.class); @@ -96,7 +96,8 @@ private void dropAutoRefreshIndex( // When the datasource is deleted. Possibly Replace with VACUUM Operation. LOGGER.info("Attempting to drop auto refresh index: {}", autoRefreshIndex); FlintIndexOpDrop flintIndexOpDrop = - new FlintIndexOpDrop(stateStore, datasourceName, emrServerlessClientFactory.getClient()); + new FlintIndexOpDrop( + flintIndexStateModelService, datasourceName, emrServerlessClientFactory.getClient()); flintIndexOpDrop.apply(flintIndexMetadata); LOGGER.info("Successfully dropped index: {}", autoRefreshIndex); } @@ -109,7 +110,7 @@ private void alterAutoRefreshIndex( FlintIndexOpAlter flintIndexOpAlter = new FlintIndexOpAlter( flintIndexOptions, - stateStore, + flintIndexStateModelService, datasourceName, emrServerlessClientFactory.getClient(), flintIndexMetadataService); diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java index 412db50e85..233e2d14c6 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java @@ -7,7 +7,6 @@ import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; -import static org.opensearch.sql.spark.execution.statestore.StateStore.createIndexDMLResult; import com.amazonaws.services.emrserverless.model.JobRunState; import java.util.Map; @@ -27,9 +26,10 @@ import org.opensearch.sql.spark.dispatcher.model.IndexDMLResult; import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails; import org.opensearch.sql.spark.execution.statement.StatementState; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexMetadataService; +import org.opensearch.sql.spark.flint.FlintIndexStateModelService; +import org.opensearch.sql.spark.flint.IndexDMLResultStorageService; import org.opensearch.sql.spark.flint.operation.FlintIndexOp; import org.opensearch.sql.spark.flint.operation.FlintIndexOpAlter; import org.opensearch.sql.spark.flint.operation.FlintIndexOpDrop; @@ -51,7 +51,8 @@ public class IndexDMLHandler extends AsyncQueryHandler { private final FlintIndexMetadataService flintIndexMetadataService; - private final StateStore stateStore; + private final FlintIndexStateModelService flintIndexStateModelService; + private final IndexDMLResultStorageService indexDMLResultStorageService; private final Client client; @@ -106,7 +107,7 @@ private AsyncQueryId storeIndexDMLResult( dispatchQueryRequest.getDatasource(), System.currentTimeMillis() - startTime, System.currentTimeMillis()); - createIndexDMLResult(stateStore, dataSourceMetadata.getResultIndex()).apply(indexDMLResult); + indexDMLResultStorageService.createIndexDMLResult(indexDMLResult, dataSourceMetadata.getName()); return asyncQueryId; } @@ -118,14 +119,16 @@ private void executeIndexOp( case DROP: FlintIndexOp dropOp = new FlintIndexOpDrop( - stateStore, dispatchQueryRequest.getDatasource(), emrServerlessClient); + flintIndexStateModelService, + dispatchQueryRequest.getDatasource(), + emrServerlessClient); dropOp.apply(indexMetadata); break; case ALTER: FlintIndexOpAlter flintIndexOpAlter = new FlintIndexOpAlter( indexQueryDetails.getFlintIndexOptions(), - stateStore, + flintIndexStateModelService, dispatchQueryRequest.getDatasource(), emrServerlessClient, flintIndexMetadataService); @@ -133,7 +136,8 @@ private void executeIndexOp( break; case VACUUM: FlintIndexOp indexVacuumOp = - new FlintIndexOpVacuum(stateStore, dispatchQueryRequest.getDatasource(), client); + new FlintIndexOpVacuum( + flintIndexStateModelService, dispatchQueryRequest.getDatasource(), client); indexVacuumOp.apply(indexMetadata); break; default: diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java index d55408f62e..db06460361 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java @@ -13,9 +13,9 @@ import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; import org.opensearch.sql.spark.dispatcher.model.JobType; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexMetadataService; +import org.opensearch.sql.spark.flint.FlintIndexStateModelService; import org.opensearch.sql.spark.flint.operation.FlintIndexOp; import org.opensearch.sql.spark.flint.operation.FlintIndexOpCancel; import org.opensearch.sql.spark.leasemanager.LeaseManager; @@ -25,18 +25,18 @@ public class RefreshQueryHandler extends BatchQueryHandler { private final FlintIndexMetadataService flintIndexMetadataService; - private final StateStore stateStore; + private final FlintIndexStateModelService flintIndexStateModelService; private final EMRServerlessClient emrServerlessClient; public RefreshQueryHandler( EMRServerlessClient emrServerlessClient, JobExecutionResponseReader jobExecutionResponseReader, FlintIndexMetadataService flintIndexMetadataService, - StateStore stateStore, + FlintIndexStateModelService flintIndexStateModelService, LeaseManager leaseManager) { super(emrServerlessClient, jobExecutionResponseReader, leaseManager); this.flintIndexMetadataService = flintIndexMetadataService; - this.stateStore = stateStore; + this.flintIndexStateModelService = flintIndexStateModelService; this.emrServerlessClient = emrServerlessClient; } @@ -52,7 +52,7 @@ public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { } FlintIndexMetadata indexMetadata = indexMetadataMap.get(asyncQueryJobMetadata.getIndexName()); FlintIndexOp jobCancelOp = - new FlintIndexOpCancel(stateStore, datasourceName, emrServerlessClient); + new FlintIndexOpCancel(flintIndexStateModelService, datasourceName, emrServerlessClient); jobCancelOp.apply(indexMetadata); return asyncQueryJobMetadata.getQueryId().getId(); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index c4f4c74868..d3d0e9ec94 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -23,8 +23,9 @@ import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails; import org.opensearch.sql.spark.dispatcher.model.JobType; import org.opensearch.sql.spark.execution.session.SessionManager; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadataService; +import org.opensearch.sql.spark.flint.FlintIndexStateModelService; +import org.opensearch.sql.spark.flint.IndexDMLResultStorageService; import org.opensearch.sql.spark.leasemanager.LeaseManager; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.model.LangType; @@ -53,7 +54,9 @@ public class SparkQueryDispatcher { private LeaseManager leaseManager; - private StateStore stateStore; + private FlintIndexStateModelService flintIndexStateModelService; + + private IndexDMLResultStorageService indexDMLResultStorageService; public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) { EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient(); @@ -91,7 +94,7 @@ public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) emrServerlessClient, jobExecutionResponseReader, flintIndexMetadataService, - stateStore, + flintIndexStateModelService, leaseManager); } } @@ -145,7 +148,7 @@ public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { emrServerlessClient, jobExecutionResponseReader, flintIndexMetadataService, - stateStore, + flintIndexStateModelService, leaseManager); } else if (asyncQueryJobMetadata.getJobType() == JobType.STREAMING) { queryHandler = @@ -162,7 +165,8 @@ private IndexDMLHandler createIndexDMLHandler(EMRServerlessClient emrServerlessC emrServerlessClient, jobExecutionResponseReader, flintIndexMetadataService, - stateStore, + flintIndexStateModelService, + indexDMLResultStorageService, client); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java index 2363615a7d..82d4b2b3a4 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java @@ -10,8 +10,6 @@ import static org.opensearch.sql.spark.execution.session.SessionState.END_STATE; import static org.opensearch.sql.spark.execution.session.SessionState.FAIL; import static org.opensearch.sql.spark.execution.statement.StatementId.newStatementId; -import static org.opensearch.sql.spark.execution.statestore.StateStore.createSession; -import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession; import java.util.Optional; import lombok.Builder; @@ -24,6 +22,8 @@ import org.opensearch.sql.spark.execution.statement.QueryRequest; import org.opensearch.sql.spark.execution.statement.Statement; import org.opensearch.sql.spark.execution.statement.StatementId; +import org.opensearch.sql.spark.execution.statement.StatementStorageService; +import org.opensearch.sql.spark.execution.statestore.SessionStorageService; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.rest.model.LangType; import org.opensearch.sql.spark.utils.TimeProvider; @@ -42,6 +42,8 @@ public class InteractiveSession implements Session { private final SessionId sessionId; private final StateStore stateStore; + private final StatementStorageService statementStorageService; + private final SessionStorageService sessionStorageService; private final EMRServerlessClient serverlessClient; private SessionModel sessionModel; // the threshold of elapsed time in milliseconds before we say a session is stale @@ -64,7 +66,7 @@ public void open(CreateSessionRequest createSessionRequest) { sessionModel = initInteractiveSession( applicationId, jobID, sessionId, createSessionRequest.getDatasourceName()); - createSession(stateStore, sessionModel.getDatasourceName()).apply(sessionModel); + sessionStorageService.createSession(sessionModel, sessionModel.getDatasourceName()); } catch (VersionConflictEngineException e) { String errorMsg = "session already exist. " + sessionId; LOG.error(errorMsg); @@ -76,7 +78,7 @@ public void open(CreateSessionRequest createSessionRequest) { @Override public void close() { Optional<SessionModel> model = - getSession(stateStore, sessionModel.getDatasourceName()).apply(sessionModel.getId()); + sessionStorageService.getSession(sessionModel.getId(), sessionModel.getDatasourceName()); if (model.isEmpty()) { throw new IllegalStateException("session does not exist. " + sessionModel.getSessionId()); } else { @@ -88,7 +90,7 @@ public void close() { /** Submit statement. If submit successfully, Statement in waiting state. */ public StatementId submit(QueryRequest request) { Optional<SessionModel> model = - getSession(stateStore, sessionModel.getDatasourceName()).apply(sessionModel.getId()); + sessionStorageService.getSession(sessionModel.getId(), sessionModel.getDatasourceName()); if (model.isEmpty()) { throw new IllegalStateException("session does not exist. " + sessionModel.getSessionId()); } else { @@ -102,6 +104,7 @@ public StatementId submit(QueryRequest request) { .applicationId(sessionModel.getApplicationId()) .jobId(sessionModel.getJobId()) .stateStore(stateStore) + .statementStorageService(statementStorageService) .statementId(statementId) .langType(LangType.SQL) .datasourceName(sessionModel.getDatasourceName()) @@ -124,8 +127,8 @@ public StatementId submit(QueryRequest request) { @Override public Optional<Statement> get(StatementId stID) { - return StateStore.getStatement(stateStore, sessionModel.getDatasourceName()) - .apply(stID.getId()) + return statementStorageService + .getStatementModel(stID.getId(), sessionModel.getDatasourceName()) .map( model -> Statement.builder() @@ -137,6 +140,7 @@ public Optional<Statement> get(StatementId stID) { .query(model.getQuery()) .queryId(model.getQueryId()) .stateStore(stateStore) + .statementStorageService(statementStorageService) .statementModel(model) .build()); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java index e441492c20..1babd8712d 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java @@ -11,6 +11,8 @@ import java.util.Optional; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; +import org.opensearch.sql.spark.execution.statement.StatementStorageService; +import org.opensearch.sql.spark.execution.statestore.SessionStorageService; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.utils.RealTimeProvider; @@ -21,14 +23,21 @@ */ public class SessionManager { private final StateStore stateStore; + private final StatementStorageService statementStorageService; + + private final SessionStorageService sessionStorageService; private final EMRServerlessClientFactory emrServerlessClientFactory; private Settings settings; public SessionManager( StateStore stateStore, + StatementStorageService statementStorageService, + SessionStorageService sessionStorageService, EMRServerlessClientFactory emrServerlessClientFactory, Settings settings) { this.stateStore = stateStore; + this.statementStorageService = statementStorageService; + this.sessionStorageService = sessionStorageService; this.emrServerlessClientFactory = emrServerlessClientFactory; this.settings = settings; } @@ -38,6 +47,8 @@ public Session createSession(CreateSessionRequest request) { InteractiveSession.builder() .sessionId(newSessionId(request.getDatasourceName())) .stateStore(stateStore) + .statementStorageService(statementStorageService) + .sessionStorageService(sessionStorageService) .serverlessClient(emrServerlessClientFactory.getClient()) .build(); session.open(request); @@ -64,12 +75,14 @@ public Session createSession(CreateSessionRequest request) { */ public Optional<Session> getSession(SessionId sid, String dataSourceName) { Optional<SessionModel> model = - StateStore.getSession(stateStore, dataSourceName).apply(sid.getSessionId()); + sessionStorageService.getSession(sid.getSessionId(), dataSourceName); if (model.isPresent()) { InteractiveSession session = InteractiveSession.builder() .sessionId(sid) .stateStore(stateStore) + .statementStorageService(statementStorageService) + .sessionStorageService(sessionStorageService) .serverlessClient(emrServerlessClientFactory.getClient()) .sessionModel(model.get()) .sessionInactivityTimeoutMilli( diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/OpenSearchStatementStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/OpenSearchStatementStorageService.java new file mode 100644 index 0000000000..95bc6b1c56 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/OpenSearchStatementStorageService.java @@ -0,0 +1,37 @@ +package org.opensearch.sql.spark.execution.statement; + +import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; + +import java.util.Optional; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.execution.xcontent.StatementModelXContentSerializer; + +@RequiredArgsConstructor +public class OpenSearchStatementStorageService implements StatementStorageService { + + private final StateStore stateStore; + + @Override + public StatementModel createStatement(StatementModel statementModel, String datasourceName) { + return stateStore.create( + statementModel, StatementModel::copy, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + } + + @Override + public StatementModel updateStatementState( + StatementModel oldStatementModel, StatementState statementState, String datasourceName) { + return stateStore.updateState( + oldStatementModel, + statementState, + StatementModel::copyWithState, + DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + } + + @Override + public Optional<StatementModel> getStatementModel(String id, String datasourceName) { + StatementModelXContentSerializer serializer = new StatementModelXContentSerializer(); + return stateStore.get( + id, serializer::fromXContent, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java index 94c1f79511..e42af72c7a 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java @@ -6,9 +6,6 @@ package org.opensearch.sql.spark.execution.statement; import static org.opensearch.sql.spark.execution.statement.StatementModel.submitStatement; -import static org.opensearch.sql.spark.execution.statestore.StateStore.createStatement; -import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement; -import static org.opensearch.sql.spark.execution.statestore.StateStore.updateStatementState; import lombok.Builder; import lombok.Getter; @@ -36,6 +33,7 @@ public class Statement { private final String query; private final String queryId; private final StateStore stateStore; + private final StatementStorageService statementStorageService; @Setter private StatementModel statementModel; @@ -52,7 +50,7 @@ public void open() { datasourceName, query, queryId); - statementModel = createStatement(stateStore, datasourceName).apply(statementModel); + statementModel = statementStorageService.createStatement(statementModel, datasourceName); } catch (VersionConflictEngineException e) { String errorMsg = "statement already exist. " + statementId; LOG.error(errorMsg); @@ -76,8 +74,8 @@ public void cancel() { } try { this.statementModel = - updateStatementState(stateStore, statementModel.getDatasourceName()) - .apply(this.statementModel, StatementState.CANCELLED); + statementStorageService.updateStatementState( + this.statementModel, StatementState.CANCELLED, statementModel.getDatasourceName()); } catch (DocumentMissingException e) { String errorMsg = String.format("cancel statement failed. no statement found. statement: %s.", statementId); @@ -85,8 +83,8 @@ public void cancel() { throw new IllegalStateException(errorMsg); } catch (VersionConflictEngineException e) { this.statementModel = - getStatement(stateStore, statementModel.getDatasourceName()) - .apply(statementModel.getId()) + statementStorageService + .getStatementModel(statementId.getId(), statementModel.getDatasourceName()) .orElse(this.statementModel); String errorMsg = String.format( diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementStorageService.java new file mode 100644 index 0000000000..02fe911cc2 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementStorageService.java @@ -0,0 +1,13 @@ +package org.opensearch.sql.spark.execution.statement; + +import java.util.Optional; + +public interface StatementStorageService { + + StatementModel createStatement(StatementModel statementModel, String datasourceName); + + StatementModel updateStatementState( + StatementModel oldStatementModel, StatementState statementState, String datasourceName); + + Optional<StatementModel> getStatementModel(String id, String datasourceName); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/OpensearchSessionStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/OpensearchSessionStorageService.java new file mode 100644 index 0000000000..1d67a35582 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/OpensearchSessionStorageService.java @@ -0,0 +1,38 @@ +package org.opensearch.sql.spark.execution.statestore; + +import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; + +import java.util.Optional; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.spark.execution.session.SessionModel; +import org.opensearch.sql.spark.execution.session.SessionState; +import org.opensearch.sql.spark.execution.xcontent.SessionModelXContentSerializer; + +@RequiredArgsConstructor +public class OpensearchSessionStorageService implements SessionStorageService { + + private final StateStore stateStore; + + @Override + public SessionModel createSession(SessionModel sessionModel, String datasourceName) { + return stateStore.create( + sessionModel, SessionModel::of, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + } + + @Override + public Optional<SessionModel> getSession(String id, String datasourceName) { + SessionModelXContentSerializer serializer = new SessionModelXContentSerializer(); + return stateStore.get( + id, serializer::fromXContent, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + } + + @Override + public SessionModel updateSessionState( + SessionModel sessionModel, SessionState sessionState, String datasourceName) { + return stateStore.updateState( + sessionModel, + sessionState, + SessionModel::copyWithState, + DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStorageService.java new file mode 100644 index 0000000000..93dd4148dc --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStorageService.java @@ -0,0 +1,15 @@ +package org.opensearch.sql.spark.execution.statestore; + +import java.util.Optional; +import org.opensearch.sql.spark.execution.session.SessionModel; +import org.opensearch.sql.spark.execution.session.SessionState; + +public interface SessionStorageService { + + SessionModel createSession(SessionModel sessionModel, String datasourceName); + + Optional<SessionModel> getSession(String id, String datasourceName); + + SessionModel updateSessionState( + SessionModel sessionModel, SessionState sessionState, String datasourceName); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java index bad44905ae..65b5e7c96e 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java @@ -14,7 +14,6 @@ import java.nio.charset.StandardCharsets; import java.util.Locale; import java.util.Optional; -import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; import lombok.RequiredArgsConstructor; @@ -256,77 +255,6 @@ private String loadConfigFromResource(String fileName) throws IOException { return IOUtils.toString(fileStream, StandardCharsets.UTF_8); } - /** Helper Functions */ - public static Function<StatementModel, StatementModel> createStatement( - StateStore stateStore, String datasourceName) { - return (st) -> - stateStore.create( - st, StatementModel::copy, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); - } - - public static Function<String, Optional<StatementModel>> getStatement( - StateStore stateStore, String datasourceName) { - StatementModelXContentSerializer serializer = new StatementModelXContentSerializer(); - return (docId) -> - stateStore.get( - docId, serializer::fromXContent, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); - } - - public static BiFunction<StatementModel, StatementState, StatementModel> updateStatementState( - StateStore stateStore, String datasourceName) { - return (old, state) -> - stateStore.updateState( - old, - state, - StatementModel::copyWithState, - DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); - } - - public static Function<SessionModel, SessionModel> createSession( - StateStore stateStore, String datasourceName) { - return (session) -> - stateStore.create( - session, SessionModel::of, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); - } - - public static Function<String, Optional<SessionModel>> getSession( - StateStore stateStore, String datasourceName) { - SessionModelXContentSerializer serializer = new SessionModelXContentSerializer(); - return (docId) -> - stateStore.get( - docId, serializer::fromXContent, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); - } - - public static BiFunction<SessionModel, SessionState, SessionModel> updateSessionState( - StateStore stateStore, String datasourceName) { - return (old, state) -> - stateStore.updateState( - old, - state, - SessionModel::copyWithState, - DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); - } - - public static Function<AsyncQueryJobMetadata, AsyncQueryJobMetadata> createJobMetaData( - StateStore stateStore, String datasourceName) { - return (jobMetadata) -> - stateStore.create( - jobMetadata, - AsyncQueryJobMetadata::copy, - DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); - } - - public static Function<String, Optional<AsyncQueryJobMetadata>> getJobMetaData( - StateStore stateStore, String datasourceName) { - AsyncQueryJobMetadataXContentSerializer asyncQueryJobMetadataXContentSerializer = - new AsyncQueryJobMetadataXContentSerializer(); - return (docId) -> - stateStore.get( - docId, - asyncQueryJobMetadataXContentSerializer::fromXContent, - DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); - } - public static Supplier<Long> activeSessionsCount(StateStore stateStore, String datasourceName) { return () -> stateStore.count( @@ -341,47 +269,6 @@ public static Supplier<Long> activeSessionsCount(StateStore stateStore, String d SessionModel.SESSION_STATE, SessionState.RUNNING.getSessionState()))); } - public static BiFunction<FlintIndexStateModel, FlintIndexState, FlintIndexStateModel> - updateFlintIndexState(StateStore stateStore, String datasourceName) { - return (old, state) -> - stateStore.updateState( - old, - state, - FlintIndexStateModel::copyWithState, - DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); - } - - public static Function<String, Optional<FlintIndexStateModel>> getFlintIndexState( - StateStore stateStore, String datasourceName) { - return (docId) -> - stateStore.get( - docId, - FlintIndexStateModel::fromXContent, - DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); - } - - public static Function<FlintIndexStateModel, FlintIndexStateModel> createFlintIndexState( - StateStore stateStore, String datasourceName) { - return (st) -> - stateStore.create( - st, FlintIndexStateModel::copy, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); - } - - /** - * @param stateStore index state store - * @param datasourceName data source name - * @return function that accepts index state doc ID and perform the deletion - */ - public static Function<String, Boolean> deleteFlintIndexState( - StateStore stateStore, String datasourceName) { - return (docId) -> stateStore.delete(docId, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); - } - - public static Function<IndexDMLResult, IndexDMLResult> createIndexDMLResult( - StateStore stateStore, String indexName) { - return (result) -> stateStore.create(result, IndexDMLResult::copy, indexName); - } - public static Supplier<Long> activeRefreshJobCount(StateStore stateStore, String datasourceName) { return () -> stateStore.count( diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexStateModelService.java b/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexStateModelService.java new file mode 100644 index 0000000000..9c3dc8220f --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexStateModelService.java @@ -0,0 +1,17 @@ +package org.opensearch.sql.spark.flint; + +import java.util.Optional; + +public interface FlintIndexStateModelService { + FlintIndexStateModel updateFlintIndexState( + FlintIndexStateModel flintIndexStateModel, + FlintIndexState flintIndexState, + String datasourceName); + + Optional<FlintIndexStateModel> getFlintIndexStateModel(String id, String datasourceName); + + FlintIndexStateModel createFlintIndexStateModel( + FlintIndexStateModel flintIndexStateModel, String datasourceName); + + boolean deleteFlintIndexStateModel(String id, String datasourceName); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/IndexDMLResultStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/flint/IndexDMLResultStorageService.java new file mode 100644 index 0000000000..37b53dbdbb --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/IndexDMLResultStorageService.java @@ -0,0 +1,8 @@ +package org.opensearch.sql.spark.flint; + +import org.opensearch.sql.spark.dispatcher.model.IndexDMLResult; + +public interface IndexDMLResultStorageService { + + IndexDMLResult createIndexDMLResult(IndexDMLResult result, String datasourceName); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelService.java b/spark/src/main/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelService.java new file mode 100644 index 0000000000..31436e5512 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/OpenSearchFlintIndexStateModelService.java @@ -0,0 +1,45 @@ +package org.opensearch.sql.spark.flint; + +import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; + +import java.util.Optional; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.spark.execution.statestore.StateStore; + +@RequiredArgsConstructor +public class OpenSearchFlintIndexStateModelService implements FlintIndexStateModelService { + + private final StateStore stateStore; + + @Override + public FlintIndexStateModel updateFlintIndexState( + FlintIndexStateModel flintIndexStateModel, + FlintIndexState flintIndexState, + String datasourceName) { + return stateStore.updateState( + flintIndexStateModel, + flintIndexState, + FlintIndexStateModel::copyWithState, + DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + } + + @Override + public Optional<FlintIndexStateModel> getFlintIndexStateModel(String id, String datasourceName) { + return stateStore.get( + id, FlintIndexStateModel::fromXContent, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + } + + @Override + public FlintIndexStateModel createFlintIndexStateModel( + FlintIndexStateModel flintIndexStateModel, String datasourceName) { + return stateStore.create( + flintIndexStateModel, + FlintIndexStateModel::copy, + DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + } + + @Override + public boolean deleteFlintIndexStateModel(String id, String datasourceName) { + return stateStore.delete(id, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName)); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/OpenSearchIndexDMLResultStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/flint/OpenSearchIndexDMLResultStorageService.java new file mode 100644 index 0000000000..34437d8609 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/OpenSearchIndexDMLResultStorageService.java @@ -0,0 +1,20 @@ +package org.opensearch.sql.spark.flint; + +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.spark.dispatcher.model.IndexDMLResult; +import org.opensearch.sql.spark.execution.statestore.StateStore; + +@RequiredArgsConstructor +public class OpenSearchIndexDMLResultStorageService implements IndexDMLResultStorageService { + + private final DataSourceService dataSourceService; + private final StateStore stateStore; + + @Override + public IndexDMLResult createIndexDMLResult(IndexDMLResult result, String datasourceName) { + DataSourceMetadata dataSourceMetadata = dataSourceService.getDataSourceMetadata(datasourceName); + return stateStore.create(result, IndexDMLResult::copy, dataSourceMetadata.getResultIndex()); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java index 8d5e301631..c8cfe2ac23 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java @@ -6,9 +6,6 @@ package org.opensearch.sql.spark.flint.operation; import static org.opensearch.sql.spark.client.EmrServerlessClientImpl.GENERIC_INTERNAL_SERVER_ERROR_MESSAGE; -import static org.opensearch.sql.spark.execution.statestore.StateStore.deleteFlintIndexState; -import static org.opensearch.sql.spark.execution.statestore.StateStore.getFlintIndexState; -import static org.opensearch.sql.spark.execution.statestore.StateStore.updateFlintIndexState; import com.amazonaws.services.emrserverless.model.ValidationException; import java.util.Locale; @@ -21,17 +18,17 @@ import org.jetbrains.annotations.NotNull; import org.opensearch.index.seqno.SequenceNumbers; import org.opensearch.sql.spark.client.EMRServerlessClient; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexState; import org.opensearch.sql.spark.flint.FlintIndexStateModel; +import org.opensearch.sql.spark.flint.FlintIndexStateModelService; /** Flint Index Operation. */ @RequiredArgsConstructor public abstract class FlintIndexOp { private static final Logger LOG = LogManager.getLogger(); - private final StateStore stateStore; + private final FlintIndexStateModelService flintIndexStateModelService; private final String datasourceName; /** Apply operation on {@link FlintIndexMetadata} */ @@ -55,8 +52,10 @@ public void apply(FlintIndexMetadata metadata) { } catch (Throwable e) { LOG.error("Rolling back transient log due to transaction operation failure", e); try { - updateFlintIndexState(stateStore, datasourceName) - .apply(transitionedFlintIndexStateModel, initialFlintIndexStateModel.getIndexState()); + flintIndexStateModelService.updateFlintIndexState( + transitionedFlintIndexStateModel, + initialFlintIndexStateModel.getIndexState(), + datasourceName); } catch (Exception ex) { LOG.error("Failed to rollback transient log", ex); } @@ -68,7 +67,7 @@ public void apply(FlintIndexMetadata metadata) { @NotNull private FlintIndexStateModel getFlintIndexStateModel(String latestId) { Optional<FlintIndexStateModel> flintIndexOptional = - getFlintIndexState(stateStore, datasourceName).apply(latestId); + flintIndexStateModelService.getFlintIndexStateModel(latestId, datasourceName); if (flintIndexOptional.isEmpty()) { String errorMsg = String.format(Locale.ROOT, "no state found. docId: %s", latestId); LOG.error(errorMsg); @@ -109,7 +108,8 @@ private FlintIndexStateModel moveToTransitioningState(FlintIndexStateModel flint FlintIndexState transitioningState = transitioningState(); try { flintIndex = - updateFlintIndexState(stateStore, datasourceName).apply(flintIndex, transitioningState()); + flintIndexStateModelService.updateFlintIndexState( + flintIndex, transitioningState(), datasourceName); } catch (Exception e) { String errorMsg = String.format(Locale.ROOT, "Moving to transition state:%s failed.", transitioningState); @@ -125,9 +125,10 @@ private void commit(FlintIndexStateModel flintIndex) { try { if (stableState == FlintIndexState.NONE) { LOG.info("Deleting index state with docId: " + flintIndex.getLatestId()); - deleteFlintIndexState(stateStore, datasourceName).apply(flintIndex.getLatestId()); + flintIndexStateModelService.deleteFlintIndexStateModel( + flintIndex.getLatestId(), datasourceName); } else { - updateFlintIndexState(stateStore, datasourceName).apply(flintIndex, stableState); + flintIndexStateModelService.updateFlintIndexState(flintIndex, stableState, datasourceName); } } catch (Exception e) { String errorMsg = diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpAlter.java b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpAlter.java index 7db4f6a4c6..04a5ca1bf7 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpAlter.java +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpAlter.java @@ -10,11 +10,11 @@ import org.apache.logging.log4j.Logger; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexMetadataService; import org.opensearch.sql.spark.flint.FlintIndexState; import org.opensearch.sql.spark.flint.FlintIndexStateModel; +import org.opensearch.sql.spark.flint.FlintIndexStateModelService; /** * Index Operation for Altering the flint index. Only handles alter operation when @@ -28,11 +28,11 @@ public class FlintIndexOpAlter extends FlintIndexOp { public FlintIndexOpAlter( FlintIndexOptions flintIndexOptions, - StateStore stateStore, + FlintIndexStateModelService flintIndexStateModelService, String datasourceName, EMRServerlessClient emrServerlessClient, FlintIndexMetadataService flintIndexMetadataService) { - super(stateStore, datasourceName); + super(flintIndexStateModelService, datasourceName); this.emrServerlessClient = emrServerlessClient; this.flintIndexMetadataService = flintIndexMetadataService; this.flintIndexOptions = flintIndexOptions; diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpCancel.java b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpCancel.java index 2317c5b6dc..f99ab26fbf 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpCancel.java +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpCancel.java @@ -9,10 +9,10 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.sql.spark.client.EMRServerlessClient; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexState; import org.opensearch.sql.spark.flint.FlintIndexStateModel; +import org.opensearch.sql.spark.flint.FlintIndexStateModelService; /** Cancel refreshing job for refresh query when user clicks cancel button on UI. */ public class FlintIndexOpCancel extends FlintIndexOp { @@ -21,8 +21,10 @@ public class FlintIndexOpCancel extends FlintIndexOp { private final EMRServerlessClient emrServerlessClient; public FlintIndexOpCancel( - StateStore stateStore, String datasourceName, EMRServerlessClient emrServerlessClient) { - super(stateStore, datasourceName); + FlintIndexStateModelService flintIndexStateModelService, + String datasourceName, + EMRServerlessClient emrServerlessClient) { + super(flintIndexStateModelService, datasourceName); this.emrServerlessClient = emrServerlessClient; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpDrop.java b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpDrop.java index 586c346863..0aa2a3a9ac 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpDrop.java +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpDrop.java @@ -9,10 +9,10 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.sql.spark.client.EMRServerlessClient; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexState; import org.opensearch.sql.spark.flint.FlintIndexStateModel; +import org.opensearch.sql.spark.flint.FlintIndexStateModelService; public class FlintIndexOpDrop extends FlintIndexOp { private static final Logger LOG = LogManager.getLogger(); @@ -20,8 +20,10 @@ public class FlintIndexOpDrop extends FlintIndexOp { private final EMRServerlessClient emrServerlessClient; public FlintIndexOpDrop( - StateStore stateStore, String datasourceName, EMRServerlessClient emrServerlessClient) { - super(stateStore, datasourceName); + FlintIndexStateModelService flintIndexStateModelService, + String datasourceName, + EMRServerlessClient emrServerlessClient) { + super(flintIndexStateModelService, datasourceName); this.emrServerlessClient = emrServerlessClient; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java index cf204450e7..6e0d386664 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java +++ b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpVacuum.java @@ -10,10 +10,10 @@ import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; import org.opensearch.action.support.master.AcknowledgedResponse; import org.opensearch.client.Client; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexState; import org.opensearch.sql.spark.flint.FlintIndexStateModel; +import org.opensearch.sql.spark.flint.FlintIndexStateModelService; /** Flint index vacuum operation. */ public class FlintIndexOpVacuum extends FlintIndexOp { @@ -23,8 +23,11 @@ public class FlintIndexOpVacuum extends FlintIndexOp { /** OpenSearch client. */ private final Client client; - public FlintIndexOpVacuum(StateStore stateStore, String datasourceName, Client client) { - super(stateStore, datasourceName); + public FlintIndexOpVacuum( + FlintIndexStateModelService flintIndexStateModelService, + String datasourceName, + Client client) { + super(flintIndexStateModelService, datasourceName); this.client = client; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java b/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java index 9038870c63..e0e9283d4c 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java +++ b/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java @@ -27,8 +27,16 @@ import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl; import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher; import org.opensearch.sql.spark.execution.session.SessionManager; +import org.opensearch.sql.spark.execution.statement.OpenSearchStatementStorageService; +import org.opensearch.sql.spark.execution.statement.StatementStorageService; +import org.opensearch.sql.spark.execution.statestore.OpensearchSessionStorageService; +import org.opensearch.sql.spark.execution.statestore.SessionStorageService; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadataServiceImpl; +import org.opensearch.sql.spark.flint.FlintIndexStateModelService; +import org.opensearch.sql.spark.flint.IndexDMLResultStorageService; +import org.opensearch.sql.spark.flint.OpenSearchFlintIndexStateModelService; +import org.opensearch.sql.spark.flint.OpenSearchIndexDMLResultStorageService; import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager; import org.opensearch.sql.spark.response.JobExecutionResponseReader; @@ -72,7 +80,8 @@ public SparkQueryDispatcher sparkQueryDispatcher( NodeClient client, SessionManager sessionManager, DefaultLeaseManager defaultLeaseManager, - StateStore stateStore) { + FlintIndexStateModelService flintIndexStateModelService, + IndexDMLResultStorageService indexDMLResultStorageService) { return new SparkQueryDispatcher( emrServerlessClientFactory, dataSourceService, @@ -81,15 +90,44 @@ public SparkQueryDispatcher sparkQueryDispatcher( client, sessionManager, defaultLeaseManager, - stateStore); + flintIndexStateModelService, + indexDMLResultStorageService); } @Provides public SessionManager sessionManager( StateStore stateStore, + StatementStorageService statementStorageService, + SessionStorageService sessionStorageService, EMRServerlessClientFactory emrServerlessClientFactory, Settings settings) { - return new SessionManager(stateStore, emrServerlessClientFactory, settings); + return new SessionManager( + stateStore, + statementStorageService, + sessionStorageService, + emrServerlessClientFactory, + settings); + } + + @Provides + public StatementStorageService statementStorageService(StateStore stateStore) { + return new OpenSearchStatementStorageService(stateStore); + } + + @Provides + public SessionStorageService sessionStorageService(StateStore stateStore) { + return new OpensearchSessionStorageService(stateStore); + } + + @Provides + public FlintIndexStateModelService flintIndexStateModelService(StateStore stateStore) { + return new OpenSearchFlintIndexStateModelService(stateStore); + } + + @Provides + public IndexDMLResultStorageService indexDMLResultStorageService( + StateStore stateStore, DataSourceService dataSourceService) { + return new OpenSearchIndexDMLResultStorageService(dataSourceService, stateStore); } @Provides diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java index f2d3bb1aa8..5ac981fcf2 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -13,8 +13,6 @@ import static org.opensearch.sql.spark.execution.session.SessionModel.SESSION_DOC_TYPE; import static org.opensearch.sql.spark.execution.statement.StatementModel.SESSION_ID; import static org.opensearch.sql.spark.execution.statement.StatementModel.STATEMENT_DOC_TYPE; -import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement; -import static org.opensearch.sql.spark.execution.statestore.StateStore.updateStatementState; import com.google.common.collect.ImmutableMap; import java.util.HashMap; @@ -144,7 +142,7 @@ public void withSessionCreateAsyncQueryThenGetResultThenCancel() { new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null)); assertNotNull(response.getSessionId()); Optional<StatementModel> statementModel = - getStatement(stateStore, MYS3_DATASOURCE).apply(response.getQueryId()); + statementStorageService.getStatementModel(response.getQueryId(), MYS3_DATASOURCE); assertTrue(statementModel.isPresent()); assertEquals(StatementState.WAITING, statementModel.get().getStatementState()); @@ -199,13 +197,13 @@ public void reuseSessionWhenCreateAsyncQuery() { .must(QueryBuilders.termQuery(SESSION_ID, first.getSessionId())))); Optional<StatementModel> firstModel = - getStatement(stateStore, MYS3_DATASOURCE).apply(first.getQueryId()); + statementStorageService.getStatementModel(first.getQueryId(), MYS3_DATASOURCE); assertTrue(firstModel.isPresent()); assertEquals(StatementState.WAITING, firstModel.get().getStatementState()); assertEquals(first.getQueryId(), firstModel.get().getStatementId().getId()); assertEquals(first.getQueryId(), firstModel.get().getQueryId()); Optional<StatementModel> secondModel = - getStatement(stateStore, MYS3_DATASOURCE).apply(second.getQueryId()); + statementStorageService.getStatementModel(second.getQueryId(), MYS3_DATASOURCE); assertEquals(StatementState.WAITING, secondModel.get().getStatementState()); assertEquals(second.getQueryId(), secondModel.get().getStatementId().getId()); assertEquals(second.getQueryId(), secondModel.get().getQueryId()); @@ -295,7 +293,7 @@ public void withSessionCreateAsyncQueryFailed() { new CreateAsyncQueryRequest("myselect 1", MYS3_DATASOURCE, LangType.SQL, null)); assertNotNull(response.getSessionId()); Optional<StatementModel> statementModel = - getStatement(stateStore, MYS3_DATASOURCE).apply(response.getQueryId()); + statementStorageService.getStatementModel(response.getQueryId(), MYS3_DATASOURCE); assertTrue(statementModel.isPresent()); assertEquals(StatementState.WAITING, statementModel.get().getStatementState()); @@ -319,7 +317,7 @@ public void withSessionCreateAsyncQueryFailed() { .seqNo(submitted.getSeqNo()) .primaryTerm(submitted.getPrimaryTerm()) .build(); - updateStatementState(stateStore, MYS3_DATASOURCE).apply(mocked, StatementState.FAILED); + statementStorageService.updateStatementState(mocked, StatementState.FAILED, MYS3_DATASOURCE); AsyncQueryExecutionResponse asyncQueryResults = asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java index c4cb96391b..af608b423a 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java @@ -9,8 +9,6 @@ import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.SPARK_EXECUTION_REFRESH_JOB_LIMIT_SETTING; import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.SPARK_EXECUTION_SESSION_LIMIT_SETTING; import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; -import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession; -import static org.opensearch.sql.spark.execution.statestore.StateStore.updateSessionState; import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; @@ -62,9 +60,17 @@ import org.opensearch.sql.spark.execution.session.SessionManager; import org.opensearch.sql.spark.execution.session.SessionModel; import org.opensearch.sql.spark.execution.session.SessionState; +import org.opensearch.sql.spark.execution.statement.OpenSearchStatementStorageService; +import org.opensearch.sql.spark.execution.statement.StatementStorageService; +import org.opensearch.sql.spark.execution.statestore.OpensearchSessionStorageService; +import org.opensearch.sql.spark.execution.statestore.SessionStorageService; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadataServiceImpl; +import org.opensearch.sql.spark.flint.FlintIndexStateModelService; import org.opensearch.sql.spark.flint.FlintIndexType; +import org.opensearch.sql.spark.flint.IndexDMLResultStorageService; +import org.opensearch.sql.spark.flint.OpenSearchFlintIndexStateModelService; +import org.opensearch.sql.spark.flint.OpenSearchIndexDMLResultStorageService; import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.storage.DataSourceFactory; @@ -79,6 +85,11 @@ public class AsyncQueryExecutorServiceSpec extends OpenSearchIntegTestCase { protected NodeClient client; protected DataSourceServiceImpl dataSourceService; protected StateStore stateStore; + protected StatementStorageService statementStorageService; + protected SessionStorageService sessionStorageService; + protected FlintIndexStateModelService flintIndexStateModelService; + + protected IndexDMLResultStorageService indexDMLResultStorageService; protected ClusterSettings clusterSettings; @Override @@ -145,6 +156,11 @@ public void setup() { .build(); dataSourceService.createDataSource(otherDm); stateStore = new StateStore(client, clusterService); + statementStorageService = new OpenSearchStatementStorageService(stateStore); + sessionStorageService = new OpensearchSessionStorageService(stateStore); + flintIndexStateModelService = new OpenSearchFlintIndexStateModelService(stateStore); + indexDMLResultStorageService = + new OpenSearchIndexDMLResultStorageService(dataSourceService, stateStore); createIndexWithMappings(dm.getResultIndex(), loadResultIndexMappings()); createIndexWithMappings(otherDm.getResultIndex(), loadResultIndexMappings()); } @@ -207,9 +223,15 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService( jobExecutionResponseReader, new FlintIndexMetadataServiceImpl(client), client, - new SessionManager(stateStore, emrServerlessClientFactory, pluginSettings), + new SessionManager( + stateStore, + statementStorageService, + sessionStorageService, + emrServerlessClientFactory, + pluginSettings), new DefaultLeaseManager(pluginSettings, stateStore), - stateStore); + flintIndexStateModelService, + indexDMLResultStorageService); return new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, @@ -314,9 +336,9 @@ int search(QueryBuilder query) { } void setSessionState(String sessionId, SessionState sessionState) { - Optional<SessionModel> model = getSession(stateStore, MYS3_DATASOURCE).apply(sessionId); + Optional<SessionModel> model = sessionStorageService.getSession(sessionId, MYS3_DATASOURCE); SessionModel updated = - updateSessionState(stateStore, MYS3_DATASOURCE).apply(model.get(), sessionState); + sessionStorageService.updateSessionState(model.get(), sessionState, MYS3_DATASOURCE); assertEquals(sessionState, updated.getSessionState()); } diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java index 10598d110c..ac2d6b3690 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java @@ -8,7 +8,6 @@ import static org.opensearch.action.support.WriteRequest.RefreshPolicy.WAIT_UNTIL; import static org.opensearch.sql.data.model.ExprValueUtils.tupleValue; import static org.opensearch.sql.datasource.model.DataSourceMetadata.DEFAULT_RESULT_INDEX; -import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement; import com.amazonaws.services.emrserverless.model.JobRunState; import com.google.common.collect.ImmutableList; @@ -30,7 +29,6 @@ import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.execution.statement.StatementModel; import org.opensearch.sql.spark.execution.statement.StatementState; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexType; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; @@ -510,8 +508,9 @@ void emrJobWriteResultDoc(Map<String, Object> resultDoc) { /** Simulate EMR-S updates query_execution_request with state */ void emrJobUpdateStatementState(StatementState newState) { - StatementModel stmt = getStatement(stateStore, MYS3_DATASOURCE).apply(queryId).get(); - StateStore.updateStatementState(stateStore, MYS3_DATASOURCE).apply(stmt, newState); + StatementModel stmt = + statementStorageService.getStatementModel(queryId, MYS3_DATASOURCE).get(); + statementStorageService.updateStatementState(stmt, newState, MYS3_DATASOURCE); } void emrJobUpdateJobState(JobRunState jobState) { diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/model/MockFlintSparkJob.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/model/MockFlintSparkJob.java index 4cfdb6a9a9..412b371ea2 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/model/MockFlintSparkJob.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/model/MockFlintSparkJob.java @@ -14,15 +14,17 @@ import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexState; import org.opensearch.sql.spark.flint.FlintIndexStateModel; +import org.opensearch.sql.spark.flint.FlintIndexStateModelService; +import org.opensearch.sql.spark.flint.OpenSearchFlintIndexStateModelService; public class MockFlintSparkJob { private FlintIndexStateModel stateModel; - private StateStore stateStore; + private FlintIndexStateModelService flintIndexStateModelService; private String datasource; public MockFlintSparkJob(StateStore stateStore, String latestId, String datasource) { assertNotNull(latestId); - this.stateStore = stateStore; + this.flintIndexStateModelService = new OpenSearchFlintIndexStateModelService(stateStore); this.datasource = datasource; stateModel = new FlintIndexStateModel( @@ -35,53 +37,54 @@ public MockFlintSparkJob(StateStore stateStore, String latestId, String datasour "", SequenceNumbers.UNASSIGNED_SEQ_NO, SequenceNumbers.UNASSIGNED_PRIMARY_TERM); - stateModel = StateStore.createFlintIndexState(stateStore, datasource).apply(stateModel); + stateModel = + this.flintIndexStateModelService.createFlintIndexStateModel(stateModel, datasource); } public void transition(FlintIndexState newState) { stateModel = - StateStore.updateFlintIndexState(stateStore, datasource).apply(stateModel, newState); + this.flintIndexStateModelService.updateFlintIndexState(stateModel, newState, datasource); } public void refreshing() { stateModel = - StateStore.updateFlintIndexState(stateStore, datasource) - .apply(stateModel, FlintIndexState.REFRESHING); + this.flintIndexStateModelService.updateFlintIndexState( + stateModel, FlintIndexState.REFRESHING, datasource); } public void active() { stateModel = - StateStore.updateFlintIndexState(stateStore, datasource) - .apply(stateModel, FlintIndexState.ACTIVE); + this.flintIndexStateModelService.updateFlintIndexState( + stateModel, FlintIndexState.ACTIVE, datasource); } public void creating() { stateModel = - StateStore.updateFlintIndexState(stateStore, datasource) - .apply(stateModel, FlintIndexState.CREATING); + this.flintIndexStateModelService.updateFlintIndexState( + stateModel, FlintIndexState.CREATING, datasource); } public void updating() { stateModel = - StateStore.updateFlintIndexState(stateStore, datasource) - .apply(stateModel, FlintIndexState.UPDATING); + this.flintIndexStateModelService.updateFlintIndexState( + stateModel, FlintIndexState.UPDATING, datasource); } public void deleting() { stateModel = - StateStore.updateFlintIndexState(stateStore, datasource) - .apply(stateModel, FlintIndexState.DELETING); + this.flintIndexStateModelService.updateFlintIndexState( + stateModel, FlintIndexState.DELETING, datasource); } public void deleted() { stateModel = - StateStore.updateFlintIndexState(stateStore, datasource) - .apply(stateModel, FlintIndexState.DELETED); + this.flintIndexStateModelService.updateFlintIndexState( + stateModel, FlintIndexState.DELETED, datasource); } public void assertState(FlintIndexState expected) { Optional<FlintIndexStateModel> stateModelOpt = - StateStore.getFlintIndexState(stateStore, datasource).apply(stateModel.getId()); + this.flintIndexStateModelService.getFlintIndexStateModel(stateModel.getId(), datasource); assertTrue((stateModelOpt.isPresent())); assertEquals(expected, stateModelOpt.get().getIndexState()); } diff --git a/spark/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java b/spark/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java index 80542ba2e0..008f976542 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/cluster/FlintStreamingJobHouseKeeperTaskTest.java @@ -84,7 +84,10 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, stateStore, emrServerlessClientFactory); + dataSourceService, + flintIndexMetadataService, + flintIndexStateModelService, + emrServerlessClientFactory); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); thread.join(); @@ -152,7 +155,10 @@ public void testStreamingJobHouseKeeperWhenCancelJobGivesTimeout() { FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, stateStore, emrServerlessClientFactory); + dataSourceService, + flintIndexMetadataService, + flintIndexStateModelService, + emrServerlessClientFactory); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); thread.join(); @@ -220,7 +226,10 @@ public void testSimulateConcurrentJobHouseKeeperExecution() { FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, stateStore, emrServerlessClientFactory); + dataSourceService, + flintIndexMetadataService, + flintIndexStateModelService, + emrServerlessClientFactory); FlintStreamingJobHouseKeeperTask.isRunning.compareAndSet(false, true); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); @@ -299,7 +308,10 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, stateStore, emrServerlessClientFactory); + dataSourceService, + flintIndexMetadataService, + flintIndexStateModelService, + emrServerlessClientFactory); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); thread.join(); @@ -375,7 +387,10 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, stateStore, emrServerlessClientFactory); + dataSourceService, + flintIndexMetadataService, + flintIndexStateModelService, + emrServerlessClientFactory); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); thread.join(); @@ -417,7 +432,10 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, stateStore, emrServerlessClientFactory); + dataSourceService, + flintIndexMetadataService, + flintIndexStateModelService, + emrServerlessClientFactory); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); thread.join(); @@ -452,7 +470,10 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, stateStore, emrServerlessClientFactory); + dataSourceService, + flintIndexMetadataService, + flintIndexStateModelService, + emrServerlessClientFactory); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); thread.join(); @@ -493,7 +514,10 @@ public void updateIndexToManualRefresh( }; FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, stateStore, emrServerlessClientFactory); + dataSourceService, + flintIndexMetadataService, + flintIndexStateModelService, + emrServerlessClientFactory); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); thread.join(); @@ -561,7 +585,10 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, stateStore, emrServerlessClientFactory); + dataSourceService, + flintIndexMetadataService, + flintIndexStateModelService, + emrServerlessClientFactory); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); thread.join(); @@ -663,7 +690,10 @@ public GetJobRunResult getJobRunResult(String applicationId, String jobId) { FlintIndexMetadataService flintIndexMetadataService = new FlintIndexMetadataServiceImpl(client); FlintStreamingJobHouseKeeperTask flintStreamingJobHouseKeeperTask = new FlintStreamingJobHouseKeeperTask( - dataSourceService, flintIndexMetadataService, stateStore, emrServerlessClientFactory); + dataSourceService, + flintIndexMetadataService, + flintIndexStateModelService, + emrServerlessClientFactory); Thread thread = new Thread(flintStreamingJobHouseKeeperTask); thread.start(); thread.join(); diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java index 045de66d0a..df4ce67c82 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandlerTest.java @@ -33,10 +33,11 @@ import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; import org.opensearch.sql.spark.dispatcher.model.IndexQueryActionType; import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexMetadataService; +import org.opensearch.sql.spark.flint.FlintIndexStateModelService; import org.opensearch.sql.spark.flint.FlintIndexType; +import org.opensearch.sql.spark.flint.IndexDMLResultStorageService; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.model.LangType; @@ -46,13 +47,14 @@ class IndexDMLHandlerTest { @Mock private EMRServerlessClient emrServerlessClient; @Mock private JobExecutionResponseReader jobExecutionResponseReader; @Mock private FlintIndexMetadataService flintIndexMetadataService; - @Mock private StateStore stateStore; + @Mock private FlintIndexStateModelService flintIndexStateModelService; + @Mock private IndexDMLResultStorageService indexDMLResultStorageService; @Mock private Client client; @Test public void getResponseFromExecutor() { JSONObject result = - new IndexDMLHandler(null, null, null, null, null).getResponseFromExecutor(null); + new IndexDMLHandler(null, null, null, null, null, null).getResponseFromExecutor(null); assertEquals("running", result.getString(STATUS_FIELD)); assertEquals("", result.getString(ERROR_FIELD)); @@ -65,7 +67,8 @@ public void testWhenIndexDetailsAreNotFound() { emrServerlessClient, jobExecutionResponseReader, flintIndexMetadataService, - stateStore, + flintIndexStateModelService, + indexDMLResultStorageService, client); DispatchQueryRequest dispatchQueryRequest = new DispatchQueryRequest( @@ -107,7 +110,8 @@ public void testWhenIndexDetailsWithInvalidQueryActionType() { emrServerlessClient, jobExecutionResponseReader, flintIndexMetadataService, - stateStore, + flintIndexStateModelService, + indexDMLResultStorageService, client); DispatchQueryRequest dispatchQueryRequest = new DispatchQueryRequest( diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 1f250a0aea..3df5b23dc4 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -72,8 +72,9 @@ import org.opensearch.sql.spark.execution.statement.Statement; import org.opensearch.sql.spark.execution.statement.StatementId; import org.opensearch.sql.spark.execution.statement.StatementState; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadataService; +import org.opensearch.sql.spark.flint.FlintIndexStateModelService; +import org.opensearch.sql.spark.flint.IndexDMLResultStorageService; import org.opensearch.sql.spark.leasemanager.LeaseManager; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.model.LangType; @@ -100,7 +101,9 @@ public class SparkQueryDispatcherTest { @Mock(answer = RETURNS_DEEP_STUBS) private Statement statement; - @Mock private StateStore stateStore; + @Mock private FlintIndexStateModelService flintIndexStateModelService; + + @Mock private IndexDMLResultStorageService indexDMLResultStorageService; private SparkQueryDispatcher sparkQueryDispatcher; @@ -119,7 +122,8 @@ void setUp() { openSearchClient, sessionManager, leaseManager, - stateStore); + flintIndexStateModelService, + indexDMLResultStorageService); when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient); } diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java index 8fca190cd6..9a646e0cda 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java @@ -9,7 +9,6 @@ import static org.opensearch.sql.spark.execution.session.SessionManagerTest.sessionSetting; import static org.opensearch.sql.spark.execution.session.SessionState.NOT_STARTED; import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; -import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession; import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; @@ -26,6 +25,10 @@ import org.opensearch.sql.spark.client.EMRServerlessClientFactory; import org.opensearch.sql.spark.client.StartJobRequest; import org.opensearch.sql.spark.dispatcher.model.JobType; +import org.opensearch.sql.spark.execution.statement.OpenSearchStatementStorageService; +import org.opensearch.sql.spark.execution.statement.StatementStorageService; +import org.opensearch.sql.spark.execution.statestore.OpensearchSessionStorageService; +import org.opensearch.sql.spark.execution.statestore.SessionStorageService; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.test.OpenSearchIntegTestCase; @@ -39,12 +42,16 @@ public class InteractiveSessionTest extends OpenSearchIntegTestCase { private TestEMRServerlessClient emrsClient; private StartJobRequest startJobRequest; private StateStore stateStore; + private StatementStorageService statementStorageService; + private SessionStorageService sessionStorageService; @Before public void setup() { emrsClient = new TestEMRServerlessClient(); startJobRequest = new StartJobRequest("", "appId", "", "", new HashMap<>(), false, ""); stateStore = new StateStore(client(), clusterService()); + statementStorageService = new OpenSearchStatementStorageService(stateStore); + sessionStorageService = new OpensearchSessionStorageService(stateStore); } @After @@ -61,11 +68,13 @@ public void openCloseSession() { InteractiveSession.builder() .sessionId(sessionId) .stateStore(stateStore) + .statementStorageService(statementStorageService) + .sessionStorageService(sessionStorageService) .serverlessClient(emrsClient) .build(); // open session - TestSession testSession = testSession(session, stateStore); + TestSession testSession = testSession(session, sessionStorageService); testSession .open(createSessionRequest()) .assertSessionState(NOT_STARTED) @@ -87,6 +96,8 @@ public void openSessionFailedConflict() { InteractiveSession.builder() .sessionId(sessionId) .stateStore(stateStore) + .sessionStorageService(sessionStorageService) + .statementStorageService(statementStorageService) .serverlessClient(emrsClient) .build(); session.open(createSessionRequest()); @@ -95,6 +106,8 @@ public void openSessionFailedConflict() { InteractiveSession.builder() .sessionId(sessionId) .stateStore(stateStore) + .statementStorageService(statementStorageService) + .sessionStorageService(sessionStorageService) .serverlessClient(emrsClient) .build(); IllegalStateException exception = @@ -110,6 +123,8 @@ public void closeNotExistSession() { InteractiveSession.builder() .sessionId(sessionId) .stateStore(stateStore) + .statementStorageService(statementStorageService) + .sessionStorageService(sessionStorageService) .serverlessClient(emrsClient) .build(); session.open(createSessionRequest()); @@ -125,10 +140,15 @@ public void closeNotExistSession() { public void sessionManagerCreateSession() { EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; Session session = - new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) + new SessionManager( + stateStore, + statementStorageService, + sessionStorageService, + emrServerlessClientFactory, + sessionSetting()) .createSession(createSessionRequest()); - TestSession testSession = testSession(session, stateStore); + TestSession testSession = testSession(session, sessionStorageService); testSession.assertSessionState(NOT_STARTED).assertAppId("appId").assertJobId("jobId"); } @@ -136,7 +156,12 @@ public void sessionManagerCreateSession() { public void sessionManagerGetSession() { EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; SessionManager sessionManager = - new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()); + new SessionManager( + stateStore, + statementStorageService, + sessionStorageService, + emrServerlessClientFactory, + sessionSetting()); Session session = sessionManager.createSession(createSessionRequest()); Optional<Session> managerSession = sessionManager.getSession(session.getSessionId()); @@ -148,7 +173,12 @@ public void sessionManagerGetSession() { public void sessionManagerGetSessionNotExist() { EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; SessionManager sessionManager = - new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()); + new SessionManager( + stateStore, + statementStorageService, + sessionStorageService, + emrServerlessClientFactory, + sessionSetting()); Optional<Session> managerSession = sessionManager.getSession(SessionId.newSessionId("no-exist")); @@ -158,17 +188,18 @@ public void sessionManagerGetSessionNotExist() { @RequiredArgsConstructor static class TestSession { private final Session session; - private final StateStore stateStore; + private final SessionStorageService sessionStorageService; - public static TestSession testSession(Session session, StateStore stateStore) { - return new TestSession(session, stateStore); + public static TestSession testSession( + Session session, SessionStorageService sessionStorageService) { + return new TestSession(session, sessionStorageService); } public TestSession assertSessionState(SessionState expected) { assertEquals(expected, session.getSessionModel().getSessionState()); Optional<SessionModel> sessionStoreState = - getSession(stateStore, DS_NAME).apply(session.getSessionModel().getId()); + sessionStorageService.getSession(session.getSessionModel().getId(), DS_NAME); assertTrue(sessionStoreState.isPresent()); assertEquals(expected, sessionStoreState.get().getSessionState()); diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java index d021bc7248..f4e9fa4e8b 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java @@ -15,6 +15,8 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.spark.client.EMRServerlessClientFactory; +import org.opensearch.sql.spark.execution.statement.StatementStorageService; +import org.opensearch.sql.spark.execution.statestore.SessionStorageService; import org.opensearch.sql.spark.execution.statestore.StateStore; @ExtendWith(MockitoExtension.class) @@ -23,10 +25,20 @@ public class SessionManagerTest { @Mock private EMRServerlessClientFactory emrServerlessClientFactory; + @Mock private StatementStorageService statementStorageService; + + @Mock private SessionStorageService sessionStorageService; + @Test public void sessionEnable() { Assertions.assertTrue( - new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()).isEnabled()); + new SessionManager( + stateStore, + statementStorageService, + sessionStorageService, + emrServerlessClientFactory, + sessionSetting()) + .isEnabled()); } public static org.opensearch.sql.common.setting.Settings sessionSetting() { diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java index 3a69fa01d7..29a2c9f87c 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java @@ -12,9 +12,6 @@ import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING; import static org.opensearch.sql.spark.execution.statement.StatementTest.TestStatement.testStatement; import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; -import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement; -import static org.opensearch.sql.spark.execution.statestore.StateStore.updateSessionState; -import static org.opensearch.sql.spark.execution.statestore.StateStore.updateStatementState; import java.util.Optional; import lombok.RequiredArgsConstructor; @@ -30,6 +27,8 @@ import org.opensearch.sql.spark.execution.session.SessionId; import org.opensearch.sql.spark.execution.session.SessionManager; import org.opensearch.sql.spark.execution.session.SessionState; +import org.opensearch.sql.spark.execution.statestore.OpensearchSessionStorageService; +import org.opensearch.sql.spark.execution.statestore.SessionStorageService; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.rest.model.LangType; import org.opensearch.test.OpenSearchIntegTestCase; @@ -40,12 +39,16 @@ public class StatementTest extends OpenSearchIntegTestCase { private static final String indexName = DATASOURCE_TO_REQUEST_INDEX.apply(DS_NAME); private StateStore stateStore; + private StatementStorageService statementStorageService; + private SessionStorageService sessionStorageService; private InteractiveSessionTest.TestEMRServerlessClient emrsClient = new InteractiveSessionTest.TestEMRServerlessClient(); @Before public void setup() { stateStore = new StateStore(client(), clusterService()); + statementStorageService = new OpenSearchStatementStorageService(stateStore); + sessionStorageService = new OpensearchSessionStorageService(stateStore); } @After @@ -68,10 +71,11 @@ public void openThenCancelStatement() { .query("query") .queryId("statementId") .stateStore(stateStore) + .statementStorageService(statementStorageService) .build(); // submit statement - TestStatement testStatement = testStatement(st, stateStore); + TestStatement testStatement = testStatement(st, statementStorageService); testStatement .open() .assertSessionState(WAITING) @@ -94,6 +98,7 @@ public void openFailedBecauseConflict() { .query("query") .queryId("statementId") .stateStore(stateStore) + .statementStorageService(statementStorageService) .build(); st.open(); @@ -109,6 +114,7 @@ public void openFailedBecauseConflict() { .query("query") .queryId("statementId") .stateStore(stateStore) + .statementStorageService(statementStorageService) .build(); IllegalStateException exception = assertThrows(IllegalStateException.class, dupSt::open); assertEquals("statement already exist. statementId=statementId", exception.getMessage()); @@ -128,6 +134,7 @@ public void cancelNotExistStatement() { .query("query") .queryId("statementId") .stateStore(stateStore) + .statementStorageService(statementStorageService) .build(); st.open(); @@ -153,12 +160,12 @@ public void cancelFailedBecauseOfConflict() { .query("query") .queryId("statementId") .stateStore(stateStore) + .statementStorageService(statementStorageService) .build(); st.open(); StatementModel running = - updateStatementState(stateStore, DS_NAME).apply(st.getStatementModel(), CANCELLED); - + statementStorageService.updateStatementState(st.getStatementModel(), CANCELLED, DS_NAME); assertEquals(StatementState.CANCELLED, running.getStatementState()); // cancel conflict @@ -242,10 +249,11 @@ public void cancelRunningStatementSuccess() { .query("query") .queryId("statementId") .stateStore(stateStore) + .statementStorageService(statementStorageService) .build(); // submit statement - TestStatement testStatement = testStatement(st, stateStore); + TestStatement testStatement = testStatement(st, statementStorageService); testStatement .open() .assertSessionState(WAITING) @@ -261,11 +269,17 @@ public void cancelRunningStatementSuccess() { public void submitStatementInRunningSession() { EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; Session session = - new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) + new SessionManager( + stateStore, + statementStorageService, + sessionStorageService, + emrServerlessClientFactory, + sessionSetting()) .createSession(createSessionRequest()); // App change state to running - updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING); + sessionStorageService.updateSessionState( + session.getSessionModel(), SessionState.RUNNING, DS_NAME); StatementId statementId = session.submit(queryRequest()); assertFalse(statementId.getId().isEmpty()); @@ -275,7 +289,12 @@ public void submitStatementInRunningSession() { public void submitStatementInNotStartedState() { EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; Session session = - new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) + new SessionManager( + stateStore, + statementStorageService, + sessionStorageService, + emrServerlessClientFactory, + sessionSetting()) .createSession(createSessionRequest()); StatementId statementId = session.submit(queryRequest()); @@ -286,10 +305,15 @@ public void submitStatementInNotStartedState() { public void failToSubmitStatementInDeadState() { EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; Session session = - new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) + new SessionManager( + stateStore, + statementStorageService, + sessionStorageService, + emrServerlessClientFactory, + sessionSetting()) .createSession(createSessionRequest()); - updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.DEAD); + sessionStorageService.updateSessionState(session.getSessionModel(), SessionState.DEAD, DS_NAME); IllegalStateException exception = assertThrows(IllegalStateException.class, () -> session.submit(queryRequest())); @@ -303,10 +327,15 @@ public void failToSubmitStatementInDeadState() { public void failToSubmitStatementInFailState() { EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; Session session = - new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) + new SessionManager( + stateStore, + statementStorageService, + sessionStorageService, + emrServerlessClientFactory, + sessionSetting()) .createSession(createSessionRequest()); - updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.FAIL); + sessionStorageService.updateSessionState(session.getSessionModel(), SessionState.FAIL, DS_NAME); IllegalStateException exception = assertThrows(IllegalStateException.class, () -> session.submit(queryRequest())); @@ -320,7 +349,12 @@ public void failToSubmitStatementInFailState() { public void newStatementFieldAssert() { EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; Session session = - new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) + new SessionManager( + stateStore, + statementStorageService, + sessionStorageService, + emrServerlessClientFactory, + sessionSetting()) .createSession(createSessionRequest()); StatementId statementId = session.submit(queryRequest()); Optional<Statement> statement = session.get(statementId); @@ -339,7 +373,12 @@ public void newStatementFieldAssert() { public void failToSubmitStatementInDeletedSession() { EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; Session session = - new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) + new SessionManager( + stateStore, + statementStorageService, + sessionStorageService, + emrServerlessClientFactory, + sessionSetting()) .createSession(createSessionRequest()); // other's delete session @@ -356,10 +395,16 @@ public void failToSubmitStatementInDeletedSession() { public void getStatementSuccess() { EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; Session session = - new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) + new SessionManager( + stateStore, + statementStorageService, + sessionStorageService, + emrServerlessClientFactory, + sessionSetting()) .createSession(createSessionRequest()); // App change state to running - updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING); + sessionStorageService.updateSessionState( + session.getSessionModel(), SessionState.RUNNING, DS_NAME); StatementId statementId = session.submit(queryRequest()); Optional<Statement> statement = session.get(statementId); @@ -372,10 +417,16 @@ public void getStatementSuccess() { public void getStatementNotExist() { EMRServerlessClientFactory emrServerlessClientFactory = () -> emrsClient; Session session = - new SessionManager(stateStore, emrServerlessClientFactory, sessionSetting()) + new SessionManager( + stateStore, + statementStorageService, + sessionStorageService, + emrServerlessClientFactory, + sessionSetting()) .createSession(createSessionRequest()); // App change state to running - updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING); + sessionStorageService.updateSessionState( + session.getSessionModel(), SessionState.RUNNING, DS_NAME); Optional<Statement> statement = session.get(StatementId.newStatementId("not-exist-id")); assertFalse(statement.isPresent()); @@ -384,17 +435,18 @@ public void getStatementNotExist() { @RequiredArgsConstructor static class TestStatement { private final Statement st; - private final StateStore stateStore; + private final StatementStorageService statementStorageService; - public static TestStatement testStatement(Statement st, StateStore stateStore) { - return new TestStatement(st, stateStore); + public static TestStatement testStatement( + Statement st, StatementStorageService statementStorageService) { + return new TestStatement(st, statementStorageService); } public TestStatement assertSessionState(StatementState expected) { assertEquals(expected, st.getStatementModel().getStatementState()); Optional<StatementModel> model = - getStatement(stateStore, DS_NAME).apply(st.getStatementId().getId()); + statementStorageService.getStatementModel(st.getStatementId().getId(), DS_NAME); assertTrue(model.isPresent()); assertEquals(expected, model.get().getStatementState()); @@ -405,7 +457,7 @@ public TestStatement assertStatementId(StatementId expected) { assertEquals(expected, st.getStatementModel().getStatementId()); Optional<StatementModel> model = - getStatement(stateStore, DS_NAME).apply(st.getStatementId().getId()); + statementStorageService.getStatementModel(st.getStatementId().getId(), DS_NAME); assertTrue(model.isPresent()); assertEquals(expected, model.get().getStatementId()); return this; @@ -423,7 +475,7 @@ public TestStatement cancel() { public TestStatement run() { StatementModel model = - updateStatementState(stateStore, DS_NAME).apply(st.getStatementModel(), RUNNING); + statementStorageService.updateStatementState(st.getStatementModel(), RUNNING, DS_NAME); st.setStatementModel(model); return this; } @@ -445,6 +497,7 @@ private Statement createStatement(StatementId stId) { .query("query") .queryId("statementId") .stateStore(stateStore) + .statementStorageService(statementStorageService) .build(); st.open(); return st; diff --git a/spark/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpTest.java b/spark/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpTest.java index 5755d03baa..523f4dc84b 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpTest.java @@ -4,7 +4,6 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; import java.util.Optional; import org.junit.jupiter.api.Assertions; @@ -13,15 +12,15 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.index.seqno.SequenceNumbers; -import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexState; import org.opensearch.sql.spark.flint.FlintIndexStateModel; +import org.opensearch.sql.spark.flint.FlintIndexStateModelService; @ExtendWith(MockitoExtension.class) public class FlintIndexOpTest { - @Mock private StateStore mockStateStore; + @Mock private FlintIndexStateModelService flintIndexStateModelService; @Test public void testApplyWithTransitioningStateFailure() { @@ -38,11 +37,11 @@ public void testApplyWithTransitioningStateFailure() { "", SequenceNumbers.UNASSIGNED_SEQ_NO, SequenceNumbers.UNASSIGNED_PRIMARY_TERM); - when(mockStateStore.get(eq("latestId"), any(), eq(DATASOURCE_TO_REQUEST_INDEX.apply("myS3")))) + when(flintIndexStateModelService.getFlintIndexStateModel(eq("latestId"), eq("myS3"))) .thenReturn(Optional.of(fakeModel)); - when(mockStateStore.updateState(any(), any(), any(), any())) + when(flintIndexStateModelService.updateFlintIndexState(any(), any(), any())) .thenThrow(new RuntimeException("Transitioning state failed")); - FlintIndexOp flintIndexOp = new TestFlintIndexOp(mockStateStore, "myS3"); + FlintIndexOp flintIndexOp = new TestFlintIndexOp(flintIndexStateModelService, "myS3"); IllegalStateException illegalStateException = Assertions.assertThrows(IllegalStateException.class, () -> flintIndexOp.apply(metadata)); Assertions.assertEquals( @@ -64,13 +63,13 @@ public void testApplyWithCommitFailure() { "", SequenceNumbers.UNASSIGNED_SEQ_NO, SequenceNumbers.UNASSIGNED_PRIMARY_TERM); - when(mockStateStore.get(eq("latestId"), any(), eq(DATASOURCE_TO_REQUEST_INDEX.apply("myS3")))) + when(flintIndexStateModelService.getFlintIndexStateModel(eq("latestId"), eq("myS3"))) .thenReturn(Optional.of(fakeModel)); - when(mockStateStore.updateState(any(), any(), any(), any())) + when(flintIndexStateModelService.updateFlintIndexState(any(), any(), any())) .thenReturn(FlintIndexStateModel.copy(fakeModel, 1, 2)) .thenThrow(new RuntimeException("Commit state failed")) .thenReturn(FlintIndexStateModel.copy(fakeModel, 1, 3)); - FlintIndexOp flintIndexOp = new TestFlintIndexOp(mockStateStore, "myS3"); + FlintIndexOp flintIndexOp = new TestFlintIndexOp(flintIndexStateModelService, "myS3"); IllegalStateException illegalStateException = Assertions.assertThrows(IllegalStateException.class, () -> flintIndexOp.apply(metadata)); Assertions.assertEquals( @@ -92,13 +91,13 @@ public void testApplyWithRollBackFailure() { "", SequenceNumbers.UNASSIGNED_SEQ_NO, SequenceNumbers.UNASSIGNED_PRIMARY_TERM); - when(mockStateStore.get(eq("latestId"), any(), eq(DATASOURCE_TO_REQUEST_INDEX.apply("myS3")))) + when(flintIndexStateModelService.getFlintIndexStateModel(eq("latestId"), eq("myS3"))) .thenReturn(Optional.of(fakeModel)); - when(mockStateStore.updateState(any(), any(), any(), any())) + when(flintIndexStateModelService.updateFlintIndexState(any(), any(), any())) .thenReturn(FlintIndexStateModel.copy(fakeModel, 1, 2)) .thenThrow(new RuntimeException("Commit state failed")) .thenThrow(new RuntimeException("Rollback failure")); - FlintIndexOp flintIndexOp = new TestFlintIndexOp(mockStateStore, "myS3"); + FlintIndexOp flintIndexOp = new TestFlintIndexOp(flintIndexStateModelService, "myS3"); IllegalStateException illegalStateException = Assertions.assertThrows(IllegalStateException.class, () -> flintIndexOp.apply(metadata)); Assertions.assertEquals( @@ -107,8 +106,9 @@ public void testApplyWithRollBackFailure() { static class TestFlintIndexOp extends FlintIndexOp { - public TestFlintIndexOp(StateStore stateStore, String datasourceName) { - super(stateStore, datasourceName); + public TestFlintIndexOp( + FlintIndexStateModelService flintIndexStateModelService, String datasourceName) { + super(flintIndexStateModelService, datasourceName); } @Override