Skip to content

Commit

Permalink
support atomic writes
Browse files Browse the repository at this point in the history
  • Loading branch information
sandeepvinayak committed Dec 7, 2024
1 parent 61f69d3 commit c95f790
Show file tree
Hide file tree
Showing 11 changed files with 162 additions and 48 deletions.
29 changes: 16 additions & 13 deletions docstore/awsdynamodb/dynamo.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,12 @@ func (c *collection) RevisionField() string { return c.opts.RevisionField }

func (c *collection) RunActions(ctx context.Context, actions []*driver.Action, opts *driver.RunActionsOptions) driver.ActionListError {
errs := make([]error, len(actions))
beforeGets, gets, writes, afterGets := driver.GroupActions(actions)
beforeGets, gets, writes, writesTx, afterGets := driver.GroupActions(actions)
c.runGets(ctx, beforeGets, errs, opts)
ch := make(chan struct{})
ch2 := make(chan struct{})
go func() { defer close(ch); c.runWrites(ctx, writes, errs, opts) }()
go func() { defer close(ch2); c.transactWrite(ctx, writesTx, errs, opts) }()
c.runGets(ctx, gets, errs, opts)
<-ch
c.runGets(ctx, afterGets, errs, opts)
Expand Down Expand Up @@ -613,25 +615,26 @@ func revisionPrecondition(doc driver.Document, revField string) (*expression.Con
return &cb, nil
}

// TODO(jba): use this if/when we support atomic writes.
func (c *collection) transactWrite(ctx context.Context, actions []*driver.Action, errs []error, opts *driver.RunActionsOptions, start, end int) {
func (c *collection) transactWrite(ctx context.Context, actions []*driver.Action, errs []error, opts *driver.RunActionsOptions) {
if len(actions) == 0 {
return
}
setErr := func(err error) {
for i := start; i <= end; i++ {
errs[actions[i].Index] = err
for _, a := range actions {
errs[a.Index] = err
}
}

tws := make([]*dyn.TransactWriteItem, 0, len(actions))
var ops []*writeOp
tws := make([]*dyn.TransactWriteItem, 0, end-start+1)
for i := start; i <= end; i++ {
a := actions[i]
op, err := c.newWriteOp(a, opts)
for _, w := range actions {
op, err := c.newWriteOp(w, opts)
if err != nil {
setErr(err)
return
errs[w.Index] = err
} else {
ops = append(ops, op)
tws = append(tws, op.writeItem)
}
ops = append(ops, op)
tws = append(tws, op.writeItem)
}

in := &dyn.TransactWriteItemsInput{
Expand Down
39 changes: 24 additions & 15 deletions docstore/docstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,18 +135,20 @@ func (c *Collection) Actions() *ActionList {
// document; a Get after the write will see the new value if the service is strongly
// consistent, but may see the old value if the service is eventually consistent.
type ActionList struct {
coll *Collection
actions []*Action
beforeDo func(asFunc func(interface{}) bool) error
coll *Collection
actions []*Action
enableAtomicWrites bool
beforeDo func(asFunc func(interface{}) bool) error
}

// An Action is a read or write on a single document.
// Use the methods of ActionList to create and execute Actions.
type Action struct {
kind driver.ActionKind
doc Document
fieldpaths []FieldPath // paths to retrieve, for Get
mods Mods // modifications to make, for Update
kind driver.ActionKind
doc Document
fieldpaths []FieldPath // paths to retrieve, for Get
mods Mods // modifications to make, for Update
inAtomicWrite bool // if this action is a part of atomic writes
}

func (l *ActionList) add(a *Action) *ActionList {
Expand All @@ -170,7 +172,7 @@ func (l *ActionList) add(a *Action) *ActionList {
// Except for setting the revision field and possibly setting the key fields, the doc
// argument is not modified.
func (l *ActionList) Create(doc Document) *ActionList {
return l.add(&Action{kind: driver.Create, doc: doc})
return l.add(&Action{kind: driver.Create, doc: doc, inAtomicWrite: l.enableAtomicWrites})
}

// Replace adds an action that replaces a document to the given ActionList, and
Expand All @@ -182,7 +184,7 @@ func (l *ActionList) Create(doc Document) *ActionList {
// See the Revisions section of the package documentation for how revisions are
// handled.
func (l *ActionList) Replace(doc Document) *ActionList {
return l.add(&Action{kind: driver.Replace, doc: doc})
return l.add(&Action{kind: driver.Replace, doc: doc, inAtomicWrite: l.enableAtomicWrites})
}

// Put adds an action that adds or replaces a document to the given ActionList, and returns the ActionList.
Expand All @@ -195,7 +197,7 @@ func (l *ActionList) Replace(doc Document) *ActionList {
// See the Revisions section of the package documentation for how revisions are
// handled.
func (l *ActionList) Put(doc Document) *ActionList {
return l.add(&Action{kind: driver.Put, doc: doc})
return l.add(&Action{kind: driver.Put, doc: doc, inAtomicWrite: l.enableAtomicWrites})
}

// Delete adds an action that deletes a document to the given ActionList, and returns
Expand All @@ -210,7 +212,7 @@ func (l *ActionList) Delete(doc Document) *ActionList {
// semantics of an action list are to stop at first error, then we might abort a
// list of Deletes just because one of the docs was not present, and that seems
// wrong, or at least something you'd want to turn off.
return l.add(&Action{kind: driver.Delete, doc: doc})
return l.add(&Action{kind: driver.Delete, doc: doc, inAtomicWrite: l.enableAtomicWrites})
}

// Get adds an action that retrieves a document to the given ActionList, and
Expand Down Expand Up @@ -252,9 +254,10 @@ func (l *ActionList) Get(doc Document, fps ...FieldPath) *ActionList {
// the updated document, call Get after calling Update.
func (l *ActionList) Update(doc Document, mods Mods) *ActionList {
return l.add(&Action{
kind: driver.Update,
doc: doc,
mods: mods,
kind: driver.Update,
doc: doc,
mods: mods,
inAtomicWrite: l.enableAtomicWrites,
})
}

Expand Down Expand Up @@ -430,7 +433,7 @@ func (c *Collection) toDriverAction(a *Action) (*driver.Action, error) {
// A Put with a revision field is equivalent to a Replace.
kind = driver.Replace
}
d := &driver.Action{Kind: kind, Doc: ddoc, Key: key}
d := &driver.Action{Kind: kind, Doc: ddoc, Key: key, InAtomicWrite: a.inAtomicWrite}
if a.fieldpaths != nil {
d.FieldPaths, err = parseFieldPaths(a.fieldpaths)
if err != nil {
Expand Down Expand Up @@ -534,6 +537,12 @@ func (l *ActionList) String() string {
return "[" + strings.Join(as, ", ") + "]"
}

// AtomicWrites causes all following writes in the list to execute atomically.
func (l *ActionList) AtomicWrites() *ActionList {
l.enableAtomicWrites = true
return l
}

func (a *Action) String() string {
buf := &strings.Builder{}
fmt.Fprintf(buf, "%s(%v", a.kind, a.doc)
Expand Down
14 changes: 7 additions & 7 deletions docstore/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,14 @@ const (

//go:generate stringer -type=ActionKind

// An Action describes a single operation on a single document.
type Action struct {
Kind ActionKind // the kind of action
Doc Document // the document on which to perform the action
Key interface{} // the document key returned by Collection.Key, to avoid recomputing it
FieldPaths [][]string // field paths to retrieve, for Get only
Mods []Mod // modifications to make, for Update only
Index int // the index of the action in the original action list
Kind ActionKind // the kind of action
Doc Document // the document on which to perform the action
Key interface{} // the document key returned by Collection.Key, to avoid recomputing it
FieldPaths [][]string // field paths to retrieve, for Get only
Mods []Mod // modifications to make, for Update only
Index int // the index of the action in the original action list
InAtomicWrite bool // if this action is a part of transaction
}

// A Mod is a modification to a field path in a document.
Expand Down
22 changes: 18 additions & 4 deletions docstore/driver/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,13 @@ func SplitActions(actions []*Action, split func(a, b *Action) bool) [][]*Action

// GroupActions separates actions into four sets: writes, gets that must happen before the writes,
// gets that must happen after the writes, and gets that can happen concurrently with the writes.
func GroupActions(actions []*Action) (beforeGets, getList, writeList, afterGets []*Action) {
func GroupActions(actions []*Action) (beforeGets, getList, writeList, writesTxList, afterGets []*Action) {
// maps from key to action
bgets := map[interface{}]*Action{}
agets := map[interface{}]*Action{}
cgets := map[interface{}]*Action{}
writes := map[interface{}]*Action{}
writesTx := map[interface{}]*Action{}
var nilkeys []*Action
for _, a := range actions {
if a.Key == nil {
Expand All @@ -69,7 +70,7 @@ func GroupActions(actions []*Action) (beforeGets, getList, writeList, afterGets
} else if a.Kind == Get {
// If there was a prior write with this key, make sure this get
// happens after the writes.
if _, ok := writes[a.Key]; ok {
if valueExistsInMaps(a.Key, writes, writesTx) {
agets[a.Key] = a
} else {
cgets[a.Key] = a
Expand All @@ -81,7 +82,11 @@ func GroupActions(actions []*Action) (beforeGets, getList, writeList, afterGets
delete(cgets, a.Key)
bgets[a.Key] = g
}
writes[a.Key] = a
if a.InAtomicWrite {
writesTx[a.Key] = a
} else {
writes[a.Key] = a
}
}
}

Expand All @@ -95,7 +100,16 @@ func GroupActions(actions []*Action) (beforeGets, getList, writeList, afterGets
return as
}

return vals(bgets), vals(cgets), append(vals(writes), nilkeys...), vals(agets)
return vals(bgets), vals(cgets), append(vals(writes), nilkeys...), vals(writesTx), vals(agets)
}

func valueExistsInMaps(key interface{}, maps ...map[interface{}]*Action) bool {
for _, m := range maps {
if _, ok := m[key]; ok {
return true
}
}
return false
}

// AsFunc creates and returns an "as function" that behaves as follows:
Expand Down
12 changes: 6 additions & 6 deletions docstore/driver/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func TestGroupActions(t *testing.T) {
}{
{
in: []*Action{{Kind: Get, Key: 1}},
want: [][]int{nil, {0}, nil, nil},
want: [][]int{nil, {0}, nil, nil, nil},
},
{
in: []*Action{
Expand All @@ -89,16 +89,16 @@ func TestGroupActions(t *testing.T) {
{Kind: Replace, Key: 2},
{Kind: Get, Key: 2},
},
want: [][]int{{0}, {1}, {2, 3}, {4}},
want: [][]int{{0}, {1}, {2, 3}, nil, {4}},
},
{
in: []*Action{{Kind: Create}, {Kind: Create}, {Kind: Create}},
want: [][]int{nil, nil, {0, 1, 2}, nil},
want: [][]int{nil, nil, {0, 1, 2}, nil, nil},
},
} {
got := make([][]*Action, 4)
got[0], got[1], got[2], got[3] = GroupActions(test.in)
want := make([][]*Action, 4)
got := make([][]*Action, 5)
got[0], got[1], got[2], got[3], got[4] = GroupActions(test.in)
want := make([][]*Action, 5)
for i, s := range test.want {
for _, x := range s {
want[i] = append(want[i], test.in[x])
Expand Down
76 changes: 76 additions & 0 deletions docstore/drivertest/drivertest.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"gocloud.dev/docstore"
"gocloud.dev/docstore/driver"
"gocloud.dev/gcerrors"
Expand Down Expand Up @@ -1900,6 +1901,81 @@ func testMultipleActions(t *testing.T, coll *docstore.Collection, revField strin
}
}

func testAtomicWrites(t *testing.T, coll *docstore.Collection, revField string) {
t.Helper()

ctx := context.Background()

must := func(err error) {
t.Helper()
if err != nil {
t.Fatal(err)
}
}

var docs []docmap
for i := 0; i < 9; i++ {
docs = append(docs, docmap{
KeyField: fmt.Sprintf("testAtomicWrites%d", i),
"s": fmt.Sprint(i),
revField: nil,
})
}

compare := func(gots, wants []docmap) {
t.Helper()
for i := 0; i < len(gots); i++ {
got := gots[i]
want := clone(wants[i])
want[revField] = got[revField]
if !cmp.Equal(got, want, cmpopts.IgnoreUnexported(tspb.Timestamp{})) {
t.Errorf("index #%d:\ngot %v\nwant %v", i, got, want)
}
}
}

// Put the first six docs.
actions := coll.Actions()
for i := 0; i < 6; i++ {
actions.Create(docs[i])
}
must(actions.Do(ctx))

// Delete the first three, get the second three, and update last three in transaction.
gdocs := []docmap{
{KeyField: docs[3][KeyField]},
{KeyField: docs[4][KeyField]},
{KeyField: docs[5][KeyField]},
}
actions = coll.Actions()
actions.Get(gdocs[0])
actions.Delete(docs[0])
actions.Delete(docs[1])
actions.Get(gdocs[1])
actions.Delete(docs[2])
actions.Get(gdocs[2])
actions.AtomicWrites()
actions.Update(docs[6], docstore.Mods{"s": "66'"})
actions.Update(docs[7], docstore.Mods{"s": "77'"})
actions.Update(docs[8], docstore.Mods{"s": "88"})

must(actions.Do(ctx))
compare(gdocs, docs[3:6])

// Get the docs updated as part of atomic writes and verify that got written.
actions = coll.Actions()

doc := docmap{KeyField: docs[6][KeyField]}
_ = coll.Get(ctx, doc)
assert.Equal(t, "66", doc["s"])
doc = docmap{KeyField: docs[7][KeyField]}
_ = coll.Get(ctx, doc)
assert.Equal(t, "77", doc["s"])
doc = docmap{KeyField: docs[8][KeyField]}
_ = coll.Get(ctx, doc)
assert.Equal(t, "88", doc["s"])
}

func testActionsOnStructNoRev(t *testing.T, _ Harness, coll *docstore.Collection) {
t.Helper()

Expand Down
2 changes: 1 addition & 1 deletion docstore/gcpfirestore/fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ func (c *collection) RevisionField() string {
// RunActions implements driver.RunActions.
func (c *collection) RunActions(ctx context.Context, actions []*driver.Action, opts *driver.RunActionsOptions) driver.ActionListError {
errs := make([]error, len(actions))
beforeGets, gets, writes, afterGets := driver.GroupActions(actions)
beforeGets, gets, writes, _, afterGets := driver.GroupActions(actions)
calls := c.buildCommitCalls(writes, errs)
// runGets does not issue concurrent RPCs, so it doesn't need a throttle.
c.runGets(ctx, beforeGets, errs, opts)
Expand Down
2 changes: 1 addition & 1 deletion docstore/memdocstore/mem.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ func (c *collection) RunActions(ctx context.Context, actions []*driver.Action, o
}
}

beforeGets, gets, writes, afterGets := driver.GroupActions(actions)
beforeGets, gets, writes, _, afterGets := driver.GroupActions(actions)
run(beforeGets)
run(gets)
run(writes)
Expand Down
2 changes: 1 addition & 1 deletion docstore/mongodocstore/mongo.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ const mongoIDField = "_id"

func (c *collection) RunActions(ctx context.Context, actions []*driver.Action, opts *driver.RunActionsOptions) driver.ActionListError {
errs := make([]error, len(actions))
beforeGets, gets, writes, afterGets := driver.GroupActions(actions)
beforeGets, gets, writes, _, afterGets := driver.GroupActions(actions)
c.runGets(ctx, beforeGets, errs, opts)
ch := make(chan []error)
go func() { ch <- c.bulkWrite(ctx, writes, errs, opts) }()
Expand Down
Loading

0 comments on commit c95f790

Please sign in to comment.