diff --git a/cmd/hyperfleet-api/migrate/cmd.go b/cmd/hyperfleet-api/migrate/cmd.go index 6f1dc97..56bfee0 100755 --- a/cmd/hyperfleet-api/migrate/cmd.go +++ b/cmd/hyperfleet-api/migrate/cmd.go @@ -58,11 +58,11 @@ func runMigrateWithError(ctx context.Context, dbConfig *config.DatabaseConfig) e } }() - if err := db.Migrate(connection.New(ctx)); err != nil { + // Use MigrateWithLock to prevent concurrent migrations from multiple pods + if err := db.MigrateWithLock(ctx, connection); err != nil { logger.WithError(ctx, err).Error("Migration failed") return err } - logger.Info(ctx, "Migration completed successfully") return nil } diff --git a/docs/database.md b/docs/database.md index 9428aef..24218c9 100644 --- a/docs/database.md +++ b/docs/database.md @@ -61,6 +61,48 @@ Uses GORM AutoMigrate: - Additive (creates missing tables, columns, indexes) - Run via `./bin/hyperfleet-api migrate` +### Migration Coordination + +**Problem:** During rolling deployments, multiple pods attempt to run migrations simultaneously, causing race conditions and deployment failures. + +**Solution:** PostgreSQL advisory locks ensure exclusive migration execution. + +#### How It Works + +```go +// Only one pod/process acquires the lock and runs migrations +// Others wait until the lock is released +db.MigrateWithLock(ctx, factory) +``` + +**Implementation:** +1. Pod sets statement timeout (5 minutes) to prevent indefinite blocking +2. Pod acquires advisory lock via `pg_advisory_xact_lock(hash("migrations"), hash("Migrations"))` +3. Lock holder runs migrations exclusively +4. Other pods block until lock is released or timeout is reached +5. 
Lock automatically released on transaction commit + +**Key Features:** +- **Zero infrastructure overhead** - Uses native PostgreSQL locks +- **Automatic cleanup** - Locks released on transaction end or pod crash +- **Timeout protection** - 5-minute timeout prevents indefinite blocking if a pod hangs +- **Nested lock support** - Same lock can be acquired in nested contexts without deadlock +- **UUID-based ownership** - Only original acquirer can unlock + +#### Testing Concurrent Migrations + +Integration tests validate concurrent behavior: + +```bash +make test-integration # Runs TestConcurrentMigrations +``` + +**Test coverage:** +- `TestConcurrentMigrations` - Multiple pods running migrations simultaneously +- `TestAdvisoryLocksConcurrently` - Lock serialization under race conditions +- `TestAdvisoryLocksWithTransactions` - Lock + transaction interaction +- `TestAdvisoryLockBlocking` - Lock blocking behavior + ## Database Setup ```bash diff --git a/pkg/config/db.go b/pkg/config/db.go index 505cdc4..45975d0 100755 --- a/pkg/config/db.go +++ b/pkg/config/db.go @@ -43,13 +43,14 @@ type SSLConfig struct { // PoolConfig holds connection pool configuration // Includes fields from HYPERFLEET-694 for connection lifecycle management type PoolConfig struct { - MaxConnections int `mapstructure:"max_connections" json:"max_connections" validate:"required,min=1,max=200"` - MaxIdleConnections int `mapstructure:"max_idle_connections" json:"max_idle_connections" validate:"min=0"` - ConnMaxLifetime time.Duration `mapstructure:"conn_max_lifetime" json:"conn_max_lifetime"` - ConnMaxIdleTime time.Duration `mapstructure:"conn_max_idle_time" json:"conn_max_idle_time"` - RequestTimeout time.Duration `mapstructure:"request_timeout" json:"request_timeout"` - ConnRetryAttempts int `mapstructure:"conn_retry_attempts" json:"conn_retry_attempts" validate:"min=1"` - ConnRetryInterval time.Duration `mapstructure:"conn_retry_interval" json:"conn_retry_interval"` + MaxConnections int 
`mapstructure:"max_connections" json:"max_connections" validate:"required,min=1,max=200"` + MaxIdleConnections int `mapstructure:"max_idle_connections" json:"max_idle_connections" validate:"min=0"` + ConnMaxLifetime time.Duration `mapstructure:"conn_max_lifetime" json:"conn_max_lifetime"` + ConnMaxIdleTime time.Duration `mapstructure:"conn_max_idle_time" json:"conn_max_idle_time"` + RequestTimeout time.Duration `mapstructure:"request_timeout" json:"request_timeout"` + ConnRetryAttempts int `mapstructure:"conn_retry_attempts" json:"conn_retry_attempts" validate:"min=1"` + ConnRetryInterval time.Duration `mapstructure:"conn_retry_interval" json:"conn_retry_interval"` + AdvisoryLockTimeout time.Duration `mapstructure:"advisory_lock_timeout" json:"advisory_lock_timeout"` // HYPERFLEET-618: prevents indefinite blocking during migrations } // MarshalJSON implements custom JSON marshaling to redact sensitive fields @@ -91,13 +92,14 @@ func NewDatabaseConfig() *DatabaseConfig { RootCertFile: "", }, Pool: PoolConfig{ - MaxConnections: 50, - MaxIdleConnections: 10, - ConnMaxLifetime: 5 * time.Minute, - ConnMaxIdleTime: 1 * time.Minute, - RequestTimeout: 30 * time.Second, - ConnRetryAttempts: 10, - ConnRetryInterval: 3 * time.Second, + MaxConnections: 50, + MaxIdleConnections: 10, + ConnMaxLifetime: 5 * time.Minute, + ConnMaxIdleTime: 1 * time.Minute, + RequestTimeout: 30 * time.Second, + ConnRetryAttempts: 10, + ConnRetryInterval: 3 * time.Second, + AdvisoryLockTimeout: 5 * time.Minute, // HYPERFLEET-618: prevents indefinite blocking during migrations }, } } diff --git a/pkg/db/advisory_locks.go b/pkg/db/advisory_locks.go new file mode 100644 index 0000000..d43262c --- /dev/null +++ b/pkg/db/advisory_locks.go @@ -0,0 +1,122 @@ +package db + +import ( + "context" + "errors" + "fmt" + "hash/fnv" + "time" + + "github.com/openshift-hyperfleet/hyperfleet-api/pkg/logger" + "gorm.io/gorm" +) + +// LockType represents the type of advisory lock +type LockType string + +const ( + // 
Migrations lock type for database migrations + Migrations LockType = "Migrations" + + // MigrationsLockID is the advisory lock ID used for migration coordination + MigrationsLockID = "migrations" +) + +// AdvisoryLock represents a postgres advisory lock +// +// begin # start a Tx +// select pg_advisory_xact_lock(id, lockType) # obtain the lock (blocking) +// end # end the Tx and release the lock +// +// ownerUUID is a way to own the lock. Only the very first +// service call that owns the lock will have the correct ownerUUID. This is necessary +// to allow functions to call other service functions as part of the same lock (id, lockType). +type AdvisoryLock struct { + g2 *gorm.DB + ownerUUID *string + id *string + lockType *LockType + timeoutSeconds int + startTime time.Time +} + +// newAdvisoryLock constructs a new AdvisoryLock object. +func newAdvisoryLock(ctx context.Context, connection SessionFactory, ownerUUID *string, id *string, locktype *LockType) (*AdvisoryLock, error) { + if connection == nil { + return nil, errors.New("AdvisoryLock: connection factory is missing") + } + + // it requires a new DB session to start the advisory lock. + g2 := connection.New(ctx) + + // start a Tx to ensure gorm will obtain/release the lock using a same connection. + tx := g2.Begin() + if tx.Error != nil { + return nil, tx.Error + } + + return &AdvisoryLock{ + ownerUUID: ownerUUID, + id: id, + lockType: locktype, + timeoutSeconds: connection.GetAdvisoryLockTimeout(), + g2: tx, + startTime: time.Now(), + }, nil +} + +// lock calls select pg_advisory_xact_lock(id, lockType) to obtain the lock defined by (id, lockType). +// It blocks until the lock is acquired or the statement timeout is reached. +// The timeout prevents indefinite blocking if a pod hangs while holding the lock. 
+func (l *AdvisoryLock) lock() error { + if l.g2 == nil { + return errors.New("AdvisoryLock: transaction is missing") + } + if l.id == nil { + return errors.New("AdvisoryLock: id is missing") + } + if l.lockType == nil { + return errors.New("AdvisoryLock: lockType is missing") + } + + // Set statement timeout to prevent indefinite blocking. + // This is transaction-scoped (SET LOCAL), so it only affects this lock acquisition. + // Note: We cannot use parameter binding (?) for SET commands in PostgreSQL + timeoutMs := l.timeoutSeconds * 1000 + if err := l.g2.Exec(fmt.Sprintf("SET LOCAL statement_timeout = %d", timeoutMs)).Error; err != nil { + return err + } + + idAsInt := hash(*l.id) + typeAsInt := hash(string(*l.lockType)) + err := l.g2.Exec("select pg_advisory_xact_lock(?, ?)", idAsInt, typeAsInt).Error + return err +} + +func (l *AdvisoryLock) unlock(ctx context.Context) error { + if l.g2 == nil { + return errors.New("AdvisoryLock: transaction is missing") + } + + duration := time.Since(l.startTime) + + // it ends the Tx and implicitly releases the lock. + err := l.g2.Commit().Error + l.g2 = nil + + if err == nil { + logger.With(ctx, logger.FieldLockDurationMs, duration.Milliseconds()).Info("Released advisory lock") + } + + return err +} + +// hash string to int32 (postgres integer) +// https://pkg.go.dev/math#pkg-constants +// https://www.postgresql.org/docs/12/datatype-numeric.html +func hash(s string) int32 { + h := fnv.New32a() + h.Write([]byte(s)) + // Sum32() returns uint32. needs conversion. 
+ return int32(h.Sum32()) +} diff --git a/pkg/db/context.go b/pkg/db/context.go index 06c5114..c262bbf 100755 --- a/pkg/db/context.go +++ b/pkg/db/context.go @@ -3,10 +3,33 @@ package db import ( "context" + "github.com/google/uuid" + dbContext "github.com/openshift-hyperfleet/hyperfleet-api/pkg/db/db_context" "github.com/openshift-hyperfleet/hyperfleet-api/pkg/logger" ) +type advisoryLockKey string + +const ( + advisoryLock advisoryLockKey = "advisoryLock" +) + +type advisoryLockMap map[string]*AdvisoryLock + +func (m advisoryLockMap) key(id string, lockType LockType) string { + return id + ":" + string(lockType) +} + +func (m advisoryLockMap) get(id string, lockType LockType) (*AdvisoryLock, bool) { + lock, ok := m[m.key(id, lockType)] + return lock, ok +} + +func (m advisoryLockMap) set(id string, lockType LockType, lock *AdvisoryLock) { + m[m.key(id, lockType)] = lock +} + // NewContext returns a new context with transaction stored in it. // Upon error, the original context is still returned along with an error func NewContext(ctx context.Context, connection SessionFactory) (context.Context, error) { @@ -53,3 +76,81 @@ func MarkForRollback(ctx context.Context, err error) { transaction.SetRollbackFlag(true) logger.WithError(ctx, err).Info("Marked transaction for rollback") } + +// NewAdvisoryLockContext returns a new context with AdvisoryLock stored in it. +// Upon error, the original context is still returned along with an error. +// +// CONCURRENCY: The returned context must not be shared across goroutines that call +// NewAdvisoryLockContext or Unlock concurrently, as the internal lock map is not +// protected by a mutex. Each goroutine should derive its own context chain. +func NewAdvisoryLockContext(ctx context.Context, connection SessionFactory, id string, lockType LockType) (context.Context, string, error) { + // lockOwnerID will be different for every service function that attempts to start a lock. + // only the initial call in the stack must unlock. 
+ // Unlock() will compare UUIDs and ensure only the top level call succeeds. + lockOwnerID := uuid.New().String() + + locks, found := ctx.Value(advisoryLock).(advisoryLockMap) + if found { + if _, ok := locks.get(id, lockType); ok { + return ctx, lockOwnerID, nil + } + } else { + locks = make(advisoryLockMap) + } + + lock, err := newAdvisoryLock(ctx, connection, &lockOwnerID, &id, &lockType) + if err != nil { + logger.WithError(ctx, err).Error("Failed to create advisory lock") + return ctx, lockOwnerID, err + } + + // obtain the advisory lock (blocking) + err = lock.lock() + if err != nil { + logger.WithError(ctx, err).Error("Failed to acquire advisory lock") + lock.g2.Rollback() // clean up the open transaction + return ctx, lockOwnerID, err + } + + locks.set(id, lockType, lock) + + ctx = context.WithValue(ctx, advisoryLock, locks) + logger.With(ctx, logger.FieldLockID, id, logger.FieldLockType, lockType).Info("Acquired advisory lock") + + return ctx, lockOwnerID, nil +} + +// Unlock searches current locks and unlocks the one matching its owner id. +func Unlock(ctx context.Context, callerUUID string) { + locks, ok := ctx.Value(advisoryLock).(advisoryLockMap) + if !ok { + logger.Error(ctx, "Could not retrieve locks from context") + return + } + + for k, lock := range locks { + if lock.ownerUUID == nil { + logger.With(ctx, logger.FieldLockID, lock.id).Warn("lockOwnerID could not be found in AdvisoryLock") + } else if *lock.ownerUUID == callerUUID { + lockID := "" + lockType := LockType("") + + if lock.id != nil { + lockID = *lock.id + } + if lock.lockType != nil { + lockType = *lock.lockType + } + + if err := lock.unlock(ctx); err != nil { + logger.With(ctx, logger.FieldLockID, lockID, logger.FieldLockType, lockType). 
+ WithError(err).Error("Could not unlock lock") + continue + } + logger.With(ctx, logger.FieldLockID, lockID, logger.FieldLockType, lockType).Info("Unlocked lock") + delete(locks, k) + } + // Note: if ownerUUID doesn't match callerUUID, the lock belongs to a different + // service call and is intentionally not unlocked here + } +} diff --git a/pkg/db/db_session/default.go b/pkg/db/db_session/default.go index 5736339..9c7ee8a 100755 --- a/pkg/db/db_session/default.go +++ b/pkg/db/db_session/default.go @@ -229,3 +229,7 @@ func (f *Default) ReconfigureLogger(level gormlogger.LogLevel) { newLogger := logger.NewGormLogger(level, slowQueryThreshold) f.g2.Logger = newLogger } + +func (f *Default) GetAdvisoryLockTimeout() int { + return int(f.config.Pool.AdvisoryLockTimeout.Seconds()) +} diff --git a/pkg/db/db_session/test.go b/pkg/db/db_session/test.go index 0bcc286..d9b0be1 100755 --- a/pkg/db/db_session/test.go +++ b/pkg/db/db_session/test.go @@ -232,3 +232,7 @@ func (f *Test) ReconfigureLogger(level gormlogger.LogLevel) { newLogger := logger.NewGormLogger(level, slowQueryThreshold) f.g2.Logger = newLogger } + +func (f *Test) GetAdvisoryLockTimeout() int { + return int(f.config.Pool.AdvisoryLockTimeout.Seconds()) +} diff --git a/pkg/db/db_session/testcontainer.go b/pkg/db/db_session/testcontainer.go index a70b8f7..ffbf26a 100755 --- a/pkg/db/db_session/testcontainer.go +++ b/pkg/db/db_session/testcontainer.go @@ -224,3 +224,7 @@ func (f *Testcontainer) ReconfigureLogger(level gormlogger.LogLevel) { newLogger := logger.NewGormLogger(level, slowQueryThreshold) f.g2.Logger = newLogger } + +func (f *Testcontainer) GetAdvisoryLockTimeout() int { + return int(f.config.Pool.AdvisoryLockTimeout.Seconds()) +} diff --git a/pkg/db/migrations.go b/pkg/db/migrations.go index 63fa6e0..5d647a7 100755 --- a/pkg/db/migrations.go +++ b/pkg/db/migrations.go @@ -24,6 +24,27 @@ func Migrate(g2 *gorm.DB) error { return nil } +// MigrateWithLock runs migrations with an advisory lock to 
prevent concurrent migrations +func MigrateWithLock(ctx context.Context, factory SessionFactory) error { + // Acquire advisory lock for migrations + ctx, lockOwnerID, err := NewAdvisoryLockContext(ctx, factory, MigrationsLockID, Migrations) + if err != nil { + logger.WithError(ctx, err).Error("Could not lock migrations") + return err + } + defer Unlock(ctx, lockOwnerID) + + // Run migrations with the locked context + g2 := factory.New(ctx) + if err := Migrate(g2); err != nil { + logger.WithError(ctx, err).Error("Could not migrate") + return err + } + + logger.Info(ctx, "Migration completed successfully") + return nil +} + // MigrateTo a specific migration will not seed the database, seeds are up to date with the latest // schema based on the most recent migration // This should be for testing purposes mainly diff --git a/pkg/db/mocks/session_factory.go b/pkg/db/mocks/session_factory.go index 18f8396..f0a81d8 100755 --- a/pkg/db/mocks/session_factory.go +++ b/pkg/db/mocks/session_factory.go @@ -77,3 +77,7 @@ func (m *MockSessionFactory) ResetDB() { func (m *MockSessionFactory) NewListener(ctx context.Context, channel string, callback func(id string)) { // Mock implementation - does nothing } + +func (m *MockSessionFactory) GetAdvisoryLockTimeout() int { + return int(config.NewDatabaseConfig().Pool.AdvisoryLockTimeout.Seconds()) +} diff --git a/pkg/db/session.go b/pkg/db/session.go index a124751..5d9dbbb 100755 --- a/pkg/db/session.go +++ b/pkg/db/session.go @@ -17,4 +17,5 @@ type SessionFactory interface { Close() error ResetDB() NewListener(ctx context.Context, channel string, callback func(id string)) + GetAdvisoryLockTimeout() int } diff --git a/pkg/logger/fields.go b/pkg/logger/fields.go index 8c5e5a0..8c0d45f 100644 --- a/pkg/logger/fields.go +++ b/pkg/logger/fields.go @@ -31,6 +31,9 @@ const ( FieldConnectionString = "connection_string" FieldTable = "table" FieldChannel = "channel" + FieldLockID = "lock_id" + FieldLockType = "lock_type" + FieldLockDurationMs = 
"lock_duration_ms" // Note: transaction_id is a context field (see context.go) ) diff --git a/test/integration/advisory_locks_test.go b/test/integration/advisory_locks_test.go new file mode 100644 index 0000000..bf85538 --- /dev/null +++ b/test/integration/advisory_locks_test.go @@ -0,0 +1,550 @@ +package integration + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + "testing" + "time" + + . "github.com/onsi/gomega" + "gorm.io/gorm" + + "github.com/openshift-hyperfleet/hyperfleet-api/pkg/db" + "github.com/openshift-hyperfleet/hyperfleet-api/test" +) + +// TestAdvisoryLocksConcurrently validates that advisory locks properly serialize +// concurrent access to shared resources. This test uses actual database operations +// to prove the lock prevents race conditions at the database level. +func TestAdvisoryLocksConcurrently(t *testing.T) { + h, _ := test.RegisterIntegration(t) + + // Create a counter table and initialize to 0 + g2 := h.DBFactory.New(context.Background()) + err := g2.Exec("CREATE TABLE IF NOT EXISTS lock_test_counter (id INTEGER PRIMARY KEY, value INTEGER)").Error + Expect(err).NotTo(HaveOccurred(), "Failed to create counter table") + + err = g2.Exec("INSERT INTO lock_test_counter (id, value) VALUES (1, 0)").Error + Expect(err).NotTo(HaveOccurred(), "Failed to initialize counter") + defer g2.Exec("DROP TABLE IF EXISTS lock_test_counter") + + total := 10 + var waiter sync.WaitGroup + waiter.Add(total) + + // Simulate a race condition where multiple threads are trying to access and modify the counter. + // The acquireLock func uses an advisory lock so the accesses should be properly serialized. 
for i := 0; i < total; i++ { + go acquireLock(h, &waiter) + } + + // Wait for all goroutines to complete + waiter.Wait() + + // All goroutines should have incremented the counter by 1, resulting in 10 + var finalValue int + err = g2.Raw("SELECT value FROM lock_test_counter WHERE id = 1").Scan(&finalValue).Error + Expect(err).NotTo(HaveOccurred(), "Failed to read final counter value") + Expect(finalValue).To(Equal(total), "Counter should equal total") +} + +func acquireLock(h *test.Helper, waiter *sync.WaitGroup) { + defer waiter.Done() + + ctx := context.Background() + + // Acquire advisory lock + ctx, lockOwnerID, err := db.NewAdvisoryLockContext(ctx, h.DBFactory, "test-resource", db.Migrations) + Expect(err).NotTo(HaveOccurred(), "Failed to acquire lock") + defer db.Unlock(ctx, lockOwnerID) + + g2 := h.DBFactory.New(ctx) + + // Read current value from database + var currentValue int + err = g2.Raw("SELECT value FROM lock_test_counter WHERE id = 1").Scan(&currentValue).Error + Expect(err).NotTo(HaveOccurred(), "Failed to read counter") + + // Some slow work to increase the likelihood of race conditions + time.Sleep(20 * time.Millisecond) + + // Increment and save to database + newValue := currentValue + 1 + err = g2.Exec("UPDATE lock_test_counter SET value = ? WHERE id = 1", newValue).Error + Expect(err).NotTo(HaveOccurred(), "Failed to update counter") +} + +// TestAdvisoryLocksWithTransactions validates that advisory locks work correctly +// when combined with database transactions in various orders. Uses actual database +// operations to prove serialization. 
+func TestAdvisoryLocksWithTransactions(t *testing.T) { + h, _ := test.RegisterIntegration(t) + + // Create a counter table and initialize to 0 + g2 := h.DBFactory.New(context.Background()) + err := g2.Exec("CREATE TABLE IF NOT EXISTS lock_test_counter_tx (id INTEGER PRIMARY KEY, value INTEGER)").Error + Expect(err).NotTo(HaveOccurred(), "Failed to create counter table") + + err = g2.Exec("INSERT INTO lock_test_counter_tx (id, value) VALUES (1, 0)").Error + Expect(err).NotTo(HaveOccurred(), "Failed to initialize counter") + defer g2.Exec("DROP TABLE IF EXISTS lock_test_counter_tx") + + // Test all three transaction ordering scenarios deterministically + testCases := []struct { + name string + txBeforeLock bool + txAfterLock bool + }{ + { + name: "tx_before_lock", + txBeforeLock: true, + txAfterLock: false, + }, + { + name: "tx_after_lock", + txBeforeLock: false, + txAfterLock: true, + }, + { + name: "no_tx", + txBeforeLock: false, + txAfterLock: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Run multiple goroutines for each scenario to test concurrency + goroutines := 3 + var waiter sync.WaitGroup + waiter.Add(goroutines) + + for i := 0; i < goroutines; i++ { + go acquireLockWithTransaction(h, &waiter, tc.txBeforeLock, tc.txAfterLock) + } + + waiter.Wait() + }) + } + + // All test cases combined should have incremented the counter by 9 (3 scenarios × 3 goroutines) + expectedTotal := 9 + var finalValue int + err = g2.Raw("SELECT value FROM lock_test_counter_tx WHERE id = 1").Scan(&finalValue).Error + Expect(err).NotTo(HaveOccurred(), "Failed to read final counter value") + Expect(finalValue).To(Equal(expectedTotal), "Counter should equal total") +} + +func acquireLockWithTransaction(h *test.Helper, waiter *sync.WaitGroup, txBeforeLock bool, txAfterLock bool) { + defer waiter.Done() + + ctx := context.Background() + + var dberr error + + // Add Tx before lock if requested + if txBeforeLock { + ctx, dberr = 
db.NewContext(ctx, h.DBFactory) + Expect(dberr).NotTo(HaveOccurred(), "Failed to create transaction context") + defer db.Resolve(ctx) + } + + // Acquire advisory lock + ctx, lockOwnerID, dberr := db.NewAdvisoryLockContext(ctx, h.DBFactory, "test-resource-tx", db.Migrations) + Expect(dberr).NotTo(HaveOccurred(), "Failed to acquire lock") + defer db.Unlock(ctx, lockOwnerID) + + // Add Tx after lock if requested + if txAfterLock { + ctx, dberr = db.NewContext(ctx, h.DBFactory) + Expect(dberr).NotTo(HaveOccurred(), "Failed to create transaction context") + defer db.Resolve(ctx) + } + + g2 := h.DBFactory.New(ctx) + + // Read current value from database + var currentValue int + err := g2.Raw("SELECT value FROM lock_test_counter_tx WHERE id = 1").Scan(&currentValue).Error + Expect(err).NotTo(HaveOccurred(), "Failed to read counter") + + // Some slow work + time.Sleep(20 * time.Millisecond) + + // Increment and save to database + newValue := currentValue + 1 + err = g2.Exec("UPDATE lock_test_counter_tx SET value = ? 
WHERE id = 1", newValue).Error + Expect(err).NotTo(HaveOccurred(), "Failed to update counter") +} + +// TestLocksAndExpectedWaits validates the behavior of advisory locks: +// - Nested locks with the same (id, lockType) should not create additional locks +// - Different (id, lockType) combinations should create separate locks +// - Unlocking should only affect the lock matching the owner ID +func TestLocksAndExpectedWaits(t *testing.T) { + h, _ := test.RegisterIntegration(t) + + // Start lock + ctx := context.Background() + ctx, lockOwnerID, err := db.NewAdvisoryLockContext(ctx, h.DBFactory, "system", db.Migrations) + Expect(err).NotTo(HaveOccurred(), "Failed to acquire lock") + defer db.Unlock(ctx, lockOwnerID) // Ensure lock is released on test exit + + // It should have 1 lock + g2 := h.DBFactory.New(ctx) + var pgLocks []struct{ Granted bool } + _ = g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks).Error + Expect(len(pgLocks)).To(Equal(1), "Expected 1 lock") + + // Successive locking should have no effect (nested lock with same id/type) + // Pretend this runs in a nested func + ctx, lockOwnerID2, err := db.NewAdvisoryLockContext(ctx, h.DBFactory, "system", db.Migrations) + Expect(err).NotTo(HaveOccurred(), "Failed to acquire nested lock") + defer db.Unlock(ctx, lockOwnerID2) // Ensure lock is released on test exit + + // It should still have 1 lock + pgLocks = nil + _ = g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks).Error + Expect(len(pgLocks)).To(Equal(1), "Expected 1 lock after nested acquire") + + // Unlock should have no effect either (unlocking nested lock) + // Pretend this runs in the nested func + db.Unlock(ctx, lockOwnerID2) + // It should still have 1 lock + pgLocks = nil + _ = g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks).Error + Expect(len(pgLocks)).To(Equal(1), "Expected 1 lock after nested 
unlock") + + // Lock on a different (id, lockType) should work + // Pretend this runs in a nested func + ctx, lockOwnerID3, err := db.NewAdvisoryLockContext(ctx, h.DBFactory, "diff_system", db.Migrations) + Expect(err).NotTo(HaveOccurred(), "Failed to acquire different lock") + defer db.Unlock(ctx, lockOwnerID3) // Ensure lock is released on test exit + + // It should have 2 locks + pgLocks = nil + _ = g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks).Error + Expect(len(pgLocks)).To(Equal(2), "Expected 2 locks") + + // Pretend it releases the new lock in the nested func + db.Unlock(ctx, lockOwnerID3) + // It should have 1 lock + pgLocks = nil + _ = g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks).Error + Expect(len(pgLocks)).To(Equal(1), "Expected 1 lock after releasing different lock") + + // Unlock the topmost lock + // Pretend it returns back to the parent func + db.Unlock(ctx, lockOwnerID) + // The lock should be gone + pgLocks = nil + _ = g2.Raw("select granted from pg_locks WHERE locktype = 'advisory' and granted = true").Scan(&pgLocks).Error + Expect(len(pgLocks)).To(Equal(0), "Expected 0 locks after final unlock") +} + +// TestConcurrentMigrations validates that the MigrateWithLock function +// properly serializes concurrent migration attempts, ensuring only one +// instance actually runs migrations at a time. 
+func TestConcurrentMigrations(t *testing.T) { + h, _ := test.RegisterIntegration(t) + + // First, reset the database to a clean state + err := h.ResetDB() + Expect(err).NotTo(HaveOccurred(), "Failed to reset database") + + total := 5 + var waiter sync.WaitGroup + waiter.Add(total) + + // Track which goroutines successfully acquired the lock + var successCount int + var mu sync.Mutex + errors := make([]error, 0) + + // Simulate multiple pods trying to run migrations concurrently + for i := 0; i < total; i++ { + go func() { + defer waiter.Done() + + ctx := context.Background() + err := db.MigrateWithLock(ctx, h.DBFactory) + + mu.Lock() + defer mu.Unlock() + + if err != nil { + errors = append(errors, err) + } else { + successCount++ + } + }() + } + + waiter.Wait() + + // All migrations should succeed (they're idempotent) + Expect(errors).To(BeEmpty(), "Expected no errors during concurrent migrations") + + // All goroutines should complete successfully + Expect(successCount).To(Equal(total), "All migrations should succeed") +} + +// TestAdvisoryLockBlocking validates that a second goroutine trying to acquire +// the same lock will block until the first goroutine releases it. 
+func TestAdvisoryLockBlocking(t *testing.T) { + h, _ := test.RegisterIntegration(t) + + ctx := context.Background() + + // First goroutine acquires the lock + ctx1, lockOwnerID1, err := db.NewAdvisoryLockContext(ctx, h.DBFactory, "blocking-test", db.Migrations) + Expect(err).NotTo(HaveOccurred(), "Failed to acquire first lock") + defer db.Unlock(ctx1, lockOwnerID1) // Ensure lock is released on test exit + + // Track when the second goroutine acquires the lock + acquired := make(chan bool, 1) + released := make(chan bool, 1) + defer close(released) // ensure goroutine exits even on timeout + + // Second goroutine tries to acquire the same lock + go func() { + ctx2, lockOwnerID2, err := db.NewAdvisoryLockContext( + context.Background(), h.DBFactory, "blocking-test", db.Migrations) + Expect(err).NotTo(HaveOccurred(), "Failed to acquire second lock") + defer db.Unlock(ctx2, lockOwnerID2) + + acquired <- true + <-released // Wait for signal to release + }() + + // Wait for the second goroutine to be actively waiting on the lock + // by polling pg_locks for a non-granted advisory lock. + // This is more reliable than sleep, especially in slow CI environments. 
+ g2 := h.DBFactory.New(ctx) + waitingForLock := false + for i := 0; i < 50; i++ { // Poll for up to 5 seconds (50 * 100ms) + var waitingLocks []struct{ Granted bool } + query := "SELECT granted FROM pg_locks WHERE locktype = 'advisory' AND granted = false" + err := g2.Raw(query).Scan(&waitingLocks).Error + Expect(err).NotTo(HaveOccurred(), "Failed to query pg_locks") + + if len(waitingLocks) > 0 { + waitingForLock = true + break + } + time.Sleep(100 * time.Millisecond) + } + + Expect(waitingForLock).To(BeTrue(), "Second goroutine should be waiting for lock") + + // The second goroutine should still be blocked + select { + case <-acquired: + t.Error("Second goroutine acquired lock while first still holds it") + default: + // Expected: second goroutine is still blocked + } + + // Release the first lock + db.Unlock(ctx1, lockOwnerID1) + + // Now the second goroutine should acquire the lock + select { + case <-acquired: + // Expected: second goroutine acquired the lock + released <- true + case <-time.After(5 * time.Second): + t.Error("Second goroutine did not acquire lock after first was released") + } +} + +// TestAdvisoryLockContextCancellation verifies that context cancellation properly +// terminates a waiting advisory lock acquisition. The context is passed through +// connection.New(ctx) and affects the blocking pg_advisory_xact_lock SQL call. 
+func TestAdvisoryLockContextCancellation(t *testing.T) { + h, _ := test.RegisterIntegration(t) + + ctx := context.Background() + + // First goroutine acquires the lock + ctx1, lockOwnerID1, err := db.NewAdvisoryLockContext(ctx, h.DBFactory, "cancel-test", db.Migrations) + Expect(err).NotTo(HaveOccurred(), "Failed to acquire first lock") + defer db.Unlock(ctx1, lockOwnerID1) + + // Track when the second goroutine gets cancelled + gotCancelError := make(chan bool, 1) + + // Create a cancellable context for the second goroutine + ctx2, cancel := context.WithCancel(context.Background()) + + // Use WaitGroup to ensure goroutine exits before test cleanup + var wg sync.WaitGroup + wg.Add(1) + + // Second goroutine tries to acquire the same lock with cancellable context + go func() { + defer wg.Done() + _, _, err := db.NewAdvisoryLockContext(ctx2, h.DBFactory, "cancel-test", db.Migrations) + if err != nil { + // Check if this is a cancellation-type error + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || + strings.Contains(err.Error(), "canceling statement due to user request") { + // Expected: context cancellation causes proper cancellation error + gotCancelError <- true + return + } + // Unexpected error - fail the test + t.Errorf("Unexpected error from lock acquisition: %v", err) + return + } + t.Error("Second goroutine acquired lock despite context cancellation (unexpected)") + }() + + // Wait for the second goroutine to be actively waiting on the lock + g2 := h.DBFactory.New(ctx) + waitingForLock := false + for i := 0; i < 50; i++ { + var waitingLocks []struct{ Granted bool } + query := "SELECT granted FROM pg_locks WHERE locktype = 'advisory' AND granted = false" + err := g2.Raw(query).Scan(&waitingLocks).Error + Expect(err).NotTo(HaveOccurred(), "Failed to query pg_locks") + + if len(waitingLocks) > 0 { + waitingForLock = true + break + } + time.Sleep(100 * time.Millisecond) + } + + Expect(waitingForLock).To(BeTrue(), "Second 
goroutine should be waiting for lock") + + // Cancel the context while the second goroutine is waiting + cancel() + + // The second goroutine should exit with a cancellation error + select { + case <-gotCancelError: + // Expected: context cancellation terminates the lock acquisition + case <-time.After(2 * time.Second): + t.Error("Second goroutine did not exit after context cancellation within timeout") + } + + // Ensure goroutine exits before test cleanup + wg.Wait() +} + +// migrateWithLockAndCustomMigration mimics db.MigrateWithLock but accepts a custom migration function +// This allows testing the lock acquisition/release pattern with controlled success/failure +func migrateWithLockAndCustomMigration( + ctx context.Context, + factory db.SessionFactory, + migrationFunc func(*gorm.DB) error, +) error { + // Acquire advisory lock for migrations (same pattern as production MigrateWithLock) + ctx, lockOwnerID, err := db.NewAdvisoryLockContext(ctx, factory, db.MigrationsLockID, db.Migrations) + if err != nil { + return err + } + defer db.Unlock(ctx, lockOwnerID) + + // Run custom migration with the locked context + g2 := factory.New(ctx) + if err := migrationFunc(g2); err != nil { + return err + } + + return nil +} + +// TestMigrationFailureUnderLock validates that when a migration fails while holding +// the advisory lock, the lock is properly released via defer, allowing other waiters +// to proceed. This tests the error path and cleanup behavior of the MigrateWithLock pattern. 
+func TestMigrationFailureUnderLock(t *testing.T) { + h, _ := test.RegisterIntegration(t) + + // Reset database to clean state + err := h.ResetDB() + Expect(err).NotTo(HaveOccurred(), "Failed to reset database") + + // Channels to coordinate goroutines + firstLockAcquired := make(chan bool, 1) + firstMigrationFailed := make(chan bool, 1) + secondCanProceed := make(chan bool, 1) + + // Track results + var mu sync.Mutex + successCount := 0 + failureCount := 0 + var wg sync.WaitGroup + + // Create a failing migration function that signals when it acquires lock and fails + failingMigration := func(_ *gorm.DB) error { + firstLockAcquired <- true + // Wait a bit to ensure second goroutine tries to acquire + time.Sleep(50 * time.Millisecond) + return fmt.Errorf("simulated migration failure") + } + + // Create a successful migration function + successfulMigration := func(_ *gorm.DB) error { + return nil + } + + // First goroutine: acquire lock and fail migration using production code path + wg.Add(1) + go func() { + defer wg.Done() + + ctx := context.Background() + err := migrateWithLockAndCustomMigration(ctx, h.DBFactory, failingMigration) + + mu.Lock() + if err != nil { + failureCount++ + } + mu.Unlock() + + firstMigrationFailed <- true + // Lock should be released via defer even though migration failed + }() + + // Wait for first goroutine to acquire lock + <-firstLockAcquired + + // Second goroutine: should block until first releases lock, then succeed + wg.Add(1) + go func() { + defer wg.Done() + + ctx := context.Background() + err := migrateWithLockAndCustomMigration(ctx, h.DBFactory, successfulMigration) + + mu.Lock() + if err == nil { + successCount++ + } + mu.Unlock() + + secondCanProceed <- true + }() + + // Wait for first migration to fail and release lock + <-firstMigrationFailed + + // Wait for second migration to complete + select { + case <-secondCanProceed: + // Expected: second goroutine acquired lock after first released it + case <-time.After(3 * 
time.Second): + t.Error("Second goroutine did not acquire lock after first failed") + } + + wg.Wait() + + // Verify both completed as expected + Expect(failureCount).To(Equal(1), "Expected 1 failure") + Expect(successCount).To(Equal(1), "Expected 1 success") +}