Skip to content

Commit

Permalink
Merge pull request #15416 from cdapio/CDAP-20785-cherry-pick
Browse files Browse the repository at this point in the history
[🍒][CDAP-20785] Do not set WI env and use client library to fetch auth token.
  • Loading branch information
itsankit-google authored Nov 9, 2023
2 parents 4b110ca + 1576288 commit a8a443d
Show file tree
Hide file tree
Showing 13 changed files with 97 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,8 @@ public void run() {
String.format("%s:%s", localhost,
cConf.getInt(Constants.ArtifactLocalizer.PORT))
));
twillPreparer = ((SecureTwillPreparer) twillPreparer)
.withNamespacedWorkloadIdentity(PreviewRunnerTwillRunnable.class.getSimpleName());
}

String priorityClass = cConf.get(Constants.Preview.CONTAINER_PRIORITY_CLASS_NAME);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,8 @@ public void run() {
String.format("%s:%s", localhost,
cConf.getInt(Constants.ArtifactLocalizer.PORT))
));
twillPreparer = ((SecureTwillPreparer) twillPreparer)
.withNamespacedWorkloadIdentity(TaskWorkerTwillRunnable.class.getSimpleName());
}

String priorityClass = cConf.get(Constants.TaskWorker.CONTAINER_PRIORITY_CLASS_NAME);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.cdap.cdap.common.conf.Constants;
import io.cdap.cdap.common.http.CommonNettyHttpServiceFactory;
import io.cdap.cdap.common.internal.remote.RemoteClientFactory;
import io.cdap.cdap.security.spi.authenticator.RemoteAuthenticator;
import io.cdap.http.NettyHttpService;
import java.net.InetAddress;
import java.nio.file.Paths;
Expand Down Expand Up @@ -52,7 +53,7 @@ public class ArtifactLocalizerService extends AbstractIdleService {
ArtifactLocalizerService(CConfiguration cConf,
ArtifactLocalizer artifactLocalizer,
CommonNettyHttpServiceFactory commonNettyHttpServiceFactory,
RemoteClientFactory remoteClientFactory) {
RemoteClientFactory remoteClientFactory, RemoteAuthenticator remoteAuthenticator) {
this.cConf = cConf;
this.artifactLocalizer = artifactLocalizer;
this.httpService = commonNettyHttpServiceFactory.builder(Constants.Service.TASK_WORKER)
Expand All @@ -61,7 +62,7 @@ public class ArtifactLocalizerService extends AbstractIdleService {
.setBossThreadPoolSize(cConf.getInt(Constants.ArtifactLocalizer.BOSS_THREADS))
.setWorkerThreadPoolSize(cConf.getInt(Constants.ArtifactLocalizer.WORKER_THREADS))
.setHttpHandlers(new ArtifactLocalizerHttpHandlerInternal(artifactLocalizer),
new GcpMetadataHttpHandlerInternal(cConf, remoteClientFactory))
new GcpMetadataHttpHandlerInternal(cConf, remoteClientFactory, remoteAuthenticator))
.build();

this.cacheCleanupInterval = cConf.getInt(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@
import com.google.inject.Guice;
import com.google.inject.Injector;
import com.google.inject.Module;
import io.cdap.cdap.api.feature.FeatureFlagsProvider;
import io.cdap.cdap.app.guice.DistributedArtifactManagerModule;
import io.cdap.cdap.common.conf.CConfiguration;
import io.cdap.cdap.common.conf.Constants;
import io.cdap.cdap.common.feature.DefaultFeatureFlagsProvider;
import io.cdap.cdap.common.guice.ConfigModule;
import io.cdap.cdap.common.guice.DFSLocationModule;
import io.cdap.cdap.common.guice.IOModule;
Expand All @@ -40,6 +42,7 @@
import io.cdap.cdap.common.logging.LoggingContext;
import io.cdap.cdap.common.logging.LoggingContextAccessor;
import io.cdap.cdap.common.logging.ServiceLoggingContext;
import io.cdap.cdap.features.Feature;
import io.cdap.cdap.logging.appender.LogAppenderInitializer;
import io.cdap.cdap.logging.guice.KafkaLogAppenderModule;
import io.cdap.cdap.logging.guice.RemoteLogAppenderModule;
Expand Down Expand Up @@ -100,7 +103,13 @@ public static Injector createInjector(CConfiguration cConf, Configuration hConf)

modules.add(new ConfigModule(cConf, hConf));
modules.add(new IOModule());
modules.add(RemoteAuthenticatorModules.getDefaultModule());
FeatureFlagsProvider featureFlagsProvider = new DefaultFeatureFlagsProvider(cConf);
if (Feature.NAMESPACED_SERVICE_ACCOUNTS.isEnabled(featureFlagsProvider)) {
modules.add(RemoteAuthenticatorModules.getDefaultModule(
Constants.ArtifactLocalizer.REMOTE_AUTHENTICATOR_NAME));
} else {
modules.add(RemoteAuthenticatorModules.getDefaultModule());
}
modules.add(new AuthenticationContextModules().getMasterModule());
modules.add(coreSecurityModule);
modules.add(new MessagingServiceModule(cConf));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package io.cdap.cdap.internal.app.worker.sidecar;

import com.google.common.base.Strings;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
Expand All @@ -37,17 +38,15 @@
import io.cdap.cdap.proto.credential.NamespaceCredentialProvider;
import io.cdap.cdap.proto.credential.NotFoundException;
import io.cdap.cdap.proto.credential.ProvisionedCredential;
import io.cdap.cdap.proto.security.Credential;
import io.cdap.cdap.proto.security.GcpMetadataTaskContext;
import io.cdap.common.http.HttpRequests;
import io.cdap.common.http.HttpResponse;
import io.cdap.cdap.security.spi.authenticator.RemoteAuthenticator;
import io.cdap.http.HttpHandler;
import io.cdap.http.HttpResponder;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponseStatus;
import java.io.IOException;
import java.net.URL;
import java.time.Duration;
import java.time.Instant;
import java.util.concurrent.ExecutionException;
Expand All @@ -57,7 +56,6 @@
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.QueryParam;
import joptsimple.internal.Strings;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -73,8 +71,8 @@ public class GcpMetadataHttpHandlerInternal extends AbstractAppFabricHttpHandler
private static final Gson GSON = new GsonBuilder().registerTypeAdapter(BasicThrowable.class,
new BasicThrowableCodec()).create();
private final CConfiguration cConf;
private final String metadataServiceTokenEndpoint;
private final NamespaceCredentialProvider credentialProvider;
private final RemoteAuthenticator remoteAuthenticator;
private final GcpWorkloadIdentityInternalAuthenticator gcpWorkloadIdentityInternalAuthenticator;
private GcpMetadataTaskContext gcpMetadataTaskContext;
private final LoadingCache<ProvisionedCredentialCacheKey,
Expand All @@ -86,10 +84,9 @@ public class GcpMetadataHttpHandlerInternal extends AbstractAppFabricHttpHandler
* @param cConf CConfiguration
*/
public GcpMetadataHttpHandlerInternal(CConfiguration cConf,
RemoteClientFactory remoteClientFactory) {
RemoteClientFactory remoteClientFactory, RemoteAuthenticator remoteAuthenticator) {
this.cConf = cConf;
this.metadataServiceTokenEndpoint = cConf.get(
Constants.TaskWorker.METADATA_SERVICE_END_POINT);
this.remoteAuthenticator = remoteAuthenticator;
this.gcpWorkloadIdentityInternalAuthenticator =
new GcpWorkloadIdentityInternalAuthenticator(gcpMetadataTaskContext);
this.credentialProvider = new RemoteNamespaceCredentialProvider(remoteClientFactory,
Expand Down Expand Up @@ -181,16 +178,17 @@ public void token(HttpRequest request, HttpResponder responder,
// fallback to gcp metadata server for backward compatibility.
}

if (metadataServiceTokenEndpoint == null) {
responder.sendString(HttpResponseStatus.NOT_IMPLEMENTED,
String.format("%s has not been set",
Constants.TaskWorker.METADATA_SERVICE_END_POINT));
return;
}

try {
responder.sendJson(HttpResponseStatus.OK,
fetchTokenFromMetadataServer(scopes).getResponseBodyAsString());
Credential credential = remoteAuthenticator.getCredentials();
if (credential == null || Strings.isNullOrEmpty(credential.getValue())) {
responder.sendJson(HttpResponseStatus.INTERNAL_SERVER_ERROR,
"Failed to fetch token from metadata server");
return;
}
GcpTokenResponse gcpTokenResponse =
new GcpTokenResponse(credential.getType().getQualifiedName(), credential.getValue(),
credential.getExpirationTimeSecs());
responder.sendJson(HttpResponseStatus.OK, GSON.toJson(gcpTokenResponse));
} catch (Exception ex) {
LOG.error("Failed to fetch token from metadata server", ex);
responder.sendJson(HttpResponseStatus.INTERNAL_SERVER_ERROR, exceptionToJson(ex));
Expand All @@ -204,17 +202,6 @@ private ProvisionedCredential fetchTokenFromCredentialProvider(
RetryStrategies.fromConfiguration(cConf, Constants.Service.TASK_WORKER + "."));
}

private HttpResponse fetchTokenFromMetadataServer(String scopes) throws IOException {
URL url = new URL(metadataServiceTokenEndpoint);
if (!Strings.isNullOrEmpty(scopes)) {
url = new URL(String.format("%s?scopes=%s", metadataServiceTokenEndpoint, scopes));
}
io.cdap.common.http.HttpRequest tokenRequest = io.cdap.common.http.HttpRequest.get(url)
.addHeader(METADATA_FLAVOR_HEADER_KEY, METADATA_FLAVOR_HEADER_VALUE)
.build();
return HttpRequests.execute(tokenRequest);
}

/**
* Sets the CDAP Namespace information.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import io.cdap.cdap.common.http.CommonNettyHttpServiceBuilder;
import io.cdap.cdap.common.id.Id;
import io.cdap.cdap.common.internal.remote.DefaultInternalAuthenticator;
import io.cdap.cdap.common.internal.remote.NoOpRemoteAuthenticator;
import io.cdap.cdap.common.internal.remote.RemoteClientFactory;
import io.cdap.cdap.common.internal.remote.TaskWorkerHttpHandlerInternal;
import io.cdap.cdap.common.metrics.NoOpMetricsCollectionService;
Expand Down Expand Up @@ -126,7 +127,8 @@ public static void init() throws Exception {
new ArtifactLocalizer(cConf, remoteClientFactory,
((namespaceId, retryStrategy) -> new NoOpArtifactManager()))
),
new GcpMetadataHttpHandlerInternal(cConf, remoteClientFactory)
new GcpMetadataHttpHandlerInternal(cConf, remoteClientFactory,
new NoOpRemoteAuthenticator())
)
.setChannelPipelineModifier(new ChannelPipelineModifier() {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import io.cdap.cdap.common.id.Id;
import io.cdap.cdap.common.internal.remote.DefaultInternalAuthenticator;
import io.cdap.cdap.common.internal.remote.NoOpInternalAuthenticator;
import io.cdap.cdap.common.internal.remote.NoOpRemoteAuthenticator;
import io.cdap.cdap.common.internal.remote.RemoteClientFactory;
import io.cdap.cdap.common.io.Locations;
import io.cdap.cdap.common.metrics.NoOpMetricsCollectionService;
Expand Down Expand Up @@ -85,7 +86,7 @@ private ArtifactLocalizerService setupArtifactLocalizerService(CConfiguration cC
cConf, new ArtifactLocalizer(cConf, remoteClientFactory, (namespaceId, retryStrategy) -> {
return new NoOpArtifactManager();
}), new CommonNettyHttpServiceFactory(cConf, new NoOpMetricsCollectionService()),
remoteClientFactory);
remoteClientFactory, new NoOpRemoteAuthenticator());
// start the service
artifactLocalizerService.startAndWait();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.cdap.cdap.common.conf.CConfiguration;
import io.cdap.cdap.common.conf.Constants;
import io.cdap.cdap.common.http.CommonNettyHttpServiceBuilder;
import io.cdap.cdap.common.internal.remote.NoOpRemoteAuthenticator;
import io.cdap.cdap.common.internal.remote.RemoteClientFactory;
import io.cdap.cdap.common.metrics.NoOpMetricsCollectionService;
import io.cdap.cdap.common.namespace.InMemoryNamespaceAdmin;
Expand Down Expand Up @@ -69,7 +70,8 @@ public static void init() throws Exception {
httpService = new CommonNettyHttpServiceBuilder(cConf, "test",
new NoOpMetricsCollectionService())
.setHttpHandlers(
new GcpMetadataHttpHandlerInternal(cConf, remoteClientFactory)
new GcpMetadataHttpHandlerInternal(cConf, remoteClientFactory,
new NoOpRemoteAuthenticator())
)
.setChannelPipelineModifier(new ChannelPipelineModifier() {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ public Credential getCredentials() throws IOException {
if (accessToken == null || accessToken.getExpirationTime().before(Date.from(clock.instant()))) {
accessToken = googleCredentials.refreshAccessToken();
}
return new Credential(accessToken.getTokenValue(), Credential.CredentialType.EXTERNAL_BEARER);
return new Credential(accessToken.getTokenValue(), Credential.CredentialType.EXTERNAL_BEARER,
accessToken.getExpirationTime().getTime() / 1000L);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,8 @@ public static final class ArtifactLocalizer {
public static final String WORKER_THREADS = "artifact.localizer.worker.threads";
public static final String PRELOAD_LIST = "artifact.localizer.preload.list";
public static final String PRELOAD_VERSION_LIMIT = "artifact.localizer.preload.version.limit";
public static final String REMOTE_AUTHENTICATOR_NAME =
"artifact.localizer.remote.authenticator.name";
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ class KubeTwillPreparer implements DependentTwillPreparer, StatefulTwillPreparer
private final String resourcePrefix;
private final Map<String, String> extraLabels;
private final Map<String, SecretDiskRunnable> secretDiskRunnables;
private final Set<String> withNamespaceWorkloadIdentityRunnables;
private final Map<String, V1SecurityContext> containerSecurityContexts;
private final Map<String, Set<String>> readonlyDisks;
private final Map<String, Map<String, String>> runnableConfigs;
Expand Down Expand Up @@ -240,6 +241,7 @@ class KubeTwillPreparer implements DependentTwillPreparer, StatefulTwillPreparer
this.dependentRunnableNames = new HashSet<>();
this.serviceAccountName = null;
this.secretDiskRunnables = new HashMap<>();
this.withNamespaceWorkloadIdentityRunnables = new HashSet<>();
this.containerSecurityContexts = new HashMap<>();
this.readonlyDisks = new HashMap<>();
this.runnableConfigs = runnables.stream()
Expand Down Expand Up @@ -368,6 +370,12 @@ public SecureTwillPreparer withSecretDisk(String runnableName, SecretDisk... sec
return this;
}

@Override
public SecureTwillPreparer withNamespacedWorkloadIdentity(String runnableName) {
withNamespaceWorkloadIdentityRunnables.add(runnableName);
return this;
}

@Override
public SecureTwillPreparer withSecurityContext(String runnableName,
SecurityContext securityContext) {
Expand Down Expand Up @@ -1285,9 +1293,8 @@ private List<V1Container> createContainers(Map<String, RuntimeSpecification> run
environs.put(JAVA_OPTS_KEY, jvmOpts);
// Add workload identity environment variable if applicable.
if (workloadIdentityEnabled && WorkloadIdentityUtil.shouldMountWorkloadIdentity(
cdapInstallNamespace,
programRuntimeNamespace,
workloadIdentityServiceAccount)) {
cdapInstallNamespace, programRuntimeNamespace, workloadIdentityServiceAccount)
&& !withNamespaceWorkloadIdentityRunnables.contains(runnableName)) {
V1EnvVar workloadIdentityEnvVar = WorkloadIdentityUtil.generateWorkloadIdentityEnvVar();
environs.put(workloadIdentityEnvVar.getName(), workloadIdentityEnvVar.getValue());
}
Expand All @@ -1314,6 +1321,13 @@ private List<V1Container> createContainers(Map<String, RuntimeSpecification> run
.filter(entry -> !entry.getKey().equals(GCE_METADATA_HOST_ENV_VAR))
.collect(Collectors.toMap(Map.Entry::getKey,
Map.Entry::getValue));
// Add workload identity environment variable in the dependent runnable if applicable.
if (workloadIdentityEnabled && WorkloadIdentityUtil.shouldMountWorkloadIdentity(
cdapInstallNamespace, programRuntimeNamespace, workloadIdentityServiceAccount)
&& !withNamespaceWorkloadIdentityRunnables.contains(name)) {
V1EnvVar workloadIdentityEnvVar = WorkloadIdentityUtil.generateWorkloadIdentityEnvVar();
envs.put(workloadIdentityEnvVar.getName(), workloadIdentityEnvVar.getValue());
}
mounts = addSecreteVolMountIfNeeded(spec, volumeMounts);
containers.add(
createContainer(name, podInfo.getContainerImage(), podInfo.getImagePullPolicy(), workDir,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,13 @@ public interface SecureTwillPreparer extends TwillPreparer {
SecureTwillPreparer withSecurityContext(String runnableName,
SecurityContext securityContext);

/**
* Runs the given runnable with namespace workload identity,
* this feature removes the GOOGLE_APPLICATION_CREDENTIALS environment variable
* to enable namespaced service accounts.
*
* @param runnableName name of the {@link TwillRunnable}
* @return this {@link TwillPreparer}
*/
SecureTwillPreparer withNamespacedWorkloadIdentity(String runnableName);
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,31 @@ public static CredentialType fromQualifiedName(String qualifiedName) {

private final String value;
private final CredentialType type;
private final Long expirationTimeSecs;

/**
* Constructs the Credential.
*
* @param value credential value
* @param type credential type
*/
public Credential(String value, CredentialType type) {
this.value = value;
this.type = type;
this.expirationTimeSecs = null;
}

/**
* Constructs the Credential.
*
* @param value credential value
* @param type credential type
* @param expirationTimeSecs the time in seconds after which credential will expire
*/
public Credential(String value, CredentialType type, Long expirationTimeSecs) {
this.value = value;
this.type = type;
this.expirationTimeSecs = expirationTimeSecs;
}

public String getValue() {
Expand All @@ -98,10 +119,15 @@ public CredentialType getType() {
return type;
}

public Long getExpirationTimeSecs() {
return expirationTimeSecs;
}

@Override
public String toString() {
return "Credential{"
+ "type=" + type
+ ", expires_in=" + expirationTimeSecs
+ ", length=" + value.length()
+ "}";
}
Expand Down

0 comments on commit a8a443d

Please sign in to comment.