Support multiple users sharing the same gateway server #52

Status: Open, merging 4 commits into base branch master.
File: src/main/java/com/ververica/flink/table/gateway/Session.java (7 additions, 2 deletions)

@@ -33,6 +33,7 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

+import java.util.Collections;
 import java.util.Map;
 import java.util.Optional;
 import java.util.concurrent.ConcurrentHashMap;
@@ -72,7 +73,11 @@ public SessionContext getContext() {
 		return context;
 	}

 	public Tuple2<ResultSet, SqlCommandParser.SqlCommand> runStatement(String statement) {
+		return this.runStatement(statement, Collections.emptyMap());
+	}
+
+	public Tuple2<ResultSet, SqlCommandParser.SqlCommand> runStatement(String statement, Map<String, String> operationConf) {
 		LOG.info("Session: {}, run statement: {}", sessionId, statement);
 		boolean isBlinkPlanner = context.getExecutionContext().getEnvironment().getExecution().getPlanner()
 			.equalsIgnoreCase(ExecutionEntry.EXECUTION_PLANNER_VALUE_BLINK);
@@ -91,7 +96,7 @@
 			throw new SqlGatewayException(e.getMessage(), e.getCause());
 		}

-		Operation operation = OperationFactory.createOperation(call, context);
+		Operation operation = OperationFactory.createOperation(call, context, operationConf);
 		ResultSet resultSet = operation.execute();

 		if (operation instanceof JobOperation) {
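
For reference, a minimal sketch of how a caller might use the new overload (the session variable and the user name are hypothetical; "proxyUser" is the key that doAsOwner() in AbstractJobOperation, shown next, reads from this map):

	// Run the statement under "alice"'s identity instead of the gateway's own.
	Map<String, String> operationConf = Collections.singletonMap("proxyUser", "alice");
	Tuple2<ResultSet, SqlCommandParser.SqlCommand> result =
		session.runStatement("SELECT * FROM orders", operationConf);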
File: AbstractJobOperation.java

@@ -24,12 +24,18 @@
 import com.ververica.flink.table.gateway.rest.result.ColumnInfo;
 import com.ververica.flink.table.gateway.rest.result.ResultKind;
 import com.ververica.flink.table.gateway.rest.result.ResultSet;
+import com.ververica.flink.table.gateway.security.HadoopSecurityContext;
+import com.ververica.flink.table.gateway.security.NoOpSecurityContext;

 import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.JobStatus;
 import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.runtime.security.SecurityConfiguration;
+import org.apache.flink.runtime.security.SecurityContext;
+import org.apache.flink.runtime.security.SecurityUtils;
 import org.apache.flink.types.Row;

+import org.apache.hadoop.security.UserGroupInformation;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

@@ -39,7 +45,9 @@
 import java.util.Iterator;
 import java.util.LinkedList;
 import java.util.List;
+import java.util.Map;
 import java.util.Optional;
+import java.util.concurrent.Callable;

 /**
  * A default implementation of JobOperation.
@@ -52,6 +60,7 @@ public abstract class AbstractJobOperation implements JobOperation {
 	protected ClusterDescriptorAdapter<?> clusterDescriptorAdapter;
 	protected final String sessionId;
 	protected volatile JobID jobId;
+	protected Map<String, String> operationConf;

 	private long currentToken;
 	private int previousMaxFetchSize;
@@ -210,6 +219,28 @@ public synchronized Optional<ResultSet> getJobResult(long token, int maxFetchSize) {
 		);
 	}

+	protected <T> T doAsOwner(final Callable<T> callable) {
+		String user = operationConf.get("proxyUser");
+
+		SecurityContext securityContext;
+		if (UserGroupInformation.isSecurityEnabled()) {
+			try {
+				SecurityUtils.install(new SecurityConfiguration(context.getExecutionContext().getFlinkConfig()));
+			} catch (Exception e) {
+				throw new RuntimeException("Failed to install security context.", e);
+			}
+			securityContext = new HadoopSecurityContext(user);
+		} else {
+			securityContext = new NoOpSecurityContext(user);
+		}
+
+		try {
+			return securityContext.runSecured(callable);
+		} catch (Exception e) {
+			throw new RuntimeException("Error while running callable as the job owner.", e);
+		}
+	}
+
 	protected abstract Optional<Tuple2<List<Row>, List<Boolean>>> fetchNewJobResults();

 	protected abstract List<ColumnInfo> getColumnInfos();
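
To make doAsOwner()'s contract concrete, a minimal sketch (the user name is hypothetical): with a "proxyUser" entry in operationConf the callable observes that user as the current Hadoop user, and without one it runs as the gateway's own login user.

	// Assuming the subclass set operationConf = {"proxyUser": "alice"}:
	String runAs = doAsOwner(
		() -> UserGroupInformation.getCurrentUser().getShortUserName());
	// runAs == "alice" on both secured and unsecured clusters.

Note that doAsOwner() dereferences operationConf without a null check, so concrete operations must assign the field before their first cluster call, as the InsertOperation and SelectOperation constructors below do.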
File: InsertOperation.java

@@ -48,6 +48,7 @@
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import java.util.Map;
 import java.util.Optional;
 import java.util.regex.Pattern;

@@ -66,17 +67,18 @@ public class InsertOperation extends AbstractJobOperation {

 	private boolean fetched = false;

-	public InsertOperation(SessionContext context, String statement) {
+	public InsertOperation(SessionContext context, String statement, Map<String, String> operationConf) {
 		super(context);
 		this.statement = statement;

 		this.columnInfos = new ArrayList<>();
 		this.columnInfos.add(ColumnInfo.create(ConstantNames.AFFECTED_ROW_COUNT, new BigIntType(false)));
+		this.operationConf = operationConf;
 	}

 	@Override
 	public ResultSet execute() {
-		jobId = executeUpdateInternal(context.getExecutionContext());
+		jobId = doAsOwner(() -> executeUpdateInternal(context.getExecutionContext()));
 		String strJobId = jobId.toString();
 		return ResultSet.builder()
 			.resultKind(ResultKind.SUCCESS_WITH_CONTENT)
@@ -118,7 +120,10 @@ protected List<ColumnInfo> getColumnInfos() {

 	@Override
 	protected void cancelJobInternal() {
-		clusterDescriptorAdapter.cancelJob();
+		doAsOwner(() -> {
+			clusterDescriptorAdapter.cancelJob();
+			return null;
+		});
 	}

 	private <C> JobID executeUpdateInternal(ExecutionContext<C> executionContext) {
File: OperationFactory.java

@@ -22,17 +22,19 @@
 import com.ververica.flink.table.gateway.SqlGatewayException;
 import com.ververica.flink.table.gateway.context.SessionContext;

+import java.util.Map;
+
 /**
  * The factory to create {@link Operation} based on {@link SqlCommandCall}.
  */
 public class OperationFactory {

-	public static Operation createOperation(SqlCommandCall call, SessionContext context) {
+	public static Operation createOperation(SqlCommandCall call, SessionContext context, Map<String, String> operationConf) {

 		Operation operation;
 		switch (call.command) {
 			case SELECT:
-				operation = new SelectOperation(context, call.operands[0]);
+				operation = new SelectOperation(context, call.operands[0], operationConf);
 				break;
 			case CREATE_VIEW:
 				operation = new CreateViewOperation(context, call.operands[0], call.operands[1]);
@@ -71,7 +73,7 @@
 				break;
 			case INSERT_INTO:
 			case INSERT_OVERWRITE:
-				operation = new InsertOperation(context, call.operands[0]);
+				operation = new InsertOperation(context, call.operands[0], operationConf);
 				break;
 			case SHOW_MODULES:
 				operation = new ShowModuleOperation(context);
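
Only the job-submitting commands (SELECT, INSERT_INTO, INSERT_OVERWRITE) thread operationConf through; metadata operations such as CREATE_VIEW or SHOW_MODULES keep their existing constructors, since they never reach the cluster under a proxied identity. A one-line sketch of the call as the single-argument Session.runStatement above makes it, with no per-statement overrides:

	Operation operation = OperationFactory.createOperation(call, context, Collections.emptyMap());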
File: SelectOperation.java

@@ -57,6 +57,7 @@
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import java.util.Map;
 import java.util.Optional;
 import java.util.UUID;

@@ -75,16 +76,17 @@ public class SelectOperation extends AbstractJobOperation {
 	private boolean resultFetched;
 	private volatile boolean noMoreResult;

-	public SelectOperation(SessionContext context, String query) {
+	public SelectOperation(SessionContext context, String query, Map<String, String> operationConf) {
 		super(context);
 		this.query = query;
 		this.resultFetched = false;
 		this.noMoreResult = false;
+		this.operationConf = operationConf;
 	}

 	@Override
 	public ResultSet execute() {
-		resultDescriptor = executeQueryInternal(context.getExecutionContext(), query);
+		resultDescriptor = doAsOwner(() -> executeQueryInternal(context.getExecutionContext(), query));
 		jobId = resultDescriptor.getJobClient().getJobID();

 		List<TableColumn> resultSchemaColumns = resultDescriptor.getResultSchema().getTableColumns();
@@ -111,7 +113,10 @@ protected void cancelJobInternal() {
 			return;
 		}

-		clusterDescriptorAdapter.cancelJob();
+		doAsOwner(() -> {
+			clusterDescriptorAdapter.cancelJob();
+			return null;
+		});
 	}

 	@Override
@@ -124,9 +129,9 @@ protected Optional<Tuple2<List<Row>, List<Boolean>>> fetchNewJobResults() {
 		}

 		if (resultDescriptor.isChangelogResult()) {
-			ret = fetchStreamingResult();
+			ret = doAsOwner(this::fetchStreamingResult);
 		} else {
-			ret = fetchBatchResult();
+			ret = doAsOwner(this::fetchBatchResult);
 		}
 	}
 	resultFetched = true;
File: statement-execution REST handler (file name not shown in this excerpt)

@@ -76,9 +76,10 @@ protected CompletableFuture<StatementExecuteResponseBody> handleRequest(

 		// TODO supports this
 		Long executionTimeoutMillis = request.getRequestBody().getExecutionTimeout();
+		Map<String, String> executionConf = request.getRequestBody().getExecutionConf();

 		try {
-			Tuple2<ResultSet, SqlCommand> tuple2 = sessionManager.getSession(sessionId).runStatement(statement);
+			Tuple2<ResultSet, SqlCommand> tuple2 = sessionManager.getSession(sessionId).runStatement(statement, executionConf);
 			ResultSet resultSet = tuple2.f0;
 			String statementType = tuple2.f1.name();
File: StatementExecuteRequestBody.java

@@ -26,6 +26,9 @@

 import javax.annotation.Nullable;

+import java.util.Collections;
+import java.util.Map;
+
 /**
  * {@link RequestBody} for executing a statement.
 */
@@ -34,6 +37,7 @@ public class StatementExecuteRequestBody implements RequestBody {

 	private static final String FIELD_STATEMENT = "statement";
 	private static final String FIELD_EXECUTION_TIMEOUT = "execution_timeout";
+	private static final String FIELD_EXECUTION_CONF = "execution_conf";

 	@JsonProperty(FIELD_STATEMENT)
 	@Nullable
@@ -43,11 +47,17 @@
 	@Nullable
 	private Long executionTimeout;

+	@JsonProperty(FIELD_EXECUTION_CONF)
+	@Nullable
+	private Map<String, String> executionConf;
+
 	public StatementExecuteRequestBody(
 			@Nullable @JsonProperty(FIELD_STATEMENT) String statement,
-			@Nullable @JsonProperty(FIELD_EXECUTION_TIMEOUT) Long executionTimeout) {
+			@Nullable @JsonProperty(FIELD_EXECUTION_TIMEOUT) Long executionTimeout,
+			@Nullable @JsonProperty(FIELD_EXECUTION_CONF) Map<String, String> executionConf) {
 		this.statement = statement;
 		this.executionTimeout = executionTimeout;
+		this.executionConf = executionConf;
 	}

 	@Nullable
@@ -61,4 +71,13 @@ public String getStatement() {
 	public Long getExecutionTimeout() {
 		return executionTimeout;
 	}
+
+	@Nullable
+	@JsonIgnore
+	public Map<String, String> getExecutionConf() {
+		if (executionConf == null) {
+			return Collections.emptyMap();
+		}
+		return executionConf;
+	}
 }
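
A sketch of what a client request now carries, both on the wire and once deserialized (statement, timeout, and user are hypothetical values):

	// JSON body POSTed to the statement-execution endpoint:
	//   {
	//     "statement": "INSERT INTO sink SELECT * FROM source",
	//     "execution_timeout": 60000,
	//     "execution_conf": { "proxyUser": "alice" }
	//   }
	// The equivalent constructed request body:
	StatementExecuteRequestBody body = new StatementExecuteRequestBody(
		"INSERT INTO sink SELECT * FROM source",
		60000L,
		Collections.singletonMap("proxyUser", "alice"));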
File: src/main/java/com/ververica/flink/table/gateway/security/HadoopSecurityContext.java (new file)

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.ververica.flink.table.gateway.security;

import org.apache.flink.runtime.security.SecurityContext;

import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.security.UserGroupInformation;

import java.security.PrivilegedExceptionAction;
import java.util.concurrent.Callable;

/**
 * An implementation of {@link SecurityContext} for a secure (Kerberos-enabled) cluster.
 */
public class HadoopSecurityContext implements SecurityContext {

	private final String user;

	public HadoopSecurityContext(String user) {
		this.user = user;
	}

	@Override
	public <T> T runSecured(Callable<T> securedCallable) throws Exception {
		UserGroupInformation ugi;
		if (StringUtils.isNotEmpty(user)) {
			// Impersonate the given user on top of the gateway's Kerberos login user.
			ugi = UserGroupInformation.createProxyUser(user, UserGroupInformation.getLoginUser());
		} else {
			ugi = UserGroupInformation.getLoginUser();
		}
		return ugi.doAs((PrivilegedExceptionAction<T>) securedCallable::call);
	}
}
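
An operational note, standard Hadoop behavior rather than anything this patch configures: createProxyUser only takes effect if the cluster authorizes the gateway's login principal to impersonate other users, via the hadoop.proxyuser.<gateway-user>.hosts and hadoop.proxyuser.<gateway-user>.groups properties in core-site.xml. Without that whitelist entry, calls made under the proxied identity fail with an authorization error.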
File: src/main/java/com/ververica/flink/table/gateway/security/NoOpSecurityContext.java (new file)

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.ververica.flink.table.gateway.security;

import org.apache.flink.runtime.security.SecurityContext;

import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.security.UserGroupInformation;

import java.security.PrivilegedExceptionAction;
import java.util.concurrent.Callable;

/**
 * An implementation of {@link SecurityContext} for an insecure (non-Kerberos) cluster.
 */
public class NoOpSecurityContext implements SecurityContext {

	private final String user;

	public NoOpSecurityContext(String user) {
		this.user = user;
	}

	@Override
	public <T> T runSecured(Callable<T> securedCallable) throws Exception {
		if (StringUtils.isNotEmpty(user)) {
			// Without Kerberos, a simple remote user is enough to impersonate.
			UserGroupInformation ugi = UserGroupInformation.createRemoteUser(user);
			return ugi.doAs((PrivilegedExceptionAction<T>) securedCallable::call);
		}
		return securedCallable.call();
	}
}
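
Finally, a self-contained sketch of impersonation from the caller's side, using the two contexts exactly as doAsOwner() wires them up ("alice" is a hypothetical proxied user):

	import com.ververica.flink.table.gateway.security.HadoopSecurityContext;
	import com.ververica.flink.table.gateway.security.NoOpSecurityContext;

	import org.apache.flink.runtime.security.SecurityContext;
	import org.apache.hadoop.security.UserGroupInformation;

	public class ImpersonationSketch {
		public static void main(String[] args) throws Exception {
			// Pick the same context doAsOwner() would pick for this cluster.
			SecurityContext ctx = UserGroupInformation.isSecurityEnabled()
				? new HadoopSecurityContext("alice")
				: new NoOpSecurityContext("alice");
			// Inside runSecured, the current UGI reports the proxied user.
			String current = ctx.runSecured(
				() -> UserGroupInformation.getCurrentUser().getShortUserName());
			System.out.println(current); // expected: alice
		}
	}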