Skip to content

Commit

Permalink
enable aws iam rds auth for the postgres scaler
Browse files Browse the repository at this point in the history
add iam auth for postgres
  • Loading branch information
Haydn Evans committed Dec 30, 2024
1 parent 87158f3 commit dec7e3c
Show file tree
Hide file tree
Showing 10 changed files with 798 additions and 0 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ require (
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.25.0 // indirect
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.48.1 // indirect
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.48.1 // indirect
github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.5.2 // indirect
github.com/census-instrumentation/opencensus-proto v0.4.1 // indirect
github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78 // indirect
github.com/envoyproxy/go-control-plane v0.13.1 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1501,6 +1501,8 @@ github.com/aws/aws-sdk-go-v2/credentials v1.17.48/go.mod h1:tOscxHN3CGmuX9idQ3+q
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.13/go.mod h1:y0eXmsNBFIVjUE8ZBjES8myOHlMsXDz7qGT93+MVdjk=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.22 h1:kqOrpojG71DxJm/KDPO+Z/y1phm1JlC8/iT+5XRmAn8=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.22/go.mod h1:NtSFajXVVL8TA2QNngagVZmUtXciyrHOt7xgz4faS/M=
github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.5.2 h1:fo+GuZNME9oGDc7VY+EBT+oCrco6RjRgUp1bKTcaHrU=
github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.5.2/go.mod h1:fnqb94UO6YCjBIic4WaqDYkNVAEFWOWiReVHitBBWW0=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.19/go.mod h1:llxE6bwUZhuCas0K7qGiu5OgMis3N7kdWtFSxoHmJ7E=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.26 h1:I/5wmGMffY4happ8NOCuIUEWGUvvFp5NSeQcXl9RHcI=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.26/go.mod h1:FR8f4turZtNy6baO0KJ5FJUmXH/cSkI9fOngs0yl6mA=
Expand Down
50 changes: 50 additions & 0 deletions pkg/scalers/postgresql_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/aws/aws-sdk-go-v2/feature/rds/auth"
"github.com/go-logr/logr"
_ "github.com/jackc/pgx/v5/stdlib" // PostreSQL drive required for this scaler
awsutils "github.com/kedacore/keda/v2/pkg/scalers/aws"
v2 "k8s.io/api/autoscaling/v2"
"k8s.io/metrics/pkg/apis/external_metrics"

Expand Down Expand Up @@ -47,6 +49,9 @@ type postgreSQLMetadata struct {
Query string `keda:"name=query, order=triggerMetadata"`
triggerIndex int
azureAuthContext azureAuthContext
AwsRegion string `keda:"name=awsRegion, order=triggerMetadata;authParams"`
awsAuthorization awsutils.AuthorizationMetadata
awsAuthContext awsAuthContext

Host string `keda:"name=host, order=authParams;triggerMetadata, optional"`
Port string `keda:"name=port, order=authParams;triggerMetadata, optional"`
Expand Down Expand Up @@ -88,6 +93,10 @@ type azureAuthContext struct {
token *azcore.AccessToken
}

type awsAuthContext struct {
expiry time.Time
}

// NewPostgreSQLScaler creates a new postgreSQL scaler
func NewPostgreSQLScaler(ctx context.Context, config *scalersconfig.ScalerConfig) (Scaler, error) {
metricType, err := GetMetricTargetType(config)
Expand Down Expand Up @@ -144,6 +153,19 @@ func parsePostgreSQLMetadata(logger logr.Logger, config *scalersconfig.ScalerCon
meta.azureAuthContext.cred = cred
authPodIdentity = kedav1alpha1.AuthPodIdentity{Provider: config.PodIdentity.Provider}

params = append(params, "%PASSWORD%")
meta.Connection = strings.Join(params, " ")
case kedav1alpha1.PodIdentityProviderAws:
params := buildConnArray(meta)

auth, err := awsutils.GetAwsAuthorization(config.TriggerUniqueKey, meta.AwsRegion, config.PodIdentity, config.TriggerMetadata, config.AuthParams, config.ResolvedEnv)
if err != nil {
return nil, authPodIdentity, err
}

meta.awsAuthorization = auth
authPodIdentity = kedav1alpha1.AuthPodIdentity{Provider: config.PodIdentity.Provider}

params = append(params, "%PASSWORD%")
meta.Connection = strings.Join(params, " ")
}
Expand Down Expand Up @@ -175,6 +197,22 @@ func getConnection(ctx context.Context, meta *postgreSQLMetadata, podIdentity ke
connectionString = passwordConnPattern.ReplaceAllString(meta.Connection, newPasswordField)
}

if podIdentity.Provider == kedav1alpha1.PodIdentityProviderAws {
cfg, err := awsutils.GetAwsConfig(ctx, meta.awsAuthorization)
if err != nil {
return nil, err
}
DBendpoint := fmt.Sprintf("%s:%s", meta.Host, meta.Port)
password, err := auth.BuildAuthToken(ctx, DBendpoint, meta.AwsRegion, meta.UserName, cfg.Credentials)
if err != nil {
return nil, err
}
meta.awsAuthContext.expiry = time.Now().Add(14 * time.Minute)

newPasswordField := "password=" + escapePostgreConnectionParameter(password)
connectionString = passwordConnPattern.ReplaceAllString(meta.Connection, newPasswordField)
}

db, err := sql.Open("pgx", connectionString)
if err != nil {
logger.Error(err, fmt.Sprintf("Found error opening postgreSQL: %s", err))
Expand Down Expand Up @@ -213,6 +251,18 @@ func (s *postgreSQLScaler) getActiveNumber(ctx context.Context) (float64, error)
}
}

if s.podIdentity.Provider == kedav1alpha1.PodIdentityProviderAws {
if s.metadata.awsAuthContext.expiry.Before(time.Now()) {
s.logger.Info("The AWS Access Token expired, retrieving a new AWS Access Token and instantiating a new Postgres connection object.")
s.connection.Close()
newConnection, err := getConnection(ctx, s.metadata, s.podIdentity, s.logger)
if err != nil {
return 0, fmt.Errorf("error establishing postgreSQL connection: %w", err)
}
s.connection = newConnection
}
}

err := s.connection.QueryRowContext(ctx, s.metadata.Query).Scan(&id)
if err != nil {
s.logger.Error(err, fmt.Sprintf("could not query postgreSQL: %s", err))
Expand Down
19 changes: 19 additions & 0 deletions pkg/scalers/postgresql_scaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ var testPodIdentityAzureWorkloadPostgreSQLConnectionstring = []postgreSQLConnect
{metadata: map[string]string{"query": "test_query", "targetQueryValue": "5", "host": "localhost", "port": "1234", "dbName": "testDb", "userName": "user", "sslmode": "required"}, connectionString: "host=localhost port=1234 user=user dbname=testDb sslmode=required %PASSWORD%"},
}

var testPodIdentityAwsWorkloadPostgresSQLConnectionstring = []postgreSQLConnectionStringTestData{
// from meta
{metadata: map[string]string{"query": "test_query", "targetQueryValue": "5", "host": "localhost", "port": "1234", "dbName": "testDb", "userName": "user", "sslmode": "required"}, connectionString: "host=localhost port=1234 user=user dbname=testDb sslmode=required %PASSWORD%"},
}

func TestPodIdentityAzureWorkloadPosgresSQLConnectionStringGeneration(t *testing.T) {
identityID := "IDENTITY_ID_CORRESPONDING_TO_USERNAME_FIELD"
for _, testData := range testPodIdentityAzureWorkloadPostgreSQLConnectionstring {
Expand All @@ -110,6 +115,20 @@ func TestPodIdentityAzureWorkloadPosgresSQLConnectionStringGeneration(t *testing
}
}

func TestPodIdentityAWSWorkloadPosgresSQLConnectionStringGeneration(t *testing.T) {
identityID := "IDENTITY_ID_CORRESPONDING_TO_USERNAME_FIELD"
for _, testData := range testPodIdentityAwsWorkloadPostgresSQLConnectionstring {
meta, _, err := parsePostgreSQLMetadata(logr.Discard(), &scalersconfig.ScalerConfig{ResolvedEnv: testData.resolvedEnv, TriggerMetadata: testData.metadata, PodIdentity: kedav1alpha1.AuthPodIdentity{Provider: kedav1alpha1.PodIdentityProviderAWSWorkload, IdentityID: &identityID}, AuthParams: testData.authParam, TriggerIndex: 0})

Check failure on line 121 in pkg/scalers/postgresql_scaler_test.go

View workflow job for this annotation

GitHub Actions / Static Checks

undefined: kedav1alpha1.PodIdentityProviderAWSWorkload (typecheck)
if err != nil {
t.Fatal("Could not parse metadata:", err)
}

if meta.Connection != testData.connectionString {
t.Errorf("Error generating connectionString, expected '%s' and get '%s'", testData.connectionString, meta.Connection)
}
}
}

type parsePostgresMetadataTestData struct {
metadata map[string]string
authParams map[string]string
Expand Down
Loading

0 comments on commit dec7e3c

Please sign in to comment.