Skip to content

Commit

Permalink
create-aws-client-with-region
Browse files Browse the repository at this point in the history
  • Loading branch information
calvix committed Mar 13, 2024
1 parent e7c9629 commit a617ca1
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 9 deletions.
6 changes: 4 additions & 2 deletions pkg/cloud/identity/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,11 @@ func GetAssumeRoleCredentials(roleIdentityProvider *AWSRolePrincipalTypeProvider
}

// NewAWSRolePrincipalTypeProvider will create a new AWSRolePrincipalTypeProvider from an AWSClusterRoleIdentity.
func NewAWSRolePrincipalTypeProvider(identity *infrav1.AWSClusterRoleIdentity, sourceProvider AWSPrincipalTypeProvider, log logger.Wrapper) *AWSRolePrincipalTypeProvider {
func NewAWSRolePrincipalTypeProvider(identity *infrav1.AWSClusterRoleIdentity, sourceProvider AWSPrincipalTypeProvider, region string, log logger.Wrapper) *AWSRolePrincipalTypeProvider {
return &AWSRolePrincipalTypeProvider{
credentials: nil,
stsClient: nil,
region: region,
Principal: identity,
sourceProvider: sourceProvider,
log: log.WithName("AWSRolePrincipalTypeProvider"),
Expand Down Expand Up @@ -130,6 +131,7 @@ func (p *AWSStaticPrincipalTypeProvider) IsExpired() bool {
type AWSRolePrincipalTypeProvider struct {
Principal *infrav1.AWSClusterRoleIdentity
credentials *credentials.Credentials
region string
sourceProvider AWSPrincipalTypeProvider
log logger.Wrapper
stsClient stsiface.STSAPI
Expand All @@ -154,7 +156,7 @@ func (p *AWSRolePrincipalTypeProvider) Name() string {
// Retrieve returns the credential values for the AWSRolePrincipalTypeProvider.
func (p *AWSRolePrincipalTypeProvider) Retrieve() (credentials.Value, error) {
if p.credentials == nil || p.IsExpired() {
awsConfig := aws.NewConfig()
awsConfig := aws.NewConfig().WithRegion(p.region)
if p.sourceProvider != nil {
sourceCreds, err := p.sourceProvider.Retrieve()
if err != nil {
Expand Down
2 changes: 2 additions & 0 deletions pkg/cloud/identity/identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ func TestAWSStaticPrincipalTypeProvider(t *testing.T) {
roleProvider := &AWSRolePrincipalTypeProvider{
credentials: nil,
Principal: roleIdentity,
region: "us-west-2",
sourceProvider: staticProvider,
stsClient: stsMock,
}
Expand All @@ -78,6 +79,7 @@ func TestAWSStaticPrincipalTypeProvider(t *testing.T) {
roleProvider2 := &AWSRolePrincipalTypeProvider{
credentials: nil,
Principal: roleIdentity2,
region: "us-west-2",
sourceProvider: roleProvider,
stsClient: stsMock,
}
Expand Down
11 changes: 6 additions & 5 deletions pkg/cloud/scope/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func sessionForClusterWithRegion(k8sClient client.Client, clusterScoper cloud.Se
return endpoints.DefaultResolver().EndpointFor(service, region, optFns...)
}

providers, err := getProvidersForCluster(context.Background(), k8sClient, clusterScoper, log)
providers, err := getProvidersForCluster(context.Background(), k8sClient, clusterScoper, region, log)
if err != nil {
// could not get providers and retrieve the credentials
conditions.MarkFalse(clusterScoper.InfraCluster(), infrav1.PrincipalCredentialRetrievedCondition, infrav1.PrincipalCredentialRetrievalFailedReason, clusterv1.ConditionSeverityError, err.Error())
Expand Down Expand Up @@ -256,6 +256,7 @@ func buildProvidersForRef(
k8sClient client.Client,
clusterScoper cloud.SessionMetadata,
ref *infrav1.AWSIdentityReference,
region string,
log logger.Wrapper) ([]identity.AWSPrincipalTypeProvider, error) {
if ref == nil {
log.Trace("AWSCluster does not have a IdentityRef specified")
Expand Down Expand Up @@ -299,7 +300,7 @@ func buildProvidersForRef(
setPrincipalUsageAllowedCondition(clusterScoper)

if roleIdentity.Spec.SourceIdentityRef != nil {
providers, err = buildProvidersForRef(ctx, providers, k8sClient, clusterScoper, roleIdentity.Spec.SourceIdentityRef, log)
providers, err = buildProvidersForRef(ctx, providers, k8sClient, clusterScoper, roleIdentity.Spec.SourceIdentityRef, region, log)
if err != nil {
return providers, err
}
Expand All @@ -313,7 +314,7 @@ func buildProvidersForRef(
}
}

provider = identity.NewAWSRolePrincipalTypeProvider(roleIdentity, sourceProvider, log)
provider = identity.NewAWSRolePrincipalTypeProvider(roleIdentity, sourceProvider, region, log)
providers = append(providers, provider)
default:
return providers, errors.Errorf("No such provider known: '%s'", ref.Kind)
Expand Down Expand Up @@ -404,9 +405,9 @@ func buildAWSClusterControllerIdentity(ctx context.Context, identityObjectKey cl
return nil
}

func getProvidersForCluster(ctx context.Context, k8sClient client.Client, clusterScoper cloud.SessionMetadata, log logger.Wrapper) ([]identity.AWSPrincipalTypeProvider, error) {
func getProvidersForCluster(ctx context.Context, k8sClient client.Client, clusterScoper cloud.SessionMetadata, region string, log logger.Wrapper) ([]identity.AWSPrincipalTypeProvider, error) {
providers := make([]identity.AWSPrincipalTypeProvider, 0)
providers, err := buildProvidersForRef(ctx, providers, k8sClient, clusterScoper, clusterScoper.IdentityRef(), log)
providers, err := buildProvidersForRef(ctx, providers, k8sClient, clusterScoper, clusterScoper.IdentityRef(), region, log)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/cloud/scope/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ func TestPrincipalParsing(t *testing.T) {
Namespace: "default",
},
},
AWSCluster: &infrav1.AWSCluster{},
AWSCluster: &infrav1.AWSCluster{Spec: infrav1.AWSClusterSpec{Region: "us-west-2"}},
},
)

Expand Down Expand Up @@ -489,7 +489,7 @@ func TestPrincipalParsing(t *testing.T) {
k8sClient := fake.NewClientBuilder().WithScheme(scheme).Build()
tc.setup(t, k8sClient)
clusterScope.AWSCluster = &tc.awsCluster
providers, err := getProvidersForCluster(context.Background(), k8sClient, clusterScope, logger.NewLogger(klog.Background()))
providers, err := getProvidersForCluster(context.Background(), k8sClient, clusterScope, clusterScope.Region(), logger.NewLogger(klog.Background()))
if tc.expectError {
if err == nil {
t.Fatal("Expected an error but didn't get one")
Expand Down

0 comments on commit a617ca1

Please sign in to comment.