Skip to content

Commit

Permalink
add method to get paginated query results (#171)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmenglund authored Nov 3, 2023
1 parent 5478c41 commit 1f17ba8
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 4 deletions.
33 changes: 33 additions & 0 deletions paginate/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,36 @@ func (p Paginator) Query(ctx context.Context, documentCh chan<- map[string]any,

return nil
}

// GetQueryResults gets all query results from the previously executed query with queryID,
// and sends the result to the document channel, which is closed once everything has been retrieved.
func (p Paginator) GetQueryResults(ctx context.Context, documentCh chan<- map[string]any, queryID string) error {
defer close(documentCh)

var cursor string
for {
var options []option.QueryResultOption
if cursor != "" {
options = append(options, option.WithQueryResultCursor(cursor))
}

res, err := p.rc.GetQueryResults(ctx, queryID, options...)
if err != nil {
return err
}

// TODO if the query hasn't finished running, this could optionally go into wait loop

// send documents from the current batch
for _, doc := range res.Results {
documentCh <- doc
}
cursor = res.Pagination.GetNextCursor()

if cursor == "" {
break
}
}

return nil
}
50 changes: 46 additions & 4 deletions paginate/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func TestQueryPaginated_integration(t *testing.T) {
wg.Add(1)
var count int
go func() {
for _ = range docs {
for range docs {
count++
}
wg.Done()
Expand Down Expand Up @@ -66,7 +66,7 @@ func TestQueryPaginated(t *testing.T) {
wg.Add(1)
var count int
go func() {
for _ = range docs {
for range docs {
count++
}
wg.Done()
Expand Down Expand Up @@ -95,7 +95,7 @@ func TestQueryPaginated_queryError(t *testing.T) {
wg.Add(1)
var count int
go func() {
for _ = range docs {
for range docs {
count++
}
wg.Done()
Expand Down Expand Up @@ -132,7 +132,7 @@ func TestQueryPaginated_paginationError(t *testing.T) {
wg.Add(1)
var count int
go func() {
for _ = range docs {
for range docs {
count++
}
wg.Done()
Expand All @@ -146,3 +146,45 @@ func TestQueryPaginated_paginationError(t *testing.T) {
assert.Equal(t, 1, rc.QueryCallCount())
assert.Equal(t, 1, rc.GetQueryResultsCallCount())
}

func TestPaginator_GetQueryResults(t *testing.T) {
ctx := test.Context()

rc := &fake.FakeRockClient{}
rc.GetQueryResultsReturnsOnCall(0, openapi.QueryPaginationResponse{
Pagination: &openapi.PaginationInfo{
NextCursor: openapi.PtrString("foo"),
},
Results: []map[string]interface{}{
{"0": "0", "1": "1"},
},
}, nil)
rc.GetQueryResultsReturnsOnCall(1, openapi.QueryPaginationResponse{
Pagination: &openapi.PaginationInfo{
NextCursor: openapi.PtrString(""),
},
Results: []map[string]interface{}{
{"2": "2", "3": "3"},
},
}, nil)
docs := make(chan map[string]any, 100)

wg := sync.WaitGroup{}
wg.Add(1)
var count int
go func() {
for range docs {
count++
}
wg.Done()
}()

p := paginate.New(rc)

err := p.GetQueryResults(ctx, docs, "id")
assert.NoError(t, err)

wg.Wait()
assert.Equal(t, 2, count)
assert.Equal(t, 2, rc.GetQueryResultsCallCount())
}

0 comments on commit 1f17ba8

Please sign in to comment.