Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consolidate fetching of MySQL server info #1229

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 3 additions & 11 deletions go/base/context.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright 2022 GitHub Inc.
Copyright 2023 GitHub Inc.
See https://github.com/github/gh-ost/blob/master/LICENSE
*/

Expand Down Expand Up @@ -163,18 +163,15 @@ type MigrationContext struct {

Hostname string
AssumeMasterHostname string
ApplierTimeZone string
TableEngine string
RowsEstimate int64
RowsDeltaEstimate int64
UsedRowsEstimateMethod RowsEstimateMethod
HasSuperPrivilege bool
OriginalBinlogFormat string
OriginalBinlogRowImage string
InspectorConnectionConfig *mysql.ConnectionConfig
InspectorMySQLVersion string
InspectorServerInfo *mysql.ServerInfo
ApplierConnectionConfig *mysql.ConnectionConfig
ApplierMySQLVersion string
ApplierServerInfo *mysql.ServerInfo
StartTime time.Time
RowCopyStartTime time.Time
RowCopyEndTime time.Time
Expand Down Expand Up @@ -368,11 +365,6 @@ func (this *MigrationContext) GetVoluntaryLockName() string {
return fmt.Sprintf("%s.%s.lock", this.DatabaseName, this.OriginalTableName)
}

// RequiresBinlogFormatChange is `true` when the original binlog format isn't `ROW`
func (this *MigrationContext) RequiresBinlogFormatChange() bool {
return this.OriginalBinlogFormat != "ROW"
}

// GetApplierHostname is a safe access method to the applier hostname
func (this *MigrationContext) GetApplierHostname() string {
if this.ApplierConnectionConfig == nil {
Expand Down
41 changes: 13 additions & 28 deletions go/base/utils.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright 2022 GitHub Inc.
Copyright 2023 GitHub Inc.
See https://github.com/github/gh-ost/blob/master/LICENSE
*/

Expand All @@ -12,8 +12,6 @@ import (
"strings"
"time"

gosql "database/sql"

"github.com/github/gh-ost/go/mysql"
)

Expand Down Expand Up @@ -61,35 +59,22 @@ func StringContainsAll(s string, substrings ...string) bool {
return nonEmptyStringsFound
}

func ValidateConnection(db *gosql.DB, connectionConfig *mysql.ConnectionConfig, migrationContext *MigrationContext, name string) (string, error) {
versionQuery := `select @@global.version`
var port, extraPort int
var version string
if err := db.QueryRow(versionQuery).Scan(&version); err != nil {
return "", err
}
extraPortQuery := `select @@global.extra_port`
if err := db.QueryRow(extraPortQuery).Scan(&extraPort); err != nil { //nolint:staticcheck
// swallow this error. not all servers support extra_port
}
// ValidateConnection confirms the database server info matches the provided connection config.
func ValidateConnection(serverInfo *mysql.ServerInfo, connectionConfig *mysql.ConnectionConfig, migrationContext *MigrationContext, name string) error {
// AliyunRDS set users port to "NULL", replace it by gh-ost param
// GCP set users port to "NULL", replace it by gh-ost param
// Azure MySQL set users port to a different value by design, replace it by gh-ost para
// Azure MySQL set users port to a different value by design, replace it by gh-ost param
if migrationContext.AliyunRDS || migrationContext.GoogleCloudPlatform || migrationContext.AzureMySQL {
port = connectionConfig.Key.Port
} else {
portQuery := `select @@global.port`
if err := db.QueryRow(portQuery).Scan(&port); err != nil {
return "", err
}
serverInfo.Port.Int64 = connectionConfig.Key.Port
serverInfo.Port.Valid = connectionConfig.Key.Port > 0
}

if connectionConfig.Key.Port == port || (extraPort > 0 && connectionConfig.Key.Port == extraPort) {
migrationContext.Log.Infof("%s connection validated on %+v", name, connectionConfig.Key)
return version, nil
} else if extraPort == 0 {
return "", fmt.Errorf("Unexpected database port reported: %+v", port)
} else {
return "", fmt.Errorf("Unexpected database port reported: %+v / extra_port: %+v", port, extraPort)
if !serverInfo.Port.Valid && !serverInfo.ExtraPort.Valid {
return fmt.Errorf("Unexpected database port reported: %+v", serverInfo.Port.Int64)
} else if connectionConfig.Key.Port != serverInfo.Port.Int64 && connectionConfig.Key.Port != serverInfo.ExtraPort.Int64 {
return fmt.Errorf("Unexpected database port reported: %+v / extra_port: %+v", serverInfo.Port.Int64, serverInfo.ExtraPort.Int64)
}

migrationContext.Log.Infof("%s connection validated on %+v", name, connectionConfig.Key)
return nil
}
85 changes: 84 additions & 1 deletion go/base/utils_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
/*
Copyright 2016 GitHub Inc.
Copyright 2023 GitHub Inc.
See https://github.com/github/gh-ost/blob/master/LICENSE
*/

package base

import (
gosql "database/sql"
"testing"

"github.com/github/gh-ost/go/mysql"
"github.com/openark/golib/log"
test "github.com/openark/golib/tests"
)
Expand All @@ -16,6 +18,10 @@ func init() {
log.SetLevel(log.ERROR)
}

func newMysqlPort(port int64) gosql.NullInt64 {
return gosql.NullInt64{Int64: port, Valid: port > 0}
}

func TestStringContainsAll(t *testing.T) {
s := `insert,delete,update`

Expand All @@ -27,3 +33,80 @@ func TestStringContainsAll(t *testing.T) {
test.S(t).ExpectTrue(StringContainsAll(s, "insert", ""))
test.S(t).ExpectTrue(StringContainsAll(s, "insert", "update", "delete"))
}

func TestValidateConnection(t *testing.T) {
connectionConfig := &mysql.ConnectionConfig{
Key: mysql.InstanceKey{
Hostname: t.Name(),
Port: mysql.DefaultInstancePort,
},
}

// check valid port matching connectionConfig validates
{
migrationContext := &MigrationContext{Log: NewDefaultLogger()}
serverInfo := &mysql.ServerInfo{
Port: newMysqlPort(mysql.DefaultInstancePort),
ExtraPort: newMysqlPort(mysql.DefaultInstancePort + 1),
}
test.S(t).ExpectNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
}
// check NULL port validates when AliyunRDS=true
{
migrationContext := &MigrationContext{
Log: NewDefaultLogger(),
AliyunRDS: true,
}
serverInfo := &mysql.ServerInfo{}
test.S(t).ExpectNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
}
// check NULL port validates when AzureMySQL=true
{
migrationContext := &MigrationContext{
Log: NewDefaultLogger(),
AzureMySQL: true,
}
serverInfo := &mysql.ServerInfo{}
test.S(t).ExpectNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
}
// check NULL port validates when GoogleCloudPlatform=true
{
migrationContext := &MigrationContext{
Log: NewDefaultLogger(),
GoogleCloudPlatform: true,
}
serverInfo := &mysql.ServerInfo{}
test.S(t).ExpectNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
}
// check extra_port validates when port=NULL
{
migrationContext := &MigrationContext{Log: NewDefaultLogger()}
serverInfo := &mysql.ServerInfo{
ExtraPort: newMysqlPort(mysql.DefaultInstancePort),
}
test.S(t).ExpectNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
}
// check extra_port validates when port does not match but extra_port does
{
migrationContext := &MigrationContext{Log: NewDefaultLogger()}
serverInfo := &mysql.ServerInfo{
Port: newMysqlPort(12345),
ExtraPort: newMysqlPort(mysql.DefaultInstancePort),
}
test.S(t).ExpectNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
}
// check validation fails when valid port does not match connectionConfig
{
migrationContext := &MigrationContext{Log: NewDefaultLogger()}
serverInfo := &mysql.ServerInfo{
Port: newMysqlPort(9999),
}
test.S(t).ExpectNotNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
}
// check validation fails when port and extra_port are invalid
{
migrationContext := &MigrationContext{Log: NewDefaultLogger()}
serverInfo := &mysql.ServerInfo{}
test.S(t).ExpectNotNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
}
}
4 changes: 2 additions & 2 deletions go/cmd/gh-ost/main.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright 2022 GitHub Inc.
Copyright 2023 GitHub Inc.
See https://github.com/github/gh-ost/blob/master/LICENSE
*/

Expand Down Expand Up @@ -49,7 +49,7 @@ func main() {
migrationContext := base.NewMigrationContext()
flag.StringVar(&migrationContext.InspectorConnectionConfig.Key.Hostname, "host", "127.0.0.1", "MySQL hostname (preferably a replica, not the master)")
flag.StringVar(&migrationContext.AssumeMasterHostname, "assume-master-host", "", "(optional) explicitly tell gh-ost the identity of the master. Format: some.host.com[:port] This is useful in master-master setups where you wish to pick an explicit master, or in a tungsten-replicator where gh-ost is unable to determine the master")
flag.IntVar(&migrationContext.InspectorConnectionConfig.Key.Port, "port", 3306, "MySQL port (preferably a replica, not the master)")
flag.Int64Var(&migrationContext.InspectorConnectionConfig.Key.Port, "port", 3306, "MySQL port (preferably a replica, not the master)")
flag.Float64Var(&migrationContext.InspectorConnectionConfig.Timeout, "mysql-timeout", 0.0, "Connect, read and write timeout for MySQL")
flag.Float64Var(&migrationContext.InspectorConnectionConfig.WaitTimeout, "mysql-wait-timeout", 0.0, "wait_timeout for MySQL sessions")
flag.StringVar(&migrationContext.CliUser, "user", "", "MySQL user")
Expand Down
39 changes: 14 additions & 25 deletions go/logic/applier.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright 2022 GitHub Inc.
Copyright 2023 GitHub Inc.
See https://github.com/github/gh-ost/blob/master/LICENSE
*/

Expand Down Expand Up @@ -71,25 +71,24 @@ func NewApplier(migrationContext *base.MigrationContext) *Applier {
}
}

func (this *Applier) ServerInfo() *mysql.ServerInfo {
return this.migrationContext.ApplierServerInfo
}

func (this *Applier) InitDBConnections() (err error) {
applierUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName)
if this.db, _, err = mysql.GetDB(this.migrationContext.Uuid, applierUri); err != nil {
return err
}
if this.migrationContext.ApplierServerInfo, err = mysql.GetServerInfo(this.db); err != nil {
return err
}
singletonApplierUri := fmt.Sprintf("%s&timeout=0", applierUri)
if this.singletonDB, _, err = mysql.GetDB(this.migrationContext.Uuid, singletonApplierUri); err != nil {
return err
}
this.singletonDB.SetMaxOpenConns(1)
version, err := base.ValidateConnection(this.db, this.connectionConfig, this.migrationContext, this.name)
if err != nil {
return err
}
if _, err := base.ValidateConnection(this.singletonDB, this.connectionConfig, this.migrationContext, this.name); err != nil {
return err
}
this.migrationContext.ApplierMySQLVersion = version
if err := this.validateAndReadTimeZone(); err != nil {
if err = base.ValidateConnection(this.ServerInfo(), this.connectionConfig, this.migrationContext, this.name); err != nil {
return err
}
if !this.migrationContext.AliyunRDS && !this.migrationContext.GoogleCloudPlatform && !this.migrationContext.AzureMySQL {
Expand All @@ -102,18 +101,8 @@ func (this *Applier) InitDBConnections() (err error) {
if err := this.readTableColumns(); err != nil {
return err
}
this.migrationContext.Log.Infof("Applier initiated on %+v, version %+v", this.connectionConfig.ImpliedKey, this.migrationContext.ApplierMySQLVersion)
return nil
}

// validateAndReadTimeZone potentially reads server time-zone
func (this *Applier) validateAndReadTimeZone() error {
query := `select /* gh-ost */ @@global.time_zone`
if err := this.db.QueryRow(query).Scan(&this.migrationContext.ApplierTimeZone); err != nil {
return err
}

this.migrationContext.Log.Infof("will use time_zone='%s' on applier", this.migrationContext.ApplierTimeZone)
this.migrationContext.Log.Infof("Applier initiated on %+v, version %+v (%+v)", this.connectionConfig.ImpliedKey,
this.ServerInfo().Version, this.ServerInfo().VersionComment)
return nil
}

Expand Down Expand Up @@ -238,7 +227,7 @@ func (this *Applier) CreateGhostTable() error {
}
defer tx.Rollback()

sessionQuery := fmt.Sprintf(`SET SESSION time_zone = '%s'`, this.migrationContext.ApplierTimeZone)
sessionQuery := fmt.Sprintf(`SET SESSION time_zone = '%s'`, this.ServerInfo().TimeZone)
sessionQuery = fmt.Sprintf("%s, %s", sessionQuery, this.generateSqlModeQuery())

if _, err := tx.Exec(sessionQuery); err != nil {
Expand Down Expand Up @@ -279,7 +268,7 @@ func (this *Applier) AlterGhost() error {
}
defer tx.Rollback()

sessionQuery := fmt.Sprintf(`SET SESSION time_zone = '%s'`, this.migrationContext.ApplierTimeZone)
sessionQuery := fmt.Sprintf(`SET SESSION time_zone = '%s'`, this.ServerInfo().TimeZone)
sessionQuery = fmt.Sprintf("%s, %s", sessionQuery, this.generateSqlModeQuery())

if _, err := tx.Exec(sessionQuery); err != nil {
Expand Down Expand Up @@ -640,7 +629,7 @@ func (this *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected
}
defer tx.Rollback()

sessionQuery := fmt.Sprintf(`SET SESSION time_zone = '%s'`, this.migrationContext.ApplierTimeZone)
sessionQuery := fmt.Sprintf(`SET SESSION time_zone = '%s'`, this.ServerInfo().TimeZone)
sessionQuery = fmt.Sprintf("%s, %s", sessionQuery, this.generateSqlModeQuery())

if _, err := tx.Exec(sessionQuery); err != nil {
Expand Down
Loading