Skip to content

Commit

Permalink
Merge pull request #1462 from github/meiji163/multi-stmt
Browse files Browse the repository at this point in the history
Use multiStatement to apply DML
  • Loading branch information
arthurschreiber authored Oct 25, 2024
2 parents a834c00 + 2e62f2a commit 7c30fb0
Showing 1 changed file with 62 additions and 22 deletions.
84 changes: 62 additions & 22 deletions go/logic/applier.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@ import (

"github.com/github/gh-ost/go/base"
"github.com/github/gh-ost/go/binlog"
"github.com/github/gh-ost/go/mysql"
"github.com/github/gh-ost/go/sql"

"github.com/openark/golib/log"
"context"
"database/sql/driver"

"github.com/github/gh-ost/go/mysql"
drivermysql "github.com/go-sql-driver/mysql"
"github.com/openark/golib/sqlutils"
)

Expand Down Expand Up @@ -77,7 +80,8 @@ func NewApplier(migrationContext *base.MigrationContext) *Applier {

func (this *Applier) InitDBConnections() (err error) {
applierUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName)
if this.db, _, err = mysql.GetDB(this.migrationContext.Uuid, applierUri); err != nil {
uriWithMulti := fmt.Sprintf("%s&multiStatements=true", applierUri)
if this.db, _, err = mysql.GetDB(this.migrationContext.Uuid, uriWithMulti); err != nil {
return err
}
singletonApplierUri := fmt.Sprintf("%s&timeout=0", applierUri)
Expand Down Expand Up @@ -1207,44 +1211,80 @@ func (this *Applier) buildDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) []*dmlB
// ApplyDMLEventQueries applies multiple DML queries onto the _ghost_ table
func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent)) error {
var totalDelta int64
ctx := context.Background()

err := func() error {
tx, err := this.db.Begin()
conn, err := this.db.Conn(ctx)
if err != nil {
return err
}
defer conn.Close()

sessionQuery := "SET /* gh-ost */ SESSION time_zone = '+00:00'"
sessionQuery = fmt.Sprintf("%s, %s", sessionQuery, this.generateSqlModeQuery())
if _, err := conn.ExecContext(ctx, sessionQuery); err != nil {
return err
}

tx, err := conn.BeginTx(ctx, nil)
if err != nil {
return err
}
rollback := func(err error) error {
tx.Rollback()
return err
}

sessionQuery := "SET /* gh-ost */ SESSION time_zone = '+00:00'"
sessionQuery = fmt.Sprintf("%s, %s", sessionQuery, this.generateSqlModeQuery())

if _, err := tx.Exec(sessionQuery); err != nil {
return rollback(err)
}
buildResults := make([]*dmlBuildResult, 0, len(dmlEvents))
nArgs := 0
for _, dmlEvent := range dmlEvents {
for _, buildResult := range this.buildDMLEventQuery(dmlEvent) {
if buildResult.err != nil {
return rollback(buildResult.err)
}
result, err := tx.Exec(buildResult.query, buildResult.args...)
if err != nil {
err = fmt.Errorf("%w; query=%s; args=%+v", err, buildResult.query, buildResult.args)
return rollback(err)
}
nArgs += len(buildResult.args)
buildResults = append(buildResults, buildResult)
}
}

rowsAffected, err := result.RowsAffected()
if err != nil {
log.Warningf("error getting rows affected from DML event query: %s. i'm going to assume that the DML affected a single row, but this may result in inaccurate statistics", err)
rowsAffected = 1
// We batch together the DML queries into multi-statements to minimize network trips.
// We have to use the raw driver connection to access the rows affected
// for each statement in the multi-statement.
execErr := conn.Raw(func(driverConn any) error {
ex := driverConn.(driver.ExecerContext)
nvc := driverConn.(driver.NamedValueChecker)

multiArgs := make([]driver.NamedValue, 0, nArgs)
multiQueryBuilder := strings.Builder{}
for _, buildResult := range buildResults {
for _, arg := range buildResult.args {
nv := driver.NamedValue{Value: driver.Value(arg)}
nvc.CheckNamedValue(&nv)
multiArgs = append(multiArgs, nv)
}
// each DML is either a single insert (delta +1), update (delta +0) or delete (delta -1).
// multiplying by the rows actually affected (either 0 or 1) will give an accurate row delta for this DML event
totalDelta += buildResult.rowsDelta * rowsAffected

multiQueryBuilder.WriteString(buildResult.query)
multiQueryBuilder.WriteString(";\n")
}

res, err := ex.ExecContext(ctx, multiQueryBuilder.String(), multiArgs)
if err != nil {
err = fmt.Errorf("%w; query=%s; args=%+v", err, multiQueryBuilder.String(), multiArgs)
return err
}

mysqlRes := res.(drivermysql.Result)

// each DML is either a single insert (delta +1), update (delta +0) or delete (delta -1).
// multiplying by the rows actually affected (either 0 or 1) will give an accurate row delta for this DML event
for i, rowsAffected := range mysqlRes.AllRowsAffected() {
totalDelta += buildResults[i].rowsDelta * rowsAffected
}
return nil
})

if execErr != nil {
return rollback(execErr)
}
if err := tx.Commit(); err != nil {
return err
Expand Down

0 comments on commit 7c30fb0

Please sign in to comment.