diff --git a/conn.go b/conn.go index 6ed7b5e6..bab2498e 100644 --- a/conn.go +++ b/conn.go @@ -7,6 +7,9 @@ import ( "fmt" "net" "net/url" + "sort" + "strconv" + "strings" "sync" "sync/atomic" "time" @@ -144,7 +147,7 @@ type DialContext struct { tc *tls.Config } -func (dc *DialContext) dial(u *url.URL) (net.Conn, error) { +func (dc *DialContext) dial(u *url.URL) (conn net.Conn, err error) { if u.Scheme == "ldapi" { if u.Path == "" || u.Path == "/" { u.Path = "/var/run/slapd/ldapi" @@ -152,27 +155,79 @@ func (dc *DialContext) dial(u *url.URL) (net.Conn, error) { return dc.d.Dial("unix", u.Path) } - host, port, err := net.SplitHostPort(u.Host) - if err != nil { - // we assume that error is due to missing port - host = u.Host - port = "" + if u.Scheme != "ldap" && u.Scheme != "ldaps" { + return nil, fmt.Errorf("Unknown scheme '%s'", u.Scheme) } - switch u.Scheme { - case "ldap": - if port == "" { - port = DefaultLdapPort + hostports := make([]string, 0, 1) + + if u.Host == "" { + fragments := strings.Split(u.Path[1:], ",") + pieces := make([]string, 0, len(fragments)) + for _, fragment := range fragments { + if strings.HasPrefix(fragment, "dc=") { + pieces = append(pieces, fragment[3:]) + } } - return dc.d.Dial("tcp", net.JoinHostPort(host, port)) - case "ldaps": + + domain := strings.Join(pieces, ".") + _, records, err := net.LookupSRV("ldap", "tcp", domain) + if err != nil { + return nil, err + } + + sort.Slice(records, func(i, j int) bool { + return records[i].Priority > records[j].Priority + }) + + if u.Scheme == "ldaps" { + dc.tc = &tls.Config{ + ServerName: domain, + } + } + + for _, record := range records { + port := strconv.Itoa(int(record.Port)) + hostports = append(hostports, net.JoinHostPort(record.Target, port)) + } + } else { + host, port, err := net.SplitHostPort(u.Host) + if err != nil { + // we assume that error is due to missing port + host = u.Host + port = "" + } + if port == "" { - port = DefaultLdapsPort + if u.Scheme == "ldap" { + port = DefaultLdapPort + } else if u.Scheme == "ldaps" { + port = DefaultLdapsPort + } } - return tls.DialWithDialer(dc.d, "tcp", net.JoinHostPort(host, port), dc.tc) + + hostports = []string{net.JoinHostPort(host, port)} + } + + for _, pair := range hostports { + conn, err = dc.dialConn(u.Scheme, pair) + if conn != nil { + return conn, err + } + } + + return +} + +func (dc *DialContext) dialConn(scheme, target string) (net.Conn, error) { + switch scheme { + case "ldap": + return dc.d.Dial("tcp", target) + case "ldaps": + return tls.DialWithDialer(dc.d, "tcp", target, dc.tc) } - return nil, fmt.Errorf("Unknown scheme '%s'", u.Scheme) + return nil, fmt.Errorf("Unknown scheme '%s'", scheme) } // Dial connects to the given address on the given network using net.Dial diff --git a/v3/conn.go b/v3/conn.go index 6ed7b5e6..bab2498e 100644 --- a/v3/conn.go +++ b/v3/conn.go @@ -7,6 +7,9 @@ import ( "fmt" "net" "net/url" + "sort" + "strconv" + "strings" "sync" "sync/atomic" "time" @@ -144,7 +147,7 @@ type DialContext struct { tc *tls.Config } -func (dc *DialContext) dial(u *url.URL) (net.Conn, error) { +func (dc *DialContext) dial(u *url.URL) (conn net.Conn, err error) { if u.Scheme == "ldapi" { if u.Path == "" || u.Path == "/" { u.Path = "/var/run/slapd/ldapi" @@ -152,27 +155,79 @@ func (dc *DialContext) dial(u *url.URL) (net.Conn, error) { return dc.d.Dial("unix", u.Path) } - host, port, err := net.SplitHostPort(u.Host) - if err != nil { - // we assume that error is due to missing port - host = u.Host - port = "" + if u.Scheme != "ldap" && u.Scheme != "ldaps" { + return nil, fmt.Errorf("Unknown scheme '%s'", u.Scheme) } - switch u.Scheme { - case "ldap": - if port == "" { - port = DefaultLdapPort + hostports := make([]string, 0, 1) + + if u.Host == "" { + fragments := strings.Split(u.Path[1:], ",") + pieces := make([]string, 0, len(fragments)) + for _, fragment := range fragments { + if strings.HasPrefix(fragment, "dc=") { + pieces = append(pieces, fragment[3:]) + } } - return dc.d.Dial("tcp", net.JoinHostPort(host, port)) - case "ldaps": + + domain := strings.Join(pieces, ".") + _, records, err := net.LookupSRV("ldap", "tcp", domain) + if err != nil { + return nil, err + } + + sort.Slice(records, func(i, j int) bool { + return records[i].Priority > records[j].Priority + }) + + if u.Scheme == "ldaps" { + dc.tc = &tls.Config{ + ServerName: domain, + } + } + + for _, record := range records { + port := strconv.Itoa(int(record.Port)) + hostports = append(hostports, net.JoinHostPort(record.Target, port)) + } + } else { + host, port, err := net.SplitHostPort(u.Host) + if err != nil { + // we assume that error is due to missing port + host = u.Host + port = "" + } + if port == "" { - port = DefaultLdapsPort + if u.Scheme == "ldap" { + port = DefaultLdapPort + } else if u.Scheme == "ldaps" { + port = DefaultLdapsPort + } } - return tls.DialWithDialer(dc.d, "tcp", net.JoinHostPort(host, port), dc.tc) + + hostports = []string{net.JoinHostPort(host, port)} + } + + for _, pair := range hostports { + conn, err = dc.dialConn(u.Scheme, pair) + if conn != nil { + return conn, err + } + } + + return +} + +func (dc *DialContext) dialConn(scheme, target string) (net.Conn, error) { + switch scheme { + case "ldap": + return dc.d.Dial("tcp", target) + case "ldaps": + return tls.DialWithDialer(dc.d, "tcp", target, dc.tc) } - return nil, fmt.Errorf("Unknown scheme '%s'", u.Scheme) + return nil, fmt.Errorf("Unknown scheme '%s'", scheme) } // Dial connects to the given address on the given network using net.Dial