Skip to content

Commit

Permalink
docstore/memdocstore: #3508 allow nested slices query
Browse files Browse the repository at this point in the history
  • Loading branch information
eqinox76 committed Dec 10, 2024
1 parent 5695484 commit 02ce845
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 22 deletions.
2 changes: 1 addition & 1 deletion docstore/memdocstore/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func decodeDoc(m storedDoc, ddoc driver.Document, fps [][]string) error {
// (We don't need the key field because ddoc must already have it.)
m2 = map[string]interface{}{}
for _, fp := range fps {
val, err := getAtFieldPath(m, fp)
val, err := getAtFieldPath(m, fp, false)
if err != nil {
if gcerrors.Code(err) == gcerrors.NotFound {
continue
Expand Down
48 changes: 32 additions & 16 deletions docstore/memdocstore/mem.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ type Options struct {
// When the collection is closed, its contents are saved to the file.
Filename string

// AllowNestedSlicesQuery allows querying with nested slices.
// This makes the memdocstore more compatible with MongoDB,
// but other providers may not support this feature.
// See https://github.com/google/go-cloud/pull/3511 for more details.
AllowNestedSlicesQuery bool

// Call this function when the collection is closed.
// For internal use only.
onClose func()
Expand Down Expand Up @@ -399,16 +405,34 @@ func (c *collection) checkRevision(arg driver.Document, current storedDoc) error

// getAtFieldPath gets the value of m at fp. It returns an error if fp is invalid
// (see getParentMap).
func getAtFieldPath(m map[string]interface{}, fp []string) (interface{}, error) {
m2, err := getParentMap(m, fp, false)
if err != nil {
return nil, err
func getAtFieldPath(m map[string]interface{}, fp []string, nested bool) (result interface{}, err error) {

var get func(m interface{}, name string) interface{}
get = func(m interface{}, name string) interface{} {
switch concrete := m.(type) {
case map[string]interface{}:
return concrete[name]
case []interface{}:
if !nested {
return nil
}
result := []interface{}{}
for _, e := range concrete {
result = append(result, get(e, name))
}
return result
}
return nil
}
v, ok := m2[fp[len(fp)-1]]
if ok {
return v, nil
result = m
for _, k := range fp {
next := get(result, k)
if next == nil {
return nil, gcerr.Newf(gcerr.NotFound, nil, "field %s not found", strings.Join(fp, "."))
}
result = next
}
return nil, gcerr.Newf(gcerr.NotFound, nil, "field %s not found", fp)
return result, nil
}

// setAtFieldPath sets m's value at fp to val. It creates intermediate maps as
Expand All @@ -422,14 +446,6 @@ func setAtFieldPath(m map[string]interface{}, fp []string, val interface{}) erro
return nil
}

// Delete the value from m at the given field path, if it exists.
func deleteAtFieldPath(m map[string]interface{}, fp []string) {
m2, _ := getParentMap(m, fp, false) // ignore error
if m2 != nil {
delete(m2, fp[len(fp)-1])
}
}

// getParentMap returns the map that directly contains the given field path;
// that is, the value of m at the field path that excludes the last component
// of fp. If a non-map is encountered along the way, an InvalidArgument error is
Expand Down
44 changes: 44 additions & 0 deletions docstore/memdocstore/mem_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package memdocstore

import (
"context"
"io"
"os"
"path/filepath"
"testing"
Expand Down Expand Up @@ -129,6 +130,49 @@ func TestUpdateAtomic(t *testing.T) {
}
}

func TestQueryNested(t *testing.T) {
ctx := context.Background()

count := func(iter *docstore.DocumentIterator) (c int) {
doc := docmap{}
for {
if err := iter.Next(ctx, doc); err != nil {
if err == io.EOF {
break
}
t.Fatal(err)
}
c++
}
return c
}

dc, err := newCollection(drivertest.KeyField, nil, &Options{AllowNestedSlicesQuery: true})
if err != nil {
t.Fatal(err)
}
coll := docstore.NewCollection(dc)
defer coll.Close()

doc := docmap{drivertest.KeyField: "TestQueryNested",
"list": []any{docmap{"a": "A"}},
"map": docmap{"b": "B"},
dc.RevisionField(): nil,
}
if err := coll.Put(ctx, doc); err != nil {
t.Fatal(err)
}

got := count(coll.Query().Where("list.a", "=", "A").Get(ctx))
if got != 1 {
t.Errorf("got %v docs when filtering by list.a, want 1", got)
}
got = count(coll.Query().Where("map.b", "=", "B").Get(ctx))
if got != 1 {
t.Errorf("got %v docs when filtering by map.b, want 1", got)
}
}

func TestSortDocs(t *testing.T) {
newDocs := func() []storedDoc {
return []storedDoc{
Expand Down
23 changes: 18 additions & 5 deletions docstore/memdocstore/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (c *collection) RunGetQuery(_ context.Context, q *driver.Query) (driver.Doc

var resultDocs []storedDoc
for _, doc := range c.docs {
if filtersMatch(q.Filters, doc) {
if filtersMatch(q.Filters, doc, c.opts.AllowNestedSlicesQuery) {
resultDocs = append(resultDocs, doc)
}
}
Expand Down Expand Up @@ -74,17 +74,17 @@ func (c *collection) RunGetQuery(_ context.Context, q *driver.Query) (driver.Doc
}, nil
}

func filtersMatch(fs []driver.Filter, doc storedDoc) bool {
func filtersMatch(fs []driver.Filter, doc storedDoc, nested bool) bool {
for _, f := range fs {
if !filterMatches(f, doc) {
if !filterMatches(f, doc, nested) {
return false
}
}
return true
}

func filterMatches(f driver.Filter, doc storedDoc) bool {
docval, err := getAtFieldPath(doc, f.FieldPath)
func filterMatches(f driver.Filter, doc storedDoc, nested bool) bool {
docval, err := getAtFieldPath(doc, f.FieldPath, nested)
// missing or bad field path => no match
if err != nil {
return false
Expand Down Expand Up @@ -138,6 +138,19 @@ func compare(x1, x2 interface{}) (int, bool) {
}
return -1, true
}
if v1.Kind() == reflect.Slice {
for i := 0; i < v1.Len(); i++ {
if c, ok := compare(x2, v1.Index(i).Interface()); ok {
if !ok {
return 0, false
}
if c == 0 {
return 0, true
}
}
}
return -1, true
}
if v1.Kind() == reflect.String && v2.Kind() == reflect.String {
return strings.Compare(v1.String(), v2.String()), true
}
Expand Down

0 comments on commit 02ce845

Please sign in to comment.