diff --git a/pkg/describer/worker.go b/pkg/describer/worker.go index fa1f7056..884a5383 100755 --- a/pkg/describer/worker.go +++ b/pkg/describer/worker.go @@ -96,7 +96,9 @@ func doDescribe( if err != nil { return nil, fmt.Errorf(" account credentials: %w", err) } - creds.CrossAccountRoleName = job.IntegrationLabels["CrossAccountRoleARN"] + crossAccountName := strings.Split(job.IntegrationLabels["CrossAccountRoleARN"], "/") + creds.CrossAccountRoleName = crossAccountName[1] + creds.AccountID = job.ProviderID logger.Info("Creds", zap.Any("creds", creds)) diff --git a/provider/configs/credentials.go b/provider/configs/credentials.go index 7ff03dd9..c7fc7bd7 100644 --- a/provider/configs/credentials.go +++ b/provider/configs/credentials.go @@ -4,4 +4,5 @@ import "github.com/opengovern/opengovernance/services/integration/integration-ty type IntegrationCredentials struct { configs.IntegrationCredentials + AccountID string `json:"account_id"` } diff --git a/provider/credentials.go b/provider/credentials.go index 4fd3fee5..fb0df7aa 100755 --- a/provider/credentials.go +++ b/provider/credentials.go @@ -4,8 +4,8 @@ import ( "encoding/json" "fmt" "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/sts" awsmodel "github.com/opengovern/og-aws-describer/aws/model" model "github.com/opengovern/og-describer-aws/pkg/sdk/models" "github.com/opengovern/og-describer-aws/provider/configs" @@ -14,25 +14,60 @@ import ( "strings" ) -// GenerateAWSConfig creates an AWS configuration using the provided credentials provider. -func GenerateAWSConfig(credsProvider aws.CredentialsProvider) (aws.Config, error) { - cfg, err := config.LoadDefaultConfig(context.TODO(), - config.WithCredentialsProvider(credsProvider), - ) - if err != nil { - return aws.Config{}, fmt.Errorf("failed to load configuration: %v", err) +// GenerateAWSConfig creates an AWS configuration using the provided credentials. +// It can assume a role if roleNameToAssume is provided. +func GenerateAWSConfig(awsAccessKeyID string, awsSecretAccessKey string, roleNameToAssume string, externalID string, accountID string) (aws.Config, error) { + // Step 1: Create base credentials provider + credsProvider := aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider( + awsAccessKeyID, + awsSecretAccessKey, + "", + )) + + // Step 2: Manually create the AWS Config struct with explicit credentials and region + cfg := aws.Config{ + Region: "us-east-2", + Credentials: credsProvider, + } + + // Step 3: If a role is specified to assume, perform the AssumeRole operation + if roleNameToAssume != "" { + // Construct Role ARN + roleArn := fmt.Sprintf("arn:aws:iam::%s:role/%s", accountID, roleNameToAssume) + + // Use STS client with the current config + stsClient := sts.NewFromConfig(cfg) + + // Prepare AssumeRole input + input := &sts.AssumeRoleInput{ + RoleArn: aws.String(roleArn), + RoleSessionName: aws.String("GenerateAWSConfigSession"), + } + if externalID != "" { + input.ExternalId = aws.String(externalID) + } + + // Perform AssumeRole + assumeRoleOutput, err := stsClient.AssumeRole(context.TODO(), input) + if err != nil { + return aws.Config{}, fmt.Errorf("failed to assume role: %v", err) + } + + // Update credentials provider with assumed role credentials + credsProvider = aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider( + *assumeRoleOutput.Credentials.AccessKeyId, + *assumeRoleOutput.Credentials.SecretAccessKey, + *assumeRoleOutput.Credentials.SessionToken, + )) + + // Update the AWS Config with the new credentials + cfg.Credentials = credsProvider } return cfg, nil } func GetConfig(config configs.IntegrationCredentials) (*aws.Config, error) { - credsProvider := aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider( - config.AwsAccessKeyID, - config.AwsSecretAccessKey, - "", - )) - - cfg, err := GenerateAWSConfig(credsProvider) + cfg, err := GenerateAWSConfig(config.AwsAccessKeyID, config.AwsSecretAccessKey, config.CrossAccountRoleName, config.ExternalID, config.AccountID) if err != nil { return nil, err }