diff --git a/internal/config/validate.go b/internal/config/validate.go index df768d879..49623764c 100644 --- a/internal/config/validate.go +++ b/internal/config/validate.go @@ -8,6 +8,7 @@ import ( "time" "github.com/centrifugal/centrifugo/v5/internal/configtypes" + "github.com/centrifugal/centrifugo/v5/internal/tools" "github.com/centrifugal/centrifuge" ) @@ -36,6 +37,10 @@ func (c Config) Validate() error { if p.Endpoint == "" { return fmt.Errorf("no endpoint set for proxy %s", p.Name) } + if err := validateStatusTransforms(p.ProxyCommon.HTTP.StatusToCodeTransforms); err != nil { + return fmt.Errorf("in proxy %s: %v", p.Name, err) + } + proxyNames = append(proxyNames, p.Name) } if slices.Contains(proxyNames, UnifiedProxyName) { @@ -59,6 +64,9 @@ func (c Config) Validate() error { if err := validateSecondPrecisionDuration(c.Channel.HistoryMetaTTL); err != nil { return fmt.Errorf("in channel.history_meta_ttl: %v", err) } + if err := validateStatusTransforms(c.UnifiedProxy.ProxyCommon.HTTP.StatusToCodeTransforms); err != nil { + return fmt.Errorf("in proxy %s: %v", UnifiedProxyName, err) + } if err := validateChannelOptions(c.Channel.WithoutNamespace, c.Channel.HistoryMetaTTL, proxyNames, c); err != nil { return fmt.Errorf("in channel.without_namespace: %v", err) @@ -124,6 +132,13 @@ func (c Config) Validate() error { consumerNames = append(consumerNames, config.Name) } + if err := validateConnectCodeTransforms(c.UniSSE.ConnectCodeToHTTPStatus.Transforms); err != nil { + return fmt.Errorf("in uni_sse.connect_code_to_http_status.transforms: %v", err) + } + if err := validateConnectCodeTransforms(c.UniHTTPStream.ConnectCodeToHTTPStatus.Transforms); err != nil { + return fmt.Errorf("in uni_http_stream.connect_code_to_http_status.transforms: %v", err) + } + return nil } @@ -256,3 +271,43 @@ func validateTokens(cfg Config) error { } return nil } + +func validateStatusTransforms(transforms []configtypes.HttpStatusToCodeTransform) error { + for i, transform := range transforms { + if transform.StatusCode == 0 { + return fmt.Errorf("status_code should be set in status_to_code_transforms[%d]", i) + } + if transform.ToDisconnect.Code == 0 && transform.ToError.Code == 0 { + return fmt.Errorf("no error or disconnect code set in status_to_code_transforms[%d]", i) + } + if transform.ToDisconnect.Code > 0 && transform.ToError.Code > 0 { + return fmt.Errorf("only error or disconnect code can be set in status_to_code_transforms[%d], but not both", i) + } + if !tools.IsASCII(transform.ToDisconnect.Reason) { + return fmt.Errorf("status_to_code_transforms[%d] disconnect reason must be ASCII", i) + } + if !tools.IsASCII(transform.ToError.Message) { + return fmt.Errorf("status_to_code_transforms[%d] error message must be ASCII", i) + } + const reasonOrMessageMaxLength = 123 // limit comes from WebSocket close reason length limit. See https://datatracker.ietf.org/doc/html/rfc6455. + if len(transform.ToError.Message) > reasonOrMessageMaxLength { + return fmt.Errorf("status_to_code_transforms[%d] item error message can be up to %d characters long", i, reasonOrMessageMaxLength) + } + if len(transform.ToDisconnect.Reason) > reasonOrMessageMaxLength { + return fmt.Errorf("status_to_code_transforms[%d] disconnect reason can be up to %d characters long", i, reasonOrMessageMaxLength) + } + } + return nil +} + +func validateConnectCodeTransforms(transforms []configtypes.ConnectCodeToHTTPStatusTransform) error { + for i, transform := range transforms { + if transform.Code == 0 { + return fmt.Errorf("code should be set in connect_code_to_http_status.transforms[%d]", i) + } + if transform.ToResponse.StatusCode == 0 { + return fmt.Errorf("status_code should be set in connect_code_to_http_status.transforms[%d].to_response", i) + } + } + return nil +} diff --git a/internal/configtypes/types.go b/internal/configtypes/types.go index bd2ecb883..1e8cbf959 100644 --- a/internal/configtypes/types.go +++ b/internal/configtypes/types.go @@ -88,6 +88,20 @@ type UniSSE struct { ConnectCodeToHTTPStatus ConnectCodeToHTTPStatus `mapstructure:"connect_code_to_http_status" json:"connect_code_to_http_status" envconfig:"connect_code_to_http_status" yaml:"connect_code_to_http_status" toml:"connect_code_to_http_status"` } +type ConnectCodeToHTTPStatusTransforms []ConnectCodeToHTTPStatusTransform + +// Decode to implement the envconfig.Decoder interface +func (d *ConnectCodeToHTTPStatusTransforms) Decode(value string) error { + // If the source is a string and the target is a slice, try to parse it as JSON. + var items ConnectCodeToHTTPStatusTransforms + err := json.Unmarshal([]byte(value), &items) + if err != nil { + return fmt.Errorf("error parsing items from JSON: %v", err) + } + *d = items + return nil +} + type ConnectCodeToHTTPStatus struct { Enabled bool `mapstructure:"enabled" json:"enabled" envconfig:"enabled" yaml:"enabled" toml:"enabled"` Transforms []ConnectCodeToHTTPStatusTransform `mapstructure:"transforms" json:"transforms" envconfig:"transforms" yaml:"transforms" toml:"transforms"` @@ -99,8 +113,8 @@ type ConnectCodeToHTTPStatusTransform struct { } type TransformedConnectErrorHttpResponse struct { - Status int `mapstructure:"status_code" json:"status_code" envconfig:"status_code" yaml:"status_code" toml:"status_code"` - Body string `mapstructure:"body" json:"body" envconfig:"body" yaml:"body" toml:"body"` + StatusCode int `mapstructure:"status_code" json:"status_code" envconfig:"status_code" yaml:"status_code" toml:"status_code"` + Body string `mapstructure:"body" json:"body" envconfig:"body" yaml:"body" toml:"body"` } func ConnectErrorToToHTTPResponse(err error, transforms []ConnectCodeToHTTPStatusTransform) (TransformedConnectErrorHttpResponse, bool) { @@ -130,8 +144,8 @@ func ConnectErrorToToHTTPResponse(err error, transforms []ConnectCodeToHTTPStatu } } return TransformedConnectErrorHttpResponse{ - Status: http.StatusInternalServerError, - Body: http.StatusText(http.StatusInternalServerError), + StatusCode: http.StatusInternalServerError, + Body: http.StatusText(http.StatusInternalServerError), }, false } @@ -410,13 +424,27 @@ type HttpStatusToCodeTransform struct { ToDisconnect TransformDisconnect `mapstructure:"to_disconnect" json:"to_disconnect" json:"to_disconnect" yaml:"to_disconnect" toml:"to_disconnect"` } +type HttpStatusToCodeTransforms []HttpStatusToCodeTransform + +// Decode to implement the envconfig.Decoder interface +func (d *HttpStatusToCodeTransforms) Decode(value string) error { + // If the source is a string and the target is a slice, try to parse it as JSON. + var items HttpStatusToCodeTransforms + err := json.Unmarshal([]byte(value), &items) + if err != nil { + return fmt.Errorf("error parsing items from JSON: %v", err) + } + *d = items + return nil +} + type ProxyCommonHTTP struct { // StaticHeaders is a static set of key/value pairs to attach to HTTP proxy request as // headers. Headers received from HTTP client request or metadata from GRPC client request // both have priority over values set in StaticHttpHeaders map. StaticHeaders MapStringString `mapstructure:"static_headers" default:"{}" json:"static_headers" envconfig:"static_headers" yaml:"static_headers" toml:"static_headers"` // Status transforms allow to map HTTP status codes from proxy to Disconnect or Error messages. - StatusToCodeTransforms []HttpStatusToCodeTransform `mapstructure:"status_to_code_transforms" default:"[]" json:"status_to_code_transforms,omitempty" envconfig:"status_to_code_transforms" yaml:"status_to_code_transforms" toml:"status_to_code_transforms"` + StatusToCodeTransforms HttpStatusToCodeTransforms `mapstructure:"status_to_code_transforms" default:"[]" json:"status_to_code_transforms,omitempty" envconfig:"status_to_code_transforms" yaml:"status_to_code_transforms" toml:"status_to_code_transforms"` } type ProxyCommonGRPC struct { diff --git a/internal/runutil/proxy.go b/internal/runutil/proxy.go index 54a05f213..5425e74fd 100644 --- a/internal/runutil/proxy.go +++ b/internal/runutil/proxy.go @@ -3,6 +3,8 @@ package runutil import ( "strings" + "github.com/centrifugal/centrifugo/v5/internal/tools" + "github.com/centrifugal/centrifugo/v5/internal/client" "github.com/centrifugal/centrifugo/v5/internal/config" "github.com/centrifugal/centrifugo/v5/internal/proxy" @@ -53,7 +55,7 @@ func buildProxyMap(cfg config.Config) (*client.ProxyMap, bool) { if err != nil { log.Fatal().Msgf("error creating connect proxy: %v", err) } - log.Info().Str("proxy_name", connectProxyName).Str("endpoint", p.Endpoint).Msg("connect proxy enabled") + log.Info().Str("proxy_name", connectProxyName).Str("endpoint", tools.RedactedLogURLs(p.Endpoint)[0]).Msg("connect proxy enabled") keepHeadersInContext = true } @@ -76,7 +78,7 @@ func buildProxyMap(cfg config.Config) (*client.ProxyMap, bool) { if err != nil { log.Fatal().Msgf("error creating refresh proxy: %v", err) } - log.Info().Str("proxy_name", refreshProxyName).Str("endpoint", p.Endpoint).Msg("refresh proxy enabled") + log.Info().Str("proxy_name", refreshProxyName).Str("endpoint", tools.RedactedLogURLs(p.Endpoint)[0]).Msg("refresh proxy enabled") keepHeadersInContext = true } @@ -99,7 +101,7 @@ func buildProxyMap(cfg config.Config) (*client.ProxyMap, bool) { log.Fatal().Msgf("error creating subscribe proxy: %v", err) } proxyMap.SubscribeProxies[subscribeProxyName] = sp - log.Info().Str("proxy_name", subscribeProxyName).Str("endpoint", p.Endpoint).Msg("subscribe proxy enabled for channels without namespace") + log.Info().Str("proxy_name", subscribeProxyName).Str("endpoint", tools.RedactedLogURLs(p.Endpoint)[0]).Msg("subscribe proxy enabled for channels without namespace") keepHeadersInContext = true } @@ -122,7 +124,7 @@ func buildProxyMap(cfg config.Config) (*client.ProxyMap, bool) { log.Fatal().Msgf("error creating publish proxy: %v", err) } proxyMap.PublishProxies[publishProxyName] = pp - log.Info().Str("proxy_name", publishProxyName).Str("endpoint", p.Endpoint).Msg("publish proxy enabled for channels without namespace") + log.Info().Str("proxy_name", publishProxyName).Str("endpoint", tools.RedactedLogURLs(p.Endpoint)[0]).Msg("publish proxy enabled for channels without namespace") keepHeadersInContext = true } @@ -145,7 +147,7 @@ func buildProxyMap(cfg config.Config) (*client.ProxyMap, bool) { log.Fatal().Msgf("error creating publish proxy: %v", err) } proxyMap.SubRefreshProxies[subRefreshProxyName] = srp - log.Info().Str("proxy_name", subRefreshProxyName).Str("endpoint", p.Endpoint).Msg("sub refresh proxy enabled for channels without namespace") + log.Info().Str("proxy_name", subRefreshProxyName).Str("endpoint", tools.RedactedLogURLs(p.Endpoint)[0]).Msg("sub refresh proxy enabled for channels without namespace") keepHeadersInContext = true } @@ -171,7 +173,7 @@ func buildProxyMap(cfg config.Config) (*client.ProxyMap, bool) { log.Fatal().Msgf("error creating subscribe proxy: %v", err) } proxyMap.SubscribeStreamProxies[subscribeProxyName] = sp - log.Info().Str("proxy_name", subscribeStreamProxyName).Str("endpoint", p.Endpoint).Msg("subscribe stream proxy enabled for channels without namespace") + log.Info().Str("proxy_name", subscribeStreamProxyName).Str("endpoint", tools.RedactedLogURLs(p.Endpoint)[0]).Msg("subscribe stream proxy enabled for channels without namespace") keepHeadersInContext = true } @@ -201,7 +203,7 @@ func buildProxyMap(cfg config.Config) (*client.ProxyMap, bool) { } proxyMap.SubscribeProxies[subscribeProxyName] = sp } - log.Info().Str("proxy_name", subscribeProxyName).Str("endpoint", p.Endpoint).Str("namespace", ns.Name).Msg("subscribe proxy enabled for channels in namespace") + log.Info().Str("proxy_name", subscribeProxyName).Str("endpoint", tools.RedactedLogURLs(p.Endpoint)[0]).Str("namespace", ns.Name).Msg("subscribe proxy enabled for channels in namespace") } if publishProxyName != "" { @@ -224,7 +226,7 @@ func buildProxyMap(cfg config.Config) (*client.ProxyMap, bool) { } proxyMap.PublishProxies[publishProxyName] = pp } - log.Info().Str("proxy_name", publishProxyName).Str("endpoint", p.Endpoint).Str("namespace", ns.Name).Msg("publish proxy enabled for channels in namespace") + log.Info().Str("proxy_name", publishProxyName).Str("endpoint", tools.RedactedLogURLs(p.Endpoint)[0]).Str("namespace", ns.Name).Msg("publish proxy enabled for channels in namespace") keepHeadersInContext = true } @@ -248,7 +250,7 @@ func buildProxyMap(cfg config.Config) (*client.ProxyMap, bool) { } proxyMap.SubRefreshProxies[subRefreshProxyName] = srp } - log.Info().Str("proxy_name", subRefreshProxyName).Str("endpoint", p.Endpoint).Str("namespace", ns.Name).Msg("sub refresh proxy enabled for channels in namespace") + log.Info().Str("proxy_name", subRefreshProxyName).Str("endpoint", tools.RedactedLogURLs(p.Endpoint)[0]).Str("namespace", ns.Name).Msg("sub refresh proxy enabled for channels in namespace") keepHeadersInContext = true } @@ -275,7 +277,7 @@ func buildProxyMap(cfg config.Config) (*client.ProxyMap, bool) { } proxyMap.SubscribeStreamProxies[subscribeStreamProxyName] = sp } - log.Info().Str("proxy_name", subscribeStreamProxyName).Str("endpoint", p.Endpoint).Str("namespace", ns.Name).Msg("subscribe stream proxy enabled for channels in namespace") + log.Info().Str("proxy_name", subscribeStreamProxyName).Str("endpoint", tools.RedactedLogURLs(p.Endpoint)[0]).Str("namespace", ns.Name).Msg("subscribe stream proxy enabled for channels in namespace") keepHeadersInContext = true } } @@ -299,7 +301,7 @@ func buildProxyMap(cfg config.Config) (*client.ProxyMap, bool) { log.Fatal().Msgf("error creating rpc proxy: %v", err) } proxyMap.RpcProxies[rpcProxyName] = rp - log.Info().Str("proxy_name", rpcProxyName).Str("endpoint", p.Endpoint).Msg("RPC proxy enabled for RPC calls without namespace") + log.Info().Str("proxy_name", rpcProxyName).Str("endpoint", tools.RedactedLogURLs(p.Endpoint)[0]).Msg("RPC proxy enabled for RPC calls without namespace") keepHeadersInContext = true } @@ -325,7 +327,7 @@ func buildProxyMap(cfg config.Config) (*client.ProxyMap, bool) { } proxyMap.RpcProxies[rpcProxyName] = rp } - log.Info().Str("proxy_name", rpcProxyName).Str("endpoint", p.Endpoint).Str("namespace", ns.Name).Msg("RPC proxy enabled for RPC calls in namespace") + log.Info().Str("proxy_name", rpcProxyName).Str("endpoint", tools.RedactedLogURLs(p.Endpoint)[0]).Str("namespace", ns.Name).Msg("RPC proxy enabled for RPC calls in namespace") keepHeadersInContext = true } } diff --git a/internal/unihttpstream/handler.go b/internal/unihttpstream/handler.go index 69b382c2a..c90de262d 100644 --- a/internal/unihttpstream/handler.go +++ b/internal/unihttpstream/handler.go @@ -97,7 +97,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if err != nil { resp, ok := configtypes.ConnectErrorToToHTTPResponse(err, h.config.ConnectCodeToHTTPStatus.Transforms) if ok { - w.WriteHeader(resp.Status) + w.WriteHeader(resp.StatusCode) _, _ = w.Write([]byte(resp.Body)) return } diff --git a/internal/unisse/handler.go b/internal/unisse/handler.go index 66e89f60f..990427385 100644 --- a/internal/unisse/handler.go +++ b/internal/unisse/handler.go @@ -108,7 +108,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if err != nil { resp, ok := configtypes.ConnectErrorToToHTTPResponse(err, h.config.ConnectCodeToHTTPStatus.Transforms) if ok { - w.WriteHeader(resp.Status) + w.WriteHeader(resp.StatusCode) _, _ = w.Write([]byte(resp.Body)) return }