Skip to content

Commit

Permalink
add UpdateOnlyColumnsReportNoRows
Browse files Browse the repository at this point in the history
  • Loading branch information
kataras committed Nov 15, 2024
1 parent 405d793 commit 1587893
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 24 deletions.
23 changes: 16 additions & 7 deletions db_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ func (db *DB) UpdateExceptColumns(ctx context.Context, columnsToExcept []string,
}

columnsToUpdate := td.ListColumnNamesExcept(columnsToExcept...)
return db.updateTableRecords(ctx, td, columnsToUpdate, values)
return db.updateTableRecords(ctx, td, columnsToUpdate, false, values)
}

// UpdateOnlyColumns updates one or more values in the database by building and executing an
Expand All @@ -315,25 +315,25 @@ func (db *DB) UpdateOnlyColumns(ctx context.Context, columnsToUpdate []string, v
return 0, err // return the error if the table definition is not found
}

return db.updateTableRecords(ctx, td, columnsToUpdate, values)
return db.updateTableRecords(ctx, td, columnsToUpdate, false, values)
}

func (db *DB) updateTableRecords(ctx context.Context, td *desc.Table, columnsToUpdate []string, values []any) (int64, error) {
func (db *DB) updateTableRecords(ctx context.Context, td *desc.Table, columnsToUpdate []string, reportNotFound bool, values []any) (int64, error) {
primaryKey, ok := td.PrimaryKey()
if !ok {
return 0, fmt.Errorf("no primary key found in table definition: %s", td.Name)
}

if len(values) == 1 {
return db.updateTableRecord(ctx, values[0], columnsToUpdate, primaryKey)
return db.updateTableRecord(ctx, values[0], columnsToUpdate, reportNotFound, primaryKey)
}

// if more than one: update each value inside a transaction.
var totalRowsAffected int64

err := db.InTransaction(ctx, func(db *DB) error {
for _, value := range values {
rowsAffected, err := db.updateTableRecord(ctx, value, columnsToUpdate, primaryKey)
rowsAffected, err := db.updateTableRecord(ctx, value, columnsToUpdate, reportNotFound, primaryKey)
if err != nil {
return err
}
Expand All @@ -350,13 +350,22 @@ func (db *DB) updateTableRecords(ctx context.Context, td *desc.Table, columnsToU
return totalRowsAffected, nil
}

func (db *DB) updateTableRecord(ctx context.Context, value any, columnsToUpdate []string, primaryKey *desc.Column) (int64, error) {
func (db *DB) updateTableRecord(ctx context.Context, value any, columnsToUpdate []string, reportNotFound bool, primaryKey *desc.Column) (int64, error) {
// build the SQL query and arguments using the table definition and its primary key.
query, args, err := desc.BuildUpdateQuery(value, columnsToUpdate, primaryKey)
query, args, err := desc.BuildUpdateQuery(value, columnsToUpdate, reportNotFound, primaryKey)
if err != nil {
return 0, err
}

if reportNotFound {
scanErr := db.QueryRow(ctx, query, args...).Scan(nil)
if scanErr != nil {
return 0, scanErr
}

return 1, nil
}

// execute the query using db.Exec and pass in the primary key values as a parameter
tag, err := db.Exec(ctx, query, args...)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion desc/insert_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func BuildInsertQuery(td *Table, structValue reflect.Value, idPtr any, forceOnCo
if idPtr != nil {
// if idPtr is not nil, it means we want to get the primary key value of the inserted row
columnDefinition, ok := td.PrimaryKey() // get the primary key column definition from the table definition
if ok && idPtr != nil {
if ok {
returningColumn = columnDefinition.Name // assign the column name to returningColumn
}
}
Expand Down
10 changes: 7 additions & 3 deletions desc/update_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (

// BuildUpdateQuery builds and returns an SQL query for updating a row in the table,
// using the given struct value and the primary key.
func BuildUpdateQuery(value any, columnsToUpdate []string, primaryKey *Column) (string, []any, error) {
func BuildUpdateQuery(value any, columnsToUpdate []string, reportNotFound bool, primaryKey *Column) (string, []any, error) {
args, err := extractUpdateArguments(value, columnsToUpdate, primaryKey)
if err != nil {
return "", nil, err
Expand All @@ -27,7 +27,7 @@ func BuildUpdateQuery(value any, columnsToUpdate []string, primaryKey *Column) (
}

// build the SQL query using the table definition and its primary key.
query := buildUpdateQuery(primaryKey.Table, args, primaryKey.Name, shouldUpdateID)
query := buildUpdateQuery(primaryKey.Table, args, primaryKey.Name, shouldUpdateID, reportNotFound)
return query, args.Values(), nil
}

Expand Down Expand Up @@ -81,7 +81,7 @@ func extractUpdateArguments(value any, columnsToUpdate []string, primaryKey *Col
return args, nil
}

func buildUpdateQuery(td *Table, args Arguments, primaryKeyName string, shouldUpdateID bool) string {
func buildUpdateQuery(td *Table, args Arguments, primaryKeyName string, shouldUpdateID bool, reportNotFound bool) string {
var b strings.Builder

b.WriteString(`UPDATE "` + td.Name + `" SET `)
Expand Down Expand Up @@ -122,6 +122,10 @@ func buildUpdateQuery(td *Table, args Arguments, primaryKeyName string, shouldUp
}
b.WriteString(` WHERE "` + primaryKeyName + `" = $` + strconv.Itoa(primaryKeyWhereIndex))

if reportNotFound {
b.WriteString(` RETURNING "` + primaryKeyName + `"`)
}

b.WriteByte(';')

return b.String()
Expand Down
8 changes: 4 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ go 1.23
require (
github.com/gertd/go-pluralize v0.2.1
github.com/jackc/pgx/v5 v5.7.1
golang.org/x/mod v0.21.0
golang.org/x/mod v0.22.0
)

require (
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
golang.org/x/crypto v0.28.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/text v0.19.0 // indirect
golang.org/x/crypto v0.29.0 // indirect
golang.org/x/sync v0.9.0 // indirect
golang.org/x/text v0.20.0 // indirect
)
16 changes: 8 additions & 8 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/crypto v0.29.0 h1:L5SG1JTTXupVV3n6sUqMTeWbjAyfPwoda2DLX8J8FrQ=
golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg=
golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4=
golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ=
golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug=
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
Expand Down
26 changes: 25 additions & 1 deletion repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,31 @@ func (repo *Repository[T]) UpdateOnlyColumns(ctx context.Context, columnsToUpdat
}

valuesAsInterfaces := toInterfaces(values)
return repo.db.updateTableRecords(ctx, repo.td, columnsToUpdate, valuesAsInterfaces)
return repo.db.updateTableRecords(ctx, repo.td, columnsToUpdate, false, valuesAsInterfaces)
}

// UpdateOnlyColumnsReportNoRows updates one or more values of type T in the database by their primary key values.
// It returns an ErrNoRows if there is no matching row on the given criteria.
func (repo *Repository[T]) UpdateOnlyColumnsReportNoRows(ctx context.Context, columnsToUpdate []string, values ...T) (bool, error) {
if repo.IsReadOnly() {
return false, ErrIsReadOnly
}

if len(values) == 0 {
return false, nil
}

valuesAsInterfaces := toInterfaces(values)
_, err := repo.db.updateTableRecords(ctx, repo.td, columnsToUpdate, true, valuesAsInterfaces)
if err != nil {
if errors.Is(err, ErrNoRows) {
return false, nil
}

return false, err
}

return true, nil
}

func toInterfaces[T any](values []T) []any {
Expand Down

0 comments on commit 1587893

Please sign in to comment.