diff --git a/internal/proxy/connect_handler_test.go b/internal/proxy/connect_handler_test.go index 405d8d8fa..4a21800e7 100644 --- a/internal/proxy/connect_handler_test.go +++ b/internal/proxy/connect_handler_test.go @@ -43,6 +43,9 @@ func getTestHttpProxy(commonProxyTestCase *tools.CommonHTTPProxyTestCase, endpoi StaticHttpHeaders: map[string]string{ "X-Test": "test", }, + HttpStatusTransforms: []HttpStatusToCodeTransform{ + {StatusCode: 404, ToDisconnect: TransformDisconnect{Code: 4504, Reason: "not found"}}, + }, } } @@ -334,3 +337,32 @@ func TestHandleConnectWithSubscriptionError(t *testing.T) { require.Equal(t, centrifuge.ConnectReply{}, reply, c.protocol) } } + +func TestHandleConnectWithHTTPCodeTransform(t *testing.T) { + grpcTestCase := newConnHandleGRPCTestCase(context.Background(), newProxyGRPCTestServer("http status code transform", proxyGRPCTestServerOptions{})) + defer grpcTestCase.Teardown() + + httpTestCase := newConnHandleHTTPTestCase(context.Background(), "/proxy") + httpTestCase.Mux.HandleFunc("/proxy", func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{}`)) + }) + defer httpTestCase.Teardown() + + cases := newConnHandleTestCases(httpTestCase, grpcTestCase) + for _, c := range cases { + if c.protocol == "grpc" { + continue // Transforms not supported. + } + + expectedErr := centrifuge.Disconnect{ + Code: 4504, + Reason: "not found", + } + + reply, err := c.invokeHandle(context.Background()) + require.NotNil(t, err, c.protocol) + require.Equal(t, expectedErr.Error(), err.Error(), c.protocol) + require.Equal(t, centrifuge.ConnectReply{}, reply, c.protocol) + } +} diff --git a/internal/proxy/connect_http.go b/internal/proxy/connect_http.go index 973a910e4..37b612766 100644 --- a/internal/proxy/connect_http.go +++ b/internal/proxy/connect_http.go @@ -41,7 +41,7 @@ func (p *HTTPConnectProxy) ProxyConnect(ctx context.Context, req *proxyproto.Con } respData, err := p.httpCaller.CallHTTP(ctx, p.config.Endpoint, httpRequestHeaders(ctx, p.config), data) if err != nil { - protocolError, protocolDisconnect := transformHTTPError(err, p.config.HttpStatusTransforms) + protocolError, protocolDisconnect := transformHTTPStatusError(err, p.config.HttpStatusTransforms) if protocolError != nil || protocolDisconnect != nil { return &proxyproto.ConnectResponse{ Error: protocolError, diff --git a/internal/proxy/http.go b/internal/proxy/http.go index 24c2f1702..557bcfbfb 100644 --- a/internal/proxy/http.go +++ b/internal/proxy/http.go @@ -109,7 +109,7 @@ func stringInSlice(a string, list []string) bool { return false } -func transformHTTPError(err error, transforms []HttpStatusToCodeTransform) (*proxyproto.Error, *proxyproto.Disconnect) { +func transformHTTPStatusError(err error, transforms []HttpStatusToCodeTransform) (*proxyproto.Error, *proxyproto.Disconnect) { if len(transforms) == 0 { return nil, nil } diff --git a/internal/proxy/publish_http.go b/internal/proxy/publish_http.go index 0a27e6ced..4186f09ac 100644 --- a/internal/proxy/publish_http.go +++ b/internal/proxy/publish_http.go @@ -44,6 +44,13 @@ func (p *HTTPPublishProxy) ProxyPublish(ctx context.Context, req *proxyproto.Pub } respData, err := p.httpCaller.CallHTTP(ctx, p.config.Endpoint, httpRequestHeaders(ctx, p.config), data) if err != nil { + protocolError, protocolDisconnect := transformHTTPStatusError(err, p.config.HttpStatusTransforms) + if protocolError != nil || protocolDisconnect != nil { + return &proxyproto.PublishResponse{ + Error: protocolError, + Disconnect: protocolDisconnect, + }, nil + } return nil, err } return httpDecoder.DecodePublishResponse(respData) diff --git a/internal/proxy/refresh_http.go b/internal/proxy/refresh_http.go index b5741e672..c3621692d 100644 --- a/internal/proxy/refresh_http.go +++ b/internal/proxy/refresh_http.go @@ -40,6 +40,13 @@ func (p *HTTPRefreshProxy) ProxyRefresh(ctx context.Context, req *proxyproto.Ref if err != nil { return nil, err } + protocolError, protocolDisconnect := transformHTTPStatusError(err, p.config.HttpStatusTransforms) + if protocolError != nil || protocolDisconnect != nil { + return &proxyproto.RefreshResponse{ + Error: protocolError, + Disconnect: protocolDisconnect, + }, nil + } return httpDecoder.DecodeRefreshResponse(respData) } diff --git a/internal/proxy/rpc_http.go b/internal/proxy/rpc_http.go index d78a37ed5..4303f7823 100644 --- a/internal/proxy/rpc_http.go +++ b/internal/proxy/rpc_http.go @@ -31,6 +31,13 @@ func (p *HTTPRPCProxy) ProxyRPC(ctx context.Context, req *proxyproto.RPCRequest) } respData, err := p.httpCaller.CallHTTP(ctx, p.config.Endpoint, httpRequestHeaders(ctx, p.config), data) if err != nil { + protocolError, protocolDisconnect := transformHTTPStatusError(err, p.config.HttpStatusTransforms) + if protocolError != nil || protocolDisconnect != nil { + return &proxyproto.RPCResponse{ + Error: protocolError, + Disconnect: protocolDisconnect, + }, nil + } return nil, err } return httpDecoder.DecodeRPCResponse(respData) diff --git a/internal/proxy/sub_refresh_http.go b/internal/proxy/sub_refresh_http.go index 4de1640c0..4d7fcf9cf 100644 --- a/internal/proxy/sub_refresh_http.go +++ b/internal/proxy/sub_refresh_http.go @@ -39,6 +39,13 @@ func (p *HTTPSubRefreshProxy) ProxySubRefresh(ctx context.Context, req *proxypro } respData, err := p.httpCaller.CallHTTP(ctx, p.config.Endpoint, httpRequestHeaders(ctx, p.config), data) if err != nil { + protocolError, protocolDisconnect := transformHTTPStatusError(err, p.config.HttpStatusTransforms) + if protocolError != nil || protocolDisconnect != nil { + return &proxyproto.SubRefreshResponse{ + Error: protocolError, + Disconnect: protocolDisconnect, + }, nil + } return nil, err } return httpDecoder.DecodeSubRefreshResponse(respData) diff --git a/internal/proxy/subscribe_http.go b/internal/proxy/subscribe_http.go index 6cf0276f2..3fa09b771 100644 --- a/internal/proxy/subscribe_http.go +++ b/internal/proxy/subscribe_http.go @@ -31,6 +31,13 @@ func (p *HTTPSubscribeProxy) ProxySubscribe(ctx context.Context, req *proxyproto } respData, err := p.httpCaller.CallHTTP(ctx, p.config.Endpoint, httpRequestHeaders(ctx, p.config), data) if err != nil { + protocolError, protocolDisconnect := transformHTTPStatusError(err, p.config.HttpStatusTransforms) + if protocolError != nil || protocolDisconnect != nil { + return &proxyproto.SubscribeResponse{ + Error: protocolError, + Disconnect: protocolDisconnect, + }, nil + } return nil, err } return httpDecoder.DecodeSubscribeResponse(respData)