Skip to content

Commit

Permalink
Clean template functions
Browse files Browse the repository at this point in the history
  • Loading branch information
LandonTClipp committed Jan 2, 2025
1 parent 09bf81f commit f798319
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 118 deletions.
10 changes: 0 additions & 10 deletions cmd/mockery.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,10 @@ func NewRootCmd() (*cobra.Command, error) {

pFlags := cmd.PersistentFlags()
pFlags.StringVar(&cfgFile, "config", "", "config file to use")
pFlags.String("dir", "", "directory to search for interfaces")
pFlags.BoolP("recursive", "r", false, "recurse search into sub-directories")
pFlags.StringArray("exclude", nil, "prefixes of subdirectories and files to exclude from search")
pFlags.Bool("all", false, "generates mocks for all found interfaces in all sub-directories")
pFlags.String("note", "", "comment to insert into prologue of each generated file")
pFlags.String("cpuprofile", "", "write cpu profile to file")
pFlags.Bool("version", false, "prints the installed version of mockery")
pFlags.String("tags", "", "space-separated list of additional build tags to load packages")
pFlags.String("mock-build-tags", "", "set the build tags of the generated mocks. Read more about the format: https://pkg.go.dev/cmd/go#hdr-Build_constraints")
pFlags.String("filename", "", "name of generated file (only works with -name and no regex)")
pFlags.String("structname", "", "name of generated struct (only works with -name and no regex)")
pFlags.String("log-level", "info", "Level of logging")
pFlags.String("srcpkg", "", "source pkg to search for interfaces")
pFlags.BoolP("dry-run", "d", false, "Do a dry run, don't modify any files")
pFlags.String("boilerplate-file", "", "File to read a boilerplate text from. Text should be a go block comment, i.e. /* ... */")
pFlags.Bool("unroll-variadic", true, "For functions with variadic arguments, do not unroll the arguments into the underlying testify call. Instead, pass variadic slice as-is.")

Expand Down
2 changes: 1 addition & 1 deletion mockery-tools.env
Original file line number Diff line number Diff line change
@@ -1 +1 @@
VERSION=v3.0.0-alpha.7
VERSION=v3.0.0-alpha.8
49 changes: 2 additions & 47 deletions pkg/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ import (
"strings"

"github.com/chigopher/pathlib"
"github.com/huandu/xstrings"
"github.com/jinzhu/copier"
"github.com/mitchellh/mapstructure"
"github.com/rs/zerolog"
"github.com/spf13/viper"
"github.com/vektra/mockery/v3/pkg/logging"
"github.com/vektra/mockery/v3/pkg/stackerr"
mockeryTemplate "github.com/vektra/mockery/v3/pkg/template"
"golang.org/x/tools/go/packages"
"gopkg.in/yaml.v3"
)
Expand Down Expand Up @@ -66,7 +66,6 @@ func NewConfigFromViper(v *viper.Viper) (*Config, error) {
v.SetDefault("formatter", "goimports")
v.SetDefault("mockname", "Mock{{.InterfaceName}}")
v.SetDefault("pkgname", "{{.SrcPackageName}}")
v.SetDefault("dry-run", false)
v.SetDefault("log-level", "info")

if err := v.UnmarshalExact(c); err != nil {
Expand Down Expand Up @@ -757,50 +756,6 @@ func (c *Config) TagName(name string) string {

var ErrInfiniteLoop = fmt.Errorf("infinite loop in template variables detected")

// Functions available in the template for manipulating
//
// Since the map and its functions are stateless, it exists as
// a package var rather than being initialized on every call
// in [parseConfigTemplates] and [generator.printTemplate]
var templateFuncMap = template.FuncMap{
// String inspection and manipulation
"contains": strings.Contains,
"hasPrefix": strings.HasPrefix,
"hasSuffix": strings.HasSuffix,
"join": strings.Join,
"replace": strings.Replace,
"replaceAll": strings.ReplaceAll,
"split": strings.Split,
"splitAfter": strings.SplitAfter,
"splitAfterN": strings.SplitAfterN,
"trim": strings.Trim,
"trimLeft": strings.TrimLeft,
"trimPrefix": strings.TrimPrefix,
"trimRight": strings.TrimRight,
"trimSpace": strings.TrimSpace,
"trimSuffix": strings.TrimSuffix,
"lower": strings.ToLower,
"upper": strings.ToUpper,
"camelcase": xstrings.ToCamelCase,
"snakecase": xstrings.ToSnakeCase,
"kebabcase": xstrings.ToKebabCase,
"firstLower": xstrings.FirstRuneToLower,
"firstUpper": xstrings.FirstRuneToUpper,

// Regular expression matching
"matchString": regexp.MatchString,
"quoteMeta": regexp.QuoteMeta,

// Filepath manipulation
"base": filepath.Base,
"clean": filepath.Clean,
"dir": filepath.Dir,

// Basic access to reading environment variables
"expandEnv": os.ExpandEnv,
"getenv": os.Getenv,
}

// ParseTemplates parses various templated strings
// in the config struct into their fully defined values. This mutates
// the config object passed. An *Interface object can be supplied to satisfy
Expand Down Expand Up @@ -891,7 +846,7 @@ func (c *Config) ParseTemplates(ctx context.Context, iface *Interface, srcPkg *p
for name, attributePointer := range templateMap {
oldVal := *attributePointer

attributeTempl, err := template.New("interface-template").Funcs(templateFuncMap).Parse(*attributePointer)
attributeTempl, err := template.New("config-template").Funcs(mockeryTemplate.StringManipulationFuncs).Parse(*attributePointer)
if err != nil {
return fmt.Errorf("failed to parse %s template: %w", name, err)
}
Expand Down
20 changes: 10 additions & 10 deletions pkg/mockery.templ
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ package {{.PkgName}}

import (
{{- range .Imports}}
{{. | ImportStatement}}
{{. | importStatement}}
{{- end}}
mock "github.com/stretchr/testify/mock"
)
Expand Down Expand Up @@ -71,19 +71,19 @@ func (_mock *{{$mock.MockName}}{{ $mock.TypeInstantiation }}) {{$method.Name}}({
{{- $calledString = "" }}
{{- end }}

{{- $lastParam := index $method.Params (len $method.Params | Add -1 )}}
{{- $lastParam := index $method.Params (len $method.Params | add -1 )}}
if len({{ $lastParam.Var.Name }}) > 0 {
{{- if ne (len $method.Returns) 0}}tmpRet = {{ end }}_mock.Called({{- if (index $mock.TemplateData "unroll-variadic") }}{{ $method.ArgCallList }}{{- else }}{{ $method.ArgCallListNoEllipsis }}{{- end }})
} else {
{{- if ne (len $method.Returns) 0}}tmpRet = {{ end }}_mock.Called({{- if (index $mock.TemplateData "unroll-variadic") }}{{ $method.ArgCallListSlice 0 (len $method.Params | Add -1 )}}{{- else }}{{ $method.ArgCallListSliceNoEllipsis 0 (len $method.Params | Add -1 )}}{{- end }})
{{- if ne (len $method.Returns) 0}}tmpRet = {{ end }}_mock.Called({{- if (index $mock.TemplateData "unroll-variadic") }}{{ $method.ArgCallListSlice 0 (len $method.Params | add -1 )}}{{- else }}{{ $method.ArgCallListSliceNoEllipsis 0 (len $method.Params | add -1 )}}{{- end }})
}
{{- else }}
{{- $calledString = printf "_mock.Called(%s)" $method.ArgCallList }}
{{- end }}
{{- else }}
{{- $lastParam := (index $method.Params (len $method.Params | Add -1)) }}
{{- $lastParam := (index $method.Params (len $method.Params | add -1)) }}
{{- $variadicArgsName := $lastParam.Var.Name }}
{{- $strippedTypeString := TrimPrefix "..." $lastParam.TypeStringEllipsis }}
{{- $strippedTypeString := trimPrefix "..." $lastParam.TypeStringEllipsis }}

{{- if and (ne $strippedTypeString "interface{}") (ne $strippedTypeString "any") }}
// {{ $strippedTypeString }}
Expand All @@ -95,7 +95,7 @@ func (_mock *{{$mock.MockName}}{{ $mock.TypeInstantiation }}) {{$method.Name}}({
{{- end }}
var _ca []interface{}
{{- if gt (len $method.Params) 1 }}
_ca = append(_ca, {{ $method.ArgCallListSlice 0 (len $method.Params | Add -1) }})
_ca = append(_ca, {{ $method.ArgCallListSlice 0 (len $method.Params | add -1) }})
{{- end }}
_ca = append(_ca, {{ $variadicArgsName }}...)
{{- $calledString = "_mock.Called(_ca...)" }}
Expand Down Expand Up @@ -136,7 +136,7 @@ func (_mock *{{$mock.MockName}}{{ $mock.TypeInstantiation }}) {{$method.Name}}({
}
{{- end }} {{/* END RETURN RANGE */}}
{{- end }}
return {{ range $retIdx, $ret := $method.Returns }}r{{ $retIdx }}{{ if ne $retIdx (len $method.Returns | Add -1) }}, {{ end }}{{ end }}
return {{ range $retIdx, $ret := $method.Returns }}r{{ $retIdx }}{{ if ne $retIdx (len $method.Returns | add -1) }}, {{ end }}{{ end }}
}

{{/* CREATE EXPECTER METHOD */}}
Expand All @@ -161,7 +161,7 @@ func (_e *{{ $expecterNameInstantiated }}) {{ $method.Name }}({{ range $method.P
{{- else }}
append([]interface{}{
{{- range $i, $param := $method.Params }}
{{- if (lt $i (len $method.Params | Add -1 ))}} {{ $param.Var.Name }},
{{- if (lt $i (len $method.Params | add -1 ))}} {{ $param.Var.Name }},
{{- else }} }, {{ $param.Var.Name }}...
{{- end }}
{{- end}} )...
Expand All @@ -173,8 +173,8 @@ func (_c *{{ $ExpecterCallNameInstantiated }}) Run(run func({{ $method.ArgList }
{{- if not $method.IsVariadic }}
run({{range $i, $param := $method.Params }}args[{{$i}}].({{ $param.TypeString}}),{{end}})
{{- else}}
{{- $variadicParam := index $method.Params (len $method.Params | Add -1) }}
{{- $nonVariadicParams := slice $method.Params 0 (len $method.Params | Add -1 )}}
{{- $variadicParam := index $method.Params (len $method.Params | add -1) }}
{{- $nonVariadicParams := slice $method.Params 0 (len $method.Params | add -1 )}}
variadicArgs := make([]{{ $variadicParam.TypeStringVariadicUnderlying }}, len(args) - {{len $nonVariadicParams}})
for i, a := range args[{{len $nonVariadicParams}}:] {
if a != nil {
Expand Down
20 changes: 11 additions & 9 deletions pkg/moq.templ
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ package {{.PkgName}}

import (
{{- range .Imports}}
{{. | ImportStatement}}
{{. | importStatement}}
{{- end}}
{{- if .Mocks | mocksSomeMethod }}
"sync"
{{- end }}
"fmt"
)

Expand Down Expand Up @@ -54,7 +56,7 @@ var _ {{$.SrcPkgQualifier}}{{.InterfaceName -}}
type {{.MockName}}
{{- if .TypeParams -}}
[{{- range $index, $param := .TypeParams}}
{{- if $index}}, {{end}}{{$param.Name | Exported}} {{$param.TypeString}}
{{- if $index}}, {{end}}{{$param.Name | exported}} {{$param.TypeString}}
{{- end -}}]
{{- end }} struct {
{{- range .Methods}}
Expand All @@ -67,14 +69,14 @@ type {{.MockName}}
// {{.Name}} holds details about calls to the {{.Name}} method.
{{.Name}} []struct {
{{- range .Params}}
// {{.Name | Exported}} is the {{.Name}} argument value.
{{.Name | Exported}} {{.TypeString}}
// {{.Name | exported}} is the {{.Name}} argument value.
{{.Name | exported}} {{.TypeString}}
{{- end}}
}
{{- end}}
}
{{- range .Methods}}
lock{{.Name}} {{$.Imports | SyncPkgQualifier}}.RWMutex
lock{{.Name}} {{$.Imports | syncPkgQualifier}}.RWMutex
{{- end}}
}
{{range .Methods}}
Expand All @@ -87,11 +89,11 @@ func (mock *{{$mock.MockName}}{{ $mock.TypeInstantiation }}) {{.Name}}({{.ArgLis
{{- end}}
callInfo := struct {
{{- range .Params}}
{{.Name | Exported}} {{.TypeString}}
{{.Name | exported}} {{.TypeString}}
{{- end}}
}{
{{- range .Params}}
{{.Name | Exported}}: {{.Name}},
{{.Name | exported}}: {{.Name}},
{{- end}}
}
mock.lock{{.Name}}.Lock()
Expand Down Expand Up @@ -125,12 +127,12 @@ func (mock *{{$mock.MockName}}{{ $mock.TypeInstantiation }}) {{.Name}}({{.ArgLis
// len(mocked{{$mock.InterfaceName}}.{{.Name}}Calls())
func (mock *{{$mock.MockName}}{{ $mock.TypeInstantiation }}) {{.Name}}Calls() []struct {
{{- range .Params}}
{{.Name | Exported}} {{.TypeString}}
{{.Name | exported}} {{.TypeString}}
{{- end}}
} {
var calls []struct {
{{- range .Params}}
{{.Name | Exported}} {{.TypeString}}
{{.Name | exported}} {{.TypeString}}
{{- end}}
}
mock.lock{{.Name}}.RLock()
Expand Down
85 changes: 48 additions & 37 deletions pkg/template/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,15 @@ type Template struct {

// New returns a new instance of Template.
func New(templateString string, name string) (Template, error) {
tmpl, err := template.New(name).Funcs(templateFuncs).Parse(templateString)
mergedFuncMap := template.FuncMap{}
for key, val := range StringManipulationFuncs {
mergedFuncMap[key] = val
}
for key, val := range TemplateMockFuncs {
mergedFuncMap[key] = val
}

tmpl, err := template.New(name).Funcs(mergedFuncMap).Parse(templateString)
if err != nil {
return Template{}, err
}
Expand Down Expand Up @@ -54,14 +62,14 @@ func exported(s string) string {
return strings.ToUpper(s[0:1]) + s[1:]
}

var templateFuncs = template.FuncMap{
"ImportStatement": func(imprt *registry.Package) string {
var TemplateMockFuncs = template.FuncMap{
"importStatement": func(imprt *registry.Package) string {
if imprt.Alias == "" {
return `"` + imprt.Path() + `"`
}
return imprt.Alias + ` "` + imprt.Path() + `"`
},
"SyncPkgQualifier": func(imports []*registry.Package) string {
"syncPkgQualifier": func(imports []*registry.Package) string {
for _, imprt := range imports {
if imprt.Path() == "sync" {
return imprt.Qualifier()
Expand All @@ -70,9 +78,9 @@ var templateFuncs = template.FuncMap{

return "sync"
},
"Exported": exported,
"exported": exported,

"MocksSomeMethod": func(mocks []MockData) bool {
"mocksSomeMethod": func(mocks []MockData) bool {
for _, m := range mocks {
if len(m.Methods) > 0 {
return true
Expand All @@ -81,7 +89,7 @@ var templateFuncs = template.FuncMap{

return false
},
"TypeConstraintTest": func(m MockData) string {
"typeConstraintTest": func(m MockData) string {
if len(m.TypeParams) == 0 {
return ""
}
Expand All @@ -97,45 +105,48 @@ var templateFuncs = template.FuncMap{
s += "]"
return s
},
}

var StringManipulationFuncs = template.FuncMap{
// String inspection and manipulation. Note that the first argument is replaced
// as the last argument in some functions in order to support chained
// template pipelines.
"Contains": func(substr string, s string) bool { return strings.Contains(s, substr) },
"HasPrefix": func(prefix string, s string) bool { return strings.HasPrefix(s, prefix) },
"HasSuffix": func(suffix string, s string) bool { return strings.HasSuffix(s, suffix) },
"Join": func(sep string, elems []string) string { return strings.Join(elems, sep) },
"Replace": func(old string, new string, n int, s string) string { return strings.Replace(s, old, new, n) },
"ReplaceAll": func(old string, new string, s string) string { return strings.ReplaceAll(s, old, new) },
"Split": func(sep string, s string) []string { return strings.Split(s, sep) },
"SplitAfter": func(sep string, s string) []string { return strings.SplitAfter(s, sep) },
"SplitAfterN": func(sep string, n int, s string) []string { return strings.SplitAfterN(s, sep, n) },
"Trim": func(cutset string, s string) string { return strings.Trim(s, cutset) },
"TrimLeft": func(cutset string, s string) string { return strings.TrimLeft(s, cutset) },
"TrimPrefix": func(prefix string, s string) string { return strings.TrimPrefix(s, prefix) },
"TrimRight": func(cutset string, s string) string { return strings.TrimRight(s, cutset) },
"TrimSpace": strings.TrimSpace,
"TrimSuffix": func(suffix string, s string) string { return strings.TrimSuffix(s, suffix) },
"Lower": strings.ToLower,
"Upper": strings.ToUpper,
"Camelcase": xstrings.ToCamelCase,
"Snakecase": xstrings.ToSnakeCase,
"Kebabcase": xstrings.ToKebabCase,
"FirstLower": xstrings.FirstRuneToLower,
"FirstUpper": xstrings.FirstRuneToUpper,
"contains": func(substr string, s string) bool { return strings.Contains(s, substr) },
"hasPrefix": func(prefix string, s string) bool { return strings.HasPrefix(s, prefix) },
"hasSuffix": func(suffix string, s string) bool { return strings.HasSuffix(s, suffix) },
"join": func(sep string, elems []string) string { return strings.Join(elems, sep) },
"replace": func(old string, new string, n int, s string) string { return strings.Replace(s, old, new, n) },
"replaceAll": func(old string, new string, s string) string { return strings.ReplaceAll(s, old, new) },
"split": func(sep string, s string) []string { return strings.Split(s, sep) },
"splitAfter": func(sep string, s string) []string { return strings.SplitAfter(s, sep) },
"splitAfterN": func(sep string, n int, s string) []string { return strings.SplitAfterN(s, sep, n) },
"trim": func(cutset string, s string) string { return strings.Trim(s, cutset) },
"trimLeft": func(cutset string, s string) string { return strings.TrimLeft(s, cutset) },
"trimPrefix": func(prefix string, s string) string { return strings.TrimPrefix(s, prefix) },
"trimRight": func(cutset string, s string) string { return strings.TrimRight(s, cutset) },
"trimSpace": strings.TrimSpace,
"trimSuffix": func(suffix string, s string) string { return strings.TrimSuffix(s, suffix) },
"lower": strings.ToLower,
"upper": strings.ToUpper,
"camelcase": xstrings.ToCamelCase,
"snakecase": xstrings.ToSnakeCase,
"kebabcase": xstrings.ToKebabCase,
"firstLower": xstrings.FirstRuneToLower,
"firstUpper": xstrings.FirstRuneToUpper,

// Regular expression matching
"MatchString": regexp.MatchString,
"QuoteMeta": regexp.QuoteMeta,
"matchString": regexp.MatchString,
"quoteMeta": regexp.QuoteMeta,

// Filepath manipulation
"Base": filepath.Base,
"Clean": filepath.Clean,
"Dir": filepath.Dir,
"base": filepath.Base,
"clean": filepath.Clean,
"dir": filepath.Dir,

// Basic access to reading environment variables
"ExpandEnv": os.ExpandEnv,
"Getenv": os.Getenv,
"expandEnv": os.ExpandEnv,
"getenv": os.Getenv,

// Arithmetic
"Add": func(i1, i2 int) int { return i1 + i2 },
"add": func(i1, i2 int) int { return i1 + i2 },
}
Loading

0 comments on commit f798319

Please sign in to comment.