Skip to content

Commit

Permalink
Enhanced WebsocketGatewayClientTransport
Browse files Browse the repository at this point in the history
  • Loading branch information
artem-v committed Oct 1, 2024
1 parent b3477ea commit c26977a
Showing 1 changed file with 30 additions and 107 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelOption;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.scalecube.services.Address;
import io.scalecube.services.ServiceReference;
Expand All @@ -12,11 +13,11 @@
import io.scalecube.services.transport.api.ClientTransport;
import java.lang.reflect.Type;
import java.time.Duration;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.UnaryOperator;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -26,28 +27,21 @@
import reactor.netty.http.client.HttpClient;
import reactor.netty.resources.ConnectionProvider;
import reactor.netty.resources.LoopResources;
import reactor.netty.tcp.SslProvider;

public final class WebsocketGatewayClientTransport implements ClientChannel, ClientTransport {

private static final Logger LOGGER =
LoggerFactory.getLogger(WebsocketGatewayClientTransport.class);

private static final String CONTENT_TYPE = "application/json";
private static final String STREAM_ID = "sid";

private static final String CONTENT_TYPE = "application/json";
private static final WebsocketGatewayClientCodec CLIENT_CODEC = new WebsocketGatewayClientCodec();
private static final int CONNECT_TIMEOUT_MILLIS = (int) Duration.ofSeconds(5).toMillis();

private final GatewayClientCodec clientCodec;
private final LoopResources loopResources;
private final Address address;
private final Duration connectTimeout;
private final String contentType;
private final boolean followRedirect;
private final SslProvider sslProvider;
private final boolean shouldWiretap;
private final Duration keepAliveInterval;
private final Map<String, String> headers;
private final Function<HttpClient, HttpClient> operator;
private final boolean ownsLoopResources;

private final AtomicLong sidCounter = new AtomicLong();
Expand All @@ -56,14 +50,8 @@ public final class WebsocketGatewayClientTransport implements ClientChannel, Cli

private WebsocketGatewayClientTransport(Builder builder) {
this.clientCodec = builder.clientCodec;
this.address = builder.address;
this.connectTimeout = builder.connectTimeout;
this.contentType = builder.contentType;
this.followRedirect = builder.followRedirect;
this.sslProvider = builder.sslProvider;
this.shouldWiretap = builder.shouldWiretap;
this.keepAliveInterval = builder.keepAliveInterval;
this.headers = builder.headers;
this.operator = builder.operator;
this.loopResources =
builder.loopResources == null
? LoopResources.create("websocket-gateway-client")
Expand All @@ -79,28 +67,20 @@ public ClientChannel create(ServiceReference serviceReference) {
return oldValue;
}

HttpClient httpClient =
HttpClient.create(ConnectionProvider.newConnection())
.headers(entries -> headers.forEach(entries::add))
.headers(entries -> entries.set("Content-Type", contentType))
.followRedirect(followRedirect)
.wiretap(shouldWiretap)
.runOn(loopResources)
.host(address.host())
.port(address.port())
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, (int) connectTimeout.toMillis())
.option(ChannelOption.TCP_NODELAY, true);

if (sslProvider != null) {
httpClient = httpClient.secure(sslProvider);
}
final HttpClient httpClient =
operator.apply(
HttpClient.create(ConnectionProvider.newConnection())
.runOn(loopResources)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, CONNECT_TIMEOUT_MILLIS)
.option(ChannelOption.TCP_NODELAY, true)
.headers(headers -> headers.set(HttpHeaderNames.CONTENT_TYPE, CONTENT_TYPE)));

return createClientSession(httpClient);
return clientSession(httpClient);
});
return this;
}

private WebsocketGatewayClientSession createClientSession(HttpClient httpClient) {
private WebsocketGatewayClientSession clientSession(HttpClient httpClient) {
try {
return httpClient
.websocket()
Expand All @@ -115,7 +95,7 @@ private WebsocketGatewayClientSession createClientSession(HttpClient httpClient)
: connection)
.map(
connection -> {
WebsocketGatewayClientSession session =
final WebsocketGatewayClientSession session =
new WebsocketGatewayClientSession(clientCodec, connection);
LOGGER.info("Created session: {}", session);
// setup shutdown hook
Expand All @@ -131,8 +111,7 @@ private WebsocketGatewayClientSession createClientSession(HttpClient httpClient)
th.toString()));
return session;
})
.doOnError(
ex -> LOGGER.warn("Failed to connect on {}, cause: {}", address, ex.toString()))
.doOnError(ex -> LOGGER.warn("Failed to connect, cause: {}", ex.toString()))
.toFuture()
.get();
} catch (Exception e) {
Expand Down Expand Up @@ -222,105 +201,49 @@ public static class Builder {

private GatewayClientCodec clientCodec = CLIENT_CODEC;
private LoopResources loopResources;
private Address address;
private Duration connectTimeout = Duration.ofSeconds(5);
private String contentType = CONTENT_TYPE;
private boolean followRedirect;
private SslProvider sslProvider;
private boolean shouldWiretap;
private Duration keepAliveInterval = Duration.ZERO;
private Map<String, String> headers = new HashMap<>();
private Function<HttpClient, HttpClient> operator = client -> client;

public Builder() {}

public GatewayClientCodec clientCodec() {
return clientCodec;
}

public Builder clientCodec(GatewayClientCodec clientCodec) {
this.clientCodec = clientCodec;
return this;
}

public LoopResources loopResources() {
return loopResources;
}

public Builder loopResources(LoopResources loopResources) {
this.loopResources = loopResources;
return this;
}

public Address address() {
return address;
}

public Builder address(Address address) {
this.address = address;
public Builder httpClient(UnaryOperator<HttpClient> operator) {
this.operator = this.operator.andThen(operator);
return this;
}

public Duration connectTimeout() {
return connectTimeout;
public Builder address(Address address) {
return httpClient(client -> client.host(address.host()).port(address.port()));
}

public Builder connectTimeout(Duration connectTimeout) {
this.connectTimeout = connectTimeout;
return this;
}

public String contentType() {
return contentType;
return httpClient(
client ->
client.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, (int) connectTimeout.toMillis()));
}

public Builder contentType(String contentType) {
this.contentType = contentType;
return this;
}

public boolean isFollowRedirect() {
return followRedirect;
}

public Builder followRedirect(boolean followRedirect) {
this.followRedirect = followRedirect;
return this;
}

public SslProvider sslProvider() {
return sslProvider;
}

public Builder sslProvider(SslProvider sslProvider) {
this.sslProvider = sslProvider;
return this;
}

public boolean isShouldWiretap() {
return shouldWiretap;
}

public Builder shouldWiretap(boolean shouldWiretap) {
this.shouldWiretap = shouldWiretap;
return this;
}

public Duration keepAliveInterval() {
return keepAliveInterval;
return httpClient(
client ->
client.headers(headers -> headers.set(HttpHeaderNames.CONTENT_TYPE, contentType)));
}

public Builder keepAliveInterval(Duration keepAliveInterval) {
this.keepAliveInterval = keepAliveInterval;
return this;
}

public Map<String, String> headers() {
return headers;
}

public Builder headers(Map<String, String> headers) {
this.headers = Collections.unmodifiableMap(new HashMap<>(headers));
return this;
return httpClient(client -> client.headers(entries -> headers.forEach(entries::set)));
}

public WebsocketGatewayClientTransport build() {
Expand Down

0 comments on commit c26977a

Please sign in to comment.