Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
systay committed Jan 31, 2025
1 parent 15c4d8b commit 9807745
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 58 deletions.
67 changes: 34 additions & 33 deletions go/tools/asthelpergen/ast_path_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func (p *pathGen) genFile(spi generatorSPI) (string, *jen.File) {
p.file.Add(p.debugString())

// Add the generated WalkASTPath method
p.file.Add(p.generateWalkASTPath(spi))
p.file.Add(p.generateWalkASTPath())

return "ast_path.go", p.file
}
Expand All @@ -91,37 +91,41 @@ func (p *pathGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi gener
return nil
}

func (p *pathGen) addStep(
container types.Type, // the name of the container type
typ types.Type, // the type of the field
name string, // the name of the field
slice bool, // whether the field is a slice
) {
s := step{
container: container,
name: name,
typ: typ,
slice: slice,
}
stepName := s.AsEnum()
fmt.Println("Adding step:", stepName)
p.steps = append(p.steps, s)

}

func (p *pathGen) addStructFields(t types.Type, strct *types.Struct, spi generatorSPI) {
for i := 0; i < strct.NumFields(); i++ {
field := strct.Field(i)
// Check if the field type implements the interface
if types.Implements(field.Type(), spi.iface()) {
p.steps = append(p.steps, step{
container: t,
name: field.Name(),
typ: field.Type(),
})
p.addStep(t, field.Type(), field.Name(), false)
continue
}
// Check if the field type is a slice
slice, isSlice := field.Type().(*types.Slice)
if isSlice {
// Check if the slice type implements the interface
if types.Implements(slice, spi.iface()) {
p.steps = append(p.steps, step{
container: t,
slice: true,
name: field.Name(),
typ: slice,
})
p.addStep(t, slice, field.Name(), true)
} else if types.Implements(slice.Elem(), spi.iface()) {
// Check if the element type of the slice implements the interface
p.steps = append(p.steps, step{
container: t,
slice: true,
name: field.Name(),
typ: slice.Elem(),
})
p.addStep(t, slice.Elem(), field.Name(), true)
}
}
}
Expand All @@ -132,16 +136,11 @@ func (p *pathGen) ptrToBasicMethod(t types.Type, basic *types.Basic, spi generat
}

func (p *pathGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorSPI) error {
elemType := slice.Elem()
if types.Implements(elemType, spi.iface()) {
p.steps = append(p.steps, step{
container: t,
name: sliceMarker,
slice: true,
typ: elemType,
})
}

//elemType := slice.Elem()
//if types.Implements(elemType, spi.iface()) {
// p.addStep(t, elemType, sliceMarker, true)
//}
//
return nil
}

Expand Down Expand Up @@ -213,11 +212,13 @@ func (p *pathGen) buildConstWithEnum() *jen.Statement {
return constBlock
}

func (p *pathGen) generateWalkASTPath(spi generatorSPI) *jen.Statement {
func (p *pathGen) generateWalkASTPath() *jen.Statement {
method := jen.Func().Id("WalkASTPath").Params(
jen.Id("node").Id(p.ifaceName),
jen.Id("path").Id("ASTPath"),
).Id(p.ifaceName).Block(
jen.If(jen.Id("path").Op("==").Id(`""`).Block(
jen.Return(jen.Id("node")))),
jen.Id("step").Op(":=").Qual("encoding/binary", "BigEndian").Dot("Uint16").Call(jen.Index().Byte().Parens(jen.Id("path[:2]"))),
jen.Id("path").Op("=").Id("path[2:]"),

Expand All @@ -235,7 +236,7 @@ func (p *pathGen) generateWalkCases() []jen.Code {

if !step.slice {
// return WalkASTPath(node.(*RefContainer).ASTType, path)
t := types.TypeString(step.typ, noQualifier)
t := types.TypeString(step.container, noQualifier)
n := jen.Id("node").Dot(fmt.Sprintf("(%s)", t)).Dot(step.name)

cases = append(cases, jen.Case(jen.Id(stepName)).Block(
Expand All @@ -250,7 +251,7 @@ func (p *pathGen) generateWalkCases() []jen.Code {
if step.name == sliceMarker {
assignNode = jen.Id("node").Dot(fmt.Sprintf("(%s)", t)).Index(jen.Id("idx"))
} else {
assignNode = jen.Id("node").Dot(step.name).Index(jen.Id("idx"))
assignNode = jen.Id("node").Dot(fmt.Sprintf("(%s)", t)).Dot(step.name).Index(jen.Id("idx"))
}

cases = append(cases, jen.Case(jen.Id(stepName+"8")).Block(
Expand All @@ -260,8 +261,8 @@ func (p *pathGen) generateWalkCases() []jen.Code {
))

cases = append(cases, jen.Case(jen.Id(stepName+"32")).Block(
jen.Id("idx").Op(":=").Qual("encoding/binary", "BigEndian").Dot("Uint16").Call(jen.Index().Byte().Parens(jen.Id("path[:2]"))),
jen.Id("path").Op("=").Id("path[2:]"),
jen.Id("idx").Op(":=").Qual("encoding/binary", "BigEndian").Dot("Uint32").Call(jen.Index().Byte().Parens(jen.Id("path[:2]"))),
jen.Id("path").Op("=").Id("path[4:]"),
jen.Return(jen.Id("WalkASTPath").Call(assignNode, jen.Id("path"))),
))
}
Expand Down
88 changes: 63 additions & 25 deletions go/tools/asthelpergen/integration/ast_path.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

67 changes: 67 additions & 0 deletions go/tools/asthelpergen/integration/ast_path_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
Copyright 2025 The Vitess Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package integration

import (
"fmt"
"github.com/stretchr/testify/assert"
"reflect"
"testing"
)

func TestWalkAllPartsOfAST(t *testing.T) {
sliceContainer := &RefSliceContainer{
something: 12,
ASTElements: []AST{},
NotASTElements: []int{1, 2},
ASTImplementationElements: []*Leaf{{v: 1}, {v: 2}},
}

for i := range 300 {
sliceContainer.ASTImplementationElements = append(sliceContainer.ASTImplementationElements, &Leaf{v: i})
}

ast := &RefContainer{
ASTType: sliceContainer,
NotASTType: 2,
ASTImplementationType: &Leaf{v: 3},
}

v := make(map[ASTPath]AST)

RewriteWithPaths(ast, func(c *Cursor) bool {
node := c.Node()
if !reflect.TypeOf(node).Comparable() {
return true
}
current := c.current
v[current] = node
return true
}, nil)

fmt.Println("walked all parts of AST")

assert.NotEmpty(t, v)

for path, n1 := range v {
s := path.DebugString()
fmt.Println(s)

n2 := WalkASTPath(ast, path)
assert.Equal(t, n1, n2)
}
}
Loading

0 comments on commit 9807745

Please sign in to comment.