Skip to content

Commit

Permalink
Merge pull request 99designs#264 from AlexRudd/support-external-id
Browse files Browse the repository at this point in the history
Read external_id into Profile and pass to assume role requests
  • Loading branch information
lox authored Aug 1, 2018
2 parents b496dcc + 60636d8 commit ef7dc21
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
1 change: 1 addition & 0 deletions vault/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ type Profile struct {
Name string `ini:"-"`
MFASerial string `ini:"mfa_serial,omitempty"`
RoleARN string `ini:"role_arn,omitempty"`
ExternalID string `ini:"external_id,omitempty"`
Region string `ini:"region,omitempty"`
SourceProfile string `ini:"source_profile,omitempty"`
RoleSessionName string `ini:"role_session_name,omitempty"`
Expand Down
26 changes: 17 additions & 9 deletions vault/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ func (p *VaultProvider) Retrieve() (credentials.Value, error) {
session.Expiration.Sub(time.Now()).String())

if profile, exists := p.config.Profile(p.profile); exists && profile.RoleARN != "" {
session, err = p.assumeRoleFromSession(session, profile.RoleARN)
session, err = p.assumeRoleFromSession(session, profile)
if err != nil {
return credentials.Value{}, err
}
Expand Down Expand Up @@ -165,7 +165,7 @@ func (p *VaultProvider) RetrieveWithoutSessionToken() (credentials.Value, error)
}

if profile, exists := p.config.Profile(p.profile); exists && profile.RoleARN != "" {
session, err := p.assumeRole(creds, profile.RoleARN)
session, err := p.assumeRole(creds, profile)
if err != nil {
return credentials.Value{}, err
}
Expand Down Expand Up @@ -283,7 +283,7 @@ func (p *VaultProvider) roleSessionName() string {
}

// assumeRoleFromSession takes a session created with GetSessionToken and uses that to assume a role
func (p *VaultProvider) assumeRoleFromSession(creds sts.Credentials, roleArn string) (sts.Credentials, error) {
func (p *VaultProvider) assumeRoleFromSession(creds sts.Credentials, profile Profile) (sts.Credentials, error) {
client := sts.New(session.New(p.awsConfig().
WithCredentials(credentials.NewStaticCredentials(
*creds.AccessKeyId,
Expand All @@ -292,12 +292,16 @@ func (p *VaultProvider) assumeRoleFromSession(creds sts.Credentials, roleArn str
))))

input := &sts.AssumeRoleInput{
RoleArn: aws.String(roleArn),
RoleArn: aws.String(profile.RoleARN),
RoleSessionName: aws.String(p.roleSessionName()),
DurationSeconds: aws.Int64(int64(p.AssumeRoleDuration.Seconds())),
}

log.Printf("Assuming role %s from session token", roleArn)
if profile.ExternalID != "" {
input.ExternalId = aws.String(profile.ExternalID)
}

log.Printf("Assuming role %s from session token", profile.RoleARN)
resp, err := client.AssumeRole(input)
if err != nil {
return sts.Credentials{}, err
Expand All @@ -307,19 +311,23 @@ func (p *VaultProvider) assumeRoleFromSession(creds sts.Credentials, roleArn str
}

// assumeRole uses IAM credentials to assume a role
func (p *VaultProvider) assumeRole(creds credentials.Value, roleArn string) (sts.Credentials, error) {
func (p *VaultProvider) assumeRole(creds credentials.Value, profile Profile) (sts.Credentials, error) {
client := sts.New(session.New(p.awsConfig().
WithCredentials(credentials.NewCredentials(&credentials.StaticProvider{Value: creds})),
))

input := &sts.AssumeRoleInput{
RoleArn: aws.String(roleArn),
RoleArn: aws.String(profile.RoleARN),
RoleSessionName: aws.String(p.roleSessionName()),
DurationSeconds: aws.Int64(int64(p.AssumeRoleDuration.Seconds())),
}

if profile.ExternalID != "" {
input.ExternalId = aws.String(profile.ExternalID)
}

// if we don't have a session, we need to include MFA token in the AssumeRole call
if profile, _ := p.Config.Profile(p.profile); profile.MFASerial != "" {
if profile.MFASerial != "" {
input.SerialNumber = aws.String(profile.MFASerial)
if p.MfaToken == "" {
token, err := p.MfaPrompt(fmt.Sprintf("Enter token for %s: ", profile.MFASerial))
Expand All @@ -332,7 +340,7 @@ func (p *VaultProvider) assumeRole(creds credentials.Value, roleArn string) (sts
}
}

log.Printf("Assuming role %s with iam credentials", roleArn)
log.Printf("Assuming role %s with iam credentials", profile.RoleARN)
resp, err := client.AssumeRole(input)
if err != nil {
return sts.Credentials{}, err
Expand Down

0 comments on commit ef7dc21

Please sign in to comment.