Skip to content

Commit

Permalink
Merge pull request #4 from signalsciences/sigsci_sync_1.4.2
Browse files Browse the repository at this point in the history
Sync fork with upstream repo v1.4.2
  • Loading branch information
amacnair authored Jan 24, 2023
2 parents 9260f86 + ca8f6ec commit c7c7d08
Showing 1 changed file with 57 additions and 14 deletions.
71 changes: 57 additions & 14 deletions forward/fwd.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package forward

import (
"bytes"
"context"
"crypto/tls"
"errors"
"io"
Expand Down Expand Up @@ -42,6 +43,9 @@ type ReqRewriter interface {
Rewrite(r *http.Request)
}

// WsHook websocket message hook called when message is received or sent
type WsHook func(req *http.Request, messageType int, reader io.Reader) (io.Reader, error)

type optSetter func(f *Forwarder) error

// PassHostHeader specifies if a client's Host header field should be delegated.
Expand Down Expand Up @@ -69,10 +73,21 @@ func Rewriter(r ReqRewriter) optSetter {
}
}

// WebsocketTLSClientConfig define the websocker client TLS configuration.
// WebsocketTLSClientConfig define the websocket client TLS configuration.
func WebsocketTLSClientConfig(tcc *tls.Config) optSetter {
return func(f *Forwarder) error {
f.httpForwarder.tlsClientConfig = tcc
f.websocketDialer.TLSClientConfig = tcc.Clone()
// WebSocket is only in http/1.1
f.websocketDialer.TLSClientConfig.NextProtos = []string{"http/1.1"}

return nil
}
}

// WebsocketNetDialContext define the websocket client DialContext function
func WebsocketNetDialContext(dialContext func(ctx context.Context, network string, addr string) (net.Conn, error)) optSetter {
return func(f *Forwarder) error {
f.websocketDialer.NetDialContext = dialContext
return nil
}
}
Expand Down Expand Up @@ -136,7 +151,23 @@ func WebsocketConnectionClosedHook(hook func(req *http.Request, conn net.Conn))
}
}

// ResponseModifier defines a response modifier for the HTTP forwarder.
// WebsocketMessageReceivedHook defines a hook called when websocket message is received.
func WebsocketMessageReceivedHook(hook WsHook) optSetter {
return func(f *Forwarder) error {
f.httpForwarder.websocketMessageReceivedHook = hook
return nil
}
}

// WebsocketMessageSentHook defines a hook called when websocket message is sent.
func WebsocketMessageSentHook(hook WsHook) optSetter {
return func(f *Forwarder) error {
f.httpForwarder.websocketMessageSentHook = hook
return nil
}
}

// ResponseModifier defines a response modifier for the HTTP forwarder
func ResponseModifier(responseModifier func(*http.Response) error) optSetter {
return func(f *Forwarder) error {
f.httpForwarder.modifyResponse = responseModifier
Expand Down Expand Up @@ -180,6 +211,9 @@ type httpForwarder struct {

bufferPool httputil.BufferPool
websocketConnectionClosedHook func(req *http.Request, conn net.Conn)
websocketMessageReceivedHook WsHook
websocketMessageSentHook WsHook
websocketDialer *websocket.Dialer
}

const defaultFlushInterval = 100 * clock.Millisecond
Expand All @@ -203,6 +237,12 @@ func New(setters ...optSetter) (*Forwarder, error) {
httpForwarder: &httpForwarder{log: &internalLogger{Logger: log.StandardLogger()}},
handlerContext: &handlerContext{},
}

f.websocketDialer = &websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: 45 * time.Second,
}

for _, s := range setters {
if err := s(f); err != nil {
return nil, err
Expand Down Expand Up @@ -234,6 +274,9 @@ func New(setters ...optSetter) (*Forwarder, error) {
if f.tlsClientConfig == nil {
if ht, ok := f.httpForwarder.roundTripper.(*http.Transport); ok {
f.tlsClientConfig = ht.TLSClientConfig
if f.websocketDialer.TLSClientConfig == nil && ht.TLSClientConfig != nil {
_ = WebsocketTLSClientConfig(ht.TLSClientConfig)(f)
}
}
}

Expand Down Expand Up @@ -315,14 +358,7 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request,

outReq := f.copyWebSocketRequest(req)

dialer := websocket.DefaultDialer
if outReq.URL.Scheme == "wss" && f.tlsClientConfig != nil {
dialer.TLSClientConfig = f.tlsClientConfig.Clone()
// WebSocket is only in http/1.1
dialer.TLSClientConfig.NextProtos = []string{"http/1.1"}
}

targetConn, resp, err := dialer.DialContext(outReq.Context(), outReq.URL.String(), outReq.Header)
targetConn, resp, err := f.websocketDialer.DialContext(outReq.Context(), outReq.URL.String(), outReq.Header)
if err != nil {
if resp == nil {
ctx.errHandler.ServeHTTP(w, req, err)
Expand Down Expand Up @@ -383,7 +419,8 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request,

errClient := make(chan error, 1)
errBackend := make(chan error, 1)
replicateWebsocketConn := func(dst, src *websocket.Conn, errc chan error) {
replicateWebsocketConn := func(dst, src *websocket.Conn, websocketMessageHook WsHook, errc chan error) {

forward := func(messageType int, reader io.Reader) error {
writer, err := dst.NextWriter(messageType)
if err != nil {
Expand Down Expand Up @@ -424,6 +461,12 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request,
}
break
}
if websocketMessageHook != nil {
if reader, err = websocketMessageHook(req, msgType, reader); err != nil {
errc <- err
break
}
}
err = forward(msgType, reader)
if err != nil {
errc <- err
Expand All @@ -432,8 +475,8 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request,
}
}

go replicateWebsocketConn(underlyingConn, targetConn, errClient)
go replicateWebsocketConn(targetConn, underlyingConn, errBackend)
go replicateWebsocketConn(underlyingConn, targetConn, f.websocketMessageSentHook, errClient)
go replicateWebsocketConn(targetConn, underlyingConn, f.websocketMessageReceivedHook, errBackend)

var message string
select {
Expand Down

0 comments on commit c7c7d08

Please sign in to comment.