diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 7e493ec5618..e28c17a4745 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -2429,48 +2429,38 @@ func RemoveKeyspaceInCol(in SQLNode) { }, in) } -// RemoveKeyspaceInTables removes the Qualifier on all TableNames in the AST -func RemoveKeyspaceInTables(in SQLNode) { - // Walk will only return an error if we return an error from the inner func. safe to ignore here - Rewrite(in, nil, func(cursor *Cursor) bool { - if tbl, ok := cursor.Node().(TableName); ok && tbl.Qualifier.NotEmpty() { - tbl.Qualifier = NewIdentifierCS("") - cursor.Replace(tbl) - } - - return true +// RemoveKeyspace removes the keyspace qualifier from all ColName and TableName +func RemoveKeyspace(in SQLNode) { + removeKeyspace(in, func(_ string) bool { + return true // Always remove }) } -// RemoveKeyspace removes the Qualifier.Qualifier on all ColNames and Qualifier on all TableNames in the AST -func RemoveKeyspace(in SQLNode) { - Rewrite(in, nil, func(cursor *Cursor) bool { - switch expr := cursor.Node().(type) { - case *ColName: - if expr.Qualifier.Qualifier.NotEmpty() { - expr.Qualifier.Qualifier = NewIdentifierCS("") - } - case TableName: - if expr.Qualifier.NotEmpty() { - expr.Qualifier = NewIdentifierCS("") - cursor.Replace(expr) - } - } - return true +// RemoveSpecificKeyspace removes the keyspace qualifier from all ColName and TableName +// when it matches the keyspace provided +func RemoveSpecificKeyspace(in SQLNode, keyspace string) { + removeKeyspace(in, func(qualifier string) bool { + return qualifier == keyspace // Remove only if it matches the provided keyspace }) } -// RemoveKeyspaceIgnoreSysSchema removes the Qualifier.Qualifier on all ColNames and Qualifier on all TableNames in the AST -// except for the system schema. +// RemoveKeyspaceIgnoreSysSchema removes the keyspace qualifier from all ColName and TableName +// except for the system schema qualifier. func RemoveKeyspaceIgnoreSysSchema(in SQLNode) { + removeKeyspace(in, func(qualifier string) bool { + return qualifier != "" && !SystemSchema(qualifier) // Remove if it's not empty and not a system schema + }) +} + +func removeKeyspace(in SQLNode, shouldRemove func(qualifier string) bool) { Rewrite(in, nil, func(cursor *Cursor) bool { switch expr := cursor.Node().(type) { case *ColName: - if expr.Qualifier.Qualifier.NotEmpty() && !SystemSchema(expr.Qualifier.Qualifier.String()) { + if shouldRemove(expr.Qualifier.Qualifier.String()) { expr.Qualifier.Qualifier = NewIdentifierCS("") } case TableName: - if expr.Qualifier.NotEmpty() && !SystemSchema(expr.Qualifier.String()) { + if shouldRemove(expr.Qualifier.String()) { expr.Qualifier = NewIdentifierCS("") cursor.Replace(expr) } diff --git a/go/vt/sqlparser/ast_funcs_test.go b/go/vt/sqlparser/ast_funcs_test.go index a3b744729f4..009f3e6cc15 100644 --- a/go/vt/sqlparser/ast_funcs_test.go +++ b/go/vt/sqlparser/ast_funcs_test.go @@ -219,3 +219,28 @@ func TestExtractTables(t *testing.T) { }) } } + +// TestRemoveKeyspace tests the RemoveKeyspaceIgnoreSysSchema function. +// It removes all the keyspace except system schema. +func TestRemoveKeyspaceIgnoreSysSchema(t *testing.T) { + stmt, err := NewTestParser().Parse("select 1 from uks.unsharded join information_schema.tables") + require.NoError(t, err) + RemoveKeyspaceIgnoreSysSchema(stmt) + + require.Equal(t, "select 1 from unsharded join information_schema.`tables`", String(stmt)) +} + +// TestRemoveSpecificKeyspace tests the RemoveSpecificKeyspace function. +// It removes the specific keyspace from the database qualifier. +func TestRemoveSpecificKeyspace(t *testing.T) { + stmt, err := NewTestParser().Parse("select 1 from uks.unsharded") + require.NoError(t, err) + + // does not match + RemoveSpecificKeyspace(stmt, "ks2") + require.Equal(t, "select 1 from uks.unsharded", String(stmt)) + + // match + RemoveSpecificKeyspace(stmt, "uks") + require.Equal(t, "select 1 from unsharded", String(stmt)) +} diff --git a/go/vt/sqlparser/ast_test.go b/go/vt/sqlparser/ast_test.go index c1484df7cc4..f01b47cbd7b 100644 --- a/go/vt/sqlparser/ast_test.go +++ b/go/vt/sqlparser/ast_test.go @@ -917,11 +917,3 @@ func TestCloneComments(t *testing.T) { assert.Equal(t, "b", val) } } - -func TestRemoveKeyspace(t *testing.T) { - stmt, err := NewTestParser().Parse("select 1 from uks.unsharded") - require.NoError(t, err) - RemoveKeyspaceIgnoreSysSchema(stmt) - - require.Equal(t, "select 1 from unsharded", String(stmt)) -} diff --git a/go/vt/vtgate/planbuilder/builder.go b/go/vt/vtgate/planbuilder/builder.go index ca4ccb7ac5a..065c50a6dfa 100644 --- a/go/vt/vtgate/planbuilder/builder.go +++ b/go/vt/vtgate/planbuilder/builder.go @@ -43,10 +43,6 @@ const ( Gen4Left2Right = querypb.ExecuteOptions_Gen4Left2Right ) -var ( - plannerVersions = []plancontext.PlannerVersion{Gen4, Gen4GreedyOnly, Gen4Left2Right} -) - type ( planResult struct { primitive engine.Primitive diff --git a/go/vt/vtgate/planbuilder/bypass.go b/go/vt/vtgate/planbuilder/bypass.go index d3384d509c1..6e3f64990d7 100644 --- a/go/vt/vtgate/planbuilder/bypass.go +++ b/go/vt/vtgate/planbuilder/bypass.go @@ -56,6 +56,8 @@ func buildPlanForBypass(stmt sqlparser.Statement, _ *sqlparser.ReservedVars, vsc } } + sqlparser.RemoveSpecificKeyspace(stmt, keyspace.Name) + send := &engine.Send{ Keyspace: keyspace, TargetDestination: vschema.Destination(), diff --git a/go/vt/vtgate/planbuilder/testdata/bypass_keyrange_cases.json b/go/vt/vtgate/planbuilder/testdata/bypass_keyrange_cases.json index b13bafd77f8..e1ff7de97a0 100644 --- a/go/vt/vtgate/planbuilder/testdata/bypass_keyrange_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/bypass_keyrange_cases.json @@ -164,5 +164,22 @@ "Query": "create /* test */ table t1(id bigint, primary key(id)) /* comments */" } } + }, + { + "comment": "remove the matching keyspace from shard targeted query", + "query": "select count(*), col from `main`.unsharded join vt_main.t1 where exists (select 1 from main.t2 join information_schema.tables where table_name = 't3')", + "plan": { + "QueryType": "SELECT", + "Original": "select count(*), col from `main`.unsharded join vt_main.t1 where exists (select 1 from main.t2 join information_schema.tables where table_name = 't3')", + "Instructions": { + "OperatorType": "Send", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "TargetDestination": "ExactKeyRange(-)", + "Query": "select count(*), col from unsharded join vt_main.t1 where exists (select 1 from t2 join information_schema.`tables` where table_name = 't3')" + } + } } ] diff --git a/go/vt/vtgate/planbuilder/testdata/bypass_shard_cases.json b/go/vt/vtgate/planbuilder/testdata/bypass_shard_cases.json index 02a00444724..80d4dbbd08b 100644 --- a/go/vt/vtgate/planbuilder/testdata/bypass_shard_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/bypass_shard_cases.json @@ -251,5 +251,22 @@ "QueryTimeout": 100 } } + }, + { + "comment": "remove the matching keyspace from shard targeted query", + "query": "select count(*), col from `main`.unsharded join vt_main.t1 where exists (select 1 from main.t2 join information_schema.tables where table_name = 't3')", + "plan": { + "QueryType": "SELECT", + "Original": "select count(*), col from `main`.unsharded join vt_main.t1 where exists (select 1 from main.t2 join information_schema.tables where table_name = 't3')", + "Instructions": { + "OperatorType": "Send", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "TargetDestination": "Shard(-80)", + "Query": "select count(*), col from unsharded join vt_main.t1 where exists (select 1 from t2 join information_schema.`tables` where table_name = 't3')" + } + } } ]