Skip to content

Commit 450a18f

Browse files
Merge pull request #185 from hanchuanchuan/fix-trans-mix-ddl-dml
fix: 修复在事务中DDL和DML混合执行时可能出错的问题 (#182)
2 parents ec01295 + c165a35 commit 450a18f

File tree

5 files changed

+183
-23
lines changed

5 files changed

+183
-23
lines changed

session/conn.go

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ func (s *session) Raw(sqlStr string) (rows *sql.Rows, err error) {
8787
return
8888
}
8989

90-
// Raw 执行sql语句,连接失败时自动重连,自动重置当前数据库
90+
// Exec 执行sql语句,连接失败时自动重连,自动重置当前数据库
9191
func (s *session) Exec(sqlStr string, retry bool) (res sql.Result, err error) {
9292
// 连接断开无效时,自动重试
9393
for i := 0; i < maxBadConnRetries; i++ {
@@ -114,6 +114,33 @@ func (s *session) Exec(sqlStr string, retry bool) (res sql.Result, err error) {
114114
return
115115
}
116116

117+
// ExecDDL 执行sql语句,连接失败时自动重连,自动重置当前数据库
118+
func (s *session) ExecDDL(sqlStr string, retry bool) (res sql.Result, err error) {
119+
// 连接断开无效时,自动重试
120+
for i := 0; i < maxBadConnRetries; i++ {
121+
res, err = s.ddlDB.DB().Exec(sqlStr)
122+
if err == nil {
123+
return
124+
} else {
125+
log.Errorf("con:%d %v sql:%s", s.sessionVars.ConnectionID, err, sqlStr)
126+
if err == mysqlDriver.ErrInvalidConn {
127+
err1 := s.initConnection()
128+
if err1 != nil {
129+
return res, err1
130+
}
131+
if retry {
132+
s.AppendErrorMessage(mysqlDriver.ErrInvalidConn.Error())
133+
continue
134+
} else {
135+
return
136+
}
137+
}
138+
return
139+
}
140+
}
141+
return
142+
}
143+
117144
// Raw 执行sql语句,连接失败时自动重连,自动重置当前数据库
118145
func (s *session) RawScan(sqlStr string, dest interface{}) (err error) {
119146
// 连接断开无效时,自动重试
@@ -180,3 +207,54 @@ func (s *session) initConnection() (err error) {
180207
}
181208
return
182209
}
210+
211+
// // SwitchDatabase USE切换到当前数据库. (避免连接断开后当前数据库置空)
212+
// func (s *session) SwitchDatabase(db *gorm.DB) error {
213+
// name := s.DBName
214+
// if name == "" {
215+
// name = s.opt.db
216+
// }
217+
// if name == "" {
218+
// return nil
219+
// }
220+
221+
// // log.Infof("SwitchDatabase: %v", name)
222+
// _, err := db.DB().Exec(fmt.Sprintf("USE `%s`", name))
223+
// if err != nil {
224+
// log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err)
225+
// if myErr, ok := err.(*mysqlDriver.MySQLError); ok {
226+
// s.AppendErrorMessage(myErr.Message)
227+
// } else {
228+
// s.AppendErrorMessage(err.Error())
229+
// }
230+
// }
231+
// return err
232+
// }
233+
234+
// // GetDatabase 获取当前数据库
235+
// func (s *session) GetDatabase() string {
236+
// log.Debug("GetDatabase")
237+
238+
// var value string
239+
// sql := "select database();"
240+
241+
// rows, err := s.Raw(sql)
242+
// if rows != nil {
243+
// defer rows.Close()
244+
// }
245+
246+
// if err != nil {
247+
// log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err)
248+
// if myErr, ok := err.(*mysqlDriver.MySQLError); ok {
249+
// s.AppendErrorMessage(myErr.Message)
250+
// } else {
251+
// s.AppendErrorMessage(err.Error())
252+
// }
253+
// } else {
254+
// for rows.Next() {
255+
// rows.Scan(&value)
256+
// }
257+
// }
258+
259+
// return value
260+
// }

session/session.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,9 @@ type session struct {
169169
db *gorm.DB
170170
backupdb *gorm.DB
171171

172+
// 执行DDL操作的数据库连接. 仅用于事务功能
173+
ddlDB *gorm.DB
174+
172175
DBName string
173176

174177
myRecord *Record

session/session_inception.go

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"bytes"
2222
"crypto/tls"
2323
"crypto/x509"
24+
"database/sql"
2425
"database/sql/driver"
2526
"fmt"
2627
"io/ioutil"
@@ -525,6 +526,9 @@ func (s *session) executeInc(ctx context.Context, sql string) (recordSets []sqle
525526
if s.db != nil {
526527
defer s.db.Close()
527528
}
529+
if s.ddlDB != nil {
530+
defer s.ddlDB.Close()
531+
}
528532
if s.backupdb != nil {
529533
defer s.backupdb.Close()
530534
}
@@ -1655,6 +1659,8 @@ func (s *session) executeAllStatement(ctx context.Context) {
16551659
trans = make([]*Record, 0, s.opt.tranBatch)
16561660
}
16571661

1662+
// 用于事务. 判断是否为DML语句
1663+
// lastIsDMLTrans := false
16581664
for i, record := range s.recordSets.All() {
16591665

16601666
// 忽略不需要备份的类型
@@ -1684,11 +1690,13 @@ func (s *session) executeAllStatement(ctx context.Context) {
16841690
}
16851691
}
16861692
}
1693+
1694+
// lastIsDMLTrans = true
16871695
case *ast.UseStmt, *ast.SetStmt:
16881696
// 环境命令
16891697
// 事务内部和非事务均需要执行
16901698
// log.Infof("1111: [%s] [%d] %s,RowsAffected: %d", s.DBName, s.fetchThreadID(), record.Sql, record.AffectedRows)
1691-
_, err := s.Exec(record.Sql, true)
1699+
_, err := s.ExecDDL(record.Sql, true)
16921700
if err != nil {
16931701
// log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err)
16941702
if myErr, ok := err.(*mysqlDriver.MySQLError); ok {
@@ -1716,7 +1724,14 @@ func (s *session) executeAllStatement(ctx context.Context) {
17161724
trans = nil
17171725
}
17181726

1719-
s.executeRemoteCommand(record)
1727+
// 如果前端是DML语句,则在执行DDL前切换一次数据库
1728+
// log.Infof("lastIsDMLTrans: %v", lastIsDMLTrans)
1729+
// if lastIsDMLTrans {
1730+
// s.SwitchDatabase(s.ddlDB)
1731+
// lastIsDMLTrans = false
1732+
// }
1733+
1734+
s.executeRemoteCommand(record, true)
17201735

17211736
// trans = append(trans, record)
17221737
// s.executeTransaction(trans)
@@ -1731,7 +1746,7 @@ func (s *session) executeAllStatement(ctx context.Context) {
17311746
}
17321747
}
17331748
} else {
1734-
s.executeRemoteCommand(record)
1749+
s.executeRemoteCommand(record, false)
17351750
}
17361751

17371752
if s.hasErrorBefore() {
@@ -1946,7 +1961,7 @@ func (s *session) executeTransaction(records []*Record) int {
19461961
return 0
19471962
}
19481963

1949-
func (s *session) executeRemoteCommand(record *Record) int {
1964+
func (s *session) executeRemoteCommand(record *Record, isTran bool) int {
19501965

19511966
s.myRecord = record
19521967
record.Stage = StageExec
@@ -1972,7 +1987,7 @@ func (s *session) executeRemoteCommand(record *Record) int {
19721987
*ast.SetStmt,
19731988
*ast.DropIndexStmt:
19741989

1975-
s.executeRemoteStatement(record)
1990+
s.executeRemoteStatement(record, isTran)
19761991

19771992
default:
19781993
log.Infof("无匹配类型: %T\n", node)
@@ -2181,10 +2196,10 @@ func statisticsTableSQL() string {
21812196
return buf.String()
21822197
}
21832198

2184-
func (s *session) executeRemoteStatement(record *Record) {
2199+
func (s *session) executeRemoteStatement(record *Record, isTran bool) {
21852200
log.Debug("executeRemoteStatement")
21862201

2187-
sql := record.Sql
2202+
sqlStmt := record.Sql
21882203

21892204
start := time.Now()
21902205

@@ -2205,7 +2220,13 @@ func (s *session) executeRemoteStatement(record *Record) {
22052220

22062221
return
22072222
} else {
2208-
res, err := s.Exec(sql, false)
2223+
var res sql.Result
2224+
var err error
2225+
if isTran {
2226+
res, err = s.ExecDDL(sqlStmt, false)
2227+
} else {
2228+
res, err = s.Exec(sqlStmt, false)
2229+
}
22092230

22102231
record.ExecTime = fmt.Sprintf("%.3f", time.Since(start).Seconds())
22112232
record.ExecTimestamp = time.Now().Unix()
@@ -2295,7 +2316,7 @@ func (s *session) executeRemoteStatementAndBackup(record *Record) {
22952316
return
22962317
}
22972318

2298-
s.executeRemoteStatement(record)
2319+
s.executeRemoteStatement(record, false)
22992320

23002321
if !s.hasError() || record.ExecComplete {
23012322
if s.opt.backup {
@@ -2906,6 +2927,11 @@ func (s *session) parseOptions(sql string) {
29062927
return
29072928
}
29082929

2930+
if s.opt.tranBatch > 1 {
2931+
s.ddlDB, _ = gorm.Open("mysql", fmt.Sprintf("%s&autocommit=1", addr))
2932+
s.ddlDB.LogMode(false)
2933+
}
2934+
29092935
// 禁用日志记录器,不显示任何日志
29102936
db.LogMode(false)
29112937

session/session_inception_common_test.go

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -491,12 +491,12 @@ func (s *testCommon) assertRows(c *C, rows [][]interface{}, rollbackSqls ...stri
491491
if tableName == "" {
492492
sql := "select tablename from `%s`.`%s` where opid_time = ?"
493493
sql = fmt.Sprintf(sql, backupDBName, s.remoteBackupTable)
494-
rows, err := s.db.Raw(sql, opid).Rows()
494+
tableRows, err := s.db.Raw(sql, opid).Rows()
495495
c.Assert(err, IsNil)
496-
for rows.Next() {
497-
rows.Scan(&tableName)
496+
for tableRows.Next() {
497+
tableRows.Scan(&tableName)
498498
}
499-
rows.Close()
499+
tableRows.Close()
500500
}
501501
c.Assert(tableName, Not(Equals), "", Commentf("%v", row))
502502

@@ -507,10 +507,9 @@ func (s *testCommon) assertRows(c *C, rows [][]interface{}, rollbackSqls ...stri
507507

508508
// 如果表改变了,或者超过500行了
509509
if lastTable != currentTable || len(ids) >= 500 {
510-
lastTable = currentTable
511510
if len(ids) > 0 {
512511
sql := "select rollback_statement from %s where opid_time in (?) order by opid_time,id;"
513-
sql = fmt.Sprintf(sql, currentTable)
512+
sql = fmt.Sprintf(sql, lastTable)
514513
rows, err := s.db.Raw(sql, ids).Rows()
515514
c.Assert(err, IsNil)
516515

@@ -522,11 +521,12 @@ func (s *testCommon) assertRows(c *C, rows [][]interface{}, rollbackSqls ...stri
522521
}
523522
rows.Close()
524523

525-
c.Assert(len(result1), Not(Equals), 0, Commentf("-----------: %v", sql))
524+
c.Assert(len(result1), Not(Equals), 0, Commentf("-----------: %v,%v", sql, ids))
526525
result = append(result, result1...)
527526

528527
ids = nil
529528
}
529+
lastTable = currentTable
530530
}
531531

532532
ids = append(ids, opid)
@@ -536,22 +536,22 @@ func (s *testCommon) assertRows(c *C, rows [][]interface{}, rollbackSqls ...stri
536536
if len(ids) > 0 {
537537
sql := "select rollback_statement from %s where opid_time in (?) order by opid_time,id;"
538538
sql = fmt.Sprintf(sql, currentTable)
539-
rows, err := s.db.Raw(sql, ids).Rows()
539+
rollbackRows, err := s.db.Raw(sql, ids).Rows()
540540
c.Assert(err, IsNil)
541541

542542
str := ""
543543
result1 := []string{}
544-
for rows.Next() {
545-
rows.Scan(&str)
544+
for rollbackRows.Next() {
545+
rollbackRows.Scan(&str)
546546
result1 = append(result1, s.trim(str))
547547
}
548-
rows.Close()
548+
rollbackRows.Close()
549549

550-
c.Assert(len(result1), Not(Equals), 0, Commentf("------2-----: %v", sql))
550+
c.Assert(len(result1), Not(Equals), 0, Commentf("------2-----: %v", rows))
551551
result = append(result, result1...)
552552
}
553553

554-
c.Assert(len(result), Equals, len(rollbackSqls), Commentf("%v", rows))
554+
c.Assert(len(result), Equals, len(rollbackSqls), Commentf("%v", result))
555555

556556
// 如果是UPDATE多表操作,此时回滚的SQL可能是无序的
557557
if len(result) > 1 && strings.HasPrefix(result[0], "UPDATE") {

session/session_inception_tran_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,3 +444,56 @@ func (s *testSessionIncTranSuite) TestDelete(c *C) {
444444
c.Assert(backup, Equals, "INSERT INTO `test_inc`.`t1`(`id`,`c1`) VALUES(1,'😁😄🙂👩');", Commentf("%v", res.Rows()))
445445

446446
}
447+
448+
func (s *testSessionIncTranSuite) TestCreateTable(c *C) {
449+
saved := config.GetGlobalConfig().Inc
450+
defer func() {
451+
config.GetGlobalConfig().Inc = saved
452+
}()
453+
454+
var (
455+
res *testkit.Result
456+
// row []interface{}
457+
// backup string
458+
)
459+
460+
res = s.mustRunBackupTran(c, `DROP TABLE IF EXISTS t1,t2;
461+
462+
CREATE TABLE t1 (id int(11) NOT NULL,
463+
c1 int(11) DEFAULT NULL,
464+
c2 int(11) DEFAULT NULL,
465+
PRIMARY KEY (id));
466+
467+
INSERT INTO t1 VALUES (1, 1, 1);
468+
469+
CREATE TABLE t2 (id int(11) NOT NULL,
470+
c1 int(11) DEFAULT NULL,
471+
c2 int(11) DEFAULT NULL,
472+
PRIMARY KEY (id))`)
473+
s.assertRows(c, res.Rows()[2:],
474+
"DROP TABLE `test_inc`.`t1`;",
475+
"DELETE FROM `test_inc`.`t1` WHERE `id`=1;",
476+
"DROP TABLE `test_inc`.`t2`;")
477+
478+
res = s.mustRunBackupTran(c, `DROP TABLE IF EXISTS t1,t2;
479+
create table t1(id int primary key,c1 int);
480+
insert into t1 values(1,1),(2,2);
481+
delete from t1 where id=1;
482+
alter table t1 add column c2 int;
483+
insert into t1 values(3,3,3);
484+
delete from t1 where id>0;
485+
create table t2(id int primary key,c1 int);
486+
insert into t2 values(3,3);`)
487+
s.assertRows(c, res.Rows()[2:],
488+
"DROP TABLE `test_inc`.`t1`;",
489+
"DELETE FROM `test_inc`.`t1` WHERE `id`=1;",
490+
"DELETE FROM `test_inc`.`t1` WHERE `id`=2;",
491+
"INSERT INTO `test_inc`.`t1`(`id`,`c1`) VALUES(1,1);",
492+
"ALTER TABLE `test_inc`.`t1` DROP COLUMN `c2`;",
493+
"DELETE FROM `test_inc`.`t1` WHERE `id`=3;",
494+
"INSERT INTO `test_inc`.`t1`(`id`,`c1`,`c2`) VALUES(2,2,NULL);",
495+
"INSERT INTO `test_inc`.`t1`(`id`,`c1`,`c2`) VALUES(3,3,3);",
496+
"DROP TABLE `test_inc`.`t2`;",
497+
"DELETE FROM `test_inc`.`t2` WHERE `id`=3;")
498+
499+
}

0 commit comments

Comments
 (0)