Skip to content

Commit

Permalink
add UpdateJSONB
Browse files Browse the repository at this point in the history
  • Loading branch information
kataras committed Nov 14, 2023
1 parent bf1a726 commit e30e9c9
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 5 deletions.
92 changes: 92 additions & 0 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"reflect"
"slices"
"strings"
"sync"

Expand Down Expand Up @@ -514,3 +515,94 @@ func (db *DB) Unlisten(ctx context.Context, channel string) error {
_, err := db.Exec(ctx, query, channel)
return err
}

// UpdateJSONB updates a JSONB column (full or partial) in the database by building and executing an
// SQL query based on the provided values and the given tableName and columnName.
// The values parameter is a map of key-value pairs where the key is the json field name and the value is its new value,
// new keys are accepted. Note that tableName and columnName are not escaped.
func (db *DB) UpdateJSONB(ctx context.Context, tableName, columnName, rowID string, values map[string]any, fieldsToUpdate []string) (int64, error) {
td, err := db.schema.GetByTableName(tableName)
if err != nil {
return 0, err
}
primaryKey, ok := td.PrimaryKey()
if !ok {
return 0, fmt.Errorf("primary key is required in order to perform update jsonb on table: %s", tableName)
}

var (
tag pgconn.CommandTag
)

// We could extract the id from the column and do a select based on that but let's keep things simple and do it per row id.
// id, ok := values[primaryKey.Name]
// if !ok {
// return 0, fmt.Errorf("missing primary key value")
// }

// Partial Update.
if len(fieldsToUpdate) > 0 {
/*
// Loop over the keys and construct the path and value arrays.
path := []string{}
value := []interface{}{}
for _, key := range fieldsToUpdate {
// Get the value for the key from the map.
v, ok := values[key]
if !ok {
return 0, fmt.Errorf("missing value for key: %s", key)
}
// Append the key to the path array.
path = append(path, key)
// Append the value to the value array.
value = append(value, v)
}
// Convert the path and value arrays to JSON.
// pathJSON, jsonErr := json.Marshal(path)
// if jsonErr != nil {
// return 0, fmt.Errorf("error converting path to json: %w", jsonErr)
// }
valueJSON, jsonErr := json.Marshal(value)
if jsonErr != nil {
return 0, fmt.Errorf("error converting value to json: %w", jsonErr)
}
// Construct the query using jsonb_set.
query := fmt.Sprintf("UPDATE %s SET %s = jsonb_set (%s, $1::text[], $2::jsonb, true) WHERE id = $3;", tableName, columnName, columnName)
fmt.Println(query, path, string(valueJSON), rowID)
// Execute the query with the path, value and id parameters.
tag, err = db.Exec(ctx, query, path, string(valueJSON), rowID)
*/

// Check if all the keys are present in the map.
for _, key := range fieldsToUpdate {
// Get the value for the key from the map.
_, ok := values[key]
if !ok {
return 0, fmt.Errorf("missing value for key: %s", key)
}
}

// Delete the keys that are not present in the fieldsToUpdate slice.
for key := range values {
if !slices.Contains(fieldsToUpdate, key) {
delete(values, key)
}
}

query := fmt.Sprintf("UPDATE %s SET %s = %s || $1 WHERE %s = $2;", tableName, columnName, columnName, primaryKey.Name)
tag, err = db.Exec(ctx, query, values, rowID)
} else {
// Full Update.
query := fmt.Sprintf("UPDATE %s SET %s = $1 WHERE %s = $2;", tableName, columnName, primaryKey.Name)
tag, err = db.Exec(ctx, query, values, rowID)
}
if err != nil {
return 0, fmt.Errorf("update jsonb: %w", err)
}

return tag.RowsAffected(), nil
}
11 changes: 6 additions & 5 deletions desc/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,10 @@ func findScanTargets(dstElemValue reflect.Value, td *Table, fieldDescs []pgconn.
}
}

if col.Type == UUID && col.Nullable {
if col.Nullable && (col.Type == UUID ||
col.Type == Text || col.Type == CharacterVarying) /* Allow receive null on uuid, text and varchar columns even if the field is not a string pointer. */ {
scanTargets[i] = &nullableScanner{
uuidFieldPtr: dstElemValue.FieldByIndex(col.FieldIndex),
fieldPtr: dstElemValue.FieldByIndex(col.FieldIndex),
}

continue
Expand All @@ -164,15 +165,15 @@ type noOpScanner struct{}
func (t *noOpScanner) Scan(src interface{}) error { return nil }

type nullableScanner struct { // useful for UUIDs with null values.
uuidFieldPtr reflect.Value
fieldPtr reflect.Value
}

func (t *nullableScanner) Scan(src interface{}) error {
if src == nil {
if src == nil { // <- IMPORTANT.
return nil
}

t.uuidFieldPtr.Set(reflect.ValueOf(src))
t.fieldPtr.Set(reflect.ValueOf(src))

return nil
}
Expand Down

0 comments on commit e30e9c9

Please sign in to comment.