diff --git a/pkg/api/controller.go b/pkg/api/controller.go index 09fc6988d1..1c04345237 100644 --- a/pkg/api/controller.go +++ b/pkg/api/controller.go @@ -49,6 +49,7 @@ type Controller struct { SyncOnDemand SyncOnDemand RelyingParties map[string]rp.RelyingParty CookieStore *CookieStore + taskScheduler *scheduler.Scheduler // runtime params chosenPort int // kernel-chosen port } @@ -366,13 +367,14 @@ func (c *Controller) LoadNewConfig(reloadCtx context.Context, newConfig *config. } func (c *Controller) Shutdown() { + c.taskScheduler.Shutdown() ctx := context.Background() _ = c.Server.Shutdown(ctx) } func (c *Controller) StartBackgroundTasks(reloadCtx context.Context) { - taskScheduler := scheduler.NewScheduler(c.Config, c.Log) - taskScheduler.RunScheduler(reloadCtx) + c.taskScheduler = scheduler.NewScheduler(c.Config, c.Log) + c.taskScheduler.RunScheduler(reloadCtx) // Enable running garbage-collect periodically for DefaultStore if c.Config.Storage.GC { @@ -381,20 +383,20 @@ func (c *Controller) StartBackgroundTasks(reloadCtx context.Context) { ImageRetention: c.Config.Storage.Retention, }, c.Audit, c.Log) - gc.CleanImageStorePeriodically(c.Config.Storage.GCInterval, taskScheduler) + gc.CleanImageStorePeriodically(c.Config.Storage.GCInterval, c.taskScheduler) } // Enable running dedupe blobs both ways (dedupe or restore deduped blobs) - c.StoreController.DefaultStore.RunDedupeBlobs(time.Duration(0), taskScheduler) + c.StoreController.DefaultStore.RunDedupeBlobs(time.Duration(0), c.taskScheduler) // Enable extensions if extension config is provided for DefaultStore if c.Config != nil && c.Config.Extensions != nil { ext.EnableMetricsExtension(c.Config, c.Log, c.Config.Storage.RootDirectory) - ext.EnableSearchExtension(c.Config, c.StoreController, c.MetaDB, taskScheduler, c.CveScanner, c.Log) + ext.EnableSearchExtension(c.Config, c.StoreController, c.MetaDB, c.taskScheduler, c.CveScanner, c.Log) } // runs once if metrics are enabled & imagestore is local if c.Config.IsMetricsEnabled() && c.Config.Storage.StorageDriver == nil { - c.StoreController.DefaultStore.PopulateStorageMetrics(time.Duration(0), taskScheduler) + c.StoreController.DefaultStore.PopulateStorageMetrics(time.Duration(0), c.taskScheduler) } if c.Config.Storage.SubPaths != nil { @@ -407,7 +409,7 @@ func (c *Controller) StartBackgroundTasks(reloadCtx context.Context) { ImageRetention: storageConfig.Retention, }, c.Audit, c.Log) - gc.CleanImageStorePeriodically(storageConfig.GCInterval, taskScheduler) + gc.CleanImageStorePeriodically(storageConfig.GCInterval, c.taskScheduler) } // Enable extensions if extension config is provided for subImageStore @@ -418,19 +420,19 @@ func (c *Controller) StartBackgroundTasks(reloadCtx context.Context) { // Enable running dedupe blobs both ways (dedupe or restore deduped blobs) for subpaths substore := c.StoreController.SubStore[route] if substore != nil { - substore.RunDedupeBlobs(time.Duration(0), taskScheduler) + substore.RunDedupeBlobs(time.Duration(0), c.taskScheduler) if c.Config.IsMetricsEnabled() && c.Config.Storage.StorageDriver == nil { - substore.PopulateStorageMetrics(time.Duration(0), taskScheduler) + substore.PopulateStorageMetrics(time.Duration(0), c.taskScheduler) } } } } if c.Config.Extensions != nil { - ext.EnableScrubExtension(c.Config, c.Log, c.StoreController, taskScheduler) + ext.EnableScrubExtension(c.Config, c.Log, c.StoreController, c.taskScheduler) //nolint: contextcheck - syncOnDemand, err := ext.EnableSyncExtension(c.Config, c.MetaDB, c.StoreController, taskScheduler, c.Log) + syncOnDemand, err := ext.EnableSyncExtension(c.Config, c.MetaDB, c.StoreController, c.taskScheduler, c.Log) if err != nil { c.Log.Error().Err(err).Msg("unable to start sync extension") } @@ -439,11 +441,11 @@ func (c *Controller) StartBackgroundTasks(reloadCtx context.Context) { } if c.CookieStore != nil { - c.CookieStore.RunSessionCleaner(taskScheduler) + c.CookieStore.RunSessionCleaner(c.taskScheduler) } // we can later move enabling the other scheduled tasks inside the call below - ext.EnableScheduledTasks(c.Config, taskScheduler, c.MetaDB, c.Log) //nolint: contextcheck + ext.EnableScheduledTasks(c.Config, c.taskScheduler, c.MetaDB, c.Log) //nolint: contextcheck } type SyncOnDemand interface { diff --git a/pkg/api/controller_test.go b/pkg/api/controller_test.go index 9e245168d4..ee8ea175d8 100644 --- a/pkg/api/controller_test.go +++ b/pkg/api/controller_test.go @@ -8378,7 +8378,7 @@ func TestGCSignaturesAndUntaggedManifestsWithMetaDB(t *testing.T) { So(len(index.Manifests), ShouldEqual, 1) // shouldn't do anything - err = gc.CleanRepo(repoName) //nolint: contextcheck + err = gc.CleanRepo(ctx, repoName) //nolint: contextcheck So(err, ShouldBeNil) // make sure both signatures are stored in repodb @@ -8404,7 +8404,7 @@ func TestGCSignaturesAndUntaggedManifestsWithMetaDB(t *testing.T) { err = UploadImage(img, baseURL, repoName, img.DigestStr()) So(err, ShouldBeNil) - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldNotBeNil) err = os.Chmod(path.Join(dir, repoName, "blobs", "sha256", refs.Manifests[0].Digest.Encoded()), 0o755) @@ -8418,7 +8418,7 @@ func TestGCSignaturesAndUntaggedManifestsWithMetaDB(t *testing.T) { err = UploadImage(img, baseURL, repoName, tag) So(err, ShouldBeNil) - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldNotBeNil) err = os.WriteFile(path.Join(dir, repoName, "blobs", "sha256", refs.Manifests[0].Digest.Encoded()), content, 0o600) @@ -8469,7 +8469,7 @@ func TestGCSignaturesAndUntaggedManifestsWithMetaDB(t *testing.T) { So(err, ShouldBeNil) newManifestDigest := godigest.FromBytes(manifestBuf) - err = gc.CleanRepo(repoName) //nolint: contextcheck + err = gc.CleanRepo(ctx, repoName) //nolint: contextcheck So(err, ShouldBeNil) // make sure both signatures are removed from metaDB and repo reference for untagged is removed @@ -8548,7 +8548,7 @@ func TestGCSignaturesAndUntaggedManifestsWithMetaDB(t *testing.T) { So(err, ShouldBeNil) cm := test.NewControllerManager(ctlr) - cm.StartAndWait(port) + cm.StartAndWait(port) //nolint: contextcheck defer cm.StopServer() gc := gc.NewGarbageCollect(ctlr.StoreController.DefaultStore, ctlr.MetaDB, @@ -8606,7 +8606,7 @@ func TestGCSignaturesAndUntaggedManifestsWithMetaDB(t *testing.T) { So(err, ShouldBeNil) So(resp.StatusCode(), ShouldEqual, http.StatusCreated) - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldBeNil) resp, err = resty.R().SetHeader("Content-Type", ispec.MediaTypeImageIndex). diff --git a/pkg/cli/client/client.go b/pkg/cli/client/client.go index 9388ea42ef..0c33bfc4a0 100644 --- a/pkg/cli/client/client.go +++ b/pkg/cli/client/client.go @@ -212,7 +212,7 @@ func (p *requestsPool) doJob(ctx context.Context, job *httpJob) { header, err := makeHEADRequest(ctx, job.url, job.username, job.password, job.config.VerifyTLS, job.config.Debug) if err != nil { - if isContextDone(ctx) { + if common.IsContextDone(ctx) { return } p.outputCh <- stringResult{"", err} @@ -224,7 +224,7 @@ func (p *requestsPool) doJob(ctx context.Context, job *httpJob) { case ispec.MediaTypeImageManifest: image, err := fetchImageManifestStruct(ctx, job) if err != nil { - if isContextDone(ctx) { + if common.IsContextDone(ctx) { return } p.outputCh <- stringResult{"", err} @@ -235,7 +235,7 @@ func (p *requestsPool) doJob(ctx context.Context, job *httpJob) { str, err := image.string(job.config.OutputFormat, len(job.imageName), len(job.tagName), len(platformStr), verbose) if err != nil { - if isContextDone(ctx) { + if common.IsContextDone(ctx) { return } p.outputCh <- stringResult{"", err} @@ -243,7 +243,7 @@ func (p *requestsPool) doJob(ctx context.Context, job *httpJob) { return } - if isContextDone(ctx) { + if common.IsContextDone(ctx) { return } @@ -251,7 +251,7 @@ func (p *requestsPool) doJob(ctx context.Context, job *httpJob) { case ispec.MediaTypeImageIndex: image, err := fetchImageIndexStruct(ctx, job) if err != nil { - if isContextDone(ctx) { + if common.IsContextDone(ctx) { return } p.outputCh <- stringResult{"", err} @@ -263,7 +263,7 @@ func (p *requestsPool) doJob(ctx context.Context, job *httpJob) { str, err := image.string(job.config.OutputFormat, len(job.imageName), len(job.tagName), len(platformStr), verbose) if err != nil { - if isContextDone(ctx) { + if common.IsContextDone(ctx) { return } p.outputCh <- stringResult{"", err} @@ -271,7 +271,7 @@ func (p *requestsPool) doJob(ctx context.Context, job *httpJob) { return } - if isContextDone(ctx) { + if common.IsContextDone(ctx) { return } @@ -287,7 +287,7 @@ func fetchImageIndexStruct(ctx context.Context, job *httpJob) (*imageStruct, err header, err := makeGETRequest(ctx, job.url, job.username, job.password, job.config.VerifyTLS, job.config.Debug, &indexContent, job.config.ResultWriter) if err != nil { - if isContextDone(ctx) { + if common.IsContextDone(ctx) { return nil, context.Canceled } @@ -378,7 +378,7 @@ func fetchManifestStruct(ctx context.Context, repo, manifestReference string, se header, err := makeGETRequest(ctx, URL, username, password, searchConf.VerifyTLS, searchConf.Debug, &manifestResp, searchConf.ResultWriter) if err != nil { - if isContextDone(ctx) { + if common.IsContextDone(ctx) { return common.ManifestSummary{}, context.Canceled } @@ -390,7 +390,7 @@ func fetchManifestStruct(ctx context.Context, repo, manifestReference string, se configContent, err := fetchConfig(ctx, repo, configDigest, searchConf, username, password) if err != nil { - if isContextDone(ctx) { + if common.IsContextDone(ctx) { return common.ManifestSummary{}, context.Canceled } @@ -467,7 +467,7 @@ func fetchConfig(ctx context.Context, repo, configDigest string, searchConf Sear _, err := makeGETRequest(ctx, URL, username, password, searchConf.VerifyTLS, searchConf.Debug, &configContent, searchConf.ResultWriter) if err != nil { - if isContextDone(ctx) { + if common.IsContextDone(ctx) { return ispec.Image{}, context.Canceled } diff --git a/pkg/cli/client/cve_cmd_test.go b/pkg/cli/client/cve_cmd_test.go index 6c46968d9a..43b08c3e51 100644 --- a/pkg/cli/client/cve_cmd_test.go +++ b/pkg/cli/client/cve_cmd_test.go @@ -583,7 +583,7 @@ func TestCVESort(t *testing.T) { } ctlr.CveScanner = mocks.CveScannerMock{ - ScanImageFn: func(image string) (map[string]cvemodel.CVE, error) { + ScanImageFn: func(ctx context.Context, image string) (map[string]cvemodel.CVE, error) { return map[string]cvemodel.CVE{ "CVE-2023-1255": { ID: "CVE-2023-1255", @@ -687,7 +687,7 @@ func getMockCveScanner(metaDB mTypes.MetaDB) cveinfo.Scanner { // MetaDB loaded with initial data now mock the scanner // Setup test CVE data in mock scanner scanner := mocks.CveScannerMock{ - ScanImageFn: func(image string) (map[string]cvemodel.CVE, error) { + ScanImageFn: func(ctx context.Context, image string) (map[string]cvemodel.CVE, error) { if strings.Contains(image, "zot-cve-test@sha256:db573b01") || image == "zot-cve-test:0.0.1" { return map[string]cvemodel.CVE{ diff --git a/pkg/cli/client/service.go b/pkg/cli/client/service.go index 82089d7fdd..a71ed759d6 100644 --- a/pkg/cli/client/service.go +++ b/pkg/cli/client/service.go @@ -415,7 +415,7 @@ func (service searchService) getReferrers(ctx context.Context, config SearchConf referrersEndpoint, err := combineServerAndEndpointURL(config.ServURL, fmt.Sprintf("/v2/%s/referrers/%s", repo, digest)) if err != nil { - if isContextDone(ctx) { + if common.IsContextDone(ctx) { return referrersResult{}, nil } @@ -427,7 +427,7 @@ func (service searchService) getReferrers(ctx context.Context, config SearchConf config.Debug, &referrerResp, config.ResultWriter) if err != nil { - if isContextDone(ctx) { + if common.IsContextDone(ctx) { return referrersResult{}, nil } @@ -477,7 +477,7 @@ func (service searchService) getAllImages(ctx context.Context, config SearchConf catalogEndPoint, err := combineServerAndEndpointURL(config.ServURL, fmt.Sprintf("%s%s", constants.RoutePrefix, constants.ExtCatalogPrefix)) if err != nil { - if isContextDone(ctx) { + if common.IsContextDone(ctx) { return } rch <- stringResult{"", err} @@ -488,7 +488,7 @@ func (service searchService) getAllImages(ctx context.Context, config SearchConf _, err = makeGETRequest(ctx, catalogEndPoint, username, password, config.VerifyTLS, config.Debug, catalog, config.ResultWriter) if err != nil { - if isContextDone(ctx) { + if common.IsContextDone(ctx) { return } rch <- stringResult{"", err} @@ -522,7 +522,7 @@ func getImage(ctx context.Context, config SearchConfig, username, password, imag tagListEndpoint, err := combineServerAndEndpointURL(config.ServURL, fmt.Sprintf("/v2/%s/tags/list", repo)) if err != nil { - if isContextDone(ctx) { + if common.IsContextDone(ctx) { return } rch <- stringResult{"", err} @@ -535,7 +535,7 @@ func getImage(ctx context.Context, config SearchConfig, username, password, imag config.Debug, &tagList, config.ResultWriter) if err != nil { - if isContextDone(ctx) { + if common.IsContextDone(ctx) { return } rch <- stringResult{"", err} @@ -601,7 +601,7 @@ func (service searchService) getImagesByDigest(ctx context.Context, config Searc err := service.makeGraphQLQuery(ctx, config, username, password, query, result) if err != nil { - if isContextDone(ctx) { + if common.IsContextDone(ctx) { return } rch <- stringResult{"", err} @@ -616,7 +616,7 @@ func (service searchService) getImagesByDigest(ctx context.Context, config Searc fmt.Fprintln(&errBuilder, err.Message) } - if isContextDone(ctx) { + if common.IsContextDone(ctx) { return } rch <- stringResult{"", errors.New(errBuilder.String())} //nolint: goerr113 @@ -640,15 +640,6 @@ func (service searchService) getImagesByDigest(ctx context.Context, config Searc localWg.Wait() } -func isContextDone(ctx context.Context) bool { - select { - case <-ctx.Done(): - return true - default: - return false - } -} - // Query using GQL, the query string is passed as a parameter // errors are returned in the stringResult channel, the unmarshalled payload is in resultPtr. func (service searchService) makeGraphQLQuery(ctx context.Context, @@ -672,7 +663,7 @@ func (service searchService) makeGraphQLQuery(ctx context.Context, func checkResultGraphQLQuery(ctx context.Context, err error, resultErrors []common.ErrorGQL, ) error { if err != nil { - if isContextDone(ctx) { + if common.IsContextDone(ctx) { return nil //nolint:nilnil } @@ -686,7 +677,7 @@ func checkResultGraphQLQuery(ctx context.Context, err error, resultErrors []comm fmt.Fprintln(&errBuilder, error.Message) } - if isContextDone(ctx) { + if common.IsContextDone(ctx) { return nil } @@ -705,7 +696,7 @@ func addManifestCallToPool(ctx context.Context, config SearchConfig, pool *reque manifestEndpoint, err := combineServerAndEndpointURL(config.ServURL, fmt.Sprintf("/v2/%s/manifests/%s", imageName, tagName)) if err != nil { - if isContextDone(ctx) { + if common.IsContextDone(ctx) { return } rch <- stringResult{"", err} @@ -1315,7 +1306,7 @@ func (service searchService) getRepos(ctx context.Context, config SearchConfig, catalogEndPoint, err := combineServerAndEndpointURL(config.ServURL, fmt.Sprintf("%s%s", constants.RoutePrefix, constants.ExtCatalogPrefix)) if err != nil { - if isContextDone(ctx) { + if common.IsContextDone(ctx) { return } rch <- stringResult{"", err} @@ -1326,7 +1317,7 @@ func (service searchService) getRepos(ctx context.Context, config SearchConfig, _, err = makeGETRequest(ctx, catalogEndPoint, username, password, config.VerifyTLS, config.Debug, catalog, config.ResultWriter) if err != nil { - if isContextDone(ctx) { + if common.IsContextDone(ctx) { return } rch <- stringResult{"", err} diff --git a/pkg/common/common.go b/pkg/common/common.go index 846b7e0c2e..d034e46848 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -1,6 +1,7 @@ package common import ( + "context" "encoding/json" "errors" "fmt" @@ -135,3 +136,12 @@ func IsReferrersTag(tag string) bool { return referrersTagRule.MatchString(tag) } + +func IsContextDone(ctx context.Context) bool { + select { + case <-ctx.Done(): + return true + default: + return false + } +} diff --git a/pkg/extensions/imagetrust/image_trust.go b/pkg/extensions/imagetrust/image_trust.go index 3fd06fa141..b56bba3d2b 100644 --- a/pkg/extensions/imagetrust/image_trust.go +++ b/pkg/extensions/imagetrust/image_trust.go @@ -256,8 +256,12 @@ func (validityT *validityTask) DoWork(ctx context.Context) error { validityT.log.Info().Msg("update signatures validity") for signedManifest, sigs := range validityT.repo.Signatures { + if zcommon.IsContextDone(ctx) { + return ctx.Err() + } + if len(sigs[zcommon.CosignSignature]) != 0 || len(sigs[zcommon.NotationSignature]) != 0 { - err := validityT.metaDB.UpdateSignaturesValidity(validityT.repo.Name, godigest.Digest(signedManifest)) + err := validityT.metaDB.UpdateSignaturesValidity(ctx, validityT.repo.Name, godigest.Digest(signedManifest)) if err != nil { validityT.log.Info().Msg("error while verifying signatures") diff --git a/pkg/extensions/search/convert/convert_internal_test.go b/pkg/extensions/search/convert/convert_internal_test.go index bd2321537f..d23053729b 100644 --- a/pkg/extensions/search/convert/convert_internal_test.go +++ b/pkg/extensions/search/convert/convert_internal_test.go @@ -61,7 +61,7 @@ func TestCVEConvert(t *testing.T) { Vulnerabilities: false, }, mocks.CveInfoMock{ - GetCVESummaryForImageMediaFn: func(repo string, digest, mediaType string, + GetCVESummaryForImageMediaFn: func(ctx context.Context, repo string, digest, mediaType string, ) (cvemodel.ImageCVESummary, error) { return cvemodel.ImageCVESummary{}, ErrTestError }, @@ -99,7 +99,7 @@ func TestCVEConvert(t *testing.T) { Vulnerabilities: false, }, mocks.CveInfoMock{ - GetCVESummaryForImageMediaFn: func(repo string, digest, mediaType string, + GetCVESummaryForImageMediaFn: func(ctx context.Context, repo string, digest, mediaType string, ) (cvemodel.ImageCVESummary, error) { return cvemodel.ImageCVESummary{ Count: 1, @@ -126,7 +126,7 @@ func TestCVEConvert(t *testing.T) { Vulnerabilities: false, }, mocks.CveInfoMock{ - GetCVESummaryForImageMediaFn: func(repo string, digest, mediaType string, + GetCVESummaryForImageMediaFn: func(ctx context.Context, repo string, digest, mediaType string, ) (cvemodel.ImageCVESummary, error) { return cvemodel.ImageCVESummary{}, ErrTestError }, @@ -149,7 +149,7 @@ func TestCVEConvert(t *testing.T) { Vulnerabilities: false, }, mocks.CveInfoMock{ - GetCVESummaryForImageMediaFn: func(repo string, digest, mediaType string, + GetCVESummaryForImageMediaFn: func(ctx context.Context, repo string, digest, mediaType string, ) (cvemodel.ImageCVESummary, error) { return cvemodel.ImageCVESummary{ Count: 1, @@ -179,7 +179,7 @@ func TestCVEConvert(t *testing.T) { Vulnerabilities: false, }, mocks.CveInfoMock{ - GetCVESummaryForImageMediaFn: func(repo string, digest, mediaType string, + GetCVESummaryForImageMediaFn: func(ctx context.Context, repo string, digest, mediaType string, ) (cvemodel.ImageCVESummary, error) { return cvemodel.ImageCVESummary{ Count: 1, @@ -207,7 +207,7 @@ func TestCVEConvert(t *testing.T) { Vulnerabilities: false, }, mocks.CveInfoMock{ - GetCVESummaryForImageMediaFn: func(repo string, digest, mediaType string, + GetCVESummaryForImageMediaFn: func(ctx context.Context, repo string, digest, mediaType string, ) (cvemodel.ImageCVESummary, error) { return cvemodel.ImageCVESummary{ Count: 1, @@ -248,7 +248,7 @@ func TestCVEConvert(t *testing.T) { Vulnerabilities: false, }, mocks.CveInfoMock{ - GetCVESummaryForImageMediaFn: func(repo string, digest, mediaType string, + GetCVESummaryForImageMediaFn: func(ctx context.Context, repo string, digest, mediaType string, ) (cvemodel.ImageCVESummary, error) { return cvemodel.ImageCVESummary{ Count: 1, @@ -271,7 +271,7 @@ func TestCVEConvert(t *testing.T) { Vulnerabilities: false, }, mocks.CveInfoMock{ - GetCVESummaryForImageMediaFn: func(repo string, digest, mediaType string, + GetCVESummaryForImageMediaFn: func(ctx context.Context, repo string, digest, mediaType string, ) (cvemodel.ImageCVESummary, error) { return cvemodel.ImageCVESummary{}, ErrTestError }, diff --git a/pkg/extensions/search/convert/cve.go b/pkg/extensions/search/convert/cve.go index b266f3f1eb..ffda720885 100644 --- a/pkg/extensions/search/convert/cve.go +++ b/pkg/extensions/search/convert/cve.go @@ -47,7 +47,7 @@ func updateImageSummaryVulnerabilities( return } - imageCveSummary, err := cveInfo.GetCVESummaryForImageMedia(*imageSummary.RepoName, *imageSummary.Digest, + imageCveSummary, err := cveInfo.GetCVESummaryForImageMedia(ctx, *imageSummary.RepoName, *imageSummary.Digest, *imageSummary.MediaType) if err != nil { // Log the error, but we should still include the image in results @@ -91,7 +91,7 @@ func updateManifestSummaryVulnerabilities( return } - imageCveSummary, err := cveInfo.GetCVESummaryForImageMedia(repoName, *manifestSummary.Digest, + imageCveSummary, err := cveInfo.GetCVESummaryForImageMedia(ctx, repoName, *manifestSummary.Digest, ispec.MediaTypeImageManifest) if err != nil { // Log the error, but we should still include the manifest in results diff --git a/pkg/extensions/search/cve/cve.go b/pkg/extensions/search/cve/cve.go index a4acb8b459..cf032b276a 100644 --- a/pkg/extensions/search/cve/cve.go +++ b/pkg/extensions/search/cve/cve.go @@ -19,20 +19,20 @@ import ( ) type CveInfo interface { - GetImageListForCVE(repo, cveID string) ([]cvemodel.TagInfo, error) - GetImageListWithCVEFixed(repo, cveID string) ([]cvemodel.TagInfo, error) - GetCVEListForImage(repo, tag string, searchedCVE string, pageinput cvemodel.PageInput, + GetImageListForCVE(ctx context.Context, repo, cveID string) ([]cvemodel.TagInfo, error) + GetImageListWithCVEFixed(ctx context.Context, repo, cveID string) ([]cvemodel.TagInfo, error) + GetCVEListForImage(ctx context.Context, repo, tag string, searchedCVE string, pageinput cvemodel.PageInput, ) ([]cvemodel.CVE, zcommon.PageInfo, error) - GetCVESummaryForImageMedia(repo, digest, mediaType string) (cvemodel.ImageCVESummary, error) + GetCVESummaryForImageMedia(ctx context.Context, repo, digest, mediaType string) (cvemodel.ImageCVESummary, error) } type Scanner interface { - ScanImage(image string) (map[string]cvemodel.CVE, error) + ScanImage(ctx context.Context, image string) (map[string]cvemodel.CVE, error) IsImageFormatScannable(repo, ref string) (bool, error) IsImageMediaScannable(repo, digestStr, mediaType string) (bool, error) IsResultCached(digestStr string) bool GetCachedResult(digestStr string) map[string]cvemodel.CVE - UpdateDB() error + UpdateDB(ctx context.Context) error } type BaseCveInfo struct { @@ -55,10 +55,10 @@ func NewCVEInfo(scanner Scanner, metaDB mTypes.MetaDB, log log.Logger) *BaseCveI } } -func (cveinfo BaseCveInfo) GetImageListForCVE(repo, cveID string) ([]cvemodel.TagInfo, error) { +func (cveinfo BaseCveInfo) GetImageListForCVE(ctx context.Context, repo, cveID string) ([]cvemodel.TagInfo, error) { imgList := make([]cvemodel.TagInfo, 0) - repoMeta, err := cveinfo.MetaDB.GetRepoMeta(context.Background(), repo) + repoMeta, err := cveinfo.MetaDB.GetRepoMeta(ctx, repo) if err != nil { cveinfo.Log.Error().Err(err).Str("repository", repo).Str("cve-id", cveID). Msg("unable to get list of tags from repo") @@ -80,8 +80,12 @@ func (cveinfo BaseCveInfo) GetImageListForCVE(repo, cveID string) ([]cvemodel.Ta continue } - cveMap, err := cveinfo.Scanner.ScanImage(zcommon.GetFullImageName(repo, tag)) + cveMap, err := cveinfo.Scanner.ScanImage(ctx, zcommon.GetFullImageName(repo, tag)) if err != nil { + if zcommon.IsContextDone(ctx) { + return imgList, err + } + cveinfo.Log.Info().Str("image", repo+":"+tag).Err(err).Msg("image scan failed") continue @@ -105,8 +109,9 @@ func (cveinfo BaseCveInfo) GetImageListForCVE(repo, cveID string) ([]cvemodel.Ta return imgList, nil } -func (cveinfo BaseCveInfo) GetImageListWithCVEFixed(repo, cveID string) ([]cvemodel.TagInfo, error) { - repoMeta, err := cveinfo.MetaDB.GetRepoMeta(context.Background(), repo) +func (cveinfo BaseCveInfo) GetImageListWithCVEFixed(ctx context.Context, repo, cveID string, +) ([]cvemodel.TagInfo, error) { + repoMeta, err := cveinfo.MetaDB.GetRepoMeta(ctx, repo) if err != nil { cveinfo.Log.Error().Err(err).Str("repository", repo).Str("cve-id", cveID). Msg("unable to get list of tags from repo") @@ -132,7 +137,12 @@ func (cveinfo BaseCveInfo) GetImageListWithCVEFixed(repo, cveID string) ([]cvemo allTags = append(allTags, tagInfo) - if cveinfo.isManifestVulnerable(repo, tag, manifestDigestStr, cveID) { + ok, err := cveinfo.isManifestVulnerable(ctx, repo, tag, manifestDigestStr, cveID) + if err != nil { + return []cvemodel.TagInfo{}, err + } + + if ok { vulnerableTags = append(vulnerableTags, tagInfo) } case ispec.MediaTypeImageIndex: @@ -162,7 +172,12 @@ func (cveinfo BaseCveInfo) GetImageListWithCVEFixed(repo, cveID string) ([]cvemo allManifests = append(allManifests, manifestDescriptorInfo) - if cveinfo.isManifestVulnerable(repo, tag, manifest.Digest.String(), cveID) { + ok, err := cveinfo.isManifestVulnerable(ctx, repo, tag, manifest.Digest.String(), cveID) + if err != nil { + return []cvemodel.TagInfo{}, err + } + + if ok { vulnerableManifests = append(vulnerableManifests, manifestDescriptorInfo) } } @@ -250,7 +265,8 @@ func getTagInfoForManifest(tag, manifestDigestStr string, metaDB mTypes.MetaDB) }, nil } -func (cveinfo *BaseCveInfo) isManifestVulnerable(repo, tag, manifestDigestStr, cveID string) bool { +func (cveinfo *BaseCveInfo) isManifestVulnerable(ctx context.Context, repo, tag, manifestDigestStr, cveID string, +) (bool, error) { image := zcommon.GetFullImageName(repo, tag) isValidImage, err := cveinfo.Scanner.IsImageMediaScannable(repo, manifestDigestStr, ispec.MediaTypeImageManifest) @@ -258,15 +274,19 @@ func (cveinfo *BaseCveInfo) isManifestVulnerable(repo, tag, manifestDigestStr, c cveinfo.Log.Debug().Str("image", image).Str("cve-id", cveID).Err(err). Msg("image media type not supported for scanning, adding as a vulnerable image") - return true + return true, nil } - cveMap, err := cveinfo.Scanner.ScanImage(zcommon.GetFullImageName(repo, manifestDigestStr)) + cveMap, err := cveinfo.Scanner.ScanImage(ctx, zcommon.GetFullImageName(repo, manifestDigestStr)) if err != nil { + if zcommon.IsContextDone(ctx) { + return false, ctx.Err() + } + cveinfo.Log.Debug().Str("image", image).Str("cve-id", cveID). Msg("scanning failed, adding as a vulnerable image") - return true + return true, nil } hasCVE := false @@ -279,7 +299,7 @@ func (cveinfo *BaseCveInfo) isManifestVulnerable(repo, tag, manifestDigestStr, c } } - return hasCVE + return hasCVE, nil } func getIndexContent(metaDB mTypes.MetaDB, indexDigestStr string) (ispec.Index, error) { @@ -330,10 +350,10 @@ func filterCVEList(cveMap map[string]cvemodel.CVE, searchedCVE string, pageFinde } } -func (cveinfo BaseCveInfo) GetCVEListForImage(repo, ref string, searchedCVE string, pageInput cvemodel.PageInput) ( - []cvemodel.CVE, - zcommon.PageInfo, - error, +func (cveinfo BaseCveInfo) GetCVEListForImage(ctx context.Context, repo, ref string, searchedCVE string, + pageInput cvemodel.PageInput, +) ( + []cvemodel.CVE, zcommon.PageInfo, error, ) { isValidImage, err := cveinfo.Scanner.IsImageFormatScannable(repo, ref) if !isValidImage { @@ -344,7 +364,7 @@ func (cveinfo BaseCveInfo) GetCVEListForImage(repo, ref string, searchedCVE stri image := zcommon.GetFullImageName(repo, ref) - cveMap, err := cveinfo.Scanner.ScanImage(image) + cveMap, err := cveinfo.Scanner.ScanImage(ctx, image) if err != nil { return []cvemodel.CVE{}, zcommon.PageInfo{}, err } @@ -361,7 +381,7 @@ func (cveinfo BaseCveInfo) GetCVEListForImage(repo, ref string, searchedCVE stri return cveList, pageInfo, nil } -func (cveinfo BaseCveInfo) GetCVESummaryForImageMedia(repo, digest, mediaType string, +func (cveinfo BaseCveInfo) GetCVESummaryForImageMedia(ctx context.Context, repo, digest, mediaType string, ) (cvemodel.ImageCVESummary, error) { // There are several cases, expected returned values below: // not scanned yet - max severity "" - cve count 0 - no Errors diff --git a/pkg/extensions/search/cve/cve_test.go b/pkg/extensions/search/cve/cve_test.go index 46fbba0dc3..1f5f341940 100644 --- a/pkg/extensions/search/cve/cve_test.go +++ b/pkg/extensions/search/cve/cve_test.go @@ -913,7 +913,7 @@ func TestCVEStruct(t *testing.T) { //nolint:gocyclo // MetaDB loaded with initial data, now mock the scanner // Setup test CVE data in mock scanner scanner := mocks.CveScannerMock{ - ScanImageFn: func(image string) (map[string]cvemodel.CVE, error) { + ScanImageFn: func(ctx context.Context, image string) (map[string]cvemodel.CVE, error) { result := cache.Get(image) // Will not match sending the repo:tag as a parameter, but we don't care if result != nil { @@ -1127,15 +1127,17 @@ func TestCVEStruct(t *testing.T) { //nolint:gocyclo SortBy: cveinfo.SeverityDsc, } + ctx := context.Background() + // Image is found - cveList, pageInfo, err := cveInfo.GetCVEListForImage(repo1, "0.1.0", "", pageInput) + cveList, pageInfo, err := cveInfo.GetCVEListForImage(ctx, repo1, "0.1.0", "", pageInput) So(err, ShouldBeNil) So(len(cveList), ShouldEqual, 1) So(cveList[0].ID, ShouldEqual, "CVE1") So(pageInfo.ItemCount, ShouldEqual, 1) So(pageInfo.TotalCount, ShouldEqual, 1) - cveList, pageInfo, err = cveInfo.GetCVEListForImage(repo1, "1.0.0", "", pageInput) + cveList, pageInfo, err = cveInfo.GetCVEListForImage(ctx, repo1, "1.0.0", "", pageInput) So(err, ShouldBeNil) So(len(cveList), ShouldEqual, 3) So(cveList[0].ID, ShouldEqual, "CVE2") @@ -1144,7 +1146,7 @@ func TestCVEStruct(t *testing.T) { //nolint:gocyclo So(pageInfo.ItemCount, ShouldEqual, 3) So(pageInfo.TotalCount, ShouldEqual, 3) - cveList, pageInfo, err = cveInfo.GetCVEListForImage(repo1, "1.0.1", "", pageInput) + cveList, pageInfo, err = cveInfo.GetCVEListForImage(ctx, repo1, "1.0.1", "", pageInput) So(err, ShouldBeNil) So(len(cveList), ShouldEqual, 2) So(cveList[0].ID, ShouldEqual, "CVE1") @@ -1152,21 +1154,21 @@ func TestCVEStruct(t *testing.T) { //nolint:gocyclo So(pageInfo.ItemCount, ShouldEqual, 2) So(pageInfo.TotalCount, ShouldEqual, 2) - cveList, pageInfo, err = cveInfo.GetCVEListForImage(repo1, "1.1.0", "", pageInput) + cveList, pageInfo, err = cveInfo.GetCVEListForImage(ctx, repo1, "1.1.0", "", pageInput) So(err, ShouldBeNil) So(len(cveList), ShouldEqual, 1) So(cveList[0].ID, ShouldEqual, "CVE3") So(pageInfo.ItemCount, ShouldEqual, 1) So(pageInfo.TotalCount, ShouldEqual, 1) - cveList, pageInfo, err = cveInfo.GetCVEListForImage(repo6, "1.0.0", "", pageInput) + cveList, pageInfo, err = cveInfo.GetCVEListForImage(ctx, repo6, "1.0.0", "", pageInput) So(err, ShouldBeNil) So(len(cveList), ShouldEqual, 0) So(pageInfo.ItemCount, ShouldEqual, 0) So(pageInfo.TotalCount, ShouldEqual, 0) // Image is multiarch - cveList, pageInfo, err = cveInfo.GetCVEListForImage(repoMultiarch, "tagIndex", "", pageInput) + cveList, pageInfo, err = cveInfo.GetCVEListForImage(ctx, repoMultiarch, "tagIndex", "", pageInput) So(err, ShouldBeNil) So(len(cveList), ShouldEqual, 1) So(cveList[0].ID, ShouldEqual, "CVE1") @@ -1174,35 +1176,35 @@ func TestCVEStruct(t *testing.T) { //nolint:gocyclo So(pageInfo.TotalCount, ShouldEqual, 1) // Image is not scannable - cveList, pageInfo, err = cveInfo.GetCVEListForImage(repo2, "1.0.0", "", pageInput) + cveList, pageInfo, err = cveInfo.GetCVEListForImage(ctx, repo2, "1.0.0", "", pageInput) So(err, ShouldEqual, zerr.ErrScanNotSupported) So(len(cveList), ShouldEqual, 0) So(pageInfo.ItemCount, ShouldEqual, 0) So(pageInfo.TotalCount, ShouldEqual, 0) // Tag is not found - cveList, pageInfo, err = cveInfo.GetCVEListForImage(repo3, "1.0.0", "", pageInput) + cveList, pageInfo, err = cveInfo.GetCVEListForImage(ctx, repo3, "1.0.0", "", pageInput) So(err, ShouldEqual, zerr.ErrTagMetaNotFound) So(len(cveList), ShouldEqual, 0) So(pageInfo.ItemCount, ShouldEqual, 0) So(pageInfo.TotalCount, ShouldEqual, 0) // Scan failed - cveList, pageInfo, err = cveInfo.GetCVEListForImage(repo7, "1.0.0", "", pageInput) + cveList, pageInfo, err = cveInfo.GetCVEListForImage(ctx, repo7, "1.0.0", "", pageInput) So(err, ShouldEqual, ErrFailedScan) So(len(cveList), ShouldEqual, 0) So(pageInfo.ItemCount, ShouldEqual, 0) So(pageInfo.TotalCount, ShouldEqual, 0) // Tag is not found - cveList, pageInfo, err = cveInfo.GetCVEListForImage("repo-with-bad-tag-digest", "tag", "", pageInput) + cveList, pageInfo, err = cveInfo.GetCVEListForImage(ctx, "repo-with-bad-tag-digest", "tag", "", pageInput) So(err, ShouldEqual, zerr.ErrImageMetaNotFound) So(len(cveList), ShouldEqual, 0) So(pageInfo.ItemCount, ShouldEqual, 0) So(pageInfo.TotalCount, ShouldEqual, 0) // Repo is not found - cveList, pageInfo, err = cveInfo.GetCVEListForImage(repo100, "1.0.0", "", pageInput) + cveList, pageInfo, err = cveInfo.GetCVEListForImage(ctx, repo100, "1.0.0", "", pageInput) So(err, ShouldEqual, zerr.ErrRepoMetaNotFound) So(len(cveList), ShouldEqual, 0) So(pageInfo.ItemCount, ShouldEqual, 0) @@ -1212,51 +1214,51 @@ func TestCVEStruct(t *testing.T) { //nolint:gocyclo t.Log("\nTest GetCVESummaryForImage\n") // Image is found - cveSummary, err := cveInfo.GetCVESummaryForImageMedia(repo1, image11Digest, image11Media) + cveSummary, err := cveInfo.GetCVESummaryForImageMedia(ctx, repo1, image11Digest, image11Media) So(err, ShouldBeNil) So(cveSummary.Count, ShouldEqual, 1) So(cveSummary.MaxSeverity, ShouldEqual, "MEDIUM") - cveSummary, err = cveInfo.GetCVESummaryForImageMedia(repo1, image12Digest, image12Media) + cveSummary, err = cveInfo.GetCVESummaryForImageMedia(ctx, repo1, image12Digest, image12Media) So(err, ShouldBeNil) So(cveSummary.Count, ShouldEqual, 3) So(cveSummary.MaxSeverity, ShouldEqual, "HIGH") - cveSummary, err = cveInfo.GetCVESummaryForImageMedia(repo1, image14Digest, image14Media) + cveSummary, err = cveInfo.GetCVESummaryForImageMedia(ctx, repo1, image14Digest, image14Media) So(err, ShouldBeNil) So(cveSummary.Count, ShouldEqual, 2) So(cveSummary.MaxSeverity, ShouldEqual, "MEDIUM") - cveSummary, err = cveInfo.GetCVESummaryForImageMedia(repo1, image13Digest, image13Media) + cveSummary, err = cveInfo.GetCVESummaryForImageMedia(ctx, repo1, image13Digest, image13Media) So(err, ShouldBeNil) So(cveSummary.Count, ShouldEqual, 1) So(cveSummary.MaxSeverity, ShouldEqual, "LOW") - cveSummary, err = cveInfo.GetCVESummaryForImageMedia(repo6, image61Digest, image61Media) + cveSummary, err = cveInfo.GetCVESummaryForImageMedia(ctx, repo6, image61Digest, image61Media) So(err, ShouldBeNil) So(cveSummary.Count, ShouldEqual, 0) So(cveSummary.MaxSeverity, ShouldEqual, "NONE") // Image is multiarch - cveSummary, err = cveInfo.GetCVESummaryForImageMedia(repoMultiarch, indexDigest, indexMedia) + cveSummary, err = cveInfo.GetCVESummaryForImageMedia(ctx, repoMultiarch, indexDigest, indexMedia) So(err, ShouldBeNil) So(cveSummary.Count, ShouldEqual, 1) So(cveSummary.MaxSeverity, ShouldEqual, "MEDIUM") // Image is not scannable - cveSummary, err = cveInfo.GetCVESummaryForImageMedia(repo2, image21Digest, image21Media) + cveSummary, err = cveInfo.GetCVESummaryForImageMedia(ctx, repo2, image21Digest, image21Media) So(err, ShouldEqual, zerr.ErrScanNotSupported) So(cveSummary.Count, ShouldEqual, 0) So(cveSummary.MaxSeverity, ShouldEqual, "") // Scan failed - cveSummary, err = cveInfo.GetCVESummaryForImageMedia(repo5, image71Digest, image71Media) + cveSummary, err = cveInfo.GetCVESummaryForImageMedia(ctx, repo5, image71Digest, image71Media) So(err, ShouldBeNil) So(cveSummary.Count, ShouldEqual, 0) So(cveSummary.MaxSeverity, ShouldEqual, "") // Repo is not found - cveSummary, err = cveInfo.GetCVESummaryForImageMedia(repo100, + cveSummary, err = cveInfo.GetCVESummaryForImageMedia(ctx, repo100, godigest.FromString("missing_digest").String(), ispec.MediaTypeImageManifest) So(err, ShouldEqual, zerr.ErrRepoMetaNotFound) So(cveSummary.Count, ShouldEqual, 0) @@ -1265,19 +1267,19 @@ func TestCVEStruct(t *testing.T) { //nolint:gocyclo t.Log("\nTest GetImageListWithCVEFixed\n") // Image is found - tagList, err := cveInfo.GetImageListWithCVEFixed(repo1, "CVE1") + tagList, err := cveInfo.GetImageListWithCVEFixed(ctx, repo1, "CVE1") So(err, ShouldBeNil) So(len(tagList), ShouldEqual, 1) So(tagList[0].Tag, ShouldEqual, "1.1.0") - tagList, err = cveInfo.GetImageListWithCVEFixed(repo1, "CVE2") + tagList, err = cveInfo.GetImageListWithCVEFixed(ctx, repo1, "CVE2") So(err, ShouldBeNil) So(len(tagList), ShouldEqual, 2) expectedTags := []string{"1.0.1", "1.1.0"} So(expectedTags, ShouldContain, tagList[0].Tag) So(expectedTags, ShouldContain, tagList[1].Tag) - tagList, err = cveInfo.GetImageListWithCVEFixed(repo1, "CVE3") + tagList, err = cveInfo.GetImageListWithCVEFixed(ctx, repo1, "CVE3") So(err, ShouldBeNil) // CVE3 is not present in 0.1.0, but that is older than all other // images where it is present. The rest of the images explicitly have it. @@ -1285,13 +1287,13 @@ func TestCVEStruct(t *testing.T) { //nolint:gocyclo So(len(tagList), ShouldEqual, 0) // Image doesn't have any CVEs in the first place - tagList, err = cveInfo.GetImageListWithCVEFixed(repo6, "CVE1") + tagList, err = cveInfo.GetImageListWithCVEFixed(ctx, repo6, "CVE1") So(err, ShouldBeNil) So(len(tagList), ShouldEqual, 1) So(tagList[0].Tag, ShouldEqual, "1.0.0") // Image is not scannable - tagList, err = cveInfo.GetImageListWithCVEFixed(repo2, "CVE100") + tagList, err = cveInfo.GetImageListWithCVEFixed(ctx, repo2, "CVE100") // CVE is not considered fixed as scan is not possible // but do not return an error So(err, ShouldBeNil) @@ -1299,14 +1301,14 @@ func TestCVEStruct(t *testing.T) { //nolint:gocyclo // Repo is not found, there could potentially be unaffected tags in the repo // but we can't access their data - tagList, err = cveInfo.GetImageListWithCVEFixed(repo100, "CVE100") + tagList, err = cveInfo.GetImageListWithCVEFixed(ctx, repo100, "CVE100") So(err, ShouldEqual, zerr.ErrRepoMetaNotFound) So(len(tagList), ShouldEqual, 0) t.Log("\nTest GetImageListForCVE\n") // Image is found - tagList, err = cveInfo.GetImageListForCVE(repo1, "CVE1") + tagList, err = cveInfo.GetImageListForCVE(ctx, repo1, "CVE1") So(err, ShouldBeNil) So(len(tagList), ShouldEqual, 3) expectedTags = []string{"0.1.0", "1.0.0", "1.0.1"} @@ -1314,12 +1316,12 @@ func TestCVEStruct(t *testing.T) { //nolint:gocyclo So(expectedTags, ShouldContain, tagList[1].Tag) So(expectedTags, ShouldContain, tagList[2].Tag) - tagList, err = cveInfo.GetImageListForCVE(repo1, "CVE2") + tagList, err = cveInfo.GetImageListForCVE(ctx, repo1, "CVE2") So(err, ShouldBeNil) So(len(tagList), ShouldEqual, 1) So(tagList[0].Tag, ShouldEqual, "1.0.0") - tagList, err = cveInfo.GetImageListForCVE(repo1, "CVE3") + tagList, err = cveInfo.GetImageListForCVE(ctx, repo1, "CVE3") So(err, ShouldBeNil) So(len(tagList), ShouldEqual, 3) expectedTags = []string{"1.0.0", "1.0.1", "1.1.0"} @@ -1328,32 +1330,32 @@ func TestCVEStruct(t *testing.T) { //nolint:gocyclo So(expectedTags, ShouldContain, tagList[2].Tag) // Image/repo doesn't have the CVE at all - tagList, err = cveInfo.GetImageListForCVE(repo6, "CVE1") + tagList, err = cveInfo.GetImageListForCVE(ctx, repo6, "CVE1") So(err, ShouldBeNil) So(len(tagList), ShouldEqual, 0) // Image is not scannable - tagList, err = cveInfo.GetImageListForCVE(repo2, "CVE100") + tagList, err = cveInfo.GetImageListForCVE(ctx, repo2, "CVE100") // Image is not considered affected with CVE as scan is not possible // but do not return an error So(err, ShouldBeNil) So(len(tagList), ShouldEqual, 0) // Tag is not found, but we should not error - tagList, err = cveInfo.GetImageListForCVE(repo3, "CVE101") + tagList, err = cveInfo.GetImageListForCVE(ctx, repo3, "CVE101") So(err, ShouldBeNil) So(len(tagList), ShouldEqual, 0) // Repo is not found, assume it is affected by the CVE // But we don't have enough of its data to actually return it - tagList, err = cveInfo.GetImageListForCVE(repo100, "CVE100") + tagList, err = cveInfo.GetImageListForCVE(ctx, repo100, "CVE100") So(err, ShouldEqual, zerr.ErrRepoMetaNotFound) So(len(tagList), ShouldEqual, 0) t.Log("\nTest errors while scanning\n") faultyScanner := mocks.CveScannerMock{ - ScanImageFn: func(image string) (map[string]cvemodel.CVE, error) { + ScanImageFn: func(ctx context.Context, image string) (map[string]cvemodel.CVE, error) { // Could be any type of error, let's reuse this one return nil, zerr.ErrScanNotSupported }, @@ -1361,24 +1363,24 @@ func TestCVEStruct(t *testing.T) { //nolint:gocyclo cveInfo = cveinfo.BaseCveInfo{Log: log, Scanner: faultyScanner, MetaDB: metaDB} - cveSummary, err = cveInfo.GetCVESummaryForImageMedia(repo1, image11Digest, image11Media) + cveSummary, err = cveInfo.GetCVESummaryForImageMedia(ctx, repo1, image11Digest, image11Media) So(err, ShouldBeNil) So(cveSummary.Count, ShouldEqual, 0) So(cveSummary.MaxSeverity, ShouldEqual, "") - cveList, pageInfo, err = cveInfo.GetCVEListForImage(repo1, "0.1.0", "", pageInput) + cveList, pageInfo, err = cveInfo.GetCVEListForImage(ctx, repo1, "0.1.0", "", pageInput) So(err, ShouldNotBeNil) So(cveList, ShouldBeEmpty) So(pageInfo.ItemCount, ShouldEqual, 0) So(pageInfo.TotalCount, ShouldEqual, 0) - tagList, err = cveInfo.GetImageListWithCVEFixed(repo1, "CVE1") + tagList, err = cveInfo.GetImageListWithCVEFixed(ctx, repo1, "CVE1") // CVE is not considered fixed as scan is not possible // but do not return an error So(err, ShouldBeNil) So(len(tagList), ShouldEqual, 0) - tagList, err = cveInfo.GetImageListForCVE(repo1, "CVE1") + tagList, err = cveInfo.GetImageListForCVE(ctx, repo1, "CVE1") // Image is not considered affected with CVE as scan is not possible // but do not return an error So(err, ShouldBeNil) @@ -1390,19 +1392,19 @@ func TestCVEStruct(t *testing.T) { //nolint:gocyclo }, }, MetaDB: metaDB} - _, err = cveInfo.GetImageListForCVE(repoMultiarch, "CVE1") + _, err = cveInfo.GetImageListForCVE(ctx, repoMultiarch, "CVE1") So(err, ShouldBeNil) cveInfo = cveinfo.BaseCveInfo{Log: log, Scanner: mocks.CveScannerMock{ IsImageFormatScannableFn: func(repo, reference string) (bool, error) { return true, nil }, - ScanImageFn: func(image string) (map[string]cvemodel.CVE, error) { + ScanImageFn: func(ctx context.Context, image string) (map[string]cvemodel.CVE, error) { return nil, zerr.ErrTypeAssertionFailed }, }, MetaDB: metaDB} - _, err = cveInfo.GetImageListForCVE(repoMultiarch, "CVE1") + _, err = cveInfo.GetImageListForCVE(ctx, repoMultiarch, "CVE1") So(err, ShouldBeNil) }) } @@ -1545,7 +1547,7 @@ func TestFixedTagsWithIndex(t *testing.T) { cveInfo := cveinfo.NewCVEInfo(ctlr.CveScanner, ctlr.MetaDB, ctlr.Log) - tagsInfo, err := cveInfo.GetImageListWithCVEFixed("repo", Vulnerability1ID) + tagsInfo, err := cveInfo.GetImageListWithCVEFixed(context.Background(), "repo", Vulnerability1ID) So(err, ShouldBeNil) So(len(tagsInfo), ShouldEqual, 1) So(len(tagsInfo[0].Manifests), ShouldEqual, 1) @@ -1593,7 +1595,7 @@ func TestGetCVESummaryForImageMediaErrors(t *testing.T) { cveInfo := cveinfo.NewCVEInfo(scanner, metaDB, log) - _, err := cveInfo.GetCVESummaryForImageMedia("repo", "digest", ispec.MediaTypeImageManifest) + _, err := cveInfo.GetCVESummaryForImageMedia(context.Background(), "repo", "digest", ispec.MediaTypeImageManifest) So(err, ShouldNotBeNil) }) }) diff --git a/pkg/extensions/search/cve/pagination_test.go b/pkg/extensions/search/cve/pagination_test.go index e374e3b394..a154ddd4db 100644 --- a/pkg/extensions/search/cve/pagination_test.go +++ b/pkg/extensions/search/cve/pagination_test.go @@ -73,7 +73,7 @@ func TestCVEPagination(t *testing.T) { // Setup test CVE data in mock scanner scanner := mocks.CveScannerMock{ - ScanImageFn: func(image string) (map[string]cvemodel.CVE, error) { + ScanImageFn: func(ctx context.Context, image string) (map[string]cvemodel.CVE, error) { cveMap := map[string]cvemodel.CVE{} if image == "repo1:0.1.0" { @@ -106,6 +106,8 @@ func TestCVEPagination(t *testing.T) { log := log.NewLogger("debug", "") cveInfo := cveinfo.BaseCveInfo{Log: log, Scanner: scanner, MetaDB: metaDB} + ctx := context.Background() + Convey("create new paginator errors", func() { paginator, err := cveinfo.NewCvePageFinder(-1, 10, cveinfo.AlphabeticAsc) So(paginator, ShouldBeNil) @@ -138,7 +140,7 @@ func TestCVEPagination(t *testing.T) { Convey("Page", func() { Convey("defaults", func() { // By default expect unlimitted results sorted by severity - cves, pageInfo, err := cveInfo.GetCVEListForImage("repo1", "0.1.0", "", cvemodel.PageInput{}) + cves, pageInfo, err := cveInfo.GetCVEListForImage(ctx, "repo1", "0.1.0", "", cvemodel.PageInput{}) So(err, ShouldBeNil) So(len(cves), ShouldEqual, 5) So(pageInfo.ItemCount, ShouldEqual, 5) @@ -149,7 +151,7 @@ func TestCVEPagination(t *testing.T) { previousSeverity = severityToInt[cve.Severity] } - cves, pageInfo, err = cveInfo.GetCVEListForImage("repo1", "1.0.0", "", cvemodel.PageInput{}) + cves, pageInfo, err = cveInfo.GetCVEListForImage(ctx, "repo1", "1.0.0", "", cvemodel.PageInput{}) So(err, ShouldBeNil) So(len(cves), ShouldEqual, 30) So(pageInfo.ItemCount, ShouldEqual, 30) @@ -167,7 +169,7 @@ func TestCVEPagination(t *testing.T) { cveIds = append(cveIds, fmt.Sprintf("CVE%d", i)) } - cves, pageInfo, err := cveInfo.GetCVEListForImage("repo1", "0.1.0", "", + cves, pageInfo, err := cveInfo.GetCVEListForImage(ctx, "repo1", "0.1.0", "", cvemodel.PageInput{SortBy: cveinfo.AlphabeticAsc}) So(err, ShouldBeNil) So(len(cves), ShouldEqual, 5) @@ -178,7 +180,7 @@ func TestCVEPagination(t *testing.T) { } sort.Strings(cveIds) - cves, pageInfo, err = cveInfo.GetCVEListForImage("repo1", "1.0.0", "", + cves, pageInfo, err = cveInfo.GetCVEListForImage(ctx, "repo1", "1.0.0", "", cvemodel.PageInput{SortBy: cveinfo.AlphabeticAsc}) So(err, ShouldBeNil) So(len(cves), ShouldEqual, 30) @@ -189,7 +191,7 @@ func TestCVEPagination(t *testing.T) { } sort.Sort(sort.Reverse(sort.StringSlice(cveIds))) - cves, pageInfo, err = cveInfo.GetCVEListForImage("repo1", "1.0.0", "", + cves, pageInfo, err = cveInfo.GetCVEListForImage(ctx, "repo1", "1.0.0", "", cvemodel.PageInput{SortBy: cveinfo.AlphabeticDsc}) So(err, ShouldBeNil) So(len(cves), ShouldEqual, 30) @@ -199,7 +201,7 @@ func TestCVEPagination(t *testing.T) { So(cve.ID, ShouldEqual, cveIds[i]) } - cves, pageInfo, err = cveInfo.GetCVEListForImage("repo1", "1.0.0", "", + cves, pageInfo, err = cveInfo.GetCVEListForImage(ctx, "repo1", "1.0.0", "", cvemodel.PageInput{SortBy: cveinfo.SeverityDsc}) So(err, ShouldBeNil) So(len(cves), ShouldEqual, 30) @@ -218,7 +220,7 @@ func TestCVEPagination(t *testing.T) { cveIds = append(cveIds, fmt.Sprintf("CVE%d", i)) } - cves, pageInfo, err := cveInfo.GetCVEListForImage("repo1", "0.1.0", "", cvemodel.PageInput{ + cves, pageInfo, err := cveInfo.GetCVEListForImage(ctx, "repo1", "0.1.0", "", cvemodel.PageInput{ Limit: 3, Offset: 1, SortBy: cveinfo.AlphabeticAsc, @@ -232,7 +234,7 @@ func TestCVEPagination(t *testing.T) { So(cves[1].ID, ShouldEqual, "CVE2") So(cves[2].ID, ShouldEqual, "CVE3") - cves, pageInfo, err = cveInfo.GetCVEListForImage("repo1", "0.1.0", "", cvemodel.PageInput{ + cves, pageInfo, err = cveInfo.GetCVEListForImage(ctx, "repo1", "0.1.0", "", cvemodel.PageInput{ Limit: 2, Offset: 1, SortBy: cveinfo.AlphabeticDsc, @@ -245,7 +247,7 @@ func TestCVEPagination(t *testing.T) { So(cves[0].ID, ShouldEqual, "CVE3") So(cves[1].ID, ShouldEqual, "CVE2") - cves, pageInfo, err = cveInfo.GetCVEListForImage("repo1", "0.1.0", "", cvemodel.PageInput{ + cves, pageInfo, err = cveInfo.GetCVEListForImage(ctx, "repo1", "0.1.0", "", cvemodel.PageInput{ Limit: 3, Offset: 1, SortBy: cveinfo.SeverityDsc, @@ -262,7 +264,7 @@ func TestCVEPagination(t *testing.T) { } sort.Strings(cveIds) - cves, pageInfo, err = cveInfo.GetCVEListForImage("repo1", "1.0.0", "", cvemodel.PageInput{ + cves, pageInfo, err = cveInfo.GetCVEListForImage(ctx, "repo1", "1.0.0", "", cvemodel.PageInput{ Limit: 5, Offset: 20, SortBy: cveinfo.AlphabeticAsc, @@ -278,7 +280,7 @@ func TestCVEPagination(t *testing.T) { }) Convey("limit > len(cves)", func() { - cves, pageInfo, err := cveInfo.GetCVEListForImage("repo1", "0.1.0", "", cvemodel.PageInput{ + cves, pageInfo, err := cveInfo.GetCVEListForImage(ctx, "repo1", "0.1.0", "", cvemodel.PageInput{ Limit: 6, Offset: 3, SortBy: cveinfo.AlphabeticAsc, @@ -291,7 +293,7 @@ func TestCVEPagination(t *testing.T) { So(cves[0].ID, ShouldEqual, "CVE3") So(cves[1].ID, ShouldEqual, "CVE4") - cves, pageInfo, err = cveInfo.GetCVEListForImage("repo1", "0.1.0", "", cvemodel.PageInput{ + cves, pageInfo, err = cveInfo.GetCVEListForImage(ctx, "repo1", "0.1.0", "", cvemodel.PageInput{ Limit: 6, Offset: 3, SortBy: cveinfo.AlphabeticDsc, @@ -304,7 +306,7 @@ func TestCVEPagination(t *testing.T) { So(cves[0].ID, ShouldEqual, "CVE1") So(cves[1].ID, ShouldEqual, "CVE0") - cves, pageInfo, err = cveInfo.GetCVEListForImage("repo1", "0.1.0", "", cvemodel.PageInput{ + cves, pageInfo, err = cveInfo.GetCVEListForImage(ctx, "repo1", "0.1.0", "", cvemodel.PageInput{ Limit: 6, Offset: 3, SortBy: cveinfo.SeverityDsc, diff --git a/pkg/extensions/search/cve/scan.go b/pkg/extensions/search/cve/scan.go index 3f3e1980d2..422807e97d 100644 --- a/pkg/extensions/search/cve/scan.go +++ b/pkg/extensions/search/cve/scan.go @@ -183,7 +183,7 @@ func (st *scanTask) DoWork(ctx context.Context) error { // We cache the results internally in the scanner // so we can discard the actual results for now - if _, err := st.generator.scanner.ScanImage(image); err != nil { + if _, err := st.generator.scanner.ScanImage(ctx, image); err != nil { st.generator.log.Error().Err(err).Str("image", image).Msg("Scheduled CVE scan errored for image") st.generator.addError(st.digest, err) diff --git a/pkg/extensions/search/cve/scan_test.go b/pkg/extensions/search/cve/scan_test.go index e6078a5206..41873a7aee 100644 --- a/pkg/extensions/search/cve/scan_test.go +++ b/pkg/extensions/search/cve/scan_test.go @@ -216,7 +216,7 @@ func TestScanGeneratorWithMockedData(t *testing.T) { //nolint: gocyclo // MetaDB loaded with initial data, now mock the scanner // Setup test CVE data in mock scanner scanner := mocks.CveScannerMock{ - ScanImageFn: func(image string) (map[string]cvemodel.CVE, error) { + ScanImageFn: func(ctx context.Context, image string) (map[string]cvemodel.CVE, error) { result := cache.Get(image) // Will not match sending the repo:tag as a parameter, but we don't care if result != nil { @@ -408,7 +408,7 @@ func TestScanGeneratorWithMockedData(t *testing.T) { //nolint: gocyclo IsResultCachedFn: func(digest string) bool { return cache.Contains(digest) }, - UpdateDBFn: func() error { + UpdateDBFn: func(ctx context.Context) error { cache.Purge() return nil @@ -416,7 +416,7 @@ func TestScanGeneratorWithMockedData(t *testing.T) { //nolint: gocyclo } // Purge scan, it should not be needed - So(scanner.UpdateDB(), ShouldBeNil) + So(scanner.UpdateDB(context.Background()), ShouldBeNil) // Verify none of the entries are cached to begin with t.Log("verify cache is initially empty") @@ -515,7 +515,7 @@ func TestScanGeneratorWithRealData(t *testing.T) { So(err, ShouldBeNil) scanner := cveinfo.NewScanner(storeController, metaDB, "ghcr.io/project-zot/trivy-db", "", logger) - err = scanner.UpdateDB() + err = scanner.UpdateDB(context.Background()) So(err, ShouldBeNil) So(scanner.IsResultCached(image.DigestStr()), ShouldBeFalse) @@ -551,7 +551,7 @@ func TestScanGeneratorWithRealData(t *testing.T) { So(scanner.IsResultCached(image.DigestStr()), ShouldBeTrue) - cveMap, err := scanner.ScanImage("zot-test:0.0.1") + cveMap, err := scanner.ScanImage(context.Background(), "zot-test:0.0.1") So(err, ShouldBeNil) t.Logf("cveMap: %v", cveMap) // As of September 22 2023 there are 5 CVEs: @@ -567,7 +567,7 @@ func TestScanGeneratorWithRealData(t *testing.T) { cveInfo := cveinfo.NewCVEInfo(scanner, metaDB, logger) // Based on cache population only, no extra scanning - cveSummary, err := cveInfo.GetCVESummaryForImageMedia("zot-test", image.DigestStr(), + cveSummary, err := cveInfo.GetCVESummaryForImageMedia(context.Background(), "zot-test", image.DigestStr(), image.ManifestDescriptor.MediaType) So(err, ShouldBeNil) So(cveSummary.Count, ShouldBeGreaterThanOrEqualTo, 5) diff --git a/pkg/extensions/search/cve/trivy/scanner.go b/pkg/extensions/search/cve/trivy/scanner.go index 38c6f1e9b9..a337c03b40 100644 --- a/pkg/extensions/search/cve/trivy/scanner.go +++ b/pkg/extensions/search/cve/trivy/scanner.go @@ -157,9 +157,7 @@ func (scanner Scanner) getTrivyOptions(image string) flag.Options { return opts } -func (scanner Scanner) runTrivy(opts flag.Options) (types.Report, error) { - ctx := context.Background() - +func (scanner Scanner) runTrivy(ctx context.Context, opts flag.Options) (types.Report, error) { err := scanner.checkDBPresence() if err != nil { return types.Report{}, err @@ -191,7 +189,7 @@ func (scanner Scanner) IsImageFormatScannable(repo, ref string) (bool, error) { ) if zcommon.IsTag(ref) { - imgDescriptor, err := getImageDescriptor(scanner.metaDB, repo, ref) + imgDescriptor, err := getImageDescriptor(context.Background(), scanner.metaDB, repo, ref) if err != nil { return false, err } @@ -316,7 +314,7 @@ func (scanner Scanner) GetCachedResult(digest string) map[string]cvemodel.CVE { return scanner.cache.Get(digest) } -func (scanner Scanner) ScanImage(image string) (map[string]cvemodel.CVE, error) { +func (scanner Scanner) ScanImage(ctx context.Context, image string) (map[string]cvemodel.CVE, error) { var ( originalImageInput = image digest string @@ -328,7 +326,7 @@ func (scanner Scanner) ScanImage(image string) (map[string]cvemodel.CVE, error) digest = ref if isTag { - imgDescriptor, err := getImageDescriptor(scanner.metaDB, repo, ref) + imgDescriptor, err := getImageDescriptor(ctx, scanner.metaDB, repo, ref) if err != nil { return map[string]cvemodel.CVE{}, err } @@ -351,9 +349,9 @@ func (scanner Scanner) ScanImage(image string) (map[string]cvemodel.CVE, error) switch mediaType { case ispec.MediaTypeImageIndex: - cveIDMap, err = scanner.scanIndex(repo, digest) + cveIDMap, err = scanner.scanIndex(ctx, repo, digest) default: - cveIDMap, err = scanner.scanManifest(repo, digest) + cveIDMap, err = scanner.scanManifest(ctx, repo, digest) } if err != nil { @@ -365,7 +363,7 @@ func (scanner Scanner) ScanImage(image string) (map[string]cvemodel.CVE, error) return cveIDMap, nil } -func (scanner Scanner) scanManifest(repo, digest string) (map[string]cvemodel.CVE, error) { +func (scanner Scanner) scanManifest(ctx context.Context, repo, digest string) (map[string]cvemodel.CVE, error) { if cachedMap := scanner.cache.Get(digest); cachedMap != nil { return cachedMap, nil } @@ -375,7 +373,7 @@ func (scanner Scanner) scanManifest(repo, digest string) (map[string]cvemodel.CV scanner.dbLock.Lock() opts := scanner.getTrivyOptions(image) - report, err := scanner.runTrivy(opts) + report, err := scanner.runTrivy(ctx, opts) scanner.dbLock.Unlock() if err != nil { //nolint: wsl @@ -441,7 +439,7 @@ func (scanner Scanner) scanManifest(repo, digest string) (map[string]cvemodel.CV return cveidMap, nil } -func (scanner Scanner) scanIndex(repo, digest string) (map[string]cvemodel.CVE, error) { +func (scanner Scanner) scanIndex(ctx context.Context, repo, digest string) (map[string]cvemodel.CVE, error) { if cachedMap := scanner.cache.Get(digest); cachedMap != nil { return cachedMap, nil } @@ -459,7 +457,7 @@ func (scanner Scanner) scanIndex(repo, digest string) (map[string]cvemodel.CVE, for _, manifest := range indexData.Index.Manifests { if isScannable, err := scanner.isManifestScanable(manifest.Digest.String()); isScannable && err == nil { - manifestCveIDMap, err := scanner.scanManifest(repo, manifest.Digest.String()) + manifestCveIDMap, err := scanner.scanManifest(ctx, repo, manifest.Digest.String()) if err != nil { return nil, err } @@ -476,7 +474,7 @@ func (scanner Scanner) scanIndex(repo, digest string) (map[string]cvemodel.CVE, } // UpdateDB downloads the Trivy DB / Cache under the store root directory. -func (scanner Scanner) UpdateDB() error { +func (scanner Scanner) UpdateDB(ctx context.Context) error { // We need a lock as using multiple substores each with its own DB // can result in a DATARACE because some varibles in trivy-db are global // https://github.com/project-zot/trivy-db/blob/main/pkg/db/db.go#L23 @@ -486,7 +484,7 @@ func (scanner Scanner) UpdateDB() error { if scanner.storeController.DefaultStore != nil { dbDir := path.Join(scanner.storeController.DefaultStore.RootDir(), "_trivy") - err := scanner.updateDB(dbDir) + err := scanner.updateDB(ctx, dbDir) if err != nil { return err } @@ -496,7 +494,7 @@ func (scanner Scanner) UpdateDB() error { for _, storage := range scanner.storeController.SubStore { dbDir := path.Join(storage.RootDir(), "_trivy") - err := scanner.updateDB(dbDir) + err := scanner.updateDB(ctx, dbDir) if err != nil { return err } @@ -508,11 +506,9 @@ func (scanner Scanner) UpdateDB() error { return nil } -func (scanner Scanner) updateDB(dbDir string) error { +func (scanner Scanner) updateDB(ctx context.Context, dbDir string) error { scanner.log.Debug().Str("dbDir", dbDir).Msg("Download Trivy DB to destination dir") - ctx := context.Background() - registryOpts := fanalTypes.RegistryOptions{Insecure: false} scanner.log.Debug().Str("dbDir", dbDir).Msg("Started downloading Trivy DB to destination dir") @@ -569,8 +565,8 @@ func (scanner Scanner) checkDBPresence() error { return nil } -func getImageDescriptor(metaDB mTypes.MetaDB, repo, tag string) (mTypes.Descriptor, error) { - repoMeta, err := metaDB.GetRepoMeta(context.Background(), repo) +func getImageDescriptor(ctx context.Context, metaDB mTypes.MetaDB, repo, tag string) (mTypes.Descriptor, error) { + repoMeta, err := metaDB.GetRepoMeta(ctx, repo) if err != nil { return mTypes.Descriptor{}, err } diff --git a/pkg/extensions/search/cve/trivy/scanner_internal_test.go b/pkg/extensions/search/cve/trivy/scanner_internal_test.go index ecfd3c8e84..ce0be37d75 100644 --- a/pkg/extensions/search/cve/trivy/scanner_internal_test.go +++ b/pkg/extensions/search/cve/trivy/scanner_internal_test.go @@ -128,46 +128,56 @@ func TestMultipleStoragePath(t *testing.T) { So(err, ShouldBeNil) // Try to scan without the DB being downloaded - _, err = scanner.ScanImage(img0) + _, err = scanner.ScanImage(context.Background(), img0) So(err, ShouldNotBeNil) So(err, ShouldWrap, zerr.ErrCVEDBNotFound) + // Try to scan with a context done + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err = scanner.ScanImage(ctx, img0) + So(err, ShouldNotBeNil) + + ctx = context.Background() + // Download DB since DB download on scan is disabled - err = scanner.UpdateDB() + err = scanner.UpdateDB(ctx) So(err, ShouldBeNil) // Scanning image in default store - cveMap, err := scanner.ScanImage(img0) + cveMap, err := scanner.ScanImage(ctx, img0) So(err, ShouldBeNil) So(len(cveMap), ShouldEqual, 0) // Scanning image in substore - cveMap, err = scanner.ScanImage(img1) + cveMap, err = scanner.ScanImage(ctx, img1) So(err, ShouldBeNil) So(len(cveMap), ShouldEqual, 0) // Scanning image which does not exist - cveMap, err = scanner.ScanImage("a/test/image2:tag100") + cveMap, err = scanner.ScanImage(ctx, "a/test/image2:tag100") So(err, ShouldNotBeNil) So(len(cveMap), ShouldEqual, 0) // Download the DB to a default store location without permissions err = os.Chmod(firstRootDir, 0o000) So(err, ShouldBeNil) - err = scanner.UpdateDB() + err = scanner.UpdateDB(ctx) So(err, ShouldNotBeNil) // Check the download works correctly when permissions allow err = os.Chmod(firstRootDir, 0o777) So(err, ShouldBeNil) - err = scanner.UpdateDB() + err = scanner.UpdateDB(ctx) So(err, ShouldBeNil) // Download the DB to a substore location without permissions err = os.Chmod(secondRootDir, 0o000) So(err, ShouldBeNil) - err = scanner.UpdateDB() + err = scanner.UpdateDB(ctx) So(err, ShouldNotBeNil) err = os.Chmod(secondRootDir, 0o777) @@ -210,12 +220,22 @@ func TestTrivyLibraryErrors(t *testing.T) { // Download DB fails for missing DB url scanner := NewScanner(storeController, metaDB, "", "", log) - err = scanner.UpdateDB() + // Try to scan with a context done + cancelCtx, cancel := context.WithCancel(context.Background()) + cancel() + + opts := scanner.getTrivyOptions(img) + _, err = scanner.runTrivy(cancelCtx, opts) + So(err, ShouldNotBeNil) + + ctx := context.Background() + + err = scanner.UpdateDB(ctx) So(err, ShouldNotBeNil) // Try to scan without the DB being downloaded - opts := scanner.getTrivyOptions(img) - _, err = scanner.runTrivy(opts) + opts = scanner.getTrivyOptions(img) + _, err = scanner.runTrivy(ctx, opts) So(err, ShouldNotBeNil) So(err, ShouldWrap, zerr.ErrCVEDBNotFound) @@ -223,37 +243,45 @@ func TestTrivyLibraryErrors(t *testing.T) { scanner = NewScanner(storeController, metaDB, "ghcr.io/project-zot/trivy-db", "ghcr.io/project-zot/trivy-not-db", log) - err = scanner.UpdateDB() + err = scanner.UpdateDB(ctx) So(err, ShouldNotBeNil) // Download DB passes for valid Trivy DB url, and missing Trivy Java DB url // Download DB is necessary since DB download on scan is disabled scanner = NewScanner(storeController, metaDB, "ghcr.io/project-zot/trivy-db", "", log) - err = scanner.UpdateDB() + err = scanner.UpdateDB(ctx) So(err, ShouldBeNil) + // UpdateDB with context done + err = scanner.UpdateDB(cancelCtx) + So(err, ShouldNotBeNil) + // Scanning image with correct options opts = scanner.getTrivyOptions(img) - _, err = scanner.runTrivy(opts) + _, err = scanner.runTrivy(ctx, opts) So(err, ShouldBeNil) + // Scanning image with context done + _, err = scanner.runTrivy(cancelCtx, opts) + So(err, ShouldNotBeNil) + // Scanning image with incorrect cache options // to trigger runner initialization errors opts.CacheOptions.CacheBackend = "redis://asdf!$%&!*)(" - _, err = scanner.runTrivy(opts) + _, err = scanner.runTrivy(ctx, opts) So(err, ShouldNotBeNil) // Scanning image with invalid input to trigger a scanner error opts = scanner.getTrivyOptions("nilnonexisting_image:0.0.1") - _, err = scanner.runTrivy(opts) + _, err = scanner.runTrivy(ctx, opts) So(err, ShouldNotBeNil) // Scanning image with incorrect report options // to trigger report filtering errors opts = scanner.getTrivyOptions(img) opts.ReportOptions.IgnorePolicy = "invalid file path" - _, err = scanner.runTrivy(opts) + _, err = scanner.runTrivy(ctx, opts) So(err, ShouldNotBeNil) }) } @@ -397,22 +425,23 @@ func TestDefaultTrivyDBUrl(t *testing.T) { scanner := NewScanner(storeController, metaDB, "ghcr.io/aquasecurity/trivy-db", "ghcr.io/aquasecurity/trivy-java-db", log) + ctx := context.Background() // Download DB since DB download on scan is disabled - err = scanner.UpdateDB() + err = scanner.UpdateDB(ctx) So(err, ShouldBeNil) // Scanning image img := "zot-test:0.0.1" //nolint:goconst opts := scanner.getTrivyOptions(img) - _, err = scanner.runTrivy(opts) + _, err = scanner.runTrivy(ctx, opts) So(err, ShouldBeNil) // Scanning image containing a jar file img = "zot-cve-java-test:0.0.1" opts = scanner.getTrivyOptions(img) - _, err = scanner.runTrivy(opts) + _, err = scanner.runTrivy(ctx, opts) So(err, ShouldBeNil) }) } diff --git a/pkg/extensions/search/cve/trivy/scanner_test.go b/pkg/extensions/search/cve/trivy/scanner_test.go index a4aa7470a7..ef69fac6c8 100644 --- a/pkg/extensions/search/cve/trivy/scanner_test.go +++ b/pkg/extensions/search/cve/trivy/scanner_test.go @@ -3,6 +3,7 @@ package trivy_test import ( + "context" "errors" "path/filepath" "testing" @@ -61,10 +62,10 @@ func TestScanBigTestFile(t *testing.T) { // scan scanner := trivy.NewScanner(ctlr.StoreController, ctlr.MetaDB, "ghcr.io/project-zot/trivy-db", "", ctlr.Log) - err = scanner.UpdateDB() + err = scanner.UpdateDB(context.Background()) So(err, ShouldBeNil) - cveMap, err := scanner.ScanImage("zot-test:0.0.1") + cveMap, err := scanner.ScanImage(context.Background(), "zot-test:0.0.1") So(err, ShouldBeNil) So(cveMap, ShouldNotBeNil) }) @@ -105,26 +106,28 @@ func TestScanningByDigest(t *testing.T) { // scan scanner := trivy.NewScanner(ctlr.StoreController, ctlr.MetaDB, "ghcr.io/project-zot/trivy-db", "", ctlr.Log) - err = scanner.UpdateDB() + ctx := context.Background() + + err = scanner.UpdateDB(ctx) So(err, ShouldBeNil) - cveMap, err := scanner.ScanImage("multi-arch@" + vulnImage.DigestStr()) + cveMap, err := scanner.ScanImage(ctx, "multi-arch@"+vulnImage.DigestStr()) So(err, ShouldBeNil) So(cveMap, ShouldContainKey, Vulnerability1ID) So(cveMap, ShouldContainKey, Vulnerability2ID) So(cveMap, ShouldContainKey, Vulnerability3ID) - cveMap, err = scanner.ScanImage("multi-arch@" + simpleImage.DigestStr()) + cveMap, err = scanner.ScanImage(ctx, "multi-arch@"+simpleImage.DigestStr()) So(err, ShouldBeNil) So(cveMap, ShouldBeEmpty) - cveMap, err = scanner.ScanImage("multi-arch@" + multiArch.DigestStr()) + cveMap, err = scanner.ScanImage(ctx, "multi-arch@"+multiArch.DigestStr()) So(err, ShouldBeNil) So(cveMap, ShouldContainKey, Vulnerability1ID) So(cveMap, ShouldContainKey, Vulnerability2ID) So(cveMap, ShouldContainKey, Vulnerability3ID) - cveMap, err = scanner.ScanImage("multi-arch:multi-arch-tag") + cveMap, err = scanner.ScanImage(ctx, "multi-arch:multi-arch-tag") So(err, ShouldBeNil) So(cveMap, ShouldContainKey, Vulnerability1ID) So(cveMap, ShouldContainKey, Vulnerability2ID) @@ -188,10 +191,10 @@ func TestVulnerableLayer(t *testing.T) { scanner := trivy.NewScanner(storeController, metaDB, "ghcr.io/project-zot/trivy-db", "", log) - err = scanner.UpdateDB() + err = scanner.UpdateDB(context.Background()) So(err, ShouldBeNil) - cveMap, err := scanner.ScanImage("repo@" + img.DigestStr()) + cveMap, err := scanner.ScanImage(context.Background(), "repo@"+img.DigestStr()) So(err, ShouldBeNil) t.Logf("cveMap: %v", cveMap) // As of September 17 2023 there are 5 CVEs: @@ -271,7 +274,7 @@ func TestScannerErrors(t *testing.T) { scanner := trivy.NewScanner(storeController, metaDB, "ghcr.io/project-zot/trivy-db", "", log) - _, err := scanner.ScanImage("image@" + godigest.FromString("digest").String()) + _, err := scanner.ScanImage(context.Background(), "image@"+godigest.FromString("digest").String()) So(err, ShouldNotBeNil) }) }) diff --git a/pkg/extensions/search/cve/update.go b/pkg/extensions/search/cve/update.go index 871bb97436..4befebf6c5 100644 --- a/pkg/extensions/search/cve/update.go +++ b/pkg/extensions/search/cve/update.go @@ -94,7 +94,7 @@ func newDBUpdadeTask(interval time.Duration, scanner Scanner, func (dbt *dbUpdateTask) DoWork(ctx context.Context) error { dbt.log.Info().Msg("updating the CVE database") - err := dbt.scanner.UpdateDB() + err := dbt.scanner.UpdateDB(ctx) if err != nil { dbt.generator.lock.Lock() dbt.generator.status = pending diff --git a/pkg/extensions/search/resolver.go b/pkg/extensions/search/resolver.go index a6949ac814..adfe0c0069 100644 --- a/pkg/extensions/search/resolver.go +++ b/pkg/extensions/search/resolver.go @@ -211,7 +211,7 @@ func getCVEListForImage( return &gql_generated.CVEResultForImage{}, gqlerror.Errorf("no reference provided") } - cveList, pageInfo, err := cveInfo.GetCVEListForImage(repo, ref, searchedCVE, pageInput) + cveList, pageInfo, err := cveInfo.GetCVEListForImage(ctx, repo, ref, searchedCVE, pageInput) if err != nil { return &gql_generated.CVEResultForImage{}, err } @@ -334,7 +334,7 @@ func getImageListForCVE( log.Info().Str("repository", repo).Str("CVE", cveID).Msg("extracting list of tags affected by CVE") - tagsInfo, err := cveInfo.GetImageListForCVE(repo, cveID) + tagsInfo, err := cveInfo.GetImageListForCVE(ctx, repo, cveID) if err != nil { log.Error().Str("repository", repo).Str("CVE", cveID).Err(err). Msg("error getting image list for CVE from repo") @@ -407,7 +407,7 @@ func getImageListWithCVEFixed( log.Info().Str("repository", repo).Str("CVE", cveID).Msg("extracting list of tags where CVE is fixed") - tagsInfo, err := cveInfo.GetImageListWithCVEFixed(repo, cveID) + tagsInfo, err := cveInfo.GetImageListWithCVEFixed(ctx, repo, cveID) if err != nil { log.Error().Str("repository", repo).Str("CVE", cveID).Err(err). Msg("error getting image list with CVE fixed from repo") diff --git a/pkg/extensions/search/resolver_test.go b/pkg/extensions/search/resolver_test.go index 0a7e7fd2d5..9629afbd64 100644 --- a/pkg/extensions/search/resolver_test.go +++ b/pkg/extensions/search/resolver_test.go @@ -1046,14 +1046,14 @@ func TestCVEResolvers(t *testing.T) { //nolint:gocyclo // MetaDB loaded with initial data, now mock the scanner // Setup test CVE data in mock scanner scanner := mocks.CveScannerMock{ - ScanImageFn: func(image string) (map[string]cvemodel.CVE, error) { + ScanImageFn: func(ctx context.Context, image string) (map[string]cvemodel.CVE, error) { repo, ref, _, _ := common.GetRepoReference(image) if common.IsDigest(ref) { return getCveResults(ref), nil } - repoMeta, _ := metaDB.GetRepoMeta(context.Background(), repo) + repoMeta, _ := metaDB.GetRepoMeta(ctx, repo) if _, ok := repoMeta.Tags[ref]; !ok { panic("unexpected tag '" + ref + "', test might be wrong") @@ -1685,7 +1685,7 @@ func TestCVEResolvers(t *testing.T) { //nolint:gocyclo ctx, "id", mocks.CveInfoMock{ - GetImageListForCVEFn: func(repo, cveID string) ([]cvemodel.TagInfo, error) { + GetImageListForCVEFn: func(ctx context.Context, repo, cveID string) ([]cvemodel.TagInfo, error) { return []cvemodel.TagInfo{}, ErrTestError }, }, diff --git a/pkg/extensions/search/search_test.go b/pkg/extensions/search/search_test.go index 9d3d92401c..e02ede3d61 100644 --- a/pkg/extensions/search/search_test.go +++ b/pkg/extensions/search/search_test.go @@ -292,7 +292,7 @@ func getMockCveScanner(metaDB mTypes.MetaDB) cveinfo.Scanner { } scanner := mocks.CveScannerMock{ - ScanImageFn: func(image string) (map[string]cvemodel.CVE, error) { + ScanImageFn: func(ctx context.Context, image string) (map[string]cvemodel.CVE, error) { return getCveResults(image), nil }, GetCachedResultFn: func(digestStr string) map[string]cvemodel.CVE { diff --git a/pkg/extensions/sync/service.go b/pkg/extensions/sync/service.go index a121c7cef1..725af08704 100644 --- a/pkg/extensions/sync/service.go +++ b/pkg/extensions/sync/service.go @@ -292,10 +292,8 @@ func (service *BaseService) SyncRepo(ctx context.Context, repo string) error { localRepo := service.contentManager.GetRepoDestination(repo) for _, tag := range tags { - select { - case <-ctx.Done(): + if common.IsContextDone(ctx) { return ctx.Err() - default: } if references.IsCosignTag(tag) || common.IsReferrersTag(tag) { diff --git a/pkg/extensions/sync/sync_test.go b/pkg/extensions/sync/sync_test.go index 35901ae285..97b78871d5 100644 --- a/pkg/extensions/sync/sync_test.go +++ b/pkg/extensions/sync/sync_test.go @@ -1883,9 +1883,6 @@ func TestConfigReloader(t *testing.T) { destConfig.Log.Output = logFile.Name() dctlr := api.NewController(destConfig) - dcm := test.NewControllerManager(dctlr) - - defer dcm.StopServer() //nolint: dupl Convey("Reload config without sync", func() { @@ -1927,6 +1924,8 @@ func TestConfigReloader(t *testing.T) { time.Sleep(100 * time.Millisecond) } + defer dctlr.Shutdown() + // let it sync time.Sleep(3 * time.Second) @@ -2075,6 +2074,8 @@ func TestConfigReloader(t *testing.T) { time.Sleep(100 * time.Millisecond) } + defer dctlr.Shutdown() + // let it sync time.Sleep(3 * time.Second) diff --git a/pkg/meta/boltdb/boltdb.go b/pkg/meta/boltdb/boltdb.go index a7a24298cf..795be2ea2f 100644 --- a/pkg/meta/boltdb/boltdb.go +++ b/pkg/meta/boltdb/boltdb.go @@ -1130,7 +1130,7 @@ func (bdw *BoltDB) UpdateStatsOnDownload(repo string, reference string) error { return err } -func (bdw *BoltDB) UpdateSignaturesValidity(repo string, manifestDigest godigest.Digest) error { +func (bdw *BoltDB) UpdateSignaturesValidity(ctx context.Context, repo string, manifestDigest godigest.Digest) error { err := bdw.DB.Update(func(transaction *bbolt.Tx) error { imgTrustStore := bdw.ImageTrustStore() @@ -1169,6 +1169,10 @@ func (bdw *BoltDB) UpdateSignaturesValidity(repo string, manifestDigest godigest manifestSignatures := proto_go.ManifestSignatures{Map: map[string]*proto_go.SignaturesInfo{"": {}}} for sigType, sigs := range protoRepoMeta.Signatures[manifestDigest.String()].Map { + if zcommon.IsContextDone(ctx) { + return ctx.Err() + } + signaturesInfo := []*proto_go.SignatureInfo{} for _, sigInfo := range sigs.List { diff --git a/pkg/meta/boltdb/boltdb_test.go b/pkg/meta/boltdb/boltdb_test.go index bd99373554..e93c9f31ea 100644 --- a/pkg/meta/boltdb/boltdb_test.go +++ b/pkg/meta/boltdb/boltdb_test.go @@ -159,16 +159,26 @@ func TestWrapperErrors(t *testing.T) { boltdbWrapper.SetImageTrustStore(imgTrustStore{}) digest := image.Digest() + ctx := context.Background() + Convey("image meta blob not found", func() { - err := boltdbWrapper.UpdateSignaturesValidity("repo", digest) + err := boltdbWrapper.UpdateSignaturesValidity(ctx, "repo", digest) So(err, ShouldBeNil) }) + Convey("UpdateSignaturesValidity with context done", func() { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := boltdbWrapper.UpdateSignaturesValidity(ctx, "repo", digest) + So(err, ShouldNotBeNil) + }) + Convey("image meta unmarshal fail", func() { err := setImageMeta(digest, badProtoBlob, boltdbWrapper.DB) So(err, ShouldBeNil) - err = boltdbWrapper.UpdateSignaturesValidity("repo", digest) + err = boltdbWrapper.UpdateSignaturesValidity(ctx, "repo", digest) So(err, ShouldNotBeNil) }) @@ -176,7 +186,7 @@ func TestWrapperErrors(t *testing.T) { err := boltdbWrapper.SetImageMeta(digest, imageMeta) So(err, ShouldBeNil) - err = boltdbWrapper.UpdateSignaturesValidity("repo", digest) + err = boltdbWrapper.UpdateSignaturesValidity(ctx, "repo", digest) So(err, ShouldNotBeNil) }) @@ -187,7 +197,7 @@ func TestWrapperErrors(t *testing.T) { err = setRepoMeta("repo", badProtoBlob, boltdbWrapper.DB) So(err, ShouldBeNil) - err = boltdbWrapper.UpdateSignaturesValidity("repo", digest) + err = boltdbWrapper.UpdateSignaturesValidity(ctx, "repo", digest) So(err, ShouldNotBeNil) }) }) diff --git a/pkg/meta/dynamodb/dynamodb.go b/pkg/meta/dynamodb/dynamodb.go index cfb84e0ec7..26da330a97 100644 --- a/pkg/meta/dynamodb/dynamodb.go +++ b/pkg/meta/dynamodb/dynamodb.go @@ -988,20 +988,20 @@ func (dwr *DynamoDB) UpdateStatsOnDownload(repo string, reference string) error return dwr.setProtoRepoMeta(repo, repoMeta) } -func (dwr *DynamoDB) UpdateSignaturesValidity(repo string, manifestDigest godigest.Digest) error { +func (dwr *DynamoDB) UpdateSignaturesValidity(ctx context.Context, repo string, manifestDigest godigest.Digest) error { imgTrustStore := dwr.ImageTrustStore() if imgTrustStore == nil { return nil } - protoImageMeta, err := dwr.GetProtoImageMeta(context.Background(), manifestDigest) + protoImageMeta, err := dwr.GetProtoImageMeta(ctx, manifestDigest) if err != nil { return err } // update signatures with details about validity and author - protoRepoMeta, err := dwr.getProtoRepoMeta(context.Background(), repo) + protoRepoMeta, err := dwr.getProtoRepoMeta(ctx, repo) if err != nil { return err } @@ -1009,6 +1009,10 @@ func (dwr *DynamoDB) UpdateSignaturesValidity(repo string, manifestDigest godige manifestSignatures := proto_go.ManifestSignatures{Map: map[string]*proto_go.SignaturesInfo{"": {}}} for sigType, sigs := range protoRepoMeta.Signatures[manifestDigest.String()].Map { + if zcommon.IsContextDone(ctx) { + return ctx.Err() + } + signaturesInfo := []*proto_go.SignatureInfo{} for _, sigInfo := range sigs.List { @@ -1041,7 +1045,7 @@ func (dwr *DynamoDB) UpdateSignaturesValidity(repo string, manifestDigest godige protoRepoMeta.Signatures[manifestDigest.String()] = &manifestSignatures - return dwr.setProtoRepoMeta(protoRepoMeta.Name, protoRepoMeta) + return dwr.setProtoRepoMeta(protoRepoMeta.Name, protoRepoMeta) //nolint: contextcheck } func (dwr *DynamoDB) AddManifestSignature(repo string, signedManifestDigest godigest.Digest, diff --git a/pkg/meta/dynamodb/dynamodb_test.go b/pkg/meta/dynamodb/dynamodb_test.go index f1e716c940..91d3160020 100644 --- a/pkg/meta/dynamodb/dynamodb_test.go +++ b/pkg/meta/dynamodb/dynamodb_test.go @@ -163,6 +163,7 @@ func TestWrapperErrors(t *testing.T) { // t.FailNow() // } + //nolint: contextcheck Convey("Errors", t, func() { params := mdynamodb.DBDriverParameters{ //nolint:contextcheck Endpoint: endpoint, @@ -257,7 +258,15 @@ func TestWrapperErrors(t *testing.T) { digest := image.Digest() Convey("image meta blob not found", func() { - err := dynamoWrapper.UpdateSignaturesValidity("repo", digest) + err := dynamoWrapper.UpdateSignaturesValidity(ctx, "repo", digest) + So(err, ShouldNotBeNil) + }) + + Convey("UpdateSignaturesValidity with context done", func() { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := dynamoWrapper.UpdateSignaturesValidity(ctx, "repo", digest) So(err, ShouldNotBeNil) }) @@ -265,7 +274,7 @@ func TestWrapperErrors(t *testing.T) { err := setImageMeta(digest, badProtoBlob, dynamoWrapper) So(err, ShouldBeNil) - err = dynamoWrapper.UpdateSignaturesValidity("repo", digest) + err = dynamoWrapper.UpdateSignaturesValidity(ctx, "repo", digest) So(err, ShouldNotBeNil) }) @@ -273,7 +282,7 @@ func TestWrapperErrors(t *testing.T) { err := dynamoWrapper.SetImageMeta(digest, imageMeta) So(err, ShouldBeNil) - err = dynamoWrapper.UpdateSignaturesValidity("repo", digest) + err = dynamoWrapper.UpdateSignaturesValidity(ctx, "repo", digest) So(err, ShouldNotBeNil) }) @@ -284,7 +293,7 @@ func TestWrapperErrors(t *testing.T) { err = setRepoMeta("repo", badProtoBlob, dynamoWrapper) So(err, ShouldBeNil) - err = dynamoWrapper.UpdateSignaturesValidity("repo", digest) + err = dynamoWrapper.UpdateSignaturesValidity(ctx, "repo", digest) So(err, ShouldNotBeNil) }) }) diff --git a/pkg/meta/meta_test.go b/pkg/meta/meta_test.go index 4c34327431..8108c44864 100644 --- a/pkg/meta/meta_test.go +++ b/pkg/meta/meta_test.go @@ -1366,7 +1366,7 @@ func RunMetaDBTests(t *testing.T, metaDB mTypes.MetaDB, preparationFuncs ...func }) So(err, ShouldBeNil) - err = metaDB.UpdateSignaturesValidity(repo1, image1.Digest()) + err = metaDB.UpdateSignaturesValidity(ctx, repo1, image1.Digest()) So(err, ShouldBeNil) repoData, err := metaDB.GetRepoMeta(ctx, repo1) @@ -1462,7 +1462,7 @@ func RunMetaDBTests(t *testing.T, metaDB mTypes.MetaDB, preparationFuncs ...func err = imagetrust.UploadCertificate(imgTrustStore.NotationStorage, certificateContent, "ca") So(err, ShouldBeNil) - err = metaDB.UpdateSignaturesValidity(repo, image1.Digest()) //nolint:contextcheck + err = metaDB.UpdateSignaturesValidity(ctx, repo, image1.Digest()) //nolint:contextcheck So(err, ShouldBeNil) repoData, err := metaDB.GetRepoMeta(ctx, repo) diff --git a/pkg/meta/parse.go b/pkg/meta/parse.go index f6979fbdd2..9e575338f4 100644 --- a/pkg/meta/parse.go +++ b/pkg/meta/parse.go @@ -300,7 +300,7 @@ func SetImageMetaFromInput(ctx context.Context, repo, reference, mediaType strin return err } - err = metaDB.UpdateSignaturesValidity(repo, signedManifestDigest) + err = metaDB.UpdateSignaturesValidity(ctx, repo, signedManifestDigest) if err != nil { log.Error().Err(err).Str("repository", repo).Str("reference", reference).Str("digest", signedManifestDigest.String()).Msg("load-repo: failed verify signatures validity for signed image") diff --git a/pkg/meta/parse_test.go b/pkg/meta/parse_test.go index 286d838e2b..ffe0953720 100644 --- a/pkg/meta/parse_test.go +++ b/pkg/meta/parse_test.go @@ -217,7 +217,9 @@ func TestParseStorageErrors(t *testing.T) { So(err, ShouldNotBeNil) }) Convey("UpdateSignaturesValidity errors", func() { - mockedMetaDB.UpdateSignaturesValidityFn = func(repo string, manifestDigest godigest.Digest) error { + mockedMetaDB.UpdateSignaturesValidityFn = func(ctx context.Context, repo string, + manifestDigest godigest.Digest, + ) error { return ErrTestError } err := meta.SetImageMetaFromInput(ctx, "repo", "tag", mediaType, goodNotationSignature.Digest(), diff --git a/pkg/meta/types/types.go b/pkg/meta/types/types.go index 0216b494c6..ecc3997a13 100644 --- a/pkg/meta/types/types.go +++ b/pkg/meta/types/types.go @@ -101,7 +101,7 @@ type MetaDB interface { //nolint:interfacebloat DeleteSignature(repo string, signedManifestDigest godigest.Digest, sigMeta SignatureMetadata) error // UpdateSignaturesValidity checks and updates signatures validity of a given manifest - UpdateSignaturesValidity(repo string, manifestDigest godigest.Digest) error + UpdateSignaturesValidity(ctx context.Context, repo string, manifestDigest godigest.Digest) error // IncrementRepoStars adds 1 to the star count of an image IncrementRepoStars(repo string) error diff --git a/pkg/retention/retention.go b/pkg/retention/retention.go index 35ed4e271f..6525159e16 100644 --- a/pkg/retention/retention.go +++ b/pkg/retention/retention.go @@ -1,6 +1,7 @@ package retention import ( + "context" "fmt" glob "github.com/bmatcuk/doublestar/v4" @@ -97,7 +98,7 @@ func (p policyManager) getRules(tagPolicy config.KeepTagsPolicy) []types.Rule { return rules } -func (p policyManager) GetRetainedTags(repoMeta mTypes.RepoMeta, index ispec.Index) []string { +func (p policyManager) GetRetainedTags(ctx context.Context, repoMeta mTypes.RepoMeta, index ispec.Index) []string { repo := repoMeta.Name matchedByName := make([]string, 0) @@ -134,6 +135,10 @@ func (p policyManager) GetRetainedTags(repoMeta mTypes.RepoMeta, index ispec.Ind grouped := p.groupCandidatesByTagPolicy(repo, candidates) for _, candidates := range grouped { + if zcommon.IsContextDone(ctx) { + return nil + } + retainCandidates := candidates.candidates // copy // tag rules rules := candidates.rules diff --git a/pkg/retention/types/types.go b/pkg/retention/types/types.go index 3e35dead9e..d960ba4c58 100644 --- a/pkg/retention/types/types.go +++ b/pkg/retention/types/types.go @@ -1,6 +1,7 @@ package types import ( + "context" "time" ispec "github.com/opencontainers/image-spec/specs-go/v1" @@ -21,7 +22,7 @@ type PolicyManager interface { HasDeleteReferrer(repo string) bool HasDeleteUntagged(repo string) bool HasTagRetention(repo string) bool - GetRetainedTags(repoMeta mTypes.RepoMeta, index ispec.Index) []string + GetRetainedTags(ctx context.Context, repoMeta mTypes.RepoMeta, index ispec.Index) []string } type Rule interface { diff --git a/pkg/scheduler/scheduler.go b/pkg/scheduler/scheduler.go index cc771d065c..53d500843c 100644 --- a/pkg/scheduler/scheduler.go +++ b/pkg/scheduler/scheduler.go @@ -5,6 +5,7 @@ import ( "context" "runtime" "sync" + "sync/atomic" "time" "zotregistry.io/zot/pkg/api/config" @@ -68,9 +69,11 @@ type Scheduler struct { waitingGenerators []*generator generatorsLock *sync.Mutex log log.Logger - stopCh chan struct{} RateLimit time.Duration NumWorkers int + workerChan chan Task + workerWg *sync.WaitGroup + isShuttingDown atomic.Bool } func NewScheduler(cfg *config.Config, logC log.Logger) *Scheduler { @@ -90,17 +93,20 @@ func NewScheduler(cfg *config.Config, logC log.Logger) *Scheduler { generators: generatorPQ, generatorsLock: new(sync.Mutex), log: log.Logger{Logger: sublogger}, - stopCh: make(chan struct{}), // default value RateLimit: rateLimit, NumWorkers: numWorkers, + workerChan: make(chan Task, numWorkers), + workerWg: new(sync.WaitGroup), } } -func (scheduler *Scheduler) poolWorker(ctx context.Context, numWorkers int, tasks chan Task) { - for i := 0; i < numWorkers; i++ { +func (scheduler *Scheduler) poolWorker(ctx context.Context) { + for i := 0; i < scheduler.NumWorkers; i++ { go func(workerID int) { - for task := range tasks { + defer scheduler.workerWg.Done() + + for task := range scheduler.workerChan { scheduler.log.Debug().Int("worker", workerID).Msg("scheduler: starting task") if err := task.DoWork(ctx); err != nil { @@ -113,33 +119,56 @@ func (scheduler *Scheduler) poolWorker(ctx context.Context, numWorkers int, task } } +func (scheduler *Scheduler) Shutdown() { + if !scheduler.inShutdown() { + scheduler.shutdown() + } + + scheduler.workerWg.Wait() +} + +func (scheduler *Scheduler) inShutdown() bool { + return scheduler.isShuttingDown.Load() +} + +func (scheduler *Scheduler) shutdown() { + close(scheduler.workerChan) + scheduler.isShuttingDown.Store(true) +} + func (scheduler *Scheduler) RunScheduler(ctx context.Context) { throttle := time.NewTicker(rateLimit).C numWorkers := scheduler.NumWorkers - tasksWorker := make(chan Task, numWorkers) + + // wait all workers to finish their work before exiting from Shutdown() + scheduler.workerWg.Add(numWorkers) // start worker pool - go scheduler.poolWorker(ctx, numWorkers, tasksWorker) + go scheduler.poolWorker(ctx) go func() { for { select { case <-ctx.Done(): - close(tasksWorker) - close(scheduler.stopCh) + if !scheduler.inShutdown() { + scheduler.shutdown() + } - scheduler.log.Debug().Msg("scheduler: received stop signal, exiting...") + scheduler.log.Debug().Msg("scheduler: received stop signal, gracefully shutting down...") return default: i := 0 for i < numWorkers { task := scheduler.getTask() + if task != nil { // push tasks into worker pool - scheduler.log.Debug().Msg("scheduler: pushing task into worker pool") - tasksWorker <- task + if !scheduler.inShutdown() { + scheduler.log.Debug().Msg("scheduler: pushing task into worker pool") + scheduler.workerChan <- task + } } i++ } @@ -251,17 +280,17 @@ func (scheduler *Scheduler) SubmitTask(task Task, priority Priority) { } // check if the scheduler it's still running in order to add the task to the channel - select { - case <-scheduler.stopCh: + if scheduler.inShutdown() { return - default: } select { - case <-scheduler.stopCh: - return case tasksQ <- task: scheduler.log.Info().Msg("scheduler: adding a new task") + default: + if scheduler.inShutdown() { + return + } } } diff --git a/pkg/storage/common/common.go b/pkg/storage/common/common.go index b058e81f35..0795b649a5 100644 --- a/pkg/storage/common/common.go +++ b/pkg/storage/common/common.go @@ -1052,7 +1052,7 @@ func newDedupeTask(imgStore storageTypes.ImageStore, digest godigest.Digest, ded func (dt *dedupeTask) DoWork(ctx context.Context) error { // run task - err := dt.imgStore.RunDedupeForDigest(dt.digest, dt.dedupe, dt.duplicateBlobs) //nolint: contextcheck + err := dt.imgStore.RunDedupeForDigest(ctx, dt.digest, dt.dedupe, dt.duplicateBlobs) //nolint: contextcheck if err != nil { // log it dt.log.Error().Err(err).Str("digest", dt.digest.String()).Msg("rebuild dedupe: failed to rebuild digest") diff --git a/pkg/storage/gc/gc.go b/pkg/storage/gc/gc.go index 0dace3d7b9..c826124c85 100644 --- a/pkg/storage/gc/gc.go +++ b/pkg/storage/gc/gc.go @@ -80,11 +80,11 @@ in any manifests referenced in repo's index.json It also gc referrers with missing subject if the Referrer Option is enabled It also gc untagged manifests. */ -func (gc GarbageCollect) CleanRepo(repo string) error { +func (gc GarbageCollect) CleanRepo(ctx context.Context, repo string) error { gc.log.Info().Str("module", "gc"). Msg(fmt.Sprintf("executing GC of orphaned blobs for %s", path.Join(gc.imgStore.RootDir(), repo))) - if err := gc.cleanRepo(repo); err != nil { + if err := gc.cleanRepo(ctx, repo); err != nil { errMessage := fmt.Sprintf("error while running GC for %s", path.Join(gc.imgStore.RootDir(), repo)) gc.log.Error().Err(err).Str("module", "gc").Msg(errMessage) gc.log.Info().Str("module", "gc"). @@ -99,7 +99,7 @@ func (gc GarbageCollect) CleanRepo(repo string) error { return nil } -func (gc GarbageCollect) cleanRepo(repo string) error { +func (gc GarbageCollect) cleanRepo(ctx context.Context, repo string) error { var lockLatency time.Time dir := path.Join(gc.imgStore.RootDir(), repo) @@ -127,12 +127,12 @@ func (gc GarbageCollect) cleanRepo(repo string) error { } // apply tags retention - if err := gc.removeTagsPerRetentionPolicy(repo, &index); err != nil { + if err := gc.removeTagsPerRetentionPolicy(ctx, repo, &index); err != nil { return err } // gc referrers manifests with missing subject and untagged manifests - if err := gc.removeManifestsPerRepoPolicy(repo, &index); err != nil { + if err := gc.removeManifestsPerRepoPolicy(ctx, repo, &index); err != nil { return err } @@ -146,20 +146,24 @@ func (gc GarbageCollect) cleanRepo(repo string) error { } // gc unreferenced blobs - if err := gc.removeUnreferencedBlobs(repo, gc.opts.Delay, gc.log); err != nil { + if err := gc.removeUnreferencedBlobs(ctx, repo, gc.opts.Delay, gc.log); err != nil { return err } return nil } -func (gc GarbageCollect) removeManifestsPerRepoPolicy(repo string, index *ispec.Index) error { +func (gc GarbageCollect) removeManifestsPerRepoPolicy(ctx context.Context, repo string, index *ispec.Index) error { var err error /* gc all manifests that have a missing subject, stop when neither gc(referrer and untagged) happened in a full loop over index.json. */ var stop bool for !stop { + if zcommon.IsContextDone(ctx) { + return ctx.Err() + } + var gcedReferrer bool var gcedUntagged bool @@ -349,22 +353,26 @@ func (gc GarbageCollect) removeReferrer(repo string, index *ispec.Index, manifes return gced, nil } -func (gc GarbageCollect) removeTagsPerRetentionPolicy(repo string, index *ispec.Index) error { +func (gc GarbageCollect) removeTagsPerRetentionPolicy(ctx context.Context, repo string, index *ispec.Index) error { if !gc.policyMgr.HasTagRetention(repo) { return nil } - repoMeta, err := gc.metaDB.GetRepoMeta(context.Background(), repo) + repoMeta, err := gc.metaDB.GetRepoMeta(ctx, repo) if err != nil { gc.log.Error().Err(err).Str("module", "gc").Str("repository", repo).Msg("can't retrieve repoMeta for repo") return err } - retainTags := gc.policyMgr.GetRetainedTags(repoMeta, *index) + retainTags := gc.policyMgr.GetRetainedTags(ctx, repoMeta, *index) // remove for _, desc := range index.Manifests { + if zcommon.IsContextDone(ctx) { + return ctx.Err() + } + // check tag tag, ok := getDescriptorTag(desc) if ok && !zcommon.Contains(retainTags, tag) { @@ -537,7 +545,7 @@ func (gc GarbageCollect) identifyManifestsReferencedInIndex(index ispec.Index, r } // removeUnreferencedBlobs gc all blobs which are not referenced by any manifest found in repo's index.json. -func (gc GarbageCollect) removeUnreferencedBlobs(repo string, delay time.Duration, log zlog.Logger, +func (gc GarbageCollect) removeUnreferencedBlobs(ctx context.Context, repo string, delay time.Duration, log zlog.Logger, ) error { gc.log.Debug().Str("module", "gc").Str("repository", repo).Msg("cleaning orphan blobs") @@ -572,6 +580,10 @@ func (gc GarbageCollect) removeUnreferencedBlobs(repo string, delay time.Duratio gcBlobs := make([]godigest.Digest, 0) for _, blob := range allBlobs { + if zcommon.IsContextDone(ctx) { + return ctx.Err() + } + digest := godigest.NewDigestFromEncoded(godigest.SHA256, blob) if err = digest.Validate(); err != nil { log.Error().Err(err).Str("module", "gc").Str("repository", repo).Str("digest", blob). @@ -841,5 +853,5 @@ func NewGCTask(imgStore types.ImageStore, gc GarbageCollect, repo string, func (gct *gcTask) DoWork(ctx context.Context) error { // run task - return gct.gc.CleanRepo(gct.repo) //nolint: contextcheck + return gct.gc.CleanRepo(ctx, gct.repo) //nolint: contextcheck } diff --git a/pkg/storage/gc/gc_internal_test.go b/pkg/storage/gc/gc_internal_test.go index 9b869df6bd..5fd7f941bd 100644 --- a/pkg/storage/gc/gc_internal_test.go +++ b/pkg/storage/gc/gc_internal_test.go @@ -292,6 +292,8 @@ func TestGarbageCollectIndexErrors(t *testing.T) { func TestGarbageCollectWithMockedImageStore(t *testing.T) { trueVal := true + ctx := context.Background() + Convey("Cover gc error paths", t, func(c C) { log := zlog.NewLogger("debug", "") audit := zlog.NewAuditLogger("debug", "") @@ -316,7 +318,7 @@ func TestGarbageCollectWithMockedImageStore(t *testing.T) { }, }, gcOptions, audit, log) - err := gc.cleanRepo(repoName) + err := gc.cleanRepo(ctx, repoName) So(err, ShouldNotBeNil) }) @@ -327,7 +329,7 @@ func TestGarbageCollectWithMockedImageStore(t *testing.T) { }, }, gcOptions, audit, log) - err := gc.removeUnreferencedBlobs("repo", time.Hour, log) + err := gc.removeUnreferencedBlobs(ctx, "repo", time.Hour, log) So(err, ShouldNotBeNil) }) @@ -366,7 +368,7 @@ func TestGarbageCollectWithMockedImageStore(t *testing.T) { }, }, gcOptions, audit, log) - err := gc.removeTagsPerRetentionPolicy("name", &ispec.Index{}) + err := gc.removeTagsPerRetentionPolicy(ctx, "name", &ispec.Index{}) So(err, ShouldNotBeNil) }) @@ -387,7 +389,7 @@ func TestGarbageCollectWithMockedImageStore(t *testing.T) { gc := NewGarbageCollect(imgStore, mocks.MetaDBMock{}, gcOptions, audit, log) - err = gc.cleanRepo(repoName) + err = gc.cleanRepo(ctx, repoName) So(err, ShouldNotBeNil) }) @@ -411,7 +413,7 @@ func TestGarbageCollectWithMockedImageStore(t *testing.T) { gc := NewGarbageCollect(imgStore, mocks.MetaDBMock{}, gcOptions, audit, log) - err = gc.cleanRepo(repoName) + err = gc.cleanRepo(ctx, repoName) So(err, ShouldNotBeNil) }) @@ -424,7 +426,7 @@ func TestGarbageCollectWithMockedImageStore(t *testing.T) { gc := NewGarbageCollect(imgStore, mocks.MetaDBMock{}, gcOptions, audit, log) - err := gc.cleanRepo(repoName) + err := gc.cleanRepo(ctx, repoName) So(err, ShouldNotBeNil) }) @@ -464,7 +466,7 @@ func TestGarbageCollectWithMockedImageStore(t *testing.T) { } gc := NewGarbageCollect(imgStore, mocks.MetaDBMock{}, gcOptions, audit, log) - err = gc.removeManifestsPerRepoPolicy(repoName, &ispec.Index{ + err = gc.removeManifestsPerRepoPolicy(ctx, repoName, &ispec.Index{ Manifests: []ispec.Descriptor{ { MediaType: ispec.MediaTypeImageIndex, @@ -493,7 +495,7 @@ func TestGarbageCollectWithMockedImageStore(t *testing.T) { gc := NewGarbageCollect(imgStore, mocks.MetaDBMock{}, gcOptions, audit, log) - err := gc.removeManifestsPerRepoPolicy(repoName, &ispec.Index{ + err := gc.removeManifestsPerRepoPolicy(ctx, repoName, &ispec.Index{ Manifests: []ispec.Descriptor{ { MediaType: ispec.MediaTypeImageManifest, @@ -534,7 +536,7 @@ func TestGarbageCollectWithMockedImageStore(t *testing.T) { } gc := NewGarbageCollect(imgStore, metaDB, gcOptions, audit, log) - err = gc.removeManifestsPerRepoPolicy(repoName, &ispec.Index{ + err = gc.removeManifestsPerRepoPolicy(ctx, repoName, &ispec.Index{ Manifests: []ispec.Descriptor{ { MediaType: ispec.MediaTypeImageManifest, @@ -612,7 +614,7 @@ func TestGarbageCollectWithMockedImageStore(t *testing.T) { gc := NewGarbageCollect(imgStore, mocks.MetaDBMock{}, gcOptions, audit, log) - err = gc.removeManifestsPerRepoPolicy(repoName, &returnedIndexImage) + err = gc.removeManifestsPerRepoPolicy(ctx, repoName, &returnedIndexImage) So(err, ShouldNotBeNil) }) @@ -643,7 +645,7 @@ func TestGarbageCollectWithMockedImageStore(t *testing.T) { gc := NewGarbageCollect(imgStore, mocks.MetaDBMock{}, gcOptions, audit, log) - err = gc.removeManifestsPerRepoPolicy(repoName, &ispec.Index{ + err = gc.removeManifestsPerRepoPolicy(ctx, repoName, &ispec.Index{ Manifests: []ispec.Descriptor{ manifestDesc, }, diff --git a/pkg/storage/gc/gc_test.go b/pkg/storage/gc/gc_test.go index dfa7aec02d..32be51f91e 100644 --- a/pkg/storage/gc/gc_test.go +++ b/pkg/storage/gc/gc_test.go @@ -159,6 +159,8 @@ func TestGarbageCollectAndRetention(t *testing.T) { storeController := storage.StoreController{} storeController.DefaultStore = imgStore + ctx := context.Background() + Convey("setup gc images", t, func() { // for gc testing // basic images @@ -244,10 +246,10 @@ func TestGarbageCollectAndRetention(t *testing.T) { err = WriteImageToFileSystem(gcNew3, "retention", "0.0.6", storeController) So(err, ShouldBeNil) - err = meta.ParseStorage(metaDB, storeController, log) + err = meta.ParseStorage(metaDB, storeController, log) //nolint: contextcheck So(err, ShouldBeNil) - retentionMeta, err := metaDB.GetRepoMeta(context.Background(), "retention") + retentionMeta, err := metaDB.GetRepoMeta(ctx, "retention") So(err, ShouldBeNil) // update timestamps for image retention @@ -305,16 +307,16 @@ func TestGarbageCollectAndRetention(t *testing.T) { }, }, audit, log) - err := gc.CleanRepo("gc-test1") + err := gc.CleanRepo(ctx, "gc-test1") So(err, ShouldBeNil) - err = gc.CleanRepo("gc-test2") + err = gc.CleanRepo(ctx, "gc-test2") So(err, ShouldBeNil) - err = gc.CleanRepo("gc-test3") + err = gc.CleanRepo(ctx, "gc-test3") So(err, ShouldBeNil) - err = gc.CleanRepo("retention") + err = gc.CleanRepo(ctx, "retention") So(err, ShouldBeNil) _, _, _, err = imgStore.GetImageManifest("gc-test1", gcTest1.DigestStr()) @@ -388,16 +390,16 @@ func TestGarbageCollectAndRetention(t *testing.T) { }, }, audit, log) - err := gc.CleanRepo("gc-test1") + err := gc.CleanRepo(ctx, "gc-test1") So(err, ShouldBeNil) - err = gc.CleanRepo("gc-test2") + err = gc.CleanRepo(ctx, "gc-test2") So(err, ShouldBeNil) - err = gc.CleanRepo("gc-test3") + err = gc.CleanRepo(ctx, "gc-test3") So(err, ShouldBeNil) - err = gc.CleanRepo("retention") + err = gc.CleanRepo(ctx, "retention") So(err, ShouldBeNil) _, _, _, err = imgStore.GetImageManifest("gc-test1", gcTest1.DigestStr()) @@ -457,7 +459,7 @@ func TestGarbageCollectAndRetention(t *testing.T) { }, }, audit, log) - err := gc.CleanRepo("gc-test1") + err := gc.CleanRepo(ctx, "gc-test1") So(err, ShouldBeNil) _, _, _, err = imgStore.GetImageManifest("gc-test1", gcUntagged1.DigestStr()) @@ -503,7 +505,7 @@ func TestGarbageCollectAndRetention(t *testing.T) { }, }, audit, log) - err := gc.CleanRepo("gc-test1") + err := gc.CleanRepo(ctx, "gc-test1") So(err, ShouldBeNil) _, _, _, err = imgStore.GetImageManifest("gc-test1", gcUntagged1.DigestStr()) @@ -549,7 +551,7 @@ func TestGarbageCollectAndRetention(t *testing.T) { }, }, audit, log) - err = gc.CleanRepo("retention") + err = gc.CleanRepo(ctx, "retention") So(err, ShouldBeNil) _, _, _, err = imgStore.GetImageManifest("gc-test1", "0.0.1") @@ -607,7 +609,7 @@ func TestGarbageCollectAndRetention(t *testing.T) { }, }, audit, log) - err = gc.CleanRepo("retention") + err = gc.CleanRepo(ctx, "retention") So(err, ShouldBeNil) tags, err := imgStore.GetImageTags("retention") @@ -643,7 +645,7 @@ func TestGarbageCollectAndRetention(t *testing.T) { }, }, audit, log) - err = gc.CleanRepo("retention") + err = gc.CleanRepo(ctx, "retention") So(err, ShouldBeNil) tags, err := imgStore.GetImageTags("retention") @@ -679,7 +681,7 @@ func TestGarbageCollectAndRetention(t *testing.T) { }, }, audit, log) - err = gc.CleanRepo("retention") + err = gc.CleanRepo(ctx, "retention") So(err, ShouldBeNil) tags, err := imgStore.GetImageTags("retention") @@ -716,7 +718,7 @@ func TestGarbageCollectAndRetention(t *testing.T) { }, }, audit, log) - err = gc.CleanRepo("retention") + err = gc.CleanRepo(ctx, "retention") So(err, ShouldBeNil) tags, err := imgStore.GetImageTags("retention") @@ -759,7 +761,7 @@ func TestGarbageCollectAndRetention(t *testing.T) { }, }, audit, log) - err = gc.CleanRepo("retention") + err = gc.CleanRepo(ctx, "retention") So(err, ShouldBeNil) tags, err := imgStore.GetImageTags("retention") @@ -789,7 +791,7 @@ func TestGarbageCollectAndRetention(t *testing.T) { }, }, audit, log) - err := gc.CleanRepo("gc-test1") + err := gc.CleanRepo(ctx, "gc-test1") So(err, ShouldBeNil) _, _, _, err = imgStore.GetImageManifest("gc-test1", gcUntagged1.DigestStr()) @@ -833,7 +835,7 @@ func TestGarbageCollectAndRetention(t *testing.T) { }, }, audit, log) - err = gc.CleanRepo("retention") + err = gc.CleanRepo(ctx, "retention") So(err, ShouldBeNil) _, _, _, err = imgStore.GetImageManifest("retention", "0.0.1") diff --git a/pkg/storage/imagestore/imagestore.go b/pkg/storage/imagestore/imagestore.go index 57a845f73c..00f663d773 100644 --- a/pkg/storage/imagestore/imagestore.go +++ b/pkg/storage/imagestore/imagestore.go @@ -2,6 +2,7 @@ package imagestore import ( "bytes" + "context" "crypto/sha256" "encoding/json" "errors" @@ -1809,7 +1810,7 @@ func (is *ImageStore) getOriginalBlob(digest godigest.Digest, duplicateBlobs []s return originalBlob, nil } -func (is *ImageStore) dedupeBlobs(digest godigest.Digest, duplicateBlobs []string) error { +func (is *ImageStore) dedupeBlobs(ctx context.Context, digest godigest.Digest, duplicateBlobs []string) error { if fmt.Sprintf("%v", is.cache) == fmt.Sprintf("%v", nil) { is.log.Error().Err(zerr.ErrDedupeRebuild).Msg("no cache driver found, can not dedupe blobs") @@ -1822,6 +1823,10 @@ func (is *ImageStore) dedupeBlobs(digest godigest.Digest, duplicateBlobs []strin // rebuild from dedupe false to true for _, blobPath := range duplicateBlobs { + if zcommon.IsContextDone(ctx) { + return ctx.Err() + } + binfo, err := is.storeDriver.Stat(blobPath) if err != nil { is.log.Error().Err(err).Str("path", blobPath).Msg("rebuild dedupe: failed to stat blob") @@ -1882,7 +1887,7 @@ func (is *ImageStore) dedupeBlobs(digest godigest.Digest, duplicateBlobs []strin return nil } -func (is *ImageStore) restoreDedupedBlobs(digest godigest.Digest, duplicateBlobs []string) error { +func (is *ImageStore) restoreDedupedBlobs(ctx context.Context, digest godigest.Digest, duplicateBlobs []string) error { is.log.Info().Str("digest", digest.String()).Msg("rebuild dedupe: restoring deduped blobs for digest") // first we need to find the original blob, either in cache or by checking each blob size @@ -1894,6 +1899,10 @@ func (is *ImageStore) restoreDedupedBlobs(digest godigest.Digest, duplicateBlobs } for _, blobPath := range duplicateBlobs { + if zcommon.IsContextDone(ctx) { + return ctx.Err() + } + binfo, err := is.storeDriver.Stat(blobPath) if err != nil { is.log.Error().Err(err).Str("path", blobPath).Msg("rebuild dedupe: failed to stat blob") @@ -1924,17 +1933,19 @@ func (is *ImageStore) restoreDedupedBlobs(digest godigest.Digest, duplicateBlobs return nil } -func (is *ImageStore) RunDedupeForDigest(digest godigest.Digest, dedupe bool, duplicateBlobs []string) error { +func (is *ImageStore) RunDedupeForDigest(ctx context.Context, digest godigest.Digest, dedupe bool, + duplicateBlobs []string, +) error { var lockLatency time.Time is.Lock(&lockLatency) defer is.Unlock(&lockLatency) if dedupe { - return is.dedupeBlobs(digest, duplicateBlobs) + return is.dedupeBlobs(ctx, digest, duplicateBlobs) } - return is.restoreDedupedBlobs(digest, duplicateBlobs) + return is.restoreDedupedBlobs(ctx, digest, duplicateBlobs) } func (is *ImageStore) RunDedupeBlobs(interval time.Duration, sch *scheduler.Scheduler) { diff --git a/pkg/storage/local/local_test.go b/pkg/storage/local/local_test.go index 60452b453b..ec41ae7bfb 100644 --- a/pkg/storage/local/local_test.go +++ b/pkg/storage/local/local_test.go @@ -1149,7 +1149,7 @@ func FuzzRunGCRepo(f *testing.F) { ImageRetention: DeleteReferrers, }, audit, log) - if err := gc.CleanRepo(data); err != nil { + if err := gc.CleanRepo(context.Background(), data); err != nil { t.Error(err) } }) @@ -1359,7 +1359,7 @@ func TestDedupeLinks(t *testing.T) { err := os.Remove(path.Join(dir, "dedupe1", "blobs", "sha256", blobDigest1)) So(err, ShouldBeNil) - err = imgStore.RunDedupeForDigest(godigest.Digest(blobDigest1), true, duplicateBlobs) + err = imgStore.RunDedupeForDigest(context.TODO(), godigest.Digest(blobDigest1), true, duplicateBlobs) So(err, ShouldNotBeNil) }) @@ -2001,9 +2001,12 @@ func TestInjectWriteFile(t *testing.T) { } func TestGarbageCollectForImageStore(t *testing.T) { + //nolint: contextcheck Convey("Garbage collect for a specific repo from an ImageStore", t, func(c C) { dir := t.TempDir() + ctx := context.Background() + Convey("Garbage collect error for repo with config removed", func() { logFile, _ := os.CreateTemp("", "zot-log*.txt") @@ -2039,7 +2042,7 @@ func TestGarbageCollectForImageStore(t *testing.T) { panic(err) } - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldNotBeNil) time.Sleep(500 * time.Millisecond) @@ -2081,7 +2084,7 @@ func TestGarbageCollectForImageStore(t *testing.T) { So(os.Chmod(path.Join(dir, repoName, "index.json"), 0o000), ShouldBeNil) - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldNotBeNil) time.Sleep(500 * time.Millisecond) @@ -2163,7 +2166,7 @@ func TestGarbageCollectForImageStore(t *testing.T) { err = WriteImageToFileSystem(cosignWithReferrersSig, repoName, "cosign", storeController) So(err, ShouldBeNil) - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldBeNil) }) }) @@ -2171,6 +2174,8 @@ func TestGarbageCollectForImageStore(t *testing.T) { func TestGarbageCollectImageUnknownManifest(t *testing.T) { Convey("Garbage collect with short delay", t, func() { + ctx := context.Background() + dir := t.TempDir() log := zlog.NewLogger("debug", "") @@ -2276,7 +2281,7 @@ func TestGarbageCollectImageUnknownManifest(t *testing.T) { time.Sleep(1 * time.Second) Convey("Garbage collect blobs referenced by manifest with unsupported media type", func() { - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldBeNil) _, _, _, err = imgStore.GetImageManifest(repoName, img.DigestStr()) @@ -2313,7 +2318,7 @@ func TestGarbageCollectImageUnknownManifest(t *testing.T) { err = imgStore.DeleteImageManifest(repoName, img.DigestStr(), true) So(err, ShouldBeNil) - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldBeNil) _, _, _, err = imgStore.GetImageManifest(repoName, img.DigestStr()) @@ -2350,6 +2355,8 @@ func TestGarbageCollectImageUnknownManifest(t *testing.T) { func TestGarbageCollectErrors(t *testing.T) { Convey("Make image store", t, func(c C) { + ctx := context.Background() + dir := t.TempDir() log := zlog.NewLogger("debug", "") @@ -2456,7 +2463,7 @@ func TestGarbageCollectErrors(t *testing.T) { time.Sleep(500 * time.Millisecond) - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldNotBeNil) }) @@ -2507,14 +2514,14 @@ func TestGarbageCollectErrors(t *testing.T) { time.Sleep(500 * time.Millisecond) - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldNotBeNil) // trigger Unmarshal error _, err = os.Create(imgStore.BlobPath(repoName, digest)) So(err, ShouldBeNil) - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldNotBeNil) }) @@ -2564,7 +2571,7 @@ func TestGarbageCollectErrors(t *testing.T) { time.Sleep(500 * time.Millisecond) - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldBeNil) // blob shouldn't be gc'ed //TODO check this one diff --git a/pkg/storage/s3/s3_test.go b/pkg/storage/s3/s3_test.go index 6ba070b262..8bd7e1d5e4 100644 --- a/pkg/storage/s3/s3_test.go +++ b/pkg/storage/s3/s3_test.go @@ -2283,7 +2283,7 @@ func TestRebuildDedupeMockStoreDriver(t *testing.T) { digest, duplicateBlobs, err := imgStore.GetNextDigestWithBlobPaths([]string{"path/to"}, []godigest.Digest{}) So(err, ShouldBeNil) - err = imgStore.RunDedupeForDigest(digest, false, duplicateBlobs) + err = imgStore.RunDedupeForDigest(context.TODO(), digest, false, duplicateBlobs) So(err, ShouldNotBeNil) }) @@ -2332,7 +2332,7 @@ func TestRebuildDedupeMockStoreDriver(t *testing.T) { digest, duplicateBlobs, err := imgStore.GetNextDigestWithBlobPaths([]string{"path/to"}, []godigest.Digest{}) So(err, ShouldBeNil) - err = imgStore.RunDedupeForDigest(digest, false, duplicateBlobs) + err = imgStore.RunDedupeForDigest(context.TODO(), digest, false, duplicateBlobs) So(err, ShouldNotBeNil) }) @@ -2381,7 +2381,7 @@ func TestRebuildDedupeMockStoreDriver(t *testing.T) { digest, duplicateBlobs, err := imgStore.GetNextDigestWithBlobPaths([]string{"path/to"}, []godigest.Digest{}) So(err, ShouldBeNil) - err = imgStore.RunDedupeForDigest(digest, false, duplicateBlobs) + err = imgStore.RunDedupeForDigest(context.TODO(), digest, false, duplicateBlobs) So(err, ShouldNotBeNil) }) @@ -2427,7 +2427,7 @@ func TestRebuildDedupeMockStoreDriver(t *testing.T) { digest, duplicateBlobs, err := imgStore.GetNextDigestWithBlobPaths([]string{"path/to"}, []godigest.Digest{}) So(err, ShouldBeNil) - err = imgStore.RunDedupeForDigest(digest, false, duplicateBlobs) + err = imgStore.RunDedupeForDigest(context.TODO(), digest, false, duplicateBlobs) So(err, ShouldNotBeNil) Convey("Trigger Stat() error in dedupeBlobs()", func() { @@ -2472,7 +2472,7 @@ func TestRebuildDedupeMockStoreDriver(t *testing.T) { digest, duplicateBlobs, err := imgStore.GetNextDigestWithBlobPaths([]string{"path/to"}, []godigest.Digest{}) So(err, ShouldBeNil) - err = imgStore.RunDedupeForDigest(digest, false, duplicateBlobs) + err = imgStore.RunDedupeForDigest(context.TODO(), digest, false, duplicateBlobs) So(err, ShouldNotBeNil) }) }) @@ -2523,7 +2523,7 @@ func TestRebuildDedupeMockStoreDriver(t *testing.T) { digest, duplicateBlobs, err := imgStore.GetNextDigestWithBlobPaths([]string{"path/to"}, []godigest.Digest{}) So(err, ShouldBeNil) - err = imgStore.RunDedupeForDigest(digest, true, duplicateBlobs) + err = imgStore.RunDedupeForDigest(context.TODO(), digest, true, duplicateBlobs) So(err, ShouldNotBeNil) }) @@ -2571,7 +2571,7 @@ func TestRebuildDedupeMockStoreDriver(t *testing.T) { digest, duplicateBlobs, err := imgStore.GetNextDigestWithBlobPaths([]string{"path/to"}, []godigest.Digest{}) So(err, ShouldBeNil) - err = imgStore.RunDedupeForDigest(digest, true, duplicateBlobs) + err = imgStore.RunDedupeForDigest(context.TODO(), digest, true, duplicateBlobs) So(err, ShouldNotBeNil) }) @@ -2619,7 +2619,7 @@ func TestRebuildDedupeMockStoreDriver(t *testing.T) { digest, duplicateBlobs, err := imgStore.GetNextDigestWithBlobPaths([]string{"path/to"}, []godigest.Digest{}) So(err, ShouldBeNil) - err = imgStore.RunDedupeForDigest(digest, true, duplicateBlobs) + err = imgStore.RunDedupeForDigest(context.TODO(), digest, true, duplicateBlobs) So(err, ShouldNotBeNil) }) @@ -2726,7 +2726,7 @@ func TestRebuildDedupeMockStoreDriver(t *testing.T) { digest, duplicateBlobs, err := imgStore.GetNextDigestWithBlobPaths([]string{"path/to"}, []godigest.Digest{}) So(err, ShouldBeNil) - err = imgStore.RunDedupeForDigest(digest, true, duplicateBlobs) + err = imgStore.RunDedupeForDigest(context.TODO(), digest, true, duplicateBlobs) So(err, ShouldNotBeNil) }) @@ -2748,7 +2748,7 @@ func TestRebuildDedupeMockStoreDriver(t *testing.T) { digest, duplicateBlobs, err := imgStore.GetNextDigestWithBlobPaths([]string{"path/to"}, []godigest.Digest{}) So(err, ShouldBeNil) - err = imgStore.RunDedupeForDigest(digest, true, duplicateBlobs) + err = imgStore.RunDedupeForDigest(context.TODO(), digest, true, duplicateBlobs) So(err, ShouldNotBeNil) }) @@ -2766,7 +2766,7 @@ func TestRebuildDedupeMockStoreDriver(t *testing.T) { digest, duplicateBlobs, err := imgStore.GetNextDigestWithBlobPaths([]string{"path/to"}, []godigest.Digest{}) So(err, ShouldBeNil) - err = imgStore.RunDedupeForDigest(digest, true, duplicateBlobs) + err = imgStore.RunDedupeForDigest(context.TODO(), digest, true, duplicateBlobs) So(err, ShouldNotBeNil) }) }) diff --git a/pkg/storage/scrub.go b/pkg/storage/scrub.go index 3c4f31c62d..45bfa78cf6 100644 --- a/pkg/storage/scrub.go +++ b/pkg/storage/scrub.go @@ -13,6 +13,7 @@ import ( ispec "github.com/opencontainers/image-spec/specs-go/v1" "zotregistry.io/zot/errors" + "zotregistry.io/zot/pkg/common" storageTypes "zotregistry.io/zot/pkg/storage/types" ) @@ -87,10 +88,6 @@ func CheckImageStoreBlobsIntegrity(ctx context.Context, imgStore storageTypes.Im func CheckRepo(ctx context.Context, imageName string, imgStore storageTypes.ImageStore) ([]ScrubImageResult, error) { results := []ScrubImageResult{} - if ctx.Err() != nil { - return results, ctx.Err() - } - var lockLatency time.Time imgStore.RLock(&lockLatency) @@ -120,6 +117,10 @@ func CheckRepo(ctx context.Context, imageName string, imgStore storageTypes.Imag scrubbedManifests := make(map[godigest.Digest]ScrubImageResult) for _, manifest := range index.Manifests { + if common.IsContextDone(ctx) { + return results, ctx.Err() + } + tag := manifest.Annotations[ispec.AnnotationRefName] scrubManifest(ctx, manifest, imgStore, imageName, tag, scrubbedManifests) results = append(results, scrubbedManifests[manifest.Digest]) @@ -159,12 +160,16 @@ func scrubManifest( } // check all manifests - for _, m := range idx.Manifests { - scrubManifest(ctx, m, imgStore, imageName, tag, scrubbedManifests) + for _, man := range idx.Manifests { + if common.IsContextDone(ctx) { + return + } + + scrubManifest(ctx, man, imgStore, imageName, tag, scrubbedManifests) // if the manifest is affected then this index is also affected - if scrubbedManifests[m.Digest].Error != "" { - mRes := scrubbedManifests[m.Digest] + if scrubbedManifests[man.Digest].Error != "" { + mRes := scrubbedManifests[man.Digest] scrubbedManifests[manifest.Digest] = newScrubImageResult(imageName, tag, mRes.Status, mRes.AffectedBlob, mRes.Error) @@ -226,6 +231,7 @@ func CheckIntegrity( func CheckManifestAndConfig( imageName string, manifestDesc ispec.Descriptor, imgStore storageTypes.ImageStore, ) (godigest.Digest, error) { + // Q oras artifacts? if manifestDesc.MediaType != ispec.MediaTypeImageManifest { return manifestDesc.Digest, errors.ErrBadManifest } diff --git a/pkg/storage/scrub_test.go b/pkg/storage/scrub_test.go index 80d403af00..fab04087b5 100644 --- a/pkg/storage/scrub_test.go +++ b/pkg/storage/scrub_test.go @@ -10,6 +10,7 @@ import ( "regexp" "strings" "testing" + "time" "github.com/docker/distribution/registry/storage/driver" guuid "github.com/gofrs/uuid" @@ -126,6 +127,22 @@ func RunCheckAllBlobsIntegrityTests( //nolint: thelper So(actual, ShouldContainSubstring, "test 1.0 ok") }) + Convey("Blobs integrity with context done", func() { + buff := bytes.NewBufferString("") + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + res, err := storeCtlr.CheckAllBlobsIntegrity(ctx) + res.PrintScrubResults(buff) + So(err, ShouldBeNil) + + space := regexp.MustCompile(`\s+`) + str := space.ReplaceAllString(buff.String(), " ") + actual := strings.TrimSpace(str) + So(actual, ShouldContainSubstring, "REPOSITORY TAG STATUS AFFECTED BLOB ERROR") + So(actual, ShouldNotContainSubstring, "test 1.0 ok") + }) + Convey("Manifest integrity affected", func() { // get content of manifest file content, _, _, err := imgStore.GetImageManifest(repoName, manifestDigest.String()) @@ -364,6 +381,22 @@ func RunCheckAllBlobsIntegrityTests( //nolint: thelper So(actual, ShouldContainSubstring, "test 1.0 ok") So(actual, ShouldContainSubstring, "test ok") + // test scrub context done + buff = bytes.NewBufferString("") + + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(1*time.Millisecond)) + defer cancel() + + res, err = storeCtlr.CheckAllBlobsIntegrity(ctx) + res.PrintScrubResults(buff) + So(err, ShouldBeNil) + + str = space.ReplaceAllString(buff.String(), " ") + actual = strings.TrimSpace(str) + So(actual, ShouldContainSubstring, "REPOSITORY TAG STATUS AFFECTED BLOB ERROR") + So(actual, ShouldNotContainSubstring, "test 1.0 ok") + So(actual, ShouldNotContainSubstring, "test ok") + // test scrub index - errors // delete content of manifest file manifestFile := path.Join(imgStore.RootDir(), repoName, "/blobs/sha256", newManifestDigest.Encoded()) diff --git a/pkg/storage/storage_test.go b/pkg/storage/storage_test.go index f9dde13d1b..51b3c7187f 100644 --- a/pkg/storage/storage_test.go +++ b/pkg/storage/storage_test.go @@ -1481,6 +1481,9 @@ func TestGarbageCollectImageManifest(t *testing.T) { metrics := monitoring.NewMetricsServer(false, log) + ctx := context.Background() + + //nolint: contextcheck Convey("Repo layout", t, func(c C) { Convey("Garbage collect with default/long delay", func() { var imgStore storageTypes.ImageStore @@ -1578,7 +1581,7 @@ func TestGarbageCollectImageManifest(t *testing.T) { _, _, err = imgStore.PutImageManifest(repoName, tag, ispec.MediaTypeImageManifest, manifestBuf) So(err, ShouldBeNil) - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldBeNil) // put artifact referencing above image @@ -1622,7 +1625,7 @@ func TestGarbageCollectImageManifest(t *testing.T) { ispec.MediaTypeImageManifest, artifactManifestBuf) So(err, ShouldBeNil) - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldBeNil) hasBlob, _, err = imgStore.CheckBlob(repoName, bdigest) @@ -1636,7 +1639,7 @@ func TestGarbageCollectImageManifest(t *testing.T) { err = imgStore.DeleteImageManifest(repoName, digest.String(), false) So(err, ShouldBeNil) - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldBeNil) hasBlob, _, err = imgStore.CheckBlob(repoName, bdigest) @@ -1872,7 +1875,7 @@ func TestGarbageCollectImageManifest(t *testing.T) { _, _, _, err = imgStore.GetImageManifest(repoName, orasDigest.String()) So(err, ShouldBeNil) - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldBeNil) hasBlob, _, err = imgStore.CheckBlob(repoName, odigest) @@ -1898,7 +1901,7 @@ func TestGarbageCollectImageManifest(t *testing.T) { err = imgStore.DeleteImageManifest(repoName, digest.String(), false) So(err, ShouldBeNil) - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldBeNil) hasBlob, _, err = imgStore.CheckBlob(repoName, bdigest) @@ -1935,7 +1938,7 @@ func TestGarbageCollectImageManifest(t *testing.T) { err = imgStore.DeleteImageManifest(repoName, tag, false) So(err, ShouldBeNil) - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldBeNil) hasBlob, _, err = imgStore.CheckBlob(repoName, bdigest) @@ -2181,7 +2184,7 @@ func TestGarbageCollectImageManifest(t *testing.T) { _, _, err = imgStore.PutImageManifest(repo2Name, tag, ispec.MediaTypeImageManifest, manifestBuf) So(err, ShouldBeNil) - err = gc.CleanRepo(repo2Name) + err = gc.CleanRepo(ctx, repo2Name) So(err, ShouldBeNil) // original blob should exist @@ -2206,6 +2209,9 @@ func TestGarbageCollectImageIndex(t *testing.T) { metrics := monitoring.NewMetricsServer(false, log) + ctx := context.Background() + + //nolint: contextcheck Convey("Repo layout", t, func(c C) { Convey("Garbage collect with default/long delay", func() { var imgStore storageTypes.ImageStore @@ -2303,7 +2309,7 @@ func TestGarbageCollectImageIndex(t *testing.T) { ispec.MediaTypeImageManifest, artifactManifestBuf) So(err, ShouldBeNil) - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldBeNil) hasBlob, _, err := imgStore.CheckBlob(repoName, bdgst) @@ -2314,7 +2320,7 @@ func TestGarbageCollectImageIndex(t *testing.T) { err = imgStore.DeleteImageManifest(repoName, indexDigest.String(), false) So(err, ShouldBeNil) - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldBeNil) hasBlob, _, err = imgStore.CheckBlob(repoName, bdgst) @@ -2516,7 +2522,7 @@ func TestGarbageCollectImageIndex(t *testing.T) { _, _, _, err = imgStore.GetImageManifest(repoName, orasDigest.String()) So(err, ShouldBeNil) - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldBeNil) _, _, _, err = imgStore.GetImageManifest(repoName, orasDigest.String()) @@ -2537,7 +2543,7 @@ func TestGarbageCollectImageIndex(t *testing.T) { time.Sleep(2 * time.Second) Convey("delete inner referenced manifest", func() { - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldBeNil) // check orphan artifact is gc'ed @@ -2556,7 +2562,7 @@ func TestGarbageCollectImageIndex(t *testing.T) { err = imgStore.DeleteImageManifest(repoName, artifactDigest.String(), false) So(err, ShouldBeNil) - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldBeNil) _, _, _, err = imgStore.GetImageManifest(repoName, artifactOfArtifactManifestDigest.String()) @@ -2570,7 +2576,7 @@ func TestGarbageCollectImageIndex(t *testing.T) { }) Convey("delete index manifest, references should not be persisted", func() { - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldBeNil) // check orphan artifact is gc'ed @@ -2589,7 +2595,7 @@ func TestGarbageCollectImageIndex(t *testing.T) { err = imgStore.DeleteImageManifest(repoName, indexDigest.String(), false) So(err, ShouldBeNil) - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldBeNil) _, _, _, err = imgStore.GetImageManifest(repoName, artifactDigest.String()) @@ -2652,6 +2658,9 @@ func TestGarbageCollectChainedImageIndexes(t *testing.T) { metrics := monitoring.NewMetricsServer(false, log) + ctx := context.Background() + + //nolint: contextcheck Convey("Garbage collect with short delay", t, func() { var imgStore storageTypes.ImageStore @@ -3019,7 +3028,7 @@ func TestGarbageCollectChainedImageIndexes(t *testing.T) { _, _, _, err = imgStore.GetImageManifest(repoName, orasDigest.String()) So(err, ShouldBeNil) - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldBeNil) _, _, _, err = imgStore.GetImageManifest(repoName, orasDigest.String()) @@ -3040,7 +3049,7 @@ func TestGarbageCollectChainedImageIndexes(t *testing.T) { time.Sleep(5 * time.Second) Convey("delete inner referenced manifest", func() { - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldBeNil) // check orphan artifact is gc'ed @@ -3059,7 +3068,7 @@ func TestGarbageCollectChainedImageIndexes(t *testing.T) { err = imgStore.DeleteImageManifest(repoName, artifactDigest.String(), false) So(err, ShouldBeNil) - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldBeNil) _, _, _, err = imgStore.GetImageManifest(repoName, artifactOfArtifactManifestDigest.String()) @@ -3073,7 +3082,7 @@ func TestGarbageCollectChainedImageIndexes(t *testing.T) { }) Convey("delete index manifest, references should not be persisted", func() { - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldBeNil) // check orphan artifact is gc'ed @@ -3092,7 +3101,7 @@ func TestGarbageCollectChainedImageIndexes(t *testing.T) { err = imgStore.DeleteImageManifest(repoName, indexDigest.String(), false) So(err, ShouldBeNil) - err = gc.CleanRepo(repoName) + err = gc.CleanRepo(ctx, repoName) So(err, ShouldBeNil) _, _, _, err = imgStore.GetImageManifest(repoName, artifactDigest.String()) diff --git a/pkg/storage/types/types.go b/pkg/storage/types/types.go index 03f7648918..a7b2b23e81 100644 --- a/pkg/storage/types/types.go +++ b/pkg/storage/types/types.go @@ -1,6 +1,7 @@ package types import ( + "context" "io" "time" @@ -58,7 +59,7 @@ type ImageStore interface { //nolint:interfacebloat GetReferrers(repo string, digest godigest.Digest, artifactTypes []string) (ispec.Index, error) GetOrasReferrers(repo string, digest godigest.Digest, artifactType string) ([]artifactspec.Descriptor, error) RunDedupeBlobs(interval time.Duration, sch *scheduler.Scheduler) - RunDedupeForDigest(digest godigest.Digest, dedupe bool, duplicateBlobs []string) error + RunDedupeForDigest(ctx context.Context, digest godigest.Digest, dedupe bool, duplicateBlobs []string) error GetNextDigestWithBlobPaths(repos []string, lastDigests []godigest.Digest) (godigest.Digest, []string, error) GetAllBlobs(repo string) ([]string, error) PopulateStorageMetrics(interval time.Duration, sch *scheduler.Scheduler) diff --git a/pkg/test/mocks/cve_mock.go b/pkg/test/mocks/cve_mock.go index 9c1e950f50..03d17a0b9d 100644 --- a/pkg/test/mocks/cve_mock.go +++ b/pkg/test/mocks/cve_mock.go @@ -1,36 +1,39 @@ package mocks import ( + "context" + "zotregistry.io/zot/pkg/common" cvemodel "zotregistry.io/zot/pkg/extensions/search/cve/model" ) type CveInfoMock struct { - GetImageListForCVEFn func(repo, cveID string) ([]cvemodel.TagInfo, error) - GetImageListWithCVEFixedFn func(repo, cveID string) ([]cvemodel.TagInfo, error) - GetCVEListForImageFn func(repo string, reference string, searchedCVE string, pageInput cvemodel.PageInput, - ) ([]cvemodel.CVE, common.PageInfo, error) - GetCVESummaryForImageMediaFn func(repo string, digest, mediaType string, + GetImageListForCVEFn func(ctx context.Context, repo, cveID string) ([]cvemodel.TagInfo, error) + GetImageListWithCVEFixedFn func(ctx context.Context, repo, cveID string) ([]cvemodel.TagInfo, error) + GetCVEListForImageFn func(ctx context.Context, repo string, reference string, searchedCVE string, + pageInput cvemodel.PageInput) ([]cvemodel.CVE, common.PageInfo, error) + GetCVESummaryForImageMediaFn func(ctx context.Context, repo string, digest, mediaType string, ) (cvemodel.ImageCVESummary, error) } -func (cveInfo CveInfoMock) GetImageListForCVE(repo, cveID string) ([]cvemodel.TagInfo, error) { +func (cveInfo CveInfoMock) GetImageListForCVE(ctx context.Context, repo, cveID string) ([]cvemodel.TagInfo, error) { if cveInfo.GetImageListForCVEFn != nil { - return cveInfo.GetImageListForCVEFn(repo, cveID) + return cveInfo.GetImageListForCVEFn(ctx, repo, cveID) } return []cvemodel.TagInfo{}, nil } -func (cveInfo CveInfoMock) GetImageListWithCVEFixed(repo, cveID string) ([]cvemodel.TagInfo, error) { +func (cveInfo CveInfoMock) GetImageListWithCVEFixed(ctx context.Context, repo, cveID string, +) ([]cvemodel.TagInfo, error) { if cveInfo.GetImageListWithCVEFixedFn != nil { - return cveInfo.GetImageListWithCVEFixedFn(repo, cveID) + return cveInfo.GetImageListWithCVEFixedFn(ctx, repo, cveID) } return []cvemodel.TagInfo{}, nil } -func (cveInfo CveInfoMock) GetCVEListForImage(repo string, reference string, +func (cveInfo CveInfoMock) GetCVEListForImage(ctx context.Context, repo string, reference string, searchedCVE string, pageInput cvemodel.PageInput, ) ( []cvemodel.CVE, @@ -38,16 +41,16 @@ func (cveInfo CveInfoMock) GetCVEListForImage(repo string, reference string, error, ) { if cveInfo.GetCVEListForImageFn != nil { - return cveInfo.GetCVEListForImageFn(repo, reference, searchedCVE, pageInput) + return cveInfo.GetCVEListForImageFn(ctx, repo, reference, searchedCVE, pageInput) } return []cvemodel.CVE{}, common.PageInfo{}, nil } -func (cveInfo CveInfoMock) GetCVESummaryForImageMedia(repo, digest, mediaType string, +func (cveInfo CveInfoMock) GetCVESummaryForImageMedia(ctx context.Context, repo, digest, mediaType string, ) (cvemodel.ImageCVESummary, error) { if cveInfo.GetCVESummaryForImageMediaFn != nil { - return cveInfo.GetCVESummaryForImageMediaFn(repo, digest, mediaType) + return cveInfo.GetCVESummaryForImageMediaFn(ctx, repo, digest, mediaType) } return cvemodel.ImageCVESummary{}, nil @@ -58,8 +61,8 @@ type CveScannerMock struct { IsImageMediaScannableFn func(repo string, digest, mediaType string) (bool, error) IsResultCachedFn func(digest string) bool GetCachedResultFn func(digest string) map[string]cvemodel.CVE - ScanImageFn func(image string) (map[string]cvemodel.CVE, error) - UpdateDBFn func() error + ScanImageFn func(ctx context.Context, image string) (map[string]cvemodel.CVE, error) + UpdateDBFn func(ctx context.Context) error } func (scanner CveScannerMock) IsImageFormatScannable(repo string, reference string) (bool, error) { @@ -94,17 +97,17 @@ func (scanner CveScannerMock) GetCachedResult(digest string) map[string]cvemodel return map[string]cvemodel.CVE{} } -func (scanner CveScannerMock) ScanImage(image string) (map[string]cvemodel.CVE, error) { +func (scanner CveScannerMock) ScanImage(ctx context.Context, image string) (map[string]cvemodel.CVE, error) { if scanner.ScanImageFn != nil { - return scanner.ScanImageFn(image) + return scanner.ScanImageFn(ctx, image) } return map[string]cvemodel.CVE{}, nil } -func (scanner CveScannerMock) UpdateDB() error { +func (scanner CveScannerMock) UpdateDB(ctx context.Context) error { if scanner.UpdateDBFn != nil { - return scanner.UpdateDBFn() + return scanner.UpdateDBFn(ctx) } return nil diff --git a/pkg/test/mocks/image_store_mock.go b/pkg/test/mocks/image_store_mock.go index 5835932f3b..2217161bb0 100644 --- a/pkg/test/mocks/image_store_mock.go +++ b/pkg/test/mocks/image_store_mock.go @@ -1,6 +1,7 @@ package mocks import ( + "context" "io" "time" @@ -46,11 +47,12 @@ type MockedImageStore struct { GetReferrersFn func(repo string, digest godigest.Digest, artifactTypes []string) (ispec.Index, error) GetOrasReferrersFn func(repo string, digest godigest.Digest, artifactType string, ) ([]artifactspec.Descriptor, error) - URLForPathFn func(path string) (string, error) - RunGCRepoFn func(repo string) error - RunGCPeriodicallyFn func(interval time.Duration, sch *scheduler.Scheduler) - RunDedupeBlobsFn func(interval time.Duration, sch *scheduler.Scheduler) - RunDedupeForDigestFn func(digest godigest.Digest, dedupe bool, duplicateBlobs []string) error + URLForPathFn func(path string) (string, error) + RunGCRepoFn func(repo string) error + RunGCPeriodicallyFn func(interval time.Duration, sch *scheduler.Scheduler) + RunDedupeBlobsFn func(interval time.Duration, sch *scheduler.Scheduler) + RunDedupeForDigestFn func(ctx context.Context, digest godigest.Digest, dedupe bool, + duplicateBlobs []string) error GetNextDigestWithBlobPathsFn func(repos []string, lastDigests []godigest.Digest) (godigest.Digest, []string, error) GetAllBlobsFn func(repo string) ([]string, error) CleanupRepoFn func(repo string, blobs []godigest.Digest, removeRepo bool) (int, error) @@ -374,9 +376,11 @@ func (is MockedImageStore) RunDedupeBlobs(interval time.Duration, sch *scheduler } } -func (is MockedImageStore) RunDedupeForDigest(digest godigest.Digest, dedupe bool, duplicateBlobs []string) error { +func (is MockedImageStore) RunDedupeForDigest(ctx context.Context, digest godigest.Digest, dedupe bool, + duplicateBlobs []string, +) error { if is.RunDedupeForDigestFn != nil { - return is.RunDedupeForDigestFn(digest, dedupe, duplicateBlobs) + return is.RunDedupeForDigestFn(ctx, digest, dedupe, duplicateBlobs) } return nil diff --git a/pkg/test/mocks/repo_db_mock.go b/pkg/test/mocks/repo_db_mock.go index b9d43ff46a..283ec3d351 100644 --- a/pkg/test/mocks/repo_db_mock.go +++ b/pkg/test/mocks/repo_db_mock.go @@ -76,7 +76,7 @@ type MetaDBMock struct { UpdateStatsOnDownloadFn func(repo string, reference string) error - UpdateSignaturesValidityFn func(repo string, manifestDigest godigest.Digest) error + UpdateSignaturesValidityFn func(ctx context.Context, crepo string, manifestDigest godigest.Digest) error AddManifestSignatureFn func(repo string, signedManifestDigest godigest.Digest, sygMeta mTypes.SignatureMetadata, ) error @@ -372,9 +372,9 @@ func (sdm MetaDBMock) UpdateStatsOnDownload(repo string, reference string) error return nil } -func (sdm MetaDBMock) UpdateSignaturesValidity(repo string, manifestDigest godigest.Digest) error { +func (sdm MetaDBMock) UpdateSignaturesValidity(ctx context.Context, repo string, manifestDigest godigest.Digest) error { if sdm.UpdateSignaturesValidityFn != nil { - return sdm.UpdateSignaturesValidityFn(repo, manifestDigest) + return sdm.UpdateSignaturesValidityFn(ctx, repo, manifestDigest) } return nil