diff --git a/pkg/scalers/mysql_scaler.go b/pkg/scalers/mysql_scaler.go index 65de8cfb1ec..0ae20863ce2 100644 --- a/pkg/scalers/mysql_scaler.go +++ b/pkg/scalers/mysql_scaler.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "strings" + "time" "github.com/go-logr/logr" "github.com/go-sql-driver/mysql" @@ -16,6 +17,16 @@ import ( kedautil "github.com/kedacore/keda/v2/pkg/util" ) +var ( + // A map that holds MySQL connection pools, keyed by connection string + connectionPools *kedautil.RefMap[string, *sql.DB] +) + +func init() { + // Initialize the global connectionPools map + connectionPools = kedautil.NewRefMap[string, *sql.DB]() +} + type mySQLScaler struct { metricType v2.MetricTargetType metadata *mySQLMetadata @@ -34,6 +45,12 @@ type mySQLMetadata struct { QueryValue float64 `keda:"name=queryValue, order=triggerMetadata"` ActivationQueryValue float64 `keda:"name=activationQueryValue, order=triggerMetadata, default=0"` MetricName string `keda:"name=metricName, order=triggerMetadata, optional"` + + // Connection pool settings + UseGlobalConnPools bool `keda:"name=useGlobalConnPools, order=triggerMetadata, optional"` + MaxOpenConns int `keda:"name=maxOpenConns, order=triggerMetadata, optional"` + MaxIdleConns int `keda:"name=maxIdleConns, order=triggerMetadata, optional"` + ConnMaxIdleTime int `keda:"name=connMaxIdleTime, order=triggerMetadata, optional"` // seconds } // NewMySQLScaler creates a new MySQL scaler @@ -50,10 +67,19 @@ func NewMySQLScaler(config *scalersconfig.ScalerConfig) (Scaler, error) { return nil, fmt.Errorf("error parsing MySQL metadata: %w", err) } - conn, err := newMySQLConnection(meta, logger) + // Create MySQL connection, if useGlobalConnPools is set to true, it will use + // the global connection pool for the given connection string, otherwise it + // will create a new local connection pool for the given connection string + var conn *sql.DB + if meta.UseGlobalConnPools { + conn, err = getConnectionPool(meta, logger) + } else { + conn, err = newMySQLConnection(meta, logger) + } if err != nil { - return nil, fmt.Errorf("error establishing MySQL connection: %w", err) + return nil, fmt.Errorf("error creating MySQL connection: %w", err) } + return &mySQLScaler{ metricType: metricType, metadata: meta, @@ -96,6 +122,40 @@ func metadataToConnectionStr(meta *mySQLMetadata) string { return connStr } +// getConnectionPool will check if the connection pool has already been +// created for the given connection string and return it. If it has not +// been created, it will create a new connection pool and store it in the +// connectionPools map. +func getConnectionPool(meta *mySQLMetadata, logger logr.Logger) (*sql.DB, error) { + connStr := metadataToConnectionStr(meta) + // Try to load an existing pool and increment its reference count if found + if pool, ok := connectionPools.Load(connStr); ok { + err := connectionPools.AddRef(connStr) + if err != nil { + logger.Error(err, "Error increasing connection pool reference count") + return nil, err + } + + return pool, nil + } + + // If pool does not exist, create a new one and store it in RefMap + newPool, err := newMySQLConnection(meta, logger) + if err != nil { + return nil, err + } + err = connectionPools.Store(connStr, newPool, func(db *sql.DB) error { + logger.Info("Closing MySQL connection pool", "connectionString", connStr) + return db.Close() + }) + if err != nil { + logger.Error(err, "Error storing connection pool in RefMap") + return nil, err + } + + return newPool, nil +} + // newMySQLConnection creates MySQL db connection func newMySQLConnection(meta *mySQLMetadata, logger logr.Logger) (*sql.DB, error) { connStr := metadataToConnectionStr(meta) @@ -104,14 +164,35 @@ func newMySQLConnection(meta *mySQLMetadata, logger logr.Logger) (*sql.DB, error logger.Error(err, fmt.Sprintf("Found error when opening connection: %s", err)) return nil, err } + err = db.Ping() if err != nil { logger.Error(err, fmt.Sprintf("Found error when pinging database: %s", err)) return nil, err } + + setConnectionPoolConfiguration(meta, db) + return db, nil } +// setConnectionPoolConfiguration configures the MySQL connection pool settings +// based on the parameters provided in mySQLMetadata. If a setting is zero, it +// is left at its default value. +func setConnectionPoolConfiguration(meta *mySQLMetadata, db *sql.DB) { + if meta.MaxOpenConns > 0 { + db.SetMaxOpenConns(meta.MaxOpenConns) + } + + if meta.MaxIdleConns > 0 { + db.SetMaxIdleConns(meta.MaxIdleConns) + } + + if meta.ConnMaxIdleTime > 0 { + db.SetConnMaxIdleTime(time.Duration(meta.ConnMaxIdleTime) * time.Second) + } +} + // parseMySQLDbNameFromConnectionStr returns dbname from connection string // in it is not able to parse it, it returns "dbname" string func parseMySQLDbNameFromConnectionStr(connectionString string) string { @@ -123,13 +204,30 @@ func parseMySQLDbNameFromConnectionStr(connectionString string) string { return "dbname" } -// Close disposes of MySQL connections -func (s *mySQLScaler) Close(context.Context) error { - err := s.connection.Close() - if err != nil { - s.logger.Error(err, "Error closing MySQL connection") +// Close disposes of MySQL connections, closing either the global pool if used +// or the local connection pool +func (s *mySQLScaler) Close(ctx context.Context) error { + if s.metadata.UseGlobalConnPools { + if err := s.closeGlobalPool(ctx); err != nil { + return fmt.Errorf("error closing MySQL connection: %w", err) + } + } else { + if err := s.connection.Close(); err != nil { + return fmt.Errorf("error closing MySQL connection: %w", err) + } + } + + return nil +} + +// closeGlobalPool closes all MySQL connections in the global pool +func (s *mySQLScaler) closeGlobalPool(_ context.Context) error { + connStr := metadataToConnectionStr(s.metadata) + if err := connectionPools.RemoveRef(connStr); err != nil { + s.logger.Error(err, "Error decreasing connection pool reference count") return err } + return nil } diff --git a/pkg/scalers/mysql_scaler_test.go b/pkg/scalers/mysql_scaler_test.go index 70ed8c71d15..41c40f53a9d 100644 --- a/pkg/scalers/mysql_scaler_test.go +++ b/pkg/scalers/mysql_scaler_test.go @@ -75,6 +75,27 @@ var testMySQLMetadata = []parseMySQLMetadataTestData{ resolvedEnv: map[string]string{}, raisesError: true, }, + // use global pool + { + metadata: map[string]string{"query": "query", "queryValue": "12", "useGlobalConnPools": "true"}, + authParams: map[string]string{"host": "test_host", "port": "test_port", "username": "test_username", "password": "MYSQL_PASSWORD", "dbName": "test_dbname"}, + resolvedEnv: testMySQLResolvedEnv, + raisesError: false, + }, + // use connection pool settings + { + metadata: map[string]string{"query": "query", "queryValue": "12", "maxOpenConns": "10", "maxIdleConns": "5", "connMaxIdleTime": "10"}, + authParams: map[string]string{"host": "test_host", "port": "test_port", "username": "test_username", "password": "MYSQL_PASSWORD", "dbName": "test_dbname"}, + resolvedEnv: testMySQLResolvedEnv, + raisesError: false, + }, + // use connection pool settings and global pool + { + metadata: map[string]string{"query": "query", "queryValue": "12", "maxOpenConns": "10", "maxIdleConns": "5", "connMaxIdleTime": "10", "useGlobalConnPools": "true"}, + authParams: map[string]string{"host": "test_host", "port": "test_port", "username": "test_username", "password": "MYSQL_PASSWORD", "dbName": "test_dbname"}, + resolvedEnv: testMySQLResolvedEnv, + raisesError: false, + }, } var mySQLMetricIdentifiers = []mySQLMetricIdentifier{ diff --git a/pkg/util/refmap.go b/pkg/util/refmap.go new file mode 100644 index 00000000000..ab353e56463 --- /dev/null +++ b/pkg/util/refmap.go @@ -0,0 +1,123 @@ +package util + +//nolint:depguard // sync/atomic +import ( + "fmt" + "sync" + "sync/atomic" +) + +// refCountedValue manages a reference-counted value with a cleanup function. +type refCountedValue[V any] struct { + value V + refCount atomic.Int64 + closeFunc func(V) error // Cleanup function to call when count reaches zero +} + +// Add increments the reference count. +func (r *refCountedValue[V]) Add() { + r.refCount.Add(1) +} + +// Remove decrements the reference count and invokes closeFunc if the count +// reaches zero. +func (r *refCountedValue[V]) Remove() error { + if r.refCount.Add(-1) == 0 { + return r.closeFunc(r.value) + } + + return nil +} + +// Value returns the underlying value. +func (r *refCountedValue[V]) Value() V { + return r.value +} + +// RefMap manages reference-counted items in a concurrent-safe map. +type RefMap[K comparable, V any] struct { + data map[K]*refCountedValue[V] + mu sync.RWMutex +} + +// NewRefMap initializes a new RefMap. A RefMap is an atomic reference-counted +// concurrent hashmap. The general usage pattern is to Store a value with a +// close function, once the value is contained within the RefMap, it can be +// accessed via the Load method. The AddRef method signals ownership of the +// value and increments the reference count. The RemoveRef method decrements +// the reference count. When the reference count reaches zero, the close +// function is called and the value is removed from the map. +func NewRefMap[K comparable, V any]() *RefMap[K, V] { + return &RefMap[K, V]{ + data: make(map[K]*refCountedValue[V]), + } +} + +// Store adds a new item with an initial reference count of 1 and a close +// function. +func (r *RefMap[K, V]) Store(key K, value V, closeFunc func(V) error) error { + r.mu.Lock() + defer r.mu.Unlock() + + if _, exists := r.data[key]; exists { + return fmt.Errorf("key already exists: %v", key) + } + + r.data[key] = &refCountedValue[V]{value: value, refCount: atomic.Int64{}, closeFunc: closeFunc} + r.data[key].Add() // Set initial reference count to 1 + + return nil +} + +// Load retrieves a value by key without modifying the reference count, +// returning the value and a boolean indicating if it was found. The reference +// count not being modified means that a check for the existence of a key +// can be performed without signalling ownership of the value. If the value is +// used after this method, it is recommended to call AddRef to increment the +// reference +func (r *RefMap[K, V]) Load(key K) (V, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + + if refValue, found := r.data[key]; found { + return refValue.Value(), true + } + var zero V + + return zero, false +} + +// AddRef increments the reference count for a key if it exists. Ensure +// to call RemoveRef when done with the value to prevent memory leaks. +func (r *RefMap[K, V]) AddRef(key K) error { + r.mu.RLock() + defer r.mu.RUnlock() + + refValue, found := r.data[key] + if !found { + return fmt.Errorf("key not found: %v", key) + } + + refValue.Add() + return nil +} + +// RemoveRef decrements the reference count and deletes the entry if count +// reaches zero. +func (r *RefMap[K, V]) RemoveRef(key K) error { + r.mu.Lock() + defer r.mu.Unlock() + + refValue, found := r.data[key] + if !found { + return fmt.Errorf("key not found: %v", key) + } + + err := refValue.Remove() + + if refValue.refCount.Load() == 0 { + delete(r.data, key) + } + + return err // returns the error from closeFunc +} diff --git a/pkg/util/refmap_test.go b/pkg/util/refmap_test.go new file mode 100644 index 00000000000..093e20e8557 --- /dev/null +++ b/pkg/util/refmap_test.go @@ -0,0 +1,251 @@ +package util + +//nolint:depguard // sync/atomic +import ( + "errors" + "sync" + "sync/atomic" + "testing" + "time" +) + +// Test the initial storage and retrieval of a value +func TestRefMapStoreAndLoad(t *testing.T) { + refMap := NewRefMap[string, int]() + cleanupCalled := atomic.Bool{} + + closeFunc := func(value int) error { + cleanupCalled.Store(true) + return nil + } + + if err := refMap.Store("testKey", 42, closeFunc); err != nil { + t.Errorf("unexpected error on Store: %v", err) + } + + val, ok := refMap.Load("testKey") + if !ok || val != 42 { + t.Errorf("expected to load value 42 for key 'testKey', got %v", val) + } + + if cleanupCalled.Load() { + t.Error("expected cleanup function not to be called initially") + } +} + +// Test adding a reference and removing it, triggering cleanup on zero count +func TestRefMapAddAndRemoveRef(t *testing.T) { + refMap := NewRefMap[string, int]() + cleanupCalled := atomic.Bool{} + + closeFunc := func(value int) error { + cleanupCalled.Store(true) + return nil + } + + if err := refMap.Store("testKey", 42, closeFunc); err != nil { + t.Errorf("unexpected error on Store: %v", err) + } + + // Add a reference + if err := refMap.AddRef("testKey"); err != nil { + t.Errorf("unexpected error on AddRef: %v", err) + } + + // Remove references + if err := refMap.RemoveRef("testKey"); err != nil { + t.Errorf("unexpected error on first RemoveRef: %v", err) + } + + if err := refMap.RemoveRef("testKey"); err != nil { + t.Errorf("unexpected error on second RemoveRef: %v", err) + } + + // Check that cleanup was called + if !cleanupCalled.Load() { + t.Error("expected cleanup function to be called") + } + + // Check that key no longer exists + _, ok := refMap.Load("testKey") + if ok { + t.Error("expected key 'testKey' to be deleted after reference count reached zero") + } +} + +// Test removing reference from a non-existent key +func TestRefMapRemoveRefNonExistentKey(t *testing.T) { + refMap := NewRefMap[string, int]() + + if err := refMap.RemoveRef("nonExistentKey"); err == nil { + t.Error("expected error when removing reference from a non-existent key") + } +} + +// Test that multiple calls to AddRef and RemoveRef are handled concurrently and correctly +func TestRefMapConcurrentAddAndRemove(t *testing.T) { + refMap := NewRefMap[string, int]() + cleanupCalled := atomic.Bool{} + wg := sync.WaitGroup{} + + closeFunc := func(value int) error { + cleanupCalled.Store(true) + return nil + } + + // Store a value with an initial reference count of 1 + if err := refMap.Store("testKey", 42, closeFunc); err != nil { + t.Errorf("unexpected error on Store: %v", err) + } + + // Confirm initial state before any increments + val, ok := refMap.Load("testKey") + if !ok || val != 42 { + t.Fatalf("expected to find 'testKey' with initial value 42, found %v", val) + } + + // Concurrently increment references + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if err := refMap.AddRef("testKey"); err != nil { + t.Errorf("unexpected error on AddRef: %v", err) + } + }() + } + + wg.Wait() // Wait for all AddRef operations to complete + + // Concurrently remove references + for i := 0; i < 101; i++ { // Including the initial count + wg.Add(1) + go func() { + defer wg.Done() + if err := refMap.RemoveRef("testKey"); err != nil { + t.Errorf("unexpected error on RemoveRef: %v", err) + } + }() + } + + wg.Wait() // Wait for all RemoveRef operations to complete + + // Verify that cleanup was called after all references are removed + if !cleanupCalled.Load() { + t.Error("expected cleanup function to be called after all references are removed") + } + + // Verify that the key no longer exists + if _, ok := refMap.Load("testKey"); ok { + t.Error("expected key 'testKey' to be deleted after reference count reached zero") + } +} + +// Test that an error in closeFunc is handled properly and does not prevent map deletion +func TestRefMapCleanupErrorHandling(t *testing.T) { + refMap := NewRefMap[string, int]() + cleanupCalled := atomic.Bool{} + + closeFunc := func(value int) error { + cleanupCalled.Store(true) + return errors.New("close error") + } + + if err := refMap.Store("testKey", 42, closeFunc); err != nil { + t.Errorf("unexpected error on Store: %v", err) + } + + // initial store should increment reference count, so this should trigger cleanup + // the closeFunc then throws an error, but the key should still be deleted + if err := refMap.RemoveRef("testKey"); err == nil { + t.Error("expected error when removing reference from a key with cleanup error") + } + + if !cleanupCalled.Load() { + t.Error("expected cleanup function to be called even if it returns an error") + } + + // Check that key no longer exists + _, ok := refMap.Load("testKey") + if ok { + t.Error("expected key 'testKey' to be deleted after reference count reached zero") + } +} + +// Test to ensure references are counted correctly +func TestRefMapReferenceCounting(t *testing.T) { + refMap := NewRefMap[string, int]() + closeFunc := func(value int) error { return nil } + + if err := refMap.Store("testKey", 100, closeFunc); err != nil { + t.Errorf("unexpected error on Store: %v", err) + } + + if err := refMap.AddRef("testKey"); err != nil { + t.Errorf("unexpected error on first AddRef: %v", err) + } + + if err := refMap.AddRef("testKey"); err != nil { + t.Errorf("unexpected error on second AddRef: %v", err) + } + + val, ok := refMap.Load("testKey") + if !ok || val != 100 { + t.Errorf("expected to load value 100 for key 'testKey', got %v", val) + } + + // Remove references one by one, expecting the item to persist until the last removal + if err := refMap.RemoveRef("testKey"); err != nil { + t.Errorf("unexpected error on RemoveRef: %v", err) + } + + if _, ok := refMap.Load("testKey"); !ok { + t.Error("expected key 'testKey' to still exist after one RemoveRef") + } + + if err := refMap.RemoveRef("testKey"); err != nil { + t.Errorf("unexpected error on second RemoveRef: %v", err) + } + + if _, ok := refMap.Load("testKey"); !ok { + t.Error("expected key 'testKey' to still exist after second RemoveRef") + } + + // Final removal should delete the key + if err := refMap.RemoveRef("testKey"); err != nil { + t.Errorf("unexpected error on final RemoveRef: %v", err) + } + + if _, ok := refMap.Load("testKey"); ok { + t.Error("expected key 'testKey' to be deleted after all references were removed") + } +} + +// Test to check if cleanup function handles delay in cleanup process. +func TestRefMapDelayedCleanup(t *testing.T) { + refMap := NewRefMap[string, int]() + cleanupCalled := atomic.Bool{} + closeFunc := func(value int) error { + time.Sleep(50 * time.Millisecond) + cleanupCalled.Store(true) + return nil + } + + if err := refMap.Store("testKey", 100, closeFunc); err != nil { + t.Errorf("unexpected error on Store: %v", err) + } + + if err := refMap.RemoveRef("testKey"); err != nil { + t.Errorf("unexpected error on RemoveRef: %v", err) + } + + if !cleanupCalled.Load() { + t.Error("expected cleanup function to be called after delay") + } + + // Verify key deletion after cleanup completes + _, ok := refMap.Load("testKey") + if ok { + t.Error("expected key 'testKey' to be deleted after cleanup") + } +}