Skip to content

Commit 7c30fb0

Browse files
Merge pull request #1462 from github/meiji163/multi-stmt
Use multiStatement to apply DML
2 parents a834c00 + 2e62f2a commit 7c30fb0

File tree

1 file changed

+62
-22
lines changed

1 file changed

+62
-22
lines changed

go/logic/applier.go

+62-22
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,13 @@ import (
1414

1515
"github.com/github/gh-ost/go/base"
1616
"github.com/github/gh-ost/go/binlog"
17-
"github.com/github/gh-ost/go/mysql"
1817
"github.com/github/gh-ost/go/sql"
1918

20-
"github.com/openark/golib/log"
19+
"context"
20+
"database/sql/driver"
21+
22+
"github.com/github/gh-ost/go/mysql"
23+
drivermysql "github.com/go-sql-driver/mysql"
2124
"github.com/openark/golib/sqlutils"
2225
)
2326

@@ -77,7 +80,8 @@ func NewApplier(migrationContext *base.MigrationContext) *Applier {
7780

7881
func (this *Applier) InitDBConnections() (err error) {
7982
applierUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName)
80-
if this.db, _, err = mysql.GetDB(this.migrationContext.Uuid, applierUri); err != nil {
83+
uriWithMulti := fmt.Sprintf("%s&multiStatements=true", applierUri)
84+
if this.db, _, err = mysql.GetDB(this.migrationContext.Uuid, uriWithMulti); err != nil {
8185
return err
8286
}
8387
singletonApplierUri := fmt.Sprintf("%s&timeout=0", applierUri)
@@ -1207,44 +1211,80 @@ func (this *Applier) buildDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) []*dmlB
12071211
// ApplyDMLEventQueries applies multiple DML queries onto the _ghost_ table
12081212
func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent)) error {
12091213
var totalDelta int64
1214+
ctx := context.Background()
12101215

12111216
err := func() error {
1212-
tx, err := this.db.Begin()
1217+
conn, err := this.db.Conn(ctx)
12131218
if err != nil {
12141219
return err
12151220
}
1221+
defer conn.Close()
1222+
1223+
sessionQuery := "SET /* gh-ost */ SESSION time_zone = '+00:00'"
1224+
sessionQuery = fmt.Sprintf("%s, %s", sessionQuery, this.generateSqlModeQuery())
1225+
if _, err := conn.ExecContext(ctx, sessionQuery); err != nil {
1226+
return err
1227+
}
12161228

1229+
tx, err := conn.BeginTx(ctx, nil)
1230+
if err != nil {
1231+
return err
1232+
}
12171233
rollback := func(err error) error {
12181234
tx.Rollback()
12191235
return err
12201236
}
12211237

1222-
sessionQuery := "SET /* gh-ost */ SESSION time_zone = '+00:00'"
1223-
sessionQuery = fmt.Sprintf("%s, %s", sessionQuery, this.generateSqlModeQuery())
1224-
1225-
if _, err := tx.Exec(sessionQuery); err != nil {
1226-
return rollback(err)
1227-
}
1238+
buildResults := make([]*dmlBuildResult, 0, len(dmlEvents))
1239+
nArgs := 0
12281240
for _, dmlEvent := range dmlEvents {
12291241
for _, buildResult := range this.buildDMLEventQuery(dmlEvent) {
12301242
if buildResult.err != nil {
12311243
return rollback(buildResult.err)
12321244
}
1233-
result, err := tx.Exec(buildResult.query, buildResult.args...)
1234-
if err != nil {
1235-
err = fmt.Errorf("%w; query=%s; args=%+v", err, buildResult.query, buildResult.args)
1236-
return rollback(err)
1237-
}
1245+
nArgs += len(buildResult.args)
1246+
buildResults = append(buildResults, buildResult)
1247+
}
1248+
}
12381249

1239-
rowsAffected, err := result.RowsAffected()
1240-
if err != nil {
1241-
log.Warningf("error getting rows affected from DML event query: %s. i'm going to assume that the DML affected a single row, but this may result in inaccurate statistics", err)
1242-
rowsAffected = 1
1250+
// We batch together the DML queries into multi-statements to minimize network trips.
1251+
// We have to use the raw driver connection to access the rows affected
1252+
// for each statement in the multi-statement.
1253+
execErr := conn.Raw(func(driverConn any) error {
1254+
ex := driverConn.(driver.ExecerContext)
1255+
nvc := driverConn.(driver.NamedValueChecker)
1256+
1257+
multiArgs := make([]driver.NamedValue, 0, nArgs)
1258+
multiQueryBuilder := strings.Builder{}
1259+
for _, buildResult := range buildResults {
1260+
for _, arg := range buildResult.args {
1261+
nv := driver.NamedValue{Value: driver.Value(arg)}
1262+
nvc.CheckNamedValue(&nv)
1263+
multiArgs = append(multiArgs, nv)
12431264
}
1244-
// each DML is either a single insert (delta +1), update (delta +0) or delete (delta -1).
1245-
// multiplying by the rows actually affected (either 0 or 1) will give an accurate row delta for this DML event
1246-
totalDelta += buildResult.rowsDelta * rowsAffected
1265+
1266+
multiQueryBuilder.WriteString(buildResult.query)
1267+
multiQueryBuilder.WriteString(";\n")
12471268
}
1269+
1270+
res, err := ex.ExecContext(ctx, multiQueryBuilder.String(), multiArgs)
1271+
if err != nil {
1272+
err = fmt.Errorf("%w; query=%s; args=%+v", err, multiQueryBuilder.String(), multiArgs)
1273+
return err
1274+
}
1275+
1276+
mysqlRes := res.(drivermysql.Result)
1277+
1278+
// each DML is either a single insert (delta +1), update (delta +0) or delete (delta -1).
1279+
// multiplying by the rows actually affected (either 0 or 1) will give an accurate row delta for this DML event
1280+
for i, rowsAffected := range mysqlRes.AllRowsAffected() {
1281+
totalDelta += buildResults[i].rowsDelta * rowsAffected
1282+
}
1283+
return nil
1284+
})
1285+
1286+
if execErr != nil {
1287+
return rollback(execErr)
12481288
}
12491289
if err := tx.Commit(); err != nil {
12501290
return err

0 commit comments

Comments
 (0)