Skip to content

Commit

Permalink
suspend if sleep is more than 1 second
Browse files Browse the repository at this point in the history
  • Loading branch information
muhamadazmy committed Mar 20, 2024
1 parent 527899a commit bc7f67d
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 21 deletions.
2 changes: 1 addition & 1 deletion internal/state/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (m *Machine) _doCall(service, method string, params []byte) ([]byte, error)
}

if response.Type() != wire.CompletionMessageType {
return nil, ErrUnexpectedMessage
return nil, wire.ErrUnexpectedMessage
}

completion := response.(*wire.CompletionMessage)
Expand Down
53 changes: 40 additions & 13 deletions internal/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,19 @@ const (
)

var (
ErrUnexpectedMessage = fmt.Errorf("unexpected message")
ErrInvalidVersion = fmt.Errorf("invalid version number")
ErrInvalidVersion = fmt.Errorf("invalid version number")
)

var (
_ restate.Context = (*Context)(nil)
)

// suspend is a struct we use to throw in a panic so we can rewind the stack
// then specially handle to suspend the invocation
type suspend struct {
resumeEntry uint32
}

type Context struct {
ctx context.Context
machine *Machine
Expand Down Expand Up @@ -138,7 +143,7 @@ func (m *Machine) Start(inner context.Context, trace string) error {

if msg.Type() != wire.StartMessageType {
// invalid negotiation
return ErrUnexpectedMessage
return wire.ErrUnexpectedMessage
}

start := msg.(*wire.StartMessage)
Expand Down Expand Up @@ -203,24 +208,46 @@ func (m *Machine) invoke(ctx *Context, input *dynrpc.RpcRequest) error {
// always terminate the invocation with
// an end message.
// this will always terminate the connection
defer m.protocol.Write(&protocol.EndMessage{})

defer func() {
if err := recover(); err != nil {
// handle service panic
// safely
// recover will return a non-nil object
// if there was a panic
//
recovered := recover()

switch typ := recovered.(type) {
case nil:
// nothing to do, just send end message and exit
break
case *suspend:
// suspend object with thrown. we need to send a suspension
// message. then terminate the connection
m.log.Debug().Msg("suspending invocation")
err := m.protocol.Write(&protocol.SuspensionMessage{
EntryIndexes: []uint32{typ.resumeEntry},
})

// this should become a retry error ErrorMessage
wErr := m.protocol.Write(&protocol.ErrorMessage{
if err != nil {
m.log.Error().Err(err).Msg("error sending failure message")
}
return
default:
// unknown panic!
// send an error message (retryable)
err := m.protocol.Write(&protocol.ErrorMessage{
Code: uint32(restate.INTERNAL),
Message: fmt.Sprint(err),
Message: fmt.Sprint(typ),
Description: string(debug.Stack()),
})

if wErr != nil {
m.log.Error().Err(wErr).Msg("error sending failure message")
if err != nil {
m.log.Error().Err(err).Msg("error sending failure message")
}
}

if err := m.protocol.Write(&protocol.EndMessage{}); err != nil {
m.log.Error().Err(err).Msg("error sending end message")
}
}()

output := m.output(m.handler.Call(ctx, input))
Expand All @@ -240,7 +267,7 @@ func (m *Machine) process(ctx *Context, start *wire.StartMessage) error {
}

if msg.Type() != wire.PollInputEntryMessageType {
return ErrUnexpectedMessage
return wire.ErrUnexpectedMessage
}

m.log.Trace().Uint32("known entries", start.Payload.KnownEntries).Msg("known entires")
Expand Down
22 changes: 18 additions & 4 deletions internal/state/sys.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ func (m *Machine) _get(key string) ([]byte, error) {
}

if response.Type() != wire.CompletionMessageType {
return nil, ErrUnexpectedMessage
return nil, wire.ErrUnexpectedMessage
}

completion := response.(*wire.CompletionMessage)
Expand Down Expand Up @@ -243,7 +243,7 @@ func (m *Machine) _keys() ([]string, error) {

if msg.Type() != wire.CompletionMessageType {
m.log.Error().Stringer("type", msg.Type()).Msg("receiving message of type")
return nil, ErrUnexpectedMessage
return nil, wire.ErrUnexpectedMessage
}

response := msg.(*wire.CompletionMessage)
Expand Down Expand Up @@ -288,20 +288,34 @@ func (m *Machine) sleep(until time.Time) error {
return err
}

// _sleep creating a new sleep entry. The implementation of this function
// will also suspend execution if sleep duration is greater than 1 second
// as a form of optimization
func (m *Machine) _sleep(until time.Time) error {
if err := m.protocol.Write(&protocol.SleepEntryMessage{
WakeUpTime: uint64(until.UnixMilli()),
}); err != nil {
}, wire.FlagRequiresAck); err != nil {
return err
}

entryIndex, err := m.protocol.ReadAck()
if err != nil {
return err
}

// if duration is more than one second, just pause the execution
if time.Until(until) > time.Second {
panic(&suspend{entryIndex})
}

// we can suspend invocation now
response, err := m.protocol.Read()
if err != nil {
return err
}

if response.Type() != wire.CompletionMessageType {
return ErrUnexpectedMessage
return wire.ErrUnexpectedMessage
}

return nil
Expand Down
41 changes: 38 additions & 3 deletions internal/wire/wire.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,24 @@ import (
"google.golang.org/protobuf/proto"
)

var (
ErrUnexpectedMessage = fmt.Errorf("unexpected message")
)

const (
// masks
FlagCompleted Flag = 0x0001
FlagAck Flag = 0x8000
FlagCompleted Flag = 0x0001
FlagRequiresAck Flag = 0x8000

VersionMask = 0x03FF
)
const (
// control
StartMessageType Type = 0x0000
CompletionMessageType Type = 0x0000 + 1
SuspensionMessageType Type = 0x0000 + 2
ErrorMessageType Type = 0x0000 + 3
EntryAckMessageType Type = 0x0000 + 4
EndMessageType Type = 0x0000 + 5

// Input/Output
Expand Down Expand Up @@ -69,7 +75,7 @@ func (r Flag) Completed() bool {
}

func (r Flag) Ack() bool {
return r&FlagAck != 0
return r&FlagRequiresAck != 0
}

type Header struct {
Expand Down Expand Up @@ -133,6 +139,21 @@ func (s *Protocol) header() (header Header, err error) {
return
}

func (s *Protocol) ReadAck() (uint32, error) {
msg, err := s.Read()
if err != nil {
return 0, err
}

if msg.Type() != EntryAckMessageType {
return 0, ErrUnexpectedMessage
}

ack := msg.(*EntryAckMessage)

return ack.Payload.EntryIndex, nil
}

func (s *Protocol) Read() (Message, error) {
header, err := s.header()
if err != nil {
Expand Down Expand Up @@ -170,6 +191,8 @@ func (s *Protocol) Write(message proto.Message, flags ...Flag) error {
case *protocol.StartMessage:
// TODO: sdk should never write this message
typ = StartMessageType
case *protocol.SuspensionMessage:
typ = SuspensionMessageType
case *protocol.PollInputStreamEntryMessage:
typ = PollInputEntryMessageType
case *protocol.OutputStreamEntryMessage:
Expand Down Expand Up @@ -241,6 +264,13 @@ var (

return msg, proto.Unmarshal(bytes, &msg.Payload)
},
EntryAckMessageType: func(header Header, bytes []byte) (Message, error) {
msg := &EntryAckMessage{
Header: header,
}

return msg, proto.Unmarshal(bytes, &msg.Payload)
},
PollInputEntryMessageType: func(header Header, bytes []byte) (Message, error) {
msg := &PollInputEntry{
Header: header,
Expand Down Expand Up @@ -393,3 +423,8 @@ type SideEffectEntryMessage struct {
Header
Payload javascript.SideEffectEntryMessage
}

type EntryAckMessage struct {
Header
Payload protocol.EntryAckMessage
}

0 comments on commit bc7f67d

Please sign in to comment.