From cf39974104bbfc291289736847cf05e3a205301e Mon Sep 17 00:00:00 2001 From: Andrey Smirnov Date: Thu, 26 Sep 2024 18:05:36 +0400 Subject: [PATCH] feat: support direct TLS serving Support certificate reload on the fly. Slice version to just `vX.Y` in the metrics. Bump IP-based limits. Signed-off-by: Andrey Smirnov --- cmd/discovery-service/main.go | 21 +++-- go.mod | 9 +- go.sum | 16 ++-- pkg/limits/limits.go | 4 +- pkg/server/addr.go | 17 +++- pkg/server/server_test.go | 4 + pkg/server/version.go | 8 +- pkg/server/version_test.go | 11 ++- pkg/service/certificate.go | 137 ++++++++++++++++++++++++++++ pkg/service/service.go | 165 ++++++++++++++++++++++++---------- 10 files changed, 312 insertions(+), 80 deletions(-) create mode 100644 pkg/service/certificate.go diff --git a/cmd/discovery-service/main.go b/cmd/discovery-service/main.go index cb93e89..70d71cf 100644 --- a/cmd/discovery-service/main.go +++ b/cmd/discovery-service/main.go @@ -34,18 +34,24 @@ var ( snapshotsEnabled = true snapshotPath = "/var/discovery-service/state.binpb" snapshotInterval = 10 * time.Minute + certificatePath = "" + keyPath = "" + trustXRealIP = true ) func init() { flag.StringVar(&listenAddr, "addr", listenAddr, "addr on which to listen") - flag.StringVar(&landingAddr, "landing-addr", landingAddr, "addr on which to listen for landing page") - flag.StringVar(&metricsAddr, "metrics-addr", metricsAddr, "prometheus metrics listen addr") + flag.StringVar(&certificatePath, "certificate-path", certificatePath, "path to the certificate file") + flag.StringVar(&keyPath, "key-path", keyPath, "path to the key file") + flag.StringVar(&landingAddr, "landing-addr", landingAddr, "addr on which to listen for landing page (set to empty to disable)") + flag.StringVar(&metricsAddr, "metrics-addr", metricsAddr, "prometheus metrics listen addr (set to empty to disable)") flag.BoolVar(&devMode, "debug", devMode, "enable debug mode") flag.DurationVar(&gcInterval, "gc-interval", gcInterval, "garbage collection interval") flag.StringVar(&redirectEndpoint, "redirect-endpoint", redirectEndpoint, "redirect all clients to a new endpoint (gRPC endpoint, e.g. 'example.com:443'") flag.BoolVar(&snapshotsEnabled, "snapshots-enabled", snapshotsEnabled, "enable snapshots") flag.StringVar(&snapshotPath, "snapshot-path", snapshotPath, "path to the snapshot file") flag.DurationVar(&snapshotInterval, "snapshot-interval", snapshotInterval, "interval to save the snapshot") + flag.BoolVar(&trustXRealIP, "trust-x-real-ip", trustXRealIP, "trust X-Real-IP header") if debug.Enabled { flag.StringVar(&debugAddr, "debug-addr", debugAddr, "debug (pprof, trace, expvar) listen addr") @@ -84,16 +90,21 @@ func main() { ListenAddr: listenAddr, GCInterval: gcInterval, - LandingServerEnabled: true, + CertificatePath: certificatePath, + KeyPath: keyPath, + + LandingServerEnabled: landingAddr != "", LandingAddr: landingAddr, - DebugServerEnabled: true, + DebugServerEnabled: debugAddr != "", DebugAddr: debugAddr, - MetricsServerEnabled: true, + MetricsServerEnabled: metricsAddr != "", MetricsAddr: metricsAddr, MetricsRegisterer: prometheus.DefaultRegisterer, + + TrustXRealIP: trustXRealIP, }, logger) }); err != nil { logger.Error("service failed", zap.Error(err)) diff --git a/go.mod b/go.mod index ee010b6..9b78bf7 100644 --- a/go.mod +++ b/go.mod @@ -1,13 +1,14 @@ module github.com/siderolabs/discovery-service -go 1.22.3 +go 1.23.1 require ( + github.com/fsnotify/fsnotify v1.7.0 github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus v1.0.1 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.1.0 github.com/jonboulle/clockwork v0.4.1-0.20231224152657-fc59783b0293 github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 - github.com/prometheus/client_golang v1.20.2 + github.com/prometheus/client_golang v1.20.4 github.com/siderolabs/discovery-api v0.1.4 github.com/siderolabs/discovery-client v0.1.9 github.com/siderolabs/gen v0.5.0 @@ -15,9 +16,10 @@ require ( github.com/stretchr/testify v1.9.0 go.uber.org/zap v1.27.0 go4.org/netipx v0.0.0-20231129151722-fdeea329fbba + golang.org/x/net v0.28.0 golang.org/x/sync v0.8.0 golang.org/x/time v0.6.0 - google.golang.org/grpc v1.66.0 + google.golang.org/grpc v1.67.0 google.golang.org/protobuf v1.34.2 ) @@ -34,7 +36,6 @@ require ( github.com/prometheus/common v0.57.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/net v0.28.0 // indirect golang.org/x/sys v0.24.0 // indirect golang.org/x/text v0.17.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240827150818-7e3bb234dfed // indirect diff --git a/go.sum b/go.sum index a693a52..af769f6 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus v1.0.1 h1:qnpSQwGEnkcRpTqNOIR6bJbR0gAorgP9CSALpRcKoAA= @@ -28,12 +30,10 @@ github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgm github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v1.20.2 h1:5ctymQzZlyOON1666svgwn3s6IKWgfbjsejTMiXIyjg= -github.com/prometheus/client_golang v1.20.2/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= +github.com/prometheus/client_golang v1.20.4 h1:Tgh3Yr67PaOv/uTqloMsCEdeuFTatm5zIq5+qNN23vI= +github.com/prometheus/client_golang v1.20.4/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= -github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc= -github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8= github.com/prometheus/common v0.57.0 h1:Ro/rKjwdq9mZn1K5QPctzh+MA4Lp0BuYk5ZZEVhoNcY= github.com/prometheus/common v0.57.0/go.mod h1:7uRPFSUTbfZWsJ7MHY56sqt7hLQu3bxXHDnNhl8E9qI= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= @@ -62,20 +62,16 @@ golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM= -golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117 h1:1GBuWVLM/KMVUv1t1En5Gs+gFZCNd360GGb4sSxtrhU= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117/go.mod h1:EfXuqaE1J41VCDicxHzUDm+8rk+7ZdXzHV0IhO/I6s0= google.golang.org/genproto/googleapis/rpc v0.0.0-20240827150818-7e3bb234dfed h1:J6izYgfBXAI3xTKLgxzTmUltdYaLsuBxFCgDHWJ/eXg= google.golang.org/genproto/googleapis/rpc v0.0.0-20240827150818-7e3bb234dfed/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= -google.golang.org/grpc v1.66.0 h1:DibZuoBznOxbDQxRINckZcUvnCEvrW9pcWIE2yF9r1c= -google.golang.org/grpc v1.66.0/go.mod h1:s3/l6xSSCURdVfAnL+TqCNMyTDAGN6+lZeVxnZR128Y= +google.golang.org/grpc v1.67.0 h1:IdH9y6PF5MPSdAntIcpjQ+tXO41pcQsfZV2RxtQgVcw= +google.golang.org/grpc v1.67.0/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/pkg/limits/limits.go b/pkg/limits/limits.go index bfa07ef..3e098b6 100644 --- a/pkg/limits/limits.go +++ b/pkg/limits/limits.go @@ -23,7 +23,7 @@ const ( // IP Rate Limiter. const ( - IPRateRequestsPerSecondMax = 5 - IPRateBurstSizeMax = 30 + IPRateRequestsPerSecondMax = 15 + IPRateBurstSizeMax = 60 IPRateGarbageCollectionPeriod = time.Minute ) diff --git a/pkg/server/addr.go b/pkg/server/addr.go index b003f5a..2b17df1 100644 --- a/pkg/server/addr.go +++ b/pkg/server/addr.go @@ -15,15 +15,24 @@ import ( "google.golang.org/grpc/peer" ) +var trustXRealIP bool + +// TrustXRealIP enables X-Real-IP header support. +func TrustXRealIP(enabled bool) { + trustXRealIP = enabled +} + // PeerAddress is used to extract peer address from the client. // it will try to extract the actual client's IP when called via // Nginx ingress first if not it will get the nginx or the machine // which calls the server, if everything fails returns an empty address. func PeerAddress(ctx context.Context) netip.Addr { - if md, ok := metadata.FromIncomingContext(ctx); ok { - if vals := md.Get("X-Real-IP"); vals != nil { - if ip, err := netip.ParseAddr(vals[0]); err == nil { - return ip + if trustXRealIP { + if md, ok := metadata.FromIncomingContext(ctx); ok { + if vals := md.Get("X-Real-IP"); vals != nil { + if ip, err := netip.ParseAddr(vals[0]); err == nil { + return ip + } } } } diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index dbc6f92..ee7f334 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -569,3 +569,7 @@ func TestServerRedirect(t *testing.T) { assert.Equal(t, "new.example.com:443", resp.GetRedirect().GetEndpoint()) } + +func init() { + server.TrustXRealIP(true) +} diff --git a/pkg/server/version.go b/pkg/server/version.go index 3ceaa1c..a03f774 100644 --- a/pkg/server/version.go +++ b/pkg/server/version.go @@ -9,15 +9,15 @@ import ( "regexp" ) -var vRE = regexp.MustCompile(`^(v\d+\.\d+\.\d+\-?[^-]*)(.*)$`) +var vRE = regexp.MustCompile(`^(v\d+\.\d+)(\.\d+)(\-?[^-]*)(.*)$`) func parseVersion(v string) string { m := vRE.FindAllStringSubmatch(v, -1) - if len(m) == 1 && len(m[0]) == 3 { + if len(m) == 1 && len(m[0]) >= 2 { res := m[0][1] - if m[0][2] != "" { - res += "-dev" + if len(m[0]) >= 3 && m[0][3] != "" { + res += "-pre" } return res diff --git a/pkg/server/version_test.go b/pkg/server/version_test.go index 3f1554c..82adf5b 100644 --- a/pkg/server/version_test.go +++ b/pkg/server/version_test.go @@ -17,10 +17,13 @@ func TestParseVersion(t *testing.T) { for v, expected := range map[string]string{ "": "unknown", "unknown": "unknown", - "v0.13.0": "v0.13.0", - "v0.13.0-beta.0": "v0.13.0-beta.0", - "v0.14.0-alpha.0-7-gf7d9f211": "v0.14.0-alpha.0-dev", - "v0.14.0-alpha.0-7-gf7d9f211-dirty": "v0.14.0-alpha.0-dev", + "v0.13.0": "v0.13", + "v0.13.0-beta.0": "v0.13-pre", + "v0.14.0-alpha.0-7-gf7d9f211": "v0.14-pre", + "v0.14.0-alpha.0-7-gf7d9f211-dirty": "v0.14-pre", + "v1.8.3": "v1.8", + "v1.8.3-7-gf7d9f211": "v1.8-pre", + "v1.8.0-beta.1": "v1.8-pre", } { t.Run(v, func(t *testing.T) { t.Parallel() diff --git a/pkg/service/certificate.go b/pkg/service/certificate.go new file mode 100644 index 0000000..0eb6126 --- /dev/null +++ b/pkg/service/certificate.go @@ -0,0 +1,137 @@ +// Copyright (c) 2024 Sidero Labs, Inc. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. + +package service + +import ( + "context" + "crypto/tls" + "fmt" + "sync" + "time" + + "github.com/fsnotify/fsnotify" + "go.uber.org/zap" +) + +// DynamicCertificate is a certificate that can be reloaded from disk. +type DynamicCertificate struct { + cert tls.Certificate + certFile string + keyFile string + mu sync.Mutex + loaded bool +} + +// NewDynamicCertificate creates a new DynamicCertificate. +func NewDynamicCertificate(certFile, keyFile string) *DynamicCertificate { + return &DynamicCertificate{ + certFile: certFile, + keyFile: keyFile, + } +} + +// Load the initial certificate. +func (c *DynamicCertificate) Load() error { + cert, err := tls.LoadX509KeyPair(c.certFile, c.keyFile) + if err != nil { + return err + } + + c.mu.Lock() + defer c.mu.Unlock() + + c.loaded = true + c.cert = cert + + return nil +} + +// GetCertificate returns the current certificate. +// +// It is suitable for use with tls.Config.GetCertificate. +func (c *DynamicCertificate) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if !c.loaded { + return nil, fmt.Errorf("the cert wasn't loaded yet") + } + + return &c.cert, nil +} + +// Watch the certificate files for changes and reload them. +func (c *DynamicCertificate) Watch(ctx context.Context, logger *zap.Logger) error { + w, err := fsnotify.NewWatcher() + if err != nil { + return fmt.Errorf("error creating fsnotify watcher: %w", err) + } + defer w.Close() //nolint:errcheck + + if err = w.Add(c.certFile); err != nil { + return fmt.Errorf("error adding watch for file %s: %w", c.certFile, err) + } + + if err = w.Add(c.keyFile); err != nil { + return fmt.Errorf("error adding watch for file %s: %w", c.keyFile, err) + } + + handleEvent := func(e fsnotify.Event) error { + defer func() { + if err = c.Load(); err != nil { + logger.Error("failed to load certs", zap.Error(err)) + + return + } + + logger.Info("reloaded certs") + }() + + if !e.Has(fsnotify.Remove) && !e.Has(fsnotify.Rename) { + return nil + } + + if err = w.Remove(e.Name); err != nil { + logger.Error("failed to remove file watch, it may have been deleted", zap.String("file", e.Name), zap.Error(err)) + } + + if err = w.Add(e.Name); err != nil { + return fmt.Errorf("error adding watch for file %s: %w", e.Name, err) + } + + return nil + } + + for { + select { + case e := <-w.Events: + if err = handleEvent(e); err != nil { + return err + } + case err = <-w.Errors: + return fmt.Errorf("received fsnotify error: %w", err) + case <-ctx.Done(): + return nil + } + } +} + +// WatchWithRestarts restarts the Watch on error. +func (c *DynamicCertificate) WatchWithRestarts(ctx context.Context, logger *zap.Logger) error { + for { + if err := c.Watch(ctx, logger); err != nil { + logger.Error("watch error", zap.Error(err)) + } else { + return nil + } + + select { + case <-ctx.Done(): + return nil + case <-time.After(5 * time.Second): // retry + } + } +} diff --git a/pkg/service/service.go b/pkg/service/service.go index 47773e0..f7488b6 100644 --- a/pkg/service/service.go +++ b/pkg/service/service.go @@ -8,10 +8,12 @@ package service import ( "context" + "crypto/tls" "errors" "fmt" "net" "net/http" + "strings" "time" grpc_prometheus "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus" @@ -23,6 +25,8 @@ import ( "github.com/siderolabs/discovery-api/api/v1alpha1/server/pb" "github.com/siderolabs/go-debug" "go.uber.org/zap" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" "golang.org/x/sync/errgroup" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -41,12 +45,15 @@ import ( type Options struct { MetricsRegisterer prom.Registerer - LandingAddr string - MetricsAddr string - SnapshotPath string - DebugAddr string + ListenAddr string + LandingAddr string + MetricsAddr string + SnapshotPath string + DebugAddr string + + CertificatePath, KeyPath string + RedirectEndpoint string - ListenAddr string GCInterval time.Duration SnapshotInterval time.Duration @@ -55,14 +62,10 @@ type Options struct { DebugServerEnabled bool MetricsServerEnabled bool SnapshotsEnabled bool + TrustXRealIP bool } -// Run starts the service with the given options. -func Run(ctx context.Context, options Options, logger *zap.Logger) error { - logger.Info("service starting") - - defer logger.Info("service shut down") - +func newGRPCServer(ctx context.Context, state *state.State, options Options, logger *zap.Logger) (*grpc.Server, *server.ClusterServer, *limiter.IPRateLimiter, *grpc_prometheus.ServerMetrics) { recoveryOpt := grpc_recovery.WithRecoveryHandler(recoveryHandler(logger)) limiter := limiter.NewIPRateLimiter(limits.IPRateRequestsPerSecondMax, limits.IPRateBurstSizeMax) @@ -100,6 +103,25 @@ func Run(ctx context.Context, options Options, logger *zap.Logger) error { grpc.WriteBufferSize(16 * 1024), } + srv := server.NewClusterServer(state, ctx.Done(), options.RedirectEndpoint) + + s := grpc.NewServer(serverOptions...) + pb.RegisterClusterServer(s, srv) + + metrics.InitializeMetrics(s) + + return s, srv, limiter, metrics +} + +// Run starts the service with the given options. +// +//nolint:gocognit,gocyclo,cyclop +func Run(ctx context.Context, options Options, logger *zap.Logger) error { + logger.Info("service starting") + defer logger.Info("service shut down") + + server.TrustXRealIP(options.TrustXRealIP) + state := state.NewState(logger) var stateStorage *storage.Storage @@ -113,46 +135,59 @@ func Run(ctx context.Context, options Options, logger *zap.Logger) error { logger.Info("snapshots are disabled") } - srv := server.NewClusterServer(state, ctx.Done(), options.RedirectEndpoint) + s, srv, limiter, metrics := newGRPCServer(ctx, state, options, logger) lis, err := net.Listen("tcp", options.ListenAddr) if err != nil { return fmt.Errorf("failed to listen: %w", err) } - landingLis, err := net.Listen("tcp", options.LandingAddr) - if err != nil { - return fmt.Errorf("failed to listen: %w", err) - } + landingHandler := landing.Handler(state, logger) - s := grpc.NewServer(serverOptions...) - pb.RegisterClusterServer(s, srv) + eg, ctx := errgroup.WithContext(ctx) - metrics.InitializeMetrics(s) + var rootHandler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.ProtoMajor == 2 && strings.Contains(r.Header.Get("Content-Type"), "application/grpc") { + s.ServeHTTP(w, r) + } else { + landingHandler.ServeHTTP(w, r) + } + }) - var ( - metricsServer http.Server - landingServer http.Server - ) + insecure := options.CertificatePath == "" && options.KeyPath == "" - if options.MetricsServerEnabled { - var metricsMux http.ServeMux + if insecure { + rootHandler = h2c.NewHandler(rootHandler, &http2.Server{}) + } - metricsMux.Handle("/metrics", promhttp.Handler()) + var tlsConfig *tls.Config - metricsServer = http.Server{ - Addr: options.MetricsAddr, - Handler: &metricsMux, + if !insecure { + certLoader := NewDynamicCertificate(options.CertificatePath, options.KeyPath) + if err = certLoader.Load(); err != nil { + return fmt.Errorf("failed to load certificate: %w", err) } - } - if options.LandingServerEnabled { - landingServer = http.Server{ - Handler: landing.Handler(state, logger), + eg.Go(func() error { + return certLoader.WatchWithRestarts(ctx, logger) + }) + + tlsConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + GetCertificate: certLoader.GetCertificate, } } - eg, ctx := errgroup.WithContext(ctx) + mainServer := &http.Server{ + ReadHeaderTimeout: 10 * time.Second, + Handler: rootHandler, + TLSConfig: tlsConfig, + ErrorLog: zap.NewStdLog(logger.With(zap.String("server", "http"))), + } + + if err = http2.ConfigureServer(mainServer, nil); err != nil { + return fmt.Errorf("failed to configure server: %w", err) + } if options.SnapshotsEnabled { eg.Go(func() error { @@ -161,9 +196,17 @@ func Run(ctx context.Context, options Options, logger *zap.Logger) error { } eg.Go(func() error { - logger.Info("gRPC server starting", zap.Stringer("address", lis.Addr())) + logger.Info("API server starting", zap.Stringer("address", lis.Addr())) - if serveErr := s.Serve(lis); serveErr != nil { + var serveErr error + + if insecure { + serveErr = mainServer.Serve(lis) + } else { + serveErr = mainServer.ServeTLS(lis, "", "") + } + + if serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) { return fmt.Errorf("failed to serve: %w", serveErr) } @@ -171,6 +214,17 @@ func Run(ctx context.Context, options Options, logger *zap.Logger) error { }) if options.LandingServerEnabled { + var landingLis net.Listener + + landingLis, err = net.Listen("tcp", options.LandingAddr) + if err != nil { + return fmt.Errorf("failed to listen: %w", err) + } + + landingServer := http.Server{ + Handler: landingHandler, + } + eg.Go(func() error { logger.Info("landing server starting", zap.Stringer("address", landingLis.Addr())) @@ -180,9 +234,27 @@ func Run(ctx context.Context, options Options, logger *zap.Logger) error { return nil }) + + eg.Go(func() error { + <-ctx.Done() + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + + return landingServer.Shutdown(shutdownCtx) //nolint:contextcheck + }) } if options.MetricsServerEnabled { + var metricsMux http.ServeMux + + metricsMux.Handle("/metrics", promhttp.Handler()) + + metricsServer := http.Server{ + Addr: options.MetricsAddr, + Handler: &metricsMux, + } + eg.Go(func() error { logger.Info("metrics starting", zap.String("address", metricsServer.Addr)) @@ -192,6 +264,15 @@ func Run(ctx context.Context, options Options, logger *zap.Logger) error { return nil }) + + eg.Go(func() error { + <-ctx.Done() + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + + return metricsServer.Shutdown(shutdownCtx) //nolint:contextcheck + }) } eg.Go(func() error { @@ -200,17 +281,7 @@ func Run(ctx context.Context, options Options, logger *zap.Logger) error { shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) defer shutdownCancel() - s.GracefulStop() - - if options.LandingServerEnabled { - landingServer.Shutdown(ctx) //nolint:errcheck - } - - if options.MetricsServerEnabled { - metricsServer.Shutdown(shutdownCtx) //nolint:errcheck,contextcheck - } - - return nil + return mainServer.Shutdown(shutdownCtx) //nolint:contextcheck }) eg.Go(func() error {