Skip to content

Commit

Permalink
feat(discovery): implement LDAP service discovery (RFC 2782)
Browse files Browse the repository at this point in the history
  • Loading branch information
tonyhhyip committed Jun 13, 2022
1 parent a3dcdda commit 6d25a80
Show file tree
Hide file tree
Showing 20 changed files with 355 additions and 32 deletions.
7 changes: 7 additions & 0 deletions add.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ type AddRequest struct {
Controls []Control
}

func (req *AddRequest) appendBaseDN(dn string) appendDnRequest {
r2 := new(AddRequest)
*r2 = *req
r2.DN = appendDN(req.DN, dn)
return r2
}

func (req *AddRequest) appendTo(envelope *ber.Packet) error {
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationAddRequest, nil, "Add Request")
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN"))
Expand Down
93 changes: 78 additions & 15 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ import (
"fmt"
"net"
"net/url"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -101,6 +104,7 @@ type Conn struct {
wgClose sync.WaitGroup
outstandingRequests uint
messageMutex sync.Mutex
rootDN string
}

var _ Client = &Conn{}
Expand Down Expand Up @@ -144,35 +148,89 @@ type DialContext struct {
tc *tls.Config
}

func (dc *DialContext) dial(u *url.URL) (net.Conn, error) {
func (dc *DialContext) dial(u *url.URL) (conn net.Conn, err error) {
if u.Scheme == "ldapi" {
if u.Path == "" || u.Path == "/" {
u.Path = "/var/run/slapd/ldapi"
}
return dc.d.Dial("unix", u.Path)
}

host, port, err := net.SplitHostPort(u.Host)
if err != nil {
// we assume that error is due to missing port
host = u.Host
port = ""
if u.Scheme != "ldap" && u.Scheme != "ldaps" {
return nil, fmt.Errorf("Unknown scheme '%s'", u.Scheme)
}

switch u.Scheme {
case "ldap":
if port == "" {
port = DefaultLdapPort
hostports := make([]string, 0, 1)

if u.Host == "" {
// Attempt to use DNS SRV discovery for uri like ldap:///dc=example,dc=com
// For ldap:///dc=example,dc=com, it would query for _ldap._tcp.example.com with SRV type record
fragments := strings.Split(u.Path[1:], ",")
pieces := make([]string, 0, len(fragments))
for _, fragment := range fragments {
if strings.HasPrefix(fragment, "dc=") {
pieces = append(pieces, fragment[3:])
}
}
return dc.d.Dial("tcp", net.JoinHostPort(host, port))
case "ldaps":

domain := strings.Join(pieces, ".")
_, records, err := net.LookupSRV("ldap", "tcp", domain)
if err != nil {
return nil, err
}

sort.Slice(records, func(i, j int) bool {
return records[i].Priority > records[j].Priority
})

if u.Scheme == "ldaps" {
dc.tc = &tls.Config{
ServerName: domain,
}
}

for _, record := range records {
port := strconv.Itoa(int(record.Port))
hostports = append(hostports, net.JoinHostPort(record.Target, port))
}
} else {
host, port, err := net.SplitHostPort(u.Host)
if err != nil {
// we assume that error is due to missing port
host = u.Host
port = ""
}

if port == "" {
port = DefaultLdapsPort
if u.Scheme == "ldap" {
port = DefaultLdapPort
} else if u.Scheme == "ldaps" {
port = DefaultLdapsPort
}
}

hostports = []string{net.JoinHostPort(host, port)}
}

for _, pair := range hostports {
conn, err = dc.dialConn(u.Scheme, pair)
if conn != nil {
return conn, err
}
return tls.DialWithDialer(dc.d, "tcp", net.JoinHostPort(host, port), dc.tc)
}

return nil, fmt.Errorf("Unknown scheme '%s'", u.Scheme)
return
}

func (dc *DialContext) dialConn(scheme, target string) (net.Conn, error) {
switch scheme {
case "ldap":
return dc.d.Dial("tcp", target)
case "ldaps":
return tls.DialWithDialer(dc.d, "tcp", target, dc.tc)
}

return nil, fmt.Errorf("Unknown scheme '%s'", scheme)
}

// Dial connects to the given address on the given network using net.Dial
Expand Down Expand Up @@ -223,7 +281,12 @@ func DialURL(addr string, opts ...DialOpt) (*Conn, error) {
return nil, NewError(ErrorNetwork, err)
}

rootDN := ""
if u.Host == "" {
rootDN = u.Path[1:]
}
conn := NewConn(c, u.Scheme == "ldaps")
conn.rootDN = rootDN
conn.Start()
return conn, nil
}
Expand Down
7 changes: 7 additions & 0 deletions del.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ type DelRequest struct {
Controls []Control
}

func (req *DelRequest) appendBaseDN(dn string) appendDnRequest {
r2 := new(DelRequest)
*r2 = *req
r2.DN = appendDN(req.DN, dn)
return r2
}

func (req *DelRequest) appendTo(envelope *ber.Packet) error {
pkt := ber.Encode(ber.ClassApplication, ber.TypePrimitive, ApplicationDelRequest, req.DN, "Del Request")
pkt.Data.Write([]byte(req.DN))
Expand Down
16 changes: 16 additions & 0 deletions dn.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,19 @@ func (r *RelativeDN) hasAllAttributesFold(attrs []*AttributeTypeAndValue) bool {
func (a *AttributeTypeAndValue) EqualFold(other *AttributeTypeAndValue) bool {
return strings.EqualFold(a.Type, other.Type) && strings.EqualFold(a.Value, other.Value)
}

// appendDN is for concat the baseDN and rootDN
// dn stand for user input dn in request
// rootDN stand for dn used during discovery
func appendDN(dn, rootDN string) string {
if rootDN != "" {
var baseDnBuilder strings.Builder
if dn != "" {
baseDnBuilder.WriteString(dn)
baseDnBuilder.WriteByte(',')
}
baseDnBuilder.WriteString(rootDN)
return baseDnBuilder.String()
}
return dn
}
31 changes: 31 additions & 0 deletions dn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,34 @@ func TestDNAncestor(t *testing.T) {
}
}
}

func TestAppendDN(t *testing.T) {
testcases := []struct {
baseDN string
rootDN string
expected string
}{
{
baseDN: "ou=A",
rootDN: "dc=ldap,dc=internal",
expected: "ou=A,dc=ldap,dc=internal",
},
{
baseDN: "ou=A,dc=ldap,dc=internal",
rootDN: "",
expected: "ou=A,dc=ldap,dc=internal",
},
{
baseDN: "",
rootDN: "dc=ldap,dc=internal",
expected: "dc=ldap,dc=internal",
},
}

for i, tc := range testcases {
result := appendDN(tc.baseDN, tc.rootDN)
if result != tc.expected {
t.Errorf("#%d, expected %s, getting: %s", i, tc.expected, result)
}
}
}
8 changes: 8 additions & 0 deletions ldap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ var attributes = []string{
"description",
}

func TestDialURLViaDiscovery(t *testing.T) {
l, err := DialURL("ldap:///dc=umich,dc=edu")
if err != nil {
t.Fatal(err)
}
defer l.Close()
}

func TestUnsecureDialURL(t *testing.T) {
l, err := DialURL(ldapServer)
if err != nil {
Expand Down
9 changes: 8 additions & 1 deletion moddn.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func NewModifyDNRequest(dn string, rdn string, delOld bool, newSup string) *Modi
//
// Refer NewModifyDNRequest for other parameters
func NewModifyDNWithControlsRequest(dn string, rdn string, delOld bool,
newSup string, controls []Control) *ModifyDNRequest {
newSup string, controls []Control) *ModifyDNRequest {
return &ModifyDNRequest{
DN: dn,
NewRDN: rdn,
Expand All @@ -50,6 +50,13 @@ func NewModifyDNWithControlsRequest(dn string, rdn string, delOld bool,
}
}

func (req *ModifyDNRequest) appendBaseDN(dn string) appendDnRequest {
r2 := new(ModifyDNRequest)
*r2 = *req
r2.DN = appendDN(req.DN, dn)
return r2
}

func (req *ModifyDNRequest) appendTo(envelope *ber.Packet) error {
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyDNRequest, nil, "Modify DN Request")
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN"))
Expand Down
7 changes: 7 additions & 0 deletions modify.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ func (req *ModifyRequest) appendChange(operation uint, attrType string, attrVals
req.Changes = append(req.Changes, Change{operation, PartialAttribute{Type: attrType, Vals: attrVals}})
}

func (req *ModifyRequest) appendBaseDN(dn string) appendDnRequest {
r2 := new(ModifyRequest)
*r2 = *req
r2.DN = appendDN(req.DN, dn)
return r2
}

func (req *ModifyRequest) appendTo(envelope *ber.Packet) error {
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyRequest, nil, "Modify Request")
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN"))
Expand Down
10 changes: 10 additions & 0 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ type request interface {
appendTo(*ber.Packet) error
}

type appendDnRequest interface {
request
appendBaseDN(dn string) appendDnRequest
}

type requestFunc func(*ber.Packet) error

func (f requestFunc) appendTo(p *ber.Packet) error {
Expand All @@ -30,6 +35,11 @@ func (l *Conn) doRequest(req request) (*messageContext, error) {

packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))

if areq, ok := req.(appendDnRequest); ok {
req = areq.appendBaseDN(l.rootDN)
}

if err := req.appendTo(packet); err != nil {
return nil, err
}
Expand Down
7 changes: 7 additions & 0 deletions search.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,13 @@ type SearchRequest struct {
Controls []Control
}

func (req *SearchRequest) appendBaseDN(dn string) appendDnRequest {
r2 := new(SearchRequest)
*r2 = *req
r2.BaseDN = appendDN(req.BaseDN, dn)
return r2
}

func (req *SearchRequest) appendTo(envelope *ber.Packet) error {
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchRequest, nil, "Search Request")
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.BaseDN, "Base DN"))
Expand Down
7 changes: 7 additions & 0 deletions v3/add.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ type AddRequest struct {
Controls []Control
}

func (req *AddRequest) appendBaseDN(dn string) appendDnRequest {
r2 := new(AddRequest)
*r2 = *req
r2.DN = appendDN(req.DN, dn)
return r2
}

func (req *AddRequest) appendTo(envelope *ber.Packet) error {
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationAddRequest, nil, "Add Request")
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN"))
Expand Down
Loading

0 comments on commit 6d25a80

Please sign in to comment.