Skip to content

Commit

Permalink
Merge pull request #353 from valenok-husky/issue_352
Browse files Browse the repository at this point in the history
Issue #352
  • Loading branch information
LandonTClipp authored Dec 28, 2020
2 parents 55c4821 + 89f6713 commit e725139
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 12 deletions.
5 changes: 5 additions & 0 deletions pkg/fixtures/func_args_collision.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package test

type FuncArgsCollision interface {
Foo(ret interface{}) error
}
41 changes: 29 additions & 12 deletions pkg/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,8 @@ func (g *Generator) Generate(ctx context.Context) error {
g.printf("(%s) {\n", strings.Join(returns.Types, ", "))
}

var formattedParamNames string
formattedParamNames := ""
setOfParamNames := make(map[string]struct{}, len(params.Names))
for i, name := range params.Names {
if i > 0 {
formattedParamNames += ", "
Expand All @@ -510,35 +511,36 @@ func (g *Generator) Generate(ctx context.Context) error {
name += "..."
}
formattedParamNames += name

setOfParamNames[name] = struct{}{}
}

called := g.generateCalled(params, formattedParamNames) // _m.Called invocation string

if len(returns.Types) > 0 {
g.printf("\tret := %s\n\n", called)
retVariable := resolveCollision(setOfParamNames, "ret")
g.printf("\t%s := %s\n\n", retVariable, called)

var (
ret []string
)
ret := make([]string, len(returns.Types))

for idx, typ := range returns.Types {
g.printf("\tvar r%d %s\n", idx, typ)
g.printf("\tif rf, ok := ret.Get(%d).(func(%s) %s); ok {\n",
idx, strings.Join(params.Types, ", "), typ)
g.printf("\tif rf, ok := %s.Get(%d).(func(%s) %s); ok {\n",
retVariable, idx, strings.Join(params.Types, ", "), typ)
g.printf("\t\tr%d = rf(%s)\n", idx, formattedParamNames)
g.printf("\t} else {\n")
if typ == "error" {
g.printf("\t\tr%d = ret.Error(%d)\n", idx, idx)
g.printf("\t\tr%d = %s.Error(%d)\n", idx, retVariable, idx)
} else if returns.Nilable[idx] {
g.printf("\t\tif ret.Get(%d) != nil {\n", idx)
g.printf("\t\t\tr%d = ret.Get(%d).(%s)\n", idx, idx, typ)
g.printf("\t\tif %s.Get(%d) != nil {\n", retVariable, idx)
g.printf("\t\t\tr%d = %s.Get(%d).(%s)\n", idx, retVariable, idx, typ)
g.printf("\t\t}\n")
} else {
g.printf("\t\tr%d = ret.Get(%d).(%s)\n", idx, idx, typ)
g.printf("\t\tr%d = %s.Get(%d).(%s)\n", idx, retVariable, idx, typ)
}
g.printf("\t}\n\n")

ret = append(ret, fmt.Sprintf("r%d", idx))
ret[idx] = fmt.Sprintf("r%d", idx)
}

g.printf("\treturn %s\n", strings.Join(ret, ", "))
Expand Down Expand Up @@ -626,3 +628,18 @@ func (g *Generator) Write(w io.Writer) error {
w.Write(res)
return nil
}

func resolveCollision(names map[string]struct{}, variable string) string {
ret := variable

for i := len(names); true; i++ {
_, ok := names[ret]
if !ok {
break
}

ret = fmt.Sprintf("%s_%d", variable, i)
}

return ret
}
25 changes: 25 additions & 0 deletions pkg/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1050,6 +1050,31 @@ func (_m *MapToInterface) Foo(arg1 ...map[string]interface{}) {

}

func (s *GeneratorSuite) TestGeneratorFunctionArgsNamesCollision() {
expected := `// FuncArgsCollision is an autogenerated mock type for the FuncArgsCollision type
type FuncArgsCollision struct {
mock.Mock
}
// Foo provides a mock function with given fields: ret
func (_m *FuncArgsCollision) Foo(ret interface{}) error {
ret_1 := _m.Called(ret)
var r0 error
if rf, ok := ret_1.Get(0).(func(interface{}) error); ok {
r0 = rf(ret)
} else {
r0 = ret_1.Error(0)
}
return r0
}
`
s.checkGeneration(
filepath.Join(fixturePath, "func_args_collision.go"), "FuncArgsCollision", false, "", expected,
)
}

func (s *GeneratorSuite) TestGeneratorWithImportSameAsLocalPackage() {
expected := `// ImportsSameAsPackage is an autogenerated mock type for the ImportsSameAsPackage type
type ImportsSameAsPackage struct {
Expand Down

0 comments on commit e725139

Please sign in to comment.