From a680e63dd7491db304496c6c9113d89894e7d5ff Mon Sep 17 00:00:00 2001 From: Nick Baker Date: Mon, 16 Sep 2024 22:28:34 -0700 Subject: [PATCH] nodeadm: block until daemon status changes are reflected (#1965) --- nodeadm/internal/aws/ecr/ecr.go | 3 +- nodeadm/internal/containerd/sandbox.go | 3 +- nodeadm/internal/daemon/interface.go | 15 ++--- nodeadm/internal/daemon/systemd.go | 41 +++++++++---- nodeadm/internal/util/retry.go | 79 ++++++++++++++++++++++---- 5 files changed, 110 insertions(+), 31 deletions(-) diff --git a/nodeadm/internal/aws/ecr/ecr.go b/nodeadm/internal/aws/ecr/ecr.go index c70018980..3e317c314 100644 --- a/nodeadm/internal/aws/ecr/ecr.go +++ b/nodeadm/internal/aws/ecr/ecr.go @@ -6,7 +6,6 @@ import ( "go.uber.org/zap" "net" "strings" - "time" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/ecr" @@ -23,7 +22,7 @@ func GetAuthorizationToken(awsRegion string) (string, error) { } ecrClient := ecr.NewFromConfig(awsConfig) var token *ecr.GetAuthorizationTokenOutput - err = util.RetryExponentialBackoff(3, 2*time.Second, func() error { + err = util.NewRetrier(util.WithBackoffExponential()).Retry(context.TODO(), func() error { token, err = ecrClient.GetAuthorizationToken(context.Background(), &ecr.GetAuthorizationTokenInput{}) return err }) diff --git a/nodeadm/internal/containerd/sandbox.go b/nodeadm/internal/containerd/sandbox.go index 41c97e318..42c58c1e9 100644 --- a/nodeadm/internal/containerd/sandbox.go +++ b/nodeadm/internal/containerd/sandbox.go @@ -1,6 +1,7 @@ package containerd import ( + "context" "fmt" "os/exec" "regexp" @@ -44,7 +45,7 @@ func cacheSandboxImage(cfg *api.NodeConfig) error { imageSpec := &v1.ImageSpec{Image: sandboxImage} authConfig := &v1.AuthConfig{Auth: ecrUserToken} - return util.RetryExponentialBackoff(3, 2*time.Second, func() error { + return util.NewRetrier(util.WithBackoffExponential()).Retry(context.TODO(), func() error { zap.L().Info("Pulling sandbox image..", zap.String("image", sandboxImage)) imageRef, err := client.PullImage(imageSpec, authConfig, nil) if err != nil { diff --git a/nodeadm/internal/daemon/interface.go b/nodeadm/internal/daemon/interface.go index 4ec53db8f..01c75da0d 100644 --- a/nodeadm/internal/daemon/interface.go +++ b/nodeadm/internal/daemon/interface.go @@ -1,20 +1,21 @@ package daemon -import "github.com/awslabs/amazon-eks-ami/nodeadm/internal/api" +import ( + "github.com/awslabs/amazon-eks-ami/nodeadm/internal/api" +) type Daemon interface { // Configure configures the daemon. Configure(*api.NodeConfig) error - - // EnsureRunning ensures that the daemon is running. - // If the daemon is not running, it will be started. - // If the daemon is already running, and has been re-configured, it will be restarted. + // EnsureRunning ensures that the daemon is running by either + // starting/restarting the daemon, then blocking until the status of the + // daemon reflects that it is running. + // * If the daemon is not running, it will be started. + // * If the daemon is already running, and has been re-configured, it will be restarted. EnsureRunning() error - // PostLaunch runs any additional step that needs to occur after the service // daemon as been started PostLaunch(*api.NodeConfig) error - // Name returns the name of the daemon. Name() string } diff --git a/nodeadm/internal/daemon/systemd.go b/nodeadm/internal/daemon/systemd.go index a76db56a2..0f27e1cc1 100644 --- a/nodeadm/internal/daemon/systemd.go +++ b/nodeadm/internal/daemon/systemd.go @@ -5,7 +5,9 @@ package daemon import ( "context" "fmt" + "time" + "github.com/awslabs/amazon-eks-ami/nodeadm/internal/util" "github.com/coreos/go-systemd/v22/dbus" ) @@ -32,21 +34,24 @@ func NewDaemonManager() (DaemonManager, error) { } func (m *systemdDaemonManager) StartDaemon(name string) error { - unitName := getServiceUnitName(name) - _, err := m.conn.StartUnitContext(context.TODO(), unitName, ModeReplace, nil) - return err + if _, err := m.conn.StartUnitContext(context.TODO(), getServiceUnitName(name), ModeReplace, nil); err != nil { + return err + } + return m.waitForStatus(context.TODO(), name, DaemonStatusRunning) } func (m *systemdDaemonManager) StopDaemon(name string) error { - unitName := getServiceUnitName(name) - _, err := m.conn.StopUnitContext(context.TODO(), unitName, ModeReplace, nil) - return err + if _, err := m.conn.StopUnitContext(context.TODO(), getServiceUnitName(name), ModeReplace, nil); err != nil { + return err + } + return m.waitForStatus(context.TODO(), name, DaemonStatusStopped) } func (m *systemdDaemonManager) RestartDaemon(name string) error { - unitName := getServiceUnitName(name) - _, err := m.conn.RestartUnitContext(context.TODO(), unitName, ModeReplace, nil) - return err + if _, err := m.conn.RestartUnitContext(context.TODO(), getServiceUnitName(name), ModeReplace, nil); err != nil { + return err + } + return m.waitForStatus(context.TODO(), name, DaemonStatusRunning) } func (m *systemdDaemonManager) GetDaemonStatus(name string) (DaemonStatus, error) { @@ -55,7 +60,7 @@ func (m *systemdDaemonManager) GetDaemonStatus(name string) (DaemonStatus, error if err != nil { return DaemonStatusUnknown, err } - switch status.Value.String() { + switch status.Value.Value().(string) { case "active": return DaemonStatusRunning, nil case "inactive": @@ -102,3 +107,19 @@ func (m *systemdDaemonManager) Close() { func getServiceUnitName(name string) string { return fmt.Sprintf("%s.service", name) } + +func (m *systemdDaemonManager) waitForStatus(ctx context.Context, name string, targetStatus DaemonStatus) error { + return util.NewRetrier( + util.WithRetryAlways(), + util.WithBackoffFixed(250*time.Millisecond), + ).Retry(ctx, func() error { + status, err := m.GetDaemonStatus(name) + if err != nil { + return err + } + if status != targetStatus { + return fmt.Errorf("%s status is not %q", name, targetStatus) + } + return nil + }) +} diff --git a/nodeadm/internal/util/retry.go b/nodeadm/internal/util/retry.go index 902395f23..ab54e4a1b 100644 --- a/nodeadm/internal/util/retry.go +++ b/nodeadm/internal/util/retry.go @@ -1,16 +1,73 @@ package util -import "time" - -func RetryExponentialBackoff(attempts int, initial time.Duration, f func() error) error { - var err error - wait := initial - for i := 0; i < attempts; i++ { - if err = f(); err == nil { - return nil +import ( + "context" + "time" +) + +type Retrier struct { + ConditionFn func(*Retrier) bool + BackoffFn func(*Retrier) time.Duration + + LastErr error + LastWait time.Duration + LastIter int +} + +func (r *Retrier) Retry(ctx context.Context, fn func() error) error { + for r.LastIter = 0; r.ConditionFn(r); r.LastIter++ { + if r.LastErr = fn(); r.LastErr == nil { + return r.LastErr + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + time.Sleep(r.LastWait) + r.LastWait = r.BackoffFn(r) } - time.Sleep(wait) - wait *= 2 } - return err + return r.LastErr +} + +type fnOpt func(*Retrier) + +func NewRetrier(fnOpts ...fnOpt) *Retrier { + retrier := Retrier{ + LastErr: nil, + LastIter: 0, + LastWait: time.Second, + } + for _, fn := range append([]fnOpt{ + WithRetryCount(5), + WithBackoffExponential(), + }, fnOpts...) { + fn(&retrier) + } + return &retrier +} + +func WithRetryCount(maxAttempts int) fnOpt { + return func(r *Retrier) { + r.ConditionFn = func(r *Retrier) bool { return r.LastIter < maxAttempts } + } +} + +func WithRetryAlways() fnOpt { + return func(r *Retrier) { + r.ConditionFn = func(r *Retrier) bool { return true } + } +} + +func WithBackoffFixed(interval time.Duration) fnOpt { + return func(r *Retrier) { + r.LastWait = interval + r.BackoffFn = func(r *Retrier) time.Duration { return interval } + } +} + +func WithBackoffExponential() fnOpt { + return func(r *Retrier) { + r.BackoffFn = func(r *Retrier) time.Duration { return r.LastWait * 2 } + } }