diff --git a/servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/RoundRobinLoadBalancer.java b/servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/RoundRobinLoadBalancer.java index d463ea8f7c..2dd1cb332a 100644 --- a/servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/RoundRobinLoadBalancer.java +++ b/servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/RoundRobinLoadBalancer.java @@ -378,7 +378,7 @@ void markInactive() { @SuppressWarnings("unchecked") List toRemove = connectionsUpdater.getAndSet(this, INACTIVE); for (C conn : toRemove) { - conn.closeAsync().subscribe(); + conn.closeAsyncGracefully().subscribe(); } } diff --git a/servicetalk-loadbalancer/src/test/java/io/servicetalk/loadbalancer/RoundRobinLoadBalancerTest.java b/servicetalk-loadbalancer/src/test/java/io/servicetalk/loadbalancer/RoundRobinLoadBalancerTest.java index 569c668575..588aaa33f9 100644 --- a/servicetalk-loadbalancer/src/test/java/io/servicetalk/loadbalancer/RoundRobinLoadBalancerTest.java +++ b/servicetalk-loadbalancer/src/test/java/io/servicetalk/loadbalancer/RoundRobinLoadBalancerTest.java @@ -22,7 +22,6 @@ import io.servicetalk.client.api.LoadBalancerReadyEvent; import io.servicetalk.client.api.NoAvailableHostException; import io.servicetalk.client.api.ServiceDiscovererEvent; -import io.servicetalk.concurrent.CompletableSource.Processor; import io.servicetalk.concurrent.PublisherSource.Subscriber; import io.servicetalk.concurrent.PublisherSource.Subscription; import io.servicetalk.concurrent.api.Completable; @@ -62,11 +61,10 @@ import java.util.function.Function; import java.util.function.Predicate; +import static io.servicetalk.concurrent.api.AsyncCloseables.emptyAsyncCloseable; import static io.servicetalk.concurrent.api.BlockingTestUtils.awaitIndefinitely; -import static io.servicetalk.concurrent.api.Processors.newCompletableProcessor; import static io.servicetalk.concurrent.api.Single.failed; import static io.servicetalk.concurrent.api.Single.succeeded; -import static io.servicetalk.concurrent.api.SourceAdapters.fromSource; import static io.servicetalk.concurrent.api.SourceAdapters.toSource; import static io.servicetalk.concurrent.internal.DeliberateException.DELIBERATE_EXCEPTION; import static io.servicetalk.concurrent.internal.ServiceTalkTestTimeout.DEFAULT_TIMEOUT_SECONDS; @@ -89,6 +87,8 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class RoundRobinLoadBalancerTest { @@ -105,7 +105,7 @@ public class RoundRobinLoadBalancerTest { private final List connectionsCreated = new CopyOnWriteArrayList<>(); private final Queue connectionRealizers = new ConcurrentLinkedQueue<>(); - private TestPublisher> serviceDiscoveryPublisher = new TestPublisher<>(); + private final TestPublisher> serviceDiscoveryPublisher = new TestPublisher<>(); private RoundRobinLoadBalancer lb; private DelegatingConnectionFactory connectionFactory; @@ -406,6 +406,16 @@ public void newConnectionIsClosedWhenSelectorRejects() throws Exception { awaitIndefinitely(connection.onClose()); } + @Test + public void hostDownGracefulCloseConnection() throws Exception { + sendServiceDiscoveryEvents(upEvent("address-1")); + TestLoadBalancedConnection conn = lb.selectConnection(any()).toFuture().get(); + sendServiceDiscoveryEvents(downEvent("address-1")); + conn.onClose().toFuture().get(); + verify(conn).closeAsyncGracefully(); + verify(conn, times(0)).closeAsync(); + } + @SuppressWarnings("unchecked") private void sendServiceDiscoveryEvents(final ServiceDiscovererEvent... events) { serviceDiscoveryPublisher.onNext((ServiceDiscovererEvent[]) events); @@ -437,12 +447,10 @@ private Single newRealizedConnectionSingle(final Str @SuppressWarnings("unchecked") private TestLoadBalancedConnection newConnection(final String address) { final TestLoadBalancedConnection cnx = mock(TestLoadBalancedConnection.class); - final Processor closeCompletable = newCompletableProcessor(); - when(cnx.closeAsync()).thenAnswer(__ -> { - closeCompletable.onComplete(); - return closeCompletable; - }); - when(cnx.onClose()).thenReturn(fromSource(closeCompletable)); + final ListenableAsyncCloseable closeable = emptyAsyncCloseable(); + when(cnx.closeAsync()).thenReturn(closeable.closeAsync()); + when(cnx.closeAsyncGracefully()).thenReturn(closeable.closeAsyncGracefully()); + when(cnx.onClose()).thenReturn(closeable.onClose()); when(cnx.address()).thenReturn(address); when(cnx.toString()).thenReturn(address + '@' + cnx.hashCode());