diff --git a/db.go b/db.go index ff00dca..023111d 100644 --- a/db.go +++ b/db.go @@ -520,10 +520,18 @@ func (db *DB) Unlisten(ctx context.Context, channel string) error { // SQL query based on the provided values and the given tableName and columnName. // The values parameter is a map of key-value pairs where the key is the json field name and the value is its new value, // new keys are accepted. Note that tableName and columnName are not escaped. -func (db *DB) UpdateJSONB(ctx context.Context, tableName, columnName, rowID string, values map[string]any, fieldsToUpdate []string, primaryKey *desc.Column) (int64, error) { +func (db *DB) UpdateJSONB(ctx context.Context, tableName, columnName, rowID string, values map[string]any, fieldsToUpdate []string) (int64, error) { + td, err := db.schema.GetByTableName(tableName) + if err != nil { + return 0, err + } + primaryKey, ok := td.PrimaryKey() + if !ok { + return 0, fmt.Errorf("primary key is required in order to perform update jsonb on table: %s", tableName) + } + var ( tag pgconn.CommandTag - err error ) // We could extract the id from the column and do a select based on that but let's keep things simple and do it per row id. @@ -559,13 +567,13 @@ func (db *DB) UpdateJSONB(ctx context.Context, tableName, columnName, rowID stri query = strings.TrimSuffix(query, ", ") // Add the WHERE clause. - query += " WHERE id = $1" + query += fmt.Sprintf(" WHERE %s = $1", primaryKey.Name) // Execute the query with the id parameter. tag, err = db.Exec(ctx, query, rowID) } else { // Full Update. - query := fmt.Sprintf("UPDATE %s SET %s = $1 WHERE id = $2;", tableName, columnName) + query := fmt.Sprintf("UPDATE %s SET %s = $1 WHERE %s = $2;", tableName, columnName, primaryKey.Name) tag, err = db.Exec(ctx, query, values, rowID) } if err != nil {