Commit 850a775

Merge branch 'main' into eng-448-faster-reconnect-on-handshake-required-response

2 parents 69b1b42 + 38572c1

6 files changed, +359 -38 lines changed

pkg/transport/nclprotocol/compute/manager.go (+17 -7)

@@ -2,6 +2,7 @@ package compute
 
 import (
     "context"
+    "errors"
     "fmt"
     "sync"
     "time"
@@ -144,8 +145,7 @@ func (cm *ConnectionManager) Close(ctx context.Context) error {
 
     select {
     case <-done:
-        cm.cleanup(ctx)
-        return nil
+        return cm.cleanup(ctx)
     case <-ctx.Done():
         return ctx.Err()
     }
@@ -155,27 +155,28 @@ func (cm *ConnectionManager) Close(ctx context.Context) error {
 // 1. Stops the data plane
 // 2. Cleans up control plane
 // 3. Closes NATS connection
-func (cm *ConnectionManager) cleanup(ctx context.Context) {
+func (cm *ConnectionManager) cleanup(ctx context.Context) error {
+    var errs error
     // Clean up data plane subscriber
     if cm.subscriber != nil {
         if err := cm.subscriber.Close(ctx); err != nil {
-            log.Error().Err(err).Msg("Failed to close subscriber")
+            errs = errors.Join(errs, fmt.Errorf("failed to close subscriber: %w", err))
         }
         cm.subscriber = nil
     }
 
     // Clean up data plane
     if cm.dataPlane != nil {
         if err := cm.dataPlane.Stop(ctx); err != nil {
-            log.Error().Err(err).Msg("Failed to stop data plane")
+            errs = errors.Join(errs, fmt.Errorf("failed to stop data plane: %w", err))
         }
         cm.dataPlane = nil
     }
 
     // Clean up control plane
     if cm.controlPlane != nil {
         if err := cm.controlPlane.Stop(ctx); err != nil {
-            log.Error().Err(err).Msg("Failed to stop control plane")
+            errs = errors.Join(errs, fmt.Errorf("failed to stop control plane: %w", err))
         }
         cm.controlPlane = nil
     }
@@ -185,6 +186,8 @@ func (cm *ConnectionManager) cleanup(ctx context.Context) {
         cm.natsConn.Close()
         cm.natsConn = nil
     }
+
+    return errs
 }
 
 // connect attempts to establish a connection to the orchestrator. It follows these steps:
@@ -202,10 +205,17 @@ func (cm *ConnectionManager) connect(ctx context.Context) error {
     log.Info().Str("node_id", cm.config.NodeID).Msg("Attempting to establish connection")
     cm.transitionState(nclprotocol.Connecting, nil)
 
+    // cleanup existing components before reconnecting
+    if err := cm.cleanup(ctx); err != nil {
+        return fmt.Errorf("failed to cleanup existing components: %w", err)
+    }
+
     var err error
     defer func() {
         if err != nil {
-            cm.cleanup(ctx)
+            if cleanupErr := cm.cleanup(ctx); cleanupErr != nil {
+                log.Warn().Err(cleanupErr).Msg("failed to cleanup after connection error")
+            }
             cm.transitionState(nclprotocol.Disconnected, err)
         }
     }()
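The cleanup rework above stops logging each failure and instead aggregates them with errors.Join from the standard library, so Close and connect can surface a single combined error. The following is a minimal, self-contained sketch of that aggregation pattern; the closeAll helper and step names are illustrative only, not code from this repository.

package main

import (
    "errors"
    "fmt"
)

type step struct {
    name string
    fn   func() error
}

// closeAll mirrors the cleanup pattern above: run every step in order,
// join any failures, and return them as one error (nil if all succeed).
func closeAll(steps []step) error {
    var errs error
    for _, s := range steps {
        if err := s.fn(); err != nil {
            errs = errors.Join(errs, fmt.Errorf("failed to close %s: %w", s.name, err))
        }
    }
    return errs
}

func main() {
    err := closeAll([]step{
        {"subscriber", func() error { return nil }},
        {"data plane", func() error { return errors.New("timeout") }},
        {"control plane", func() error { return errors.New("not connected") }},
    })
    // errors.Join preserves the individual causes, so errors.Is/As still work.
    fmt.Println(err)
}

Because errors.Join discards nil arguments, errs stays nil when every step succeeds, which keeps the happy path of Close unchanged.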

pkg/transport/nclprotocol/dispatcher/dispatcher.go (+12 -1)

@@ -92,6 +92,11 @@ func (d *Dispatcher) Start(ctx context.Context) error {
         d.mu.Unlock()
         return fmt.Errorf("dispatcher already running")
     }
+
+    // Reset state before starting
+    d.state.reset()
+    d.recovery.reset()
+
     d.running = true
     d.mu.Unlock()
 
@@ -121,7 +126,13 @@ func (d *Dispatcher) Stop(ctx context.Context) error {
     d.running = false
     d.mu.Unlock()
 
+    // Signal recovery to stop
+    d.recovery.stop()
+
+    // Stop background goroutines
     close(d.stopCh)
+
+    // Stop watcher after recovery to avoid new messages
     d.watcher.Stop(ctx)
 
     // Wait with timeout for all goroutines
@@ -212,7 +223,7 @@ func (d *Dispatcher) checkStalledMessages(ctx context.Context) {
                 Uint64("eventSeq", msg.eventSeqNum).
                 Time("publishTime", msg.publishTime).
                 Msg("Message publish stalled")
-            // Could implement recovery logic here
+            // TODO: Could implement recovery logic here
         }
     }
 }
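Stop now tears the dispatcher down in a fixed order: stop recovery, close stopCh to signal background goroutines, stop the watcher, then wait. The sketch below shows the underlying close-as-broadcast plus sync.WaitGroup idiom in isolation; the worker and ticker names are illustrative and not taken from the dispatcher code.

package main

import (
    "fmt"
    "sync"
    "time"
)

// worker exits as soon as stopCh is closed. Closing a channel is a broadcast,
// so any number of goroutines can select on the same stop signal.
func worker(id int, stopCh <-chan struct{}, wg *sync.WaitGroup) {
    defer wg.Done()
    ticker := time.NewTicker(50 * time.Millisecond)
    defer ticker.Stop()
    for {
        select {
        case <-stopCh:
            fmt.Printf("worker %d stopping\n", id)
            return
        case <-ticker.C:
            // periodic work would go here
        }
    }
}

func main() {
    stopCh := make(chan struct{})
    var wg sync.WaitGroup
    for i := 0; i < 3; i++ {
        wg.Add(1)
        go worker(i, stopCh, &wg)
    }
    time.Sleep(200 * time.Millisecond)
    close(stopCh) // signal every worker at once
    wg.Wait()     // then wait for them to drain, as Dispatcher.Stop does with a timeout
}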

pkg/transport/nclprotocol/dispatcher/dispatcher_e2e_test.go (+38)

@@ -254,6 +254,44 @@ func (s *DispatcherE2ETestSuite) TestCheckpointOnStop() {
     s.Equal(uint64(5), checkpoint)
 }
 
+func (s *DispatcherE2ETestSuite) TestRestartResetsState() {
+    config := dispatcher.Config{
+        CheckpointInterval: 100 * time.Millisecond,
+    }
+    d := s.startDispatcher(config)
+
+    // Store and process first batch of events (1-5)
+    s.storeEvents(5)
+    s.Eventually(func() bool {
+        return d.State().LastAckedSeqNum == 5
+    }, time.Second, 10*time.Millisecond)
+
+    // Stop dispatcher
+    err := d.Stop(s.ctx)
+    s.Require().NoError(err)
+
+    // Clear received messages
+    s.received = s.received[:0]
+
+    // Start new dispatcher - should start fresh but use checkpoint
+    d = s.startDispatcher(config)
+
+    // Store next batch of events (6-8)
+    s.storeEvent(6)
+    s.storeEvent(7)
+    s.storeEvent(8)
+
+    // Should only get new events since state was reset
+    s.Eventually(func() bool {
+        return d.State().LastAckedSeqNum == 8
+    }, time.Second, 10*time.Millisecond)
+
+    // Verify messages
+    for i, msg := range s.received {
+        s.verifyMsg(msg, i+6)
+    }
+}
+
 func (s *DispatcherE2ETestSuite) storeEvent(index int) {
     err := s.store.StoreEvent(s.ctx, watcher.StoreEventRequest{
         Operation: watcher.OperationCreate,

pkg/transport/nclprotocol/dispatcher/dispatcher_test.go (+49)

@@ -172,6 +172,55 @@ func (suite *DispatcherTestSuite) TestStopNonStarted() {
     suite.NoError(err)
 }
 
+func (suite *DispatcherTestSuite) TestStartResetsState() {
+    // Setup
+    suite.watcher.EXPECT().SetHandler(gomock.Any()).Return(nil)
+    suite.watcher.EXPECT().Start(gomock.Any()).Return(nil)
+
+    d, err := New(suite.publisher, suite.watcher, suite.creator, suite.config)
+    suite.Require().NoError(err)
+
+    // Start first time
+    err = d.Start(suite.ctx)
+    suite.NoError(err)
+
+    // Stop
+    suite.watcher.EXPECT().Stop(gomock.Any())
+    err = d.Stop(suite.ctx)
+    suite.NoError(err)
+
+    // Start again - should reset state
+    suite.watcher.EXPECT().Start(gomock.Any()).Return(nil)
+    err = d.Start(suite.ctx)
+    suite.NoError(err)
+}
+
+func (suite *DispatcherTestSuite) TestStopSequence() {
+    // Setup
+    suite.watcher.EXPECT().SetHandler(gomock.Any()).Return(nil)
+    suite.watcher.EXPECT().Start(gomock.Any()).Return(nil)
+
+    d, err := New(suite.publisher, suite.watcher, suite.creator, suite.config)
+    suite.Require().NoError(err)
+
+    // Start
+    err = d.Start(suite.ctx)
+    suite.NoError(err)
+
+    // Stop should:
+    // 1. Stop recovery
+    // 2. Close stopCh
+    // 3. Stop watcher
+    // 4. Wait for goroutines
+    gomock.InOrder(
+        // Watcher.Stop will be called after recovery.stop and stopCh close
+        suite.watcher.EXPECT().Stop(gomock.Any()),
+    )
+
+    err = d.Stop(suite.ctx)
+    suite.NoError(err)
+}
+
 func TestDispatcherTestSuite(t *testing.T) {
     suite.Run(t, new(DispatcherTestSuite))
 }

pkg/transport/nclprotocol/dispatcher/recovery.go (+42 -5)

@@ -26,6 +26,8 @@ type recovery struct {
     isRecovering bool      // true while in recovery process
     lastFailure  time.Time // time of most recent failure
     failures     int       // failure count for backoff
+    stopCh       chan struct{}
+    wg           sync.WaitGroup
 }
 
 func newRecovery(
@@ -35,6 +37,7 @@ func newRecovery(
         watcher: watcher,
         state:   state,
         backoff: backoff.NewExponential(config.BaseRetryInterval, config.MaxRetryInterval),
+        stopCh:  make(chan struct{}),
     }
 }
 
@@ -80,11 +83,13 @@ func (r *recovery) handleError(ctx context.Context, msg *pendingMessage, err err
     log.Ctx(ctx).Debug().Msg("Reset dispatcher state state after publish failure")
 
     // Launch recovery goroutine
+    r.wg.Add(1)
     go r.recoveryLoop(ctx, r.failures)
 }
 
 // recoveryLoop handles the recovery process with backoff
 func (r *recovery) recoveryLoop(ctx context.Context, failures int) {
+    defer r.wg.Done()
     defer func() {
         r.mu.Lock()
         r.isRecovering = false
@@ -95,33 +100,51 @@ func (r *recovery) recoveryLoop(ctx context.Context, failures int) {
         // Perform backoff
         backoffDuration := r.backoff.BackoffDuration(failures)
         log.Debug().Int("failures", failures).Dur("backoff", backoffDuration).Msg("Performing backoff")
-        r.backoff.Backoff(ctx, failures)
+
+        // Perform backoff with interruptibility
+        timer := time.NewTimer(backoffDuration)
+        select {
+        case <-timer.C:
+        case <-r.stopCh:
+            timer.Stop()
+            return
+        case <-ctx.Done():
+            timer.Stop()
+            return
+        }
 
         // Just restart the watcher - it will resume from last checkpoint
         if err := r.watcher.Start(ctx); err != nil {
             if r.watcher.Stats().State == watcher.StateRunning {
                 log.Debug().Msg("Watcher already after recovery. Exiting recovery loop.")
                 return
             }
-            if ctx.Err() != nil {
+            select {
+            case <-r.stopCh:
+                return
+            case <-ctx.Done():
                 return
+            default:
+                log.Error().Err(err).Msg("Failed to restart watcher after backoff. Retrying...")
+                failures++
             }
-            log.Error().Err(err).Msg("Failed to restart watcher after backoff. Retrying...")
-            failures++
         } else {
-            log.Info().Msg("Successfully restarted watcher after backoff")
+            log.Debug().Msg("Successfully restarted watcher after backoff")
             return
         }
     }
 }
 
 // reset resets the recovery state
 func (r *recovery) reset() {
+    r.stop() // Stop any existing recovery first
+
     r.mu.Lock()
     defer r.mu.Unlock()
     r.isRecovering = false
     r.lastFailure = time.Time{}
     r.failures = 0
+    r.stopCh = make(chan struct{})
 }
 
 // getState returns current recovery state
@@ -132,3 +155,17 @@ func (r *recovery) getState() (bool, time.Time, int) {
     defer r.mu.RUnlock()
     return r.isRecovering, r.lastFailure, r.failures
 }
+
+func (r *recovery) stop() {
+    r.mu.Lock()
+    // Try to close the channel only if it's not already closed
+    select {
+    case <-r.stopCh:
+    default:
+        close(r.stopCh)
+    }
+    r.mu.Unlock()
+
+    // Wait for recovery loop to exit, if running
+    r.wg.Wait()
+}
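The main change in recovery.go is that the backoff is no longer a blind sleep: recoveryLoop selects on the backoff timer, the new stopCh, and the context, so Dispatcher.Stop can interrupt it immediately, while stop() closes the channel idempotently and waits on the WaitGroup. Below is a standalone sketch of that interruptible wait; the sleepInterruptible name is illustrative, not part of the package.

package main

import (
    "context"
    "fmt"
    "time"
)

// sleepInterruptible waits up to d, but returns early if ctx is cancelled or
// stopCh is closed first, mirroring the timer/select used in recoveryLoop.
// It reports whether the full backoff elapsed.
func sleepInterruptible(ctx context.Context, stopCh <-chan struct{}, d time.Duration) bool {
    timer := time.NewTimer(d)
    defer timer.Stop()
    select {
    case <-timer.C:
        return true
    case <-stopCh:
        return false
    case <-ctx.Done():
        return false
    }
}

func main() {
    stopCh := make(chan struct{})
    go func() {
        time.Sleep(100 * time.Millisecond)
        close(stopCh) // e.g. recovery.stop() called during Dispatcher.Stop
    }()

    start := time.Now()
    completed := sleepInterruptible(context.Background(), stopCh, 5*time.Second)
    fmt.Printf("completed=%v after %v\n", completed, time.Since(start).Round(10*time.Millisecond))
}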
