Skip to content

Commit

Permalink
[CDAP-20872] Refactor user credential encryption to use new SPI
Browse files Browse the repository at this point in the history
  • Loading branch information
dli357 committed Nov 13, 2023
1 parent 838fb75 commit 4ad445e
Show file tree
Hide file tree
Showing 13 changed files with 435 additions and 309 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
import io.cdap.cdap.gateway.router.handlers.HttpStatusRequestHandler;
import io.cdap.cdap.security.auth.TokenValidator;
import io.cdap.cdap.security.auth.UserIdentityExtractor;
import io.cdap.cdap.security.encryption.AeadCipherService;
import io.cdap.cdap.security.encryption.guice.AeadEncryptionModule;
import io.cdap.cdap.security.impersonation.SecurityUtil;
import io.cdap.http.SSLConfig;
import io.cdap.http.SSLHandlerFactory;
Expand Down Expand Up @@ -94,6 +96,7 @@ public class NettyRouter extends AbstractIdleService {
private final UserIdentityExtractor userIdentityExtractor;
private final boolean sslEnabled;
private final DiscoveryServiceClient discoveryServiceClient;
private final AeadCipherService userCredentialAeadCipherService;

private InetSocketAddress boundAddress;
private Cancellable serverCancellable;
Expand All @@ -105,7 +108,9 @@ public NettyRouter(CConfiguration cConf, SConfiguration sConf,
@Named(Constants.Router.ADDRESS) InetAddress hostname,
RouterServiceLookup serviceLookup, TokenValidator tokenValidator,
UserIdentityExtractor userIdentityExtractor,
DiscoveryServiceClient discoveryServiceClient) {
DiscoveryServiceClient discoveryServiceClient,
@Named(AeadEncryptionModule.USER_CREDENTIAL_ENCRYPTION)
AeadCipherService userCredentialAeadCipherService) {
this.cConf = cConf;
this.sConf = sConf;
this.serverBossThreadPoolSize = cConf.getInt(Constants.Router.SERVER_BOSS_THREADS);
Expand All @@ -121,6 +126,7 @@ public NettyRouter(CConfiguration cConf, SConfiguration sConf,
this.port = sslEnabled
? cConf.getInt(Constants.Router.ROUTER_SSL_PORT)
: cConf.getInt(Constants.Router.ROUTER_PORT);
this.userCredentialAeadCipherService = userCredentialAeadCipherService;
}

/**
Expand All @@ -137,6 +143,7 @@ protected void startUp() throws Exception {
if (SecurityUtil.isManagedSecurity(cConf) && !SecurityUtil.isInternalAuthEnabled(cConf)) {
tokenValidator.startAndWait();
}
userCredentialAeadCipherService.startAndWait();
ChannelGroup channelGroup = new DefaultChannelGroup(ImmediateEventExecutor.INSTANCE);
serverCancellable = startServer(createServerBootstrap(channelGroup), channelGroup);
scheduleConfigReloadThread();
Expand All @@ -153,6 +160,7 @@ protected void shutDown() {
if (SecurityUtil.isManagedSecurity(cConf) && !SecurityUtil.isInternalAuthEnabled(cConf)) {
tokenValidator.stopAndWait();
}
userCredentialAeadCipherService.stopAndWait();

LOG.info("Stopped Netty Router.");
}
Expand Down Expand Up @@ -225,7 +233,7 @@ protected void initChannel(SocketChannel ch) {
if (securityEnabled) {
pipeline.addLast("access-token-authenticator",
new AuthenticationHandler(cConf, sConf, discoveryServiceClient,
userIdentityExtractor));
userIdentityExtractor, userCredentialAeadCipherService));
}
if (cConf.getBoolean(Constants.Router.ROUTER_AUDIT_LOG_ENABLED)) {
pipeline.addLast("audit-log", new AuditLogHandler());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,18 @@
import com.google.gson.JsonPrimitive;
import io.cdap.cdap.common.conf.CConfiguration;
import io.cdap.cdap.common.conf.Constants;
import io.cdap.cdap.common.conf.Constants.Security.Encryption;
import io.cdap.cdap.common.conf.SConfiguration;
import io.cdap.cdap.common.logging.AuditLogEntry;
import io.cdap.cdap.common.utils.Networks;
import io.cdap.cdap.proto.security.Credential;
import io.cdap.cdap.security.auth.CipherException;
import io.cdap.cdap.security.auth.TinkCipher;
import io.cdap.cdap.security.auth.UserIdentity;
import io.cdap.cdap.security.auth.UserIdentityExtractionResponse;
import io.cdap.cdap.security.auth.UserIdentityExtractionState;
import io.cdap.cdap.security.auth.UserIdentityExtractor;
import io.cdap.cdap.security.auth.UserIdentityPair;
import io.cdap.cdap.security.encryption.AeadCipherService;
import io.cdap.cdap.security.spi.encryption.CipherOperationException;
import io.cdap.cdap.security.server.GrantAccessToken;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
Expand Down Expand Up @@ -86,10 +87,12 @@ public class AuthenticationHandler extends ChannelInboundHandlerAdapter {
private final List<String> authServerURLs;
private final DiscoveryServiceClient discoveryServiceClient;
private final UserIdentityExtractor userIdentityExtractor;
private final AeadCipherService userCredentialAeadCipherService;

public AuthenticationHandler(CConfiguration cConf, SConfiguration sConf,
DiscoveryServiceClient discoveryServiceClient,
UserIdentityExtractor userIdentityExtractor) {
UserIdentityExtractor userIdentityExtractor,
AeadCipherService userCredentialAeadCipherService) {
this.cConf = cConf;
this.sConf = sConf;
this.realm = cConf.get(Constants.Security.CFG_REALM);
Expand All @@ -98,6 +101,7 @@ public AuthenticationHandler(CConfiguration cConf, SConfiguration sConf,
this.authServerURLs = getConfiguredAuthServerURLs(cConf);
this.discoveryServiceClient = discoveryServiceClient;
this.userIdentityExtractor = userIdentityExtractor;
this.userCredentialAeadCipherService = userCredentialAeadCipherService;
}

@Override
Expand Down Expand Up @@ -239,7 +243,7 @@ public void onChange(ServiceDiscovered serviceDiscovered) {
* Get user credential from {@link UserIdentityPair} and return it in encrypted form if enabled.
*/
@Nullable
private Credential getUserCredential(UserIdentityPair userIdentityPair) throws CipherException {
private Credential getUserCredential(UserIdentityPair userIdentityPair) throws CipherOperationException {
String userCredential = userIdentityPair.getUserCredential();
UserIdentity userIdentity = userIdentityPair.getUserIdentity();
if (userIdentity.getIdentifierType() == UserIdentity.IdentifierType.INTERNAL) {
Expand All @@ -250,7 +254,9 @@ private Credential getUserCredential(UserIdentityPair userIdentityPair) throws C
false)) {
return new Credential(userCredential, Credential.CredentialType.EXTERNAL);
}
String encryptedCredential = new TinkCipher(sConf).encryptStringToBase64(userCredential, null);
String encryptedCredential = userCredentialAeadCipherService
.encryptStringToBase64(userCredential,
Encryption.USER_CREDENTIAL_ENCRYPTION_ASSOCIATED_DATA.getBytes());
return new Credential(encryptedCredential, Credential.CredentialType.EXTERNAL_ENCRYPTED);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import io.cdap.cdap.common.security.AuditDetail;
import io.cdap.cdap.common.security.AuditPolicy;
import io.cdap.cdap.security.auth.TokenValidator;
import io.cdap.cdap.security.encryption.NoOpAeadCipherService;
import io.cdap.http.AbstractHttpHandler;
import io.cdap.http.HttpResponder;
import io.cdap.http.NettyHttpService;
Expand Down Expand Up @@ -82,7 +83,8 @@ public class AuditLogTest {
public static void init() throws Exception {
// Configure a log appender programmatically for the audit log
TestLogAppender.addAppender(Constants.Router.AUDIT_LOGGER_NAME);
((ch.qos.logback.classic.Logger) LoggerFactory.getLogger(Constants.Router.AUDIT_LOGGER_NAME)).setLevel(Level.TRACE);
((ch.qos.logback.classic.Logger) LoggerFactory.getLogger(Constants.Router.AUDIT_LOGGER_NAME))
.setLevel(Level.TRACE);

CConfiguration cConf = CConfiguration.create();
SConfiguration sConf = SConfiguration.create();
Expand All @@ -93,18 +95,21 @@ public static void init() throws Exception {

InMemoryDiscoveryService discoveryService = new InMemoryDiscoveryService();

RouterServiceLookup serviceLookup = new RouterServiceLookup(cConf, discoveryService, new RouterPathLookup());
RouterServiceLookup serviceLookup = new RouterServiceLookup(cConf, discoveryService,
new RouterPathLookup());

TokenValidator successValidator = new SuccessTokenValidator();
router = new NettyRouter(cConf, sConf, InetAddress.getLoopbackAddress(), serviceLookup, successValidator,
new MockAccessTokenIdentityExtractor(successValidator), discoveryService);
router = new NettyRouter(cConf, sConf, InetAddress.getLoopbackAddress(), serviceLookup,
successValidator,
new MockAccessTokenIdentityExtractor(successValidator), discoveryService,
new NoOpAeadCipherService());
router.startAndWait();

httpService = NettyHttpService.builder("test").setHttpHandlers(new TestHandler()).build();
httpService.start();

cancelDiscovery = discoveryService.register(new Discoverable(Constants.Service.APP_FABRIC_HTTP,
httpService.getBindAddress()));
httpService.getBindAddress()));

int port = router.getBoundAddress().orElseThrow(IllegalStateException::new).getPort();
baseURI = URI.create(String.format("http://%s:%d", cConf.get(Constants.Router.ADDRESS), port));
Expand All @@ -126,30 +131,35 @@ public void testAuditLog() throws IOException {
urlConn = createURLConnection("/put", HttpMethod.PUT);
urlConn.getOutputStream().write("Test Put".getBytes(StandardCharsets.UTF_8));
Assert.assertEquals(200, urlConn.getResponseCode());
Assert.assertEquals("Test Put", new String(ByteStreams.toByteArray(urlConn.getInputStream()), "UTF-8"));
Assert.assertEquals("Test Put",
new String(ByteStreams.toByteArray(urlConn.getInputStream()), "UTF-8"));
urlConn.getInputStream().close();

urlConn = createURLConnection("/post", HttpMethod.POST);
urlConn.getOutputStream().write("Test Post".getBytes(StandardCharsets.UTF_8));
Assert.assertEquals(200, urlConn.getResponseCode());
Assert.assertEquals("Test Post", new String(ByteStreams.toByteArray(urlConn.getInputStream()), "UTF-8"));
Assert.assertEquals("Test Post",
new String(ByteStreams.toByteArray(urlConn.getInputStream()), "UTF-8"));
urlConn.getInputStream().close();

urlConn = createURLConnection("/postHeaders", HttpMethod.POST);
urlConn.setRequestProperty("user-id", "cdap");
urlConn.getOutputStream().write("Post Headers".getBytes(StandardCharsets.UTF_8));
Assert.assertEquals(200, urlConn.getResponseCode());
Assert.assertEquals("Post Headers", new String(ByteStreams.toByteArray(urlConn.getInputStream()), "UTF-8"));
Assert.assertEquals("Post Headers",
new String(ByteStreams.toByteArray(urlConn.getInputStream()), "UTF-8"));
urlConn.getInputStream().close();

List<String> loggedMessages = TestLogAppender.INSTANCE.getLoggedMessages();
Assert.assertEquals(4, loggedMessages.size());

Assert.assertTrue(loggedMessages.get(0).endsWith("\"GET /get HTTP/1.1\" - - 200 0 -"));
Assert.assertTrue(loggedMessages.get(1).endsWith("\"PUT /put HTTP/1.1\" - Test Put 200 8 -"));
Assert.assertTrue(loggedMessages.get(2).endsWith("\"POST /post HTTP/1.1\" - Test Post 200 9 Test Post"));
Assert.assertTrue(
loggedMessages.get(3).endsWith("\"POST /postHeaders HTTP/1.1\" {user-id=cdap} Post Headers 200 12 Post Headers"));
loggedMessages.get(2).endsWith("\"POST /post HTTP/1.1\" - Test Post 200 9 Test Post"));
Assert.assertTrue(
loggedMessages.get(3).endsWith(
"\"POST /postHeaders HTTP/1.1\" {user-id=cdap} Post Headers 200 12 Post Headers"));
}

private HttpURLConnection createURLConnection(String path, HttpMethod method) throws IOException {
Expand Down Expand Up @@ -177,21 +187,25 @@ public void get(HttpRequest request, HttpResponder responder, @QueryParam("q") S
@PUT
@AuditPolicy(AuditDetail.REQUEST_BODY)
public void put(FullHttpRequest request, HttpResponder responder) {
responder.sendContent(HttpResponseStatus.OK, request.content().retainedDuplicate(), EmptyHttpHeaders.INSTANCE);
responder.sendContent(HttpResponseStatus.OK, request.content().retainedDuplicate(),
EmptyHttpHeaders.INSTANCE);
}

@Path("/post")
@POST
@AuditPolicy({AuditDetail.REQUEST_BODY, AuditDetail.RESPONSE_BODY})
public void post(FullHttpRequest request, HttpResponder responder) {
responder.sendContent(HttpResponseStatus.OK, request.content().retainedDuplicate(), EmptyHttpHeaders.INSTANCE);
responder.sendContent(HttpResponseStatus.OK, request.content().retainedDuplicate(),
EmptyHttpHeaders.INSTANCE);
}

@Path("/postHeaders")
@POST
@AuditPolicy({AuditDetail.REQUEST_BODY, AuditDetail.RESPONSE_BODY, AuditDetail.HEADERS})
public void postHeaders(FullHttpRequest request, HttpResponder responder, @HeaderParam("user-id") String userId) {
responder.sendContent(HttpResponseStatus.OK, request.content().retainedDuplicate(), EmptyHttpHeaders.INSTANCE);
public void postHeaders(FullHttpRequest request, HttpResponder responder,
@HeaderParam("user-id") String userId) {
responder.sendContent(HttpResponseStatus.OK, request.content().retainedDuplicate(),
EmptyHttpHeaders.INSTANCE);
}
}

Expand All @@ -208,7 +222,6 @@ static void addAppender(String loggerName) {
LoggerContext loggerContext = (LoggerContext) LoggerFactory.getILoggerFactory();
Logger logger = loggerContext.getLogger(loggerName);


// Check if the logger already contains the logAppender
if (Iterators.contains(logger.iteratorForAppenders(), INSTANCE)) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import io.cdap.cdap.security.auth.AuthenticationMode;
import io.cdap.cdap.security.auth.TokenValidator;
import io.cdap.cdap.security.auth.UserIdentityExtractor;
import io.cdap.cdap.security.encryption.NoOpAeadCipherService;
import io.cdap.cdap.security.guice.CoreSecurityRuntimeModule;
import io.cdap.cdap.security.guice.ExternalAuthenticationModule;
import io.cdap.cdap.security.server.GrantAccessToken;
Expand All @@ -55,16 +56,19 @@
import org.junit.Test;

public class AuthServerAnnounceTest {

private static final String HOSTNAME = "127.0.0.1";
private static final int CONNECTION_IDLE_TIMEOUT_SECS = 2;
private static final DiscoveryService DISCOVERY_SERVICE = new InMemoryDiscoveryService();
private static final String ANNOUNCE_URLS = "https://vip.cask.co:80,http://vip.cask.co:1000";
private static final Gson GSON = new Gson();
private static final Type TYPE = new TypeToken<Map<String, List<String>>>() { }.getType();
private static final Type TYPE = new TypeToken<Map<String, List<String>>>() {
}.getType();

@Test
public void testEmptyAnnounceAddressURLsConfig() throws Exception {
HttpRouterService routerService = new AuthServerAnnounceTest.HttpRouterService(HOSTNAME, DISCOVERY_SERVICE);
HttpRouterService routerService = new AuthServerAnnounceTest.HttpRouterService(HOSTNAME,
DISCOVERY_SERVICE);
routerService.startUp();
try {
Assert.assertEquals(Collections.EMPTY_LIST, getAuthURI(routerService));
Expand All @@ -75,36 +79,40 @@ public void testEmptyAnnounceAddressURLsConfig() throws Exception {

@Test
public void testAnnounceURLsConfig() throws Exception {
HttpRouterService routerService = new AuthServerAnnounceTest.HttpRouterService(HOSTNAME, DISCOVERY_SERVICE);
HttpRouterService routerService = new AuthServerAnnounceTest.HttpRouterService(HOSTNAME,
DISCOVERY_SERVICE);
routerService.cConf.set(Constants.Security.AUTH_SERVER_ANNOUNCE_URLS, ANNOUNCE_URLS);
routerService.startUp();
try {
List<String> expected = Stream.of(ANNOUNCE_URLS.split(","))
.map(url -> String.format("%s/%s", url, GrantAccessToken.Paths.GET_TOKEN))
.collect(Collectors.toList());
.map(url -> String.format("%s/%s", url, GrantAccessToken.Paths.GET_TOKEN))
.collect(Collectors.toList());
Assert.assertEquals(expected, getAuthURI(routerService));
} finally {
routerService.shutDown();
}
}

private List<String> getAuthURI(HttpRouterService routerService) throws IOException, URISyntaxException {
private List<String> getAuthURI(HttpRouterService routerService)
throws IOException, URISyntaxException {
DefaultHttpClient client = new DefaultHttpClient();
String url = resolveURI("/v3/apps", routerService);
HttpGet get = new HttpGet(url);
HttpResponse response = client.execute(get);
Map<String, List<String>> responseMap =
GSON.fromJson(new InputStreamReader(response.getEntity().getContent()), TYPE);
GSON.fromJson(new InputStreamReader(response.getEntity().getContent()), TYPE);
return responseMap.get("auth_uri");
}

private String resolveURI(String path, HttpRouterService routerService) throws URISyntaxException {
private String resolveURI(String path, HttpRouterService routerService)
throws URISyntaxException {
InetSocketAddress address = routerService.getRouterAddress();
return new URI(String.format("%s://%s:%d", "http", address.getHostName(),
address.getPort())).resolve(path).toASCIIString();
address.getPort())).resolve(path).toASCIIString();
}

private static class HttpRouterService extends AbstractIdleService {

private final String hostname;
private final DiscoveryService discoveryService;
private CConfiguration cConf = CConfiguration.create();
Expand All @@ -119,10 +127,11 @@ private HttpRouterService(String hostname, DiscoveryService discoveryService) {
protected void startUp() {
SConfiguration sConfiguration = SConfiguration.create();
Injector injector = Guice.createInjector(new CoreSecurityRuntimeModule().getInMemoryModules(),
new ExternalAuthenticationModule(),
new InMemoryDiscoveryModule(),
new AppFabricTestModule(cConf));
DiscoveryServiceClient discoveryServiceClient = injector.getInstance(DiscoveryServiceClient.class);
new ExternalAuthenticationModule(),
new InMemoryDiscoveryModule(),
new AppFabricTestModule(cConf));
DiscoveryServiceClient discoveryServiceClient = injector
.getInstance(DiscoveryServiceClient.class);
TokenValidator validator = new MissingTokenValidator();
UserIdentityExtractor userIdentityExtractor = new MockAccessTokenIdentityExtractor(validator);
cConf.set(Constants.Router.ADDRESS, hostname);
Expand All @@ -132,10 +141,11 @@ protected void startUp() {
cConf.setEnum(Constants.Security.Authentication.MODE, AuthenticationMode.MANAGED);

router =
new NettyRouter(cConf, sConfiguration, InetAddresses.forString(hostname),
new RouterServiceLookup(cConf, (DiscoveryServiceClient) discoveryService,
new RouterPathLookup()),
validator, userIdentityExtractor, discoveryServiceClient);
new NettyRouter(cConf, sConfiguration, InetAddresses.forString(hostname),
new RouterServiceLookup(cConf, (DiscoveryServiceClient) discoveryService,
new RouterPathLookup()),
validator, userIdentityExtractor, discoveryServiceClient,
new NoOpAeadCipherService());
router.startAndWait();
}

Expand Down
Loading

0 comments on commit 4ad445e

Please sign in to comment.