diff --git a/cli/exec.go b/cli/exec.go index fa4a16ef0..60a936e7b 100644 --- a/cli/exec.go +++ b/cli/exec.go @@ -1,6 +1,8 @@ package cli import ( + "encoding/json" + "fmt" "log" "os" "os/exec" @@ -17,17 +19,26 @@ import ( ) type ExecCommandInput struct { - Profile string - Command string - Args []string - Keyring keyring.Keyring - Duration time.Duration - RoleDuration time.Duration - MfaToken string - MfaPrompt prompt.PromptFunc - StartServer bool - Signals chan os.Signal - NoSession bool + Profile string + Command string + Args []string + Keyring keyring.Keyring + Duration time.Duration + RoleDuration time.Duration + MfaToken string + MfaPrompt prompt.PromptFunc + StartServer bool + CredentialHelper bool + Signals chan os.Signal + NoSession bool +} + +type AwsCredentialHelperData struct { + Version string `json:"Version"` + AccessKeyID string `json:"AccessKeyId"` + SecretAccessKey string `json:"SecretAccessKey"` + SessionToken string `json:"SessionToken"` + Expiration string `json:"Expiration,omitempty"` } func ConfigureExecCommand(app *kingpin.Application) { @@ -53,6 +64,10 @@ func ConfigureExecCommand(app *kingpin.Application) { Short('m'). StringVar(&input.MfaToken) + cmd.Flag("json", "AWS credential helper"). + Short('j'). + BoolVar(&input.CredentialHelper) + cmd.Flag("server", "Run the server in the background for credentials"). Short('s'). BoolVar(&input.StartServer) @@ -116,67 +131,85 @@ func ExecCommand(app *kingpin.Application, input ExecCommandInput) { } } - env := environ(os.Environ()) - env.Set("AWS_VAULT", input.Profile) - - env.Unset("AWS_ACCESS_KEY_ID") - env.Unset("AWS_SECRET_ACCESS_KEY") - env.Unset("AWS_CREDENTIAL_FILE") - env.Unset("AWS_DEFAULT_PROFILE") - env.Unset("AWS_PROFILE") - - if profile, _ := awsConfig.Profile(input.Profile); profile.Region != "" { - log.Printf("Setting subprocess env: AWS_DEFAULT_REGION=%s, AWS_REGION=%s", profile.Region, profile.Region) - env.Set("AWS_DEFAULT_REGION", profile.Region) - env.Set("AWS_REGION", profile.Region) - } + if input.CredentialHelper { + credentialData := AwsCredentialHelperData{ + Version: "1", + AccessKeyID: val.AccessKeyID, + SecretAccessKey: val.SecretAccessKey, + SessionToken: val.SessionToken, + } + if !input.NoSession { + credentialData.Expiration = time.Now().Add(input.Duration).Format("2006-01-02T15:04:05") + } + json, err := json.Marshal(&credentialData) + if err != nil { + app.Fatalf("Error creating credential json") + } + fmt.Printf(string(json)) + } else { + + env := environ(os.Environ()) + env.Set("AWS_VAULT", input.Profile) + + env.Unset("AWS_ACCESS_KEY_ID") + env.Unset("AWS_SECRET_ACCESS_KEY") + env.Unset("AWS_CREDENTIAL_FILE") + env.Unset("AWS_DEFAULT_PROFILE") + env.Unset("AWS_PROFILE") + + if profile, _ := awsConfig.Profile(input.Profile); profile.Region != "" { + log.Printf("Setting subprocess env: AWS_DEFAULT_REGION=%s, AWS_REGION=%s", profile.Region, profile.Region) + env.Set("AWS_DEFAULT_REGION", profile.Region) + env.Set("AWS_REGION", profile.Region) + } - if setEnv { - log.Println("Setting subprocess env: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY") - env.Set("AWS_ACCESS_KEY_ID", val.AccessKeyID) - env.Set("AWS_SECRET_ACCESS_KEY", val.SecretAccessKey) + if setEnv { + log.Println("Setting subprocess env: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY") + env.Set("AWS_ACCESS_KEY_ID", val.AccessKeyID) + env.Set("AWS_SECRET_ACCESS_KEY", val.SecretAccessKey) - if val.SessionToken != "" { - log.Println("Setting subprocess env: AWS_SESSION_TOKEN, AWS_SECURITY_TOKEN") - env.Set("AWS_SESSION_TOKEN", val.SessionToken) - env.Set("AWS_SECURITY_TOKEN", val.SessionToken) + if val.SessionToken != "" { + log.Println("Setting subprocess env: AWS_SESSION_TOKEN, AWS_SECURITY_TOKEN") + env.Set("AWS_SESSION_TOKEN", val.SessionToken) + env.Set("AWS_SECURITY_TOKEN", val.SessionToken) + } } - } - cmd := exec.Command(input.Command, input.Args...) - cmd.Env = env - cmd.Stdin = os.Stdin - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - signal.Notify(input.Signals, os.Interrupt, os.Kill) + cmd := exec.Command(input.Command, input.Args...) + cmd.Env = env + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + signal.Notify(input.Signals, os.Interrupt, os.Kill) - if err := cmd.Start(); err != nil { - app.Fatalf("%v", err) - } - // wait for the command to finish - waitCh := make(chan error, 1) - go func() { - waitCh <- cmd.Wait() - close(waitCh) - }() - - for { - select { - case sig := <-input.Signals: - if err = cmd.Process.Signal(sig); err != nil { - app.Errorf("%v", err) - break - } - case err := <-waitCh: - var waitStatus syscall.WaitStatus - if exitError, ok := err.(*exec.ExitError); ok { - waitStatus = exitError.Sys().(syscall.WaitStatus) - os.Exit(waitStatus.ExitStatus()) - } - if err != nil { - app.Fatalf("%v", err) + if err := cmd.Start(); err != nil { + app.Fatalf("%v", err) + } + // wait for the command to finish + waitCh := make(chan error, 1) + go func() { + waitCh <- cmd.Wait() + close(waitCh) + }() + + for { + select { + case sig := <-input.Signals: + if err = cmd.Process.Signal(sig); err != nil { + app.Errorf("%v", err) + break + } + case err := <-waitCh: + var waitStatus syscall.WaitStatus + if exitError, ok := err.(*exec.ExitError); ok { + waitStatus = exitError.Sys().(syscall.WaitStatus) + os.Exit(waitStatus.ExitStatus()) + } + if err != nil { + app.Fatalf("%v", err) + } + return } - return } } }