Skip to content

Commit

Permalink
feat: Refactor ParseDN function to fix resource usage and invalid p…
Browse files Browse the repository at this point in the history
…arsings (fixes #487) (#497)
  • Loading branch information
cpuschma authored Apr 1, 2024
1 parent f09ee91 commit ff74b2c
Show file tree
Hide file tree
Showing 6 changed files with 356 additions and 214 deletions.
260 changes: 166 additions & 94 deletions dn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@ package ldap

import (
"bytes"
enchex "encoding/hex"
"encoding/asn1"
"encoding/hex"
"errors"
"fmt"
"sort"
"strings"

ber "github.com/go-asn1-ber/asn1-ber"
)

// AttributeTypeAndValue represents an attributeTypeAndValue from https://tools.ietf.org/html/rfc4514
Expand All @@ -19,8 +18,43 @@ type AttributeTypeAndValue struct {
Value string
}

func (a *AttributeTypeAndValue) setType(str string) error {
result, err := decodeString(str)
if err != nil {
return err
}
a.Type = result

return nil
}

func (a *AttributeTypeAndValue) setValue(s string) error {
// https://www.ietf.org/rfc/rfc4514.html#section-2.4
// If the AttributeType is of the dotted-decimal form, the
// AttributeValue is represented by an number sign ('#' U+0023)
// character followed by the hexadecimal encoding of each of the octets
// of the BER encoding of the X.500 AttributeValue.
if len(s) > 0 && s[0] == '#' {
decodedString, err := decodeEncodedString(s[1:])
if err != nil {
return err
}

a.Value = decodedString
return nil
} else {
decodedString, err := decodeString(s)
if err != nil {
return err
}

a.Value = decodedString
return nil
}
}

// String returns a normalized string representation of this attribute type and
// value pair which is the a lowercased join of the Type and Value with a "=".
// value pair which is the lowercase join of the Type and Value with a "=".
func (a *AttributeTypeAndValue) String() string {
return strings.ToLower(a.Type) + "=" + a.encodeValue()
}
Expand All @@ -39,7 +73,7 @@ func (a *AttributeTypeAndValue) encodeValue() string {

escapeHex := func(c byte) {
encodedBuf.WriteByte('\\')
encodedBuf.WriteString(enchex.EncodeToString([]byte{c}))
encodedBuf.WriteString(hex.EncodeToString([]byte{c}))
}

for i := 0; i < len(value); i++ {
Expand Down Expand Up @@ -108,114 +142,152 @@ func (d *DN) String() string {
return strings.Join(rdns, ",")
}

// Remove leading and trailing spaces from the attribute type and value
// and unescape any escaped characters in these fields
//
// decodeString is based on https://github.com/inteon/cert-manager/blob/ed280d28cd02b262c5db46054d88e70ab518299c/pkg/util/pki/internal/dn.go#L170
func decodeString(str string) (string, error) {
s := []rune(strings.TrimSpace(str))
// Re-add the trailing space if the last character was an escaped space character
if len(s) > 0 && s[len(s)-1] == '\\' && str[len(str)-2] == ' ' {
s = append(s, ' ')
}

builder := strings.Builder{}
for i := 0; i < len(s); i++ {
char := s[i]

// If the character is not an escape character, just add it to the
// builder and continue
if char != '\\' {
builder.WriteRune(char)
continue
}

// If the escape character is the last character, it's a corrupted
// escaped character
if i+1 >= len(s) {
return "", fmt.Errorf("got corrupted escaped character: '%s'", string(s))
}

// If the escaped character is a special character, just add it to
// the builder and continue
switch s[i+1] {
case ' ', '"', '#', '+', ',', ';', '<', '=', '>', '\\':
builder.WriteRune(s[i+1])
i++
continue
}

// If the escaped character is not a special character, it should
// be a hex-encoded character of the form \XX if it's not at least
// two characters long, it's a corrupted escaped character
if i+2 >= len(s) {
return "", errors.New("failed to decode escaped character: encoding/hex: invalid byte: " + string(s[i+1]))
}

// Get the runes for the two characters after the escape character
// and convert them to a byte slice
xx := []byte(string(s[i+1 : i+3]))

// If the two runes are not hex characters and result in more than
// two bytes when converted to a byte slice, it's a corrupted
// escaped character
if len(xx) != 2 {
return "", errors.New("failed to decode escaped character: invalid byte: " + string(xx))
}

// Decode the hex-encoded character and add it to the builder
dst := []byte{0}
if n, err := hex.Decode(dst, xx); err != nil {
return "", errors.New("failed to decode escaped character: " + err.Error())
} else if n != 1 {
return "", fmt.Errorf("failed to decode escaped character: encoding/hex: expected 1 byte when un-escaping, got %d", n)
}

builder.WriteByte(dst[0])
i += 2
}

return builder.String(), nil
}

func decodeEncodedString(str string) (string, error) {
decoded, err := hex.DecodeString(str)
if err != nil {
return "", fmt.Errorf("failed to decode BER encoding: %s", err)
}

var rawValue asn1.RawValue
result, err := asn1.Unmarshal(decoded, &rawValue)
if err != nil {
return "", fmt.Errorf("failed to unmarshal hex-encoded string: %s", err)
}
if len(result) != 0 {
return "", errors.New("trailing data after unmarshalling hex-encoded string")
}

return string(rawValue.Bytes), nil
}

// ParseDN returns a distinguishedName or an error.
// The function respects https://tools.ietf.org/html/rfc4514
func ParseDN(str string) (*DN, error) {
dn := new(DN)
dn.RDNs = make([]*RelativeDN, 0)
rdn := new(RelativeDN)
rdn.Attributes = make([]*AttributeTypeAndValue, 0)
buffer := bytes.Buffer{}
attribute := new(AttributeTypeAndValue)
escaping := false

unescapedTrailingSpaces := 0
stringFromBuffer := func() string {
s := buffer.String()
s = s[0 : len(s)-unescapedTrailingSpaces]
buffer.Reset()
unescapedTrailingSpaces = 0
return s
var dn = &DN{RDNs: make([]*RelativeDN, 0)}
if str = strings.TrimSpace(str); len(str) == 0 {
return dn, nil
}

var (
rdn = &RelativeDN{}
attr = &AttributeTypeAndValue{}
escaping bool
startPos int
appendAttributesToRDN = func(end bool) {
rdn.Attributes = append(rdn.Attributes, attr)
attr = &AttributeTypeAndValue{}
if end {
dn.RDNs = append(dn.RDNs, rdn)
rdn = &RelativeDN{}
}
}
)

for i := 0; i < len(str); i++ {
char := str[i]
switch {
case escaping:
unescapedTrailingSpaces = 0
escaping = false
switch char {
case ' ', '"', '#', '+', ',', ';', '<', '=', '>', '\\':
buffer.WriteByte(char)
continue
}
// Not a special character, assume hex encoded octet
if len(str) == i+1 {
return nil, errors.New("got corrupted escaped character")
}

dst := []byte{0}
n, err := enchex.Decode([]byte(dst), []byte(str[i:i+2]))
if err != nil {
return nil, fmt.Errorf("failed to decode escaped character: %s", err)
} else if n != 1 {
return nil, fmt.Errorf("expected 1 byte when un-escaping, got %d", n)
}
buffer.WriteByte(dst[0])
i++
case char == '\\':
unescapedTrailingSpaces = 0
escaping = true
case char == '=' && attribute.Type == "":
attribute.Type = stringFromBuffer()
// Special case: If the first character in the value is # the
// following data is BER encoded so we can just fast forward
// and decode.
if len(str) > i+1 && str[i+1] == '#' {
i += 2
index := strings.IndexAny(str[i:], ",+")
var data string
if index > 0 {
data = str[i : i+index]
} else {
data = str[i:]
}
rawBER, err := enchex.DecodeString(data)
if err != nil {
return nil, fmt.Errorf("failed to decode BER encoding: %s", err)
}
packet, err := ber.DecodePacketErr(rawBER)
if err != nil {
return nil, fmt.Errorf("failed to decode BER packet: %s", err)
}
buffer.WriteString(packet.Data.String())
i += len(data) - 1
case char == '=' && len(attr.Type) == 0:
if err := attr.setType(str[startPos:i]); err != nil {
return nil, err
}
startPos = i + 1
case char == ',' || char == '+' || char == ';':
// We're done with this RDN or value, push it
if len(attribute.Type) == 0 {
return nil, errors.New("incomplete type, value pair")
if len(attr.Type) == 0 {
return dn, errors.New("incomplete type, value pair")
}
attribute.Value = stringFromBuffer()
rdn.Attributes = append(rdn.Attributes, attribute)
attribute = new(AttributeTypeAndValue)
if char == ',' || char == ';' {
dn.RDNs = append(dn.RDNs, rdn)
rdn = new(RelativeDN)
rdn.Attributes = make([]*AttributeTypeAndValue, 0)
if err := attr.setValue(str[startPos:i]); err != nil {
return nil, err
}
case char == ' ' && buffer.Len() == 0:
// ignore unescaped leading spaces
continue
default:
if char == ' ' {
// Track unescaped spaces in case they are trailing and we need to remove them
unescapedTrailingSpaces++
} else {
// Reset if we see a non-space char
unescapedTrailingSpaces = 0
}
buffer.WriteByte(char)

startPos = i + 1
last := char == ',' || char == ';'
appendAttributesToRDN(last)
}
}
if buffer.Len() > 0 {
if len(attribute.Type) == 0 {
return nil, errors.New("DN ended with incomplete type, value pair")
}
attribute.Value = stringFromBuffer()
rdn.Attributes = append(rdn.Attributes, attribute)
dn.RDNs = append(dn.RDNs, rdn)

if len(attr.Type) == 0 {
return dn, errors.New("DN ended with incomplete type, value pair")
}

if err := attr.setValue(str[startPos:]); err != nil {
return dn, err
}
appendAttributesToRDN(true)

return dn, nil
}

Expand Down
18 changes: 9 additions & 9 deletions dn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,21 +82,21 @@ func TestSuccessfulDNParsing(t *testing.T) {
for test, answer := range testcases {
dn, err := ParseDN(test)
if err != nil {
t.Errorf(err.Error())
t.Errorf("ParseDN failed for DN test '%s': %s", test, err)
continue
}
if !reflect.DeepEqual(dn, &answer) {
t.Errorf("Parsed DN %s is not equal to the expected structure", test)
t.Errorf("Parsed DN '%s' is not equal to the expected structure", test)
t.Logf("Expected:")
for _, rdn := range answer.RDNs {
for _, attribs := range rdn.Attributes {
t.Logf("#%v\n", attribs)
for _, attribute := range rdn.Attributes {
t.Logf(" #%v\n", attribute)
}
}
t.Logf("Actual:")
for _, rdn := range dn.RDNs {
for _, attribs := range rdn.Attributes {
t.Logf("#%v\n", attribs)
for _, attribute := range rdn.Attributes {
t.Logf(" #%v\n", attribute)
}
}
}
Expand All @@ -107,7 +107,7 @@ func TestErrorDNParsing(t *testing.T) {
testcases := map[string]string{
"*": "DN ended with incomplete type, value pair",
"cn=Jim\\0Test": "failed to decode escaped character: encoding/hex: invalid byte: U+0054 'T'",
"cn=Jim\\0": "got corrupted escaped character",
"cn=Jim\\0": "failed to decode escaped character: encoding/hex: invalid byte: 0",
"DC=example,=net": "DN ended with incomplete type, value pair",
"1=#0402486": "failed to decode BER encoding: encoding/hex: odd length hex string",
"test,DC=example,DC=com": "incomplete type, value pair",
Expand All @@ -117,9 +117,9 @@ func TestErrorDNParsing(t *testing.T) {
for test, answer := range testcases {
_, err := ParseDN(test)
if err == nil {
t.Errorf("Expected %s to fail parsing but succeeded\n", test)
t.Errorf("Expected '%s' to fail parsing but succeeded\n", test)
} else if err.Error() != answer {
t.Errorf("Unexpected error on %s:\n%s\nvs.\n%s\n", test, answer, err.Error())
t.Errorf("Unexpected error on: '%s':\nExpected: %s\nGot: %s\n", test, answer, err.Error())
}
}
}
Expand Down
7 changes: 3 additions & 4 deletions ldap.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,6 @@ func DebugBinaryFile(fileName string) error {
return nil
}

var hex = "0123456789abcdef"

func mustEscape(c byte) bool {
return c > 0x7f || c == '(' || c == ')' || c == '\\' || c == '*' || c == 0
}
Expand All @@ -324,6 +322,7 @@ func mustEscape(c byte) bool {
// characters in the set `()*\` and those out of the range 0 < c < 0x80,
// as defined in RFC4515.
func EscapeFilter(filter string) string {
const hexValues = "0123456789abcdef"
escape := 0
for i := 0; i < len(filter); i++ {
if mustEscape(filter[i]) {
Expand All @@ -338,8 +337,8 @@ func EscapeFilter(filter string) string {
c := filter[i]
if mustEscape(c) {
buf[j+0] = '\\'
buf[j+1] = hex[c>>4]
buf[j+2] = hex[c&0xf]
buf[j+1] = hexValues[c>>4]
buf[j+2] = hexValues[c&0xf]
j += 3
} else {
buf[j] = c
Expand Down
Loading

0 comments on commit ff74b2c

Please sign in to comment.