diff --git a/cmd/app/grpc.go b/cmd/app/grpc.go index 3aaf283a5..05dccbfb4 100644 --- a/cmd/app/grpc.go +++ b/cmd/app/grpc.go @@ -21,8 +21,10 @@ import ( "fmt" "net" "os" + "os/signal" "runtime" "sync" + "syscall" "github.com/fsnotify/fsnotify" "github.com/goadesign/goa/grpc/middleware" @@ -197,7 +199,7 @@ func (g *grpcServer) setupPrometheus(reg *prometheus.Registry) { grpc_prometheus.Register(g.Server) } -func (g *grpcServer) startTCPListener() { +func (g *grpcServer) startTCPListener(wg *sync.WaitGroup) { // lis is closed by g.Server.Serve() upon exit lis, err := net.Listen("tcp", g.grpcServerEndpoint) if err != nil { @@ -206,13 +208,30 @@ func (g *grpcServer) startTCPListener() { g.grpcServerEndpoint = lis.Addr().String() log.Logger.Infof("listening on grpc at %s", g.grpcServerEndpoint) + + idleConnsClosed := make(chan struct{}) + go func() { + sigint := make(chan os.Signal, 1) + signal.Notify(sigint, syscall.SIGINT, syscall.SIGTERM) + <-sigint + + // received an interrupt signal, shut down + g.Server.GracefulStop() + close(idleConnsClosed) + log.Logger.Info("stopped grpc server") + }() + + wg.Add(1) go func() { if g.tlsCertWatcher != nil { defer g.tlsCertWatcher.Close() } if err := g.Server.Serve(lis); err != nil { - log.Logger.Errorf("error shutting down grpcServer: %w", err) + log.Logger.Fatalf("error shutting down grpcServer: %w", err) } + <-idleConnsClosed + wg.Done() + log.Logger.Info("grpc server shutdown") }() } @@ -221,12 +240,12 @@ func (g *grpcServer) startUnixListener() { if runtime.GOOS != "linux" { // As MacOS doesn't have abstract unix domain sockets the file // created by a previous run needs to be explicitly removed - if err := os.RemoveAll(LegacyUnixDomainSocket); err != nil { + if err := os.RemoveAll(g.grpcServerEndpoint); err != nil { log.Logger.Fatal(err) } } - unixAddr, err := net.ResolveUnixAddr("unix", LegacyUnixDomainSocket) + unixAddr, err := net.ResolveUnixAddr("unix", g.grpcServerEndpoint) if err != nil { log.Logger.Fatal(err) } @@ -246,7 +265,7 @@ func (g *grpcServer) ExposesGRPCTLS() bool { return viper.IsSet("grpc-tls-certificate") && viper.IsSet("grpc-tls-key") } -func createLegacyGRPCServer(cfg *config.FulcioConfig, v2Server gw.CAServer) (*grpcServer, error) { +func createLegacyGRPCServer(cfg *config.FulcioConfig, unixDomainSocket string, v2Server gw.CAServer) (*grpcServer, error) { logger, opts := log.SetupGRPCLogging() myServer := grpc.NewServer(grpc.UnaryInterceptor( @@ -264,7 +283,7 @@ func createLegacyGRPCServer(cfg *config.FulcioConfig, v2Server gw.CAServer) (*gr // Register your gRPC service implementations. gw_legacy.RegisterCAServer(myServer, legacyGRPCCAServer) - return &grpcServer{myServer, LegacyUnixDomainSocket, v2Server, nil}, nil + return &grpcServer{myServer, unixDomainSocket, v2Server, nil}, nil } func panicRecoveryHandler(ctx context.Context, p interface{}) error { diff --git a/cmd/app/http.go b/cmd/app/http.go index ad096d22d..19a1564cc 100644 --- a/cmd/app/http.go +++ b/cmd/app/http.go @@ -21,8 +21,12 @@ import ( "errors" "fmt" "net/http" + "os" + "os/signal" "strconv" "strings" + "sync" + "syscall" "time" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" @@ -102,12 +106,32 @@ func createHTTPServer(ctx context.Context, serverEndpoint string, grpcServer, le return httpServer{&api, serverEndpoint} } -func (h httpServer) startListener() { +func (h httpServer) startListener(wg *sync.WaitGroup) { log.Logger.Infof("listening on http at %s", h.httpServerEndpoint) + + idleConnsClosed := make(chan struct{}) + go func() { + sigint := make(chan os.Signal, 1) + signal.Notify(sigint, syscall.SIGINT, syscall.SIGTERM) + <-sigint + + // received an interrupt signal, shut down + if err := h.Shutdown(context.Background()); err != nil { + // error from closing listeners, or context timeout + log.Logger.Errorf("HTTP server Shutdown: %v", err) + } + close(idleConnsClosed) + log.Logger.Info("stopped http server") + }() + + wg.Add(1) go func() { if err := h.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { log.Logger.Fatal(err) } + <-idleConnsClosed + wg.Done() + log.Logger.Info("http server shutdown") }() } diff --git a/cmd/app/http_test.go b/cmd/app/http_test.go index 7495ee483..debb70036 100644 --- a/cmd/app/http_test.go +++ b/cmd/app/http_test.go @@ -27,6 +27,7 @@ import ( "os" "path/filepath" "strings" + "sync" "testing" "github.com/sigstore/fulcio/pkg/ca" @@ -50,7 +51,8 @@ func setupHTTPServer(t *testing.T) (httpServer, string) { if err != nil { t.Error(err) } - grpcServer.startTCPListener() + var wg sync.WaitGroup + grpcServer.startTCPListener(&wg) conn, err := grpc.Dial(grpcServer.grpcServerEndpoint, grpc.WithTransportCredentials(insecure.NewCredentials())) defer func() { if conn != nil { @@ -95,7 +97,9 @@ func setupHTTPServerWithGRPCTLS(t *testing.T) (httpServer, string) { if err != nil { t.Error(err) } - grpcServer.startTCPListener() + + var wg sync.WaitGroup + grpcServer.startTCPListener(&wg) conn, err := grpc.Dial(grpcServer.grpcServerEndpoint, grpc.WithTransportCredentials(insecure.NewCredentials())) defer func() { if conn != nil { @@ -105,7 +109,7 @@ func setupHTTPServerWithGRPCTLS(t *testing.T) (httpServer, string) { if err != nil { t.Error(err) } - legacyGRPCServer, err := createLegacyGRPCServer(nil, grpcServer.caService) + legacyGRPCServer, err := createLegacyGRPCServer(nil, LegacyUnixDomainSocket, grpcServer.caService) if err != nil { t.Fatal(err) } diff --git a/cmd/app/serve.go b/cmd/app/serve.go index a7cc950fe..2196bfbc3 100644 --- a/cmd/app/serve.go +++ b/cmd/app/serve.go @@ -19,13 +19,17 @@ import ( "bytes" "context" "crypto/x509" + "errors" "flag" "fmt" "net" "net/http" "os" + "os/signal" "path/filepath" "strings" + "sync" + "syscall" "time" "chainguard.dev/go-grpc-kit/pkg/duplex" @@ -99,6 +103,7 @@ func newServeCmd() *cobra.Command { cmd.Flags().String("grpc-host", "0.0.0.0", "The host on which to serve requests for GRPC") cmd.Flags().String("grpc-port", "8081", "The port on which to serve requests for GRPC") cmd.Flags().String("metrics-port", "2112", "The port on which to serve prometheus metrics endpoint") + cmd.Flags().String("legacy-unix-domain-socket", LegacyUnixDomainSocket, "The Unix domain socket used for the legacy gRPC server") cmd.Flags().Duration("read-header-timeout", 10*time.Second, "The time allowed to read the headers of the requests in seconds") cmd.Flags().String("grpc-tls-certificate", "", "the certificate file to use for secure connections - only applies to grpc-port") cmd.Flags().String("grpc-tls-key", "", "the private key file to use for secure connections (without passphrase) - only applies to grpc-port") @@ -286,6 +291,9 @@ func runServeCmd(cmd *cobra.Command, args []string) { //nolint: revive return } + // waiting for http and grpc servers to shutdown gracefully + var wg sync.WaitGroup + httpServerEndpoint := fmt.Sprintf("%v:%v", viper.GetString("http-host"), viper.GetString("http-port")) reg := prometheus.NewRegistry() @@ -295,16 +303,16 @@ func runServeCmd(cmd *cobra.Command, args []string) { //nolint: revive log.Logger.Fatal(err) } grpcServer.setupPrometheus(reg) - grpcServer.startTCPListener() + grpcServer.startTCPListener(&wg) - legacyGRPCServer, err := createLegacyGRPCServer(cfg, grpcServer.caService) + legacyGRPCServer, err := createLegacyGRPCServer(cfg, viper.GetString("legacy-unix-domain-socket"), grpcServer.caService) if err != nil { log.Logger.Fatal(err) } legacyGRPCServer.startUnixListener() httpServer := createHTTPServer(ctx, httpServerEndpoint, grpcServer, legacyGRPCServer) - httpServer.startListener() + httpServer.startListener(&wg) readHeaderTimeout := viper.GetDuration("read-header-timeout") prom := http.Server{ @@ -312,7 +320,29 @@ func runServeCmd(cmd *cobra.Command, args []string) { //nolint: revive Handler: promhttp.Handler(), ReadHeaderTimeout: readHeaderTimeout, } - log.Logger.Error(prom.ListenAndServe()) + + idleConnsClosed := make(chan struct{}) + go func() { + sigint := make(chan os.Signal, 1) + signal.Notify(sigint, syscall.SIGINT, syscall.SIGTERM) + <-sigint + + // received an interrupt signal, shut down + if err := prom.Shutdown(context.Background()); err != nil { + // error from closing listeners, or context timeout + log.Logger.Errorf("HTTP server Shutdown: %v", err) + } + close(idleConnsClosed) + log.Logger.Info("stopped prom server") + }() + if err := prom.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Logger.Fatal(err) + } + <-idleConnsClosed + log.Logger.Info("prom server shutdown") + + // wait for http and grpc servers to shutdown + wg.Wait() } func checkServeCmdConfigFile() error {