Skip to content

Commit

Permalink
Fix bug, using basic instead of basicauth (#2342)
Browse files Browse the repository at this point in the history
* Fix bug, using basic instead of basicauth

Signed-off-by: Peng Huo <[email protected]>

* fix codestyle

Signed-off-by: Peng Huo <[email protected]>

* fix IT failure: datasourceWithBasicAuth

Signed-off-by: Peng Huo <[email protected]>

* fix UT

Signed-off-by: Peng Huo <[email protected]>

---------

Signed-off-by: Peng Huo <[email protected]>
(cherry picked from commit a27e733)
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
github-actions[bot] committed Oct 24, 2023
1 parent f3fdead commit f1c8c53
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
public class SparkSubmitParameters {
public static final String SPACE = " ";
public static final String EQUALS = "=";
public static final String FLINT_BASIC_AUTH = "basic";

private final String className;
private final Map<String, String> config;
Expand Down Expand Up @@ -114,7 +115,7 @@ private void setFlintIndexStoreAuthProperties(
Supplier<String> password,
Supplier<String> region) {
if (AuthenticationType.get(authType).equals(AuthenticationType.BASICAUTH)) {
config.put(FLINT_INDEX_STORE_AUTH_KEY, authType);
config.put(FLINT_INDEX_STORE_AUTH_KEY, FLINT_BASIC_AUTH);
config.put(FLINT_INDEX_STORE_AUTH_USERNAME, userName.get());
config.put(FLINT_INDEX_STORE_AUTH_PASSWORD, password.get());
} else if (AuthenticationType.get(authType).equals(AuthenticationType.AWSSIGV4AUTH)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJob
Statement statement = getStatementByQueryId(asyncQueryJobMetadata.getSessionId(), queryId);
StatementState statementState = statement.getStatementState();
result.put(STATUS_FIELD, statementState.getState());
result.put(ERROR_FIELD, "");
result.put(ERROR_FIELD, Optional.of(statement.getStatementModel().getError()).orElse(""));
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,40 @@ public class CreateSessionRequest {
private final String datasourceName;

public StartJobRequest getStartJobRequest() {
return new StartJobRequest(
return new InteractiveSessionStartJobRequest(
"select 1",
jobName,
applicationId,
executionRoleArn,
sparkSubmitParametersBuilder.build().toString(),
tags,
false,
resultIndex);
}

static class InteractiveSessionStartJobRequest extends StartJobRequest {
public InteractiveSessionStartJobRequest(
String query,
String jobName,
String applicationId,
String executionRoleArn,
String sparkSubmitParams,
Map<String, String> tags,
String resultIndex) {
super(
query,
jobName,
applicationId,
executionRoleArn,
sparkSubmitParams,
tags,
false,
resultIndex);
}

/** Interactive query keep running. */
@Override
public Long executionTimeout() {
return 0L;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import static org.opensearch.sql.spark.execution.statement.StatementModel.STATEMENT_DOC_TYPE;
import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX;
import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement;
import static org.opensearch.sql.spark.execution.statestore.StateStore.updateStatementState;

import com.amazonaws.services.emrserverless.model.CancelJobRunResult;
import com.amazonaws.services.emrserverless.model.GetJobRunResult;
Expand All @@ -26,7 +27,9 @@
import com.google.common.collect.ImmutableSet;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import lombok.Getter;
import org.junit.After;
Expand Down Expand Up @@ -109,7 +112,7 @@ public void setup() {
"glue.auth.role_arn",
"arn:aws:iam::924196221507:role/FlintOpensearchServiceRole",
"glue.indexstore.opensearch.uri",
"http://ec2-18-237-133-156.us-west-2.compute.amazonaws" + ".com:9200",
"http://localhost:9200",
"glue.indexstore.opensearch.auth",
"noauth"),
null));
Expand Down Expand Up @@ -269,8 +272,114 @@ public void reuseSessionWhenCreateAsyncQuery() {
assertEquals(second.getQueryId(), secondModel.get().getQueryId());
}

@Test
public void batchQueryHasTimeout() {
LocalEMRSClient emrsClient = new LocalEMRSClient();
AsyncQueryExecutorService asyncQueryExecutorService =
createAsyncQueryExecutorService(emrsClient);

enableSession(false);
CreateAsyncQueryResponse response =
asyncQueryExecutorService.createAsyncQuery(
new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null));

assertEquals(120L, (long) emrsClient.getJobRequest().executionTimeout());
}

@Test
public void interactiveQueryNoTimeout() {
LocalEMRSClient emrsClient = new LocalEMRSClient();
AsyncQueryExecutorService asyncQueryExecutorService =
createAsyncQueryExecutorService(emrsClient);

// enable session
enableSession(true);

asyncQueryExecutorService.createAsyncQuery(
new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null));
assertEquals(0L, (long) emrsClient.getJobRequest().executionTimeout());
}

@Test
public void datasourceWithBasicAuth() {
Map<String, String> properties = new HashMap<>();
properties.put("glue.auth.type", "iam_role");
properties.put(
"glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole");
properties.put("glue.indexstore.opensearch.uri", "http://localhost:9200");
properties.put("glue.indexstore.opensearch.auth", "basicauth");
properties.put("glue.indexstore.opensearch.auth.username", "username");
properties.put("glue.indexstore.opensearch.auth.password", "password");

dataSourceService.createDataSource(
new DataSourceMetadata(
"mybasicauth", DataSourceType.S3GLUE, ImmutableList.of(), properties, null));
LocalEMRSClient emrsClient = new LocalEMRSClient();
AsyncQueryExecutorService asyncQueryExecutorService =
createAsyncQueryExecutorService(emrsClient);

// enable session
enableSession(true);

asyncQueryExecutorService.createAsyncQuery(
new CreateAsyncQueryRequest("select 1", "mybasicauth", LangType.SQL, null));
String params = emrsClient.getJobRequest().getSparkSubmitParams();
assertTrue(params.contains(String.format("--conf spark.datasource.flint.auth=basic")));
assertTrue(
params.contains(String.format("--conf spark.datasource.flint.auth.username=username")));
assertTrue(
params.contains(String.format("--conf spark.datasource.flint.auth.password=password")));
}

@Test
public void withSessionCreateAsyncQueryFailed() {
LocalEMRSClient emrsClient = new LocalEMRSClient();
AsyncQueryExecutorService asyncQueryExecutorService =
createAsyncQueryExecutorService(emrsClient);

// enable session
enableSession(true);

// 1. create async query.
CreateAsyncQueryResponse response =
asyncQueryExecutorService.createAsyncQuery(
new CreateAsyncQueryRequest("myselect 1", DATASOURCE, LangType.SQL, null));
assertNotNull(response.getSessionId());
Optional<StatementModel> statementModel =
getStatement(stateStore, DATASOURCE).apply(response.getQueryId());
assertTrue(statementModel.isPresent());
assertEquals(StatementState.WAITING, statementModel.get().getStatementState());

// 2. fetch async query result. not result write to SPARK_RESPONSE_BUFFER_INDEX_NAME yet.
// mock failed statement.
StatementModel submitted = statementModel.get();
StatementModel mocked =
StatementModel.builder()
.version("1.0")
.statementState(submitted.getStatementState())
.statementId(submitted.getStatementId())
.sessionId(submitted.getSessionId())
.applicationId(submitted.getApplicationId())
.jobId(submitted.getJobId())
.langType(submitted.getLangType())
.datasourceName(submitted.getDatasourceName())
.query(submitted.getQuery())
.queryId(submitted.getQueryId())
.submitTime(submitted.getSubmitTime())
.error("mock error")
.seqNo(submitted.getSeqNo())
.primaryTerm(submitted.getPrimaryTerm())
.build();
updateStatementState(stateStore, DATASOURCE).apply(mocked, StatementState.FAILED);

AsyncQueryExecutionResponse asyncQueryResults =
asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId());
assertEquals(StatementState.FAILED.getState(), asyncQueryResults.getStatus());
assertEquals("mock error", asyncQueryResults.getError());
}

private DataSourceServiceImpl createDataSourceService() {
String masterKey = "1234567890";
String masterKey = "a57d991d9b573f75b9bba1df";
DataSourceMetadataStorage dataSourceMetadataStorage =
new OpenSearchDataSourceMetadataStorage(
client, clusterService, new EncryptorImpl(masterKey));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ public class SparkQueryDispatcherTest {
@Mock(answer = RETURNS_DEEP_STUBS)
private Session session;

@Mock private Statement statement;
@Mock(answer = RETURNS_DEEP_STUBS)
private Statement statement;

private SparkQueryDispatcher sparkQueryDispatcher;

Expand Down Expand Up @@ -181,7 +182,7 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() {
String query = "select * from my_glue.default.http_logs";
String sparkSubmitParameters =
constructExpectedSparkSubmitParameterString(
"basicauth",
"basic",
new HashMap<>() {
{
put(FLINT_INDEX_STORE_AUTH_USERNAME, "username");
Expand Down Expand Up @@ -723,6 +724,7 @@ void testGetQueryResponse() {
void testGetQueryResponseWithSession() {
doReturn(Optional.of(session)).when(sessionManager).getSession(new SessionId(MOCK_SESSION_ID));
doReturn(Optional.of(statement)).when(session).get(any());
when(statement.getStatementModel().getError()).thenReturn("mock error");
doReturn(StatementState.WAITING).when(statement).getStatementState();

doReturn(new JSONObject())
Expand Down

0 comments on commit f1c8c53

Please sign in to comment.