Skip to content

Commit

Permalink
Merge pull request #326 from vektra/revert-323-load-speedup
Browse files Browse the repository at this point in the history
Revert "Speed up loading of source code"
  • Loading branch information
LandonTClipp authored Aug 13, 2020
2 parents b78c3cf + 9d10898 commit c146958
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 99 deletions.
28 changes: 26 additions & 2 deletions cmd/mockery.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"os"
"path/filepath"
"regexp"
"runtime/pprof"
"strings"
Expand All @@ -19,6 +20,7 @@ import (
"github.com/vektra/mockery/v2/pkg/config"
"github.com/vektra/mockery/v2/pkg/logging"
"golang.org/x/crypto/ssh/terminal"
"golang.org/x/tools/go/packages"
)

var (
Expand Down Expand Up @@ -48,6 +50,7 @@ func printStackTrace(e error) {
fmt.Printf("%+s:%d\n", f, f)
}
}

}

// Execute executes the cobra CLI workflow
Expand Down Expand Up @@ -83,7 +86,7 @@ func init() {
pFlags.String("filename", "", "name of generated file (only works with -name and no regex)")
pFlags.String("structname", "", "name of generated struct (only works with -name and no regex)")
pFlags.String("log-level", "info", "Level of logging")
pFlags.String("srcpkg", "", "source package(s) to search for interfaces, may be a single package name or a package pattern (example: 'github.com/mockery/vektra/...'")
pFlags.String("srcpkg", "", "source pkg to search for interfaces")
pFlags.BoolP("dry-run", "d", false, "Do a dry run, don't modify any files")

viper.BindPFlags(pFlags)
Expand Down Expand Up @@ -133,6 +136,7 @@ func (r *RootApp) Run() error {
var recursive bool
var filter *regexp.Regexp
var err error
var limitOne bool

if r.Quiet {
// if "quiet" flag is set, disable logging
Expand Down Expand Up @@ -167,6 +171,7 @@ func (r *RootApp) Run() error {
}
} else {
filter = regexp.MustCompile(fmt.Sprintf("^%s$", r.Config.Name))
limitOne = true
}
} else if r.Config.All {
recursive = true
Expand Down Expand Up @@ -207,7 +212,25 @@ func (r *RootApp) Run() error {
baseDir := r.Config.Dir

if r.Config.SrcPkg != "" {
baseDir = r.Config.SrcPkg
pkgs, err := packages.Load(&packages.Config{
Mode: packages.NeedFiles,
}, r.Config.SrcPkg)
if err != nil || len(pkgs) == 0 {
log.Fatal().Err(err).Msgf("Failed to load package %s", r.Config.SrcPkg)
}

// NOTE: we only pass one package name (config.SrcPkg) to packages.Load
// it should return one package at most
pkg := pkgs[0]

if pkg.Errors != nil {
log.Fatal().Err(pkg.Errors[0]).Msgf("Failed to load package %s", r.Config.SrcPkg)
}

if len(pkg.GoFiles) == 0 {
log.Fatal().Msgf("No go files in package %s", r.Config.SrcPkg)
}
baseDir = filepath.Dir(pkg.GoFiles[0])
}

visitor := &pkg.GeneratorVisitor{
Expand All @@ -225,6 +248,7 @@ func (r *RootApp) Run() error {
BaseDir: baseDir,
Recursive: recursive,
Filter: filter,
LimitOne: limitOne,
BuildTags: strings.Split(r.Config.BuildTags, " "),
}

Expand Down
9 changes: 3 additions & 6 deletions pkg/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@ import (
"context"
"go/format"
"io/ioutil"
"os"
"path/filepath"
"strings"
"testing"

"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/vektra/mockery/v2/pkg/config"
Expand All @@ -25,9 +23,8 @@ type GeneratorSuite struct {
}

func (s *GeneratorSuite) SetupTest() {
log := zerolog.New(os.Stdout).Level(zerolog.DebugLevel)
s.ctx = log.WithContext(context.Background())
s.parser = NewParser(nil)
s.ctx = context.Background()
}

func (s *GeneratorSuite) getInterfaceFromFile(interfacePath, interfaceName string) *Interface {
Expand Down Expand Up @@ -719,7 +716,7 @@ func (_m *Fooer) Foo(f func(string) string) error {
}
`
s.checkGeneration(
filepath.Join(fixturePath, "argument_is_func_type.go"), "Fooer", false, "", expected,
filepath.Join(fixturePath, "func_type.go"), "Fooer", false, "", expected,
)
}

Expand Down Expand Up @@ -911,7 +908,7 @@ func (_m *MapFunc) Get(m map[string]func(string) string) error {
}
`
s.checkGeneration(
filepath.Join(fixturePath, "argument_is_map_func.go"), "MapFunc", false, "", expected,
filepath.Join(fixturePath, "map_func.go"), "MapFunc", false, "", expected,
)
}

Expand Down
110 changes: 50 additions & 60 deletions pkg/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@ package pkg

import (
"context"
"fmt"
"go/ast"
"go/types"
"io/ioutil"
"os"
"path/filepath"
"sort"
"strings"

"github.com/pkg/errors"
"github.com/rs/zerolog"
"github.com/vektra/mockery/v2/pkg/logging"
"golang.org/x/tools/go/packages"
)

Expand All @@ -26,14 +26,12 @@ type Parser struct {
entries []*parserEntry
entriesByFileName map[string]*parserEntry
parserPackages []*types.Package
conf *packages.Config
conf packages.Config
}

func NewParser(buildTags []string) *Parser {
conf := &packages.Config{
Mode: packages.NeedFiles | packages.NeedImports | packages.NeedName | packages.NeedSyntax | packages.NeedTypes,
Tests: false,
}
var conf packages.Config
conf.Mode = packages.LoadSyntax
if len(buildTags) > 0 {
conf.BuildFlags = []string{"-tags", strings.Join(buildTags, ",")}
}
Expand All @@ -44,75 +42,67 @@ func NewParser(buildTags []string) *Parser {
}
}

func (p *Parser) Parse(ctx context.Context, pattern string) error {
log := zerolog.Ctx(ctx)
func (p *Parser) Parse(ctx context.Context, path string) error {
// To support relative paths to mock targets w/ vendor deps, we need to provide eventual
// calls to build.Context.Import with an absolute path. It needs to be absolute because
// Import will only find the vendor directory if our target path for parsing is under
// a "root" (GOROOT or a GOPATH). Only absolute paths will pass the prefix-based validation.
//
// For example, if our parse target is "./ifaces", Import will check if any "roots" are a
// prefix of "ifaces" and decide to skip the vendor search.
path, err := filepath.Abs(path)
if err != nil {
return err
}

info, err := os.Stat(pattern)
if err != nil && !os.IsNotExist(err) {
dir := filepath.Dir(path)

files, err := ioutil.ReadDir(dir)
if err != nil {
return err
}

var query string
switch {
case os.IsNotExist(err):
// The pattern represents one or more packages and should be passed directly to the package loader.
log.Debug().Msgf("Loading packages corresponding to pattern %q.", pattern)
query = pattern

case !info.IsDir():
// A file should be passed directly to the package loader as a 'file' query.
if filepath.Ext(pattern) != ".go" {
return errors.Errorf("specified file %q cannot be parsed as it is not a source file", pattern)
} else if strings.HasPrefix(info.Name(), ".") || strings.HasPrefix(info.Name(), "_") {
log.Debug().Msgf("Skipping file %q as it is prefixed with either '.' or '_'.", pattern)
return nil
}
pattern, err = filepath.Abs(pattern)
if err != nil {
return err
for _, fi := range files {
log := zerolog.Ctx(ctx).With().
Str(logging.LogKeyDir, dir).
Str(logging.LogKeyFile, fi.Name()).
Logger()
ctx = log.WithContext(ctx)

if filepath.Ext(fi.Name()) != ".go" || strings.HasSuffix(fi.Name(), "_test.go") {
continue
}

log.Debug().Msgf("Loading file %q.", pattern)
query = "file=" + pattern
log.Debug().Msgf("parsing")

fname := fi.Name()
fpath := filepath.Join(dir, fname)
if _, ok := p.entriesByFileName[fpath]; ok {
continue
}

case info.IsDir():
// A directory must have its files parsed individually as the package loader does not accept directory queries.
var dir []os.FileInfo
dir, err = ioutil.ReadDir(pattern)
pkgs, err := packages.Load(&p.conf, "file="+fpath)
if err != nil {
return err
}
log.Debug().Msgf("Loading files in directory %q.", pattern)
for _, fi := range dir {
if fi.IsDir() || filepath.Ext(fi.Name()) != ".go" {
continue
}
if err = p.Parse(ctx, filepath.Join(pattern, fi.Name())); err != nil {
return err
if len(pkgs) == 0 {
continue
}
if len(pkgs) > 1 {
names := make([]string, len(pkgs))
for i, p := range pkgs {
names[i] = p.Name
}
panic(fmt.Sprintf("file %s resolves to multiple packages: %s", fpath, strings.Join(names, ", ")))
}

default:
// This is theoretically impossible to reach due to the disjunction of cases operated above.
return errors.Errorf("encountered unexpected situation when retrieving information about %q", pattern)
}

log.Debug().Msgf("parsing")

pkgs, err := packages.Load(p.conf, query)
if err != nil {
return err
} else if filepath.Ext(pattern) == ".go" && len(pkgs) > 1 {
err := errors.Errorf("file %q maps to multiple packages (%d) instead of a single one", pattern, len(pkgs))
log.Err(err).Msgf("invalid file content")
return err
}

for _, pkg := range pkgs {
log.Debug().Msgf("Parsed sources from %q.", query)
pkg := pkgs[0]
if len(pkg.Errors) > 0 {
return pkg.Errors[0]
}
if len(pkg.GoFiles) == 0 {
continue
}

for idx, f := range pkg.GoFiles {
if _, ok := p.entriesByFileName[f]; ok {
Expand Down
14 changes: 14 additions & 0 deletions pkg/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@ func TestFileParse(t *testing.T) {
assert.NotNil(t, node)
}

func noTestFileInterfaces(t *testing.T) {
parser := NewParser(nil)

err := parser.Parse(ctx, testFile)
assert.NoError(t, err)

err = parser.Load()
assert.NoError(t, err)

nodes := parser.Interfaces()
assert.Equal(t, 1, len(nodes))
assert.Equal(t, "Requester", nodes[0].Name)
}

func TestBuildTagInFilename(t *testing.T) {
parser := NewParser(nil)

Expand Down
Loading

0 comments on commit c146958

Please sign in to comment.