Skip to content

Commit 7e90ff0

Browse files
authored
feat: Delete All User Data (RTBF) (#225)
* delete all user data * check return from rollback
1 parent d43c0d9 commit 7e90ff0

File tree

5 files changed

+115
-8
lines changed

5 files changed

+115
-8
lines changed

pkg/store/postgres/purge_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import (
77
)
88

99
func TestPurgeDeleted(t *testing.T) {
10-
sessionID, err := setupTestDeleteData(testCtx, testDB)
10+
sessionID, err := setupSessionDeleteTestData(testCtx, testDB, "")
1111
assert.NoError(t, err, "setupTestDeleteData should not return an error")
1212

1313
sessionStore := NewSessionDAO(testDB)

pkg/store/postgres/session.go

+29-4
Original file line numberDiff line numberDiff line change
@@ -201,41 +201,66 @@ func (dao *SessionDAO) updateSession(
201201
}
202202

203203
// Delete soft-deletes a session from the database by its sessionID.
204-
// It also soft-deletes all messages and message embeddings associated with the session.
204+
// It also soft-deletes all messages, message embeddings, and summaries associated with the session.
205205
func (dao *SessionDAO) Delete(ctx context.Context, sessionID string) error {
206206
dbSession := &SessionSchema{}
207207

208-
r, err := dao.db.NewDelete().
208+
tx, err := dao.db.BeginTx(ctx, nil)
209+
if err != nil {
210+
return fmt.Errorf("failed to begin transaction: %w", err)
211+
}
212+
213+
r, err := tx.NewDelete().
209214
Model(dbSession).
210215
Where("session_id = ?", sessionID).
211216
Exec(ctx)
212217
if err != nil {
218+
rollbackErr := tx.Rollback()
219+
if rollbackErr != nil {
220+
return fmt.Errorf("failed to delete session: %v, failed to rollback transaction: %w", err, rollbackErr)
221+
}
213222
return fmt.Errorf("failed to delete session: %w", err)
214223
}
215224

216225
rowsAffected, err := r.RowsAffected()
217226
if err != nil {
227+
rollbackErr := tx.Rollback()
228+
if rollbackErr != nil {
229+
return fmt.Errorf("failed to delete session: %v, failed to rollback transaction: %w", err, rollbackErr)
230+
}
218231
return fmt.Errorf("failed to get rows affected: %w", err)
219232
}
220233
if rowsAffected == 0 {
234+
rollbackErr := tx.Rollback()
235+
if rollbackErr != nil {
236+
return fmt.Errorf("failed to delete session: %v, failed to rollback transaction: %w", err, rollbackErr)
237+
}
221238
return models.NewNotFoundError("session " + sessionID)
222239
}
223240

224-
// delete all messages and message embeddings associated with the session
241+
// delete all messages, message embeddings, and summaries associated with the session
225242
for _, schema := range messageTableList {
226243
if _, ok := schema.(*SessionSchema); ok {
227244
continue
228245
}
229246
log.Debugf("deleting session %s from schema %T", sessionID, schema)
230-
_, err := dao.db.NewDelete().
247+
_, err := tx.NewDelete().
231248
Model(schema).
232249
Where("session_id = ?", sessionID).
233250
Exec(ctx)
234251
if err != nil {
252+
rollbackErr := tx.Rollback()
253+
if rollbackErr != nil {
254+
return fmt.Errorf("failed to delete session: %v, failed to rollback transaction: %w", err, rollbackErr)
255+
}
235256
return fmt.Errorf("error deleting rows from %T: %w", schema, err)
236257
}
237258
}
238259

260+
if err := tx.Commit(); err != nil {
261+
return fmt.Errorf("failed to commit transaction: %w", err)
262+
}
263+
239264
return nil
240265
}
241266

pkg/store/postgres/session_test.go

+9-3
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ func TestSessionDAO_DeleteSessionDeletesSummaryMessages(t *testing.T) {
246246

247247
sessionStore := NewSessionDAO(testDB)
248248

249-
sessionID, err := setupTestDeleteData(testCtx, testDB)
249+
sessionID, err := setupSessionDeleteTestData(testCtx, testDB, "")
250250
assert.NoError(t, err, "setupTestDeleteData should not return an error")
251251

252252
err = sessionStore.Delete(testCtx, sessionID)
@@ -268,7 +268,7 @@ func TestSessionDAO_DeleteSessionDeletesSummaryMessages(t *testing.T) {
268268
}
269269

270270
func TestSessionDAO_UndeleteSession(t *testing.T) {
271-
sessionID, err := setupTestDeleteData(testCtx, testDB)
271+
sessionID, err := setupSessionDeleteTestData(testCtx, testDB, "")
272272
assert.NoError(t, err, "setupTestDeleteData should not return an error")
273273

274274
sessionStore := NewSessionDAO(testDB)
@@ -292,16 +292,22 @@ func TestSessionDAO_UndeleteSession(t *testing.T) {
292292
assert.Nil(t, respMessages, "getMessages should return nil")
293293
}
294294

295-
func setupTestDeleteData(ctx context.Context, testDB *bun.DB) (string, error) {
295+
func setupSessionDeleteTestData(ctx context.Context, testDB *bun.DB, userID string) (string, error) {
296296
// Test data
297297
sessionID, err := testutils.GenerateRandomSessionID(16)
298298
if err != nil {
299299
return "", err
300300
}
301301

302+
var userIDPtr *string
303+
if userID != "" {
304+
userIDPtr = &userID
305+
}
306+
302307
dao := NewSessionDAO(testDB)
303308
_, err = dao.Create(ctx, &models.CreateSessionRequest{
304309
SessionID: sessionID,
310+
UserID: userIDPtr,
305311
})
306312
if err != nil {
307313
return "", err

pkg/store/postgres/userstore.go

+47
Original file line numberDiff line numberDiff line change
@@ -170,18 +170,65 @@ func (dao *UserStoreDAO) updateUser(
170170

171171
// Delete deletes a user.
172172
func (dao *UserStoreDAO) Delete(ctx context.Context, userID string) error {
173+
// Start a new transaction
174+
tx, err := dao.db.Begin()
175+
if err != nil {
176+
return err
177+
}
178+
179+
// Delete all related sessions
180+
sessions, err := dao.GetSessions(ctx, userID)
181+
if err != nil {
182+
rollbackErr := tx.Rollback()
183+
if rollbackErr != nil {
184+
return fmt.Errorf("failed to delete user: %v, failed to rollback transaction: %w", err, rollbackErr)
185+
}
186+
return err
187+
}
188+
189+
sessionStore := NewSessionDAO(dao.db)
190+
for s := range sessions {
191+
err := sessionStore.Delete(ctx, sessions[s].SessionID)
192+
if err != nil {
193+
rollbackErr := tx.Rollback()
194+
if rollbackErr != nil {
195+
return fmt.Errorf("failed to delete user: %v, failed to rollback transaction: %w", err, rollbackErr)
196+
}
197+
return err
198+
}
199+
}
200+
201+
// Delete User
173202
r, err := dao.db.NewDelete().Model(&models.User{}).Where("user_id = ?", userID).Exec(ctx)
174203
if err != nil {
204+
rollbackErr := tx.Rollback()
205+
if rollbackErr != nil {
206+
return fmt.Errorf("failed to delete user: %v, failed to rollback transaction: %w", err, rollbackErr)
207+
}
175208
return err
176209
}
177210
rowsAffected, err := r.RowsAffected()
178211
if err != nil {
212+
rollbackErr := tx.Rollback()
213+
if rollbackErr != nil {
214+
return fmt.Errorf("failed to delete user: %v, failed to rollback transaction: %w", err, rollbackErr)
215+
}
179216
return err
180217
}
181218
if rowsAffected == 0 {
219+
rollbackErr := tx.Rollback()
220+
if rollbackErr != nil {
221+
return fmt.Errorf("failed to delete user: %v, failed to rollback transaction: %w", err, rollbackErr)
222+
}
182223
return models.NewNotFoundError("user " + userID)
183224
}
184225

226+
// Commit the transaction
227+
err = tx.Commit()
228+
if err != nil {
229+
return err
230+
}
231+
185232
return nil
186233
}
187234

pkg/store/postgres/userstore_test.go

+29
Original file line numberDiff line numberDiff line change
@@ -135,11 +135,40 @@ func TestUserStoreDAO(t *testing.T) {
135135

136136
// Test Delete
137137
t.Run("Delete", func(t *testing.T) {
138+
testSessions := []string{}
139+
for i := 0; i < 2; i++ {
140+
sessionID, err := setupSessionDeleteTestData(testCtx, testDB, user.UserID)
141+
assert.NoError(t, err, "setupTestDeleteData should not return an error")
142+
testSessions = append(testSessions, sessionID)
143+
}
144+
138145
err := userStore.Delete(ctx, user.UserID)
139146
assert.NoError(t, err)
140147

141148
_, err = userStore.Get(ctx, user.UserID)
142149
assert.ErrorIs(t, err, models.ErrNotFound)
150+
151+
// Check that all related sessions are deleted
152+
retSessions, err := userStore.GetSessions(ctx, user.UserID)
153+
assert.NoError(t, err)
154+
assert.Equal(t, 0, len(retSessions))
155+
156+
// Test that messages and summaries are deleted
157+
for _, sessionID := range testSessions {
158+
respMessages, err := getMessages(testCtx, testDB, sessionID, 999, nil, 999)
159+
assert.NoError(t, err, "getMessages should not return an error")
160+
assert.Nil(t, respMessages, "getMessages should return nil")
161+
162+
// Test that summary is deleted
163+
respSummary, err := getSummary(testCtx, testDB, sessionID)
164+
assert.NoError(t, err, "getSummary should not return an error")
165+
assert.Nil(t, respSummary, "getSummary should return nil")
166+
167+
// check that embeddings are deleted
168+
respEmbeddings, err := getMessageEmbeddings(testCtx, testDB, sessionID)
169+
assert.NoError(t, err, "getMessageEmbeddings should not return an error")
170+
assert.Equal(t, 0, len(respEmbeddings), "getMessageEmbeddings should return 0 results")
171+
}
143172
})
144173

145174
t.Run("Delete Non-Existant Session should result in NotFoundError", func(t *testing.T) {

0 commit comments

Comments
 (0)