From 6d25a80b9798547eadb133a8e2d97335a86f8258 Mon Sep 17 00:00:00 2001 From: Tony Yip Date: Mon, 21 Mar 2022 22:05:20 +0800 Subject: [PATCH] feat(discovery): implement LDAP service discovery (RFC 2782) --- add.go | 7 ++++ conn.go | 93 +++++++++++++++++++++++++++++++++++++++++-------- del.go | 7 ++++ dn.go | 16 +++++++++ dn_test.go | 31 +++++++++++++++++ ldap_test.go | 8 +++++ moddn.go | 9 ++++- modify.go | 7 ++++ request.go | 10 ++++++ search.go | 7 ++++ v3/add.go | 7 ++++ v3/conn.go | 91 +++++++++++++++++++++++++++++++++++++++-------- v3/del.go | 7 ++++ v3/dn.go | 16 +++++++++ v3/dn_test.go | 31 +++++++++++++++++ v3/ldap_test.go | 8 +++++ v3/moddn.go | 9 ++++- v3/modify.go | 7 ++++ v3/request.go | 9 +++++ v3/search.go | 7 ++++ 20 files changed, 355 insertions(+), 32 deletions(-) diff --git a/add.go b/add.go index c3101b76..630e4742 100644 --- a/add.go +++ b/add.go @@ -33,6 +33,13 @@ type AddRequest struct { Controls []Control } +func (req *AddRequest) appendBaseDN(dn string) appendDnRequest { + r2 := new(AddRequest) + *r2 = *req + r2.DN = appendDN(req.DN, dn) + return r2 +} + func (req *AddRequest) appendTo(envelope *ber.Packet) error { pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationAddRequest, nil, "Add Request") pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN")) diff --git a/conn.go b/conn.go index 6ed7b5e6..065d5830 100644 --- a/conn.go +++ b/conn.go @@ -7,6 +7,9 @@ import ( "fmt" "net" "net/url" + "sort" + "strconv" + "strings" "sync" "sync/atomic" "time" @@ -101,6 +104,7 @@ type Conn struct { wgClose sync.WaitGroup outstandingRequests uint messageMutex sync.Mutex + rootDN string } var _ Client = &Conn{} @@ -144,7 +148,7 @@ type DialContext struct { tc *tls.Config } -func (dc *DialContext) dial(u *url.URL) (net.Conn, error) { +func (dc *DialContext) dial(u *url.URL) (conn net.Conn, err error) { if u.Scheme == "ldapi" { if u.Path == "" || u.Path == "/" { u.Path = "/var/run/slapd/ldapi" @@ -152,27 +156,81 @@ func (dc *DialContext) dial(u *url.URL) (net.Conn, error) { return dc.d.Dial("unix", u.Path) } - host, port, err := net.SplitHostPort(u.Host) - if err != nil { - // we assume that error is due to missing port - host = u.Host - port = "" + if u.Scheme != "ldap" && u.Scheme != "ldaps" { + return nil, fmt.Errorf("Unknown scheme '%s'", u.Scheme) } - switch u.Scheme { - case "ldap": - if port == "" { - port = DefaultLdapPort + hostports := make([]string, 0, 1) + + if u.Host == "" { + // Attempt to use DNS SRV discovery for uri like ldap:///dc=example,dc=com + // For ldap:///dc=example,dc=com, it would query for _ldap._tcp.example.com with SRV type record + fragments := strings.Split(u.Path[1:], ",") + pieces := make([]string, 0, len(fragments)) + for _, fragment := range fragments { + if strings.HasPrefix(fragment, "dc=") { + pieces = append(pieces, fragment[3:]) + } } - return dc.d.Dial("tcp", net.JoinHostPort(host, port)) - case "ldaps": + + domain := strings.Join(pieces, ".") + _, records, err := net.LookupSRV("ldap", "tcp", domain) + if err != nil { + return nil, err + } + + sort.Slice(records, func(i, j int) bool { + return records[i].Priority > records[j].Priority + }) + + if u.Scheme == "ldaps" { + dc.tc = &tls.Config{ + ServerName: domain, + } + } + + for _, record := range records { + port := strconv.Itoa(int(record.Port)) + hostports = append(hostports, net.JoinHostPort(record.Target, port)) + } + } else { + host, port, err := net.SplitHostPort(u.Host) + if err != nil { + // we assume that error is due to missing port + host = u.Host + port = "" + } + if port == "" { - port = DefaultLdapsPort + if u.Scheme == "ldap" { + port = DefaultLdapPort + } else if u.Scheme == "ldaps" { + port = DefaultLdapsPort + } + } + + hostports = []string{net.JoinHostPort(host, port)} + } + + for _, pair := range hostports { + conn, err = dc.dialConn(u.Scheme, pair) + if conn != nil { + return conn, err } - return tls.DialWithDialer(dc.d, "tcp", net.JoinHostPort(host, port), dc.tc) } - return nil, fmt.Errorf("Unknown scheme '%s'", u.Scheme) + return +} + +func (dc *DialContext) dialConn(scheme, target string) (net.Conn, error) { + switch scheme { + case "ldap": + return dc.d.Dial("tcp", target) + case "ldaps": + return tls.DialWithDialer(dc.d, "tcp", target, dc.tc) + } + + return nil, fmt.Errorf("Unknown scheme '%s'", scheme) } // Dial connects to the given address on the given network using net.Dial @@ -223,7 +281,12 @@ func DialURL(addr string, opts ...DialOpt) (*Conn, error) { return nil, NewError(ErrorNetwork, err) } + rootDN := "" + if u.Host == "" { + rootDN = u.Path[1:] + } conn := NewConn(c, u.Scheme == "ldaps") + conn.rootDN = rootDN conn.Start() return conn, nil } diff --git a/del.go b/del.go index bac0dfb7..5f0ca903 100644 --- a/del.go +++ b/del.go @@ -12,6 +12,13 @@ type DelRequest struct { Controls []Control } +func (req *DelRequest) appendBaseDN(dn string) appendDnRequest { + r2 := new(DelRequest) + *r2 = *req + r2.DN = appendDN(req.DN, dn) + return r2 +} + func (req *DelRequest) appendTo(envelope *ber.Packet) error { pkt := ber.Encode(ber.ClassApplication, ber.TypePrimitive, ApplicationDelRequest, req.DN, "Del Request") pkt.Data.Write([]byte(req.DN)) diff --git a/dn.go b/dn.go index 916984b9..64e01db1 100644 --- a/dn.go +++ b/dn.go @@ -268,3 +268,19 @@ func (r *RelativeDN) hasAllAttributesFold(attrs []*AttributeTypeAndValue) bool { func (a *AttributeTypeAndValue) EqualFold(other *AttributeTypeAndValue) bool { return strings.EqualFold(a.Type, other.Type) && strings.EqualFold(a.Value, other.Value) } + +// appendDN is for concat the baseDN and rootDN +// dn stand for user input dn in request +// rootDN stand for dn used during discovery +func appendDN(dn, rootDN string) string { + if rootDN != "" { + var baseDnBuilder strings.Builder + if dn != "" { + baseDnBuilder.WriteString(dn) + baseDnBuilder.WriteByte(',') + } + baseDnBuilder.WriteString(rootDN) + return baseDnBuilder.String() + } + return dn +} diff --git a/dn_test.go b/dn_test.go index 6a82c72d..f47f39d6 100644 --- a/dn_test.go +++ b/dn_test.go @@ -264,3 +264,34 @@ func TestDNAncestor(t *testing.T) { } } } + +func TestAppendDN(t *testing.T) { + testcases := []struct { + baseDN string + rootDN string + expected string + }{ + { + baseDN: "ou=A", + rootDN: "dc=ldap,dc=internal", + expected: "ou=A,dc=ldap,dc=internal", + }, + { + baseDN: "ou=A,dc=ldap,dc=internal", + rootDN: "", + expected: "ou=A,dc=ldap,dc=internal", + }, + { + baseDN: "", + rootDN: "dc=ldap,dc=internal", + expected: "dc=ldap,dc=internal", + }, + } + + for i, tc := range testcases { + result := appendDN(tc.baseDN, tc.rootDN) + if result != tc.expected { + t.Errorf("#%d, expected %s, getting: %s", i, tc.expected, result) + } + } +} diff --git a/ldap_test.go b/ldap_test.go index b488bf0e..9aba0331 100644 --- a/ldap_test.go +++ b/ldap_test.go @@ -25,6 +25,14 @@ var attributes = []string{ "description", } +func TestDialURLViaDiscovery(t *testing.T) { + l, err := DialURL("ldap:///dc=umich,dc=edu") + if err != nil { + t.Fatal(err) + } + defer l.Close() +} + func TestUnsecureDialURL(t *testing.T) { l, err := DialURL(ldapServer) if err != nil { diff --git a/moddn.go b/moddn.go index ec246d1f..910fe4c4 100644 --- a/moddn.go +++ b/moddn.go @@ -40,7 +40,7 @@ func NewModifyDNRequest(dn string, rdn string, delOld bool, newSup string) *Modi // // Refer NewModifyDNRequest for other parameters func NewModifyDNWithControlsRequest(dn string, rdn string, delOld bool, - newSup string, controls []Control) *ModifyDNRequest { + newSup string, controls []Control) *ModifyDNRequest { return &ModifyDNRequest{ DN: dn, NewRDN: rdn, @@ -50,6 +50,13 @@ func NewModifyDNWithControlsRequest(dn string, rdn string, delOld bool, } } +func (req *ModifyDNRequest) appendBaseDN(dn string) appendDnRequest { + r2 := new(ModifyDNRequest) + *r2 = *req + r2.DN = appendDN(req.DN, dn) + return r2 +} + func (req *ModifyDNRequest) appendTo(envelope *ber.Packet) error { pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyDNRequest, nil, "Modify DN Request") pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN")) diff --git a/modify.go b/modify.go index 8b379558..841073f9 100644 --- a/modify.go +++ b/modify.go @@ -82,6 +82,13 @@ func (req *ModifyRequest) appendChange(operation uint, attrType string, attrVals req.Changes = append(req.Changes, Change{operation, PartialAttribute{Type: attrType, Vals: attrVals}}) } +func (req *ModifyRequest) appendBaseDN(dn string) appendDnRequest { + r2 := new(ModifyRequest) + *r2 = *req + r2.DN = appendDN(req.DN, dn) + return r2 +} + func (req *ModifyRequest) appendTo(envelope *ber.Packet) error { pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyRequest, nil, "Modify Request") pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN")) diff --git a/request.go b/request.go index adc3b1c2..cab296b7 100644 --- a/request.go +++ b/request.go @@ -17,6 +17,11 @@ type request interface { appendTo(*ber.Packet) error } +type appendDnRequest interface { + request + appendBaseDN(dn string) appendDnRequest +} + type requestFunc func(*ber.Packet) error func (f requestFunc) appendTo(p *ber.Packet) error { @@ -30,6 +35,11 @@ func (l *Conn) doRequest(req request) (*messageContext, error) { packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) + + if areq, ok := req.(appendDnRequest); ok { + req = areq.appendBaseDN(l.rootDN) + } + if err := req.appendTo(packet); err != nil { return nil, err } diff --git a/search.go b/search.go index 35fc2497..d8d161b4 100644 --- a/search.go +++ b/search.go @@ -231,6 +231,13 @@ type SearchRequest struct { Controls []Control } +func (req *SearchRequest) appendBaseDN(dn string) appendDnRequest { + r2 := new(SearchRequest) + *r2 = *req + r2.BaseDN = appendDN(req.BaseDN, dn) + return r2 +} + func (req *SearchRequest) appendTo(envelope *ber.Packet) error { pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchRequest, nil, "Search Request") pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.BaseDN, "Base DN")) diff --git a/v3/add.go b/v3/add.go index c3101b76..630e4742 100644 --- a/v3/add.go +++ b/v3/add.go @@ -33,6 +33,13 @@ type AddRequest struct { Controls []Control } +func (req *AddRequest) appendBaseDN(dn string) appendDnRequest { + r2 := new(AddRequest) + *r2 = *req + r2.DN = appendDN(req.DN, dn) + return r2 +} + func (req *AddRequest) appendTo(envelope *ber.Packet) error { pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationAddRequest, nil, "Add Request") pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN")) diff --git a/v3/conn.go b/v3/conn.go index 6ed7b5e6..0981c553 100644 --- a/v3/conn.go +++ b/v3/conn.go @@ -7,6 +7,9 @@ import ( "fmt" "net" "net/url" + "sort" + "strconv" + "strings" "sync" "sync/atomic" "time" @@ -101,6 +104,7 @@ type Conn struct { wgClose sync.WaitGroup outstandingRequests uint messageMutex sync.Mutex + rootDN string } var _ Client = &Conn{} @@ -144,7 +148,7 @@ type DialContext struct { tc *tls.Config } -func (dc *DialContext) dial(u *url.URL) (net.Conn, error) { +func (dc *DialContext) dial(u *url.URL) (conn net.Conn, err error) { if u.Scheme == "ldapi" { if u.Path == "" || u.Path == "/" { u.Path = "/var/run/slapd/ldapi" @@ -152,27 +156,79 @@ func (dc *DialContext) dial(u *url.URL) (net.Conn, error) { return dc.d.Dial("unix", u.Path) } - host, port, err := net.SplitHostPort(u.Host) - if err != nil { - // we assume that error is due to missing port - host = u.Host - port = "" + if u.Scheme != "ldap" && u.Scheme != "ldaps" { + return nil, fmt.Errorf("Unknown scheme '%s'", u.Scheme) } - switch u.Scheme { - case "ldap": - if port == "" { - port = DefaultLdapPort + hostports := make([]string, 0, 1) + + if u.Host == "" { + fragments := strings.Split(u.Path[1:], ",") + pieces := make([]string, 0, len(fragments)) + for _, fragment := range fragments { + if strings.HasPrefix(fragment, "dc=") { + pieces = append(pieces, fragment[3:]) + } } - return dc.d.Dial("tcp", net.JoinHostPort(host, port)) - case "ldaps": + + domain := strings.Join(pieces, ".") + _, records, err := net.LookupSRV("ldap", "tcp", domain) + if err != nil { + return nil, err + } + + sort.Slice(records, func(i, j int) bool { + return records[i].Priority > records[j].Priority + }) + + if u.Scheme == "ldaps" { + dc.tc = &tls.Config{ + ServerName: domain, + } + } + + for _, record := range records { + port := strconv.Itoa(int(record.Port)) + hostports = append(hostports, net.JoinHostPort(record.Target, port)) + } + } else { + host, port, err := net.SplitHostPort(u.Host) + if err != nil { + // we assume that error is due to missing port + host = u.Host + port = "" + } + if port == "" { - port = DefaultLdapsPort + if u.Scheme == "ldap" { + port = DefaultLdapPort + } else if u.Scheme == "ldaps" { + port = DefaultLdapsPort + } + } + + hostports = []string{net.JoinHostPort(host, port)} + } + + for _, pair := range hostports { + conn, err = dc.dialConn(u.Scheme, pair) + if conn != nil { + return conn, err } - return tls.DialWithDialer(dc.d, "tcp", net.JoinHostPort(host, port), dc.tc) } - return nil, fmt.Errorf("Unknown scheme '%s'", u.Scheme) + return +} + +func (dc *DialContext) dialConn(scheme, target string) (net.Conn, error) { + switch scheme { + case "ldap": + return dc.d.Dial("tcp", target) + case "ldaps": + return tls.DialWithDialer(dc.d, "tcp", target, dc.tc) + } + + return nil, fmt.Errorf("Unknown scheme '%s'", scheme) } // Dial connects to the given address on the given network using net.Dial @@ -223,7 +279,12 @@ func DialURL(addr string, opts ...DialOpt) (*Conn, error) { return nil, NewError(ErrorNetwork, err) } + rootDN := "" + if u.Host == "" { + rootDN = u.Path[1:] + } conn := NewConn(c, u.Scheme == "ldaps") + conn.rootDN = rootDN conn.Start() return conn, nil } diff --git a/v3/del.go b/v3/del.go index bac0dfb7..5f0ca903 100644 --- a/v3/del.go +++ b/v3/del.go @@ -12,6 +12,13 @@ type DelRequest struct { Controls []Control } +func (req *DelRequest) appendBaseDN(dn string) appendDnRequest { + r2 := new(DelRequest) + *r2 = *req + r2.DN = appendDN(req.DN, dn) + return r2 +} + func (req *DelRequest) appendTo(envelope *ber.Packet) error { pkt := ber.Encode(ber.ClassApplication, ber.TypePrimitive, ApplicationDelRequest, req.DN, "Del Request") pkt.Data.Write([]byte(req.DN)) diff --git a/v3/dn.go b/v3/dn.go index d802580e..066b19c9 100644 --- a/v3/dn.go +++ b/v3/dn.go @@ -268,3 +268,19 @@ func (r *RelativeDN) hasAllAttributesFold(attrs []*AttributeTypeAndValue) bool { func (a *AttributeTypeAndValue) EqualFold(other *AttributeTypeAndValue) bool { return strings.EqualFold(a.Type, other.Type) && strings.EqualFold(a.Value, other.Value) } + +// appendDN is for concat the baseDN and rootDN +// baseDN stand for user input dn in request +// rootDN stand for dn used during discovery +func appendDN(baseDN, rootDN string) string { + if rootDN != "" { + var baseDnBuilder strings.Builder + if baseDN != "" { + baseDnBuilder.WriteString(baseDN) + baseDnBuilder.WriteByte(',') + } + baseDnBuilder.WriteString(rootDN) + return baseDnBuilder.String() + } + return baseDN +} diff --git a/v3/dn_test.go b/v3/dn_test.go index 6a82c72d..f47f39d6 100644 --- a/v3/dn_test.go +++ b/v3/dn_test.go @@ -264,3 +264,34 @@ func TestDNAncestor(t *testing.T) { } } } + +func TestAppendDN(t *testing.T) { + testcases := []struct { + baseDN string + rootDN string + expected string + }{ + { + baseDN: "ou=A", + rootDN: "dc=ldap,dc=internal", + expected: "ou=A,dc=ldap,dc=internal", + }, + { + baseDN: "ou=A,dc=ldap,dc=internal", + rootDN: "", + expected: "ou=A,dc=ldap,dc=internal", + }, + { + baseDN: "", + rootDN: "dc=ldap,dc=internal", + expected: "dc=ldap,dc=internal", + }, + } + + for i, tc := range testcases { + result := appendDN(tc.baseDN, tc.rootDN) + if result != tc.expected { + t.Errorf("#%d, expected %s, getting: %s", i, tc.expected, result) + } + } +} diff --git a/v3/ldap_test.go b/v3/ldap_test.go index b488bf0e..9aba0331 100644 --- a/v3/ldap_test.go +++ b/v3/ldap_test.go @@ -25,6 +25,14 @@ var attributes = []string{ "description", } +func TestDialURLViaDiscovery(t *testing.T) { + l, err := DialURL("ldap:///dc=umich,dc=edu") + if err != nil { + t.Fatal(err) + } + defer l.Close() +} + func TestUnsecureDialURL(t *testing.T) { l, err := DialURL(ldapServer) if err != nil { diff --git a/v3/moddn.go b/v3/moddn.go index ec246d1f..910fe4c4 100644 --- a/v3/moddn.go +++ b/v3/moddn.go @@ -40,7 +40,7 @@ func NewModifyDNRequest(dn string, rdn string, delOld bool, newSup string) *Modi // // Refer NewModifyDNRequest for other parameters func NewModifyDNWithControlsRequest(dn string, rdn string, delOld bool, - newSup string, controls []Control) *ModifyDNRequest { + newSup string, controls []Control) *ModifyDNRequest { return &ModifyDNRequest{ DN: dn, NewRDN: rdn, @@ -50,6 +50,13 @@ func NewModifyDNWithControlsRequest(dn string, rdn string, delOld bool, } } +func (req *ModifyDNRequest) appendBaseDN(dn string) appendDnRequest { + r2 := new(ModifyDNRequest) + *r2 = *req + r2.DN = appendDN(req.DN, dn) + return r2 +} + func (req *ModifyDNRequest) appendTo(envelope *ber.Packet) error { pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyDNRequest, nil, "Modify DN Request") pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN")) diff --git a/v3/modify.go b/v3/modify.go index 8b379558..841073f9 100644 --- a/v3/modify.go +++ b/v3/modify.go @@ -82,6 +82,13 @@ func (req *ModifyRequest) appendChange(operation uint, attrType string, attrVals req.Changes = append(req.Changes, Change{operation, PartialAttribute{Type: attrType, Vals: attrVals}}) } +func (req *ModifyRequest) appendBaseDN(dn string) appendDnRequest { + r2 := new(ModifyRequest) + *r2 = *req + r2.DN = appendDN(req.DN, dn) + return r2 +} + func (req *ModifyRequest) appendTo(envelope *ber.Packet) error { pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyRequest, nil, "Modify Request") pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN")) diff --git a/v3/request.go b/v3/request.go index adc3b1c2..c58bfeec 100644 --- a/v3/request.go +++ b/v3/request.go @@ -17,6 +17,11 @@ type request interface { appendTo(*ber.Packet) error } +type appendDnRequest interface { + request + appendBaseDN(dn string) appendDnRequest +} + type requestFunc func(*ber.Packet) error func (f requestFunc) appendTo(p *ber.Packet) error { @@ -30,6 +35,10 @@ func (l *Conn) doRequest(req request) (*messageContext, error) { packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) + + if areq, ok := req.(appendDnRequest); ok { + req = areq.appendBaseDN(l.rootDN) + } if err := req.appendTo(packet); err != nil { return nil, err } diff --git a/v3/search.go b/v3/search.go index 35fc2497..d8d161b4 100644 --- a/v3/search.go +++ b/v3/search.go @@ -231,6 +231,13 @@ type SearchRequest struct { Controls []Control } +func (req *SearchRequest) appendBaseDN(dn string) appendDnRequest { + r2 := new(SearchRequest) + *r2 = *req + r2.BaseDN = appendDN(req.BaseDN, dn) + return r2 +} + func (req *SearchRequest) appendTo(envelope *ber.Packet) error { pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchRequest, nil, "Search Request") pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.BaseDN, "Base DN"))