Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable aws iam rds auth for the postgres scaler #6449

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ require (
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.25.0 // indirect
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.48.1 // indirect
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.48.1 // indirect
github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.5.2 // indirect
github.com/census-instrumentation/opencensus-proto v0.4.1 // indirect
github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78 // indirect
github.com/envoyproxy/go-control-plane v0.13.1 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1501,6 +1501,8 @@ github.com/aws/aws-sdk-go-v2/credentials v1.17.48/go.mod h1:tOscxHN3CGmuX9idQ3+q
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.13/go.mod h1:y0eXmsNBFIVjUE8ZBjES8myOHlMsXDz7qGT93+MVdjk=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.22 h1:kqOrpojG71DxJm/KDPO+Z/y1phm1JlC8/iT+5XRmAn8=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.22/go.mod h1:NtSFajXVVL8TA2QNngagVZmUtXciyrHOt7xgz4faS/M=
github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.5.2 h1:fo+GuZNME9oGDc7VY+EBT+oCrco6RjRgUp1bKTcaHrU=
github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.5.2/go.mod h1:fnqb94UO6YCjBIic4WaqDYkNVAEFWOWiReVHitBBWW0=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.19/go.mod h1:llxE6bwUZhuCas0K7qGiu5OgMis3N7kdWtFSxoHmJ7E=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.26 h1:I/5wmGMffY4happ8NOCuIUEWGUvvFp5NSeQcXl9RHcI=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.26/go.mod h1:FR8f4turZtNy6baO0KJ5FJUmXH/cSkI9fOngs0yl6mA=
Expand Down
50 changes: 50 additions & 0 deletions pkg/scalers/postgresql_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/aws/aws-sdk-go-v2/feature/rds/auth"
"github.com/go-logr/logr"
_ "github.com/jackc/pgx/v5/stdlib" // PostreSQL drive required for this scaler
awsutils "github.com/kedacore/keda/v2/pkg/scalers/aws"
v2 "k8s.io/api/autoscaling/v2"
"k8s.io/metrics/pkg/apis/external_metrics"

Expand Down Expand Up @@ -47,6 +49,9 @@ type postgreSQLMetadata struct {
Query string `keda:"name=query, order=triggerMetadata"`
triggerIndex int
azureAuthContext azureAuthContext
AwsRegion string `keda:"name=awsRegion, order=triggerMetadata;authParams"`
awsAuthorization awsutils.AuthorizationMetadata
awsAuthContext awsAuthContext

Host string `keda:"name=host, order=authParams;triggerMetadata, optional"`
Port string `keda:"name=port, order=authParams;triggerMetadata, optional"`
Expand Down Expand Up @@ -88,6 +93,10 @@ type azureAuthContext struct {
token *azcore.AccessToken
}

type awsAuthContext struct {
expiry time.Time
}

// NewPostgreSQLScaler creates a new postgreSQL scaler
func NewPostgreSQLScaler(ctx context.Context, config *scalersconfig.ScalerConfig) (Scaler, error) {
metricType, err := GetMetricTargetType(config)
Expand Down Expand Up @@ -144,6 +153,19 @@ func parsePostgreSQLMetadata(logger logr.Logger, config *scalersconfig.ScalerCon
meta.azureAuthContext.cred = cred
authPodIdentity = kedav1alpha1.AuthPodIdentity{Provider: config.PodIdentity.Provider}

params = append(params, "%PASSWORD%")
meta.Connection = strings.Join(params, " ")
case kedav1alpha1.PodIdentityProviderAws:
params := buildConnArray(meta)

auth, err := awsutils.GetAwsAuthorization(config.TriggerUniqueKey, meta.AwsRegion, config.PodIdentity, config.TriggerMetadata, config.AuthParams, config.ResolvedEnv)
if err != nil {
return nil, authPodIdentity, err
}

meta.awsAuthorization = auth
authPodIdentity = kedav1alpha1.AuthPodIdentity{Provider: config.PodIdentity.Provider}

params = append(params, "%PASSWORD%")
meta.Connection = strings.Join(params, " ")
}
Expand Down Expand Up @@ -175,6 +197,22 @@ func getConnection(ctx context.Context, meta *postgreSQLMetadata, podIdentity ke
connectionString = passwordConnPattern.ReplaceAllString(meta.Connection, newPasswordField)
}

if podIdentity.Provider == kedav1alpha1.PodIdentityProviderAws {
cfg, err := awsutils.GetAwsConfig(ctx, meta.awsAuthorization)
if err != nil {
return nil, err
}
DBendpoint := fmt.Sprintf("%s:%s", meta.Host, meta.Port)
Copy link

@semgrep-app semgrep-app bot Dec 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use net.JoinHostPort instead of fmt.Sprintf(meta.Port, meta.Host)

🧹 Fixed in commit d00a759 🧹

password, err := auth.BuildAuthToken(ctx, DBendpoint, meta.AwsRegion, meta.UserName, cfg.Credentials)
if err != nil {
return nil, err
}
meta.awsAuthContext.expiry = time.Now().Add(14 * time.Minute)

newPasswordField := "password=" + escapePostgreConnectionParameter(password)
connectionString = passwordConnPattern.ReplaceAllString(meta.Connection, newPasswordField)
}

db, err := sql.Open("pgx", connectionString)
if err != nil {
logger.Error(err, fmt.Sprintf("Found error opening postgreSQL: %s", err))
Expand Down Expand Up @@ -213,6 +251,18 @@ func (s *postgreSQLScaler) getActiveNumber(ctx context.Context) (float64, error)
}
}

if s.podIdentity.Provider == kedav1alpha1.PodIdentityProviderAws {
if s.metadata.awsAuthContext.expiry.Before(time.Now()) {
s.logger.Info("The AWS Access Token expired, retrieving a new AWS Access Token and instantiating a new Postgres connection object.")
s.connection.Close()
newConnection, err := getConnection(ctx, s.metadata, s.podIdentity, s.logger)
if err != nil {
return 0, fmt.Errorf("error establishing postgreSQL connection: %w", err)
}
s.connection = newConnection
}
}

err := s.connection.QueryRowContext(ctx, s.metadata.Query).Scan(&id)
if err != nil {
s.logger.Error(err, fmt.Sprintf("could not query postgreSQL: %s", err))
Expand Down
19 changes: 19 additions & 0 deletions pkg/scalers/postgresql_scaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@
{metadata: map[string]string{"query": "test_query", "targetQueryValue": "5", "host": "localhost", "port": "1234", "dbName": "testDb", "userName": "user", "sslmode": "required"}, connectionString: "host=localhost port=1234 user=user dbname=testDb sslmode=required %PASSWORD%"},
}

var testPodIdentityAwsWorkloadPostgresSQLConnectionstring = []postgreSQLConnectionStringTestData{
// from meta
{metadata: map[string]string{"query": "test_query", "targetQueryValue": "5", "host": "localhost", "port": "1234", "dbName": "testDb", "userName": "user", "sslmode": "required"}, connectionString: "host=localhost port=1234 user=user dbname=testDb sslmode=required %PASSWORD%"},
}

func TestPodIdentityAzureWorkloadPosgresSQLConnectionStringGeneration(t *testing.T) {
identityID := "IDENTITY_ID_CORRESPONDING_TO_USERNAME_FIELD"
for _, testData := range testPodIdentityAzureWorkloadPostgreSQLConnectionstring {
Expand All @@ -110,6 +115,20 @@
}
}

func TestPodIdentityAWSWorkloadPosgresSQLConnectionStringGeneration(t *testing.T) {
identityID := "IDENTITY_ID_CORRESPONDING_TO_USERNAME_FIELD"
for _, testData := range testPodIdentityAwsWorkloadPostgresSQLConnectionstring {
meta, _, err := parsePostgreSQLMetadata(logr.Discard(), &scalersconfig.ScalerConfig{ResolvedEnv: testData.resolvedEnv, TriggerMetadata: testData.metadata, PodIdentity: kedav1alpha1.AuthPodIdentity{Provider: kedav1alpha1.PodIdentityProviderAWSWorkload, IdentityID: &identityID}, AuthParams: testData.authParam, TriggerIndex: 0})

Check failure on line 121 in pkg/scalers/postgresql_scaler_test.go

View workflow job for this annotation

GitHub Actions / Static Checks

undefined: kedav1alpha1.PodIdentityProviderAWSWorkload (typecheck)
if err != nil {
t.Fatal("Could not parse metadata:", err)
}

if meta.Connection != testData.connectionString {
t.Errorf("Error generating connectionString, expected '%s' and get '%s'", testData.connectionString, meta.Connection)
}
}
}

type parsePostgresMetadataTestData struct {
metadata map[string]string
authParams map[string]string
Expand Down
Loading
Loading