From 202dcf9d60734092a54d17181c268dfd0b3645c9 Mon Sep 17 00:00:00 2001 From: lance6716 Date: Thu, 1 Aug 2024 17:04:20 +0800 Subject: [PATCH] *: improve security (#810) close pingcap/tidb-tools#811 --- pkg/check/privilege.go | 4 +-- pkg/dbutil/variable.go | 15 ++--------- pkg/dbutil/variable_test.go | 6 ++--- pkg/ddl-checker/executable_checker.go | 11 ++------ pkg/diff/checkpoint.go | 37 +++------------------------ pkg/diff/checkpoint_test.go | 5 ---- pkg/diff/diff.go | 4 +-- pkg/diff/diff_test.go | 2 +- 8 files changed, 15 insertions(+), 69 deletions(-) diff --git a/pkg/check/privilege.go b/pkg/check/privilege.go index db69c7b1e..ad670c1ee 100644 --- a/pkg/check/privilege.go +++ b/pkg/check/privilege.go @@ -69,7 +69,7 @@ func (pc *SourceDumpPrivilegeChecker) Check(ctx context.Context) *Result { Extra: fmt.Sprintf("address of db instance - %s:%d", pc.dbinfo.Host, pc.dbinfo.Port), } - grants, err := dbutil.ShowGrants(ctx, pc.db, "", "") + grants, err := dbutil.ShowGrants(ctx, pc.db) if err != nil { markCheckError(result, err) return result @@ -107,7 +107,7 @@ func (pc *SourceReplicatePrivilegeChecker) Check(ctx context.Context) *Result { Extra: fmt.Sprintf("address of db instance - %s:%d", pc.dbinfo.Host, pc.dbinfo.Port), } - grants, err := dbutil.ShowGrants(ctx, pc.db, "", "") + grants, err := dbutil.ShowGrants(ctx, pc.db) if err != nil { markCheckError(result, err) return result diff --git a/pkg/dbutil/variable.go b/pkg/dbutil/variable.go index 6bbd95594..3f0a554a0 100644 --- a/pkg/dbutil/variable.go +++ b/pkg/dbutil/variable.go @@ -55,19 +55,8 @@ func ShowMySQLVariable(ctx context.Context, db QueryExecutor, variable string) ( // ShowGrants queries privileges for a mysql user. // For mysql 8.0, if user has granted roles, ShowGrants also extract privilege from roles. -func ShowGrants(ctx context.Context, db QueryExecutor, user, host string) ([]string, error) { - if host == "" { - host = "%" - } - - var query string - if user == "" { - // for current user. - query = "SHOW GRANTS" - } else { - query = fmt.Sprintf("SHOW GRANTS FOR '%s'@'%s'", user, host) - } - +func ShowGrants(ctx context.Context, db QueryExecutor) ([]string, error) { + query := "SHOW GRANTS" readGrantsFunc := func() ([]string, error) { rows, err := db.QueryContext(ctx, query) if err != nil { diff --git a/pkg/dbutil/variable_test.go b/pkg/dbutil/variable_test.go index a4343f790..c23ba88a6 100644 --- a/pkg/dbutil/variable_test.go +++ b/pkg/dbutil/variable_test.go @@ -22,7 +22,7 @@ func (*testDBSuite) TestShowGrants(c *C) { } mock.ExpectQuery("SHOW GRANTS").WillReturnRows(rows) - grants, err := ShowGrants(ctx, db, "", "") + grants, err := ShowGrants(ctx, db) c.Assert(err, IsNil) c.Assert(grants, DeepEquals, mockGrants) c.Assert(mock.ExpectationsWereMet(), IsNil) @@ -54,7 +54,7 @@ func (*testDBSuite) TestShowGrantsWithRoles(c *C) { } mock.ExpectQuery("SHOW GRANTS").WillReturnRows(rows2) - grants, err := ShowGrants(ctx, db, "", "") + grants, err := ShowGrants(ctx, db) c.Assert(err, IsNil) c.Assert(grants, DeepEquals, mockGrantsWithRoles) c.Assert(mock.ExpectationsWereMet(), IsNil) @@ -92,7 +92,7 @@ func (*testDBSuite) TestShowGrantsPasswordMasked(c *C) { rows.AddRow(ca.original) mock.ExpectQuery("SHOW GRANTS").WillReturnRows(rows) - grants, err := ShowGrants(ctx, db, "", "") + grants, err := ShowGrants(ctx, db) c.Assert(err, IsNil) c.Assert(grants, HasLen, 1) c.Assert(grants[0], DeepEquals, ca.expected) diff --git a/pkg/ddl-checker/executable_checker.go b/pkg/ddl-checker/executable_checker.go index 13b430e64..3f100795e 100644 --- a/pkg/ddl-checker/executable_checker.go +++ b/pkg/ddl-checker/executable_checker.go @@ -15,10 +15,10 @@ package checker import ( "context" - "fmt" "github.com/pingcap/errors" "github.com/pingcap/log" + "github.com/pingcap/tidb/pkg/lightning/common" "github.com/pingcap/tidb/pkg/parser" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/session" @@ -70,13 +70,6 @@ func (ec *ExecutableChecker) Execute(context context.Context, sql string) error return nil } -// IsTableExist returns whether the table with the specified name exists -func (ec *ExecutableChecker) IsTableExist(context *context.Context, tableName string) bool { - _, err := ec.session.Execute(*context, - fmt.Sprintf("select 0 from `%s` limit 1", tableName)) - return err == nil -} - // CreateTable creates a new table with the specified sql func (ec *ExecutableChecker) CreateTable(context context.Context, sql string) error { err := ec.Execute(context, sql) @@ -88,7 +81,7 @@ func (ec *ExecutableChecker) CreateTable(context context.Context, sql string) er // DropTable drops the the specified table func (ec *ExecutableChecker) DropTable(context context.Context, tableName string) error { - err := ec.Execute(context, fmt.Sprintf("drop table if exists `%s`", tableName)) + err := ec.Execute(context, common.SprintfWithIdentifiers("drop table if exists %s", tableName)) if err != nil { return errors.Trace(err) } diff --git a/pkg/diff/checkpoint.go b/pkg/diff/checkpoint.go index afe57f929..92c6dd504 100644 --- a/pkg/diff/checkpoint.go +++ b/pkg/diff/checkpoint.go @@ -156,37 +156,6 @@ func initChunks(ctx context.Context, db *sql.DB, instanceID, schema, table strin return nil } -// getChunk gets chunk info from table `chunk` by chunkID -func getChunk(ctx context.Context, db *sql.DB, instanceID, schema, table string, chunkID int) (*ChunkRange, error) { - query := fmt.Sprintf("SELECT `chunk_str` FROM `%s`.`%s` WHERE `instance_id` = ? AND `schema` = ? AND `table` = ? AND `chunk_id` = ? limit 1", checkpointSchemaName, chunkTableName) - rows, err := db.QueryContext(ctx, query, instanceID, schema, table, chunkID) - if err != nil { - return nil, err - } - defer rows.Close() - - for rows.Next() { - fields, err1 := dbutil.ScanRow(rows) - if err1 != nil { - return nil, errors.Trace(err1) - } - - chunkStr := fields["chunk_str"].Data - chunk := new(ChunkRange) - err := json.Unmarshal(chunkStr, &chunk) - if err != nil { - return nil, err - } - return chunk, nil - } - - if rows.Err() != nil { - return nil, errors.Trace(rows.Err()) - } - - return nil, errors.NotFoundf("instanceID %d, schema %s, table %s, chunk %d", instanceID, schema, table, chunkID) -} - // loadChunks loads chunk info from table `chunk` func loadChunks(ctx context.Context, db *sql.DB, instanceID, schema, table string) ([]*ChunkRange, error) { chunks := make([]*ChunkRange, 0, 100) @@ -250,7 +219,7 @@ func getTableSummary(ctx context.Context, db *sql.DB, schema, table string) (tot } // initTableSummary initials a table's summary info in table `summary` -func initTableSummary(ctx context.Context, db *sql.DB, schema, table string, configHash string) error { +func initTableSummary(ctx context.Context, db *sql.DB, schema, table, configHash string) error { sql := fmt.Sprintf("REPLACE INTO `%s`.`%s`(`schema`, `table`, `state`, `config_hash`) VALUES(?, ?, ?, ?)", checkpointSchemaName, summaryTableName) err := dbutil.ExecSQLWithRetry(ctx, db, sql, schema, table, notCheckedState, configHash) if err != nil { @@ -323,7 +292,7 @@ func createCheckpointTable(ctx context.Context, db *sql.DB) error { "`check_failed_num` int not null default 0," + "`check_ignore_num` int not null default 0," + "`state` enum('not_checked', 'checking', 'success', 'failed') DEFAULT 'not_checked'," + - "`config_hash` varchar(50)," + + "`config_hash` varchar(100)," + "`update_time` datetime ON UPDATE CURRENT_TIMESTAMP," + "PRIMARY KEY(`schema`, `table`));" @@ -410,7 +379,7 @@ func loadFromCheckPoint(ctx context.Context, db *sql.DB, schema, table, configHa } if cfgHash.Valid { - if configHash != cfgHash.String { + if string(configHash) != cfgHash.String { return false, nil } } diff --git a/pkg/diff/checkpoint_test.go b/pkg/diff/checkpoint_test.go index f417cf808..d6ca1a1b0 100644 --- a/pkg/diff/checkpoint_test.go +++ b/pkg/diff/checkpoint_test.go @@ -71,11 +71,6 @@ func (s *testCheckpointSuite) testSaveAndLoadChunk(c *C, db *sql.DB) { err := saveChunk(context.Background(), db, chunk.ID, "target", "test", "checkpoint", "", chunk) c.Assert(err, IsNil) - newChunk, err := getChunk(context.Background(), db, "target", "test", "checkpoint", chunk.ID) - c.Assert(err, IsNil) - newChunk.updateColumnOffset() - c.Assert(newChunk, DeepEquals, chunk) - chunks, err := loadChunks(context.Background(), db, "target", "test", "checkpoint") c.Assert(err, IsNil) c.Assert(chunks, HasLen, 1) diff --git a/pkg/diff/diff.go b/pkg/diff/diff.go index 06242f9e3..2b151dfff 100644 --- a/pkg/diff/diff.go +++ b/pkg/diff/diff.go @@ -16,7 +16,7 @@ package diff import ( "container/heap" "context" - "crypto/md5" + "crypto/sha256" "database/sql" "encoding/json" "fmt" @@ -120,7 +120,7 @@ func (t *TableDiff) setConfigHash() error { return errors.Trace(err) } - t.configHash = fmt.Sprintf("%x", md5.Sum(jsonBytes)) + t.configHash = fmt.Sprintf("%x", sha256.Sum256(jsonBytes)) log.Debug("sync-diff-inspector config", zap.ByteString("config", jsonBytes), zap.String("hash", t.configHash)) return nil diff --git a/pkg/diff/diff_test.go b/pkg/diff/diff_test.go index 05bcda743..c16933a5e 100644 --- a/pkg/diff/diff_test.go +++ b/pkg/diff/diff_test.go @@ -394,5 +394,5 @@ func (*testDiffSuite) TestConfigHash(c *C) { tbDiff.Range = "b < 10" tbDiff.setConfigHash() hash3 := tbDiff.configHash - c.Assert(hash1 == hash3, Equals, false) + c.Assert(hash1, Not(Equals), hash3) }