Skip to content

Commit

Permalink
SSM works like a charm ;)
Browse files Browse the repository at this point in the history
  • Loading branch information
krystian-panek-vmltech committed Jan 29, 2024
1 parent 2d1fbf5 commit d4043d4
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 35 deletions.
12 changes: 6 additions & 6 deletions internal/client/client_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ func (c ClientManager) connection(typeName string, settings map[string]string) (
}, nil
case "aws-ssm":
return &AWSSSMConnection{
instanceID: settings["instance_id"],
region: settings["region"],
outputTimeout: cast.ToDuration(settings["output_timeout"]),
minWaitDelay: cast.ToDuration(settings["min_wait_delay"]),
maxWaitDelay: cast.ToDuration(settings["max_wait_delay"]),
context: context.Background(),
instanceID: settings["instance_id"],
region: settings["region"],
context: context.Background(),
commandOutputTimeout: cast.ToDuration(settings["command_output_timeout"]),
commandWaitMin: cast.ToDuration(settings["command_wait_min"]),
commandWaitMax: cast.ToDuration(settings["command_wait_max"]),
}, nil
}
return nil, fmt.Errorf("unknown AEM client type: %s", typeName)
Expand Down
59 changes: 30 additions & 29 deletions internal/client/connection_aws_ssm.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ import (
)

type AWSSSMConnection struct {
instanceID string
region string
outputTimeout time.Duration
minWaitDelay time.Duration
maxWaitDelay time.Duration
client *ssm.Client
sessionId *string
context context.Context
instanceID string
region string
client *ssm.Client
sessionId *string
context context.Context
commandOutputTimeout time.Duration
commandWaitMax time.Duration
commandWaitMin time.Duration
}

func (a *AWSSSMConnection) Info() string {
Expand All @@ -40,14 +40,14 @@ func (a *AWSSSMConnection) User() string {
}

func (a *AWSSSMConnection) Connect() error {
if a.outputTimeout == 0 {
a.outputTimeout = 5 * time.Hour
if a.commandOutputTimeout == 0 {
a.commandOutputTimeout = 5 * time.Hour
}
if a.minWaitDelay == 0 {
a.minWaitDelay = 5 * time.Millisecond
if a.commandWaitMin == 0 {
a.commandWaitMin = 5 * time.Millisecond
}
if a.maxWaitDelay == 0 {
a.maxWaitDelay = 5 * time.Second
if a.commandWaitMax == 0 {
a.commandWaitMax = 5 * time.Second
}

var optFns []func(*config.LoadOptions) error
Expand All @@ -61,24 +61,21 @@ func (a *AWSSSMConnection) Connect() error {
}

client := ssm.NewFromConfig(cfg)
startSessionInput := &ssm.StartSessionInput{Target: aws.String(a.instanceID)}

startSessionOutput, err := client.StartSession(a.context, startSessionInput)
sessionIn := &ssm.StartSessionInput{Target: aws.String(a.instanceID)}
sessionOut, err := client.StartSession(a.context, sessionIn)
if err != nil {
return fmt.Errorf("ssm: error starting session: %v", err)
}

a.client = client
a.sessionId = startSessionOutput.SessionId
a.sessionId = sessionOut.SessionId

return nil
}

func (a *AWSSSMConnection) Disconnect() error {
// Disconnect from the session
terminateSessionInput := &ssm.TerminateSessionInput{SessionId: a.sessionId}

_, err := a.client.TerminateSession(a.context, terminateSessionInput)
sessionIn := &ssm.TerminateSessionInput{SessionId: a.sessionId}
_, err := a.client.TerminateSession(a.context, sessionIn)
if err != nil {
return fmt.Errorf("ssm: error terminating session: %v", err)
}
Expand All @@ -88,14 +85,14 @@ func (a *AWSSSMConnection) Disconnect() error {

func (a *AWSSSMConnection) Command(cmdLine []string) ([]byte, error) {
command := strings.Join(cmdLine, " ")
runCommandInput := &ssm.SendCommandInput{
commandIn := &ssm.SendCommandInput{
DocumentName: aws.String("AWS-RunShellScript"),
InstanceIds: []string{a.instanceID},
Parameters: map[string][]string{
"commands": {command},
},
}
runOut, err := a.client.SendCommand(a.context, runCommandInput)
runOut, err := a.client.SendCommand(a.context, commandIn)
if err != nil {
return nil, fmt.Errorf("ssm: error executing command: %v", err)
}
Expand All @@ -105,12 +102,16 @@ func (a *AWSSSMConnection) Command(cmdLine []string) ([]byte, error) {
CommandId: commandId,
InstanceId: aws.String(a.instanceID),
}
optFns := []func(opt *ssm.CommandExecutedWaiterOptions){func(opt *ssm.CommandExecutedWaiterOptions) {
opt.MinDelay = a.minWaitDelay
opt.MaxDelay = a.maxWaitDelay
}}
var optFns []func(opt *ssm.CommandExecutedWaiterOptions)
if a.commandWaitMax > 0 && a.commandWaitMin > 0 {
optFns = []func(opt *ssm.CommandExecutedWaiterOptions){func(opt *ssm.CommandExecutedWaiterOptions) {
opt.MinDelay = a.commandWaitMin
opt.MaxDelay = a.commandWaitMax
}}
}

waiter := ssm.NewCommandExecutedWaiter(a.client, optFns...)
invocationOut, err := waiter.WaitForOutput(a.context, invocationIn, a.outputTimeout)
invocationOut, err := waiter.WaitForOutput(a.context, invocationIn, a.commandOutputTimeout)
if err != nil {
invocationOut, err = a.client.GetCommandInvocation(a.context, invocationIn)
if invocationOut != nil {
Expand Down

0 comments on commit d4043d4

Please sign in to comment.