Skip to content
This repository was archived by the owner on Jun 27, 2025. It is now read-only.

Commit 4594709

Browse files
samhzaIoIxD
authored andcommitted
allow for caching messages in postgres
1 parent 913815e commit 4594709

File tree

7 files changed

+552
-62
lines changed

7 files changed

+552
-62
lines changed

database/database.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package database
2+
3+
import (
4+
"context"
5+
"time"
6+
7+
"github.com/diamondburned/arikawa/v3/discord"
8+
)
9+
10+
type Database interface {
11+
Close() error
12+
13+
SetUpdatedAt(ctx context.Context, post discord.ChannelID, t time.Time) error
14+
UpdatedAt(ctx context.Context, post discord.ChannelID) (time.Time, error)
15+
UpdateMessages(ctx context.Context, post discord.ChannelID, msgs []discord.Message) error
16+
InsertMessage(ctx context.Context, msg discord.Message) error
17+
UpdateMessage(ctx context.Context, msg discord.Message) error
18+
DeleteMessage(ctx context.Context, msg discord.MessageID) error
19+
MessagesAfter(ctx context.Context, post discord.ChannelID, after discord.MessageID, limit uint) ([]discord.Message, bool, error)
20+
MessagesBefore(ctx context.Context, post discord.ChannelID, before discord.MessageID, limit uint) ([]discord.Message, bool, error)
21+
}

database/postgres.go

Lines changed: 350 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,350 @@
1+
package database
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"encoding/json"
7+
"errors"
8+
"fmt"
9+
"log"
10+
"time"
11+
12+
"github.com/diamondburned/arikawa/v3/discord"
13+
_ "github.com/lib/pq"
14+
)
15+
16+
const postgresConfigSchema = `
17+
CREATE TABLE IF NOT EXISTS "Config" (
18+
id SMALLINT PRIMARY KEY,
19+
version INTEGER NOT NULL,
20+
CHECK(id = 1)
21+
);
22+
`
23+
24+
const postgresSchema = `
25+
CREATE TABLE "Message" (
26+
id BIGINT NOT NULL PRIMARY KEY,
27+
edited_at TIMESTAMP WITH TIME ZONE,
28+
author BIGINT NOT NULL,
29+
channel BIGINT NOT NULL,
30+
content TEXT NOT NULL,
31+
json TEXT NOT NULL
32+
);
33+
34+
CREATE TABLE "Channel" (
35+
id BIGINT NOT NULL PRIMARY KEY,
36+
updated_at TIMESTAMP NOT NULL
37+
);
38+
`
39+
40+
var postgresMigrations = []string{""}
41+
42+
type Postgres struct {
43+
db *sql.DB
44+
connectedAt time.Time
45+
}
46+
47+
func (db *Postgres) Close() error {
48+
return db.db.Close()
49+
}
50+
51+
func (db *Postgres) SetUpdatedAt(ctx context.Context, post discord.ChannelID, time time.Time) error {
52+
_, err := db.db.ExecContext(ctx, `UPDATE "Channel" SET updated_at = $1 WHERE id = $2`, time, post)
53+
return err
54+
}
55+
56+
func (db *Postgres) UpdatedAt(ctx context.Context, post discord.ChannelID) (time.Time, error) {
57+
var t time.Time
58+
err := db.db.QueryRowContext(ctx, `SELECT updated_at FROM "Channel" WHERE id = $1`, post).Scan(&t)
59+
if err != nil && !errors.Is(err, sql.ErrNoRows) {
60+
return time.Time{}, err
61+
}
62+
return t, nil
63+
}
64+
65+
func (db *Postgres) UpdateMessages(ctx context.Context, post discord.ChannelID, msgs []discord.Message) error {
66+
tx, err := db.db.BeginTx(ctx, nil)
67+
if err != nil {
68+
return err
69+
}
70+
defer tx.Rollback()
71+
var exists bool
72+
err = tx.QueryRowContext(ctx, `SELECT EXISTS(SELECT 1 FROM "Channel" WHERE id = $1)`, post).Scan(&exists)
73+
if err != nil {
74+
return fmt.Errorf("reading channel cache information: %w", err)
75+
}
76+
if exists {
77+
_, err = tx.ExecContext(ctx, `UPDATE "Channel" SET updated_at = $1 WHERE id = $2`, time.Now().UTC(), post)
78+
} else {
79+
_, err = tx.ExecContext(ctx, `INSERT INTO "Channel" (id, updated_at) VALUES ($1, $2)`, post, time.Now().UTC())
80+
}
81+
if err != nil {
82+
return fmt.Errorf("writing channel cache information: %w", err)
83+
}
84+
insert, err := tx.PrepareContext(ctx, `INSERT INTO "Message" (id, author, channel, edited_at, content, json) VALUES($1, $2, $3, $4, $5, $6)`)
85+
if err != nil {
86+
return err
87+
}
88+
if !exists {
89+
for _, msg := range msgs {
90+
content := msg.Content
91+
msg.Content = ""
92+
jsonb, err := json.Marshal(msg)
93+
if err != nil {
94+
return err
95+
}
96+
_, err = insert.ExecContext(ctx, msg.ID, msg.Author.ID, msg.ChannelID, msg.EditedTimestamp.Time(), content, jsonb)
97+
if err != nil {
98+
return fmt.Errorf("inserting message: %w", err)
99+
}
100+
}
101+
return tx.Commit()
102+
}
103+
rows, err := tx.QueryContext(ctx, `SELECT id, edited_at FROM "Message" WHERE channel = $1 ORDER BY id ASC`, post)
104+
if err != nil {
105+
return err
106+
}
107+
var toDelete []discord.MessageID
108+
var toInsert []discord.Message
109+
var toUpdate []discord.Message
110+
for _, msg := range msgs {
111+
var id discord.MessageID
112+
var updated time.Time
113+
exists := false
114+
for rows.Next() {
115+
if err = rows.Scan(&id, &updated); err != nil {
116+
return err
117+
}
118+
if id != msg.ID {
119+
toDelete = append(toDelete, id)
120+
continue
121+
}
122+
exists = true
123+
}
124+
if !exists {
125+
toInsert = append(toInsert, msg)
126+
continue
127+
}
128+
if updated.Before(msg.EditedTimestamp.Time()) {
129+
toUpdate = append(toUpdate, msg)
130+
}
131+
}
132+
for rows.Next() {
133+
var id discord.MessageID
134+
var updated time.Time
135+
if err = rows.Scan(&id, &updated); err != nil {
136+
return err
137+
}
138+
toDelete = append(toDelete, id)
139+
}
140+
if len(toDelete) > 0 {
141+
del, err := tx.PrepareContext(ctx, `DELETE FROM "Message" WHERE ID = $1`)
142+
if err != nil {
143+
return err
144+
}
145+
for _, id := range toDelete {
146+
if _, err := del.ExecContext(ctx, id); err != nil {
147+
return err
148+
}
149+
}
150+
}
151+
if len(toUpdate) > 0 {
152+
update, err := tx.PrepareContext(ctx, `UPDATE "Message" SET content = $1, edited_at = $2, json = $3 WHERE id = $4`)
153+
if err != nil {
154+
return err
155+
}
156+
for _, msg := range toUpdate {
157+
content := msg.Content
158+
msg.Content = ""
159+
jsonb, err := json.Marshal(msg)
160+
if err != nil {
161+
return fmt.Errorf("marshaling message as JSON: %v", err)
162+
}
163+
if _, err := update.ExecContext(ctx, content, msg.EditedTimestamp.Time(), jsonb, msg.ID); err != nil {
164+
return err
165+
}
166+
}
167+
}
168+
if len(toInsert) > 0 {
169+
for _, msg := range toInsert {
170+
content := msg.Content
171+
msg.Content = ""
172+
jsonb, err := json.Marshal(msg)
173+
if err != nil {
174+
return err
175+
}
176+
_, err = insert.ExecContext(ctx, msg.ID, msg.Author.ID, msg.ChannelID, msg.EditedTimestamp.Time(), content, jsonb)
177+
if err != nil {
178+
return fmt.Errorf("inserting message: %w", err)
179+
}
180+
}
181+
}
182+
return tx.Commit()
183+
}
184+
185+
func (db *Postgres) InsertMessage(ctx context.Context, msg discord.Message) error {
186+
content := msg.Content
187+
msg.Content = ""
188+
jsonb, err := json.Marshal(msg)
189+
if err != nil {
190+
return fmt.Errorf("marshaling message as JSON: %v", err)
191+
}
192+
_, err = db.db.ExecContext(ctx, `INSERT INTO "Message" (id, author, channel, edited_at, content, json) VALUES($1, $2, $3, $4, $5, $6) ON CONFLICT DO NOTHING`,
193+
msg.ID, msg.Author.ID, msg.ChannelID, msg.EditedTimestamp.Time(), content, jsonb)
194+
return err
195+
}
196+
197+
func (db *Postgres) DeleteMessage(ctx context.Context, msg discord.MessageID) error {
198+
_, err := db.db.ExecContext(ctx, `DELETE FROM "Message" WHERE id = $1`, msg)
199+
return err
200+
}
201+
202+
func (db *Postgres) UpdateMessage(ctx context.Context, msg discord.Message) error {
203+
tx, err := db.db.BeginTx(ctx, nil)
204+
if err != nil {
205+
return err
206+
}
207+
var edited time.Time
208+
err = tx.QueryRowContext(ctx, `SELECT edited_at FROM "Message" WHERE id = $1`, msg.ID).Scan(&edited)
209+
if err != nil && errors.Is(err, sql.ErrNoRows) {
210+
return err
211+
}
212+
if msg.EditedTimestamp.Time().Before(edited) {
213+
return nil
214+
}
215+
content := msg.Content
216+
msg.Content = ""
217+
jsonb, err := json.Marshal(msg)
218+
if err != nil {
219+
return fmt.Errorf("marshaling message as JSON: %v", err)
220+
}
221+
_, err = db.db.ExecContext(ctx, `UPDATE "Message" SET content = $1, edited_at = $2, json = $3 WHERE id = $4`,
222+
content, msg.EditedTimestamp.Time(), jsonb, msg.ID)
223+
return err
224+
}
225+
226+
func (db *Postgres) MessagesAfter(ctx context.Context, ch discord.ChannelID, msg discord.MessageID, limit uint) (msgs []discord.Message, hasbefore bool, err error) {
227+
tx, err := db.db.BeginTx(ctx, nil)
228+
if err != nil {
229+
return
230+
}
231+
defer tx.Rollback()
232+
err = tx.QueryRowContext(ctx, `SELECT EXISTS (SELECT 1 FROM "Message" WHERE channel = $1 AND id <= $2)`, ch, msg).Scan(&hasbefore)
233+
if err != nil {
234+
return
235+
}
236+
rows, err := db.db.QueryContext(ctx, `SELECT content, json FROM "Message" WHERE channel = $1 AND id > $2 ORDER BY id ASC LIMIT $3`,
237+
ch, msg, limit)
238+
if err != nil {
239+
err = fmt.Errorf("querying messages: %v", err)
240+
return
241+
}
242+
for rows.Next() {
243+
var content string
244+
var jsonb []byte
245+
if err = rows.Scan(&content, &jsonb); err != nil {
246+
fmt.Errorf("error scanning message: %w", err)
247+
return
248+
}
249+
var msg discord.Message
250+
if err = json.Unmarshal(jsonb, &msg); err != nil {
251+
fmt.Errorf("unmrshaling message content: %w", err)
252+
return
253+
}
254+
msg.Content = content
255+
msgs = append(msgs, msg)
256+
}
257+
return
258+
}
259+
260+
func (db *Postgres) MessagesBefore(ctx context.Context, ch discord.ChannelID, msg discord.MessageID, limit uint) (msgs []discord.Message, hasafter bool, err error) {
261+
tx, err := db.db.BeginTx(ctx, nil)
262+
if err != nil {
263+
return
264+
}
265+
defer tx.Rollback()
266+
err = tx.QueryRowContext(ctx, `SELECT EXISTS (SELECT 1 FROM "Message" WHERE channel = $1 AND id >= $2)`, ch, msg).Scan(&hasafter)
267+
if err != nil {
268+
return
269+
}
270+
rows, err := tx.QueryContext(ctx, `SELECT content, json FROM (SELECT id, content, json FROM "Message" WHERE channel = $1 AND id < $2 ORDER BY id DESC LIMIT $3) ORDER BY id ASC`,
271+
ch, msg, limit)
272+
if err != nil {
273+
err = fmt.Errorf("querying messages: %v", err)
274+
return
275+
}
276+
for rows.Next() {
277+
var content string
278+
var jsonb []byte
279+
if err = rows.Scan(&content); err != nil {
280+
err = fmt.Errorf("error scanning message content: %w", err)
281+
return
282+
}
283+
if err = rows.Scan(&jsonb); err != nil {
284+
err = fmt.Errorf("error scanning message json: %w", err)
285+
return
286+
}
287+
var msg discord.Message
288+
if err = json.Unmarshal(jsonb, &msg); err != nil {
289+
err = fmt.Errorf("unmrshaling message content: %w", err)
290+
return
291+
}
292+
msg.Content = content
293+
msgs = append(msgs, msg)
294+
}
295+
return
296+
}
297+
298+
func OpenPostgres(source string) (Database, error) {
299+
sqldb, err := sql.Open("postgres", source)
300+
if err != nil {
301+
return nil, err
302+
}
303+
sqldb.SetMaxOpenConns(25)
304+
305+
db := &Postgres{db: sqldb, connectedAt: time.Now()}
306+
if err := db.upgrade(); err != nil {
307+
sqldb.Close()
308+
return nil, err
309+
}
310+
return db, nil
311+
}
312+
313+
func (db *Postgres) upgrade() error {
314+
tx, err := db.db.Begin()
315+
if err != nil {
316+
return fmt.Errorf("couldn't start db transaction: %w", err)
317+
}
318+
defer tx.Rollback()
319+
_, err = tx.Exec(postgresConfigSchema)
320+
if err != nil {
321+
return err
322+
}
323+
var version int
324+
err = tx.QueryRow(`SELECT version FROM "Config"`).Scan(&version)
325+
if err != nil && !errors.Is(err, sql.ErrNoRows) {
326+
return fmt.Errorf("couldn't query schema version: %v", err)
327+
}
328+
if version > len(postgresMigrations) {
329+
log.Fatalln("database is from a newer dforum")
330+
}
331+
if version == 0 {
332+
if _, err := tx.Exec(postgresSchema); err != nil {
333+
return fmt.Errorf("failed while executing schema: %v", err)
334+
}
335+
} else if version < len(postgresMigrations) {
336+
for version < len(postgresMigrations) {
337+
_, err := tx.Exec(postgresMigrations[version])
338+
if err != nil {
339+
return fmt.Errorf("failed while executing migration %d: %v", version, err)
340+
}
341+
version++
342+
}
343+
}
344+
_, err = tx.Exec(`INSERT INTO "Config" (id, version) VALUES (1, $1)
345+
ON CONFLICT (id) DO UPDATE SET version = $1`, len(postgresMigrations))
346+
if err != nil {
347+
return fmt.Errorf("failed to change schema version: %v", err)
348+
}
349+
return tx.Commit()
350+
}

0 commit comments

Comments
 (0)