From 9807745addbf3046ee1d50e961c006dadd72c872 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Fri, 31 Jan 2025 10:21:46 +0100 Subject: [PATCH] wip --- go/tools/asthelpergen/ast_path_gen.go | 67 +++++++------- go/tools/asthelpergen/integration/ast_path.go | 88 +++++++++++++------ .../asthelpergen/integration/ast_path_test.go | 67 ++++++++++++++ .../asthelpergen/integration/test_helpers.go | 73 +++++++++++++++ 4 files changed, 237 insertions(+), 58 deletions(-) create mode 100644 go/tools/asthelpergen/integration/ast_path_test.go diff --git a/go/tools/asthelpergen/ast_path_gen.go b/go/tools/asthelpergen/ast_path_gen.go index d6784bf0c9f..8b981b46f81 100644 --- a/go/tools/asthelpergen/ast_path_gen.go +++ b/go/tools/asthelpergen/ast_path_gen.go @@ -72,7 +72,7 @@ func (p *pathGen) genFile(spi generatorSPI) (string, *jen.File) { p.file.Add(p.debugString()) // Add the generated WalkASTPath method - p.file.Add(p.generateWalkASTPath(spi)) + p.file.Add(p.generateWalkASTPath()) return "ast_path.go", p.file } @@ -91,16 +91,30 @@ func (p *pathGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi gener return nil } +func (p *pathGen) addStep( + container types.Type, // the name of the container type + typ types.Type, // the type of the field + name string, // the name of the field + slice bool, // whether the field is a slice +) { + s := step{ + container: container, + name: name, + typ: typ, + slice: slice, + } + stepName := s.AsEnum() + fmt.Println("Adding step:", stepName) + p.steps = append(p.steps, s) + +} + func (p *pathGen) addStructFields(t types.Type, strct *types.Struct, spi generatorSPI) { for i := 0; i < strct.NumFields(); i++ { field := strct.Field(i) // Check if the field type implements the interface if types.Implements(field.Type(), spi.iface()) { - p.steps = append(p.steps, step{ - container: t, - name: field.Name(), - typ: field.Type(), - }) + p.addStep(t, field.Type(), field.Name(), false) continue } // Check if the field type is a slice @@ -108,20 +122,10 @@ func (p *pathGen) addStructFields(t types.Type, strct *types.Struct, spi generat if isSlice { // Check if the slice type implements the interface if types.Implements(slice, spi.iface()) { - p.steps = append(p.steps, step{ - container: t, - slice: true, - name: field.Name(), - typ: slice, - }) + p.addStep(t, slice, field.Name(), true) } else if types.Implements(slice.Elem(), spi.iface()) { // Check if the element type of the slice implements the interface - p.steps = append(p.steps, step{ - container: t, - slice: true, - name: field.Name(), - typ: slice.Elem(), - }) + p.addStep(t, slice.Elem(), field.Name(), true) } } } @@ -132,16 +136,11 @@ func (p *pathGen) ptrToBasicMethod(t types.Type, basic *types.Basic, spi generat } func (p *pathGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorSPI) error { - elemType := slice.Elem() - if types.Implements(elemType, spi.iface()) { - p.steps = append(p.steps, step{ - container: t, - name: sliceMarker, - slice: true, - typ: elemType, - }) - } - + //elemType := slice.Elem() + //if types.Implements(elemType, spi.iface()) { + // p.addStep(t, elemType, sliceMarker, true) + //} + // return nil } @@ -213,11 +212,13 @@ func (p *pathGen) buildConstWithEnum() *jen.Statement { return constBlock } -func (p *pathGen) generateWalkASTPath(spi generatorSPI) *jen.Statement { +func (p *pathGen) generateWalkASTPath() *jen.Statement { method := jen.Func().Id("WalkASTPath").Params( jen.Id("node").Id(p.ifaceName), jen.Id("path").Id("ASTPath"), ).Id(p.ifaceName).Block( + jen.If(jen.Id("path").Op("==").Id(`""`).Block( + jen.Return(jen.Id("node")))), jen.Id("step").Op(":=").Qual("encoding/binary", "BigEndian").Dot("Uint16").Call(jen.Index().Byte().Parens(jen.Id("path[:2]"))), jen.Id("path").Op("=").Id("path[2:]"), @@ -235,7 +236,7 @@ func (p *pathGen) generateWalkCases() []jen.Code { if !step.slice { // return WalkASTPath(node.(*RefContainer).ASTType, path) - t := types.TypeString(step.typ, noQualifier) + t := types.TypeString(step.container, noQualifier) n := jen.Id("node").Dot(fmt.Sprintf("(%s)", t)).Dot(step.name) cases = append(cases, jen.Case(jen.Id(stepName)).Block( @@ -250,7 +251,7 @@ func (p *pathGen) generateWalkCases() []jen.Code { if step.name == sliceMarker { assignNode = jen.Id("node").Dot(fmt.Sprintf("(%s)", t)).Index(jen.Id("idx")) } else { - assignNode = jen.Id("node").Dot(step.name).Index(jen.Id("idx")) + assignNode = jen.Id("node").Dot(fmt.Sprintf("(%s)", t)).Dot(step.name).Index(jen.Id("idx")) } cases = append(cases, jen.Case(jen.Id(stepName+"8")).Block( @@ -260,8 +261,8 @@ func (p *pathGen) generateWalkCases() []jen.Code { )) cases = append(cases, jen.Case(jen.Id(stepName+"32")).Block( - jen.Id("idx").Op(":=").Qual("encoding/binary", "BigEndian").Dot("Uint16").Call(jen.Index().Byte().Parens(jen.Id("path[:2]"))), - jen.Id("path").Op("=").Id("path[2:]"), + jen.Id("idx").Op(":=").Qual("encoding/binary", "BigEndian").Dot("Uint32").Call(jen.Index().Byte().Parens(jen.Id("path[:2]"))), + jen.Id("path").Op("=").Id("path[4:]"), jen.Return(jen.Id("WalkASTPath").Call(assignNode, jen.Id("path"))), )) } diff --git a/go/tools/asthelpergen/integration/ast_path.go b/go/tools/asthelpergen/integration/ast_path.go index 87cb659c430..986b6dda8f4 100644 --- a/go/tools/asthelpergen/integration/ast_path.go +++ b/go/tools/asthelpergen/integration/ast_path.go @@ -17,14 +17,12 @@ limitations under the License. package integration +import "encoding/binary" + type ASTStep uint16 const ( - InterfaceSlice8 ASTStep = iota - InterfaceSlice32 - LeafSlice8 - LeafSlice32 - RefOfRefContainerASTType + RefOfRefContainerASTType ASTStep = iota RefOfRefContainerASTImplementationType RefOfRefSliceContainerASTElements8 RefOfRefSliceContainerASTElements32 @@ -36,10 +34,6 @@ const ( ValueSliceContainerASTElements8 ValueSliceContainerASTElements32 ValueSliceContainerASTImplementationElements - SliceOfAST8 - SliceOfAST32 - SliceOfRefOfLeaf8 - SliceOfRefOfLeaf32 RefOfValueContainerASTType RefOfValueContainerASTImplementationType RefOfValueSliceContainerASTElements8 @@ -49,14 +43,6 @@ const ( func (s ASTStep) DebugString() string { switch s { - case InterfaceSlice8: - return "(InterfaceSlice)[]8" - case InterfaceSlice32: - return "(InterfaceSlice)[]32" - case LeafSlice8: - return "(LeafSlice)[]8" - case LeafSlice32: - return "(LeafSlice)[]32" case RefOfRefContainerASTType: return "(*RefContainer).ASTType" case RefOfRefContainerASTImplementationType: @@ -81,14 +67,6 @@ func (s ASTStep) DebugString() string { return "(ValueSliceContainer).ASTElements32" case ValueSliceContainerASTImplementationElements: return "(ValueSliceContainer).ASTImplementationElements" - case SliceOfAST8: - return "([]AST)[]8" - case SliceOfAST32: - return "([]AST)[]32" - case SliceOfRefOfLeaf8: - return "([]*Leaf)[]8" - case SliceOfRefOfLeaf32: - return "([]*Leaf)[]32" case RefOfValueContainerASTType: return "(*ValueContainer).ASTType" case RefOfValueContainerASTImplementationType: @@ -102,3 +80,63 @@ func (s ASTStep) DebugString() string { } panic("unknown ASTStep") } +func WalkASTPath(node AST, path ASTPath) AST { + if path == "" { + return node + } + step := binary.BigEndian.Uint16([]byte(path[:2])) + path = path[2:] + switch ASTStep(step) { + case RefOfRefContainerASTType: + return WalkASTPath(node.(*RefContainer).ASTType, path) + case RefOfRefContainerASTImplementationType: + return WalkASTPath(node.(*RefContainer).ASTImplementationType, path) + case RefOfRefSliceContainerASTElements8: + idx := path[0] + path = path[1:] + return WalkASTPath(node.(*RefSliceContainer).ASTElements[idx], path) + case RefOfRefSliceContainerASTElements32: + idx := binary.BigEndian.Uint32([]byte(path[:2])) + path = path[4:] + return WalkASTPath(node.(*RefSliceContainer).ASTElements[idx], path) + case RefOfRefSliceContainerASTImplementationElements8: + idx := path[0] + path = path[1:] + return WalkASTPath(node.(*RefSliceContainer).ASTImplementationElements[idx], path) + case RefOfRefSliceContainerASTImplementationElements32: + idx := binary.BigEndian.Uint32([]byte(path[:2])) + path = path[4:] + return WalkASTPath(node.(*RefSliceContainer).ASTImplementationElements[idx], path) + case RefOfSubImplinner: + return WalkASTPath(node.(*SubImpl).inner, path) + case ValueContainerASTType: + return WalkASTPath(node.(ValueContainer).ASTType, path) + case ValueContainerASTImplementationType: + return WalkASTPath(node.(ValueContainer).ASTImplementationType, path) + case ValueSliceContainerASTElements8: + idx := path[0] + path = path[1:] + return WalkASTPath(node.(ValueSliceContainer).ASTElements[idx], path) + case ValueSliceContainerASTElements32: + idx := binary.BigEndian.Uint32([]byte(path[:2])) + path = path[4:] + return WalkASTPath(node.(ValueSliceContainer).ASTElements[idx], path) + case ValueSliceContainerASTImplementationElements: + return WalkASTPath(node.(ValueSliceContainer).ASTImplementationElements, path) + case RefOfValueContainerASTType: + return WalkASTPath(node.(*ValueContainer).ASTType, path) + case RefOfValueContainerASTImplementationType: + return WalkASTPath(node.(*ValueContainer).ASTImplementationType, path) + case RefOfValueSliceContainerASTElements8: + idx := path[0] + path = path[1:] + return WalkASTPath(node.(*ValueSliceContainer).ASTElements[idx], path) + case RefOfValueSliceContainerASTElements32: + idx := binary.BigEndian.Uint32([]byte(path[:2])) + path = path[4:] + return WalkASTPath(node.(*ValueSliceContainer).ASTElements[idx], path) + case RefOfValueSliceContainerASTImplementationElements: + return WalkASTPath(node.(*ValueSliceContainer).ASTImplementationElements, path) + } + return nil +} diff --git a/go/tools/asthelpergen/integration/ast_path_test.go b/go/tools/asthelpergen/integration/ast_path_test.go new file mode 100644 index 00000000000..e32956fac37 --- /dev/null +++ b/go/tools/asthelpergen/integration/ast_path_test.go @@ -0,0 +1,67 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package integration + +import ( + "fmt" + "github.com/stretchr/testify/assert" + "reflect" + "testing" +) + +func TestWalkAllPartsOfAST(t *testing.T) { + sliceContainer := &RefSliceContainer{ + something: 12, + ASTElements: []AST{}, + NotASTElements: []int{1, 2}, + ASTImplementationElements: []*Leaf{{v: 1}, {v: 2}}, + } + + for i := range 300 { + sliceContainer.ASTImplementationElements = append(sliceContainer.ASTImplementationElements, &Leaf{v: i}) + } + + ast := &RefContainer{ + ASTType: sliceContainer, + NotASTType: 2, + ASTImplementationType: &Leaf{v: 3}, + } + + v := make(map[ASTPath]AST) + + RewriteWithPaths(ast, func(c *Cursor) bool { + node := c.Node() + if !reflect.TypeOf(node).Comparable() { + return true + } + current := c.current + v[current] = node + return true + }, nil) + + fmt.Println("walked all parts of AST") + + assert.NotEmpty(t, v) + + for path, n1 := range v { + s := path.DebugString() + fmt.Println(s) + + n2 := WalkASTPath(ast, path) + assert.Equal(t, n1, n2) + } +} diff --git a/go/tools/asthelpergen/integration/test_helpers.go b/go/tools/asthelpergen/integration/test_helpers.go index 64278f3c357..592cda2e19e 100644 --- a/go/tools/asthelpergen/integration/test_helpers.go +++ b/go/tools/asthelpergen/integration/test_helpers.go @@ -17,6 +17,8 @@ limitations under the License. package integration import ( + "encoding/binary" + "fmt" "strings" ) @@ -100,6 +102,22 @@ func Rewrite(node AST, pre, post ApplyFunc) AST { return outer.AST } +func RewriteWithPaths(node AST, pre, post ApplyFunc) AST { + outer := &struct{ AST }{node} + + a := &application{ + pre: pre, + post: post, + collectPaths: true, + } + + a.rewriteAST(outer, node, func(newNode, parent AST) { + outer.AST = newNode + }) + + return outer.AST +} + type ( cow struct { pre func(node, parent AST) bool @@ -115,3 +133,58 @@ type ( func (c *cow) postVisit(a, b AST, d bool) (AST, bool) { return a, d } + +func (path ASTPath) DebugString() string { + var sb strings.Builder + + remaining := []byte(path) + stepCount := 0 + + for len(remaining) >= 2 { + // Read the step code (2 bytes) + stepVal := binary.BigEndian.Uint16(remaining[:2]) + remaining = remaining[2:] + + step := ASTStep(stepVal) + stepStr := step.DebugString() // e.g. "CaseExprWhens8" or "CaseExprWhens32" + + // If this isn't the very first step in the path, prepend a separator + if stepCount > 0 { + sb.WriteString("->") + } + stepCount++ + + // Write the step name + sb.WriteString(stepStr) + + // Check suffix to see if we need to read an offset + switch { + // 1-byte offset if stepStr ends with "8" + case strings.HasSuffix(stepStr, "8"): + if len(remaining) < 1 { + sb.WriteString("(ERR-no-offset-byte)") + return sb.String() + } + offsetByte := remaining[0] + remaining = remaining[1:] + sb.WriteString(fmt.Sprintf("(%d)", offsetByte)) + + // 4-byte offset if stepStr ends with "32" + case strings.HasSuffix(stepStr, "32"): + if len(remaining) < 4 { + sb.WriteString("(ERR-no-offset-uint32)") + return sb.String() + } + offsetVal := binary.BigEndian.Uint32(remaining[:4]) + remaining = remaining[4:] + sb.WriteString(fmt.Sprintf("(%d)", offsetVal)) + } + } + + // If there's leftover data that doesn't fit into 2 (or more) bytes, you could note it: + if len(remaining) != 0 { + sb.WriteString("->(ERR-unaligned-extra-bytes)") + } + + return sb.String() +}