From 3f3b2053c118a651b13c4fce4654eaa2114f8462 Mon Sep 17 00:00:00 2001 From: Marcos Yacob Date: Sat, 14 Sep 2024 14:51:29 -0300 Subject: [PATCH] Force rotation intermediate and Server SVIDs (#5431) * Force rotation of intermediates signed by a compromised authority * Force rotation of Server SVIDs signed by a compromised authority * Force rotation of server SVIDs when not using an upstream authority Signed-off-by: Marcos Yacob --- pkg/agent/agent.go | 2 +- pkg/agent/manager/cache/lru_cache.go | 2 +- pkg/agent/manager/manager.go | 2 +- pkg/agent/svid/rotator.go | 2 +- pkg/agent/svid/rotator_config.go | 2 +- pkg/{agent => }/common/backoff/backoff.go | 0 .../common/backoff/backoff_test.go | 0 .../common/backoff/size_backoff.go | 0 .../common/backoff/size_backoff_test.go | 0 pkg/server/api/localauthority/v1/service.go | 5 + .../api/localauthority/v1/service_test.go | 48 ++++ pkg/server/ca/ca.go | 24 +- pkg/server/ca/ca_test.go | 17 ++ pkg/server/ca/manager/manager.go | 227 ++++++++++++++++-- pkg/server/ca/manager/manager_test.go | 192 ++++++++++++++- pkg/server/ca/rotator/rotator.go | 4 +- pkg/server/ca/rotator/rotator_test.go | 2 +- pkg/server/svid/rotator.go | 51 +++- pkg/server/svid/rotator_test.go | 85 ++++++- test/fakes/fakeserverca/serverca.go | 8 + 20 files changed, 633 insertions(+), 40 deletions(-) rename pkg/{agent => }/common/backoff/backoff.go (100%) rename pkg/{agent => }/common/backoff/backoff_test.go (100%) rename pkg/{agent => }/common/backoff/size_backoff.go (100%) rename pkg/{agent => }/common/backoff/size_backoff_test.go (100%) diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index 0d11ca29e0..1c7dc9050c 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -17,13 +17,13 @@ import ( node_attestor "github.com/spiffe/spire/pkg/agent/attestor/node" workload_attestor "github.com/spiffe/spire/pkg/agent/attestor/workload" "github.com/spiffe/spire/pkg/agent/catalog" - "github.com/spiffe/spire/pkg/agent/common/backoff" "github.com/spiffe/spire/pkg/agent/endpoints" "github.com/spiffe/spire/pkg/agent/manager" "github.com/spiffe/spire/pkg/agent/manager/storecache" "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor" "github.com/spiffe/spire/pkg/agent/storage" "github.com/spiffe/spire/pkg/agent/svid/store" + "github.com/spiffe/spire/pkg/common/backoff" "github.com/spiffe/spire/pkg/common/diskutil" "github.com/spiffe/spire/pkg/common/health" "github.com/spiffe/spire/pkg/common/nodeutil" diff --git a/pkg/agent/manager/cache/lru_cache.go b/pkg/agent/manager/cache/lru_cache.go index 8d01fd07bd..eb5b4e5140 100644 --- a/pkg/agent/manager/cache/lru_cache.go +++ b/pkg/agent/manager/cache/lru_cache.go @@ -11,7 +11,7 @@ import ( "github.com/sirupsen/logrus" "github.com/spiffe/go-spiffe/v2/bundle/spiffebundle" "github.com/spiffe/go-spiffe/v2/spiffeid" - "github.com/spiffe/spire/pkg/agent/common/backoff" + "github.com/spiffe/spire/pkg/common/backoff" "github.com/spiffe/spire/pkg/common/telemetry" agentmetrics "github.com/spiffe/spire/pkg/common/telemetry/agent" "github.com/spiffe/spire/proto/spire/common" diff --git a/pkg/agent/manager/manager.go b/pkg/agent/manager/manager.go index b0dec8290c..3294b1c2fc 100644 --- a/pkg/agent/manager/manager.go +++ b/pkg/agent/manager/manager.go @@ -13,11 +13,11 @@ import ( "github.com/spiffe/go-spiffe/v2/bundle/spiffebundle" "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/agent/client" - "github.com/spiffe/spire/pkg/agent/common/backoff" "github.com/spiffe/spire/pkg/agent/manager/cache" "github.com/spiffe/spire/pkg/agent/manager/storecache" "github.com/spiffe/spire/pkg/agent/storage" "github.com/spiffe/spire/pkg/agent/svid" + "github.com/spiffe/spire/pkg/common/backoff" "github.com/spiffe/spire/pkg/common/nodeutil" "github.com/spiffe/spire/pkg/common/rotationutil" "github.com/spiffe/spire/pkg/common/telemetry" diff --git a/pkg/agent/svid/rotator.go b/pkg/agent/svid/rotator.go index fd532c44c8..434230d854 100644 --- a/pkg/agent/svid/rotator.go +++ b/pkg/agent/svid/rotator.go @@ -14,8 +14,8 @@ import ( agentv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/server/agent/v1" node_attestor "github.com/spiffe/spire/pkg/agent/attestor/node" "github.com/spiffe/spire/pkg/agent/client" - "github.com/spiffe/spire/pkg/agent/common/backoff" "github.com/spiffe/spire/pkg/agent/plugin/keymanager" + "github.com/spiffe/spire/pkg/common/backoff" "github.com/spiffe/spire/pkg/common/nodeutil" "github.com/spiffe/spire/pkg/common/rotationutil" "github.com/spiffe/spire/pkg/common/telemetry" diff --git a/pkg/agent/svid/rotator_config.go b/pkg/agent/svid/rotator_config.go index 6eb4b0538d..203c194ec0 100644 --- a/pkg/agent/svid/rotator_config.go +++ b/pkg/agent/svid/rotator_config.go @@ -11,10 +11,10 @@ import ( "github.com/sirupsen/logrus" "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/agent/client" - "github.com/spiffe/spire/pkg/agent/common/backoff" "github.com/spiffe/spire/pkg/agent/manager/cache" "github.com/spiffe/spire/pkg/agent/plugin/keymanager" "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor" + "github.com/spiffe/spire/pkg/common/backoff" "github.com/spiffe/spire/pkg/common/rotationutil" "github.com/spiffe/spire/pkg/common/telemetry" ) diff --git a/pkg/agent/common/backoff/backoff.go b/pkg/common/backoff/backoff.go similarity index 100% rename from pkg/agent/common/backoff/backoff.go rename to pkg/common/backoff/backoff.go diff --git a/pkg/agent/common/backoff/backoff_test.go b/pkg/common/backoff/backoff_test.go similarity index 100% rename from pkg/agent/common/backoff/backoff_test.go rename to pkg/common/backoff/backoff_test.go diff --git a/pkg/agent/common/backoff/size_backoff.go b/pkg/common/backoff/size_backoff.go similarity index 100% rename from pkg/agent/common/backoff/size_backoff.go rename to pkg/common/backoff/size_backoff.go diff --git a/pkg/agent/common/backoff/size_backoff_test.go b/pkg/common/backoff/size_backoff_test.go similarity index 100% rename from pkg/agent/common/backoff/size_backoff_test.go rename to pkg/common/backoff/size_backoff_test.go diff --git a/pkg/server/api/localauthority/v1/service.go b/pkg/server/api/localauthority/v1/service.go index 97f8af987d..e2d66302c1 100644 --- a/pkg/server/api/localauthority/v1/service.go +++ b/pkg/server/api/localauthority/v1/service.go @@ -33,6 +33,7 @@ type CAManager interface { RotateX509CA(ctx context.Context) IsUpstreamAuthority() bool + NotifyTaintedX509Authority(ctx context.Context, authorityID string) error } // RegisterService registers the service on the gRPC server. @@ -367,6 +368,10 @@ func (s *Service) TaintX509Authority(ctx context.Context, req *localauthorityv1. AuthorityId: nextSlot.AuthorityID(), } + if err := s.ca.NotifyTaintedX509Authority(ctx, nextSlot.AuthorityID()); err != nil { + return nil, api.MakeErr(log, codes.Internal, "failed to notify tainted authority", err) + } + rpccontext.AuditRPC(ctx) log.Info("X.509 authority tainted successfully") diff --git a/pkg/server/api/localauthority/v1/service_test.go b/pkg/server/api/localauthority/v1/service_test.go index badce0565f..adf8ec12be 100644 --- a/pkg/server/api/localauthority/v1/service_test.go +++ b/pkg/server/api/localauthority/v1/service_test.go @@ -27,6 +27,7 @@ import ( "github.com/spiffe/spire/test/testca" "github.com/spiffe/spire/test/testkey" testutil "github.com/spiffe/spire/test/util" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -1324,6 +1325,7 @@ func TestTaintX509Authority(t *testing.T) { keyToTaint string customRootCAs []*common.Certificate isUpstreamAuthority bool + notifyTaintedErr error expectLogs []spiretest.LogEntry expectCode codes.Code @@ -1537,6 +1539,36 @@ func TestTaintX509Authority(t *testing.T) { }, }, }, + { + name: "fail to notify tainted authority", + currentSlot: createSlot(journal.Status_ACTIVE, currentAuthorityID, currentKey.Public(), notAfterCurrent), + nextSlot: createSlot(journal.Status_OLD, nextAuthorityID, nextKey.Public(), notAfterNext), + keyToTaint: nextAuthorityID, + notifyTaintedErr: errors.New("oh no"), + expectCode: codes.Internal, + expectMsg: "failed to notify tainted authority: oh no", + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.ErrorLevel, + Message: "Failed to notify tainted authority", + Data: logrus.Fields{ + telemetry.LocalAuthorityID: nextAuthorityID, + logrus.ErrorKey: "oh no", + }, + }, + { + Level: logrus.InfoLevel, + Message: "API accessed", + Data: logrus.Fields{ + telemetry.Status: "error", + telemetry.StatusCode: "Internal", + telemetry.StatusMessage: "failed to notify tainted authority: oh no", + telemetry.Type: "audit", + telemetry.LocalAuthorityID: nextAuthorityID, + }, + }, + }, + }, } { t.Run(tt.name, func(t *testing.T) { test := setupServiceTest(t) @@ -1545,6 +1577,7 @@ func TestTaintX509Authority(t *testing.T) { test.ca.currentX509CASlot = tt.currentSlot test.ca.nextX509CASlot = tt.nextSlot test.ca.isUpstreamAuthority = tt.isUpstreamAuthority + test.ca.notifyTaintedExpectErr = tt.notifyTaintedErr rootCAs := defaultRootCAs if tt.customRootCAs != nil { @@ -1564,6 +1597,10 @@ func TestTaintX509Authority(t *testing.T) { spiretest.AssertGRPCStatusHasPrefix(t, err, tt.expectCode, tt.expectMsg) spiretest.AssertProtoEqual(t, tt.expectResp, resp) spiretest.AssertLogs(t, test.logHook.AllEntries(), tt.expectLogs) + // Validate notification is received on success test cases + if tt.expectMsg == "" { + assert.Equal(t, tt.keyToTaint, test.ca.notifyTaintedAuthorityID) + } }) } } @@ -2445,6 +2482,17 @@ type fakeCAManager struct { prepareX509CAErr error isUpstreamAuthority bool + + notifyTaintedExpectErr error + notifyTaintedAuthorityID string +} + +func (m *fakeCAManager) NotifyTaintedX509Authority(ctx context.Context, authorityID string) error { + if m.notifyTaintedExpectErr != nil { + return m.notifyTaintedExpectErr + } + m.notifyTaintedAuthorityID = authorityID + return nil } func (m *fakeCAManager) IsUpstreamAuthority() bool { diff --git a/pkg/server/ca/ca.go b/pkg/server/ca/ca.go index 8b52ae1e98..673ca817d1 100644 --- a/pkg/server/ca/ca.go +++ b/pkg/server/ca/ca.go @@ -36,6 +36,7 @@ type ServerCA interface { SignAgentX509SVID(ctx context.Context, params AgentX509SVIDParams) ([]*x509.Certificate, error) SignWorkloadX509SVID(ctx context.Context, params WorkloadX509SVIDParams) ([]*x509.Certificate, error) SignWorkloadJWTSVID(ctx context.Context, params WorkloadJWTSVIDParams) (string, error) + TaintedAuthorities() <-chan []*x509.Certificate } // DownstreamX509CAParams are parameters relevant to downstream X.509 CA creation @@ -133,10 +134,11 @@ type Config struct { type CA struct { c Config - mu sync.RWMutex - x509CA *X509CA - x509CAChain []*x509.Certificate - jwtKey *JWTKey + mu sync.RWMutex + x509CA *X509CA + x509CAChain []*x509.Certificate + jwtKey *JWTKey + taintedAuthoritiesCh chan []*x509.Certificate } func NewCA(config Config) *CA { @@ -146,6 +148,9 @@ func NewCA(config Config) *CA { ca := &CA{ c: config, + + // Notify caller about any tainted authority + taintedAuthoritiesCh: make(chan []*x509.Certificate, 1), } _ = config.HealthChecker.AddCheck("server.ca", &caHealth{ @@ -188,6 +193,17 @@ func (ca *CA) SetJWTKey(jwtKey *JWTKey) { ca.jwtKey = jwtKey } +func (ca *CA) NotifyTaintedX509Authorities(taintedAuthorities []*x509.Certificate) { + select { + case ca.taintedAuthoritiesCh <- taintedAuthorities: + default: + } +} + +func (ca *CA) TaintedAuthorities() <-chan []*x509.Certificate { + return ca.taintedAuthoritiesCh +} + func (ca *CA) SignDownstreamX509CA(ctx context.Context, params DownstreamX509CAParams) ([]*x509.Certificate, error) { x509CA, caChain, err := ca.getX509CA() if err != nil { diff --git a/pkg/server/ca/ca_test.go b/pkg/server/ca/ca_test.go index 6472f22d3a..6dcab1a468 100644 --- a/pkg/server/ca/ca_test.go +++ b/pkg/server/ca/ca_test.go @@ -436,6 +436,23 @@ func (s *CATestSuite) TestNoJWTKeySet() { s.Require().EqualError(err, "JWT key is not available for signing") } +func (s *CATestSuite) TestTaintedAuthoritiesArePropagated() { + authorities := []*x509.Certificate{ + {Raw: []byte("foh")}, + {Raw: []byte("bar")}, + } + s.ca.NotifyTaintedX509Authorities(authorities) + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + select { + case got := <-s.ca.TaintedAuthorities(): + s.Require().Equal(authorities, got) + case <-ctx.Done(): + s.Fail("no notification received") + } +} + func (s *CATestSuite) TestSignWorkloadJWTSVIDUsesDefaultTTLIfTTLUnspecified() { token, err := s.ca.SignWorkloadJWTSVID(ctx, s.createJWTSVIDParams(trustDomainExample, 0)) s.Require().NoError(err) diff --git a/pkg/server/ca/manager/manager.go b/pkg/server/ca/manager/manager.go index 393fd14cee..e67530ef13 100644 --- a/pkg/server/ca/manager/manager.go +++ b/pkg/server/ca/manager/manager.go @@ -6,6 +6,7 @@ import ( "crypto" "crypto/rand" "crypto/x509" + "errors" "fmt" "sync" "time" @@ -13,6 +14,7 @@ import ( "github.com/andres-erbsen/clock" "github.com/sirupsen/logrus" "github.com/spiffe/go-spiffe/v2/spiffeid" + "github.com/spiffe/spire/pkg/common/backoff" "github.com/spiffe/spire/pkg/common/coretypes/x509certificate" "github.com/spiffe/spire/pkg/common/telemetry" telemetry_server "github.com/spiffe/spire/pkg/common/telemetry/server" @@ -43,11 +45,15 @@ const ( sevenDays = 7 * 24 * time.Hour activationThresholdCap = sevenDays activationThresholdDivisor = 6 + + taintBackoffInterval = 5 * time.Second + taintBackoffMaxElapsedTime = 1 * time.Minute ) type ManagedCA interface { SetX509CA(*ca.X509CA) SetJWTKey(*ca.JWTKey) + NotifyTaintedX509Authorities([]*x509.Certificate) } type JwtKeyPublisher interface { @@ -65,6 +71,7 @@ type AuthorityManager interface { RotateX509CA(ctx context.Context) IsUpstreamAuthority() bool PublishJWTKey(ctx context.Context, jwtKey *common.PublicKey) ([]*common.PublicKey, error) + NotifyTaintedX509Authority(ctx context.Context, authorityID string) error } type Config struct { @@ -82,11 +89,12 @@ type Config struct { } type Manager struct { - c Config - caTTL time.Duration - bundleUpdatedCh chan struct{} - upstreamClient *ca.UpstreamClient - upstreamPluginName string + c Config + caTTL time.Duration + bundleUpdatedCh chan struct{} + taintedUpstreamAuthoritiesCh chan []*x509.Certificate + upstreamClient *ca.UpstreamClient + upstreamPluginName string currentX509CA *x509CASlot nextX509CA *x509CASlot @@ -100,6 +108,9 @@ type Manager struct { // Used to log a warning only once when the UpstreamAuthority does not support JWT-SVIDs. jwtUnimplementedWarnOnce sync.Once + + // Used for testing backoff, must not be set in regular code + triggerBackOffCh chan error } func NewManager(ctx context.Context, c Config) (*Manager, error) { @@ -108,19 +119,22 @@ func NewManager(ctx context.Context, c Config) (*Manager, error) { } m := &Manager{ - c: c, - caTTL: c.CredBuilder.Config().X509CATTL, - bundleUpdatedCh: make(chan struct{}, 1), + c: c, + caTTL: c.CredBuilder.Config().X509CATTL, + bundleUpdatedCh: make(chan struct{}, 1), + taintedUpstreamAuthoritiesCh: make(chan []*x509.Certificate, 1), } if upstreamAuthority, ok := c.Catalog.GetUpstreamAuthority(); ok { m.upstreamClient = ca.NewUpstreamClient(ca.UpstreamClientConfig{ UpstreamAuthority: upstreamAuthority, BundleUpdater: &bundleUpdater{ - log: c.Log, - trustDomainID: c.TrustDomain.IDString(), - ds: c.Catalog.GetDataStore(), - updated: m.bundleUpdated, + log: c.Log, + trustDomainID: c.TrustDomain.IDString(), + ds: c.Catalog.GetDataStore(), + updated: m.bundleUpdated, + upstreamAuthoritiesTainted: m.notifyUpstreamAuthoritiesTainted, + processedTaintedAuthorities: map[string]struct{}{}, }, }) m.upstreamPluginName = upstreamAuthority.Name() @@ -181,6 +195,16 @@ func (m *Manager) Close() { } } +func (m *Manager) NotifyTaintedX509Authority(ctx context.Context, authoirtyID string) error { + taintedAuthority, err := m.fetchRootCAByAuthorityID(ctx, authoirtyID) + if err != nil { + return err + } + + m.c.CA.NotifyTaintedX509Authorities([]*x509.Certificate{taintedAuthority}) + return nil +} + func (m *Manager) GetCurrentX509CASlot() Slot { m.x509CAMutex.RLock() defer m.x509CAMutex.RUnlock() @@ -461,13 +485,19 @@ func (m *Manager) PruneCAJournals(ctx context.Context) (err error) { return nil } -func (m *Manager) NotifyOnBundleUpdate(ctx context.Context) { +// ProcessBundleUpdates Notify any bundle update, or process tainted authorities +func (m *Manager) ProcessBundleUpdates(ctx context.Context) { for { select { case <-m.bundleUpdatedCh: if err := m.notifyBundleUpdated(ctx); err != nil { m.c.Log.WithError(err).Warn("Failed to notify on bundle update") } + case taintedAuthorities := <-m.taintedUpstreamAuthoritiesCh: + if err := m.notifyTaintedAuthorities(ctx, taintedAuthorities); err != nil { + m.c.Log.WithError(err).Error("Failed to force intermediate bundle rotation") + return + } case <-ctx.Done(): return } @@ -550,6 +580,114 @@ func (m *Manager) dropBundleUpdated() { default: } } + +func (m *Manager) notifyUpstreamAuthoritiesTainted(taintedAuthorities []*x509.Certificate) { + select { + case m.taintedUpstreamAuthoritiesCh <- taintedAuthorities: + default: + } +} + +func (m *Manager) fetchRootCAByAuthorityID(ctx context.Context, authorityID string) (*x509.Certificate, error) { + bundle, err := m.fetchRequiredBundle(ctx) + if err != nil { + return nil, err + } + + for _, rootCA := range bundle.RootCas { + if rootCA.TaintedKey { + cert, err := x509.ParseCertificate(rootCA.DerBytes) + if err != nil { + return nil, fmt.Errorf("failed to parse RootCA: %w", err) + } + + skID := x509util.SubjectKeyIDToString(cert.SubjectKeyId) + if authorityID == skID { + return cert, nil + } + } + } + + return nil, fmt.Errorf("no tainted root CA found with authority ID: %q", authorityID) +} + +func (m *Manager) notifyTaintedAuthorities(ctx context.Context, taintedAuthorities []*x509.Certificate) error { + taintBackoff := backoff.NewBackoff( + m.c.Clock, + taintBackoffInterval, + backoff.WithMaxElapsedTime(taintBackoffMaxElapsedTime), + ) + + for { + err := m.processTaintedUpstreamAuthorities(ctx, taintedAuthorities) + if err == nil { + break + } + + nextDuration := taintBackoff.NextBackOff() + if nextDuration == backoff.Stop { + return err + } + m.c.Log.WithError(err).Warn("Failed to process tainted keys on upstream authority") + if m.triggerBackOffCh != nil { + m.triggerBackOffCh <- err + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-m.c.Clock.After(nextDuration): + continue + } + } + + return nil +} + +func (m *Manager) processTaintedUpstreamAuthorities(ctx context.Context, taintedAuthorities []*x509.Certificate) error { + // Nothing to rotate if no upstream authority is used + if m.upstreamClient == nil { + return errors.New("processing of tainted upstream authorities must not be reached when not using an upstream authority; please report this bug") + } + + if len(taintedAuthorities) == 0 { + // No tainted keys found + return nil + } + + m.c.Log.Debug("Processing tainted keys on upstream authority") + + currentSlotCA := m.currentX509CA.x509CA + if ok := isX509AuthorityTainted(currentSlotCA, taintedAuthorities); ok { + m.c.Log.Info("Current root CA is signed by a tainted upstream authority, preparing rotation") + if ok := m.shouldPrepareX509CA(taintedAuthorities); ok { + if err := m.PrepareX509CA(ctx); err != nil { + return fmt.Errorf("failed to prepare x509 authority: %w", err) + } + } + + // Activate the prepared X.509 authority + m.RotateX509CA(ctx) + } + + // Now that we have rotated the intermediate, we can notify about the + // tainted authorities, so agents and downstream servers can start forcing + // the rotation of their SVIDs. + ds := m.c.Catalog.GetDataStore() + for _, each := range taintedAuthorities { + skID := x509util.SubjectKeyIDToString(each.SubjectKeyId) + if err := ds.TaintX509CA(ctx, m.c.TrustDomain.IDString(), skID); err != nil { + return fmt.Errorf("could not taint X509 CA in datastore: %w", err) + } + } + + // Intermediate is safe. Notify rotator to force rotation + // of tainted X.509 SVID. + m.c.CA.NotifyTaintedX509Authorities(taintedAuthorities) + + return nil +} + func (m *Manager) notifyBundleUpdated(ctx context.Context) error { var bundle *common.Bundle return m.notify(ctx, "bundle updated", false, @@ -712,6 +850,20 @@ func (m *Manager) appendBundle(ctx context.Context, caChain []*x509.Certificate, return res, nil } +func (m *Manager) shouldPrepareX509CA(taintedAuthorities []*x509.Certificate) bool { + slot := m.nextX509CA + switch { + case slot.IsEmpty(): + return true + case slot.Status() == journal.Status_PREPARED: + isTainted := isX509AuthorityTainted(slot.x509CA, taintedAuthorities) + m.c.Log.Info("Next authority is tainted, prepare new X.509 authority") + return isTainted + default: + return false + } +} + // MaxSVIDTTL returns the maximum SVID lifetime that can be guaranteed to not // be cut artificially short by a scheduled rotation. func MaxSVIDTTL() time.Duration { @@ -740,10 +892,12 @@ func MinCATTLForSVIDTTL(svidTTL time.Duration) time.Duration { } type bundleUpdater struct { - log logrus.FieldLogger - trustDomainID string - ds datastore.DataStore - updated func() + log logrus.FieldLogger + trustDomainID string + ds datastore.DataStore + updated func() + upstreamAuthoritiesTainted func([]*x509.Certificate) + processedTaintedAuthorities map[string]struct{} } func (u *bundleUpdater) SyncX509Roots(ctx context.Context, roots []*x509certificate.X509Authority) error { @@ -758,6 +912,7 @@ func (u *bundleUpdater) SyncX509Roots(ctx context.Context, roots []*x509certific } newAuthorities := make(map[string]struct{}, len(roots)) + var taintedAuthorities []*x509.Certificate for _, root := range roots { skID := x509util.SubjectKeyIDToString(root.Certificate.SubjectKeyId) // Collect all skIDs @@ -767,10 +922,13 @@ func (u *bundleUpdater) SyncX509Roots(ctx context.Context, roots []*x509certific if root.Tainted { // Taint x.509 authority, if required if found, ok := x509Authorities[skID]; ok && !found.Tainted { - if err := u.ds.TaintX509CA(ctx, u.trustDomainID, skID); err != nil { - return fmt.Errorf("failed to taint x.509 authority %q: %w", skID, err) + _, alreadyProcessed := u.processedTaintedAuthorities[skID] + if !alreadyProcessed { + u.processedTaintedAuthorities[skID] = struct{}{} + // Add to the list of new tainted authorities + taintedAuthorities = append(taintedAuthorities, found.Certificate) + u.log.WithField(telemetry.SubjectKeyID, skID).Info("X.509 authority tainted") } - u.log.WithField(telemetry.SubjectKeyID, skID).Info("X.509 authority tainted") // Prevent to add tainted keys, since status is updated before continue } @@ -782,6 +940,14 @@ func (u *bundleUpdater) SyncX509Roots(ctx context.Context, roots []*x509certific }) } + // Notify about tainted authorities to force the rotation of + // intermediates and update the database. This is done in a separate thread + // to prevent agents and downstream servers to start the rotation before the + // current server starts the rotation of the intermediate. + if len(taintedAuthorities) > 0 { + u.upstreamAuthoritiesTainted(taintedAuthorities) + } + for skID, authority := range x509Authorities { // Only tainted keys can ke revoked if authority.Tainted { @@ -895,3 +1061,24 @@ func publicKeyFromJWTKey(jwtKey *ca.JWTKey) (*common.PublicKey, error) { NotAfter: jwtKey.NotAfter.Unix(), }, nil } + +// isX509AuthorityTainted verifies if the provided X.509 authority is tainted +func isX509AuthorityTainted(x509CA *ca.X509CA, taintedAuthorities []*x509.Certificate) bool { + rootPool := x509.NewCertPool() + for _, taintedKey := range taintedAuthorities { + rootPool.AddCert(taintedKey) + } + + intermediatePool := x509.NewCertPool() + for _, intermediateCA := range x509CA.UpstreamChain { + intermediatePool.AddCert(intermediateCA) + } + + // Verify certificate chain, using tainted authority as root + _, err := x509CA.Certificate.Verify(x509.VerifyOptions{ + Intermediates: intermediatePool, + Roots: rootPool, + }) + + return err == nil +} diff --git a/pkg/server/ca/manager/manager_test.go b/pkg/server/ca/manager/manager_test.go index ba70f14817..3050da19fb 100644 --- a/pkg/server/ca/manager/manager_test.go +++ b/pkg/server/ca/manager/manager_test.go @@ -37,6 +37,7 @@ import ( "github.com/spiffe/spire/test/fakes/fakeserverkeymanager" "github.com/spiffe/spire/test/fakes/fakeupstreamauthority" "github.com/spiffe/spire/test/spiretest" + "github.com/spiffe/spire/test/testca" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" @@ -247,6 +248,65 @@ func TestSlotLoadedWhenJournalIsLost(t *testing.T) { require.True(t, test.m.GetCurrentX509CASlot().IsEmpty()) } +func TestNotifyTaintedX509Authority(t *testing.T) { + ctx := context.Background() + test := setupTest(t) + test.initSelfSignedManager() + + // Create a test CA + ca := testca.New(t, testTrustDomain) + cert := ca.X509Authorities()[0] + bundle, err := test.ds.CreateBundle(ctx, &common.Bundle{ + TrustDomainId: testTrustDomain.IDString(), + RootCas: []*common.Certificate{ + { + DerBytes: cert.Raw, + TaintedKey: true, + }, + }, + }) + require.NoError(t, err) + + t.Run("notify tainted authority", func(t *testing.T) { + err = test.m.NotifyTaintedX509Authority(ctx, ca.GetSubjectKeyID()) + require.NoError(t, err) + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() + + expectedTaintedAuthorities := []*x509.Certificate{cert} + select { + case taintedAuthorities := <-test.ca.taintedAuthoritiesCh: + require.Equal(t, expectedTaintedAuthorities, taintedAuthorities) + case <-ctx.Done(): + assert.Fail(t, "no notification received") + } + }) + + // Untaint authority + bundle.RootCas[0].TaintedKey = false + bundle, err = test.ds.UpdateBundle(ctx, bundle, nil) + require.NoError(t, err) + + t.Run("no tainted authority", func(t *testing.T) { + err := test.m.NotifyTaintedX509Authority(ctx, ca.GetSubjectKeyID()) + + expectedErr := fmt.Sprintf("no tainted root CA found with authority ID: %q", ca.GetSubjectKeyID()) + require.EqualError(t, err, expectedErr) + }) + + bundle.RootCas = append(bundle.RootCas, &common.Certificate{ + DerBytes: []byte("foh"), + TaintedKey: true, + }) + _, err = test.ds.UpdateBundle(ctx, bundle, nil) + require.NoError(t, err) + + t.Run("malformed root CA", func(t *testing.T) { + err := test.m.NotifyTaintedX509Authority(ctx, ca.GetSubjectKeyID()) + require.EqualError(t, err, "failed to parse RootCA: x509: malformed certificate") + }) +} + func TestSelfSigning(t *testing.T) { ctx := context.Background() test := setupTest(t) @@ -307,12 +367,128 @@ func TestUpstreamSigned(t *testing.T) { x509Roots := fakeUA.X509Roots() require.True(t, x509Roots[0].Tainted) + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() + + select { + case taintedAuthorities := <-test.m.taintedUpstreamAuthoritiesCh: + expectedTaintedAuthorities := []*x509.Certificate{x509Roots[0].Certificate} + require.Equal(t, expectedTaintedAuthorities, taintedAuthorities) + case <-ctx.Done(): + assert.Fail(t, "no notification received") + } +} + +func TestUpstreamProcesssTaintedAuthority(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + test := setupTest(t) + + upstreamAuthority, fakeUA := fakeupstreamauthority.Load(t, fakeupstreamauthority.Config{ + TrustDomain: testTrustDomain, + DisallowPublishJWTKey: true, + }) + + test.initAndActivateUpstreamSignedManager(ctx, upstreamAuthority) + require.True(t, test.m.IsUpstreamAuthority()) + + // Prepared must be tainted too + err := test.m.PrepareX509CA(ctx) + require.NoError(t, err) + + go test.m.ProcessBundleUpdates(ctx) + + // Taint first root + err = fakeUA.TaintAuthority(0) + require.NoError(t, err) + + // Get the roots again and verify that the first X.509 authority is tainted + x509Roots := fakeUA.X509Roots() + require.True(t, x509Roots[0].Tainted) + + commonCertificates := x509certificate.RequireToCommonProtos(x509Roots) + // Retry until the Tainted attribute is propagated to the database + require.Eventually(t, func() bool { + bundle := test.fetchBundle(ctx) + return spiretest.AssertProtoListEqual(t, commonCertificates, bundle.RootCas) + }, time.Minute, 500*time.Millisecond) + + expectedTaintedAuthorities := []*x509.Certificate{x509Roots[0].Certificate} + select { + case received := <-test.ca.taintedAuthoritiesCh: + require.Equal(t, expectedTaintedAuthorities, received) + case <-ctx.Done(): + assert.Fail(t, "deadline reached") + } +} + +func TestUpstreamProcesssTaintedAuthorityBackoff(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + test := setupTest(t) + + upstreamAuthority, fakeUA := fakeupstreamauthority.Load(t, fakeupstreamauthority.Config{ + TrustDomain: testTrustDomain, + DisallowPublishJWTKey: true, + }) + + test.initAndActivateUpstreamSignedManager(ctx, upstreamAuthority) + require.True(t, test.m.IsUpstreamAuthority()) + + test.m.triggerBackOffCh = make(chan error, 1) + + // Prepared must be tainted too + go test.m.ProcessBundleUpdates(ctx) + + // Set an invalid key type to make prepare fails + test.m.c.X509CAKeyType = 123 + err := test.m.PrepareX509CA(ctx) + require.Error(t, err) + + // Taint first root + err = fakeUA.TaintAuthority(0) + require.NoError(t, err) + + // Get the roots again and verify that the first X.509 authority is tainted + x509Roots := fakeUA.X509Roots() + require.True(t, x509Roots[0].Tainted) + + expectBackoffErr := func(t *testing.T) { + select { + case receivedErr := <-test.m.triggerBackOffCh: + require.EqualError(t, receivedErr, "failed to prepare x509 authority: rpc error: code = Internal desc = keymanager(fake): facade does not support key type \"UNKNOWN(123)\"") + case <-ctx.Done(): + assert.Fail(t, "deadline reached") + } + } + + // Must fail due to the invalid key type + expectBackoffErr(t) + + // Try again; expect to fail + test.clock.Add(6 * time.Second) + expectBackoffErr(t) + + // Restore to a valid key type, and advance time again + test.m.c.X509CAKeyType = keymanager.ECP256 + test.clock.Add(10 * time.Second) + commonCertificates := x509certificate.RequireToCommonProtos(x509Roots) // Retry until the Tainted attribute is propagated to the database require.Eventually(t, func() bool { bundle := test.fetchBundle(ctx) return spiretest.AssertProtoListEqual(t, commonCertificates, bundle.RootCas) }, time.Minute, 500*time.Millisecond) + + expectedTaintedAuthorities := []*x509.Certificate{x509Roots[0].Certificate} + select { + case received := <-test.ca.taintedAuthoritiesCh: + require.Equal(t, expectedTaintedAuthorities, received) + case <-ctx.Done(): + assert.Fail(t, "deadline reached") + } } func TestGetCurrentX509CASlotUpstreamSigned(t *testing.T) { @@ -470,7 +646,7 @@ func TestX509CARotation(t *testing.T) { // kick off a goroutine to service bundle update notifications. This is // typically handled by Run() but using it would complicate the test. test.m.dropBundleUpdated() - go test.m.NotifyOnBundleUpdate(ctx) + go test.m.ProcessBundleUpdates(ctx) // after initialization, we should have a current X509CA but no next. first := test.currentX509CA() @@ -560,7 +736,7 @@ func TestJWTKeyRotation(t *testing.T) { // kick off a goroutine to service bundle update notifications. This is // typically handled by Run() but using it would complicate the test. test.m.dropBundleUpdated() // drop bundle update message produce by initialization - go test.m.NotifyOnBundleUpdate(ctx) + go test.m.ProcessBundleUpdates(ctx) // after initialization, we should have a current JWTKey but no next. first := test.currentJWTKey() @@ -646,7 +822,7 @@ func TestPruneBundle(t *testing.T) { // kick off a goroutine to service bundle update notifications. This is // typically handled by Run() but using it would complicate the test. test.m.dropBundleUpdated() // drop bundle update message produce by initialization - go test.m.NotifyOnBundleUpdate(ctx) + go test.m.ProcessBundleUpdates(ctx) // advance just past the expiration time of the first and prune. nothing // should change. @@ -1045,7 +1221,9 @@ type managerTest struct { func setupTest(t *testing.T) *managerTest { clock := clock.NewMock(t) - ca := new(fakeCA) + ca := &fakeCA{ + taintedAuthoritiesCh: make(chan []*x509.Certificate, 1), + } log, logHook := test.NewNullLogger() metrics := fakemetrics.New() @@ -1347,6 +1525,8 @@ type fakeCA struct { mu sync.Mutex x509CA *ca.X509CA jwtKey *ca.JWTKey + + taintedAuthoritiesCh chan []*x509.Certificate } func (s *fakeCA) X509CA() *ca.X509CA { @@ -1372,3 +1552,7 @@ func (s *fakeCA) SetJWTKey(jwtKey *ca.JWTKey) { defer s.mu.Unlock() s.jwtKey = jwtKey } + +func (s *fakeCA) NotifyTaintedX509Authorities(taintedAuthorities []*x509.Certificate) { + s.taintedAuthoritiesCh <- taintedAuthorities +} diff --git a/pkg/server/ca/rotator/rotator.go b/pkg/server/ca/rotator/rotator.go index 17f4ef190d..923a020ca7 100644 --- a/pkg/server/ca/rotator/rotator.go +++ b/pkg/server/ca/rotator/rotator.go @@ -22,7 +22,7 @@ const ( type CAManager interface { NotifyBundleLoaded(ctx context.Context) error - NotifyOnBundleUpdate(ctx context.Context) + ProcessBundleUpdates(ctx context.Context) GetCurrentX509CASlot() manager.Slot GetNextX509CASlot() manager.Slot @@ -91,7 +91,7 @@ func (r *Rotator) Run(ctx context.Context) error { func(ctx context.Context) error { // notifyOnBundleUpdate does not fail but rather logs any errors // encountered while notifying - r.c.Manager.NotifyOnBundleUpdate(ctx) + r.c.Manager.ProcessBundleUpdates(ctx) return nil }, ) diff --git a/pkg/server/ca/rotator/rotator_test.go b/pkg/server/ca/rotator/rotator_test.go index b8fb1bb6d8..6361ee810d 100644 --- a/pkg/server/ca/rotator/rotator_test.go +++ b/pkg/server/ca/rotator/rotator_test.go @@ -413,7 +413,7 @@ func (f *fakeCAManager) NotifyBundleLoaded(context.Context) error { return nil } -func (f *fakeCAManager) NotifyOnBundleUpdate(context.Context) { +func (f *fakeCAManager) ProcessBundleUpdates(context.Context) { } func (f *fakeCAManager) GetCurrentX509CASlot() manager.Slot { diff --git a/pkg/server/svid/rotator.go b/pkg/server/svid/rotator.go index ca23b1ea49..fc62dbaff8 100644 --- a/pkg/server/svid/rotator.go +++ b/pkg/server/svid/rotator.go @@ -13,10 +13,16 @@ import ( "github.com/spiffe/spire/pkg/server/ca" ) +var ( + defaultBundleVerificationTicker = 30 * time.Second +) + type Rotator struct { c *RotatorConfig - state observer.Property + state observer.Property + isSVIDTainted bool + taintedReceived chan bool } // State is the current SVID and key @@ -42,17 +48,31 @@ func (r *Rotator) Interval() time.Duration { return r.c.Interval } +func (r *Rotator) triggerTaintedReceived(tainted bool) { + r.taintedReceived <- tainted +} + // Run starts a ticker which monitors the server SVID // for expiration and rotates the SVID as necessary. func (r *Rotator) Run(ctx context.Context) error { t := r.c.Clock.Ticker(r.c.Interval) defer t.Stop() + bundeVerificationTicker := r.c.Clock.Ticker(defaultBundleVerificationTicker) + defer bundeVerificationTicker.Stop() + for { select { case <-ctx.Done(): r.c.Log.Debug("Stopping SVID rotator") return nil + case taintedAuthorities := <-r.c.ServerCA.TaintedAuthorities(): + isTainted := r.isX509AuthorityTainted(taintedAuthorities) + if isTainted { + r.triggerTaintedReceived(true) + r.c.Log.Info("Server SVID signed using a tainted authority, forcing rotation of the Server SVID") + r.isSVIDTainted = true + } case <-t.C: if r.shouldRotate() { if err := r.rotateSVID(ctx); err != nil { @@ -72,7 +92,31 @@ func (r *Rotator) shouldRotate() bool { return true } - return r.c.Clock.Now().After(certHalfLife(s.SVID[0])) + return r.c.Clock.Now().After(certHalfLife(s.SVID[0])) || + r.isSVIDTainted +} + +func (r *Rotator) isX509AuthorityTainted(taintedAuthorities []*x509.Certificate) bool { + svid := r.State().SVID + + rootPool := x509.NewCertPool() + for _, taintedKey := range taintedAuthorities { + rootPool.AddCert(taintedKey) + } + + intermediatePool := x509.NewCertPool() + for _, intermediateCA := range svid[1:] { + intermediatePool.AddCert(intermediateCA) + } + + // Verify certificate chain, using tainted authority as root + _, err := svid[0].Verify(x509.VerifyOptions{ + Intermediates: intermediatePool, + Roots: rootPool, + CurrentTime: r.c.Clock.Now(), + }) + + return err == nil } // rotateSVID cuts a new server SVID from the CA plugin and installs @@ -103,6 +147,9 @@ func (r *Rotator) rotateSVID(ctx context.Context) (err error) { SVID: svid, Key: signer, }) + // New SVID must not be tainted. Rotator is notified about tainted + // authorities only when the intermediate is already rotated. + r.isSVIDTainted = false return nil } diff --git a/pkg/server/svid/rotator_test.go b/pkg/server/svid/rotator_test.go index 1536362b69..6a321491c1 100644 --- a/pkg/server/svid/rotator_test.go +++ b/pkg/server/svid/rotator_test.go @@ -14,10 +14,16 @@ import ( "github.com/sirupsen/logrus/hooks/test" "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/common/telemetry" + "github.com/spiffe/spire/pkg/common/x509util" + "github.com/spiffe/spire/pkg/server/ca" + "github.com/spiffe/spire/pkg/server/credtemplate" "github.com/spiffe/spire/pkg/server/plugin/keymanager" "github.com/spiffe/spire/test/clock" "github.com/spiffe/spire/test/fakes/fakeserverca" "github.com/spiffe/spire/test/spiretest" + "github.com/spiffe/spire/test/testkey" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) @@ -25,6 +31,10 @@ const ( testTTL = time.Minute * 10 ) +var ( + trustDomain = spiffeid.RequireTrustDomainFromString("example.org") +) + func TestRotator(t *testing.T) { suite.Run(t, new(RotatorTestSuite)) } @@ -39,8 +49,6 @@ type RotatorTestSuite struct { } func (s *RotatorTestSuite) SetupTest() { - trustDomain := spiffeid.RequireTrustDomainFromString("example.org") - s.clock = clock.NewMock(s.T()) s.serverCA = fakeserverca.New(s.T(), trustDomain, &fakeserverca.Options{ Clock: s.clock, @@ -105,6 +113,79 @@ func (s *RotatorTestSuite) TestRotationSucceeds() { s.Require().NoError(<-errCh) } +func (s *RotatorTestSuite) TestForceRotation() { + stream := s.r.Subscribe() + t := s.T() + + var wg sync.WaitGroup + defer wg.Wait() + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + err := s.r.Initialize(ctx) + s.Require().NoError(err) + + originalCA := s.serverCA.Bundle() + + // New CA + signer := testkey.MustEC256() + template, err := s.serverCA.CredBuilder().BuildSelfSignedX509CATemplate(context.Background(), credtemplate.SelfSignedX509CAParams{ + PublicKey: signer.Public(), + }) + require.NoError(t, err) + + newCA, err := x509util.CreateCertificate(template, template, signer.Public(), signer) + require.NoError(t, err) + + newCASubjectID := newCA.SubjectKeyId + + // The call to initialize should do the first rotation + cert := s.requireNewCert(stream, big.NewInt(-1)) + + // Run should rotate whenever the certificate is within half of its + // remaining lifetime. + wg.Add(1) + errCh := make(chan error, 1) + go func() { + defer wg.Done() + errCh <- s.r.Run(ctx) + }() + + // Change X509CA + s.serverCA.SetX509CA(&ca.X509CA{ + Signer: signer, + Certificate: newCA, + }) + + s.clock.WaitForTicker(time.Minute, "waiting for the Run() ticker") + + s.r.taintedReceived = make(chan bool, 1) + // Notify that old authority is tainted + s.serverCA.NotifyTaintedX509Authorities(originalCA) + + select { + case received := <-s.r.taintedReceived: + assert.True(t, received) + case <-ctx.Done(): + s.Fail("no notification received") + } + + // Advance interval, so new SVID is signed + s.clock.Add(DefaultRotatorInterval) + cert = s.requireNewCert(stream, cert.SerialNumber) + require.Equal(t, newCASubjectID, cert.AuthorityKeyId) + + // Notify again, must not mark as tainted + s.serverCA.NotifyTaintedX509Authorities(originalCA) + s.clock.Add(DefaultRotatorInterval) + s.requireStateChangeTimeout(stream) + require.False(t, s.r.isSVIDTainted) + + cancel() + s.Require().NoError(<-errCh) +} + func (s *RotatorTestSuite) TestRotationFails() { var wg sync.WaitGroup defer wg.Wait() diff --git a/test/fakes/fakeserverca/serverca.go b/test/fakes/fakeserverca/serverca.go index 1a5d09341b..5a85db8df3 100644 --- a/test/fakes/fakeserverca/serverca.go +++ b/test/fakes/fakeserverca/serverca.go @@ -128,6 +128,10 @@ func (c *CA) SetJWTKey(jwtKey *ca.JWTKey) { c.ca.SetJWTKey(jwtKey) } +func (c *CA) NotifyTaintedX509Authorities(taintedAuthorities []*x509.Certificate) { + c.ca.NotifyTaintedX509Authorities(taintedAuthorities) +} + func (c *CA) SignDownstreamX509CA(ctx context.Context, params ca.DownstreamX509CAParams) ([]*x509.Certificate, error) { if c.err != nil { return nil, c.err @@ -163,6 +167,10 @@ func (c *CA) SignWorkloadJWTSVID(ctx context.Context, params ca.WorkloadJWTSVIDP return c.ca.SignWorkloadJWTSVID(ctx, params) } +func (c *CA) TaintedAuthorities() <-chan []*x509.Certificate { + return c.ca.TaintedAuthorities() +} + func (c *CA) SetError(err error) { c.err = err }