diff --git a/.gitignore b/.gitignore index 3b735ec..bf3691f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,3 @@ -# If you prefer the allow list template instead of the deny list, see community template: -# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore -# # Binaries for programs and plugins *.exe *.exe~ @@ -8,14 +5,17 @@ *.so *.dylib -# Test binary, built with `go test -c` +# Test binary, build with `go test -c` *.test # Output of the go coverage tool, specifically when used with LiteIDE *.out -# Dependency directories (remove the comment below to include it) -# vendor/ +# Visual Studio Code internal folder +.vscode -# Go workspace file -go.work +# Packages ouput folder +dist + +# delve debugger file +debug \ No newline at end of file diff --git a/README.md b/README.md index b6fd806..92837a1 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,56 @@ # pgx-migrator -Simple pgx based PostgreSQL schema migration library for Go + +Simple [pgx](https://github.com/jackc/pgx) oriented [PostgreSQL](https://www.postgresql.org/) schema migration library for Go based on [lopezator/migrator](https://github.com/lopezator/migrator). + +# Features + +* Simple code +* Usage as a library, embeddable and extensible on your behalf +* Made to use with `jackc/pgx` +* Go code migrations, either transactional or transaction-less, using `pgx.Tx` (`migrator.Migration`) or `pgx.Conn` and `pgx.Pool` (`migrator.MigrationNoTx`) +* No need to use `//go:embed` or others, since all migrations are just Go code + +# Usage + +Customize this to your needs by changing the driver and/or connection settings. + +### QuickStart: + +```go +package main + +import ( + + pgx "github.com/jackc/pgx/v5" + migrator "github.com/cybertec-postgresql/pgx-migrator" +) + +func main() { + // Configure migrations + m, err := migrator.New( + migrator.Migrations( + &migrator.Migration{ + Name: "Create table foo", + Func: func(ctx context.Context, tx pgx.Tx) error { + _, err := tx.Exec(ctx, "CREATE TABLE foo (id INT PRIMARY KEY)") + return err + }, + }, + ), + ) + if err != nil { + panic(err) + } + + // Open database connection + conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL")) + if err != nil { + panic(err) + } + + // Migrate up + if err := m.Migrate(conn); err != nil { + panic(err) + } +} +``` \ No newline at end of file diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..118d3c2 --- /dev/null +++ b/go.mod @@ -0,0 +1,21 @@ +module github.com/cybertec-postgresql/pgx-migrator + +go 1.22.0 + +require ( + github.com/jackc/pgx/v5 v5.5.5 + github.com/pashagolub/pgxmock/v3 v3.4.0 + github.com/stretchr/testify v1.9.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/crypto v0.17.0 // indirect + golang.org/x/sync v0.1.0 // indirect + golang.org/x/text v0.14.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..998f725 --- /dev/null +++ b/go.sum @@ -0,0 +1,38 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= +github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= +github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= +github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pashagolub/pgxmock/v3 v3.4.0 h1:87VMr2q7m2+6VzXo4Tsp9kMklGlj6mMN19Hp/bp2Rwo= +github.com/pashagolub/pgxmock/v3 v3.4.0/go.mod h1:FvCl7xqPbLLI3XohihJ1NzXnikjM3q/NWSixg4t9hrU= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= +github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= +golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/migrator.go b/migrator.go new file mode 100644 index 0000000..95bc216 --- /dev/null +++ b/migrator.go @@ -0,0 +1,218 @@ +package migrator + +import ( + "context" + "errors" + "fmt" + + pgx "github.com/jackc/pgx/v5" + pgconn "github.com/jackc/pgx/v5/pgconn" +) + +// PgxIface is interface for database connection or transaction +type PgxIface interface { + Begin(ctx context.Context) (pgx.Tx, error) + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) + QueryRow(context.Context, string, ...interface{}) pgx.Row + Query(ctx context.Context, query string, args ...interface{}) (pgx.Rows, error) + Ping(ctx context.Context) error +} + +const defaultTableName = "migrations" + +// Migrator is the migrator implementation +type Migrator struct { + TableName string + migrations []interface{} + onNotice func(string) +} + +// Option sets options such migrations or table name. +type Option func(*Migrator) + +// TableName creates an option to allow overriding the default table name +func TableName(tableName string) Option { + return func(m *Migrator) { + m.TableName = tableName + } +} + +// SetNotice overrides the default standard output function +func SetNotice(noticeFunc func(string)) Option { + return func(m *Migrator) { + m.onNotice = noticeFunc + } +} + +// Migrations creates an option with provided migrations +func Migrations(migrations ...interface{}) Option { + return func(m *Migrator) { + m.migrations = migrations + } +} + +// New creates a new migrator instance +func New(opts ...Option) (*Migrator, error) { + m := &Migrator{ + TableName: defaultTableName, + onNotice: func(msg string) { + fmt.Println(msg) + }, + } + for _, opt := range opts { + opt(m) + } + + if len(m.migrations) == 0 { + return nil, errors.New("Migrations must be provided") + } + + for _, m := range m.migrations { + switch m.(type) { + case *Migration: + case *MigrationNoTx: + default: + return nil, errors.New("Invalid migration type") + } + } + + return m, nil +} + +// Migrate applies all available migrations +func (m *Migrator) Migrate(ctx context.Context, db PgxIface) error { + // create migrations table if doesn't exist + _, err := db.Exec(ctx, fmt.Sprintf(` + CREATE TABLE IF NOT EXISTS %s ( + id INT8 NOT NULL, + version TEXT NOT NULL, + PRIMARY KEY (id) + ); + `, m.TableName)) + if err != nil { + return err + } + + pm, count, err := m.Pending(ctx, db) + if err != nil { + return err + } + + // plan migrations + for idx, migration := range pm { + insertVersion := fmt.Sprintf("INSERT INTO %s (id, version) VALUES (%d, '%s')", m.TableName, idx+count, migration.(fmt.Stringer).String()) + switch mm := migration.(type) { + case *Migration: + if err := migrate(ctx, db, insertVersion, mm, m.onNotice); err != nil { + return fmt.Errorf("Error while running migrations: %w", err) + } + case *MigrationNoTx: + if err := migrateNoTx(ctx, db, insertVersion, mm, m.onNotice); err != nil { + return fmt.Errorf("Error while running migrations: %w", err) + } + } + } + + return nil +} + +// Pending returns all pending (not yet applied) migrations and count of migration applied +func (m *Migrator) Pending(ctx context.Context, db PgxIface) ([]interface{}, int, error) { + count, err := countApplied(ctx, db, m.TableName) + if err != nil { + return nil, 0, err + } + if count > len(m.migrations) { + count = len(m.migrations) + } + return m.migrations[count:len(m.migrations)], count, nil +} + +// NeedUpgrade returns True if database need to be updated with migrations +func (m *Migrator) NeedUpgrade(ctx context.Context, db PgxIface) (bool, error) { + exists, err := tableExists(ctx, db, m.TableName) + if !exists { + return true, err + } + mm, _, err := m.Pending(ctx, db) + return len(mm) > 0, err +} + +func countApplied(ctx context.Context, db PgxIface, tableName string) (int, error) { + // count applied migrations + var count int + err := db.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s", tableName)).Scan(&count) + if err != nil { + return 0, err + } + return count, nil +} + +func tableExists(ctx context.Context, db PgxIface, tableName string) (bool, error) { + var exists bool + err := db.QueryRow(ctx, "SELECT to_regclass($1) IS NOT NULL", tableName).Scan(&exists) + if err != nil { + return false, err + } + return exists, nil +} + +// Migration represents a single migration +type Migration struct { + Name string + Func func(context.Context, pgx.Tx) error +} + +// String returns a string representation of the migration +func (m *Migration) String() string { + return m.Name +} + +// MigrationNoTx represents a single not transactional migration +type MigrationNoTx struct { + Name string + Func func(context.Context, PgxIface) error +} + +func (m *MigrationNoTx) String() string { + return m.Name +} + +func migrate(ctx context.Context, db PgxIface, insertVersion string, migration *Migration, notice func(string)) error { + tx, err := db.Begin(ctx) + if err != nil { + return err + } + defer func() { + if err != nil { + if errRb := tx.Rollback(ctx); errRb != nil { + err = fmt.Errorf("Error rolling back: %s\n%s", errRb, err) + } + return + } + err = tx.Commit(ctx) + }() + notice(fmt.Sprintf("Applying migration named '%s'...", migration.Name)) + if err = migration.Func(ctx, tx); err != nil { + return fmt.Errorf("Error executing golang migration: %w", err) + } + if _, err = tx.Exec(ctx, insertVersion); err != nil { + return fmt.Errorf("Error updating migration versions: %w", err) + } + notice(fmt.Sprintf("Applied migration named '%s'", migration.Name)) + + return err +} + +func migrateNoTx(ctx context.Context, db PgxIface, insertVersion string, migration *MigrationNoTx, notice func(string)) error { + notice(fmt.Sprintf("Applying no tx migration named '%s'...", migration.Name)) + if err := migration.Func(ctx, db); err != nil { + return fmt.Errorf("Error executing golang migration: %w", err) + } + if _, err := db.Exec(ctx, insertVersion); err != nil { + return fmt.Errorf("Error updating migration versions: %w", err) + } + notice(fmt.Sprintf("Applied no tx migration named '%s'...", migration.Name)) + + return nil +} diff --git a/migrator_test.go b/migrator_test.go new file mode 100644 index 0000000..c03d2b0 --- /dev/null +++ b/migrator_test.go @@ -0,0 +1,214 @@ +package migrator_test + +import ( + "context" + "errors" + "math" + "testing" + + migrator "github.com/cybertec-postgresql/pgx-migrator" + pgx "github.com/jackc/pgx/v5" + "github.com/pashagolub/pgxmock/v3" + "github.com/stretchr/testify/assert" +) + +var migrations = []interface{}{ + &migrator.Migration{ + Name: "Using tx, encapsulate two queries", + Func: func(ctx context.Context, tx pgx.Tx) error { + if _, err := tx.Exec(ctx, "CREATE TABLE foo (id INT PRIMARY KEY)"); err != nil { + return err + } + if _, err := tx.Exec(ctx, "INSERT INTO foo (id) VALUES (1)"); err != nil { + return err + } + return nil + }, + }, + &migrator.MigrationNoTx{ + Name: "Using db, execute one query", + Func: func(ctx context.Context, db migrator.PgxIface) error { + if _, err := db.Exec(ctx, "INSERT INTO foo (id) VALUES (2)"); err != nil { + return err + } + return nil + }, + }, +} + +func TestMigratorConstructor(t *testing.T) { + _, err := migrator.New() //migrator.Migrations(migrations...) + assert.Error(t, err, "Should throw error when migrations are empty") + + _, err = migrator.New(migrator.Migrations(struct{ Foo string }{Foo: "bar"})) + assert.Error(t, err, "Should throw error for unknown migration type") +} + +func TestTableExists(t *testing.T) { + mock, err := pgxmock.NewPool() + assert.NoError(t, err) + defer mock.Close() + + m, err := migrator.New(migrator.Migrations(migrations...)) + assert.NoError(t, err) + assert.NotNil(t, m) + + sqlresults := []struct { + testname string + tableexists bool + appliedcount int + needupgrade bool + tableerr error + counterr error + }{ + { + testname: "table exists and no migrations applied", + tableexists: true, + appliedcount: 0, + needupgrade: true, + tableerr: nil, + counterr: nil, + }, + { + testname: "table exists and a lot of migrations applied", + tableexists: true, + appliedcount: math.MaxInt32, + needupgrade: false, + tableerr: nil, + counterr: nil, + }, + { + testname: "error occurred during count query", + tableexists: true, + appliedcount: 0, + needupgrade: false, + tableerr: nil, + counterr: errors.New("internal error"), + }, + { + testname: "error occurred during table exists query", + tableexists: false, + appliedcount: 0, + needupgrade: true, + tableerr: errors.New("internal error"), + counterr: nil, + }, + } + var expectederr error + for _, res := range sqlresults { + if q := mock.ExpectQuery("SELECT to_regclass").WithArgs(pgxmock.AnyArg()); res.tableerr != nil { + q.WillReturnError(res.tableerr) + expectederr = res.tableerr + } else { + q.WillReturnRows(pgxmock.NewRows([]string{"to_regclass"}).AddRow(res.tableexists)) + } + if q := mock.ExpectQuery("SELECT count"); res.counterr != nil { + q.WillReturnError(res.counterr) + expectederr = res.counterr + } else { + q.WillReturnRows(pgxmock.NewRows([]string{"count"}).AddRow(res.appliedcount)) + } + need, err := m.NeedUpgrade(context.Background(), mock) + assert.Equal(t, expectederr, err, "NeedUpgrade test failed: ", res.testname) + assert.Equal(t, res.needupgrade, need, "NeedUpgrade incorrect return: ", res.testname) + } +} + +func TestMigrateExists(t *testing.T) { + mock, err := pgxmock.NewPool() + assert.NoError(t, err) + defer mock.Close() + + m, err := migrator.New(migrator.Migrations(migrations...)) + assert.NoError(t, err) + assert.NotNil(t, m) + + expectederr := errors.New("internal error") + + mock.ExpectExec("CREATE TABLE").WillReturnResult(pgxmock.NewResult("DDL", 0)) + mock.ExpectQuery("SELECT count").WillReturnError(expectederr) + + err = m.Migrate(context.Background(), mock) + assert.Equal(t, expectederr, err, "Migrate test failed: ", err) + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestMigrateNoTxError(t *testing.T) { + mock, err := pgxmock.NewPool() + assert.NoError(t, err) + defer mock.Close() + + m, err := migrator.New(migrator.Migrations(&migrator.MigrationNoTx{Func: func(context.Context, migrator.PgxIface) error { return nil }})) + assert.NoError(t, err) + assert.NotNil(t, m) + + expectederr := errors.New("internal error") + + mock.ExpectExec("CREATE TABLE").WillReturnResult(pgxmock.NewResult("DDL", 0)) + mock.ExpectQuery("SELECT count").WillReturnRows(pgxmock.NewRows([]string{"count"}).AddRow(0)) + mock.ExpectExec("INSERT").WillReturnError(expectederr) + + err = m.Migrate(context.Background(), mock) + for errors.Unwrap(err) != nil { + err = errors.Unwrap(err) + } + assert.Equal(t, expectederr, err, "MigrateNoTxError test failed: ", err) + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestMigrateTxError(t *testing.T) { + mock, err := pgxmock.NewPool() + assert.NoError(t, err) + defer mock.Close() + + m, err := migrator.New(migrator.Migrations(&migrator.Migration{Func: func(context.Context, pgx.Tx) error { return nil }})) + assert.NoError(t, err) + assert.NotNil(t, m) + + expectederr := errors.New("create table error") + mock.ExpectExec("CREATE TABLE").WillReturnError(expectederr) + err = m.Migrate(context.Background(), mock) + assert.Equal(t, expectederr, err, "MigrateTxError test failed: ", err) + + expectederr = errors.New("internal tx error") + mock.ExpectExec("CREATE TABLE").WillReturnResult(pgxmock.NewResult("DDL", 0)) + mock.ExpectQuery("SELECT count").WillReturnRows(pgxmock.NewRows([]string{"count"}).AddRow(0)) + mock.ExpectBegin().WillReturnError(expectederr) + err = m.Migrate(context.Background(), mock) + for errors.Unwrap(err) != nil { + err = errors.Unwrap(err) + } + assert.Equal(t, expectederr, err, "MigrateTxError test failed: ", err) + + expectederr = errors.New("internal tx error") + mock.ExpectExec("CREATE TABLE").WillReturnResult(pgxmock.NewResult("DDL", 0)) + mock.ExpectQuery("SELECT count").WillReturnRows(pgxmock.NewRows([]string{"count"}).AddRow(0)) + mock.ExpectBegin() + mock.ExpectExec("INSERT").WillReturnError(expectederr) + err = m.Migrate(context.Background(), mock) + for errors.Unwrap(err) != nil { + err = errors.Unwrap(err) + } + assert.Equal(t, expectederr, err, "MigrateTxError test failed: ", err) + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestMigratorOptions(t *testing.T) { + O := migrator.TableName("foo") + m := &migrator.Migrator{} + O(m) + assert.Equal(t, "foo", m.TableName) + + f := func(string) {} + O = migrator.SetNotice(f) + O(m) +}