Skip to content

Commit

Permalink
we still have to rewrite expressions in derived tables
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
  • Loading branch information
systay committed Jan 25, 2025
1 parent 0217996 commit 33bcae8
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 25 deletions.
62 changes: 37 additions & 25 deletions go/vt/sqlparser/normalizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package sqlparser

import (
"bytes"
"fmt"
"strconv"
"strings"

Expand Down Expand Up @@ -59,8 +60,8 @@ type (
sysVars map[string]string
views VSchemaViews

onLeave map[*AliasedExpr]func(*AliasedExpr)
shouldParameterize bool
onLeave map[*AliasedExpr]func(*AliasedExpr)
parameterize bool
}
// RewriteASTResult holds the result of rewriting the AST, including bind variable needs.
RewriteASTResult struct {
Expand Down Expand Up @@ -137,32 +138,33 @@ func newNormalizer(
parameterize bool,
) *normalizer {
return &normalizer{
bindVars: bindVars,
reserved: reserved,
vals: make(map[Literal]string),
bindVarNeeds: &BindVarNeeds{},
keyspace: keyspace,
selectLimit: selectLimit,
setVarComment: setVarComment,
fkChecksState: fkChecksState,
sysVars: sysVars,
views: views,
onLeave: make(map[*AliasedExpr]func(*AliasedExpr)),
shouldParameterize: parameterize,
bindVars: bindVars,
reserved: reserved,
vals: make(map[Literal]string),
bindVarNeeds: &BindVarNeeds{},
keyspace: keyspace,
selectLimit: selectLimit,
setVarComment: setVarComment,
fkChecksState: fkChecksState,
sysVars: sysVars,
views: views,
onLeave: make(map[*AliasedExpr]func(*AliasedExpr)),
parameterize: parameterize,
}
}

// walkDown processes nodes when traversing down the AST.
// It handles normalization logic based on node types.
func (nz *normalizer) walkDown(node, _ SQLNode) bool {
fmt.Printf("DOWN %T\n", node)
switch node := node.(type) {
case *Begin, *Commit, *Rollback, *Savepoint, *SRollback, *Release, *OtherAdmin, *Analyze, *AssignmentExpr,
*PrepareStmt, *ExecuteStmt, *FramePoint, *ColName, TableName, *ConvertType:
// These statement don't need normalizing
return false
case *Set:
// Disable parameterization within SET statements.
nz.shouldParameterize = false
nz.parameterize = false
case *DerivedTable:
nz.inDerived++
case *Select:
Expand All @@ -172,8 +174,8 @@ func (nz *normalizer) walkDown(node, _ SQLNode) bool {
}
case *AliasedExpr:
nz.noteAliasedExprName(node)
case SelectExprs:
return nz.inDerived == 0
//case SelectExprs:
// return nz.inDerived == 0
case *ComparisonExpr:
nz.convertComparison(node)
case *UpdateExpr:
Expand All @@ -191,7 +193,11 @@ func (nz *normalizer) walkDown(node, _ SQLNode) bool {
nz.bindVarNeeds.AddSysVar(sysVar)
}
}
return nz.err == nil
b := nz.err == nil
if !b {
fmt.Println(1)
}
return b
}

// noteAliasedExprName tracks expressions without aliases to add alias if expression is rewritten
Expand All @@ -212,6 +218,7 @@ func (nz *normalizer) noteAliasedExprName(node *AliasedExpr) {
// walkUp processes nodes when traversing up the AST.
// It finalizes normalization logic based on node types.
func (nz *normalizer) walkUp(cursor *Cursor) bool {
fmt.Printf("UP %T\n", cursor.Node())
// Add SET_VAR comments if applicable.
if supportOptimizerHint, supports := cursor.Node().(SupportOptimizerHint); supports {
if nz.setVarComment != "" {
Expand Down Expand Up @@ -241,6 +248,7 @@ func (nz *normalizer) walkUp(cursor *Cursor) bool {
// if we are tracking this node for changes, this is the time to add the alias if needed
if onLeave, ok := nz.onLeave[node]; ok {
onLeave(node)
delete(nz.onLeave, node)
}
case *Union:
nz.rewriteUnion(node)
Expand All @@ -267,7 +275,7 @@ func (nz *normalizer) walkUp(cursor *Cursor) bool {
}

func (nz *normalizer) visitLiteral(cursor *Cursor, node *Literal) {
if !nz.shouldParameterize {
if !nz.shouldParameterize() {
return
}
if nz.inSelect == 0 {
Expand Down Expand Up @@ -370,15 +378,15 @@ func (nz *normalizer) convertComparison(node *ComparisonExpr) {

// rewriteOtherComparisons parameterizes non-IN comparison expressions.
func (nz *normalizer) rewriteOtherComparisons(node *ComparisonExpr) {
newR := nz.parameterize(node.Left, node.Right)
newR := nz.normalizeComparisonWithBindVar(node.Left, node.Right)
if newR != nil {
node.Right = newR
}
}

// parameterize attempts to replace a literal in a comparison with a bind variable.
func (nz *normalizer) parameterize(left, right Expr) Expr {
if !nz.shouldParameterize {
// normalizeComparisonWithBindVar attempts to replace a literal in a comparison with a bind variable.
func (nz *normalizer) normalizeComparisonWithBindVar(left, right Expr) Expr {
if !nz.shouldParameterize() {
return nil
}
col, ok := left.(*ColName)
Expand Down Expand Up @@ -424,7 +432,7 @@ func (nz *normalizer) decideBindVarName(lit *Literal, col *ColName, bval *queryp

// rewriteInComparisons converts IN and NOT IN expressions to use list bind variables.
func (nz *normalizer) rewriteInComparisons(node *ComparisonExpr) {
if !nz.shouldParameterize {
if !nz.shouldParameterize() {
return
}
tupleVals, ok := node.Right.(ValTuple)
Expand Down Expand Up @@ -453,7 +461,7 @@ func (nz *normalizer) rewriteInComparisons(node *ComparisonExpr) {

// convertUpdateExpr parameterizes expressions in UPDATE statements.
func (nz *normalizer) convertUpdateExpr(node *UpdateExpr) {
newR := nz.parameterize(node.Name, node.Expr)
newR := nz.normalizeComparisonWithBindVar(node.Name, node.Expr)
if newR != nil {
node.Expr = newR
}
Expand Down Expand Up @@ -824,6 +832,10 @@ func (nz *normalizer) rewriteDistinctableAggr(node DistinctableAggr) {
}
}

func (nz *normalizer) shouldParameterize() bool {
return !(nz.inDerived > 0 && len(nz.onLeave) > 0) && nz.parameterize
}

// SystemSchema checks if the given schema is a system schema.
func SystemSchema(schema string) bool {
return strings.EqualFold(schema, "information_schema") ||
Expand Down
4 changes: 4 additions & 0 deletions go/vt/sqlparser/normalizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,10 @@ func TestRewrites(in *testing.T) {
in: `select (select (select (select (select (select last_insert_id()))))) as x`,
expected: "select :__lastInsertId as x from dual",
liid: true,
}, {
in: `select * from (select last_insert_id()) as t`,
expected: "select * from (select :__lastInsertId as `last_insert_id()` from dual) as t",
liid: true,
}, {
in: `select * from user where col = @@ddl_strategy`,
expected: "select * from user where col = :__vtddl_strategy",
Expand Down

0 comments on commit 33bcae8

Please sign in to comment.