From d4043d4699d90416a1750b5667805a9653e951f6 Mon Sep 17 00:00:00 2001 From: Krystian Panek Date: Mon, 29 Jan 2024 11:10:07 +0100 Subject: [PATCH] SSM works like a charm ;) --- internal/client/client_manager.go | 12 +++--- internal/client/connection_aws_ssm.go | 59 ++++++++++++++------------- 2 files changed, 36 insertions(+), 35 deletions(-) diff --git a/internal/client/client_manager.go b/internal/client/client_manager.go index 2327db8..8f37b5d 100644 --- a/internal/client/client_manager.go +++ b/internal/client/client_manager.go @@ -41,12 +41,12 @@ func (c ClientManager) connection(typeName string, settings map[string]string) ( }, nil case "aws-ssm": return &AWSSSMConnection{ - instanceID: settings["instance_id"], - region: settings["region"], - outputTimeout: cast.ToDuration(settings["output_timeout"]), - minWaitDelay: cast.ToDuration(settings["min_wait_delay"]), - maxWaitDelay: cast.ToDuration(settings["max_wait_delay"]), - context: context.Background(), + instanceID: settings["instance_id"], + region: settings["region"], + context: context.Background(), + commandOutputTimeout: cast.ToDuration(settings["command_output_timeout"]), + commandWaitMin: cast.ToDuration(settings["command_wait_min"]), + commandWaitMax: cast.ToDuration(settings["command_wait_max"]), }, nil } return nil, fmt.Errorf("unknown AEM client type: %s", typeName) diff --git a/internal/client/connection_aws_ssm.go b/internal/client/connection_aws_ssm.go index 5fc5824..e2a5608 100644 --- a/internal/client/connection_aws_ssm.go +++ b/internal/client/connection_aws_ssm.go @@ -13,14 +13,14 @@ import ( ) type AWSSSMConnection struct { - instanceID string - region string - outputTimeout time.Duration - minWaitDelay time.Duration - maxWaitDelay time.Duration - client *ssm.Client - sessionId *string - context context.Context + instanceID string + region string + client *ssm.Client + sessionId *string + context context.Context + commandOutputTimeout time.Duration + commandWaitMax time.Duration + commandWaitMin time.Duration } func (a *AWSSSMConnection) Info() string { @@ -40,14 +40,14 @@ func (a *AWSSSMConnection) User() string { } func (a *AWSSSMConnection) Connect() error { - if a.outputTimeout == 0 { - a.outputTimeout = 5 * time.Hour + if a.commandOutputTimeout == 0 { + a.commandOutputTimeout = 5 * time.Hour } - if a.minWaitDelay == 0 { - a.minWaitDelay = 5 * time.Millisecond + if a.commandWaitMin == 0 { + a.commandWaitMin = 5 * time.Millisecond } - if a.maxWaitDelay == 0 { - a.maxWaitDelay = 5 * time.Second + if a.commandWaitMax == 0 { + a.commandWaitMax = 5 * time.Second } var optFns []func(*config.LoadOptions) error @@ -61,24 +61,21 @@ func (a *AWSSSMConnection) Connect() error { } client := ssm.NewFromConfig(cfg) - startSessionInput := &ssm.StartSessionInput{Target: aws.String(a.instanceID)} - - startSessionOutput, err := client.StartSession(a.context, startSessionInput) + sessionIn := &ssm.StartSessionInput{Target: aws.String(a.instanceID)} + sessionOut, err := client.StartSession(a.context, sessionIn) if err != nil { return fmt.Errorf("ssm: error starting session: %v", err) } a.client = client - a.sessionId = startSessionOutput.SessionId + a.sessionId = sessionOut.SessionId return nil } func (a *AWSSSMConnection) Disconnect() error { - // Disconnect from the session - terminateSessionInput := &ssm.TerminateSessionInput{SessionId: a.sessionId} - - _, err := a.client.TerminateSession(a.context, terminateSessionInput) + sessionIn := &ssm.TerminateSessionInput{SessionId: a.sessionId} + _, err := a.client.TerminateSession(a.context, sessionIn) if err != nil { return fmt.Errorf("ssm: error terminating session: %v", err) } @@ -88,14 +85,14 @@ func (a *AWSSSMConnection) Disconnect() error { func (a *AWSSSMConnection) Command(cmdLine []string) ([]byte, error) { command := strings.Join(cmdLine, " ") - runCommandInput := &ssm.SendCommandInput{ + commandIn := &ssm.SendCommandInput{ DocumentName: aws.String("AWS-RunShellScript"), InstanceIds: []string{a.instanceID}, Parameters: map[string][]string{ "commands": {command}, }, } - runOut, err := a.client.SendCommand(a.context, runCommandInput) + runOut, err := a.client.SendCommand(a.context, commandIn) if err != nil { return nil, fmt.Errorf("ssm: error executing command: %v", err) } @@ -105,12 +102,16 @@ func (a *AWSSSMConnection) Command(cmdLine []string) ([]byte, error) { CommandId: commandId, InstanceId: aws.String(a.instanceID), } - optFns := []func(opt *ssm.CommandExecutedWaiterOptions){func(opt *ssm.CommandExecutedWaiterOptions) { - opt.MinDelay = a.minWaitDelay - opt.MaxDelay = a.maxWaitDelay - }} + var optFns []func(opt *ssm.CommandExecutedWaiterOptions) + if a.commandWaitMax > 0 && a.commandWaitMin > 0 { + optFns = []func(opt *ssm.CommandExecutedWaiterOptions){func(opt *ssm.CommandExecutedWaiterOptions) { + opt.MinDelay = a.commandWaitMin + opt.MaxDelay = a.commandWaitMax + }} + } + waiter := ssm.NewCommandExecutedWaiter(a.client, optFns...) - invocationOut, err := waiter.WaitForOutput(a.context, invocationIn, a.outputTimeout) + invocationOut, err := waiter.WaitForOutput(a.context, invocationIn, a.commandOutputTimeout) if err != nil { invocationOut, err = a.client.GetCommandInvocation(a.context, invocationIn) if invocationOut != nil {