diff --git a/main.go b/main.go index 5847cbb..4407ffc 100644 --- a/main.go +++ b/main.go @@ -71,6 +71,7 @@ func main() { // Load credentials from configFilePath if it exists, else use regular AWS config var creds *credentials.Value + var region *string var err error if roleArnRe.MatchString(role) { creds, err = assumeRole(role, "", *duration) @@ -88,7 +89,7 @@ func main() { creds, err = assumeRole(roleConfig.Role, roleConfig.MFA, *duration) } else { - creds, err = assumeProfile(role) + creds, region, err = assumeProfile(role) } must(err) @@ -96,11 +97,11 @@ func main() { if len(args) == 0 { switch *format { case "powershell": - printPowerShellCredentials(role, creds) + printPowerShellCredentials(role, creds, region) case "bash": - printCredentials(role, creds) + printCredentials(role, creds, region) case "fish": - printFishCredentials(role, creds) + printFishCredentials(role, creds, region) default: flag.Usage() os.Exit(1) @@ -108,11 +109,11 @@ func main() { return } - err = execWithCredentials(args, creds) + err = execWithCredentials(args, creds, region) must(err) } -func execWithCredentials(argv []string, creds *credentials.Value) error { +func execWithCredentials(argv []string, creds *credentials.Value, region *string) error { argv0, err := exec.LookPath(argv[0]) if err != nil { return err @@ -122,6 +123,9 @@ func execWithCredentials(argv []string, creds *credentials.Value) error { os.Setenv("AWS_SECRET_ACCESS_KEY", creds.SecretAccessKey) os.Setenv("AWS_SESSION_TOKEN", creds.SessionToken) os.Setenv("AWS_SECURITY_TOKEN", creds.SessionToken) + if region != nil { + os.Setenv("AWS_DEFAULT_REGION", *region) + } env := os.Environ() return syscall.Exec(argv0, argv, env) @@ -129,36 +133,45 @@ func execWithCredentials(argv []string, creds *credentials.Value) error { // printCredentials prints the credentials in a way that can easily be sourced // with bash. -func printCredentials(role string, creds *credentials.Value) { +func printCredentials(role string, creds *credentials.Value, region *string) { fmt.Printf("export AWS_ACCESS_KEY_ID=\"%s\"\n", creds.AccessKeyID) fmt.Printf("export AWS_SECRET_ACCESS_KEY=\"%s\"\n", creds.SecretAccessKey) fmt.Printf("export AWS_SESSION_TOKEN=\"%s\"\n", creds.SessionToken) fmt.Printf("export AWS_SECURITY_TOKEN=\"%s\"\n", creds.SessionToken) fmt.Printf("export ASSUMED_ROLE=\"%s\"\n", role) + if region != nil { + fmt.Printf("export AWS_DEFAULT_REGION=\"%s\"\n", *region) + } fmt.Printf("# Run this to configure your shell:\n") fmt.Printf("# eval $(%s)\n", strings.Join(os.Args, " ")) } // printFishCredentials prints the credentials in a way that can easily be sourced // with fish. -func printFishCredentials(role string, creds *credentials.Value) { +func printFishCredentials(role string, creds *credentials.Value, region *string) { fmt.Printf("set -gx AWS_ACCESS_KEY_ID \"%s\";\n", creds.AccessKeyID) fmt.Printf("set -gx AWS_SECRET_ACCESS_KEY \"%s\";\n", creds.SecretAccessKey) fmt.Printf("set -gx AWS_SESSION_TOKEN \"%s\";\n", creds.SessionToken) fmt.Printf("set -gx AWS_SECURITY_TOKEN \"%s\";\n", creds.SessionToken) fmt.Printf("set -gx ASSUMED_ROLE \"%s\";\n", role) + if region != nil { + fmt.Printf("set -gx AWS_DEFAULT_REGION \"%s\";\n", *region) + } fmt.Printf("# Run this to configure your shell:\n") fmt.Printf("# eval (%s)\n", strings.Join(os.Args, " ")) } // printPowerShellCredentials prints the credentials in a way that can easily be sourced // with Windows powershell using Invoke-Expression. -func printPowerShellCredentials(role string, creds *credentials.Value) { +func printPowerShellCredentials(role string, creds *credentials.Value, region *string) { fmt.Printf("$env:AWS_ACCESS_KEY_ID=\"%s\"\n", creds.AccessKeyID) fmt.Printf("$env:AWS_SECRET_ACCESS_KEY=\"%s\"\n", creds.SecretAccessKey) fmt.Printf("$env:AWS_SESSION_TOKEN=\"%s\"\n", creds.SessionToken) fmt.Printf("$env:AWS_SECURITY_TOKEN=\"%s\"\n", creds.SessionToken) fmt.Printf("$env:ASSUMED_ROLE=\"%s\"\n", role) + if region != nil { + fmt.Printf("$env:AWS_DEFAULT_REGION=\"%s\"\n", *region) + } fmt.Printf("# Run this to configure your shell:\n") fmt.Printf("# %s | Invoke-Expression \n", strings.Join(os.Args, " ")) } @@ -166,18 +179,19 @@ func printPowerShellCredentials(role string, creds *credentials.Value) { // assumeProfile assumes the named profile which must exist in ~/.aws/config // (https://docs.aws.amazon.com/cli/latest/userguide/cli-roles.html) and returns the temporary STS // credentials. -func assumeProfile(profile string) (*credentials.Value, error) { +func assumeProfile(profile string) (*credentials.Value, *string, error) { sess := session.Must(session.NewSessionWithOptions(session.Options{ Profile: profile, SharedConfigState: session.SharedConfigEnable, AssumeRoleTokenProvider: readTokenCode, })) + region := sess.Config.Region creds, err := sess.Config.Credentials.Get() if err != nil { - return nil, err + return nil, nil, err } - return &creds, nil + return &creds, region, nil } // assumeRole assumes the given role and returns the temporary STS credentials.