Skip to content

Commit

Permalink
restructure the code
Browse files Browse the repository at this point in the history
This makes the context just an accessor to underlying state machine
code
  • Loading branch information
muhamadazmy committed Mar 15, 2024
1 parent 8f8bf82 commit 13d533d
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 138 deletions.
24 changes: 17 additions & 7 deletions internal/state/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,17 @@ type serviceCall struct {
method string
}

func (c *serviceCall) makeRequest(key string, body any) ([]byte, error) {
// Do makes a call and wait for the response
func (c *serviceCall) Do(key string, body any) ([]byte, error) {
return c.machine.doCall(c.service, c.method, key, body)
}

// Send runs a call in the background after delay duration
func (c *serviceCall) Send(key string, body any, delay time.Duration) error {
return c.machine.sendCall(c.service, c.method, key, body, delay)
}

func (c *Machine) makeRequest(key string, body any) ([]byte, error) {

input, err := json.Marshal(body)
if err != nil {
Expand All @@ -58,7 +68,7 @@ func (c *serviceCall) makeRequest(key string, body any) ([]byte, error) {
return proto.Marshal(params)
}

func (c *serviceCall) Do(key string, body any) ([]byte, error) {
func (c *Machine) doCall(service, method, key string, body any) ([]byte, error) {
c.mutex.Lock()
defer c.mutex.Unlock()

Expand All @@ -68,8 +78,8 @@ func (c *serviceCall) Do(key string, body any) ([]byte, error) {
}

err = c.protocol.Write(&protocol.InvokeEntryMessage{
ServiceName: c.service,
MethodName: c.method,
ServiceName: service,
MethodName: method,
Parameter: input,
})

Expand Down Expand Up @@ -110,7 +120,7 @@ func (c *serviceCall) Do(key string, body any) ([]byte, error) {
return rpcResponse.Response.MarshalJSON()
}

func (c *serviceCall) Send(key string, body any, delay time.Duration) error {
func (c *Machine) sendCall(service, method, key string, body any, delay time.Duration) error {
c.mutex.Lock()
defer c.mutex.Unlock()

Expand All @@ -125,8 +135,8 @@ func (c *serviceCall) Send(key string, body any, delay time.Duration) error {
}

err = c.protocol.Write(&protocol.BackgroundInvokeEntryMessage{
ServiceName: c.service,
MethodName: c.method,
ServiceName: service,
MethodName: method,
Parameter: input,
InvokeTime: invokeTime,
})
Expand Down
168 changes: 40 additions & 128 deletions internal/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,138 +34,35 @@ var (
)

type Context struct {
ctx context.Context

partial bool
current map[string][]byte

protocol *wire.Protocol
mutex sync.Mutex
ctx context.Context
machine *Machine
}

func (c *Context) Ctx() context.Context {
return c.ctx
}

func (c *Context) Set(key string, value []byte) error {
c.mutex.Lock()
defer c.mutex.Unlock()

c.current[key] = value

return c.protocol.Write(
&protocol.SetStateEntryMessage{
Key: []byte(key),
Value: value,
})
return c.machine.set(key, value)
}

func (c *Context) Clear(key string) error {
c.mutex.Lock()
defer c.mutex.Unlock()
return c.machine.clear(key)

return c.protocol.Write(
&protocol.ClearStateEntryMessage{
Key: []byte(key),
},
)
}

// ClearAll drops all associated keys
func (c *Context) ClearAll() error {
c.mutex.Lock()
defer c.mutex.Unlock()
return c.machine.clearAll()

return c.protocol.Write(
&protocol.ClearAllStateEntryMessage{},
)
}

func (c *Context) Get(key string) ([]byte, error) {
msg := &protocol.GetStateEntryMessage{
Key: []byte(key),
}

c.mutex.Lock()
defer c.mutex.Unlock()

value, ok := c.current[key]

if ok {
// value in map, we still send the current
// value to the runtime
msg.Result = &protocol.GetStateEntryMessage_Value{
Value: value,
}

if err := c.protocol.Write(msg); err != nil {
return nil, err
}

return value, nil
}

// key is not in map! there are 2 cases.
if !c.partial {
// current is complete. we need to return nil to the user
// but also send an empty get state entry message
msg.Result = &protocol.GetStateEntryMessage_Empty{}

if err := c.protocol.Write(msg); err != nil {
return nil, err
}

return nil, nil
}

if err := c.protocol.Write(msg); err != nil {
return nil, err
}

// wait for completion
response, err := c.protocol.Read()
if err != nil {
return nil, err
}

if response.Type() != wire.CompletionMessageType {
return nil, ErrUnexpectedMessage
}

completion := response.(*wire.CompletionMessage)

switch value := completion.Payload.Result.(type) {
case *protocol.CompletionMessage_Empty:
return nil, nil
case *protocol.CompletionMessage_Failure:
// the get state entry message is not failable so this should
// never happen
return nil, fmt.Errorf("[%d] %s", value.Failure.Code, value.Failure.Message)
case *protocol.CompletionMessage_Value:
c.current[key] = value.Value
return value.Value, nil
}

return nil, fmt.Errorf("unreachable")
return c.machine.get(key)
}

func (c *Context) Sleep(until time.Time) error {
if err := c.protocol.Write(&protocol.SleepEntryMessage{
WakeUpTime: uint64(until.UnixMilli()),
}); err != nil {
return err
}

response, err := c.protocol.Read()
if err != nil {
return err
}

if response.Type() != wire.CompletionMessageType {
return ErrUnexpectedMessage
}

return nil
return c.machine.sleep(until)
}

func (c *Context) Service(service string) restate.Service {
Expand All @@ -175,22 +72,16 @@ func (c *Context) Service(service string) restate.Service {
}
}

func newContext(inner context.Context, protocol *wire.Protocol, start *wire.StartMessage) *Context {
log.Debug().
Bool("partial-state", start.Payload.PartialState).
Int("state-len", len(start.Payload.StateMap)).
Msg("start message")
func newContext(inner context.Context, machine *Machine) *Context {

state := make(map[string][]byte)
for _, entry := range start.Payload.StateMap {
state[string(entry.Key)] = entry.Value
}
// state := make(map[string][]byte)
// for _, entry := range start.Payload.StateMap {
// state[string(entry.Key)] = entry.Value
// }

ctx := &Context{
ctx: inner,
partial: start.Payload.PartialState,
current: state,
protocol: protocol,
ctx: inner,
machine: machine,
}

return ctx
Expand All @@ -199,15 +90,22 @@ func newContext(inner context.Context, protocol *wire.Protocol, start *wire.Star
type Machine struct {
handler restate.Handler
protocol *wire.Protocol
mutex sync.Mutex

// state
id []byte

partial bool
current map[string][]byte

entries []wire.Message
}

func NewMachine(handler restate.Handler, conn io.ReadWriter) *Machine {
return &Machine{
handler: handler,
protocol: wire.NewProtocol(conn),
current: make(map[string][]byte),
}
}

Expand Down Expand Up @@ -280,7 +178,12 @@ func (m *Machine) invoke(ctx *Context, input *dynrpc.RpcRequest) error {
return m.protocol.Write(output)
}

func (m *Machine) process(ctx *Context) error {
func (m *Machine) process(ctx *Context, start *wire.StartMessage) error {
m.id = start.Payload.Id

for _, entry := range start.Payload.StateMap {
m.current[string(entry.Key)] = entry.Value
}

// expect input message
msg, err := m.protocol.Read()
Expand All @@ -292,6 +195,16 @@ func (m *Machine) process(ctx *Context) error {
return ErrUnexpectedMessage
}

log.Debug().Uint32("known entries", start.Payload.KnownEntries).Msg("known entires")
for i := uint32(1); i < start.Payload.KnownEntries; i++ {
msg, err := m.protocol.Read()
if err != nil {
return fmt.Errorf("failed to read entry: %w", err)
}

log.Debug().Uint16("type", uint16(msg.Type())).Msg("received entry")
}

inputMsg := msg.(*wire.PollInputEntry)
value := inputMsg.Payload.GetValue()
var input dynrpc.RpcRequest
Expand Down Expand Up @@ -321,9 +234,8 @@ func (m *Machine) Start(inner context.Context) error {
return ErrInvalidVersion
}

m.id = start.Payload.Id
ctx := newContext(inner, m)

ctx := newContext(inner, m.protocol, start)
log.Debug().Str("id", base64.URLEncoding.EncodeToString(m.id)).Msg("start invocation")
return m.process(ctx)
log.Trace().Str("id", base64.URLEncoding.EncodeToString(m.id)).Msg("start invocation")
return m.process(ctx, start)
}
Loading

0 comments on commit 13d533d

Please sign in to comment.