From 627cb97ef139583ea1246e76ad0c76396be3c7fc Mon Sep 17 00:00:00 2001 From: Petu Eusebiu Date: Thu, 2 Dec 2021 19:45:26 +0200 Subject: [PATCH] Add wait group for graceful shutdown, closes #302 Signed-off-by: Petu Eusebiu --- pkg/api/controller.go | 14 ++- pkg/api/routes.go | 4 +- pkg/extensions/extensions.go | 12 ++- pkg/extensions/minimal.go | 8 +- pkg/extensions/sync/on_demand.go | 4 +- pkg/extensions/sync/sync.go | 8 +- pkg/extensions/sync/sync_internal_test.go | 3 +- pkg/extensions/sync/sync_test.go | 120 ++++++---------------- 8 files changed, 68 insertions(+), 105 deletions(-) diff --git a/pkg/api/controller.go b/pkg/api/controller.go index fee3a5c50..99c042e68 100644 --- a/pkg/api/controller.go +++ b/pkg/api/controller.go @@ -1,12 +1,14 @@ package api import ( + "context" "crypto/tls" "crypto/x509" "fmt" "io/ioutil" "net" "net/http" + goSync "sync" "time" "github.com/gorilla/handlers" @@ -34,6 +36,7 @@ type Controller struct { Audit *log.Logger Server *http.Server Metrics monitoring.MetricServer + wgShutDown *goSync.WaitGroup // use it to gracefully shutdown goroutines } func NewController(config *config.Config) *Controller { @@ -43,6 +46,7 @@ func NewController(config *config.Config) *Controller { controller.Config = config controller.Log = logger + controller.wgShutDown = new(goSync.WaitGroup) if config.Log.Audit != "" { audit := log.NewAuditLogger(config.Log.Level, config.Log.Audit) @@ -195,7 +199,7 @@ func (c *Controller) Run() error { // Enable extensions if extension config is provided if c.Config.Extensions != nil && c.Config.Extensions.Sync != nil { - ext.EnableSyncExtension(c.Config, c.Log, c.StoreController) + ext.EnableSyncExtension(c.Config, c.wgShutDown, c.StoreController, c.Log) } monitoring.SetServerInfo(c.Metrics, c.Config.Commit, c.Config.BinaryType, c.Config.GoVersion, c.Config.Version) @@ -247,3 +251,11 @@ func (c *Controller) Run() error { return server.Serve(l) } + +func (c *Controller) Shutdown() { + // wait gracefully + c.wgShutDown.Wait() + + ctx := context.Background() + _ = c.Server.Shutdown(ctx) +} diff --git a/pkg/api/routes.go b/pkg/api/routes.go index e32d34ce6..9747dbc6a 100644 --- a/pkg/api/routes.go +++ b/pkg/api/routes.go @@ -1271,7 +1271,7 @@ func getImageManifest(rh *RouteHandler, is storage.ImageStore, name, if rh.c.Config.Extensions != nil && rh.c.Config.Extensions.Sync != nil { rh.c.Log.Info().Msgf("image not found, trying to get image %s:%s by syncing on demand", name, reference) - errSync := ext.SyncOneImage(rh.c.Config, rh.c.Log, rh.c.StoreController, name, reference) + errSync := ext.SyncOneImage(rh.c.Config, rh.c.StoreController, name, reference, rh.c.Log) if errSync != nil { rh.c.Log.Err(errSync).Msgf("error encounter while syncing image %s:%s", name, reference) } else { @@ -1283,7 +1283,7 @@ func getImageManifest(rh *RouteHandler, is storage.ImageStore, name, if rh.c.Config.Extensions != nil && rh.c.Config.Extensions.Sync != nil { rh.c.Log.Info().Msgf("manifest not found, trying to get image %s:%s by syncing on demand", name, reference) - errSync := ext.SyncOneImage(rh.c.Config, rh.c.Log, rh.c.StoreController, name, reference) + errSync := ext.SyncOneImage(rh.c.Config, rh.c.StoreController, name, reference, rh.c.Log) if errSync != nil { rh.c.Log.Err(errSync).Msgf("error encounter while syncing image %s:%s", name, reference) } else { diff --git a/pkg/extensions/extensions.go b/pkg/extensions/extensions.go index f42509dd0..5a426ca9b 100644 --- a/pkg/extensions/extensions.go +++ b/pkg/extensions/extensions.go @@ -4,6 +4,7 @@ package extensions import ( + goSync "sync" "time" gqlHandler "github.com/99designs/gqlgen/graphql/handler" @@ -68,7 +69,8 @@ func EnableExtensions(config *config.Config, log log.Logger, rootDir string) { } // EnableSyncExtension enables sync extension. -func EnableSyncExtension(config *config.Config, log log.Logger, storeController storage.StoreController) { +func EnableSyncExtension(config *config.Config, wg *goSync.WaitGroup, + storeController storage.StoreController, log log.Logger) { if config.Extensions.Sync != nil { defaultPollInterval, _ := time.ParseDuration("1h") for id, registryCfg := range config.Extensions.Sync.Registries { @@ -83,7 +85,7 @@ func EnableSyncExtension(config *config.Config, log log.Logger, storeController } } - if err := sync.Run(*config.Extensions.Sync, storeController, log); err != nil { + if err := sync.Run(*config.Extensions.Sync, storeController, wg, log); err != nil { log.Error().Err(err).Msg("Error encountered while setting up syncing") } } else { @@ -128,11 +130,11 @@ func SetupRoutes(config *config.Config, router *mux.Router, storeController stor } // SyncOneImage syncs one image. -func SyncOneImage(config *config.Config, log log.Logger, - storeController storage.StoreController, repoName, reference string) error { +func SyncOneImage(config *config.Config, storeController storage.StoreController, + repoName, reference string, log log.Logger) error { log.Info().Msgf("syncing image %s:%s", repoName, reference) - err := sync.OneImage(*config.Extensions.Sync, log, storeController, repoName, reference) + err := sync.OneImage(*config.Extensions.Sync, storeController, repoName, reference, log) return err } diff --git a/pkg/extensions/minimal.go b/pkg/extensions/minimal.go index 1d3238fb4..3076d1d17 100644 --- a/pkg/extensions/minimal.go +++ b/pkg/extensions/minimal.go @@ -4,6 +4,7 @@ package extensions import ( + goSync "sync" "time" "github.com/gorilla/mux" @@ -24,7 +25,8 @@ func EnableExtensions(config *config.Config, log log.Logger, rootDir string) { } // EnableSyncExtension ... -func EnableSyncExtension(config *config.Config, log log.Logger, storeController storage.StoreController) { +func EnableSyncExtension(config *config.Config, wg *goSync.WaitGroup, + storeController storage.StoreController, log log.Logger) { log.Warn().Msg("skipping enabling sync extension because given zot binary doesn't support any extensions," + "please build zot full binary for this feature") } @@ -36,8 +38,8 @@ func SetupRoutes(conf *config.Config, router *mux.Router, storeController storag } // SyncOneImage ... -func SyncOneImage(config *config.Config, log log.Logger, storeController storage.StoreController, - repoName, reference string) error { +func SyncOneImage(config *config.Config, storeController storage.StoreController, + repoName, reference string, log log.Logger) error { log.Warn().Msg("skipping syncing on demand because given zot binary doesn't support any extensions," + "please build zot full binary for this feature") return nil diff --git a/pkg/extensions/sync/on_demand.go b/pkg/extensions/sync/on_demand.go index 22932a4ce..b4ae1b5a7 100644 --- a/pkg/extensions/sync/on_demand.go +++ b/pkg/extensions/sync/on_demand.go @@ -17,8 +17,8 @@ import ( "zotregistry.io/zot/pkg/storage" ) -func OneImage(cfg Config, log log.Logger, - storeController storage.StoreController, repo, tag string) error { +func OneImage(cfg Config, storeController storage.StoreController, + repo, tag string, log log.Logger) error { var credentialsFile CredentialsFile if cfg.CredentialsFile != "" { diff --git a/pkg/extensions/sync/sync.go b/pkg/extensions/sync/sync.go index d2e311cba..b96c5d165 100644 --- a/pkg/extensions/sync/sync.go +++ b/pkg/extensions/sync/sync.go @@ -12,6 +12,7 @@ import ( "path" "regexp" "strings" + goSync "sync" "time" "github.com/Masterminds/semver" @@ -438,7 +439,7 @@ func getLocalContexts(log log.Logger) (*types.SystemContext, *signature.PolicyCo return localCtx, policyContext, nil } -func Run(cfg Config, storeController storage.StoreController, logger log.Logger) error { +func Run(cfg Config, storeController storage.StoreController, wg *goSync.WaitGroup, logger log.Logger) error { var credentialsFile CredentialsFile var err error @@ -468,6 +469,9 @@ func Run(cfg Config, storeController storage.StoreController, logger log.Logger) continue } + // increment reference since will be busy, so shutdown has to wait + wg.Add(1) + // schedule each registry sync ticker := time.NewTicker(regCfg.PollInterval) @@ -484,6 +488,8 @@ func Run(cfg Config, storeController storage.StoreController, logger log.Logger) l.Error().Err(err).Msg("sync exited with error, stopping it...") ticker.Stop() } + // mark as done after a single sync run + wg.Done() } }(regCfg, l) } diff --git a/pkg/extensions/sync/sync_internal_test.go b/pkg/extensions/sync/sync_internal_test.go index a8ac13b94..cbf6beeb4 100644 --- a/pkg/extensions/sync/sync_internal_test.go +++ b/pkg/extensions/sync/sync_internal_test.go @@ -7,6 +7,7 @@ import ( "io/ioutil" "os" "path" + goSync "sync" "testing" "time" @@ -101,7 +102,7 @@ func TestSyncInternal(t *testing.T) { cfg := Config{Registries: []RegistryConfig{syncRegistryConfig}, CredentialsFile: "/invalid/path/to/file"} - So(Run(cfg, storage.StoreController{}, log.NewLogger("debug", "")), ShouldNotBeNil) + So(Run(cfg, storage.StoreController{}, new(goSync.WaitGroup), log.NewLogger("debug", "")), ShouldNotBeNil) _, err = getFileCredentials("/invalid/path/to/file") So(err, ShouldNotBeNil) diff --git a/pkg/extensions/sync/sync_test.go b/pkg/extensions/sync/sync_test.go index d0bedccc3..f292973b0 100644 --- a/pkg/extensions/sync/sync_test.go +++ b/pkg/extensions/sync/sync_test.go @@ -4,7 +4,6 @@ package sync_test import ( - "context" "crypto/tls" "crypto/x509" "encoding/json" @@ -235,9 +234,7 @@ func TestSyncOnDemand(t *testing.T) { defer os.RemoveAll(srcDir) defer func() { - ctx := context.Background() - _ = sc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + sc.Shutdown() }() regex := ".*" @@ -269,9 +266,7 @@ func TestSyncOnDemand(t *testing.T) { defer os.RemoveAll(destDir) defer func() { - ctx := context.Background() - _ = dc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + dc.Shutdown() }() var srcTagsList TagsList @@ -369,9 +364,7 @@ func TestSync(t *testing.T) { defer os.RemoveAll(srcDir) defer func() { - ctx := context.Background() - _ = sc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + sc.Shutdown() }() regex := ".*" @@ -400,9 +393,7 @@ func TestSync(t *testing.T) { defer os.RemoveAll(destDir) defer func() { - ctx := context.Background() - _ = dc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + dc.Shutdown() }() var srcTagsList TagsList @@ -480,9 +471,7 @@ func TestSync(t *testing.T) { defer os.RemoveAll(destDir) defer func() { - ctx := context.Background() - _ = dc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + dc.Shutdown() }() var srcTagsList TagsList @@ -551,9 +540,7 @@ func TestSyncPermsDenied(t *testing.T) { defer os.RemoveAll(srcDir) defer func() { - ctx := context.Background() - _ = sc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + sc.Shutdown() }() regex := ".*" @@ -582,9 +569,7 @@ func TestSyncPermsDenied(t *testing.T) { defer os.RemoveAll(destDir) defer func() { - ctx := context.Background() - _ = dc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + dc.Shutdown() }() err := os.Chmod(path.Join(destDir, testImage, sync.SyncBlobUploadDir), 0000) @@ -608,9 +593,7 @@ func TestSyncBadTLS(t *testing.T) { defer os.RemoveAll(srcDir) defer func() { - ctx := context.Background() - _ = sc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + sc.Shutdown() }() regex := ".*" @@ -639,9 +622,7 @@ func TestSyncBadTLS(t *testing.T) { defer os.RemoveAll(destDir) defer func() { - ctx := context.Background() - _ = dc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + dc.Shutdown() }() // give it time to set up sync @@ -669,9 +650,7 @@ func TestSyncTLS(t *testing.T) { defer os.RemoveAll(srcDir) defer func() { - ctx := context.Background() - _ = sc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + sc.Shutdown() }() var srcIndex ispec.Index @@ -738,9 +717,7 @@ func TestSyncTLS(t *testing.T) { defer os.RemoveAll(destDir) defer func() { - ctx := context.Background() - _ = dc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + dc.Shutdown() }() // wait till ready @@ -781,9 +758,7 @@ func TestSyncBasicAuth(t *testing.T) { defer os.RemoveAll(srcDir) defer func() { - ctx := context.Background() - _ = sc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + sc.Shutdown() }() Convey("Verify sync basic auth with file credentials", func() { @@ -811,9 +786,7 @@ func TestSyncBasicAuth(t *testing.T) { defer os.RemoveAll(destDir) defer func() { - ctx := context.Background() - _ = dc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + dc.Shutdown() }() var srcTagsList TagsList @@ -915,9 +888,7 @@ func TestSyncBasicAuth(t *testing.T) { }() defer func() { - ctx := context.Background() - _ = dc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + dc.Shutdown() }() // wait till ready @@ -982,9 +953,7 @@ func TestSyncBasicAuth(t *testing.T) { defer os.RemoveAll(destDir) defer func() { - ctx := context.Background() - _ = dc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + dc.Shutdown() }() resp, err := destClient.R().Get(destBaseURL + "/v2/" + testImage + "/manifests/" + testImageTag) @@ -1028,9 +997,7 @@ func TestSyncBasicAuth(t *testing.T) { defer os.RemoveAll(destDir) defer func() { - ctx := context.Background() - _ = dc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + dc.Shutdown() }() var srcTagsList TagsList @@ -1119,9 +1086,7 @@ func TestSyncBadURL(t *testing.T) { defer os.RemoveAll(destDir) defer func() { - ctx := context.Background() - _ = dc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + dc.Shutdown() }() Convey("Test sync on POST request on /sync", func() { @@ -1141,9 +1106,7 @@ func TestSyncNoImagesByRegex(t *testing.T) { defer os.RemoveAll(srcDir) defer func() { - ctx := context.Background() - _ = sc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + sc.Shutdown() }() regex := "9.9.9" @@ -1170,9 +1133,7 @@ func TestSyncNoImagesByRegex(t *testing.T) { defer os.RemoveAll(destDir) defer func() { - ctx := context.Background() - _ = dc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + dc.Shutdown() }() Convey("Test sync on POST request on /sync", func() { @@ -1205,9 +1166,7 @@ func TestSyncInvalidRegex(t *testing.T) { defer os.RemoveAll(srcDir) defer func() { - ctx := context.Background() - _ = sc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + sc.Shutdown() }() regex := "[" @@ -1234,9 +1193,7 @@ func TestSyncInvalidRegex(t *testing.T) { defer os.RemoveAll(destDir) defer func() { - ctx := context.Background() - _ = dc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + dc.Shutdown() }() Convey("Test sync on POST request on /sync", func() { @@ -1256,9 +1213,7 @@ func TestSyncNotSemver(t *testing.T) { defer os.RemoveAll(srcDir) defer func() { - ctx := context.Background() - _ = sc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + sc.Shutdown() }() // get manifest so we can update it with a semver non compliant tag @@ -1300,9 +1255,7 @@ func TestSyncNotSemver(t *testing.T) { defer os.RemoveAll(destDir) defer func() { - ctx := context.Background() - _ = dc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + dc.Shutdown() }() Convey("Test sync on POST request on /sync", func() { @@ -1335,9 +1288,7 @@ func TestSyncInvalidCerts(t *testing.T) { defer os.RemoveAll(srcDir) defer func() { - ctx := context.Background() - _ = sc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + sc.Shutdown() }() // copy client certs, use them in sync config @@ -1397,9 +1348,7 @@ func TestSyncInvalidCerts(t *testing.T) { defer os.RemoveAll(destDir) defer func() { - ctx := context.Background() - _ = dc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + dc.Shutdown() }() Convey("Test sync on POST request on /sync", func() { @@ -1456,9 +1405,7 @@ func TestSyncInvalidUrl(t *testing.T) { defer os.RemoveAll(destDir) defer func() { - ctx := context.Background() - _ = dc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + dc.Shutdown() }() resp, err := destClient.R().Get(destBaseURL + "/v2/" + testImage + "/manifests/" + testImageTag) @@ -1475,9 +1422,7 @@ func TestSyncInvalidTags(t *testing.T) { defer os.RemoveAll(srcDir) defer func() { - ctx := context.Background() - _ = sc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + sc.Shutdown() }() regex := ".*" @@ -1509,9 +1454,7 @@ func TestSyncInvalidTags(t *testing.T) { defer os.RemoveAll(destDir) defer func() { - ctx := context.Background() - _ = dc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + dc.Shutdown() }() resp, err := destClient.R().Get(destBaseURL + "/v2/" + testImage + "/manifests/" + "invalid:tag") @@ -1565,9 +1508,7 @@ func TestSyncSubPaths(t *testing.T) { } defer func() { - ctx := context.Background() - _ = sc.Server.Shutdown(ctx) - time.Sleep(500 * time.Millisecond) + sc.Shutdown() }() regex := ".*" @@ -1646,8 +1587,7 @@ func TestSyncSubPaths(t *testing.T) { } defer func() { - ctx := context.Background() - _ = dc.Server.Shutdown(ctx) + dc.Shutdown() }() var destTagsList TagsList