diff --git a/README.md b/README.md index 7b80771..e1c218b 100644 --- a/README.md +++ b/README.md @@ -6,19 +6,22 @@ --- -`defer` is a golang analyzer that finds defer functions which return anything. +`nakedefer` is a golang analyzer that finds defer functions which return anything. ### Installation ```shell -go get -u github.com/GaijinEntertainment/go-defer/cmd/defer +go get -u github.com/GaijinEntertainment/go-nakedefer/cmd/nakedefer ``` ### Usage ``` -defer ./... +nakedefer [-flag] [package] +Flags: + -e value + Regular expression to exclude function names ``` @@ -30,38 +33,90 @@ func funcNotReturnAnyType() { } func funcReturnErr() error { - return errors.New("some error") + return errors.New("some error") } -// valid -func someFuncWithValidDefer1() { - defer func() { - }() +func funcReturnFuncAndErr() (func(), error) { + return func() { + }, nil } -// valid -func someFuncWithValidDefer2() { - defer funcNotReturnAnyType() +func ignoreFunc() error { + return errors.New("some error") } -// invalid, deferred call should not return any type -func someFuncWithInvalidDefer1() { - defer func() error { +func testCaseValid1() { + defer funcNotReturnAnyType() // valid + + defer func() { // valid + funcNotReturnAnyType() + }() + + defer func() { // valid + _ = funcReturnErr() + }() +} + +func testCaseInvalid1() { + defer funcReturnErr() // invalid + + defer funcReturnFuncAndErr() // invalid + + defer func() error { // invalid return nil }() + + defer func() func() { // invalid + return func() {} + }() } -// invalid, deferred call should not return any type -func someFuncWithInvalidDefer2() { - defer funcReturnErr() +func testCase1() { + defer fmt.Errorf("some text") // invalid + + r := new(bytes.Buffer) + defer io.LimitReader(r, 1) // invalid + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("DONE")) + })) + defer srv.Close() // invalid + defer srv.CloseClientConnections() // invalid + defer srv.Certificate() // invalid } -// invalid, deferred call should not return any type -func someFuncWithInvalidDefer3() { - defer func() func() { - return func() { +func testCase2() { + s := datatest.SomeStruct{} + defer s.RetNothing() // valid + defer s.RetErr() // invalid + defer s.RetInAndErr() // invalid +} + +func testCaseExclude1() { + // exclude ignoreFunc + defer ignoreFunc() // valid - excluded +} + +func testCaseExclude2() { + // exclude os\.(Create|WriteFile|Chmod) + defer os.Create("file_test1") // valid - excluded + defer os.WriteFile("file_test2", []byte("data"), os.ModeAppend) // valid - excluded + defer os.Chmod("file_test3", os.ModeAppend) // valid - excluded + defer os.FindProcess(100500) // invalid +} + +func testCaseExclude3() { + // exclude fmt\.Print.* + defer fmt.Println("e1") // valid - excluded + defer fmt.Print("e1") // valid - excluded + defer fmt.Printf("e1") // valid - excluded + defer fmt.Sprintf("some text") // invalid +} - } - }() +func testCaseExclude4() { + // exclude io\.Close + rc, _ := zlib.NewReader(bytes.NewReader([]byte("111"))) + defer rc.Close() // valid - excluded } ``` \ No newline at end of file diff --git a/cmd/defer/main.go b/cmd/nakedefer/main.go similarity index 58% rename from cmd/defer/main.go rename to cmd/nakedefer/main.go index ed91ce3..1036211 100644 --- a/cmd/defer/main.go +++ b/cmd/nakedefer/main.go @@ -3,11 +3,11 @@ package main import ( "golang.org/x/tools/go/analysis/singlechecker" - "github.com/GaijinEntertainment/go-defer/pkg/analyzer" + "github.com/GaijinEntertainment/go-nakedefer/pkg/analyzer" ) func main() { - a, err := analyzer.NewAnalyzer() + a, err := analyzer.NewAnalyzer([]string{}) if err != nil { panic(err) } diff --git a/go.mod b/go.mod index 019b6ae..9472fcb 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/GaijinEntertainment/go-defer +module github.com/GaijinEntertainment/go-nakedefer go 1.19 diff --git a/pkg/analyzer/analyzer.go b/pkg/analyzer/analyzer.go index 422897a..a5ca516 100644 --- a/pkg/analyzer/analyzer.go +++ b/pkg/analyzer/analyzer.go @@ -1,36 +1,76 @@ package analyzer import ( + "bytes" + "errors" + "flag" "go/ast" + "go/printer" + "go/token" + "go/types" + "strings" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/passes/inspect" "golang.org/x/tools/go/ast/inspector" ) +var ( + ErrEmptyExcludePattern = errors.New("pattern for excluding function can't be empty") +) + +type analyzer struct { + typesInfo *types.Info + exclude PatternsList +} + // NewAnalyzer returns a go/analysis-compatible analyzer. -func NewAnalyzer() (*analysis.Analyzer, error) { +func NewAnalyzer(exclude []string) (*analysis.Analyzer, error) { + a := analyzer{} //nolint:exhaustruct + + var err error + + a.exclude, err = newPatternsList(exclude) + if err != nil { + return nil, err + } + return &analysis.Analyzer{ //nolint:exhaustruct - Name: "defer", + Name: "nakedefer", Doc: "Checks that deferred call does not return anything.", - Run: run, + Run: a.run, Requires: []*analysis.Analyzer{inspect.Analyzer}, + Flags: a.newFlagSet(), }, nil } -func run(pass *analysis.Pass) (interface{}, error) { +func (a *analyzer) newFlagSet() flag.FlagSet { + fs := flag.NewFlagSet("nakedefer flags", flag.PanicOnError) + + fs.Var( + &reListVar{values: &a.exclude}, + "e", + "Regular expression to exclude function names", + ) + + return *fs +} + +func (a *analyzer) run(pass *analysis.Pass) (interface{}, error) { insp := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) //nolint:forcetypeassert nodeFilter := []ast.Node{ (*ast.DeferStmt)(nil), } - insp.Preorder(nodeFilter, newVisitor(pass)) + a.typesInfo = pass.TypesInfo + + insp.Preorder(nodeFilter, a.newVisitor(pass)) return nil, nil //nolint:nilnil } -func newVisitor(pass *analysis.Pass) func(node ast.Node) { +func (a *analyzer) newVisitor(pass *analysis.Pass) func(node ast.Node) { return func(node ast.Node) { deferStmt, ok := node.(*ast.DeferStmt) if !ok { @@ -41,22 +81,24 @@ func newVisitor(pass *analysis.Pass) func(node ast.Node) { return } - var outgoingFieldList *ast.FieldList + funcName := a.funcName(deferStmt.Call) + if funcName != "" && a.exclude.MatchesAny(funcName) { + return + } + var hasReturn bool switch v := deferStmt.Call.Fun.(type) { - case *ast.Ident: // function is named - outgoingFieldList = getFuncDeclResults(v) case *ast.FuncLit: // function is anonymous - outgoingFieldList = getFuncLitResults(v) + hasReturn = a.isFuncLitReturnValues(v) + case *ast.Ident: + hasReturn = a.isIdentReturnValues(v) + case *ast.SelectorExpr: + hasReturn = a.isSelExprReturnValues(v) default: return } - if outgoingFieldList == nil || outgoingFieldList.List == nil { - return - } - - if len(outgoingFieldList.List) == 0 { + if !hasReturn { return } @@ -64,23 +106,91 @@ func newVisitor(pass *analysis.Pass) func(node ast.Node) { } } -func getFuncDeclResults(ident *ast.Ident) *ast.FieldList { - if ident.Obj == nil { - return nil +func (a *analyzer) isIdentReturnValues(ident *ast.Ident) bool { + if ident == nil || ident.Obj == nil { + return false } funcDecl, ok := ident.Obj.Decl.(*ast.FuncDecl) if !ok { - return nil + return false + } + + if funcDecl.Type == nil || funcDecl.Type.Results == nil { + return false + } + + if len(funcDecl.Type.Results.List) == 0 { + return false + } + + return true +} + +func (a *analyzer) isFuncLitReturnValues(funcLit *ast.FuncLit) bool { + if funcLit == nil || funcLit.Type == nil { + return false + } + + if funcLit.Type == nil || funcLit.Type.Results == nil { + return false + } + + if len(funcLit.Type.Results.List) == 0 { + return false + } + + return true +} + +func (a *analyzer) isSelExprReturnValues(selExpr *ast.SelectorExpr) bool { + if selExpr == nil { + return false + } + + t, ok := a.typesInfo.Types[selExpr].Type.(*types.Signature) + if !ok { + return false + } + + if t.Results() == nil || t.Results().Len() == 0 { + return false + } + + return true +} + +func (a *analyzer) funcName(call *ast.CallExpr) string { + fn, ok := a.getFunc(call) + if !ok { + return gofmt(call.Fun) } - return funcDecl.Type.Results + name := fn.FullName() + name = strings.ReplaceAll(name, "(", "") + name = strings.ReplaceAll(name, ")", "") + + return name } -func getFuncLitResults(funcLit *ast.FuncLit) *ast.FieldList { - if funcLit.Type == nil { - return nil +func (a *analyzer) getFunc(call *ast.CallExpr) (*types.Func, bool) { + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return nil, false } - return funcLit.Type.Results + fn, ok := a.typesInfo.ObjectOf(sel.Sel).(*types.Func) + if !ok { + return nil, false + } + + return fn, true +} + +func gofmt(x interface{}) string { + buf := bytes.Buffer{} + fs := token.NewFileSet() + printer.Fprint(&buf, fs, x) + + return buf.String() } diff --git a/pkg/analyzer/analyzer_test.go b/pkg/analyzer/analyzer_test.go index 9df43fa..29f27fc 100644 --- a/pkg/analyzer/analyzer_test.go +++ b/pkg/analyzer/analyzer_test.go @@ -7,7 +7,7 @@ import ( "golang.org/x/tools/go/analysis/analysistest" - "github.com/GaijinEntertainment/go-defer/pkg/analyzer" + "github.com/GaijinEntertainment/go-nakedefer/pkg/analyzer" ) func TestAll(t *testing.T) { @@ -20,7 +20,9 @@ func TestAll(t *testing.T) { testdata := filepath.Join(filepath.Dir(filepath.Dir(wd)), "testdata") - a, err := analyzer.NewAnalyzer() + a, err := analyzer.NewAnalyzer( + []string{"ignoreFunc", "os\\.(Create|WriteFile|Chmod)", "fmt\\.Print.*", "io\\.Close"}, + ) if err != nil { t.Error(err) } diff --git a/pkg/analyzer/patterns-list.go b/pkg/analyzer/patterns-list.go new file mode 100644 index 0000000..9de3afb --- /dev/null +++ b/pkg/analyzer/patterns-list.go @@ -0,0 +1,68 @@ +package analyzer + +import ( + "fmt" + "regexp" +) + +type PatternsList []*regexp.Regexp + +// MatchesAny matches provided string against all regexps in a slice. +func (l PatternsList) MatchesAny(str string) bool { + for _, r := range l { + if r.MatchString(str) { + return true + } + } + + return false +} + +// newPatternsList parses slice of strings to a slice of compiled regular +// expressions. +func newPatternsList(in []string) (PatternsList, error) { + list := PatternsList{} + + for _, str := range in { + re, err := strToRegexp(str) + if err != nil { + return nil, err + } + + list = append(list, re) + } + + return list, nil +} + +type reListVar struct { + values *PatternsList +} + +func (v *reListVar) Set(value string) error { + re, err := strToRegexp(value) + if err != nil { + return err + } + + *v.values = append(*v.values, re) + + return nil +} + +func (v *reListVar) String() string { + return "" +} + +func strToRegexp(str string) (*regexp.Regexp, error) { + if str == "" { + return nil, ErrEmptyExcludePattern + } + + re, err := regexp.Compile(str) + if err != nil { + return nil, fmt.Errorf("unable to compile %s as regular expression: %w", str, err) + } + + return re, nil +} diff --git a/testdata/src/datatest/datatest.go b/testdata/src/datatest/datatest.go new file mode 100644 index 0000000..a82c7b1 --- /dev/null +++ b/testdata/src/datatest/datatest.go @@ -0,0 +1,15 @@ +package datatest + +type SomeStruct struct { +} + +func (s SomeStruct) RetErr() error { + return nil +} + +func (s SomeStruct) RetInAndErr() (int, error) { + return 100, nil +} + +func (s SomeStruct) RetNothing() { +} diff --git a/testdata/src/p/p.go b/testdata/src/p/p.go index b6eacba..eb7dbdf 100644 --- a/testdata/src/p/p.go +++ b/testdata/src/p/p.go @@ -1,7 +1,16 @@ package p import ( + "bytes" + "compress/zlib" "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + + "datatest" ) func funcNotReturnAnyType() { @@ -16,32 +25,81 @@ func funcReturnFuncAndErr() (func(), error) { }, nil } -func funcDeferNotReturnAnyType1() { - defer funcNotReturnAnyType() +func ignoreFunc() error { + return errors.New("some error") } -func funcDeferNotReturnAnyType2() { - defer func() { +func testCaseValid1() { + defer funcNotReturnAnyType() // ignore + + defer func() { //ignore + funcNotReturnAnyType() + }() + + defer func() { //ignore _ = funcReturnErr() }() } -func funcDeferReturnErr() { +func testCaseInvalid1() { defer funcReturnErr() // want "deferred call should not return anything" -} -func funcDeferReturnErrAndFunc() { defer funcReturnFuncAndErr() // want "deferred call should not return anything" -} -func funcDeferAnonymousReturnFunc() { + defer func() error { // want "deferred call should not return anything" + return nil + }() + defer func() func() { // want "deferred call should not return anything" return func() {} }() } -func funcDeferAnonymousReturnIntAndErr() { - defer func() (int, error) { // want "deferred call should not return anything" - return 1, nil - }() +func testCase1() { + defer fmt.Errorf("some text") // want "deferred call should not return anything" + + r := new(bytes.Buffer) + defer io.LimitReader(r, 1) // want "deferred call should not return anything" + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("DONE")) + })) + defer srv.Close() //ignore + defer srv.CloseClientConnections() //ignore + defer srv.Certificate() // want "deferred call should not return anything" +} + +func testCase2() { + s := datatest.SomeStruct{} + defer s.RetNothing() // ignore + defer s.RetErr() // want "deferred call should not return anything" + defer s.RetInAndErr() // want "deferred call should not return anything" +} + +func testCaseExclude1() { + // exclude ignoreFunc + defer ignoreFunc() // ignore +} + +func testCaseExclude2() { + // exclude os\.(Create|WriteFile|Chmod) + defer os.Create("file_test1") // ignore + defer os.WriteFile("file_test2", []byte("data"), os.ModeAppend) // ignore + defer os.Chmod("file_test3", os.ModeAppend) // ignore + defer os.FindProcess(100500) // want "deferred call should not return anything" +} + +func testCaseExclude3() { + // exclude fmt\.Print.* + defer fmt.Println("e1") // ignore + defer fmt.Print("e1") // ignore + defer fmt.Printf("e1") // ignore + defer fmt.Sprintf("some text") // want "deferred call should not return anything" +} + +func testCaseExclude4() { + // exclude io\.Close + rc, _ := zlib.NewReader(bytes.NewReader([]byte("111"))) + defer rc.Close() }