diff --git a/pkg/agent/manager/cache/lru_cache.go b/pkg/agent/manager/cache/lru_cache.go index aa4377c868..90c9415f82 100644 --- a/pkg/agent/manager/cache/lru_cache.go +++ b/pkg/agent/manager/cache/lru_cache.go @@ -24,6 +24,13 @@ const ( DefaultSVIDCacheMaxSize = 1000 // SVIDSyncInterval is the interval at which SVIDs are synced with subscribers SVIDSyncInterval = 500 * time.Millisecond + // Default batch size for processing tainted SVIDs + defaultProcessingBatchSize = 100 +) + +var ( + // Time interval between SVID batch processing + processingTaintedX509SVIDInterval = 5 * time.Second ) // Cache caches each registration entry, bundles, and JWT SVIDs for the agent. @@ -485,84 +492,32 @@ func (c *LRUCache) UpdateSVIDs(update *UpdateSVIDs) { } } -func (c *LRUCache) scheduleRotation(ctx context.Context, entriesToForce []string, taintedX509Authorities []*x509.Certificate) { - // TODO: move to const - batch := 100 - // TODO: search const used for sync interval... - interval := 5 * time.Second - - ticker := c.clk.Ticker(interval) - defer ticker.Stop() - - for len(entriesToForce) > 0 { - if batch > len(entriesToForce) { - batch = len(entriesToForce) - } - processingEntries := entriesToForce[:batch] - start := time.Now() - c.processTaintedSVIDs(processingEntries, taintedX509Authorities) - - c.log.Debugf("******************************************************") - c.log.Debugf("Duration to process %d svids: %v", len(processingEntries), time.Since(start)) - c.log.Debugf("******************************************************") - - entriesToForce = entriesToForce[batch:] - c.log.WithField(telemetry.Count, batch).Debug("entries to process") - - select { - case <-ticker.C: - case <-ctx.Done(): - return - } - } -} - -func (c *LRUCache) processTaintedSVIDs(entries []string, taintedX509Authorities []*x509.Certificate) { - c.mu.Lock() - defer c.mu.Unlock() - - taintedSVIDs := 0 - for _, processingEntry := range entries { - svid, ok := c.svids[processingEntry] - if !ok { - // SVID is not longer there - continue - } - - if svid == nil { - // no SVID stored - continue - } - - if tainted := x509util.IsSignedByRoot(svid.Chain, taintedX509Authorities); tainted { - taintedSVIDs += 1 - delete(c.svids, processingEntry) - } - - } - - agentmetrics.AddCacheManagerTaintedSVIDsSample(c.metrics, "", float32(taintedSVIDs)) - c.log.WithField(telemetry.TaintedSVIDs, taintedSVIDs).Debug("Tainted X.509 SVIDs") -} - +// TaintX509SVIDs initiates the processing of all cached SVIDs, checking if they are tainted by the provided authorities. +// It schedules the processing to run asynchronously in batches. func (c *LRUCache) TaintX509SVIDs(ctx context.Context, taintedX509Authorities []*x509.Certificate) { - // TODO: add elapsed time metrics - c.mu.Lock() - defer c.mu.Unlock() + c.mu.RLock() + defer c.mu.RUnlock() var entriesToProcess []string for key, svid := range c.svids { - // no process already tainted or empty SVIDs - if svid == nil { - continue + if svid != nil { + entriesToProcess = append(entriesToProcess, key) } + } - entriesToProcess = append(entriesToProcess, key) + // Check if there are any entries to process before scheduling + if len(entriesToProcess) == 0 { + c.log.Debug("No SVID entries to process for tainted X.509 authorities") + return } - go c.scheduleRotation(ctx, entriesToProcess, taintedX509Authorities) + // Schedule the rotation process in a separate goroutine + go func() { + c.scheduleRotation(ctx, entriesToProcess, taintedX509Authorities) + }() - c.log.Debug("Scheduling rotation of tainted authorities") + c.log.WithField(telemetry.Count, len(entriesToProcess)). + Debug("Scheduled rotation for SVID entries due to tainted X.509 authorities") } // GetStaleEntries obtains a list of stale entries @@ -603,6 +558,63 @@ func (c *LRUCache) SyncSVIDsWithSubscribers() { c.syncSVIDsWithSubscribers() } +// scheduleRotation processes SVID entries in batches, removing those tainted by X.509 authorities. +// The process continues at regular intervals until all entries have been processed or the context is cancelled. +func (c *LRUCache) scheduleRotation(ctx context.Context, entryIDs []string, taintedX509Authorities []*x509.Certificate) { + ticker := c.clk.Ticker(processingTaintedX509SVIDInterval) + defer ticker.Stop() + + for len(entryIDs) > 0 { + // Processing SVIDs in batches + batchSize := min(defaultProcessingBatchSize, len(entryIDs)) + processingEntries := entryIDs[:batchSize] + + start := time.Now() + c.processTaintedSVIDs(processingEntries, taintedX509Authorities) + + c.log.Debugf("******************************************************") + c.log.Debugf("Processed %d SVIDs in %v", len(processingEntries), time.Since(start)) + c.log.Debugf("******************************************************") + + // Update the list to remove processed entries + entryIDs = entryIDs[batchSize:] + c.log.WithField(telemetry.Count, batchSize).Debug("entries left to process") + + select { + case <-ticker.C: + case <-ctx.Done(): + c.log.Debug("Context cancelled, exiting rotation schedule") + return + } + } +} + +// processTaintedSVIDs identifies and removes tainted SVIDs from the cache that have been signed by the given tainted authorities. +func (c *LRUCache) processTaintedSVIDs(entryIDs []string, taintedX509Authorities []*x509.Certificate) { + taintedSVIDs := 0 + + c.mu.Lock() + defer c.mu.Unlock() + + for _, entryID := range entryIDs { + svid, exists := c.svids[entryID] + if !exists || svid == nil { + // Skip if the SVID is not in cache or is nil + continue + } + + // Check if the SVID is signed by any tainted authority + if tainted := x509util.IsSignedByRoot(svid.Chain, taintedX509Authorities); tainted { + taintedSVIDs++ + delete(c.svids, entryID) + } + + } + + agentmetrics.AddCacheManagerTaintedSVIDsSample(c.metrics, "", float32(taintedSVIDs)) + c.log.WithField(telemetry.TaintedSVIDs, taintedSVIDs).Debug("Tainted X.509 SVIDs") +} + // Notify subscriber of selector set only if all SVIDs for corresponding selector set are cached // It returns whether all SVIDs are cached or not. // This method should be retried with backoff to avoid lock contention.