From 4bfad989cc8c9eb2c7b2a212a6b048dac7b6b3eb Mon Sep 17 00:00:00 2001 From: Marcos Yacob Date: Wed, 28 Aug 2024 19:05:21 -0300 Subject: [PATCH 1/4] * Force rotation of X.509 workload SVIDs in lru cache * Force rotation of X.509 workload SVIDs in store SVID cache * Force rotation of Agent SVID Signed-off-by: Marcos Yacob --- pkg/agent/agent.go | 5 +- pkg/agent/manager/cache/cache.go | 10 +++ pkg/agent/manager/cache/lru_cache.go | 31 +++++++++ pkg/agent/manager/config.go | 3 + pkg/agent/manager/manager.go | 8 +++ pkg/agent/manager/storecache/cache.go | 33 ++++++++++ pkg/agent/manager/sync.go | 94 +++++++++++++++++++++++++-- pkg/agent/svid/rotator.go | 39 ++++++++++- pkg/common/bundleutil/bundle.go | 37 +++++++++-- pkg/common/telemetry/agent/manager.go | 10 +++ pkg/common/telemetry/names.go | 6 ++ pkg/common/x509util/cert.go | 22 +++++++ 12 files changed, 287 insertions(+), 11 deletions(-) diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index 433bad4417..01514f307c 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -144,7 +144,7 @@ func (a *Agent) Run(ctx context.Context) error { } } - svidStoreCache := a.newSVIDStoreCache() + svidStoreCache := a.newSVIDStoreCache(metrics) manager, err := a.newManager(ctx, a.sto, cat, metrics, as, svidStoreCache, nodeAttestor) if err != nil { @@ -328,10 +328,11 @@ func (a *Agent) newManager(ctx context.Context, sto storage.Storage, cat catalog } } -func (a *Agent) newSVIDStoreCache() *storecache.Cache { +func (a *Agent) newSVIDStoreCache(metrics telemetry.Metrics) *storecache.Cache { config := &storecache.Config{ Log: a.c.Log.WithField(telemetry.SubsystemName, "svid_store_cache"), TrustDomain: a.c.TrustDomain, + Metrics: metrics, } return storecache.New(config) diff --git a/pkg/agent/manager/cache/cache.go b/pkg/agent/manager/cache/cache.go index 7ad5293090..af8c41b4d1 100644 --- a/pkg/agent/manager/cache/cache.go +++ b/pkg/agent/manager/cache/cache.go @@ -41,6 +41,12 @@ type UpdateEntries struct { // Bundles is a set of ALL trust bundles available to the agent, keyed by trust domain Bundles map[spiffeid.TrustDomain]*spiffebundle.Bundle + // TaintedX509Authorities is a set of all tainted X.509 authorities notified by the server. + TaintedX509Authorities []string + + // TaintedJWTAuthorities is a set of all tainted JWT authorities notified by the server. + TaintedJWTAuthorities []string + // RegistrationEntries is a set of ALL registration entries available to the // agent, keyed by registration entry id. RegistrationEntries map[string]*common.RegistrationEntry @@ -413,6 +419,10 @@ func (c *Cache) UpdateSVIDs(update *UpdateSVIDs) { } } +func (c *Cache) TaintX509SVIDs(taintedX509Authorities []*x509.Certificate) { + // This cache is going to be removed in 1.11... +} + // GetStaleEntries obtains a list of stale entries func (c *Cache) GetStaleEntries() []*StaleEntry { c.mu.Lock() diff --git a/pkg/agent/manager/cache/lru_cache.go b/pkg/agent/manager/cache/lru_cache.go index 8d01fd07bd..b4cb5c775d 100644 --- a/pkg/agent/manager/cache/lru_cache.go +++ b/pkg/agent/manager/cache/lru_cache.go @@ -2,6 +2,7 @@ package cache import ( "context" + "crypto/x509" "fmt" "sort" "sync" @@ -14,6 +15,7 @@ import ( "github.com/spiffe/spire/pkg/agent/common/backoff" "github.com/spiffe/spire/pkg/common/telemetry" agentmetrics "github.com/spiffe/spire/pkg/common/telemetry/agent" + "github.com/spiffe/spire/pkg/common/x509util" "github.com/spiffe/spire/proto/spire/common" ) @@ -483,6 +485,35 @@ func (c *LRUCache) UpdateSVIDs(update *UpdateSVIDs) { } } +func (c *LRUCache) TaintX509SVIDs(taintedX509Authorities []*x509.Certificate) { + // TOOD: add elapsed time metrics + c.mu.Lock() + defer c.mu.Unlock() + + start := time.Now() + + taintedSVIDs := 0 + for key, svid := range c.svids { + // no process already tainted or empty SVIDs + if svid == nil { + continue + } + + if tainted := x509util.IsSignedByRoot(svid.Chain, taintedX509Authorities); tainted { + taintedSVIDs += 1 + delete(c.svids, key) + } + } + + // TODO: remove.... + c.log.Debugf("******************************************************") + c.log.Debugf("Duration to process %d svids: %v", taintedSVIDs, time.Since(start)) + c.log.Debugf("******************************************************") + + agentmetrics.AddCacheManagerExpiredSVIDsSample(c.metrics, "", float32(taintedSVIDs)) + c.log.WithField(telemetry.TaintedSVIDs, taintedSVIDs).Debug("Tainted X.509 SVIDs") +} + // GetStaleEntries obtains a list of stale entries func (c *LRUCache) GetStaleEntries() []*StaleEntry { c.mu.Lock() diff --git a/pkg/agent/manager/config.go b/pkg/agent/manager/config.go index f5d71bbe12..9e3725baee 100644 --- a/pkg/agent/manager/config.go +++ b/pkg/agent/manager/config.go @@ -101,6 +101,9 @@ func newManager(c *Config) *manager { client: client, clk: c.Clk, svidStoreCache: c.SVIDStoreCache, + + processedTaintedX509Authorities: make(map[string]struct{}), + processedTaintedJWTAuthorities: make(map[string]struct{}), } return m diff --git a/pkg/agent/manager/manager.go b/pkg/agent/manager/manager.go index b0dec8290c..a395a1f804 100644 --- a/pkg/agent/manager/manager.go +++ b/pkg/agent/manager/manager.go @@ -171,6 +171,14 @@ type manager struct { // cache. syncedEntries map[string]*common.RegistrationEntry syncedBundles map[string]*common.Bundle + + // processedTaintedX509Authorities holds all the already processed tainted X.509 Authorities + // to prevent processing them again. + processedTaintedX509Authorities map[string]struct{} + + // processedTaintedJWTAuthorities holds all the already processed tainted JWT Authorities + // to prevent processing them again. + processedTaintedJWTAuthorities map[string]struct{} } func (m *manager) Initialize(ctx context.Context) error { diff --git a/pkg/agent/manager/storecache/cache.go b/pkg/agent/manager/storecache/cache.go index f12225bd00..984ab3c340 100644 --- a/pkg/agent/manager/storecache/cache.go +++ b/pkg/agent/manager/storecache/cache.go @@ -1,6 +1,7 @@ package storecache import ( + "crypto/x509" "sort" "sync" "time" @@ -10,6 +11,8 @@ import ( "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/agent/manager/cache" "github.com/spiffe/spire/pkg/common/telemetry" + telemetry_agent "github.com/spiffe/spire/pkg/common/telemetry/agent" + "github.com/spiffe/spire/pkg/common/x509util" "github.com/spiffe/spire/proto/spire/common" ) @@ -46,6 +49,7 @@ type cachedRecord struct { type Config struct { Log logrus.FieldLogger TrustDomain spiffeid.TrustDomain + Metrics telemetry.Metrics } type Cache struct { @@ -219,6 +223,35 @@ func (c *Cache) UpdateSVIDs(update *cache.UpdateSVIDs) { } } +func (c *Cache) TaintX509SVIDs(taintedX509Authorities []*x509.Certificate) { + // TOOD: add elapsed time metrics + c.mtx.Lock() + defer c.mtx.Unlock() + + start := time.Now() + + taintedSVIDs := 0 + for _, record := range c.records { + // no process already tainted or empty SVIDs + if record.svid == nil { + continue + } + + if tainted := x509util.IsSignedByRoot(record.svid.Chain, taintedX509Authorities); tainted { + taintedSVIDs += 1 + record.svid = nil + } + } + + telemetry_agent.AddCacheManagerExpiredSVIDsSample(c.c.Metrics, "svid_store", float32(taintedSVIDs)) + c.c.Log.WithField(telemetry.TaintedSVIDs, taintedSVIDs).Debug("Tainted X.509 SVIDs") + + // TODO: remove.... + c.c.Log.Debugf("******************************************************") + c.c.Log.Debugf("Duration to process %d svids: %v", taintedSVIDs, time.Since(start)) + c.c.Log.Debugf("******************************************************") +} + // GetStaleEntries obtains a list of stale entries, that needs new SVIDs func (c *Cache) GetStaleEntries() []*cache.StaleEntry { c.mtx.Lock() diff --git a/pkg/agent/manager/sync.go b/pkg/agent/manager/sync.go index 141d615fe1..648e3d2c3a 100644 --- a/pkg/agent/manager/sync.go +++ b/pkg/agent/manager/sync.go @@ -4,6 +4,8 @@ import ( "context" "crypto" "crypto/x509" + "fmt" + "strings" "time" "github.com/sirupsen/logrus" @@ -15,6 +17,7 @@ import ( "github.com/spiffe/spire/pkg/common/telemetry" telemetry_agent "github.com/spiffe/spire/pkg/common/telemetry/agent" "github.com/spiffe/spire/pkg/common/util" + "github.com/spiffe/spire/pkg/common/x509util" "github.com/spiffe/spire/proto/spire/common" ) @@ -33,6 +36,10 @@ type SVIDCache interface { // GetStaleEntries gets a list of records that need update SVIDs GetStaleEntries() []*cache.StaleEntry + + // TaintX509SVIDs marks all SVIDs signed by a tainted X.509 authority as tainted + // to force their rotation. + TaintX509SVIDs(taintedX509Authorities []*x509.Certificate) } func (m *manager) syncSVIDs(ctx context.Context) (err error) { @@ -44,6 +51,44 @@ func (m *manager) syncSVIDs(ctx context.Context) (err error) { return nil } +// processTaintedAuthorities verifies if a new authority is tainted and forces rotation in all caches if required. +func (m *manager) processTaintedAuthorities(x509Authorities []string, jwtAuthorities []string) error { + newTaintedX509Authorities := getNewItems(m.processedTaintedX509Authorities, x509Authorities) + if len(newTaintedX509Authorities) > 0 { + m.c.Log.WithField(telemetry.SubjectKeyIDs, strings.Join(newTaintedX509Authorities, ",")). + Debug("New tainted X.509 authorities found") + + taintedX509Authorities, err := bundleutil.FindX509Authorities(m.c.Bundle, newTaintedX509Authorities) + if err != nil { + return fmt.Errorf("failed to search X.509 authorities: %w", err) + } + + // Taint all regular X.509 SVIDs + m.cache.TaintX509SVIDs(taintedX509Authorities) + + // Taint all SVIDStore SVIDs + m.svidStoreCache.TaintX509SVIDs(taintedX509Authorities) + + // Notify rotator about new tainted authorities + if err := m.svid.NotifyTaintedAuthorities(taintedX509Authorities); err != nil { + return err + } + + for _, subjectKeyID := range newTaintedX509Authorities { + m.processedTaintedX509Authorities[subjectKeyID] = struct{}{} + } + } + + newTaintedJWTAuthorities := getNewItems(m.processedTaintedJWTAuthorities, jwtAuthorities) + if len(newTaintedJWTAuthorities) > 0 { + m.c.Log.WithField(telemetry.SubjectKeyIDs, strings.Join(newTaintedJWTAuthorities, ",")). + Debug("New tainted JWT authorities found") + // TODO: IMPLEMENT!!! + } + + return nil +} + // synchronize fetches the authorized entries from the server, updates the // cache, and fetches missing/expiring SVIDs. func (m *manager) synchronize(ctx context.Context) (err error) { @@ -52,6 +97,11 @@ func (m *manager) synchronize(ctx context.Context) (err error) { return err } + // Process all tainted authorities. The bundle is shared between both caches using regular cache data. + if err := m.processTaintedAuthorities(cacheUpdate.TaintedX509Authorities, cacheUpdate.TaintedJWTAuthorities); err != nil { + return err + } + if err := m.updateCache(ctx, cacheUpdate, m.c.Log.WithField(telemetry.CacheType, "workload"), "", m.cache); err != nil { return err } @@ -258,6 +308,27 @@ func (m *manager) fetchEntries(ctx context.Context) (_ *cache.UpdateEntries, _ * return nil, nil, err } + // Get all Subject Key IDs and KeyIDs of tainted authorities + var taintedX509Authorities []string + var taintedJWTAuthorities []string + if b, ok := update.Bundles[m.c.TrustDomain.IDString()]; ok { + for _, rootCA := range b.RootCas { + if rootCA.TaintedKey { + cert, err := x509.ParseCertificate(rootCA.DerBytes) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse tainted x509 authority: %w", err) + } + subjectKeyID := x509util.SubjectKeyIDToString(cert.SubjectKeyId) + taintedX509Authorities = append(taintedX509Authorities, subjectKeyID) + } + } + for _, jwtKey := range b.JwtSigningKeys { + if jwtKey.TaintedKey { + taintedJWTAuthorities = append(taintedJWTAuthorities, jwtKey.Kid) + } + } + } + cacheEntries := make(map[string]*common.RegistrationEntry) storeEntries := make(map[string]*common.RegistrationEntry) @@ -271,11 +342,15 @@ func (m *manager) fetchEntries(ctx context.Context) (_ *cache.UpdateEntries, _ * } return &cache.UpdateEntries{ - Bundles: bundles, - RegistrationEntries: cacheEntries, + Bundles: bundles, + RegistrationEntries: cacheEntries, + TaintedJWTAuthorities: taintedJWTAuthorities, + TaintedX509Authorities: taintedX509Authorities, }, &cache.UpdateEntries{ - Bundles: bundles, - RegistrationEntries: storeEntries, + Bundles: bundles, + RegistrationEntries: storeEntries, + TaintedJWTAuthorities: taintedJWTAuthorities, + TaintedX509Authorities: taintedX509Authorities, }, nil } @@ -307,3 +382,14 @@ func parseBundles(bundles map[string]*common.Bundle) (map[spiffeid.TrustDomain]* } return out, nil } + +func getNewItems(current map[string]struct{}, items []string) []string { + var newItems []string + for _, subjectKeyID := range items { + if _, ok := current[subjectKeyID]; !ok { + newItems = append(newItems, subjectKeyID) + } + } + + return newItems +} diff --git a/pkg/agent/svid/rotator.go b/pkg/agent/svid/rotator.go index fd532c44c8..d22f401adb 100644 --- a/pkg/agent/svid/rotator.go +++ b/pkg/agent/svid/rotator.go @@ -21,12 +21,16 @@ import ( "github.com/spiffe/spire/pkg/common/telemetry" telemetry_agent "github.com/spiffe/spire/pkg/common/telemetry/agent" "github.com/spiffe/spire/pkg/common/util" + "github.com/spiffe/spire/pkg/common/x509util" "google.golang.org/grpc" ) type Rotator interface { Run(ctx context.Context) error Reattest(ctx context.Context) error + // NotifyTaintedAuthorities processes new tainted authorities. If the current SVID is compromised, + // it is marked to force rotation. + NotifyTaintedAuthorities([]*x509.Certificate) error State() State Subscribe() observer.Stream @@ -58,6 +62,8 @@ type rotator struct { // Hook that will be called when the SVID rotation finishes rotationFinishedHook func() + + tainted bool } type State struct { @@ -130,6 +136,24 @@ func (r *rotator) Subscribe() observer.Stream { return r.state.Observe() } +func (r *rotator) NotifyTaintedAuthorities(taintedAuthorities []*x509.Certificate) error { + state, ok := r.state.Value().(State) + if !ok { + return fmt.Errorf("unexpected value type: %T", r.state.Value()) + } + + if r.tainted { + // Already tainted... + return nil + } + + r.tainted = x509util.IsSignedByRoot(state.SVID, taintedAuthorities) + if r.tainted { + r.c.Log.Debug("Agent SVID is tainted, forcing rotation...") + } + return nil +} + func (r *rotator) GetRotationMtx() *sync.RWMutex { return r.rotMtx } @@ -162,7 +186,7 @@ func (r *rotator) rotateSVIDIfNeeded(ctx context.Context) (err error) { return fmt.Errorf("unexpected value type: %T", r.state.Value()) } - if r.c.RotationStrategy.ShouldRotateX509(r.clk.Now(), state.SVID[0]) { + if r.c.RotationStrategy.ShouldRotateX509(r.clk.Now(), state.SVID[0]) || r.tainted { if state.Reattestable { err = r.reattest(ctx) } else { @@ -222,6 +246,7 @@ func (r *rotator) reattest(ctx context.Context) (err error) { } r.state.Update(s) + r.tainted = false // We must release the client because its underlaying connection is tied to an // expired SVID, so next time the client is used, it will get a new connection with @@ -269,6 +294,7 @@ func (r *rotator) rotateSVID(ctx context.Context) (err error) { } r.state.Update(s) + r.tainted = false // We must release the client because its underlaying connection is tied to an // expired SVID, so next time the client is used, it will get a new connection with @@ -323,3 +349,14 @@ func rotationError(state State) string { return "rotate agent SVID" } + +func getNewItems(current map[string]struct{}, items []string) []string { + var newItems []string + for _, subjectKeyID := range items { + if _, ok := current[subjectKeyID]; !ok { + newItems = append(newItems, subjectKeyID) + } + } + + return newItems +} diff --git a/pkg/common/bundleutil/bundle.go b/pkg/common/bundleutil/bundle.go index ac3dd8bff9..592c348648 100644 --- a/pkg/common/bundleutil/bundle.go +++ b/pkg/common/bundleutil/bundle.go @@ -12,6 +12,7 @@ import ( "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire-api-sdk/proto/spire/api/types" "github.com/spiffe/spire/pkg/common/telemetry" + "github.com/spiffe/spire/pkg/common/x509util" "github.com/spiffe/spire/proto/spire/common" "google.golang.org/protobuf/proto" ) @@ -29,7 +30,8 @@ func CommonBundleFromProto(b *types.Bundle) (*common.Bundle, error) { var rootCAs []*common.Certificate for _, rootCA := range b.X509Authorities { rootCAs = append(rootCAs, &common.Certificate{ - DerBytes: rootCA.Asn1, + DerBytes: rootCA.Asn1, + TaintedKey: rootCA.Tainted, }) } @@ -40,9 +42,10 @@ func CommonBundleFromProto(b *types.Bundle) (*common.Bundle, error) { } jwtKeys = append(jwtKeys, &common.PublicKey{ - PkixBytes: key.PublicKey, - Kid: key.KeyId, - NotAfter: key.ExpiresAt, + PkixBytes: key.PublicKey, + Kid: key.KeyId, + NotAfter: key.ExpiresAt, + TaintedKey: key.Tainted, }) } @@ -237,6 +240,32 @@ pruneRootCA: return newBundle, changed, nil } +// FindX509Authorities search for all X.509 authorities with provided subjectKeyIDs +func FindX509Authorities(bundle *spiffebundle.Bundle, subjectKeyIDs []string) ([]*x509.Certificate, error) { + var x509Authorities []*x509.Certificate + for _, subjectKeyID := range subjectKeyIDs { + x509Authority, err := getX509Authority(bundle, subjectKeyID) + if err != nil { + return nil, err + } + + x509Authorities = append(x509Authorities, x509Authority) + } + + return x509Authorities, nil +} + +func getX509Authority(bundle *spiffebundle.Bundle, subjectKeyID string) (*x509.Certificate, error) { + for _, x509Authority := range bundle.X509Authorities() { + authoritySKID := x509util.SubjectKeyIDToString(x509Authority.SubjectKeyId) + if authoritySKID == subjectKeyID { + return x509Authority, nil + } + } + + return nil, fmt.Errorf("no X.509 authority found with SubjectKeyID %q", subjectKeyID) +} + func cloneBundle(b *common.Bundle) *common.Bundle { return proto.Clone(b).(*common.Bundle) } diff --git a/pkg/common/telemetry/agent/manager.go b/pkg/common/telemetry/agent/manager.go index bd17d420df..b30105a397 100644 --- a/pkg/common/telemetry/agent/manager.go +++ b/pkg/common/telemetry/agent/manager.go @@ -46,6 +46,16 @@ func AddCacheManagerOutdatedSVIDsSample(m telemetry.Metrics, cacheType string, c m.AddSample(key, count) } +// AddCacheManagerTaintedSVIDsSample count of tainted SVIDs according to +// agent cache manager +func AddCacheManagerTaintedSVIDsSample(m telemetry.Metrics, cacheType string, count float32) { + key := []string{telemetry.CacheManager, cacheType, telemetry.TaintedSVIDs} + if cacheType != "" { + key = append(key, cacheType) + } + m.AddSample(key, count) +} + // End Add Samples func SetSyncStats(m telemetry.Metrics, stats client.SyncStats) { diff --git a/pkg/common/telemetry/names.go b/pkg/common/telemetry/names.go index b24ed815d6..1b9851177b 100644 --- a/pkg/common/telemetry/names.go +++ b/pkg/common/telemetry/names.go @@ -547,6 +547,9 @@ const ( // SubjectKeyID tags a certificate subject key ID SubjectKeyID = "subject_key_id" + // SubjectKeyIDs tags a list of subject key ID + SubjectKeyIDs = "subject_key_ids" + // SVIDMapSize is the gauge key for the size of the LRU cache SVID map SVIDMapSize = "lru_cache_svid_map_size" @@ -777,6 +780,9 @@ const ( // RegistrationManager functionality related to a registration manager RegistrationManager = "registration_manager" + //TaintedSVIDs tags tainted SVID count/list + TaintedSVIDs = "tainted_svids" + // Telemetry tags a telemetry module Telemetry = "telemetry" diff --git a/pkg/common/x509util/cert.go b/pkg/common/x509util/cert.go index 28ce5f960e..1fba0d796b 100644 --- a/pkg/common/x509util/cert.go +++ b/pkg/common/x509util/cert.go @@ -72,3 +72,25 @@ func RawCertsFromCertificates(certs []*x509.Certificate) [][]byte { } return rawCerts } + +// IsSignedByRoot checks if the provided certificate chain is signed by one of the specified root CAs. +func IsSignedByRoot(chain []*x509.Certificate, rootCAs []*x509.Certificate) bool { + rootPool := x509.NewCertPool() + for _, x509Authority := range rootCAs { + rootPool.AddCert(x509Authority) + } + + intermediatePool := x509.NewCertPool() + for _, intermediateCA := range chain[1:] { + intermediatePool.AddCert(intermediateCA) + } + + // Verify certificate chain, using tainted authorities as root + _, err := chain[0].Verify(x509.VerifyOptions{ + Intermediates: intermediatePool, + Roots: rootPool, + }) + + // TODO: may we verify if error is different to Signed by unknown authority? + return err == nil +} From 686b8fc7259c82a223a111cd4796f8255ca10069 Mon Sep 17 00:00:00 2001 From: Marcos Yacob Date: Fri, 6 Sep 2024 13:54:23 -0300 Subject: [PATCH 2/4] Another option Signed-off-by: Marcos Yacob --- pkg/agent/manager/cache/cache.go | 2 +- pkg/agent/manager/cache/lru_cache.go | 69 ++++++++++++++++++++++----- pkg/agent/manager/storecache/cache.go | 3 +- pkg/agent/manager/sync.go | 10 ++-- pkg/common/x509util/cert.go | 3 ++ 5 files changed, 67 insertions(+), 20 deletions(-) diff --git a/pkg/agent/manager/cache/cache.go b/pkg/agent/manager/cache/cache.go index af8c41b4d1..9b6a4dbb63 100644 --- a/pkg/agent/manager/cache/cache.go +++ b/pkg/agent/manager/cache/cache.go @@ -419,7 +419,7 @@ func (c *Cache) UpdateSVIDs(update *UpdateSVIDs) { } } -func (c *Cache) TaintX509SVIDs(taintedX509Authorities []*x509.Certificate) { +func (c *Cache) TaintX509SVIDs(context.Context, []*x509.Certificate) { // This cache is going to be removed in 1.11... } diff --git a/pkg/agent/manager/cache/lru_cache.go b/pkg/agent/manager/cache/lru_cache.go index b4cb5c775d..e256cfc256 100644 --- a/pkg/agent/manager/cache/lru_cache.go +++ b/pkg/agent/manager/cache/lru_cache.go @@ -485,35 +485,78 @@ func (c *LRUCache) UpdateSVIDs(update *UpdateSVIDs) { } } -func (c *LRUCache) TaintX509SVIDs(taintedX509Authorities []*x509.Certificate) { - // TOOD: add elapsed time metrics +func (c *LRUCache) scheduleRotation(ctx context.Context, entriesToForce []string, taintedX509Authorities []*x509.Certificate) { + // TODO: move to const + batch := 100 + // TODO: search const used for sync interval... + interval := 5 * time.Second + + ticker := c.clk.Ticker(interval) + defer ticker.Stop() + + for len(entriesToForce) > 0 { + processingEntries := entriesToForce[:batch] + c.processTaintedSVIDs(processingEntries, taintedX509Authorities) + + processingEntries = processingEntries[batch:] + c.log.WithField(telemetry.Count, len(processingEntries)).Debug("entries to process") + + select { + case <-ticker.C: + case <-ctx.Done(): + return + } + } +} + +func (c *LRUCache) processTaintedSVIDs(entries []string, taintedX509Authorities []*x509.Certificate) { c.mu.Lock() defer c.mu.Unlock() - start := time.Now() - taintedSVIDs := 0 - for key, svid := range c.svids { - // no process already tainted or empty SVIDs + for _, processingEntry := range entries { + svid, ok := c.svids[processingEntry] + if !ok { + // SVID is not longer there + continue + } + if svid == nil { + // no SVID stored continue } if tainted := x509util.IsSignedByRoot(svid.Chain, taintedX509Authorities); tainted { taintedSVIDs += 1 - delete(c.svids, key) + delete(c.svids, processingEntry) } - } - // TODO: remove.... - c.log.Debugf("******************************************************") - c.log.Debugf("Duration to process %d svids: %v", taintedSVIDs, time.Since(start)) - c.log.Debugf("******************************************************") + } - agentmetrics.AddCacheManagerExpiredSVIDsSample(c.metrics, "", float32(taintedSVIDs)) + agentmetrics.AddCacheManagerTaintedSVIDsSample(c.metrics, "", float32(taintedSVIDs)) c.log.WithField(telemetry.TaintedSVIDs, taintedSVIDs).Debug("Tainted X.509 SVIDs") } +func (c *LRUCache) TaintX509SVIDs(ctx context.Context, taintedX509Authorities []*x509.Certificate) { + // TODO: add elapsed time metrics + c.mu.Lock() + defer c.mu.Unlock() + + var entriesToProcess []string + for key, svid := range c.svids { + // no process already tainted or empty SVIDs + if svid == nil { + continue + } + + entriesToProcess = append(entriesToProcess, key) + } + + go c.scheduleRotation(ctx, entriesToProcess, taintedX509Authorities) + + c.log.Debug("Scheduling rotation of tainted authorities") +} + // GetStaleEntries obtains a list of stale entries func (c *LRUCache) GetStaleEntries() []*StaleEntry { c.mu.Lock() diff --git a/pkg/agent/manager/storecache/cache.go b/pkg/agent/manager/storecache/cache.go index 984ab3c340..5dbb9fb1d8 100644 --- a/pkg/agent/manager/storecache/cache.go +++ b/pkg/agent/manager/storecache/cache.go @@ -1,6 +1,7 @@ package storecache import ( + "context" "crypto/x509" "sort" "sync" @@ -223,7 +224,7 @@ func (c *Cache) UpdateSVIDs(update *cache.UpdateSVIDs) { } } -func (c *Cache) TaintX509SVIDs(taintedX509Authorities []*x509.Certificate) { +func (c *Cache) TaintX509SVIDs(ctx context.Context, taintedX509Authorities []*x509.Certificate) { // TOOD: add elapsed time metrics c.mtx.Lock() defer c.mtx.Unlock() diff --git a/pkg/agent/manager/sync.go b/pkg/agent/manager/sync.go index 648e3d2c3a..e71a0387a9 100644 --- a/pkg/agent/manager/sync.go +++ b/pkg/agent/manager/sync.go @@ -39,7 +39,7 @@ type SVIDCache interface { // TaintX509SVIDs marks all SVIDs signed by a tainted X.509 authority as tainted // to force their rotation. - TaintX509SVIDs(taintedX509Authorities []*x509.Certificate) + TaintX509SVIDs(ctx context.Context, taintedX509Authorities []*x509.Certificate) } func (m *manager) syncSVIDs(ctx context.Context) (err error) { @@ -52,7 +52,7 @@ func (m *manager) syncSVIDs(ctx context.Context) (err error) { } // processTaintedAuthorities verifies if a new authority is tainted and forces rotation in all caches if required. -func (m *manager) processTaintedAuthorities(x509Authorities []string, jwtAuthorities []string) error { +func (m *manager) processTaintedAuthorities(ctx context.Context, x509Authorities []string, jwtAuthorities []string) error { newTaintedX509Authorities := getNewItems(m.processedTaintedX509Authorities, x509Authorities) if len(newTaintedX509Authorities) > 0 { m.c.Log.WithField(telemetry.SubjectKeyIDs, strings.Join(newTaintedX509Authorities, ",")). @@ -64,10 +64,10 @@ func (m *manager) processTaintedAuthorities(x509Authorities []string, jwtAuthori } // Taint all regular X.509 SVIDs - m.cache.TaintX509SVIDs(taintedX509Authorities) + m.cache.TaintX509SVIDs(ctx, taintedX509Authorities) // Taint all SVIDStore SVIDs - m.svidStoreCache.TaintX509SVIDs(taintedX509Authorities) + m.svidStoreCache.TaintX509SVIDs(ctx, taintedX509Authorities) // Notify rotator about new tainted authorities if err := m.svid.NotifyTaintedAuthorities(taintedX509Authorities); err != nil { @@ -98,7 +98,7 @@ func (m *manager) synchronize(ctx context.Context) (err error) { } // Process all tainted authorities. The bundle is shared between both caches using regular cache data. - if err := m.processTaintedAuthorities(cacheUpdate.TaintedX509Authorities, cacheUpdate.TaintedJWTAuthorities); err != nil { + if err := m.processTaintedAuthorities(ctx, cacheUpdate.TaintedX509Authorities, cacheUpdate.TaintedJWTAuthorities); err != nil { return err } diff --git a/pkg/common/x509util/cert.go b/pkg/common/x509util/cert.go index 1fba0d796b..84316f7e51 100644 --- a/pkg/common/x509util/cert.go +++ b/pkg/common/x509util/cert.go @@ -75,6 +75,9 @@ func RawCertsFromCertificates(certs []*x509.Certificate) [][]byte { // IsSignedByRoot checks if the provided certificate chain is signed by one of the specified root CAs. func IsSignedByRoot(chain []*x509.Certificate, rootCAs []*x509.Certificate) bool { + if len(chain) == 0 { + return false + } rootPool := x509.NewCertPool() for _, x509Authority := range rootCAs { rootPool.AddCert(x509Authority) From 2dc547b1018a14c0df7cd61164c0d586c0f66369 Mon Sep 17 00:00:00 2001 From: Marcos Yacob Date: Fri, 6 Sep 2024 14:43:38 -0300 Subject: [PATCH 3/4] resolve issues.. Signed-off-by: Marcos Yacob --- pkg/agent/manager/cache/lru_cache.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/pkg/agent/manager/cache/lru_cache.go b/pkg/agent/manager/cache/lru_cache.go index e256cfc256..aa4377c868 100644 --- a/pkg/agent/manager/cache/lru_cache.go +++ b/pkg/agent/manager/cache/lru_cache.go @@ -495,11 +495,19 @@ func (c *LRUCache) scheduleRotation(ctx context.Context, entriesToForce []string defer ticker.Stop() for len(entriesToForce) > 0 { + if batch > len(entriesToForce) { + batch = len(entriesToForce) + } processingEntries := entriesToForce[:batch] + start := time.Now() c.processTaintedSVIDs(processingEntries, taintedX509Authorities) - processingEntries = processingEntries[batch:] - c.log.WithField(telemetry.Count, len(processingEntries)).Debug("entries to process") + c.log.Debugf("******************************************************") + c.log.Debugf("Duration to process %d svids: %v", len(processingEntries), time.Since(start)) + c.log.Debugf("******************************************************") + + entriesToForce = entriesToForce[batch:] + c.log.WithField(telemetry.Count, batch).Debug("entries to process") select { case <-ticker.C: From e209b7d1644ecded346a1e78f263962dd0e3dd55 Mon Sep 17 00:00:00 2001 From: Marcos Yacob Date: Fri, 6 Sep 2024 16:09:48 -0300 Subject: [PATCH 4/4] some cleanup Signed-off-by: Marcos Yacob --- pkg/agent/manager/cache/lru_cache.go | 150 +++++++++++++++------------ 1 file changed, 81 insertions(+), 69 deletions(-) diff --git a/pkg/agent/manager/cache/lru_cache.go b/pkg/agent/manager/cache/lru_cache.go index aa4377c868..90c9415f82 100644 --- a/pkg/agent/manager/cache/lru_cache.go +++ b/pkg/agent/manager/cache/lru_cache.go @@ -24,6 +24,13 @@ const ( DefaultSVIDCacheMaxSize = 1000 // SVIDSyncInterval is the interval at which SVIDs are synced with subscribers SVIDSyncInterval = 500 * time.Millisecond + // Default batch size for processing tainted SVIDs + defaultProcessingBatchSize = 100 +) + +var ( + // Time interval between SVID batch processing + processingTaintedX509SVIDInterval = 5 * time.Second ) // Cache caches each registration entry, bundles, and JWT SVIDs for the agent. @@ -485,84 +492,32 @@ func (c *LRUCache) UpdateSVIDs(update *UpdateSVIDs) { } } -func (c *LRUCache) scheduleRotation(ctx context.Context, entriesToForce []string, taintedX509Authorities []*x509.Certificate) { - // TODO: move to const - batch := 100 - // TODO: search const used for sync interval... - interval := 5 * time.Second - - ticker := c.clk.Ticker(interval) - defer ticker.Stop() - - for len(entriesToForce) > 0 { - if batch > len(entriesToForce) { - batch = len(entriesToForce) - } - processingEntries := entriesToForce[:batch] - start := time.Now() - c.processTaintedSVIDs(processingEntries, taintedX509Authorities) - - c.log.Debugf("******************************************************") - c.log.Debugf("Duration to process %d svids: %v", len(processingEntries), time.Since(start)) - c.log.Debugf("******************************************************") - - entriesToForce = entriesToForce[batch:] - c.log.WithField(telemetry.Count, batch).Debug("entries to process") - - select { - case <-ticker.C: - case <-ctx.Done(): - return - } - } -} - -func (c *LRUCache) processTaintedSVIDs(entries []string, taintedX509Authorities []*x509.Certificate) { - c.mu.Lock() - defer c.mu.Unlock() - - taintedSVIDs := 0 - for _, processingEntry := range entries { - svid, ok := c.svids[processingEntry] - if !ok { - // SVID is not longer there - continue - } - - if svid == nil { - // no SVID stored - continue - } - - if tainted := x509util.IsSignedByRoot(svid.Chain, taintedX509Authorities); tainted { - taintedSVIDs += 1 - delete(c.svids, processingEntry) - } - - } - - agentmetrics.AddCacheManagerTaintedSVIDsSample(c.metrics, "", float32(taintedSVIDs)) - c.log.WithField(telemetry.TaintedSVIDs, taintedSVIDs).Debug("Tainted X.509 SVIDs") -} - +// TaintX509SVIDs initiates the processing of all cached SVIDs, checking if they are tainted by the provided authorities. +// It schedules the processing to run asynchronously in batches. func (c *LRUCache) TaintX509SVIDs(ctx context.Context, taintedX509Authorities []*x509.Certificate) { - // TODO: add elapsed time metrics - c.mu.Lock() - defer c.mu.Unlock() + c.mu.RLock() + defer c.mu.RUnlock() var entriesToProcess []string for key, svid := range c.svids { - // no process already tainted or empty SVIDs - if svid == nil { - continue + if svid != nil { + entriesToProcess = append(entriesToProcess, key) } + } - entriesToProcess = append(entriesToProcess, key) + // Check if there are any entries to process before scheduling + if len(entriesToProcess) == 0 { + c.log.Debug("No SVID entries to process for tainted X.509 authorities") + return } - go c.scheduleRotation(ctx, entriesToProcess, taintedX509Authorities) + // Schedule the rotation process in a separate goroutine + go func() { + c.scheduleRotation(ctx, entriesToProcess, taintedX509Authorities) + }() - c.log.Debug("Scheduling rotation of tainted authorities") + c.log.WithField(telemetry.Count, len(entriesToProcess)). + Debug("Scheduled rotation for SVID entries due to tainted X.509 authorities") } // GetStaleEntries obtains a list of stale entries @@ -603,6 +558,63 @@ func (c *LRUCache) SyncSVIDsWithSubscribers() { c.syncSVIDsWithSubscribers() } +// scheduleRotation processes SVID entries in batches, removing those tainted by X.509 authorities. +// The process continues at regular intervals until all entries have been processed or the context is cancelled. +func (c *LRUCache) scheduleRotation(ctx context.Context, entryIDs []string, taintedX509Authorities []*x509.Certificate) { + ticker := c.clk.Ticker(processingTaintedX509SVIDInterval) + defer ticker.Stop() + + for len(entryIDs) > 0 { + // Processing SVIDs in batches + batchSize := min(defaultProcessingBatchSize, len(entryIDs)) + processingEntries := entryIDs[:batchSize] + + start := time.Now() + c.processTaintedSVIDs(processingEntries, taintedX509Authorities) + + c.log.Debugf("******************************************************") + c.log.Debugf("Processed %d SVIDs in %v", len(processingEntries), time.Since(start)) + c.log.Debugf("******************************************************") + + // Update the list to remove processed entries + entryIDs = entryIDs[batchSize:] + c.log.WithField(telemetry.Count, batchSize).Debug("entries left to process") + + select { + case <-ticker.C: + case <-ctx.Done(): + c.log.Debug("Context cancelled, exiting rotation schedule") + return + } + } +} + +// processTaintedSVIDs identifies and removes tainted SVIDs from the cache that have been signed by the given tainted authorities. +func (c *LRUCache) processTaintedSVIDs(entryIDs []string, taintedX509Authorities []*x509.Certificate) { + taintedSVIDs := 0 + + c.mu.Lock() + defer c.mu.Unlock() + + for _, entryID := range entryIDs { + svid, exists := c.svids[entryID] + if !exists || svid == nil { + // Skip if the SVID is not in cache or is nil + continue + } + + // Check if the SVID is signed by any tainted authority + if tainted := x509util.IsSignedByRoot(svid.Chain, taintedX509Authorities); tainted { + taintedSVIDs++ + delete(c.svids, entryID) + } + + } + + agentmetrics.AddCacheManagerTaintedSVIDsSample(c.metrics, "", float32(taintedSVIDs)) + c.log.WithField(telemetry.TaintedSVIDs, taintedSVIDs).Debug("Tainted X.509 SVIDs") +} + // Notify subscriber of selector set only if all SVIDs for corresponding selector set are cached // It returns whether all SVIDs are cached or not. // This method should be retried with backoff to avoid lock contention.