Skip to content

Commit

Permalink
State machine (#1)
Browse files Browse the repository at this point in the history
* WIP: support replay

* support replay

* implement get all keys

* add ci

* add missing go.sum
  • Loading branch information
muhamadazmy authored Mar 16, 2024
1 parent 13d533d commit ec9ad33
Show file tree
Hide file tree
Showing 11 changed files with 472 additions and 64 deletions.
21 changes: 21 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
name: Go
on: [push]

jobs:
build:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4
- name: Setup Go
uses: actions/setup-go@v4
with:
go-version: "1.22.x"
- name: Install dependencies
run: go get .
- name: Vet
run: go vet -v ./...
- name: Build
run: go build -v ./...
- name: Test with the Go CLI
run: go test -v ./...
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
go.sum
34 changes: 34 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/posener/h2conn v0.0.0-20231204025407-3997deeca0f0 h1:zZg03nifrj6ayWNaDO8tNj57tqrOIKDwiBaLkhtK7Kk=
github.com/posener/h2conn v0.0.0-20231204025407-3997deeca0f0/go.mod h1:bblJa8QcHntareAJYfLJUzLj42sUFBKCBeTDK5LyUrw=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0=
github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I=
google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
5 changes: 3 additions & 2 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ import (
"google.golang.org/protobuf/types/known/structpb"
)

// Empty struct used as a quick shortcut when you don't care about
// the input or output type
// Void is a placeholder used usually for functions that their signature require that
// you accept an input or return an output but the function implementation does not
// require them
type Void struct{}

func (v Void) MarshalJSON() ([]byte, error) {
Expand Down
64 changes: 52 additions & 12 deletions internal/state/call.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package state

import (
"bytes"
"encoding/json"
"fmt"
"time"
Expand Down Expand Up @@ -69,18 +70,39 @@ func (c *Machine) makeRequest(key string, body any) ([]byte, error) {
}

func (c *Machine) doCall(service, method, key string, body any) ([]byte, error) {
c.mutex.Lock()
defer c.mutex.Unlock()

input, err := c.makeRequest(key, body)
params, err := c.makeRequest(key, body)
if err != nil {
return nil, err
}

err = c.protocol.Write(&protocol.InvokeEntryMessage{
return replayOrNew(
c,
wire.InvokeEntryMessageType,
func(entry *wire.InvokeEntryMessage) ([]byte, error) {
if entry.Payload.ServiceName != service ||
entry.Payload.MethodName != method ||
!bytes.Equal(entry.Payload.Parameter, params) {
return nil, errEntryMismatch
}

switch result := entry.Payload.Result.(type) {
case *protocol.InvokeEntryMessage_Failure:
return nil, fmt.Errorf("[%d] %s", result.Failure.Code, result.Failure.Message)
case *protocol.InvokeEntryMessage_Value:
return result.Value, nil
}

return nil, errUnreachable
}, func() ([]byte, error) {
return c._doCall(service, method, params)
})
}

func (c *Machine) _doCall(service, method string, params []byte) ([]byte, error) {
err := c.protocol.Write(&protocol.InvokeEntryMessage{
ServiceName: service,
MethodName: method,
Parameter: input,
Parameter: params,
})

if err != nil {
Expand Down Expand Up @@ -121,23 +143,41 @@ func (c *Machine) doCall(service, method, key string, body any) ([]byte, error)
}

func (c *Machine) sendCall(service, method, key string, body any, delay time.Duration) error {
c.mutex.Lock()
defer c.mutex.Unlock()

input, err := c.makeRequest(key, body)
params, err := c.makeRequest(key, body)
if err != nil {
return err
}

_, err = replayOrNew(
c,
wire.BackgroundInvokeEntryMessageType,
func(entry *wire.BackgroundInvokeEntryMessage) (restate.Void, error) {
if entry.Payload.ServiceName != service ||
entry.Payload.MethodName != method ||
!bytes.Equal(entry.Payload.Parameter, params) {
return restate.Void{}, errEntryMismatch
}

return restate.Void{}, nil
},
func() (restate.Void, error) {
return restate.Void{}, c._sendCall(service, method, params, delay)
},
)

return err
}

func (c *Machine) _sendCall(service, method string, params []byte, delay time.Duration) error {
var invokeTime uint64
if delay != 0 {
invokeTime = uint64(time.Now().Add(delay).UnixMilli())
}

err = c.protocol.Write(&protocol.BackgroundInvokeEntryMessage{
err := c.protocol.Write(&protocol.BackgroundInvokeEntryMessage{
ServiceName: service,
MethodName: method,
Parameter: input,
Parameter: params,
InvokeTime: invokeTime,
})

Expand Down
110 changes: 85 additions & 25 deletions internal/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ const (
var (
ErrUnexpectedMessage = fmt.Errorf("unexpected message")
ErrInvalidVersion = fmt.Errorf("invalid version number")

errUnreachable = fmt.Errorf("unreachable")
)

var (
Expand Down Expand Up @@ -61,6 +59,10 @@ func (c *Context) Get(key string) ([]byte, error) {
return c.machine.get(key)
}

func (c *Context) Keys() ([]string, error) {
return c.machine.keys()
}

func (c *Context) Sleep(until time.Time) error {
return c.machine.sleep(until)
}
Expand Down Expand Up @@ -98,7 +100,8 @@ type Machine struct {
partial bool
current map[string][]byte

entries []wire.Message
entries []wire.Message
entryIndex int
}

func NewMachine(handler restate.Handler, conn io.ReadWriter) *Machine {
Expand All @@ -109,11 +112,33 @@ func NewMachine(handler restate.Handler, conn io.ReadWriter) *Machine {
}
}

func (m *Machine) output(r *dynrpc.RpcResponse, err error) proto.Message {
// TODO: if err is terminal return outputStreamEntryMessage but if error is
// not terminal, return ErrorMessage instead.
//var output protocol.OutputStreamEntryMessage
// Start starts the state machine
func (m *Machine) Start(inner context.Context) error {
// reader starts a rea
msg, err := m.protocol.Read()
if err != nil {
return err
}

if msg.Type() != wire.StartMessageType {
// invalid negotiation
return ErrUnexpectedMessage
}

start := msg.(*wire.StartMessage)

if start.Version != Version {
return ErrInvalidVersion
}

ctx := newContext(inner, m)

log.Trace().Str("id", base64.URLEncoding.EncodeToString(m.id)).Msg("start invocation")
return m.process(ctx, start)
}

// handle handler response and build proper response message
func (m *Machine) output(r *dynrpc.RpcResponse, err error) proto.Message {
if err != nil && restate.IsTerminalError(err) {
// terminal errors.
return &protocol.OutputStreamEntryMessage{
Expand Down Expand Up @@ -195,14 +220,18 @@ func (m *Machine) process(ctx *Context, start *wire.StartMessage) error {
return ErrUnexpectedMessage
}

log.Debug().Uint32("known entries", start.Payload.KnownEntries).Msg("known entires")
log.Trace().Uint32("known entries", start.Payload.KnownEntries).Msg("known entires")
m.entries = make([]wire.Message, 0, start.Payload.KnownEntries-1)

// we don't track the poll input entry
for i := uint32(1); i < start.Payload.KnownEntries; i++ {
msg, err := m.protocol.Read()
if err != nil {
return fmt.Errorf("failed to read entry: %w", err)
}

log.Debug().Uint16("type", uint16(msg.Type())).Msg("received entry")
log.Trace().Uint16("type", uint16(msg.Type())).Msg("replay log entry")
m.entries = append(m.entries, msg)
}

inputMsg := msg.(*wire.PollInputEntry)
Expand All @@ -216,26 +245,57 @@ func (m *Machine) process(ctx *Context, start *wire.StartMessage) error {

}

func (m *Machine) Start(inner context.Context) error {
// reader starts a rea
msg, err := m.protocol.Read()
if err != nil {
return err
func (c *Machine) currentEntry() (wire.Message, bool) {
if c.entryIndex < len(c.entries) {
return c.entries[c.entryIndex], true
}

if msg.Type() != wire.StartMessageType {
// invalid negotiation
return ErrUnexpectedMessage
}
return nil, false
}

start := msg.(*wire.StartMessage)
// replayOrNew is a utility function to easily either
// replay a log entry, or create a new one if one
// does not exist
//
// this should be an instance method on Machine but unfortunately
// go does not support generics on instance methods
//
// the idea is when called, it will check if there is a log
// entry at current index, then compare the entry message type
// if not matching, that's obviously an error with the code version
// (code has changed and now doesn't match the play log)
//
// if type is okay, the function will then call a `replay“ callback.
// the replay callback just need to extract the result from the entry
//
// otherwise this function will call a `new` callback to create a new entry in the log
// by sending the proper runtime messages
func replayOrNew[M wire.Message, O any](
m *Machine,
typ wire.Type,
replay func(msg M) (O, error),
new func() (O, error),
) (output O, err error) {

m.mutex.Lock()
defer m.mutex.Unlock()

if start.Version != Version {
return ErrInvalidVersion
}
defer func() {
m.entryIndex += 1
}()

ctx := newContext(inner, m)
// check if there is an entry as this index
entry, ok := m.currentEntry()

log.Trace().Str("id", base64.URLEncoding.EncodeToString(m.id)).Msg("start invocation")
return m.process(ctx, start)
// if entry exists, we need to replay it
// by calling the replay function
if ok {
if entry.Type() != typ {
return output, errEntryMismatch
}
return replay(entry.(M))
}

// other wise call the new function
return new()
}
Loading

0 comments on commit ec9ad33

Please sign in to comment.