Skip to content

Commit

Permalink
add SearchAsync
Browse files Browse the repository at this point in the history
Signed-off-by: Adphi <[email protected]>
  • Loading branch information
Adphi committed Jan 28, 2023
1 parent f61ea45 commit 337a51e
Show file tree
Hide file tree
Showing 7 changed files with 401 additions and 8 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.idea
1 change: 1 addition & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@ type Client interface {
PasswordModify(*PasswordModifyRequest) (*PasswordModifyResult, error)

Search(*SearchRequest) (*SearchResult, error)
SearchAsync(searchRequest *SearchRequest, done chan struct{}) (<-chan *SearchAsyncResponse, error)
SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error)
}
84 changes: 84 additions & 0 deletions ldap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,90 @@ func TestSearch(t *testing.T) {
t.Logf("TestSearch: %s -> num of entries = %d", searchRequest.Filter, len(sr.Entries))
}

func TestSearchAsync(t *testing.T) {
l, err := DialURL(ldapServer)
if err != nil {
t.Fatal(err)
}
defer l.Close()

searchRequest := NewSearchRequest(
baseDN,
ScopeWholeSubtree, DerefAlways, 0, 0, false,
filter[0],
attributes,
nil)

var entries []*Entry
responses, err := l.SearchAsync(searchRequest, nil)
if err != nil {
t.Fatal(err)
}
for res := range responses {
if err := res.Err(); err != nil {
t.Error(err)
break
}
if res.Closed() {
break
}
switch res.Type {
case SearchAsyncResponseTypeEntry:
entries = append(entries, res.Entry)
case SearchAsyncResponseTypeReferral:
t.Logf("Received Referral: %s", res.Referral)
case SearchAsyncResponseTypeControl:
t.Logf("Received Control: %s", res.Control)
}
}
t.Logf("TestSearch: %s -> num of entries = %d", searchRequest.Filter, len(entries))
}

func TestSearchAsyncStop(t *testing.T) {
l, err := DialURL(ldapServer)
if err != nil {
t.Fatal(err)
}
defer l.Close()

searchRequest := NewSearchRequest(
baseDN,
ScopeWholeSubtree, DerefAlways, 0, 0, false,
filter[0],
attributes,
nil)

var entries []*Entry
done := make(chan struct{})
responses, err := l.SearchAsync(searchRequest, done)
if err != nil {
t.Fatal(err)
}
for res := range responses {
if err := res.Err(); err != nil {
t.Error(err)
break
}

if res.Closed() {
break
}
close(done)
switch res.Type {
case SearchAsyncResponseTypeEntry:
entries = append(entries, res.Entry)
case SearchAsyncResponseTypeReferral:
t.Logf("Received Referral: %s", res.Referral)
case SearchAsyncResponseTypeControl:
t.Logf("Received Control: %s", res.Control)
}
}
if len(entries) > 1 {
t.Errorf("Expected 1 entry, got %d", len(entries))
}
t.Logf("TestSearch: %s -> num of entries = %d", searchRequest.Filter, len(entries))
}

func TestSearchStartTLS(t *testing.T) {
l, err := DialURL(ldapServer)
if err != nil {
Expand Down
119 changes: 115 additions & 4 deletions search.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,42 @@ func (s *SearchResult) PrettyPrint(indent int) {
}
}

// SearchAsyncResponseType describes the SearchAsyncResponse content type
type SearchAsyncResponseType uint8

const (
SearchAsyncResponseTypeNone SearchAsyncResponseType = iota
SearchAsyncResponseTypeEntry
SearchAsyncResponseTypeReferral
SearchAsyncResponseTypeControl
)

// SearchAsyncResponse holds the server's response message to an async search request
type SearchAsyncResponse struct {
// Type indicates the SearchAsyncResponse type
Type SearchAsyncResponseType
// Entry is the received entry, only set if Type is SearchAsyncResponseTypeEntry
Entry *Entry
// Referral is the received referral, only set if Type is SearchAsyncResponseTypeReferral
Referral string
// Control is the received control, only set if Type is SearchAsyncResponseTypeControl
Control Control
// closed indicates that the request is finished
closed bool
// err holds the encountered error while processing server's response, if any
err error
}

// Closed returns true if the request is finished
func (r *SearchAsyncResponse) Closed() bool {
return r.closed
}

// Err returns the encountered error while processing server's response, if any
func (r *SearchAsyncResponse) Err() error {
return r.err
}

// SearchRequest represents a search request to send to the server
type SearchRequest struct {
BaseDN string
Expand Down Expand Up @@ -285,10 +321,11 @@ func NewSearchRequest(
// SearchWithPaging accepts a search request and desired page size in order to execute LDAP queries to fulfill the
// search request. All paged LDAP query responses will be buffered and the final result will be returned atomically.
// The following four cases are possible given the arguments:
// - given SearchRequest missing a control of type ControlTypePaging: we will add one with the desired paging size
// - given SearchRequest contains a control of type ControlTypePaging that isn't actually a ControlPaging: fail without issuing any queries
// - given SearchRequest contains a control of type ControlTypePaging with pagingSize equal to the size requested: no change to the search request
// - given SearchRequest contains a control of type ControlTypePaging with pagingSize not equal to the size requested: fail without issuing any queries
// - given SearchRequest missing a control of type ControlTypePaging: we will add one with the desired paging size
// - given SearchRequest contains a control of type ControlTypePaging that isn't actually a ControlPaging: fail without issuing any queries
// - given SearchRequest contains a control of type ControlTypePaging with pagingSize equal to the size requested: no change to the search request
// - given SearchRequest contains a control of type ControlTypePaging with pagingSize not equal to the size requested: fail without issuing any queries
//
// A requested pagingSize of 0 is interpreted as no limit by LDAP servers.
func (l *Conn) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error) {
var pagingControl *ControlPaging
Expand Down Expand Up @@ -402,6 +439,80 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) {
}
}

// SearchAsync performs the given search request asynchronously, it takes an optional done channel to stop the request. It returns a SearchAsyncResponse channel which will be
// closed when the request finished and an error, not nil if the request to the server failed
func (l *Conn) SearchAsync(searchRequest *SearchRequest, done chan struct{}) (<-chan *SearchAsyncResponse, error) {
if done == nil {
done = make(chan struct{})
}
msgCtx, err := l.doRequest(searchRequest)
if err != nil {
return nil, err
}
responses := make(chan *SearchAsyncResponse)
ch := make(chan *SearchAsyncResponse)
rcv := func() {
for {
packet, err := l.readPacket(msgCtx)
if err != nil {
ch <- &SearchAsyncResponse{closed: true, err: err}
return
}

switch packet.Children[1].Tag {
case 4:
entry := &Entry{
DN: packet.Children[1].Children[0].Value.(string),
Attributes: unpackAttributes(packet.Children[1].Children[1].Children),
}
ch <- &SearchAsyncResponse{Type: SearchAsyncResponseTypeEntry, Entry: entry}
case 5:
err := GetLDAPError(packet)
if err != nil {
ch <- &SearchAsyncResponse{closed: true, err: err}
return
}
var response SearchAsyncResponse
if len(packet.Children) == 3 {
for _, child := range packet.Children[2].Children {
decodedChild, err := DecodeControl(child)
if err != nil {
responses <- &SearchAsyncResponse{closed: true, err: fmt.Errorf("failed to decode child control: %s", err)}
return
}
response = SearchAsyncResponse{Type: SearchAsyncResponseTypeControl, Control: decodedChild}
}
}
response.closed = true
ch <- &response
return
case 19:
ch <- &SearchAsyncResponse{Type: SearchAsyncResponseTypeReferral, Referral: packet.Children[1].Children[0].Value.(string)}
}
}
}
go func() {
defer l.finishMessage(msgCtx)
defer close(responses)
go rcv()
for {
select {
case <-done:
responses <- &SearchAsyncResponse{
closed: true,
}
return
case res := <-ch:
responses <- res
if res.Closed() {
return
}
}
}
}()
return responses, nil
}

// unpackAttributes will extract all given LDAP attributes and it's values
// from the ber.Packet
func unpackAttributes(children []*ber.Packet) []*EntryAttribute {
Expand Down
1 change: 1 addition & 0 deletions v3/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@ type Client interface {
PasswordModify(*PasswordModifyRequest) (*PasswordModifyResult, error)

Search(*SearchRequest) (*SearchResult, error)
SearchAsync(searchRequest *SearchRequest, done chan struct{}) (<-chan *SearchAsyncResponse, error)
SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error)
}
84 changes: 84 additions & 0 deletions v3/ldap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,90 @@ func TestSearch(t *testing.T) {
t.Logf("TestSearch: %s -> num of entries = %d", searchRequest.Filter, len(sr.Entries))
}

func TestSearchAsync(t *testing.T) {
l, err := DialURL(ldapServer)
if err != nil {
t.Fatal(err)
}
defer l.Close()

searchRequest := NewSearchRequest(
baseDN,
ScopeWholeSubtree, DerefAlways, 0, 0, false,
filter[0],
attributes,
nil)

var entries []*Entry
responses, err := l.SearchAsync(searchRequest, nil)
if err != nil {
t.Fatal(err)
}
for res := range responses {
if err := res.Err(); err != nil {
t.Error(err)
break
}
if res.Closed() {
break
}
switch res.Type {
case SearchAsyncResponseTypeEntry:
entries = append(entries, res.Entry)
case SearchAsyncResponseTypeReferral:
t.Logf("Received Referral: %s", res.Referral)
case SearchAsyncResponseTypeControl:
t.Logf("Received Control: %s", res.Control)
}
}
t.Logf("TestSearch: %s -> num of entries = %d", searchRequest.Filter, len(entries))
}

func TestSearchAsyncStop(t *testing.T) {
l, err := DialURL(ldapServer)
if err != nil {
t.Fatal(err)
}
defer l.Close()

searchRequest := NewSearchRequest(
baseDN,
ScopeWholeSubtree, DerefAlways, 0, 0, false,
filter[0],
attributes,
nil)

var entries []*Entry
done := make(chan struct{})
responses, err := l.SearchAsync(searchRequest, done)
if err != nil {
t.Fatal(err)
}
for res := range responses {
if err := res.Err(); err != nil {
t.Error(err)
break
}

if res.Closed() {
break
}
close(done)
switch res.Type {
case SearchAsyncResponseTypeEntry:
entries = append(entries, res.Entry)
case SearchAsyncResponseTypeReferral:
t.Logf("Received Referral: %s", res.Referral)
case SearchAsyncResponseTypeControl:
t.Logf("Received Control: %s", res.Control)
}
}
if len(entries) > 1 {
t.Errorf("Expected 1 entry, got %d", len(entries))
}
t.Logf("TestSearch: %s -> num of entries = %d", searchRequest.Filter, len(entries))
}

func TestSearchStartTLS(t *testing.T) {
l, err := DialURL(ldapServer)
if err != nil {
Expand Down
Loading

0 comments on commit 337a51e

Please sign in to comment.