diff --git a/go/tools/asthelpergen/integration/ast_equals.go b/go/tools/asthelpergen/integration/ast_equals.go index 5e7b068671c..f3a4b6da9de 100644 --- a/go/tools/asthelpergen/integration/ast_equals.go +++ b/go/tools/asthelpergen/integration/ast_equals.go @@ -191,7 +191,8 @@ func (cmp *Comparator) RefOfRefSliceContainer(a, b *RefSliceContainer) bool { if a == nil || b == nil { return false } - return cmp.SliceOfAST(a.ASTElements, b.ASTElements) && + return a.something == b.something && + cmp.SliceOfAST(a.ASTElements, b.ASTElements) && cmp.SliceOfInt(a.NotASTElements, b.NotASTElements) && cmp.SliceOfRefOfLeaf(a.ASTImplementationElements, b.ASTImplementationElements) } diff --git a/go/tools/asthelpergen/integration/ast_rewrite.go b/go/tools/asthelpergen/integration/ast_rewrite.go index d4283711792..3f14d19d814 100644 --- a/go/tools/asthelpergen/integration/ast_rewrite.go +++ b/go/tools/asthelpergen/integration/ast_rewrite.go @@ -120,6 +120,10 @@ func (a *application) rewriteInterfaceSlice(parent AST, node InterfaceSlice, rep return true } } + var path ASTPath + if a.collectPaths { + path = a.cur.current + } for x, el := range node { if !a.rewriteAST(node, el, func(idx int) replacerFunc { return func(newNode, parent AST) { @@ -129,6 +133,9 @@ func (a *application) rewriteInterfaceSlice(parent AST, node InterfaceSlice, rep return false } } + if a.collectPaths { + a.cur.current = path + } if a.post != nil { a.cur.replacer = replacer a.cur.parent = parent @@ -181,6 +188,10 @@ func (a *application) rewriteLeafSlice(parent AST, node LeafSlice, replacer repl return true } } + var path ASTPath + if a.collectPaths { + path = a.cur.current + } for x, el := range node { if !a.rewriteRefOfLeaf(node, el, func(idx int) replacerFunc { return func(newNode, parent AST) { @@ -190,6 +201,9 @@ func (a *application) rewriteLeafSlice(parent AST, node LeafSlice, replacer repl return false } } + if a.collectPaths { + a.cur.current = path + } if a.post != nil { a.cur.replacer = replacer a.cur.parent = parent @@ -280,6 +294,9 @@ func (a *application) rewriteRefOfRefSliceContainer(parent AST, node *RefSliceCo } } var path ASTPath + if a.collectPaths { + path = a.cur.current + } for x, el := range node.ASTElements { if !a.rewriteAST(node, el, func(idx int) replacerFunc { return func(newNode, parent AST) { @@ -396,6 +413,9 @@ func (a *application) rewriteValueSliceContainer(parent AST, node ValueSliceCont } } var path ASTPath + if a.collectPaths { + path = a.cur.current + } for _, el := range node.ASTElements { if !a.rewriteAST(node, el, func(newNode, parent AST) { panic("[BUG] tried to replace 'ASTElements' on 'ValueSliceContainer'") @@ -537,6 +557,9 @@ func (a *application) rewriteRefOfValueSliceContainer(parent AST, node *ValueSli } } var path ASTPath + if a.collectPaths { + path = a.cur.current + } for x, el := range node.ASTElements { if !a.rewriteAST(node, el, func(idx int) replacerFunc { return func(newNode, parent AST) { diff --git a/go/tools/asthelpergen/integration/types.go b/go/tools/asthelpergen/integration/types.go index c090a1af327..0f3e9a146db 100644 --- a/go/tools/asthelpergen/integration/types.go +++ b/go/tools/asthelpergen/integration/types.go @@ -41,6 +41,7 @@ type ( } // Container implements the interface ByRef RefSliceContainer struct { + something int // want a non-AST field first ASTElements []AST NotASTElements []int ASTImplementationElements []*Leaf diff --git a/go/tools/asthelpergen/rewrite_gen.go b/go/tools/asthelpergen/rewrite_gen.go index 951188e0a4e..d7774bb7454 100644 --- a/go/tools/asthelpergen/rewrite_gen.go +++ b/go/tools/asthelpergen/rewrite_gen.go @@ -197,23 +197,40 @@ func (r *rewriteGen) sliceMethod(t types.Type, slice *types.Slice, spi generator jen.If(jen.Id("kontinue").Block(jen.Return(jen.True()))), ) - stmts = append(stmts, jen.If(jen.Id("a.pre!= nil").Block(preStmts...))) + stmts = append(stmts, jen.If(jen.Id("a.pre != nil").Block(preStmts...))) haveChildren := false if shouldAdd(slice.Elem(), spi.iface()) { /* + var path ASTPath + if a.collectPaths { + path = a.cur.current + } for i, el := range node { - if err := rewriteRefOfLeaf(node, el, func(newNode, parent AST) { - parent.(LeafSlice)[i] = newNode.(*Leaf) - }, pre, post); err != nil { - return err - } - } + if err := rewriteRefOfLeaf(node, el, func(newNode, parent AST) { + parent.(LeafSlice)[i] = newNode.(*Leaf) + }, pre, post); err != nil { + return err + } + } */ haveChildren = true + forBlock := []jen.Code{ + jen.Var().Id("path").Id("ASTPath"), + jen.If(jen.Id("a.collectPaths")).Block( + jen.Id("path").Op("=").Id("a.cur.current"), + ), + } + stmts = append(stmts, forBlock...) + rewriteChild := r.rewriteChildSlice(t, slice.Elem(), "notUsed", jen.Id("el"), jen.Index(jen.Id("idx")), false, 0) + stmts = append(stmts, jen.For(jen.Id("x, el").Op(":=").Id("range node")). - Block(r.rewriteChildSlice(t, slice.Elem(), "notUsed", jen.Id("el"), jen.Index(jen.Id("idx")), false))) + Block(rewriteChild...)) + + stmts = append(stmts, jen.If(jen.Id("a.collectPaths")).Block( + jen.Id("a.cur.current").Op("=").Id("path"), + )) } stmts = append(stmts, executePost(haveChildren)) @@ -298,16 +315,26 @@ func (r *rewriteGen) rewriteAllStructFields(t types.Type, strct *types.Struct, s */ var output []jen.Code + fieldNumber := 0 for i := 0; i < strct.NumFields(); i++ { field := strct.Field(i) if types.Implements(field.Type(), spi.iface()) { spi.addType(field.Type()) - rewriteLines := r.rewriteChild(t, field.Type(), field.Name(), jen.Id("node").Dot(field.Name()), jen.Dot(field.Name()), fail, i) + rewriteLines := r.rewriteChild(t, field.Type(), field.Name(), jen.Id("node").Dot(field.Name()), jen.Dot(field.Name()), fail, fieldNumber) + fieldNumber++ output = append(output, rewriteLines...) continue } slice, isSlice := field.Type().(*types.Slice) if isSlice && types.Implements(slice.Elem(), spi.iface()) { + if fieldNumber == 0 { + // if this is the first field we are dealing with, we need to store the incoming + // path into the local variable first of all + // if a.collectPaths { path = a.cur.current } + output = append(output, + jen.If(jen.Id("a.collectPaths")).Block(jen.Id("path").Op("=").Id("a.cur.current"))) + } + spi.addType(slice.Elem()) id := jen.Id("x") if fail { @@ -315,7 +342,8 @@ func (r *rewriteGen) rewriteAllStructFields(t types.Type, strct *types.Struct, s } output = append(output, jen.For(jen.List(id, jen.Id("el")).Op(":=").Id("range node."+field.Name())). - Block(r.rewriteChildSlice(t, slice.Elem(), field.Name(), jen.Id("el"), jen.Dot(field.Name()).Index(jen.Id("idx")), fail))) + Block(r.rewriteChildSlice(t, slice.Elem(), field.Name(), jen.Id("el"), jen.Dot(field.Name()).Index(jen.Id("idx")), fail, fieldNumber)...)) + fieldNumber++ } } return output @@ -380,7 +408,7 @@ func (r *rewriteGen) rewriteChild(t, field types.Type, fieldName string, param j } } -func (r *rewriteGen) rewriteChildSlice(t, field types.Type, fieldName string, param jen.Code, replace jen.Code, fail bool) jen.Code { +func (r *rewriteGen) rewriteChildSlice(t, field types.Type, fieldName string, param jen.Code, replace jen.Code, fail bool, fieldOffset int) []jen.Code { /* if errF := a.rewriteAST(node, el, func(idx int) replacerFunc { return func(newNode, parent AST) { @@ -416,7 +444,9 @@ func (r *rewriteGen) rewriteChildSlice(t, field types.Type, fieldName string, pa param, funcBlock).Block(returnFalse())) - return rewriteField + return []jen.Code{ + rewriteField, + } } var noQualifier = func(p *types.Package) string { diff --git a/go/vt/sqlparser/ast_rewrite.go b/go/vt/sqlparser/ast_rewrite.go index aeeaf53388e..1f06fb1b1db 100644 --- a/go/vt/sqlparser/ast_rewrite.go +++ b/go/vt/sqlparser/ast_rewrite.go @@ -582,6 +582,9 @@ func (a *application) rewriteRefOfAddColumns(parent SQLNode, node *AddColumns, r } } var path ASTPath + if a.collectPaths { + path = a.cur.current + } for x, el := range node.Columns { if !a.rewriteRefOfColumnDefinition(node, el, func(idx int) replacerFunc { return func(newNode, parent SQLNode) { @@ -986,6 +989,7 @@ func (a *application) rewriteRefOfAlterMigration(parent SQLNode, node *AlterMigr } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfAlterMigrationRatio) } if !a.rewriteRefOfLiteral(node, node.Ratio, func(newNode, parent SQLNode) { @@ -1155,6 +1159,7 @@ func (a *application) rewriteRefOfAlterVschema(parent SQLNode, node *AlterVschem } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfAlterVschemaTable) } if !a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { @@ -1214,6 +1219,7 @@ func (a *application) rewriteRefOfAnalyze(parent SQLNode, node *Analyze, replace } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfAnalyzeTable) } if !a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { @@ -1370,6 +1376,7 @@ func (a *application) rewriteRefOfArgumentLessWindowExpr(parent SQLNode, node *A } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfArgumentLessWindowExprOverClause) } if !a.rewriteRefOfOverClause(node, node.OverClause, func(newNode, parent SQLNode) { @@ -1572,6 +1579,7 @@ func (a *application) rewriteRefOfBetweenExpr(parent SQLNode, node *BetweenExpr, } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfBetweenExprLeft) } if !a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { @@ -1627,6 +1635,7 @@ func (a *application) rewriteRefOfBinaryExpr(parent SQLNode, node *BinaryExpr, r } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfBinaryExprLeft) } if !a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { @@ -2246,6 +2255,10 @@ func (a *application) rewriteColumns(parent SQLNode, node Columns, replacer repl return true } } + var path ASTPath + if a.collectPaths { + path = a.cur.current + } for x, el := range node { if !a.rewriteIdentifierCI(node, el, func(idx int) replacerFunc { return func(newNode, parent SQLNode) { @@ -2255,6 +2268,9 @@ func (a *application) rewriteColumns(parent SQLNode, node Columns, replacer repl return false } } + if a.collectPaths { + a.cur.current = path + } if a.post != nil { a.cur.replacer = replacer a.cur.parent = parent @@ -2383,6 +2399,7 @@ func (a *application) rewriteRefOfComparisonExpr(parent SQLNode, node *Compariso } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfComparisonExprLeft) } if !a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { @@ -2641,6 +2658,7 @@ func (a *application) rewriteRefOfCountStar(parent SQLNode, node *CountStar, rep } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfCountStarOverClause) } if !a.rewriteRefOfOverClause(node, node.OverClause, func(newNode, parent SQLNode) { @@ -2718,6 +2736,7 @@ func (a *application) rewriteRefOfCreateTable(parent SQLNode, node *CreateTable, } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfCreateTableTable) } if !a.rewriteTableName(node, node.Table, func(newNode, parent SQLNode) { @@ -3071,6 +3090,7 @@ func (a *application) rewriteRefOfDerivedTable(parent SQLNode, node *DerivedTabl } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfDerivedTableSelect) } if !a.rewriteTableStatement(node, node.Select, func(newNode, parent SQLNode) { @@ -3183,6 +3203,7 @@ func (a *application) rewriteRefOfDropKey(parent SQLNode, node *DropKey, replace } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfDropKeyName) } if !a.rewriteIdentifierCI(node, node.Name, func(newNode, parent SQLNode) { @@ -3217,6 +3238,7 @@ func (a *application) rewriteRefOfDropTable(parent SQLNode, node *DropTable, rep } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfDropTableFromTables) } if !a.rewriteTableNames(node, node.FromTables, func(newNode, parent SQLNode) { @@ -3394,6 +3416,7 @@ func (a *application) rewriteRefOfExplainStmt(parent SQLNode, node *ExplainStmt, } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfExplainStmtStatement) } if !a.rewriteStatement(node, node.Statement, func(newNode, parent SQLNode) { @@ -3475,6 +3498,10 @@ func (a *application) rewriteExprs(parent SQLNode, node Exprs, replacer replacer return true } } + var path ASTPath + if a.collectPaths { + path = a.cur.current + } for x, el := range node { if !a.rewriteExpr(node, el, func(idx int) replacerFunc { return func(newNode, parent SQLNode) { @@ -3484,6 +3511,9 @@ func (a *application) rewriteExprs(parent SQLNode, node Exprs, replacer replacer return false } } + if a.collectPaths { + a.cur.current = path + } if a.post != nil { a.cur.replacer = replacer a.cur.parent = parent @@ -3513,6 +3543,7 @@ func (a *application) rewriteRefOfExtractFuncExpr(parent SQLNode, node *ExtractF } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfExtractFuncExprExpr) } if !a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -3600,6 +3631,7 @@ func (a *application) rewriteRefOfFirstOrLastValueExpr(parent SQLNode, node *Fir } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfFirstOrLastValueExprExpr) } if !a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -3650,6 +3682,7 @@ func (a *application) rewriteRefOfFlush(parent SQLNode, node *Flush, replacer re } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfFlushTableNames) } if !a.rewriteTableNames(node, node.TableNames, func(newNode, parent SQLNode) { @@ -3759,6 +3792,7 @@ func (a *application) rewriteRefOfFrameClause(parent SQLNode, node *FrameClause, } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfFrameClauseStart) } if !a.rewriteRefOfFramePoint(node, node.Start, func(newNode, parent SQLNode) { @@ -3801,6 +3835,7 @@ func (a *application) rewriteRefOfFramePoint(parent SQLNode, node *FramePoint, r } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfFramePointExpr) } if !a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -3920,6 +3955,7 @@ func (a *application) rewriteRefOfGTIDFuncExpr(parent SQLNode, node *GTIDFuncExp } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfGTIDFuncExprSet1) } if !a.rewriteExpr(node, node.Set1, func(newNode, parent SQLNode) { @@ -4143,6 +4179,7 @@ func (a *application) rewriteRefOfGeomCollPropertyFuncExpr(parent SQLNode, node } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfGeomCollPropertyFuncExprGeomColl) } if !a.rewriteExpr(node, node.GeomColl, func(newNode, parent SQLNode) { @@ -4190,6 +4227,7 @@ func (a *application) rewriteRefOfGeomFormatExpr(parent SQLNode, node *GeomForma } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfGeomFormatExprGeom) } if !a.rewriteExpr(node, node.Geom, func(newNode, parent SQLNode) { @@ -4237,6 +4275,7 @@ func (a *application) rewriteRefOfGeomFromGeoHashExpr(parent SQLNode, node *Geom } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfGeomFromGeoHashExprGeoHash) } if !a.rewriteExpr(node, node.GeoHash, func(newNode, parent SQLNode) { @@ -4340,6 +4379,7 @@ func (a *application) rewriteRefOfGeomFromTextExpr(parent SQLNode, node *GeomFro } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfGeomFromTextExprWktText) } if !a.rewriteExpr(node, node.WktText, func(newNode, parent SQLNode) { @@ -4395,6 +4435,7 @@ func (a *application) rewriteRefOfGeomFromWKBExpr(parent SQLNode, node *GeomFrom } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfGeomFromWKBExprWkbBlob) } if !a.rewriteExpr(node, node.WkbBlob, func(newNode, parent SQLNode) { @@ -4450,6 +4491,7 @@ func (a *application) rewriteRefOfGeomPropertyFuncExpr(parent SQLNode, node *Geo } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfGeomPropertyFuncExprGeom) } if !a.rewriteExpr(node, node.Geom, func(newNode, parent SQLNode) { @@ -4483,6 +4525,9 @@ func (a *application) rewriteRefOfGroupBy(parent SQLNode, node *GroupBy, replace } } var path ASTPath + if a.collectPaths { + path = a.cur.current + } for x, el := range node.Exprs { if !a.rewriteExpr(node, el, func(idx int) replacerFunc { return func(newNode, parent SQLNode) { @@ -4524,6 +4569,7 @@ func (a *application) rewriteRefOfGroupConcatExpr(parent SQLNode, node *GroupCon } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfGroupConcatExprExprs) } if !a.rewriteExprs(node, node.Exprs, func(newNode, parent SQLNode) { @@ -4650,6 +4696,9 @@ func (a *application) rewriteRefOfIndexHint(parent SQLNode, node *IndexHint, rep } } var path ASTPath + if a.collectPaths { + path = a.cur.current + } for x, el := range node.Indexes { if !a.rewriteIdentifierCI(node, el, func(idx int) replacerFunc { return func(newNode, parent SQLNode) { @@ -4690,6 +4739,10 @@ func (a *application) rewriteIndexHints(parent SQLNode, node IndexHints, replace return true } } + var path ASTPath + if a.collectPaths { + path = a.cur.current + } for x, el := range node { if !a.rewriteRefOfIndexHint(node, el, func(idx int) replacerFunc { return func(newNode, parent SQLNode) { @@ -4699,6 +4752,9 @@ func (a *application) rewriteIndexHints(parent SQLNode, node IndexHints, replace return false } } + if a.collectPaths { + a.cur.current = path + } if a.post != nil { a.cur.replacer = replacer a.cur.parent = parent @@ -4723,6 +4779,7 @@ func (a *application) rewriteRefOfIndexInfo(parent SQLNode, node *IndexInfo, rep } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfIndexInfoName) } if !a.rewriteIdentifierCI(node, node.Name, func(newNode, parent SQLNode) { @@ -4765,6 +4822,7 @@ func (a *application) rewriteRefOfInsert(parent SQLNode, node *Insert, replacer } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfInsertComments) } if !a.rewriteRefOfParsedComments(node, node.Comments, func(newNode, parent SQLNode) { @@ -4916,6 +4974,7 @@ func (a *application) rewriteRefOfIntervalDateExpr(parent SQLNode, node *Interva } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfIntervalDateExprDate) } if !a.rewriteExpr(node, node.Date, func(newNode, parent SQLNode) { @@ -5011,6 +5070,7 @@ func (a *application) rewriteRefOfIntroducerExpr(parent SQLNode, node *Introduce } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfIntroducerExprExpr) } if !a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -5178,6 +5238,7 @@ func (a *application) rewriteRefOfJSONAttributesExpr(parent SQLNode, node *JSONA } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfJSONAttributesExprJSONDoc) } if !a.rewriteExpr(node, node.JSONDoc, func(newNode, parent SQLNode) { @@ -5491,6 +5552,9 @@ func (a *application) rewriteRefOfJSONObjectExpr(parent SQLNode, node *JSONObjec } } var path ASTPath + if a.collectPaths { + path = a.cur.current + } for x, el := range node.Params { if !a.rewriteRefOfJSONObjectParam(node, el, func(idx int) replacerFunc { return func(newNode, parent SQLNode) { @@ -6172,6 +6236,7 @@ func (a *application) rewriteRefOfJSONValueMergeExpr(parent SQLNode, node *JSONV } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfJSONValueMergeExprJSONDoc) } if !a.rewriteExpr(node, node.JSONDoc, func(newNode, parent SQLNode) { @@ -6219,6 +6284,7 @@ func (a *application) rewriteRefOfJSONValueModifierExpr(parent SQLNode, node *JS } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfJSONValueModifierExprJSONDoc) } if !a.rewriteExpr(node, node.JSONDoc, func(newNode, parent SQLNode) { @@ -6380,6 +6446,7 @@ func (a *application) rewriteRefOfJtOnResponse(parent SQLNode, node *JtOnRespons } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfJtOnResponseExpr) } if !a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -6467,6 +6534,7 @@ func (a *application) rewriteRefOfLagLeadExpr(parent SQLNode, node *LagLeadExpr, } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfLagLeadExprExpr) } if !a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -6621,6 +6689,7 @@ func (a *application) rewriteRefOfLinestrPropertyFuncExpr(parent SQLNode, node * } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfLinestrPropertyFuncExprLinestring) } if !a.rewriteExpr(node, node.Linestring, func(newNode, parent SQLNode) { @@ -6825,6 +6894,7 @@ func (a *application) rewriteRefOfLockingFunc(parent SQLNode, node *LockingFunc, } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfLockingFuncName) } if !a.rewriteExpr(node, node.Name, func(newNode, parent SQLNode) { @@ -6871,6 +6941,9 @@ func (a *application) rewriteRefOfMatchExpr(parent SQLNode, node *MatchExpr, rep } } var path ASTPath + if a.collectPaths { + path = a.cur.current + } for x, el := range node.Columns { if !a.rewriteRefOfColName(node, el, func(idx int) replacerFunc { return func(newNode, parent SQLNode) { @@ -7338,6 +7411,10 @@ func (a *application) rewriteNamedWindows(parent SQLNode, node NamedWindows, rep return true } } + var path ASTPath + if a.collectPaths { + path = a.cur.current + } for x, el := range node { if !a.rewriteRefOfNamedWindow(node, el, func(idx int) replacerFunc { return func(newNode, parent SQLNode) { @@ -7347,6 +7424,9 @@ func (a *application) rewriteNamedWindows(parent SQLNode, node NamedWindows, rep return false } } + if a.collectPaths { + a.cur.current = path + } if a.post != nil { a.cur.replacer = replacer a.cur.parent = parent @@ -7552,6 +7632,7 @@ func (a *application) rewriteRefOfOffset(parent SQLNode, node *Offset, replacer } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfOffsetOriginal) } if !a.rewriteExpr(node, node.Original, func(newNode, parent SQLNode) { @@ -7590,6 +7671,10 @@ func (a *application) rewriteOnDup(parent SQLNode, node OnDup, replacer replacer return true } } + var path ASTPath + if a.collectPaths { + path = a.cur.current + } for x, el := range node { if !a.rewriteRefOfUpdateExpr(node, el, func(idx int) replacerFunc { return func(newNode, parent SQLNode) { @@ -7599,6 +7684,9 @@ func (a *application) rewriteOnDup(parent SQLNode, node OnDup, replacer replacer return false } } + if a.collectPaths { + a.cur.current = path + } if a.post != nil { a.cur.replacer = replacer a.cur.parent = parent @@ -7745,6 +7833,10 @@ func (a *application) rewriteOrderBy(parent SQLNode, node OrderBy, replacer repl return true } } + var path ASTPath + if a.collectPaths { + path = a.cur.current + } for x, el := range node { if !a.rewriteRefOfOrder(node, el, func(idx int) replacerFunc { return func(newNode, parent SQLNode) { @@ -7754,6 +7846,9 @@ func (a *application) rewriteOrderBy(parent SQLNode, node OrderBy, replacer repl return false } } + if a.collectPaths { + a.cur.current = path + } if a.post != nil { a.cur.replacer = replacer a.cur.parent = parent @@ -8081,6 +8176,7 @@ func (a *application) rewriteRefOfPartitionOption(parent SQLNode, node *Partitio } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfPartitionOptionColList) } if !a.rewriteColumns(node, node.ColList, func(newNode, parent SQLNode) { @@ -8140,6 +8236,7 @@ func (a *application) rewriteRefOfPartitionSpec(parent SQLNode, node *PartitionS } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfPartitionSpecNames) } if !a.rewritePartitions(node, node.Names, func(newNode, parent SQLNode) { @@ -8199,6 +8296,7 @@ func (a *application) rewriteRefOfPartitionValueRange(parent SQLNode, node *Part } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfPartitionValueRangeRange) } if !a.rewriteValTuple(node, node.Range, func(newNode, parent SQLNode) { @@ -8237,6 +8335,10 @@ func (a *application) rewritePartitions(parent SQLNode, node Partitions, replace return true } } + var path ASTPath + if a.collectPaths { + path = a.cur.current + } for x, el := range node { if !a.rewriteIdentifierCI(node, el, func(idx int) replacerFunc { return func(newNode, parent SQLNode) { @@ -8246,6 +8348,9 @@ func (a *application) rewritePartitions(parent SQLNode, node Partitions, replace return false } } + if a.collectPaths { + a.cur.current = path + } if a.post != nil { a.cur.replacer = replacer a.cur.parent = parent @@ -8275,6 +8380,7 @@ func (a *application) rewriteRefOfPerformanceSchemaFuncExpr(parent SQLNode, node } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfPerformanceSchemaFuncExprArgument) } if !a.rewriteExpr(node, node.Argument, func(newNode, parent SQLNode) { @@ -8362,6 +8468,7 @@ func (a *application) rewriteRefOfPointPropertyFuncExpr(parent SQLNode, node *Po } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfPointPropertyFuncExprPoint) } if !a.rewriteExpr(node, node.Point, func(newNode, parent SQLNode) { @@ -8449,6 +8556,7 @@ func (a *application) rewriteRefOfPolygonPropertyFuncExpr(parent SQLNode, node * } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfPolygonPropertyFuncExprPolygon) } if !a.rewriteExpr(node, node.Polygon, func(newNode, parent SQLNode) { @@ -9101,6 +9209,7 @@ func (a *application) rewriteRefOfRevertMigration(parent SQLNode, node *RevertMi } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfRevertMigrationComments) } if !a.rewriteRefOfParsedComments(node, node.Comments, func(newNode, parent SQLNode) { @@ -9304,6 +9413,7 @@ func (a *application) rewriteRefOfSelect(parent SQLNode, node *Select, replacer } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfSelectWith) } if !a.rewriteRefOfWith(node, node.With, func(newNode, parent SQLNode) { @@ -9423,6 +9533,10 @@ func (a *application) rewriteSelectExprs(parent SQLNode, node SelectExprs, repla return true } } + var path ASTPath + if a.collectPaths { + path = a.cur.current + } for x, el := range node { if !a.rewriteSelectExpr(node, el, func(idx int) replacerFunc { return func(newNode, parent SQLNode) { @@ -9432,6 +9546,9 @@ func (a *application) rewriteSelectExprs(parent SQLNode, node SelectExprs, repla return false } } + if a.collectPaths { + a.cur.current = path + } if a.post != nil { a.cur.replacer = replacer a.cur.parent = parent @@ -9570,6 +9687,10 @@ func (a *application) rewriteSetExprs(parent SQLNode, node SetExprs, replacer re return true } } + var path ASTPath + if a.collectPaths { + path = a.cur.current + } for x, el := range node { if !a.rewriteRefOfSetExpr(node, el, func(idx int) replacerFunc { return func(newNode, parent SQLNode) { @@ -9579,6 +9700,9 @@ func (a *application) rewriteSetExprs(parent SQLNode, node SetExprs, replacer re return false } } + if a.collectPaths { + a.cur.current = path + } if a.post != nil { a.cur.replacer = replacer a.cur.parent = parent @@ -9638,6 +9762,7 @@ func (a *application) rewriteRefOfShowBasic(parent SQLNode, node *ShowBasic, rep } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfShowBasicTbl) } if !a.rewriteTableName(node, node.Tbl, func(newNode, parent SQLNode) { @@ -9688,6 +9813,7 @@ func (a *application) rewriteRefOfShowCreate(parent SQLNode, node *ShowCreate, r } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfShowCreateOp) } if !a.rewriteTableName(node, node.Op, func(newNode, parent SQLNode) { @@ -9722,6 +9848,7 @@ func (a *application) rewriteRefOfShowFilter(parent SQLNode, node *ShowFilter, r } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfShowFilterFilter) } if !a.rewriteExpr(node, node.Filter, func(newNode, parent SQLNode) { @@ -9756,6 +9883,7 @@ func (a *application) rewriteRefOfShowMigrationLogs(parent SQLNode, node *ShowMi } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfShowMigrationLogsComments) } if !a.rewriteRefOfParsedComments(node, node.Comments, func(newNode, parent SQLNode) { @@ -10164,6 +10292,7 @@ func (a *application) rewriteRefOfSubPartition(parent SQLNode, node *SubPartitio } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfSubPartitionColList) } if !a.rewriteColumns(node, node.ColList, func(newNode, parent SQLNode) { @@ -10312,6 +10441,10 @@ func (a *application) rewriteSubPartitionDefinitions(parent SQLNode, node SubPar return true } } + var path ASTPath + if a.collectPaths { + path = a.cur.current + } for x, el := range node { if !a.rewriteRefOfSubPartitionDefinition(node, el, func(idx int) replacerFunc { return func(newNode, parent SQLNode) { @@ -10321,6 +10454,9 @@ func (a *application) rewriteSubPartitionDefinitions(parent SQLNode, node SubPar return false } } + if a.collectPaths { + a.cur.current = path + } if a.post != nil { a.cur.replacer = replacer a.cur.parent = parent @@ -10493,6 +10629,10 @@ func (a *application) rewriteTableExprs(parent SQLNode, node TableExprs, replace return true } } + var path ASTPath + if a.collectPaths { + path = a.cur.current + } for x, el := range node { if !a.rewriteTableExpr(node, el, func(idx int) replacerFunc { return func(newNode, parent SQLNode) { @@ -10502,6 +10642,9 @@ func (a *application) rewriteTableExprs(parent SQLNode, node TableExprs, replace return false } } + if a.collectPaths { + a.cur.current = path + } if a.post != nil { a.cur.replacer = replacer a.cur.parent = parent @@ -10570,6 +10713,10 @@ func (a *application) rewriteTableNames(parent SQLNode, node TableNames, replace return true } } + var path ASTPath + if a.collectPaths { + path = a.cur.current + } for x, el := range node { if !a.rewriteTableName(node, el, func(idx int) replacerFunc { return func(newNode, parent SQLNode) { @@ -10579,6 +10726,9 @@ func (a *application) rewriteTableNames(parent SQLNode, node TableNames, replace return false } } + if a.collectPaths { + a.cur.current = path + } if a.post != nil { a.cur.replacer = replacer a.cur.parent = parent @@ -10632,6 +10782,9 @@ func (a *application) rewriteRefOfTableSpec(parent SQLNode, node *TableSpec, rep } } var path ASTPath + if a.collectPaths { + path = a.cur.current + } for x, el := range node.Columns { if !a.rewriteRefOfColumnDefinition(node, el, func(idx int) replacerFunc { return func(newNode, parent SQLNode) { @@ -10779,6 +10932,7 @@ func (a *application) rewriteRefOfTrimFuncExpr(parent SQLNode, node *TrimFuncExp } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfTrimFuncExprTrimArg) } if !a.rewriteExpr(node, node.TrimArg, func(newNode, parent SQLNode) { @@ -10861,6 +11015,7 @@ func (a *application) rewriteRefOfUnaryExpr(parent SQLNode, node *UnaryExpr, rep } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfUnaryExprExpr) } if !a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -11125,6 +11280,10 @@ func (a *application) rewriteUpdateExprs(parent SQLNode, node UpdateExprs, repla return true } } + var path ASTPath + if a.collectPaths { + path = a.cur.current + } for x, el := range node { if !a.rewriteRefOfUpdateExpr(node, el, func(idx int) replacerFunc { return func(newNode, parent SQLNode) { @@ -11134,6 +11293,9 @@ func (a *application) rewriteUpdateExprs(parent SQLNode, node UpdateExprs, repla return false } } + if a.collectPaths { + a.cur.current = path + } if a.post != nil { a.cur.replacer = replacer a.cur.parent = parent @@ -11249,6 +11411,7 @@ func (a *application) rewriteRefOfVExplainStmt(parent SQLNode, node *VExplainStm } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfVExplainStmtStatement) } if !a.rewriteStatement(node, node.Statement, func(newNode, parent SQLNode) { @@ -11362,6 +11525,10 @@ func (a *application) rewriteValTuple(parent SQLNode, node ValTuple, replacer re return true } } + var path ASTPath + if a.collectPaths { + path = a.cur.current + } for x, el := range node { if !a.rewriteExpr(node, el, func(idx int) replacerFunc { return func(newNode, parent SQLNode) { @@ -11371,6 +11538,9 @@ func (a *application) rewriteValTuple(parent SQLNode, node ValTuple, replacer re return false } } + if a.collectPaths { + a.cur.current = path + } if a.post != nil { a.cur.replacer = replacer a.cur.parent = parent @@ -11423,6 +11593,10 @@ func (a *application) rewriteValues(parent SQLNode, node Values, replacer replac return true } } + var path ASTPath + if a.collectPaths { + path = a.cur.current + } for x, el := range node { if !a.rewriteValTuple(node, el, func(idx int) replacerFunc { return func(newNode, parent SQLNode) { @@ -11432,6 +11606,9 @@ func (a *application) rewriteValues(parent SQLNode, node Values, replacer replac return false } } + if a.collectPaths { + a.cur.current = path + } if a.post != nil { a.cur.replacer = replacer a.cur.parent = parent @@ -11672,6 +11849,7 @@ func (a *application) rewriteRefOfVariable(parent SQLNode, node *Variable, repla } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfVariableName) } if !a.rewriteIdentifierCI(node, node.Name, func(newNode, parent SQLNode) { @@ -11929,6 +12107,7 @@ func (a *application) rewriteRefOfWhere(parent SQLNode, node *Where, replacer re } var path ASTPath if a.collectPaths { + path = a.cur.current a.cur.current = AddStep(path, RefOfWhereExpr) } if !a.rewriteExpr(node, node.Expr, func(newNode, parent SQLNode) { @@ -12010,6 +12189,10 @@ func (a *application) rewriteWindowDefinitions(parent SQLNode, node WindowDefini return true } } + var path ASTPath + if a.collectPaths { + path = a.cur.current + } for x, el := range node { if !a.rewriteRefOfWindowDefinition(node, el, func(idx int) replacerFunc { return func(newNode, parent SQLNode) { @@ -12019,6 +12202,9 @@ func (a *application) rewriteWindowDefinitions(parent SQLNode, node WindowDefini return false } } + if a.collectPaths { + a.cur.current = path + } if a.post != nil { a.cur.replacer = replacer a.cur.parent = parent @@ -12101,6 +12287,9 @@ func (a *application) rewriteRefOfWith(parent SQLNode, node *With, replacer repl } } var path ASTPath + if a.collectPaths { + path = a.cur.current + } for x, el := range node.CTEs { if !a.rewriteRefOfCommonTableExpr(node, el, func(idx int) replacerFunc { return func(newNode, parent SQLNode) { diff --git a/go/vt/sqlparser/rewriter_test.go b/go/vt/sqlparser/rewriter_test.go index 89dfc25cac9..7e358d4ff40 100644 --- a/go/vt/sqlparser/rewriter_test.go +++ b/go/vt/sqlparser/rewriter_test.go @@ -66,7 +66,7 @@ func TestReplaceWorksInLaterCalls(t *testing.T) { assert.Equal(t, 2, count) } -func TestFindColNames(t *testing.T) { +func TestFindColNamesWithPaths(t *testing.T) { // this is the tpch query #1 q := "select l_returnflag, l_linestatus, sum(l_quantity) as sum_qty, sum(l_extendedprice) as sum_base_price, sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, avg(l_quantity) as avg_qty, avg(l_extendedprice) as avg_price, avg(l_discount) as avg_disc, count(*) as count_order from lineitem where l_shipdate <= '1998-12-01' - interval '108' day group by l_returnflag, l_linestatus order by l_returnflag, l_linestatus" ast, err := NewTestParser().Parse(q)