Skip to content

Commit

Permalink
new db and repository method: ListenTable to listen for table's row c…
Browse files Browse the repository at this point in the history
…hanges (INSERT, UPDATE, DELETE)
  • Loading branch information
kataras committed Oct 28, 2023
1 parent 84634dd commit e75bffa
Show file tree
Hide file tree
Showing 6 changed files with 352 additions and 14 deletions.
49 changes: 40 additions & 9 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"reflect"
"strings"
"sync"

"github.com/kataras/pg/desc"

Expand Down Expand Up @@ -87,6 +88,10 @@ type DB struct {
tx pgx.Tx
dbTxClosed bool

tableChangeNotifyOnceMutex *sync.RWMutex
tableChangeNotifyFunctionOnce *uint32
tableChangeNotifyTriggerOnce map[string]struct{}

schema *Schema
}

Expand Down Expand Up @@ -176,10 +181,13 @@ func OpenPool(schema *Schema, pool *pgxpool.Pool) *DB {
}

db := &DB{ // create a new DB instance with the fields
Pool: pool, // set the pool field
ConnectionOptions: config, // set the connection options field
searchPath: searchPath, // set the search path field
schema: schema, // set the schema field
Pool: pool, // set the pool field
ConnectionOptions: config, // set the connection options field
searchPath: searchPath, // set the search path field
schema: schema, // set the schema field
tableChangeNotifyOnceMutex: new(sync.RWMutex),
tableChangeNotifyFunctionOnce: new(uint32),
tableChangeNotifyTriggerOnce: make(map[string]struct{}),
}

return db // return the DB instance
Expand All @@ -194,11 +202,13 @@ func (db *DB) Close() {
// and returns a new DB pointer to instance.
func (db *DB) clone(tx pgx.Tx) *DB {
clone := &DB{
Pool: db.Pool,
ConnectionOptions: db.ConnectionOptions,
tx: tx,
schema: db.schema,
searchPath: db.searchPath,
Pool: db.Pool,
ConnectionOptions: db.ConnectionOptions,
tx: tx,
schema: db.schema,
searchPath: db.searchPath,
tableChangeNotifyFunctionOnce: db.tableChangeNotifyFunctionOnce,
tableChangeNotifyTriggerOnce: db.tableChangeNotifyTriggerOnce,
}

return clone
Expand Down Expand Up @@ -440,6 +450,27 @@ func (db *DB) ExecFiles(ctx context.Context, fileReader interface {
}

// Listen listens for notifications on the given channel and returns a Listener instance.
//
// Example Code:
//
// conn, err := db.Listen(context.Background(), channel)
// if err != nil {
// fmt.Println(fmt.Errorf("listen: %w\n", err))
// return
// }
//
// // To just terminate this listener's connection and unlisten from the channel:
// defer conn.Close(context.Background())
//
// for {
// notification, err := conn.Accept(context.Background())
// if err != nil {
// fmt.Println(fmt.Errorf("accept: %w\n", err))
// return
// }
//
// fmt.Printf("channel: %s, payload: %s\n", notification.Channel, notification.Payload)
// }
func (db *DB) Listen(ctx context.Context, channel string) (*Listener, error) {
conn, err := db.Pool.Acquire(ctx) // Always on top.
if err != nil {
Expand Down
134 changes: 134 additions & 0 deletions db_table_listener.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package pg

import (
"context"
"encoding/json"
"fmt"
"sync/atomic"
)

// TableChangeType is the type of the table change.
// Available values: INSERT, UPDATE, DELETE.
type TableChangeType string

const (
// TableChangeTypeInsert is the INSERT table change type.
TableChangeTypeInsert TableChangeType = "INSERT"
// TableChangeTypeUpdate is the UPDATE table change type.
TableChangeTypeUpdate TableChangeType = "UPDATE"
// TableChangeTypeDelete is the DELETE table change type.
TableChangeTypeDelete TableChangeType = "DELETE"
)

type (
// TableNotification is the notification message sent by the postgresql server
// when a table change occurs.
// The subscribed postgres channel is named 'table_change_notifications'.
// The "old" and "new" fields are the old and new values of the row.
// The "old" field is only available for UPDATE and DELETE table change types.
// The "new" field is only available for INSERT and UPDATE table change types.
// The "old" and "new" fields are raw json values, use the "json.Unmarshal" to decode them.
// See "DB.ListenTable" method.
TableNotification[T any] struct {
Table string `json:"table"`
Change TableChangeType `json:"change"` // INSERT, UPDATE, DELETE.

New T `json:"new"`
Old T `json:"old"`
}

// TableNotificationJSON is the generic version of the TableNotification.
TableNotificationJSON = TableNotification[json.RawMessage]
)

// ListenTable registers a function which notifies on the given "table" changes (INSERT, UPDATE, DELETE),
// the subscribed postgres channel is named 'table_change_notifications'.
//
// The callback function can return ErrStop to stop the listener without actual error.
// The callback function can return any other error to stop the listener and return the error.
// The callback function can return nil to continue listening.
//
// TableNotification's New and Old fields are raw json values, use the "json.Unmarshal" to decode them
// to the actual type.
func (db *DB) ListenTable(ctx context.Context, table string, callback func(TableNotificationJSON, error) error) (Closer, error) {
channelName := "table_change_notifications"

if atomic.LoadUint32(db.tableChangeNotifyFunctionOnce) == 0 {
// First, check and create the trigger for all tables.
query := fmt.Sprintf(`
CREATE OR REPLACE FUNCTION table_change_notify() RETURNS trigger AS $$
DECLARE
payload text;
channel text := '%s';
BEGIN
SELECT json_build_object('table', TG_TABLE_NAME, 'change', TG_OP, 'old', OLD, 'new', NEW)::text
INTO payload;
PERFORM pg_notify(channel, payload);
RETURN NEW;
END;
$$
LANGUAGE plpgsql;`, channelName)

_, err := db.Exec(ctx, query)
if err != nil {
return nil, fmt.Errorf("create or replace function table_change_notify: %w", err)
}

atomic.StoreUint32(db.tableChangeNotifyFunctionOnce, 1)
}

db.tableChangeNotifyOnceMutex.RLock()
_, triggerCreated := db.tableChangeNotifyTriggerOnce[table]
db.tableChangeNotifyOnceMutex.RUnlock()
if !triggerCreated {
query := `CREATE TRIGGER ` + table + `_table_change_notify
BEFORE INSERT OR
UPDATE OR
DELETE
ON ` + table + `
FOR EACH ROW
EXECUTE FUNCTION table_change_notify();`

_, err := db.Exec(ctx, query)
if err != nil {
return nil, fmt.Errorf("create trigger %s_table_change_notify: %w", table, err)
}

db.tableChangeNotifyOnceMutex.Lock()
db.tableChangeNotifyTriggerOnce[table] = struct{}{}
db.tableChangeNotifyOnceMutex.Unlock()
}

conn, err := db.Listen(ctx, channelName)
if err != nil {
return nil, err
}

go func() {
defer conn.Close(ctx)

for {
var evt TableNotificationJSON

notification, err := conn.Accept(context.Background())
if err != nil {
if callback(evt, err) != nil {
return
}
}

if err = json.Unmarshal([]byte(notification.Payload), &evt); err != nil {
if callback(evt, err) != nil {
return
}
}

if err = callback(evt, nil); err != nil {
return
}
}
}()

return conn, nil
}
58 changes: 58 additions & 0 deletions db_table_listener_example_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package pg

import (
"context"
"fmt"
"time"
)

func ExampleDB_ListenTable() {
db, err := openTestConnection()
if err != nil {
handleExampleError(err)
return
}
defer db.Close()

closer, err := db.ListenTable(context.Background(), "customers", func(evt TableNotificationJSON, err error) error {
if err != nil {
fmt.Printf("received error: %v\n", err)
return err
}

if evt.Change == "INSERT" {
fmt.Printf("table: %s, event: %s, old: %s\n", evt.Table, evt.Change, string(evt.Old)) // new can't be predicated through its ID and timestamps.
} else {
fmt.Printf("table: %s, event: %s\n", evt.Table, evt.Change)
}

return nil
})
if err != nil {
fmt.Println(err)
return
}
defer closer.Close(context.Background())

newCustomer := Customer{
CognitoUserID: "766064d4-a2a7-442d-aa75-33493bb4dbb9",
Email: "[email protected]",
Name: "Makis",
}
err = db.InsertSingle(context.Background(), newCustomer, &newCustomer.ID)
if err != nil {
fmt.Println(err)
return
}

newCustomer.Name = "Makis_UPDATED"
_, err = db.UpdateOnlyColumns(context.Background(), []string{"name"}, newCustomer)
if err != nil {
fmt.Println(err)
return
}
time.Sleep(5 * time.Second) // give it sometime to receive the notifications.
// Output:
// table: customers, event: INSERT, old: null
// table: customers, event: UPDATE
}
27 changes: 22 additions & 5 deletions listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"sync/atomic"
"unsafe"

"github.com/jackc/pgx/v5/pgconn"
Expand All @@ -13,13 +14,22 @@ import (
// Notification is a type alias of pgconn.Notification type.
type Notification = pgconn.Notification

// Closer is the interface which is implemented by the Listener.
// It's used to close the underline connection.
type Closer interface {
Close(ctx context.Context) error
}

// Listener represents a postgres database LISTEN connection.
type Listener struct {
conn *pgxpool.Conn

channel string
closed uint32
}

var _ Closer = (*Listener)(nil)

// ErrEmptyPayload is returned when the notification payload is empty.
var ErrEmptyPayload = fmt.Errorf("empty payload")

Expand All @@ -46,15 +56,22 @@ func (l *Listener) Accept(ctx context.Context) (*Notification, error) {

// Close closes the listener connection.
func (l *Listener) Close(ctx context.Context) error {
if l == nil {
return nil
}

if l.conn == nil {
return nil
}
defer l.conn.Release()

query := `SELECT UNLISTEN $1;`
_, err := l.conn.Exec(ctx, query, l.channel)
if err != nil {
return err
if atomic.CompareAndSwapUint32(&l.closed, 0, 1) {
defer l.conn.Release()

query := `SELECT UNLISTEN $1;`
_, err := l.conn.Exec(ctx, query, l.channel)
if err != nil {
return err
}
}

return nil
Expand Down
43 changes: 43 additions & 0 deletions repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package pg

import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"

"github.com/kataras/pg/desc"
Expand Down Expand Up @@ -296,3 +298,44 @@ func toInterfaces[T any](values []T) []any {
func (repo *Repository[T]) Duplicate(ctx context.Context, id any, newIDPtr any) error {
return repo.db.duplicateTableRecord(ctx, repo.td, id, newIDPtr)
}

// ListenTable registers a function which notifies on the current table's changes (INSERT, UPDATE, DELETE),
// the subscribed postgres channel is named 'table_change_notifications'.
// The callback function is called on a separate goroutine.
//
// The callback function can return ErrStop to stop the listener without actual error.
// The callback function can return any other error to stop the listener and return the error.
// The callback function can return nil to continue listening.
func (repo *Repository[T]) ListenTable(ctx context.Context, callback func(TableNotification[T], error) error) (Closer, error) {
return repo.db.ListenTable(ctx, repo.td.Name, func(tableEvt TableNotificationJSON, err error) error {
if err != nil {
failEvt := TableNotification[T]{
Table: repo.td.Name,
Change: tableEvt.Change, // may empty.
}

return callback(failEvt, err)
}

evt := TableNotification[T]{
Table: tableEvt.Table,
Change: tableEvt.Change,
}

if len(tableEvt.Old) > 0 {
err := json.Unmarshal(tableEvt.Old, &evt.Old)
if err != nil {
return fmt.Errorf("table: %s: unmarshal old: %w", tableEvt.Table, err)
}
}

if len(tableEvt.New) > 0 {
err := json.Unmarshal(tableEvt.New, &evt.New)
if err != nil {
return fmt.Errorf("table: %s: unmarshal old: %w", tableEvt.Table, err)
}
}

return callback(evt, nil)
})
}
Loading

0 comments on commit e75bffa

Please sign in to comment.