diff --git a/servicetalk-grpc-netty/src/test/java/io/servicetalk/grpc/netty/KeepAliveTest.java b/servicetalk-grpc-netty/src/test/java/io/servicetalk/grpc/netty/KeepAliveTest.java new file mode 100644 index 0000000000..0911be8172 --- /dev/null +++ b/servicetalk-grpc-netty/src/test/java/io/servicetalk/grpc/netty/KeepAliveTest.java @@ -0,0 +1,184 @@ +/* + * Copyright © 2020 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.grpc.netty; + +import io.servicetalk.concurrent.api.CompositeCloseable; +import io.servicetalk.concurrent.api.Publisher; +import io.servicetalk.concurrent.api.Single; +import io.servicetalk.concurrent.internal.ServiceTalkTestTimeout; +import io.servicetalk.grpc.api.GrpcClientBuilder; +import io.servicetalk.grpc.api.GrpcServerBuilder; +import io.servicetalk.grpc.api.GrpcServiceContext; +import io.servicetalk.grpc.netty.TesterProto.TestRequest; +import io.servicetalk.grpc.netty.TesterProto.TestResponse; +import io.servicetalk.grpc.netty.TesterProto.Tester.ServiceFactory; +import io.servicetalk.grpc.netty.TesterProto.Tester.TesterClient; +import io.servicetalk.grpc.netty.TesterProto.Tester.TesterService; +import io.servicetalk.http.netty.H2KeepAlivePolicies; +import io.servicetalk.http.netty.H2ProtocolConfig; +import io.servicetalk.http.netty.HttpProtocolConfigs; +import io.servicetalk.transport.api.HostAndPort; +import io.servicetalk.transport.api.ServerContext; +import io.servicetalk.transport.api.ServiceTalkSocketOptions; + +import org.junit.After; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.net.InetSocketAddress; +import java.time.Duration; +import java.util.Collection; +import java.util.concurrent.TimeoutException; +import java.util.function.Function; + +import static io.servicetalk.concurrent.api.AsyncCloseables.newCompositeCloseable; +import static io.servicetalk.concurrent.api.Publisher.never; +import static io.servicetalk.concurrent.api.Single.succeeded; +import static io.servicetalk.transport.netty.internal.AddressUtils.localAddress; +import static io.servicetalk.transport.netty.internal.AddressUtils.serverHostAndPort; +import static java.time.Duration.ofSeconds; +import static java.util.Arrays.asList; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.MINUTES; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.fail; +import static org.junit.Assume.assumeThat; + +@RunWith(Parameterized.class) +public class KeepAliveTest { + private final TesterClient client; + private final ServerContext ctx; + + @Rule + public final Timeout timeout = new ServiceTalkTestTimeout(1, MINUTES); + private final long idleTimeoutMillis; + + public KeepAliveTest(final boolean keepAlivesFromClient, + final Function protocolConfigSupplier, + final long idleTimeoutMillis) throws Exception { + this.idleTimeoutMillis = idleTimeoutMillis; + GrpcServerBuilder serverBuilder = GrpcServers.forAddress(localAddress(0)); + if (!keepAlivesFromClient) { + serverBuilder.protocols(protocolConfigSupplier.apply("servicetalk-tests-server-wire-logger")); + } else { + serverBuilder.socketOption(ServiceTalkSocketOptions.IDLE_TIMEOUT, idleTimeoutMillis) + .protocols(HttpProtocolConfigs.h2() + .enableFrameLogging("servicetalk-tests-server-wire-logger").build()); + } + ctx = serverBuilder.listenAndAwait(new ServiceFactory(new InfiniteStreamsService())); + GrpcClientBuilder clientBuilder = + GrpcClients.forAddress(serverHostAndPort(ctx)); + if (keepAlivesFromClient) { + clientBuilder.protocols(protocolConfigSupplier.apply("servicetalk-tests-client-wire-logger")); + } else { + clientBuilder.socketOption(ServiceTalkSocketOptions.IDLE_TIMEOUT, idleTimeoutMillis) + .protocols(HttpProtocolConfigs.h2() + .enableFrameLogging("servicetalk-tests-client-wire-logger").build()); + } + client = clientBuilder.build(new TesterProto.Tester.ClientFactory()); + } + + @Parameterized.Parameters(name = "keepAlivesFromClient? {0}, idleTimeout: {2}") + public static Collection data() { + return asList(newParam(true, ofSeconds(10), ofSeconds(12)), + newParam(false, ofSeconds(10), ofSeconds(12))); + } + + private static Object[] newParam(final boolean keepAlivesFromClient, final Duration keepAliveIdleDuration, + final Duration idleTimeoutDuration) { + return new Object[] {keepAlivesFromClient, + (Function) frameLogger -> + HttpProtocolConfigs.h2() + .keepAlivePolicy(H2KeepAlivePolicies.whenIdleFor(keepAliveIdleDuration)) + .enableFrameLogging(frameLogger).build(), + idleTimeoutDuration.toMillis()}; + } + + @After + public void tearDown() throws Exception { + CompositeCloseable closeable = newCompositeCloseable().appendAll(client, ctx); + closeable.close(); + } + + @Test + public void bidiStream() throws Exception { + // Ignore test on CI due to high timeouts + assumeThat(ServiceTalkTestTimeout.CI, is(false)); + + try { + client.testBiDiStream(never()).toFuture().get(idleTimeoutMillis + 100, MILLISECONDS); + fail("Unexpected response available."); + } catch (TimeoutException e) { + // expected + } + } + + @Test + public void requestStream() throws Exception { + // Ignore test on CI due to high timeouts + assumeThat(ServiceTalkTestTimeout.CI, is(false)); + + try { + client.testRequestStream(never()).toFuture().get(idleTimeoutMillis + 100, MILLISECONDS); + fail("Unexpected response available."); + } catch (TimeoutException e) { + // expected + } + } + + @Test + public void responseStream() throws Exception { + // Ignore test on CI due to high timeouts + assumeThat(ServiceTalkTestTimeout.CI, is(false)); + + try { + client.testResponseStream(TestRequest.newBuilder().build()) + .toFuture().get(idleTimeoutMillis + 100, MILLISECONDS); + fail("Unexpected response available."); + } catch (TimeoutException e) { + // expected + } + } + + private static final class InfiniteStreamsService implements TesterService { + + @Override + public Publisher testBiDiStream(final GrpcServiceContext ctx, + final Publisher request) { + return request.map(testRequest -> TestResponse.newBuilder().build()); + } + + @Override + public Single testRequestStream(final GrpcServiceContext ctx, + final Publisher request) { + return request.collect(() -> null, (testResponse, testRequest) -> null) + .map(__ -> TestResponse.newBuilder().build()); + } + + @Override + public Publisher testResponseStream(final GrpcServiceContext ctx, final TestRequest request) { + return never(); + } + + @Override + public Single test(final GrpcServiceContext ctx, final TestRequest request) { + return succeeded(TestResponse.newBuilder().build()); + } + } +} diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/DefaultKeepAlivePolicy.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/DefaultKeepAlivePolicy.java new file mode 100644 index 0000000000..71f3c65c8a --- /dev/null +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/DefaultKeepAlivePolicy.java @@ -0,0 +1,49 @@ +/* + * Copyright © 2020 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.http.netty; + +import io.servicetalk.http.netty.H2ProtocolConfig.KeepAlivePolicy; + +import java.time.Duration; + +import static java.util.Objects.requireNonNull; + +final class DefaultKeepAlivePolicy implements KeepAlivePolicy { + private final Duration idleDuration; + private final Duration ackTimeout; + private final boolean withoutActiveStreams; + + DefaultKeepAlivePolicy(final Duration idleDuration, final Duration ackTimeout, final boolean withoutActiveStreams) { + this.idleDuration = requireNonNull(idleDuration); + this.ackTimeout = requireNonNull(ackTimeout); + this.withoutActiveStreams = withoutActiveStreams; + } + + @Override + public Duration idleDuration() { + return idleDuration; + } + + @Override + public Duration ackTimeout() { + return ackTimeout; + } + + @Override + public boolean withoutActiveStreams() { + return withoutActiveStreams; + } +} diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/H2ClientParentConnectionContext.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/H2ClientParentConnectionContext.java index 6af54c2db1..3d93156ff1 100644 --- a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/H2ClientParentConnectionContext.java +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/H2ClientParentConnectionContext.java @@ -81,8 +81,9 @@ final class H2ClientParentConnectionContext extends H2ParentConnectionContext { private H2ClientParentConnectionContext(Channel channel, BufferAllocator allocator, Executor executor, FlushStrategy flushStrategy, @Nullable Long idleTimeoutMs, - HttpExecutionStrategy executionStrategy) { - super(channel, allocator, executor, flushStrategy, idleTimeoutMs, executionStrategy); + HttpExecutionStrategy executionStrategy, + final KeepAliveManager keepAliveManager) { + super(channel, allocator, executor, flushStrategy, idleTimeoutMs, executionStrategy, keepAliveManager); } interface H2ClientParentConnection extends FilterableStreamingHttpConnection, NettyConnectionContext { @@ -103,8 +104,10 @@ protected void handleSubscribe(final Subscriber + * ping if the channel is idle for the passed {@code idleDuration}. Default values are used for other parameters + * of the returned {@link KeepAlivePolicy}. + * + * @param idleDuration {@link Duration} of idleness on a connection after which a + * ping is sent. + * @return A {@link KeepAlivePolicy} that sends a + * ping if the channel is idle for the passed {@code idleDuration}. + */ + public static KeepAlivePolicy whenIdleFor(final Duration idleDuration) { + return new KeepAlivePolicyBuilder().idleDuration(idleDuration).build(); + } + + /** + * Returns a {@link KeepAlivePolicy} that sends a + * ping if the channel is idle for the passed {@code idleDuration} and waits for {@code ackTimeout} for an ack + * for that ping. Default values are used for other + * parameters of the returned {@link KeepAlivePolicy}. + * + * @param idleDuration {@link Duration} of idleness on a connection after which a + * ping is sent. + * @param ackTimeout {@link Duration} to wait for an acknowledgment of a previously sent + * ping. + * @return A {@link KeepAlivePolicy} that sends a + * ping if the channel is idle for the passed {@code idleDuration} and waits for {@code ackTimeout} for an ack + * for that ping + */ + public static KeepAlivePolicy whenIdleFor(final Duration idleDuration, final Duration ackTimeout) { + return new KeepAlivePolicyBuilder().idleDuration(idleDuration).ackTimeout(ackTimeout).build(); + } + + /** + * A builder of {@link KeepAlivePolicy}. + */ + public static final class KeepAlivePolicyBuilder { + private Duration idleDuration = DEFAULT_IDLE_DURATION; + private Duration ackTimeout = DEFAULT_ACK_TIMEOUT; + private boolean withoutActiveStreams; + + /** + * Set the {@link Duration} of idleness on a connection after which a + * ping is sent. + *

+ * Too short ping durations may cause high network traffic, so a minimum duration may be + * enforced. + * + * @param idleDuration {@link Duration} of idleness on a connection after which a + * ping is sent. + * @return {@code this}. + * @see KeepAlivePolicy#idleDuration() + */ + public KeepAlivePolicyBuilder idleDuration(final Duration idleDuration) { + if (idleDuration.getSeconds() < 10 || idleDuration.toDays() > 1) { + throw new IllegalArgumentException("idleDuration: " + idleDuration + + " (expected >= 10 seconds and < 1 day)"); + } + this.idleDuration = idleDuration; + return this; + } + + /** + * Set the maximum {@link Duration} to wait for an acknowledgment of a previously sent + * ping. If no acknowledgment is received, the + * connection will be closed. + * + * @param ackTimeout {@link Duration} to wait for an acknowledgment of a previously sent + * ping. + * @return {@code this}. + * @see KeepAlivePolicy#ackTimeout() + */ + public KeepAlivePolicyBuilder ackTimeout(final Duration ackTimeout) { + this.ackTimeout = requireNonNull(ackTimeout); + return this; + } + + /** + * Allow/disallow sending pings even + * when no streams are active. + * + * @param withoutActiveStreams {@code true} if + * pings are expected when no streams are + * active. + * @return {@code this}. + * @see KeepAlivePolicy#withoutActiveStreams() + */ + public KeepAlivePolicyBuilder withoutActiveStreams(final boolean withoutActiveStreams) { + this.withoutActiveStreams = withoutActiveStreams; + return this; + } + + /** + * Build a new {@link KeepAlivePolicy}. + * + * @return new {@link KeepAlivePolicy}. + */ + public KeepAlivePolicy build() { + return new DefaultKeepAlivePolicy(idleDuration, ackTimeout, withoutActiveStreams); + } + } +} diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/H2ParentConnectionContext.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/H2ParentConnectionContext.java index 01f15f716d..371f9a8e52 100644 --- a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/H2ParentConnectionContext.java +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/H2ParentConnectionContext.java @@ -17,7 +17,7 @@ import io.servicetalk.buffer.api.BufferAllocator; import io.servicetalk.concurrent.Cancellable; -import io.servicetalk.concurrent.CompletableSource; +import io.servicetalk.concurrent.CompletableSource.Processor; import io.servicetalk.concurrent.SingleSource; import io.servicetalk.concurrent.api.Completable; import io.servicetalk.concurrent.api.Executor; @@ -36,70 +36,48 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.channel.ChannelOutboundInvoker; -import io.netty.channel.EventLoop; -import io.netty.handler.codec.http2.DefaultHttp2GoAwayFrame; -import io.netty.handler.codec.http2.DefaultHttp2PingFrame; import io.netty.handler.codec.http2.Http2GoAwayFrame; import io.netty.handler.codec.http2.Http2PingFrame; import io.netty.handler.codec.http2.Http2SettingsAckFrame; import io.netty.handler.codec.http2.Http2SettingsFrame; import io.netty.handler.ssl.SslHandshakeCompletionEvent; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import java.net.SocketAddress; import java.net.SocketOption; -import java.util.concurrent.Delayed; -import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.ThreadLocalRandom; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import javax.annotation.Nullable; import javax.net.ssl.SSLSession; -import static io.netty.handler.codec.http2.Http2Error.NO_ERROR; import static io.netty.util.ReferenceCountUtil.release; import static io.servicetalk.concurrent.api.Processors.newCompletableProcessor; import static io.servicetalk.concurrent.api.Processors.newSingleProcessor; import static io.servicetalk.concurrent.api.SourceAdapters.fromSource; import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_2_0; -import static io.servicetalk.http.netty.H2ToStH1Utils.DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT_MILLIS; import static io.servicetalk.transport.netty.internal.NettyIoExecutors.fromNettyEventLoop; import static io.servicetalk.transport.netty.internal.NettyPipelineSslUtils.extractSslSession; import static io.servicetalk.transport.netty.internal.SocketOptionUtils.getOption; -import static java.util.concurrent.TimeUnit.MILLISECONDS; class H2ParentConnectionContext extends NettyChannelListenableAsyncCloseable implements NettyConnectionContext, HttpConnectionContext { - - private static final AtomicIntegerFieldUpdater activeChildChannelsUpdater = - AtomicIntegerFieldUpdater.newUpdater(H2ParentConnectionContext.class, "activeChildChannels"); - private static final Logger LOGGER = LoggerFactory.getLogger(H2ParentConnectionContext.class); - private static final ScheduledFuture GRACEFUL_CLOSE_PING_PENDING = new NoopScheduledFuture(); - private static final ScheduledFuture GRACEFUL_CLOSE_PING_ACK_RECV = new NoopScheduledFuture(); - private static final long GRACEFUL_CLOSE_PING_CONTENT = ThreadLocalRandom.current().nextLong(); - private static final long GRACEFUL_CLOSE_PING_ACK_TIMEOUT_MS = 10000; final FlushStrategyHolder flushStrategyHolder; private final HttpExecutionContext executionContext; private final SingleSource.Processor transportError = newSingleProcessor(); - private final CompletableSource.Processor onClosing = newCompletableProcessor(); + private final Processor onClosing = newCompletableProcessor(); + private final KeepAliveManager keepAliveManager; @Nullable final Long idleTimeoutMs; @Nullable private SSLSession sslSession; - @Nullable - private ScheduledFuture gracefulCloseTimeoutFuture; - private volatile int activeChildChannels; - H2ParentConnectionContext(Channel channel, BufferAllocator allocator, Executor executor, - FlushStrategy flushStrategy, @Nullable Long idleTimeoutMs, - HttpExecutionStrategy executionStrategy) { + H2ParentConnectionContext(final Channel channel, final BufferAllocator allocator, final Executor executor, + final FlushStrategy flushStrategy, @Nullable final Long idleTimeoutMs, + final HttpExecutionStrategy executionStrategy, + final KeepAliveManager keepAliveManager) { super(channel, executor); this.executionContext = new DefaultHttpExecutionContext(allocator, fromNettyEventLoop(channel.eventLoop()), executor, executionStrategy); this.flushStrategyHolder = new FlushStrategyHolder(flushStrategy); this.idleTimeoutMs = idleTimeoutMs; + this.keepAliveManager = keepAliveManager; // Just in case the channel abruptly closes, we should complete the onClosing Completable. onClose().subscribe(onClosing::onComplete); } @@ -163,87 +141,11 @@ public final Channel nettyChannel() { @Override protected final void doCloseAsyncGracefully() { - EventLoop eventLoop = channel().eventLoop(); - if (eventLoop.inEventLoop()) { - doCloseAsyncGracefully0(); - } else { - try { - eventLoop.execute(this::doCloseAsyncGracefully0); - } catch (Throwable cause) { - close0(channel()); - LOGGER.warn("channel={} EventLoop rejected a task for graceful shutdown, force closing connection", - channel(), cause); - } - } - } - - final void doCloseAsyncGracefully0() { - if (gracefulCloseTimeoutFuture == null) { - // Set the gracefulCloseTimeoutFuture before doing the write, because we will reference the state - // when we receive the PING(ACK) to determine if action is necessary, and it is conceivable that the - // write future may not be executed which sets the timer. - gracefulCloseTimeoutFuture = GRACEFUL_CLOSE_PING_PENDING; - - onClosing.onComplete(); - - // The graceful close process is described in [1]. In general it involves sending 2 GOAWAY frames. The first - // GOAWAY has last-stream-id= to indicate no new streams can be created, wait for 2 RTT - // time duration for inflight frames to land, and the second GOAWAY includes the maximum known stream ID. - // To account for 2 RTTs we can send a PING and when the PING(ACK) comes back we can send the second GOAWAY. - // https://tools.ietf.org/html/rfc7540#section-6.8 - DefaultHttp2GoAwayFrame goAwayFrame = new DefaultHttp2GoAwayFrame(NO_ERROR); - goAwayFrame.setExtraStreamIds(Integer.MAX_VALUE); - channel().write(goAwayFrame); - channel().writeAndFlush(new DefaultHttp2PingFrame(GRACEFUL_CLOSE_PING_CONTENT)).addListener(future -> { - // If gracefulCloseTimeoutFuture is not GRACEFUL_CLOSE_PING_PENDING that means we have - // already received the PING(ACK) and there is no need to apply the timeout. - if (future.isSuccess() && gracefulCloseTimeoutFuture == GRACEFUL_CLOSE_PING_PENDING) { - gracefulCloseTimeoutFuture = channel().eventLoop().schedule(() -> { - // If the PING(ACK) times out we may have under estimated the 2RTT time so we - // optimistically keep the connection open and rely upon higher level timeouts to tear - // down the connection. - gracefulCloseWriteSecondGoAway(channel()); - LOGGER.debug("channel={} timeout {}ms waiting for PING(ACK) during graceful close", - channel(), GRACEFUL_CLOSE_PING_ACK_TIMEOUT_MS); - }, GRACEFUL_CLOSE_PING_ACK_TIMEOUT_MS, MILLISECONDS); - } - }); - } - } - - final void gracefulCloseWriteSecondGoAway(ChannelOutboundInvoker ctx) { - ctx.writeAndFlush(new DefaultHttp2GoAwayFrame(NO_ERROR)).addListener(future -> { - if (activeChildChannels == 0) { - close0(channel()); - } else if (future.isSuccess()) { - gracefulCloseTimeoutFuture = channel().eventLoop().schedule(() -> { - LOGGER.debug("channel={} timeout {}ms waiting for graceful close with {} active streams", - channel(), DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT_MILLIS, activeChildChannels); - close0(channel()); - }, DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT_MILLIS, MILLISECONDS); - } - }); + keepAliveManager.initiateGracefulClose(onClosing::onComplete); } final void trackActiveStream(Channel streamChannel) { - activeChildChannelsUpdater.incrementAndGet(this); - streamChannel.closeFuture().addListener(future1 -> { - activeChildChannelsUpdater.decrementAndGet(this); - if (activeChildChannels == 0 && gracefulCloseTimeoutFuture != null && - gracefulCloseTimeoutFuture != GRACEFUL_CLOSE_PING_PENDING) { - // gracefulCloseTimeoutFuture will be cancelled during connection closure elsewhere. - close0(channel()); - } - }); - } - - private static void close0(Channel channel) { - assert channel.eventLoop().inEventLoop(); - // The way netty H2 stream state machine works, we may trigger stream closures during writes with flushes - // pending behind the writes. In such cases, we may close too early ignoring the writes. Hence we flush before - // closure, if there is no write pending then flush is a noop. - channel.flush(); - channel.close(); + keepAliveManager.trackActiveStream(streamChannel); } abstract static class AbstractH2ParentConnection extends ChannelInboundHandlerAdapter { @@ -292,7 +194,7 @@ public final void channelInactive(ChannelHandlerContext ctx) { tryFailSubscriber(StacklessClosedChannelException.newInstance( H2ParentConnectionContext.class, "channelInactive(...)")); } - doConnectionCleanup(); + parentContext.keepAliveManager.channelClosed(); } @Override @@ -301,7 +203,7 @@ public final void handlerRemoved(ChannelHandlerContext ctx) { tryFailSubscriber(StacklessClosedChannelException.newInstance( H2ParentConnectionContext.class, "handlerRemoved(...)")); } - doConnectionCleanup(); + parentContext.keepAliveManager.channelClosed(); } @Override @@ -336,26 +238,14 @@ public final void channelRead(ChannelHandlerContext ctx, Object msg) { // We trigger the graceful close process here (with no timeout) to make sure the socket is closed once // the existing streams are closed. The MultiplexCodec may simulate a GOAWAY when the stream IDs are // exhausted so we shouldn't rely upon our peer to close the transport. - parentContext.doCloseAsyncGracefully0(); + parentContext.keepAliveManager.initiateGracefulClose(parentContext.onClosing::onComplete); } else if (msg instanceof Http2PingFrame) { - Http2PingFrame pingFrame = (Http2PingFrame) msg; - if (pingFrame.ack() && pingFrame.content() == GRACEFUL_CLOSE_PING_CONTENT && - parentContext.gracefulCloseTimeoutFuture != null) { - parentContext.gracefulCloseTimeoutFuture.cancel(true); - parentContext.gracefulCloseTimeoutFuture = GRACEFUL_CLOSE_PING_ACK_RECV; - parentContext.gracefulCloseWriteSecondGoAway(ctx); - } + parentContext.keepAliveManager.pingReceived((Http2PingFrame) msg); } else { ctx.fireChannelRead(msg); } } - private void doConnectionCleanup() { - if (parentContext.gracefulCloseTimeoutFuture != null) { - parentContext.gracefulCloseTimeoutFuture.cancel(true); - } - } - private void doChannelActive(ChannelHandlerContext ctx) { if (waitForSslHandshake) { // Force a read to get the SSL handshake started, any application data that makes it past the SslHandler @@ -366,51 +256,4 @@ private void doChannelActive(ChannelHandlerContext ctx) { } } } - - private static final class NoopScheduledFuture implements ScheduledFuture { - @Override - public long getDelay(final TimeUnit unit) { - return 0; - } - - @Override - public int compareTo(final Delayed o) { - return o == this ? 0 : 1; - } - - @Override - public boolean cancel(final boolean mayInterruptIfRunning) { - return false; - } - - @Override - public boolean isCancelled() { - return false; - } - - @Override - public boolean isDone() { - return false; - } - - @Override - public Object get() { - throw new UnsupportedOperationException(); - } - - @Override - public Object get(final long timeout, final TimeUnit unit) { - throw new UnsupportedOperationException(); - } - - @Override - public int hashCode() { - return 0; - } - - @Override - public boolean equals(Object o) { - return o == this; - } - } } diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/H2ProtocolConfig.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/H2ProtocolConfig.java index b912b158c4..105400fcd4 100644 --- a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/H2ProtocolConfig.java +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/H2ProtocolConfig.java @@ -19,6 +19,7 @@ import org.slf4j.event.Level; +import java.time.Duration; import java.util.function.BiPredicate; import javax.annotation.Nullable; @@ -53,4 +54,46 @@ default String alpnId() { */ @Nullable String frameLoggerName(); + + /** + * Configured {@link KeepAlivePolicy}. + * + * @return configured {@link KeepAlivePolicy} or {@code null} if none is configured. + */ + @Nullable + KeepAlivePolicy keepAlivePolicy(); + + /** + * A policy for sending PING frames to the peer. + */ + interface KeepAlivePolicy { + /** + * {@link Duration} of time the connection has to be idle before a + * ping is sent. + * + * @return {@link Duration} of time the connection has to be idle before a + * ping is sent. + */ + Duration idleDuration(); + + /** + * {@link Duration} to wait for acknowledgment from the peer after a + * ping is sent. If no acknowledgment is received, + * a closure of the connection will be initiated. + * + * @return {@link Duration} to wait for acknowledgment from the peer after a + * ping is sent. + */ + Duration ackTimeout(); + + /** + * Whether this policy allows to send pings + * even if there are no streams active on the connection. + * + * @return {@code true} if this policy allows to send + * pings even if there are no streams active on + * the connection. + */ + boolean withoutActiveStreams(); + } } diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/H2ProtocolConfigBuilder.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/H2ProtocolConfigBuilder.java index c7b12729fb..f8050c05ce 100644 --- a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/H2ProtocolConfigBuilder.java +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/H2ProtocolConfigBuilder.java @@ -17,6 +17,7 @@ import io.servicetalk.http.api.HttpHeaders; import io.servicetalk.http.api.HttpHeadersFactory; +import io.servicetalk.http.netty.H2ProtocolConfig.KeepAlivePolicy; import org.slf4j.event.Level; @@ -24,6 +25,7 @@ import javax.annotation.Nullable; import static io.servicetalk.http.netty.H2HeadersFactory.DEFAULT_SENSITIVITY_DETECTOR; +import static io.servicetalk.http.netty.H2KeepAlivePolicies.DISABLE_KEEP_ALIVE; import static java.util.Objects.requireNonNull; /** @@ -37,6 +39,8 @@ public final class H2ProtocolConfigBuilder { private BiPredicate headersSensitivityDetector = DEFAULT_SENSITIVITY_DETECTOR; @Nullable private String frameLoggerName; + @Nullable + private KeepAlivePolicy keepAlivePolicy; H2ProtocolConfigBuilder() { } @@ -81,13 +85,26 @@ public H2ProtocolConfigBuilder enableFrameLogging(final String loggerName) { return this; } + /** + * Sets the {@link KeepAlivePolicy} to use. + * + * @param policy {@link KeepAlivePolicy} to use. + * @return {@code this} + * @see H2KeepAlivePolicies + */ + public H2ProtocolConfigBuilder keepAlivePolicy(final KeepAlivePolicy policy) { + this.keepAlivePolicy = policy == DISABLE_KEEP_ALIVE ? null : requireNonNull(policy); + return this; + } + /** * Builds {@link H2ProtocolConfig}. * * @return {@link H2ProtocolConfig} */ public H2ProtocolConfig build() { - return new DefaultH2ProtocolConfig(headersFactory, headersSensitivityDetector, frameLoggerName); + return new DefaultH2ProtocolConfig(headersFactory, headersSensitivityDetector, frameLoggerName, + keepAlivePolicy); } private static final class DefaultH2ProtocolConfig implements H2ProtocolConfig { @@ -96,13 +113,16 @@ private static final class DefaultH2ProtocolConfig implements H2ProtocolConfig { private final BiPredicate headersSensitivityDetector; @Nullable private final String frameLoggerName; + @Nullable + private final KeepAlivePolicy keepAlivePolicy; DefaultH2ProtocolConfig(final HttpHeadersFactory headersFactory, final BiPredicate headersSensitivityDetector, - @Nullable final String frameLogger) { + @Nullable final String frameLogger, @Nullable final KeepAlivePolicy keepAlivePolicy) { this.headersFactory = headersFactory; this.headersSensitivityDetector = headersSensitivityDetector; this.frameLoggerName = frameLogger; + this.keepAlivePolicy = keepAlivePolicy; } @Override @@ -120,5 +140,11 @@ public BiPredicate headersSensitivityDetector() { public String frameLoggerName() { return frameLoggerName; } - } + + @Nullable + @Override + public KeepAlivePolicy keepAlivePolicy() { + return keepAlivePolicy; + } + } } diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/H2ServerParentConnectionContext.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/H2ServerParentConnectionContext.java index 700ae51271..9b12f97a09 100644 --- a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/H2ServerParentConnectionContext.java +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/H2ServerParentConnectionContext.java @@ -61,8 +61,9 @@ private H2ServerParentConnectionContext(final Channel channel, final BufferAlloc final Executor executor, final FlushStrategy flushStrategy, @Nullable final Long idleTimeoutMs, final HttpExecutionStrategy executionStrategy, - final SocketAddress listenAddress) { - super(channel, allocator, executor, flushStrategy, idleTimeoutMs, executionStrategy); + final SocketAddress listenAddress, + final KeepAliveManager keepAliveManager) { + super(channel, allocator, executor, flushStrategy, idleTimeoutMs, executionStrategy, keepAliveManager); this.listenAddress = requireNonNull(listenAddress); } @@ -109,13 +110,14 @@ protected void handleSubscribe(final Subscriber activeChildChannelsUpdater = + AtomicIntegerFieldUpdater.newUpdater(KeepAliveManager.class, "activeChildChannels"); + private static final long GRACEFUL_CLOSE_PING_CONTENT = ThreadLocalRandom.current().nextLong(); + private static final long KEEP_ALIVE_PING_CONTENT = ThreadLocalRandom.current().nextLong(); + private static final Object CLOSED = new Object(); + private static final Object GRACEFUL_CLOSE_START = new Object(); + private static final Object GRACEFUL_CLOSE_SECOND_GO_AWAY_SENT = new Object(); + private static final Object KEEP_ALIVE_ACK_PENDING = new Object(); + private static final Object KEEP_ALIVE_ACK_TIMEDOUT = new Object(); + + private volatile int activeChildChannels; + + private final Channel channel; + private final long pingAckTimeoutNanos; + private final boolean disallowKeepAliveWithoutActiveStreams; + private final Scheduler scheduler; + + // below state should only be accessed from eventloop + /** + * This stores the following possible values: + *
    + *
  • {@code null} if graceful close has not started.
  • + *
  • {@link #GRACEFUL_CLOSE_START} if graceful close process has been initiated.
  • + *
  • {@link Future} instance to timeout ack of PING sent to measure RTT.
  • + *
  • {@link #GRACEFUL_CLOSE_SECOND_GO_AWAY_SENT} if we have sent the second go away frame.
  • + *
  • {@link #CLOSED} if the channel is closed.
  • + *
+ */ + @Nullable + private Object gracefulCloseState; + + /** + * This stores the following possible values: + *
    + *
  • {@code null} if keep-alive PING process is not started.
  • + *
  • {@link #KEEP_ALIVE_ACK_PENDING} if a keep-alive PING has been sent but ack is not received.
  • + *
  • {@link Future} instance to timeout ack of PING sent.
  • + *
  • {@link #KEEP_ALIVE_ACK_TIMEDOUT} if we fail to receive a PING ack for the configured timeout.
  • + *
  • {@link #CLOSED} if the channel is closed.
  • + *
+ */ + @Nullable + private Object keepAliveState; + @Nullable + private final GenericFutureListener> pingWriteCompletionListener; + + KeepAliveManager(final Channel channel, @Nullable final KeepAlivePolicy keepAlivePolicy) { + this(channel, keepAlivePolicy, (task, delay, unit) -> + channel.eventLoop().schedule(task, delay, unit), + (ch, idlenessThresholdSeconds, onIdle) -> ch.pipeline().addLast( + new IdleStateHandler(idlenessThresholdSeconds, idlenessThresholdSeconds, 0) { + @Override + protected void channelIdle(final ChannelHandlerContext ctx, final IdleStateEvent evt) { + onIdle.run(); + } + })); + } + + KeepAliveManager(final Channel channel, @Nullable final KeepAlivePolicy keepAlivePolicy, + final Scheduler scheduler, final IdlenessDetector idlenessDetector) { + this.channel = channel; + this.scheduler = scheduler; + if (keepAlivePolicy != null) { + disallowKeepAliveWithoutActiveStreams = !keepAlivePolicy.withoutActiveStreams(); + pingAckTimeoutNanos = keepAlivePolicy.ackTimeout().toNanos(); + pingWriteCompletionListener = future -> { + if (future.isSuccess() && keepAliveState == KEEP_ALIVE_ACK_PENDING) { + // Schedule a task to verify ping ack within the pingAckTimeoutMillis + keepAliveState = scheduler.afterDuration(() -> { + if (keepAliveState != null) { + keepAliveState = KEEP_ALIVE_ACK_TIMEDOUT; + LOGGER.debug( + "channel={}, timeout {}ns waiting for keep-alive PING(ACK), writing go_away.", + this.channel, pingAckTimeoutNanos); + channel.writeAndFlush(new DefaultHttp2GoAwayFrame(NO_ERROR)) + .addListener(f -> { + if (f.isSuccess()) { + LOGGER.debug("Closing channel={}, after keep-alive timeout.", this.channel); + KeepAliveManager.this.close0(); + } + }); + } + }, pingAckTimeoutNanos, NANOSECONDS); + } + }; + int idleInSeconds = (int) min(keepAlivePolicy.idleDuration().getSeconds(), Integer.MAX_VALUE); + idlenessDetector.configure(channel, idleInSeconds, this::channelIdle); + } else { + disallowKeepAliveWithoutActiveStreams = false; + pingAckTimeoutNanos = DEFAULT_ACK_TIMEOUT.toNanos(); + pingWriteCompletionListener = null; + } + } + + void pingReceived(final Http2PingFrame pingFrame) { + assert channel.eventLoop().inEventLoop(); + + if (pingFrame.ack()) { + long pingAckContent = pingFrame.content(); + if (pingAckContent == GRACEFUL_CLOSE_PING_CONTENT) { + LOGGER.debug("channel={}, graceful close ping ack received.", channel); + cancelIfStateIsAFuture(gracefulCloseState); + gracefulCloseWriteSecondGoAway(); + } else if (pingAckContent == KEEP_ALIVE_PING_CONTENT) { + cancelIfStateIsAFuture(keepAliveState); + keepAliveState = null; + } + } else { + // Send an ack for the received ping + channel.writeAndFlush(new DefaultHttp2PingFrame(pingFrame.content(), true)); + } + } + + void trackActiveStream(final Channel streamChannel) { + activeChildChannelsUpdater.incrementAndGet(this); + streamChannel.closeFuture().addListener(f -> { + if (activeChildChannelsUpdater.decrementAndGet(this) == 0 && + gracefulCloseState == GRACEFUL_CLOSE_SECOND_GO_AWAY_SENT) { + close0(); + } + }); + } + + void channelClosed() { + assert channel.eventLoop().inEventLoop(); + + cancelIfStateIsAFuture(gracefulCloseState); + cancelIfStateIsAFuture(keepAliveState); + gracefulCloseState = CLOSED; + keepAliveState = CLOSED; + } + + void initiateGracefulClose(final Runnable whenInitiated) { + EventLoop eventLoop = channel.eventLoop(); + if (eventLoop.inEventLoop()) { + doCloseAsyncGracefully0(whenInitiated); + } else { + eventLoop.execute(() -> doCloseAsyncGracefully0(whenInitiated)); + } + } + + void channelIdle() { + assert channel.eventLoop().inEventLoop(); + assert pingWriteCompletionListener != null; + + if (keepAliveState != null || disallowKeepAliveWithoutActiveStreams && activeChildChannels == 0) { + return; + } + // idleness detected for the first time, send a ping to detect closure, if any. + keepAliveState = KEEP_ALIVE_ACK_PENDING; + channel.writeAndFlush(new DefaultHttp2PingFrame(KEEP_ALIVE_PING_CONTENT, false)) + .addListener(pingWriteCompletionListener); + } + + /** + * Scheduler of {@link Runnable}s. + */ + @FunctionalInterface + interface Scheduler { + + /** + * Run the passed {@link Runnable} after {@code delay} milliseconds. + * + * @param task {@link Runnable} to run. + * @param delay after which the task is to be run. + * @param unit {@link TimeUnit} for the delay. + * @return {@link Future} for the scheduled task. + */ + Future afterDuration(Runnable task, long delay, TimeUnit unit); + } + + /** + * Scheduler of {@link Runnable}s. + */ + @FunctionalInterface + interface IdlenessDetector { + /** + * Configure idleness detection for the passed {@code channel}. + * + * @param channel {@link Channel} for which idleness detection is to be configured. + * @param idlenessThresholdSeconds Seconds of idleness after which {@link Runnable#run()} should be called on + * the passed {@code onIdle}. + * @param onIdle {@link Runnable} to call when the channel is idle more than {@code idlenessThresholdSeconds}. + */ + void configure(Channel channel, int idlenessThresholdSeconds, Runnable onIdle); + } + + private void doCloseAsyncGracefully0(final Runnable whenInitiated) { + assert channel.eventLoop().inEventLoop(); + + if (gracefulCloseState != null) { + // either we are already closed or have already initiated graceful closure. + return; + } + + whenInitiated.run(); + + // Set the pingState before doing the write, because we will reference the state + // when we receive the PING(ACK) to determine if action is necessary, and it is conceivable that the + // write future may not be executed which sets the timer. + gracefulCloseState = GRACEFUL_CLOSE_START; + + // The graceful close process is described in [1]. It involves sending 2 GOAWAY frames. The first + // GOAWAY has last-stream-id= to indicate no new streams can be created, wait for 2 RTT + // time duration for inflight frames to land, and the second GOAWAY includes the maximum known stream ID. + // To account for 2 RTTs we can send a PING and when the PING(ACK) comes back we can send the second GOAWAY. + // [1] https://tools.ietf.org/html/rfc7540#section-6.8 + DefaultHttp2GoAwayFrame goAwayFrame = new DefaultHttp2GoAwayFrame(NO_ERROR); + goAwayFrame.setExtraStreamIds(Integer.MAX_VALUE); + channel.write(goAwayFrame); + channel.writeAndFlush(new DefaultHttp2PingFrame(GRACEFUL_CLOSE_PING_CONTENT)).addListener(future -> { + // If gracefulCloseState is not GRACEFUL_CLOSE_START that means we have already received the PING(ACK) and + // there is no need to apply the timeout. + if (future.isSuccess() && gracefulCloseState == GRACEFUL_CLOSE_START) { + gracefulCloseState = scheduler.afterDuration(() -> { + // If the PING(ACK) times out we may have under estimated the 2RTT time so we + // optimistically keep the connection open and rely upon higher level timeouts to tear + // down the connection. + LOGGER.debug("channel={} timeout {}ns waiting for PING(ACK) during graceful close.", + channel, pingAckTimeoutNanos); + gracefulCloseWriteSecondGoAway(); + }, pingAckTimeoutNanos, NANOSECONDS); + } + }); + } + + private void gracefulCloseWriteSecondGoAway() { + assert channel.eventLoop().inEventLoop(); + + if (gracefulCloseState == GRACEFUL_CLOSE_SECOND_GO_AWAY_SENT) { + return; + } + + gracefulCloseState = GRACEFUL_CLOSE_SECOND_GO_AWAY_SENT; + + channel.writeAndFlush(new DefaultHttp2GoAwayFrame(NO_ERROR)).addListener(future -> { + if (activeChildChannels == 0) { + close0(); + } + }); + } + + private void close0() { + assert channel.eventLoop().inEventLoop(); + + if (gracefulCloseState == CLOSED && keepAliveState == CLOSED) { + return; + } + gracefulCloseState = CLOSED; + keepAliveState = CLOSED; + + // The way netty H2 stream state machine works, we may trigger stream closures during writes with flushes + // pending behind the writes. In such cases, we may close too early ignoring the writes. Hence we flush before + // closure, if there is no write pending then flush is a noop. + channel.flush(); + channel.close(); + } + + private void cancelIfStateIsAFuture(@Nullable final Object state) { + if (state instanceof Future) { + try { + ((Future) state).cancel(true); + } catch (Throwable t) { + LOGGER.debug("Failed to cancel {} scheduled future.", + state == keepAliveState ? "keep-alive" : "graceful close", t); + } + } + } +} diff --git a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/KeepAliveManagerTest.java b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/KeepAliveManagerTest.java new file mode 100644 index 0000000000..8e6fac4874 --- /dev/null +++ b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/KeepAliveManagerTest.java @@ -0,0 +1,332 @@ +/* + * Copyright © 2020 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.http.netty; + +import io.servicetalk.concurrent.internal.ServiceTalkTestTimeout; +import io.servicetalk.http.netty.H2ProtocolConfig.KeepAlivePolicy; + +import io.netty.channel.ChannelPromise; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http2.DefaultHttp2PingFrame; +import io.netty.handler.codec.http2.Http2Error; +import io.netty.handler.codec.http2.Http2GoAwayFrame; +import io.netty.handler.codec.http2.Http2PingFrame; +import io.netty.util.concurrent.Promise; +import org.hamcrest.Matcher; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; + +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadLocalRandom; + +import static io.servicetalk.http.netty.H2KeepAlivePolicies.DEFAULT_ACK_TIMEOUT; +import static io.servicetalk.http.netty.H2KeepAlivePolicies.DEFAULT_IDLE_DURATION; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class KeepAliveManagerTest { + @Rule + public final Timeout timeout = new ServiceTalkTestTimeout(); + + private final BlockingQueue scheduledTasks; + private final EmbeddedChannel channel; + + public KeepAliveManagerTest() { + scheduledTasks = new LinkedBlockingQueue<>(); + channel = new EmbeddedChannel(); + } + + @Test + public void keepAliveDisallowedWithNoActiveStreams() { + KeepAliveManager manager = newManager(false); + manager.channelIdle(); + verifyNoWrite(); + verifyNoScheduledTasks(); + } + + @Test + public void keepAliveAllowedWithNoActiveStreams() { + KeepAliveManager manager = newManager(true); + manager.channelIdle(); + verifyWrite(instanceOf(Http2PingFrame.class)); + } + + @Test + public void keepAliveWithActiveStreams() { + KeepAliveManager manager = newManager(false); + addActiveStream(manager); + manager.channelIdle(); + verifyWrite(instanceOf(Http2PingFrame.class)); + } + + @Test + public void keepAlivePingAckReceived() { + KeepAliveManager manager = newManager(false); + addActiveStream(manager); + manager.channelIdle(); + Http2PingFrame ping = verifyWrite(instanceOf(Http2PingFrame.class)); + ScheduledTask ackTimeoutTask = verifyPingAckTimeoutScheduled(); + + manager.pingReceived(new DefaultHttp2PingFrame(ping.content(), true)); + assertThat("Ping ack timeout task not cancelled.", ackTimeoutTask.promise.isCancelled(), is(true)); + + ackTimeoutTask.task.run(); + verifyNoWrite(); + verifyNoScheduledTasks(); + assertThat("Channel unexpectedly closed.", channel.isOpen(), is(true)); + } + + @Test + public void keepAlivePingAckWithUnknownContent() { + KeepAliveManager manager = newManager(false); + addActiveStream(manager); + manager.channelIdle(); + Http2PingFrame ping = verifyWrite(instanceOf(Http2PingFrame.class)); + ScheduledTask ackTimeoutTask = verifyPingAckTimeoutScheduled(); + + manager.pingReceived(new DefaultHttp2PingFrame(ping.content() + 1, true)); + assertThat("Ping ack timeout task cancelled.", ackTimeoutTask.promise.isCancelled(), is(false)); + + verifyChannelCloseOnMissingPingAck(ackTimeoutTask); + } + + @Test + public void keepAliveMissingPingAck() { + KeepAliveManager manager = newManager(false); + addActiveStream(manager); + manager.channelIdle(); + verifyWrite(instanceOf(Http2PingFrame.class)); + verifyChannelCloseOnMissingPingAck(verifyPingAckTimeoutScheduled()); + } + + @Test + public void gracefulCloseNoActiveStreams() throws Exception { + KeepAliveManager manager = newManager(false); + Http2PingFrame pingFrame = initiateGracefulCloseVerifyGoAwayAndPing(manager); + + sendGracefulClosePingAckAndVerifySecondGoAway(manager, pingFrame); + + channel.closeFuture().sync().await(); + } + + @Test + public void gracefulCloseWithActiveStreams() throws Exception { + KeepAliveManager manager = newManager(false); + EmbeddedChannel activeStream = addActiveStream(manager); + Http2PingFrame pingFrame = initiateGracefulCloseVerifyGoAwayAndPing(manager); + + sendGracefulClosePingAckAndVerifySecondGoAway(manager, pingFrame); + + assertThat("Channel not closed.", channel.isOpen(), is(true)); + activeStream.close().sync().await(); + + channel.closeFuture().sync().await(); + } + + @Test + public void gracefulCloseNoActiveStreamsMissingPingAck() throws Exception { + KeepAliveManager manager = newManager(false); + initiateGracefulCloseVerifyGoAwayAndPing(manager); + + ScheduledTask pingAckTimeoutTask = scheduledTasks.take(); + pingAckTimeoutTask.task.run(); + verifySecondGoAway(); + + channel.closeFuture().sync().await(); + } + + @Test + public void gracefulCloseActiveStreamsMissingPingAck() throws Exception { + KeepAliveManager manager = newManager(false); + EmbeddedChannel activeStream = addActiveStream(manager); + initiateGracefulCloseVerifyGoAwayAndPing(manager); + + ScheduledTask pingAckTimeoutTask = scheduledTasks.take(); + pingAckTimeoutTask.task.run(); + verifySecondGoAway(); + + assertThat("Channel closed.", channel.isOpen(), is(true)); + + activeStream.close().sync().await(); + + channel.closeFuture().sync().await(); + } + + @Test + public void gracefulClosePendingPingsCloseConnection() throws Exception { + KeepAliveManager manager = newManager(false); + addActiveStream(manager); + Http2PingFrame pingFrame = initiateGracefulCloseVerifyGoAwayAndPing(manager); + + sendGracefulClosePingAckAndVerifySecondGoAway(manager, pingFrame); + assertThat("Channel closed.", channel.isOpen(), is(true)); + + manager.channelIdle(); + verifyWrite(instanceOf(Http2PingFrame.class)); + verifyChannelCloseOnMissingPingAck(verifyPingAckTimeoutScheduled()); + } + + @Test + public void pingsAreAcked() { + KeepAliveManager manager = newManager(false); + long pingContent = ThreadLocalRandom.current().nextLong(); + manager.pingReceived(new DefaultHttp2PingFrame(pingContent, false)); + Http2PingFrame pingFrame = verifyWrite(instanceOf(Http2PingFrame.class)); + assertThat("Unexpected ping ack content.", pingFrame.content(), is(pingContent)); + assertThat("Unexpected ping ack content.", pingFrame.ack(), is(true)); + } + + @Test + public void channelClosedDuringGracefulClose() throws Exception { + KeepAliveManager manager = newManager(false); + addActiveStream(manager); + initiateGracefulCloseVerifyGoAwayAndPing(manager); + ScheduledTask pingAckTimeoutTask = scheduledTasks.take(); + assertThat("Ping ack timeout not scheduled.", pingAckTimeoutTask, is(notNullValue())); + + manager.channelClosed(); + assertThat("Graceful close ping ack timeout not cancelled.", pingAckTimeoutTask.promise.isCancelled(), + is(true)); + + verifyNoOtherActionPostClose(manager); + } + + @Test + public void channelClosedDuringPing() { + KeepAliveManager manager = newManager(false); + addActiveStream(manager); + manager.channelIdle(); + verifyWrite(instanceOf(Http2PingFrame.class)); + ScheduledTask ackTimeoutTask = verifyPingAckTimeoutScheduled(); + + manager.channelClosed(); + assertThat("Keep alive ping ack timeout not cancelled.", ackTimeoutTask.promise.isCancelled(), + is(true)); + + verifyNoOtherActionPostClose(manager); + } + + private void verifyNoOtherActionPostClose(final KeepAliveManager manager) { + manager.channelIdle(); + verifyNoWrite(); + verifyNoScheduledTasks(); + + manager.initiateGracefulClose(() -> { }); + verifyNoWrite(); + verifyNoScheduledTasks(); + } + + private ScheduledTask verifyPingAckTimeoutScheduled() { + ScheduledTask ackTimeoutTask = scheduledTasks.poll(); + assertThat("Ping ack timeout not scheduled.", ackTimeoutTask, is(notNullValue())); + assertThat("Unexpected ping ack timeout duration.", ackTimeoutTask.delayMillis, + is(DEFAULT_ACK_TIMEOUT.toMillis())); + return ackTimeoutTask; + } + + private void sendGracefulClosePingAckAndVerifySecondGoAway(final KeepAliveManager manager, + final Http2PingFrame pingFrame) throws Exception { + ScheduledTask pingAckTimeoutTask = scheduledTasks.take(); + manager.pingReceived(new DefaultHttp2PingFrame(pingFrame.content(), true)); + assertThat("Ping ack task not cancelled.", pingAckTimeoutTask.promise.isCancelled(), is(true)); + verifySecondGoAway(); + + pingAckTimeoutTask.task.run(); + + verifyNoWrite(); + verifyNoScheduledTasks(); + } + + private void verifySecondGoAway() { + Http2GoAwayFrame secondGoAway = verifyWrite(instanceOf(Http2GoAwayFrame.class)); + assertThat("Unexpected error in go_away", secondGoAway.errorCode(), is(Http2Error.NO_ERROR.code())); + verifyNoScheduledTasks(); + } + + private Http2PingFrame initiateGracefulCloseVerifyGoAwayAndPing(final KeepAliveManager manager) { + manager.initiateGracefulClose(() -> { }); + Http2GoAwayFrame firstGoAway = verifyWrite(instanceOf(Http2GoAwayFrame.class)); + assertThat("Unexpected error in go_away", firstGoAway.errorCode(), is(Http2Error.NO_ERROR.code())); + Http2PingFrame pingFrame = verifyWrite(instanceOf(Http2PingFrame.class)); + verifyNoWrite(); + return pingFrame; + } + + @SuppressWarnings("unchecked") + private T verifyWrite(Matcher writeMatcher) { + Object written = channel.outboundMessages().poll(); + assertThat("Unexpected frame written.", written, is(notNullValue())); + assertThat("Unexpected frame written.", written, writeMatcher); + return (T) written; + } + + private void verifyNoWrite() { + assertThat("Unexpected frame written.", channel.outboundMessages().poll(), is(nullValue())); + } + + private void verifyNoScheduledTasks() { + assertThat("Unexpected tasks scheduled.", scheduledTasks.poll(), is(nullValue())); + } + + private EmbeddedChannel addActiveStream(final KeepAliveManager manager) { + EmbeddedChannel stream = new EmbeddedChannel(); + manager.trackActiveStream(stream); + return stream; + } + + private void verifyChannelCloseOnMissingPingAck(final ScheduledTask ackTimeoutTask) { + ackTimeoutTask.task.run(); + verifyWrite(instanceOf(Http2GoAwayFrame.class)); + verifyNoScheduledTasks(); + assertThat("Channel not closed.", channel.isOpen(), is(false)); + } + + private KeepAliveManager newManager(final boolean allowPingWithoutActiveStreams) { + KeepAlivePolicy policy = mock(KeepAlivePolicy.class); + when(policy.idleDuration()).thenReturn(DEFAULT_IDLE_DURATION); + when(policy.ackTimeout()).thenReturn(DEFAULT_ACK_TIMEOUT); + when(policy.withoutActiveStreams()).thenReturn(allowPingWithoutActiveStreams); + return new KeepAliveManager(channel, policy, + (task, delay, unit) -> { + ChannelPromise promise = channel.newPromise(); + ScheduledTask scheduledTask = new ScheduledTask(task, promise, + MILLISECONDS.convert(delay, unit)); + scheduledTasks.add(scheduledTask); + return scheduledTask.promise; + }, + (__, ___, ____) -> { }); + } + + private static final class ScheduledTask { + final Runnable task; + final Promise promise; + final long delayMillis; + + ScheduledTask(final Runnable task, final Promise promise, final long delayMillis) { + this.task = task; + this.promise = promise; + this.delayMillis = delayMillis; + } + } +} diff --git a/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/NettyChannelListenableAsyncCloseable.java b/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/NettyChannelListenableAsyncCloseable.java index f0ccca86be..8b5a7d87e5 100644 --- a/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/NettyChannelListenableAsyncCloseable.java +++ b/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/NettyChannelListenableAsyncCloseable.java @@ -26,6 +26,7 @@ import static io.servicetalk.concurrent.Cancellable.IGNORE_CANCEL; import static io.servicetalk.concurrent.api.SourceAdapters.toSource; +import static io.servicetalk.concurrent.internal.SubscriberUtils.deliverTerminalFromSource; import static io.servicetalk.transport.netty.internal.CloseStates.CLOSING; import static io.servicetalk.transport.netty.internal.CloseStates.GRACEFULLY_CLOSING; import static io.servicetalk.transport.netty.internal.CloseStates.OPEN; @@ -82,7 +83,12 @@ public final Completable closeAsyncGracefully() { @Override protected void handleSubscribe(final Subscriber subscriber) { if (stateUpdater.compareAndSet(NettyChannelListenableAsyncCloseable.this, OPEN, GRACEFULLY_CLOSING)) { - doCloseAsyncGracefully(); + try { + doCloseAsyncGracefully(); + } catch (Throwable t) { + deliverTerminalFromSource(subscriber, t); + return; + } } toSource(onClose()).subscribe(subscriber); }