Fix race conditions resulting in multiple data planes (#4771)
<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

## Release Notes

- **New Features**
  - Enhanced error handling and cleanup processes for connection management.
  - State reset functionality added to the dispatcher during start and stop operations.
  - New mechanism for gracefully stopping the recovery process.

- **Bug Fixes**
  - Improved error reporting during connection cleanup and state management.

- **Tests**
  - Added new test methods to validate dispatcher state management and recovery processes.
  - Enhanced existing tests for error handling and recovery logic validation.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
wdbaruni authored Dec 16, 2024
1 parent f9751b5 commit 38572c1
Showing 6 changed files with 359 additions and 38 deletions.
24 changes: 17 additions & 7 deletions pkg/transport/nclprotocol/compute/manager.go
@@ -2,6 +2,7 @@ package compute

import (
"context"
"errors"
"fmt"
"sync"
"time"
@@ -144,8 +145,7 @@ func (cm *ConnectionManager) Close(ctx context.Context) error {

select {
case <-done:
cm.cleanup(ctx)
return nil
return cm.cleanup(ctx)
case <-ctx.Done():
return ctx.Err()
}
@@ -155,27 +155,28 @@ func (cm *ConnectionManager) Close(ctx context.Context) error {
// 1. Stops the data plane
// 2. Cleans up control plane
// 3. Closes NATS connection
func (cm *ConnectionManager) cleanup(ctx context.Context) {
func (cm *ConnectionManager) cleanup(ctx context.Context) error {
var errs error
// Clean up data plane subscriber
if cm.subscriber != nil {
if err := cm.subscriber.Close(ctx); err != nil {
log.Error().Err(err).Msg("Failed to close subscriber")
errs = errors.Join(errs, fmt.Errorf("failed to close subscriber: %w", err))
}
cm.subscriber = nil
}

// Clean up data plane
if cm.dataPlane != nil {
if err := cm.dataPlane.Stop(ctx); err != nil {
log.Error().Err(err).Msg("Failed to stop data plane")
errs = errors.Join(errs, fmt.Errorf("failed to stop data plane: %w", err))
}
cm.dataPlane = nil
}

// Clean up control plane
if cm.controlPlane != nil {
if err := cm.controlPlane.Stop(ctx); err != nil {
log.Error().Err(err).Msg("Failed to stop control plane")
errs = errors.Join(errs, fmt.Errorf("failed to stop control plane: %w", err))
}
cm.controlPlane = nil
}
@@ -185,6 +186,8 @@ func (cm *ConnectionManager) cleanup(ctx context.Context) {
cm.natsConn.Close()
cm.natsConn = nil
}

return errs
}
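
The aggregation pattern above — joining each step's failure with errors.Join and returning the combined error instead of only logging it — can be shown in isolation. A minimal sketch with illustrative names (`closer` and `cleanupAll` are not part of this package):

```go
package example

import (
	"errors"
	"fmt"
)

// closer stands in for the subscriber, data plane, and control plane.
type closer interface {
	Close() error
}

// cleanupAll closes every component, continuing past failures and joining
// them so the caller receives a single aggregated error (or nil).
func cleanupAll(components ...closer) error {
	var errs error
	for _, c := range components {
		if c == nil {
			continue
		}
		if err := c.Close(); err != nil {
			errs = errors.Join(errs, fmt.Errorf("failed to close component: %w", err))
		}
	}
	return errs
}
```

errors.Join returns nil when nothing failed, so the happy path in Close still reports success.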

// connect attempts to establish a connection to the orchestrator. It follows these steps:
@@ -202,10 +205,17 @@ func (cm *ConnectionManager) connect(ctx context.Context) error {
log.Info().Str("node_id", cm.config.NodeID).Msg("Attempting to establish connection")
cm.transitionState(nclprotocol.Connecting, nil)

// cleanup existing components before reconnecting
if err := cm.cleanup(ctx); err != nil {
return fmt.Errorf("failed to cleanup existing components: %w", err)
}

var err error
defer func() {
if err != nil {
cm.cleanup(ctx)
if cleanupErr := cm.cleanup(ctx); cleanupErr != nil {
log.Warn().Err(cleanupErr).Msg("failed to cleanup after connection error")
}
cm.transitionState(nclprotocol.Disconnected, err)
}
}()
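The connect() changes are the core of the race fix: leftovers from a previous attempt are torn down before new components are built, and a failed attempt rolls itself back, so at most one data plane can exist at a time. A sketch of that shape under illustrative names (cleanup and build are placeholders, not this package's API):

```go
package example

import (
	"context"
	"fmt"
	"log"
)

// connectWithRollback tears down stale components, builds new ones, and on
// failure best-effort cleans up again so a half-built connection never leaks.
func connectWithRollback(
	ctx context.Context,
	cleanup func(context.Context) error,
	build func(context.Context) error,
) (err error) {
	// Remove anything left behind by a previous attempt before reconnecting.
	if err = cleanup(ctx); err != nil {
		return fmt.Errorf("failed to cleanup existing components: %w", err)
	}

	defer func() {
		if err != nil {
			// The build error is what the caller sees; cleanup failures are only logged.
			if cleanupErr := cleanup(ctx); cleanupErr != nil {
				log.Printf("failed to cleanup after connection error: %v", cleanupErr)
			}
		}
	}()

	return build(ctx)
}
```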
13 changes: 12 additions & 1 deletion pkg/transport/nclprotocol/dispatcher/dispatcher.go
@@ -92,6 +92,11 @@ func (d *Dispatcher) Start(ctx context.Context) error {
d.mu.Unlock()
return fmt.Errorf("dispatcher already running")
}

// Reset state before starting
d.state.reset()
d.recovery.reset()

d.running = true
d.mu.Unlock()

@@ -121,7 +126,13 @@ func (d *Dispatcher) Stop(ctx context.Context) error {
d.running = false
d.mu.Unlock()

// Signal recovery to stop
d.recovery.stop()

// Stop background goroutines
close(d.stopCh)

// Stop watcher after recovery to avoid new messages
d.watcher.Stop(ctx)

// Wait with timeout for all goroutines
@@ -212,7 +223,7 @@ func (d *Dispatcher) checkStalledMessages(ctx context.Context) {
Uint64("eventSeq", msg.eventSeqNum).
Time("publishTime", msg.publishTime).
Msg("Message publish stalled")
// Could implement recovery logic here
// TODO: Could implement recovery logic here
}
}
}
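The Start/Stop changes give the dispatcher a deterministic lifecycle: Start resets dispatcher and recovery state before running, and Stop signals recovery first, then closes stopCh, then stops the watcher, and only then waits for goroutines. A condensed sketch of that shutdown ordering, with illustrative names (not the dispatcher's actual types), bounded by the caller's context:

```go
package example

import (
	"context"
	"sync"
)

// lifecycle sketches the Stop ordering used by the dispatcher: signal the
// recovery helper, close the stop channel, stop the upstream watcher, and
// only then wait for background goroutines to drain.
type lifecycle struct {
	mu      sync.Mutex
	running bool
	stopCh  chan struct{}
	wg      sync.WaitGroup
}

func (l *lifecycle) stop(ctx context.Context, stopRecovery, stopWatcher func()) error {
	l.mu.Lock()
	if !l.running {
		l.mu.Unlock()
		return nil
	}
	l.running = false
	l.mu.Unlock()

	stopRecovery()  // 1. stop retrying so no new goroutines are spawned
	close(l.stopCh) // 2. tell background goroutines to exit
	stopWatcher()   // 3. stop the watcher so no new messages arrive

	// 4. wait for goroutines, giving up if the context expires first
	done := make(chan struct{})
	go func() {
		l.wg.Wait()
		close(done)
	}()
	select {
	case <-done:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
```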
38 changes: 38 additions & 0 deletions pkg/transport/nclprotocol/dispatcher/dispatcher_e2e_test.go
@@ -254,6 +254,44 @@ func (s *DispatcherE2ETestSuite) TestCheckpointOnStop() {
s.Equal(uint64(5), checkpoint)
}

func (s *DispatcherE2ETestSuite) TestRestartResetsState() {
config := dispatcher.Config{
CheckpointInterval: 100 * time.Millisecond,
}
d := s.startDispatcher(config)

// Store and process first batch of events (1-5)
s.storeEvents(5)
s.Eventually(func() bool {
return d.State().LastAckedSeqNum == 5
}, time.Second, 10*time.Millisecond)

// Stop dispatcher
err := d.Stop(s.ctx)
s.Require().NoError(err)

// Clear received messages
s.received = s.received[:0]

// Start new dispatcher - should start fresh but use checkpoint
d = s.startDispatcher(config)

// Store next batch of events (6-8)
s.storeEvent(6)
s.storeEvent(7)
s.storeEvent(8)

// Should only get new events since state was reset
s.Eventually(func() bool {
return d.State().LastAckedSeqNum == 8
}, time.Second, 10*time.Millisecond)

// Verify messages
for i, msg := range s.received {
s.verifyMsg(msg, i+6)
}
}

func (s *DispatcherE2ETestSuite) storeEvent(index int) {
err := s.store.StoreEvent(s.ctx, watcher.StoreEventRequest{
Operation: watcher.OperationCreate,
49 changes: 49 additions & 0 deletions pkg/transport/nclprotocol/dispatcher/dispatcher_test.go
@@ -172,6 +172,55 @@ func (suite *DispatcherTestSuite) TestStopNonStarted() {
suite.NoError(err)
}

func (suite *DispatcherTestSuite) TestStartResetsState() {
// Setup
suite.watcher.EXPECT().SetHandler(gomock.Any()).Return(nil)
suite.watcher.EXPECT().Start(gomock.Any()).Return(nil)

d, err := New(suite.publisher, suite.watcher, suite.creator, suite.config)
suite.Require().NoError(err)

// Start first time
err = d.Start(suite.ctx)
suite.NoError(err)

// Stop
suite.watcher.EXPECT().Stop(gomock.Any())
err = d.Stop(suite.ctx)
suite.NoError(err)

// Start again - should reset state
suite.watcher.EXPECT().Start(gomock.Any()).Return(nil)
err = d.Start(suite.ctx)
suite.NoError(err)
}

func (suite *DispatcherTestSuite) TestStopSequence() {
// Setup
suite.watcher.EXPECT().SetHandler(gomock.Any()).Return(nil)
suite.watcher.EXPECT().Start(gomock.Any()).Return(nil)

d, err := New(suite.publisher, suite.watcher, suite.creator, suite.config)
suite.Require().NoError(err)

// Start
err = d.Start(suite.ctx)
suite.NoError(err)

// Stop should:
// 1. Stop recovery
// 2. Close stopCh
// 3. Stop watcher
// 4. Wait for goroutines
gomock.InOrder(
// Watcher.Stop will be called after recovery.stop and stopCh close
suite.watcher.EXPECT().Stop(gomock.Any()),
)

err = d.Stop(suite.ctx)
suite.NoError(err)
}

func TestDispatcherTestSuite(t *testing.T) {
suite.Run(t, new(DispatcherTestSuite))
}
47 changes: 42 additions & 5 deletions pkg/transport/nclprotocol/dispatcher/recovery.go
@@ -26,6 +26,8 @@ type recovery struct {
isRecovering bool // true while in recovery process
lastFailure time.Time // time of most recent failure
failures int // failure count for backoff
stopCh chan struct{}
wg sync.WaitGroup
}

func newRecovery(
@@ -35,6 +37,7 @@
watcher: watcher,
state: state,
backoff: backoff.NewExponential(config.BaseRetryInterval, config.MaxRetryInterval),
stopCh: make(chan struct{}),
}
}

@@ -80,11 +83,13 @@ func (r *recovery) handleError(ctx context.Context, msg *pendingMessage, err error) {
log.Ctx(ctx).Debug().Msg("Reset dispatcher state after publish failure")

// Launch recovery goroutine
r.wg.Add(1)
go r.recoveryLoop(ctx, r.failures)
}

// recoveryLoop handles the recovery process with backoff
func (r *recovery) recoveryLoop(ctx context.Context, failures int) {
defer r.wg.Done()
defer func() {
r.mu.Lock()
r.isRecovering = false
@@ -95,33 +100,51 @@
// Perform backoff
backoffDuration := r.backoff.BackoffDuration(failures)
log.Debug().Int("failures", failures).Dur("backoff", backoffDuration).Msg("Performing backoff")
r.backoff.Backoff(ctx, failures)

// Perform backoff with interruptibility
timer := time.NewTimer(backoffDuration)
select {
case <-timer.C:
case <-r.stopCh:
timer.Stop()
return
case <-ctx.Done():
timer.Stop()
return
}

// Just restart the watcher - it will resume from last checkpoint
if err := r.watcher.Start(ctx); err != nil {
if r.watcher.Stats().State == watcher.StateRunning {
log.Debug().Msg("Watcher already after recovery. Exiting recovery loop.")
return
}
if ctx.Err() != nil {
select {
case <-r.stopCh:
return
case <-ctx.Done():
return
default:
log.Error().Err(err).Msg("Failed to restart watcher after backoff. Retrying...")
failures++
}
log.Error().Err(err).Msg("Failed to restart watcher after backoff. Retrying...")
failures++
} else {
log.Info().Msg("Successfully restarted watcher after backoff")
log.Debug().Msg("Successfully restarted watcher after backoff")
return
}
}
}
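
The backoff rewrite is what makes recovery stoppable: rather than calling r.backoff.Backoff(ctx, failures), which only the context could interrupt, the loop now waits on a timer and can also be cancelled via stopCh. The same pattern as a standalone helper (a sketch, not part of this package):

```go
package example

import (
	"context"
	"time"
)

// waitBackoff blocks for d, but returns false immediately if the stop channel
// is closed or the context is cancelled, so a recovery loop can be shut down
// mid-backoff instead of sleeping through it.
func waitBackoff(ctx context.Context, stopCh <-chan struct{}, d time.Duration) bool {
	timer := time.NewTimer(d)
	defer timer.Stop()
	select {
	case <-timer.C:
		return true // backoff elapsed; caller may retry
	case <-stopCh:
		return false // explicit stop requested
	case <-ctx.Done():
		return false // context cancelled
	}
}
```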

// reset resets the recovery state
func (r *recovery) reset() {
r.stop() // Stop any existing recovery first

r.mu.Lock()
defer r.mu.Unlock()
r.isRecovering = false
r.lastFailure = time.Time{}
r.failures = 0
r.stopCh = make(chan struct{})
}

// getState returns current recovery state
@@ -132,3 +155,17 @@ func (r *recovery) getState() (bool, time.Time, int) {
defer r.mu.RUnlock()
return r.isRecovering, r.lastFailure, r.failures
}

func (r *recovery) stop() {
r.mu.Lock()
// Try to close the channel only if it's not already closed
select {
case <-r.stopCh:
default:
close(r.stopCh)
}
r.mu.Unlock()

// Wait for recovery loop to exit, if running
r.wg.Wait()
}
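
stop() relies on two small idioms worth calling out: the select-with-default closes stopCh only if it is still open, which makes repeated calls (from both Stop and reset) safe, and the WaitGroup guarantees the recovery goroutine has exited before reset replaces the channel. Sketched on its own with illustrative names:

```go
package example

import "sync"

// stopOnce closes stopCh at most once, then waits for tracked goroutines to
// exit. Calling it repeatedly is harmless, which is what lets reset call it
// unconditionally before rebuilding the channel.
func stopOnce(mu *sync.Mutex, stopCh chan struct{}, wg *sync.WaitGroup) {
	mu.Lock()
	select {
	case <-stopCh:
		// already closed by an earlier call
	default:
		close(stopCh)
	}
	mu.Unlock()

	// Wait outside the lock so exiting goroutines can take it if they need to.
	wg.Wait()
}
```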
