Skip to content

Commit

Permalink
*: improve security (#810)
Browse files Browse the repository at this point in the history
close #811
  • Loading branch information
lance6716 authored Aug 1, 2024
1 parent 0d29f19 commit 202dcf9
Show file tree
Hide file tree
Showing 8 changed files with 15 additions and 69 deletions.
4 changes: 2 additions & 2 deletions pkg/check/privilege.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (pc *SourceDumpPrivilegeChecker) Check(ctx context.Context) *Result {
Extra: fmt.Sprintf("address of db instance - %s:%d", pc.dbinfo.Host, pc.dbinfo.Port),
}

grants, err := dbutil.ShowGrants(ctx, pc.db, "", "")
grants, err := dbutil.ShowGrants(ctx, pc.db)
if err != nil {
markCheckError(result, err)
return result
Expand Down Expand Up @@ -107,7 +107,7 @@ func (pc *SourceReplicatePrivilegeChecker) Check(ctx context.Context) *Result {
Extra: fmt.Sprintf("address of db instance - %s:%d", pc.dbinfo.Host, pc.dbinfo.Port),
}

grants, err := dbutil.ShowGrants(ctx, pc.db, "", "")
grants, err := dbutil.ShowGrants(ctx, pc.db)
if err != nil {
markCheckError(result, err)
return result
Expand Down
15 changes: 2 additions & 13 deletions pkg/dbutil/variable.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,8 @@ func ShowMySQLVariable(ctx context.Context, db QueryExecutor, variable string) (

// ShowGrants queries privileges for a mysql user.
// For mysql 8.0, if user has granted roles, ShowGrants also extract privilege from roles.
func ShowGrants(ctx context.Context, db QueryExecutor, user, host string) ([]string, error) {
if host == "" {
host = "%"
}

var query string
if user == "" {
// for current user.
query = "SHOW GRANTS"
} else {
query = fmt.Sprintf("SHOW GRANTS FOR '%s'@'%s'", user, host)
}

func ShowGrants(ctx context.Context, db QueryExecutor) ([]string, error) {
query := "SHOW GRANTS"
readGrantsFunc := func() ([]string, error) {
rows, err := db.QueryContext(ctx, query)
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions pkg/dbutil/variable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func (*testDBSuite) TestShowGrants(c *C) {
}
mock.ExpectQuery("SHOW GRANTS").WillReturnRows(rows)

grants, err := ShowGrants(ctx, db, "", "")
grants, err := ShowGrants(ctx, db)
c.Assert(err, IsNil)
c.Assert(grants, DeepEquals, mockGrants)
c.Assert(mock.ExpectationsWereMet(), IsNil)
Expand Down Expand Up @@ -54,7 +54,7 @@ func (*testDBSuite) TestShowGrantsWithRoles(c *C) {
}
mock.ExpectQuery("SHOW GRANTS").WillReturnRows(rows2)

grants, err := ShowGrants(ctx, db, "", "")
grants, err := ShowGrants(ctx, db)
c.Assert(err, IsNil)
c.Assert(grants, DeepEquals, mockGrantsWithRoles)
c.Assert(mock.ExpectationsWereMet(), IsNil)
Expand Down Expand Up @@ -92,7 +92,7 @@ func (*testDBSuite) TestShowGrantsPasswordMasked(c *C) {
rows.AddRow(ca.original)
mock.ExpectQuery("SHOW GRANTS").WillReturnRows(rows)

grants, err := ShowGrants(ctx, db, "", "")
grants, err := ShowGrants(ctx, db)
c.Assert(err, IsNil)
c.Assert(grants, HasLen, 1)
c.Assert(grants[0], DeepEquals, ca.expected)
Expand Down
11 changes: 2 additions & 9 deletions pkg/ddl-checker/executable_checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ package checker

import (
"context"
"fmt"

"github.com/pingcap/errors"
"github.com/pingcap/log"
"github.com/pingcap/tidb/pkg/lightning/common"
"github.com/pingcap/tidb/pkg/parser"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/session"
Expand Down Expand Up @@ -70,13 +70,6 @@ func (ec *ExecutableChecker) Execute(context context.Context, sql string) error
return nil
}

// IsTableExist returns whether the table with the specified name exists
func (ec *ExecutableChecker) IsTableExist(context *context.Context, tableName string) bool {
_, err := ec.session.Execute(*context,
fmt.Sprintf("select 0 from `%s` limit 1", tableName))
return err == nil
}

// CreateTable creates a new table with the specified sql
func (ec *ExecutableChecker) CreateTable(context context.Context, sql string) error {
err := ec.Execute(context, sql)
Expand All @@ -88,7 +81,7 @@ func (ec *ExecutableChecker) CreateTable(context context.Context, sql string) er

// DropTable drops the the specified table
func (ec *ExecutableChecker) DropTable(context context.Context, tableName string) error {
err := ec.Execute(context, fmt.Sprintf("drop table if exists `%s`", tableName))
err := ec.Execute(context, common.SprintfWithIdentifiers("drop table if exists %s", tableName))
if err != nil {
return errors.Trace(err)
}
Expand Down
37 changes: 3 additions & 34 deletions pkg/diff/checkpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,37 +156,6 @@ func initChunks(ctx context.Context, db *sql.DB, instanceID, schema, table strin
return nil
}

// getChunk gets chunk info from table `chunk` by chunkID
func getChunk(ctx context.Context, db *sql.DB, instanceID, schema, table string, chunkID int) (*ChunkRange, error) {
query := fmt.Sprintf("SELECT `chunk_str` FROM `%s`.`%s` WHERE `instance_id` = ? AND `schema` = ? AND `table` = ? AND `chunk_id` = ? limit 1", checkpointSchemaName, chunkTableName)
rows, err := db.QueryContext(ctx, query, instanceID, schema, table, chunkID)
if err != nil {
return nil, err
}
defer rows.Close()

for rows.Next() {
fields, err1 := dbutil.ScanRow(rows)
if err1 != nil {
return nil, errors.Trace(err1)
}

chunkStr := fields["chunk_str"].Data
chunk := new(ChunkRange)
err := json.Unmarshal(chunkStr, &chunk)
if err != nil {
return nil, err
}
return chunk, nil
}

if rows.Err() != nil {
return nil, errors.Trace(rows.Err())
}

return nil, errors.NotFoundf("instanceID %d, schema %s, table %s, chunk %d", instanceID, schema, table, chunkID)
}

// loadChunks loads chunk info from table `chunk`
func loadChunks(ctx context.Context, db *sql.DB, instanceID, schema, table string) ([]*ChunkRange, error) {
chunks := make([]*ChunkRange, 0, 100)
Expand Down Expand Up @@ -250,7 +219,7 @@ func getTableSummary(ctx context.Context, db *sql.DB, schema, table string) (tot
}

// initTableSummary initials a table's summary info in table `summary`
func initTableSummary(ctx context.Context, db *sql.DB, schema, table string, configHash string) error {
func initTableSummary(ctx context.Context, db *sql.DB, schema, table, configHash string) error {
sql := fmt.Sprintf("REPLACE INTO `%s`.`%s`(`schema`, `table`, `state`, `config_hash`) VALUES(?, ?, ?, ?)", checkpointSchemaName, summaryTableName)
err := dbutil.ExecSQLWithRetry(ctx, db, sql, schema, table, notCheckedState, configHash)
if err != nil {
Expand Down Expand Up @@ -323,7 +292,7 @@ func createCheckpointTable(ctx context.Context, db *sql.DB) error {
"`check_failed_num` int not null default 0," +
"`check_ignore_num` int not null default 0," +
"`state` enum('not_checked', 'checking', 'success', 'failed') DEFAULT 'not_checked'," +
"`config_hash` varchar(50)," +
"`config_hash` varchar(100)," +
"`update_time` datetime ON UPDATE CURRENT_TIMESTAMP," +
"PRIMARY KEY(`schema`, `table`));"

Expand Down Expand Up @@ -410,7 +379,7 @@ func loadFromCheckPoint(ctx context.Context, db *sql.DB, schema, table, configHa
}

if cfgHash.Valid {
if configHash != cfgHash.String {
if string(configHash) != cfgHash.String {
return false, nil
}
}
Expand Down
5 changes: 0 additions & 5 deletions pkg/diff/checkpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,6 @@ func (s *testCheckpointSuite) testSaveAndLoadChunk(c *C, db *sql.DB) {
err := saveChunk(context.Background(), db, chunk.ID, "target", "test", "checkpoint", "", chunk)
c.Assert(err, IsNil)

newChunk, err := getChunk(context.Background(), db, "target", "test", "checkpoint", chunk.ID)
c.Assert(err, IsNil)
newChunk.updateColumnOffset()
c.Assert(newChunk, DeepEquals, chunk)

chunks, err := loadChunks(context.Background(), db, "target", "test", "checkpoint")
c.Assert(err, IsNil)
c.Assert(chunks, HasLen, 1)
Expand Down
4 changes: 2 additions & 2 deletions pkg/diff/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ package diff
import (
"container/heap"
"context"
"crypto/md5"
"crypto/sha256"
"database/sql"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -120,7 +120,7 @@ func (t *TableDiff) setConfigHash() error {
return errors.Trace(err)
}

t.configHash = fmt.Sprintf("%x", md5.Sum(jsonBytes))
t.configHash = fmt.Sprintf("%x", sha256.Sum256(jsonBytes))
log.Debug("sync-diff-inspector config", zap.ByteString("config", jsonBytes), zap.String("hash", t.configHash))

return nil
Expand Down
2 changes: 1 addition & 1 deletion pkg/diff/diff_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,5 +394,5 @@ func (*testDiffSuite) TestConfigHash(c *C) {
tbDiff.Range = "b < 10"
tbDiff.setConfigHash()
hash3 := tbDiff.configHash
c.Assert(hash1 == hash3, Equals, false)
c.Assert(hash1, Not(Equals), hash3)
}

0 comments on commit 202dcf9

Please sign in to comment.