Skip to content

Commit

Permalink
Add --rewrite flag to get command (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeroenvervaeke authored Sep 25, 2024
1 parent 04bf93d commit ad13f95
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
6 changes: 4 additions & 2 deletions internal/importpaths/rewrite.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ type RewriteModuleOptions struct {
NewVersion string
NewPrefix string
PkgDir string
OnRewrite func(pos token.Position, oldpath, newpath string)
OnRewrite func(pos token.Position, oldpath, newpath string) error
}

// RewriteModule rewrites imports of a specific module to a new version or prefix.
Expand All @@ -188,7 +188,9 @@ func RewriteModule(dir string, opt RewriteModuleOptions) error {
return "", ErrSkip
}
if opt.OnRewrite != nil {
opt.OnRewrite(pos, path, newpath)
if err := opt.OnRewrite(pos, path, newpath); err != nil {
return "", err
}
}
return newpath, nil
})
Expand Down
24 changes: 20 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"go/token"
"os"
"os/exec"
"regexp"
"runtime/debug"

"golang.org/x/mod/modfile"
Expand Down Expand Up @@ -117,13 +118,14 @@ func listcmd(args []string) error {
}

func getcmd(args []string) error {
var dir string
var dir, rewrite string
var pre, cached, major bool
fset := flag.NewFlagSet("get", flag.ExitOnError)
fset.BoolVar(&pre, "pre", false, "allow non-v0 prerelease versions")
fset.BoolVar(&major, "major", false, "only get newer major versions")
fset.StringVar(&dir, "dir", ".", "working directory")
fset.BoolVar(&cached, "cached", true, "only fetch cached content from the module proxy")
fset.StringVar(&rewrite, "rewrite", "", "exact package version to upgrade")
fset.Usage = func() {
fmt.Fprintln(os.Stderr, "Usage: gomajor get <pathspec>")
fset.PrintDefaults()
Expand Down Expand Up @@ -162,8 +164,9 @@ func getcmd(args []string) error {
err := importpaths.RewriteModule(dir, importpaths.RewriteModuleOptions{
Prefix: packages.ModPrefix(u.Module.Path),
NewVersion: u.Latest.Version,
OnRewrite: func(pos token.Position, _, newpath string) {
OnRewrite: func(pos token.Position, _, newpath string) error {
fmt.Printf("%s %s\n", pos, newpath)
return nil
},
})
if err != nil {
Expand Down Expand Up @@ -217,13 +220,25 @@ func getcmd(args []string) error {
if err := cmd.Run(); err != nil {
return err
}
var rewriteRegex *regexp.Regexp
if rewrite != "" {
if rewriteRegex, err = regexp.Compile(rewrite); err != nil {
return err
}
}
// rewrite imports
err = importpaths.RewriteModule(dir, importpaths.RewriteModuleOptions{
PkgDir: pkgdir,
Prefix: modprefix,
NewVersion: version,
OnRewrite: func(pos token.Position, _, newpath string) {
OnRewrite: func(pos token.Position, oldpath, newpath string) error {
if rewriteRegex != nil && !rewriteRegex.MatchString(oldpath) {
return importpaths.ErrSkip
}

fmt.Printf("%s %s\n", pos, newpath)

return nil
},
})
if err != nil {
Expand Down Expand Up @@ -297,8 +312,9 @@ func pathcmd(args []string) error {
Prefix: oldmodprefix,
NewVersion: version,
NewPrefix: modprefix,
OnRewrite: func(pos token.Position, _, newpath string) {
OnRewrite: func(pos token.Position, _, newpath string) error {
fmt.Printf("%s %s\n", pos, newpath)
return nil
},
})
if err != nil {
Expand Down

0 comments on commit ad13f95

Please sign in to comment.