diff --git a/lib/cloud/awsconfig/awsconfig.go b/lib/cloud/awsconfig/awsconfig.go index 8be00483f4012..7b1cabe5ffe75 100644 --- a/lib/cloud/awsconfig/awsconfig.go +++ b/lib/cloud/awsconfig/awsconfig.go @@ -28,6 +28,7 @@ import ( "github.com/gravitational/trace" "go.opentelemetry.io/otel" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/modules" ) @@ -43,12 +44,25 @@ const ( credentialsSourceIntegration ) -// IntegrationSessionProviderFunc defines a function that creates a credential provider from a region and an integration. -// This is used to generate aws configs for clients that must use an integration instead of ambient credentials. -type IntegrationCredentialProviderFunc func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) +// OIDCIntegrationClient is an interface that indicates which APIs are +// required to generate an AWS OIDC integration token. +type OIDCIntegrationClient interface { + // GetIntegration returns the specified integration resource. + GetIntegration(ctx context.Context, name string) (types.Integration, error) -// AssumeRoleClientProviderFunc provides an AWS STS assume role API client. -type AssumeRoleClientProviderFunc func(aws.Config) stscreds.AssumeRoleAPIClient + // GenerateAWSOIDCToken generates a token to be used to execute an AWS OIDC + // Integration action. + GenerateAWSOIDCToken(ctx context.Context, integrationName string) (string, error) +} + +// STSClient is a subset of the AWS STS API. +type STSClient interface { + stscreds.AssumeRoleAPIClient + stscreds.AssumeRoleWithWebIdentityAPIClient +} + +// STSClientProviderFunc provides an AWS STS assume role API client. +type STSClientProviderFunc func(aws.Config) STSClient // AssumeRole is an AWS role to assume, optionally with an external ID. type AssumeRole struct { @@ -68,14 +82,16 @@ type options struct { credentialsSource credentialsSource // integration is the name of the integration to be used to fetch the credentials. integration string - // integrationCredentialsProvider is the integration credential provider to use. - integrationCredentialsProvider IntegrationCredentialProviderFunc + // oidcIntegrationClient provides APIs to generate AWS OIDC tokens, which + // can then be exchanged for IAM credentials. + // Required if integration credentials are requested. + oidcIntegrationClient OIDCIntegrationClient // customRetryer is a custom retryer to use for the config. customRetryer func() aws.Retryer // maxRetries is the maximum number of retries to use for the config. maxRetries *int - // assumeRoleClientProvider sets the STS assume role client provider func. - assumeRoleClientProvider AssumeRoleClientProviderFunc + // stsClientProvider sets the STS assume role client provider func. + stsClientProvider STSClientProviderFunc } func buildOptions(optFns ...OptionsFn) (*options, error) { @@ -99,6 +115,9 @@ func (o *options) checkAndSetDefaults() error { if o.integration == "" { return trace.BadParameter("missing integration name") } + if o.oidcIntegrationClient == nil { + return trace.BadParameter("missing AWS OIDC integration client") + } default: return trace.BadParameter("missing credentials source (ambient or integration)") } @@ -106,8 +125,8 @@ func (o *options) checkAndSetDefaults() error { return trace.BadParameter("role chain contains more than 2 roles") } - if o.assumeRoleClientProvider == nil { - o.assumeRoleClientProvider = func(cfg aws.Config) stscreds.AssumeRoleAPIClient { + if o.stsClientProvider == nil { + o.stsClientProvider = func(cfg aws.Config) STSClient { return sts.NewFromConfig(cfg, func(o *sts.Options) { o.TracerProvider = smithyoteltracing.Adapt(otel.GetTracerProvider()) }) @@ -175,18 +194,17 @@ func WithAmbientCredentials() OptionsFn { } } -// WithIntegrationCredentialProvider sets the integration credential provider. -func WithIntegrationCredentialProvider(cred IntegrationCredentialProviderFunc) OptionsFn { +// WithSTSClientProvider sets the STS API client factory func. +func WithSTSClientProvider(fn STSClientProviderFunc) OptionsFn { return func(options *options) { - options.integrationCredentialsProvider = cred + options.stsClientProvider = fn } } -// WithAssumeRoleClientProviderFunc sets the STS API client factory func used to -// assume roles. -func WithAssumeRoleClientProviderFunc(fn AssumeRoleClientProviderFunc) OptionsFn { +// WithOIDCIntegrationClient sets the OIDC integration client. +func WithOIDCIntegrationClient(c OIDCIntegrationClient) OptionsFn { return func(options *options) { - options.assumeRoleClientProvider = fn + options.oidcIntegrationClient = c } } @@ -202,7 +220,7 @@ func GetConfig(ctx context.Context, region string, optFns ...OptionsFn) (aws.Con if err != nil { return aws.Config{}, trace.Wrap(err) } - return getConfigForRoleChain(ctx, cfg, opts.assumeRoles, opts.assumeRoleClientProvider) + return getConfigForRoleChain(ctx, cfg, opts.assumeRoles, opts.stsClientProvider) } // loadDefaultConfig loads a new config. @@ -217,6 +235,7 @@ func buildConfigOptions(region string, cred aws.CredentialsProvider, opts *optio config.WithDefaultRegion(defaultRegion), config.WithRegion(region), config.WithCredentialsProvider(cred), + config.WithCredentialsCacheOptions(awsCredentialsCacheOptions), } if modules.GetModules().IsBoringBinary() { configOpts = append(configOpts, config.WithUseFIPSEndpoint(aws.FIPSEndpointStateEnabled)) @@ -232,27 +251,35 @@ func buildConfigOptions(region string, cred aws.CredentialsProvider, opts *optio // getBaseConfig returns an AWS config without assuming any roles. func getBaseConfig(ctx context.Context, region string, opts *options) (aws.Config, error) { - var cred aws.CredentialsProvider + slog.DebugContext(ctx, "Initializing AWS config from default credential chain", + "region", region, + ) + cfg, err := loadDefaultConfig(ctx, region, nil, opts) + if err != nil { + return aws.Config{}, trace.Wrap(err) + } + if opts.credentialsSource == credentialsSourceIntegration { - if opts.integrationCredentialsProvider == nil { - return aws.Config{}, trace.BadParameter("missing aws integration credential provider") + slog.DebugContext(ctx, "Initializing AWS config with OIDC integration credentials", + "region", region, + "integration", opts.integration, + ) + provider := &integrationCredentialsProvider{ + OIDCIntegrationClient: opts.oidcIntegrationClient, + stsClt: opts.stsClientProvider(cfg), + integrationName: opts.integration, } - - slog.DebugContext(ctx, "Initializing AWS config with integration", "region", region, "integration", opts.integration) - var err error - cred, err = opts.integrationCredentialsProvider(ctx, region, opts.integration) + cc := aws.NewCredentialsCache(provider, awsCredentialsCacheOptions) + _, err := cc.Retrieve(ctx) if err != nil { return aws.Config{}, trace.Wrap(err) } - } else { - slog.DebugContext(ctx, "Initializing AWS config from default credential chain", "region", region) + cfg.Credentials = cc } - - cfg, err := loadDefaultConfig(ctx, region, cred, opts) - return cfg, trace.Wrap(err) + return cfg, nil } -func getConfigForRoleChain(ctx context.Context, cfg aws.Config, roles []AssumeRole, newCltFn AssumeRoleClientProviderFunc) (aws.Config, error) { +func getConfigForRoleChain(ctx context.Context, cfg aws.Config, roles []AssumeRole, newCltFn STSClientProviderFunc) (aws.Config, error) { for _, r := range roles { cfg.Credentials = getAssumeRoleProvider(ctx, newCltFn(cfg), r) } @@ -277,3 +304,41 @@ func getAssumeRoleProvider(ctx context.Context, clt stscreds.AssumeRoleAPIClient } }) } + +// staticIdentityToken provides itself as a JWT []byte token to implement +// [stscreds.IdentityTokenRetriever]. +type staticIdentityToken string + +// GetIdentityToken retrieves the JWT token. +func (t staticIdentityToken) GetIdentityToken() ([]byte, error) { + return []byte(t), nil +} + +// integrationCredentialsProvider provides AWS OIDC integration credentials. +type integrationCredentialsProvider struct { + OIDCIntegrationClient + stsClt STSClient + integrationName string +} + +// Retrieve provides [aws.Credentials] for an AWS OIDC integration. +func (p *integrationCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) { + integration, err := p.GetIntegration(ctx, p.integrationName) + if err != nil { + return aws.Credentials{}, trace.Wrap(err) + } + spec := integration.GetAWSOIDCIntegrationSpec() + if spec == nil { + return aws.Credentials{}, trace.BadParameter("invalid integration subkind, expected awsoidc, got %s", integration.GetSubKind()) + } + token, err := p.GenerateAWSOIDCToken(ctx, p.integrationName) + if err != nil { + return aws.Credentials{}, trace.Wrap(err) + } + cred, err := stscreds.NewWebIdentityRoleProvider( + p.stsClt, + spec.RoleARN, + staticIdentityToken(token), + ).Retrieve(ctx) + return cred, trace.Wrap(err) +} diff --git a/lib/cloud/awsconfig/awsconfig_test.go b/lib/cloud/awsconfig/awsconfig_test.go index 3cb2c4eda3123..2de624fe86c54 100644 --- a/lib/cloud/awsconfig/awsconfig_test.go +++ b/lib/cloud/awsconfig/awsconfig_test.go @@ -24,20 +24,13 @@ import ( "time" "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/credentials/stscreds" "github.com/aws/aws-sdk-go-v2/service/sts" ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" "github.com/gravitational/trace" "github.com/stretchr/testify/require" -) - -type mockCredentialProvider struct { - cred aws.Credentials -} -func (m *mockCredentialProvider) Retrieve(_ context.Context) (aws.Credentials, error) { - return m.cred, nil -} + "github.com/gravitational/teleport/api/types" +) type mockAssumeRoleAPIClient struct{} @@ -57,6 +50,18 @@ func (m *mockAssumeRoleAPIClient) AssumeRole(_ context.Context, params *sts.Assu }, nil } +func (m *mockAssumeRoleAPIClient) AssumeRoleWithWebIdentity(ctx context.Context, in *sts.AssumeRoleWithWebIdentityInput, _ ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { + expiry := time.Now().Add(60 * time.Minute) + return &sts.AssumeRoleWithWebIdentityOutput{ + Credentials: &ststypes.Credentials{ + AccessKeyId: in.RoleArn, + SecretAccessKey: in.WebIdentityToken, + SessionToken: aws.String("token"), + Expiration: &expiry, + }, + }, nil +} + func TestGetConfigIntegration(t *testing.T) { t.Parallel() @@ -86,32 +91,100 @@ func testGetConfigIntegration(t *testing.T, provider Provider) { dummyIntegration := "integration-test" dummyRegion := "test-region-123" - t.Run("without an integration credential provider, must return missing credential provider error", func(t *testing.T) { + awsOIDCIntegration, err := types.NewIntegrationAWSOIDC( + types.Metadata{Name: "integration-test"}, + &types.AWSOIDCIntegrationSpecV1{ + RoleARN: "arn:aws:sts::123456789012:role/TestRole", + }, + ) + require.NoError(t, err) + fakeIntegrationClt := fakeOIDCIntegrationClient{ + getIntegrationFn: func(context.Context, string) (types.Integration, error) { + return awsOIDCIntegration, nil + }, + getTokenFn: func(context.Context, string) (string, error) { + return "oidc-token", nil + }, + } + + stsClt := func(cfg aws.Config) STSClient { + return &mockAssumeRoleAPIClient{} + } + + t.Run("without an integration client, must return missing credential provider error", func(t *testing.T) { ctx := context.Background() _, err := provider.GetConfig(ctx, dummyRegion, WithCredentialsMaybeIntegration(dummyIntegration)) require.True(t, trace.IsBadParameter(err), "unexpected error: %v", err) - require.ErrorContains(t, err, "missing aws integration credential provider") + require.ErrorContains(t, err, "missing AWS OIDC integration client") + }) + + t.Run("with an integration client, must return integration fetch error", func(t *testing.T) { + ctx := context.Background() + + fakeIntegrationClt := fakeIntegrationClt + fakeIntegrationClt.getIntegrationFn = func(context.Context, string) (types.Integration, error) { + return nil, trace.NotFound("integration not found") + } + _, err := provider.GetConfig(ctx, dummyRegion, + WithCredentialsMaybeIntegration(dummyIntegration), + WithOIDCIntegrationClient(&fakeIntegrationClt), + WithSTSClientProvider(stsClt), + ) + require.Error(t, err) + require.ErrorContains(t, err, "integration not found") + }) + + t.Run("with an integration client, must check for AWS integration subkind", func(t *testing.T) { + ctx := context.Background() + + azureIntegration, err := types.NewIntegrationAzureOIDC( + types.Metadata{Name: "integration-test"}, + &types.AzureOIDCIntegrationSpecV1{ + TenantID: "abc", + ClientID: "123", + }, + ) + require.NoError(t, err) + fakeIntegrationClt := fakeIntegrationClt + fakeIntegrationClt.getIntegrationFn = func(context.Context, string) (types.Integration, error) { + return azureIntegration, nil + } + _, err = provider.GetConfig(ctx, dummyRegion, + WithCredentialsMaybeIntegration(dummyIntegration), + WithOIDCIntegrationClient(&fakeIntegrationClt), + WithSTSClientProvider(stsClt), + ) + require.Error(t, err) + require.ErrorContains(t, err, "invalid integration subkind") + }) + + t.Run("with an integration client, must return token generation errors", func(t *testing.T) { + ctx := context.Background() + fakeIntegrationClt := fakeIntegrationClt + fakeIntegrationClt.getTokenFn = func(context.Context, string) (string, error) { + return "", trace.BadParameter("failed to generate OIDC token") + } + _, err = provider.GetConfig(ctx, dummyRegion, + WithCredentialsMaybeIntegration(dummyIntegration), + WithOIDCIntegrationClient(&fakeIntegrationClt), + WithSTSClientProvider(stsClt), + ) + require.Error(t, err) + require.ErrorContains(t, err, "failed to generate OIDC token") }) - t.Run("with an integration credential provider, must return the credentials", func(t *testing.T) { + t.Run("with an integration client, must return the credentials", func(t *testing.T) { ctx := context.Background() cfg, err := provider.GetConfig(ctx, dummyRegion, WithCredentialsMaybeIntegration(dummyIntegration), - WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { - if region == dummyRegion && integration == dummyIntegration { - return &mockCredentialProvider{ - cred: aws.Credentials{ - SessionToken: "foo-bar", - }, - }, nil - } - return nil, trace.NotFound("no creds in region %q with integration %q", region, integration) - })) + WithOIDCIntegrationClient(&fakeIntegrationClt), + WithSTSClientProvider(stsClt), + ) require.NoError(t, err) creds, err := cfg.Credentials.Retrieve(ctx) require.NoError(t, err) - require.Equal(t, "foo-bar", creds.SessionToken) + require.Equal(t, "oidc-token", creds.SecretAccessKey) }) t.Run("with an integration credential provider assuming a role, must return assumed role credentials", func(t *testing.T) { @@ -119,23 +192,9 @@ func testGetConfigIntegration(t *testing.T, provider Provider) { cfg, err := provider.GetConfig(ctx, dummyRegion, WithCredentialsMaybeIntegration(dummyIntegration), - WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { - if region == dummyRegion && integration == dummyIntegration { - return &mockCredentialProvider{ - cred: aws.Credentials{ - SessionToken: "foo-bar", - }, - }, nil - } - return nil, trace.NotFound("no creds in region %q with integration %q", region, integration) - }), + WithOIDCIntegrationClient(&fakeIntegrationClt), WithAssumeRole("roleA", "abc123"), - WithAssumeRoleClientProviderFunc(func(cfg aws.Config) stscreds.AssumeRoleAPIClient { - creds, err := cfg.Credentials.Retrieve(context.Background()) - require.NoError(t, err) - require.Equal(t, "foo-bar", creds.SessionToken) - return &mockAssumeRoleAPIClient{} - }), + WithSTSClientProvider(stsClt), ) require.NoError(t, err) creds, err := cfg.Credentials.Retrieve(ctx) @@ -148,25 +207,11 @@ func testGetConfigIntegration(t *testing.T, provider Provider) { ctx := context.Background() _, err := provider.GetConfig(ctx, dummyRegion, WithCredentialsMaybeIntegration(dummyIntegration), - WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { - if region == dummyRegion && integration == dummyIntegration { - return &mockCredentialProvider{ - cred: aws.Credentials{ - SessionToken: "foo-bar", - }, - }, nil - } - return nil, trace.NotFound("no creds in region %q with integration %q", region, integration) - }), + WithOIDCIntegrationClient(&fakeIntegrationClt), WithAssumeRole("roleA", "abc123"), WithAssumeRole("roleB", "abc123"), WithAssumeRole("roleC", "abc123"), - WithAssumeRoleClientProviderFunc(func(cfg aws.Config) stscreds.AssumeRoleAPIClient { - creds, err := cfg.Credentials.Retrieve(context.Background()) - require.NoError(t, err) - require.Equal(t, "foo-bar", creds.SessionToken) - return &mockAssumeRoleAPIClient{} - }), + WithSTSClientProvider(stsClt), ) require.Error(t, err) require.ErrorContains(t, err, "role chain contains more than 2 roles") @@ -177,10 +222,8 @@ func testGetConfigIntegration(t *testing.T, provider Provider) { _, err := provider.GetConfig(ctx, dummyRegion, WithCredentialsMaybeIntegration(""), - WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { - require.Fail(t, "this function should not be called") - return nil, nil - })) + WithOIDCIntegrationClient(&fakeOIDCIntegrationClient{unauth: true}), + ) require.NoError(t, err) }) @@ -189,10 +232,8 @@ func testGetConfigIntegration(t *testing.T, provider Provider) { _, err := provider.GetConfig(ctx, dummyRegion, WithAmbientCredentials(), - WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { - require.Fail(t, "this function should not be called") - return nil, nil - })) + WithOIDCIntegrationClient(&fakeOIDCIntegrationClient{unauth: true}), + ) require.NoError(t, err) }) @@ -200,10 +241,8 @@ func testGetConfigIntegration(t *testing.T, provider Provider) { ctx := context.Background() _, err := provider.GetConfig(ctx, dummyRegion, - WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { - require.Fail(t, "this function should not be called") - return nil, nil - })) + WithOIDCIntegrationClient(&fakeOIDCIntegrationClient{unauth: true}), + ) require.Error(t, err) require.ErrorContains(t, err, "missing credentials source") }) @@ -221,3 +260,24 @@ func TestNewCacheKey(t *testing.T) { `) require.Equal(t, want, got) } + +type fakeOIDCIntegrationClient struct { + unauth bool + + getIntegrationFn func(context.Context, string) (types.Integration, error) + getTokenFn func(context.Context, string) (string, error) +} + +func (f *fakeOIDCIntegrationClient) GetIntegration(ctx context.Context, name string) (types.Integration, error) { + if f.unauth { + return nil, trace.AccessDenied("unauthorized") + } + return f.getIntegrationFn(ctx, name) +} + +func (f *fakeOIDCIntegrationClient) GenerateAWSOIDCToken(ctx context.Context, integrationName string) (string, error) { + if f.unauth { + return "", trace.AccessDenied("unauthorized") + } + return f.getTokenFn(ctx, integrationName) +} diff --git a/lib/cloud/awsconfig/cache.go b/lib/cloud/awsconfig/cache.go index 3d664ba04c350..cdb315703212a 100644 --- a/lib/cloud/awsconfig/cache.go +++ b/lib/cloud/awsconfig/cache.go @@ -36,10 +36,23 @@ func awsCredentialsCacheOptions(opts *aws.CredentialsCacheOptions) { // role. type Cache struct { awsConfigCache *utils.FnCache + defaultOptions []OptionsFn +} + +// CacheOption is an option func for setting additional options when creating +// a new config cache. +type CacheOption func(*Cache) + +// WithDefaults is a [CacheOption] function that sets default [OptionsFn] to +// use when getting AWS config. +func WithDefaults(optFns ...OptionsFn) CacheOption { + return func(c *Cache) { + c.defaultOptions = optFns + } } // NewCache returns a new [Cache]. -func NewCache() (*Cache, error) { +func NewCache(optFns ...CacheOption) (*Cache, error) { c, err := utils.NewFnCache(utils.FnCacheConfig{ TTL: 15 * time.Minute, ReloadOnErr: true, @@ -47,14 +60,27 @@ func NewCache() (*Cache, error) { if err != nil { return nil, trace.Wrap(err) } - return &Cache{ + cache := &Cache{ awsConfigCache: c, - }, nil + } + for _, fn := range optFns { + fn(cache) + } + return cache, nil +} + +// withDefaultOptions prepends default options to the given option funcs, +// providing for default cache options and per-call options. +func (c *Cache) withDefaultOptions(optFns []OptionsFn) []OptionsFn { + if c.defaultOptions != nil { + return append(c.defaultOptions, optFns...) + } + return optFns } // GetConfig returns an [aws.Config] for the given region and options. func (c *Cache) GetConfig(ctx context.Context, region string, optFns ...OptionsFn) (aws.Config, error) { - opts, err := buildOptions(optFns...) + opts, err := buildOptions(c.withDefaultOptions(optFns)...) if err != nil { return aws.Config{}, trace.Wrap(err) } @@ -112,7 +138,7 @@ func (c *Cache) getConfigForRoleChain(ctx context.Context, cfg aws.Config, opts } credProvider, err := utils.FnCacheGet(ctx, c.awsConfigCache, cacheKey, func(ctx context.Context) (aws.CredentialsProvider, error) { - clt := opts.assumeRoleClientProvider(cfg) + clt := opts.stsClientProvider(cfg) credProvider := getAssumeRoleProvider(ctx, clt, r) cc := aws.NewCredentialsCache(credProvider, awsCredentialsCacheOptions, diff --git a/lib/cloud/mocks/aws_config.go b/lib/cloud/mocks/aws_config.go index 7edadf80a9e20..b52dfbd36d74a 100644 --- a/lib/cloud/mocks/aws_config.go +++ b/lib/cloud/mocks/aws_config.go @@ -22,12 +22,15 @@ import ( "context" "github.com/aws/aws-sdk-go-v2/aws" + "github.com/gravitational/trace" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/cloud/awsconfig" ) type AWSConfigProvider struct { - STSClient *STSClient + STSClient *STSClient + OIDCIntegrationClient awsconfig.OIDCIntegrationClient } func (f *AWSConfigProvider) GetConfig(ctx context.Context, region string, optFns ...awsconfig.OptionsFn) (aws.Config, error) { @@ -35,8 +38,32 @@ func (f *AWSConfigProvider) GetConfig(ctx context.Context, region string, optFns if stsClt == nil { stsClt = &STSClient{} } - optFns = append(optFns, awsconfig.WithAssumeRoleClientProviderFunc( - newAssumeRoleClientProviderFunc(stsClt), - )) + optFns = append(optFns, + awsconfig.WithOIDCIntegrationClient(f.OIDCIntegrationClient), + awsconfig.WithSTSClientProvider( + newAssumeRoleClientProviderFunc(stsClt), + ), + ) return awsconfig.GetConfig(ctx, region, optFns...) } + +type FakeOIDCIntegrationClient struct { + Unauth bool + + Integration types.Integration + Token string +} + +func (f *FakeOIDCIntegrationClient) GetIntegration(ctx context.Context, name string) (types.Integration, error) { + if f.Unauth { + return nil, trace.AccessDenied("unauthorized") + } + return f.Integration, nil +} + +func (f *FakeOIDCIntegrationClient) GenerateAWSOIDCToken(ctx context.Context, integrationName string) (string, error) { + if f.Unauth { + return "", trace.AccessDenied("unauthorized") + } + return f.Token, nil +} diff --git a/lib/cloud/mocks/aws_sts.go b/lib/cloud/mocks/aws_sts.go index 713de480ebf86..178a1259669a4 100644 --- a/lib/cloud/mocks/aws_sts.go +++ b/lib/cloud/mocks/aws_sts.go @@ -54,7 +54,20 @@ type STSClient struct { recordFn func(roleARN, externalID string) } -func (m *STSClient) AssumeRole(ctx context.Context, in *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { +func (m *STSClient) AssumeRoleWithWebIdentity(ctx context.Context, in *sts.AssumeRoleWithWebIdentityInput, _ ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { + m.record(aws.ToString(in.RoleArn), "") + expiry := time.Now().Add(60 * time.Minute) + return &sts.AssumeRoleWithWebIdentityOutput{ + Credentials: &ststypes.Credentials{ + AccessKeyId: in.RoleArn, + SecretAccessKey: aws.String("secret"), + SessionToken: aws.String("token"), + Expiration: &expiry, + }, + }, nil +} + +func (m *STSClient) AssumeRole(ctx context.Context, in *sts.AssumeRoleInput, _ ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { // Retrieve credentials if we have a credential provider, so that all // assume-role providers in a role chain are triggered to call AssumeRole. if m.credentialProvider != nil { @@ -93,8 +106,8 @@ func (m *STSClient) record(roleARN, externalID string) { } } -func newAssumeRoleClientProviderFunc(base *STSClient) awsconfig.AssumeRoleClientProviderFunc { - return func(cfg aws.Config) stscreds.AssumeRoleAPIClient { +func newAssumeRoleClientProviderFunc(base *STSClient) awsconfig.STSClientProviderFunc { + return func(cfg aws.Config) awsconfig.STSClient { if cfg.Credentials != nil { if _, ok := cfg.Credentials.(*stscreds.AssumeRoleProvider); ok { // Create a new fake client linked to the old one. diff --git a/lib/integrations/awsoidc/clientsv1.go b/lib/integrations/awsoidc/clientsv1.go index 8c16f4c66156a..ae2e0be6a186b 100644 --- a/lib/integrations/awsoidc/clientsv1.go +++ b/lib/integrations/awsoidc/clientsv1.go @@ -44,9 +44,6 @@ type IntegrationTokenGenerator interface { // GetIntegration returns the specified integration resources. GetIntegration(ctx context.Context, name string) (types.Integration, error) - // GetProxies returns a list of registered proxies. - GetProxies() ([]types.Server, error) - // GenerateAWSOIDCToken generates a token to be used to execute an AWS OIDC Integration action. GenerateAWSOIDCToken(ctx context.Context, integration string) (string, error) } diff --git a/lib/srv/discovery/discovery.go b/lib/srv/discovery/discovery.go index 28690130d51a7..f37ba025d2450 100644 --- a/lib/srv/discovery/discovery.go +++ b/lib/srv/discovery/discovery.go @@ -224,7 +224,11 @@ kubernetes matchers are present.`) c.CloudClients = cloudClients } if c.AWSConfigProvider == nil { - provider, err := awsconfig.NewCache() + provider, err := awsconfig.NewCache( + awsconfig.WithDefaults( + awsconfig.WithOIDCIntegrationClient(c.AccessPoint), + ), + ) if err != nil { return trace.Wrap(err, "unable to create AWS config provider cache") } @@ -232,9 +236,8 @@ kubernetes matchers are present.`) } if c.AWSDatabaseFetcherFactory == nil { factory, err := db.NewAWSFetcherFactory(db.AWSFetcherFactoryConfig{ - CloudClients: c.CloudClients, - AWSConfigProvider: c.AWSConfigProvider, - IntegrationCredentialProviderFn: c.getIntegrationCredentialProviderFn(), + CloudClients: c.CloudClients, + AWSConfigProvider: c.AWSConfigProvider, }) if err != nil { return trace.Wrap(err) @@ -312,33 +315,10 @@ kubernetes matchers are present.`) } func (c *Config) getAWSConfig(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (aws.Config, error) { - opts = append(opts, awsconfig.WithIntegrationCredentialProvider(c.getIntegrationCredentialProviderFn())) cfg, err := c.AWSConfigProvider.GetConfig(ctx, region, opts...) return cfg, trace.Wrap(err) } -func (c *Config) getIntegrationCredentialProviderFn() awsconfig.IntegrationCredentialProviderFunc { - return func(ctx context.Context, region, integrationName string) (aws.CredentialsProvider, error) { - integration, err := c.AccessPoint.GetIntegration(ctx, integrationName) - if err != nil { - return nil, trace.Wrap(err) - } - if integration.GetAWSOIDCIntegrationSpec() == nil { - return nil, trace.BadParameter("integration does not have aws oidc spec fields %q", integrationName) - } - token, err := c.AccessPoint.GenerateAWSOIDCToken(ctx, integrationName) - if err != nil { - return nil, trace.Wrap(err) - } - cred, err := awsoidc.NewAWSCredentialsProvider(ctx, &awsoidc.AWSClientRequest{ - Token: token, - RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN, - Region: region, - }) - return cred, trace.Wrap(err) - } -} - // Server is a discovery server, used to discover cloud resources for // inclusion in Teleport type Server struct { diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go index f3c387a475932..865517ba4c33c 100644 --- a/lib/srv/discovery/discovery_test.go +++ b/lib/srv/discovery/discovery_test.go @@ -37,7 +37,6 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redis/armredis/v3" awsv2 "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/ec2" ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/aws/aws-sdk-go-v2/service/redshift" @@ -2032,18 +2031,6 @@ func TestDiscoveryDatabase(t *testing.T) { Clusters: []*eks.Cluster{eksAWSResource}, }, } - fakeConfigProvider := &mocks.AWSConfigProvider{} - dbFetcherFactory, err := db.NewAWSFetcherFactory(db.AWSFetcherFactoryConfig{ - AWSConfigProvider: fakeConfigProvider, - CloudClients: testCloudClients, - IntegrationCredentialProviderFn: func(_ context.Context, _, _ string) (awsv2.CredentialsProvider, error) { - return credentials.NewStaticCredentialsProvider("key", "secret", "session"), nil - }, - RedshiftClientProviderFn: newFakeRedshiftClientProvider(&mocks.RedshiftClient{ - Clusters: []redshifttypes.Cluster{*awsRedshiftResource}, - }), - }) - require.NoError(t, err) tcs := []struct { name string @@ -2334,6 +2321,23 @@ func TestDiscoveryDatabase(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { require.NoError(t, tlsServer.Close()) }) + awsOIDCIntegration, err := types.NewIntegrationAWSOIDC(types.Metadata{ + Name: integrationName, + }, &types.AWSOIDCIntegrationSpecV1{ + RoleARN: "arn:aws:iam::123456789012:role/teleport", + }) + require.NoError(t, err) + + testAuthServer.AuthServer.IntegrationsTokenGenerator = &mockIntegrationsTokenGenerator{ + proxies: nil, + integrations: map[string]types.Integration{ + awsOIDCIntegration.GetName(): awsOIDCIntegration, + }, + } + + _, err = tlsServer.Auth().CreateIntegration(ctx, awsOIDCIntegration) + require.NoError(t, err) + // Auth client for discovery service. identity := auth.TestServerID(types.RoleDiscovery, "hostID") authClient, err := tlsServer.NewClient(identity) @@ -2349,6 +2353,19 @@ func TestDiscoveryDatabase(t *testing.T) { waitForReconcile := make(chan struct{}) reporter := &mockUsageReporter{} tlsServer.Auth().SetUsageReporter(reporter) + accessPoint := getDiscoveryAccessPoint(tlsServer.Auth(), authClient) + fakeConfigProvider := &mocks.AWSConfigProvider{ + OIDCIntegrationClient: accessPoint, + } + dbFetcherFactory, err := db.NewAWSFetcherFactory(db.AWSFetcherFactoryConfig{ + AWSConfigProvider: fakeConfigProvider, + CloudClients: testCloudClients, + RedshiftClientProviderFn: newFakeRedshiftClientProvider(&mocks.RedshiftClient{ + Clusters: []redshifttypes.Cluster{*awsRedshiftResource}, + }), + }) + require.NoError(t, err) + srv, err := New( authz.ContextWithUser(ctx, identity.I), &Config{ @@ -2358,7 +2375,7 @@ func TestDiscoveryDatabase(t *testing.T) { AWSConfigProvider: fakeConfigProvider, ClusterFeatures: func() proto.Features { return proto.Features{} }, KubernetesClient: fake.NewSimpleClientset(), - AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient), + AccessPoint: accessPoint, Matchers: Matchers{ AWS: tc.awsMatchers, Azure: tc.azureMatchers, diff --git a/lib/srv/discovery/fetchers/db/aws.go b/lib/srv/discovery/fetchers/db/aws.go index f87e0e9a6c443..d6d70912d7092 100644 --- a/lib/srv/discovery/fetchers/db/aws.go +++ b/lib/srv/discovery/fetchers/db/aws.go @@ -55,9 +55,6 @@ type awsFetcherConfig struct { AWSClients cloud.AWSClients // AWSConfigProvider provides [aws.Config] for AWS SDK service clients. AWSConfigProvider awsconfig.Provider - // IntegrationCredentialProviderFn is a required function that provides - // credentials via AWS OIDC integration. - IntegrationCredentialProviderFn awsconfig.IntegrationCredentialProviderFunc // Type is the type of DB matcher, for example "rds", "redshift", etc. Type string // AssumeRole provides a role ARN and ExternalID to assume an AWS role diff --git a/lib/srv/discovery/fetchers/db/aws_redshift.go b/lib/srv/discovery/fetchers/db/aws_redshift.go index 508cb6e8810f1..0cda0b478e67b 100644 --- a/lib/srv/discovery/fetchers/db/aws_redshift.go +++ b/lib/srv/discovery/fetchers/db/aws_redshift.go @@ -53,7 +53,6 @@ func (f *redshiftPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig awsCfg, err := cfg.AWSConfigProvider.GetConfig(ctx, cfg.Region, awsconfig.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID), awsconfig.WithCredentialsMaybeIntegration(cfg.Integration), - awsconfig.WithIntegrationCredentialProvider(cfg.IntegrationCredentialProviderFn), ) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/srv/discovery/fetchers/db/db.go b/lib/srv/discovery/fetchers/db/db.go index 3ef56532d90af..8d79bc2bb65bc 100644 --- a/lib/srv/discovery/fetchers/db/db.go +++ b/lib/srv/discovery/fetchers/db/db.go @@ -73,9 +73,6 @@ type AWSFetcherFactoryConfig struct { AWSConfigProvider awsconfig.Provider // CloudClients is an interface for retrieving AWS SDK v1 cloud clients. CloudClients cloud.AWSClients - // IntegrationCredentialProviderFn is an optional function that provides - // credentials via AWS OIDC integration. - IntegrationCredentialProviderFn awsconfig.IntegrationCredentialProviderFunc // RedshiftClientProviderFn is an optional function that provides RedshiftClientProviderFn RedshiftClientProviderFunc } @@ -128,16 +125,15 @@ func (f *AWSFetcherFactory) MakeFetchers(ctx context.Context, matchers []types.A for _, makeFetcher := range makeFetchers { for _, region := range matcher.Regions { fetcher, err := makeFetcher(awsFetcherConfig{ - AWSClients: f.cfg.CloudClients, - Type: matcherType, - AssumeRole: assumeRole, - Labels: matcher.Tags, - Region: region, - Integration: matcher.Integration, - DiscoveryConfigName: discoveryConfigName, - AWSConfigProvider: f.cfg.AWSConfigProvider, - IntegrationCredentialProviderFn: f.cfg.IntegrationCredentialProviderFn, - redshiftClientProviderFn: f.cfg.RedshiftClientProviderFn, + AWSClients: f.cfg.CloudClients, + Type: matcherType, + AssumeRole: assumeRole, + Labels: matcher.Tags, + Region: region, + Integration: matcher.Integration, + DiscoveryConfigName: discoveryConfigName, + AWSConfigProvider: f.cfg.AWSConfigProvider, + redshiftClientProviderFn: f.cfg.RedshiftClientProviderFn, }) if err != nil { return nil, trace.Wrap(err)