diff --git a/error.go b/error.go index 3c2559e..53c6d62 100644 --- a/error.go +++ b/error.go @@ -1,6 +1,7 @@ package ldap import ( + "errors" "fmt" ber "github.com/go-asn1-ber/asn1-ber" @@ -241,8 +242,8 @@ func IsErrorAnyOf(err error, codes ...uint16) bool { return false } - serverError, ok := err.(*Error) - if !ok { + var serverError *Error + if !errors.As(err, &serverError) { return false } diff --git a/error_test.go b/error_test.go index c115e00..323f766 100644 --- a/error_test.go +++ b/error_test.go @@ -2,6 +2,7 @@ package ldap import ( "errors" + "fmt" "io" "net" "strings" @@ -11,6 +12,84 @@ import ( ber "github.com/go-asn1-ber/asn1-ber" ) +// TestWrappedError tests that match the result code when an error is wrapped. +func TestWrappedError(t *testing.T) { + resultCodes := []uint16{ + LDAPResultProtocolError, + LDAPResultBusy, + ErrorNetwork, + } + + tests := []struct { + name string + err error + codes []uint16 + expected bool + }{ + // success + { + name: "a normal error", + err: &Error{ + ResultCode: ErrorNetwork, + }, + codes: resultCodes, + expected: true, + }, + + { + name: "a wrapped error", + err: fmt.Errorf("wrap: %w", &Error{ + ResultCode: LDAPResultBusy, + }), + codes: resultCodes, + expected: true, + }, + + { + name: "multiple wrapped error", + err: fmt.Errorf("second: %w", + fmt.Errorf("first: %w", + &Error{ + ResultCode: LDAPResultProtocolError, + }, + ), + ), + codes: resultCodes, + expected: true, + }, + + // failure + { + name: "not match a normal error", + err: &Error{ + ResultCode: LDAPResultSuccess, + }, + codes: resultCodes, + expected: false, + }, + + { + name: "not match a wrapped error", + err: fmt.Errorf("wrap: %w", &Error{ + ResultCode: LDAPResultNoSuchObject, + }), + codes: resultCodes, + expected: false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + actual := IsErrorAnyOf(tt.err, tt.codes...) + if tt.expected != actual { + t.Errorf("expected %t, but got %t", tt.expected, actual) + } + }) + } +} + // TestNilPacket tests that nil packets don't cause a panic. func TestNilPacket(t *testing.T) { // Test for nil packet diff --git a/v3/error.go b/v3/error.go index 3c2559e..53c6d62 100644 --- a/v3/error.go +++ b/v3/error.go @@ -1,6 +1,7 @@ package ldap import ( + "errors" "fmt" ber "github.com/go-asn1-ber/asn1-ber" @@ -241,8 +242,8 @@ func IsErrorAnyOf(err error, codes ...uint16) bool { return false } - serverError, ok := err.(*Error) - if !ok { + var serverError *Error + if !errors.As(err, &serverError) { return false } diff --git a/v3/error_test.go b/v3/error_test.go index c115e00..323f766 100644 --- a/v3/error_test.go +++ b/v3/error_test.go @@ -2,6 +2,7 @@ package ldap import ( "errors" + "fmt" "io" "net" "strings" @@ -11,6 +12,84 @@ import ( ber "github.com/go-asn1-ber/asn1-ber" ) +// TestWrappedError tests that match the result code when an error is wrapped. +func TestWrappedError(t *testing.T) { + resultCodes := []uint16{ + LDAPResultProtocolError, + LDAPResultBusy, + ErrorNetwork, + } + + tests := []struct { + name string + err error + codes []uint16 + expected bool + }{ + // success + { + name: "a normal error", + err: &Error{ + ResultCode: ErrorNetwork, + }, + codes: resultCodes, + expected: true, + }, + + { + name: "a wrapped error", + err: fmt.Errorf("wrap: %w", &Error{ + ResultCode: LDAPResultBusy, + }), + codes: resultCodes, + expected: true, + }, + + { + name: "multiple wrapped error", + err: fmt.Errorf("second: %w", + fmt.Errorf("first: %w", + &Error{ + ResultCode: LDAPResultProtocolError, + }, + ), + ), + codes: resultCodes, + expected: true, + }, + + // failure + { + name: "not match a normal error", + err: &Error{ + ResultCode: LDAPResultSuccess, + }, + codes: resultCodes, + expected: false, + }, + + { + name: "not match a wrapped error", + err: fmt.Errorf("wrap: %w", &Error{ + ResultCode: LDAPResultNoSuchObject, + }), + codes: resultCodes, + expected: false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + actual := IsErrorAnyOf(tt.err, tt.codes...) + if tt.expected != actual { + t.Errorf("expected %t, but got %t", tt.expected, actual) + } + }) + } +} + // TestNilPacket tests that nil packets don't cause a panic. func TestNilPacket(t *testing.T) { // Test for nil packet