Skip to content

Commit

Permalink
feat: support excluding types from comparison
Browse files Browse the repository at this point in the history
This is supported by extending the `msgAndArgs...any` variadic parameter
to `msgArgsAndCompareOptions...any`. Any argument that is an
`assert.CompareOption` will be passed to `assert.Compare`.

eg.

```go
assert.Equal(t, a, b, "Unequal! %#v != %#v", a, b, assert.Exclude[time.Time]())
```
  • Loading branch information
alecthomas committed Oct 3, 2023
1 parent fed5290 commit 0a3ee63
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 38 deletions.
115 changes: 78 additions & 37 deletions assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,52 +15,62 @@ import (
"github.com/hexops/gotextdiff/myers"
)

func objectsAreEqual(expected, actual interface{}) bool {
if expected == nil || actual == nil {
return expected == actual
}
// A CompareOption modifies how object comparisons behave.
type CompareOption func() []repr.Option

exp, eok := expected.([]byte)
act, aok := actual.([]byte)

if eok && aok {
return bytes.Equal(exp, act)
// Exclude fields of the given type from comparison.
func Exclude[T any]() CompareOption {
return func() []repr.Option {
return []repr.Option{repr.Hide[T]()}
}

return reflect.DeepEqual(expected, actual)
}

// Compare two values for equality and return true or false.
func Compare[T any](t testing.TB, x, y T) bool {
return objectsAreEqual(x, y)
func Compare[T any](t testing.TB, x, y T, options ...CompareOption) bool {
return objectsAreEqual(x, y, options...)
}

func extractCompareOptions(msgAndArgs ...any) ([]any, []CompareOption) {
compareOptions := []CompareOption{}
out := []any{}
for _, arg := range msgAndArgs {
if opt, ok := arg.(CompareOption); ok {
compareOptions = append(compareOptions, opt)
} else {
out = append(out, arg)
}
}
return out, compareOptions
}

// Equal asserts that "expected" and "actual" are equal.
//
// If they are not, a diff of the Go representation of the values will be displayed.
func Equal[T any](t testing.TB, expected, actual T, msgAndArgs ...interface{}) {
if objectsAreEqual(expected, actual) {
func Equal[T any](t testing.TB, expected, actual T, msgArgsAndCompareOptions ...any) {
msgArgsAndCompareOptions, compareOptions := extractCompareOptions(msgArgsAndCompareOptions...)
if objectsAreEqual(expected, actual, compareOptions...) {
return
}
t.Helper()
msg := formatMsgAndArgs("Expected values to be equal:", msgAndArgs...)
t.Fatalf("%s\n%s", msg, diff(expected, actual))
msg := formatMsgAndArgs("Expected values to be equal:", msgArgsAndCompareOptions...)
t.Fatalf("%s\n%s", msg, diff(expected, actual, compareOptions...))
}

// NotEqual asserts that "expected" is not equal to "actual".
//
// If they are equal the expected value will be displayed.
func NotEqual[T any](t testing.TB, expected, actual T, msgAndArgs ...interface{}) {
if !objectsAreEqual(expected, actual) {
func NotEqual[T any](t testing.TB, expected, actual T, msgArgsAndCompareOptions ...any) {
msgArgsAndCompareOptions, compareOptions := extractCompareOptions(msgArgsAndCompareOptions...)
if !objectsAreEqual(expected, actual, compareOptions...) {
return
}
t.Helper()
msg := formatMsgAndArgs("Expected values to not be equal but both were:", msgAndArgs...)
msg := formatMsgAndArgs("Expected values to not be equal but both were:", msgArgsAndCompareOptions...)
t.Fatalf("%s\n%s", msg, repr.String(expected, repr.Indent(" ")))
}

// Contains asserts that "haystack" contains "needle".
func Contains(t testing.TB, haystack string, needle string, msgAndArgs ...interface{}) {
func Contains(t testing.TB, haystack string, needle string, msgAndArgs ...any) {
if strings.Contains(haystack, needle) {
return
}
Expand All @@ -70,7 +80,7 @@ func Contains(t testing.TB, haystack string, needle string, msgAndArgs ...interf
}

// NotContains asserts that "haystack" does not contain "needle".
func NotContains(t testing.TB, haystack string, needle string, msgAndArgs ...interface{}) {
func NotContains(t testing.TB, haystack string, needle string, msgAndArgs ...any) {
if !strings.Contains(haystack, needle) {
return
}
Expand All @@ -81,7 +91,7 @@ func NotContains(t testing.TB, haystack string, needle string, msgAndArgs ...int
}

// Zero asserts that a value is its zero value.
func Zero[T any](t testing.TB, value T, msgAndArgs ...interface{}) {
func Zero[T any](t testing.TB, value T, msgAndArgs ...any) {
var zero T
if objectsAreEqual(value, zero) {
return
Expand All @@ -96,7 +106,7 @@ func Zero[T any](t testing.TB, value T, msgAndArgs ...interface{}) {
}

// NotZero asserts that a value is not its zero value.
func NotZero[T any](t testing.TB, value T, msgAndArgs ...interface{}) {
func NotZero[T any](t testing.TB, value T, msgAndArgs ...any) {
var zero T
if !objectsAreEqual(value, zero) {
val := reflect.ValueOf(value)
Expand All @@ -111,7 +121,7 @@ func NotZero[T any](t testing.TB, value T, msgAndArgs ...interface{}) {

// EqualError asserts that either an error is non-nil and that its message is what is expected,
// or that error is nil if the expected message is empty.
func EqualError(t testing.TB, err error, errString string, msgAndArgs ...interface{}) {
func EqualError(t testing.TB, err error, errString string, msgAndArgs ...any) {
if err == nil && errString == "" {
return
}
Expand All @@ -126,7 +136,7 @@ func EqualError(t testing.TB, err error, errString string, msgAndArgs ...interfa
}

// IsError asserts than any error in "err"'s tree matches "target".
func IsError(t testing.TB, err, target error, msgAndArgs ...interface{}) {
func IsError(t testing.TB, err, target error, msgAndArgs ...any) {
if errors.Is(err, target) {
return
}
Expand All @@ -135,7 +145,7 @@ func IsError(t testing.TB, err, target error, msgAndArgs ...interface{}) {
}

// NotIsError asserts than no error in "err"'s tree matches "target".
func NotIsError(t testing.TB, err, target error, msgAndArgs ...interface{}) {
func NotIsError(t testing.TB, err, target error, msgAndArgs ...any) {
if !errors.Is(err, target) {
return
}
Expand All @@ -144,7 +154,7 @@ func NotIsError(t testing.TB, err, target error, msgAndArgs ...interface{}) {
}

// Error asserts that an error is not nil.
func Error(t testing.TB, err error, msgAndArgs ...interface{}) {
func Error(t testing.TB, err error, msgAndArgs ...any) {
if err != nil {
return
}
Expand All @@ -153,7 +163,7 @@ func Error(t testing.TB, err error, msgAndArgs ...interface{}) {
}

// NoError asserts that an error is nil.
func NoError(t testing.TB, err error, msgAndArgs ...interface{}) {
func NoError(t testing.TB, err error, msgAndArgs ...any) {
if err == nil {
return
}
Expand All @@ -163,7 +173,7 @@ func NoError(t testing.TB, err error, msgAndArgs ...interface{}) {
}

// True asserts that an expression is true.
func True(t testing.TB, ok bool, msgAndArgs ...interface{}) {
func True(t testing.TB, ok bool, msgAndArgs ...any) {
if ok {
return
}
Expand All @@ -172,7 +182,7 @@ func True(t testing.TB, ok bool, msgAndArgs ...interface{}) {
}

// False asserts that an expression is false.
func False(t testing.TB, ok bool, msgAndArgs ...interface{}) {
func False(t testing.TB, ok bool, msgAndArgs ...any) {
if !ok {
return
}
Expand All @@ -181,7 +191,7 @@ func False(t testing.TB, ok bool, msgAndArgs ...interface{}) {
}

// Panics asserts that the given function panics.
func Panics(t testing.TB, fn func(), msgAndArgs ...interface{}) {
func Panics(t testing.TB, fn func(), msgAndArgs ...any) {
t.Helper()
defer func() {
if recover() == nil {
Expand All @@ -193,7 +203,7 @@ func Panics(t testing.TB, fn func(), msgAndArgs ...interface{}) {
}

// NotPanics asserts that the given function does not panic.
func NotPanics(t testing.TB, fn func(), msgAndArgs ...interface{}) {
func NotPanics(t testing.TB, fn func(), msgAndArgs ...any) {
t.Helper()
defer func() {
if err := recover(); err != nil {
Expand All @@ -204,15 +214,16 @@ func NotPanics(t testing.TB, fn func(), msgAndArgs ...interface{}) {
fn()
}

func diff[T any](before, after T) string {
func diff[T any](before, after T, compareOptions ...CompareOption) string {
var lhss, rhss string
// Special case strings so we get nice diffs.
if l, ok := any(before).(string); ok {
lhss = l
rhss = any(after).(string)
} else {
lhss = repr.String(before, repr.Indent(" ")) + "\n"
rhss = repr.String(after, repr.Indent(" ")) + "\n"
ropts := expandCompareOptions(compareOptions...)
lhss = repr.String(before, ropts...) + "\n"
rhss = repr.String(after, ropts...) + "\n"
}
edits := myers.ComputeEdits("a.txt", lhss, rhss)
lines := strings.Split(fmt.Sprint(gotextdiff.ToUnified("expected.txt", "actual.txt", lhss, edits)), "\n")
Expand All @@ -222,7 +233,7 @@ func diff[T any](before, after T) string {
return strings.Join(lines[3:], "\n")
}

func formatMsgAndArgs(dflt string, msgAndArgs ...interface{}) string {
func formatMsgAndArgs(dflt string, msgAndArgs ...any) string {
if len(msgAndArgs) == 0 {
return dflt
}
Expand All @@ -243,3 +254,33 @@ func needlePosition(haystack, needle string) (quotedHaystack, quotedNeedle, posi
}
return
}

func expandCompareOptions(options ...CompareOption) []repr.Option {
ropts := []repr.Option{repr.Indent(" ")}
for _, option := range options {
ropts = append(ropts, option()...)
}
return ropts
}

func objectsAreEqual(expected, actual any, options ...CompareOption) bool {
if expected == nil || actual == nil {
return expected == actual
}
if exp, eok := expected.([]byte); eok {
if act, aok := actual.([]byte); aok {
return bytes.Equal(exp, act)
}
}
if exp, eok := expected.(string); eok {
if act, aok := actual.(string); aok {
return exp == act
}
}

ropts := expandCompareOptions(options...)
expectedStr := repr.String(expected, ropts...)
actualStr := repr.String(actual, ropts...)

return expectedStr == actualStr
}
6 changes: 6 additions & 0 deletions assert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ func TestEqual(t *testing.T) {
assertFail(t, "Different numbers", func(t testing.TB) {
Equal(t, 42, 43)
})
assertOk(t, "Exclude", func(t testing.TB) {
Equal(t, Data{Str: "expected", Num: 1234}, Data{Str: "expected"}, Exclude[int64]())
})
}

func TestEqualStrings(t *testing.T) {
Expand All @@ -48,6 +51,9 @@ func TestNotEqual(t *testing.T) {
assertFail(t, "SameValue", func(t testing.TB) {
NotEqual(t, Data{"expected", 1234}, Data{"expected", 1234})
})
assertFail(t, "Exclude", func(t testing.TB) {
NotEqual(t, Data{Str: "expected", Num: 1234}, Data{Str: "expected"}, Exclude[int64]())
})
}

func TestContains(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ module github.com/alecthomas/assert/v2
go 1.18

require (
github.com/alecthomas/repr v0.2.0
github.com/alecthomas/repr v0.3.0
github.com/hexops/gotextdiff v1.0.3
)
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
github.com/alecthomas/repr v0.2.0 h1:HAzS41CIzNW5syS8Mf9UwXhNH1J9aix/BvDRf1Ml2Yk=
github.com/alecthomas/repr v0.2.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
github.com/alecthomas/repr v0.3.0 h1:NeYzUPfjjlqHY4KtzgKJiWd6sVq2eNUPTi34PiFGjY8=
github.com/alecthomas/repr v0.3.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=

0 comments on commit 0a3ee63

Please sign in to comment.