Fix race conditions resulting in multiple data planes #4771

Merged
24 changes: 17 additions & 7 deletions pkg/transport/nclprotocol/compute/manager.go
@@ -2,6 +2,7 @@ package compute

import (
"context"
"errors"
"fmt"
"sync"
"time"
@@ -144,8 +145,7 @@ func (cm *ConnectionManager) Close(ctx context.Context) error {

select {
case <-done:
cm.cleanup(ctx)
return nil
return cm.cleanup(ctx)
case <-ctx.Done():
return ctx.Err()
}
@@ -155,27 +155,28 @@ func (cm *ConnectionManager) Close(ctx context.Context) error {
// 1. Stops the data plane
// 2. Cleans up control plane
// 3. Closes NATS connection
func (cm *ConnectionManager) cleanup(ctx context.Context) {
func (cm *ConnectionManager) cleanup(ctx context.Context) error {
var errs error
// Clean up data plane subscriber
if cm.subscriber != nil {
if err := cm.subscriber.Close(ctx); err != nil {
log.Error().Err(err).Msg("Failed to close subscriber")
errs = errors.Join(errs, fmt.Errorf("failed to close subscriber: %w", err))
}
cm.subscriber = nil
}

// Clean up data plane
if cm.dataPlane != nil {
if err := cm.dataPlane.Stop(ctx); err != nil {
log.Error().Err(err).Msg("Failed to stop data plane")
errs = errors.Join(errs, fmt.Errorf("failed to stop data plane: %w", err))
}
cm.dataPlane = nil
}

// Clean up control plane
if cm.controlPlane != nil {
if err := cm.controlPlane.Stop(ctx); err != nil {
log.Error().Err(err).Msg("Failed to stop control plane")
errs = errors.Join(errs, fmt.Errorf("failed to stop control plane: %w", err))
}
cm.controlPlane = nil
}
@@ -185,6 +186,8 @@ func (cm *ConnectionManager) cleanup(ctx context.Context) {
cm.natsConn.Close()
cm.natsConn = nil
}

return errs
}

// connect attempts to establish a connection to the orchestrator. It follows these steps:
@@ -202,10 +205,17 @@ func (cm *ConnectionManager) connect(ctx context.Context) error {
log.Info().Str("node_id", cm.config.NodeID).Msg("Attempting to establish connection")
cm.transitionState(nclprotocol.Connecting, nil)

// cleanup existing components before reconnecting
if err := cm.cleanup(ctx); err != nil {
return fmt.Errorf("failed to cleanup existing components: %w", err)
}

var err error
defer func() {
if err != nil {
cm.cleanup(ctx)
if cleanupErr := cm.cleanup(ctx); cleanupErr != nil {
log.Warn().Err(cleanupErr).Msg("failed to cleanup after connection error")
}
cm.transitionState(nclprotocol.Disconnected, err)
}
}()
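
Note on the change above: cleanup now aggregates failures with errors.Join instead of only logging them, so Close and connect receive a single combined error, and connect can abort a reconnect whose pre-cleanup failed. Below is a minimal, self-contained sketch of that aggregation pattern; the closer interface, fakeComponent, and cleanupAll are invented for illustration and are not part of this PR.

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

// closer stands in for the subscriber, data plane, and control plane that
// cleanup tears down; the interface and names below exist only in this sketch.
type closer interface {
	Close(ctx context.Context) error
}

type fakeComponent struct {
	name string
	err  error
}

func (f fakeComponent) Close(ctx context.Context) error {
	if f.err != nil {
		return fmt.Errorf("%s: %w", f.name, f.err)
	}
	return nil
}

// cleanupAll mirrors the pattern in the diff: attempt every shutdown step,
// join the failures, and hand the caller one combined error.
func cleanupAll(ctx context.Context, components ...closer) error {
	var errs error
	for _, c := range components {
		if err := c.Close(ctx); err != nil {
			errs = errors.Join(errs, fmt.Errorf("failed to close component: %w", err))
		}
	}
	return errs
}

func main() {
	err := cleanupAll(context.Background(),
		fakeComponent{name: "subscriber", err: errors.New("flush timeout")},
		fakeComponent{name: "dataPlane"},
		fakeComponent{name: "controlPlane", err: errors.New("already stopped")},
	)
	// Every component was closed, and both failures show up in the joined error.
	fmt.Println(err)
}
```

The point of the pattern is that every shutdown step still runs even when an earlier one fails, while the caller still learns about each failure from the joined error.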
13 changes: 12 additions & 1 deletion pkg/transport/nclprotocol/dispatcher/dispatcher.go
@@ -92,6 +92,11 @@ func (d *Dispatcher) Start(ctx context.Context) error {
d.mu.Unlock()
return fmt.Errorf("dispatcher already running")
}

// Reset state before starting
d.state.reset()
d.recovery.reset()

d.running = true
d.mu.Unlock()

@@ -121,7 +126,13 @@ func (d *Dispatcher) Stop(ctx context.Context) error {
d.running = false
d.mu.Unlock()

// Signal recovery to stop
d.recovery.stop()

// Stop background goroutines
close(d.stopCh)

// Stop watcher after recovery to avoid new messages
d.watcher.Stop(ctx)

// Wait with timeout for all goroutines
@@ -212,7 +223,7 @@ func (d *Dispatcher) checkStalledMessages(ctx context.Context) {
Uint64("eventSeq", msg.eventSeqNum).
Time("publishTime", msg.publishTime).
Msg("Message publish stalled")
// Could implement recovery logic here
// TODO: Could implement recovery logic here
}
}
}
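
Note on the change above: Stop now follows a fixed shutdown order — signal recovery, close stopCh for the background goroutines, stop the watcher, then wait (with a timeout) for the goroutines to exit. The sketch below shows that close-then-wait shape in isolation; worker and its fields are illustrative names only, not types from this PR.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// worker is an illustrative stand-in for the dispatcher: a background
// goroutine driven by a stop channel and tracked by a WaitGroup.
type worker struct {
	stopCh chan struct{}
	wg     sync.WaitGroup
}

func (w *worker) start() {
	w.stopCh = make(chan struct{})
	w.wg.Add(1)
	go func() {
		defer w.wg.Done()
		for {
			select {
			case <-w.stopCh:
				return
			case <-time.After(50 * time.Millisecond):
				// periodic work would happen here
			}
		}
	}()
}

// stop mirrors the ordering in the diff: signal the goroutines first,
// then wait for them with a timeout instead of blocking forever.
func (w *worker) stop(timeout time.Duration) error {
	close(w.stopCh)

	done := make(chan struct{})
	go func() {
		w.wg.Wait()
		close(done)
	}()

	select {
	case <-done:
		return nil
	case <-time.After(timeout):
		return fmt.Errorf("timed out waiting for background goroutines")
	}
}

func main() {
	w := &worker{}
	w.start()
	fmt.Println(w.stop(time.Second)) // <nil>: the goroutine exited promptly
}
```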
38 changes: 38 additions & 0 deletions pkg/transport/nclprotocol/dispatcher/dispatcher_e2e_test.go
@@ -254,6 +254,44 @@ func (s *DispatcherE2ETestSuite) TestCheckpointOnStop() {
s.Equal(uint64(5), checkpoint)
}

func (s *DispatcherE2ETestSuite) TestRestartResetsState() {
config := dispatcher.Config{
CheckpointInterval: 100 * time.Millisecond,
}
d := s.startDispatcher(config)

// Store and process first batch of events (1-5)
s.storeEvents(5)
s.Eventually(func() bool {
return d.State().LastAckedSeqNum == 5
}, time.Second, 10*time.Millisecond)

// Stop dispatcher
err := d.Stop(s.ctx)
s.Require().NoError(err)

// Clear received messages
s.received = s.received[:0]

// Start new dispatcher - should start fresh but use checkpoint
d = s.startDispatcher(config)

// Store next batch of events (6-8)
s.storeEvent(6)
s.storeEvent(7)
s.storeEvent(8)

// Should only get new events since state was reset
s.Eventually(func() bool {
return d.State().LastAckedSeqNum == 8
}, time.Second, 10*time.Millisecond)

// Verify messages
for i, msg := range s.received {
s.verifyMsg(msg, i+6)
}
}

func (s *DispatcherE2ETestSuite) storeEvent(index int) {
err := s.store.StoreEvent(s.ctx, watcher.StoreEventRequest{
Operation: watcher.OperationCreate,
49 changes: 49 additions & 0 deletions pkg/transport/nclprotocol/dispatcher/dispatcher_test.go
@@ -172,6 +172,55 @@ func (suite *DispatcherTestSuite) TestStopNonStarted() {
suite.NoError(err)
}

func (suite *DispatcherTestSuite) TestStartResetsState() {
// Setup
suite.watcher.EXPECT().SetHandler(gomock.Any()).Return(nil)
suite.watcher.EXPECT().Start(gomock.Any()).Return(nil)

d, err := New(suite.publisher, suite.watcher, suite.creator, suite.config)
suite.Require().NoError(err)

// Start first time
err = d.Start(suite.ctx)
suite.NoError(err)

// Stop
suite.watcher.EXPECT().Stop(gomock.Any())
err = d.Stop(suite.ctx)
suite.NoError(err)

// Start again - should reset state
suite.watcher.EXPECT().Start(gomock.Any()).Return(nil)
err = d.Start(suite.ctx)
suite.NoError(err)
}

func (suite *DispatcherTestSuite) TestStopSequence() {
// Setup
suite.watcher.EXPECT().SetHandler(gomock.Any()).Return(nil)
suite.watcher.EXPECT().Start(gomock.Any()).Return(nil)

d, err := New(suite.publisher, suite.watcher, suite.creator, suite.config)
suite.Require().NoError(err)

// Start
err = d.Start(suite.ctx)
suite.NoError(err)

// Stop should:
// 1. Stop recovery
// 2. Close stopCh
// 3. Stop watcher
// 4. Wait for goroutines
gomock.InOrder(
// Watcher.Stop will be called after recovery.stop and stopCh close
suite.watcher.EXPECT().Stop(gomock.Any()),
)

err = d.Stop(suite.ctx)
suite.NoError(err)
}

func TestDispatcherTestSuite(t *testing.T) {
suite.Run(t, new(DispatcherTestSuite))
}
45 changes: 40 additions & 5 deletions pkg/transport/nclprotocol/dispatcher/recovery.go
@@ -26,6 +26,8 @@ type recovery struct {
isRecovering bool // true while in recovery process
lastFailure time.Time // time of most recent failure
failures int // failure count for backoff
stopCh chan struct{}
wg sync.WaitGroup
}

func newRecovery(
@@ -35,6 +37,7 @@ func newRecovery(
watcher: watcher,
state: state,
backoff: backoff.NewExponential(config.BaseRetryInterval, config.MaxRetryInterval),
stopCh: make(chan struct{}),
}
}

@@ -80,11 +83,13 @@ func (r *recovery) handleError(ctx context.Context, msg *pendingMessage, err error) {
log.Ctx(ctx).Debug().Msg("Reset dispatcher state state after publish failure")

// Launch recovery goroutine
r.wg.Add(1)
go r.recoveryLoop(ctx, r.failures)
}

// recoveryLoop handles the recovery process with backoff
func (r *recovery) recoveryLoop(ctx context.Context, failures int) {
defer r.wg.Done()
defer func() {
r.mu.Lock()
r.isRecovering = false
@@ -95,21 +100,36 @@ func (r *recovery) recoveryLoop(ctx context.Context, failures int) {
// Perform backoff
backoffDuration := r.backoff.BackoffDuration(failures)
log.Debug().Int("failures", failures).Dur("backoff", backoffDuration).Msg("Performing backoff")
r.backoff.Backoff(ctx, failures)

// Perform backoff with interruptibility
timer := time.NewTimer(backoffDuration)
select {
case <-timer.C:
case <-r.stopCh:
timer.Stop()
return
case <-ctx.Done():
timer.Stop()
return
}

// Just restart the watcher - it will resume from last checkpoint
if err := r.watcher.Start(ctx); err != nil {
if r.watcher.Stats().State == watcher.StateRunning {
log.Debug().Msg("Watcher already after recovery. Exiting recovery loop.")
return
}
if ctx.Err() != nil {
select {
case <-r.stopCh:
return
case <-ctx.Done():
return
default:
log.Error().Err(err).Msg("Failed to restart watcher after backoff. Retrying...")
failures++
}
log.Error().Err(err).Msg("Failed to restart watcher after backoff. Retrying...")
failures++
} else {
log.Info().Msg("Successfully restarted watcher after backoff")
log.Debug().Msg("Successfully restarted watcher after backoff")
return
}
}
@@ -122,6 +142,7 @@ func (r *recovery) reset() {
r.isRecovering = false
r.lastFailure = time.Time{}
r.failures = 0
r.stopCh = make(chan struct{})
}

// getState returns current recovery state
@@ -132,3 +153,17 @@ func (r *recovery) getState() (bool, time.Time, int) {
defer r.mu.RUnlock()
return r.isRecovering, r.lastFailure, r.failures
}

func (r *recovery) stop() {
r.mu.Lock()
// Try to close the channel only if it's not already closed
select {
case <-r.stopCh:
default:
close(r.stopCh)
}
r.mu.Unlock()

// Wait for recovery loop to exit, if running
r.wg.Wait()
}
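
Note on the changes above: the backoff sleep is now a timer raced against both stopCh and the context, and stop() closes stopCh at most once before waiting on the WaitGroup, so a pending recovery can be interrupted promptly and Stop never blocks on a long backoff. A standalone sketch of that pattern follows; retrier and retryAfter are invented names for illustration, not part of this PR.

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// retrier is an illustrative stand-in for the recovery struct: it owns a
// stop channel and a WaitGroup for the goroutines it launches.
type retrier struct {
	mu     sync.Mutex
	stopCh chan struct{}
	wg     sync.WaitGroup
}

func newRetrier() *retrier {
	return &retrier{stopCh: make(chan struct{})}
}

// retryAfter waits out the backoff in a select so it can be interrupted
// by stop() or by context cancellation, then runs the attempt.
func (r *retrier) retryAfter(ctx context.Context, backoff time.Duration, attempt func() error) {
	r.wg.Add(1)
	go func() {
		defer r.wg.Done()
		timer := time.NewTimer(backoff)
		defer timer.Stop()
		select {
		case <-timer.C:
		case <-r.stopCh:
			return
		case <-ctx.Done():
			return
		}
		if err := attempt(); err != nil {
			fmt.Println("retry failed:", err)
		}
	}()
}

// stop closes stopCh exactly once, then waits for in-flight retries.
func (r *retrier) stop() {
	r.mu.Lock()
	select {
	case <-r.stopCh:
		// already stopped
	default:
		close(r.stopCh)
	}
	r.mu.Unlock()
	r.wg.Wait()
}

func main() {
	r := newRetrier()
	r.retryAfter(context.Background(), time.Hour, func() error { return nil })
	r.stop() // returns quickly: the hour-long backoff is interrupted
	fmt.Println("stopped")
	r.stop() // safe to call again; the channel is only closed once
}
```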