From bc7f67d895384e5016cb9f09767ed0c339d5a608 Mon Sep 17 00:00:00 2001 From: Muhamad Azamy Date: Wed, 20 Mar 2024 10:11:55 +0100 Subject: [PATCH] suspend if sleep is more than 1 second --- internal/state/call.go | 2 +- internal/state/state.go | 53 +++++++++++++++++++++++++++++++---------- internal/state/sys.go | 22 +++++++++++++---- internal/wire/wire.go | 41 ++++++++++++++++++++++++++++--- 4 files changed, 97 insertions(+), 21 deletions(-) diff --git a/internal/state/call.go b/internal/state/call.go index d3db104..871efe8 100644 --- a/internal/state/call.go +++ b/internal/state/call.go @@ -145,7 +145,7 @@ func (m *Machine) _doCall(service, method string, params []byte) ([]byte, error) } if response.Type() != wire.CompletionMessageType { - return nil, ErrUnexpectedMessage + return nil, wire.ErrUnexpectedMessage } completion := response.(*wire.CompletionMessage) diff --git a/internal/state/state.go b/internal/state/state.go index b6c3afb..03c9417 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -24,14 +24,19 @@ const ( ) var ( - ErrUnexpectedMessage = fmt.Errorf("unexpected message") - ErrInvalidVersion = fmt.Errorf("invalid version number") + ErrInvalidVersion = fmt.Errorf("invalid version number") ) var ( _ restate.Context = (*Context)(nil) ) +// suspend is a struct we use to throw in a panic so we can rewind the stack +// then specially handle to suspend the invocation +type suspend struct { + resumeEntry uint32 +} + type Context struct { ctx context.Context machine *Machine @@ -138,7 +143,7 @@ func (m *Machine) Start(inner context.Context, trace string) error { if msg.Type() != wire.StartMessageType { // invalid negotiation - return ErrUnexpectedMessage + return wire.ErrUnexpectedMessage } start := msg.(*wire.StartMessage) @@ -203,24 +208,46 @@ func (m *Machine) invoke(ctx *Context, input *dynrpc.RpcRequest) error { // always terminate the invocation with // an end message. // this will always terminate the connection - defer m.protocol.Write(&protocol.EndMessage{}) defer func() { - if err := recover(); err != nil { - // handle service panic - // safely + // recover will return a non-nil object + // if there was a panic + // + recovered := recover() + + switch typ := recovered.(type) { + case nil: + // nothing to do, just send end message and exit + break + case *suspend: + // suspend object with thrown. we need to send a suspension + // message. then terminate the connection + m.log.Debug().Msg("suspending invocation") + err := m.protocol.Write(&protocol.SuspensionMessage{ + EntryIndexes: []uint32{typ.resumeEntry}, + }) - // this should become a retry error ErrorMessage - wErr := m.protocol.Write(&protocol.ErrorMessage{ + if err != nil { + m.log.Error().Err(err).Msg("error sending failure message") + } + return + default: + // unknown panic! + // send an error message (retryable) + err := m.protocol.Write(&protocol.ErrorMessage{ Code: uint32(restate.INTERNAL), - Message: fmt.Sprint(err), + Message: fmt.Sprint(typ), Description: string(debug.Stack()), }) - if wErr != nil { - m.log.Error().Err(wErr).Msg("error sending failure message") + if err != nil { + m.log.Error().Err(err).Msg("error sending failure message") } } + + if err := m.protocol.Write(&protocol.EndMessage{}); err != nil { + m.log.Error().Err(err).Msg("error sending end message") + } }() output := m.output(m.handler.Call(ctx, input)) @@ -240,7 +267,7 @@ func (m *Machine) process(ctx *Context, start *wire.StartMessage) error { } if msg.Type() != wire.PollInputEntryMessageType { - return ErrUnexpectedMessage + return wire.ErrUnexpectedMessage } m.log.Trace().Uint32("known entries", start.Payload.KnownEntries).Msg("known entires") diff --git a/internal/state/sys.go b/internal/state/sys.go index c6f9387..c5d3c1a 100644 --- a/internal/state/sys.go +++ b/internal/state/sys.go @@ -189,7 +189,7 @@ func (m *Machine) _get(key string) ([]byte, error) { } if response.Type() != wire.CompletionMessageType { - return nil, ErrUnexpectedMessage + return nil, wire.ErrUnexpectedMessage } completion := response.(*wire.CompletionMessage) @@ -243,7 +243,7 @@ func (m *Machine) _keys() ([]string, error) { if msg.Type() != wire.CompletionMessageType { m.log.Error().Stringer("type", msg.Type()).Msg("receiving message of type") - return nil, ErrUnexpectedMessage + return nil, wire.ErrUnexpectedMessage } response := msg.(*wire.CompletionMessage) @@ -288,20 +288,34 @@ func (m *Machine) sleep(until time.Time) error { return err } +// _sleep creating a new sleep entry. The implementation of this function +// will also suspend execution if sleep duration is greater than 1 second +// as a form of optimization func (m *Machine) _sleep(until time.Time) error { if err := m.protocol.Write(&protocol.SleepEntryMessage{ WakeUpTime: uint64(until.UnixMilli()), - }); err != nil { + }, wire.FlagRequiresAck); err != nil { return err } + entryIndex, err := m.protocol.ReadAck() + if err != nil { + return err + } + + // if duration is more than one second, just pause the execution + if time.Until(until) > time.Second { + panic(&suspend{entryIndex}) + } + + // we can suspend invocation now response, err := m.protocol.Read() if err != nil { return err } if response.Type() != wire.CompletionMessageType { - return ErrUnexpectedMessage + return wire.ErrUnexpectedMessage } return nil diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 2a087c7..19078d3 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -15,10 +15,14 @@ import ( "google.golang.org/protobuf/proto" ) +var ( + ErrUnexpectedMessage = fmt.Errorf("unexpected message") +) + const ( // masks - FlagCompleted Flag = 0x0001 - FlagAck Flag = 0x8000 + FlagCompleted Flag = 0x0001 + FlagRequiresAck Flag = 0x8000 VersionMask = 0x03FF ) @@ -26,7 +30,9 @@ const ( // control StartMessageType Type = 0x0000 CompletionMessageType Type = 0x0000 + 1 + SuspensionMessageType Type = 0x0000 + 2 ErrorMessageType Type = 0x0000 + 3 + EntryAckMessageType Type = 0x0000 + 4 EndMessageType Type = 0x0000 + 5 // Input/Output @@ -69,7 +75,7 @@ func (r Flag) Completed() bool { } func (r Flag) Ack() bool { - return r&FlagAck != 0 + return r&FlagRequiresAck != 0 } type Header struct { @@ -133,6 +139,21 @@ func (s *Protocol) header() (header Header, err error) { return } +func (s *Protocol) ReadAck() (uint32, error) { + msg, err := s.Read() + if err != nil { + return 0, err + } + + if msg.Type() != EntryAckMessageType { + return 0, ErrUnexpectedMessage + } + + ack := msg.(*EntryAckMessage) + + return ack.Payload.EntryIndex, nil +} + func (s *Protocol) Read() (Message, error) { header, err := s.header() if err != nil { @@ -170,6 +191,8 @@ func (s *Protocol) Write(message proto.Message, flags ...Flag) error { case *protocol.StartMessage: // TODO: sdk should never write this message typ = StartMessageType + case *protocol.SuspensionMessage: + typ = SuspensionMessageType case *protocol.PollInputStreamEntryMessage: typ = PollInputEntryMessageType case *protocol.OutputStreamEntryMessage: @@ -241,6 +264,13 @@ var ( return msg, proto.Unmarshal(bytes, &msg.Payload) }, + EntryAckMessageType: func(header Header, bytes []byte) (Message, error) { + msg := &EntryAckMessage{ + Header: header, + } + + return msg, proto.Unmarshal(bytes, &msg.Payload) + }, PollInputEntryMessageType: func(header Header, bytes []byte) (Message, error) { msg := &PollInputEntry{ Header: header, @@ -393,3 +423,8 @@ type SideEffectEntryMessage struct { Header Payload javascript.SideEffectEntryMessage } + +type EntryAckMessage struct { + Header + Payload protocol.EntryAckMessage +}