diff --git a/internal/app/engine.go b/internal/app/engine.go
index 01ac102d7b..7b0cfd95ed 100644
--- a/internal/app/engine.go
+++ b/internal/app/engine.go
@@ -7,6 +7,7 @@ import (
 	"github.com/centrifugal/centrifugo/v5/internal/confighelpers"
 	"github.com/centrifugal/centrifugo/v5/internal/natsbroker"
 	"github.com/centrifugal/centrifugo/v5/internal/redisnatsbroker"
+	"github.com/centrifugal/centrifugo/v5/internal/service"
 
 	"github.com/centrifugal/centrifuge"
 	"github.com/rs/zerolog/log"
@@ -18,7 +19,7 @@ type engineModes struct {
 	presenceManagerMode string
 }
 
-func configureEngines(node *centrifuge.Node, cfgContainer *config.Container) (engineModes, error) {
+func configureEngines(node *centrifuge.Node, cfgContainer *config.Container, serviceManager *service.Manager) (engineModes, error) {
 	cfg := cfgContainer.Config()
 
 	var modes engineModes
@@ -32,7 +33,7 @@ func configureEngines(node *centrifuge.Node, cfgContainer *config.Container) (en
 	case "memory":
 		broker, presenceManager, err = createMemoryEngine(node)
 	case "redis":
-		broker, presenceManager, modes.engineMode, err = createRedisEngine(node, cfgContainer)
+		broker, presenceManager, modes.engineMode, err = createRedisEngine(node, cfgContainer, serviceManager)
 	default:
 		return modes, fmt.Errorf("unknown engine type: %s", cfg.Engine.Type)
 	}
@@ -97,7 +98,7 @@ func configureEngines(node *centrifuge.Node, cfgContainer *config.Container) (en
 	case "memory":
 		presenceManager, err = createMemoryPresenceManager(node)
 	case "redis":
-		presenceManager, modes.presenceManagerMode, err = createRedisPresenceManager(node, cfgContainer)
+		presenceManager, modes.presenceManagerMode, err = createRedisPresenceManager(node, cfgContainer, serviceManager)
 	default:
 		return modes, fmt.Errorf("unknown presence manager type: %s", cfg.PresenceManager.Type)
 	}
@@ -162,7 +163,7 @@ func NatsBroker(node *centrifuge.Node, cfg config.Config) (*natsbroker.NatsBroke
 	return natsbroker.New(node, cfg.Broker.Nats)
 }
 
-func createRedisEngine(n *centrifuge.Node, cfgContainer *config.Container) (*centrifuge.RedisBroker, centrifuge.PresenceManager, string, error) {
+func createRedisEngine(n *centrifuge.Node, cfgContainer *config.Container, _ *service.Manager) (*centrifuge.RedisBroker, centrifuge.PresenceManager, string, error) {
 	cfg := cfgContainer.Config()
 	redisShards, mode, err := confighelpers.CentrifugeRedisShards(n, cfg.Broker.Redis.Redis)
 	if err != nil {
@@ -202,7 +203,7 @@ func createRedisBroker(n *centrifuge.Node, cfgContainer *config.Container) (*cen
 	return broker, mode, nil
 }
 
-func createRedisPresenceManager(n *centrifuge.Node, cfgContainer *config.Container) (centrifuge.PresenceManager, string, error) {
+func createRedisPresenceManager(n *centrifuge.Node, cfgContainer *config.Container, _ *service.Manager) (centrifuge.PresenceManager, string, error) {
 	cfg := cfgContainer.Config()
 	redisShards, mode, err := confighelpers.CentrifugeRedisShards(n, cfg.PresenceManager.Redis.Redis)
 	if err != nil {
diff --git a/internal/app/run.go b/internal/app/run.go
index f337737fa0..5aedef96dc 100644
--- a/internal/app/run.go
+++ b/internal/app/run.go
@@ -31,7 +31,6 @@ import (
 	"github.com/rs/zerolog/log"
 	"github.com/spf13/cobra"
 	"go.uber.org/automaxprocs/maxprocs"
-	"golang.org/x/sync/errgroup"
 	"google.golang.org/grpc"
 
 	_ "google.golang.org/grpc/encoding/gzip"
@@ -62,6 +61,10 @@ func Run(cmd *cobra.Command, configFile string) {
 		log.Info().Msgf(strings.ToLower(s), i...)
 	}))
 
+	// Registered services will be run after node.Run() but before HTTP/GRPC servers start.
+	// Registered services will be stopped after node's shutdown and HTTP/GRPC servers shutdown.
+	serviceManager := service.NewManager()
+
 	entry := log.Info().
 		Str("version", build.Version).
 		Str("runtime", runtime.Version()).
@@ -108,7 +111,7 @@ func Run(cmd *cobra.Command, configFile string) {
 		}
 	}
 
-	modes, err := configureEngines(node, cfgContainer)
+	modes, err := configureEngines(node, cfgContainer, serviceManager)
 	if err != nil {
 		log.Fatal().Msgf("%v", err)
 	}
@@ -166,8 +169,6 @@ func Run(cmd *cobra.Command, configFile string) {
 		UseOpenTelemetry: useConsumingOpentelemetry,
 	})
 
-	var services []service.Service
-
 	consumingHandler := api.NewConsumingHandler(node, consumingAPIExecutor, api.ConsumingHandlerConfig{
 		UseOpenTelemetry: useConsumingOpentelemetry,
 	})
@@ -177,37 +178,10 @@ func Run(cmd *cobra.Command, configFile string) {
 		log.Fatal().Msgf("error initializing consumers: %v", err)
 	}
 
-	services = append(services, consumingServices...)
-
-	if err = node.Run(); err != nil {
-		log.Fatal().Msgf("error running node: %v", err)
-	}
-
-	var grpcAPIServer *grpc.Server
-	if cfg.GrpcAPI.Enabled {
-		var err error
-		grpcAPIServer, err = runGRPCAPIServer(cfg, node, useAPIOpentelemetry, grpcAPIExecutor)
-		if err != nil {
-			log.Fatal().Msgf("error creating GRPC API server: %v", err)
-		}
-	}
-
-	var grpcUniServer *grpc.Server
-	if cfg.UniGRPC.Enabled {
-		var err error
-		grpcAPIServer, err = runGRPCUniServer(cfg, node)
-		if err != nil {
-			log.Fatal().Msgf("error creating GRPC API server: %v", err)
-		}
-	}
-
-	httpServers, err := runHTTPServers(node, cfgContainer, httpAPIExecutor, keepHeadersInContext)
-	if err != nil {
-		log.Fatal().Msgf("error running HTTP server: %v", err)
-	}
+	serviceManager.Register(consumingServices...)
 
 	if cfg.Graphite.Enabled {
-		services = append(services, graphiteExporter(cfg, nodeCfg))
+		serviceManager.Register(graphiteExporter(cfg, nodeCfg))
 	}
 
 	var statsSender *usage.Sender
@@ -252,37 +226,55 @@
 			Throttling:   false,
 			Singleflight: false,
 		})
-		services = append(services, statsSender)
+		serviceManager.Register(statsSender)
 	}
 
 	notify.RegisterHandlers(node, statsSender)
 
-	var serviceGroup *errgroup.Group
-	serviceCancel := func() {}
-	if len(services) > 0 {
-		var serviceCtx context.Context
-		serviceCtx, serviceCancel = context.WithCancel(context.Background())
-		serviceGroup, serviceCtx = errgroup.WithContext(serviceCtx)
-		for _, s := range services {
-			s := s
-			serviceGroup.Go(func() error {
-				return s.Run(serviceCtx)
-			})
+	if err = node.Run(); err != nil {
+		log.Fatal().Msgf("error running node: %v", err)
+	}
+
+	ctx, serviceCancel := context.WithCancel(context.Background())
+	defer serviceCancel()
+	serviceManager.Run(ctx)
+
+	var grpcAPIServer *grpc.Server
+	if cfg.GrpcAPI.Enabled {
+		var err error
+		grpcAPIServer, err = runGRPCAPIServer(cfg, node, useAPIOpentelemetry, grpcAPIExecutor)
+		if err != nil {
+			log.Fatal().Msgf("error creating GRPC API server: %v", err)
+		}
+	}
+
+	var grpcUniServer *grpc.Server
+	if cfg.UniGRPC.Enabled {
+		var err error
+		grpcUniServer, err = runGRPCUniServer(cfg, node)
+		if err != nil {
+			log.Fatal().Msgf("error creating GRPC uni server: %v", err)
 		}
 	}
 
+	httpServers, err := runHTTPServers(node, cfgContainer, httpAPIExecutor, keepHeadersInContext)
+	if err != nil {
+		log.Fatal().Msgf("error running HTTP server: %v", err)
+	}
+
 	logStartWarnings(cfg, cfgMeta)
+
 	handleSignals(
 		cmd, configFile, node, cfgContainer, tokenVerifier, subTokenVerifier, httpServers, grpcAPIServer, grpcUniServer,
-		serviceGroup, serviceCancel,
+		serviceManager, serviceCancel,
 	)
 }
 
 func handleSignals(
 	cmd *cobra.Command, configFile string, n *centrifuge.Node, cfgContainer *config.Container,
 	tokenVerifier *jwtverify.VerifierJWT, subTokenVerifier *jwtverify.VerifierJWT, httpServers []*http.Server,
-	grpcAPIServer *grpc.Server, grpcUniServer *grpc.Server, serviceGroup *errgroup.Group,
+	grpcAPIServer *grpc.Server, grpcUniServer *grpc.Server, serviceManager *service.Manager,
 	serviceCancel context.CancelFunc,
 ) {
 	cfg := cfgContainer.Config()
@@ -339,20 +331,11 @@
 			if pidFile != "" {
 				_ = os.Remove(pidFile)
 			}
-			os.Exit(1)
+			log.Fatal().Msg("shutdown timeout reached")
 		})
 
 		var wg sync.WaitGroup
 
-		if serviceGroup != nil {
-			serviceCancel()
-			wg.Add(1)
-			go func() {
-				defer wg.Done()
-				_ = serviceGroup.Wait()
-			}()
-		}
-
 		if grpcAPIServer != nil {
 			wg.Add(1)
 			go func() {
@@ -369,20 +352,19 @@
 			}()
 		}
 
-		ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout.ToDuration())
-
 		for _, srv := range httpServers {
 			wg.Add(1)
 			go func(srv *http.Server) {
 				defer wg.Done()
-				_ = srv.Shutdown(ctx)
+				_ = srv.Shutdown(context.Background()) // We have a separate timeout goroutine.
 			}(srv)
 		}
 
-		_ = n.Shutdown(ctx)
-
+		_ = n.Shutdown(context.Background()) // We have a separate timeout goroutine.
 		wg.Wait()
-		cancel()
+
+		serviceCancel()
+		_ = serviceManager.Wait()
 
 		if pidFile != "" {
 			_ = os.Remove(pidFile)
diff --git a/internal/service/service.go b/internal/service/service.go
index 0bec2c2c11..ab1e19a061 100644
--- a/internal/service/service.go
+++ b/internal/service/service.go
@@ -2,8 +2,60 @@ package service
 
 import (
 	"context"
+	"sync"
+
+	"golang.org/x/sync/errgroup"
 )
 
 type Service interface {
 	Run(ctx context.Context) error
 }
+
+// Manager manages a collection of services.
+type Manager struct {
+	services []Service
+	mu       sync.Mutex // Protects services and wg.
+	wg       *errgroup.Group
+}
+
+// NewManager creates a Manager with no registered services.
+func NewManager() *Manager {
+	return &Manager{services: make([]Service, 0)}
+}
+
+// Register adds one or more services to the Manager.
+func (sm *Manager) Register(s ...Service) {
+	sm.mu.Lock()
+	defer sm.mu.Unlock()
+	sm.services = append(sm.services, s...)
+}
+
+// Run runs all registered services concurrently using an errgroup.
+func (sm *Manager) Run(ctx context.Context) {
+	sm.mu.Lock()
+	defer sm.mu.Unlock()
+
+	if len(sm.services) == 0 {
+		return
+	}
+
+	group, ctx := errgroup.WithContext(ctx)
+	for _, s := range sm.services {
+		s := s // Capture the service in the loop.
+		group.Go(func() error {
+			return s.Run(ctx)
+		})
+	}
+	sm.wg = group
+}
+
+// Wait blocks until all running services return and reports the first error.
+func (sm *Manager) Wait() error {
+	sm.mu.Lock()
+	defer sm.mu.Unlock()
+	if sm.wg == nil {
+		// Run was never called or no services were registered.
+		return nil
+	}
+	return sm.wg.Wait()
+}
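
Below is a minimal usage sketch of the new service.Manager lifecycle as wired in run.go: register services first, call Run after the node starts, then cancel the context and Wait during shutdown. The tickerService type is hypothetical and exists only for illustration; service.Service, NewManager, Register, Run and Wait are the APIs introduced in this diff. Note the package lives under internal/, so this code only compiles from inside the centrifugo module.

```go
package main

import (
	"context"
	"fmt"
	"time"

	"github.com/centrifugal/centrifugo/v5/internal/service"
)

// tickerService is a hypothetical Service implementation used only to
// demonstrate the Manager lifecycle; it is not part of this diff.
type tickerService struct{}

// Run ticks until the context is canceled, satisfying service.Service.
func (t *tickerService) Run(ctx context.Context) error {
	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return ctx.Err() // Surfaces through Manager.Wait via the errgroup.
		case <-ticker.C:
			fmt.Println("tick")
		}
	}
}

func main() {
	manager := service.NewManager()
	manager.Register(&tickerService{}) // Register everything up front, as run.go does.

	ctx, cancel := context.WithCancel(context.Background())
	manager.Run(ctx) // Non-blocking: each service starts in an errgroup goroutine.

	time.Sleep(3 * time.Second) // Stand-in for the server's lifetime.

	cancel()           // Shutdown path: cancel the shared context first...
	_ = manager.Wait() // ...then block until every Run has returned.
}
```

Because Run wires services through errgroup.WithContext, the first service to return a non-nil error cancels the shared context and unwinds the rest; on a clean shutdown Wait returns whatever the services return on cancellation (context.Canceled in this sketch), which run.go deliberately discards with `_ = serviceManager.Wait()`.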