From a617ca15cc737b022f8b528d6aefc737b419fdb3 Mon Sep 17 00:00:00 2001 From: calvix Date: Wed, 13 Mar 2024 08:44:49 +0100 Subject: [PATCH] create-aws-client-with-region --- pkg/cloud/identity/identity.go | 6 ++++-- pkg/cloud/identity/identity_test.go | 2 ++ pkg/cloud/scope/session.go | 11 ++++++----- pkg/cloud/scope/session_test.go | 4 ++-- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/pkg/cloud/identity/identity.go b/pkg/cloud/identity/identity.go index 29a57a7337..18e77bf293 100644 --- a/pkg/cloud/identity/identity.go +++ b/pkg/cloud/identity/identity.go @@ -80,10 +80,11 @@ func GetAssumeRoleCredentials(roleIdentityProvider *AWSRolePrincipalTypeProvider } // NewAWSRolePrincipalTypeProvider will create a new AWSRolePrincipalTypeProvider from an AWSClusterRoleIdentity. -func NewAWSRolePrincipalTypeProvider(identity *infrav1.AWSClusterRoleIdentity, sourceProvider AWSPrincipalTypeProvider, log logger.Wrapper) *AWSRolePrincipalTypeProvider { +func NewAWSRolePrincipalTypeProvider(identity *infrav1.AWSClusterRoleIdentity, sourceProvider AWSPrincipalTypeProvider, region string, log logger.Wrapper) *AWSRolePrincipalTypeProvider { return &AWSRolePrincipalTypeProvider{ credentials: nil, stsClient: nil, + region: region, Principal: identity, sourceProvider: sourceProvider, log: log.WithName("AWSRolePrincipalTypeProvider"), @@ -130,6 +131,7 @@ func (p *AWSStaticPrincipalTypeProvider) IsExpired() bool { type AWSRolePrincipalTypeProvider struct { Principal *infrav1.AWSClusterRoleIdentity credentials *credentials.Credentials + region string sourceProvider AWSPrincipalTypeProvider log logger.Wrapper stsClient stsiface.STSAPI @@ -154,7 +156,7 @@ func (p *AWSRolePrincipalTypeProvider) Name() string { // Retrieve returns the credential values for the AWSRolePrincipalTypeProvider. func (p *AWSRolePrincipalTypeProvider) Retrieve() (credentials.Value, error) { if p.credentials == nil || p.IsExpired() { - awsConfig := aws.NewConfig() + awsConfig := aws.NewConfig().WithRegion(p.region) if p.sourceProvider != nil { sourceCreds, err := p.sourceProvider.Retrieve() if err != nil { diff --git a/pkg/cloud/identity/identity_test.go b/pkg/cloud/identity/identity_test.go index 8c204be9f4..9f4a995ab8 100644 --- a/pkg/cloud/identity/identity_test.go +++ b/pkg/cloud/identity/identity_test.go @@ -61,6 +61,7 @@ func TestAWSStaticPrincipalTypeProvider(t *testing.T) { roleProvider := &AWSRolePrincipalTypeProvider{ credentials: nil, Principal: roleIdentity, + region: "us-west-2", sourceProvider: staticProvider, stsClient: stsMock, } @@ -78,6 +79,7 @@ func TestAWSStaticPrincipalTypeProvider(t *testing.T) { roleProvider2 := &AWSRolePrincipalTypeProvider{ credentials: nil, Principal: roleIdentity2, + region: "us-west-2", sourceProvider: roleProvider, stsClient: stsMock, } diff --git a/pkg/cloud/scope/session.go b/pkg/cloud/scope/session.go index 95f5e68662..546e11089b 100644 --- a/pkg/cloud/scope/session.go +++ b/pkg/cloud/scope/session.go @@ -120,7 +120,7 @@ func sessionForClusterWithRegion(k8sClient client.Client, clusterScoper cloud.Se return endpoints.DefaultResolver().EndpointFor(service, region, optFns...) } - providers, err := getProvidersForCluster(context.Background(), k8sClient, clusterScoper, log) + providers, err := getProvidersForCluster(context.Background(), k8sClient, clusterScoper, region, log) if err != nil { // could not get providers and retrieve the credentials conditions.MarkFalse(clusterScoper.InfraCluster(), infrav1.PrincipalCredentialRetrievedCondition, infrav1.PrincipalCredentialRetrievalFailedReason, clusterv1.ConditionSeverityError, err.Error()) @@ -256,6 +256,7 @@ func buildProvidersForRef( k8sClient client.Client, clusterScoper cloud.SessionMetadata, ref *infrav1.AWSIdentityReference, + region string, log logger.Wrapper) ([]identity.AWSPrincipalTypeProvider, error) { if ref == nil { log.Trace("AWSCluster does not have a IdentityRef specified") @@ -299,7 +300,7 @@ func buildProvidersForRef( setPrincipalUsageAllowedCondition(clusterScoper) if roleIdentity.Spec.SourceIdentityRef != nil { - providers, err = buildProvidersForRef(ctx, providers, k8sClient, clusterScoper, roleIdentity.Spec.SourceIdentityRef, log) + providers, err = buildProvidersForRef(ctx, providers, k8sClient, clusterScoper, roleIdentity.Spec.SourceIdentityRef, region, log) if err != nil { return providers, err } @@ -313,7 +314,7 @@ func buildProvidersForRef( } } - provider = identity.NewAWSRolePrincipalTypeProvider(roleIdentity, sourceProvider, log) + provider = identity.NewAWSRolePrincipalTypeProvider(roleIdentity, sourceProvider, region, log) providers = append(providers, provider) default: return providers, errors.Errorf("No such provider known: '%s'", ref.Kind) @@ -404,9 +405,9 @@ func buildAWSClusterControllerIdentity(ctx context.Context, identityObjectKey cl return nil } -func getProvidersForCluster(ctx context.Context, k8sClient client.Client, clusterScoper cloud.SessionMetadata, log logger.Wrapper) ([]identity.AWSPrincipalTypeProvider, error) { +func getProvidersForCluster(ctx context.Context, k8sClient client.Client, clusterScoper cloud.SessionMetadata, region string, log logger.Wrapper) ([]identity.AWSPrincipalTypeProvider, error) { providers := make([]identity.AWSPrincipalTypeProvider, 0) - providers, err := buildProvidersForRef(ctx, providers, k8sClient, clusterScoper, clusterScoper.IdentityRef(), log) + providers, err := buildProvidersForRef(ctx, providers, k8sClient, clusterScoper, clusterScoper.IdentityRef(), region, log) if err != nil { return nil, err } diff --git a/pkg/cloud/scope/session_test.go b/pkg/cloud/scope/session_test.go index 13bffa1a9e..9620d23df1 100644 --- a/pkg/cloud/scope/session_test.go +++ b/pkg/cloud/scope/session_test.go @@ -228,7 +228,7 @@ func TestPrincipalParsing(t *testing.T) { Namespace: "default", }, }, - AWSCluster: &infrav1.AWSCluster{}, + AWSCluster: &infrav1.AWSCluster{Spec: infrav1.AWSClusterSpec{Region: "us-west-2"}}, }, ) @@ -489,7 +489,7 @@ func TestPrincipalParsing(t *testing.T) { k8sClient := fake.NewClientBuilder().WithScheme(scheme).Build() tc.setup(t, k8sClient) clusterScope.AWSCluster = &tc.awsCluster - providers, err := getProvidersForCluster(context.Background(), k8sClient, clusterScope, logger.NewLogger(klog.Background())) + providers, err := getProvidersForCluster(context.Background(), k8sClient, clusterScope, clusterScope.Region(), logger.NewLogger(klog.Background())) if tc.expectError { if err == nil { t.Fatal("Expected an error but didn't get one")