Skip to content

Commit

Permalink
feat(discovery): implement LDAP service discovery (RFC 2782)
Browse files Browse the repository at this point in the history
  • Loading branch information
tonyhhyip committed May 31, 2022
1 parent 7d3b8d4 commit eb13da1
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 30 deletions.
85 changes: 70 additions & 15 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ import (
"fmt"
"net"
"net/url"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -144,35 +147,87 @@ type DialContext struct {
tc *tls.Config
}

func (dc *DialContext) dial(u *url.URL) (net.Conn, error) {
func (dc *DialContext) dial(u *url.URL) (conn net.Conn, err error) {
if u.Scheme == "ldapi" {
if u.Path == "" || u.Path == "/" {
u.Path = "/var/run/slapd/ldapi"
}
return dc.d.Dial("unix", u.Path)
}

host, port, err := net.SplitHostPort(u.Host)
if err != nil {
// we assume that error is due to missing port
host = u.Host
port = ""
if u.Scheme != "ldap" && u.Scheme != "ldaps" {
return nil, fmt.Errorf("Unknown scheme '%s'", u.Scheme)
}

switch u.Scheme {
case "ldap":
if port == "" {
port = DefaultLdapPort
hostports := make([]string, 0, 1)

if u.Host == "" {
fragments := strings.Split(u.Path[1:], ",")
pieces := make([]string, 0, len(fragments))
for _, fragment := range fragments {
if strings.HasPrefix(fragment, "dc=") {
pieces = append(pieces, fragment[3:])
}
}
return dc.d.Dial("tcp", net.JoinHostPort(host, port))
case "ldaps":

domain := strings.Join(pieces, ".")
_, records, err := net.LookupSRV("ldap", "tcp", domain)
if err != nil {
return nil, err
}

sort.Slice(records, func(i, j int) bool {
return records[i].Priority > records[j].Priority
})

if u.Scheme == "ldaps" {
dc.tc = &tls.Config{
ServerName: domain,
}
}

for _, record := range records {
port := strconv.Itoa(int(record.Port))
hostports = append(hostports, net.JoinHostPort(record.Target, port))
}
} else {
host, port, err := net.SplitHostPort(u.Host)
if err != nil {
// we assume that error is due to missing port
host = u.Host
port = ""
}

if port == "" {
port = DefaultLdapsPort
if u.Scheme == "ldap" {
port = DefaultLdapPort
} else if u.Scheme == "ldaps" {
port = DefaultLdapsPort
}
}
return tls.DialWithDialer(dc.d, "tcp", net.JoinHostPort(host, port), dc.tc)

hostports = []string{net.JoinHostPort(host, port)}
}

for _, pair := range hostports {
conn, err = dc.dialConn(u.Scheme, pair)
if conn != nil {
return conn, err
}
}

return
}

func (dc *DialContext) dialConn(scheme, target string) (net.Conn, error) {
switch scheme {
case "ldap":
return dc.d.Dial("tcp", target)
case "ldaps":
return tls.DialWithDialer(dc.d, "tcp", target, dc.tc)
}

return nil, fmt.Errorf("Unknown scheme '%s'", u.Scheme)
return nil, fmt.Errorf("Unknown scheme '%s'", scheme)
}

// Dial connects to the given address on the given network using net.Dial
Expand Down
85 changes: 70 additions & 15 deletions v3/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ import (
"fmt"
"net"
"net/url"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -144,35 +147,87 @@ type DialContext struct {
tc *tls.Config
}

func (dc *DialContext) dial(u *url.URL) (net.Conn, error) {
func (dc *DialContext) dial(u *url.URL) (conn net.Conn, err error) {
if u.Scheme == "ldapi" {
if u.Path == "" || u.Path == "/" {
u.Path = "/var/run/slapd/ldapi"
}
return dc.d.Dial("unix", u.Path)
}

host, port, err := net.SplitHostPort(u.Host)
if err != nil {
// we assume that error is due to missing port
host = u.Host
port = ""
if u.Scheme != "ldap" && u.Scheme != "ldaps" {
return nil, fmt.Errorf("Unknown scheme '%s'", u.Scheme)
}

switch u.Scheme {
case "ldap":
if port == "" {
port = DefaultLdapPort
hostports := make([]string, 0, 1)

if u.Host == "" {
fragments := strings.Split(u.Path[1:], ",")
pieces := make([]string, 0, len(fragments))
for _, fragment := range fragments {
if strings.HasPrefix(fragment, "dc=") {
pieces = append(pieces, fragment[3:])
}
}
return dc.d.Dial("tcp", net.JoinHostPort(host, port))
case "ldaps":

domain := strings.Join(pieces, ".")
_, records, err := net.LookupSRV("ldap", "tcp", domain)
if err != nil {
return nil, err
}

sort.Slice(records, func(i, j int) bool {
return records[i].Priority > records[j].Priority
})

if u.Scheme == "ldaps" {
dc.tc = &tls.Config{
ServerName: domain,
}
}

for _, record := range records {
port := strconv.Itoa(int(record.Port))
hostports = append(hostports, net.JoinHostPort(record.Target, port))
}
} else {
host, port, err := net.SplitHostPort(u.Host)
if err != nil {
// we assume that error is due to missing port
host = u.Host
port = ""
}

if port == "" {
port = DefaultLdapsPort
if u.Scheme == "ldap" {
port = DefaultLdapPort
} else if u.Scheme == "ldaps" {
port = DefaultLdapsPort
}
}
return tls.DialWithDialer(dc.d, "tcp", net.JoinHostPort(host, port), dc.tc)

hostports = []string{net.JoinHostPort(host, port)}
}

for _, pair := range hostports {
conn, err = dc.dialConn(u.Scheme, pair)
if conn != nil {
return conn, err
}
}

return
}

func (dc *DialContext) dialConn(scheme, target string) (net.Conn, error) {
switch scheme {
case "ldap":
return dc.d.Dial("tcp", target)
case "ldaps":
return tls.DialWithDialer(dc.d, "tcp", target, dc.tc)
}

return nil, fmt.Errorf("Unknown scheme '%s'", u.Scheme)
return nil, fmt.Errorf("Unknown scheme '%s'", scheme)
}

// Dial connects to the given address on the given network using net.Dial
Expand Down

0 comments on commit eb13da1

Please sign in to comment.