Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
creack committed Jan 9, 2025
1 parent c9d05c6 commit 5586bc0
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 67 deletions.
6 changes: 6 additions & 0 deletions lib/cloud/mocks/aws_sts.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ type STSClient struct {
recordFn func(roleARN, externalID string)
}

func (m *STSClient) GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) {
return &sts.GetCallerIdentityOutput{
Arn: aws.String(m.ARN),
}, nil
}

func (m *STSClient) AssumeRoleWithWebIdentity(ctx context.Context, in *sts.AssumeRoleWithWebIdentityInput, _ ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) {
m.record(aws.ToString(in.RoleArn), "")
expiry := time.Now().Add(60 * time.Minute)
Expand Down
75 changes: 9 additions & 66 deletions lib/srv/discovery/discovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redis/armredis/v3"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/service/ec2"
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/aws/aws-sdk-go-v2/service/eks"
Expand All @@ -46,8 +45,6 @@ import (
redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types"
"github.com/aws/aws-sdk-go-v2/service/ssm"
ssmtypes "github.com/aws/aws-sdk-go-v2/service/ssm/types"
"github.com/aws/aws-sdk-go-v2/service/sts"
ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types"
"github.com/aws/aws-sdk-go/service/rds"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
Expand Down Expand Up @@ -1530,7 +1527,10 @@ func TestDiscoveryInCloudKube(t *testing.T) {
tlsServer.Auth().SetUsageReporter(reporter)

mockedClients := &mockFetchersClients{
stsClient: &mockSTSClient{},
AWSConfigProvider: mocks.AWSConfigProvider{
STSClient: &mocks.STSClient{},
OIDCIntegrationClient: newFakeAccessPoint(),
},
eksClusters: newPopulatedEKSMock().clusters,
}

Expand Down Expand Up @@ -1586,8 +1586,8 @@ func TestDiscoveryInCloudKube(t *testing.T) {
return len(clustersNotUpdated) == 0 && clustersFoundInAuth
}, 5*time.Second, 200*time.Millisecond)

require.ElementsMatch(t, tc.expectedAssumedRoles, mockedClients.stsClient.GetAssumedRoleARNs(), "roles incorrectly assumed")
require.ElementsMatch(t, tc.expectedExternalIDs, mockedClients.stsClient.GetAssumedRoleExternalIDs(), "external IDs incorrectly assumed")
require.ElementsMatch(t, tc.expectedAssumedRoles, mockedClients.STSClient.GetAssumedRoleARNs(), "roles incorrectly assumed")
require.ElementsMatch(t, tc.expectedExternalIDs, mockedClients.STSClient.GetAssumedRoleExternalIDs(), "external IDs incorrectly assumed")

if tc.wantEvents > 0 {
require.Eventually(t, func() bool {
Expand Down Expand Up @@ -1824,65 +1824,8 @@ func newPopulatedEKSMock() *mockEKSAPI {
}
}

type mockSTSClient struct {
mu sync.Mutex

ARN string

assumedRoleARNs []string
assumedRoleExternalIDs []string

// NOTE: Not used, but needed to comply with awsconfig.WithSTSClientProvider.
stscreds.AssumeRoleWithWebIdentityAPIClient
}

func (m *mockSTSClient) GetAssumedRoleARNs() []string {
m.mu.Lock()
defer m.mu.Unlock()
return m.assumedRoleARNs
}

func (m *mockSTSClient) GetAssumedRoleExternalIDs() []string {
m.mu.Lock()
defer m.mu.Unlock()
return m.assumedRoleExternalIDs
}

func (m *mockSTSClient) ResetAssumeRoleHistory() {
m.mu.Lock()
defer m.mu.Unlock()
m.assumedRoleARNs = nil
m.assumedRoleExternalIDs = nil
}

func (m *mockSTSClient) AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) {
m.mu.Lock()
defer m.mu.Unlock()

if !slices.Contains(m.assumedRoleARNs, aws.ToString(params.RoleArn)) {
m.assumedRoleARNs = append(m.assumedRoleARNs, aws.ToString(params.RoleArn))
m.assumedRoleExternalIDs = append(m.assumedRoleExternalIDs, aws.ToString(params.ExternalId))
}
expiry := time.Now().Add(60 * time.Minute)
return &sts.AssumeRoleOutput{
Credentials: &ststypes.Credentials{
AccessKeyId: params.RoleArn,
SecretAccessKey: aws.String("secret"),
SessionToken: aws.String("token"),
Expiration: &expiry,
},
}, nil
}

func (m *mockSTSClient) GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) {
return &sts.GetCallerIdentityOutput{
Arn: aws.String(m.ARN),
}, nil
}

type mockFetchersClients struct {
mocks.AWSConfigProvider
stsClient *mockSTSClient
eksClusters []*ekstypes.Cluster
}

Expand All @@ -1893,10 +1836,10 @@ func (m *mockFetchersClients) GetAWSEKSClient(aws.Config) fetchers.EKSClient {
}

func (m *mockFetchersClients) GetAWSSTSClient(aws.Config) fetchers.STSClient {
if m.stsClient != nil {
return m.stsClient
if m.AWSConfigProvider.STSClient != nil {
return m.AWSConfigProvider.STSClient
}
return &mockSTSClient{}
return &mocks.STSClient{}
}

func (m *mockFetchersClients) GetAWSSTSPresignClient(aws.Config) fetchers.STSPresignClient {
Expand Down
35 changes: 34 additions & 1 deletion lib/srv/discovery/kube_integration_watcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import (
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/cloud/mocks"
"github.com/gravitational/teleport/lib/integrations/awsoidc"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv/discovery/common"
Expand Down Expand Up @@ -138,6 +139,37 @@ func TestDiscoveryKubeIntegrationEKS(t *testing.T) {
testCAData = "VGVzdENBREFUQQ=="
)

// Create and start test auth server.
testAuthServer, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{
Dir: t.TempDir(),
})
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, testAuthServer.Close()) })

awsOIDCIntegration, err := types.NewIntegrationAWSOIDC(types.Metadata{
Name: "integration1",
}, &types.AWSOIDCIntegrationSpecV1{
RoleARN: roleArn,
})
require.NoError(t, err)
testAuthServer.AuthServer.IntegrationsTokenGenerator = &mockIntegrationsTokenGenerator{
proxies: nil,
integrations: map[string]types.Integration{
awsOIDCIntegration.GetName(): awsOIDCIntegration,
},
}

ctx := context.Background()
tlsServer, err := testAuthServer.NewTestTLSServer()
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, tlsServer.Close()) })
_, err = tlsServer.Auth().CreateIntegration(ctx, awsOIDCIntegration)
require.NoError(t, err)

fakeConfigProvider := mocks.AWSConfigProvider{
OIDCIntegrationClient: tlsServer.Auth(),
}

testEKSClusters := []ekstypes.Cluster{
{
Name: aws.String("eks-cluster1"),
Expand Down Expand Up @@ -364,7 +396,8 @@ func TestDiscoveryKubeIntegrationEKS(t *testing.T) {
authz.ContextWithUser(ctx, identity.I),
&Config{
AWSFetchersClients: &mockFetchersClients{
eksClusters: eksMockClusters[:2],
AWSConfigProvider: fakeConfigProvider,
eksClusters: eksMockClusters[:2],
},
ClusterFeatures: func() proto.Features { return proto.Features{} },
KubernetesClient: fake.NewSimpleClientset(),
Expand Down

0 comments on commit 5586bc0

Please sign in to comment.