Skip to content

Commit

Permalink
filter support
Browse files Browse the repository at this point in the history
  • Loading branch information
yawkat committed Dec 7, 2023
1 parent 1df33d6 commit 6fd5290
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 59 deletions.
3 changes: 2 additions & 1 deletion core/build.gradle
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import me.champeau.gradle.japicmp.JapicmpTask
import io.micronaut.build.internal.japicmp.RemovedPackages
import me.champeau.gradle.japicmp.JapicmpTask

plugins {
id "io.micronaut.build.internal.convention-core-library"
Expand All @@ -15,6 +15,7 @@ dependencies {
compileOnly libs.managed.jakarta.annotation.api
compileOnly libs.graal
compileOnly libs.managed.kotlin.stdlib
compileOnly libs.managed.netty.common
}

spotless {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,8 @@ final class PropagatedContextImpl implements PropagatedContext {

static final PropagatedContextImpl EMPTY = new PropagatedContextImpl(new PropagatedContextElement[0], false);

private static final ThreadLocal<PropagatedContextImpl> THREAD_CONTEXT = new ThreadLocal<>() {
@Override
public String toString() {
return "Micronaut Propagation Context";
}
};

private static final Scope CLEANUP = THREAD_CONTEXT::remove;
private static final Scope CLEANUP = ThreadContext::remove;
private static final Scope NOOP = () -> {};

private final PropagatedContextElement[] elements;
private final boolean containsThreadElements;
Expand Down Expand Up @@ -78,28 +72,28 @@ private static boolean isThreadElement(PropagatedContextElement element) {
}

public static boolean exists() {
PropagatedContextImpl propagatedContext = PropagatedContextImpl.THREAD_CONTEXT.get();
PropagatedContextImpl propagatedContext = ThreadContext.get();
if (propagatedContext == null) {
return false;
}
return propagatedContext.elements.length != 0;
}

public static PropagatedContextImpl get() {
PropagatedContextImpl propagatedContext = THREAD_CONTEXT.get();
PropagatedContextImpl propagatedContext = ThreadContext.get();
if (propagatedContext == null) {
throw new IllegalStateException("No active propagation context!");
}
return propagatedContext;
}

public static Optional<PropagatedContext> find() {
return Optional.ofNullable(THREAD_CONTEXT.get());
return Optional.ofNullable(ThreadContext.get());
}

@NonNull
public static PropagatedContextImpl getOrEmpty() {
PropagatedContextImpl propagatedContext = THREAD_CONTEXT.get();
PropagatedContextImpl propagatedContext = ThreadContext.get();
if (propagatedContext == null) {
return EMPTY;
}
Expand Down Expand Up @@ -185,25 +179,30 @@ public List<PropagatedContextElement> getAllElements() {

@Override
public Scope propagate() {
PropagatedContextImpl prevCtx = THREAD_CONTEXT.get();
Scope restore = prevCtx == null ? CLEANUP : () -> THREAD_CONTEXT.set(prevCtx);
if (prevCtx == this) {
return restore;
}
if (elements.length == 0) {
THREAD_CONTEXT.remove();
return restore;
PropagatedContextImpl prevCtx = ThreadContext.get();
Scope restore;
if (prevCtx == null && elements.length == 0) {
return NOOP;
} else if (prevCtx == null) {
restore = CLEANUP;
} else { // elements.length == 0
restore = () -> ThreadContext.set(prevCtx);
if (elements.length == 0) {
ThreadContext.remove();
return restore;
}
}

PropagatedContextImpl ctx = this;
THREAD_CONTEXT.set(ctx);
ThreadContext.set(ctx);
if (containsThreadElements) {
List<Map.Entry<ThreadPropagatedContextElement<Object>, Object>> threadState = ctx.updateThreadState();
return () -> {
ctx.restoreState(threadState);
if (prevCtx == null) {
THREAD_CONTEXT.remove();
ThreadContext.remove();
} else {
THREAD_CONTEXT.set(prevCtx);
ThreadContext.set(prevCtx);
}
};
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright 2017-2023 original authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.micronaut.core.propagation;

import io.netty.util.concurrent.FastThreadLocal;

@SuppressWarnings("unchecked")
final class ThreadContext {
private static final Object FAST;
private static final ThreadLocal<PropagatedContextImpl> SLOW;

static {
Object fast;
ThreadLocal<PropagatedContextImpl> slow;
try {
fast = new FastThreadLocal<PropagatedContextImpl>();
slow = null;
} catch (NoClassDefFoundError e) {
fast = null;
slow = new ThreadLocal<>() {
@Override
public String toString() {
return "Micronaut Propagation Context";
}
};
}
FAST = fast;
SLOW = slow;
}

static void remove() {
if (FAST == null) {
SLOW.remove();
} else {
((FastThreadLocal<PropagatedContextImpl>) FAST).remove();
}
}

static PropagatedContextImpl get() {
if (FAST == null) {
return SLOW.get();
} else {
return ((FastThreadLocal<PropagatedContextImpl>) FAST).get();
}
}

static void set(PropagatedContextImpl value) {
if (FAST == null) {
SLOW.set(value);
} else {
((FastThreadLocal<PropagatedContextImpl>) FAST).set(value);
}
}
}
1 change: 1 addition & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ managed-methvin-directoryWatcher = { module = "io.methvin:directory-watcher", ve
managed-netty-buffer = { module = "io.netty:netty-buffer", version.ref = "managed-netty" }
managed-netty-codec-http = { module = "io.netty:netty-codec-http", version.ref = "managed-netty" }
managed-netty-codec-http2 = { module = "io.netty:netty-codec-http2", version.ref = "managed-netty" }
managed-netty-common = { module = "io.netty:netty-common", version.ref = "managed-netty" }
managed-netty-incubator-codec-http3 = { module = "io.netty.incubator:netty-incubator-codec-http3", version.ref = "managed-netty-http3" }
managed-netty-handler = { module = "io.netty:netty-handler", version.ref = "managed-netty" }
managed-netty-handler-proxy = { module = "io.netty:netty-handler-proxy", version.ref = "managed-netty" }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.micronaut.core.annotation.Nullable;
import io.micronaut.core.async.publisher.Publishers;
import io.micronaut.core.convert.ConversionService;
import io.micronaut.core.execution.ExecutionFlow;
import io.micronaut.core.io.buffer.ByteBuffer;
import io.micronaut.core.io.buffer.ByteBufferFactory;
import io.micronaut.core.propagation.PropagatedContext;
Expand Down Expand Up @@ -56,7 +57,6 @@
import io.micronaut.http.netty.body.ShortCircuitNettyBodyWriter;
import io.micronaut.http.netty.channel.ChannelPipelineCustomizer;
import io.micronaut.http.netty.stream.JsonSubscriber;
import io.micronaut.http.netty.stream.StreamedHttpRequest;
import io.micronaut.http.netty.stream.StreamedHttpResponse;
import io.micronaut.http.server.RouteExecutor;
import io.micronaut.http.server.binding.RequestArgumentSatisfier;
Expand Down Expand Up @@ -110,6 +110,7 @@
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import java.util.regex.Pattern;

Expand Down Expand Up @@ -290,10 +291,11 @@ private ExecutionLeaf<RequestHandler> shortCircuitHandler(MatchRule rule, UriRou
return ExecutionLeaf.indeterminate();
}
List<GenericHttpFilter> fixedFilters = routeExecutor.getRouter().getFixedFilters().orElse(null);
// CorsFilter is handled specially here. It's always present, so we can't bail, but it only does anything when the Origin header is set, which is checked in accept().
if (fixedFilters == null || !fixedFilters.stream().allMatch(ghf -> FilterRunner.isCorsFilter(ghf, CorsFilter.class))) {
if (fixedFilters == null) {
return ExecutionLeaf.indeterminate();
}
// CorsFilter is handled specially here. It's always present, so we can't bail, but it only does anything when the Origin header is set, which is checked in accept().
fixedFilters = fixedFilters.stream().filter(f -> !FilterRunner.isCorsFilter(f, CorsFilter.class)).toList();
MethodExecutionHandle<?, ?> executionHandle = routeInfo.getTargetMethod();
if (executionHandle.getReturnType().isOptional() ||
executionHandle.getReturnType().getType() == HttpStatus.class) {
Expand Down Expand Up @@ -346,37 +348,45 @@ private ExecutionLeaf<RequestHandler> shortCircuitHandler(MatchRule rule, UriRou
} else {
return ExecutionLeaf.indeterminate();
}
List<GenericHttpFilter> finalFixedFilters = fixedFilters;
BiFunction<HttpRequest<?>, PropagatedContext, ExecutionFlow<HttpResponse<?>>> exec = (httpRequest, propagatedContext) -> {
Object[] arguments = shortCircuitBinders.length == 0 ? ArrayUtils.EMPTY_OBJECT_ARRAY : new Object[shortCircuitBinders.length];
ImmediateByteBody body = (ImmediateByteBody) ((NettyHttpRequest<?>) httpRequest).byteBody();
for (int i = 0; i < arguments.length; i++) {
arguments[i] = shortCircuitBinders[i].bind(httpRequest.getHeaders(), body);
}
Object result = unsafeExecutionHandle.invokeUnsafe(arguments);
if (unwrapResponse) {
return ExecutionFlow.just((HttpResponse<?>) result);
} else {
return ExecutionFlow.just(HttpResponse.ok(result));
}
};
return new ExecutionLeaf.Route<>(new RequestHandler() {
@Override
public void accept(ChannelHandlerContext ctx, io.netty.handler.codec.http.HttpRequest request, ByteBody body, PipeliningServerHandler.OutboundAccess outboundAccess) {
try {
NettyHttpHeaders requestHeaders = new NettyHttpHeaders(request.headers(), conversionService);

Object[] arguments = shortCircuitBinders.length == 0 ? ArrayUtils.EMPTY_OBJECT_ARRAY : new Object[shortCircuitBinders.length];
for (int i = 0; i < arguments.length; i++) {
arguments[i] = shortCircuitBinders[i].bind(request, requestHeaders, (ImmediateByteBody) body);
}
Object result = unsafeExecutionHandle.invokeUnsafe(arguments);
HttpResponseStatus status = HttpResponseStatus.OK;
HttpHeaders responseHeaders;
if (unwrapResponse) {
HttpResponse<?> resp = (HttpResponse<?>) result;
responseHeaders = ((NettyHttpHeaders) resp.getHeaders()).getNettyHeaders();
if (!responseHeaders.contains(HttpHeaderNames.CONTENT_TYPE)) {
responseHeaders.set(HttpHeaderNames.CONTENT_TYPE, responseMediaType.toString());
NettyHttpRequest<Object> nhr = new NettyHttpRequest<>(request, body, ctx, conversionService, serverConfiguration);
outboundAccess.attachment(nhr);

new FilterRunner(finalFixedFilters, exec).run(nhr, PropagatedContext.empty()).onComplete((response, err) -> {
if (err != null) {
RoutingInBoundHandler.this.handleUnboundError(err);
} else {
HttpHeaders responseHeaders = ((NettyHttpHeaders) response.getHeaders()).getNettyHeaders();
if (!responseHeaders.contains(HttpHeaderNames.CONTENT_TYPE)) {
responseHeaders.set(HttpHeaderNames.CONTENT_TYPE, responseMediaType.toString());
}
HttpResponseStatus status = HttpResponseStatus.valueOf(response.code(), response.reason());
if (scWriter != null) {
scWriter.writeTo(nhr.getHeaders(), status, responseHeaders, response.body(), outboundAccess);
} else {
ByteBuf buf = (ByteBuf) rawWriter.writeTo(responseMediaType, response.body(), NettyByteBufferFactory.DEFAULT).asNativeBuffer();
outboundAccess.writeFull(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, buf, responseHeaders, EmptyHttpHeaders.INSTANCE));
}
}
status = HttpResponseStatus.valueOf(resp.code(), resp.reason());
result = resp.body();
} else {
responseHeaders = new DefaultHttpHeaders();
responseHeaders.set(HttpHeaderNames.CONTENT_TYPE, responseMediaType.toString());
}
if (scWriter != null) {
scWriter.writeTo(requestHeaders, status, responseHeaders, result, outboundAccess);
} else {
ByteBuf buf = (ByteBuf) rawWriter.writeTo(responseMediaType, result, NettyByteBufferFactory.DEFAULT).asNativeBuffer();
outboundAccess.writeFull(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, buf, responseHeaders, EmptyHttpHeaders.INSTANCE));
}
});

} catch (Exception e) {
RoutingInBoundHandler.this.handleUnboundError(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
import io.micronaut.core.convert.ConversionService;
import io.micronaut.core.convert.value.ConvertibleValues;
import io.micronaut.core.execution.ExecutionFlow;
import io.micronaut.http.HttpAttributes;
import io.micronaut.core.type.Argument;
import io.micronaut.http.HttpAttributes;
import io.micronaut.http.HttpRequest;
import io.micronaut.http.MediaType;
import io.micronaut.http.annotation.Body;
Expand All @@ -37,8 +37,8 @@
import io.micronaut.http.server.netty.NettyHttpRequest;
import io.micronaut.http.server.netty.body.ImmediateByteBody;
import io.micronaut.http.server.netty.shortcircuit.ShortCircuitArgumentBinder;
import io.micronaut.web.router.shortcircuit.MatchRule;
import io.micronaut.web.router.RouteInfo;
import io.micronaut.web.router.shortcircuit.MatchRule;

import java.util.List;
import java.util.Optional;
Expand Down Expand Up @@ -194,7 +194,7 @@ public Optional<Prepared> prepare(Argument<T> argument, MatchRule.ContentType fi
}
reader = opt.get();
}
return Optional.of((nettyRequest, mnHeaders, body) -> {
return Optional.of((mnHeaders, body) -> {
if (body.empty()) {
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import io.micronaut.http.bind.binders.RequestArgumentBinder;
import io.micronaut.http.server.netty.body.ImmediateByteBody;
import io.micronaut.web.router.shortcircuit.MatchRule;
import io.netty.handler.codec.http.HttpRequest;

import java.util.Optional;

Expand Down Expand Up @@ -52,11 +51,10 @@ interface Prepared {
/**
* Bind the parameter.
*
* @param nettyRequest The netty request
* @param mnHeaders The request headers (micronaut-http class)
* @param body The request body
* @return The bound argument
*/
Object bind(@NonNull HttpRequest nettyRequest, HttpHeaders mnHeaders, @NonNull ImmediateByteBody body);
Object bind(HttpHeaders mnHeaders, @NonNull ImmediateByteBody body);
}
}

0 comments on commit 6fd5290

Please sign in to comment.