From 12f15cc814eb312f7a73021e63d0cebbd9f51baf Mon Sep 17 00:00:00 2001 From: Tetsuya Morimoto Date: Tue, 21 Nov 2023 09:14:08 +0900 Subject: [PATCH] refactor: match the result code with IsErrorAnyOf even if an error is wrapped (#471) --- error.go | 5 +-- error_test.go | 79 ++++++++++++++++++++++++++++++++++++++++++++++++ v3/error.go | 5 +-- v3/error_test.go | 79 ++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 164 insertions(+), 4 deletions(-) diff --git a/error.go b/error.go index 3c2559e..53c6d62 100644 --- a/error.go +++ b/error.go @@ -1,6 +1,7 @@ package ldap import ( + "errors" "fmt" ber "github.com/go-asn1-ber/asn1-ber" @@ -241,8 +242,8 @@ func IsErrorAnyOf(err error, codes ...uint16) bool { return false } - serverError, ok := err.(*Error) - if !ok { + var serverError *Error + if !errors.As(err, &serverError) { return false } diff --git a/error_test.go b/error_test.go index c115e00..323f766 100644 --- a/error_test.go +++ b/error_test.go @@ -2,6 +2,7 @@ package ldap import ( "errors" + "fmt" "io" "net" "strings" @@ -11,6 +12,84 @@ import ( ber "github.com/go-asn1-ber/asn1-ber" ) +// TestWrappedError tests that match the result code when an error is wrapped. +func TestWrappedError(t *testing.T) { + resultCodes := []uint16{ + LDAPResultProtocolError, + LDAPResultBusy, + ErrorNetwork, + } + + tests := []struct { + name string + err error + codes []uint16 + expected bool + }{ + // success + { + name: "a normal error", + err: &Error{ + ResultCode: ErrorNetwork, + }, + codes: resultCodes, + expected: true, + }, + + { + name: "a wrapped error", + err: fmt.Errorf("wrap: %w", &Error{ + ResultCode: LDAPResultBusy, + }), + codes: resultCodes, + expected: true, + }, + + { + name: "multiple wrapped error", + err: fmt.Errorf("second: %w", + fmt.Errorf("first: %w", + &Error{ + ResultCode: LDAPResultProtocolError, + }, + ), + ), + codes: resultCodes, + expected: true, + }, + + // failure + { + name: "not match a normal error", + err: &Error{ + ResultCode: LDAPResultSuccess, + }, + codes: resultCodes, + expected: false, + }, + + { + name: "not match a wrapped error", + err: fmt.Errorf("wrap: %w", &Error{ + ResultCode: LDAPResultNoSuchObject, + }), + codes: resultCodes, + expected: false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + actual := IsErrorAnyOf(tt.err, tt.codes...) + if tt.expected != actual { + t.Errorf("expected %t, but got %t", tt.expected, actual) + } + }) + } +} + // TestNilPacket tests that nil packets don't cause a panic. func TestNilPacket(t *testing.T) { // Test for nil packet diff --git a/v3/error.go b/v3/error.go index 3c2559e..53c6d62 100644 --- a/v3/error.go +++ b/v3/error.go @@ -1,6 +1,7 @@ package ldap import ( + "errors" "fmt" ber "github.com/go-asn1-ber/asn1-ber" @@ -241,8 +242,8 @@ func IsErrorAnyOf(err error, codes ...uint16) bool { return false } - serverError, ok := err.(*Error) - if !ok { + var serverError *Error + if !errors.As(err, &serverError) { return false } diff --git a/v3/error_test.go b/v3/error_test.go index c115e00..323f766 100644 --- a/v3/error_test.go +++ b/v3/error_test.go @@ -2,6 +2,7 @@ package ldap import ( "errors" + "fmt" "io" "net" "strings" @@ -11,6 +12,84 @@ import ( ber "github.com/go-asn1-ber/asn1-ber" ) +// TestWrappedError tests that match the result code when an error is wrapped. +func TestWrappedError(t *testing.T) { + resultCodes := []uint16{ + LDAPResultProtocolError, + LDAPResultBusy, + ErrorNetwork, + } + + tests := []struct { + name string + err error + codes []uint16 + expected bool + }{ + // success + { + name: "a normal error", + err: &Error{ + ResultCode: ErrorNetwork, + }, + codes: resultCodes, + expected: true, + }, + + { + name: "a wrapped error", + err: fmt.Errorf("wrap: %w", &Error{ + ResultCode: LDAPResultBusy, + }), + codes: resultCodes, + expected: true, + }, + + { + name: "multiple wrapped error", + err: fmt.Errorf("second: %w", + fmt.Errorf("first: %w", + &Error{ + ResultCode: LDAPResultProtocolError, + }, + ), + ), + codes: resultCodes, + expected: true, + }, + + // failure + { + name: "not match a normal error", + err: &Error{ + ResultCode: LDAPResultSuccess, + }, + codes: resultCodes, + expected: false, + }, + + { + name: "not match a wrapped error", + err: fmt.Errorf("wrap: %w", &Error{ + ResultCode: LDAPResultNoSuchObject, + }), + codes: resultCodes, + expected: false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + actual := IsErrorAnyOf(tt.err, tt.codes...) + if tt.expected != actual { + t.Errorf("expected %t, but got %t", tt.expected, actual) + } + }) + } +} + // TestNilPacket tests that nil packets don't cause a panic. func TestNilPacket(t *testing.T) { // Test for nil packet