Commit f9751b5

Add Shutdown Notice from Compute Nodes (#4769)
When compute nodes shut down, they should notify the orchestrator with their final sequence numbers. While this information is already shared in heartbeats, a dedicated shutdown message confirms an intentional shutdown (as opposed to a connection failure) and guarantees the orchestrator receives the latest sequence numbers. This lets the orchestrator clean up node state immediately rather than waiting for missed heartbeats.

Changes:
- Added `sendShutdownNotification` method to ControlPlane for notifying the orchestrator before shutdown
- Modified Stop() to send the notification if the node is connected and the context isn't cancelled
- Added test cases covering successful notification, skipped notifications (when disconnected or cancelled), and error handling

## Summary by CodeRabbit

- **New Features**
  - Introduced constants for shutdown notice request and response message types.
  - Added structures for handling shutdown notice requests and responses.
  - Implemented a method for processing shutdown notifications in the node manager.
  - Enhanced control plane to send shutdown notifications to the orchestrator.
  - Updated the compute manager to handle shutdown requests from nodes.
- **Bug Fixes**
  - Improved error handling and state management during shutdown operations.
- **Tests**
  - Added comprehensive tests for shutdown functionality in node manager and control plane.
  - Enhanced mock responder to simulate shutdown notifications for testing purposes.
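To illustrate the difference this buys, here is a conceptual, self-contained sketch under simplified assumptions (the `registry` type below is illustrative only, not the real node manager): without the notice, the orchestrator only learns of the departure once a heartbeat deadline lapses; with it, cleanup is immediate and the disconnect is recorded as intentional.

```go
package main

import (
	"fmt"
	"time"
)

// registry is a toy stand-in for the orchestrator's node-state tracking.
type registry struct {
	lastHeartbeat map[string]time.Time
	state         map[string]string
}

// onShutdownNotice marks the node disconnected right away and records the reason.
func (r *registry) onShutdownNotice(nodeID string) {
	r.state[nodeID] = "disconnected (graceful shutdown)"
	delete(r.lastHeartbeat, nodeID)
}

// sweep is the fallback path: nodes that silently vanish are only noticed
// after their heartbeat goes stale.
func (r *registry) sweep(now time.Time, timeout time.Duration) {
	for id, seen := range r.lastHeartbeat {
		if now.Sub(seen) > timeout {
			r.state[id] = "disconnected (timeout)"
			delete(r.lastHeartbeat, id)
		}
	}
}

func main() {
	r := &registry{
		lastHeartbeat: map[string]time.Time{"node-1": time.Now(), "node-2": time.Now()},
		state:         map[string]string{"node-1": "connected", "node-2": "connected"},
	}

	// node-1 shuts down cleanly and tells us; node-2 just disappears.
	r.onShutdownNotice("node-1")
	r.sweep(time.Now().Add(2*time.Minute), time.Minute)

	fmt.Println(r.state["node-1"]) // disconnected (graceful shutdown), recorded immediately
	fmt.Println(r.state["node-2"]) // disconnected (timeout), only after the deadline
}
```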
1 parent 86e8a96 commit f9751b5

File tree: 10 files changed, +516 -0 lines changed


pkg/models/messages/constants.go (+2)

@@ -13,8 +13,10 @@ const (
 	HandshakeRequestMessageType      = "transport.HandshakeRequest"
 	HeartbeatRequestMessageType      = "transport.HeartbeatRequest"
 	NodeInfoUpdateRequestMessageType = "transport.UpdateNodeInfoRequest"
+	ShutdownNoticeRequestMessageType = "transport.ShutdownNoticeRequest"
 
 	HandshakeResponseType      = "transport.HandshakeResponse"
 	HeartbeatResponseType      = "transport.HeartbeatResponse"
 	NodeInfoUpdateResponseType = "transport.UpdateNodeInfoResponse"
+	ShutdownNoticeResponseType = "transport.ShutdownNoticeResponse"
 )
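These message-type strings are what the transport layer keys dispatch on. A minimal, self-contained sketch of that pattern follows; the `dispatcher` type and its handler are hypothetical illustrations, not part of this commit or the real transport package.

```go
package main

import "fmt"

// Message-type constants mirroring the ones added in this commit.
const (
	ShutdownNoticeRequestMessageType = "transport.ShutdownNoticeRequest"
	ShutdownNoticeResponseType       = "transport.ShutdownNoticeResponse"
)

// dispatcher routes a payload to a handler keyed by its message type string.
type dispatcher struct {
	handlers map[string]func(payload []byte) error
}

func (d *dispatcher) register(msgType string, h func([]byte) error) {
	d.handlers[msgType] = h
}

func (d *dispatcher) dispatch(msgType string, payload []byte) error {
	h, ok := d.handlers[msgType]
	if !ok {
		return fmt.Errorf("no handler for message type %q", msgType)
	}
	return h(payload)
}

func main() {
	d := &dispatcher{handlers: map[string]func([]byte) error{}}
	d.register(ShutdownNoticeRequestMessageType, func(p []byte) error {
		fmt.Println("handling shutdown notice:", string(p))
		return nil
	})
	_ = d.dispatch(ShutdownNoticeRequestMessageType, []byte(`{"NodeID":"node-1"}`))
}
```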

pkg/models/messages/node.go → pkg/models/messages/control_plane.go (+12)

@@ -40,3 +40,15 @@ type UpdateNodeInfoResponse struct {
 	Accepted bool   `json:"accepted"`
 	Reason   string `json:"reason,omitempty"`
 }
+
+// ShutdownNoticeRequest tells the orchestrator that this node is shutting down
+type ShutdownNoticeRequest struct {
+	NodeID                 string `json:"NodeID"`
+	Reason                 string `json:"Reason,omitempty"`
+	LastOrchestratorSeqNum uint64 `json:"LastOrchestratorSeqNum"` // Last seq received from orchestrator
+}
+
+// ShutdownNoticeResponse sends any final instructions back to the shutting-down node
+type ShutdownNoticeResponse struct {
+	LastComputeSeqNum uint64 `json:"LastComputeSeqNum"` // Last seq received from compute node
+}
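The struct tags above determine the JSON wire shape of the two messages. A self-contained sketch of that round-trip, with the structs copied locally so it runs standalone (the real definitions live in pkg/models/messages):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Local copies of the message structs added in this commit, for illustration only.
type ShutdownNoticeRequest struct {
	NodeID                 string `json:"NodeID"`
	Reason                 string `json:"Reason,omitempty"`
	LastOrchestratorSeqNum uint64 `json:"LastOrchestratorSeqNum"`
}

type ShutdownNoticeResponse struct {
	LastComputeSeqNum uint64 `json:"LastComputeSeqNum"`
}

func main() {
	req := ShutdownNoticeRequest{
		NodeID:                 "node-1",
		Reason:                 "maintenance",
		LastOrchestratorSeqNum: 42,
	}
	b, _ := json.Marshal(req)
	fmt.Println(string(b)) // {"NodeID":"node-1","Reason":"maintenance","LastOrchestratorSeqNum":42}

	var resp ShutdownNoticeResponse
	_ = json.Unmarshal([]byte(`{"LastComputeSeqNum":24}`), &resp)
	fmt.Println(resp.LastComputeSeqNum) // 24
}
```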

pkg/orchestrator/nodes/manager.go (+56)

@@ -582,6 +582,62 @@ func (n *nodesManager) Heartbeat(
 		return messages.HeartbeatResponse{}, NewErrConcurrentModification()
 	}
 
+// ShutdownNotice processes a shutdown notification from a node and updates its state.
+// It updates:
+// - Final sequence numbers
+// - Connection state to disconnected
+// - Preserves the sequence numbers in persistent storage
+//
+// Returns ShutdownNoticeResponse with the last sequence number processed from that node.
+func (n *nodesManager) ShutdownNotice(
+	ctx context.Context, request ExtendedShutdownNoticeRequest) (messages.ShutdownNoticeResponse, error) {
+	// Get existing live state
+	existingEntry, ok := n.liveState.Load(request.NodeID)
+	if !ok {
+		return messages.ShutdownNoticeResponse{}, NewErrHandshakeRequired(request.NodeID)
+	}
+
+	existing := existingEntry.(*trackedLiveState)
+	if existing.connectionState.Status != models.NodeStates.CONNECTED {
+		return messages.ShutdownNoticeResponse{}, NewErrHandshakeRequired(request.NodeID)
+	}
+
+	// Update connection state with final sequence numbers
+	updated := existing.connectionState
+	updated.Status = models.NodeStates.DISCONNECTED
+	updated.DisconnectedSince = n.clock.Now().UTC()
+	n.updateSequenceNumbers(&updated, request.LastOrchestratorSeqNum, request.LastComputeSeqNum)
+	updated.LastError = "graceful shutdown"
+
+	// Attempt atomic update
+	if !n.liveState.CompareAndSwap(request.NodeID, existingEntry, &trackedLiveState{
+		connectionState:   updated,
+		availableCapacity: models.Resources{}, // Clear capacity since node is shutting down
+		queueUsedCapacity: models.Resources{},
+	}) {
+		return messages.ShutdownNoticeResponse{}, NewErrConcurrentModification()
+	}
+
+	log.Info().
+		Str("node", request.NodeID).
+		Str("reason", request.Reason).
+		Uint64("lastOrchestratorSeq", updated.LastOrchestratorSeqNum).
+		Uint64("lastComputeSeq", updated.LastComputeSeqNum).
+		Msg("Node shutdown notice received")
+
+	// Notify about state change
+	n.notifyConnectionStateChange(NodeConnectionEvent{
+		NodeID:    request.NodeID,
+		Previous:  models.NodeStates.CONNECTED,
+		Current:   models.NodeStates.DISCONNECTED,
+		Timestamp: updated.DisconnectedSince,
+	})
+
+	return messages.ShutdownNoticeResponse{
+		LastComputeSeqNum: updated.LastComputeSeqNum,
+	}, nil
+}
+
 // ApproveNode approves a node for cluster participation.
 // The node must be in PENDING state. The operation updates
 // both persistent and live state.
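The concurrency guarantee in ShutdownNotice comes from the Load + CompareAndSwap pair on the live-state map: of several racing shutdown notices for the same node, only one swap succeeds and the rest fail. A stripped-down sketch of that pattern using sync.Map, with a simplified state type rather than the manager's actual trackedLiveState:

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

type state struct{ status string }

var liveState sync.Map

// shutdown flips a node from "connected" to "disconnected" exactly once,
// mirroring the Load/CompareAndSwap shape used by ShutdownNotice.
func shutdown(nodeID string) error {
	entry, ok := liveState.Load(nodeID)
	if !ok || entry.(*state).status != "connected" {
		return errors.New("handshake required")
	}
	if !liveState.CompareAndSwap(nodeID, entry, &state{status: "disconnected"}) {
		return errors.New("concurrent modification")
	}
	return nil
}

func main() {
	liveState.Store("node-1", &state{status: "connected"})

	var wg sync.WaitGroup
	var mu sync.Mutex
	succeeded := 0
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if shutdown("node-1") == nil {
				mu.Lock()
				succeeded++
				mu.Unlock()
			}
		}()
	}
	wg.Wait()
	fmt.Println("successful shutdowns:", succeeded) // always 1
}
```

This is the same property the TestConcurrentShutdown case below asserts: ten concurrent notices, exactly one success.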

pkg/orchestrator/nodes/manager_test.go (+179)

@@ -6,6 +6,7 @@ import (
 	"context"
 	"fmt"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -1072,6 +1073,184 @@ func (s *NodeManagerTestSuite) TestPersistenceWithContextCancellation() {
 	assert.Equal(s.T(), lastComputeSeqNum, state.ConnectionState.LastComputeSeqNum)
 }
 
+func (s *NodeManagerTestSuite) TestShutdownNotice() {
+	lastOrchestratorSeqNum := uint64(42)
+	lastComputeSeqNum := uint64(24)
+
+	// Test cases covering different shutdown scenarios
+	tests := []struct {
+		name           string
+		setupNode      bool // whether to setup node with handshake first
+		disconnect     bool // whether to disconnect node before shutdown
+		reason         string
+		expectedError  string
+		validateState  func(*testing.T, models.NodeState)
+		validateEvents func(*testing.T, []nodes.NodeConnectionEvent)
+	}{
+		{
+			name:      "successful shutdown",
+			setupNode: true,
+			reason:    "maintenance",
+			validateState: func(t *testing.T, state models.NodeState) {
+				assert.Equal(t, models.NodeStates.DISCONNECTED, state.ConnectionState.Status)
+				assert.Equal(t, "graceful shutdown", state.ConnectionState.LastError)
+				assert.False(t, state.ConnectionState.DisconnectedSince.IsZero())
+			},
+			validateEvents: func(t *testing.T, events []nodes.NodeConnectionEvent) {
+				require.Len(t, events, 2) // connect + disconnect
+				assert.Equal(t, models.NodeStates.CONNECTED, events[1].Previous)
+				assert.Equal(t, models.NodeStates.DISCONNECTED, events[1].Current)
+			},
+		},
+		{
+			name:          "shutdown without handshake",
+			setupNode:     false,
+			reason:        "testing",
+			expectedError: "handshake required",
+		},
+		{
+			name:          "shutdown already disconnected node",
+			setupNode:     true,
+			disconnect:    true,
+			reason:        "testing",
+			expectedError: "handshake required",
+			validateState: func(t *testing.T, state models.NodeState) {
+				assert.Equal(t, models.NodeStates.DISCONNECTED, state.ConnectionState.Status)
+			},
+		},
+		{
+			name:      "shutdown preserves sequence numbers",
+			setupNode: true,
+			reason:    "testing",
+			validateState: func(t *testing.T, state models.NodeState) {
+				assert.Equal(t, lastOrchestratorSeqNum, state.ConnectionState.LastOrchestratorSeqNum)
+				assert.Equal(t, lastComputeSeqNum, state.ConnectionState.LastComputeSeqNum)
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		s.Run(tt.name, func() {
+			// Track connection events
+			var events []nodes.NodeConnectionEvent
+			eventsMu := sync.Mutex{}
+			s.manager.OnConnectionStateChange(func(event nodes.NodeConnectionEvent) {
+				eventsMu.Lock()
+				events = append(events, event)
+				eventsMu.Unlock()
+			})
+
+			nodeInfo := s.createNodeInfo("shutdown-test")
+
+			// Setup node if required
+			if tt.setupNode {
+				_, err := s.manager.Handshake(s.ctx, messages.HandshakeRequest{NodeInfo: nodeInfo})
+				s.Require().NoError(err)
+
+				// Update sequence numbers
+				_, err = s.manager.Heartbeat(s.ctx, nodes.ExtendedHeartbeatRequest{
+					HeartbeatRequest: messages.HeartbeatRequest{
+						NodeID:                 nodeInfo.ID(),
+						LastOrchestratorSeqNum: 42,
+					},
+					LastComputeSeqNum: 24,
+				})
+				s.Require().NoError(err)
+
+				if tt.disconnect {
+					s.clock.Add(s.disconnected + time.Second)
+					s.Eventually(func() bool {
+						state, err := s.manager.Get(s.ctx, nodeInfo.ID())
+						s.Require().NoError(err)
+						return state.ConnectionState.Status == models.NodeStates.DISCONNECTED
+					}, 500*time.Millisecond, 20*time.Millisecond)
+				}
+			}
+
+			// Send shutdown notice
+			req := nodes.ExtendedShutdownNoticeRequest{
+				ShutdownNoticeRequest: messages.ShutdownNoticeRequest{
+					NodeID:                 nodeInfo.ID(),
+					Reason:                 tt.reason,
+					LastOrchestratorSeqNum: lastOrchestratorSeqNum,
+				},
+				LastComputeSeqNum: lastComputeSeqNum,
+			}
+
+			_, err := s.manager.ShutdownNotice(s.ctx, req)
+			if tt.expectedError != "" {
+				s.Assert().Error(err)
+				s.Assert().Contains(err.Error(), tt.expectedError)
+				return
+			}
+			s.Assert().NoError(err)
+
+			// Validate final state
+			state, err := s.manager.Get(s.ctx, nodeInfo.ID())
+			s.Require().NoError(err)
+
+			if tt.validateState != nil {
+				tt.validateState(s.T(), state)
+			}
+
+			if tt.validateEvents != nil {
+				eventsMu.Lock()
+				tt.validateEvents(s.T(), events)
+				eventsMu.Unlock()
+			}
+		})
+	}
+}
+
+func (s *NodeManagerTestSuite) TestConcurrentShutdown() {
+	nodeInfo := s.createNodeInfo("concurrent-shutdown")
+	lastOrchestratorSeqNum := uint64(42)
+	lastComputeSeqNum := uint64(24)
+
+	// Connect node
+	_, err := s.manager.Handshake(s.ctx, messages.HandshakeRequest{NodeInfo: nodeInfo})
+	s.Require().NoError(err)
+
+	// Track successful shutdowns
+	var wg sync.WaitGroup
+	successCount := int32(0)
+	const numConcurrent = 10
+
+	for i := 0; i < numConcurrent; i++ {
+		wg.Add(1)
+		go func(attempt int) {
+			defer wg.Done()
+
+			req := nodes.ExtendedShutdownNoticeRequest{
+				ShutdownNoticeRequest: messages.ShutdownNoticeRequest{
+					NodeID:                 nodeInfo.ID(),
+					Reason:                 fmt.Sprintf("concurrent shutdown %d", attempt),
+					LastOrchestratorSeqNum: lastOrchestratorSeqNum,
+				},
+				LastComputeSeqNum: lastComputeSeqNum,
+			}

+			_, err := s.manager.ShutdownNotice(s.ctx, req)
+			if err == nil {
+				atomic.AddInt32(&successCount, 1)
+			}
+		}(i)
+	}
+
+	wg.Wait()
+
+	// Exactly one shutdown should succeed
+	s.Assert().Equal(int32(1), successCount)
+
+	// Verify final state
+	state, err := s.manager.Get(s.ctx, nodeInfo.ID())
+	s.Require().NoError(err)
+	s.Assert().Equal(models.NodeStates.DISCONNECTED, state.ConnectionState.Status)
+	s.Assert().Equal("graceful shutdown", state.ConnectionState.LastError)
+	s.Assert().Equal(lastOrchestratorSeqNum, state.ConnectionState.LastOrchestratorSeqNum)
+	s.Assert().Equal(lastComputeSeqNum, state.ConnectionState.LastComputeSeqNum)
+}
+
 type mockNodeInfoProvider struct {
 	info models.NodeInfo
 }

pkg/orchestrator/nodes/types.go (+12)

@@ -64,6 +64,10 @@ type Manager interface {
 	// The node must be registered and not rejected.
 	UpdateNodeInfo(ctx context.Context, request messages.UpdateNodeInfoRequest) (messages.UpdateNodeInfoResponse, error)
 
+	// ShutdownNotice handles a node's graceful shutdown notification.
+	// It updates sequence numbers and marks the node as cleanly disconnected.
+	ShutdownNotice(ctx context.Context, request ExtendedShutdownNoticeRequest) (messages.ShutdownNoticeResponse, error)
+
 	// Heartbeat processes a node's heartbeat message and updates its state.
 	// It returns the last known sequence numbers for synchronization.
 	Heartbeat(ctx context.Context, request ExtendedHeartbeatRequest) (messages.HeartbeatResponse, error)

@@ -136,3 +140,11 @@ type ExtendedHeartbeatRequest struct {
 	// LastComputeSeqNum is the last processed compute message sequence
 	LastComputeSeqNum uint64
 }
+
+// ExtendedShutdownNoticeRequest represents a shutdown message with additional metadata.
+type ExtendedShutdownNoticeRequest struct {
+	messages.ShutdownNoticeRequest
+
+	// LastComputeSeqNum is the last processed compute message sequence
+	LastComputeSeqNum uint64
+}
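Because ExtendedShutdownNoticeRequest embeds messages.ShutdownNoticeRequest, the wire-level fields are promoted onto the wrapper while the extra orchestrator-side metadata sits alongside them, mirroring the existing ExtendedHeartbeatRequest shape. A tiny self-contained sketch of that embedding pattern, using local stand-in types rather than the real package:

```go
package main

import "fmt"

// Stand-ins for the wire message and its orchestrator-side wrapper.
type ShutdownNoticeRequest struct {
	NodeID                 string
	LastOrchestratorSeqNum uint64
}

type ExtendedShutdownNoticeRequest struct {
	ShutdownNoticeRequest        // embedded: its fields are promoted
	LastComputeSeqNum     uint64 // extra metadata carried alongside the wire message
}

func main() {
	req := ExtendedShutdownNoticeRequest{
		ShutdownNoticeRequest: ShutdownNoticeRequest{NodeID: "node-1", LastOrchestratorSeqNum: 42},
		LastComputeSeqNum:     24,
	}
	// Promoted field access: req.NodeID instead of req.ShutdownNoticeRequest.NodeID.
	fmt.Println(req.NodeID, req.LastOrchestratorSeqNum, req.LastComputeSeqNum)
}
```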

pkg/transport/nclprotocol/compute/controlplane.go (+27)

@@ -186,6 +186,25 @@ func (cp *ControlPlane) updateNodeInfo(ctx context.Context) error {
 	return nil
 }
 
+// sendShutdownNotification informs the orchestrator that this node is gracefully shutting down.
+// It includes the last message sequence numbers to help prevent message duplication on reconnect.
+// The notification is best-effort - we don't wait for or retry the response since we're shutting down.
+func (cp *ControlPlane) sendShutdownNotification(ctx context.Context) error {
+	ctx, cancel := context.WithTimeout(ctx, cp.cfg.RequestTimeout)
+	defer cancel()
+
+	msg := envelope.NewMessage(messages.ShutdownNoticeRequest{
+		NodeID:                 cp.cfg.NodeID,
+		LastOrchestratorSeqNum: cp.incomingSeqTracker.GetLastSeqNum(),
+	}).WithMetadataValue(envelope.KeyMessageType, messages.ShutdownNoticeRequestMessageType)
+
+	_, err := cp.requester.Request(ctx, ncl.NewPublishRequest(msg))
+	if err != nil {
+		return fmt.Errorf("shutdown notification failed: %w", err)
+	}
+	return nil
+}
+
 // checkpointProgress saves the latest processed message sequence number if it has
 // changed since the last checkpoint. This allows for resuming message processing
 // from the last known point after node restarts.

@@ -211,10 +230,18 @@ func (cp *ControlPlane) Stop(ctx context.Context) error {
 		return nil
 	}
 
+	// Prevent new operations
 	cp.running = false
 	close(cp.stopCh)
 	cp.mu.Unlock()
 
+	// Send shutdown notification before closing
+	if cp.healthTracker.GetHealth().CurrentState == nclprotocol.Connected && ctx.Err() == nil {
+		if err := cp.sendShutdownNotification(ctx); err != nil {
+			log.Error().Err(err).Msg("Failed to send shutdown notification")
+		}
+	}
+
 	// Wait for graceful shutdown
 	done := make(chan struct{})
 	go func() {
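The Stop path treats the notification as best-effort: it is skipped when the node isn't connected or the caller's context is already cancelled, bounded by a request timeout, and a failure is only logged rather than aborting shutdown. A simplified, self-contained sketch of that shape, where `notify` is a stand-in for sendShutdownNotification rather than the real call:

```go
package main

import (
	"context"
	"fmt"
	"log"
	"time"
)

// notify stands in for sendShutdownNotification: one request/response attempt,
// bounded by the context deadline.
func notify(ctx context.Context) error {
	select {
	case <-time.After(50 * time.Millisecond): // simulate the request round-trip
		return nil
	case <-ctx.Done():
		return fmt.Errorf("shutdown notification failed: %w", ctx.Err())
	}
}

// stop mirrors the shape of the Stop() change: attempt the notice only if the
// node is connected and the caller's context is still live, bound it with a
// timeout, and log (rather than fail shutdown) if it doesn't get through.
func stop(ctx context.Context, connected bool) {
	if connected && ctx.Err() == nil {
		reqCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
		defer cancel()
		if err := notify(reqCtx); err != nil {
			log.Printf("failed to send shutdown notification: %v", err)
		}
	}
	// ...continue with the rest of the shutdown regardless of the outcome.
}

func main() {
	stop(context.Background(), true)
	fmt.Println("stopped")
}
```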
