Skip to content

Commit

Permalink
Adding credential helper
Browse files Browse the repository at this point in the history
Signed-off-by: Noel Georgi <[email protected]>
  • Loading branch information
frezbo committed Oct 10, 2018
1 parent 0cd0697 commit d15847e
Showing 1 changed file with 99 additions and 66 deletions.
165 changes: 99 additions & 66 deletions cli/exec.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package cli

import (
"encoding/json"
"fmt"
"log"
"os"
"os/exec"
Expand All @@ -17,17 +19,26 @@ import (
)

type ExecCommandInput struct {
Profile string
Command string
Args []string
Keyring keyring.Keyring
Duration time.Duration
RoleDuration time.Duration
MfaToken string
MfaPrompt prompt.PromptFunc
StartServer bool
Signals chan os.Signal
NoSession bool
Profile string
Command string
Args []string
Keyring keyring.Keyring
Duration time.Duration
RoleDuration time.Duration
MfaToken string
MfaPrompt prompt.PromptFunc
StartServer bool
CredentialHelper bool
Signals chan os.Signal
NoSession bool
}

type AwsCredentialHelperData struct {
Version string `json:"Version"`
AccessKeyID string `json:"AccessKeyId"`
SecretAccessKey string `json:"SecretAccessKey"`
SessionToken string `json:"SessionToken"`
Expiration string `json:"Expiration,omitempty"`
}

func ConfigureExecCommand(app *kingpin.Application) {
Expand All @@ -53,6 +64,10 @@ func ConfigureExecCommand(app *kingpin.Application) {
Short('m').
StringVar(&input.MfaToken)

cmd.Flag("json", "AWS credential helper").
Short('j').
BoolVar(&input.CredentialHelper)

cmd.Flag("server", "Run the server in the background for credentials").
Short('s').
BoolVar(&input.StartServer)
Expand Down Expand Up @@ -116,67 +131,85 @@ func ExecCommand(app *kingpin.Application, input ExecCommandInput) {
}
}

env := environ(os.Environ())
env.Set("AWS_VAULT", input.Profile)

env.Unset("AWS_ACCESS_KEY_ID")
env.Unset("AWS_SECRET_ACCESS_KEY")
env.Unset("AWS_CREDENTIAL_FILE")
env.Unset("AWS_DEFAULT_PROFILE")
env.Unset("AWS_PROFILE")

if profile, _ := awsConfig.Profile(input.Profile); profile.Region != "" {
log.Printf("Setting subprocess env: AWS_DEFAULT_REGION=%s, AWS_REGION=%s", profile.Region, profile.Region)
env.Set("AWS_DEFAULT_REGION", profile.Region)
env.Set("AWS_REGION", profile.Region)
}
if input.CredentialHelper {
credentialData := AwsCredentialHelperData{
Version: "1",
AccessKeyID: val.AccessKeyID,
SecretAccessKey: val.SecretAccessKey,
SessionToken: val.SessionToken,
}
if !input.NoSession {
credentialData.Expiration = time.Now().Add(input.Duration).Format("2006-01-02T15:04:05")
}
json, err := json.Marshal(&credentialData)
if err != nil {
app.Fatalf("Error creating credential json")
}
fmt.Printf(string(json))
} else {

env := environ(os.Environ())
env.Set("AWS_VAULT", input.Profile)

env.Unset("AWS_ACCESS_KEY_ID")
env.Unset("AWS_SECRET_ACCESS_KEY")
env.Unset("AWS_CREDENTIAL_FILE")
env.Unset("AWS_DEFAULT_PROFILE")
env.Unset("AWS_PROFILE")

if profile, _ := awsConfig.Profile(input.Profile); profile.Region != "" {
log.Printf("Setting subprocess env: AWS_DEFAULT_REGION=%s, AWS_REGION=%s", profile.Region, profile.Region)
env.Set("AWS_DEFAULT_REGION", profile.Region)
env.Set("AWS_REGION", profile.Region)
}

if setEnv {
log.Println("Setting subprocess env: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY")
env.Set("AWS_ACCESS_KEY_ID", val.AccessKeyID)
env.Set("AWS_SECRET_ACCESS_KEY", val.SecretAccessKey)
if setEnv {
log.Println("Setting subprocess env: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY")
env.Set("AWS_ACCESS_KEY_ID", val.AccessKeyID)
env.Set("AWS_SECRET_ACCESS_KEY", val.SecretAccessKey)

if val.SessionToken != "" {
log.Println("Setting subprocess env: AWS_SESSION_TOKEN, AWS_SECURITY_TOKEN")
env.Set("AWS_SESSION_TOKEN", val.SessionToken)
env.Set("AWS_SECURITY_TOKEN", val.SessionToken)
if val.SessionToken != "" {
log.Println("Setting subprocess env: AWS_SESSION_TOKEN, AWS_SECURITY_TOKEN")
env.Set("AWS_SESSION_TOKEN", val.SessionToken)
env.Set("AWS_SECURITY_TOKEN", val.SessionToken)
}
}
}

cmd := exec.Command(input.Command, input.Args...)
cmd.Env = env
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
signal.Notify(input.Signals, os.Interrupt, os.Kill)
cmd := exec.Command(input.Command, input.Args...)
cmd.Env = env
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
signal.Notify(input.Signals, os.Interrupt, os.Kill)

if err := cmd.Start(); err != nil {
app.Fatalf("%v", err)
}
// wait for the command to finish
waitCh := make(chan error, 1)
go func() {
waitCh <- cmd.Wait()
close(waitCh)
}()

for {
select {
case sig := <-input.Signals:
if err = cmd.Process.Signal(sig); err != nil {
app.Errorf("%v", err)
break
}
case err := <-waitCh:
var waitStatus syscall.WaitStatus
if exitError, ok := err.(*exec.ExitError); ok {
waitStatus = exitError.Sys().(syscall.WaitStatus)
os.Exit(waitStatus.ExitStatus())
}
if err != nil {
app.Fatalf("%v", err)
if err := cmd.Start(); err != nil {
app.Fatalf("%v", err)
}
// wait for the command to finish
waitCh := make(chan error, 1)
go func() {
waitCh <- cmd.Wait()
close(waitCh)
}()

for {
select {
case sig := <-input.Signals:
if err = cmd.Process.Signal(sig); err != nil {
app.Errorf("%v", err)
break
}
case err := <-waitCh:
var waitStatus syscall.WaitStatus
if exitError, ok := err.(*exec.ExitError); ok {
waitStatus = exitError.Sys().(syscall.WaitStatus)
os.Exit(waitStatus.ExitStatus())
}
if err != nil {
app.Fatalf("%v", err)
}
return
}
return
}
}
}
Expand Down

0 comments on commit d15847e

Please sign in to comment.