Skip to content

Commit

Permalink
Add MDC context in SecurityContdxtTenantAware (#1818)
Browse files Browse the repository at this point in the history
Signed-off-by: Marinov Avgustin <[email protected]>
  • Loading branch information
avgustinmm authored Aug 13, 2024
1 parent 96d8831 commit 9bb61fd
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ public static MDCHandler getInstance() {
return SINGLETON;
}

/**
* Executes callable and returns the result. If MDC is enabled, it sets the tenant and / or user in the MDC context.
*
* @param <T> the return type
* @param callable the callable to execute
* @return the result
* @throws Exception if thrown by the callable
*/
public <T> T withLogging(final Callable<T> callable) throws Exception {
if (!mdcEnabled) {
return callable.call();
Expand Down Expand Up @@ -81,29 +89,41 @@ public <T> T withLogging(final Callable<T> callable) throws Exception {
}
}

private <T> T putUserAndCall(final Callable<T> callable) throws WrappedException {
/**
* With logging throwing Runtime Exception (wihtLoggingRE). Calls the {@link #withLogging(Callable)} method and
* wraps any catchable exception into a {@link RuntimeException}.
*
* @param <T> the return type
* @param callable the callable to execute
* @return the result
*/
public <T> T withLoggingRE(final Callable<T> callable) {
try {
return withLogging(callable);
} catch (final RuntimeException re) {
throw re;
} catch (final Exception e) {
throw new RuntimeException(e);
}
}

private <T> T putUserAndCall(final Callable<T> callable) throws Exception {
final String user = springSecurityAuditorAware
.getCurrentAuditor()
.filter(username -> !username.equals("system")) // null and system are the same - system user
.map(username -> (securityContext != null && securityContext.isCurrentThreadSystemCode() ? "as " : "") + username)
.orElse(null);

final String currentUser = MDC.get(MDC_KEY_USER);
try {
if (Objects.equals(currentUser, user)) {
if (Objects.equals(currentUser, user)) {
return callable.call();
} else {
put(MDC_KEY_USER, user);
try {
return callable.call();
} else {
put(MDC_KEY_USER, user);
try {
return callable.call();
} finally {
put(MDC_KEY_USER, currentUser);
}
} finally {
put(MDC_KEY_USER, currentUser);
}
} catch (final RuntimeException e) {
throw e;
} catch (final Exception e) {
throw new WrappedException(e);
}
}

Expand All @@ -115,18 +135,6 @@ private static void put(final String key, final String value) {
}
}

// Wraps catchable exceptions to rethrow
public static class WrappedException extends Exception {

public WrappedException(final Throwable cause) {
super(cause);
}

public RuntimeException toRuntimeException() {
return new RuntimeException(getCause() == null ? this : getCause());
}
}

@NoArgsConstructor(access = AccessLevel.PRIVATE)
public static class Filter {

Expand All @@ -144,19 +152,9 @@ protected void doFilterInternal(
filterChain.doFilter(request, response);
return null;
});
} catch (final RuntimeException re) {
throw re;
} catch (final WrappedException we) {
final Throwable cause = we.getCause();
if (cause instanceof ServletException se) {
throw se;
} else if (cause instanceof IOException ioe) {
throw ioe;
} else {
throw we.toRuntimeException();
}
} catch (final ServletException | IOException | RuntimeException e) {
throw e;
} catch (final Exception e) {
// should never be here - if mdc is handler is enabled non-runtime exceptions are always wrapped
throw new RuntimeException(e);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import org.eclipse.hawkbit.ContextAware;
Expand Down Expand Up @@ -43,8 +44,8 @@
public class SecurityContextTenantAware implements ContextAware {

public static final String SYSTEM_USER = "system";
private static final Collection<? extends GrantedAuthority> SYSTEM_AUTHORITIES = Collections
.singletonList(new SimpleGrantedAuthority(SpringEvalExpressions.SYSTEM_ROLE));
private static final Collection<? extends GrantedAuthority> SYSTEM_AUTHORITIES =
Collections.singletonList(new SimpleGrantedAuthority(SpringEvalExpressions.SYSTEM_ROLE));

private final UserAuthoritiesResolver authoritiesResolver;
private final SecurityContextSerializer securityContextSerializer;
Expand All @@ -66,11 +67,8 @@ public SecurityContextTenantAware(final UserAuthoritiesResolver authoritiesResol
* Creates the {@link SecurityContextTenantAware} based on the given
* {@link UserAuthoritiesResolver}.
*
* @param authoritiesResolver
* Resolver to retrieve the authorities for a given user. Must
* not be <code>null</code>.
* @param securityContextSerializer
* Serializer that is used to serialize / deserialize {@link SecurityContext}s.
* @param authoritiesResolver Resolver to retrieve the authorities for a given user. Must not be <code>null</code>.
* @param securityContextSerializer Serializer that is used to serialize / deserialize {@link SecurityContext}s.
*/
public SecurityContextTenantAware(final UserAuthoritiesResolver authoritiesResolver, @Nullable final SecurityContextSerializer securityContextSerializer) {
this.authoritiesResolver = authoritiesResolver;
Expand Down Expand Up @@ -106,24 +104,25 @@ public String getCurrentUsername() {
return null;
}

@Override
public Optional<String> getCurrentContext() {
return Optional.ofNullable(SecurityContextHolder.getContext()).map(securityContextSerializer::serialize);
}

@Override
public <T> T runAsTenant(final String tenant, final TenantRunner<T> tenantRunner) {
return runInContext(buildSystemSecurityContext(tenant), tenantRunner);
return runInContext(buildUserSecurityContext(tenant, SYSTEM_USER, SYSTEM_AUTHORITIES), tenantRunner::run);
}

@Override
public <T> T runAsTenantAsUser(final String tenant, final String username, final TenantRunner<T> tenantRunner) {
Objects.requireNonNull(tenant);
Objects.requireNonNull(username);

final List<SimpleGrantedAuthority> authorities = runAsSystem(
() -> authoritiesResolver.getUserAuthorities(tenant, username).stream().map(SimpleGrantedAuthority::new)
.collect(Collectors.toList()));
return runInContext(buildUserSecurityContext(tenant, username, authorities), tenantRunner);
}

@Override
public Optional<String> getCurrentContext() {
return Optional.ofNullable(SecurityContextHolder.getContext()).map(securityContextSerializer::serialize);
return runInContext(buildUserSecurityContext(tenant, username, authorities), tenantRunner::run);
}

@Override
Expand All @@ -133,45 +132,35 @@ public <T, R> R runInContext(final String serializedContext, final Function<T, R
final SecurityContext securityContext = securityContextSerializer.deserialize(serializedContext);
Objects.requireNonNull(securityContext);

return runInContext(securityContext, () -> function.apply(t));
}

private static <T> T runInContext(final SecurityContext securityContext, final Supplier<T> supplier) {
final SecurityContext originalContext = SecurityContextHolder.getContext();
if (Objects.equals(securityContext, originalContext)) {
return function.apply(t);
return supplier.get();
} else {
SecurityContextHolder.setContext(securityContext);
try {
return function.apply(t);
return MDCHandler.getInstance().withLoggingRE(supplier::get);
} finally {
SecurityContextHolder.setContext(originalContext);
}
}
}

private static <T> T runInContext(final SecurityContext context, final TenantRunner<T> tenantRunner) {
final SecurityContext originalContext = SecurityContextHolder.getContext();
try {
SecurityContextHolder.setContext(context);
return tenantRunner.run();
} finally {
SecurityContextHolder.setContext(originalContext);
}
}

private static SecurityContext buildSystemSecurityContext(final String tenant) {
return buildUserSecurityContext(tenant, SYSTEM_USER, SYSTEM_AUTHORITIES);
}

private static <T> T runAsSystem(final TenantRunner<T> tenantRunner) {
final SecurityContext currentContext = SecurityContextHolder.getContext();
SystemSecurityContext.setSystemContext(currentContext);
try {
SystemSecurityContext.setSystemContext(currentContext);
return tenantRunner.run();
return MDCHandler.getInstance().withLoggingRE(tenantRunner::run);
} finally {
SecurityContextHolder.setContext(currentContext);
}
}

private static SecurityContext buildUserSecurityContext(final String tenant, final String username,
final Collection<? extends GrantedAuthority> authorities) {
private static SecurityContext buildUserSecurityContext(
final String tenant, final String username, final Collection<? extends GrantedAuthority> authorities) {
final SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
securityContext.setAuthentication(new AuthenticationDelegate(
SecurityContextHolder.getContext().getAuthentication(), tenant, username, authorities));
Expand All @@ -189,21 +178,25 @@ private static final class AuthenticationDelegate implements Authentication {
private static final long serialVersionUID = 1L;

private final Authentication delegate;

private final TenantAwareUser principal;

private final TenantAwareAuthenticationDetails tenantAwareAuthenticationDetails;

private AuthenticationDelegate(final Authentication delegate, final String tenant, final String username,
final Collection<? extends GrantedAuthority> authorities) {
this.delegate = delegate;
this.principal = new TenantAwareUser(username, username, authorities, tenant);
principal = new TenantAwareUser(username, username, authorities, tenant);
tenantAwareAuthenticationDetails = new TenantAwareAuthenticationDetails(tenant, false);
}

@Override
public boolean equals(final Object another) {
return Objects.equals(delegate, another);
if (another instanceof Authentication anotherAuthentication) {
return Objects.equals(delegate, anotherAuthentication) &&
Objects.equals(principal, anotherAuthentication.getPrincipal()) &&
Objects.equals(tenantAwareAuthenticationDetails, anotherAuthentication.getDetails());
} else {
return false;
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,20 +107,14 @@ public <T> T runAsSystem(final Callable<T> callable) {
public <T> T runAsSystemAsTenant(final Callable<T> callable, final String tenant) {
final SecurityContext oldContext = SecurityContextHolder.getContext();
try {
log.debug("entering system code execution");
log.debug("Entering system code execution");
return tenantAware.runAsTenant(tenant, () -> {
try {
setSystemContext(SecurityContextHolder.getContext());
return MDCHandler.getInstance().withLogging(callable);
} catch (final RuntimeException e) {
throw e;
} catch (final Exception e) {
throw new RuntimeException(e);
}
setSystemContext(SecurityContextHolder.getContext());
return MDCHandler.getInstance().withLoggingRE(callable);
});
} finally {
SecurityContextHolder.setContext(oldContext);
log.debug("leaving system code execution");
log.debug("Leaving system code execution");
}
}

Expand All @@ -144,12 +138,8 @@ public <T> T runAsControllerAsTenant(@NotEmpty final String tenant, @NotNull fin
.singletonList(new SimpleGrantedAuthority(SpringEvalExpressions.CONTROLLER_ROLE_ANONYMOUS));
try {
return tenantAware.runAsTenant(tenant, () -> {
try {
setCustomSecurityContext(tenant, oldContext.getAuthentication().getPrincipal(), authorities);
return MDCHandler.getInstance().withLogging(callable);
} catch (final Exception e) {
throw new RuntimeException(e);
}
setCustomSecurityContext(tenant, oldContext.getAuthentication().getPrincipal(), authorities);
return MDCHandler.getInstance().withLoggingRE(callable);
});
} finally {
SecurityContextHolder.setContext(oldContext);
Expand Down

0 comments on commit 9bb61fd

Please sign in to comment.