Skip to content

Commit

Permalink
fix: properly parse peer address
Browse files Browse the repository at this point in the history
After switch to Go's http/server, the peer address comes wrapped, so use
a different method to unwrap it.

The tests haven't caught that, as they were using gRPC's server, so
switch tests to use same approach as production, ans enable HTTP/2 over
TLS, as otherwise h2c is a mess, and it doesn't abort connections
properly for test purposes.

Signed-off-by: Andrey Smirnov <[email protected]>
  • Loading branch information
smira committed Sep 26, 2024
1 parent cf39974 commit efbb10b
Show file tree
Hide file tree
Showing 8 changed files with 219 additions and 70 deletions.
3 changes: 1 addition & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@ require (
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10
github.com/prometheus/client_golang v1.20.4
github.com/siderolabs/discovery-api v0.1.4
github.com/siderolabs/discovery-client v0.1.9
github.com/siderolabs/discovery-client v0.1.10
github.com/siderolabs/gen v0.5.0
github.com/siderolabs/go-debug v0.4.0
github.com/stretchr/testify v1.9.0
go.uber.org/zap v1.27.0
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
golang.org/x/net v0.28.0
golang.org/x/sync v0.8.0
golang.org/x/time v0.6.0
Expand Down
6 changes: 2 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjR
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/siderolabs/discovery-api v0.1.4 h1:2fMEFSMiWaD1zDiBDY5md8VxItvL1rDQRSOfeXNjYKc=
github.com/siderolabs/discovery-api v0.1.4/go.mod h1:kaBy+G42v2xd/uAF/NIe383sjNTBE2AhxPTyi9SZI0s=
github.com/siderolabs/discovery-client v0.1.9 h1:yDzvts++Nf/2qczdDUfU5GAibkEIgz/eo9RPG/k/rOc=
github.com/siderolabs/discovery-client v0.1.9/go.mod h1:Ew1z07eyJwqNwum84IKYH4S649KEKK5WUmRW49HlXS8=
github.com/siderolabs/discovery-client v0.1.10 h1:bTAvFLiISSzVXyYL1cIgAz8cPYd9ZfvhxwdebgtxARA=
github.com/siderolabs/discovery-client v0.1.10/go.mod h1:Ew1z07eyJwqNwum84IKYH4S649KEKK5WUmRW49HlXS8=
github.com/siderolabs/gen v0.5.0 h1:Afdjx+zuZDf53eH5DB+E+T2JeCwBXGinV66A6osLgQI=
github.com/siderolabs/gen v0.5.0/go.mod h1:1GUMBNliW98Xeq8GPQeVMYqQE09LFItE8enR3wgMh3Q=
github.com/siderolabs/go-debug v0.4.0 h1:pbFt6Rzumm90s3GvbRer7yIxFNc0gQ94I53omkqswHA=
Expand All @@ -56,8 +56,6 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M=
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y=
golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE=
golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
Expand Down
53 changes: 53 additions & 0 deletions internal/grpclog/grpclog.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright (c) 2024 Sidero Labs, Inc.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.

// Package grpclog provides a logger that logs to a zap logger.
package grpclog

import (
"context"
"fmt"

"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging"
"go.uber.org/zap"
)

// Adapter returns a logging.Logger that logs to the provided zap logger.
func Adapter(l *zap.Logger) logging.Logger {
return logging.LoggerFunc(func(_ context.Context, lvl logging.Level, msg string, fields ...any) {
f := make([]zap.Field, 0, len(fields)/2)

for i := 0; i < len(fields); i += 2 {
key := fields[i].(string) //nolint:forcetypeassert,errcheck
value := fields[i+1]

switch v := value.(type) {
case string:
f = append(f, zap.String(key, v))
case int:
f = append(f, zap.Int(key, v))
case bool:
f = append(f, zap.Bool(key, v))
default:
f = append(f, zap.Any(key, v))
}
}

logger := l.WithOptions(zap.AddCallerSkip(1)).With(f...)

switch lvl {
case logging.LevelDebug:
logger.Debug(msg)
case logging.LevelInfo:
logger.Info(msg)
case logging.LevelWarn:
logger.Warn(msg)
case logging.LevelError:
logger.Error(msg)
default:
panic(fmt.Sprintf("unknown level %v", lvl))
}
})
}
8 changes: 2 additions & 6 deletions pkg/server/addr.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@ package server

import (
"context"
"net"
"net/netip"

"go4.org/netipx"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
)
Expand Down Expand Up @@ -38,10 +36,8 @@ func PeerAddress(ctx context.Context) netip.Addr {
}

if peer, ok := peer.FromContext(ctx); ok {
if addr, ok := peer.Addr.(*net.TCPAddr); ok {
if ip, ok := netipx.FromStdIP(addr.IP); ok {
return ip
}
if addrPort, err := netip.ParseAddrPort(peer.Addr.String()); err == nil {
return addrPort.Addr()
}
}

Expand Down
97 changes: 97 additions & 0 deletions pkg/server/cert_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright (c) 2024 Sidero Labs, Inc.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.

package server_test

import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"net"
"sync"
"testing"
"time"

"github.com/stretchr/testify/require"
)

var onceCert = sync.OnceValues(func() (*tls.Certificate, error) {
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, err
}

keyUsage := x509.KeyUsageDigitalSignature
notBefore := time.Now()
notAfter := notBefore.Add(time.Hour)

serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return nil, err
}

template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"SideroLabs Testing"},
},
NotBefore: notBefore,
NotAfter: notAfter,

KeyUsage: keyUsage,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
}

derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil {
return nil, err
}

cert := tls.Certificate{
Certificate: [][]byte{derBytes},
PrivateKey: priv,
}

cert.Leaf, err = x509.ParseCertificate(derBytes)
if err != nil {
return nil, err
}

return &cert, nil
})

func GetServerTLSConfig(t *testing.T) *tls.Config {
t.Helper()

cert, err := onceCert()
require.NoError(t, err)

return &tls.Config{
Certificates: []tls.Certificate{*cert},
MinVersion: tls.VersionTLS12,
}
}

func GetClientTLSConfig(t *testing.T) *tls.Config {
t.Helper()

cert, err := onceCert()
require.NoError(t, err)

certPool := x509.NewCertPool()
certPool.AddCert(cert.Leaf)

return &tls.Config{
RootCAs: certPool,
MinVersion: tls.VersionTLS12,
}
}
14 changes: 7 additions & 7 deletions pkg/server/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func TestClient(t *testing.T) {
ClusterID: clusterID,
AffiliateID: affiliate1,
TTL: time.Minute,
Insecure: true,
TLSConfig: GetClientTLSConfig(t),
})
require.NoError(t, err)

Expand All @@ -70,7 +70,7 @@ func TestClient(t *testing.T) {
ClusterID: clusterID,
AffiliateID: affiliate2,
TTL: time.Minute,
Insecure: true,
TLSConfig: GetClientTLSConfig(t),
})
require.NoError(t, err)

Expand Down Expand Up @@ -239,7 +239,7 @@ func TestClient(t *testing.T) {
ClusterID: clusterID,
AffiliateID: affiliate1,
TTL: time.Second,
Insecure: true,
TLSConfig: GetClientTLSConfig(t),
})
require.NoError(t, err)

Expand All @@ -249,7 +249,7 @@ func TestClient(t *testing.T) {
ClusterID: clusterID,
AffiliateID: affiliate2,
TTL: time.Minute,
Insecure: true,
TLSConfig: GetClientTLSConfig(t),
})
require.NoError(t, err)

Expand Down Expand Up @@ -392,7 +392,7 @@ func clusterSimulator(t *testing.T, endpoint string, logger *zap.Logger, numAffi
AffiliateID: fmt.Sprintf("affiliate-%d", i),
ClientVersion: "v0.0.1",
TTL: 10 * time.Second,
Insecure: true,
TLSConfig: GetClientTLSConfig(t),
})
require.NoError(t, err)
}
Expand Down Expand Up @@ -557,7 +557,7 @@ func TestClientRedirect(t *testing.T) {
ClusterID: clusterID,
AffiliateID: affiliate1,
TTL: time.Minute,
Insecure: true,
TLSConfig: GetClientTLSConfig(t),
})
require.NoError(t, err)

Expand All @@ -567,7 +567,7 @@ func TestClientRedirect(t *testing.T) {
ClusterID: clusterID,
AffiliateID: affiliate2,
TTL: time.Minute,
Insecure: true,
TLSConfig: GetClientTLSConfig(t),
})
require.NoError(t, err)

Expand Down
Loading

0 comments on commit efbb10b

Please sign in to comment.