diff --git a/.github/workflows/01-golang-lint.yaml b/.github/workflows/01-golang-lint.yaml new file mode 100644 index 00000000..cff8143e --- /dev/null +++ b/.github/workflows/01-golang-lint.yaml @@ -0,0 +1,24 @@ +name: golangci-lint +on: + push: + tags: + - v* + branches: + - master + - main + pull_request: +permissions: + contents: read +jobs: + golangci: + name: lint + runs-on: ubuntu-latest + steps: + - uses: actions/setup-go@v4 + with: + go-version: stable + - uses: actions/checkout@v3 + - name: golangci-lint + uses: golangci/golangci-lint-action@v3 + with: + version: latest diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 00000000..99192466 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,39 @@ +linters-settings: + govet: + check-shadowing: true + maligned: + suggest-new: true + dupl: + threshold: 200 + goconst: + min-len: 3 + min-occurrences: 2 + #forbidigo: + # forbid: + # - ^print.*$ + # - 'fmt\.Print.*' + gocognit: + min-complexity: 61 # This is a rather high value. We should gradually lower it to 30-40. + +linters: + enable: + - gofmt + - goimports + - bodyclose + - dupl + - gocognit + - gocritic + - goimports + - gosec + - nakedret + #- nolintlint + - revive + - stylecheck + - unconvert + - unparam + disable: + - forbidigo + - maligned + - lll + - gochecknoinits + - gochecknoglobals diff --git a/bool_slice_test.go b/bool_slice_test.go index 3c5a274f..a5128678 100644 --- a/bool_slice_test.go +++ b/bool_slice_test.go @@ -5,234 +5,212 @@ import ( "strconv" "strings" "testing" -) -func setUpBSFlagSet(bsp *[]bool) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.BoolSliceVar(bsp, "bs", []bool{}, "Command separated list!") - return f -} + "github.com/stretchr/testify/require" +) -func setUpBSFlagSetWithDefault(bsp *[]bool) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.BoolSliceVar(bsp, "bs", []bool{false, true}, "Command separated list!") - return f -} +func TestBoolSlice(t *testing.T) { + t.Parallel() -func TestEmptyBS(t *testing.T) { - var bs []bool - f := setUpBSFlagSet(&bs) - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) + newFlag := func(bsp *[]bool) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.BoolSliceVar(bsp, "bs", []bool{}, "Command separated list!") + return f } - getBS, err := f.GetBoolSlice("bs") - if err != nil { - t.Fatal("got an error from GetBoolSlice():", err) - } - if len(getBS) != 0 { - t.Fatalf("got bs %v with len=%d but expected length=0", getBS, len(getBS)) - } -} + t.Run("with empty slice", func(t *testing.T) { + bs := make([]bool, 0) + f := newFlag(&bs) -func TestBS(t *testing.T) { - var bs []bool - f := setUpBSFlagSet(&bs) + require.NoError(t, f.Parse([]string{})) - vals := []string{"1", "F", "TRUE", "0"} - arg := fmt.Sprintf("--bs=%s", strings.Join(vals, ",")) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range bs { - b, err := strconv.ParseBool(vals[i]) - if err != nil { - t.Fatalf("got error: %v", err) - } - if b != v { - t.Fatalf("expected is[%d] to be %s but got: %t", i, vals[i], v) - } - } - getBS, err := f.GetBoolSlice("bs") - if err != nil { - t.Fatalf("got error: %v", err) - } - for i, v := range getBS { - b, err := strconv.ParseBool(vals[i]) - if err != nil { - t.Fatalf("got error: %v", err) - } - if b != v { - t.Fatalf("expected bs[%d] to be %s but got: %t from GetBoolSlice", i, vals[i], v) - } - } -} + getBS, err := f.GetBoolSlice("bs") + require.NoErrorf(t, err, + "got an error from GetBoolSlice(): %v", err, + ) -func TestBSDefault(t *testing.T) { - var bs []bool - f := setUpBSFlagSetWithDefault(&bs) + require.Empty(t, getBS) + }) - vals := []string{"false", "T"} + t.Run("with truthy/falsy values", func(t *testing.T) { + vals := []string{"1", "F", "TRUE", "0"} + bs := make([]bool, 0, len(vals)) + f := newFlag(&bs) - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range bs { - b, err := strconv.ParseBool(vals[i]) - if err != nil { - t.Fatalf("got error: %v", err) - } - if b != v { - t.Fatalf("expected bs[%d] to be %t from GetBoolSlice but got: %t", i, b, v) - } - } + require.NoError(t, f.Parse([]string{ + fmt.Sprintf("--bs=%s", strings.Join(vals, ",")), + })) - getBS, err := f.GetBoolSlice("bs") - if err != nil { - t.Fatal("got an error from GetBoolSlice():", err) - } - for i, v := range getBS { - b, err := strconv.ParseBool(vals[i]) - if err != nil { - t.Fatal("got an error from GetBoolSlice():", err) - } - if b != v { - t.Fatalf("expected bs[%d] to be %t from GetBoolSlice but got: %t", i, b, v) + for i, v := range bs { + b, err := strconv.ParseBool(vals[i]) + require.NoError(t, err) + require.Equalf(t, v, b, + "expected is[%d] to be %s but got: %t", i, vals[i], v, + ) } - } -} -func TestBSWithDefault(t *testing.T) { - var bs []bool - f := setUpBSFlagSetWithDefault(&bs) + getBS, erb := f.GetBoolSlice("bs") + require.NoError(t, erb) - vals := []string{"FALSE", "1"} - arg := fmt.Sprintf("--bs=%s", strings.Join(vals, ",")) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range bs { - b, err := strconv.ParseBool(vals[i]) - if err != nil { - t.Fatalf("got error: %v", err) - } - if b != v { - t.Fatalf("expected bs[%d] to be %t but got: %t", i, b, v) + for i, v := range getBS { + b, err := strconv.ParseBool(vals[i]) + require.NoError(t, err) + require.Equalf(t, v, b, + "expected bs[%d] to be %s but got: %t from GetBoolSlice", i, vals[i], v, + ) } - } + }) - getBS, err := f.GetBoolSlice("bs") - if err != nil { - t.Fatal("got an error from GetBoolSlice():", err) + newFlagWithDefault := func(bsp *[]bool) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.BoolSliceVar(bsp, "bs", []bool{false, true}, "Command separated list!") + return f } - for i, v := range getBS { - b, err := strconv.ParseBool(vals[i]) - if err != nil { - t.Fatalf("got error: %v", err) + + t.Run("with defaults (1)", func(t *testing.T) { + vals := []string{"false", "T"} + bs := make([]bool, 0, len(vals)) + f := newFlagWithDefault(&bs) + + require.NoError(t, f.Parse([]string{})) + + for i, v := range bs { + b, err := strconv.ParseBool(vals[i]) + require.NoError(t, err) + require.Equalf(t, v, b, + "expected bs[%d] to be %t from GetBoolSlice but got: %t", i, b, v, + ) } - if b != v { - t.Fatalf("expected bs[%d] to be %t from GetBoolSlice but got: %t", i, b, v) + + getBS, erb := f.GetBoolSlice("bs") + require.NoErrorf(t, erb, + "got an error from GetBoolSlice(): %v", erb, + ) + + for i, v := range getBS { + b, err := strconv.ParseBool(vals[i]) + require.NoErrorf(t, err, + "got an error from GetBoolSlice(): %v", err, + ) + require.Equalf(t, v, b, + "expected bs[%d] to be %t from GetBoolSlice but got: %t", i, b, v, + ) } - } -} + }) -func TestBSCalledTwice(t *testing.T) { - var bs []bool - f := setUpBSFlagSet(&bs) - - in := []string{"T,F", "T"} - expected := []bool{true, false, true} - argfmt := "--bs=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - err := f.Parse([]string{arg1, arg2}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range bs { - if expected[i] != v { - t.Fatalf("expected bs[%d] to be %t but got %t", i, expected[i], v) + t.Run("with defaults (2)", func(t *testing.T) { + vals := []string{"FALSE", "1"} + bs := make([]bool, 0, len(vals)) + f := newFlagWithDefault(&bs) + + arg := fmt.Sprintf("--bs=%s", strings.Join(vals, ",")) + require.NoError(t, f.Parse([]string{arg})) + + for i, v := range bs { + b, err := strconv.ParseBool(vals[i]) + require.NoError(t, err) + require.Equalf(t, v, b, + "expected bs[%d] to be %t but got: %t", i, b, v, + ) } - } -} -func TestBSAsSliceValue(t *testing.T) { - var bs []bool - f := setUpBSFlagSet(&bs) - - in := []string{"true", "false"} - argfmt := "--bs=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - err := f.Parse([]string{arg1, arg2}) - if err != nil { - t.Fatal("expected no error; got", err) - } + getBS, erb := f.GetBoolSlice("bs") + require.NoErrorf(t, erb, + "got an error from GetBoolSlice(): %v", erb, + ) - f.VisitAll(func(f *Flag) { - if val, ok := f.Value.(SliceValue); ok { - _ = val.Replace([]string{"false"}) + for i, v := range getBS { + b, err := strconv.ParseBool(vals[i]) + require.NoError(t, err) + require.Equalf(t, v, b, + "expected bs[%d] to be %t from GetBoolSlice but got: %t", i, b, v, + ) } }) - if len(bs) != 1 || bs[0] != false { - t.Fatalf("Expected ss to be overwritten with 'false', but got: %v", bs) - } -} -func TestBSBadQuoting(t *testing.T) { - - tests := []struct { - Want []bool - FlagArg []string - }{ - { - Want: []bool{true, false, true}, - FlagArg: []string{"1", "0", "true"}, - }, - { - Want: []bool{true, false}, - FlagArg: []string{"True", "F"}, - }, - { - Want: []bool{true, false}, - FlagArg: []string{"T", "0"}, - }, - { - Want: []bool{true, false}, - FlagArg: []string{"1", "0"}, - }, - { - Want: []bool{true, false, false}, - FlagArg: []string{"true,false", "false"}, - }, - { - Want: []bool{true, false, false, true, false, true, false}, - FlagArg: []string{`"true,false,false,1,0, T"`, " false "}, - }, - { - Want: []bool{false, false, true, false, true, false, true}, - FlagArg: []string{`"0, False, T,false , true,F"`, "true"}, - }, - } + t.Run("called twice", func(t *testing.T) { + const argfmt = "--bs=%s" + in := []string{"T,F", "T"} + bs := make([]bool, 0, len(in)) + f := newFlag(&bs) + expected := []bool{true, false, true} - for i, test := range tests { + require.NoError(t, f.Parse([]string{ + fmt.Sprintf(argfmt, in[0]), + fmt.Sprintf(argfmt, in[1]), + })) - var bs []bool - f := setUpBSFlagSet(&bs) + require.Equal(t, expected, bs) + }) - if err := f.Parse([]string{fmt.Sprintf("--bs=%s", strings.Join(test.FlagArg, ","))}); err != nil { - t.Fatalf("flag parsing failed with error: %s\nparsing:\t%#v\nwant:\t\t%#v", - err, test.FlagArg, test.Want[i]) - } + t.Run("as slice value", func(t *testing.T) { + const argfmt = "--bs=%s" + in := []string{"true", "false"} + bs := make([]bool, 0, len(in)) + f := newFlag(&bs) - for j, b := range bs { - if b != test.Want[j] { - t.Fatalf("bad value parsed for test %d on bool %d:\nwant:\t%t\ngot:\t%t", i, j, test.Want[j], b) + require.NoError(t, f.Parse([]string{ + fmt.Sprintf(argfmt, in[0]), + fmt.Sprintf(argfmt, in[1]), + })) + + f.VisitAll(func(f *Flag) { + if val, ok := f.Value.(SliceValue); ok { + require.NoError(t, val.Replace([]string{"false"})) } + }) + + require.Equalf(t, []bool{false}, bs, + "expected ss to be overwritten with 'false', but got: %v", bs, + ) + }) + + t.Run("with quoting", func(t *testing.T) { + tests := []struct { + Want []bool + FlagArg []string + }{ + { + Want: []bool{true, false, true}, + FlagArg: []string{"1", "0", "true"}, + }, + { + Want: []bool{true, false}, + FlagArg: []string{"True", "F"}, + }, + { + Want: []bool{true, false}, + FlagArg: []string{"T", "0"}, + }, + { + Want: []bool{true, false}, + FlagArg: []string{"1", "0"}, + }, + { + Want: []bool{true, false, false}, + FlagArg: []string{"true,false", "false"}, + }, + { + Want: []bool{true, false, false, true, false, true, false}, + FlagArg: []string{`"true,false,false,1,0, T"`, " false "}, + }, + { + Want: []bool{false, false, true, false, true, false, true}, + FlagArg: []string{`"0, False, T,false , true,F"`, "true"}, + }, + } + + for i, test := range tests { + bs := make([]bool, 0, 7) + f := newFlag(&bs) + + require.NoErrorf(t, + f.Parse([]string{fmt.Sprintf("--bs=%s", strings.Join(test.FlagArg, ","))}), + "flag parsing failed for test %d with error:\nparsing:\t%#vnwant:\t\t%#v", + test.FlagArg, test.Want, + ) + + require.Equalf(t, test.Want, bs, "on test %d", i) } - } + }) } diff --git a/bool_test.go b/bool_test.go index a4319e79..ecba7103 100644 --- a/bool_test.go +++ b/bool_test.go @@ -8,6 +8,8 @@ import ( "bytes" "strconv" "testing" + + "github.com/stretchr/testify/require" ) // This value can be a boolean ("true", "false") or "maybe" @@ -26,7 +28,7 @@ func (v *triStateValue) IsBoolFlag() bool { } func (v *triStateValue) Get() interface{} { - return triStateValue(*v) + return *v } func (v *triStateValue) Set(s string) error { @@ -40,6 +42,7 @@ func (v *triStateValue) Set(s string) error { } else { *v = triStateFalse } + return err } @@ -60,120 +63,102 @@ func setUpFlagSet(tristate *triStateValue) *FlagSet { *tristate = triStateFalse flag := f.VarPF(tristate, "tristate", "t", "tristate value (true, maybe or false)") flag.NoOptDefVal = "true" - return f -} - -func TestExplicitTrue(t *testing.T) { - var tristate triStateValue - f := setUpFlagSet(&tristate) - err := f.Parse([]string{"--tristate=true"}) - if err != nil { - t.Fatal("expected no error; got", err) - } - if tristate != triStateTrue { - t.Fatal("expected", triStateTrue, "(triStateTrue) but got", tristate, "instead") - } -} - -func TestImplicitTrue(t *testing.T) { - var tristate triStateValue - f := setUpFlagSet(&tristate) - err := f.Parse([]string{"--tristate"}) - if err != nil { - t.Fatal("expected no error; got", err) - } - if tristate != triStateTrue { - t.Fatal("expected", triStateTrue, "(triStateTrue) but got", tristate, "instead") - } -} - -func TestShortFlag(t *testing.T) { - var tristate triStateValue - f := setUpFlagSet(&tristate) - err := f.Parse([]string{"-t"}) - if err != nil { - t.Fatal("expected no error; got", err) - } - if tristate != triStateTrue { - t.Fatal("expected", triStateTrue, "(triStateTrue) but got", tristate, "instead") - } -} - -func TestShortFlagExtraArgument(t *testing.T) { - var tristate triStateValue - f := setUpFlagSet(&tristate) - // The"maybe"turns into an arg, since short boolean options will only do true/false - err := f.Parse([]string{"-t", "maybe"}) - if err != nil { - t.Fatal("expected no error; got", err) - } - if tristate != triStateTrue { - t.Fatal("expected", triStateTrue, "(triStateTrue) but got", tristate, "instead") - } - args := f.Args() - if len(args) != 1 || args[0] != "maybe" { - t.Fatal("expected an extra 'maybe' argument to stick around") - } -} - -func TestExplicitMaybe(t *testing.T) { - var tristate triStateValue - f := setUpFlagSet(&tristate) - err := f.Parse([]string{"--tristate=maybe"}) - if err != nil { - t.Fatal("expected no error; got", err) - } - if tristate != triStateMaybe { - t.Fatal("expected", triStateMaybe, "(triStateMaybe) but got", tristate, "instead") - } -} - -func TestExplicitFalse(t *testing.T) { - var tristate triStateValue - f := setUpFlagSet(&tristate) - err := f.Parse([]string{"--tristate=false"}) - if err != nil { - t.Fatal("expected no error; got", err) - } - if tristate != triStateFalse { - t.Fatal("expected", triStateFalse, "(triStateFalse) but got", tristate, "instead") - } -} -func TestImplicitFalse(t *testing.T) { - var tristate triStateValue - f := setUpFlagSet(&tristate) - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) - } - if tristate != triStateFalse { - t.Fatal("expected", triStateFalse, "(triStateFalse) but got", tristate, "instead") - } -} - -func TestInvalidValue(t *testing.T) { - var tristate triStateValue - f := setUpFlagSet(&tristate) - var buf bytes.Buffer - f.SetOutput(&buf) - err := f.Parse([]string{"--tristate=invalid"}) - if err == nil { - t.Fatal("expected an error but did not get any, tristate has value", tristate) - } + return f } -func TestBoolP(t *testing.T) { - b := BoolP("bool", "b", false, "bool value in CommandLine") - c := BoolP("c", "c", false, "other bool value") - args := []string{"--bool"} - if err := CommandLine.Parse(args); err != nil { - t.Error("expected no error, got ", err) - } - if *b != true { - t.Errorf("expected b=true got b=%v", *b) - } - if *c != false { - t.Errorf("expect c=false got c=%v", *c) - } +func TestBool(t *testing.T) { + t.Parallel() + + t.Run("with explicit true", func(t *testing.T) { + var triState triStateValue + f := setUpFlagSet(&triState) + require.NoError(t, f.Parse([]string{"--tristate=true"})) + require.Equalf(t, triStateTrue, triState, + "expected", triStateTrue, "(triStateTrue) but got", triState, "instead", + ) + }) + + t.Run("with implicit true", func(t *testing.T) { + var triState triStateValue + f := setUpFlagSet(&triState) + require.NoError(t, f.Parse([]string{"--tristate"})) + require.Equalf(t, triStateTrue, triState, + "expected", triStateTrue, "(triStateTrue) but got", triState, "instead", + ) + }) + + t.Run("with short flag", func(t *testing.T) { + var triState triStateValue + f := setUpFlagSet(&triState) + require.NoError(t, f.Parse([]string{"-t"})) + require.Equalf(t, triStateTrue, triState, + "expected", triStateTrue, "(triStateTrue) but got", triState, "instead", + ) + }) + + t.Run("with short flag extra argument", func(t *testing.T) { + var triState triStateValue + f := setUpFlagSet(&triState) + // The"maybe"turns into an arg, since short boolean options will only do true/false + require.NoError(t, f.Parse([]string{"-t", "maybe"})) + require.Equalf(t, triStateTrue, triState, + "expected", triStateTrue, "(triStateTrue) but got", triState, "instead", + ) + args := f.Args() + require.Len(t, args, 1) + require.Equalf(t, "maybe", args[0], + "expected an extra 'maybe' argument to stick around", + ) + }) + + t.Run("with explicit maybe", func(t *testing.T) { + var triState triStateValue + f := setUpFlagSet(&triState) + require.NoError(t, f.Parse([]string{"--tristate=maybe"})) + require.Equalf(t, triStateMaybe, triState, + "expected", triStateMaybe, "(triStateMaybe) but got", triState, "instead", + ) + }) + + t.Run("with explicit false", func(t *testing.T) { + var triState triStateValue + f := setUpFlagSet(&triState) + require.NoError(t, f.Parse([]string{"--tristate=false"})) + require.Equalf(t, triStateFalse, triState, + "expected", triStateFalse, "(triStateFalse) but got", triState, "instead", + ) + }) + + t.Run("with implicit false", func(t *testing.T) { + var triState triStateValue + f := setUpFlagSet(&triState) + require.NoError(t, f.Parse([]string{})) + require.Equalf(t, triStateFalse, triState, + "expected", triStateFalse, "(triStateFalse) but got", triState, "instead", + ) + }) + + t.Run("with invalid value", func(t *testing.T) { + var triState triStateValue + f := setUpFlagSet(&triState) + var buf bytes.Buffer + f.SetOutput(&buf) + require.Errorf(t, f.Parse([]string{"--tristate=invalid"}), + "expected an error but did not get any, tristate has value", triState, + ) + }) + + t.Run("with BoolP", func(t *testing.T) { + b := BoolP("bool", "b", false, "bool value in CommandLine") + c := BoolP("c", "c", false, "other bool value") + args := []string{"--bool"} + require.NoError(t, CommandLine.Parse(args)) + require.Truef(t, *b, + "expected b=true got b=%v", *b, + ) + require.Falsef(t, *c, + "expect c=false got c=%v", *c, + ) + }) } diff --git a/bytes_test.go b/bytes_test.go index 5251f347..7956e789 100644 --- a/bytes_test.go +++ b/bytes_test.go @@ -5,16 +5,18 @@ import ( "fmt" "os" "testing" -) -func setUpBytesHex(bytesHex *[]byte) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.BytesHexVar(bytesHex, "bytes", []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0}, "Some bytes in HEX") - f.BytesHexVarP(bytesHex, "bytes2", "B", []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0}, "Some bytes in HEX") - return f -} + "github.com/stretchr/testify/require" +) func TestBytesHex(t *testing.T) { + newFlag := func(bytesHex *[]byte) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.BytesHexVar(bytesHex, "bytes", []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0}, "Some bytes in HEX") + f.BytesHexVarP(bytesHex, "bytes2", "B", []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0}, "Some bytes in HEX") + return f + } + testCases := []struct { input string success bool @@ -38,7 +40,7 @@ func TestBytesHex(t *testing.T) { for i := range testCases { var bytesHex []byte - f := setUpBytesHex(&bytesHex) + f := newFlag(&bytesHex) tc := &testCases[i] @@ -52,34 +54,37 @@ func TestBytesHex(t *testing.T) { for _, arg := range args { err := f.Parse([]string{arg}) - if err != nil && tc.success == true { - t.Errorf("expected success, got %q", err) - continue - } else if err == nil && tc.success == false { - // bytesHex, err := f.GetBytesHex("bytes") - t.Errorf("expected failure while processing %q", tc.input) + if !tc.success { + require.Errorf(t, err, + "expected failure while processing %q", tc.input, + ) + continue - } else if tc.success { - bytesHex, err := f.GetBytesHex("bytes") - if err != nil { - t.Errorf("Got error trying to fetch the 'bytes' flag: %v", err) - } - if fmt.Sprintf("%X", bytesHex) != tc.expected { - t.Errorf("expected %q, got '%X'", tc.expected, bytesHex) - } } + + require.NoErrorf(t, err, "expected success, got %q", err) + + bytesHex, err := f.GetBytesHex("bytes") + require.NoErrorf(t, err, + "got error trying to fetch the 'bytes' flag: %v", err, + ) + + require.Equalf(t, tc.expected, fmt.Sprintf("%X", bytesHex), + "expected %q, got '%X'", tc.expected, bytesHex, + ) + } } } -func setUpBytesBase64(bytesBase64 *[]byte) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.BytesBase64Var(bytesBase64, "bytes", []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0}, "Some bytes in Base64") - f.BytesBase64VarP(bytesBase64, "bytes2", "B", []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0}, "Some bytes in Base64") - return f -} - func TestBytesBase64(t *testing.T) { + newFlag := func(bytesBase64 *[]byte) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.BytesBase64Var(bytesBase64, "bytes", []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0}, "Some bytes in Base64") + f.BytesBase64VarP(bytesBase64, "bytes2", "B", []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0}, "Some bytes in Base64") + return f + } + testCases := []struct { input string success bool @@ -99,11 +104,9 @@ func TestBytesBase64(t *testing.T) { for i := range testCases { var bytesBase64 []byte - f := setUpBytesBase64(&bytesBase64) - + f := newFlag(&bytesBase64) tc := &testCases[i] - // --bytes args := []string{ fmt.Sprintf("--bytes=%s", tc.input), fmt.Sprintf("-B %s", tc.input), @@ -112,23 +115,23 @@ func TestBytesBase64(t *testing.T) { for _, arg := range args { err := f.Parse([]string{arg}) + if !tc.success { + require.Errorf(t, err, + "expected failure while processing %q", tc.input, + ) - if err != nil && tc.success == true { - t.Errorf("expected success, got %q", err) - continue - } else if err == nil && tc.success == false { - // bytesBase64, err := f.GetBytesBase64("bytes") - t.Errorf("expected failure while processing %q", tc.input) continue - } else if tc.success { - bytesBase64, err := f.GetBytesBase64("bytes") - if err != nil { - t.Errorf("Got error trying to fetch the 'bytes' flag: %v", err) - } - if base64.StdEncoding.EncodeToString(bytesBase64) != tc.expected { - t.Errorf("expected %q, got '%X'", tc.expected, bytesBase64) - } } + + require.NoErrorf(t, err, "expected success, got %q", err) + + bytesBase64, err := f.GetBytesBase64("bytes") + require.NoErrorf(t, err, + "got error trying to fetch the 'bytes' flag: %v", err, + ) + require.Equalf(t, tc.expected, base64.StdEncoding.EncodeToString(bytesBase64), + "expected %q, got '%X'", tc.expected, bytesBase64, + ) } } } diff --git a/count.go b/count.go index a0b2679f..5d659e52 100644 --- a/count.go +++ b/count.go @@ -13,7 +13,7 @@ func newCountValue(val int, p *int) *countValue { func (i *countValue) Set(s string) error { // "+1" means that no specific value was passed, so increment if s == "+1" { - *i = countValue(*i + 1) + *i++ return nil } v, err := strconv.ParseInt(s, 0, 0) diff --git a/count_test.go b/count_test.go index 3785d375..ec10b287 100644 --- a/count_test.go +++ b/count_test.go @@ -3,15 +3,17 @@ package pflag import ( "os" "testing" -) -func setUpCount(c *int) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.CountVarP(c, "verbose", "v", "a counter") - return f -} + "github.com/stretchr/testify/require" +) func TestCount(t *testing.T) { + newFlag := func(c *int) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.CountVarP(c, "verbose", "v", "a counter") + return f + } + testCases := []struct { input []string success bool @@ -30,27 +32,26 @@ func TestCount(t *testing.T) { devnull, _ := os.Open(os.DevNull) os.Stderr = devnull + for i := range testCases { var count int - f := setUpCount(&count) + f := newFlag(&count) tc := &testCases[i] err := f.Parse(tc.input) - if err != nil && tc.success == true { - t.Errorf("expected success, got %q", err) - continue - } else if err == nil && tc.success == false { - t.Errorf("expected failure, got success") + if !tc.success { + require.Errorf(t, err, + "expected failure with %q, got success", tc.input, + ) + continue - } else if tc.success { - c, err := f.GetCount("verbose") - if err != nil { - t.Errorf("Got error trying to fetch the counter flag") - } - if c != tc.expected { - t.Errorf("expected %d, got %d", tc.expected, c) - } } + + require.NoError(t, err) + + c, err := f.GetCount("verbose") + require.NoError(t, err) + require.Equal(t, tc.expected, c) } } diff --git a/duration_slice.go b/duration_slice.go index badadda5..ee987274 100644 --- a/duration_slice.go +++ b/duration_slice.go @@ -1,7 +1,6 @@ package pflag import ( - "fmt" "strings" "time" ) @@ -46,7 +45,7 @@ func (s *durationSliceValue) Type() string { func (s *durationSliceValue) String() string { out := make([]string, len(*s.value)) for i, d := range *s.value { - out[i] = fmt.Sprintf("%s", d) + out[i] = d.String() } return "[" + strings.Join(out, ",") + "]" } @@ -56,7 +55,7 @@ func (s *durationSliceValue) fromString(val string) (time.Duration, error) { } func (s *durationSliceValue) toString(val time.Duration) string { - return fmt.Sprintf("%s", val) + return val.String() } func (s *durationSliceValue) Append(val string) error { diff --git a/duration_slice_test.go b/duration_slice_test.go index 651fbd8b..3f44d1fa 100644 --- a/duration_slice_test.go +++ b/duration_slice_test.go @@ -9,180 +9,161 @@ import ( "strings" "testing" "time" -) -func setUpDSFlagSet(dsp *[]time.Duration) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.DurationSliceVar(dsp, "ds", []time.Duration{}, "Command separated list!") - return f -} + "github.com/stretchr/testify/require" +) -func setUpDSFlagSetWithDefault(dsp *[]time.Duration) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.DurationSliceVar(dsp, "ds", []time.Duration{0, 1}, "Command separated list!") - return f -} +func TestDurationSlice(t *testing.T) { + t.Parallel() -func TestEmptyDS(t *testing.T) { - var ds []time.Duration - f := setUpDSFlagSet(&ds) - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) + newFlag := func(dsp *[]time.Duration) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.DurationSliceVar(dsp, "ds", []time.Duration{}, "Command separated list!") + return f } - getDS, err := f.GetDurationSlice("ds") - if err != nil { - t.Fatal("got an error from GetDurationSlice():", err) - } - if len(getDS) != 0 { - t.Fatalf("got ds %v with len=%d but expected length=0", getDS, len(getDS)) - } -} + t.Run("with empty slice", func(t *testing.T) { + ds := make([]time.Duration, 0) + f := newFlag(&ds) + require.NoError(t, f.Parse([]string{})) -func TestDS(t *testing.T) { - var ds []time.Duration - f := setUpDSFlagSet(&ds) + getDS, err := f.GetDurationSlice("ds") + require.NoErrorf(t, err, + "got an error from GetDurationSlice(): %v", err, + ) + require.Empty(t, getDS) + }) - vals := []string{"1ns", "2ms", "3m", "4h"} - arg := fmt.Sprintf("--ds=%s", strings.Join(vals, ",")) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range ds { - d, err := time.ParseDuration(vals[i]) - if err != nil { - t.Fatalf("got error: %v", err) + t.Run("with values", func(t *testing.T) { + vals := []string{"1ns", "2ms", "3m", "4h"} + ds := make([]time.Duration, 0, len(vals)) + f := newFlag(&ds) + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf("--ds=%s", strings.Join(vals, ",")), + })) + + for i, v := range ds { + d, err := time.ParseDuration(vals[i]) + require.NoError(t, err) + require.Equalf(t, v, d, + "expected ds[%d] to be %s but got: %d", i, vals[i], v, + ) } - if d != v { - t.Fatalf("expected ds[%d] to be %s but got: %d", i, vals[i], v) - } - } - getDS, err := f.GetDurationSlice("ds") - if err != nil { - t.Fatalf("got error: %v", err) - } - for i, v := range getDS { - d, err := time.ParseDuration(vals[i]) - if err != nil { - t.Fatalf("got error: %v", err) - } - if d != v { - t.Fatalf("expected ds[%d] to be %s but got: %d from GetDurationSlice", i, vals[i], v) - } - } -} - -func TestDSDefault(t *testing.T) { - var ds []time.Duration - f := setUpDSFlagSetWithDefault(&ds) - vals := []string{"0s", "1ns"} + getDS, erd := f.GetDurationSlice("ds") + require.NoError(t, erd) - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range ds { - d, err := time.ParseDuration(vals[i]) - if err != nil { - t.Fatalf("got error: %v", err) + for i, v := range getDS { + d, err := time.ParseDuration(vals[i]) + require.NoError(t, err) + require.Equalf(t, v, d, + "expected ds[%d] to be %s but got: %d from GetDurationSlice", i, vals[i], v, + ) } - if d != v { - t.Fatalf("expected ds[%d] to be %d but got: %d", i, d, v) - } - } + }) - getDS, err := f.GetDurationSlice("ds") - if err != nil { - t.Fatal("got an error from GetDurationSlice():", err) - } - for i, v := range getDS { - d, err := time.ParseDuration(vals[i]) - if err != nil { - t.Fatal("got an error from GetDurationSlice():", err) - } - if d != v { - t.Fatalf("expected ds[%d] to be %d from GetDurationSlice but got: %d", i, d, v) - } + newFlagWithDefault := func(dsp *[]time.Duration) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.DurationSliceVar(dsp, "ds", []time.Duration{0, 1}, "Command separated list!") + return f } -} -func TestDSWithDefault(t *testing.T) { - var ds []time.Duration - f := setUpDSFlagSetWithDefault(&ds) + t.Run("with default (1)", func(t *testing.T) { + vals := []string{"0s", "1ns"} + ds := make([]time.Duration, 0, len(vals)) + f := newFlagWithDefault(&ds) - vals := []string{"1ns", "2ns"} - arg := fmt.Sprintf("--ds=%s", strings.Join(vals, ",")) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range ds { - d, err := time.ParseDuration(vals[i]) - if err != nil { - t.Fatalf("got error: %v", err) - } - if d != v { - t.Fatalf("expected ds[%d] to be %d but got: %d", i, d, v) - } - } + require.NoError(t, f.Parse([]string{})) - getDS, err := f.GetDurationSlice("ds") - if err != nil { - t.Fatal("got an error from GetDurationSlice():", err) - } - for i, v := range getDS { - d, err := time.ParseDuration(vals[i]) - if err != nil { - t.Fatalf("got error: %v", err) + for i, v := range ds { + d, err := time.ParseDuration(vals[i]) + require.NoError(t, err) + require.Equalf(t, v, d, + "expected ds[%d] to be %d but got: %d", i, d, v, + ) } - if d != v { - t.Fatalf("expected ds[%d] to be %d from GetDurationSlice but got: %d", i, d, v) + + getDS, erd := f.GetDurationSlice("ds") + require.NoErrorf(t, erd, + "got an error from GetDurationSlice(): %v", erd, + ) + + for i, v := range getDS { + d, err := time.ParseDuration(vals[i]) + require.NoErrorf(t, err, + "got an error from GetDurationSlice(): %v", err, + ) + require.Equalf(t, v, d, + "expected ds[%d] to be %d from GetDurationSlice but got: %d", i, d, v, + ) } - } -} + }) -func TestDSAsSliceValue(t *testing.T) { - var ds []time.Duration - f := setUpDSFlagSet(&ds) - - in := []string{"1ns", "2ns"} - argfmt := "--ds=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - err := f.Parse([]string{arg1, arg2}) - if err != nil { - t.Fatal("expected no error; got", err) - } + t.Run("with default (2)", func(t *testing.T) { + vals := []string{"1ns", "2ns"} + ds := make([]time.Duration, 0, len(vals)) + f := newFlagWithDefault(&ds) + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf("--ds=%s", strings.Join(vals, ",")), + })) + + for i, v := range ds { + d, err := time.ParseDuration(vals[i]) + require.NoError(t, err) + require.Equalf(t, v, d, + "expected ds[%d] to be %d but got: %d", i, d, v, + ) + } - f.VisitAll(func(f *Flag) { - if val, ok := f.Value.(SliceValue); ok { - _ = val.Replace([]string{"3ns"}) + getDS, erd := f.GetDurationSlice("ds") + require.NoErrorf(t, erd, + "got an error from GetDurationSlice(): %v", erd, + ) + + for i, v := range getDS { + d, err := time.ParseDuration(vals[i]) + require.NoError(t, err) + require.Equalf(t, v, d, + "expected ds[%d] to be %d from GetDurationSlice but got: %d", i, d, v, + ) } }) - if len(ds) != 1 || ds[0] != time.Duration(3) { - t.Fatalf("Expected ss to be overwritten with '3ns', but got: %v", ds) - } -} -func TestDSCalledTwice(t *testing.T) { - var ds []time.Duration - f := setUpDSFlagSet(&ds) - - in := []string{"1ns,2ns", "3ns"} - expected := []time.Duration{1, 2, 3} - argfmt := "--ds=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - err := f.Parse([]string{arg1, arg2}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range ds { - if expected[i] != v { - t.Fatalf("expected ds[%d] to be %d but got: %d", i, expected[i], v) - } - } + t.Run("as SliceValue", func(t *testing.T) { + in := []string{"1ns", "2ns"} + ds := make([]time.Duration, 0, len(in)) + f := newFlag(&ds) + + argfmt := "--ds=%s" + arg1 := fmt.Sprintf(argfmt, in[0]) + arg2 := fmt.Sprintf(argfmt, in[1]) + require.NoError(t, f.Parse([]string{arg1, arg2})) + + f.VisitAll(func(f *Flag) { + if val, ok := f.Value.(SliceValue); ok { + require.NoError(t, val.Replace([]string{"3ns"})) + } + }) + + require.Equalf(t, []time.Duration{time.Duration(3)}, ds, + "expected ss to be overwritten with '3ns', but got: %v", ds, + ) + }) + + t.Run("called twice", func(t *testing.T) { + const argfmt = "--ds=%s" + in := []string{"1ns,2ns", "3ns"} + ds := make([]time.Duration, 0, len(in)) + f := newFlag(&ds) + expected := []time.Duration{1, 2, 3} + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf(argfmt, in[0]), + fmt.Sprintf(argfmt, in[1]), + })) + + require.Equal(t, expected, ds) + }) } diff --git a/example_test.go b/example_test.go index abd7806f..574ace0b 100644 --- a/example_test.go +++ b/example_test.go @@ -11,7 +11,7 @@ import ( ) func ExampleShorthandLookup() { - name := "verbose" + const name = "verbose" short := name[:1] pflag.BoolP(name, short, false, "verbose output") @@ -23,7 +23,7 @@ func ExampleShorthandLookup() { } func ExampleFlagSet_ShorthandLookup() { - name := "verbose" + const name = "verbose" short := name[:1] fs := pflag.NewFlagSet("Example", pflag.ContinueOnError) diff --git a/flag.go b/flag.go index 7c058de3..4121521e 100644 --- a/flag.go +++ b/flag.go @@ -27,23 +27,32 @@ unaffected. Define flags using flag.String(), Bool(), Int(), etc. This declares an integer flag, -flagname, stored in the pointer ip, with type *int. + var ip = flag.Int("flagname", 1234, "help message for flagname") + If you like, you can bind the flag to a variable using the Var() functions. + var flagvar int func init() { flag.IntVar(&flagvar, "flagname", 1234, "help message for flagname") } + Or you can create custom flags that satisfy the Value interface (with pointer receivers) and couple them to flag parsing by + flag.Var(&flagVal, "name", "help message for flagname") + For such flags, the default value is just the initial value of the variable. After all flags are defined, call + flag.Parse() + to parse the command line into the defined flags. Flags may then be used directly. If you're using the flags themselves, they are all pointers; if you bind to variables, they're values. + fmt.Println("ip has value ", *ip) fmt.Println("flagvar has value ", flagvar) @@ -54,22 +63,26 @@ The arguments are indexed from 0 through flag.NArg()-1. The pflag package also defines some new functions that are not in flag, that give one-letter shorthands for flags. You can use these by appending 'P' to the name of any function that defines a flag. + var ip = flag.IntP("flagname", "f", 1234, "help message") var flagvar bool func init() { flag.BoolVarP(&flagvar, "boolname", "b", true, "help message") } flag.VarP(&flagval, "varname", "v", "help message") + Shorthand letters can be used with single dashes on the command line. Boolean shorthand flags can be combined with other shorthand flags. Command line flag syntax: + --flag // boolean flags only --flag=x Unlike the flag package, a single dash before an option means something different than a double dash. Single dashes signify a series of shorthand letters for flags. All but the last shorthand letter must be boolean flags. + // boolean flags -f -abc @@ -365,7 +378,7 @@ func (f *FlagSet) ShorthandLookup(name string) *Flag { } if len(name) > 1 { msg := fmt.Sprintf("can not look up shorthand which is more than one ASCII character: %q", name) - fmt.Fprintf(f.Output(), msg) + fmt.Fprintln(f.Output(), msg) panic(msg) } c := name[0] @@ -606,7 +619,7 @@ func UnquoteUsage(flag *Flag) (name string, usage string) { name = "bools" } - return + return name, usage } // Splits the string `s` on whitespace into an initial substring up to @@ -634,7 +647,7 @@ func wrapN(i, slop int, s string) (string, string) { // caller). Pass `w` == 0 to do no wrapping func wrap(i, w int, s string) string { if w == 0 { - return strings.Replace(s, "\n", "\n"+strings.Repeat(" ", i), -1) + return strings.ReplaceAll(s, "\n", "\n"+strings.Repeat(" ", i)) } // space between indent i and end of line width w into which @@ -652,26 +665,26 @@ func wrap(i, w int, s string) string { } // If still not enough space then don't even try to wrap. if wrap < 24 { - return strings.Replace(s, "\n", r, -1) + return strings.ReplaceAll(s, "\n", r) } // Try to avoid short orphan words on the final line, by // allowing wrapN to go a bit over if that would fit in the // remainder of the line. slop := 5 - wrap = wrap - slop + wrap -= slop // Handle first line, which is indented by the caller (or the // special case above) l, s = wrapN(wrap, slop, s) - r = r + strings.Replace(l, "\n", "\n"+strings.Repeat(" ", i), -1) + r += strings.ReplaceAll(l, "\n", "\n"+strings.Repeat(" ", i)) // Now wrap the rest for s != "" { var t string t, s = wrapN(wrap, slop, s) - r = r + "\n" + strings.Repeat(" ", i) + strings.Replace(t, "\n", "\n"+strings.Repeat(" ", i), -1) + r = r + "\n" + strings.Repeat(" ", i) + strings.ReplaceAll(t, "\n", "\n"+strings.Repeat(" ", i)) } return r @@ -867,7 +880,7 @@ func (f *FlagSet) AddFlag(flag *Flag) { } if len(flag.Shorthand) > 1 { msg := fmt.Sprintf("%q shorthand is more than one ASCII character", flag.Shorthand) - fmt.Fprintf(f.Output(), msg) + fmt.Fprintln(f.Output(), msg) panic(msg) } if f.shorthands == nil { @@ -877,7 +890,7 @@ func (f *FlagSet) AddFlag(flag *Flag) { used, alreadyThere := f.shorthands[c] if alreadyThere { msg := fmt.Sprintf("unable to redefine %q shorthand in %q flagset: it's already used for %q flag", c, f.name, used.Name) - fmt.Fprintf(f.Output(), msg) + fmt.Fprintln(f.Output(), msg) panic(msg) } f.shorthands[c] = flag @@ -925,18 +938,19 @@ func (f *FlagSet) failf(format string, a ...interface{}) error { // usage calls the Usage method for the flag set, or the usage function if // the flag set is CommandLine. func (f *FlagSet) usage() { - if f == CommandLine { + switch { + case f == CommandLine: Usage() - } else if f.Usage == nil { + case f.Usage == nil: defaultUsage(f) - } else { + default: f.Usage() } } -//--unknown (args will be empty) -//--unknown --next-flag ... (args will be --next-flag ...) -//--unknown arg ... (args will be arg ...) +// --unknown (args will be empty) +// --unknown --next-flag ... (args will be --next-flag ...) +// --unknown arg ... (args will be arg ...) func stripUnknownFlagValue(args []string) []string { if len(args) == 0 { //--unknown @@ -961,7 +975,8 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin name := s[2:] if len(name) == 0 || name[0] == '-' || name[0] == '=' { err = f.failf("bad flag syntax: %s", s) - return + + return a, err } split := strings.SplitN(name, "=", 2) @@ -983,39 +998,43 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin return stripUnknownFlagValue(a), nil default: err = f.failf("unknown flag: --%s", name) - return + + return a, err } } var value string - if len(split) == 2 { + switch { + case len(split) == 2: // '--flag=arg' value = split[1] - } else if flag.NoOptDefVal != "" { + case flag.NoOptDefVal != "": // '--flag' (arg was optional) value = flag.NoOptDefVal - } else if len(a) > 0 { + case len(a) > 0: // '--flag arg' value = a[0] a = a[1:] - } else { + default: // '--flag' (arg was required) err = f.failf("flag needs an argument: %s", s) - return + + return a, err } err = fn(flag, value) if err != nil { - f.failf(err.Error()) + _ = f.failf(err.Error()) } - return + + return a, err } func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parseFunc) (outShorts string, outArgs []string, err error) { outArgs = args if strings.HasPrefix(shorthands, "test.") { - return + return "", outArgs, nil } outShorts = shorthands[1:] @@ -1026,44 +1045,44 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse switch { case c == 'h': f.usage() - err = ErrHelp - return + + return "", outArgs, ErrHelp case f.ParseErrorsWhitelist.UnknownFlags: // '-f=arg arg ...' // we do not want to lose arg in this case if len(shorthands) > 2 && shorthands[1] == '=' { - outShorts = "" - return + + return "", outArgs, nil } outArgs = stripUnknownFlagValue(outArgs) - return + + return outShorts, outArgs, err default: - err = f.failf("unknown shorthand flag: %q in -%s", c, shorthands) - return + return outShorts, outArgs, f.failf("unknown shorthand flag: %q in -%s", c, shorthands) } } var value string - if len(shorthands) > 2 && shorthands[1] == '=' { + switch { + case len(shorthands) > 2 && shorthands[1] == '=': // '-f=arg' value = shorthands[2:] outShorts = "" - } else if flag.NoOptDefVal != "" { + case flag.NoOptDefVal != "": // '-f' (arg was optional) value = flag.NoOptDefVal - } else if len(shorthands) > 1 { + case len(shorthands) > 1: // '-farg' value = shorthands[1:] outShorts = "" - } else if len(args) > 0 { + case len(args) > 0: // '-f arg' value = args[0] outArgs = args[1:] - } else { + default: // '-f' (arg was required) - err = f.failf("flag needs an argument: %q in -%s", c, shorthands) - return + return outShorts, outArgs, f.failf("flag needs an argument: %q in -%s", c, shorthands) } if flag.ShorthandDeprecated != "" { @@ -1072,9 +1091,10 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse err = fn(flag, value) if err != nil { - f.failf(err.Error()) + _ = f.failf(err.Error()) } - return + + return outShorts, outArgs, err } func (f *FlagSet) parseShortArg(s string, args []string, fn parseFunc) (a []string, err error) { @@ -1085,24 +1105,27 @@ func (f *FlagSet) parseShortArg(s string, args []string, fn parseFunc) (a []stri for len(shorthands) > 0 { shorthands, a, err = f.parseSingleShortArg(shorthands, args, fn) if err != nil { - return + return a, err } } - return + return a, nil } func (f *FlagSet) parseArgs(args []string, fn parseFunc) (err error) { for len(args) > 0 { s := args[0] args = args[1:] + if len(s) == 0 || s[0] != '-' || len(s) == 1 { if !f.interspersed { f.args = append(f.args, s) f.args = append(f.args, args...) + return nil } f.args = append(f.args, s) + continue } @@ -1110,6 +1133,7 @@ func (f *FlagSet) parseArgs(args []string, fn parseFunc) (err error) { if len(s) == 2 { // "--" terminates the flags f.argsLenAtDash = len(f.args) f.args = append(f.args, args...) + break } args, err = f.parseLongArg(s, args, fn) @@ -1117,10 +1141,11 @@ func (f *FlagSet) parseArgs(args []string, fn parseFunc) (err error) { args, err = f.parseShortArg(s, args, fn) } if err != nil { - return + return err } } - return + + return nil } // Parse parses flag definitions from the argument list, which should not @@ -1130,15 +1155,11 @@ func (f *FlagSet) parseArgs(args []string, fn parseFunc) (err error) { func (f *FlagSet) Parse(arguments []string) error { if f.addedGoFlagSets != nil { for _, goFlagSet := range f.addedGoFlagSets { - goFlagSet.Parse(nil) + _ = goFlagSet.Parse(nil) } } f.parsed = true - if len(arguments) < 0 { - return nil - } - f.args = make([]string, 0, len(arguments)) set := func(flag *Flag, value string) error { @@ -1157,6 +1178,7 @@ func (f *FlagSet) Parse(arguments []string) error { panic(err) } } + return nil } @@ -1194,7 +1216,7 @@ func (f *FlagSet) Parsed() bool { // after all flags are defined and before flags are accessed by the program. func Parse() { // Ignore errors; CommandLine is set for ExitOnError. - CommandLine.Parse(os.Args[1:]) + _ = CommandLine.Parse(os.Args[1:]) } // ParseAll parses the command-line flags from os.Args[1:] and called fn for each. @@ -1202,7 +1224,7 @@ func Parse() { // defined and before flags are accessed by the program. func ParseAll(fn func(flag *Flag, value string) error) { // Ignore errors; CommandLine is set for ExitOnError. - CommandLine.ParseAll(os.Args[1:], fn) + _ = CommandLine.ParseAll(os.Args[1:], fn) } // SetInterspersed sets whether to support interspersed option/non-option arguments. diff --git a/flag_test.go b/flag_test.go index 58a5d25a..500f8c30 100644 --- a/flag_test.go +++ b/flag_test.go @@ -11,41 +11,59 @@ import ( "io/ioutil" "net" "os" - "reflect" "sort" "strconv" "strings" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var ( - testBool = Bool("test_bool", false, "bool value") - testInt = Int("test_int", 0, "int value") - testInt64 = Int64("test_int64", 0, "int64 value") - testUint = Uint("test_uint", 0, "uint value") - testUint64 = Uint64("test_uint64", 0, "uint64 value") - testString = String("test_string", "0", "string value") - testFloat = Float64("test_float64", 0, "float64 value") - testDuration = Duration("test_duration", 0, "time.Duration value") - testOptionalInt = Int("test_optional_int", 0, "optional int value") + testBool *bool + testInt *int + testInt64 *int64 + testUint *uint + testUint64 *uint64 + testString *string + testFloat *float64 + testDuration *time.Duration + testOptionalInt *int + normalizeFlagNameInvocations = 0 ) -func boolString(s string) string { - if s == "0" { - return "false" - } - return "true" +func init() { + testBool = Bool("test_bool", false, "bool value") + testInt = Int("test_int", 0, "int value") + testInt64 = Int64("test_int64", 0, "int64 value") + testUint = Uint("test_uint", 0, "uint value") + testUint64 = Uint64("test_uint64", 0, "uint64 value") + testString = String("test_string", "0", "string value") + testFloat = Float64("test_float64", 0, "float64 value") + testDuration = Duration("test_duration", 0, "time.Duration value") + testOptionalInt = Int("test_optional_int", 0, "optional int value") } -func TestEverything(t *testing.T) { - m := make(map[string]*Flag) - desired := "0" - visitor := func(f *Flag) { - if len(f.Name) > 5 && f.Name[0:5] == "test_" { +func TestVisit(t *testing.T) { + boolString := func(s string) string { + if s == "0" { + return "false" + } + return "true" + } + + visitor := func(desired string, m map[string]*Flag) func(*Flag) { + return func(f *Flag) { + if len(f.Name) <= 5 || f.Name[0:5] != "test_" { + return + } + m[f.Name] = f ok := false + switch { case f.Value.String() == desired: ok = true @@ -54,61 +72,77 @@ func TestEverything(t *testing.T) { case f.Name == "test_duration" && f.Value.String() == desired+"s": ok = true } - if !ok { - t.Error("Visit: bad value", f.Value.String(), "for", f.Name) - } + require.Truef(t, ok, + "visit: bad value", f.Value.String(), "for", f.Name, + ) } } - VisitAll(visitor) - if len(m) != 9 { - t.Error("VisitAll misses some flags") + + printMap := func(m map[string]*Flag) { for k, v := range m { t.Log(k, *v) } } - m = make(map[string]*Flag) - Visit(visitor) - if len(m) != 0 { - t.Errorf("Visit sees unset flags") - for k, v := range m { - t.Log(k, *v) + + t.Run("with VisitAll", func(t *testing.T) { + const desired = "0" + m := make(map[string]*Flag) + + VisitAll(visitor(desired, m)) + if !assert.Lenf(t, m, 9, "VisitAll misses some flags") { + printMap(m) } - } - // Now set all flags - Set("test_bool", "true") - Set("test_int", "1") - Set("test_int64", "1") - Set("test_uint", "1") - Set("test_uint64", "1") - Set("test_string", "1") - Set("test_float64", "1") - Set("test_duration", "1s") - Set("test_optional_int", "1") - desired = "1" - Visit(visitor) - if len(m) != 9 { - t.Error("Visit fails after set") - for k, v := range m { - t.Log(k, *v) + }) + + t.Run("with Visit", func(t *testing.T) { + const desired = "0" + m := make(map[string]*Flag) + + Visit(visitor(desired, m)) + if !assert.Lenf(t, m, 0, "Visit sees unset flags") { + printMap(m) } - } - // Now test they're visited in sort order. - var flagNames []string - Visit(func(f *Flag) { flagNames = append(flagNames, f.Name) }) - if !sort.StringsAreSorted(flagNames) { - t.Errorf("flag names not sorted: %v", flagNames) - } + }) + + t.Run("with all flags set", func(t *testing.T) { + const desired = "1" + m := make(map[string]*Flag) + + require.NoError(t, Set("test_bool", "true")) + require.NoError(t, Set("test_int", "1")) + require.NoError(t, Set("test_int64", "1")) + require.NoError(t, Set("test_uint", "1")) + require.NoError(t, Set("test_uint64", "1")) + require.NoError(t, Set("test_string", "1")) + require.NoError(t, Set("test_float64", "1")) + require.NoError(t, Set("test_duration", "1s")) + require.NoError(t, Set("test_optional_int", "1")) + + Visit(visitor(desired, m)) + if !assert.Lenf(t, m, 9, "Visit fails after set") { + printMap(m) + } + }) + + t.Run("visit in sorted order", func(t *testing.T) { + var flagNames []string + Visit(func(f *Flag) { flagNames = append(flagNames, f.Name) }) + require.Truef(t, sort.StringsAreSorted(flagNames), + "flag names not sorted: %v", flagNames, + ) + }) } func TestUsage(t *testing.T) { called := false ResetForTesting(func() { called = true }) - if GetCommandLine().Parse([]string{"--x"}) == nil { - t.Error("parse did not fail for unknown flag") - } - if called { - t.Error("did call Usage while using ContinueOnError") - } + + require.NotNilf(t, GetCommandLine().Parse([]string{"--x"}), + "parse did not fail for unknown flag", + ) + require.Falsef(t, called, + "did call Usage while using ContinueOnError", + ) } func TestAddFlagSet(t *testing.T) { @@ -123,56 +157,59 @@ func TestAddFlagSet(t *testing.T) { oldSet.AddFlagSet(newSet) - if len(oldSet.formal) != 3 { - t.Errorf("Unexpected result adding a FlagSet to a FlagSet %v", oldSet) - } + require.Lenf(t, oldSet.formal, 3, + "unexpected result adding a FlagSet to a FlagSet %v", oldSet, + ) } func TestAnnotation(t *testing.T) { f := NewFlagSet("shorthand", ContinueOnError) - if err := f.SetAnnotation("missing-flag", "key", nil); err == nil { - t.Errorf("Expected error setting annotation on non-existent flag") - } + require.Errorf(t, f.SetAnnotation("missing-flag", "key", nil), + "expected error setting annotation on non-existent flag", + ) f.StringP("stringa", "a", "", "string value") - if err := f.SetAnnotation("stringa", "key", nil); err != nil { - t.Errorf("Unexpected error setting new nil annotation: %v", err) - } - if annotation := f.Lookup("stringa").Annotations["key"]; annotation != nil { - t.Errorf("Unexpected annotation: %v", annotation) - } + require.NoErrorf(t, f.SetAnnotation("stringa", "key", nil), + "unexpected error setting new nil annotation", + ) + require.Nil(t, f.Lookup("stringa").Annotations["key"], + "unexpected annotation", + ) f.StringP("stringb", "b", "", "string2 value") - if err := f.SetAnnotation("stringb", "key", []string{"value1"}); err != nil { - t.Errorf("Unexpected error setting new annotation: %v", err) - } - if annotation := f.Lookup("stringb").Annotations["key"]; !reflect.DeepEqual(annotation, []string{"value1"}) { - t.Errorf("Unexpected annotation: %v", annotation) - } - - if err := f.SetAnnotation("stringb", "key", []string{"value2"}); err != nil { - t.Errorf("Unexpected error updating annotation: %v", err) - } - if annotation := f.Lookup("stringb").Annotations["key"]; !reflect.DeepEqual(annotation, []string{"value2"}) { - t.Errorf("Unexpected annotation: %v", annotation) - } + require.NoErrorf(t, f.SetAnnotation("stringb", "key", []string{"value1"}), + "unexpected error setting new annotation", + ) + + annotation := f.Lookup("stringb").Annotations["key"] + require.EqualValuesf(t, []string{"value1"}, annotation, + "unexpected annotation: %v", annotation, + ) + + require.NoErrorf(t, f.SetAnnotation("stringb", "key", []string{"value2"}), + "unexpected error updating annotation", + ) + annotation = f.Lookup("stringb").Annotations["key"] + require.EqualValuesf(t, []string{"value2"}, annotation, + "unexpected annotation: %v", annotation, + ) } func TestName(t *testing.T) { - flagSetName := "bob" + const flagSetName = "bob" f := NewFlagSet(flagSetName, ContinueOnError) givenName := f.Name() - if givenName != flagSetName { - t.Errorf("Unexpected result when retrieving a FlagSet's name: expected %s, but found %s", flagSetName, givenName) - } + require.Equalf(t, flagSetName, givenName, + "unexpected result when retrieving a FlagSet's name: expected %s, but found %s", + flagSetName, givenName, + ) } func testParse(f *FlagSet, t *testing.T) { - if f.Parsed() { - t.Error("f.Parse() = true before Parse") - } + require.Falsef(t, f.Parsed(), "f.Parse() = true before Parse") + boolFlag := f.Bool("bool", false, "bool value") bool2Flag := f.Bool("bool2", false, "bool2 value") bool3Flag := f.Bool("bool3", false, "bool3 value") @@ -193,168 +230,237 @@ func testParse(f *FlagSet, t *testing.T) { maskFlag := f.IPMask("mask", ParseIPv4Mask("0.0.0.0"), "mask value") durationFlag := f.Duration("duration", 5*time.Second, "time.Duration value") optionalIntNoValueFlag := f.Int("optional-int-no-value", 0, "int value") + f.Lookup("optional-int-no-value").NoOptDefVal = "9" optionalIntWithValueFlag := f.Int("optional-int-with-value", 0, "int value") f.Lookup("optional-int-no-value").NoOptDefVal = "9" - extra := "one-extra-argument" - args := []string{ - "--bool", - "--bool2=true", - "--bool3=false", - "--int=22", - "--int8=-8", - "--int16=-16", - "--int32=-32", - "--int64=0x23", - "--uint", "24", - "--uint8=8", - "--uint16=16", - "--uint32=32", - "--uint64=25", - "--string=hello", - "--float32=-172e12", - "--float64=2718e28", - "--ip=10.11.12.13", - "--mask=255.255.255.0", - "--duration=2m", - "--optional-int-no-value", - "--optional-int-with-value=42", - extra, - } - if err := f.Parse(args); err != nil { - t.Fatal(err) - } - if !f.Parsed() { - t.Error("f.Parse() = false after Parse") - } - if *boolFlag != true { - t.Error("bool flag should be true, is ", *boolFlag) - } - if v, err := f.GetBool("bool"); err != nil || v != *boolFlag { - t.Error("GetBool does not work.") - } - if *bool2Flag != true { - t.Error("bool2 flag should be true, is ", *bool2Flag) - } - if *bool3Flag != false { - t.Error("bool3 flag should be false, is ", *bool2Flag) - } - if *intFlag != 22 { - t.Error("int flag should be 22, is ", *intFlag) - } - if v, err := f.GetInt("int"); err != nil || v != *intFlag { - t.Error("GetInt does not work.") - } - if *int8Flag != -8 { - t.Error("int8 flag should be 0x23, is ", *int8Flag) - } - if *int16Flag != -16 { - t.Error("int16 flag should be -16, is ", *int16Flag) - } - if v, err := f.GetInt8("int8"); err != nil || v != *int8Flag { - t.Error("GetInt8 does not work.") - } - if v, err := f.GetInt16("int16"); err != nil || v != *int16Flag { - t.Error("GetInt16 does not work.") - } - if *int32Flag != -32 { - t.Error("int32 flag should be 0x23, is ", *int32Flag) - } - if v, err := f.GetInt32("int32"); err != nil || v != *int32Flag { - t.Error("GetInt32 does not work.") - } - if *int64Flag != 0x23 { - t.Error("int64 flag should be 0x23, is ", *int64Flag) - } - if v, err := f.GetInt64("int64"); err != nil || v != *int64Flag { - t.Error("GetInt64 does not work.") - } - if *uintFlag != 24 { - t.Error("uint flag should be 24, is ", *uintFlag) - } - if v, err := f.GetUint("uint"); err != nil || v != *uintFlag { - t.Error("GetUint does not work.") - } - if *uint8Flag != 8 { - t.Error("uint8 flag should be 8, is ", *uint8Flag) - } - if v, err := f.GetUint8("uint8"); err != nil || v != *uint8Flag { - t.Error("GetUint8 does not work.") - } - if *uint16Flag != 16 { - t.Error("uint16 flag should be 16, is ", *uint16Flag) - } - if v, err := f.GetUint16("uint16"); err != nil || v != *uint16Flag { - t.Error("GetUint16 does not work.") - } - if *uint32Flag != 32 { - t.Error("uint32 flag should be 32, is ", *uint32Flag) - } - if v, err := f.GetUint32("uint32"); err != nil || v != *uint32Flag { - t.Error("GetUint32 does not work.") - } - if *uint64Flag != 25 { - t.Error("uint64 flag should be 25, is ", *uint64Flag) - } - if v, err := f.GetUint64("uint64"); err != nil || v != *uint64Flag { - t.Error("GetUint64 does not work.") - } - if *stringFlag != "hello" { - t.Error("string flag should be `hello`, is ", *stringFlag) - } - if v, err := f.GetString("string"); err != nil || v != *stringFlag { - t.Error("GetString does not work.") - } - if *float32Flag != -172e12 { - t.Error("float32 flag should be -172e12, is ", *float32Flag) - } - if v, err := f.GetFloat32("float32"); err != nil || v != *float32Flag { - t.Errorf("GetFloat32 returned %v but float32Flag was %v", v, *float32Flag) - } - if *float64Flag != 2718e28 { - t.Error("float64 flag should be 2718e28, is ", *float64Flag) - } - if v, err := f.GetFloat64("float64"); err != nil || v != *float64Flag { - t.Errorf("GetFloat64 returned %v but float64Flag was %v", v, *float64Flag) - } - if !(*ipFlag).Equal(net.ParseIP("10.11.12.13")) { - t.Error("ip flag should be 10.11.12.13, is ", *ipFlag) - } - if v, err := f.GetIP("ip"); err != nil || !v.Equal(*ipFlag) { - t.Errorf("GetIP returned %v but ipFlag was %v", v, *ipFlag) - } - if (*maskFlag).String() != ParseIPv4Mask("255.255.255.0").String() { - t.Error("mask flag should be 255.255.255.0, is ", (*maskFlag).String()) - } - if v, err := f.GetIPv4Mask("mask"); err != nil || v.String() != (*maskFlag).String() { - t.Errorf("GetIP returned %v maskFlag was %v error was %v", v, *maskFlag, err) - } - if *durationFlag != 2*time.Minute { - t.Error("duration flag should be 2m, is ", *durationFlag) - } - if v, err := f.GetDuration("duration"); err != nil || v != *durationFlag { - t.Error("GetDuration does not work.") - } - if _, err := f.GetInt("duration"); err == nil { - t.Error("GetInt parsed a time.Duration?!?!") - } - if *optionalIntNoValueFlag != 9 { - t.Error("optional int flag should be the default value, is ", *optionalIntNoValueFlag) - } - if *optionalIntWithValueFlag != 42 { - t.Error("optional int flag should be 42, is ", *optionalIntWithValueFlag) - } - if len(f.Args()) != 1 { - t.Error("expected one argument, got", len(f.Args())) - } else if f.Args()[0] != extra { - t.Errorf("expected argument %q got %q", extra, f.Args()[0]) - } + + const extra = "one-extra-argument" + + t.Run("parse args", func(t *testing.T) { + args := []string{ + "--bool", + "--bool2=true", + "--bool3=false", + "--int=22", + "--int8=-8", + "--int16=-16", + "--int32=-32", + "--int64=0x23", + "--uint", "24", + "--uint8=8", + "--uint16=16", + "--uint32=32", + "--uint64=25", + "--string=hello", + "--float32=-172e12", + "--float64=2718e28", + "--ip=10.11.12.13", + "--mask=255.255.255.0", + "--duration=2m", + "--optional-int-no-value", + "--optional-int-with-value=42", + extra, + } + require.NoError(t, f.Parse(args)) + require.Truef(t, f.Parsed(), "f.Parse() = false after Parse") + }) + + t.Run("with bool flags", func(t *testing.T) { + require.Truef(t, *boolFlag, + "bool flag should be true, is ", *boolFlag, + ) + + v, err := f.GetBool("bool") + require.NoError(t, err) + require.Equalf(t, *boolFlag, v, "GetBool does not work") + require.Truef(t, *bool2Flag, + "bool2 flag should be true, is ", *bool2Flag, + ) + require.Falsef(t, *bool3Flag, + "bool3 flag should be false, is ", *bool2Flag, + ) + }) + + t.Run("with integer flags", func(t *testing.T) { + t.Run("int", func(t *testing.T) { + require.Equalf(t, 22, *intFlag, + "int flag should be 22, is ", *intFlag, + ) + v, err := f.GetInt("int") + require.NoError(t, err) + require.Equalf(t, *intFlag, v, "GetInt does not work") + }) + + t.Run("int8", func(t *testing.T) { + require.Equalf(t, int8(-8), *int8Flag, + "int8 flag should be 0x23, is ", *int8Flag, + ) + v, err := f.GetInt8("int8") + require.NoError(t, err) + require.Equalf(t, *int8Flag, v, "GetInt8 does not work") + }) + + t.Run("int16", func(t *testing.T) { + require.Equalf(t, int16(-16), *int16Flag, + "int16 flag should be -16, is ", *int16Flag, + ) + v, err := f.GetInt16("int16") + require.NoError(t, err) + require.Equalf(t, *int16Flag, v, "GetInt16 does not work") + }) + + t.Run("int32", func(t *testing.T) { + require.Equalf(t, int32(-32), *int32Flag, + "int32 flag should be 0x23, is ", *int32Flag, + ) + v, err := f.GetInt32("int32") + require.NoError(t, err) + require.Equalf(t, *int32Flag, v, "GetInt32 does not work") + }) + + t.Run("int64", func(t *testing.T) { + require.Equalf(t, int64(0x23), *int64Flag, + "int64 flag should be 0x23, is ", *int64Flag, + ) + v, err := f.GetInt64("int64") + require.NoError(t, err) + require.Equalf(t, *int64Flag, v, "GetInt64 does not work") + }) + + t.Run("uint", func(t *testing.T) { + require.Equalf(t, uint(24), *uintFlag, + "uint flag should be 24, is ", *uintFlag, + ) + v, err := f.GetUint("uint") + require.NoError(t, err) + require.Equalf(t, *uintFlag, v, "GetUint does not work") + }) + + t.Run("uint8", func(t *testing.T) { + require.Equalf(t, uint8(8), *uint8Flag, + "uint8 flag should be 8, is ", *uint8Flag, + ) + v, err := f.GetUint8("uint8") + require.NoError(t, err) + require.Equalf(t, *uint8Flag, v, "GetUint8 does not work") + }) + + t.Run("uint16", func(t *testing.T) { + require.Equalf(t, uint16(16), *uint16Flag, + "uint16 flag should be 16, is ", *uint16Flag, + ) + v, err := f.GetUint16("uint16") + require.NoError(t, err) + require.Equalf(t, *uint16Flag, v, "GetUint16 does not work") + }) + + t.Run("uint32", func(t *testing.T) { + require.Equalf(t, uint32(32), *uint32Flag, + "uint32 flag should be 32, is ", *uint32Flag, + ) + v, err := f.GetUint32("uint32") + require.NoError(t, err) + require.Equalf(t, *uint32Flag, v, "GetUint32 does not work") + }) + + t.Run("uint64", func(t *testing.T) { + require.Equalf(t, uint64(25), *uint64Flag, + "uint64 flag should be 25, is ", *uint64Flag, + ) + v, err := f.GetUint64("uint64") + require.NoError(t, err) + require.Equalf(t, *uint64Flag, v, "GetUint64 does not work") + }) + }) + + t.Run("with string flags", func(t *testing.T) { + require.Equalf(t, "hello", *stringFlag, + "string flag should be `hello`, is ", *stringFlag, + ) + v, err := f.GetString("string") + require.NoError(t, err) + require.Equalf(t, *stringFlag, v, "GetString does not work") + }) + + t.Run("with float flags", func(t *testing.T) { + t.Run("float32", func(t *testing.T) { + require.Equalf(t, float32(-172e12), *float32Flag, + "float32 flag should be -172e12, is ", *float32Flag, + ) + v, err := f.GetFloat32("float32") + require.NoError(t, err) + require.Equalf(t, *float32Flag, v, "GetFloat32 returned %v but float32Flag was %v", v, *float32Flag) + }) + + t.Run("float64", func(t *testing.T) { + require.Equalf(t, 2718e28, *float64Flag, + "float64 flag should be 2718e28, is ", *float64Flag, + ) + v, err := f.GetFloat64("float64") + require.NoError(t, err) + require.Equalf(t, *float64Flag, v, "GetFloat64 returned %v but float64Flag was %v", v, *float64Flag) + }) + }) + + t.Run("with IP address flags", func(t *testing.T) { + t.Run("IP", func(t *testing.T) { + require.True(t, ipFlag.Equal(net.ParseIP("10.11.12.13")), + "ip flag should be 10.11.12.13, is ", *ipFlag, + ) + v, err := f.GetIP("ip") + require.NoError(t, err) + require.True(t, v.Equal(*ipFlag), + "GetIP returned %v but ipFlag was %v", v, *ipFlag, + ) + }) + + t.Run("IPv4Mask", func(t *testing.T) { + require.Equal(t, ParseIPv4Mask("255.255.255.0").String(), maskFlag.String(), + "mask flag should be 255.255.255.0, is ", maskFlag.String(), + ) + v, err := f.GetIPv4Mask("mask") + require.NoError(t, err) + require.Equal(t, maskFlag.String(), v.String(), + "GetIP returned %v maskFlag was %v", v, *maskFlag, + ) + }) + }) + + t.Run("with duration flags", func(t *testing.T) { + require.Equalf(t, 2*time.Minute, *durationFlag, + "duration flag should be 2m, is ", *durationFlag, + ) + v, err := f.GetDuration("duration") + require.NoError(t, err) + require.Equalf(t, *durationFlag, v, "GetDuration does not work") + + _, err = f.GetInt("duration") + require.Errorf(t, err, "unexpectedly, GetInt parsed a time.Duration") + }) + + t.Run("flags with no-value defaults", func(t *testing.T) { + require.Equalf(t, 9, *optionalIntNoValueFlag, + "optional int flag should be the default value, is ", *optionalIntNoValueFlag, + ) + require.Equalf(t, 42, *optionalIntWithValueFlag, + "optional int flag should be 42, is ", *optionalIntWithValueFlag, + ) + }) + + t.Run("with non-flag argument", func(t *testing.T) { + require.Lenf(t, f.Args(), 1, + "expected one argument, got", len(f.Args()), + ) + require.Equalf(t, extra, f.Args()[0], + "expected argument %q got %q", extra, f.Args()[0], + ) + }) } func testParseAll(f *FlagSet, t *testing.T) { - if f.Parsed() { - t.Error("f.Parse() = true before Parse") - } + require.Falsef(t, f.Parsed(), "f.Parse() = true before Parse") + f.BoolP("boola", "a", false, "bool value") f.BoolP("boolb", "b", false, "bool2 value") f.BoolP("boolc", "c", false, "bool3 value") @@ -364,6 +470,7 @@ func testParseAll(f *FlagSet, t *testing.T) { f.StringP("stringx", "x", "0", "string value") f.StringP("stringy", "y", "0", "string value") f.Lookup("stringx").NoOptDefVal = "1" + args := []string{ "-ab", "-cs=xx", @@ -373,6 +480,7 @@ func testParseAll(f *FlagSet, t *testing.T) { "-y", "ee", } + want := []string{ "boola", "true", "boolb", "true", @@ -383,7 +491,8 @@ func testParseAll(f *FlagSet, t *testing.T) { "stringx", "1", "stringy", "ee", } - got := []string{} + got := make([]string, 0, len(want)) + store := func(flag *Flag, value string) error { got = append(got, flag.Name) if len(value) > 0 { @@ -391,23 +500,17 @@ func testParseAll(f *FlagSet, t *testing.T) { } return nil } - if err := f.ParseAll(args, store); err != nil { - t.Errorf("expected no error, got %s", err) - } - if !f.Parsed() { - t.Errorf("f.Parse() = false after Parse") - } - if !reflect.DeepEqual(got, want) { - t.Errorf("f.ParseAll() fail to restore the args") - t.Errorf("Got: %v", got) - t.Errorf("Want: %v", want) - } + + require.NoError(t, f.ParseAll(args, store)) + require.Truef(t, f.Parsed(), "f.Parse() = false after Parse") + require.Equalf(t, want, got, + "f.ParseAll() fail to restore the args. Got: %v, Want: %v", + got, want, + ) } func testParseWithUnknownFlags(f *FlagSet, t *testing.T) { - if f.Parsed() { - t.Error("f.Parse() = true before Parse") - } + require.Falsef(t, f.Parsed(), "f.Parse() = true before Parse") f.ParseErrorsWhitelist.UnknownFlags = true f.BoolP("boola", "a", false, "bool value") @@ -420,7 +523,9 @@ func testParseWithUnknownFlags(f *FlagSet, t *testing.T) { f.StringP("stringx", "x", "0", "string value") f.StringP("stringy", "y", "0", "string value") f.StringP("stringo", "o", "0", "string value") + f.Lookup("stringx").NoOptDefVal = "1" + args := []string{ "-ab", "-cs=xx", @@ -433,7 +538,7 @@ func testParseWithUnknownFlags(f *FlagSet, t *testing.T) { "-u=unknown3Value", "-p", "unknown4Value", - "-q", //another unknown with bool value + "-q", // another unknown with bool value "-y", "ee", "--unknown7=unknown7value", @@ -447,6 +552,7 @@ func testParseWithUnknownFlags(f *FlagSet, t *testing.T) { "--unknown10", "--unknown11", } + want := []string{ "boola", "true", "boolb", "true", @@ -459,7 +565,8 @@ func testParseWithUnknownFlags(f *FlagSet, t *testing.T) { "stringo", "ovalue", "boole", "true", } - got := []string{} + got := make([]string, 0, len(want)) + store := func(flag *Flag, value string) error { got = append(got, flag.Name) if len(value) > 0 { @@ -467,32 +574,29 @@ func testParseWithUnknownFlags(f *FlagSet, t *testing.T) { } return nil } - if err := f.ParseAll(args, store); err != nil { - t.Errorf("expected no error, got %s", err) - } - if !f.Parsed() { - t.Errorf("f.Parse() = false after Parse") - } - if !reflect.DeepEqual(got, want) { - t.Errorf("f.ParseAll() fail to restore the args") - t.Errorf("Got: %v", got) - t.Errorf("Want: %v", want) - } + + require.NoError(t, f.ParseAll(args, store)) + require.Truef(t, f.Parsed(), "f.Parse() = false after Parse") + require.Equalf(t, want, got, + "f.ParseAll() fail to restore the args. Got: %v, Want: %v", + got, want, + ) } func TestShorthand(t *testing.T) { f := NewFlagSet("shorthand", ContinueOnError) - if f.Parsed() { - t.Error("f.Parse() = true before Parse") - } + require.Falsef(t, f.Parsed(), "f.Parse() = true before Parse") + boolaFlag := f.BoolP("boola", "a", false, "bool value") boolbFlag := f.BoolP("boolb", "b", false, "bool2 value") boolcFlag := f.BoolP("boolc", "c", false, "bool3 value") booldFlag := f.BoolP("boold", "d", false, "bool4 value") stringaFlag := f.StringP("stringa", "s", "0", "string value") stringzFlag := f.StringP("stringz", "z", "0", "string value") - extra := "interspersed-argument" - notaflag := "--i-look-like-a-flag" + + const extra = "interspersed-argument" + const notaflag = "--i-look-like-a-flag" + args := []string{ "-ab", extra, @@ -503,77 +607,51 @@ func TestShorthand(t *testing.T) { "--", notaflag, } + f.SetOutput(ioutil.Discard) - if err := f.Parse(args); err != nil { - t.Error("expected no error, got ", err) - } - if !f.Parsed() { - t.Error("f.Parse() = false after Parse") - } - if *boolaFlag != true { - t.Error("boola flag should be true, is ", *boolaFlag) - } - if *boolbFlag != true { - t.Error("boolb flag should be true, is ", *boolbFlag) - } - if *boolcFlag != true { - t.Error("boolc flag should be true, is ", *boolcFlag) - } - if *booldFlag != true { - t.Error("boold flag should be true, is ", *booldFlag) - } - if *stringaFlag != "hello" { - t.Error("stringa flag should be `hello`, is ", *stringaFlag) - } - if *stringzFlag != "something" { - t.Error("stringz flag should be `something`, is ", *stringzFlag) - } - if len(f.Args()) != 2 { - t.Error("expected one argument, got", len(f.Args())) - } else if f.Args()[0] != extra { - t.Errorf("expected argument %q got %q", extra, f.Args()[0]) - } else if f.Args()[1] != notaflag { - t.Errorf("expected argument %q got %q", notaflag, f.Args()[1]) - } - if f.ArgsLenAtDash() != 1 { - t.Errorf("expected argsLenAtDash %d got %d", f.ArgsLenAtDash(), 1) - } + require.NoError(t, f.Parse(args)) + require.Truef(t, f.Parsed(), "f.Parse() = false after Parse") + + require.Truef(t, *boolaFlag, "boola flag should be true, is ", *boolaFlag) + require.Truef(t, *boolbFlag, "boolb flag should be true, is ", *boolbFlag) + require.Truef(t, *boolcFlag, "boolc flag should be true, is ", *boolcFlag) + require.Truef(t, *booldFlag, "boold flag should be true, is ", *booldFlag) + require.Equalf(t, "hello", *stringaFlag, "stringa flag should be `hello`, is ", *stringaFlag) + require.Equalf(t, "something", *stringzFlag, "stringz flag should be `something`, is ", *stringzFlag) + + require.Len(t, f.Args(), 2, "expected one argument, got", len(f.Args())) + require.Equalf(t, extra, f.Args()[0], "expected argument %q got %q", extra, f.Args()[0]) + require.Equalf(t, notaflag, f.Args()[1], "expected argument %q got %q", notaflag, f.Args()[1]) + require.Equal(t, 1, f.ArgsLenAtDash(), "expected argsLenAtDash %d got %d", f.ArgsLenAtDash(), 1) } func TestShorthandLookup(t *testing.T) { f := NewFlagSet("shorthand", ContinueOnError) - if f.Parsed() { - t.Error("f.Parse() = true before Parse") - } + require.Falsef(t, f.Parsed(), "f.Parse() = true before Parse") + f.BoolP("boola", "a", false, "bool value") f.BoolP("boolb", "b", false, "bool2 value") + args := []string{ "-ab", } + f.SetOutput(ioutil.Discard) - if err := f.Parse(args); err != nil { - t.Error("expected no error, got ", err) - } - if !f.Parsed() { - t.Error("f.Parse() = false after Parse") - } + require.NoError(t, f.Parse(args)) + require.Truef(t, f.Parsed(), "f.Parse() = false after Parse") + flag := f.ShorthandLookup("a") - if flag == nil { - t.Errorf("f.ShorthandLookup(\"a\") returned nil") - } - if flag.Name != "boola" { - t.Errorf("f.ShorthandLookup(\"a\") found %q instead of \"boola\"", flag.Name) - } - flag = f.ShorthandLookup("") - if flag != nil { - t.Errorf("f.ShorthandLookup(\"\") did not return nil") - } - defer func() { - recover() - }() - flag = f.ShorthandLookup("ab") - // should NEVER get here. lookup should panic. defer'd func should recover it. - t.Errorf("f.ShorthandLookup(\"ab\") did not panic") + require.NotNil(t, flag, "f.ShorthandLookup(\"a\") returned nil") + + require.Equalf(t, "boola", flag.Name, + "f.ShorthandLookup(\"a\") found %q instead of \"boola\"", flag.Name, + ) + require.Nil(t, f.ShorthandLookup(""), + "f.ShorthandLookup(\"\") did not return nil", + ) + require.Panicsf(t, func() { _ = f.ShorthandLookup("ab") }, + "f.ShorthandLookup(\"ab\") did not panic", + ) } func TestParse(t *testing.T) { @@ -581,6 +659,10 @@ func TestParse(t *testing.T) { testParse(GetCommandLine(), t) } +func TestFlagSetParse(t *testing.T) { + testParse(NewFlagSet("test", ContinueOnError), t) +} + func TestParseAll(t *testing.T) { ResetForTesting(func() { t.Error("bad parse") }) testParseAll(GetCommandLine(), t) @@ -591,51 +673,38 @@ func TestIgnoreUnknownFlags(t *testing.T) { testParseWithUnknownFlags(GetCommandLine(), t) } -func TestFlagSetParse(t *testing.T) { - testParse(NewFlagSet("test", ContinueOnError), t) -} - func TestChangedHelper(t *testing.T) { f := NewFlagSet("changedtest", ContinueOnError) + f.Bool("changed", false, "changed bool") f.Bool("settrue", true, "true to true") f.Bool("setfalse", false, "false to false") f.Bool("unchanged", false, "unchanged bool") args := []string{"--changed", "--settrue", "--setfalse=false"} - if err := f.Parse(args); err != nil { - t.Error("f.Parse() = false after Parse") - } - if !f.Changed("changed") { - t.Errorf("--changed wasn't changed!") - } - if !f.Changed("settrue") { - t.Errorf("--settrue wasn't changed!") - } - if !f.Changed("setfalse") { - t.Errorf("--setfalse wasn't changed!") - } - if f.Changed("unchanged") { - t.Errorf("--unchanged was changed!") - } - if f.Changed("invalid") { - t.Errorf("--invalid was changed!") - } - if f.ArgsLenAtDash() != -1 { - t.Errorf("Expected argsLenAtDash: %d but got %d", -1, f.ArgsLenAtDash()) - } + + require.NoError(t, f.Parse(args)) + require.Truef(t, f.Changed("changed"), "--changed wasn't changed!") + require.Truef(t, f.Changed("settrue"), "--settrue wasn't changed!") + require.Truef(t, f.Changed("setfalse"), "--setfalse wasn't changed!") + require.Falsef(t, f.Changed("unchanged"), "--unchanged was changed!") + require.Falsef(t, f.Changed("invalid"), "--invalid was changed!") + + require.Equalf(t, -1, f.ArgsLenAtDash(), + "expected argsLenAtDash: %d but got %d", -1, f.ArgsLenAtDash(), + ) } -func replaceSeparators(name string, from []string, to string) string { +func replaceSeparators(name string, from []string, to string) string { //nolint: unparam result := name for _, sep := range from { - result = strings.Replace(result, sep, to, -1) + result = strings.ReplaceAll(result, sep, to) } // Type convert to indicate normalization has been done. return result } -func wordSepNormalizeFunc(f *FlagSet, name string) NormalizedName { +func wordSepNormalizeFunc(_ *FlagSet, name string) NormalizedName { seps := []string{"-", "_"} name = replaceSeparators(name, seps, ".") normalizeFlagNameInvocations++ @@ -645,63 +714,59 @@ func wordSepNormalizeFunc(f *FlagSet, name string) NormalizedName { func testWordSepNormalizedNames(args []string, t *testing.T) { f := NewFlagSet("normalized", ContinueOnError) - if f.Parsed() { - t.Error("f.Parse() = true before Parse") - } + require.Falsef(t, f.Parsed(), "f.Parse() = true before Parse") + withDashFlag := f.Bool("with-dash-flag", false, "bool value") // Set this after some flags have been added and before others. f.SetNormalizeFunc(wordSepNormalizeFunc) withUnderFlag := f.Bool("with_under_flag", false, "bool value") withBothFlag := f.Bool("with-both_flag", false, "bool value") - if err := f.Parse(args); err != nil { - t.Fatal(err) - } - if !f.Parsed() { - t.Error("f.Parse() = false after Parse") - } - if *withDashFlag != true { - t.Error("withDashFlag flag should be true, is ", *withDashFlag) - } - if *withUnderFlag != true { - t.Error("withUnderFlag flag should be true, is ", *withUnderFlag) - } - if *withBothFlag != true { - t.Error("withBothFlag flag should be true, is ", *withBothFlag) - } + + require.NoError(t, f.Parse(args)) + require.Truef(t, f.Parsed(), "f.Parse() = false after Parse") + + require.Truef(t, *withDashFlag, "withDashFlag flag should be true, is ", *withDashFlag) + require.Truef(t, *withUnderFlag, "withUnderFlag flag should be true, is ", *withUnderFlag) + require.Truef(t, *withBothFlag, "withBothFlag flag should be true, is ", *withBothFlag) } func TestWordSepNormalizedNames(t *testing.T) { - args := []string{ - "--with-dash-flag", - "--with-under-flag", - "--with-both-flag", - } - testWordSepNormalizedNames(args, t) + t.Run("with dashes", func(t *testing.T) { + args := []string{ + "--with-dash-flag", + "--with-under-flag", + "--with-both-flag", + } + testWordSepNormalizedNames(args, t) + }) - args = []string{ - "--with_dash_flag", - "--with_under_flag", - "--with_both_flag", - } - testWordSepNormalizedNames(args, t) + t.Run("with underscores", func(t *testing.T) { + args := []string{ + "--with_dash_flag", + "--with_under_flag", + "--with_both_flag", + } + testWordSepNormalizedNames(args, t) + }) - args = []string{ - "--with-dash_flag", - "--with-under_flag", - "--with-both_flag", - } - testWordSepNormalizedNames(args, t) + t.Run("with dash and underscores", func(t *testing.T) { + args := []string{ + "--with-dash_flag", + "--with-under_flag", + "--with-both_flag", + } + testWordSepNormalizedNames(args, t) + }) } -func aliasAndWordSepFlagNames(f *FlagSet, name string) NormalizedName { +func aliasAndWordSepFlagNames(_ *FlagSet, name string) NormalizedName { seps := []string{"-", "_"} oldName := replaceSeparators("old-valid_flag", seps, ".") newName := replaceSeparators("valid-flag", seps, ".") name = replaceSeparators(name, seps, ".") - switch name { - case oldName: + if name == oldName { name = newName } @@ -710,62 +775,57 @@ func aliasAndWordSepFlagNames(f *FlagSet, name string) NormalizedName { func TestCustomNormalizedNames(t *testing.T) { f := NewFlagSet("normalized", ContinueOnError) - if f.Parsed() { - t.Error("f.Parse() = true before Parse") - } + require.Falsef(t, f.Parsed(), "f.Parse() = true before Parse") validFlag := f.Bool("valid-flag", false, "bool value") f.SetNormalizeFunc(aliasAndWordSepFlagNames) someOtherFlag := f.Bool("some-other-flag", false, "bool value") args := []string{"--old_valid_flag", "--some-other_flag"} - if err := f.Parse(args); err != nil { - t.Fatal(err) - } + require.NoError(t, f.Parse(args)) - if *validFlag != true { - t.Errorf("validFlag is %v even though we set the alias --old_valid_falg", *validFlag) - } - if *someOtherFlag != true { - t.Error("someOtherFlag should be true, is ", *someOtherFlag) - } + require.Truef(t, *validFlag, "validFlag is %v even though we set the alias --old_valid_flag", *validFlag) + require.Truef(t, *someOtherFlag, "someOtherFlag should be true, is ", *someOtherFlag) } -// Every flag we add, the name (displayed also in usage) should normalized +// Every flag we add, the name (displayed also in usage) should be normalized func TestNormalizationFuncShouldChangeFlagName(t *testing.T) { - // Test normalization after addition - f := NewFlagSet("normalized", ContinueOnError) - - f.Bool("valid_flag", false, "bool value") - if f.Lookup("valid_flag").Name != "valid_flag" { - t.Error("The new flag should have the name 'valid_flag' instead of ", f.Lookup("valid_flag").Name) - } + t.Run("with normalization after addition", func(t *testing.T) { + f := NewFlagSet("normalized", ContinueOnError) + + f.Bool("valid_flag", false, "bool value") + require.Equalf(t, "valid_flag", f.Lookup("valid_flag").Name, + "the new flag should have the name 'valid_flag' instead of ", f.Lookup("valid_flag").Name, + ) + + f.SetNormalizeFunc(wordSepNormalizeFunc) + require.Equalf(t, "valid.flag", f.Lookup("valid_flag").Name, + "the new flag should have the name 'valid.flag' instead of ", f.Lookup("valid_flag").Name, + ) + }) - f.SetNormalizeFunc(wordSepNormalizeFunc) - if f.Lookup("valid_flag").Name != "valid.flag" { - t.Error("The new flag should have the name 'valid.flag' instead of ", f.Lookup("valid_flag").Name) - } + t.Run("with normalization before addition", func(t *testing.T) { + f := NewFlagSet("normalized", ContinueOnError) + f.SetNormalizeFunc(wordSepNormalizeFunc) - // Test normalization before addition - f = NewFlagSet("normalized", ContinueOnError) - f.SetNormalizeFunc(wordSepNormalizeFunc) - - f.Bool("valid_flag", false, "bool value") - if f.Lookup("valid_flag").Name != "valid.flag" { - t.Error("The new flag should have the name 'valid.flag' instead of ", f.Lookup("valid_flag").Name) - } + f.Bool("valid_flag", false, "bool value") + require.Equalf(t, "valid.flag", f.Lookup("valid_flag").Name, + "the new flag should have the name 'valid.flag' instead of ", f.Lookup("valid_flag").Name, + ) + }) } // Related to https://github.com/spf13/cobra/issues/521. func TestNormalizationSharedFlags(t *testing.T) { f := NewFlagSet("set f", ContinueOnError) g := NewFlagSet("set g", ContinueOnError) + + const testName = "valid_flag" nfunc := wordSepNormalizeFunc - testName := "valid_flag" normName := nfunc(nil, testName) - if testName == string(normName) { - t.Error("TestNormalizationSharedFlags meaningless: the original and normalized flag names are identical:", testName) - } + require.NotEqualf(t, string(normName), testName, + "TestNormalizationSharedFlags meaningless: the original and normalized flag names are identical:", testName, + ) f.Bool(testName, false, "bool value") g.AddFlagSet(f) @@ -773,51 +833,60 @@ func TestNormalizationSharedFlags(t *testing.T) { f.SetNormalizeFunc(nfunc) g.SetNormalizeFunc(nfunc) - if len(f.formal) != 1 { - t.Error("Normalizing flags should not result in duplications in the flag set:", f.formal) - } - if f.orderedFormal[0].Name != string(normName) { - t.Error("Flag name not normalized") - } - for k := range f.formal { - if k != "valid.flag" { - t.Errorf("The key in the flag map should have been normalized: wanted \"%s\", got \"%s\" instead", normName, k) - } - } + require.Lenf(t, f.formal, 1, + "normalizing flags should not result in duplications in the flag set:", f.formal, + ) + require.Equalf(t, string(normName), f.orderedFormal[0].Name, + "flag name not normalized", + ) - if !reflect.DeepEqual(f.formal, g.formal) || !reflect.DeepEqual(f.orderedFormal, g.orderedFormal) { - t.Error("Two flag sets sharing the same flags should stay consistent after being normalized. Original set:", f.formal, "Duplicate set:", g.formal) - } + for k := range f.formal { + require.Equalf(t, "valid.flag", string(k), + "the key in the flag map should have been normalized: wanted \"%s\", got \"%s\" instead", normName, k, + ) + } + + require.Equalf(t, g.formal, f.formal, + "two flag sets sharing the same flags should stay consistent after being normalized. Original set:", + f.formal, "Duplicate set:", g.formal, + ) + require.Equalf(t, g.orderedFormal, f.orderedFormal, + "two ordered flag sets sharing the same flags should stay consistent after being normalized. Original set:", + f.formal, "Duplicate set:", g.formal, + ) } func TestNormalizationSetFlags(t *testing.T) { f := NewFlagSet("normalized", ContinueOnError) nfunc := wordSepNormalizeFunc - testName := "valid_flag" + const testName = "valid_flag" normName := nfunc(nil, testName) - if testName == string(normName) { - t.Error("TestNormalizationSetFlags meaningless: the original and normalized flag names are identical:", testName) - } + + require.NotEqualf(t, string(normName), testName, + "TestNormalizationSetFlags meaningless: the original and normalized flag names are identical:", testName, + ) f.Bool(testName, false, "bool value") - f.Set(testName, "true") + require.NoError(t, f.Set(testName, "true")) f.SetNormalizeFunc(nfunc) - if len(f.formal) != 1 { - t.Error("Normalizing flags should not result in duplications in the flag set:", f.formal) - } - if f.orderedFormal[0].Name != string(normName) { - t.Error("Flag name not normalized") - } + require.Lenf(t, f.formal, 1, + "normalizing flags should not result in duplications in the flag set:", f.formal, + ) + require.Equalf(t, string(normName), f.orderedFormal[0].Name, + "flag name not normalized", + ) + for k := range f.formal { - if k != "valid.flag" { - t.Errorf("The key in the flag map should have been normalized: wanted \"%s\", got \"%s\" instead", normName, k) - } + require.Equalf(t, "valid.flag", string(k), + "the key in the flag map should have been normalized: wanted \"%s\", got \"%s\" instead", normName, k, + ) } - if !reflect.DeepEqual(f.formal, f.actual) { - t.Error("The map of set flags should get normalized. Formal:", f.formal, "Actual:", f.actual) - } + require.Equalf(t, f.actual, f.formal, + "the map of set flags should get normalized. Formal:", + f.formal, "Actual:", f.actual, + ) } // Declare a user-defined flag type. @@ -837,107 +906,143 @@ func (f *flagVar) Type() string { } func TestUserDefined(t *testing.T) { - var flags FlagSet + var ( + flags FlagSet + v flagVar + ) + flags.Init("test", ContinueOnError) - var v flagVar flags.VarP(&v, "v", "v", "usage") - if err := flags.Parse([]string{"--v=1", "-v2", "-v", "3"}); err != nil { - t.Error(err) - } - if len(v) != 3 { - t.Fatal("expected 3 args; got ", len(v)) - } - expect := "[1 2 3]" - if v.String() != expect { - t.Errorf("expected value %q got %q", expect, v.String()) - } + + require.NoError(t, flags.Parse([]string{"--v=1", "-v2", "-v", "3"})) + require.Lenf(t, v, 3, "expected 3 args; got ", len(v)) + + const expect = "[1 2 3]" + require.Equalf(t, expect, v.String(), + "expected value %q got %q", expect, v.String(), + ) } func TestSetOutput(t *testing.T) { - var flags FlagSet - var buf bytes.Buffer - flags.SetOutput(&buf) - flags.Init("test", ContinueOnError) - flags.Parse([]string{"--unknown"}) - if out := buf.String(); !strings.Contains(out, "--unknown") { - t.Logf("expected output mentioning unknown; got %q", out) - } + t.Run("with ContinueOnError", func(t *testing.T) { + var ( + flags FlagSet + buf bytes.Buffer + ) + + flags.SetOutput(&buf) + flags.Init("test", ContinueOnError) + err := flags.Parse([]string{"--unknown"}) + require.Error(t, err) + + out := buf.String() + require.Emptyf(t, out, "expected no output, only error") + require.Containsf(t, err.Error(), "--unknown", + "expected output mentioning unknown; got %q", err, + ) + }) + + t.Run("with PanicOnError", func(t *testing.T) { + // notice the behavior inconsistent with the above test. It is what it is... + var ( + flags FlagSet + buf bytes.Buffer + ) + + flags.SetOutput(&buf) + flags.Init("test", PanicOnError) + require.PanicsWithError(t, "unknown flag: --unknown", func() { + _ = flags.Parse([]string{"--unknown"}) + }) + + out := buf.String() + require.Containsf(t, out, "--unknown", + "expected output mentioning unknown; got %q", out, + ) + }) } func TestOutput(t *testing.T) { - var flags FlagSet - var buf bytes.Buffer - expect := "an example string" + var ( + flags FlagSet + buf bytes.Buffer + ) + + const expect = "an example string" flags.SetOutput(&buf) fmt.Fprint(flags.Output(), expect) - if out := buf.String(); !strings.Contains(out, expect) { - t.Errorf("expected output %q; got %q", expect, out) - } + out := buf.String() + require.Containsf(t, out, expect, + "expected output %q; got %q", expect, out, + ) } // This tests that one can reset the flags. This still works but not well, and is // superseded by FlagSet. +// +// NOTE: this does not work well with parallel testing. func TestChangingArgs(t *testing.T) { ResetForTesting(func() { t.Fatal("bad parse") }) oldArgs := os.Args defer func() { os.Args = oldArgs }() + os.Args = []string{"cmd", "--before", "subcmd"} before := Bool("before", false, "") - if err := GetCommandLine().Parse(os.Args[1:]); err != nil { - t.Fatal(err) - } + require.NoError(t, GetCommandLine().Parse(os.Args[1:])) + cmd := Arg(0) os.Args = []string{"subcmd", "--after", "args"} after := Bool("after", false, "") Parse() args := Args() - if !*before || cmd != "subcmd" || !*after || len(args) != 1 || args[0] != "args" { - t.Fatalf("expected true subcmd true [args] got %v %v %v %v", *before, cmd, *after, args) - } + require.True(t, *before) + require.Equal(t, "subcmd", cmd) + require.True(t, *after) + require.Len(t, args, 1) + require.Equal(t, "args", args[0]) } // Test that -help invokes the usage message and returns ErrHelp. func TestHelp(t *testing.T) { - var helpCalled = false - fs := NewFlagSet("help test", ContinueOnError) - fs.Usage = func() { helpCalled = true } var flag bool - fs.BoolVar(&flag, "flag", false, "regular flag") - // Regular flag invocation should work - err := fs.Parse([]string{"--flag=true"}) - if err != nil { - t.Fatal("expected no error; got ", err) - } - if !flag { - t.Error("flag was not set by --flag") - } - if helpCalled { - t.Error("help called for regular flag") - helpCalled = false // reset for next test - } - // Help flag should work as expected. - err = fs.Parse([]string{"--help"}) - if err == nil { - t.Fatal("error expected") - } - if err != ErrHelp { - t.Fatal("expected ErrHelp; got ", err) - } - if !helpCalled { - t.Fatal("help was not called") - } - // If we define a help flag, that should override. - var help bool - fs.BoolVar(&help, "help", false, "help flag") - helpCalled = false - err = fs.Parse([]string{"--help"}) - if err != nil { - t.Fatal("expected no error for defined --help; got ", err) - } - if helpCalled { - t.Fatal("help was called; should not have been for defined help flag") + mockHelp := func(called *bool) func() { + return func() { + *called = true + } } + + t.Run("not called, regular flag invocation should work", func(t *testing.T) { + var helpCalled bool + fs := NewFlagSet("help test", ContinueOnError) + fs.Usage = mockHelp(&helpCalled) + + fs.BoolVar(&flag, "flag", false, "regular flag") + require.NoError(t, fs.Parse([]string{"--flag=true"})) + require.Truef(t, flag, "flag was not set by --flag") + require.Falsef(t, helpCalled, "help called for regular flag") + }) + + t.Run("called, help flag should work", func(t *testing.T) { + var helpCalled bool + fs := NewFlagSet("help test", ContinueOnError) + fs.Usage = mockHelp(&helpCalled) + err := fs.Parse([]string{"--help"}) + require.Error(t, err) + require.ErrorIsf(t, err, ErrHelp, "expected ErrHelp; got %v", err) + require.Truef(t, helpCalled, "help was not called") + }) + + t.Run("with help flag override", func(t *testing.T) { + var help, helpCalled bool + fs := NewFlagSet("help test", ContinueOnError) + fs.Usage = mockHelp(&helpCalled) + fs.BoolVar(&help, "help", false, "help flag") + require.NoErrorf(t, fs.Parse([]string{"--help"}), "expected no error for defined --help") + require.Falsef(t, helpCalled, + "help was called unexpectedly for a user-defined help flag", + ) + }) } func TestNoInterspersed(t *testing.T) { @@ -945,180 +1050,143 @@ func TestNoInterspersed(t *testing.T) { f.SetInterspersed(false) f.Bool("true", true, "always true") f.Bool("false", false, "always false") - err := f.Parse([]string{"--true", "break", "--false"}) - if err != nil { - t.Fatal("expected no error; got ", err) - } + require.NoError(t, f.Parse([]string{"--true", "break", "--false"})) + args := f.Args() - if len(args) != 2 || args[0] != "break" || args[1] != "--false" { - t.Fatal("expected interspersed options/non-options to fail") - } + require.Len(t, args, 2) + require.Equal(t, "break", args[0]) + require.Equal(t, "--false", args[1]) } func TestTermination(t *testing.T) { f := NewFlagSet("termination", ContinueOnError) boolFlag := f.BoolP("bool", "l", false, "bool value") - if f.Parsed() { - t.Error("f.Parse() = true before Parse") - } - arg1 := "ls" - arg2 := "-l" + require.Falsef(t, f.Parsed(), "f.Parse() = true before Parse") + + const ( + arg1 = "ls" + arg2 = "-l" + ) args := []string{ "--", arg1, arg2, } f.SetOutput(ioutil.Discard) - if err := f.Parse(args); err != nil { - t.Fatal("expected no error; got ", err) - } - if !f.Parsed() { - t.Error("f.Parse() = false after Parse") - } - if *boolFlag { - t.Error("expected boolFlag=false, got true") - } - if len(f.Args()) != 2 { - t.Errorf("expected 2 arguments, got %d: %v", len(f.Args()), f.Args()) - } - if f.Args()[0] != arg1 { - t.Errorf("expected argument %q got %q", arg1, f.Args()[0]) - } - if f.Args()[1] != arg2 { - t.Errorf("expected argument %q got %q", arg2, f.Args()[1]) - } - if f.ArgsLenAtDash() != 0 { - t.Errorf("expected argsLenAtDash %d got %d", 0, f.ArgsLenAtDash()) - } + require.NoError(t, f.Parse(args)) + require.Truef(t, f.Parsed(), "f.Parse() = false after Parse") + require.Falsef(t, *boolFlag, "expected boolFlag=false, got true") + require.Lenf(t, f.Args(), 2, + "expected 2 arguments, got %d: %v", len(f.Args()), f.Args(), + ) + require.Equalf(t, arg1, f.Args()[0], + "expected argument %q got %q", arg1, f.Args()[0], + ) + require.Equalf(t, arg2, f.Args()[1], + "expected argument %q got %q", arg2, f.Args()[0], + ) + require.Equalf(t, 0, f.ArgsLenAtDash(), + "expected argsLenAtDash %d got %d", 0, f.ArgsLenAtDash(), + ) } -func getDeprecatedFlagSet() *FlagSet { - f := NewFlagSet("bob", ContinueOnError) - f.Bool("badflag", true, "always true") - f.MarkDeprecated("badflag", "use --good-flag instead") - return f -} -func TestDeprecatedFlagInDocs(t *testing.T) { - f := getDeprecatedFlagSet() +func TestDeprecated(t *testing.T) { + const ( + badFlag = "badFlag" + usageMsg = "use --good-flag instead" + shortHandName = "noshorthandflag" + shortHandMsg = "use --noshorthandflag instead" + ) - out := new(bytes.Buffer) - f.SetOutput(out) - f.PrintDefaults() - - if strings.Contains(out.String(), "badflag") { - t.Errorf("found deprecated flag in usage!") - } -} + newFlag := func() *FlagSet { + f := NewFlagSet("bob", ContinueOnError) + f.Bool(badFlag, true, "always true") + _ = f.MarkDeprecated(badFlag, usageMsg) -func TestUnHiddenDeprecatedFlagInDocs(t *testing.T) { - f := getDeprecatedFlagSet() - flg := f.Lookup("badflag") - if flg == nil { - t.Fatalf("Unable to lookup 'bob' in TestUnHiddenDeprecatedFlagInDocs") + return f } - flg.Hidden = false - out := new(bytes.Buffer) - f.SetOutput(out) - f.PrintDefaults() - - defaults := out.String() - if !strings.Contains(defaults, "badflag") { - t.Errorf("Did not find deprecated flag in usage!") - } - if !strings.Contains(defaults, "use --good-flag instead") { - t.Errorf("Did not find 'use --good-flag instead' in defaults") - } -} - -func TestDeprecatedFlagShorthandInDocs(t *testing.T) { - f := NewFlagSet("bob", ContinueOnError) - name := "noshorthandflag" - f.BoolP(name, "n", true, "always true") - f.MarkShorthandDeprecated("noshorthandflag", fmt.Sprintf("use --%s instead", name)) - - out := new(bytes.Buffer) - f.SetOutput(out) - f.PrintDefaults() - - if strings.Contains(out.String(), "-n,") { - t.Errorf("found deprecated flag shorthand in usage!") - } -} - -func parseReturnStderr(t *testing.T, f *FlagSet, args []string) (string, error) { - oldStderr := os.Stderr - r, w, _ := os.Pipe() - os.Stderr = w - - err := f.Parse(args) - - outC := make(chan string) - // copy the output in a separate goroutine so printing can't block indefinitely - go func() { - var buf bytes.Buffer - io.Copy(&buf, r) - outC <- buf.String() - }() + t.Run("with flag in doc", func(t *testing.T) { + f := newFlag() + require.NotContainsf(t, printFlagDefaults(f), badFlag, + "found deprecated flag in usage!", + ) + }) - w.Close() - os.Stderr = oldStderr - out := <-outC + t.Run("with unhidden flag in doc", func(t *testing.T) { + f := newFlag() + flg := f.Lookup(badFlag) + require.NotNilf(t, flg, + "unable to lookup %q in flag doc", badFlag, + ) + flg.Hidden = false + defaults := printFlagDefaults(f) + + require.Containsf(t, defaults, badFlag, + "did not find deprecated flag in usage!", + ) + require.Containsf(t, defaults, usageMsg, + "did not find %q in defaults", usageMsg, + ) + }) - return out, err -} + t.Run("with shorthand in doc", func(t *testing.T) { + f := newFlag() + f.BoolP(shortHandName, "n", true, "always true") + require.NoError(t, + f.MarkShorthandDeprecated("noshorthandflag", shortHandMsg), + ) -func TestDeprecatedFlagUsage(t *testing.T) { - f := NewFlagSet("bob", ContinueOnError) - f.Bool("badflag", true, "always true") - usageMsg := "use --good-flag instead" - f.MarkDeprecated("badflag", usageMsg) + require.NotContainsf(t, printFlagDefaults(f), "-n,", + "found deprecated flag shorthand in usage!", + ) + }) - args := []string{"--badflag"} - out, err := parseReturnStderr(t, f, args) - if err != nil { - t.Fatal("expected no error; got ", err) - } + t.Run("with usage", func(t *testing.T) { + f := newFlag() + f.Bool("badflag", true, "always true") + usageMsg := "use --good-flag instead" + require.NoError(t, + f.MarkDeprecated("badflag", usageMsg), + ) + + args := []string{"--badflag"} + out, err := parseReturnStderr(t, f, args) + require.NoError(t, err) + + require.Containsf(t, out, usageMsg, + "%q not printed when using a deprecated flag!", usageMsg, + ) + }) - if !strings.Contains(out, usageMsg) { - t.Errorf("usageMsg not printed when using a deprecated flag!") - } -} + t.Run("with shorthand usage", func(t *testing.T) { + f := newFlag() + f.BoolP(shortHandName, "n", true, "always true") + _ = f.MarkShorthandDeprecated(shortHandName, shortHandMsg) -func TestDeprecatedFlagShorthandUsage(t *testing.T) { - f := NewFlagSet("bob", ContinueOnError) - name := "noshorthandflag" - f.BoolP(name, "n", true, "always true") - usageMsg := fmt.Sprintf("use --%s instead", name) - f.MarkShorthandDeprecated(name, usageMsg) - - args := []string{"-n"} - out, err := parseReturnStderr(t, f, args) - if err != nil { - t.Fatal("expected no error; got ", err) - } + args := []string{"-n"} + out, err := parseReturnStderr(t, f, args) + require.NoError(t, err) - if !strings.Contains(out, usageMsg) { - t.Errorf("usageMsg not printed when using a deprecated flag!") - } -} + require.Containsf(t, out, shortHandMsg, + "%q not printed when using a deprecated flag!", shortHandMsg, + ) + }) -func TestDeprecatedFlagUsageNormalized(t *testing.T) { - f := NewFlagSet("bob", ContinueOnError) - f.Bool("bad-double_flag", true, "always true") - f.SetNormalizeFunc(wordSepNormalizeFunc) - usageMsg := "use --good-flag instead" - f.MarkDeprecated("bad_double-flag", usageMsg) + t.Run("with usage normalized", func(t *testing.T) { + f := newFlag() + f.Bool("bad-double_flag", true, "always true") + f.SetNormalizeFunc(wordSepNormalizeFunc) + require.NoError(t, f.MarkDeprecated("bad_double-flag", usageMsg)) - args := []string{"--bad_double_flag"} - out, err := parseReturnStderr(t, f, args) - if err != nil { - t.Fatal("expected no error; got ", err) - } + args := []string{"--bad_double_flag"} + out, err := parseReturnStderr(t, f, args) + require.NoError(t, err) - if !strings.Contains(out, usageMsg) { - t.Errorf("usageMsg not printed when using a deprecated flag!") - } + require.Containsf(t, out, usageMsg, + "%q not printed when using a deprecated flag!", usageMsg, + ) + }) } // Name normalization function should be called only once on flag addition @@ -1129,41 +1197,39 @@ func TestMultipleNormalizeFlagNameInvocations(t *testing.T) { f.SetNormalizeFunc(wordSepNormalizeFunc) f.Bool("with_under_flag", false, "bool value") - if normalizeFlagNameInvocations != 1 { - t.Fatal("Expected normalizeFlagNameInvocations to be 1; got ", normalizeFlagNameInvocations) - } + require.Equalf(t, 1, normalizeFlagNameInvocations, + "expected normalizeFlagNameInvocations to be 1; got ", normalizeFlagNameInvocations, + ) } -// -func TestHiddenFlagInUsage(t *testing.T) { - f := NewFlagSet("bob", ContinueOnError) - f.Bool("secretFlag", true, "shhh") - f.MarkHidden("secretFlag") +func TestHidden(t *testing.T) { + t.Run("with doc", func(t *testing.T) { + f := NewFlagSet("bob", ContinueOnError) + f.Bool("secretFlag", true, "shhh") + require.NoError(t, + f.MarkHidden("secretFlag"), + ) + + require.NotContains(t, printFlagDefaults(f), "secretFlag", + "found hidden flag in usage!", + ) + }) - out := new(bytes.Buffer) - f.SetOutput(out) - f.PrintDefaults() + t.Run("with usage", func(t *testing.T) { + f := NewFlagSet("bob", ContinueOnError) + f.Bool("secretFlag", true, "shhh") + require.NoError(t, + f.MarkHidden("secretFlag"), + ) - if strings.Contains(out.String(), "secretFlag") { - t.Errorf("found hidden flag in usage!") - } -} + args := []string{"--secretFlag"} + out, err := parseReturnStderr(t, f, args) + require.NoError(t, err) -// -func TestHiddenFlagUsage(t *testing.T) { - f := NewFlagSet("bob", ContinueOnError) - f.Bool("secretFlag", true, "shhh") - f.MarkHidden("secretFlag") - - args := []string{"--secretFlag"} - out, err := parseReturnStderr(t, f, args) - if err != nil { - t.Fatal("expected no error; got ", err) - } - - if strings.Contains(out, "shhh") { - t.Errorf("usage message printed when using a hidden flag!") - } + require.NotContainsf(t, out, "shhh", + "usage message printed when using a hidden flag!", + ) + }) } const defaultOutput = ` --A for bootstrapping, allow 'any' type @@ -1205,6 +1271,7 @@ func (cv *customValue) Type() string { return "custom" } func TestPrintDefaults(t *testing.T) { fs := NewFlagSet("print defaults test", ContinueOnError) var buf bytes.Buffer + fs.SetOutput(&buf) fs.Bool("A", false, "for bootstrapping, allow 'any' type") fs.Bool("Alongflagname", false, "disable bounds checking") @@ -1237,11 +1304,9 @@ func TestPrintDefaults(t *testing.T) { fs.PrintDefaults() got := buf.String() - if got != defaultOutput { - fmt.Println("\n" + got) - fmt.Println("\n" + defaultOutput) - t.Errorf("got %q want %q\n", got, defaultOutput) - } + require.Equalf(t, defaultOutput, got, + "got:\n%q\nwant:\n%q", got, defaultOutput, + ) } func TestVisitAllFlagOrder(t *testing.T) { @@ -1259,9 +1324,9 @@ func TestVisitAllFlagOrder(t *testing.T) { i := 0 fs.VisitAll(func(f *Flag) { - if names[i] != f.Name { - t.Errorf("Incorrect order. Expected %v, got %v", names[i], f.Name) - } + require.Equalf(t, f.Name, names[i], + "incorrect order. Expected %v, got %v", names[i], f.Name, + ) i++ }) } @@ -1272,14 +1337,44 @@ func TestVisitFlagOrder(t *testing.T) { names := []string{"C", "B", "A", "D"} for _, name := range names { fs.Bool(name, false, "") - fs.Set(name, "true") + _ = fs.Set(name, "true") } i := 0 fs.Visit(func(f *Flag) { - if names[i] != f.Name { - t.Errorf("Incorrect order. Expected %v, got %v", names[i], f.Name) - } + require.Equalf(t, f.Name, names[i], + "incorrect order. Expected %v, got %v", names[i], f.Name, + ) i++ }) } + +func parseReturnStderr(_ *testing.T, f *FlagSet, args []string) (string, error) { + oldStderr := os.Stderr + r, w, _ := os.Pipe() + os.Stderr = w + + err := f.Parse(args) + + outC := make(chan string) + // copy the output in a separate goroutine so printing can't block indefinitely + go func() { + var buf bytes.Buffer + _, _ = io.Copy(&buf, r) + outC <- buf.String() + }() + + w.Close() + os.Stderr = oldStderr + out := <-outC + + return out, err +} + +func printFlagDefaults(f *FlagSet) string { + out := new(bytes.Buffer) + f.SetOutput(out) + f.PrintDefaults() + + return out.String() +} diff --git a/float32_slice_test.go b/float32_slice_test.go index 997ce5c6..f612cb9f 100644 --- a/float32_slice_test.go +++ b/float32_slice_test.go @@ -9,192 +9,172 @@ import ( "strconv" "strings" "testing" -) -func setUpF32SFlagSet(f32sp *[]float32) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.Float32SliceVar(f32sp, "f32s", []float32{}, "Command separated list!") - return f -} + "github.com/stretchr/testify/require" +) -func setUpF32SFlagSetWithDefault(f32sp *[]float32) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.Float32SliceVar(f32sp, "f32s", []float32{0.0, 1.0}, "Command separated list!") - return f -} +func TestFloat32Slice(t *testing.T) { + t.Parallel() -func TestEmptyF32S(t *testing.T) { - var f32s []float32 - f := setUpF32SFlagSet(&f32s) - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) + newFlag := func(f32sp *[]float32) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.Float32SliceVar(f32sp, "f32s", []float32{}, "Command separated list!") + return f } - getF32S, err := f.GetFloat32Slice("f32s") - if err != nil { - t.Fatal("got an error from GetFloat32Slice():", err) - } - if len(getF32S) != 0 { - t.Fatalf("got f32s %v with len=%d but expected length=0", getF32S, len(getF32S)) - } -} + t.Run("with empty slice", func(t *testing.T) { + f32s := make([]float32, 0) + f := newFlag(&f32s) -func TestF32S(t *testing.T) { - var f32s []float32 - f := setUpF32SFlagSet(&f32s) + require.NoError(t, f.Parse([]string{})) - vals := []string{"1.0", "2.0", "4.0", "3.0"} - arg := fmt.Sprintf("--f32s=%s", strings.Join(vals, ",")) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range f32s { - d64, err := strconv.ParseFloat(vals[i], 32) - if err != nil { - t.Fatalf("got error: %v", err) - } + getF32S, err := f.GetFloat32Slice("f32s") + require.NoErrorf(t, err, + "got an error from GetFloat32Slice(): %v", err, + ) + require.Empty(t, getF32S) + }) - d := float32(d64) - if d != v { - t.Fatalf("expected f32s[%d] to be %s but got: %f", i, vals[i], v) - } - } - getF32S, err := f.GetFloat32Slice("f32s") - if err != nil { - t.Fatalf("got error: %v", err) - } - for i, v := range getF32S { - d64, err := strconv.ParseFloat(vals[i], 32) - if err != nil { - t.Fatalf("got error: %v", err) - } + t.Run("with values", func(t *testing.T) { + vals := []string{"1.0", "2.0", "4.0", "3.0"} + f32s := make([]float32, 0, len(vals)) + f := newFlag(&f32s) + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf("--f32s=%s", strings.Join(vals, ",")), + })) + + for i, v := range f32s { + d64, err := strconv.ParseFloat(vals[i], 32) + require.NoError(t, err) - d := float32(d64) - if d != v { - t.Fatalf("expected f32s[%d] to be %s but got: %f from GetFloat32Slice", i, vals[i], v) + d := float32(d64) + require.Equalf(t, v, d, + "expected f32s[%d] to be %s but got: %f", i, vals[i], v, + ) } - } -} -func TestF32SDefault(t *testing.T) { - var f32s []float32 - f := setUpF32SFlagSetWithDefault(&f32s) + getF32S, erf := f.GetFloat32Slice("f32s") + require.NoError(t, erf) - vals := []string{"0.0", "1.0"} + for i, v := range getF32S { + d64, err := strconv.ParseFloat(vals[i], 32) + require.NoError(t, err) - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range f32s { - d64, err := strconv.ParseFloat(vals[i], 32) - if err != nil { - t.Fatalf("got error: %v", err) + d := float32(d64) + require.Equalf(t, v, d, + "expected f32s[%d] to be %s but got: %f from GetFloat32Slice", i, vals[i], v, + ) } + }) - d := float32(d64) - if d != v { - t.Fatalf("expected f32s[%d] to be %f but got: %f", i, d, v) - } + newFlagWithDefault := func(f32sp *[]float32) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.Float32SliceVar(f32sp, "f32s", []float32{0.0, 1.0}, "Command separated list!") + return f } - getF32S, err := f.GetFloat32Slice("f32s") - if err != nil { - t.Fatal("got an error from GetFloat32Slice():", err) - } - for i, v := range getF32S { - d64, err := strconv.ParseFloat(vals[i], 32) - if err != nil { - t.Fatal("got an error from GetFloat32Slice():", err) - } + t.Run("with defaults (1)", func(t *testing.T) { + vals := []string{"0.0", "1.0"} + f32s := make([]float32, 0, len(vals)) + f := newFlagWithDefault(&f32s) - d := float32(d64) - if d != v { - t.Fatalf("expected f32s[%d] to be %f from GetFloat32Slice but got: %f", i, d, v) - } - } -} + require.NoError(t, f.Parse([]string{})) -func TestF32SWithDefault(t *testing.T) { - var f32s []float32 - f := setUpF32SFlagSetWithDefault(&f32s) + for i, v := range f32s { + d64, err := strconv.ParseFloat(vals[i], 32) + require.NoError(t, err) - vals := []string{"1.0", "2.0"} - arg := fmt.Sprintf("--f32s=%s", strings.Join(vals, ",")) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range f32s { - d64, err := strconv.ParseFloat(vals[i], 32) - if err != nil { - t.Fatalf("got error: %v", err) + d := float32(d64) + require.Equalf(t, v, d, + "expected f32s[%d] to be %f but got: %f", i, d, v, + ) } - d := float32(d64) - if d != v { - t.Fatalf("expected f32s[%d] to be %f but got: %f", i, d, v) - } - } + getF32S, erf := f.GetFloat32Slice("f32s") + require.NoErrorf(t, erf, + "got an error from GetFloat32Slice(): %v", erf, + ) - getF32S, err := f.GetFloat32Slice("f32s") - if err != nil { - t.Fatal("got an error from GetFloat32Slice():", err) - } - for i, v := range getF32S { - d64, err := strconv.ParseFloat(vals[i], 32) - if err != nil { - t.Fatalf("got error: %v", err) + for i, v := range getF32S { + d64, err := strconv.ParseFloat(vals[i], 32) + require.NoErrorf(t, err, + "got an error from GetFloat32Slice(): %v", err, + ) + + require.Equalf(t, v, float32(d64), + "expected f32s[%d] to be %f from GetFloat32Slice but got: %f", i, float32(d64), v, + ) } + }) + + t.Run("with defaults (2)", func(t *testing.T) { + vals := []string{"1.0", "2.0"} + f32s := make([]float32, 0, len(vals)) + f := newFlagWithDefault(&f32s) + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf("--f32s=%s", strings.Join(vals, ",")), + })) - d := float32(d64) - if d != v { - t.Fatalf("expected f32s[%d] to be %f from GetFloat32Slice but got: %f", i, d, v) + for i, v := range f32s { + d64, err := strconv.ParseFloat(vals[i], 32) + require.NoError(t, err) + + require.Equalf(t, v, float32(d64), + "expected f32s[%d] to be %f but got: %f", i, float32(d64), v, + ) } - } -} -func TestF32SAsSliceValue(t *testing.T) { - var f32s []float32 - f := setUpF32SFlagSet(&f32s) - - in := []string{"1.0", "2.0"} - argfmt := "--f32s=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - err := f.Parse([]string{arg1, arg2}) - if err != nil { - t.Fatal("expected no error; got", err) - } + getF32S, erf := f.GetFloat32Slice("f32s") + require.NoErrorf(t, erf, + "got an error from GetFloat32Slice(): %v", erf, + ) + + for i, v := range getF32S { + d64, err := strconv.ParseFloat(vals[i], 32) + require.NoError(t, err) - f.VisitAll(func(f *Flag) { - if val, ok := f.Value.(SliceValue); ok { - _ = val.Replace([]string{"3.1"}) + require.Equalf(t, v, float32(d64), + "expected f32s[%d] to be %f from GetFloat32Slice but got: %f", i, float32(d64), v, + ) } }) - if len(f32s) != 1 || f32s[0] != 3.1 { - t.Fatalf("Expected ss to be overwritten with '3.1', but got: %v", f32s) - } -} -func TestF32SCalledTwice(t *testing.T) { - var f32s []float32 - f := setUpF32SFlagSet(&f32s) - - in := []string{"1.0,2.0", "3.0"} - expected := []float32{1.0, 2.0, 3.0} - argfmt := "--f32s=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - err := f.Parse([]string{arg1, arg2}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range f32s { - if expected[i] != v { - t.Fatalf("expected f32s[%d] to be %f but got: %f", i, expected[i], v) - } - } + t.Run("called twice", func(t *testing.T) { + const argfmt = "--f32s=%s" + in := []string{"1.0,2.0", "3.0"} + f32s := make([]float32, 0, len(in)) + f := newFlag(&f32s) + + expected := []float32{1.0, 2.0, 3.0} + require.NoError(t, f.Parse([]string{ + fmt.Sprintf(argfmt, in[0]), + fmt.Sprintf(argfmt, in[1]), + })) + + require.Equal(t, expected, f32s) + }) + + t.Run("as slice value", func(t *testing.T) { + const argfmt = "--f32s=%s" + in := []string{"1.0", "2.0"} + f32s := make([]float32, 0, len(in)) + f := newFlag(&f32s) + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf(argfmt, in[0]), + fmt.Sprintf(argfmt, in[1]), + })) + + f.VisitAll(func(f *Flag) { + if val, ok := f.Value.(SliceValue); ok { + require.NoError(t, val.Replace([]string{"3.1"})) + } + }) + + require.Equalf(t, []float32{3.1}, f32s, + "expected ss to be overwritten with '3.1', but got: %v", f32s, + ) + }) } diff --git a/float64_slice_test.go b/float64_slice_test.go index 43778ef1..00be9d15 100644 --- a/float64_slice_test.go +++ b/float64_slice_test.go @@ -9,180 +9,164 @@ import ( "strconv" "strings" "testing" -) -func setUpF64SFlagSet(f64sp *[]float64) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.Float64SliceVar(f64sp, "f64s", []float64{}, "Command separated list!") - return f -} + "github.com/stretchr/testify/require" +) -func setUpF64SFlagSetWithDefault(f64sp *[]float64) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.Float64SliceVar(f64sp, "f64s", []float64{0.0, 1.0}, "Command separated list!") - return f -} +func TestFloat64Slice(t *testing.T) { + t.Parallel() -func TestEmptyF64S(t *testing.T) { - var f64s []float64 - f := setUpF64SFlagSet(&f64s) - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) + newFlag := func(f64sp *[]float64) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.Float64SliceVar(f64sp, "f64s", []float64{}, "Command separated list!") + return f } - getF64S, err := f.GetFloat64Slice("f64s") - if err != nil { - t.Fatal("got an error from GetFloat64Slice():", err) - } - if len(getF64S) != 0 { - t.Fatalf("got f64s %v with len=%d but expected length=0", getF64S, len(getF64S)) - } -} + t.Run("with empty slice", func(t *testing.T) { + f64s := make([]float64, 0) + f := newFlag(&f64s) -func TestF64S(t *testing.T) { - var f64s []float64 - f := setUpF64SFlagSet(&f64s) + require.NoError(t, f.Parse([]string{})) - vals := []string{"1.0", "2.0", "4.0", "3.0"} - arg := fmt.Sprintf("--f64s=%s", strings.Join(vals, ",")) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range f64s { - d, err := strconv.ParseFloat(vals[i], 64) - if err != nil { - t.Fatalf("got error: %v", err) - } - if d != v { - t.Fatalf("expected f64s[%d] to be %s but got: %f", i, vals[i], v) - } - } - getF64S, err := f.GetFloat64Slice("f64s") - if err != nil { - t.Fatalf("got error: %v", err) - } - for i, v := range getF64S { - d, err := strconv.ParseFloat(vals[i], 64) - if err != nil { - t.Fatalf("got error: %v", err) - } - if d != v { - t.Fatalf("expected f64s[%d] to be %s but got: %f from GetFloat64Slice", i, vals[i], v) - } - } -} + getF64S, err := f.GetFloat64Slice("f64s") + require.NoErrorf(t, err, + "got an error from GetFloat64Slice(): %v", err, + ) + require.Empty(t, getF64S) + }) -func TestF64SDefault(t *testing.T) { - var f64s []float64 - f := setUpF64SFlagSetWithDefault(&f64s) + t.Run("with values", func(t *testing.T) { + vals := []string{"1.0", "2.0", "4.0", "3.0"} + f64s := make([]float64, 0, len(vals)) + f := newFlag(&f64s) + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf("--f64s=%s", strings.Join(vals, ",")), + })) + + for i, v := range f64s { + d, err := strconv.ParseFloat(vals[i], 64) + require.NoError(t, err) + require.Equalf(t, v, d, + "expected f64s[%d] to be %s but got: %f", i, vals[i], v, + ) + } - vals := []string{"0.0", "1.0"} + getF64S, err := f.GetFloat64Slice("f64s") + require.NoError(t, err) - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range f64s { - d, err := strconv.ParseFloat(vals[i], 64) - if err != nil { - t.Fatalf("got error: %v", err) - } - if d != v { - t.Fatalf("expected f64s[%d] to be %f but got: %f", i, d, v) + for i, v := range getF64S { + d, err := strconv.ParseFloat(vals[i], 64) + require.NoError(t, err) + require.Equalf(t, v, d, + "expected f64s[%d] to be %s but got: %f from GetFloat64Slice", i, vals[i], v, + ) } - } + }) - getF64S, err := f.GetFloat64Slice("f64s") - if err != nil { - t.Fatal("got an error from GetFloat64Slice():", err) - } - for i, v := range getF64S { - d, err := strconv.ParseFloat(vals[i], 64) - if err != nil { - t.Fatal("got an error from GetFloat64Slice():", err) - } - if d != v { - t.Fatalf("expected f64s[%d] to be %f from GetFloat64Slice but got: %f", i, d, v) - } + newFlagWithDefault := func(f64sp *[]float64) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.Float64SliceVar(f64sp, "f64s", []float64{0.0, 1.0}, "Command separated list!") + return f } -} -func TestF64SWithDefault(t *testing.T) { - var f64s []float64 - f := setUpF64SFlagSetWithDefault(&f64s) + t.Run("with defaults (1)", func(t *testing.T) { + vals := []string{"0.0", "1.0"} + f64s := make([]float64, 0, len(vals)) + f := newFlagWithDefault(&f64s) - vals := []string{"1.0", "2.0"} - arg := fmt.Sprintf("--f64s=%s", strings.Join(vals, ",")) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range f64s { - d, err := strconv.ParseFloat(vals[i], 64) - if err != nil { - t.Fatalf("got error: %v", err) + require.NoError(t, f.Parse([]string{})) + + for i, v := range f64s { + d, err := strconv.ParseFloat(vals[i], 64) + require.NoError(t, err) + + require.Equalf(t, v, d, + "expected f64s[%d] to be %f but got: %f", i, d, v, + ) } - if d != v { - t.Fatalf("expected f64s[%d] to be %f but got: %f", i, d, v) + + getF64S, erf := f.GetFloat64Slice("f64s") + require.NoErrorf(t, erf, + "got an error from GetFloat64Slice(): %v", erf, + ) + + for i, v := range getF64S { + d, err := strconv.ParseFloat(vals[i], 64) + require.NoErrorf(t, err, + "got an error from GetFloat64Slice(): %v", err, + ) + require.Equalf(t, v, d, + "expected f64s[%d] to be %f from GetFloat64Slice but got: %f", i, d, v, + ) } - } + }) - getF64S, err := f.GetFloat64Slice("f64s") - if err != nil { - t.Fatal("got an error from GetFloat64Slice():", err) - } - for i, v := range getF64S { - d, err := strconv.ParseFloat(vals[i], 64) - if err != nil { - t.Fatalf("got error: %v", err) + t.Run("with defaults (2)", func(t *testing.T) { + vals := []string{"1.0", "2.0"} + f64s := make([]float64, 0, len(vals)) + f := newFlagWithDefault(&f64s) + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf("--f64s=%s", strings.Join(vals, ",")), + })) + + for i, v := range f64s { + d, err := strconv.ParseFloat(vals[i], 64) + require.NoError(t, err) + require.Equalf(t, v, d, + "expected f64s[%d] to be %f but got: %f", i, d, v, + ) } - if d != v { - t.Fatalf("expected f64s[%d] to be %f from GetFloat64Slice but got: %f", i, d, v) + + getF64S, erf := f.GetFloat64Slice("f64s") + require.NoErrorf(t, erf, + "got an error from GetFloat64Slice(): %v", erf, + ) + + for i, v := range getF64S { + d, err := strconv.ParseFloat(vals[i], 64) + require.NoError(t, err) + require.Equalf(t, v, d, + "expected f64s[%d] to be %f from GetFloat64Slice but got: %f", i, d, v, + ) } - } -} + }) -func TestF64SAsSliceValue(t *testing.T) { - var f64s []float64 - f := setUpF64SFlagSet(&f64s) - - in := []string{"1.0", "2.0"} - argfmt := "--f64s=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - err := f.Parse([]string{arg1, arg2}) - if err != nil { - t.Fatal("expected no error; got", err) - } + t.Run("called twice", func(t *testing.T) { + const argfmt = "--f64s=%s" + in := []string{"1.0,2.0", "3.0"} + f64s := make([]float64, 0, len(in)) + f := newFlag(&f64s) + expected := []float64{1.0, 2.0, 3.0} - f.VisitAll(func(f *Flag) { - if val, ok := f.Value.(SliceValue); ok { - _ = val.Replace([]string{"3.1"}) - } + require.NoError(t, f.Parse([]string{ + fmt.Sprintf(argfmt, in[0]), + fmt.Sprintf(argfmt, in[1]), + })) + + require.Equal(t, expected, f64s) }) - if len(f64s) != 1 || f64s[0] != 3.1 { - t.Fatalf("Expected ss to be overwritten with '3.1', but got: %v", f64s) - } -} -func TestF64SCalledTwice(t *testing.T) { - var f64s []float64 - f := setUpF64SFlagSet(&f64s) - - in := []string{"1.0,2.0", "3.0"} - expected := []float64{1.0, 2.0, 3.0} - argfmt := "--f64s=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - err := f.Parse([]string{arg1, arg2}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range f64s { - if expected[i] != v { - t.Fatalf("expected f64s[%d] to be %f but got: %f", i, expected[i], v) - } - } + t.Run("as slice value", func(t *testing.T) { + const argfmt = "--f64s=%s" + in := []string{"1.0", "2.0"} + f64s := make([]float64, 0, len(in)) + f := newFlag(&f64s) + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf(argfmt, in[0]), + fmt.Sprintf(argfmt, in[1]), + })) + + f.VisitAll(func(f *Flag) { + if val, ok := f.Value.(SliceValue); ok { + require.NoError(t, val.Replace([]string{"3.1"})) + } + }) + + require.Equalf(t, []float64{3.1}, f64s, + "expected ss to be overwritten with '3.1', but got: %v", f64s, + ) + }) } diff --git a/go.mod b/go.mod index b2287eec..ce6c9b4b 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/spf13/pflag go 1.12 + +require github.com/stretchr/testify v1.8.4 // indirect diff --git a/go.sum b/go.sum index e69de29b..5bddba9a 100644 --- a/go.sum +++ b/go.sum @@ -0,0 +1,17 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/golangflag.go b/golangflag.go index d3dd72b7..e058a787 100644 --- a/golangflag.go +++ b/golangflag.go @@ -61,7 +61,7 @@ func (v *flagValueWrapper) Type() string { // If the *flag.Flag.Name was a single character (ex: `v`) it will be accessiblei // with both `-v` and `--v` in flags. If the golang flag was more than a single // character (ex: `verbose`) it will only be accessible via `--verbose` -func PFlagFromGoFlag(goflag *goflag.Flag) *Flag { +func PFlagFromGoFlag(goflag *goflag.Flag) *Flag { // nolint: revive // Remember the default value as a string; it won't change. flag := &Flag{ Name: goflag.Name, diff --git a/golangflag_test.go b/golangflag_test.go index 5bd831bf..e233c8be 100644 --- a/golangflag_test.go +++ b/golangflag_test.go @@ -7,41 +7,27 @@ package pflag import ( goflag "flag" "testing" + + "github.com/stretchr/testify/require" ) func TestGoflags(t *testing.T) { goflag.String("stringFlag", "stringFlag", "stringFlag") goflag.Bool("boolFlag", false, "boolFlag") - f := NewFlagSet("test", ContinueOnError) f.AddGoFlagSet(goflag.CommandLine) - err := f.Parse([]string{"--stringFlag=bob", "--boolFlag"}) - if err != nil { - t.Fatal("expected no error; get", err) - } + require.NoError(t, f.Parse([]string{"--stringFlag=bob", "--boolFlag"})) getString, err := f.GetString("stringFlag") - if err != nil { - t.Fatal("expected no error; get", err) - } - if getString != "bob" { - t.Fatalf("expected getString=bob but got getString=%s", getString) - } + require.NoError(t, err) + require.Equal(t, "bob", getString) getBool, err := f.GetBool("boolFlag") - if err != nil { - t.Fatal("expected no error; get", err) - } - if getBool != true { - t.Fatalf("expected getBool=true but got getBool=%v", getBool) - } - if !f.Parsed() { - t.Fatal("f.Parsed() return false after f.Parse() called") - } - - // in fact it is useless. because `go test` called flag.Parse() - if !goflag.CommandLine.Parsed() { - t.Fatal("goflag.CommandLine.Parsed() return false after f.Parse() called") - } + require.NoError(t, err) + + require.True(t, getBool) + require.Truef(t, f.Parsed(), + "f.Parsed() return false after f.Parse() called", + ) } diff --git a/int16.go b/int16.go index f1a01d05..eef29a8c 100644 --- a/int16.go +++ b/int16.go @@ -1,3 +1,4 @@ +// nolint: dupl package pflag import "strconv" diff --git a/int32.go b/int32.go index 9b95944f..35002a5d 100644 --- a/int32.go +++ b/int32.go @@ -1,3 +1,4 @@ +// nolint: dupl package pflag import "strconv" diff --git a/int32_slice_test.go b/int32_slice_test.go index 809c5633..aadee002 100644 --- a/int32_slice_test.go +++ b/int32_slice_test.go @@ -9,186 +9,168 @@ import ( "strconv" "strings" "testing" -) -func setUpI32SFlagSet(isp *[]int32) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.Int32SliceVar(isp, "is", []int32{}, "Command separated list!") - return f -} + "github.com/stretchr/testify/require" +) -func setUpI32SFlagSetWithDefault(isp *[]int32) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.Int32SliceVar(isp, "is", []int32{0, 1}, "Command separated list!") - return f -} +func TestInt32Slice(t *testing.T) { + t.Parallel() -func TestEmptyI32S(t *testing.T) { - var is []int32 - f := setUpI32SFlagSet(&is) - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) + newFlag := func(isp *[]int32) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.Int32SliceVar(isp, "is", []int32{}, "Command separated list!") + return f } - getI32S, err := f.GetInt32Slice("is") - if err != nil { - t.Fatal("got an error from GetInt32Slice():", err) - } - if len(getI32S) != 0 { - t.Fatalf("got is %v with len=%d but expected length=0", getI32S, len(getI32S)) - } -} + t.Run("with empty slice", func(t *testing.T) { + is := make([]int32, 0) + f := newFlag(&is) + require.NoError(t, f.Parse([]string{})) -func TestI32S(t *testing.T) { - var is []int32 - f := setUpI32SFlagSet(&is) + getI32S, err := f.GetInt32Slice("is") + require.NoErrorf(t, err, + "got an error from GetInt32Slice(): %v", err, + ) + require.Empty(t, getI32S) + }) - vals := []string{"1", "2", "4", "3"} - arg := fmt.Sprintf("--is=%s", strings.Join(vals, ",")) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range is { - d64, err := strconv.ParseInt(vals[i], 0, 32) - if err != nil { - t.Fatalf("got error: %v", err) - } - d := int32(d64) - if d != v { - t.Fatalf("expected is[%d] to be %s but got: %d", i, vals[i], v) - } - } - getI32S, err := f.GetInt32Slice("is") - if err != nil { - t.Fatalf("got error: %v", err) - } - for i, v := range getI32S { - d64, err := strconv.ParseInt(vals[i], 0, 32) - if err != nil { - t.Fatalf("got error: %v", err) - } - d := int32(d64) - if d != v { - t.Fatalf("expected is[%d] to be %s but got: %d from GetInt32Slice", i, vals[i], v) - } - } -} + t.Run("with values", func(t *testing.T) { + vals := []string{"1", "2", "4", "3"} + is := make([]int32, 0, len(vals)) + f := newFlag(&is) -func TestI32SDefault(t *testing.T) { - var is []int32 - f := setUpI32SFlagSetWithDefault(&is) + require.NoError(t, f.Parse([]string{ + fmt.Sprintf("--is=%s", strings.Join(vals, ",")), + })) - vals := []string{"0", "1"} + for i, v := range is { + d64, err := strconv.ParseInt(vals[i], 0, 32) + require.NoError(t, err) - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range is { - d64, err := strconv.ParseInt(vals[i], 0, 32) - if err != nil { - t.Fatalf("got error: %v", err) + require.Equalf(t, v, int32(d64), + "expected is[%d] to be %s but got: %d", i, vals[i], int32(d64), + ) } - d := int32(d64) - if d != v { - t.Fatalf("expected is[%d] to be %d but got: %d", i, d, v) - } - } - getI32S, err := f.GetInt32Slice("is") - if err != nil { - t.Fatal("got an error from GetInt32Slice():", err) - } - for i, v := range getI32S { - d64, err := strconv.ParseInt(vals[i], 0, 32) - if err != nil { - t.Fatal("got an error from GetInt32Slice():", err) - } - d := int32(d64) - if d != v { - t.Fatalf("expected is[%d] to be %d from GetInt32Slice but got: %d", i, d, v) + getI32S, eri := f.GetInt32Slice("is") + require.NoError(t, eri) + + for i, v := range getI32S { + d64, err := strconv.ParseInt(vals[i], 0, 32) + require.NoError(t, err) + + require.Equalf(t, v, int32(d64), + "expected is[%d] to be %s but got: %d from GetInt32Slice", i, vals[i], int32(d64), + ) } + }) + + newFlagWithDefault := func(isp *[]int32) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.Int32SliceVar(isp, "is", []int32{0, 1}, "Command separated list!") + return f } -} -func TestI32SWithDefault(t *testing.T) { - var is []int32 - f := setUpI32SFlagSetWithDefault(&is) + t.Run("with defaults (1)", func(t *testing.T) { + vals := []string{"0", "1"} + is := make([]int32, 0, len(vals)) + f := newFlagWithDefault(&is) - vals := []string{"1", "2"} - arg := fmt.Sprintf("--is=%s", strings.Join(vals, ",")) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range is { - d64, err := strconv.ParseInt(vals[i], 0, 32) - if err != nil { - t.Fatalf("got error: %v", err) - } - d := int32(d64) - if d != v { - t.Fatalf("expected is[%d] to be %d but got: %d", i, d, v) + require.NoError(t, f.Parse([]string{})) + + for i, v := range is { + d64, err := strconv.ParseInt(vals[i], 0, 32) + require.NoError(t, err) + + require.Equalf(t, v, int32(d64), + "expected is[%d] to be %d but got: %d", i, v, int32(d64), + ) } - } - getI32S, err := f.GetInt32Slice("is") - if err != nil { - t.Fatal("got an error from GetInt32Slice():", err) - } - for i, v := range getI32S { - d64, err := strconv.ParseInt(vals[i], 0, 32) - if err != nil { - t.Fatalf("got error: %v", err) + getI32S, eri := f.GetInt32Slice("is") + require.NoErrorf(t, eri, + "got an error from GetInt32Slice(): %v", eri, + ) + + for i, v := range getI32S { + d64, err := strconv.ParseInt(vals[i], 0, 32) + require.NoErrorf(t, err, + "got an error from GetInt32Slice(): %v", err, + ) + + require.Equalf(t, v, int32(d64), + "expected is[%d] to be %d from GetInt32Slice but got: %d", i, v, int32(d64), + ) } - d := int32(d64) - if d != v { - t.Fatalf("expected is[%d] to be %d from GetInt32Slice but got: %d", i, d, v) + }) + + t.Run("with defaults (2)", func(t *testing.T) { + vals := []string{"1", "2"} + is := make([]int32, 0, len(vals)) + f := newFlagWithDefault(&is) + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf("--is=%s", strings.Join(vals, ",")), + })) + + for i, v := range is { + d64, err := strconv.ParseInt(vals[i], 0, 32) + require.NoError(t, err) + + require.Equalf(t, v, int32(d64), + "expected is[%d] to be %d but got: %d", i, v, int32(d64), + ) } - } -} -func TestI32SAsSliceValue(t *testing.T) { - var i32s []int32 - f := setUpI32SFlagSet(&i32s) - - in := []string{"1", "2"} - argfmt := "--is=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - err := f.Parse([]string{arg1, arg2}) - if err != nil { - t.Fatal("expected no error; got", err) - } + getI32S, eri := f.GetInt32Slice("is") + require.NoErrorf(t, eri, + "got an error from GetInt32Slice(): %v", eri, + ) - f.VisitAll(func(f *Flag) { - if val, ok := f.Value.(SliceValue); ok { - _ = val.Replace([]string{"3"}) + for i, v := range getI32S { + d64, err := strconv.ParseInt(vals[i], 0, 32) + require.NoError(t, err) + + require.Equalf(t, v, int32(d64), + "expected is[%d] to be %d from GetInt32Slice but got: %d", i, v, int32(d64), + ) } }) - if len(i32s) != 1 || i32s[0] != 3 { - t.Fatalf("Expected ss to be overwritten with '3.1', but got: %v", i32s) - } -} -func TestI32SCalledTwice(t *testing.T) { - var is []int32 - f := setUpI32SFlagSet(&is) - - in := []string{"1,2", "3"} - expected := []int32{1, 2, 3} - argfmt := "--is=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - err := f.Parse([]string{arg1, arg2}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range is { - if expected[i] != v { - t.Fatalf("expected is[%d] to be %d but got: %d", i, expected[i], v) - } - } + t.Run("called twice", func(t *testing.T) { + const argfmt = "--is=%s" + in := []string{"1,2", "3"} + is := make([]int32, 0, len(in)) + f := newFlag(&is) + expected := []int32{1, 2, 3} + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf(argfmt, in[0]), + fmt.Sprintf(argfmt, in[1]), + })) + + require.Equal(t, expected, is) + }) + + t.Run("as slice value", func(t *testing.T) { + const argfmt = "--is=%s" + in := []string{"1", "2"} + i32s := make([]int32, 0, len(in)) + f := newFlag(&i32s) + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf(argfmt, in[0]), + fmt.Sprintf(argfmt, in[1]), + })) + + f.VisitAll(func(f *Flag) { + if val, ok := f.Value.(SliceValue); ok { + require.NoError(t, val.Replace([]string{"3"})) + } + }) + + require.Equalf(t, []int32{3}, i32s, + "expected ss to be overwritten with '3.1', but got: %v", i32s, + ) + }) } diff --git a/int64_slice_test.go b/int64_slice_test.go index 09805c76..2eb6fcc7 100644 --- a/int64_slice_test.go +++ b/int64_slice_test.go @@ -9,180 +9,167 @@ import ( "strconv" "strings" "testing" -) -func setUpI64SFlagSet(isp *[]int64) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.Int64SliceVar(isp, "is", []int64{}, "Command separated list!") - return f -} + "github.com/stretchr/testify/require" +) -func setUpI64SFlagSetWithDefault(isp *[]int64) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.Int64SliceVar(isp, "is", []int64{0, 1}, "Command separated list!") - return f -} +func TestInt64Slice(t *testing.T) { + t.Parallel() -func TestEmptyI64S(t *testing.T) { - var is []int64 - f := setUpI64SFlagSet(&is) - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) + newFlag := func(isp *[]int64) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.Int64SliceVar(isp, "is", []int64{}, "Command separated list!") + return f } - getI64S, err := f.GetInt64Slice("is") - if err != nil { - t.Fatal("got an error from GetInt64Slice():", err) - } - if len(getI64S) != 0 { - t.Fatalf("got is %v with len=%d but expected length=0", getI64S, len(getI64S)) - } -} + t.Run("with empty slice", func(t *testing.T) { + is := make([]int64, 0) + f := newFlag(&is) + require.NoError(t, f.Parse([]string{})) -func TestI64S(t *testing.T) { - var is []int64 - f := setUpI64SFlagSet(&is) + getI64S, err := f.GetInt64Slice("is") + require.NoErrorf(t, err, + "got an error from GetInt64Slice(): %v", err, + ) + require.Empty(t, getI64S) + }) - vals := []string{"1", "2", "4", "3"} - arg := fmt.Sprintf("--is=%s", strings.Join(vals, ",")) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range is { - d, err := strconv.ParseInt(vals[i], 0, 64) - if err != nil { - t.Fatalf("got error: %v", err) - } - if d != v { - t.Fatalf("expected is[%d] to be %s but got: %d", i, vals[i], v) - } - } - getI64S, err := f.GetInt64Slice("is") - if err != nil { - t.Fatalf("got error: %v", err) - } - for i, v := range getI64S { - d, err := strconv.ParseInt(vals[i], 0, 64) - if err != nil { - t.Fatalf("got error: %v", err) - } - if d != v { - t.Fatalf("expected is[%d] to be %s but got: %d from GetInt64Slice", i, vals[i], v) - } - } -} + t.Run("with values", func(t *testing.T) { + vals := []string{"1", "2", "4", "3"} + is := make([]int64, 0, len(vals)) + f := newFlag(&is) -func TestI64SDefault(t *testing.T) { - var is []int64 - f := setUpI64SFlagSetWithDefault(&is) + require.NoError(t, f.Parse([]string{ + fmt.Sprintf("--is=%s", strings.Join(vals, ",")), + })) - vals := []string{"0", "1"} + for i, v := range is { + d, err := strconv.ParseInt(vals[i], 0, 64) + require.NoError(t, err) - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range is { - d, err := strconv.ParseInt(vals[i], 0, 64) - if err != nil { - t.Fatalf("got error: %v", err) - } - if d != v { - t.Fatalf("expected is[%d] to be %d but got: %d", i, d, v) + require.Equalf(t, v, d, + "expected is[%d] to be %s but got: %d", i, vals[i], v, + ) } - } - getI64S, err := f.GetInt64Slice("is") - if err != nil { - t.Fatal("got an error from GetInt64Slice():", err) - } - for i, v := range getI64S { - d, err := strconv.ParseInt(vals[i], 0, 64) - if err != nil { - t.Fatal("got an error from GetInt64Slice():", err) - } - if d != v { - t.Fatalf("expected is[%d] to be %d from GetInt64Slice but got: %d", i, d, v) + getI64S, eri := f.GetInt64Slice("is") + require.NoError(t, eri) + + for i, v := range getI64S { + d, err := strconv.ParseInt(vals[i], 0, 64) + require.NoError(t, err) + + require.Equalf(t, v, d, + "expected is[%d] to be %s but got: %d from GetInt64Slice", i, vals[i], v, + ) } + }) + + newFlagWithDefault := func(isp *[]int64) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.Int64SliceVar(isp, "is", []int64{0, 1}, "Command separated list!") + return f } -} -func TestI64SWithDefault(t *testing.T) { - var is []int64 - f := setUpI64SFlagSetWithDefault(&is) + t.Run("with defaults (1)", func(t *testing.T) { + vals := []string{"0", "1"} + is := make([]int64, 0, len(vals)) + f := newFlagWithDefault(&is) - vals := []string{"1", "2"} - arg := fmt.Sprintf("--is=%s", strings.Join(vals, ",")) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range is { - d, err := strconv.ParseInt(vals[i], 0, 64) - if err != nil { - t.Fatalf("got error: %v", err) - } - if d != v { - t.Fatalf("expected is[%d] to be %d but got: %d", i, d, v) + require.NoError(t, f.Parse([]string{})) + + for i, v := range is { + d, err := strconv.ParseInt(vals[i], 0, 64) + require.NoError(t, err) + + require.Equalf(t, v, d, + "expected is[%d] to be %d but got: %d", i, v, d, + ) } - } - getI64S, err := f.GetInt64Slice("is") - if err != nil { - t.Fatal("got an error from GetInt64Slice():", err) - } - for i, v := range getI64S { - d, err := strconv.ParseInt(vals[i], 0, 64) - if err != nil { - t.Fatalf("got error: %v", err) + getI64S, eri := f.GetInt64Slice("is") + require.NoErrorf(t, eri, + "got an error from GetInt64Slice(): %v", eri, + ) + + for i, v := range getI64S { + d, err := strconv.ParseInt(vals[i], 0, 64) + require.NoErrorf(t, err, + "got an error from GetInt64Slice(): %v", err, + ) + require.Equalf(t, v, d, + "expected is[%d] to be %d from GetInt64Slice but got: %d", i, v, d, + ) } - if d != v { - t.Fatalf("expected is[%d] to be %d from GetInt64Slice but got: %d", i, d, v) + }) + + t.Run("with defaults (2)", func(t *testing.T) { + vals := []string{"1", "2"} + is := make([]int64, 0, len(vals)) + f := newFlagWithDefault(&is) + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf("--is=%s", strings.Join(vals, ",")), + })) + + for i, v := range is { + d, err := strconv.ParseInt(vals[i], 0, 64) + require.NoError(t, err) + + require.Equalf(t, v, d, + "expected is[%d] to be %d but got: %d", i, v, d, + ) } - } -} -func TestI64SAsSliceValue(t *testing.T) { - var i64s []int64 - f := setUpI64SFlagSet(&i64s) - - in := []string{"1", "2"} - argfmt := "--is=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - err := f.Parse([]string{arg1, arg2}) - if err != nil { - t.Fatal("expected no error; got", err) - } + getI64S, eri := f.GetInt64Slice("is") + require.NoErrorf(t, eri, + "got an error from GetInt64Slice(): %v", eri, + ) - f.VisitAll(func(f *Flag) { - if val, ok := f.Value.(SliceValue); ok { - _ = val.Replace([]string{"3"}) + for i, v := range getI64S { + d, err := strconv.ParseInt(vals[i], 0, 64) + require.NoError(t, err) + + require.Equalf(t, v, d, + "expected is[%d] to be %d from GetInt64Slice but got: %d", i, d, v, + ) } }) - if len(i64s) != 1 || i64s[0] != 3 { - t.Fatalf("Expected ss to be overwritten with '3.1', but got: %v", i64s) - } -} -func TestI64SCalledTwice(t *testing.T) { - var is []int64 - f := setUpI64SFlagSet(&is) - - in := []string{"1,2", "3"} - expected := []int64{1, 2, 3} - argfmt := "--is=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - err := f.Parse([]string{arg1, arg2}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range is { - if expected[i] != v { - t.Fatalf("expected is[%d] to be %d but got: %d", i, expected[i], v) - } - } + t.Run("called twice", func(t *testing.T) { + const argfmt = "--is=%s" + in := []string{"1,2", "3"} + is := make([]int64, 0, len(in)) + f := newFlag(&is) + expected := []int64{1, 2, 3} + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf(argfmt, in[0]), + fmt.Sprintf(argfmt, in[1]), + })) + + require.Equal(t, expected, is) + }) + + t.Run("as slice value", func(t *testing.T) { + const argfmt = "--is=%s" + in := []string{"1", "2"} + i64s := make([]int64, 0, len(in)) + f := newFlag(&i64s) + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf(argfmt, in[0]), + fmt.Sprintf(argfmt, in[1]), + })) + + f.VisitAll(func(f *Flag) { + if val, ok := f.Value.(SliceValue); ok { + require.NoError(t, val.Replace([]string{"3"})) + } + }) + + require.Equalf(t, []int64{3}, i64s, + "expected ss to be overwritten with '3.1', but got: %v", i64s, + ) + }) } diff --git a/int8.go b/int8.go index 4da92228..952c9eed 100644 --- a/int8.go +++ b/int8.go @@ -1,3 +1,4 @@ +// nolint: dupl package pflag import "strconv" diff --git a/int_slice_test.go b/int_slice_test.go index 745aecb9..2ae0ba4d 100644 --- a/int_slice_test.go +++ b/int_slice_test.go @@ -9,157 +9,138 @@ import ( "strconv" "strings" "testing" -) - -func setUpISFlagSet(isp *[]int) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.IntSliceVar(isp, "is", []int{}, "Command separated list!") - return f -} - -func setUpISFlagSetWithDefault(isp *[]int) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.IntSliceVar(isp, "is", []int{0, 1}, "Command separated list!") - return f -} -func TestEmptyIS(t *testing.T) { - var is []int - f := setUpISFlagSet(&is) - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) - } - - getIS, err := f.GetIntSlice("is") - if err != nil { - t.Fatal("got an error from GetIntSlice():", err) - } - if len(getIS) != 0 { - t.Fatalf("got is %v with len=%d but expected length=0", getIS, len(getIS)) - } -} - -func TestIS(t *testing.T) { - var is []int - f := setUpISFlagSet(&is) + "github.com/stretchr/testify/require" +) - vals := []string{"1", "2", "4", "3"} - arg := fmt.Sprintf("--is=%s", strings.Join(vals, ",")) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range is { - d, err := strconv.Atoi(vals[i]) - if err != nil { - t.Fatalf("got error: %v", err) +func TestIntSlice(t *testing.T) { + t.Parallel() + + newFlag := func(isp *[]int) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.IntSliceVar(isp, "is", []int{}, "Command separated list!") + return f + } + + t.Run("with empty slice", func(t *testing.T) { + is := make([]int, 0) + f := newFlag(&is) + require.NoError(t, f.Parse([]string{})) + + getIS, err := f.GetIntSlice("is") + require.NoErrorf(t, err, + "got an error from GetIntSlice(): %v", err, + ) + require.Empty(t, getIS) + }) + + t.Run("with values", func(t *testing.T) { + vals := []string{"1", "2", "4", "3"} + is := make([]int, 0, len(vals)) + f := newFlag(&is) + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf("--is=%s", strings.Join(vals, ",")), + })) + + for i, v := range is { + d, err := strconv.Atoi(vals[i]) + require.NoError(t, err) + require.Equalf(t, v, d, + "expected is[%d] to be %d but got: %d", i, v, d, + ) } - if d != v { - t.Fatalf("expected is[%d] to be %s but got: %d", i, vals[i], v) - } - } - getIS, err := f.GetIntSlice("is") - if err != nil { - t.Fatalf("got error: %v", err) - } - for i, v := range getIS { - d, err := strconv.Atoi(vals[i]) - if err != nil { - t.Fatalf("got error: %v", err) - } - if d != v { - t.Fatalf("expected is[%d] to be %s but got: %d from GetIntSlice", i, vals[i], v) - } - } -} - -func TestISDefault(t *testing.T) { - var is []int - f := setUpISFlagSetWithDefault(&is) - vals := []string{"0", "1"} + getIS, eri := f.GetIntSlice("is") + require.NoError(t, eri) - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range is { - d, err := strconv.Atoi(vals[i]) - if err != nil { - t.Fatalf("got error: %v", err) + for i, v := range getIS { + d, err := strconv.Atoi(vals[i]) + require.NoError(t, err) + require.Equalf(t, v, d, + "expected is[%d] to be %d but got: %d from GetIntSlice", i, v, d, + ) } - if d != v { - t.Fatalf("expected is[%d] to be %d but got: %d", i, d, v) - } - } + }) - getIS, err := f.GetIntSlice("is") - if err != nil { - t.Fatal("got an error from GetIntSlice():", err) + newFlagWithDefault := func(isp *[]int) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.IntSliceVar(isp, "is", []int{0, 1}, "Command separated list!") + return f } - for i, v := range getIS { - d, err := strconv.Atoi(vals[i]) - if err != nil { - t.Fatal("got an error from GetIntSlice():", err) - } - if d != v { - t.Fatalf("expected is[%d] to be %d from GetIntSlice but got: %d", i, d, v) - } - } -} -func TestISWithDefault(t *testing.T) { - var is []int - f := setUpISFlagSetWithDefault(&is) + t.Run("with defaults (1)", func(t *testing.T) { + vals := []string{"0", "1"} + is := make([]int, 0, len(vals)) + f := newFlagWithDefault(&is) - vals := []string{"1", "2"} - arg := fmt.Sprintf("--is=%s", strings.Join(vals, ",")) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range is { - d, err := strconv.Atoi(vals[i]) - if err != nil { - t.Fatalf("got error: %v", err) - } - if d != v { - t.Fatalf("expected is[%d] to be %d but got: %d", i, d, v) + require.NoError(t, f.Parse([]string{})) + + for i, v := range is { + d, err := strconv.Atoi(vals[i]) + require.NoError(t, err) + require.Equalf(t, v, d, + "expected is[%d] to be %d but got: %d", i, v, d, + ) } - } - getIS, err := f.GetIntSlice("is") - if err != nil { - t.Fatal("got an error from GetIntSlice():", err) - } - for i, v := range getIS { - d, err := strconv.Atoi(vals[i]) - if err != nil { - t.Fatalf("got error: %v", err) + getIS, eri := f.GetIntSlice("is") + require.NoErrorf(t, eri, + "got an error from GetIntSlice(): %v", eri, + ) + + for i, v := range getIS { + d, err := strconv.Atoi(vals[i]) + require.NoError(t, err) + require.Equalf(t, v, d, + "expected is[%d] to be %d from GetIntSlice but got: %d", i, v, d, + ) } - if d != v { - t.Fatalf("expected is[%d] to be %d from GetIntSlice but got: %d", i, d, v) + }) + + t.Run("with defaults (2)", func(t *testing.T) { + vals := []string{"1", "2"} + is := make([]int, 0, len(vals)) + f := newFlagWithDefault(&is) + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf("--is=%s", strings.Join(vals, ",")), + })) + + for i, v := range is { + d, err := strconv.Atoi(vals[i]) + require.NoError(t, err) + require.Equalf(t, v, d, + "expected is[%d] to be %d but got: %d", i, v, d, + ) } - } -} -func TestISCalledTwice(t *testing.T) { - var is []int - f := setUpISFlagSet(&is) - - in := []string{"1,2", "3"} - expected := []int{1, 2, 3} - argfmt := "--is=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - err := f.Parse([]string{arg1, arg2}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range is { - if expected[i] != v { - t.Fatalf("expected is[%d] to be %d but got: %d", i, expected[i], v) + getIS, eri := f.GetIntSlice("is") + require.NoErrorf(t, eri, + "got an error from GetIntSlice(): %v", eri, + ) + + for i, v := range getIS { + d, err := strconv.Atoi(vals[i]) + require.NoError(t, err) + require.Equalf(t, v, d, + "expected is[%d] to be %d from GetIntSlice but got: %d", i, v, d, + ) } - } + }) + + t.Run("called twice", func(t *testing.T) { + const argfmt = "--is=%s" + in := []string{"1,2", "3"} + is := make([]int, 0, len(in)) + f := newFlag(&is) + expected := []int{1, 2, 3} + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf(argfmt, in[0]), + fmt.Sprintf(argfmt, in[1]), + })) + + require.Equal(t, expected, is) + }) } diff --git a/ip_slice.go b/ip_slice.go index 775faae4..11223900 100644 --- a/ip_slice.go +++ b/ip_slice.go @@ -72,7 +72,7 @@ func (s *ipSliceValue) String() string { return "[" + out + "]" } -func (s *ipSliceValue) fromString(val string) (net.IP, error) { +func (s *ipSliceValue) fromString(val string) (net.IP, error) { // nolint: unparam return net.ParseIP(strings.TrimSpace(val)), nil } diff --git a/ip_slice_test.go b/ip_slice_test.go index d1892768..de1b548a 100644 --- a/ip_slice_test.go +++ b/ip_slice_test.go @@ -5,241 +5,242 @@ import ( "net" "strings" "testing" -) -func setUpIPSFlagSet(ipsp *[]net.IP) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.IPSliceVar(ipsp, "ips", []net.IP{}, "Command separated list!") - return f -} + "github.com/stretchr/testify/require" +) -func setUpIPSFlagSetWithDefault(ipsp *[]net.IP) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.IPSliceVar(ipsp, "ips", - []net.IP{ - net.ParseIP("192.168.1.1"), - net.ParseIP("0:0:0:0:0:0:0:1"), - }, - "Command separated list!") - return f -} +func TestIPSlice(t *testing.T) { + t.Parallel() -func TestEmptyIP(t *testing.T) { - var ips []net.IP - f := setUpIPSFlagSet(&ips) - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) + newFlag := func(ipsp *[]net.IP) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.IPSliceVar(ipsp, "ips", []net.IP{}, "Command separated list!") + return f } - getIPS, err := f.GetIPSlice("ips") - if err != nil { - t.Fatal("got an error from GetIPSlice():", err) - } - if len(getIPS) != 0 { - t.Fatalf("got ips %v with len=%d but expected length=0", getIPS, len(getIPS)) - } -} + t.Run("with empty slice", func(t *testing.T) { + ips := make([]net.IP, 0) + f := newFlag(&ips) + require.NoError(t, f.Parse([]string{})) -func TestIPS(t *testing.T) { - var ips []net.IP - f := setUpIPSFlagSet(&ips) + getIPS, err := f.GetIPSlice("ips") + require.NoErrorf(t, err, + "got an error from GetIPSlice(): %v", err, + ) + require.Empty(t, getIPS) + }) - vals := []string{"192.168.1.1", "10.0.0.1", "0:0:0:0:0:0:0:2"} - arg := fmt.Sprintf("--ips=%s", strings.Join(vals, ",")) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range ips { - if ip := net.ParseIP(vals[i]); ip == nil { - t.Fatalf("invalid string being converted to IP address: %s", vals[i]) - } else if !ip.Equal(v) { - t.Fatalf("expected ips[%d] to be %s but got: %s from GetIPSlice", i, vals[i], v) + t.Run("with values", func(t *testing.T) { + vals := []string{"192.168.1.1", "10.0.0.1", "0:0:0:0:0:0:0:2"} + ips := make([]net.IP, 0, len(vals)) + f := newFlag(&ips) + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf("--ips=%s", strings.Join(vals, ",")), + })) + + for i, v := range ips { + ip := net.ParseIP(vals[i]) + require.NotNilf(t, ip, + "invalid string being converted to IP address: %s", vals[i], + ) + require.Truef(t, ip.Equal(v), + "expected ips[%d] to be %s but got: %s from GetIPSlice", i, vals[i], v, + ) } + }) + + newFlagWithDefault := func(ipsp *[]net.IP) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.IPSliceVar(ipsp, "ips", + []net.IP{ + net.ParseIP("192.168.1.1"), + net.ParseIP("0:0:0:0:0:0:0:1"), + }, + "Command separated list!") + return f } -} -func TestIPSDefault(t *testing.T) { - var ips []net.IP - f := setUpIPSFlagSetWithDefault(&ips) + t.Run("with defaults (1)", func(t *testing.T) { + vals := []string{"192.168.1.1", "0:0:0:0:0:0:0:1"} + ips := make([]net.IP, 0, len(vals)) + f := newFlagWithDefault(&ips) - vals := []string{"192.168.1.1", "0:0:0:0:0:0:0:1"} - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range ips { - if ip := net.ParseIP(vals[i]); ip == nil { - t.Fatalf("invalid string being converted to IP address: %s", vals[i]) - } else if !ip.Equal(v) { - t.Fatalf("expected ips[%d] to be %s but got: %s", i, vals[i], v) - } - } + require.NoError(t, f.Parse([]string{})) - getIPS, err := f.GetIPSlice("ips") - if err != nil { - t.Fatal("got an error from GetIPSlice") - } - for i, v := range getIPS { - if ip := net.ParseIP(vals[i]); ip == nil { - t.Fatalf("invalid string being converted to IP address: %s", vals[i]) - } else if !ip.Equal(v) { - t.Fatalf("expected ips[%d] to be %s but got: %s", i, vals[i], v) + for i, v := range ips { + ip := net.ParseIP(vals[i]) + require.NotNilf(t, ip, + "invalid string being converted to IP address: %s", vals[i], + ) + require.Truef(t, ip.Equal(v), + "expected ips[%d] to be %s but got: %s", i, vals[i], v, + ) } - } -} - -func TestIPSWithDefault(t *testing.T) { - var ips []net.IP - f := setUpIPSFlagSetWithDefault(&ips) - vals := []string{"192.168.1.1", "0:0:0:0:0:0:0:1"} - arg := fmt.Sprintf("--ips=%s", strings.Join(vals, ",")) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range ips { - if ip := net.ParseIP(vals[i]); ip == nil { - t.Fatalf("invalid string being converted to IP address: %s", vals[i]) - } else if !ip.Equal(v) { - t.Fatalf("expected ips[%d] to be %s but got: %s", i, vals[i], v) + getIPS, eri := f.GetIPSlice("ips") + require.NoErrorf(t, eri, + "got an error from GetIPSlice: %v", eri, + ) + + for i, v := range getIPS { + ip := net.ParseIP(vals[i]) + require.NotNilf(t, ip, + "invalid string being converted to IP address: %s", vals[i], + ) + require.Truef(t, ip.Equal(v), + "expected ips[%d] to be %s but got: %s", i, vals[i], v, + ) } - } + }) - getIPS, err := f.GetIPSlice("ips") - if err != nil { - t.Fatal("got an error from GetIPSlice") - } - for i, v := range getIPS { - if ip := net.ParseIP(vals[i]); ip == nil { - t.Fatalf("invalid string being converted to IP address: %s", vals[i]) - } else if !ip.Equal(v) { - t.Fatalf("expected ips[%d] to be %s but got: %s", i, vals[i], v) + t.Run("with defaults (2)", func(t *testing.T) { + vals := []string{"192.168.1.1", "0:0:0:0:0:0:0:1"} + ips := make([]net.IP, 0, len(vals)) + f := newFlagWithDefault(&ips) + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf("--ips=%s", strings.Join(vals, ",")), + })) + + for i, v := range ips { + ip := net.ParseIP(vals[i]) + require.NotNilf(t, ip, + "invalid string being converted to IP address: %s", vals[i], + ) + require.Truef(t, ip.Equal(v), + "expected ips[%d] to be %s but got: %s", i, vals[i], v, + ) } - } -} -func TestIPSCalledTwice(t *testing.T) { - var ips []net.IP - f := setUpIPSFlagSet(&ips) - - in := []string{"192.168.1.2,0:0:0:0:0:0:0:1", "10.0.0.1"} - expected := []net.IP{net.ParseIP("192.168.1.2"), net.ParseIP("0:0:0:0:0:0:0:1"), net.ParseIP("10.0.0.1")} - argfmt := "ips=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - err := f.Parse([]string{arg1, arg2}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range ips { - if !expected[i].Equal(v) { - t.Fatalf("expected ips[%d] to be %s but got: %s", i, expected[i], v) + getIPS, err := f.GetIPSlice("ips") + require.NoErrorf(t, err, + "got an error from GetIPSlice: %v", err, + ) + + for i, v := range getIPS { + ip := net.ParseIP(vals[i]) + require.NotNilf(t, ip, + "invalid string being converted to IP address: %s", vals[i], + ) + require.Truef(t, ip.Equal(v), + "expected ips[%d] to be %s but got: %s", i, vals[i], v, + ) } - } -} + }) -func TestIPSAsSliceValue(t *testing.T) { - var ips []net.IP - f := setUpIPSFlagSet(&ips) - - in := []string{"192.168.1.1", "0:0:0:0:0:0:0:1"} - argfmt := "--ips=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - err := f.Parse([]string{arg1, arg2}) - if err != nil { - t.Fatal("expected no error; got", err) - } + t.Run("called twice", func(t *testing.T) { + const argfmt = "--ips=%s" + in := []string{"192.168.1.2,0:0:0:0:0:0:0:1", "10.0.0.1"} + ips := make([]net.IP, 0, len(in)) + f := newFlag(&ips) + expected := []net.IP{net.ParseIP("192.168.1.2"), net.ParseIP("0:0:0:0:0:0:0:1"), net.ParseIP("10.0.0.1")} - f.VisitAll(func(f *Flag) { - if val, ok := f.Value.(SliceValue); ok { - _ = val.Replace([]string{"192.168.1.2"}) - } + require.NoError(t, f.Parse([]string{ + fmt.Sprintf(argfmt, in[0]), + fmt.Sprintf(argfmt, in[1]), + })) + + require.Equal(t, expected, ips) }) - if len(ips) != 1 || !ips[0].Equal(net.ParseIP("192.168.1.2")) { - t.Fatalf("Expected ss to be overwritten with '192.168.1.2', but got: %v", ips) - } -} -func TestIPSBadQuoting(t *testing.T) { - - tests := []struct { - Want []net.IP - FlagArg []string - }{ - { - Want: []net.IP{ - net.ParseIP("a4ab:61d:f03e:5d7d:fad7:d4c2:a1a5:568"), - net.ParseIP("203.107.49.208"), - net.ParseIP("14.57.204.90"), - }, - FlagArg: []string{ - "a4ab:61d:f03e:5d7d:fad7:d4c2:a1a5:568", - "203.107.49.208", - "14.57.204.90", - }, - }, - { - Want: []net.IP{ - net.ParseIP("204.228.73.195"), - net.ParseIP("86.141.15.94"), - }, - FlagArg: []string{ - "204.228.73.195", - "86.141.15.94", - }, - }, - { - Want: []net.IP{ - net.ParseIP("c70c:db36:3001:890f:c6ea:3f9b:7a39:cc3f"), - net.ParseIP("4d17:1d6e:e699:bd7a:88c5:5e7e:ac6a:4472"), + t.Run("as slice value", func(t *testing.T) { + const argfmt = "--ips=%s" + in := []string{"192.168.1.1", "0:0:0:0:0:0:0:1"} + ips := make([]net.IP, 0, len(in)) + f := newFlag(&ips) + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf(argfmt, in[0]), + fmt.Sprintf(argfmt, in[1]), + })) + + f.VisitAll(func(f *Flag) { + if val, ok := f.Value.(SliceValue); ok { + require.NoError(t, val.Replace([]string{"192.168.1.2"})) + } + }) + + require.Equalf(t, []net.IP{net.ParseIP("192.168.1.2")}, ips, + "expected ss to be overwritten with '192.168.1.2', but got: %v", ips, + ) + }) + + t.Run("bad quoting", func(t *testing.T) { + tests := []struct { + Want []net.IP + FlagArg []string + }{ + { + Want: []net.IP{ + net.ParseIP("a4ab:61d:f03e:5d7d:fad7:d4c2:a1a5:568"), + net.ParseIP("203.107.49.208"), + net.ParseIP("14.57.204.90"), + }, + FlagArg: []string{ + "a4ab:61d:f03e:5d7d:fad7:d4c2:a1a5:568", + "203.107.49.208", + "14.57.204.90", + }, }, - FlagArg: []string{ - "c70c:db36:3001:890f:c6ea:3f9b:7a39:cc3f", - "4d17:1d6e:e699:bd7a:88c5:5e7e:ac6a:4472", + { + Want: []net.IP{ + net.ParseIP("204.228.73.195"), + net.ParseIP("86.141.15.94"), + }, + FlagArg: []string{ + "204.228.73.195", + "86.141.15.94", + }, }, - }, - { - Want: []net.IP{ - net.ParseIP("5170:f971:cfac:7be3:512a:af37:952c:bc33"), - net.ParseIP("93.21.145.140"), - net.ParseIP("2cac:61d3:c5ff:6caf:73e0:1b1a:c336:c1ca"), + { + Want: []net.IP{ + net.ParseIP("c70c:db36:3001:890f:c6ea:3f9b:7a39:cc3f"), + net.ParseIP("4d17:1d6e:e699:bd7a:88c5:5e7e:ac6a:4472"), + }, + FlagArg: []string{ + "c70c:db36:3001:890f:c6ea:3f9b:7a39:cc3f", + "4d17:1d6e:e699:bd7a:88c5:5e7e:ac6a:4472", + }, }, - FlagArg: []string{ - " 5170:f971:cfac:7be3:512a:af37:952c:bc33 , 93.21.145.140 ", - "2cac:61d3:c5ff:6caf:73e0:1b1a:c336:c1ca", + { + Want: []net.IP{ + net.ParseIP("5170:f971:cfac:7be3:512a:af37:952c:bc33"), + net.ParseIP("93.21.145.140"), + net.ParseIP("2cac:61d3:c5ff:6caf:73e0:1b1a:c336:c1ca"), + }, + FlagArg: []string{ + " 5170:f971:cfac:7be3:512a:af37:952c:bc33 , 93.21.145.140 ", + "2cac:61d3:c5ff:6caf:73e0:1b1a:c336:c1ca", + }, }, - }, - { - Want: []net.IP{ - net.ParseIP("2e5e:66b2:6441:848:5b74:76ea:574c:3a7b"), - net.ParseIP("2e5e:66b2:6441:848:5b74:76ea:574c:3a7b"), - net.ParseIP("2e5e:66b2:6441:848:5b74:76ea:574c:3a7b"), - net.ParseIP("2e5e:66b2:6441:848:5b74:76ea:574c:3a7b"), + { + Want: []net.IP{ + net.ParseIP("2e5e:66b2:6441:848:5b74:76ea:574c:3a7b"), + net.ParseIP("2e5e:66b2:6441:848:5b74:76ea:574c:3a7b"), + net.ParseIP("2e5e:66b2:6441:848:5b74:76ea:574c:3a7b"), + net.ParseIP("2e5e:66b2:6441:848:5b74:76ea:574c:3a7b"), + }, + FlagArg: []string{ + `"2e5e:66b2:6441:848:5b74:76ea:574c:3a7b, 2e5e:66b2:6441:848:5b74:76ea:574c:3a7b,2e5e:66b2:6441:848:5b74:76ea:574c:3a7b "`, + " 2e5e:66b2:6441:848:5b74:76ea:574c:3a7b"}, }, - FlagArg: []string{ - `"2e5e:66b2:6441:848:5b74:76ea:574c:3a7b, 2e5e:66b2:6441:848:5b74:76ea:574c:3a7b,2e5e:66b2:6441:848:5b74:76ea:574c:3a7b "`, - " 2e5e:66b2:6441:848:5b74:76ea:574c:3a7b"}, - }, - } - - for i, test := range tests { + } - var ips []net.IP - f := setUpIPSFlagSet(&ips) + for i, test := range tests { + var ips []net.IP + f := newFlag(&ips) - if err := f.Parse([]string{fmt.Sprintf("--ips=%s", strings.Join(test.FlagArg, ","))}); err != nil { - t.Fatalf("flag parsing failed with error: %s\nparsing:\t%#v\nwant:\t\t%s", - err, test.FlagArg, test.Want[i]) - } + if err := f.Parse([]string{fmt.Sprintf("--ips=%s", strings.Join(test.FlagArg, ","))}); err != nil { + t.Fatalf("flag parsing failed with error: %s\nparsing:\t%#v\nwant:\t\t%s", + err, test.FlagArg, test.Want[i]) + } - for j, b := range ips { - if !b.Equal(test.Want[j]) { - t.Fatalf("bad value parsed for test %d on net.IP %d:\nwant:\t%s\ngot:\t%s", i, j, test.Want[j], b) + for j, b := range ips { + if !b.Equal(test.Want[j]) { + t.Fatalf("bad value parsed for test %d on net.IP %d:\nwant:\t%s\ngot:\t%s", i, j, test.Want[j], b) + } } } - } + }) } diff --git a/ip_test.go b/ip_test.go index 1fec50e4..113e157d 100644 --- a/ip_test.go +++ b/ip_test.go @@ -5,15 +5,17 @@ import ( "net" "os" "testing" -) -func setUpIP(ip *net.IP) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.IPVar(ip, "address", net.ParseIP("0.0.0.0"), "IP Address") - return f -} + "github.com/stretchr/testify/require" +) func TestIP(t *testing.T) { + newFlag := func(ip *net.IP) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.IPVar(ip, "address", net.ParseIP("0.0.0.0"), "IP Address") + return f + } + testCases := []struct { input string success bool @@ -36,28 +38,27 @@ func TestIP(t *testing.T) { devnull, _ := os.Open(os.DevNull) os.Stderr = devnull + for i := range testCases { var addr net.IP - f := setUpIP(&addr) - + f := newFlag(&addr) tc := &testCases[i] - arg := fmt.Sprintf("--address=%s", tc.input) - err := f.Parse([]string{arg}) - if err != nil && tc.success == true { - t.Errorf("expected success, got %q", err) - continue - } else if err == nil && tc.success == false { - t.Errorf("expected failure") + err := f.Parse([]string{ + fmt.Sprintf("--address=%s", tc.input), + }) + if !tc.success { + require.Errorf(t, err, "expected failure") + continue - } else if tc.success { - ip, err := f.GetIP("address") - if err != nil { - t.Errorf("Got error trying to fetch the IP flag: %v", err) - } - if ip.String() != tc.expected { - t.Errorf("expected %q, got %q", tc.expected, ip.String()) - } } + + require.NoErrorf(t, err, "expected success, got %q", err) + + ip, err := f.GetIP("address") + require.NoErrorf(t, err, + "got error trying to fetch the IP flag: %v", err, + ) + require.Equal(t, tc.expected, ip.String()) } } diff --git a/ipnet_slice_test.go b/ipnet_slice_test.go index 11644c58..749f8409 100644 --- a/ipnet_slice_test.go +++ b/ipnet_slice_test.go @@ -5,235 +5,232 @@ import ( "net" "strings" "testing" -) - -// Helper function to set static slices -func getCIDR(ip net.IP, cidr *net.IPNet, err error) net.IPNet { - return *cidr -} - -func equalCIDR(c1 net.IPNet, c2 net.IPNet) bool { - if c1.String() == c2.String() { - return true - } - return false -} - -func setUpIPNetFlagSet(ipsp *[]net.IPNet) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.IPNetSliceVar(ipsp, "cidrs", []net.IPNet{}, "Command separated list!") - return f -} - -func setUpIPNetFlagSetWithDefault(ipsp *[]net.IPNet) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.IPNetSliceVar(ipsp, "cidrs", - []net.IPNet{ - getCIDR(net.ParseCIDR("192.168.1.1/16")), - getCIDR(net.ParseCIDR("fd00::/64")), - }, - "Command separated list!") - return f -} - -func TestEmptyIPNet(t *testing.T) { - var cidrs []net.IPNet - f := setUpIPNetFlagSet(&cidrs) - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) - } - getIPNet, err := f.GetIPNetSlice("cidrs") - if err != nil { - t.Fatal("got an error from GetIPNetSlice():", err) - } - if len(getIPNet) != 0 { - t.Fatalf("got ips %v with len=%d but expected length=0", getIPNet, len(getIPNet)) - } -} - -func TestIPNets(t *testing.T) { - var ips []net.IPNet - f := setUpIPNetFlagSet(&ips) + "github.com/stretchr/testify/require" +) - vals := []string{"192.168.1.1/24", "10.0.0.1/16", "fd00:0:0:0:0:0:0:2/64"} - arg := fmt.Sprintf("--cidrs=%s", strings.Join(vals, ",")) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range ips { - if _, cidr, _ := net.ParseCIDR(vals[i]); cidr == nil { - t.Fatalf("invalid string being converted to CIDR: %s", vals[i]) - } else if !equalCIDR(*cidr, v) { - t.Fatalf("expected ips[%d] to be %s but got: %s from GetIPSlice", i, vals[i], v) +func TestIPNetSlice(t *testing.T) { + t.Parallel() + + newFlag := func(ipsp *[]net.IPNet) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.IPNetSliceVar(ipsp, "cidrs", []net.IPNet{}, "Command separated list!") + return f + } + + t.Run("with empty slice", func(t *testing.T) { + cidrs := make([]net.IPNet, 0) + f := newFlag(&cidrs) + require.NoError(t, f.Parse([]string{})) + + getIPNet, err := f.GetIPNetSlice("cidrs") + require.NoErrorf(t, err, + "got an error from GetIPNetSlice(): %v", err, + ) + require.Empty(t, getIPNet) + }) + + t.Run("with values", func(t *testing.T) { + vals := []string{"192.168.1.1/24", "10.0.0.1/16", "fd00:0:0:0:0:0:0:2/64"} + ips := make([]net.IPNet, 0, len(vals)) + f := newFlag(&ips) + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf("--cidrs=%s", strings.Join(vals, ",")), + })) + + for i, v := range ips { + _, cidr, _ := net.ParseCIDR(vals[i]) + require.NotNilf(t, cidr, + "invalid string being converted to CIDR: %s", vals[i], + ) + require.Truef(t, equalCIDR(*cidr, v), + "expected ips[%d] to be %s but got: %s from GetIPSlice", i, vals[i], v, + ) } + }) + + newFlagWithDefault := func(ipsp *[]net.IPNet) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.IPNetSliceVar(ipsp, "cidrs", + []net.IPNet{ + getCIDR(net.ParseCIDR("192.168.1.1/16")), + getCIDR(net.ParseCIDR("fd00::/64")), + }, + "Command separated list!") + return f } -} -func TestIPNetDefault(t *testing.T) { - var cidrs []net.IPNet - f := setUpIPNetFlagSetWithDefault(&cidrs) + t.Run("with defaults (1)", func(t *testing.T) { + vals := []string{"192.168.1.1/16", "fd00::/64"} + cidrs := make([]net.IPNet, 0, len(vals)) + f := newFlagWithDefault(&cidrs) - vals := []string{"192.168.1.1/16", "fd00::/64"} - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range cidrs { - if _, cidr, _ := net.ParseCIDR(vals[i]); cidr == nil { - t.Fatalf("invalid string being converted to CIDR: %s", vals[i]) - } else if !equalCIDR(*cidr, v) { - t.Fatalf("expected cidrs[%d] to be %s but got: %s", i, vals[i], v) - } - } + require.NoError(t, f.Parse([]string{})) - getIPNet, err := f.GetIPNetSlice("cidrs") - if err != nil { - t.Fatal("got an error from GetIPNetSlice") - } - for i, v := range getIPNet { - if _, cidr, _ := net.ParseCIDR(vals[i]); cidr == nil { - t.Fatalf("invalid string being converted to CIDR: %s", vals[i]) - } else if !equalCIDR(*cidr, v) { - t.Fatalf("expected cidrs[%d] to be %s but got: %s", i, vals[i], v) + for i, v := range cidrs { + _, cidr, _ := net.ParseCIDR(vals[i]) + require.NotNilf(t, cidr, + "invalid string being converted to CIDR: %s", vals[i], + ) + require.Truef(t, equalCIDR(*cidr, v), + "expected cidrs[%d] to be %s but got: %s", i, vals[i], v, + ) } - } -} - -func TestIPNetWithDefault(t *testing.T) { - var cidrs []net.IPNet - f := setUpIPNetFlagSetWithDefault(&cidrs) - vals := []string{"192.168.1.1/16", "fd00::/64"} - arg := fmt.Sprintf("--cidrs=%s", strings.Join(vals, ",")) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range cidrs { - if _, cidr, _ := net.ParseCIDR(vals[i]); cidr == nil { - t.Fatalf("invalid string being converted to CIDR: %s", vals[i]) - } else if !equalCIDR(*cidr, v) { - t.Fatalf("expected cidrs[%d] to be %s but got: %s", i, vals[i], v) + getIPNet, err := f.GetIPNetSlice("cidrs") + require.NoErrorf(t, err, + "got an error from GetIPNetSlice: %v", err, + ) + + for i, v := range getIPNet { + _, cidr, _ := net.ParseCIDR(vals[i]) + require.NotNilf(t, cidr, + "invalid string being converted to CIDR: %s", vals[i], + ) + require.Truef(t, equalCIDR(*cidr, v), + "expected cidrs[%d] to be %s but got: %s", i, vals[i], v, + ) } - } - - getIPNet, err := f.GetIPNetSlice("cidrs") - if err != nil { - t.Fatal("got an error from GetIPNetSlice") - } - for i, v := range getIPNet { - if _, cidr, _ := net.ParseCIDR(vals[i]); cidr == nil { - t.Fatalf("invalid string being converted to CIDR: %s", vals[i]) - } else if !equalCIDR(*cidr, v) { - t.Fatalf("expected cidrs[%d] to be %s but got: %s", i, vals[i], v) + }) + + t.Run("with defaults (2)", func(t *testing.T) { + vals := []string{"192.168.1.1/16", "fd00::/64"} + cidrs := make([]net.IPNet, 0, len(vals)) + f := newFlagWithDefault(&cidrs) + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf("--cidrs=%s", strings.Join(vals, ",")), + })) + + for i, v := range cidrs { + _, cidr, _ := net.ParseCIDR(vals[i]) + require.NotNilf(t, cidr, + "invalid string being converted to CIDR: %s", vals[i], + ) + require.Truef(t, equalCIDR(*cidr, v), + "expected cidrs[%d] to be %s but got: %s", i, vals[i], v, + ) } - } -} - -func TestIPNetCalledTwice(t *testing.T) { - var cidrs []net.IPNet - f := setUpIPNetFlagSet(&cidrs) - - in := []string{"192.168.1.2/16,fd00::/64", "10.0.0.1/24"} - expected := []net.IPNet{ - getCIDR(net.ParseCIDR("192.168.1.2/16")), - getCIDR(net.ParseCIDR("fd00::/64")), - getCIDR(net.ParseCIDR("10.0.0.1/24")), - } - argfmt := "--cidrs=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - err := f.Parse([]string{arg1, arg2}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range cidrs { - if !equalCIDR(expected[i], v) { - t.Fatalf("expected cidrs[%d] to be %s but got: %s", i, expected[i], v) + getIPNet, err := f.GetIPNetSlice("cidrs") + require.NoErrorf(t, err, + "got an error from GetIPNetSlice: %v", err, + ) + + for i, v := range getIPNet { + _, cidr, _ := net.ParseCIDR(vals[i]) + require.NotNilf(t, cidr, + "invalid string being converted to CIDR: %s", vals[i], + ) + require.Truef(t, equalCIDR(*cidr, v), + "expected cidrs[%d] to be %s but got: %s", i, vals[i], v, + ) + } + }) + + t.Run("called twice", func(t *testing.T) { + const argfmt = "--cidrs=%s" + in := []string{"192.168.1.2/16,fd00::/64", "10.0.0.1/24"} + cidrs := make([]net.IPNet, 0, len(in)) + f := newFlag(&cidrs) + expected := []net.IPNet{ + getCIDR(net.ParseCIDR("192.168.1.2/16")), + getCIDR(net.ParseCIDR("fd00::/64")), + getCIDR(net.ParseCIDR("10.0.0.1/24")), } - } -} -func TestIPNetBadQuoting(t *testing.T) { - - tests := []struct { - Want []net.IPNet - FlagArg []string - }{ - { - Want: []net.IPNet{ - getCIDR(net.ParseCIDR("a4ab:61d:f03e:5d7d:fad7:d4c2:a1a5:568/128")), - getCIDR(net.ParseCIDR("203.107.49.208/32")), - getCIDR(net.ParseCIDR("14.57.204.90/32")), - }, - FlagArg: []string{ - "a4ab:61d:f03e:5d7d:fad7:d4c2:a1a5:568/128", - "203.107.49.208/32", - "14.57.204.90/32", + require.NoError(t, f.Parse([]string{ + fmt.Sprintf(argfmt, in[0]), + fmt.Sprintf(argfmt, in[1]), + })) + + require.Equal(t, expected, cidrs) + }) + + t.Run("bad quoting", func(t *testing.T) { + tests := []struct { + Want []net.IPNet + FlagArg []string + }{ + { + Want: []net.IPNet{ + getCIDR(net.ParseCIDR("a4ab:61d:f03e:5d7d:fad7:d4c2:a1a5:568/128")), + getCIDR(net.ParseCIDR("203.107.49.208/32")), + getCIDR(net.ParseCIDR("14.57.204.90/32")), + }, + FlagArg: []string{ + "a4ab:61d:f03e:5d7d:fad7:d4c2:a1a5:568/128", + "203.107.49.208/32", + "14.57.204.90/32", + }, }, - }, - { - Want: []net.IPNet{ - getCIDR(net.ParseCIDR("204.228.73.195/32")), - getCIDR(net.ParseCIDR("86.141.15.94/32")), + { + Want: []net.IPNet{ + getCIDR(net.ParseCIDR("204.228.73.195/32")), + getCIDR(net.ParseCIDR("86.141.15.94/32")), + }, + FlagArg: []string{ + "204.228.73.195/32", + "86.141.15.94/32", + }, }, - FlagArg: []string{ - "204.228.73.195/32", - "86.141.15.94/32", + { + Want: []net.IPNet{ + getCIDR(net.ParseCIDR("c70c:db36:3001:890f:c6ea:3f9b:7a39:cc3f/128")), + getCIDR(net.ParseCIDR("4d17:1d6e:e699:bd7a:88c5:5e7e:ac6a:4472/128")), + }, + FlagArg: []string{ + "c70c:db36:3001:890f:c6ea:3f9b:7a39:cc3f/128", + "4d17:1d6e:e699:bd7a:88c5:5e7e:ac6a:4472/128", + }, }, - }, - { - Want: []net.IPNet{ - getCIDR(net.ParseCIDR("c70c:db36:3001:890f:c6ea:3f9b:7a39:cc3f/128")), - getCIDR(net.ParseCIDR("4d17:1d6e:e699:bd7a:88c5:5e7e:ac6a:4472/128")), + { + Want: []net.IPNet{ + getCIDR(net.ParseCIDR("5170:f971:cfac:7be3:512a:af37:952c:bc33/128")), + getCIDR(net.ParseCIDR("93.21.145.140/32")), + getCIDR(net.ParseCIDR("2cac:61d3:c5ff:6caf:73e0:1b1a:c336:c1ca/128")), + }, + FlagArg: []string{ + " 5170:f971:cfac:7be3:512a:af37:952c:bc33/128 , 93.21.145.140/32 ", + "2cac:61d3:c5ff:6caf:73e0:1b1a:c336:c1ca/128", + }, }, - FlagArg: []string{ - "c70c:db36:3001:890f:c6ea:3f9b:7a39:cc3f/128", - "4d17:1d6e:e699:bd7a:88c5:5e7e:ac6a:4472/128", + { + Want: []net.IPNet{ + getCIDR(net.ParseCIDR("2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128")), + getCIDR(net.ParseCIDR("2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128")), + getCIDR(net.ParseCIDR("2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128")), + getCIDR(net.ParseCIDR("2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128")), + }, + FlagArg: []string{ + `"2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128, 2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128,2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128 "`, + " 2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128"}, }, - }, - { - Want: []net.IPNet{ - getCIDR(net.ParseCIDR("5170:f971:cfac:7be3:512a:af37:952c:bc33/128")), - getCIDR(net.ParseCIDR("93.21.145.140/32")), - getCIDR(net.ParseCIDR("2cac:61d3:c5ff:6caf:73e0:1b1a:c336:c1ca/128")), - }, - FlagArg: []string{ - " 5170:f971:cfac:7be3:512a:af37:952c:bc33/128 , 93.21.145.140/32 ", - "2cac:61d3:c5ff:6caf:73e0:1b1a:c336:c1ca/128", - }, - }, - { - Want: []net.IPNet{ - getCIDR(net.ParseCIDR("2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128")), - getCIDR(net.ParseCIDR("2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128")), - getCIDR(net.ParseCIDR("2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128")), - getCIDR(net.ParseCIDR("2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128")), - }, - FlagArg: []string{ - `"2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128, 2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128,2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128 "`, - " 2e5e:66b2:6441:848:5b74:76ea:574c:3a7b/128"}, - }, - } - - for i, test := range tests { + } - var cidrs []net.IPNet - f := setUpIPNetFlagSet(&cidrs) + for i, test := range tests { + cidrs := make([]net.IPNet, 0, len(test.Want)) + f := newFlag(&cidrs) - if err := f.Parse([]string{fmt.Sprintf("--cidrs=%s", strings.Join(test.FlagArg, ","))}); err != nil { - t.Fatalf("flag parsing failed with error: %s\nparsing:\t%#v\nwant:\t\t%s", - err, test.FlagArg, test.Want[i]) - } + require.NoErrorf(t, f.Parse([]string{fmt.Sprintf("--cidrs=%s", strings.Join(test.FlagArg, ","))}), + "flag parsing failed with error:\nparsing:\t%#v\nwant:\t\t%s", + test.FlagArg, test.Want, + ) - for j, b := range cidrs { - if !equalCIDR(b, test.Want[j]) { - t.Fatalf("bad value parsed for test %d on net.IP %d:\nwant:\t%s\ngot:\t%s", i, j, test.Want[j], b) + for j, b := range cidrs { + require.Truef(t, equalCIDR(b, test.Want[j]), + "bad value parsed for test %d on net.IP %d:\nwant:\t%s\ngot:\t%s", i, j, test.Want[j], b, + ) } } - } + }) +} + +func getCIDR(_ net.IP, cidr *net.IPNet, _ error) net.IPNet { + return *cidr +} + +func equalCIDR(c1 net.IPNet, c2 net.IPNet) bool { + return c1.String() == c2.String() } diff --git a/ipnet_test.go b/ipnet_test.go index 335b6fa1..f4c7c12a 100644 --- a/ipnet_test.go +++ b/ipnet_test.go @@ -5,16 +5,18 @@ import ( "net" "os" "testing" -) -func setUpIPNet(ip *net.IPNet) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - _, def, _ := net.ParseCIDR("0.0.0.0/0") - f.IPNetVar(ip, "address", *def, "IP Address") - return f -} + "github.com/stretchr/testify/require" +) func TestIPNet(t *testing.T) { + newFlag := func(ip *net.IPNet) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + _, def, _ := net.ParseCIDR("0.0.0.0/0") + f.IPNetVar(ip, "address", *def, "IP Address") + return f + } + testCases := []struct { input string success bool @@ -43,28 +45,27 @@ func TestIPNet(t *testing.T) { devnull, _ := os.Open(os.DevNull) os.Stderr = devnull + for i := range testCases { var addr net.IPNet - f := setUpIPNet(&addr) - + f := newFlag(&addr) tc := &testCases[i] - arg := fmt.Sprintf("--address=%s", tc.input) - err := f.Parse([]string{arg}) - if err != nil && tc.success == true { - t.Errorf("expected success, got %q", err) - continue - } else if err == nil && tc.success == false { - t.Errorf("expected failure") + err := f.Parse([]string{ + fmt.Sprintf("--address=%s", tc.input), + }) + if !tc.success { + require.Errorf(t, err, "expected failure") + continue - } else if tc.success { - ip, err := f.GetIPNet("address") - if err != nil { - t.Errorf("Got error trying to fetch the IP flag: %v", err) - } - if ip.String() != tc.expected { - t.Errorf("expected %q, got %q", tc.expected, ip.String()) - } } + + require.NoErrorf(t, err, "expected success, got %q", err) + + ip, err := f.GetIPNet("address") + require.NoErrorf(t, err, + "got error trying to fetch the IPnet flag: %v", err, + ) + require.Equal(t, tc.expected, ip.String()) } } diff --git a/printusage_test.go b/printusage_test.go index df982aab..d12a9a5c 100644 --- a/printusage_test.go +++ b/printusage_test.go @@ -1,12 +1,21 @@ package pflag import ( - "bytes" - "io" "testing" + + "github.com/stretchr/testify/require" ) -const expectedOutput = ` --long-form Some description +func TestPrintUsage(t *testing.T) { + t.Run("with print", func(t *testing.T) { + f := NewFlagSet("test", ExitOnError) + + f.Bool("long-form", false, "Some description") + f.Bool("long-form2", false, "Some description\n with multiline") + f.BoolP("long-name", "s", false, "Some description") + f.BoolP("long-name2", "t", false, "Some description with\n multiline") + + const expectedOutput = ` --long-form Some description --long-form2 Some description with multiline -s, --long-name Some description @@ -14,40 +23,23 @@ const expectedOutput = ` --long-form Some description multiline ` -func setUpPFlagSet(buf io.Writer) *FlagSet { - f := NewFlagSet("test", ExitOnError) - f.Bool("long-form", false, "Some description") - f.Bool("long-form2", false, "Some description\n with multiline") - f.BoolP("long-name", "s", false, "Some description") - f.BoolP("long-name2", "t", false, "Some description with\n multiline") - f.SetOutput(buf) - return f -} + require.Equal(t, expectedOutput, printFlagDefaults(f)) + }) -func TestPrintUsage(t *testing.T) { - buf := bytes.Buffer{} - f := setUpPFlagSet(&buf) - f.PrintDefaults() - res := buf.String() - if res != expectedOutput { - t.Errorf("Expected \n%s \nActual \n%s", expectedOutput, res) - } -} + t.Run("with wrapped columns", func(t *testing.T) { + const cols = 80 -func setUpPFlagSet2(buf io.Writer) *FlagSet { - f := NewFlagSet("test", ExitOnError) - f.Bool("long-form", false, "Some description") - f.Bool("long-form2", false, "Some description\n with multiline") - f.BoolP("long-name", "s", false, "Some description") - f.BoolP("long-name2", "t", false, "Some description with\n multiline") - f.StringP("some-very-long-arg", "l", "test", "Some very long description having break the limit") - f.StringP("other-very-long-arg", "o", "long-default-value", "Some very long description having break the limit") - f.String("some-very-long-arg2", "very long default value", "Some very long description\nwith line break\nmultiple") - f.SetOutput(buf) - return f -} + f := NewFlagSet("test", ExitOnError) + + f.Bool("long-form", false, "Some description") + f.Bool("long-form2", false, "Some description\n with multiline") + f.BoolP("long-name", "s", false, "Some description") + f.BoolP("long-name2", "t", false, "Some description with\n multiline") + f.StringP("some-very-long-arg", "l", "test", "Some very long description having break the limit") + f.StringP("other-very-long-arg", "o", "long-default-value", "Some very long description having break the limit") + f.String("some-very-long-arg2", "very long default value", "Some very long description\nwith line break\nmultiple") -const expectedOutput2 = ` --long-form Some description + const expectedOutput = ` --long-form Some description --long-form2 Some description with multiline -s, --long-name Some description @@ -64,11 +56,6 @@ const expectedOutput2 = ` --long-form Some description value") ` -func TestPrintUsage_2(t *testing.T) { - buf := bytes.Buffer{} - f := setUpPFlagSet2(&buf) - res := f.FlagUsagesWrapped(80) - if res != expectedOutput2 { - t.Errorf("Expected \n%q \nActual \n%q", expectedOutput2, res) - } + require.Equal(t, expectedOutput, f.FlagUsagesWrapped(cols)) + }) } diff --git a/string_array.go b/string_array.go index d1ff0a96..5b67af38 100644 --- a/string_array.go +++ b/string_array.go @@ -30,18 +30,14 @@ func (s *stringArrayValue) Append(val string) error { func (s *stringArrayValue) Replace(val []string) error { out := make([]string, len(val)) - for i, d := range val { - out[i] = d - } + copy(out, val) *s.value = out return nil } func (s *stringArrayValue) GetSlice() []string { out := make([]string, len(*s.value)) - for i, d := range *s.value { - out[i] = d - } + copy(out, *s.value) return out } @@ -55,6 +51,9 @@ func (s *stringArrayValue) String() string { } func stringArrayConv(sval string) (interface{}, error) { + if len(sval) == 0 { + return []string{}, nil + } sval = sval[1 : len(sval)-1] // An empty string would cause a array with one (empty) string if len(sval) == 0 { diff --git a/string_array_test.go b/string_array_test.go index 3c6d5958..89abc09d 100644 --- a/string_array_test.go +++ b/string_array_test.go @@ -7,250 +7,161 @@ package pflag import ( "fmt" "testing" -) - -func setUpSAFlagSet(sap *[]string) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.StringArrayVar(sap, "sa", []string{}, "Command separated list!") - return f -} - -func setUpSAFlagSetWithDefault(sap *[]string) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.StringArrayVar(sap, "sa", []string{"default", "values"}, "Command separated list!") - return f -} - -func TestEmptySA(t *testing.T) { - var sa []string - f := setUpSAFlagSet(&sa) - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) - } - - getSA, err := f.GetStringArray("sa") - if err != nil { - t.Fatal("got an error from GetStringArray():", err) - } - if len(getSA) != 0 { - t.Fatalf("got sa %v with len=%d but expected length=0", getSA, len(getSA)) - } -} - -func TestEmptySAValue(t *testing.T) { - var sa []string - f := setUpSAFlagSet(&sa) - err := f.Parse([]string{"--sa="}) - if err != nil { - t.Fatal("expected no error; got", err) - } - - getSA, err := f.GetStringArray("sa") - if err != nil { - t.Fatal("got an error from GetStringArray():", err) - } - if len(getSA) != 0 { - t.Fatalf("got sa %v with len=%d but expected length=0", getSA, len(getSA)) - } -} - -func TestSADefault(t *testing.T) { - var sa []string - f := setUpSAFlagSetWithDefault(&sa) - vals := []string{"default", "values"} - - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range sa { - if vals[i] != v { - t.Fatalf("expected sa[%d] to be %s but got: %s", i, vals[i], v) - } - } - - getSA, err := f.GetStringArray("sa") - if err != nil { - t.Fatal("got an error from GetStringArray():", err) - } - for i, v := range getSA { - if vals[i] != v { - t.Fatalf("expected sa[%d] to be %s from GetStringArray but got: %s", i, vals[i], v) - } - } -} - -func TestSAWithDefault(t *testing.T) { - var sa []string - f := setUpSAFlagSetWithDefault(&sa) - - val := "one" - arg := fmt.Sprintf("--sa=%s", val) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - - if len(sa) != 1 { - t.Fatalf("expected number of values to be %d but %d", 1, len(sa)) - } - - if sa[0] != val { - t.Fatalf("expected value to be %s but got: %s", sa[0], val) - } - - getSA, err := f.GetStringArray("sa") - if err != nil { - t.Fatal("got an error from GetStringArray():", err) - } + "github.com/stretchr/testify/require" +) - if len(getSA) != 1 { - t.Fatalf("expected number of values to be %d but %d", 1, len(getSA)) - } +func TestStringArray(t *testing.T) { + t.Parallel() - if getSA[0] != val { - t.Fatalf("expected value to be %s but got: %s", getSA[0], val) + newFlag := func(sap *[]string) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.StringArrayVar(sap, "sa", []string{}, "Command separated list!") + return f } -} -func TestSACalledTwice(t *testing.T) { - var sa []string - f := setUpSAFlagSet(&sa) + t.Run("with empty slice", func(t *testing.T) { + sa := make([]string, 0) + f := newFlag(&sa) + require.NoError(t, f.Parse([]string{})) - in := []string{"one", "two"} - expected := []string{"one", "two"} - argfmt := "--sa=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - err := f.Parse([]string{arg1, arg2}) - if err != nil { - t.Fatal("expected no error; got", err) - } + getSA, err := f.GetStringArray("sa") + require.NoErrorf(t, err, + "got an error from GetStringArray(): %v", err, + ) + require.Empty(t, getSA) + }) - if len(expected) != len(sa) { - t.Fatalf("expected number of sa to be %d but got: %d", len(expected), len(sa)) - } - for i, v := range sa { - if expected[i] != v { - t.Fatalf("expected sa[%d] to be %s but got: %s", i, expected[i], v) - } - } + t.Run("with empty value", func(t *testing.T) { + sa := make([]string, 0) + f := newFlag(&sa) + require.NoError(t, f.Parse([]string{"--sa="})) - values, err := f.GetStringArray("sa") - if err != nil { - t.Fatal("expected no error; got", err) - } + getSA, err := f.GetStringArray("sa") + require.NoErrorf(t, err, + "got an error from GetStringArray(): %v", err, + ) + require.Empty(t, getSA) + }) - if len(expected) != len(values) { - t.Fatalf("expected number of values to be %d but got: %d", len(expected), len(sa)) - } - for i, v := range values { - if expected[i] != v { - t.Fatalf("expected got sa[%d] to be %s but got: %s", i, expected[i], v) - } + newFlagWithDefault := func(sap *[]string) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.StringArrayVar(sap, "sa", []string{"default", "values"}, "Command separated list!") + return f } -} -func TestSAWithSpecialChar(t *testing.T) { - var sa []string - f := setUpSAFlagSet(&sa) + t.Run("with default (1)", func(t *testing.T) { + vals := []string{"default", "values"} + sa := make([]string, 0, len(vals)) + f := newFlagWithDefault(&sa) - in := []string{"one,two", `"three"`, `"four,five",six`, "seven eight"} - expected := []string{"one,two", `"three"`, `"four,five",six`, "seven eight"} - argfmt := "--sa=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - arg3 := fmt.Sprintf(argfmt, in[2]) - arg4 := fmt.Sprintf(argfmt, in[3]) - err := f.Parse([]string{arg1, arg2, arg3, arg4}) - if err != nil { - t.Fatal("expected no error; got", err) - } + require.NoError(t, f.Parse([]string{})) + require.Equal(t, vals, sa) - if len(expected) != len(sa) { - t.Fatalf("expected number of sa to be %d but got: %d", len(expected), len(sa)) - } - for i, v := range sa { - if expected[i] != v { - t.Fatalf("expected sa[%d] to be %s but got: %s", i, expected[i], v) - } - } + getSA, err := f.GetStringArray("sa") + require.NoError(t, err) + require.Equal(t, vals, getSA) + }) - values, err := f.GetStringArray("sa") - if err != nil { - t.Fatal("expected no error; got", err) - } + t.Run("with default (2)", func(t *testing.T) { + val := "one" + sa := make([]string, 0, len(val)) + f := newFlagWithDefault(&sa) + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf("--sa=%s", val), + })) + require.Equal(t, []string{val}, sa) + + getSA, err := f.GetStringArray("sa") + require.NoErrorf(t, err, + "got an error from GetStringArray(): %v", err, + ) + require.Equal(t, []string{val}, getSA) + }) - if len(expected) != len(values) { - t.Fatalf("expected number of values to be %d but got: %d", len(expected), len(values)) - } - for i, v := range values { - if expected[i] != v { - t.Fatalf("expected got sa[%d] to be %s but got: %s", i, expected[i], v) - } - } -} + t.Run("called twice", func(t *testing.T) { + const argfmt = "--sa=%s" + in := []string{"one", "two"} + sa := make([]string, 0, len(in)) + f := newFlag(&sa) + expected := []string{"one", "two"} + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf(argfmt, in[0]), + fmt.Sprintf(argfmt, in[1]), + })) + require.Equal(t, expected, sa) + + values, err := f.GetStringArray("sa") + require.NoError(t, err) + require.Equal(t, expected, values) + }) -func TestSAAsSliceValue(t *testing.T) { - var sa []string - f := setUpSAFlagSet(&sa) + t.Run("with special char", func(t *testing.T) { + const argfmt = "--sa=%s" + in := []string{"one,two", `"three"`, `"four,five",six`, "seven eight"} + sa := make([]string, 0, len(in)) + f := newFlag(&sa) + expected := []string{"one,two", `"three"`, `"four,five",six`, "seven eight"} + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf(argfmt, in[0]), + fmt.Sprintf(argfmt, in[1]), + fmt.Sprintf(argfmt, in[2]), + fmt.Sprintf(argfmt, in[3]), + })) + require.Equal(t, expected, sa) + + values, err := f.GetStringArray("sa") + require.NoError(t, err) + require.Equal(t, expected, values) + }) - in := []string{"1ns", "2ns"} - argfmt := "--sa=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - err := f.Parse([]string{arg1, arg2}) - if err != nil { - t.Fatal("expected no error; got", err) - } + t.Run("with square bracket", func(t *testing.T) { + const argfmt = "--sa=%s" + in := []string{"][]-[", "[a-z]", "[a-z]+"} + sa := make([]string, 0, len(in)) + f := newFlag(&sa) + expected := []string{"][]-[", "[a-z]", "[a-z]+"} + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf(argfmt, in[0]), + fmt.Sprintf(argfmt, in[1]), + fmt.Sprintf(argfmt, in[2]), + })) + require.Equal(t, expected, sa) + + values, err := f.GetStringArray("sa") + require.NoError(t, err) + require.Equal(t, expected, values) + }) - f.VisitAll(func(f *Flag) { - if val, ok := f.Value.(SliceValue); ok { - _ = val.Replace([]string{"3ns"}) - } + t.Run("with slice as value", func(t *testing.T) { + const argfmt = "--sa=%s" + in := []string{"1ns", "2ns"} + sa := make([]string, 0, len(in)) + f := newFlag(&sa) + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf(argfmt, in[0]), + fmt.Sprintf(argfmt, in[1]), + })) + + f.VisitAll(func(f *Flag) { + if val, ok := f.Value.(SliceValue); ok { + require.NoError(t, val.Replace([]string{"3ns"})) + } + }) + require.Equalf(t, []string{"3ns"}, sa, + "expected ss to be overwritten with '3ns', but got: %v", sa, + ) }) - if len(sa) != 1 || sa[0] != "3ns" { - t.Fatalf("Expected ss to be overwritten with '3ns', but got: %v", sa) - } } -func TestSAWithSquareBrackets(t *testing.T) { - var sa []string - f := setUpSAFlagSet(&sa) - - in := []string{"][]-[", "[a-z]", "[a-z]+"} - expected := []string{"][]-[", "[a-z]", "[a-z]+"} - argfmt := "--sa=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - arg3 := fmt.Sprintf(argfmt, in[2]) - err := f.Parse([]string{arg1, arg2, arg3}) - if err != nil { - t.Fatal("expected no error; got", err) - } - - if len(expected) != len(sa) { - t.Fatalf("expected number of sa to be %d but got: %d", len(expected), len(sa)) - } - for i, v := range sa { - if expected[i] != v { - t.Fatalf("expected sa[%d] to be %s but got: %s", i, expected[i], v) - } - } - - values, err := f.GetStringArray("sa") - if err != nil { - t.Fatal("expected no error; got", err) - } - - if len(expected) != len(values) { - t.Fatalf("expected number of values to be %d but got: %d", len(expected), len(values)) - } - for i, v := range values { - if expected[i] != v { - t.Fatalf("expected got sa[%d] to be %s but got: %s", i, expected[i], v) - } - } +func TestStringArrayConv(t *testing.T) { + t.Run("with empty string", func(t *testing.T) { + _, err := stringArrayConv("") + require.NoError(t, err) + }) } diff --git a/string_slice.go b/string_slice.go index 3cb2e69d..d421887e 100644 --- a/string_slice.go +++ b/string_slice.go @@ -98,9 +98,12 @@ func (f *FlagSet) GetStringSlice(name string) ([]string, error) { // The argument p points to a []string variable in which to store the value of the flag. // Compared to StringArray flags, StringSlice flags take comma-separated value as arguments and split them accordingly. // For example: -// --ss="v1,v2" --ss="v3" +// +// --ss="v1,v2" --ss="v3" +// // will result in -// []string{"v1", "v2", "v3"} +// +// []string{"v1", "v2", "v3"} func (f *FlagSet) StringSliceVar(p *[]string, name string, value []string, usage string) { f.VarP(newStringSliceValue(value, p), name, "", usage) } @@ -114,9 +117,12 @@ func (f *FlagSet) StringSliceVarP(p *[]string, name, shorthand string, value []s // The argument p points to a []string variable in which to store the value of the flag. // Compared to StringArray flags, StringSlice flags take comma-separated value as arguments and split them accordingly. // For example: -// --ss="v1,v2" --ss="v3" +// +// --ss="v1,v2" --ss="v3" +// // will result in -// []string{"v1", "v2", "v3"} +// +// []string{"v1", "v2", "v3"} func StringSliceVar(p *[]string, name string, value []string, usage string) { CommandLine.VarP(newStringSliceValue(value, p), name, "", usage) } @@ -130,9 +136,12 @@ func StringSliceVarP(p *[]string, name, shorthand string, value []string, usage // The return value is the address of a []string variable that stores the value of the flag. // Compared to StringArray flags, StringSlice flags take comma-separated value as arguments and split them accordingly. // For example: -// --ss="v1,v2" --ss="v3" +// +// --ss="v1,v2" --ss="v3" +// // will result in -// []string{"v1", "v2", "v3"} +// +// []string{"v1", "v2", "v3"} func (f *FlagSet) StringSlice(name string, value []string, usage string) *[]string { p := []string{} f.StringSliceVarP(&p, name, "", value, usage) @@ -150,9 +159,12 @@ func (f *FlagSet) StringSliceP(name, shorthand string, value []string, usage str // The return value is the address of a []string variable that stores the value of the flag. // Compared to StringArray flags, StringSlice flags take comma-separated value as arguments and split them accordingly. // For example: -// --ss="v1,v2" --ss="v3" +// +// --ss="v1,v2" --ss="v3" +// // will result in -// []string{"v1", "v2", "v3"} +// +// []string{"v1", "v2", "v3"} func StringSlice(name string, value []string, usage string) *[]string { return CommandLine.StringSliceP(name, "", value, usage) } diff --git a/string_slice_test.go b/string_slice_test.go index 96924617..a980c9bc 100644 --- a/string_slice_test.go +++ b/string_slice_test.go @@ -8,269 +8,175 @@ import ( "fmt" "strings" "testing" -) -func setUpSSFlagSet(ssp *[]string) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.StringSliceVar(ssp, "ss", []string{}, "Command separated list!") - return f -} + "github.com/stretchr/testify/require" +) -func setUpSSFlagSetWithDefault(ssp *[]string) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.StringSliceVar(ssp, "ss", []string{"default", "values"}, "Command separated list!") - return f -} +func TestStringSlice(t *testing.T) { + t.Parallel() -func TestEmptySS(t *testing.T) { - var ss []string - f := setUpSSFlagSet(&ss) - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) + newFlag := func(ssp *[]string) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.StringSliceVar(ssp, "ss", []string{}, "Command separated list!") + return f } - getSS, err := f.GetStringSlice("ss") - if err != nil { - t.Fatal("got an error from GetStringSlice():", err) - } - if len(getSS) != 0 { - t.Fatalf("got ss %v with len=%d but expected length=0", getSS, len(getSS)) - } -} + t.Run("with empty slice", func(t *testing.T) { + ss := make([]string, 0) + f := newFlag(&ss) -func TestEmptySSValue(t *testing.T) { - var ss []string - f := setUpSSFlagSet(&ss) - err := f.Parse([]string{"--ss="}) - if err != nil { - t.Fatal("expected no error; got", err) - } + require.NoError(t, f.Parse([]string{})) - getSS, err := f.GetStringSlice("ss") - if err != nil { - t.Fatal("got an error from GetStringSlice():", err) - } - if len(getSS) != 0 { - t.Fatalf("got ss %v with len=%d but expected length=0", getSS, len(getSS)) - } -} + getSS, err := f.GetStringSlice("ss") + require.NoErrorf(t, err, + "got an error from GetStringSlice(): %v", err, + ) + require.Empty(t, getSS) + }) -func TestSS(t *testing.T) { - var ss []string - f := setUpSSFlagSet(&ss) + t.Run("with empty values", func(t *testing.T) { + ss := make([]string, 0) + f := newFlag(&ss) + require.NoError(t, f.Parse([]string{"--ss="})) - vals := []string{"one", "two", "4", "3"} - arg := fmt.Sprintf("--ss=%s", strings.Join(vals, ",")) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range ss { - if vals[i] != v { - t.Fatalf("expected ss[%d] to be %s but got: %s", i, vals[i], v) - } - } + getSS, err := f.GetStringSlice("ss") + require.NoErrorf(t, err, + "got an error from GetStringSlice(): %v", err, + ) + require.Empty(t, getSS) + }) - getSS, err := f.GetStringSlice("ss") - if err != nil { - t.Fatal("got an error from GetStringSlice():", err) - } - for i, v := range getSS { - if vals[i] != v { - t.Fatalf("expected ss[%d] to be %s from GetStringSlice but got: %s", i, vals[i], v) - } - } -} + t.Run("with values", func(t *testing.T) { + vals := []string{"one", "two", "4", "3"} + ss := make([]string, 0, len(vals)) + f := newFlag(&ss) -func TestSSDefault(t *testing.T) { - var ss []string - f := setUpSSFlagSetWithDefault(&ss) + require.NoError(t, f.Parse([]string{ + fmt.Sprintf("--ss=%s", strings.Join(vals, ",")), + })) - vals := []string{"default", "values"} + require.Equal(t, vals, ss) - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range ss { - if vals[i] != v { - t.Fatalf("expected ss[%d] to be %s but got: %s", i, vals[i], v) - } - } + getSS, err := f.GetStringSlice("ss") + require.NoErrorf(t, err, + "got an error from GetStringSlice(): %v", err, + ) + require.Equal(t, vals, getSS) + }) - getSS, err := f.GetStringSlice("ss") - if err != nil { - t.Fatal("got an error from GetStringSlice():", err) + newFlagWithDefault := func(ssp *[]string) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.StringSliceVar(ssp, "ss", []string{"default", "values"}, "Command separated list!") + return f } - for i, v := range getSS { - if vals[i] != v { - t.Fatalf("expected ss[%d] to be %s from GetStringSlice but got: %s", i, vals[i], v) - } - } -} -func TestSSWithDefault(t *testing.T) { - var ss []string - f := setUpSSFlagSetWithDefault(&ss) + t.Run("with defaults (1)", func(t *testing.T) { + vals := []string{"default", "values"} + ss := make([]string, 0, len(vals)) + f := newFlagWithDefault(&ss) - vals := []string{"one", "two", "4", "3"} - arg := fmt.Sprintf("--ss=%s", strings.Join(vals, ",")) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range ss { - if vals[i] != v { - t.Fatalf("expected ss[%d] to be %s but got: %s", i, vals[i], v) - } - } + require.NoError(t, f.Parse([]string{})) + require.Equal(t, vals, ss) - getSS, err := f.GetStringSlice("ss") - if err != nil { - t.Fatal("got an error from GetStringSlice():", err) - } - for i, v := range getSS { - if vals[i] != v { - t.Fatalf("expected ss[%d] to be %s from GetStringSlice but got: %s", i, vals[i], v) - } - } -} - -func TestSSCalledTwice(t *testing.T) { - var ss []string - f := setUpSSFlagSet(&ss) - - in := []string{"one,two", "three"} - expected := []string{"one", "two", "three"} - argfmt := "--ss=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - err := f.Parse([]string{arg1, arg2}) - if err != nil { - t.Fatal("expected no error; got", err) - } - - if len(expected) != len(ss) { - t.Fatalf("expected number of ss to be %d but got: %d", len(expected), len(ss)) - } - for i, v := range ss { - if expected[i] != v { - t.Fatalf("expected ss[%d] to be %s but got: %s", i, expected[i], v) - } - } + getSS, err := f.GetStringSlice("ss") + require.NoErrorf(t, err, + "got an error from GetStringSlice(): %v", err, + ) + require.Equal(t, vals, getSS) + }) - values, err := f.GetStringSlice("ss") - if err != nil { - t.Fatal("expected no error; got", err) - } + t.Run("with defaults (2)", func(t *testing.T) { + vals := []string{"one", "two", "4", "3"} + ss := make([]string, 0, len(vals)) + f := newFlagWithDefault(&ss) + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf("--ss=%s", strings.Join(vals, ",")), + })) + require.Equal(t, vals, ss) + + getSS, err := f.GetStringSlice("ss") + require.NoErrorf(t, err, + "got an error from GetStringSlice(): %v", err, + ) + require.Equal(t, vals, getSS) + }) - if len(expected) != len(values) { - t.Fatalf("expected number of values to be %d but got: %d", len(expected), len(ss)) - } - for i, v := range values { - if expected[i] != v { - t.Fatalf("expected got ss[%d] to be %s but got: %s", i, expected[i], v) - } - } -} + t.Run("called twice", func(t *testing.T) { + const argfmt = "--ss=%s" + in := []string{"one,two", "three"} + ss := make([]string, 0, len(in)) + f := newFlag(&ss) + expected := []string{"one", "two", "three"} + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf(argfmt, in[0]), + fmt.Sprintf(argfmt, in[1]), + })) + require.Equal(t, expected, ss) + + values, err := f.GetStringSlice("ss") + require.NoError(t, err) + require.Equal(t, expected, values) + }) -func TestSSWithComma(t *testing.T) { - var ss []string - f := setUpSSFlagSet(&ss) - - in := []string{`"one,two"`, `"three"`, `"four,five",six`} - expected := []string{"one,two", "three", "four,five", "six"} - argfmt := "--ss=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - arg3 := fmt.Sprintf(argfmt, in[2]) - err := f.Parse([]string{arg1, arg2, arg3}) - if err != nil { - t.Fatal("expected no error; got", err) - } + t.Run("with comma", func(t *testing.T) { + const argfmt = "--ss=%s" + in := []string{`"one,two"`, `"three"`, `"four,five",six`} + ss := make([]string, 0, len(in)) + f := newFlag(&ss) + expected := []string{"one,two", "three", "four,five", "six"} - if len(expected) != len(ss) { - t.Fatalf("expected number of ss to be %d but got: %d", len(expected), len(ss)) - } - for i, v := range ss { - if expected[i] != v { - t.Fatalf("expected ss[%d] to be %s but got: %s", i, expected[i], v) - } - } + require.NoError(t, f.Parse([]string{ + fmt.Sprintf(argfmt, in[0]), + fmt.Sprintf(argfmt, in[1]), + fmt.Sprintf(argfmt, in[2]), + })) - values, err := f.GetStringSlice("ss") - if err != nil { - t.Fatal("expected no error; got", err) - } + require.Equal(t, expected, ss) - if len(expected) != len(values) { - t.Fatalf("expected number of values to be %d but got: %d", len(expected), len(values)) - } - for i, v := range values { - if expected[i] != v { - t.Fatalf("expected got ss[%d] to be %s but got: %s", i, expected[i], v) - } - } -} - -func TestSSWithSquareBrackets(t *testing.T) { - var ss []string - f := setUpSSFlagSet(&ss) - - in := []string{`"[a-z]"`, `"[a-z]+"`} - expected := []string{"[a-z]", "[a-z]+"} - argfmt := "--ss=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - err := f.Parse([]string{arg1, arg2}) - if err != nil { - t.Fatal("expected no error; got", err) - } + values, err := f.GetStringSlice("ss") + require.NoError(t, err) + require.Equal(t, expected, values) + }) - if len(expected) != len(ss) { - t.Fatalf("expected number of ss to be %d but got: %d", len(expected), len(ss)) - } - for i, v := range ss { - if expected[i] != v { - t.Fatalf("expected ss[%d] to be %s but got: %s", i, expected[i], v) - } - } + t.Run("with square bracket", func(t *testing.T) { + const argfmt = "--ss=%s" + in := []string{`"[a-z]"`, `"[a-z]+"`} + ss := make([]string, 0, len(in)) + f := newFlag(&ss) + expected := []string{"[a-z]", "[a-z]+"} - values, err := f.GetStringSlice("ss") - if err != nil { - t.Fatal("expected no error; got", err) - } + require.NoError(t, f.Parse([]string{ + fmt.Sprintf(argfmt, in[0]), + fmt.Sprintf(argfmt, in[1]), + })) - if len(expected) != len(values) { - t.Fatalf("expected number of values to be %d but got: %d", len(expected), len(values)) - } - for i, v := range values { - if expected[i] != v { - t.Fatalf("expected got ss[%d] to be %s but got: %s", i, expected[i], v) - } - } -} + require.Equal(t, expected, ss) -func TestSSAsSliceValue(t *testing.T) { - var ss []string - f := setUpSSFlagSet(&ss) - - in := []string{"one", "two"} - argfmt := "--ss=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - err := f.Parse([]string{arg1, arg2}) - if err != nil { - t.Fatal("expected no error; got", err) - } + values, err := f.GetStringSlice("ss") + require.NoError(t, err) + require.Equal(t, expected, values) + }) - f.VisitAll(func(f *Flag) { - if val, ok := f.Value.(SliceValue); ok { - _ = val.Replace([]string{"three"}) - } + t.Run("as slice value", func(t *testing.T) { + const argfmt = "--ss=%s" + in := []string{"one", "two"} + ss := make([]string, 0, len(in)) + f := newFlag(&ss) + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf(argfmt, in[0]), + fmt.Sprintf(argfmt, in[1]), + })) + + f.VisitAll(func(f *Flag) { + if val, ok := f.Value.(SliceValue); ok { + require.NoError(t, val.Replace([]string{"three"})) + } + }) + require.Equalf(t, []string{"three"}, ss, + "expected ss to be overwritten with 'three', but got: %s", ss, + ) }) - if len(ss) != 1 || ss[0] != "three" { - t.Fatalf("Expected ss to be overwritten with 'three', but got: %s", ss) - } } diff --git a/string_to_int64_test.go b/string_to_int64_test.go index 2b3f2989..27e4c4d0 100644 --- a/string_to_int64_test.go +++ b/string_to_int64_test.go @@ -9,148 +9,104 @@ import ( "fmt" "strconv" "testing" -) - -func setUpS2I64FlagSet(s2ip *map[string]int64) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.StringToInt64Var(s2ip, "s2i", map[string]int64{}, "Command separated ls2it!") - return f -} - -func setUpS2I64FlagSetWithDefault(s2ip *map[string]int64) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.StringToInt64Var(s2ip, "s2i", map[string]int64{"a": 1, "b": 2}, "Command separated ls2it!") - return f -} - -func createS2I64Flag(vals map[string]int64) string { - var buf bytes.Buffer - i := 0 - for k, v := range vals { - if i > 0 { - buf.WriteRune(',') - } - buf.WriteString(k) - buf.WriteRune('=') - buf.WriteString(strconv.FormatInt(v, 10)) - i++ - } - return buf.String() -} - -func TestEmptyS2I64(t *testing.T) { - var s2i map[string]int64 - f := setUpS2I64FlagSet(&s2i) - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) - } - - getS2I, err := f.GetStringToInt64("s2i") - if err != nil { - t.Fatal("got an error from GetStringToInt64():", err) - } - if len(getS2I) != 0 { - t.Fatalf("got s2i %v with len=%d but expected length=0", getS2I, len(getS2I)) - } -} - -func TestS2I64(t *testing.T) { - var s2i map[string]int64 - f := setUpS2I64FlagSet(&s2i) - - vals := map[string]int64{"a": 1, "b": 2, "d": 4, "c": 3} - arg := fmt.Sprintf("--s2i=%s", createS2I64Flag(vals)) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for k, v := range s2i { - if vals[k] != v { - t.Fatalf("expected s2i[%s] to be %d but got: %d", k, vals[k], v) - } - } - getS2I, err := f.GetStringToInt64("s2i") - if err != nil { - t.Fatalf("got error: %v", err) - } - for k, v := range getS2I { - if vals[k] != v { - t.Fatalf("expected s2i[%s] to be %d but got: %d from GetStringToInt64", k, vals[k], v) - } - } -} -func TestS2I64Default(t *testing.T) { - var s2i map[string]int64 - f := setUpS2I64FlagSetWithDefault(&s2i) - - vals := map[string]int64{"a": 1, "b": 2} - - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for k, v := range s2i { - if vals[k] != v { - t.Fatalf("expected s2i[%s] to be %d but got: %d", k, vals[k], v) - } - } - - getS2I, err := f.GetStringToInt64("s2i") - if err != nil { - t.Fatal("got an error from GetStringToInt64():", err) - } - for k, v := range getS2I { - if vals[k] != v { - t.Fatalf("expected s2i[%s] to be %d from GetStringToInt64 but got: %d", k, vals[k], v) - } - } -} - -func TestS2I64WithDefault(t *testing.T) { - var s2i map[string]int64 - f := setUpS2I64FlagSetWithDefault(&s2i) - - vals := map[string]int64{"a": 1, "b": 2} - arg := fmt.Sprintf("--s2i=%s", createS2I64Flag(vals)) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for k, v := range s2i { - if vals[k] != v { - t.Fatalf("expected s2i[%s] to be %d but got: %d", k, vals[k], v) - } - } - - getS2I, err := f.GetStringToInt64("s2i") - if err != nil { - t.Fatal("got an error from GetStringToInt64():", err) - } - for k, v := range getS2I { - if vals[k] != v { - t.Fatalf("expected s2i[%s] to be %d from GetStringToInt64 but got: %d", k, vals[k], v) - } - } -} + "github.com/stretchr/testify/require" +) -func TestS2I64CalledTwice(t *testing.T) { - var s2i map[string]int64 - f := setUpS2I64FlagSet(&s2i) - - in := []string{"a=1,b=2", "b=3"} - expected := map[string]int64{"a": 1, "b": 3} - argfmt := "--s2i=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - err := f.Parse([]string{arg1, arg2}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range s2i { - if expected[i] != v { - t.Fatalf("expected s2i[%s] to be %d but got: %d", i, expected[i], v) +func TestMapInt64(t *testing.T) { + newFlag := func(s2ip *map[string]int64) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.StringToInt64Var(s2ip, "s2i", map[string]int64{}, "Command separated ls2it!") + return f + } + + createFlag := func(vals map[string]int64) string { + var buf bytes.Buffer + i := 0 + for k, v := range vals { + if i > 0 { + buf.WriteRune(',') + } + buf.WriteString(k) + buf.WriteRune('=') + buf.WriteString(strconv.FormatInt(v, 10)) + i++ } - } + return buf.String() + } + + t.Run("with empty map", func(t *testing.T) { + s2i := make(map[string]int64, 0) + f := newFlag(&s2i) + require.NoError(t, f.Parse([]string{})) + + getS2I, err := f.GetStringToInt64("s2i") + require.NoErrorf(t, err, + "got an error from GetStringToInt64(): %v", err, + ) + require.Empty(t, getS2I) + }) + + t.Run("with values", func(t *testing.T) { + vals := map[string]int64{"a": 1, "b": 2, "d": 4, "c": 3} + s2i := make(map[string]int64, len(vals)) + f := newFlag(&s2i) + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf("--s2i=%s", createFlag(vals)), + })) + require.Equal(t, vals, s2i) + + getS2I, err := f.GetStringToInt64("s2i") + require.NoError(t, err) + require.Equal(t, vals, getS2I) + }) + + newFlagWithDefault := func(s2ip *map[string]int64) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.StringToInt64Var(s2ip, "s2i", map[string]int64{"a": 1, "b": 2}, "Command separated ls2it!") + return f + } + + t.Run("with defaults (1)", func(t *testing.T) { + vals := map[string]int64{"a": 1, "b": 2} + s2i := make(map[string]int64, len(vals)) + f := newFlagWithDefault(&s2i) + + require.NoError(t, f.Parse([]string{})) + require.Equal(t, vals, s2i) + + getS2I, err := f.GetStringToInt64("s2i") + require.NoError(t, err) + require.Equal(t, vals, getS2I) + }) + + t.Run("with defaults (2)", func(t *testing.T) { + vals := map[string]int64{"a": 1, "b": 2} + s2i := make(map[string]int64, len(vals)) + f := newFlagWithDefault(&s2i) + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf("--s2i=%s", createFlag(vals)), + })) + require.Equal(t, vals, s2i) + + getS2I, err := f.GetStringToInt64("s2i") + require.NoError(t, err) + require.Equal(t, vals, getS2I) + }) + + t.Run("called twice", func(t *testing.T) { + const argfmt = "--s2i=%s" + in := []string{"a=1,b=2", "b=3"} + s2i := make(map[string]int64, len(in)) + f := newFlag(&s2i) + expected := map[string]int64{"a": 1, "b": 3} + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf(argfmt, in[0]), + fmt.Sprintf(argfmt, in[1]), + })) + require.Equal(t, expected, s2i) + }) } diff --git a/string_to_int_test.go b/string_to_int_test.go index b60bbafb..3ef00d0b 100644 --- a/string_to_int_test.go +++ b/string_to_int_test.go @@ -9,148 +9,103 @@ import ( "fmt" "strconv" "testing" -) - -func setUpS2IFlagSet(s2ip *map[string]int) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.StringToIntVar(s2ip, "s2i", map[string]int{}, "Command separated ls2it!") - return f -} - -func setUpS2IFlagSetWithDefault(s2ip *map[string]int) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.StringToIntVar(s2ip, "s2i", map[string]int{"a": 1, "b": 2}, "Command separated ls2it!") - return f -} - -func createS2IFlag(vals map[string]int) string { - var buf bytes.Buffer - i := 0 - for k, v := range vals { - if i > 0 { - buf.WriteRune(',') - } - buf.WriteString(k) - buf.WriteRune('=') - buf.WriteString(strconv.Itoa(v)) - i++ - } - return buf.String() -} - -func TestEmptyS2I(t *testing.T) { - var s2i map[string]int - f := setUpS2IFlagSet(&s2i) - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) - } - - getS2I, err := f.GetStringToInt("s2i") - if err != nil { - t.Fatal("got an error from GetStringToInt():", err) - } - if len(getS2I) != 0 { - t.Fatalf("got s2i %v with len=%d but expected length=0", getS2I, len(getS2I)) - } -} - -func TestS2I(t *testing.T) { - var s2i map[string]int - f := setUpS2IFlagSet(&s2i) - - vals := map[string]int{"a": 1, "b": 2, "d": 4, "c": 3} - arg := fmt.Sprintf("--s2i=%s", createS2IFlag(vals)) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for k, v := range s2i { - if vals[k] != v { - t.Fatalf("expected s2i[%s] to be %d but got: %d", k, vals[k], v) - } - } - getS2I, err := f.GetStringToInt("s2i") - if err != nil { - t.Fatalf("got error: %v", err) - } - for k, v := range getS2I { - if vals[k] != v { - t.Fatalf("expected s2i[%s] to be %d but got: %d from GetStringToInt", k, vals[k], v) - } - } -} -func TestS2IDefault(t *testing.T) { - var s2i map[string]int - f := setUpS2IFlagSetWithDefault(&s2i) - - vals := map[string]int{"a": 1, "b": 2} - - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for k, v := range s2i { - if vals[k] != v { - t.Fatalf("expected s2i[%s] to be %d but got: %d", k, vals[k], v) - } - } - - getS2I, err := f.GetStringToInt("s2i") - if err != nil { - t.Fatal("got an error from GetStringToInt():", err) - } - for k, v := range getS2I { - if vals[k] != v { - t.Fatalf("expected s2i[%s] to be %d from GetStringToInt but got: %d", k, vals[k], v) - } - } -} - -func TestS2IWithDefault(t *testing.T) { - var s2i map[string]int - f := setUpS2IFlagSetWithDefault(&s2i) - - vals := map[string]int{"a": 1, "b": 2} - arg := fmt.Sprintf("--s2i=%s", createS2IFlag(vals)) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for k, v := range s2i { - if vals[k] != v { - t.Fatalf("expected s2i[%s] to be %d but got: %d", k, vals[k], v) - } - } - - getS2I, err := f.GetStringToInt("s2i") - if err != nil { - t.Fatal("got an error from GetStringToInt():", err) - } - for k, v := range getS2I { - if vals[k] != v { - t.Fatalf("expected s2i[%s] to be %d from GetStringToInt but got: %d", k, vals[k], v) - } - } -} + "github.com/stretchr/testify/require" +) -func TestS2ICalledTwice(t *testing.T) { - var s2i map[string]int - f := setUpS2IFlagSet(&s2i) - - in := []string{"a=1,b=2", "b=3"} - expected := map[string]int{"a": 1, "b": 3} - argfmt := "--s2i=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - err := f.Parse([]string{arg1, arg2}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range s2i { - if expected[i] != v { - t.Fatalf("expected s2i[%s] to be %d but got: %d", i, expected[i], v) +func TestMapInt(t *testing.T) { + t.Parallel() + + newFlag := func(s2ip *map[string]int) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.StringToIntVar(s2ip, "s2i", map[string]int{}, "Command separated ls2it!") + return f + } + + createFlag := func(vals map[string]int) string { + var buf bytes.Buffer + i := 0 + for k, v := range vals { + if i > 0 { + buf.WriteRune(',') + } + buf.WriteString(k) + buf.WriteRune('=') + buf.WriteString(strconv.Itoa(v)) + i++ } - } + return buf.String() + } + + t.Run("with empty map", func(t *testing.T) { + s2i := make(map[string]int, 0) + f := newFlag(&s2i) + require.NoError(t, f.Parse([]string{})) + + getS2I, err := f.GetStringToInt("s2i") + require.NoErrorf(t, err, + "got an error from GetStringToInt(): %v", err, + ) + require.Empty(t, getS2I) + }) + + t.Run("with value", func(t *testing.T) { + vals := map[string]int{"a": 1, "b": 2, "d": 4, "c": 3} + s2i := make(map[string]int, len(vals)) + f := newFlag(&s2i) + require.NoError(t, f.Parse([]string{ + fmt.Sprintf("--s2i=%s", createFlag(vals)), + })) + require.Equal(t, vals, s2i) + + getS2I, err := f.GetStringToInt("s2i") + require.NoError(t, err) + require.Equal(t, vals, getS2I) + }) + + newFlagWithDefault := func(s2ip *map[string]int) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.StringToIntVar(s2ip, "s2i", map[string]int{"a": 1, "b": 2}, "Command separated ls2it!") + return f + } + + t.Run("with defaults (1)", func(t *testing.T) { + vals := map[string]int{"a": 1, "b": 2} + s2i := make(map[string]int, len(vals)) + f := newFlagWithDefault(&s2i) + require.NoError(t, f.Parse([]string{})) + require.Equal(t, vals, s2i) + + getS2I, err := f.GetStringToInt("s2i") + require.NoError(t, err) + require.Equal(t, vals, getS2I) + }) + + t.Run("with defaults (2)", func(t *testing.T) { + vals := map[string]int{"a": 1, "b": 2} + s2i := make(map[string]int, len(vals)) + f := newFlagWithDefault(&s2i) + require.NoError(t, f.Parse([]string{ + fmt.Sprintf("--s2i=%s", createFlag(vals)), + })) + require.Equal(t, vals, s2i) + + getS2I, err := f.GetStringToInt("s2i") + require.NoError(t, err) + require.Equal(t, vals, getS2I) + }) + + t.Run("called twice", func(t *testing.T) { + const argfmt = "--s2i=%s" + in := []string{"a=1,b=2", "b=3"} + s2i := make(map[string]int, len(in)) + f := newFlag(&s2i) + expected := map[string]int{"a": 1, "b": 3} + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf(argfmt, in[0]), + fmt.Sprintf(argfmt, in[1]), + })) + require.Equal(t, expected, s2i) + }) } diff --git a/string_to_string_test.go b/string_to_string_test.go index 0777f03f..09612e84 100644 --- a/string_to_string_test.go +++ b/string_to_string_test.go @@ -10,153 +10,105 @@ import ( "fmt" "strings" "testing" -) - -func setUpS2SFlagSet(s2sp *map[string]string) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.StringToStringVar(s2sp, "s2s", map[string]string{}, "Command separated ls2st!") - return f -} - -func setUpS2SFlagSetWithDefault(s2sp *map[string]string) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.StringToStringVar(s2sp, "s2s", map[string]string{"da": "1", "db": "2", "de": "5,6"}, "Command separated ls2st!") - return f -} - -func createS2SFlag(vals map[string]string) string { - records := make([]string, 0, len(vals)>>1) - for k, v := range vals { - records = append(records, k+"="+v) - } - - var buf bytes.Buffer - w := csv.NewWriter(&buf) - if err := w.Write(records); err != nil { - panic(err) - } - w.Flush() - return strings.TrimSpace(buf.String()) -} - -func TestEmptyS2S(t *testing.T) { - var s2s map[string]string - f := setUpS2SFlagSet(&s2s) - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) - } - - getS2S, err := f.GetStringToString("s2s") - if err != nil { - t.Fatal("got an error from GetStringToString():", err) - } - if len(getS2S) != 0 { - t.Fatalf("got s2s %v with len=%d but expected length=0", getS2S, len(getS2S)) - } -} - -func TestS2S(t *testing.T) { - var s2s map[string]string - f := setUpS2SFlagSet(&s2s) - - vals := map[string]string{"a": "1", "b": "2", "d": "4", "c": "3", "e": "5,6"} - arg := fmt.Sprintf("--s2s=%s", createS2SFlag(vals)) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for k, v := range s2s { - if vals[k] != v { - t.Fatalf("expected s2s[%s] to be %s but got: %s", k, vals[k], v) - } - } - getS2S, err := f.GetStringToString("s2s") - if err != nil { - t.Fatalf("got error: %v", err) - } - for k, v := range getS2S { - if vals[k] != v { - t.Fatalf("expected s2s[%s] to be %s but got: %s from GetStringToString", k, vals[k], v) - } - } -} -func TestS2SDefault(t *testing.T) { - var s2s map[string]string - f := setUpS2SFlagSetWithDefault(&s2s) - - vals := map[string]string{"da": "1", "db": "2", "de": "5,6"} - - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for k, v := range s2s { - if vals[k] != v { - t.Fatalf("expected s2s[%s] to be %s but got: %s", k, vals[k], v) - } - } - - getS2S, err := f.GetStringToString("s2s") - if err != nil { - t.Fatal("got an error from GetStringToString():", err) - } - for k, v := range getS2S { - if vals[k] != v { - t.Fatalf("expected s2s[%s] to be %s from GetStringToString but got: %s", k, vals[k], v) - } - } -} + "github.com/stretchr/testify/require" +) -func TestS2SWithDefault(t *testing.T) { - var s2s map[string]string - f := setUpS2SFlagSetWithDefault(&s2s) +func TestMapString(t *testing.T) { + t.Parallel() - vals := map[string]string{"a": "1", "b": "2", "e": "5,6"} - arg := fmt.Sprintf("--s2s=%s", createS2SFlag(vals)) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for k, v := range s2s { - if vals[k] != v { - t.Fatalf("expected s2s[%s] to be %s but got: %s", k, vals[k], v) - } + newFlag := func(s2sp *map[string]string) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.StringToStringVar(s2sp, "s2s", map[string]string{}, "Command separated ls2st!") + return f } - getS2S, err := f.GetStringToString("s2s") - if err != nil { - t.Fatal("got an error from GetStringToString():", err) - } - for k, v := range getS2S { - if vals[k] != v { - t.Fatalf("expected s2s[%s] to be %s from GetStringToString but got: %s", k, vals[k], v) + createFlag := func(vals map[string]string) string { + records := make([]string, 0, len(vals)>>1) + for k, v := range vals { + records = append(records, k+"="+v) } - } -} -func TestS2SCalledTwice(t *testing.T) { - var s2s map[string]string - f := setUpS2SFlagSet(&s2s) - - in := []string{"a=1,b=2", "b=3", `"e=5,6"`, `f=7,8`} - expected := map[string]string{"a": "1", "b": "3", "e": "5,6", "f": "7,8"} - argfmt := "--s2s=%s" - arg0 := fmt.Sprintf(argfmt, in[0]) - arg1 := fmt.Sprintf(argfmt, in[1]) - arg2 := fmt.Sprintf(argfmt, in[2]) - arg3 := fmt.Sprintf(argfmt, in[3]) - err := f.Parse([]string{arg0, arg1, arg2, arg3}) - if err != nil { - t.Fatal("expected no error; got", err) - } - if len(s2s) != len(expected) { - t.Fatalf("expected %d flags; got %d flags", len(expected), len(s2s)) - } - for i, v := range s2s { - if expected[i] != v { - t.Fatalf("expected s2s[%s] to be %s but got: %s", i, expected[i], v) + var buf bytes.Buffer + w := csv.NewWriter(&buf) + if err := w.Write(records); err != nil { + panic(err) } - } + w.Flush() + return strings.TrimSpace(buf.String()) + } + + t.Run("with empty map", func(t *testing.T) { + s2s := make(map[string]string, 0) + f := newFlag(&s2s) + require.NoError(t, f.Parse([]string{})) + + getS2S, err := f.GetStringToString("s2s") + require.NoError(t, err) + require.Empty(t, getS2S) + }) + + t.Run("with value", func(t *testing.T) { + vals := map[string]string{"a": "1", "b": "2", "d": "4", "c": "3", "e": "5,6"} + s2s := make(map[string]string, len(vals)) + f := newFlag(&s2s) + require.NoError(t, f.Parse([]string{ + fmt.Sprintf("--s2s=%s", createFlag(vals)), + })) + require.Equal(t, vals, s2s) + + getS2S, err := f.GetStringToString("s2s") + require.NoError(t, err) + require.Equal(t, vals, getS2S) + }) + + newFlagWithDefault := func(s2sp *map[string]string) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.StringToStringVar(s2sp, "s2s", map[string]string{"da": "1", "db": "2", "de": "5,6"}, "Command separated ls2st!") + return f + } + + t.Run("with defaults (1)", func(t *testing.T) { + vals := map[string]string{"da": "1", "db": "2", "de": "5,6"} + s2s := make(map[string]string, len(vals)) + f := newFlagWithDefault(&s2s) + + require.NoError(t, f.Parse([]string{})) + require.Equal(t, vals, s2s) + + getS2S, err := f.GetStringToString("s2s") + require.NoError(t, err) + require.Equal(t, vals, getS2S) + }) + + t.Run("with defaults (2)", func(t *testing.T) { + vals := map[string]string{"a": "1", "b": "2", "e": "5,6"} + s2s := make(map[string]string, len(vals)) + f := newFlagWithDefault(&s2s) + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf("--s2s=%s", createFlag(vals)), + })) + require.Equal(t, vals, s2s) + + getS2S, err := f.GetStringToString("s2s") + require.NoError(t, err) + require.Equal(t, vals, getS2S) + }) + + t.Run("called twice", func(t *testing.T) { + const argfmt = "--s2s=%s" + in := []string{"a=1,b=2", "b=3", `"e=5,6"`, `f=7,8`} + s2s := make(map[string]string, len(in)) + f := newFlag(&s2s) + expected := map[string]string{"a": "1", "b": "3", "e": "5,6", "f": "7,8"} + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf(argfmt, in[0]), + fmt.Sprintf(argfmt, in[1]), + fmt.Sprintf(argfmt, in[2]), + fmt.Sprintf(argfmt, in[3]), + })) + require.Equal(t, expected, s2s) + }) } diff --git a/uint.go b/uint.go index dcbc2b75..b8671fb8 100644 --- a/uint.go +++ b/uint.go @@ -1,3 +1,4 @@ +// nolint: dupl package pflag import "strconv" diff --git a/uint16.go b/uint16.go index 7e9914ed..8774c5a2 100644 --- a/uint16.go +++ b/uint16.go @@ -1,3 +1,4 @@ +// nolint: dupl package pflag import "strconv" diff --git a/uint32.go b/uint32.go index d8024539..8ed0fdd6 100644 --- a/uint32.go +++ b/uint32.go @@ -1,3 +1,4 @@ +// nolint: dupl package pflag import "strconv" diff --git a/uint64.go b/uint64.go index f62240f2..86d8c7e6 100644 --- a/uint64.go +++ b/uint64.go @@ -27,7 +27,7 @@ func uint64Conv(sval string) (interface{}, error) { if err != nil { return 0, err } - return uint64(v), nil + return v, nil } // GetUint64 return the uint64 value of a flag with the given name diff --git a/uint8.go b/uint8.go index bb0e83c1..4a48d494 100644 --- a/uint8.go +++ b/uint8.go @@ -1,3 +1,4 @@ +// nolint: dupl package pflag import "strconv" diff --git a/uint_slice_test.go b/uint_slice_test.go index d0da4d07..36c3f69c 100644 --- a/uint_slice_test.go +++ b/uint_slice_test.go @@ -5,180 +5,147 @@ import ( "strconv" "strings" "testing" -) -func setUpUISFlagSet(uisp *[]uint) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.UintSliceVar(uisp, "uis", []uint{}, "Command separated list!") - return f -} + "github.com/stretchr/testify/require" +) -func setUpUISFlagSetWithDefault(uisp *[]uint) *FlagSet { - f := NewFlagSet("test", ContinueOnError) - f.UintSliceVar(uisp, "uis", []uint{0, 1}, "Command separated list!") - return f -} +func TestUintSlice(t *testing.T) { + t.Parallel() -func TestEmptyUIS(t *testing.T) { - var uis []uint - f := setUpUISFlagSet(&uis) - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) + newFlag := func(uisp *[]uint) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.UintSliceVar(uisp, "uis", []uint{}, "Command separated list!") + return f } - getUIS, err := f.GetUintSlice("uis") - if err != nil { - t.Fatal("got an error from GetUintSlice():", err) - } - if len(getUIS) != 0 { - t.Fatalf("got is %v with len=%d but expected length=0", getUIS, len(getUIS)) - } -} + t.Run("with empty slice", func(t *testing.T) { + uis := make([]uint, 0) + f := newFlag(&uis) + require.NoError(t, f.Parse([]string{})) -func TestUIS(t *testing.T) { - var uis []uint - f := setUpUISFlagSet(&uis) + getUIS, err := f.GetUintSlice("uis") + require.NoErrorf(t, err, + "got an error from GetUintSlice(): %v", err, + ) + require.Empty(t, getUIS) + }) - vals := []string{"1", "2", "4", "3"} - arg := fmt.Sprintf("--uis=%s", strings.Join(vals, ",")) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range uis { - u, err := strconv.ParseUint(vals[i], 10, 0) - if err != nil { - t.Fatalf("got error: %v", err) - } - if uint(u) != v { - t.Fatalf("expected uis[%d] to be %s but got %d", i, vals[i], v) - } - } - getUIS, err := f.GetUintSlice("uis") - if err != nil { - t.Fatalf("got error: %v", err) - } - for i, v := range getUIS { - u, err := strconv.ParseUint(vals[i], 10, 0) - if err != nil { - t.Fatalf("got error: %v", err) + t.Run("with values", func(t *testing.T) { + vals := []string{"1", "2", "4", "3"} + uis := make([]uint, 0, len(vals)) + f := newFlag(&uis) + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf("--uis=%s", strings.Join(vals, ",")), + })) + + for i, v := range uis { + u, err := strconv.ParseUint(vals[i], 10, 0) + require.NoError(t, err) + require.Equal(t, v, uint(u)) } - if uint(u) != v { - t.Fatalf("expected uis[%d] to be %s but got: %d from GetUintSlice", i, vals[i], v) + + getUIS, eru := f.GetUintSlice("uis") + require.NoError(t, eru) + + for i, v := range getUIS { + u, err := strconv.ParseUint(vals[i], 10, 0) + require.NoError(t, err) + require.Equal(t, v, uint(u)) } + }) + + newFlagWithDefault := func(uisp *[]uint) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.UintSliceVar(uisp, "uis", []uint{0, 1}, "Command separated list!") + return f } -} -func TestUISDefault(t *testing.T) { - var uis []uint - f := setUpUISFlagSetWithDefault(&uis) + t.Run("with defaults (1)", func(t *testing.T) { + vals := []string{"0", "1"} + uis := make([]uint, 0, len(vals)) + f := newFlagWithDefault(&uis) - vals := []string{"0", "1"} + require.NoError(t, f.Parse([]string{})) - err := f.Parse([]string{}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range uis { - u, err := strconv.ParseUint(vals[i], 10, 0) - if err != nil { - t.Fatalf("got error: %v", err) + for i, v := range uis { + u, err := strconv.ParseUint(vals[i], 10, 0) + require.NoError(t, err) + require.Equal(t, v, uint(u)) } - if uint(u) != v { - t.Fatalf("expect uis[%d] to be %d but got: %d", i, u, v) - } - } - getUIS, err := f.GetUintSlice("uis") - if err != nil { - t.Fatal("got an error from GetUintSlice():", err) - } - for i, v := range getUIS { - u, err := strconv.ParseUint(vals[i], 10, 0) - if err != nil { - t.Fatal("got an error from GetIntSlice():", err) - } - if uint(u) != v { - t.Fatalf("expected uis[%d] to be %d from GetUintSlice but got: %d", i, u, v) + getUIS, eru := f.GetUintSlice("uis") + require.NoErrorf(t, eru, + "got an error from GetUintSlice(): %v", eru, + ) + + for i, v := range getUIS { + u, err := strconv.ParseUint(vals[i], 10, 0) + require.NoError(t, err) + require.Equal(t, v, uint(u)) } - } -} + }) -func TestUISWithDefault(t *testing.T) { - var uis []uint - f := setUpUISFlagSetWithDefault(&uis) + t.Run("with defaults (2)", func(t *testing.T) { + vals := []string{"1", "2"} + uis := make([]uint, 0, len(vals)) + f := newFlagWithDefault(&uis) - vals := []string{"1", "2"} - arg := fmt.Sprintf("--uis=%s", strings.Join(vals, ",")) - err := f.Parse([]string{arg}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range uis { - u, err := strconv.ParseUint(vals[i], 10, 0) - if err != nil { - t.Fatalf("got error: %v", err) - } - if uint(u) != v { - t.Fatalf("expected uis[%d] to be %d from GetUintSlice but got: %d", i, u, v) - } - } + require.NoError(t, f.Parse([]string{ + fmt.Sprintf("--uis=%s", strings.Join(vals, ",")), + })) - getUIS, err := f.GetUintSlice("uis") - if err != nil { - t.Fatal("got an error from GetUintSlice():", err) - } - for i, v := range getUIS { - u, err := strconv.ParseUint(vals[i], 10, 0) - if err != nil { - t.Fatalf("got error: %v", err) - } - if uint(u) != v { - t.Fatalf("expected uis[%d] to be %d from GetUintSlice but got: %d", i, u, v) + for i, v := range uis { + u, err := strconv.ParseUint(vals[i], 10, 0) + require.NoError(t, err) + require.Equal(t, v, uint(u)) } - } -} -func TestUISAsSliceValue(t *testing.T) { - var uis []uint - f := setUpUISFlagSet(&uis) - - in := []string{"1", "2"} - argfmt := "--uis=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - err := f.Parse([]string{arg1, arg2}) - if err != nil { - t.Fatal("expected no error; got", err) - } + getUIS, eru := f.GetUintSlice("uis") + require.NoErrorf(t, eru, + "got an error from GetUintSlice(): %v", eru, + ) - f.VisitAll(func(f *Flag) { - if val, ok := f.Value.(SliceValue); ok { - _ = val.Replace([]string{"3"}) + for i, v := range getUIS { + u, err := strconv.ParseUint(vals[i], 10, 0) + require.NoError(t, err) + require.Equal(t, v, uint(u)) } }) - if len(uis) != 1 || uis[0] != 3 { - t.Fatalf("Expected ss to be overwritten with '3.1', but got: %v", uis) - } -} -func TestUISCalledTwice(t *testing.T) { - var uis []uint - f := setUpUISFlagSet(&uis) - - in := []string{"1,2", "3"} - expected := []int{1, 2, 3} - argfmt := "--uis=%s" - arg1 := fmt.Sprintf(argfmt, in[0]) - arg2 := fmt.Sprintf(argfmt, in[1]) - err := f.Parse([]string{arg1, arg2}) - if err != nil { - t.Fatal("expected no error; got", err) - } - for i, v := range uis { - if uint(expected[i]) != v { - t.Fatalf("expected uis[%d] to be %d but got: %d", i, expected[i], v) - } - } + t.Run("called twice", func(t *testing.T) { + const argfmt = "--uis=%s" + in := []string{"1,2", "3"} + uis := make([]uint, 0, len(in)) + f := newFlag(&uis) + expected := []uint{1, 2, 3} + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf(argfmt, in[0]), + fmt.Sprintf(argfmt, in[1]), + })) + + require.Equal(t, expected, uis) + }) + + t.Run("as slice value", func(t *testing.T) { + const argfmt = "--uis=%s" + in := []string{"1", "2"} + uis := make([]uint, 0, len(in)) + f := newFlag(&uis) + + require.NoError(t, f.Parse([]string{ + fmt.Sprintf(argfmt, in[0]), + fmt.Sprintf(argfmt, in[1]), + })) + + f.VisitAll(func(f *Flag) { + if val, ok := f.Value.(SliceValue); ok { + require.NoError(t, val.Replace([]string{"3"})) + } + }) + require.Equalf(t, []uint{3}, uis, + "expected ss to be overwritten with '3.1', but got: %v", uis, + ) + }) }