diff --git a/pkg/fixtures/func_args_collision.go b/pkg/fixtures/func_args_collision.go new file mode 100644 index 000000000..408551969 --- /dev/null +++ b/pkg/fixtures/func_args_collision.go @@ -0,0 +1,5 @@ +package test + +type FuncArgsCollision interface { + Foo(ret interface{}) error +} diff --git a/pkg/generator.go b/pkg/generator.go index dde2a64c6..33772ed9b 100644 --- a/pkg/generator.go +++ b/pkg/generator.go @@ -498,7 +498,8 @@ func (g *Generator) Generate(ctx context.Context) error { g.printf("(%s) {\n", strings.Join(returns.Types, ", ")) } - var formattedParamNames string + formattedParamNames := "" + setOfParamNames := make(map[string]struct{}, len(params.Names)) for i, name := range params.Names { if i > 0 { formattedParamNames += ", " @@ -510,35 +511,36 @@ func (g *Generator) Generate(ctx context.Context) error { name += "..." } formattedParamNames += name + + setOfParamNames[name] = struct{}{} } called := g.generateCalled(params, formattedParamNames) // _m.Called invocation string if len(returns.Types) > 0 { - g.printf("\tret := %s\n\n", called) + retVariable := resolveCollision(setOfParamNames, "ret") + g.printf("\t%s := %s\n\n", retVariable, called) - var ( - ret []string - ) + ret := make([]string, len(returns.Types)) for idx, typ := range returns.Types { g.printf("\tvar r%d %s\n", idx, typ) - g.printf("\tif rf, ok := ret.Get(%d).(func(%s) %s); ok {\n", - idx, strings.Join(params.Types, ", "), typ) + g.printf("\tif rf, ok := %s.Get(%d).(func(%s) %s); ok {\n", + retVariable, idx, strings.Join(params.Types, ", "), typ) g.printf("\t\tr%d = rf(%s)\n", idx, formattedParamNames) g.printf("\t} else {\n") if typ == "error" { - g.printf("\t\tr%d = ret.Error(%d)\n", idx, idx) + g.printf("\t\tr%d = %s.Error(%d)\n", idx, retVariable, idx) } else if returns.Nilable[idx] { - g.printf("\t\tif ret.Get(%d) != nil {\n", idx) - g.printf("\t\t\tr%d = ret.Get(%d).(%s)\n", idx, idx, typ) + g.printf("\t\tif %s.Get(%d) != nil {\n", retVariable, idx) + g.printf("\t\t\tr%d = %s.Get(%d).(%s)\n", idx, retVariable, idx, typ) g.printf("\t\t}\n") } else { - g.printf("\t\tr%d = ret.Get(%d).(%s)\n", idx, idx, typ) + g.printf("\t\tr%d = %s.Get(%d).(%s)\n", idx, retVariable, idx, typ) } g.printf("\t}\n\n") - ret = append(ret, fmt.Sprintf("r%d", idx)) + ret[idx] = fmt.Sprintf("r%d", idx) } g.printf("\treturn %s\n", strings.Join(ret, ", ")) @@ -626,3 +628,18 @@ func (g *Generator) Write(w io.Writer) error { w.Write(res) return nil } + +func resolveCollision(names map[string]struct{}, variable string) string { + ret := variable + + for i := len(names); true; i++ { + _, ok := names[ret] + if !ok { + break + } + + ret = fmt.Sprintf("%s_%d", variable, i) + } + + return ret +} diff --git a/pkg/generator_test.go b/pkg/generator_test.go index e9d260862..395a840e4 100644 --- a/pkg/generator_test.go +++ b/pkg/generator_test.go @@ -1050,6 +1050,31 @@ func (_m *MapToInterface) Foo(arg1 ...map[string]interface{}) { } +func (s *GeneratorSuite) TestGeneratorFunctionArgsNamesCollision() { + expected := `// FuncArgsCollision is an autogenerated mock type for the FuncArgsCollision type +type FuncArgsCollision struct { + mock.Mock +} + +// Foo provides a mock function with given fields: ret +func (_m *FuncArgsCollision) Foo(ret interface{}) error { + ret_1 := _m.Called(ret) + + var r0 error + if rf, ok := ret_1.Get(0).(func(interface{}) error); ok { + r0 = rf(ret) + } else { + r0 = ret_1.Error(0) + } + + return r0 +} +` + s.checkGeneration( + filepath.Join(fixturePath, "func_args_collision.go"), "FuncArgsCollision", false, "", expected, + ) +} + func (s *GeneratorSuite) TestGeneratorWithImportSameAsLocalPackage() { expected := `// ImportsSameAsPackage is an autogenerated mock type for the ImportsSameAsPackage type type ImportsSameAsPackage struct {