|
6 | 6 | "context"
|
7 | 7 | "fmt"
|
8 | 8 | "sync"
|
| 9 | + "sync/atomic" |
9 | 10 | "testing"
|
10 | 11 | "time"
|
11 | 12 |
|
@@ -1072,6 +1073,184 @@ func (s *NodeManagerTestSuite) TestPersistenceWithContextCancellation() {
|
1072 | 1073 | assert.Equal(s.T(), lastComputeSeqNum, state.ConnectionState.LastComputeSeqNum)
|
1073 | 1074 | }
|
1074 | 1075 |
|
| 1076 | +func (s *NodeManagerTestSuite) TestShutdownNotice() { |
| 1077 | + lastOrchestratorSeqNum := uint64(42) |
| 1078 | + lastComputeSeqNum := uint64(24) |
| 1079 | + |
| 1080 | + // Test cases covering different shutdown scenarios |
| 1081 | + tests := []struct { |
| 1082 | + name string |
| 1083 | + setupNode bool // whether to setup node with handshake first |
| 1084 | + disconnect bool // whether to disconnect node before shutdown |
| 1085 | + reason string |
| 1086 | + expectedError string |
| 1087 | + validateState func(*testing.T, models.NodeState) |
| 1088 | + validateEvents func(*testing.T, []nodes.NodeConnectionEvent) |
| 1089 | + }{ |
| 1090 | + { |
| 1091 | + name: "successful shutdown", |
| 1092 | + setupNode: true, |
| 1093 | + reason: "maintenance", |
| 1094 | + validateState: func(t *testing.T, state models.NodeState) { |
| 1095 | + assert.Equal(t, models.NodeStates.DISCONNECTED, state.ConnectionState.Status) |
| 1096 | + assert.Equal(t, "graceful shutdown", state.ConnectionState.LastError) |
| 1097 | + assert.False(t, state.ConnectionState.DisconnectedSince.IsZero()) |
| 1098 | + }, |
| 1099 | + validateEvents: func(t *testing.T, events []nodes.NodeConnectionEvent) { |
| 1100 | + require.Len(t, events, 2) // connect + disconnect |
| 1101 | + assert.Equal(t, models.NodeStates.CONNECTED, events[1].Previous) |
| 1102 | + assert.Equal(t, models.NodeStates.DISCONNECTED, events[1].Current) |
| 1103 | + }, |
| 1104 | + }, |
| 1105 | + { |
| 1106 | + name: "shutdown without handshake", |
| 1107 | + setupNode: false, |
| 1108 | + reason: "testing", |
| 1109 | + expectedError: "handshake required", |
| 1110 | + }, |
| 1111 | + { |
| 1112 | + name: "shutdown already disconnected node", |
| 1113 | + setupNode: true, |
| 1114 | + disconnect: true, |
| 1115 | + reason: "testing", |
| 1116 | + expectedError: "handshake required", |
| 1117 | + validateState: func(t *testing.T, state models.NodeState) { |
| 1118 | + assert.Equal(t, models.NodeStates.DISCONNECTED, state.ConnectionState.Status) |
| 1119 | + }, |
| 1120 | + }, |
| 1121 | + { |
| 1122 | + name: "shutdown preserves sequence numbers", |
| 1123 | + setupNode: true, |
| 1124 | + reason: "testing", |
| 1125 | + validateState: func(t *testing.T, state models.NodeState) { |
| 1126 | + assert.Equal(t, lastOrchestratorSeqNum, state.ConnectionState.LastOrchestratorSeqNum) |
| 1127 | + assert.Equal(t, lastComputeSeqNum, state.ConnectionState.LastComputeSeqNum) |
| 1128 | + }, |
| 1129 | + }, |
| 1130 | + } |
| 1131 | + |
| 1132 | + for _, tt := range tests { |
| 1133 | + s.Run(tt.name, func() { |
| 1134 | + // Track connection events |
| 1135 | + var events []nodes.NodeConnectionEvent |
| 1136 | + eventsMu := sync.Mutex{} |
| 1137 | + s.manager.OnConnectionStateChange(func(event nodes.NodeConnectionEvent) { |
| 1138 | + eventsMu.Lock() |
| 1139 | + events = append(events, event) |
| 1140 | + eventsMu.Unlock() |
| 1141 | + }) |
| 1142 | + |
| 1143 | + nodeInfo := s.createNodeInfo("shutdown-test") |
| 1144 | + |
| 1145 | + // Setup node if required |
| 1146 | + if tt.setupNode { |
| 1147 | + _, err := s.manager.Handshake(s.ctx, messages.HandshakeRequest{NodeInfo: nodeInfo}) |
| 1148 | + s.Require().NoError(err) |
| 1149 | + |
| 1150 | + // Update sequence numbers |
| 1151 | + _, err = s.manager.Heartbeat(s.ctx, nodes.ExtendedHeartbeatRequest{ |
| 1152 | + HeartbeatRequest: messages.HeartbeatRequest{ |
| 1153 | + NodeID: nodeInfo.ID(), |
| 1154 | + LastOrchestratorSeqNum: 42, |
| 1155 | + }, |
| 1156 | + LastComputeSeqNum: 24, |
| 1157 | + }) |
| 1158 | + s.Require().NoError(err) |
| 1159 | + |
| 1160 | + if tt.disconnect { |
| 1161 | + s.clock.Add(s.disconnected + time.Second) |
| 1162 | + s.Eventually(func() bool { |
| 1163 | + state, err := s.manager.Get(s.ctx, nodeInfo.ID()) |
| 1164 | + s.Require().NoError(err) |
| 1165 | + return state.ConnectionState.Status == models.NodeStates.DISCONNECTED |
| 1166 | + }, 500*time.Millisecond, 20*time.Millisecond) |
| 1167 | + } |
| 1168 | + } |
| 1169 | + |
| 1170 | + // Send shutdown notice |
| 1171 | + req := nodes.ExtendedShutdownNoticeRequest{ |
| 1172 | + ShutdownNoticeRequest: messages.ShutdownNoticeRequest{ |
| 1173 | + NodeID: nodeInfo.ID(), |
| 1174 | + Reason: tt.reason, |
| 1175 | + LastOrchestratorSeqNum: lastOrchestratorSeqNum, |
| 1176 | + }, |
| 1177 | + LastComputeSeqNum: lastComputeSeqNum, |
| 1178 | + } |
| 1179 | + |
| 1180 | + _, err := s.manager.ShutdownNotice(s.ctx, req) |
| 1181 | + if tt.expectedError != "" { |
| 1182 | + s.Assert().Error(err) |
| 1183 | + s.Assert().Contains(err.Error(), tt.expectedError) |
| 1184 | + return |
| 1185 | + } |
| 1186 | + s.Assert().NoError(err) |
| 1187 | + |
| 1188 | + // Validate final state |
| 1189 | + state, err := s.manager.Get(s.ctx, nodeInfo.ID()) |
| 1190 | + s.Require().NoError(err) |
| 1191 | + |
| 1192 | + if tt.validateState != nil { |
| 1193 | + tt.validateState(s.T(), state) |
| 1194 | + } |
| 1195 | + |
| 1196 | + if tt.validateEvents != nil { |
| 1197 | + eventsMu.Lock() |
| 1198 | + tt.validateEvents(s.T(), events) |
| 1199 | + eventsMu.Unlock() |
| 1200 | + } |
| 1201 | + }) |
| 1202 | + } |
| 1203 | +} |
| 1204 | + |
| 1205 | +func (s *NodeManagerTestSuite) TestConcurrentShutdown() { |
| 1206 | + nodeInfo := s.createNodeInfo("concurrent-shutdown") |
| 1207 | + lastOrchestratorSeqNum := uint64(42) |
| 1208 | + lastComputeSeqNum := uint64(24) |
| 1209 | + |
| 1210 | + // Connect node |
| 1211 | + _, err := s.manager.Handshake(s.ctx, messages.HandshakeRequest{NodeInfo: nodeInfo}) |
| 1212 | + s.Require().NoError(err) |
| 1213 | + |
| 1214 | + // Track successful shutdowns |
| 1215 | + var wg sync.WaitGroup |
| 1216 | + successCount := int32(0) |
| 1217 | + const numConcurrent = 10 |
| 1218 | + |
| 1219 | + for i := 0; i < numConcurrent; i++ { |
| 1220 | + wg.Add(1) |
| 1221 | + go func(attempt int) { |
| 1222 | + defer wg.Done() |
| 1223 | + |
| 1224 | + req := nodes.ExtendedShutdownNoticeRequest{ |
| 1225 | + ShutdownNoticeRequest: messages.ShutdownNoticeRequest{ |
| 1226 | + NodeID: nodeInfo.ID(), |
| 1227 | + Reason: fmt.Sprintf("concurrent shutdown %d", attempt), |
| 1228 | + LastOrchestratorSeqNum: lastOrchestratorSeqNum, |
| 1229 | + }, |
| 1230 | + LastComputeSeqNum: lastComputeSeqNum, |
| 1231 | + } |
| 1232 | + |
| 1233 | + _, err := s.manager.ShutdownNotice(s.ctx, req) |
| 1234 | + if err == nil { |
| 1235 | + atomic.AddInt32(&successCount, 1) |
| 1236 | + } |
| 1237 | + }(i) |
| 1238 | + } |
| 1239 | + |
| 1240 | + wg.Wait() |
| 1241 | + |
| 1242 | + // Exactly one shutdown should succeed |
| 1243 | + s.Assert().Equal(int32(1), successCount) |
| 1244 | + |
| 1245 | + // Verify final state |
| 1246 | + state, err := s.manager.Get(s.ctx, nodeInfo.ID()) |
| 1247 | + s.Require().NoError(err) |
| 1248 | + s.Assert().Equal(models.NodeStates.DISCONNECTED, state.ConnectionState.Status) |
| 1249 | + s.Assert().Equal("graceful shutdown", state.ConnectionState.LastError) |
| 1250 | + s.Assert().Equal(lastOrchestratorSeqNum, state.ConnectionState.LastOrchestratorSeqNum) |
| 1251 | + s.Assert().Equal(lastComputeSeqNum, state.ConnectionState.LastComputeSeqNum) |
| 1252 | +} |
| 1253 | + |
// mockNodeInfoProvider holds a single models.NodeInfo value in its info field.
// NOTE(review): presumably a test double that serves this fixed NodeInfo; its
// methods are not visible in this view — confirm against the rest of the file.
1075 | 1254 | type mockNodeInfoProvider struct {
|
1076 | 1255 | info models.NodeInfo
|
1077 | 1256 | }
|
|
0 commit comments