diff --git a/go/vt/sqlparser/rewriter_api.go b/go/vt/sqlparser/rewriter_api.go index 260089dd202..ec448162f27 100644 --- a/go/vt/sqlparser/rewriter_api.go +++ b/go/vt/sqlparser/rewriter_api.go @@ -34,6 +34,10 @@ package sqlparser // Only fields that refer to AST nodes are considered children; // i.e., fields of basic types (strings, []byte, etc.) are ignored. func Rewrite(node SQLNode, pre, post ApplyFunc) (result SQLNode) { + return rewriteNode(node, pre, post, false) +} + +func rewriteNode(node SQLNode, pre ApplyFunc, post ApplyFunc, collectPaths bool) SQLNode { parent := &RootNode{node} // this is the root-replacer, used when the user replaces the root of the ast @@ -42,8 +46,9 @@ func Rewrite(node SQLNode, pre, post ApplyFunc) (result SQLNode) { } a := &application{ - pre: pre, - post: post, + pre: pre, + post: post, + collectPaths: collectPaths, } a.rewriteSQLNode(parent, node, replacer) @@ -51,6 +56,10 @@ func Rewrite(node SQLNode, pre, post ApplyFunc) (result SQLNode) { return parent.SQLNode } +func RewriteWithPath(node SQLNode, pre, post ApplyFunc) (result SQLNode) { + return rewriteNode(node, pre, post, true) +} + // SafeRewrite does not allow replacing nodes on the down walk of the tree walking // Long term this is the only Rewrite functionality we want func SafeRewrite( @@ -139,6 +148,12 @@ func (c *Cursor) ReplaceAndRevisit(newNode SQLNode) { c.revisit = true } +// CurrentPath returns the current path that got us to the current location in the AST +// Only works if the AST walk was configured to collect path as walking +func (c *Cursor) CurrentPath() ASTPath { + return c.current +} + type replacerFunc func(newNode, parent SQLNode) // application carries all the shared data so we can pass it around cheaply. diff --git a/go/vt/sqlparser/rewriter_test.go b/go/vt/sqlparser/rewriter_test.go index 628d6fbd0a4..89dfc25cac9 100644 --- a/go/vt/sqlparser/rewriter_test.go +++ b/go/vt/sqlparser/rewriter_test.go @@ -17,6 +17,7 @@ limitations under the License. package sqlparser import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -65,6 +66,20 @@ func TestReplaceWorksInLaterCalls(t *testing.T) { assert.Equal(t, 2, count) } +func TestFindColNames(t *testing.T) { + // this is the tpch query #1 + q := "select l_returnflag, l_linestatus, sum(l_quantity) as sum_qty, sum(l_extendedprice) as sum_base_price, sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, avg(l_quantity) as avg_qty, avg(l_extendedprice) as avg_price, avg(l_discount) as avg_disc, count(*) as count_order from lineitem where l_shipdate <= '1998-12-01' - interval '108' day group by l_returnflag, l_linestatus order by l_returnflag, l_linestatus" + ast, err := NewTestParser().Parse(q) + require.NoError(t, err) + RewriteWithPath(ast, nil, func(cursor *Cursor) bool { + if _, isColName := cursor.Node().(*ColName); isColName { + // TODO: actually assert something here + fmt.Println(cursor.CurrentPath().DebugString()) + } + return true + }) +} + func TestReplaceAndRevisitWorksInLaterCalls(t *testing.T) { q := "select * from tbl1" parser := NewTestParser()