Skip to content

Commit

Permalink
refactor: match the result code with IsErrorAnyOf even if an error is…
Browse files Browse the repository at this point in the history
… wrapped
  • Loading branch information
t2y committed Nov 21, 2023
1 parent ef0e538 commit 00afd10
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 4 deletions.
5 changes: 3 additions & 2 deletions error.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ldap

import (
"errors"
"fmt"

ber "github.com/go-asn1-ber/asn1-ber"
Expand Down Expand Up @@ -241,8 +242,8 @@ func IsErrorAnyOf(err error, codes ...uint16) bool {
return false
}

serverError, ok := err.(*Error)
if !ok {
var serverError *Error
if !errors.As(err, &serverError) {
return false
}

Expand Down
79 changes: 79 additions & 0 deletions error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ldap

import (
"errors"
"fmt"
"io"
"net"
"strings"
Expand All @@ -11,6 +12,84 @@ import (
ber "github.com/go-asn1-ber/asn1-ber"
)

// TestWrappedError tests that match the result code when an error is wrapped.
func TestWrappedError(t *testing.T) {
resultCodes := []uint16{
LDAPResultProtocolError,
LDAPResultBusy,
ErrorNetwork,
}

tests := []struct {
name string
err error
codes []uint16
expected bool
}{
// success
{
name: "a normal error",
err: &Error{
ResultCode: ErrorNetwork,
},
codes: resultCodes,
expected: true,
},

{
name: "a wrapped error",
err: fmt.Errorf("wrap: %w", &Error{
ResultCode: LDAPResultBusy,
}),
codes: resultCodes,
expected: true,
},

{
name: "multiple wrapped error",
err: fmt.Errorf("second: %w",
fmt.Errorf("first: %w",
&Error{
ResultCode: LDAPResultProtocolError,
},
),
),
codes: resultCodes,
expected: true,
},

// failure
{
name: "not match a normal error",
err: &Error{
ResultCode: LDAPResultSuccess,
},
codes: resultCodes,
expected: false,
},

{
name: "not match a wrapped error",
err: fmt.Errorf("wrap: %w", &Error{
ResultCode: LDAPResultNoSuchObject,
}),
codes: resultCodes,
expected: false,
},
}

for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
actual := IsErrorAnyOf(tt.err, tt.codes...)
if tt.expected != actual {
t.Errorf("expected %t, but got %t", tt.expected, actual)
}
})
}
}

// TestNilPacket tests that nil packets don't cause a panic.
func TestNilPacket(t *testing.T) {
// Test for nil packet
Expand Down
5 changes: 3 additions & 2 deletions v3/error.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ldap

import (
"errors"
"fmt"

ber "github.com/go-asn1-ber/asn1-ber"
Expand Down Expand Up @@ -241,8 +242,8 @@ func IsErrorAnyOf(err error, codes ...uint16) bool {
return false
}

serverError, ok := err.(*Error)
if !ok {
var serverError *Error
if !errors.As(err, &serverError) {
return false
}

Expand Down
79 changes: 79 additions & 0 deletions v3/error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ldap

import (
"errors"
"fmt"
"io"
"net"
"strings"
Expand All @@ -11,6 +12,84 @@ import (
ber "github.com/go-asn1-ber/asn1-ber"
)

// TestWrappedError tests that match the result code when an error is wrapped.
func TestWrappedError(t *testing.T) {
resultCodes := []uint16{
LDAPResultProtocolError,
LDAPResultBusy,
ErrorNetwork,
}

tests := []struct {
name string
err error
codes []uint16
expected bool
}{
// success
{
name: "a normal error",
err: &Error{
ResultCode: ErrorNetwork,
},
codes: resultCodes,
expected: true,
},

{
name: "a wrapped error",
err: fmt.Errorf("wrap: %w", &Error{
ResultCode: LDAPResultBusy,
}),
codes: resultCodes,
expected: true,
},

{
name: "multiple wrapped error",
err: fmt.Errorf("second: %w",
fmt.Errorf("first: %w",
&Error{
ResultCode: LDAPResultProtocolError,
},
),
),
codes: resultCodes,
expected: true,
},

// failure
{
name: "not match a normal error",
err: &Error{
ResultCode: LDAPResultSuccess,
},
codes: resultCodes,
expected: false,
},

{
name: "not match a wrapped error",
err: fmt.Errorf("wrap: %w", &Error{
ResultCode: LDAPResultNoSuchObject,
}),
codes: resultCodes,
expected: false,
},
}

for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
actual := IsErrorAnyOf(tt.err, tt.codes...)
if tt.expected != actual {
t.Errorf("expected %t, but got %t", tt.expected, actual)
}
})
}
}

// TestNilPacket tests that nil packets don't cause a panic.
func TestNilPacket(t *testing.T) {
// Test for nil packet
Expand Down

0 comments on commit 00afd10

Please sign in to comment.