Skip to content

Commit

Permalink
refactor: add TableExists assertion support and improve docs (#641)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfridman authored Nov 12, 2023
1 parent c5e0d3c commit 9e6ef20
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 44 deletions.
7 changes: 5 additions & 2 deletions database/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ type Store interface {
Tablename() string

// CreateVersionTable creates the version table. This table is used to record applied
// migrations.
// migrations. When creating the table, the implementation must also insert a row for the
// initial version (0).
CreateVersionTable(ctx context.Context, db DBTxConn) error

// Insert inserts a version id into the version table.
Expand All @@ -35,7 +36,9 @@ type Store interface {
GetMigration(ctx context.Context, db DBTxConn, version int64) (*GetMigrationResult, error)

// ListMigrations retrieves all migrations sorted in descending order by id or timestamp. If
// there are no migrations, return empty slice with no error.
// there are no migrations, return empty slice with no error. Typically this method will return
// at least one migration, because the initial version (0) is always inserted into the version
// table when it is created.
ListMigrations(ctx context.Context, db DBTxConn) ([]*ListMigrationsResult, error)

// TODO(mf): remove this method once the Provider is public and a custom Store can be used.
Expand Down
6 changes: 3 additions & 3 deletions lock/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,12 @@ func (l *postgresSessionLocker) SessionUnlock(ctx context.Context, conn *sql.Con
return nil
}
/*
TODO(mf): provide users with some documentation on how they can unlock the session
docs(md): provide users with some documentation on how they can unlock the session
manually.
This is probably not an issue for 99.99% of users since pg_advisory_unlock_all() will
release all session level advisory locks held by the current session. This function is
implicitly invoked at session end, even if the client disconnects ungracefully.
release all session level advisory locks held by the current session. It is implicitly
invoked at session end, even if the client disconnects ungracefully.
Here is output from a session that has a lock held:
Expand Down
5 changes: 5 additions & 0 deletions migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,3 +391,8 @@ func truncateDuration(d time.Duration) time.Duration {
}
return d
}

// ref returns a string that identifies the migration. This is used for logging and error messages.
func (m *Migration) ref() string {
return fmt.Sprintf("(type:%s,version:%d)", m.Type, m.Version)
}
10 changes: 5 additions & 5 deletions provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ func (p *Provider) up(
}
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to initialize: %w", err)
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
Expand Down Expand Up @@ -339,7 +339,7 @@ func (p *Provider) down(
) (_ []*MigrationResult, retErr error) {
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to initialize: %w", err)
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
Expand Down Expand Up @@ -397,7 +397,7 @@ func (p *Provider) apply(
}
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to initialize: %w", err)
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
Expand All @@ -422,7 +422,7 @@ func (p *Provider) apply(
func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) {
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to initialize: %w", err)
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
Expand Down Expand Up @@ -455,7 +455,7 @@ func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr err
func (p *Provider) getDBMaxVersion(ctx context.Context) (_ int64, retErr error) {
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return 0, err
return 0, fmt.Errorf("failed to initialize: %w", err)
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
Expand Down
8 changes: 4 additions & 4 deletions provider_collect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,8 @@ func TestCheckMissingMigrations(t *testing.T) {
}
got := checkMissingMigrations(dbMigrations, fsMigrations)
check.Number(t, len(got), 2)
check.Number(t, got[0].versionID, 2)
check.Number(t, got[1].versionID, 6)
check.Number(t, got[0], 2)
check.Number(t, got[1], 6)

// Sanity check.
check.Number(t, len(checkMissingMigrations(nil, nil)), 0)
Expand All @@ -333,8 +333,8 @@ func TestCheckMissingMigrations(t *testing.T) {
}
got := checkMissingMigrations(dbMigrations, fsMigrations)
check.Number(t, len(got), 2)
check.Number(t, got[0].versionID, 3)
check.Number(t, got[1].versionID, 4)
check.Number(t, got[0], 3)
check.Number(t, got[1], 4)
})
}

Expand Down
10 changes: 4 additions & 6 deletions provider_errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package goose
import (
"errors"
"fmt"
"path/filepath"
)

var (
Expand Down Expand Up @@ -32,9 +31,8 @@ type PartialError struct {
}

func (e *PartialError) Error() string {
filename := "(file unknown)"
if e.Failed != nil && e.Failed.Source.Path != "" {
filename = fmt.Sprintf("(%s)", filepath.Base(e.Failed.Source.Path))
}
return fmt.Sprintf("partial migration error %s (%d): %v", filename, e.Failed.Source.Version, e.Err)
return fmt.Sprintf(
"partial migration error (type:%s,version:%d): %v",
e.Failed.Source.Type, e.Failed.Source.Version, e.Err,
)
}
51 changes: 30 additions & 21 deletions provider_run.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"io/fs"
"sort"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -43,7 +44,7 @@ func (p *Provider) resolveUpMigrations(
if len(missingMigrations) > 0 && !p.cfg.allowMissing {
var collected []string
for _, v := range missingMigrations {
collected = append(collected, fmt.Sprintf("%d", v.versionID))
collected = append(collected, strconv.FormatInt(v, 10))
}
msg := "migration"
if len(collected) > 1 {
Expand All @@ -53,8 +54,8 @@ func (p *Provider) resolveUpMigrations(
len(missingMigrations), msg, dbMaxVersion, strings.Join(collected, ","),
)
}
for _, v := range missingMigrations {
m, err := p.getMigration(v.versionID)
for _, missingVersion := range missingMigrations {
m, err := p.getMigration(missingVersion)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -141,7 +142,7 @@ func (p *Provider) runMigrations(

for _, m := range apply {
if err := p.prepareMigration(p.fsys, m, direction.ToBool()); err != nil {
return nil, err
return nil, fmt.Errorf("failed to prepare migration %s: %w", m.ref(), err)
}
}

Expand Down Expand Up @@ -301,11 +302,26 @@ func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, err
}

func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retErr error) {
// feat(mf): this is where we can check if the version table exists instead of trying to fetch
// from a table that may not exist. https://github.com/pressly/goose/issues/461
res, err := p.store.GetMigration(ctx, conn, 0)
if err == nil && res != nil {
return nil
// existor is an interface that extends the Store interface with a method to check if the
// version table exists. This API is not stable and may change in the future.
type existor interface {
TableExists(context.Context, database.DBTxConn, string) (bool, error)
}
if e, ok := p.store.(existor); ok {
exists, err := e.TableExists(ctx, conn, p.store.Tablename())
if err != nil {
return fmt.Errorf("failed to check if version table exists: %w", err)
}
if exists {
return nil
}
} else {
// feat(mf): this is where we can check if the version table exists instead of trying to fetch
// from a table that may not exist. https://github.com/pressly/goose/issues/461
res, err := p.store.GetMigration(ctx, conn, 0)
if err == nil && res != nil {
return nil
}
}
return beginTx(ctx, conn, func(tx *sql.Tx) error {
if err := p.store.CreateVersionTable(ctx, tx); err != nil {
Expand All @@ -318,16 +334,12 @@ func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retE
})
}

type missingMigration struct {
versionID int64
}

// checkMissingMigrations returns a list of migrations that are missing from the database. A missing
// migration is one that has a version less than the max version in the database.
func checkMissingMigrations(
dbMigrations []*database.ListMigrationsResult,
fsMigrations []*Migration,
) []missingMigration {
) []int64 {
existing := make(map[int64]bool)
var dbMaxVersion int64
for _, m := range dbMigrations {
Expand All @@ -336,17 +348,14 @@ func checkMissingMigrations(
dbMaxVersion = m.Version
}
}
var missing []missingMigration
var missing []int64
for _, m := range fsMigrations {
version := m.Version
if !existing[version] && version < dbMaxVersion {
missing = append(missing, missingMigration{
versionID: version,
})
if !existing[m.Version] && m.Version < dbMaxVersion {
missing = append(missing, m.Version)
}
}
sort.Slice(missing, func(i, j int) bool {
return missing[i].versionID < missing[j].versionID
return missing[i] < missing[j]
})
return missing
}
Expand Down
31 changes: 29 additions & 2 deletions provider_run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"testing/fstest"

"github.com/pressly/goose/v3"
"github.com/pressly/goose/v3/database"
"github.com/pressly/goose/v3/internal/check"
"github.com/pressly/goose/v3/internal/testdb"
"github.com/pressly/goose/v3/lock"
Expand All @@ -31,7 +32,7 @@ func TestProviderRun(t *testing.T) {
check.NoError(t, db.Close())
_, err := p.Up(context.Background())
check.HasError(t, err)
check.Equal(t, err.Error(), "sql: database is closed")
check.Equal(t, err.Error(), "failed to initialize: sql: database is closed")
})
t.Run("ping_and_close", func(t *testing.T) {
p, _ := newProviderWithDB(t)
Expand Down Expand Up @@ -324,7 +325,7 @@ INSERT INTO owners (owner_name) VALUES ('seed-user-3');
check.NoError(t, err)
_, err = p.Up(ctx)
check.HasError(t, err)
check.Contains(t, err.Error(), "partial migration error (00002_partial_error.sql) (2)")
check.Contains(t, err.Error(), "partial migration error (type:sql,version:2)")
var expected *goose.PartialError
check.Bool(t, errors.As(err, &expected), true)
// Check Err field
Expand Down Expand Up @@ -723,6 +724,32 @@ func TestSQLiteSharedCache(t *testing.T) {
})
}

func TestCustomStoreTableExists(t *testing.T) {
t.Parallel()

store, err := database.NewStore(database.DialectSQLite3, goose.DefaultTablename)
check.NoError(t, err)
p, err := goose.NewProvider("", newDB(t), newFsys(),
goose.WithStore(&customStoreSQLite3{store}),
)
check.NoError(t, err)
_, err = p.Up(context.Background())
check.NoError(t, err)
}

type customStoreSQLite3 struct {
database.Store
}

func (s *customStoreSQLite3) TableExists(ctx context.Context, db database.DBTxConn, name string) (bool, error) {
q := `SELECT EXISTS (SELECT 1 FROM sqlite_master WHERE type='table' AND name=$1) AS table_exists`
var exists bool
if err := db.QueryRowContext(ctx, q, name).Scan(&exists); err != nil {
return false, err
}
return exists, nil
}

func getGooseVersionCount(db *sql.DB, gooseTable string) (int64, error) {
var gotVersion int64
if err := db.QueryRow(
Expand Down
1 change: 0 additions & 1 deletion up.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ func UpByOneContext(ctx context.Context, db *sql.DB, dir string, opts ...Options
}

// listAllDBVersions returns a list of all migrations, ordered ascending.
// TODO(mf): fairly cheap, but a nice-to-have is pagination support.
func listAllDBVersions(ctx context.Context, db *sql.DB) (Migrations, error) {
dbMigrations, err := store.ListMigrations(ctx, db, TableName())
if err != nil {
Expand Down

0 comments on commit 9e6ef20

Please sign in to comment.