@@ -2,6 +2,8 @@ package testutils
22
33import (
44 "bytes"
5+ "context"
6+ "database/sql"
57 "encoding/json"
68 "fmt"
79 "github.com/go-jet/jet/v2/internal/jet"
@@ -25,6 +27,18 @@ var UnixTimeComparer = cmp.Comparer(func(t1, t2 time.Time) bool {
2527 return t1 .Unix () == t2 .Unix ()
2628})
2729
30+ // AssertExecAndRollback will execute and rollback statement in sql transaction
31+ func AssertExecAndRollback (t * testing.T , stmt jet.Statement , db * sql.DB , rowsAffected ... int64 ) {
32+ tx , err := db .Begin ()
33+ require .NoError (t , err )
34+ defer func () {
35+ err := tx .Rollback ()
36+ require .NoError (t , err )
37+ }()
38+
39+ AssertExec (t , stmt , tx , rowsAffected ... )
40+ }
41+
2842// AssertExec assert statement execution for successful execution and number of rows affected
2943func AssertExec (t * testing.T , stmt jet.Statement , db qrm.DB , rowsAffected ... int64 ) {
3044 res , err := stmt .Exec (db )
@@ -38,13 +52,32 @@ func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int
3852 }
3953}
4054
55+ // ExecuteInTxAndRollback will execute function in sql transaction and then rollback transaction
56+ func ExecuteInTxAndRollback (t * testing.T , db * sql.DB , f func (tx * sql.Tx )) {
57+ tx , err := db .Begin ()
58+ require .NoError (t , err )
59+ defer func () {
60+ err := tx .Rollback ()
61+ require .NoError (t , err )
62+ }()
63+
64+ f (tx )
65+ }
66+
4167// AssertExecErr assert statement execution for failed execution with error string errorStr
4268func AssertExecErr (t * testing.T , stmt jet.Statement , db qrm.DB , errorStr string ) {
4369 _ , err := stmt .Exec (db )
4470
4571 require .Error (t , err , errorStr )
4672}
4773
74+ // AssertExecContextErr assert statement execution for failed execution with error string errorStr
75+ func AssertExecContextErr (t * testing.T , stmt jet.Statement , ctx context.Context , db qrm.DB , errorStr string ) {
76+ _ , err := stmt .ExecContext (ctx , db )
77+
78+ require .Error (t , err , errorStr )
79+ }
80+
4881func getFullPath (relativePath string ) string {
4982 path , _ := os .Getwd ()
5083 return filepath .Join (path , "../" , relativePath )
0 commit comments