From 43aaded95b72ffd772c7e95dc54978a59809f9df Mon Sep 17 00:00:00 2001 From: Igor Lazarev Date: Sun, 19 Jun 2022 15:44:26 +0300 Subject: [PATCH] generic errors.As() method --- README.md | 3 +- errors.go | 38 +++--- errors_test.go | 206 ++++++++++++++++++++++++++++--- example_test.go | 46 +++++++ go.mod | 4 +- logging/logrusadapter/adapter.go | 18 ++- stack_test.go | 12 +- 7 files changed, 283 insertions(+), 44 deletions(-) create mode 100644 example_test.go diff --git a/README.md b/README.md index 50fbe6d..d04529d 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ Key differences and features: and should be used to create sentinel package-level errors; * minimalistic API: few methods to wrap an error: `errors.Errorf()`, `errors.Wrap()`; * adds stack trace idempotently (only once in a chain); +* `errors.As()` method is based on typed parameters (aka generics); * options to skip caller in a stack trace and to add error fields for structured logging; * error fields are made for the statically typed logger interface; * package errors can be easily marshaled into JSON with all fields in a chain. @@ -210,7 +211,7 @@ logrusadapter.Log(err, logger) Output ``` -ERRO[0000] find product: sql error: sql: no rows in result set productID=123 requestID=24874020-cab7-4ef3-bac5-76858832f8b0 sql="SELECT id, name FROM product WHERE id = ?" stackTrace="[scratch.go:12 proc.go:250 asm_amd64.s:1571]" +ERRO[0000] find product: sql error: sql: no rows in result set productID=123 requestID=24874020-cab7-4ef3-bac5-76858832f8b0 sql="SELECT id, name FROM product WHERE id = ?" stackTrace="[{main.main /home/strider/projects/errors/var/scratch.go 12} {runtime.main /usr/local/go/src/runtime/proc.go 250} {runtime.goexit /usr/local/go/src/runtime/asm_amd64.s 1571}]" ``` ## Contributing diff --git a/errors.go b/errors.go index ace38ae..e8bb173 100644 --- a/errors.go +++ b/errors.go @@ -46,26 +46,34 @@ func Is(err, target error) bool { return errors.Is(err, target) } -// As finds the first error in err's chain that matches target, and if one is found, sets -// target to that error value and returns true. Otherwise, it returns false. +// As finds the first error in err's chain that matches type T, and if one is found, returns +// its value and true. Otherwise, it returns zero value and false. // // The chain consists of err itself followed by the sequence of errors obtained by // repeatedly calling Unwrap. // -// An error matches target if the error's concrete value is assignable to the value -// pointed to by target, or if the error has a method As(interface{}) bool such that -// As(target) returns true. In the latter case, the As method is responsible for -// setting target. +// An error matches target if the error's concrete value is of type T, or if the error +// has a method As(any) bool such that As(target) returns true. In the latter case, +// the As method is responsible for setting returned value. // // An error type might provide an As method so it can be treated as if it were a // different error type. -// -// As panics if target is not a non-nil pointer to either a type that implements -// error, or to any interface type. -// -// This function is an alias to standard errors.As. -func As(err error, target interface{}) bool { - return errors.As(err, target) +func As[T any](err error) (T, bool) { + for err != nil { + if t, ok := err.(T); ok { + return t, true + } + if x, ok := err.(interface{ As(any) bool }); ok { + var t T + if x.As(&t) { + return t, true + } + } + err = Unwrap(err) + } + + var z T + return z, false } // Unwrap returns the result of calling the Unwrap method on err, if err's @@ -130,9 +138,9 @@ func isWrapper(err error) bool { return false } - var w wrapper + _, ok := As[wrapper](err) - return errors.As(err, &w) + return ok } type wrapped struct { diff --git a/errors_test.go b/errors_test.go index 669c608..dc0997e 100644 --- a/errors_test.go +++ b/errors_test.go @@ -3,6 +3,8 @@ package errors_test import ( "encoding/json" "fmt" + "io/fs" + "os" "testing" "time" @@ -21,7 +23,7 @@ func TestStackTrace(t *testing.T) { err: errors.Errorf("ooh"), want: []string{ "github.com/muonsoft/errors_test.TestStackTrace\n" + - "\t.+/errors/errors_test.go:21", + "\t.+/errors/errors_test.go:23", }, }, { @@ -29,7 +31,7 @@ func TestStackTrace(t *testing.T) { err: errors.Wrap(errors.Errorf("ooh")), want: []string{ "github.com/muonsoft/errors_test.TestStackTrace\n" + - "\t.+/errors/errors_test.go:29", + "\t.+/errors/errors_test.go:31", }, }, { @@ -37,7 +39,7 @@ func TestStackTrace(t *testing.T) { err: errors.Wrap(errors.New("ooh")), want: []string{ "github.com/muonsoft/errors_test.TestStackTrace\n" + - "\t.+/errors/errors_test.go:37", + "\t.+/errors/errors_test.go:39", }, }, { @@ -45,7 +47,7 @@ func TestStackTrace(t *testing.T) { err: errors.Wrap(errors.Wrap(errors.New("ooh"))), want: []string{ "github.com/muonsoft/errors_test.TestStackTrace\n" + - "\t.+/errors/errors_test.go:45", + "\t.+/errors/errors_test.go:47", }, }, { @@ -53,7 +55,7 @@ func TestStackTrace(t *testing.T) { err: errors.Errorf("ooh"), want: []string{ "github.com/muonsoft/errors_test.TestStackTrace\n" + - "\t.+/errors/errors_test.go:53", + "\t.+/errors/errors_test.go:55", }, }, { @@ -61,7 +63,7 @@ func TestStackTrace(t *testing.T) { err: errors.Errorf("%v", errors.New("ooh")), want: []string{ "github.com/muonsoft/errors_test.TestStackTrace\n" + - "\t.+/errors/errors_test.go:61", + "\t.+/errors/errors_test.go:63", }, }, { @@ -69,7 +71,7 @@ func TestStackTrace(t *testing.T) { err: errors.Errorf("%w", errors.Wrap(errors.New("ooh"))), want: []string{ "github.com/muonsoft/errors_test.TestStackTrace\n" + - "\t.+/errors/errors_test.go:69", + "\t.+/errors/errors_test.go:71", }, }, { @@ -77,7 +79,7 @@ func TestStackTrace(t *testing.T) { err: errors.Errorf("%%w %v", errors.New("ooh")), want: []string{ "github.com/muonsoft/errors_test.TestStackTrace\n" + - "\t.+/errors/errors_test.go:77", + "\t.+/errors/errors_test.go:79", }, }, { @@ -85,7 +87,7 @@ func TestStackTrace(t *testing.T) { err: errors.Errorf("%s: %w", "prefix", errors.New("ooh")), want: []string{ "github.com/muonsoft/errors_test.TestStackTrace\n" + - "\t.+/errors/errors_test.go:85", + "\t.+/errors/errors_test.go:87", }, }, { @@ -93,7 +95,7 @@ func TestStackTrace(t *testing.T) { err: errors.Errorf("%w", errors.Errorf("%w", errors.New("ooh"))), want: []string{ "github.com/muonsoft/errors_test.TestStackTrace\n" + - "\t.+/errors/errors_test.go:93", + "\t.+/errors/errors_test.go:95", }, }, { @@ -101,7 +103,7 @@ func TestStackTrace(t *testing.T) { err: errors.Errorf("%w", fmt.Errorf("%w", errors.New("ooh"))), want: []string{ "github.com/muonsoft/errors_test.TestStackTrace\n" + - "\t.+/errors/errors_test.go:101", + "\t.+/errors/errors_test.go:103", }, }, { @@ -109,9 +111,9 @@ func TestStackTrace(t *testing.T) { err: wrap(errors.New("ooh")), want: []string{ "github.com/muonsoft/errors_test.wrap\n" + - "\t.+/errors/errors_test.go:150", + "\t.+/errors/errors_test.go:152", "github.com/muonsoft/errors_test.TestStackTrace\n" + - "\t.+/errors/errors_test.go:109", + "\t.+/errors/errors_test.go:111", }, }, { @@ -119,7 +121,7 @@ func TestStackTrace(t *testing.T) { err: wrapSkipCaller(errors.New("ooh")), want: []string{ "github.com/muonsoft/errors_test.TestStackTrace\n" + - "\t.+/errors/errors_test.go:119", + "\t.+/errors/errors_test.go:121", }, }, { @@ -127,15 +129,15 @@ func TestStackTrace(t *testing.T) { err: errorfSkipCaller("ooh"), want: []string{ "github.com/muonsoft/errors_test.TestStackTrace\n" + - "\t.+/errors/errors_test.go:127", + "\t.+/errors/errors_test.go:129", }, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { assertSingleStack(t, test.err) - var stacked StackTracer - if !errors.As(test.err, &stacked) { + stacked, ok := errors.As[StackTracer](test.err) + if !ok { t.Fatalf("expected %#v to implement errors.StackTracer", test.err) } st := stacked.StackTrace() @@ -240,12 +242,12 @@ func TestFields(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - var err errors.LoggableError - if !errors.As(test.err, &err) { + loggable, ok := errors.As[errors.LoggableError](test.err) + if !ok { t.Fatalf("expected %#v to implement errors.LoggableError", test.err) } logger := errorstest.NewLogger() - err.LogFields(logger) + loggable.LogFields(logger) logger.AssertField(t, "key", test.expected) }) } @@ -267,3 +269,167 @@ func TestIs(t *testing.T) { t.Error("want errors is true") } } + +func TestAs(t *testing.T) { + type timeout interface{ Timeout() bool } + _, errFileNotFound := os.Open("non-existing") + poserErr := &poser{"oh no", nil} + + tests := []struct { + name string + err error + as func(err error) (any, bool) + match bool + want any // value of target on match + }{ + { + "nil", + nil, + func(err error) (any, bool) { + return errors.As[*fs.PathError](err) + }, + false, + nil, + }, + { + "wrapped error", + wrapped{"pitied the fool", errorT{"T"}}, + func(err error) (any, bool) { + return errors.As[errorT](err) + }, + true, + errorT{"T"}, + }, + { + "match path error", + errFileNotFound, + func(err error) (any, bool) { + return errors.As[*fs.PathError](err) + }, + true, + errFileNotFound, + }, + { + "not match path error", + errorT{}, + func(err error) (any, bool) { + return errors.As[*fs.PathError](err) + }, + false, + nil, + }, + { + "wrapped nil", + wrapped{"wrapped", nil}, + func(err error) (any, bool) { + return errors.As[errorT](err) + }, + false, + nil, + }, + { + "error with matching as method", + &poser{"error", nil}, + func(err error) (any, bool) { + return errors.As[errorT](err) + }, + true, + errorT{"poser"}, + }, + { + "error with matching as method", + &poser{"path", nil}, + func(err error) (any, bool) { + return errors.As[*fs.PathError](err) + }, + true, + poserPathErr, + }, + { + "error with matching as method", + poserErr, + func(err error) (any, bool) { + return errors.As[*poser](err) + }, + true, + poserErr, + }, + { + "timeout error", + errors.New("err"), + func(err error) (any, bool) { + return errors.As[timeout](err) + }, + false, + nil, + }, + { + "file not found as timeout", + errFileNotFound, + func(err error) (any, bool) { + return errors.As[timeout](err) + }, + true, + errFileNotFound, + }, + { + "wrapped file not found as timeout", + wrapped{"path error", errFileNotFound}, + func(err error) (any, bool) { + return errors.As[timeout](err) + }, + true, + errFileNotFound, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got, match := test.as(test.err) + if match != test.match { + t.Fatalf("match: got %v; want %v", match, test.match) + } + if !match { + return + } + if got != test.want { + t.Fatalf("got %#v, want %#v", got, test.want) + } + }) + } +} + +type poser struct { + msg string + f func(error) bool +} + +var poserPathErr = &fs.PathError{Op: "poser"} + +func (p *poser) Error() string { return p.msg } +func (p *poser) Is(err error) bool { return p.f(err) } +func (p *poser) As(err any) bool { + switch x := err.(type) { + case **poser: + *x = p + case *errorT: + *x = errorT{"poser"} + case **fs.PathError: + *x = poserPathErr + default: + return false + } + return true +} + +type errorT struct{ s string } + +func (e errorT) Error() string { return fmt.Sprintf("errorT(%s)", e.s) } + +type wrapped struct { + msg string + err error +} + +func (e wrapped) Error() string { return e.msg } + +func (e wrapped) Unwrap() error { return e.err } diff --git a/example_test.go b/example_test.go new file mode 100644 index 0000000..9cf58d8 --- /dev/null +++ b/example_test.go @@ -0,0 +1,46 @@ +package errors_test + +import ( + "fmt" + "io/fs" + "os" + + "github.com/muonsoft/errors" +) + +func ExampleIs() { + if _, err := os.Open("non-existing"); err != nil { + if errors.Is(err, fs.ErrNotExist) { + fmt.Println("file does not exist") + } else { + fmt.Println(err) + } + } + + // Output: + // file does not exist +} + +func ExampleAs() { + if _, err := os.Open("non-existing"); err != nil { + if pathError, ok := errors.As[*fs.PathError](err); ok { + fmt.Println("Failed at path:", pathError.Path) + } else { + fmt.Println(err) + } + } + + // Output: + // Failed at path: non-existing +} + +func ExampleUnwrap() { + err1 := errors.New("error1") + err2 := fmt.Errorf("error2: [%w]", err1) + fmt.Println(err2) + fmt.Println(errors.Unwrap(err2)) + + // Output: + // error2: [error1] + // error1 +} diff --git a/go.mod b/go.mod index e2ff665..cedc942 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,7 @@ module github.com/muonsoft/errors -go 1.13 +go 1.18 require github.com/sirupsen/logrus v1.8.1 + +require golang.org/x/sys v0.0.0-20191026070338-33540a1f6037 // indirect diff --git a/logging/logrusadapter/adapter.go b/logging/logrusadapter/adapter.go index 93ae444..98282a8 100644 --- a/logging/logrusadapter/adapter.go +++ b/logging/logrusadapter/adapter.go @@ -27,4 +27,20 @@ func (a *adapter) SetValue(key string, value interface{}) { a.l = a.l.WithF func (a *adapter) SetTime(key string, value time.Time) { a.l = a.l.WithField(key, value) } func (a *adapter) SetDuration(key string, value time.Duration) { a.l = a.l.WithField(key, value) } func (a *adapter) SetJSON(key string, value json.RawMessage) { a.l = a.l.WithField(key, value) } -func (a *adapter) SetStackTrace(trace errors.StackTrace) { a.l = a.l.WithField("stackTrace", trace) } + +func (a *adapter) SetStackTrace(trace errors.StackTrace) { + type Frame struct { + Function string `json:"function"` + File string `json:"file,omitempty"` + Line int `json:"line,omitempty"` + } + + frames := make([]Frame, len(trace)) + for i, frame := range trace { + frames[i].File = frame.File() + frames[i].Function = frame.Name() + frames[i].Line = frame.Line() + } + + a.l = a.l.WithField("stackTrace", frames) +} diff --git a/stack_test.go b/stack_test.go index f186c95..400e603 100644 --- a/stack_test.go +++ b/stack_test.go @@ -217,8 +217,8 @@ func caller() errors.Frame { func TestStackTrace_String(t *testing.T) { err := errors.Errorf("ooh") - var stacked StackTracer - if !errors.As(err, &stacked) { + stacked, ok := errors.As[StackTracer](err) + if !ok { t.Fatalf("expected %#v to implement errors.StackTracer", err) } s := stacked.StackTrace().String() @@ -230,8 +230,8 @@ func TestStackTrace_String(t *testing.T) { func TestStackTrace_Strings(t *testing.T) { err := errors.Errorf("ooh") - var stacked StackTracer - if !errors.As(err, &stacked) { + stacked, ok := errors.As[StackTracer](err) + if !ok { t.Fatalf("expected %#v to implement errors.StackTracer", err) } s := stacked.StackTrace().Strings() @@ -243,8 +243,8 @@ func TestStackTrace_Strings(t *testing.T) { func TestStackTrace_MarshalJSON(t *testing.T) { err := errors.Errorf("ooh") - var stacked StackTracer - if !errors.As(err, &stacked) { + stacked, ok := errors.As[StackTracer](err) + if !ok { t.Fatalf("expected %#v to implement errors.StackTracer", err) } st := stacked.StackTrace()