diff --git a/lib/cloud/mocks/aws_sts.go b/lib/cloud/mocks/aws_sts.go index 178a1259669a4..cf117788e696f 100644 --- a/lib/cloud/mocks/aws_sts.go +++ b/lib/cloud/mocks/aws_sts.go @@ -54,6 +54,12 @@ type STSClient struct { recordFn func(roleARN, externalID string) } +func (m *STSClient) GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { + return &sts.GetCallerIdentityOutput{ + Arn: aws.String(m.ARN), + }, nil +} + func (m *STSClient) AssumeRoleWithWebIdentity(ctx context.Context, in *sts.AssumeRoleWithWebIdentityInput, _ ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { m.record(aws.ToString(in.RoleArn), "") expiry := time.Now().Add(60 * time.Minute) diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go index a5ed6cbeff80d..3eea560f67174 100644 --- a/lib/srv/discovery/discovery_test.go +++ b/lib/srv/discovery/discovery_test.go @@ -37,7 +37,6 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redis/armredis/v3" "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/credentials/stscreds" "github.com/aws/aws-sdk-go-v2/service/ec2" ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/aws/aws-sdk-go-v2/service/eks" @@ -46,8 +45,6 @@ import ( redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/aws/aws-sdk-go-v2/service/ssm" ssmtypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" - "github.com/aws/aws-sdk-go-v2/service/sts" - ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" "github.com/aws/aws-sdk-go/service/rds" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" @@ -1530,7 +1527,10 @@ func TestDiscoveryInCloudKube(t *testing.T) { tlsServer.Auth().SetUsageReporter(reporter) mockedClients := &mockFetchersClients{ - stsClient: &mockSTSClient{}, + AWSConfigProvider: mocks.AWSConfigProvider{ + STSClient: &mocks.STSClient{}, + OIDCIntegrationClient: newFakeAccessPoint(), + }, eksClusters: newPopulatedEKSMock().clusters, } @@ -1586,8 +1586,8 @@ func TestDiscoveryInCloudKube(t *testing.T) { return len(clustersNotUpdated) == 0 && clustersFoundInAuth }, 5*time.Second, 200*time.Millisecond) - require.ElementsMatch(t, tc.expectedAssumedRoles, mockedClients.stsClient.GetAssumedRoleARNs(), "roles incorrectly assumed") - require.ElementsMatch(t, tc.expectedExternalIDs, mockedClients.stsClient.GetAssumedRoleExternalIDs(), "external IDs incorrectly assumed") + require.ElementsMatch(t, tc.expectedAssumedRoles, mockedClients.STSClient.GetAssumedRoleARNs(), "roles incorrectly assumed") + require.ElementsMatch(t, tc.expectedExternalIDs, mockedClients.STSClient.GetAssumedRoleExternalIDs(), "external IDs incorrectly assumed") if tc.wantEvents > 0 { require.Eventually(t, func() bool { @@ -1824,65 +1824,8 @@ func newPopulatedEKSMock() *mockEKSAPI { } } -type mockSTSClient struct { - mu sync.Mutex - - ARN string - - assumedRoleARNs []string - assumedRoleExternalIDs []string - - // NOTE: Not used, but needed to comply with awsconfig.WithSTSClientProvider. - stscreds.AssumeRoleWithWebIdentityAPIClient -} - -func (m *mockSTSClient) GetAssumedRoleARNs() []string { - m.mu.Lock() - defer m.mu.Unlock() - return m.assumedRoleARNs -} - -func (m *mockSTSClient) GetAssumedRoleExternalIDs() []string { - m.mu.Lock() - defer m.mu.Unlock() - return m.assumedRoleExternalIDs -} - -func (m *mockSTSClient) ResetAssumeRoleHistory() { - m.mu.Lock() - defer m.mu.Unlock() - m.assumedRoleARNs = nil - m.assumedRoleExternalIDs = nil -} - -func (m *mockSTSClient) AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { - m.mu.Lock() - defer m.mu.Unlock() - - if !slices.Contains(m.assumedRoleARNs, aws.ToString(params.RoleArn)) { - m.assumedRoleARNs = append(m.assumedRoleARNs, aws.ToString(params.RoleArn)) - m.assumedRoleExternalIDs = append(m.assumedRoleExternalIDs, aws.ToString(params.ExternalId)) - } - expiry := time.Now().Add(60 * time.Minute) - return &sts.AssumeRoleOutput{ - Credentials: &ststypes.Credentials{ - AccessKeyId: params.RoleArn, - SecretAccessKey: aws.String("secret"), - SessionToken: aws.String("token"), - Expiration: &expiry, - }, - }, nil -} - -func (m *mockSTSClient) GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { - return &sts.GetCallerIdentityOutput{ - Arn: aws.String(m.ARN), - }, nil -} - type mockFetchersClients struct { mocks.AWSConfigProvider - stsClient *mockSTSClient eksClusters []*ekstypes.Cluster } @@ -1893,10 +1836,10 @@ func (m *mockFetchersClients) GetAWSEKSClient(aws.Config) fetchers.EKSClient { } func (m *mockFetchersClients) GetAWSSTSClient(aws.Config) fetchers.STSClient { - if m.stsClient != nil { - return m.stsClient + if m.AWSConfigProvider.STSClient != nil { + return m.AWSConfigProvider.STSClient } - return &mockSTSClient{} + return &mocks.STSClient{} } func (m *mockFetchersClients) GetAWSSTSPresignClient(aws.Config) fetchers.STSPresignClient { diff --git a/lib/srv/discovery/kube_integration_watcher_test.go b/lib/srv/discovery/kube_integration_watcher_test.go index cc312bd0b817b..adb82ad29ecbe 100644 --- a/lib/srv/discovery/kube_integration_watcher_test.go +++ b/lib/srv/discovery/kube_integration_watcher_test.go @@ -44,6 +44,7 @@ import ( "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/authz" + "github.com/gravitational/teleport/lib/cloud/mocks" "github.com/gravitational/teleport/lib/integrations/awsoidc" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/discovery/common" @@ -138,6 +139,37 @@ func TestDiscoveryKubeIntegrationEKS(t *testing.T) { testCAData = "VGVzdENBREFUQQ==" ) + // Create and start test auth server. + testAuthServer, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{ + Dir: t.TempDir(), + }) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, testAuthServer.Close()) }) + + awsOIDCIntegration, err := types.NewIntegrationAWSOIDC(types.Metadata{ + Name: "integration1", + }, &types.AWSOIDCIntegrationSpecV1{ + RoleARN: roleArn, + }) + require.NoError(t, err) + testAuthServer.AuthServer.IntegrationsTokenGenerator = &mockIntegrationsTokenGenerator{ + proxies: nil, + integrations: map[string]types.Integration{ + awsOIDCIntegration.GetName(): awsOIDCIntegration, + }, + } + + ctx := context.Background() + tlsServer, err := testAuthServer.NewTestTLSServer() + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, tlsServer.Close()) }) + _, err = tlsServer.Auth().CreateIntegration(ctx, awsOIDCIntegration) + require.NoError(t, err) + + fakeConfigProvider := mocks.AWSConfigProvider{ + OIDCIntegrationClient: tlsServer.Auth(), + } + testEKSClusters := []ekstypes.Cluster{ { Name: aws.String("eks-cluster1"), @@ -364,7 +396,8 @@ func TestDiscoveryKubeIntegrationEKS(t *testing.T) { authz.ContextWithUser(ctx, identity.I), &Config{ AWSFetchersClients: &mockFetchersClients{ - eksClusters: eksMockClusters[:2], + AWSConfigProvider: fakeConfigProvider, + eksClusters: eksMockClusters[:2], }, ClusterFeatures: func() proto.Features { return proto.Features{} }, KubernetesClient: fake.NewSimpleClientset(),