Skip to content

Commit

Permalink
Allow users with STATE permissions to read and write the state APIs f…
Browse files Browse the repository at this point in the history
…or querying with deep storage (apache#14944)

Currently, only the user who has submitted the async query has permission to interact with the status APIs for that async query. However, often we want an administrator to interact with these resources as well.
Druid handles these with the STATE resource traditionally, and if the requesting user has necessary permissions on it as well, alternatively, they should be allowed to interact with the status APIs, irrespective of whether they are the submitter of the query.
  • Loading branch information
LakshSingla authored Sep 21, 2023
1 parent 883c269 commit ebb7946
Show file tree
Hide file tree
Showing 5 changed files with 257 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,14 @@
import org.apache.druid.rpc.HttpResponseException;
import org.apache.druid.rpc.indexing.OverlordClient;
import org.apache.druid.server.QueryResponse;
import org.apache.druid.server.security.Access;
import org.apache.druid.server.security.Action;
import org.apache.druid.server.security.AuthenticationResult;
import org.apache.druid.server.security.AuthorizationUtils;
import org.apache.druid.server.security.AuthorizerMapper;
import org.apache.druid.server.security.ForbiddenException;
import org.apache.druid.server.security.Resource;
import org.apache.druid.server.security.ResourceAction;
import org.apache.druid.sql.DirectStatement;
import org.apache.druid.sql.HttpStatement;
import org.apache.druid.sql.SqlRowTransformer;
Expand All @@ -103,6 +108,7 @@
import javax.ws.rs.core.StreamingOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand All @@ -120,20 +126,23 @@ public class SqlStatementResource
private final ObjectMapper jsonMapper;
private final OverlordClient overlordClient;
private final StorageConnector storageConnector;
private final AuthorizerMapper authorizerMapper;


@Inject
public SqlStatementResource(
final @MSQ SqlStatementFactory msqSqlStatementFactory,
final ObjectMapper jsonMapper,
final OverlordClient overlordClient,
final @MultiStageQuery StorageConnector storageConnector
final @MultiStageQuery StorageConnector storageConnector,
final AuthorizerMapper authorizerMapper
)
{
this.msqSqlStatementFactory = msqSqlStatementFactory;
this.jsonMapper = jsonMapper;
this.overlordClient = overlordClient;
this.storageConnector = storageConnector;
this.authorizerMapper = authorizerMapper;
}

/**
Expand Down Expand Up @@ -178,7 +187,7 @@ public Response doPost(final SqlQuery sqlQuery, @Context final HttpServletReques
final boolean isTaskStruct = MSQTaskSqlEngine.TASK_STRUCT_FIELD_NAMES.equals(rowTransformer.getFieldList());

if (isTaskStruct) {
return buildTaskResponse(sequence, stmt.query().authResult().getIdentity());
return buildTaskResponse(sequence, stmt.query().authResult());
} else {
// Used for EXPLAIN
return buildStandardResponse(sequence, modifiedQuery, sqlQueryId, rowTransformer);
Expand Down Expand Up @@ -231,8 +240,9 @@ public Response doGetStatus(

Optional<SqlStatementResult> sqlStatementResult = getStatementStatus(
queryId,
authenticationResult.getIdentity(),
true
authenticationResult,
true,
Action.READ
);

if (sqlStatementResult.isPresent()) {
Expand Down Expand Up @@ -288,7 +298,11 @@ public Response doGetResults(
throw queryNotFoundException(queryId);
}

MSQControllerTask msqControllerTask = getMSQControllerTaskOrThrow(queryId, authenticationResult.getIdentity());
MSQControllerTask msqControllerTask = getMSQControllerTaskAndCheckPermission(
queryId,
authenticationResult,
Action.READ
);
throwIfQueryIsNotSuccessful(queryId, statusPlus);

Optional<List<ColumnNameAndTypes>> signature = SqlStatementResourceHelper.getSignature(msqControllerTask);
Expand Down Expand Up @@ -353,8 +367,9 @@ public Response deleteQuery(@PathParam("id") final String queryId, @Context fina

Optional<SqlStatementResult> sqlStatementResult = getStatementStatus(
queryId,
authenticationResult.getIdentity(),
false
authenticationResult,
false,
Action.WRITE
);
if (sqlStatementResult.isPresent()) {
switch (sqlStatementResult.get().getState()) {
Expand Down Expand Up @@ -448,7 +463,7 @@ private Response buildStandardResponse(
}
}

private Response buildTaskResponse(Sequence<Object[]> sequence, String user)
private Response buildTaskResponse(Sequence<Object[]> sequence, AuthenticationResult authenticationResult)
{
List<Object[]> rows = sequence.toList();
int numRows = rows.size();
Expand All @@ -464,7 +479,7 @@ private Response buildTaskResponse(Sequence<Object[]> sequence, String user)
}
String taskId = String.valueOf(firstRow[0]);

Optional<SqlStatementResult> statementResult = getStatementStatus(taskId, user, true);
Optional<SqlStatementResult> statementResult = getStatementStatus(taskId, authenticationResult, true, Action.READ);

if (statementResult.isPresent()) {
return Response.status(Response.Status.OK).entity(statementResult.get()).build();
Expand Down Expand Up @@ -565,8 +580,12 @@ private Optional<ResultSetInformation> getSampleResults(
}


private Optional<SqlStatementResult> getStatementStatus(String queryId, String currentUser, boolean withResults)
throws DruidException
private Optional<SqlStatementResult> getStatementStatus(
String queryId,
AuthenticationResult authenticationResult,
boolean withResults,
Action forAction
) throws DruidException
{
TaskStatusResponse taskResponse = contactOverlord(overlordClient.taskStatus(queryId), queryId);
if (taskResponse == null) {
Expand All @@ -579,7 +598,7 @@ private Optional<SqlStatementResult> getStatementStatus(String queryId, String c
}

// since we need the controller payload for auth checks.
MSQControllerTask msqControllerTask = getMSQControllerTaskOrThrow(queryId, currentUser);
MSQControllerTask msqControllerTask = getMSQControllerTaskAndCheckPermission(queryId, authenticationResult, forAction);
SqlStatementState sqlStatementState = SqlStatementResourceHelper.getSqlStatementState(statusPlus);

if (SqlStatementState.FAILED == sqlStatementState) {
Expand Down Expand Up @@ -610,7 +629,20 @@ private Optional<SqlStatementResult> getStatementStatus(String queryId, String c
}


private MSQControllerTask getMSQControllerTaskOrThrow(String queryId, String currentUser)
/**
* This method contacts the overlord for the controller task and checks if the requested user has the
* necessary permissions. A user has the necessary permissions if one of the following criteria is satisfied:
* 1. The user is the one who submitted the query
* 2. The user belongs to a role containing the READ or WRITE permissions over the STATE resource. For endpoints like GET,
* the user should have READ permission for the STATE resource, while for endpoints like DELETE, the user should
* have WRITE permission for the STATE resource. (Note: POST API does not need to check the state permissions since
* the currentUser always equal to the queryUser)
*/
private MSQControllerTask getMSQControllerTaskAndCheckPermission(
String queryId,
AuthenticationResult authenticationResult,
Action forAction
) throws ForbiddenException
{
TaskPayloadResponse taskPayloadResponse = contactOverlord(overlordClient.taskPayload(queryId), queryId);
SqlStatementResourceHelper.isMSQPayload(taskPayloadResponse, queryId);
Expand All @@ -620,15 +652,28 @@ private MSQControllerTask getMSQControllerTaskOrThrow(String queryId, String cur
.getQuery()
.getContext()
.get(MSQTaskQueryMaker.USER_KEY));
if (currentUser == null || !currentUser.equals(queryUser)) {
throw new ForbiddenException(StringUtils.format(
"The current user[%s] cannot view query id[%s] since the query is owned by user[%s]",
currentUser,
queryId,
queryUser
));

String currentUser = authenticationResult.getIdentity();

if (currentUser != null && currentUser.equals(queryUser)) {
return msqControllerTask;
}

Access access = AuthorizationUtils.authorizeAllResourceActions(
authenticationResult,
Collections.singletonList(new ResourceAction(Resource.STATE_RESOURCE, forAction)),
authorizerMapper
);

if (access.isAllowed()) {
return msqControllerTask;
}
return msqControllerTask;

throw new ForbiddenException(StringUtils.format(
"The current user[%s] cannot view query id[%s] since the query is owned by another user",
currentUser,
queryId
));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ public void init()
sqlStatementFactory,
objectMapper,
indexingServiceClient,
localFileStorageConnector
localFileStorageConnector,
authorizerMapper
);
}

Expand Down Expand Up @@ -274,7 +275,8 @@ public void durableStorageDisabledTest()
sqlStatementFactory,
objectMapper,
indexingServiceClient,
NilStorageConnector.getInstance()
NilStorageConnector.getInstance(),
authorizerMapper
);

String errorMessage = "The sql statement api cannot read from the select destination [durableStorage] provided in "
Expand Down
Loading

0 comments on commit ebb7946

Please sign in to comment.