diff --git a/db.go b/db.go index 89a52d1..f8b633f 100644 --- a/db.go +++ b/db.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "reflect" + "slices" "strings" "sync" @@ -514,3 +515,94 @@ func (db *DB) Unlisten(ctx context.Context, channel string) error { _, err := db.Exec(ctx, query, channel) return err } + +// UpdateJSONB updates a JSONB column (full or partial) in the database by building and executing an +// SQL query based on the provided values and the given tableName and columnName. +// The values parameter is a map of key-value pairs where the key is the json field name and the value is its new value, +// new keys are accepted. Note that tableName and columnName are not escaped. +func (db *DB) UpdateJSONB(ctx context.Context, tableName, columnName, rowID string, values map[string]any, fieldsToUpdate []string) (int64, error) { + td, err := db.schema.GetByTableName(tableName) + if err != nil { + return 0, err + } + primaryKey, ok := td.PrimaryKey() + if !ok { + return 0, fmt.Errorf("primary key is required in order to perform update jsonb on table: %s", tableName) + } + + var ( + tag pgconn.CommandTag + ) + + // We could extract the id from the column and do a select based on that but let's keep things simple and do it per row id. + // id, ok := values[primaryKey.Name] + // if !ok { + // return 0, fmt.Errorf("missing primary key value") + // } + + // Partial Update. + if len(fieldsToUpdate) > 0 { + /* + // Loop over the keys and construct the path and value arrays. + path := []string{} + value := []interface{}{} + for _, key := range fieldsToUpdate { + // Get the value for the key from the map. + v, ok := values[key] + if !ok { + return 0, fmt.Errorf("missing value for key: %s", key) + } + // Append the key to the path array. + path = append(path, key) + // Append the value to the value array. + value = append(value, v) + } + + // Convert the path and value arrays to JSON. + // pathJSON, jsonErr := json.Marshal(path) + // if jsonErr != nil { + // return 0, fmt.Errorf("error converting path to json: %w", jsonErr) + // } + valueJSON, jsonErr := json.Marshal(value) + if jsonErr != nil { + return 0, fmt.Errorf("error converting value to json: %w", jsonErr) + } + + // Construct the query using jsonb_set. + query := fmt.Sprintf("UPDATE %s SET %s = jsonb_set (%s, $1::text[], $2::jsonb, true) WHERE id = $3;", tableName, columnName, columnName) + + fmt.Println(query, path, string(valueJSON), rowID) + + // Execute the query with the path, value and id parameters. + tag, err = db.Exec(ctx, query, path, string(valueJSON), rowID) + */ + + // Check if all the keys are present in the map. + for _, key := range fieldsToUpdate { + // Get the value for the key from the map. + _, ok := values[key] + if !ok { + return 0, fmt.Errorf("missing value for key: %s", key) + } + } + + // Delete the keys that are not present in the fieldsToUpdate slice. + for key := range values { + if !slices.Contains(fieldsToUpdate, key) { + delete(values, key) + } + } + + query := fmt.Sprintf("UPDATE %s SET %s = %s || $1 WHERE %s = $2;", tableName, columnName, columnName, primaryKey.Name) + tag, err = db.Exec(ctx, query, values, rowID) + } else { + // Full Update. + query := fmt.Sprintf("UPDATE %s SET %s = $1 WHERE %s = $2;", tableName, columnName, primaryKey.Name) + tag, err = db.Exec(ctx, query, values, rowID) + } + if err != nil { + return 0, fmt.Errorf("update jsonb: %w", err) + } + + return tag.RowsAffected(), nil +} diff --git a/desc/scanner.go b/desc/scanner.go index 498fb6d..e1c25c9 100644 --- a/desc/scanner.go +++ b/desc/scanner.go @@ -144,9 +144,10 @@ func findScanTargets(dstElemValue reflect.Value, td *Table, fieldDescs []pgconn. } } - if col.Type == UUID && col.Nullable { + if col.Nullable && (col.Type == UUID || + col.Type == Text || col.Type == CharacterVarying) /* Allow receive null on uuid, text and varchar columns even if the field is not a string pointer. */ { scanTargets[i] = &nullableScanner{ - uuidFieldPtr: dstElemValue.FieldByIndex(col.FieldIndex), + fieldPtr: dstElemValue.FieldByIndex(col.FieldIndex), } continue @@ -164,15 +165,15 @@ type noOpScanner struct{} func (t *noOpScanner) Scan(src interface{}) error { return nil } type nullableScanner struct { // useful for UUIDs with null values. - uuidFieldPtr reflect.Value + fieldPtr reflect.Value } func (t *nullableScanner) Scan(src interface{}) error { - if src == nil { + if src == nil { // <- IMPORTANT. return nil } - t.uuidFieldPtr.Set(reflect.ValueOf(src)) + t.fieldPtr.Set(reflect.ValueOf(src)) return nil }