From f426804275cdbc1531549c7598ad93582d336a7f Mon Sep 17 00:00:00 2001 From: FZambia Date: Wed, 23 Oct 2024 20:54:50 +0300 Subject: [PATCH] transforms for all proxies, add test --- internal/proxy/connect_handler_test.go | 32 ++++++++++++++++++++++++++ internal/proxy/connect_http.go | 2 +- internal/proxy/http.go | 2 +- internal/proxy/publish_http.go | 7 ++++++ internal/proxy/refresh_http.go | 7 ++++++ internal/proxy/rpc_http.go | 7 ++++++ internal/proxy/sub_refresh_http.go | 7 ++++++ internal/proxy/subscribe_http.go | 7 ++++++ 8 files changed, 69 insertions(+), 2 deletions(-) diff --git a/internal/proxy/connect_handler_test.go b/internal/proxy/connect_handler_test.go index 405d8d8fa0..4a21800e7a 100644 --- a/internal/proxy/connect_handler_test.go +++ b/internal/proxy/connect_handler_test.go @@ -43,6 +43,9 @@ func getTestHttpProxy(commonProxyTestCase *tools.CommonHTTPProxyTestCase, endpoi StaticHttpHeaders: map[string]string{ "X-Test": "test", }, + HttpStatusTransforms: []HttpStatusToCodeTransform{ + {StatusCode: 404, ToDisconnect: TransformDisconnect{Code: 4504, Reason: "not found"}}, + }, } } @@ -334,3 +337,32 @@ func TestHandleConnectWithSubscriptionError(t *testing.T) { require.Equal(t, centrifuge.ConnectReply{}, reply, c.protocol) } } + +func TestHandleConnectWithHTTPCodeTransform(t *testing.T) { + grpcTestCase := newConnHandleGRPCTestCase(context.Background(), newProxyGRPCTestServer("http status code transform", proxyGRPCTestServerOptions{})) + defer grpcTestCase.Teardown() + + httpTestCase := newConnHandleHTTPTestCase(context.Background(), "/proxy") + httpTestCase.Mux.HandleFunc("/proxy", func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{}`)) + }) + defer httpTestCase.Teardown() + + cases := newConnHandleTestCases(httpTestCase, grpcTestCase) + for _, c := range cases { + if c.protocol == "grpc" { + continue // Transforms not supported. + } + + expectedErr := centrifuge.Disconnect{ + Code: 4504, + Reason: "not found", + } + + reply, err := c.invokeHandle(context.Background()) + require.NotNil(t, err, c.protocol) + require.Equal(t, expectedErr.Error(), err.Error(), c.protocol) + require.Equal(t, centrifuge.ConnectReply{}, reply, c.protocol) + } +} diff --git a/internal/proxy/connect_http.go b/internal/proxy/connect_http.go index 973a910e4c..37b612766b 100644 --- a/internal/proxy/connect_http.go +++ b/internal/proxy/connect_http.go @@ -41,7 +41,7 @@ func (p *HTTPConnectProxy) ProxyConnect(ctx context.Context, req *proxyproto.Con } respData, err := p.httpCaller.CallHTTP(ctx, p.config.Endpoint, httpRequestHeaders(ctx, p.config), data) if err != nil { - protocolError, protocolDisconnect := transformHTTPError(err, p.config.HttpStatusTransforms) + protocolError, protocolDisconnect := transformHTTPStatusError(err, p.config.HttpStatusTransforms) if protocolError != nil || protocolDisconnect != nil { return &proxyproto.ConnectResponse{ Error: protocolError, diff --git a/internal/proxy/http.go b/internal/proxy/http.go index 24c2f17025..557bcfbfb5 100644 --- a/internal/proxy/http.go +++ b/internal/proxy/http.go @@ -109,7 +109,7 @@ func stringInSlice(a string, list []string) bool { return false } -func transformHTTPError(err error, transforms []HttpStatusToCodeTransform) (*proxyproto.Error, *proxyproto.Disconnect) { +func transformHTTPStatusError(err error, transforms []HttpStatusToCodeTransform) (*proxyproto.Error, *proxyproto.Disconnect) { if len(transforms) == 0 { return nil, nil } diff --git a/internal/proxy/publish_http.go b/internal/proxy/publish_http.go index 0a27e6cedc..4186f09ac2 100644 --- a/internal/proxy/publish_http.go +++ b/internal/proxy/publish_http.go @@ -44,6 +44,13 @@ func (p *HTTPPublishProxy) ProxyPublish(ctx context.Context, req *proxyproto.Pub } respData, err := p.httpCaller.CallHTTP(ctx, p.config.Endpoint, httpRequestHeaders(ctx, p.config), data) if err != nil { + protocolError, protocolDisconnect := transformHTTPStatusError(err, p.config.HttpStatusTransforms) + if protocolError != nil || protocolDisconnect != nil { + return &proxyproto.PublishResponse{ + Error: protocolError, + Disconnect: protocolDisconnect, + }, nil + } return nil, err } return httpDecoder.DecodePublishResponse(respData) diff --git a/internal/proxy/refresh_http.go b/internal/proxy/refresh_http.go index b5741e6721..c3621692d4 100644 --- a/internal/proxy/refresh_http.go +++ b/internal/proxy/refresh_http.go @@ -40,6 +40,13 @@ func (p *HTTPRefreshProxy) ProxyRefresh(ctx context.Context, req *proxyproto.Ref if err != nil { return nil, err } + protocolError, protocolDisconnect := transformHTTPStatusError(err, p.config.HttpStatusTransforms) + if protocolError != nil || protocolDisconnect != nil { + return &proxyproto.RefreshResponse{ + Error: protocolError, + Disconnect: protocolDisconnect, + }, nil + } return httpDecoder.DecodeRefreshResponse(respData) } diff --git a/internal/proxy/rpc_http.go b/internal/proxy/rpc_http.go index d78a37ed58..4303f78238 100644 --- a/internal/proxy/rpc_http.go +++ b/internal/proxy/rpc_http.go @@ -31,6 +31,13 @@ func (p *HTTPRPCProxy) ProxyRPC(ctx context.Context, req *proxyproto.RPCRequest) } respData, err := p.httpCaller.CallHTTP(ctx, p.config.Endpoint, httpRequestHeaders(ctx, p.config), data) if err != nil { + protocolError, protocolDisconnect := transformHTTPStatusError(err, p.config.HttpStatusTransforms) + if protocolError != nil || protocolDisconnect != nil { + return &proxyproto.RPCResponse{ + Error: protocolError, + Disconnect: protocolDisconnect, + }, nil + } return nil, err } return httpDecoder.DecodeRPCResponse(respData) diff --git a/internal/proxy/sub_refresh_http.go b/internal/proxy/sub_refresh_http.go index 4de1640c07..4d7fcf9cf4 100644 --- a/internal/proxy/sub_refresh_http.go +++ b/internal/proxy/sub_refresh_http.go @@ -39,6 +39,13 @@ func (p *HTTPSubRefreshProxy) ProxySubRefresh(ctx context.Context, req *proxypro } respData, err := p.httpCaller.CallHTTP(ctx, p.config.Endpoint, httpRequestHeaders(ctx, p.config), data) if err != nil { + protocolError, protocolDisconnect := transformHTTPStatusError(err, p.config.HttpStatusTransforms) + if protocolError != nil || protocolDisconnect != nil { + return &proxyproto.SubRefreshResponse{ + Error: protocolError, + Disconnect: protocolDisconnect, + }, nil + } return nil, err } return httpDecoder.DecodeSubRefreshResponse(respData) diff --git a/internal/proxy/subscribe_http.go b/internal/proxy/subscribe_http.go index 6cf0276f27..3fa09b7711 100644 --- a/internal/proxy/subscribe_http.go +++ b/internal/proxy/subscribe_http.go @@ -31,6 +31,13 @@ func (p *HTTPSubscribeProxy) ProxySubscribe(ctx context.Context, req *proxyproto } respData, err := p.httpCaller.CallHTTP(ctx, p.config.Endpoint, httpRequestHeaders(ctx, p.config), data) if err != nil { + protocolError, protocolDisconnect := transformHTTPStatusError(err, p.config.HttpStatusTransforms) + if protocolError != nil || protocolDisconnect != nil { + return &proxyproto.SubscribeResponse{ + Error: protocolError, + Disconnect: protocolDisconnect, + }, nil + } return nil, err } return httpDecoder.DecodeSubscribeResponse(respData)