Skip to content
This repository has been archived by the owner on Jun 21, 2023. It is now read-only.

Commit

Permalink
Merge pull request #135 from scalecube/copy-msg-hdrs-corresp-client-t…
Browse files Browse the repository at this point in the history
…rnsprt

Passing client headers on connection setup
  • Loading branch information
artem-v authored May 25, 2020
2 parents ba3507e + 8fb5ed2 commit afb0b07
Show file tree
Hide file tree
Showing 13 changed files with 113 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import io.scalecube.services.exceptions.DefaultErrorMapper;
import io.scalecube.services.exceptions.ServiceClientErrorMapper;
import java.time.Duration;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import reactor.netty.tcp.SslProvider;

public class GatewayClientSettings {
Expand All @@ -20,6 +23,7 @@ public class GatewayClientSettings {
private final ServiceClientErrorMapper errorMapper;
private final Duration keepAliveInterval;
private final boolean wiretap;
private final Map<String, String> headers;

private GatewayClientSettings(Builder builder) {
this.host = builder.host;
Expand All @@ -30,6 +34,7 @@ private GatewayClientSettings(Builder builder) {
this.errorMapper = builder.errorMapper;
this.keepAliveInterval = builder.keepAliveInterval;
this.wiretap = builder.wiretap;
this.headers = builder.headers;
}

public String host() {
Expand Down Expand Up @@ -64,6 +69,10 @@ public boolean wiretap() {
return this.wiretap;
}

public Map<String, String> headers() {
return headers;
}

public static Builder builder() {
return new Builder();
}
Expand Down Expand Up @@ -96,9 +105,9 @@ public static class Builder {
private ServiceClientErrorMapper errorMapper = DefaultErrorMapper.INSTANCE;
private Duration keepAliveInterval = DEFAULT_KEEPALIVE_INTERVAL;
private boolean wiretap = false;
private Map<String, String> headers = Collections.emptyMap();

private Builder() {
}
private Builder() {}

private Builder(GatewayClientSettings originalSettings) {
this.host = originalSettings.host;
Expand All @@ -109,6 +118,7 @@ private Builder(GatewayClientSettings originalSettings) {
this.errorMapper = originalSettings.errorMapper;
this.keepAliveInterval = originalSettings.keepAliveInterval;
this.wiretap = originalSettings.wiretap;
this.headers = Collections.unmodifiableMap(new HashMap<>(originalSettings.headers));
}

public Builder host(String host) {
Expand Down Expand Up @@ -191,6 +201,11 @@ public Builder errorMapper(ServiceClientErrorMapper errorMapper) {
return this;
}

public Builder headers(Map<String, String> headers) {
this.headers = Collections.unmodifiableMap(new HashMap<>(headers));
return this;
}

public GatewayClientSettings build() {
return new GatewayClientSettings(this);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public HttpGatewayClient(GatewayClientSettings settings, GatewayClientCodec<Byte

httpClient =
HttpClient.create(ConnectionProvider.elastic("http-gateway-client"))
.headers(headers -> settings.headers().forEach(headers::add))
.followRedirect(settings.followRedirect())
.tcpConfiguration(
tcpClient -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import io.rsocket.core.RSocketConnector;
import io.rsocket.frame.decoder.PayloadDecoder;
import io.rsocket.transport.netty.client.WebsocketClientTransport;
import io.rsocket.util.EmptyPayload;
import io.scalecube.services.api.ServiceMessage;
import io.scalecube.services.exceptions.ConnectionClosedException;
import io.scalecube.services.gateway.transport.GatewayClient;
Expand Down Expand Up @@ -135,8 +136,14 @@ private Mono<RSocket> getOrConnect0(Mono prev) {
return prev;
}

Payload setupPayload = EmptyPayload.INSTANCE;
if (!settings.headers().isEmpty()) {
setupPayload = codec.encode(ServiceMessage.builder().headers(settings.headers()).build());
}

return RSocketConnector.create()
.payloadDecoder(PayloadDecoder.DEFAULT)
.setupPayload(setupPayload)
.metadataMimeType(settings.contentType())
.connect(createRSocketTransport(settings))
.doOnSuccess(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ public WebsocketGatewayClient(GatewayClientSettings settings, GatewayClientCodec

httpClient =
HttpClient.newConnection()
.headers(headers -> settings.headers().forEach(headers::add))
.followRedirect(settings.followRedirect())
.tcpConfiguration(
tcpClient -> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package io.scalecube.services.gateway;

import java.util.List;
import java.util.Map;

public interface GatewaySession {
Expand All @@ -15,7 +14,7 @@ public interface GatewaySession {
/**
* Returns headers associated with session.
*
* @return heades map
* @return headers map
*/
Map<String, List<String>> headers();
Map<String, String> headers();
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import io.netty.buffer.ByteBuf;
import io.scalecube.services.api.ServiceMessage;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -84,7 +83,7 @@ default void onSessionError(GatewaySession session, Throwable throwable) {
* @param headers connection/session headers
* @return mono result
*/
default Mono<Void> onConnectionOpen(long sessionId, Map<String, List<String>> headers) {
default Mono<Void> onConnectionOpen(long sessionId, Map<String, String> headers) {
return Mono.fromRunnable(
() ->
LOGGER.debug(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import io.scalecube.services.gateway.GatewaySessionHandler;
import io.scalecube.services.gateway.ServiceMessageCodec;
import io.scalecube.services.transport.api.HeadersCodec;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;
Expand Down Expand Up @@ -41,6 +42,7 @@ public Mono<RSocket> accept(ConnectionSetupPayload setup, RSocket rsocket) {
new RSocketGatewaySession(
serviceCall,
messageCodec,
headers(messageCodec, setup),
(session, req) -> sessionHandler.mapMessage(session, req, Context.empty()));
sessionHandler.onSessionOpen(gatewaySession);
rsocket
Expand All @@ -54,4 +56,11 @@ public Mono<RSocket> accept(ConnectionSetupPayload setup, RSocket rsocket) {

return Mono.just(gatewaySession);
}

private Map<String, String> headers(
ServiceMessageCodec messageCodec, ConnectionSetupPayload setup) {
return messageCodec
.decode(setup.sliceData().retain(), setup.sliceMetadata().retain())
.headers();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import io.scalecube.services.gateway.ReferenceCountUtil;
import io.scalecube.services.gateway.ServiceMessageCodec;
import java.util.Collections;
import java.util.List;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.BiFunction;
Expand All @@ -30,6 +30,7 @@ public final class RSocketGatewaySession extends AbstractRSocket implements Gate
private final ServiceMessageCodec messageCodec;
private final long sessionId;
private final BiFunction<GatewaySession, ServiceMessage, ServiceMessage> messageMapper;
private final Map<String, String> headers;

/**
* Constructor for gateway rsocket.
Expand All @@ -40,11 +41,13 @@ public final class RSocketGatewaySession extends AbstractRSocket implements Gate
public RSocketGatewaySession(
ServiceCall serviceCall,
ServiceMessageCodec messageCodec,
Map<String, String> headers,
BiFunction<GatewaySession, ServiceMessage, ServiceMessage> messageMapper) {
this.serviceCall = serviceCall;
this.messageCodec = messageCodec;
this.messageMapper = messageMapper;
this.sessionId = SESSION_ID_GENERATOR.incrementAndGet();
this.headers = Collections.unmodifiableMap(new HashMap<>(headers));
}

@Override
Expand All @@ -53,8 +56,8 @@ public long sessionId() {
}

@Override
public Map<String, List<String>> headers() {
return Collections.emptyMap();
public Map<String, String> headers() {
return headers;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@
import io.scalecube.services.exceptions.UnauthorizedException;
import io.scalecube.services.gateway.GatewaySessionHandler;
import io.scalecube.services.gateway.ReferenceCountUtil;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import org.reactivestreams.Publisher;
import reactor.core.Disposable;
import reactor.core.publisher.Flux;
Expand Down Expand Up @@ -63,7 +63,7 @@ public WebsocketGatewayAcceptor(ServiceCall serviceCall, GatewaySessionHandler g

@Override
public Publisher<Void> apply(HttpServerRequest httpRequest, HttpServerResponse httpResponse) {
final Map<String, List<String>> headers = computeHeaders(httpRequest.requestHeaders());
final Map<String, String> headers = computeHeaders(httpRequest.requestHeaders());
final long sessionId = SESSION_ID_GENERATOR.incrementAndGet();

return gatewayHandler
Expand All @@ -85,12 +85,9 @@ public Publisher<Void> apply(HttpServerRequest httpRequest, HttpServerResponse h
.onErrorResume(throwable -> Mono.empty());
}

private static Map<String, List<String>> computeHeaders(HttpHeaders httpHeaders) {
Map<String, List<String>> headers = new HashMap<>();
for (String name : httpHeaders.names()) {
headers.put(name, httpHeaders.getAll(name));
}
return headers;
private static Map<String, String> computeHeaders(HttpHeaders httpHeaders) {
// exception will be thrown on duplicate
return httpHeaders.entries().stream().collect(Collectors.toMap(Entry::getKey, Entry::getValue));
}

private static int toStatusCode(Throwable throwable) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import io.scalecube.services.gateway.GatewaySessionHandler;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.jctools.maps.NonBlockingHashMapLong;
import org.slf4j.Logger;
Expand All @@ -33,7 +32,7 @@ public final class WebsocketGatewaySession implements GatewaySession {
private final WebsocketServiceMessageCodec codec;

private final long sessionId;
private final Map<String, List<String>> headers;
private final Map<String, String> headers;

/**
* Create a new websocket session with given handshake, inbound and outbound channels.
Expand All @@ -48,7 +47,7 @@ public final class WebsocketGatewaySession implements GatewaySession {
public WebsocketGatewaySession(
long sessionId,
WebsocketServiceMessageCodec codec,
Map<String, List<String>> headers,
Map<String, String> headers,
WebsocketInbound inbound,
WebsocketOutbound outbound,
GatewaySessionHandler gatewayHandler) {
Expand All @@ -68,7 +67,7 @@ public long sessionId() {
}

@Override
public Map<String, List<String>> headers() {
public Map<String, String> headers() {
return headers;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import io.scalecube.services.api.ServiceMessage;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import reactor.util.context.Context;

public class TestGatewaySessionHandler implements GatewaySessionHandler {

public final CountDownLatch msgLatch = new CountDownLatch(1);
public final CountDownLatch connLatch = new CountDownLatch(1);
public final CountDownLatch disconnLatch = new CountDownLatch(1);
private final AtomicReference<GatewaySession> lastSession = new AtomicReference<>();

@Override
public ServiceMessage mapMessage(GatewaySession s, ServiceMessage req, Context context) {
Expand All @@ -19,10 +21,15 @@ public ServiceMessage mapMessage(GatewaySession s, ServiceMessage req, Context c
@Override
public void onSessionOpen(GatewaySession s) {
connLatch.countDown();
lastSession.set(s);
}

@Override
public void onSessionClose(GatewaySession s) {
disconnLatch.countDown();
}

public GatewaySession lastSession() {
return lastSession.get();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import io.scalecube.services.transport.rsocket.RSocketServiceTransport;
import java.io.IOException;
import java.time.Duration;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.jupiter.api.AfterEach;
Expand Down Expand Up @@ -149,4 +151,29 @@ public void testHandlerEvents() throws InterruptedException {
sessionEventHandler.disconnLatch.await(3, TimeUnit.SECONDS);
Assertions.assertEquals(0, sessionEventHandler.disconnLatch.getCount());
}

@Test
void testClientSettingsHeaders() {
String headerKey = "secret-token";
String headerValue = UUID.randomUUID().toString();
client =
new RSocketGatewayClient(
GatewayClientSettings.builder()
.headers(Map.of(headerKey, headerValue))
.address(gatewayAddress)
.build(),
CLIENT_CODEC);

TestService service =
new ServiceCall()
.transport(new GatewayClientTransport(client))
.router(new StaticAddressRouter(gatewayAddress))
.api(TestService.class);

StepVerifier.create(
service.one("one").then(Mono.fromCallable(() -> sessionEventHandler.lastSession())))
.assertNext(session -> assertEquals(headerValue, session.headers().get(headerKey)))
.expectComplete()
.verify(TIMEOUT);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.time.Duration;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
Expand Down Expand Up @@ -200,4 +202,28 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception

assertEquals(0, keepaliveLatch.getCount());
}

@Test
void testClientSettingsHeaders() {
String headerKey = "secret-token";
String headerValue = UUID.randomUUID().toString();
client =
new WebsocketGatewayClient(
GatewayClientSettings.builder()
.address(gatewayAddress)
.headers(Map.of(headerKey, headerValue))
.build(),
CLIENT_CODEC);
TestService service =
new ServiceCall()
.transport(new GatewayClientTransport(client))
.router(new StaticAddressRouter(gatewayAddress))
.api(TestService.class);

StepVerifier.create(
service.one("one").then(Mono.fromCallable(() -> sessionEventHandler.lastSession())))
.assertNext(session -> assertEquals(headerValue, session.headers().get(headerKey)))
.expectComplete()
.verify(TIMEOUT);
}
}

0 comments on commit afb0b07

Please sign in to comment.