diff --git a/go.mod b/go.mod index 9b78bf7..8c74b14 100644 --- a/go.mod +++ b/go.mod @@ -10,12 +10,11 @@ require ( github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 github.com/prometheus/client_golang v1.20.4 github.com/siderolabs/discovery-api v0.1.4 - github.com/siderolabs/discovery-client v0.1.9 + github.com/siderolabs/discovery-client v0.1.10 github.com/siderolabs/gen v0.5.0 github.com/siderolabs/go-debug v0.4.0 github.com/stretchr/testify v1.9.0 go.uber.org/zap v1.27.0 - go4.org/netipx v0.0.0-20231129151722-fdeea329fbba golang.org/x/net v0.28.0 golang.org/x/sync v0.8.0 golang.org/x/time v0.6.0 diff --git a/go.sum b/go.sum index af769f6..40ad215 100644 --- a/go.sum +++ b/go.sum @@ -42,8 +42,8 @@ github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjR github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/siderolabs/discovery-api v0.1.4 h1:2fMEFSMiWaD1zDiBDY5md8VxItvL1rDQRSOfeXNjYKc= github.com/siderolabs/discovery-api v0.1.4/go.mod h1:kaBy+G42v2xd/uAF/NIe383sjNTBE2AhxPTyi9SZI0s= -github.com/siderolabs/discovery-client v0.1.9 h1:yDzvts++Nf/2qczdDUfU5GAibkEIgz/eo9RPG/k/rOc= -github.com/siderolabs/discovery-client v0.1.9/go.mod h1:Ew1z07eyJwqNwum84IKYH4S649KEKK5WUmRW49HlXS8= +github.com/siderolabs/discovery-client v0.1.10 h1:bTAvFLiISSzVXyYL1cIgAz8cPYd9ZfvhxwdebgtxARA= +github.com/siderolabs/discovery-client v0.1.10/go.mod h1:Ew1z07eyJwqNwum84IKYH4S649KEKK5WUmRW49HlXS8= github.com/siderolabs/gen v0.5.0 h1:Afdjx+zuZDf53eH5DB+E+T2JeCwBXGinV66A6osLgQI= github.com/siderolabs/gen v0.5.0/go.mod h1:1GUMBNliW98Xeq8GPQeVMYqQE09LFItE8enR3wgMh3Q= github.com/siderolabs/go-debug v0.4.0 h1:pbFt6Rzumm90s3GvbRer7yIxFNc0gQ94I53omkqswHA= @@ -56,8 +56,6 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= -go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M= -go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y= golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= diff --git a/internal/grpclog/grpclog.go b/internal/grpclog/grpclog.go new file mode 100644 index 0000000..6971f21 --- /dev/null +++ b/internal/grpclog/grpclog.go @@ -0,0 +1,53 @@ +// Copyright (c) 2024 Sidero Labs, Inc. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. + +// Package grpclog provides a logger that logs to a zap logger. +package grpclog + +import ( + "context" + "fmt" + + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" + "go.uber.org/zap" +) + +// Adapter returns a logging.Logger that logs to the provided zap logger. +func Adapter(l *zap.Logger) logging.Logger { + return logging.LoggerFunc(func(_ context.Context, lvl logging.Level, msg string, fields ...any) { + f := make([]zap.Field, 0, len(fields)/2) + + for i := 0; i < len(fields); i += 2 { + key := fields[i].(string) //nolint:forcetypeassert,errcheck + value := fields[i+1] + + switch v := value.(type) { + case string: + f = append(f, zap.String(key, v)) + case int: + f = append(f, zap.Int(key, v)) + case bool: + f = append(f, zap.Bool(key, v)) + default: + f = append(f, zap.Any(key, v)) + } + } + + logger := l.WithOptions(zap.AddCallerSkip(1)).With(f...) + + switch lvl { + case logging.LevelDebug: + logger.Debug(msg) + case logging.LevelInfo: + logger.Info(msg) + case logging.LevelWarn: + logger.Warn(msg) + case logging.LevelError: + logger.Error(msg) + default: + panic(fmt.Sprintf("unknown level %v", lvl)) + } + }) +} diff --git a/pkg/server/addr.go b/pkg/server/addr.go index 2b17df1..7101482 100644 --- a/pkg/server/addr.go +++ b/pkg/server/addr.go @@ -7,10 +7,8 @@ package server import ( "context" - "net" "net/netip" - "go4.org/netipx" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" ) @@ -38,10 +36,8 @@ func PeerAddress(ctx context.Context) netip.Addr { } if peer, ok := peer.FromContext(ctx); ok { - if addr, ok := peer.Addr.(*net.TCPAddr); ok { - if ip, ok := netipx.FromStdIP(addr.IP); ok { - return ip - } + if addrPort, err := netip.ParseAddrPort(peer.Addr.String()); err == nil { + return addrPort.Addr() } } diff --git a/pkg/server/cert_test.go b/pkg/server/cert_test.go new file mode 100644 index 0000000..425927e --- /dev/null +++ b/pkg/server/cert_test.go @@ -0,0 +1,97 @@ +// Copyright (c) 2024 Sidero Labs, Inc. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. + +package server_test + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "net" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +var onceCert = sync.OnceValues(func() (*tls.Certificate, error) { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, err + } + + keyUsage := x509.KeyUsageDigitalSignature + notBefore := time.Now() + notAfter := notBefore.Add(time.Hour) + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return nil, err + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"SideroLabs Testing"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: keyUsage, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback}, + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return nil, err + } + + cert := tls.Certificate{ + Certificate: [][]byte{derBytes}, + PrivateKey: priv, + } + + cert.Leaf, err = x509.ParseCertificate(derBytes) + if err != nil { + return nil, err + } + + return &cert, nil +}) + +func GetServerTLSConfig(t *testing.T) *tls.Config { + t.Helper() + + cert, err := onceCert() + require.NoError(t, err) + + return &tls.Config{ + Certificates: []tls.Certificate{*cert}, + MinVersion: tls.VersionTLS12, + } +} + +func GetClientTLSConfig(t *testing.T) *tls.Config { + t.Helper() + + cert, err := onceCert() + require.NoError(t, err) + + certPool := x509.NewCertPool() + certPool.AddCert(cert.Leaf) + + return &tls.Config{ + RootCAs: certPool, + MinVersion: tls.VersionTLS12, + } +} diff --git a/pkg/server/client_test.go b/pkg/server/client_test.go index 697efd0..3641c25 100644 --- a/pkg/server/client_test.go +++ b/pkg/server/client_test.go @@ -60,7 +60,7 @@ func TestClient(t *testing.T) { ClusterID: clusterID, AffiliateID: affiliate1, TTL: time.Minute, - Insecure: true, + TLSConfig: GetClientTLSConfig(t), }) require.NoError(t, err) @@ -70,7 +70,7 @@ func TestClient(t *testing.T) { ClusterID: clusterID, AffiliateID: affiliate2, TTL: time.Minute, - Insecure: true, + TLSConfig: GetClientTLSConfig(t), }) require.NoError(t, err) @@ -239,7 +239,7 @@ func TestClient(t *testing.T) { ClusterID: clusterID, AffiliateID: affiliate1, TTL: time.Second, - Insecure: true, + TLSConfig: GetClientTLSConfig(t), }) require.NoError(t, err) @@ -249,7 +249,7 @@ func TestClient(t *testing.T) { ClusterID: clusterID, AffiliateID: affiliate2, TTL: time.Minute, - Insecure: true, + TLSConfig: GetClientTLSConfig(t), }) require.NoError(t, err) @@ -392,7 +392,7 @@ func clusterSimulator(t *testing.T, endpoint string, logger *zap.Logger, numAffi AffiliateID: fmt.Sprintf("affiliate-%d", i), ClientVersion: "v0.0.1", TTL: 10 * time.Second, - Insecure: true, + TLSConfig: GetClientTLSConfig(t), }) require.NoError(t, err) } @@ -557,7 +557,7 @@ func TestClientRedirect(t *testing.T) { ClusterID: clusterID, AffiliateID: affiliate1, TTL: time.Minute, - Insecure: true, + TLSConfig: GetClientTLSConfig(t), }) require.NoError(t, err) @@ -567,7 +567,7 @@ func TestClientRedirect(t *testing.T) { ClusterID: clusterID, AffiliateID: affiliate2, TTL: time.Minute, - Insecure: true, + TLSConfig: GetClientTLSConfig(t), }) require.NoError(t, err) diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index ee7f334..36c544e 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -11,25 +11,29 @@ import ( "errors" "fmt" "net" + "net/http" "strings" "testing" "time" + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" prom "github.com/prometheus/client_golang/prometheus" promtestutil "github.com/prometheus/client_golang/prometheus/testutil" "github.com/siderolabs/discovery-api/api/v1alpha1/server/pb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/zap/zaptest" + "golang.org/x/net/http2" "golang.org/x/time/rate" "google.golang.org/grpc" "google.golang.org/grpc/codes" - "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/durationpb" + "github.com/siderolabs/discovery-service/internal/grpclog" "github.com/siderolabs/discovery-service/internal/limiter" _ "github.com/siderolabs/discovery-service/internal/proto" "github.com/siderolabs/discovery-service/internal/state" @@ -48,6 +52,7 @@ func checkMetrics(t *testing.T, c prom.Collector) { type testServer struct { //nolint:govet lis net.Listener s *grpc.Server + httpServer *http.Server state *state.State stopCh <-chan struct{} serverOptions []grpc.ServerOption @@ -89,13 +94,20 @@ func setupServer(t *testing.T, rateLimit rate.Limit, redirectEndpoint string) *t limiter := limiter.NewIPRateLimiter(rateLimit, limits.IPRateBurstSizeMax) + loggingOpts := []logging.Option{ + logging.WithLogOnEvents(logging.StartCall, logging.FinishCall), + logging.WithFieldsFromContext(logging.ExtractFields), + } + testServer.serverOptions = []grpc.ServerOption{ grpc.ChainUnaryInterceptor( server.AddLoggingFieldsUnaryServerInterceptor(), + logging.UnaryServerInterceptor(grpclog.Adapter(logger), loggingOpts...), server.RateLimitUnaryServerInterceptor(limiter), ), grpc.ChainStreamInterceptor( server.AddLoggingFieldsStreamServerInterceptor(), + logging.StreamServerInterceptor(grpclog.Adapter(logger), loggingOpts...), server.RateLimitStreamServerInterceptor(limiter), ), grpc.SharedWriteBuffer(true), @@ -106,19 +118,35 @@ func setupServer(t *testing.T, rateLimit rate.Limit, redirectEndpoint string) *t testServer.s = grpc.NewServer(testServer.serverOptions...) pb.RegisterClusterServer(testServer.s, srv) + testServer.httpServer = &http.Server{ + Handler: testServer.s, + TLSConfig: GetServerTLSConfig(t), + } + + require.NoError(t, http2.ConfigureServer(testServer.httpServer, nil)) + go func() { - if stopErr := testServer.s.Serve(testServer.lis); stopErr != nil && !errors.Is(stopErr, grpc.ErrServerStopped) { - require.NoError(t, err) + if stopErr := testServer.httpServer.ServeTLS(testServer.lis, "", ""); stopErr != nil && !errors.Is(stopErr, http.ErrServerClosed) { + assert.NoError(t, stopErr) } + + t.Logf("server stopped") }() - t.Cleanup(testServer.s.Stop) + t.Cleanup(func() { + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + + assert.NoError(t, testServer.httpServer.Shutdown(shutdownCtx)) + }) return testServer } func (testServer *testServer) restartWithRedirect(t *testing.T, redirectEndpoint string) { - testServer.s.Stop() + t.Logf("restarting server with redirect to %s", redirectEndpoint) + + assert.NoError(t, testServer.httpServer.Close()) srv := server.NewClusterServer(testServer.state, testServer.stopCh, redirectEndpoint) @@ -130,13 +158,27 @@ func (testServer *testServer) restartWithRedirect(t *testing.T, redirectEndpoint testServer.lis, err = net.Listen("tcp", testServer.address) require.NoError(t, err) + testServer.httpServer = &http.Server{ + Handler: testServer.s, + TLSConfig: GetServerTLSConfig(t), + } + + require.NoError(t, http2.ConfigureServer(testServer.httpServer, nil)) + go func() { - if stopErr := testServer.s.Serve(testServer.lis); stopErr != nil && !errors.Is(stopErr, grpc.ErrServerStopped) { - require.NoError(t, err) + if stopErr := testServer.httpServer.ServeTLS(testServer.lis, "", ""); stopErr != nil && !errors.Is(stopErr, http.ErrServerClosed) { + assert.NoError(t, stopErr) } + + t.Logf("restarted server stopped") }() - t.Cleanup(testServer.s.Stop) + t.Cleanup(func() { + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + + assert.NoError(t, testServer.httpServer.Shutdown(shutdownCtx)) + }) } func TestServerAPI(t *testing.T) { @@ -144,7 +186,7 @@ func TestServerAPI(t *testing.T) { addr := setupServer(t, 5000, "").address - conn, e := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + conn, e := grpc.NewClient(addr, grpc.WithTransportCredentials(credentials.NewTLS(GetClientTLSConfig(t)))) require.NoError(t, e) client := pb.NewClusterClient(conn) @@ -344,7 +386,7 @@ func TestValidation(t *testing.T) { addr := setupServer(t, 5000, "").address - conn, e := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + conn, e := grpc.NewClient(addr, grpc.WithTransportCredentials(credentials.NewTLS(GetClientTLSConfig(t)))) require.NoError(t, e) client := pb.NewClusterClient(conn) @@ -539,7 +581,7 @@ func TestServerRateLimit(t *testing.T) { addr := setupServer(t, 1, "").address - conn, e := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + conn, e := grpc.NewClient(addr, grpc.WithTransportCredentials(credentials.NewTLS(GetClientTLSConfig(t)))) require.NoError(t, e) client := pb.NewClusterClient(conn) @@ -553,7 +595,7 @@ func TestServerRedirect(t *testing.T) { addr := setupServer(t, 1, "new.example.com:443").address - conn, e := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + conn, e := grpc.NewClient(addr, grpc.WithTransportCredentials(credentials.NewTLS(GetClientTLSConfig(t)))) require.NoError(t, e) client := pb.NewClusterClient(conn) diff --git a/pkg/service/service.go b/pkg/service/service.go index f7488b6..d6a3d27 100644 --- a/pkg/service/service.go +++ b/pkg/service/service.go @@ -33,6 +33,7 @@ import ( "google.golang.org/grpc/keepalive" "google.golang.org/grpc/status" + "github.com/siderolabs/discovery-service/internal/grpclog" "github.com/siderolabs/discovery-service/internal/landing" "github.com/siderolabs/discovery-service/internal/limiter" "github.com/siderolabs/discovery-service/internal/state" @@ -83,7 +84,7 @@ func newGRPCServer(ctx context.Context, state *state.State, options Options, log serverOptions := []grpc.ServerOption{ grpc.ChainUnaryInterceptor( server.AddLoggingFieldsUnaryServerInterceptor(), - logging.UnaryServerInterceptor(interceptorLogger(logger), loggingOpts...), + logging.UnaryServerInterceptor(grpclog.Adapter(logger), loggingOpts...), server.RateLimitUnaryServerInterceptor(limiter), metrics.UnaryServerInterceptor(), grpc_recovery.UnaryServerInterceptor(recoveryOpt), @@ -91,7 +92,7 @@ func newGRPCServer(ctx context.Context, state *state.State, options Options, log grpc.ChainStreamInterceptor( server.AddLoggingFieldsStreamServerInterceptor(), server.RateLimitStreamServerInterceptor(limiter), - logging.StreamServerInterceptor(interceptorLogger(logger), loggingOpts...), + logging.StreamServerInterceptor(grpclog.Adapter(logger), loggingOpts...), metrics.StreamServerInterceptor(), grpc_recovery.StreamServerInterceptor(recoveryOpt), ), @@ -325,43 +326,6 @@ func recoveryHandler(logger *zap.Logger) grpc_recovery.RecoveryHandlerFunc { } } -func interceptorLogger(l *zap.Logger) logging.Logger { - return logging.LoggerFunc(func(_ context.Context, lvl logging.Level, msg string, fields ...any) { - f := make([]zap.Field, 0, len(fields)/2) - - for i := 0; i < len(fields); i += 2 { - key := fields[i].(string) //nolint:forcetypeassert,errcheck - value := fields[i+1] - - switch v := value.(type) { - case string: - f = append(f, zap.String(key, v)) - case int: - f = append(f, zap.Int(key, v)) - case bool: - f = append(f, zap.Bool(key, v)) - default: - f = append(f, zap.Any(key, v)) - } - } - - logger := l.WithOptions(zap.AddCallerSkip(1)).With(f...) - - switch lvl { - case logging.LevelDebug: - logger.Debug(msg) - case logging.LevelInfo: - logger.Info(msg) - case logging.LevelWarn: - logger.Warn(msg) - case logging.LevelError: - logger.Error(msg) - default: - panic(fmt.Sprintf("unknown level %v", lvl)) - } - }) -} - func unregisterCollectors(registerer prom.Registerer, collectors ...prom.Collector) { for _, collector := range collectors { if collector == nil {