Skip to content

Commit

Permalink
Add provider CheckPending method (#756)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfridman authored Apr 25, 2024
1 parent 31de74d commit 6a5697e
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 17 deletions.
51 changes: 39 additions & 12 deletions internal/testing/integration/postgres_locking_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,30 +410,37 @@ func TestPostgresProviderLocking(t *testing.T) {
})
}

func TestPostgresHasPending(t *testing.T) {
func TestPostgresPending(t *testing.T) {
t.Parallel()
if testing.Short() {
t.Skip("skipping test in short mode.")
}
const testDir = "testdata/migrations/postgres"

db, cleanup, err := testdb.NewPostgres()
require.NoError(t, err)
t.Cleanup(cleanup)

files, err := os.ReadDir(testDir)
require.NoError(t, err)

workers := 15

run := func(want bool) {
run := func(t *testing.T, want bool, wantCurrent, wantTarget int) {
t.Helper()
var g errgroup.Group
boolCh := make(chan bool, workers)
for i := 0; i < workers; i++ {
g.Go(func() error {
p, err := goose.NewProvider(goose.DialectPostgres, db, os.DirFS("testdata/migrations/postgres"))
p, err := goose.NewProvider(goose.DialectPostgres, db, os.DirFS(testDir))
check.NoError(t, err)
hasPending, err := p.HasPending(context.Background())
if err != nil {
return err
}
check.NoError(t, err)
boolCh <- hasPending
current, target, err := p.CheckPending(context.Background())
check.NoError(t, err)
check.Number(t, current, int64(wantCurrent))
check.Number(t, target, int64(wantTarget))
return nil

})
Expand All @@ -446,7 +453,7 @@ func TestPostgresHasPending(t *testing.T) {
}
}
t.Run("concurrent_has_pending", func(t *testing.T) {
run(true)
run(t, true, 0, len(files))
})

// apply all migrations
Expand All @@ -456,12 +463,12 @@ func TestPostgresHasPending(t *testing.T) {
check.NoError(t, err)

t.Run("concurrent_no_pending", func(t *testing.T) {
run(false)
run(t, false, len(files), len(files))
})

// Add a new migration file
last := p.ListSources()[len(p.ListSources())-1]
newVersion := fmt.Sprintf("%d_new_migration.sql", last.Version+1)
lastVersion := len(files)
newVersion := fmt.Sprintf("%d_new_migration.sql", lastVersion+1)
fsys := fstest.MapFS{
newVersion: &fstest.MapFile{Data: []byte(`
-- +goose Up
Expand All @@ -476,7 +483,7 @@ SELECT pg_sleep_for('4 seconds');
check.NoError(t, err)
check.Number(t, len(newProvider.ListSources()), 1)
oldProvider := p
check.Number(t, len(oldProvider.ListSources()), 6)
check.Number(t, len(oldProvider.ListSources()), len(files))

var g errgroup.Group
g.Go(func() error {
Expand All @@ -485,6 +492,12 @@ SELECT pg_sleep_for('4 seconds');
return err
}
check.Bool(t, hasPending, true)
current, target, err := newProvider.CheckPending(context.Background())
if err != nil {
return err
}
check.Number(t, current, lastVersion)
check.Number(t, target, lastVersion+1)
return nil
})
g.Go(func() error {
Expand All @@ -493,6 +506,12 @@ SELECT pg_sleep_for('4 seconds');
return err
}
check.Bool(t, hasPending, false)
current, target, err := oldProvider.CheckPending(context.Background())
if err != nil {
return err
}
check.Number(t, current, lastVersion)
check.Number(t, target, lastVersion)
return nil
})
check.NoError(t, g.Wait())
Expand All @@ -512,16 +531,24 @@ SELECT pg_sleep_for('4 seconds');
hasPending, err := oldProvider.HasPending(context.Background())
check.NoError(t, err)
check.Bool(t, hasPending, false)
current, target, err := oldProvider.CheckPending(context.Background())
check.NoError(t, err)
check.Number(t, current, lastVersion)
check.Number(t, target, lastVersion)
// Wait for the long running migration to finish
check.NoError(t, g.Wait())
// Check that the new migration was applied
hasPending, err = newProvider.HasPending(context.Background())
check.NoError(t, err)
check.Bool(t, hasPending, false)
current, target, err = newProvider.CheckPending(context.Background())
check.NoError(t, err)
check.Number(t, current, lastVersion+1)
check.Number(t, target, lastVersion+1)
// The max version should be the new migration
currentVersion, err := newProvider.GetDBVersion(context.Background())
check.NoError(t, err)
check.Number(t, currentVersion, last.Version+1)
check.Number(t, currentVersion, lastVersion+1)
}

func existsPgLock(ctx context.Context, db *sql.DB, lockID int64) (bool, error) {
Expand Down
44 changes: 44 additions & 0 deletions provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,15 @@ func (p *Provider) HasPending(ctx context.Context) (bool, error) {
return p.hasPending(ctx)
}

// CheckPending returns the current database version and the target version to migrate to. If there
// are no pending migrations, the target version will be the same as the current version.
//
// Note, this method will not use a SessionLocker if one is configured. This allows callers to check
// for pending migrations without blocking or being blocked by other operations.
func (p *Provider) CheckPending(ctx context.Context) (current, target int64, err error) {
return p.checkPending(ctx)
}

// GetDBVersion returns the highest version recorded in the database, regardless of the order in
// which migrations were applied. For example, if migrations were applied out of order (1,4,2,3),
// this method returns 4. If no migrations have been applied, it returns 0.
Expand Down Expand Up @@ -465,6 +474,41 @@ func (p *Provider) apply(
return p.runMigrations(ctx, conn, []*Migration{m}, d, true)
}

func (p *Provider) checkPending(ctx context.Context) (current, target int64, retErr error) {
conn, cleanup, err := p.initialize(ctx, false)
if err != nil {
return -1, -1, fmt.Errorf("failed to initialize: %w", err)
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
}()

// If versioning is disabled, we always have pending migrations and the target version is the
// last migration.
if p.cfg.disableVersioning {
return -1, p.migrations[len(p.migrations)-1].Version, nil
}
// optimize(mf): we should only fetch the max version from the database, no need to fetch all
// migrations only to get the max version when we're not using out-of-order migrations.
res, err := p.store.ListMigrations(ctx, conn)
if err != nil {
return -1, -1, err
}
dbVersions := make([]int64, 0, len(res))
for _, m := range res {
dbVersions = append(dbVersions, m.Version)
}
sort.Slice(dbVersions, func(i, j int) bool {
return dbVersions[i] < dbVersions[j]
})
if len(dbVersions) == 0 {
return -1, -1, errMissingZeroVersion
} else {
current = dbVersions[len(dbVersions)-1]
}
return current, p.migrations[len(p.migrations)-1].Version, nil
}

func (p *Provider) hasPending(ctx context.Context) (_ bool, retErr error) {
conn, cleanup, err := p.initialize(ctx, false)
if err != nil {
Expand Down
25 changes: 20 additions & 5 deletions provider_run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"database/sql"
"errors"
"fmt"
"io/fs"
"math"
"math/rand"
"os"
Expand Down Expand Up @@ -775,11 +774,12 @@ func TestProviderApply(t *testing.T) {
check.Bool(t, errors.Is(err, goose.ErrNotApplied), true)
}

func TestHasPending(t *testing.T) {
func TestPending(t *testing.T) {
t.Parallel()
t.Run("allow_out_of_order", func(t *testing.T) {
ctx := context.Background()
p, err := goose.NewProvider(goose.DialectSQLite3, newDB(t), newFsys(),
fsys := newFsys()
p, err := goose.NewProvider(goose.DialectSQLite3, newDB(t), fsys,
goose.WithAllowOutofOrder(true),
)
check.NoError(t, err)
Expand All @@ -791,17 +791,25 @@ func TestHasPending(t *testing.T) {
hasPending, err := p.HasPending(ctx)
check.NoError(t, err)
check.Bool(t, hasPending, true)
current, target, err := p.CheckPending(ctx)
check.NoError(t, err)
check.Number(t, current, 3)
check.Number(t, target, len(fsys))
// Apply the missing migrations.
_, err = p.Up(ctx)
check.NoError(t, err)
// All migrations have been applied.
hasPending, err = p.HasPending(ctx)
check.NoError(t, err)
check.Bool(t, hasPending, false)
current, target, err = p.CheckPending(ctx)
check.NoError(t, err)
check.Number(t, current, target)
})
t.Run("disallow_out_of_order", func(t *testing.T) {
ctx := context.Background()
p, err := goose.NewProvider(goose.DialectSQLite3, newDB(t), newFsys(),
fsys := newFsys()
p, err := goose.NewProvider(goose.DialectSQLite3, newDB(t), fsys,
goose.WithAllowOutofOrder(false),
)
check.NoError(t, err)
Expand All @@ -813,12 +821,19 @@ func TestHasPending(t *testing.T) {
hasPending, err := p.HasPending(ctx)
check.NoError(t, err)
check.Bool(t, hasPending, true)
current, target, err := p.CheckPending(ctx)
check.NoError(t, err)
check.Number(t, current, 2)
check.Number(t, target, len(fsys))
_, err = p.Up(ctx)
check.NoError(t, err)
// All migrations have been applied.
hasPending, err = p.HasPending(ctx)
check.NoError(t, err)
check.Bool(t, hasPending, false)
current, target, err = p.CheckPending(ctx)
check.NoError(t, err)
check.Number(t, current, target)
})
}

Expand Down Expand Up @@ -1089,7 +1104,7 @@ func newMapFile(data string) *fstest.MapFile {
}
}

func newFsys() fs.FS {
func newFsys() fstest.MapFS {
return fstest.MapFS{
"00001_users_table.sql": newMapFile(runMigration1),
"00002_posts_table.sql": newMapFile(runMigration2),
Expand Down

0 comments on commit 6a5697e

Please sign in to comment.