From 4191347d64d25a0e5f6cc12b71c6c6e917556b17 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Thu, 23 Jan 2025 11:03:41 +0100 Subject: [PATCH] Always make sure to escape all strings Don't directly interpolate these strings. We don't know of any user controllable ways to do this, but it's still too risky to ever do this. We always need to escape all strings. Ideally we refactor this as well to use better statement binding in the future. Signed-off-by: Dirkjan Bussink --- .../vreplication/vdiff_helper_test.go | 4 +--- go/vt/binlog/binlogplayer/binlog_player.go | 22 +++++++++---------- go/vt/vtctl/vdiff_env_test.go | 2 +- go/vt/vtctl/workflow/resharder.go | 1 - go/vt/vtctl/workflow/traffic_switcher.go | 10 ++++----- go/vt/vtctl/workflow/utils.go | 5 +---- go/vt/vttablet/endtoend/vstreamer_test.go | 4 +--- go/vt/vttablet/onlineddl/executor.go | 4 ++-- go/vt/vttablet/tabletmanager/vdiff/utils.go | 4 +--- .../vreplication/insert_generator.go | 8 +++---- .../tabletmanager/vreplication/vreplicator.go | 6 ++--- go/vt/vttablet/tabletserver/schema/tracker.go | 4 +--- .../tabletserver/vstreamer/vstreamer.go | 4 +--- go/vt/wrangler/keyspace.go | 4 +--- 14 files changed, 31 insertions(+), 51 deletions(-) diff --git a/go/test/endtoend/vreplication/vdiff_helper_test.go b/go/test/endtoend/vreplication/vdiff_helper_test.go index fcc112b670b..49fe4c45f6a 100644 --- a/go/test/endtoend/vreplication/vdiff_helper_test.go +++ b/go/test/endtoend/vreplication/vdiff_helper_test.go @@ -274,9 +274,7 @@ func getVDiffInfo(json string) *vdiffInfo { } func encodeString(in string) string { - var buf strings.Builder - sqltypes.NewVarChar(in).EncodeSQL(&buf) - return buf.String() + return sqltypes.EncodeStringSQL(in) } // generateMoreCustomers creates additional test data for better tests diff --git a/go/vt/binlog/binlogplayer/binlog_player.go b/go/vt/binlog/binlogplayer/binlog_player.go index 92718a4b5ed..29264cf54b9 100644 --- a/go/vt/binlog/binlogplayer/binlog_player.go +++ b/go/vt/binlog/binlogplayer/binlog_player.go @@ -549,7 +549,7 @@ func (blp *BinlogPlayer) setVReplicationState(state binlogdatapb.VReplicationWor }) } blp.blplStats.State.Store(state.String()) - query := fmt.Sprintf("update _vt.vreplication set state='%v', message=%v where id=%v", state.String(), encodeString(MessageTruncate(message)), blp.uid) + query := fmt.Sprintf("update _vt.vreplication set state=%v, message=%v where id=%v", encodeString(state.String()), encodeString(MessageTruncate(message)), blp.uid) if _, err := blp.dbClient.ExecuteFetch(query, 1); err != nil { return fmt.Errorf("could not set state: %v: %v", query, err) } @@ -637,9 +637,9 @@ func CreateVReplication(workflow string, source *binlogdatapb.BinlogSource, posi protoutil.SortBinlogSourceTables(source) return fmt.Sprintf("insert into _vt.vreplication "+ "(workflow, source, pos, max_tps, max_replication_lag, time_updated, transaction_timestamp, state, db_name, workflow_type, workflow_sub_type, defer_secondary_keys, options) "+ - "values (%v, %v, %v, %v, %v, %v, 0, '%v', %v, %d, %d, %v, %s)", + "values (%v, %v, %v, %v, %v, %v, 0, %v, %v, %d, %d, %v, %s)", encodeString(workflow), encodeString(source.String()), encodeString(position), maxTPS, maxReplicationLag, - timeUpdated, binlogdatapb.VReplicationWorkflowState_Running.String(), encodeString(dbName), workflowType, + timeUpdated, encodeString(binlogdatapb.VReplicationWorkflowState_Running.String()), encodeString(dbName), workflowType, workflowSubType, deferSecondaryKeys, encodeString("{}")) } @@ -649,9 +649,9 @@ func CreateVReplicationState(workflow string, source *binlogdatapb.BinlogSource, protoutil.SortBinlogSourceTables(source) return fmt.Sprintf("insert into _vt.vreplication "+ "(workflow, source, pos, max_tps, max_replication_lag, time_updated, transaction_timestamp, state, db_name, workflow_type, workflow_sub_type, options) "+ - "values (%v, %v, %v, %v, %v, %v, 0, '%v', %v, %d, %d, %s)", + "values (%v, %v, %v, %v, %v, %v, 0, %v, %v, %d, %d, %s)", encodeString(workflow), encodeString(source.String()), encodeString(position), throttler.MaxRateModuleDisabled, - throttler.ReplicationLagModuleDisabled, time.Now().Unix(), state.String(), encodeString(dbName), + throttler.ReplicationLagModuleDisabled, time.Now().Unix(), encodeString(state.String()), encodeString(dbName), workflowType, workflowSubType, encodeString("{}")) } @@ -694,15 +694,15 @@ func GenerateUpdateTimeThrottled(uid int32, timeThrottledUnix int64, componentTh // StartVReplicationUntil returns a statement to start the replication with a stop position. func StartVReplicationUntil(uid int32, pos string) string { return fmt.Sprintf( - "update _vt.vreplication set state='%v', stop_pos=%v where id=%v", - binlogdatapb.VReplicationWorkflowState_Running.String(), encodeString(pos), uid) + "update _vt.vreplication set state=%v, stop_pos=%v where id=%v", + encodeString(binlogdatapb.VReplicationWorkflowState_Running.String()), encodeString(pos), uid) } // StopVReplication returns a statement to stop the replication. func StopVReplication(uid int32, message string) string { return fmt.Sprintf( - "update _vt.vreplication set state='%v', message=%v where id=%v", - binlogdatapb.VReplicationWorkflowState_Stopped.String(), encodeString(MessageTruncate(message)), uid) + "update _vt.vreplication set state=%v, message=%v where id=%v", + encodeString(binlogdatapb.VReplicationWorkflowState_Stopped.String()), encodeString(MessageTruncate(message)), uid) } // DeleteVReplication returns a statement to delete the replication. @@ -717,9 +717,7 @@ func MessageTruncate(msg string) string { } func encodeString(in string) string { - buf := bytes.NewBuffer(nil) - sqltypes.NewVarChar(in).EncodeSQL(buf) - return buf.String() + return sqltypes.EncodeStringSQL(in) } // ReadVReplicationPos returns a statement to query the gtid for a diff --git a/go/vt/vtctl/vdiff_env_test.go b/go/vt/vtctl/vdiff_env_test.go index fdcf29367cc..9b2fade3204 100644 --- a/go/vt/vtctl/vdiff_env_test.go +++ b/go/vt/vtctl/vdiff_env_test.go @@ -128,7 +128,7 @@ func newTestVDiffEnv(t testing.TB, ctx context.Context, sourceShards, targetShar // But this is one statement per stream. env.tmc.setVRResults( primary.tablet, - fmt.Sprintf("update _vt.vreplication set state='Running', stop_pos='%s', message='synchronizing for vdiff' where id=%d", vdiffSourceGtid, j+1), + fmt.Sprintf("update _vt.vreplication set state='Running', stop_pos=%s, message='synchronizing for vdiff' where id=%d", sqltypes.EncodeStringSQL(vdiffSourceGtid), j+1), &sqltypes.Result{}, ) } diff --git a/go/vt/vtctl/workflow/resharder.go b/go/vt/vtctl/workflow/resharder.go index c270a9a6f0b..81261f3e39e 100644 --- a/go/vt/vtctl/workflow/resharder.go +++ b/go/vt/vtctl/workflow/resharder.go @@ -296,7 +296,6 @@ func (rs *resharder) createStreams(ctx context.Context) error { if err != nil { return err } - optionsJSON = fmt.Sprintf("'%s'", optionsJSON) for _, source := range rs.sourceShards { if !key.KeyRangeIntersect(target.KeyRange, source.KeyRange) { continue diff --git a/go/vt/vtctl/workflow/traffic_switcher.go b/go/vt/vtctl/workflow/traffic_switcher.go index 3b151103d7f..76c5ac551f4 100644 --- a/go/vt/vtctl/workflow/traffic_switcher.go +++ b/go/vt/vtctl/workflow/traffic_switcher.go @@ -858,8 +858,8 @@ func (ts *trafficSwitcher) getReverseVReplicationUpdateQuery(targetCell string, } if ts.optCells != "" || ts.optTabletTypes != "" { - query := fmt.Sprintf("update _vt.vreplication set cell = '%s', tablet_types = '%s', options = '%s' where workflow = '%s' and db_name = '%s'", - ts.optCells, ts.optTabletTypes, options, ts.ReverseWorkflowName(), dbname) + query := fmt.Sprintf("update _vt.vreplication set cell = %s, tablet_types = %s, options = %s where workflow = %s and db_name = %s", + sqltypes.EncodeStringSQL(ts.optCells), sqltypes.EncodeStringSQL(ts.optTabletTypes), sqltypes.EncodeStringSQL(options), sqltypes.EncodeStringSQL(ts.ReverseWorkflowName()), sqltypes.EncodeStringSQL(dbname)) return query } return "" @@ -941,8 +941,8 @@ func (ts *trafficSwitcher) createReverseVReplication(ctx context.Context) error // For non-reference tables we return an error if there's no primary // vindex as it's not clear what to do. if len(vtable.ColumnVindexes) > 0 && len(vtable.ColumnVindexes[0].Columns) > 0 { - inKeyrange = fmt.Sprintf(" where in_keyrange(%s, '%s.%s', '%s')", sqlparser.String(vtable.ColumnVindexes[0].Columns[0]), - ts.SourceKeyspaceName(), vtable.ColumnVindexes[0].Name, key.KeyRangeString(source.GetShard().KeyRange)) + inKeyrange = fmt.Sprintf(" where in_keyrange(%s, %s, %s)", sqlparser.String(vtable.ColumnVindexes[0].Columns[0]), + sqlparser.String(sqlparser.NewTableNameWithQualifier(vtable.ColumnVindexes[0].Name, ts.SourceKeyspaceName())), sqltypes.EncodeStringSQL(key.KeyRangeString(source.GetShard().KeyRange))) } else { return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "no primary vindex found for the %s table in the %s keyspace", vtable.Name.String(), ts.SourceKeyspaceName()) @@ -1184,7 +1184,7 @@ func (ts *trafficSwitcher) freezeTargetVReplication(ctx context.Context) error { // re-invoked after a freeze, it will skip all the previous steps err := ts.ForAllTargets(func(target *MigrationTarget) error { ts.Logger().Infof("Marking target streams frozen for workflow %s db_name %s", ts.WorkflowName(), target.GetPrimary().DbName()) - query := fmt.Sprintf("update _vt.vreplication set message = '%s' where db_name=%s and workflow=%s", Frozen, + query := fmt.Sprintf("update _vt.vreplication set message = %s where db_name=%s and workflow=%s", encodeString(Frozen), encodeString(target.GetPrimary().DbName()), encodeString(ts.WorkflowName())) _, err := ts.TabletManagerClient().VReplicationExec(ctx, target.GetPrimary().Tablet, query) return err diff --git a/go/vt/vtctl/workflow/utils.go b/go/vt/vtctl/workflow/utils.go index 3c33ef25560..571c40b474f 100644 --- a/go/vt/vtctl/workflow/utils.go +++ b/go/vt/vtctl/workflow/utils.go @@ -17,7 +17,6 @@ limitations under the License. package workflow import ( - "bytes" "context" "encoding/json" "fmt" @@ -627,9 +626,7 @@ func ReverseWorkflowName(workflow string) string { // this public, but it doesn't belong in package workflow. Maybe package sqltypes, // or maybe package sqlescape? func encodeString(in string) string { - buf := bytes.NewBuffer(nil) - sqltypes.NewVarChar(in).EncodeSQL(buf) - return buf.String() + return sqltypes.EncodeStringSQL(in) } func getRenameFileName(tableName string) string { diff --git a/go/vt/vttablet/endtoend/vstreamer_test.go b/go/vt/vttablet/endtoend/vstreamer_test.go index 997ab222255..776d45dbc02 100644 --- a/go/vt/vttablet/endtoend/vstreamer_test.go +++ b/go/vt/vttablet/endtoend/vstreamer_test.go @@ -472,9 +472,7 @@ func expectLogs(ctx context.Context, t *testing.T, query string, eventCh chan [] } func encodeString(in string) string { - buf := bytes.NewBuffer(nil) - sqltypes.NewVarChar(in).EncodeSQL(buf) - return buf.String() + return sqltypes.EncodeStringSQL(in) } func validateSchemaInserted(client *framework.QueryClient, ddl string) bool { diff --git a/go/vt/vttablet/onlineddl/executor.go b/go/vt/vttablet/onlineddl/executor.go index 76c7af7fc2e..07119263399 100644 --- a/go/vt/vttablet/onlineddl/executor.go +++ b/go/vt/vttablet/onlineddl/executor.go @@ -1571,8 +1571,8 @@ func (e *Executor) ExecuteWithVReplication(ctx context.Context, onlineDDL *schem { // temporary hack. todo: this should be done when inserting any _vt.vreplication record across all workflow types - query := fmt.Sprintf("update _vt.vreplication set workflow_type = %d where workflow = '%s'", - binlogdatapb.VReplicationWorkflowType_OnlineDDL, v.workflow) + query := fmt.Sprintf("update _vt.vreplication set workflow_type = %d where workflow = %s", + binlogdatapb.VReplicationWorkflowType_OnlineDDL, sqltypes.EncodeStringSQL(v.workflow)) if _, err := e.vreplicationExec(ctx, tablet.Tablet, query); err != nil { return vterrors.Wrapf(err, "VReplicationExec(%v, %s)", tablet.Tablet, query) } diff --git a/go/vt/vttablet/tabletmanager/vdiff/utils.go b/go/vt/vttablet/tabletmanager/vdiff/utils.go index 68e8a6acb57..a19fce67c56 100644 --- a/go/vt/vttablet/tabletmanager/vdiff/utils.go +++ b/go/vt/vttablet/tabletmanager/vdiff/utils.go @@ -59,9 +59,7 @@ func newMergeSorter(participants map[string]*shardStreamer, comparePKs []compare // Utility functions func encodeString(in string) string { - var buf strings.Builder - sqltypes.NewVarChar(in).EncodeSQL(&buf) - return buf.String() + return sqltypes.EncodeStringSQL(in) } func pkColsToGroupByParams(pkCols []int, collationEnv *collations.Environment) []*engine.GroupByParams { diff --git a/go/vt/vttablet/tabletmanager/vreplication/insert_generator.go b/go/vt/vttablet/tabletmanager/vreplication/insert_generator.go index a43278d783c..86753456af1 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/insert_generator.go +++ b/go/vt/vttablet/tabletmanager/vreplication/insert_generator.go @@ -53,10 +53,10 @@ func NewInsertGenerator(state binlogdatapb.VReplicationWorkflowState, dbname str func (ig *InsertGenerator) AddRow(workflow string, bls *binlogdatapb.BinlogSource, pos, cell, tabletTypes string, workflowType binlogdatapb.VReplicationWorkflowType, workflowSubType binlogdatapb.VReplicationWorkflowSubType, deferSecondaryKeys bool, options string) { if options == "" { - options = "'{}'" + options = "{}" } protoutil.SortBinlogSourceTables(bls) - fmt.Fprintf(ig.buf, "%s(%v, %v, %v, %v, %v, %v, %v, %v, 0, '%v', %v, %d, %d, %v, %v)", + fmt.Fprintf(ig.buf, "%s(%v, %v, %v, %v, %v, %v, %v, %v, 0, %v, %v, %d, %d, %v, %v)", ig.prefix, encodeString(workflow), encodeString(bls.String()), @@ -66,12 +66,12 @@ func (ig *InsertGenerator) AddRow(workflow string, bls *binlogdatapb.BinlogSourc encodeString(cell), encodeString(tabletTypes), ig.now, - ig.state, + encodeString(ig.state), encodeString(ig.dbname), workflowType, workflowSubType, deferSecondaryKeys, - options, + encodeString(options), ) ig.prefix = ", " } diff --git a/go/vt/vttablet/tabletmanager/vreplication/vreplicator.go b/go/vt/vttablet/tabletmanager/vreplication/vreplicator.go index 76177b56b5b..54e76efa092 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/vreplicator.go +++ b/go/vt/vttablet/tabletmanager/vreplication/vreplicator.go @@ -508,7 +508,7 @@ func (vr *vreplicator) setState(state binlogdatapb.VReplicationWorkflowState, me }) } vr.stats.State.Store(state.String()) - query := fmt.Sprintf("update _vt.vreplication set state='%v', message=%v where id=%v", state, encodeString(binlogplayer.MessageTruncate(message)), vr.id) + query := fmt.Sprintf("update _vt.vreplication set state=%v, message=%v where id=%v", encodeString(state.String()), encodeString(binlogplayer.MessageTruncate(message)), vr.id) // If we're batching a transaction, then include the state update // in the current transaction batch. if vr.dbClient.InTransaction && vr.dbClient.maxBatchSize > 0 { @@ -528,9 +528,7 @@ func (vr *vreplicator) setState(state binlogdatapb.VReplicationWorkflowState, me } func encodeString(in string) string { - var buf strings.Builder - sqltypes.NewVarChar(in).EncodeSQL(&buf) - return buf.String() + return sqltypes.EncodeStringSQL(in) } func (vr *vreplicator) getSettingFKCheck() error { diff --git a/go/vt/vttablet/tabletserver/schema/tracker.go b/go/vt/vttablet/tabletserver/schema/tracker.go index 252a81f3493..82ab31addb5 100644 --- a/go/vt/vttablet/tabletserver/schema/tracker.go +++ b/go/vt/vttablet/tabletserver/schema/tracker.go @@ -243,9 +243,7 @@ func (tr *Tracker) saveCurrentSchemaToDb(ctx context.Context, gtid, ddl string, } func encodeString(in string) string { - buf := bytes.NewBuffer(nil) - sqltypes.NewVarChar(in).EncodeSQL(buf) - return buf.String() + return sqltypes.EncodeStringSQL(in) } // MustReloadSchemaOnDDL returns true if the ddl is for the db which is part of the workflow and is not an online ddl artifact diff --git a/go/vt/vttablet/tabletserver/vstreamer/vstreamer.go b/go/vt/vttablet/tabletserver/vstreamer/vstreamer.go index fb4cb324047..237b2e78386 100644 --- a/go/vt/vttablet/tabletserver/vstreamer/vstreamer.go +++ b/go/vt/vttablet/tabletserver/vstreamer/vstreamer.go @@ -960,9 +960,7 @@ type extColInfo struct { } func encodeString(in string) string { - buf := bytes.NewBuffer(nil) - sqltypes.NewVarChar(in).EncodeSQL(buf) - return buf.String() + return sqltypes.EncodeStringSQL(in) } func (vs *vstreamer) processJournalEvent(vevents []*binlogdatapb.VEvent, plan *streamerPlan, rows mysql.Rows) ([]*binlogdatapb.VEvent, error) { diff --git a/go/vt/wrangler/keyspace.go b/go/vt/wrangler/keyspace.go index 98551a084c9..e6bc451ff1b 100644 --- a/go/vt/wrangler/keyspace.go +++ b/go/vt/wrangler/keyspace.go @@ -125,7 +125,5 @@ func (wr *Wrangler) updateShardRecords(ctx context.Context, keyspace string, sha } func encodeString(in string) string { - buf := bytes.NewBuffer(nil) - sqltypes.NewVarChar(in).EncodeSQL(buf) - return buf.String() + return sqltypes.EncodeStringSQL(in) }