From 74642113e664bff94b086d49ac1241799a3a4984 Mon Sep 17 00:00:00 2001
From: Evan Cordell <cordell.evan@gmail.com>
Date: Sat, 22 Jul 2023 19:27:36 -0400
Subject: [PATCH] add sh.RunSh and sh.ExecSh that take functional options

this is primarily to enable passing in a working directory for the
command.
---
 sh/cmd.go      | 190 +++++++++++++++++++++++++++++++++++++------------
 sh/cmd_test.go |  21 ++++++
 2 files changed, 166 insertions(+), 45 deletions(-)

diff --git a/sh/cmd.go b/sh/cmd.go
index 312de65a..43f124a1 100644
--- a/sh/cmd.go
+++ b/sh/cmd.go
@@ -12,29 +12,152 @@ import (
 	"github.com/magefile/mage/mg"
 )
 
+// runOptions is a set of options to be applied with ExecSh.
+type runOptions struct {
+	cmd            string
+	args           []string
+	dir            string
+	env            map[string]string
+	stderr, stdout io.Writer
+}
+
+// RunOpt applies an option to a runOptions set.
+type RunOpt func(*runOptions)
+
+// WithV sets stderr and stdout the standard streams
+func WithV() RunOpt {
+	return func(options *runOptions) {
+		options.stdout = os.Stdout
+		options.stderr = os.Stderr
+	}
+}
+
+// WithEnv sets the env passed in env vars.
+func WithEnv(env map[string]string) RunOpt {
+	return func(options *runOptions) {
+		if options.env == nil {
+			options.env = make(map[string]string)
+		}
+		for k, v := range env {
+			options.env[k] = v
+		}
+	}
+}
+
+// WithStderr sets the stderr stream.
+func WithStderr(w io.Writer) RunOpt {
+	return func(options *runOptions) {
+		options.stderr = w
+	}
+}
+
+// WithStdout sets the stdout stream.
+func WithStdout(w io.Writer) RunOpt {
+	return func(options *runOptions) {
+		options.stdout = w
+	}
+}
+
+// WithDir sets the working directory for the command.
+func WithDir(dir string) RunOpt {
+	return func(options *runOptions) {
+		options.dir = dir
+	}
+}
+
+// WithArgs appends command arguments.
+func WithArgs(args ...string) RunOpt {
+	return func(options *runOptions) {
+		if options.args == nil {
+			options.args = make([]string, 0, len(args))
+		}
+		options.args = append(options.args, args...)
+	}
+}
+
+// RunSh returns a function that calls ExecSh, only returning errors.
+func RunSh(cmd string, options ...RunOpt) func(args ...string) error {
+	run := ExecSh(cmd, options...)
+	return func(args ...string) error {
+		_, err := run(args...)
+		return err
+	}
+}
+
+// ExecSh returns a function that executes the command, piping its stdout and
+// stderr according to the config options. If the command fails, it will return
+// an error that, if returned from a target or mg.Deps call, will cause mage to
+// exit with the same code as the command failed with.
+//
+// ExecSh takes a variable list of RunOpt objects to configure how the command
+// is executed. See RunOpt docs for more details.
+//
+// Env vars configured on the command override the current environment variables
+// set (which are also passed to the command). The cmd and args may include
+// references to environment variables in $FOO format, in which case these will be
+// expanded before the command is run.
+//
+// Ran reports if the command ran (rather than was not found or not executable).
+// Code reports the exit code the command returned if it ran. If err == nil, ran
+// is always true and code is always 0.
+func ExecSh(cmd string, options ...RunOpt) func(args ...string) (bool, error) {
+	opts := runOptions{
+		cmd: cmd,
+	}
+	for _, o := range options {
+		o(&opts)
+	}
+
+	if opts.stdout == nil && mg.Verbose() {
+		opts.stdout = os.Stdout
+	}
+
+	return func(args ...string) (bool, error) {
+		expand := func(s string) string {
+			s2, ok := opts.env[s]
+			if ok {
+				return s2
+			}
+			return os.Getenv(s)
+		}
+		cmd = os.Expand(cmd, expand)
+		finalArgs := append(opts.args, args...)
+		for i := range finalArgs {
+			finalArgs[i] = os.Expand(finalArgs[i], expand)
+		}
+		ran, code, err := run(opts.dir, opts.env, opts.stdout, opts.stderr, cmd, finalArgs...)
+
+		if err == nil {
+			return ran, nil
+		}
+		if ran {
+			return ran, mg.Fatalf(code, `running "%s %s" failed with exit code %d`, cmd, strings.Join(args, " "), code)
+		}
+		return ran, fmt.Errorf(`failed to run "%s %s: %v"`, cmd, strings.Join(args, " "), err)
+	}
+}
+
 // RunCmd returns a function that will call Run with the given command. This is
 // useful for creating command aliases to make your scripts easier to read, like
 // this:
 //
-//  // in a helper file somewhere
-//  var g0 = sh.RunCmd("go")  // go is a keyword :(
+//	 // in a helper file somewhere
+//	 var g0 = sh.RunCmd("go")  // go is a keyword :(
 //
-//  // somewhere in your main code
-//	if err := g0("install", "github.com/gohugo/hugo"); err != nil {
-//		return err
-//  }
+//	 // somewhere in your main code
+//		if err := g0("install", "github.com/gohugo/hugo"); err != nil {
+//			return err
+//	 }
 //
 // Args passed to command get baked in as args to the command when you run it.
 // Any args passed in when you run the returned function will be appended to the
 // original args.  For example, this is equivalent to the above:
 //
-//  var goInstall = sh.RunCmd("go", "install") goInstall("github.com/gohugo/hugo")
+//	var goInstall = sh.RunCmd("go", "install") goInstall("github.com/gohugo/hugo")
 //
 // RunCmd uses Exec underneath, so see those docs for more details.
 func RunCmd(cmd string, args ...string) func(args ...string) error {
-	return func(args2 ...string) error {
-		return Run(cmd, append(args, args2...)...)
-	}
+	return RunSh(cmd, WithArgs(args...))
 }
 
 // OutCmd is like RunCmd except the command returns the output of the
@@ -47,13 +170,12 @@ func OutCmd(cmd string, args ...string) func(args ...string) (string, error) {
 
 // Run is like RunWith, but doesn't specify any environment variables.
 func Run(cmd string, args ...string) error {
-	return RunWith(nil, cmd, args...)
+	return RunSh(cmd, WithArgs(args...))()
 }
 
 // RunV is like Run, but always sends the command's stdout to os.Stdout.
 func RunV(cmd string, args ...string) error {
-	_, err := Exec(nil, os.Stdout, os.Stderr, cmd, args...)
-	return err
+	return RunSh(cmd, WithV(), WithArgs(args...))()
 }
 
 // RunWith runs the given command, directing stderr to this program's stderr and
@@ -61,31 +183,25 @@ func RunV(cmd string, args ...string) error {
 // environment variables for the command being run. Environment variables should
 // be in the format name=value.
 func RunWith(env map[string]string, cmd string, args ...string) error {
-	var output io.Writer
-	if mg.Verbose() {
-		output = os.Stdout
-	}
-	_, err := Exec(env, output, os.Stderr, cmd, args...)
-	return err
+	return RunSh(cmd, WithEnv(env), WithArgs(args...))()
 }
 
 // RunWithV is like RunWith, but always sends the command's stdout to os.Stdout.
 func RunWithV(env map[string]string, cmd string, args ...string) error {
-	_, err := Exec(env, os.Stdout, os.Stderr, cmd, args...)
-	return err
+	return RunSh(cmd, WithV(), WithEnv(env), WithArgs(args...))()
 }
 
 // Output runs the command and returns the text from stdout.
 func Output(cmd string, args ...string) (string, error) {
 	buf := &bytes.Buffer{}
-	_, err := Exec(nil, buf, os.Stderr, cmd, args...)
+	err := RunSh(cmd, WithStderr(os.Stderr), WithStdout(buf), WithArgs(args...))()
 	return strings.TrimSuffix(buf.String(), "\n"), err
 }
 
 // OutputWith is like RunWith, but returns what is written to stdout.
 func OutputWith(env map[string]string, cmd string, args ...string) (string, error) {
 	buf := &bytes.Buffer{}
-	_, err := Exec(env, buf, os.Stderr, cmd, args...)
+	err := RunSh(cmd, WithEnv(env), WithStderr(os.Stderr), WithStdout(buf), WithArgs(args...))()
 	return strings.TrimSuffix(buf.String(), "\n"), err
 }
 
@@ -102,40 +218,23 @@ func OutputWith(env map[string]string, cmd string, args ...string) (string, erro
 // Code reports the exit code the command returned if it ran. If err == nil, ran
 // is always true and code is always 0.
 func Exec(env map[string]string, stdout, stderr io.Writer, cmd string, args ...string) (ran bool, err error) {
-	expand := func(s string) string {
-		s2, ok := env[s]
-		if ok {
-			return s2
-		}
-		return os.Getenv(s)
-	}
-	cmd = os.Expand(cmd, expand)
-	for i := range args {
-		args[i] = os.Expand(args[i], expand)
-	}
-	ran, code, err := run(env, stdout, stderr, cmd, args...)
-	if err == nil {
-		return true, nil
-	}
-	if ran {
-		return ran, mg.Fatalf(code, `running "%s %s" failed with exit code %d`, cmd, strings.Join(args, " "), code)
-	}
-	return ran, fmt.Errorf(`failed to run "%s %s: %v"`, cmd, strings.Join(args, " "), err)
+	return ExecSh(cmd, WithArgs(args...), WithStderr(stderr), WithStdout(stdout), WithEnv(env))()
 }
 
-func run(env map[string]string, stdout, stderr io.Writer, cmd string, args ...string) (ran bool, code int, err error) {
+func run(dir string, env map[string]string, stdout, stderr io.Writer, cmd string, args ...string) (ran bool, code int, err error) {
 	c := exec.Command(cmd, args...)
 	c.Env = os.Environ()
 	for k, v := range env {
 		c.Env = append(c.Env, k+"="+v)
 	}
+	c.Dir = dir
 	c.Stderr = stderr
 	c.Stdout = stdout
 	c.Stdin = os.Stdin
 
-	var quoted []string 
+	var quoted []string
 	for i := range args {
-		quoted = append(quoted, fmt.Sprintf("%q", args[i]));
+		quoted = append(quoted, fmt.Sprintf("%q", args[i]))
 	}
 	// To protect against logging from doing exec in global variables
 	if mg.Verbose() {
@@ -144,6 +243,7 @@ func run(env map[string]string, stdout, stderr io.Writer, cmd string, args ...st
 	err = c.Run()
 	return CmdRan(err), ExitStatus(err), err
 }
+
 // CmdRan examines the error to determine if it was generated as a result of a
 // command running via os/exec.Command.  If the error is nil, or the command ran
 // (even if it exited with a non-zero exit code), CmdRan reports true.  If the
diff --git a/sh/cmd_test.go b/sh/cmd_test.go
index c2f5d04f..aabc1412 100644
--- a/sh/cmd_test.go
+++ b/sh/cmd_test.go
@@ -3,6 +3,8 @@ package sh
 import (
 	"bytes"
 	"os"
+	"path/filepath"
+	"strings"
 	"testing"
 )
 
@@ -68,5 +70,24 @@ func TestAutoExpand(t *testing.T) {
 	if s != "baz" {
 		t.Fatalf(`Expected "baz" but got %q`, s)
 	}
+}
 
+func TestDirectory(t *testing.T) {
+	tmp := t.TempDir()
+	buf := &bytes.Buffer{}
+	err := RunSh("pwd", WithDir(tmp), WithStdout(buf))()
+	if err != nil {
+		t.Fatal(err)
+	}
+	dir, err := filepath.EvalSymlinks(strings.TrimSpace(buf.String()))
+	if err != nil {
+		t.Fatal(err)
+	}
+	tmpDir, err := filepath.EvalSymlinks(tmp)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if dir != tmpDir {
+		t.Fatalf(`Expected %q but got %q`, tmpDir, dir)
+	}
 }