Skip to content

Commit

Permalink
restrict which commands can have arguments rewritten
Browse files Browse the repository at this point in the history
  • Loading branch information
jakecoffman committed Jun 29, 2022
1 parent ea2873b commit 31a37b5
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 9 deletions.
36 changes: 27 additions & 9 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,15 @@ func main() {
return
}

// Change any ssh or git url arguments to https
for i, arg := range os.Args {
os.Args[i] = Scrub(arg)
if IsRewriteAllowed(os.Args[1:]) {
// Change any ssh or git url arguments to https
for i, arg := range os.Args[1:] {
os.Args[i] = Scrub(arg)
}
}

// Run the scrubbed git command
var cmd *exec.Cmd
if len(os.Args) == 1 {
cmd = exec.Command(os.Args[0])
} else {
cmd = exec.Command(os.Args[0], os.Args[1:]...)
}
cmd := exec.Command(os.Args[0], os.Args[1:]...)
cmd.Stderr = os.Stderr
cmd.Stdout = os.Stdout
err := cmd.Run()
Expand All @@ -39,6 +36,26 @@ func main() {
}
}

var allowedCommands = []string{"clone", "fetch"}

// IsRewriteAllowed returns true if it is safe to rewrite arguments. Some commands
// such as config would break if rewritten, like when using insteadOf.
func IsRewriteAllowed(args []string) bool {
for _, arg := range args {
if strings.HasPrefix(arg, "-") {
continue
}
for _, allowed := range allowedCommands {
if arg == allowed {
return true
}
}
return false
}
return false
}

// FindGit finds the second git executable on the path, the first being this one.
func FindGit(envPath string) string {
paths := strings.Split(envPath, string(os.PathListSeparator))
var shimPath string
Expand Down Expand Up @@ -68,6 +85,7 @@ func FindGit(envPath string) string {

var scpUrl = regexp.MustCompile(`^(?P<user>\S+?)@(?P<host>[a-zA-Z\d-]+(\.[a-zA-Z\d-]+)+\.?):(?P<path>.*?/.*?)$`)

// Scrub rewrites arguments that look like URLs to have the HTTPS protocol.
func Scrub(argument string) string {
u, err := url.ParseRequestURI(argument)
if err == nil && u.Scheme != "" {
Expand Down
32 changes: 32 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,38 @@ import (
"testing"
)

func TestIsRewriteAllowed(t *testing.T) {
var cases = []struct {
input []string
expected bool
}{
{
input: []string{"clone", "[email protected]:org/repo"},
expected: true,
},
{
input: []string{"fetch", ""},
expected: true,
},
{
input: []string{"--work-tree=/work", "clone"},
expected: true,
},
{
input: []string{"config", "--global", "url.\"https://github.com/\".insteadOf", "[email protected]:"},
expected: false,
},
}

for _, test := range cases {
t.Run(fmt.Sprintln(test.input), func(t *testing.T) {
if v := IsRewriteAllowed(test.input); v != test.expected {
t.Errorf("Input: %v\tExpected: %v\tGot: %v\n", test.input, test.expected, v)
}
})
}
}

func TestScrub(t *testing.T) {
var cases = []struct {
input string
Expand Down

0 comments on commit 31a37b5

Please sign in to comment.