Skip to content

Commit cf5cb00

Browse files
adonovangopherbot
authored andcommittedApr 17, 2025·
internal/astutil: PreorderStack: a safer ast.Inspect for stacks
This CL defines PreorderStack, a safer function than ast.Inspect for when you need to maintain a stack. Beware, the stack that it produces does not include n itself--a half-open interval--so that nested traversals compose correctly. The CL also uses the new function in various places in x/tools where appropriate; in some cases it was clearer to rewrite using cursor.Cursor. + test Updates golang/go#73319 Change-Id: I843122cdd49cc4af8a7318badd8c34389479a92a Reviewed-on: https://go-review.googlesource.com/c/tools/+/664635 Auto-Submit: Alan Donovan <adonovan@google.com> Commit-Queue: Alan Donovan <adonovan@google.com> Reviewed-by: Robert Findley <rfindley@google.com> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
1 parent f76b112 commit cf5cb00

File tree

8 files changed

+171
-115
lines changed

8 files changed

+171
-115
lines changed
 

‎go/analysis/passes/lostcancel/lostcancel.go

+7-14
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"golang.org/x/tools/go/ast/inspector"
1818
"golang.org/x/tools/go/cfg"
1919
"golang.org/x/tools/internal/analysisinternal"
20+
"golang.org/x/tools/internal/astutil"
2021
)
2122

2223
//go:embed doc.go
@@ -83,30 +84,22 @@ func runFunc(pass *analysis.Pass, node ast.Node) {
8384
// {FuncDecl,FuncLit,CallExpr,SelectorExpr}.
8485

8586
// Find the set of cancel vars to analyze.
86-
stack := make([]ast.Node, 0, 32)
87-
ast.Inspect(node, func(n ast.Node) bool {
88-
switch n.(type) {
89-
case *ast.FuncLit:
90-
if len(stack) > 0 {
91-
return false // don't stray into nested functions
92-
}
93-
case nil:
94-
stack = stack[:len(stack)-1] // pop
95-
return true
87+
astutil.PreorderStack(node, nil, func(n ast.Node, stack []ast.Node) bool {
88+
if _, ok := n.(*ast.FuncLit); ok && len(stack) > 0 {
89+
return false // don't stray into nested functions
9690
}
97-
stack = append(stack, n) // push
9891

99-
// Look for [{AssignStmt,ValueSpec} CallExpr SelectorExpr]:
92+
// Look for n=SelectorExpr beneath stack=[{AssignStmt,ValueSpec} CallExpr]:
10093
//
10194
// ctx, cancel := context.WithCancel(...)
10295
// ctx, cancel = context.WithCancel(...)
10396
// var ctx, cancel = context.WithCancel(...)
10497
//
105-
if !isContextWithCancel(pass.TypesInfo, n) || !isCall(stack[len(stack)-2]) {
98+
if !isContextWithCancel(pass.TypesInfo, n) || !isCall(stack[len(stack)-1]) {
10699
return true
107100
}
108101
var id *ast.Ident // id of cancel var
109-
stmt := stack[len(stack)-3]
102+
stmt := stack[len(stack)-2]
110103
switch stmt := stmt.(type) {
111104
case *ast.ValueSpec:
112105
if len(stmt.Names) > 1 {

‎gopls/internal/golang/codeaction.go

+10-21
Original file line numberDiff line numberDiff line change
@@ -713,33 +713,24 @@ func refactorRewriteEliminateDotImport(ctx context.Context, req *codeActionsRequ
713713

714714
// Go through each use of the dot imported package, checking its scope for
715715
// shadowing and calculating an edit to qualify the identifier.
716-
var stack []ast.Node
717-
ast.Inspect(req.pgf.File, func(n ast.Node) bool {
718-
if n == nil {
719-
stack = stack[:len(stack)-1] // pop
720-
return false
721-
}
722-
stack = append(stack, n) // push
716+
for curId := range req.pgf.Cursor.Preorder((*ast.Ident)(nil)) {
717+
ident := curId.Node().(*ast.Ident)
723718

724-
ident, ok := n.(*ast.Ident)
725-
if !ok {
726-
return true
727-
}
728719
// Only keep identifiers that use a symbol from the
729720
// dot imported package.
730721
use := req.pkg.TypesInfo().Uses[ident]
731722
if use == nil || use.Pkg() == nil {
732-
return true
723+
continue
733724
}
734725
if use.Pkg() != imported {
735-
return true
726+
continue
736727
}
737728

738729
// Only qualify unqualified identifiers (due to dot imports).
739730
// All other references to a symbol imported from another package
740731
// are nested within a select expression (pkg.Foo, v.Method, v.Field).
741-
if is[*ast.SelectorExpr](stack[len(stack)-2]) {
742-
return true
732+
if is[*ast.SelectorExpr](curId.Parent().Node()) {
733+
continue
743734
}
744735

745736
// Make sure that the package name will not be shadowed by something else in scope.
@@ -750,24 +741,22 @@ func refactorRewriteEliminateDotImport(ctx context.Context, req *codeActionsRequ
750741
// allowed to go through.
751742
sc := fileScope.Innermost(ident.Pos())
752743
if sc == nil {
753-
return true
744+
continue
754745
}
755746
_, obj := sc.LookupParent(newName, ident.Pos())
756747
if obj != nil {
757-
return true
748+
continue
758749
}
759750

760751
rng, err := req.pgf.PosRange(ident.Pos(), ident.Pos()) // sic, zero-width range before ident
761752
if err != nil {
762-
return true
753+
continue
763754
}
764755
edits = append(edits, protocol.TextEdit{
765756
Range: rng,
766757
NewText: newName + ".",
767758
})
768-
769-
return true
770-
})
759+
}
771760

772761
req.addEditAction("Eliminate dot import", nil, protocol.DocumentChangeEdit(
773762
req.fh,

‎gopls/internal/golang/hover.go

+3-8
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import (
3838
gastutil "golang.org/x/tools/gopls/internal/util/astutil"
3939
"golang.org/x/tools/gopls/internal/util/bug"
4040
"golang.org/x/tools/gopls/internal/util/safetoken"
41+
internalastutil "golang.org/x/tools/internal/astutil"
4142
"golang.org/x/tools/internal/event"
4243
"golang.org/x/tools/internal/stdlib"
4344
"golang.org/x/tools/internal/tokeninternal"
@@ -1502,16 +1503,10 @@ func findDeclInfo(files []*ast.File, pos token.Pos) (decl ast.Decl, spec ast.Spe
15021503
stack := make([]ast.Node, 0, 20)
15031504

15041505
// Allocate the closure once, outside the loop.
1505-
f := func(n ast.Node) bool {
1506+
f := func(n ast.Node, stack []ast.Node) bool {
15061507
if found {
15071508
return false
15081509
}
1509-
if n != nil {
1510-
stack = append(stack, n) // push
1511-
} else {
1512-
stack = stack[:len(stack)-1] // pop
1513-
return false
1514-
}
15151510

15161511
// Skip subtrees (incl. files) that don't contain the search point.
15171512
if !(n.Pos() <= pos && pos < n.End()) {
@@ -1596,7 +1591,7 @@ func findDeclInfo(files []*ast.File, pos token.Pos) (decl ast.Decl, spec ast.Spe
15961591
return true
15971592
}
15981593
for _, file := range files {
1599-
ast.Inspect(file, f)
1594+
internalastutil.PreorderStack(file, stack, f)
16001595
if found {
16011596
return decl, spec, field
16021597
}

‎gopls/internal/golang/rename_check.go

+31-33
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ import (
4545
"golang.org/x/tools/go/ast/astutil"
4646
"golang.org/x/tools/gopls/internal/cache"
4747
"golang.org/x/tools/gopls/internal/util/safetoken"
48+
"golang.org/x/tools/internal/astutil/cursor"
49+
"golang.org/x/tools/internal/astutil/edge"
4850
"golang.org/x/tools/internal/typeparams"
4951
"golang.org/x/tools/internal/typesinternal"
5052
"golang.org/x/tools/refactor/satisfy"
@@ -338,64 +340,61 @@ func deeper(x, y *types.Scope) bool {
338340
// lexical block enclosing the reference. If fn returns false the
339341
// iteration is terminated and findLexicalRefs returns false.
340342
func forEachLexicalRef(pkg *cache.Package, obj types.Object, fn func(id *ast.Ident, block *types.Scope) bool) bool {
343+
filter := []ast.Node{
344+
(*ast.Ident)(nil),
345+
(*ast.SelectorExpr)(nil),
346+
(*ast.CompositeLit)(nil),
347+
}
341348
ok := true
342-
var stack []ast.Node
343-
344-
var visit func(n ast.Node) bool
345-
visit = func(n ast.Node) bool {
346-
if n == nil {
347-
stack = stack[:len(stack)-1] // pop
348-
return false
349-
}
349+
var visit func(cur cursor.Cursor, push bool) (descend bool)
350+
visit = func(cur cursor.Cursor, push bool) (descend bool) {
350351
if !ok {
351352
return false // bail out
352353
}
353-
354-
stack = append(stack, n) // push
355-
switch n := n.(type) {
354+
if !push {
355+
return false
356+
}
357+
switch n := cur.Node().(type) {
356358
case *ast.Ident:
357359
if pkg.TypesInfo().Uses[n] == obj {
358-
block := enclosingBlock(pkg.TypesInfo(), stack)
360+
block := enclosingBlock(pkg.TypesInfo(), cur)
359361
if !fn(n, block) {
360362
ok = false
361363
}
362364
}
363-
return visit(nil) // pop stack
364365

365366
case *ast.SelectorExpr:
366367
// don't visit n.Sel
367-
ast.Inspect(n.X, visit)
368-
return visit(nil) // pop stack, don't descend
368+
cur.ChildAt(edge.SelectorExpr_X, -1).Inspect(filter, visit)
369+
return false // don't descend
369370

370371
case *ast.CompositeLit:
371372
// Handle recursion ourselves for struct literals
372373
// so we don't visit field identifiers.
373374
tv, ok := pkg.TypesInfo().Types[n]
374375
if !ok {
375-
return visit(nil) // pop stack, don't descend
376+
return false // don't descend
376377
}
377378
if is[*types.Struct](typeparams.CoreType(typeparams.Deref(tv.Type))) {
378379
if n.Type != nil {
379-
ast.Inspect(n.Type, visit)
380+
cur.ChildAt(edge.CompositeLit_Type, -1).Inspect(filter, visit)
380381
}
381-
for _, elt := range n.Elts {
382-
if kv, ok := elt.(*ast.KeyValueExpr); ok {
383-
ast.Inspect(kv.Value, visit)
384-
} else {
385-
ast.Inspect(elt, visit)
382+
for i, elt := range n.Elts {
383+
curElt := cur.ChildAt(edge.CompositeLit_Elts, i)
384+
if _, ok := elt.(*ast.KeyValueExpr); ok {
385+
// skip kv.Key
386+
curElt = curElt.ChildAt(edge.KeyValueExpr_Value, -1)
386387
}
388+
curElt.Inspect(filter, visit)
387389
}
388-
return visit(nil) // pop stack, don't descend
390+
return false // don't descend
389391
}
390392
}
391393
return true
392394
}
393395

394-
for _, f := range pkg.Syntax() {
395-
ast.Inspect(f, visit)
396-
if len(stack) != 0 {
397-
panic(stack)
398-
}
396+
for _, pgf := range pkg.CompiledGoFiles() {
397+
pgf.Cursor.Inspect(filter, visit)
399398
if !ok {
400399
break
401400
}
@@ -404,11 +403,10 @@ func forEachLexicalRef(pkg *cache.Package, obj types.Object, fn func(id *ast.Ide
404403
}
405404

406405
// enclosingBlock returns the innermost block logically enclosing the
407-
// specified AST node (an ast.Ident), specified in the form of a path
408-
// from the root of the file, [file...n].
409-
func enclosingBlock(info *types.Info, stack []ast.Node) *types.Scope {
410-
for i := range stack {
411-
n := stack[len(stack)-1-i]
406+
// AST node (an ast.Ident), specified as a Cursor.
407+
func enclosingBlock(info *types.Info, curId cursor.Cursor) *types.Scope {
408+
for cur := range curId.Enclosing() {
409+
n := cur.Node()
412410
// For some reason, go/types always associates a
413411
// function's scope with its FuncType.
414412
// See comments about scope above.

‎internal/astutil/util.go

+32
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,35 @@ func PosInStringLiteral(lit *ast.BasicLit, offset int) (token.Pos, error) {
5757
}
5858
return pos, nil
5959
}
60+
61+
// PreorderStack traverses the tree rooted at root,
62+
// calling f before visiting each node.
63+
//
64+
// Each call to f provides the current node and traversal stack,
65+
// consisting of the original value of stack appended with all nodes
66+
// from root to n, excluding n itself. (This design allows calls
67+
// to PreorderStack to be nested without double counting.)
68+
//
69+
// If f returns false, the traversal skips over that subtree. Unlike
70+
// [ast.Inspect], no second call to f is made after visiting node n.
71+
// In practice, the second call is nearly always used only to pop the
72+
// stack, and it is surprisingly tricky to do this correctly; see
73+
// https://go.dev/issue/73319.
74+
func PreorderStack(root ast.Node, stack []ast.Node, f func(n ast.Node, stack []ast.Node) bool) {
75+
before := len(stack)
76+
ast.Inspect(root, func(n ast.Node) bool {
77+
if n != nil {
78+
if !f(n, stack) {
79+
// Do not push, as there will be no corresponding pop.
80+
return false
81+
}
82+
stack = append(stack, n) // push
83+
} else {
84+
stack = stack[:len(stack)-1] // pop
85+
}
86+
return true
87+
})
88+
if len(stack) != before {
89+
panic("push/pop mismatch")
90+
}
91+
}

‎internal/astutil/util_test.go

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// Copyright 2025 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package astutil_test
6+
7+
import (
8+
"fmt"
9+
"go/ast"
10+
"go/parser"
11+
"go/token"
12+
"reflect"
13+
"strings"
14+
"testing"
15+
16+
"golang.org/x/tools/internal/astutil"
17+
)
18+
19+
func TestPreorderStack(t *testing.T) {
20+
const src = `package a
21+
func f() {
22+
print("hello")
23+
}
24+
func g() {
25+
print("goodbye")
26+
panic("oops")
27+
}
28+
`
29+
fset := token.NewFileSet()
30+
f, _ := parser.ParseFile(fset, "a.go", src, 0)
31+
32+
str := func(n ast.Node) string {
33+
return strings.TrimPrefix(reflect.TypeOf(n).String(), "*ast.")
34+
}
35+
36+
var events []string
37+
var gotStack []string
38+
astutil.PreorderStack(f, nil, func(n ast.Node, stack []ast.Node) bool {
39+
events = append(events, str(n))
40+
if decl, ok := n.(*ast.FuncDecl); ok && decl.Name.Name == "f" {
41+
return false // skip subtree of f()
42+
}
43+
if lit, ok := n.(*ast.BasicLit); ok && lit.Value == `"oops"` {
44+
for _, n := range stack {
45+
gotStack = append(gotStack, str(n))
46+
}
47+
}
48+
return true
49+
})
50+
51+
// Check sequence of events.
52+
const wantEvents = `[File Ident ` + // package a
53+
`FuncDecl ` + // func f() [pruned]
54+
`FuncDecl Ident FuncType FieldList BlockStmt ` + // func g()
55+
`ExprStmt CallExpr Ident BasicLit ` + // print...
56+
`ExprStmt CallExpr Ident BasicLit]` // panic...
57+
if got := fmt.Sprint(events); got != wantEvents {
58+
t.Errorf("PreorderStack events:\ngot: %s\nwant: %s", got, wantEvents)
59+
}
60+
61+
// Check captured stack.
62+
const wantStack = `[File FuncDecl BlockStmt ExprStmt CallExpr]`
63+
if got := fmt.Sprint(gotStack); got != wantStack {
64+
t.Errorf("PreorderStack stack:\ngot: %s\nwant: %s", got, wantStack)
65+
}
66+
67+
}

‎internal/refactor/inline/callee.go

+10-19
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"strings"
1919

2020
"golang.org/x/tools/go/types/typeutil"
21+
"golang.org/x/tools/internal/astutil"
2122
"golang.org/x/tools/internal/typeparams"
2223
"golang.org/x/tools/internal/typesinternal"
2324
)
@@ -132,16 +133,11 @@ func AnalyzeCallee(logf func(string, ...any), fset *token.FileSet, pkg *types.Pa
132133
freeRefs []freeRef // free refs that may need renaming
133134
unexported []string // free refs to unexported objects, for later error checks
134135
)
135-
var f func(n ast.Node) bool
136-
visit := func(n ast.Node) { ast.Inspect(n, f) }
136+
var f func(n ast.Node, stack []ast.Node) bool
137137
var stack []ast.Node
138138
stack = append(stack, decl.Type) // for scope of function itself
139-
f = func(n ast.Node) bool {
140-
if n != nil {
141-
stack = append(stack, n) // push
142-
} else {
143-
stack = stack[:len(stack)-1] // pop
144-
}
139+
visit := func(n ast.Node, stack []ast.Node) { astutil.PreorderStack(n, stack, f) }
140+
f = func(n ast.Node, stack []ast.Node) bool {
145141
switch n := n.(type) {
146142
case *ast.SelectorExpr:
147143
// Check selections of free fields/methods.
@@ -153,7 +149,7 @@ func AnalyzeCallee(logf func(string, ...any), fset *token.FileSet, pkg *types.Pa
153149
}
154150

155151
// Don't recur into SelectorExpr.Sel.
156-
visit(n.X)
152+
visit(n.X, stack)
157153
return false
158154

159155
case *ast.CompositeLit:
@@ -162,7 +158,7 @@ func AnalyzeCallee(logf func(string, ...any), fset *token.FileSet, pkg *types.Pa
162158
litType := typeparams.Deref(info.TypeOf(n))
163159
if s, ok := typeparams.CoreType(litType).(*types.Struct); ok {
164160
if n.Type != nil {
165-
visit(n.Type)
161+
visit(n.Type, stack)
166162
}
167163
for i, elt := range n.Elts {
168164
var field *types.Var
@@ -180,7 +176,7 @@ func AnalyzeCallee(logf func(string, ...any), fset *token.FileSet, pkg *types.Pa
180176
}
181177

182178
// Don't recur into KeyValueExpr.Key.
183-
visit(value)
179+
visit(value, stack)
184180
}
185181
return false
186182
}
@@ -234,7 +230,7 @@ func AnalyzeCallee(logf func(string, ...any), fset *token.FileSet, pkg *types.Pa
234230
}
235231
return true
236232
}
237-
visit(decl)
233+
visit(decl, stack)
238234

239235
// Analyze callee body for "return expr" form,
240236
// where expr is f() or <-ch. These forms are
@@ -466,13 +462,7 @@ func analyzeParams(logf func(string, ...any), fset *token.FileSet, info *types.I
466462
fieldObjs := fieldObjs(sig)
467463
var stack []ast.Node
468464
stack = append(stack, decl.Type) // for scope of function itself
469-
ast.Inspect(decl.Body, func(n ast.Node) bool {
470-
if n != nil {
471-
stack = append(stack, n) // push
472-
} else {
473-
stack = stack[:len(stack)-1] // pop
474-
}
475-
465+
astutil.PreorderStack(decl.Body, stack, func(n ast.Node, stack []ast.Node) bool {
476466
if id, ok := n.(*ast.Ident); ok {
477467
if v, ok := info.Uses[id].(*types.Var); ok {
478468
if pinfo, ok := paramInfos[v]; ok {
@@ -487,6 +477,7 @@ func analyzeParams(logf func(string, ...any), fset *token.FileSet, info *types.I
487477
// Contrapositively, if param is not an interface type, then the
488478
// assignment may lose type information, for example in the case that
489479
// the substituted expression is an untyped constant or unnamed type.
480+
stack = append(stack, n) // (the two calls below want n)
490481
assignable, ifaceAssign, affectsInference := analyzeAssignment(info, stack)
491482
ref := refInfo{
492483
Offset: int(n.Pos() - decl.Pos()),

‎refactor/rename/check.go

+11-20
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"go/types"
1414

1515
"golang.org/x/tools/go/loader"
16+
"golang.org/x/tools/internal/astutil"
1617
"golang.org/x/tools/internal/typeparams"
1718
"golang.org/x/tools/internal/typesinternal"
1819
"golang.org/x/tools/refactor/satisfy"
@@ -313,19 +314,12 @@ func deeper(x, y *types.Scope) bool {
313314
// iteration is terminated and findLexicalRefs returns false.
314315
func forEachLexicalRef(info *loader.PackageInfo, obj types.Object, fn func(id *ast.Ident, block *types.Scope) bool) bool {
315316
ok := true
316-
var stack []ast.Node
317317

318-
var visit func(n ast.Node) bool
319-
visit = func(n ast.Node) bool {
320-
if n == nil {
321-
stack = stack[:len(stack)-1] // pop
322-
return false
323-
}
318+
var visit func(n ast.Node, stack []ast.Node) bool
319+
visit = func(n ast.Node, stack []ast.Node) bool {
324320
if !ok {
325321
return false // bail out
326322
}
327-
328-
stack = append(stack, n) // push
329323
switch n := n.(type) {
330324
case *ast.Ident:
331325
if info.Uses[n] == obj {
@@ -334,39 +328,36 @@ func forEachLexicalRef(info *loader.PackageInfo, obj types.Object, fn func(id *a
334328
ok = false
335329
}
336330
}
337-
return visit(nil) // pop stack
331+
return false
338332

339333
case *ast.SelectorExpr:
340334
// don't visit n.Sel
341-
ast.Inspect(n.X, visit)
342-
return visit(nil) // pop stack, don't descend
335+
astutil.PreorderStack(n.X, stack, visit)
336+
return false // don't descend
343337

344338
case *ast.CompositeLit:
345339
// Handle recursion ourselves for struct literals
346340
// so we don't visit field identifiers.
347341
tv := info.Types[n]
348342
if is[*types.Struct](typeparams.CoreType(typeparams.Deref(tv.Type))) {
349343
if n.Type != nil {
350-
ast.Inspect(n.Type, visit)
344+
astutil.PreorderStack(n.Type, stack, visit)
351345
}
352346
for _, elt := range n.Elts {
353347
if kv, ok := elt.(*ast.KeyValueExpr); ok {
354-
ast.Inspect(kv.Value, visit)
348+
astutil.PreorderStack(kv.Value, stack, visit)
355349
} else {
356-
ast.Inspect(elt, visit)
350+
astutil.PreorderStack(elt, stack, visit)
357351
}
358352
}
359-
return visit(nil) // pop stack, don't descend
353+
return false // don't descend
360354
}
361355
}
362356
return true
363357
}
364358

365359
for _, f := range info.Files {
366-
ast.Inspect(f, visit)
367-
if len(stack) != 0 {
368-
panic(stack)
369-
}
360+
astutil.PreorderStack(f, nil, visit)
370361
if !ok {
371362
break
372363
}

0 commit comments

Comments
 (0)
Please sign in to comment.