Skip to content

Commit

Permalink
Merge branch 'main' into add-spark-submit-parameters-to-cluster-settings
Browse files Browse the repository at this point in the history
Signed-off-by: Chen Dai <[email protected]>
  • Loading branch information
dai-chen committed Oct 5, 2023
2 parents b67f6e2 + 5df6105 commit cab57eb
Show file tree
Hide file tree
Showing 16 changed files with 539 additions and 89 deletions.
8 changes: 5 additions & 3 deletions plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
import org.opensearch.sql.spark.client.EmrServerlessClientImplEMR;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher;
import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl;
import org.opensearch.sql.spark.response.JobExecutionResponseReader;
import org.opensearch.sql.spark.rest.RestAsyncQueryManagementAction;
import org.opensearch.sql.spark.storage.SparkStorageFactory;
Expand Down Expand Up @@ -297,14 +298,15 @@ private DataSourceServiceImpl createDataSourceService() {
private AsyncQueryExecutorService createAsyncQueryExecutorService() {
AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService =
new OpensearchAsyncQueryJobMetadataStorageService(client, clusterService);
EMRServerlessClient EMRServerlessClient = createEMRServerlessClient();
EMRServerlessClient emrServerlessClient = createEMRServerlessClient();
JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client);
SparkQueryDispatcher sparkQueryDispatcher =
new SparkQueryDispatcher(
EMRServerlessClient,
emrServerlessClient,
this.dataSourceService,
new DataSourceUserAuthorizationHelperImpl(client),
jobExecutionResponseReader);
jobExecutionResponseReader,
new FlintIndexMetadataReaderImpl(client));
return new AsyncQueryExecutorServiceImpl(
asyncQueryJobMetadataStorageService, sparkQueryDispatcher, pluginSettings);
}
Expand Down
4 changes: 3 additions & 1 deletion spark/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ jacocoTestCoverageVerification {
'org.opensearch.sql.spark.transport.model.*',
'org.opensearch.sql.spark.asyncquery.model.*',
'org.opensearch.sql.spark.asyncquery.exceptions.*',
'org.opensearch.sql.spark.dispatcher.model.*'
'org.opensearch.sql.spark.dispatcher.model.*',
'org.opensearch.sql.spark.flint.FlintIndexType',
'org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl'
]
limit {
counter = 'LINE'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse;
import org.opensearch.sql.spark.functions.response.DefaultSparkSqlFunctionResponseHandle;
import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest;
import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse;
Expand Down Expand Up @@ -64,7 +65,7 @@ public CreateAsyncQueryResponse createAsyncQuery(
SparkExecutionEngineConfig.toSparkExecutionEngineConfig(
sparkExecutionEngineConfigString));
ClusterName clusterName = settings.getSettingValue(CLUSTER_NAME);
String jobId =
DispatchQueryResponse dispatchQueryResponse =
sparkQueryDispatcher.dispatch(
new DispatchQueryRequest(
sparkExecutionEngineConfig.getApplicationId(),
Expand All @@ -75,8 +76,11 @@ public CreateAsyncQueryResponse createAsyncQuery(
clusterName.value(),
sparkExecutionEngineConfig.getSparkSubmitParameters()));
asyncQueryJobMetadataStorageService.storeJobMetadata(
new AsyncQueryJobMetadata(jobId, sparkExecutionEngineConfig.getApplicationId()));
return new CreateAsyncQueryResponse(jobId);
new AsyncQueryJobMetadata(
sparkExecutionEngineConfig.getApplicationId(),
dispatchQueryResponse.getJobId(),
dispatchQueryResponse.isDropIndexQuery()));
return new CreateAsyncQueryResponse(dispatchQueryResponse.getJobId());
}

@Override
Expand All @@ -85,9 +89,7 @@ public AsyncQueryExecutionResponse getAsyncQueryResults(String queryId) {
Optional<AsyncQueryJobMetadata> jobMetadata =
asyncQueryJobMetadataStorageService.getJobMetadata(queryId);
if (jobMetadata.isPresent()) {
JSONObject jsonObject =
sparkQueryDispatcher.getQueryResponse(
jobMetadata.get().getApplicationId(), jobMetadata.get().getJobId());
JSONObject jsonObject = sparkQueryDispatcher.getQueryResponse(jobMetadata.get());
if (JobRunState.SUCCESS.toString().equals(jsonObject.getString("status"))) {
DefaultSparkSqlFunctionResponseHandle sparkSqlFunctionResponseHandle =
new DefaultSparkSqlFunctionResponseHandle(jsonObject);
Expand All @@ -109,8 +111,7 @@ public String cancelQuery(String queryId) {
Optional<AsyncQueryJobMetadata> asyncQueryJobMetadata =
asyncQueryJobMetadataStorageService.getJobMetadata(queryId);
if (asyncQueryJobMetadata.isPresent()) {
return sparkQueryDispatcher.cancelJob(
asyncQueryJobMetadata.get().getApplicationId(), queryId);
return sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata.get());
}
throw new AsyncQueryNotFoundException(String.format("QueryId: %s not found", queryId));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.io.IOException;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.xcontent.DeprecationHandler;
Expand All @@ -23,9 +24,17 @@
/** This class models all the metadata required for a job. */
@Data
@AllArgsConstructor
@EqualsAndHashCode
public class AsyncQueryJobMetadata {
private String jobId;
private String applicationId;
private String jobId;
private boolean isDropIndexQuery;

public AsyncQueryJobMetadata(String applicationId, String jobId) {
this.applicationId = applicationId;
this.jobId = jobId;
this.isDropIndexQuery = false;
}

@Override
public String toString() {
Expand All @@ -44,6 +53,7 @@ public static XContentBuilder convertToXContent(AsyncQueryJobMetadata metadata)
builder.startObject();
builder.field("jobId", metadata.getJobId());
builder.field("applicationId", metadata.getApplicationId());
builder.field("isDropIndexQuery", metadata.isDropIndexQuery());
builder.endObject();
return builder;
}
Expand Down Expand Up @@ -77,6 +87,7 @@ public static AsyncQueryJobMetadata toJobMetadata(String json) throws IOExceptio
public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws IOException {
String jobId = null;
String applicationId = null;
boolean isDropIndexQuery = false;
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
Expand All @@ -88,13 +99,16 @@ public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws
case "applicationId":
applicationId = parser.textOrNull();
break;
case "isDropIndexQuery":
isDropIndexQuery = parser.booleanValue();
break;
default:
throw new IllegalArgumentException("Unknown field: " + fieldName);
}
}
if (jobId == null || applicationId == null) {
throw new IllegalArgumentException("jobId and applicationId are required fields.");
}
return new AsyncQueryJobMetadata(jobId, applicationId);
return new AsyncQueryJobMetadata(applicationId, jobId, isDropIndexQuery);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,19 @@
import java.util.HashMap;
import java.util.Map;
import lombok.AllArgsConstructor;
import org.apache.commons.lang3.RandomStringUtils;
import org.json.JSONObject;
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters;
import org.opensearch.sql.spark.client.EMRServerlessClient;
import org.opensearch.sql.spark.client.StartJobRequest;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse;
import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName;
import org.opensearch.sql.spark.dispatcher.model.IndexDetails;
import org.opensearch.sql.spark.flint.FlintIndexMetadataReader;
import org.opensearch.sql.spark.response.JobExecutionResponseReader;
import org.opensearch.sql.spark.rest.model.LangType;
import org.opensearch.sql.spark.utils.SQLQueryUtils;
Expand All @@ -42,49 +46,64 @@ public class SparkQueryDispatcher {

private JobExecutionResponseReader jobExecutionResponseReader;

public String dispatch(DispatchQueryRequest dispatchQueryRequest) {
return emrServerlessClient.startJobRun(getStartJobRequest(dispatchQueryRequest));
private FlintIndexMetadataReader flintIndexMetadataReader;

public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) {
if (LangType.SQL.equals(dispatchQueryRequest.getLangType())) {
return handleSQLQuery(dispatchQueryRequest);
} else {
// Since we don't need any extra handling for PPL, we are treating it as normal dispatch
// Query.
return handleNonIndexQuery(dispatchQueryRequest);
}
}

// TODO : Fetch from Result Index and then make call to EMR Serverless.
public JSONObject getQueryResponse(String applicationId, String queryId) {
GetJobRunResult getJobRunResult = emrServerlessClient.getJobRunResult(applicationId, queryId);
public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) {
GetJobRunResult getJobRunResult =
emrServerlessClient.getJobRunResult(
asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId());
JSONObject result = new JSONObject();
if (getJobRunResult.getJobRun().getState().equals(JobRunState.SUCCESS.toString())) {
result = jobExecutionResponseReader.getResultFromOpensearchIndex(queryId);
result =
jobExecutionResponseReader.getResultFromOpensearchIndex(asyncQueryJobMetadata.getJobId());
}
result.put("status", getJobRunResult.getJobRun().getState());
return result;
}

public String cancelJob(String applicationId, String jobId) {
CancelJobRunResult cancelJobRunResult = emrServerlessClient.cancelJobRun(applicationId, jobId);
public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) {
CancelJobRunResult cancelJobRunResult =
emrServerlessClient.cancelJobRun(
asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId());
return cancelJobRunResult.getJobRunId();
}

// we currently don't support index queries in PPL language.
// so we are treating all of them as non-index queries which don't require any kind of query
// parsing.
private StartJobRequest getStartJobRequest(DispatchQueryRequest dispatchQueryRequest) {
if (LangType.SQL.equals(dispatchQueryRequest.getLangType())) {
if (SQLQueryUtils.isIndexQuery(dispatchQueryRequest.getQuery()))
return getStartJobRequestForIndexRequest(dispatchQueryRequest);
else {
return getStartJobRequestForNonIndexQueries(dispatchQueryRequest);
private DispatchQueryResponse handleSQLQuery(DispatchQueryRequest dispatchQueryRequest) {
if (SQLQueryUtils.isIndexQuery(dispatchQueryRequest.getQuery())) {
IndexDetails indexDetails =
SQLQueryUtils.extractIndexDetails(dispatchQueryRequest.getQuery());
if (indexDetails.isDropIndex()) {
return handleDropIndexQuery(dispatchQueryRequest, indexDetails);
} else {
return handleIndexQuery(dispatchQueryRequest, indexDetails);
}
} else {
return getStartJobRequestForNonIndexQueries(dispatchQueryRequest);
return handleNonIndexQuery(dispatchQueryRequest);
}
}

private StartJobRequest getStartJobRequestForNonIndexQueries(
DispatchQueryRequest dispatchQueryRequest) {
StartJobRequest startJobRequest;
private DispatchQueryResponse handleIndexQuery(
DispatchQueryRequest dispatchQueryRequest, IndexDetails indexDetails) {
FullyQualifiedTableName fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName();
dataSourceUserAuthorizationHelper.authorizeDataSource(
this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource()));
String jobName = dispatchQueryRequest.getClusterName() + ":" + "non-index-query";
String jobName = dispatchQueryRequest.getClusterName() + ":" + "index-query";
Map<String, String> tags = getDefaultTagsForJobSubmission(dispatchQueryRequest);
startJobRequest =
tags.put(INDEX_TAG_KEY, indexDetails.getIndexName());
tags.put(TABLE_TAG_KEY, fullyQualifiedTableName.getTableName());
tags.put(SCHEMA_TAG_KEY, fullyQualifiedTableName.getSchemaName());
StartJobRequest startJobRequest =
new StartJobRequest(
dispatchQueryRequest.getQuery(),
jobName,
Expand All @@ -94,27 +113,22 @@ private StartJobRequest getStartJobRequestForNonIndexQueries(
.dataSource(
dataSourceService.getRawDataSourceMetadata(
dispatchQueryRequest.getDatasource()))
.structuredStreaming(indexDetails.getAutoRefresh())
.extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams())
.build()
.toString(),
tags,
false);
return startJobRequest;
indexDetails.getAutoRefresh());
String jobId = emrServerlessClient.startJobRun(startJobRequest);
return new DispatchQueryResponse(jobId, false);
}

private StartJobRequest getStartJobRequestForIndexRequest(
DispatchQueryRequest dispatchQueryRequest) {
StartJobRequest startJobRequest;
IndexDetails indexDetails = SQLQueryUtils.extractIndexDetails(dispatchQueryRequest.getQuery());
FullyQualifiedTableName fullyQualifiedTableName = indexDetails.getFullyQualifiedTableName();
private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQueryRequest) {
dataSourceUserAuthorizationHelper.authorizeDataSource(
this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource()));
String jobName = dispatchQueryRequest.getClusterName() + ":" + "index-query";
String jobName = dispatchQueryRequest.getClusterName() + ":" + "non-index-query";
Map<String, String> tags = getDefaultTagsForJobSubmission(dispatchQueryRequest);
tags.put(INDEX_TAG_KEY, indexDetails.getIndexName());
tags.put(TABLE_TAG_KEY, fullyQualifiedTableName.getTableName());
tags.put(SCHEMA_TAG_KEY, fullyQualifiedTableName.getSchemaName());
startJobRequest =
StartJobRequest startJobRequest =
new StartJobRequest(
dispatchQueryRequest.getQuery(),
jobName,
Expand All @@ -124,13 +138,23 @@ private StartJobRequest getStartJobRequestForIndexRequest(
.dataSource(
dataSourceService.getRawDataSourceMetadata(
dispatchQueryRequest.getDatasource()))
.structuredStreaming(indexDetails.getAutoRefresh())
.extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams())
.build()
.toString(),
tags,
indexDetails.getAutoRefresh());
return startJobRequest;
false);
String jobId = emrServerlessClient.startJobRun(startJobRequest);
return new DispatchQueryResponse(jobId, false);
}

private DispatchQueryResponse handleDropIndexQuery(
DispatchQueryRequest dispatchQueryRequest, IndexDetails indexDetails) {
dataSourceUserAuthorizationHelper.authorizeDataSource(
this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource()));
String jobId = flintIndexMetadataReader.getJobIdFromFlintIndexMetadata(indexDetails);
emrServerlessClient.cancelJobRun(dispatchQueryRequest.getApplicationId(), jobId);
String dropIndexDummyJobId = RandomStringUtils.randomAlphanumeric(16);
return new DispatchQueryResponse(dropIndexDummyJobId, true);
}

private static Map<String, String> getDefaultTagsForJobSubmission(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package org.opensearch.sql.spark.dispatcher.model;

import lombok.AllArgsConstructor;
import lombok.Data;

@Data
@AllArgsConstructor
public class DispatchQueryResponse {
private String jobId;
private boolean isDropIndexQuery;
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,22 @@

package org.opensearch.sql.spark.dispatcher.model;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.opensearch.sql.spark.flint.FlintIndexType;

/** Index details in an async query. */
@Data
@AllArgsConstructor
@NoArgsConstructor
@EqualsAndHashCode
public class IndexDetails {
private String indexName;
private FullyQualifiedTableName fullyQualifiedTableName;
// by default, auto_refresh = false;
private Boolean autoRefresh = false;
private boolean isDropIndex;
private FlintIndexType indexType;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package org.opensearch.sql.spark.flint;

import org.opensearch.sql.spark.dispatcher.model.IndexDetails;

/** Interface for FlintIndexMetadataReader */
public interface FlintIndexMetadataReader {

/**
* Given Index details, get the streaming job Id.
*
* @param indexDetails indexDetails.
* @return jobId.
*/
String getJobIdFromFlintIndexMetadata(IndexDetails indexDetails);
}
Loading

0 comments on commit cab57eb

Please sign in to comment.