diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go index 63fb7ed21c..04688a379c 100644 --- a/internal/db/bundb/notification.go +++ b/internal/db/bundb/notification.go @@ -22,7 +22,6 @@ import ( "errors" "slices" - "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -108,6 +107,11 @@ func (n *notificationDB) GetNotificationsByIDs(ctx context.Context, ids []string notifs, err := n.state.Caches.GTS.Notification.LoadIDs("ID", ids, func(uncached []string) ([]*gtsmodel.Notification, error) { + // Skip query if everything was cached. + if len(uncached) == 0 { + return nil, nil + } + // Preallocate expected length of uncached notifications. notifs := make([]*gtsmodel.Notification, 0, len(uncached)) @@ -282,26 +286,18 @@ func (n *notificationDB) PutNotification(ctx context.Context, notif *gtsmodel.No } func (n *notificationDB) DeleteNotificationByID(ctx context.Context, id string) error { - defer n.state.Caches.GTS.Notification.Invalidate("ID", id) - - // Load notif into cache before attempting a delete, - // as we need it cached in order to trigger the invalidate - // callback. This in turn invalidates others. - _, err := n.GetNotificationByID(gtscontext.SetBarebones(ctx), id) - if err != nil { - if errors.Is(err, db.ErrNoEntries) { - // not an issue. - err = nil - } + // Delete notif from DB. + if _, err := n.db. + NewDelete(). + Table("notifications"). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx); err != nil { return err } - // Finally delete notif from DB. - _, err = n.db.NewDelete(). - TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")). - Where("? = ?", bun.Ident("notification.id"), id). - Exec(ctx) - return err + // Invalidate deleted notification by ID. + n.state.Caches.GTS.Notification.Invalidate("ID", id) + return nil } func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) error { @@ -309,11 +305,8 @@ func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string return errors.New("DeleteNotifications: one of targetAccountID or originAccountID must be set") } - var notifIDs []string - q := n.db. - NewSelect(). - Column("id"). + NewDelete(). Table("notifications") if len(types) > 0 { @@ -328,61 +321,33 @@ func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string q = q.Where("? = ?", bun.Ident("origin_account_id"), originAccountID) } - if _, err := q.Exec(ctx, ¬ifIDs); err != nil { - return err - } - - // Invalidate all cached notifications by IDs on return. - defer n.state.Caches.GTS.Notification.InvalidateIDs("ID", notifIDs) + var notifIDs []string + q = q.Returning("?", bun.Ident("id")) - // Load all notif into cache, this *really* isn't great - // but it is the only way we can ensure we invalidate all - // related caches correctly (e.g. visibility). - for _, id := range notifIDs { - _, err := n.GetNotificationByID(ctx, id) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - return err - } + // Delete from DB. + if _, err := q. + Exec(ctx, ¬ifIDs); err != nil { + return err } - // Finally delete all from DB. - _, err := n.db.NewDelete(). - Table("notifications"). - Where("? IN (?)", bun.Ident("id"), bun.In(notifIDs)). - Exec(ctx) - return err + // Invalidate all deleted notifications by IDs. + n.state.Caches.GTS.Notification.InvalidateIDs("ID", notifIDs) + return nil } func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statusID string) error { var notifIDs []string - q := n.db. - NewSelect(). - Column("id"). + if _, err := n.db. + NewDelete(). Table("notifications"). - Where("? = ?", bun.Ident("status_id"), statusID) - - if _, err := q.Exec(ctx, ¬ifIDs); err != nil { + Where("? = ?", bun.Ident("status_id"), statusID). + Returning("?", bun.Ident("id")). + Exec(ctx, ¬ifIDs); err != nil { return err } - // Invalidate all cached notifications by IDs on return. - defer n.state.Caches.GTS.Notification.InvalidateIDs("ID", notifIDs) - - // Load all notif into cache, this *really* isn't great - // but it is the only way we can ensure we invalidate all - // related caches correctly (e.g. visibility). - for _, id := range notifIDs { - _, err := n.GetNotificationByID(ctx, id) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - return err - } - } - - // Finally delete all from DB. - _, err := n.db.NewDelete(). - Table("notifications"). - Where("? IN (?)", bun.Ident("id"), bun.In(notifIDs)). - Exec(ctx) - return err + // Invalidate all deleted notifications by IDs. + n.state.Caches.GTS.Notification.InvalidateIDs("ID", notifIDs) + return nil } diff --git a/internal/db/bundb/notification_test.go b/internal/db/bundb/notification_test.go index 9cc2e4743c..984c0ef8dd 100644 --- a/internal/db/bundb/notification_test.go +++ b/internal/db/bundb/notification_test.go @@ -73,7 +73,7 @@ func (suite *NotificationTestSuite) spamNotifs() { Read: util.Ptr(false), } - if err := suite.db.Put(context.Background(), notif); err != nil { + if err := suite.db.PutNotification(context.Background(), notif); err != nil { panic(err) } } @@ -133,9 +133,8 @@ func (suite *NotificationTestSuite) TestGetAccountNotificationsWithoutSpam() { func (suite *NotificationTestSuite) TestDeleteNotificationsWithSpam() { suite.spamNotifs() testAccount := suite.testAccounts["local_account_1"] - err := suite.db.DeleteNotifications(context.Background(), nil, testAccount.ID, "") - suite.NoError(err) + // Test getting notifs first. notifications, err := suite.db.GetAccountNotifications( gtscontext.SetBarebones(context.Background()), testAccount.ID, @@ -145,8 +144,29 @@ func (suite *NotificationTestSuite) TestDeleteNotificationsWithSpam() { 20, nil, ) - suite.NoError(err) - suite.Nil(notifications) + if err != nil { + suite.FailNow(err.Error()) + } + suite.Len(notifications, 20) + + // Now delete. + if err := suite.db.DeleteNotifications(context.Background(), nil, testAccount.ID, ""); err != nil { + suite.FailNow(err.Error()) + } + + // Now try getting again. + notifications, err = suite.db.GetAccountNotifications( + gtscontext.SetBarebones(context.Background()), + testAccount.ID, + id.Highest, + id.Lowest, + "", + 20, + nil, + ) + if err != nil { + suite.FailNow(err.Error()) + } suite.Empty(notifications) }