Skip to content

Commit

Permalink
Faster reconnect on handshake required response (#4772)
Browse files Browse the repository at this point in the history
When the orchestrator restarts, compute nodes wait for 5 failed heartbeats
(~75s) before attempting to reconnect, even though the orchestrator
immediately returns "Handshake required" errors.

Modify compute nodes to detect this specific error and trigger immediate
reconnection, rather than waiting for the heartbeat failure threshold.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Enhanced error handling for heartbeat operations, specifically
addressing handshake requirements.
- New boolean field `HandshakeRequired` added to track handshake
necessity.

- **Bug Fixes**
- Improved robustness of connection health monitoring by incorporating
handshake checks.

- **Tests**
- Added tests for new handshake handling scenarios in both
`ControlPlaneTestSuite` and `ConnectionManagerTestSuite`.
- Enhanced coverage for `HealthTracker` functionality regarding
handshake states.

- **Documentation**
- Updated comments in connection health checks for clarity on new
criteria.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
wdbaruni authored Dec 16, 2024
1 parent 38572c1 commit 36c44b5
Show file tree
Hide file tree
Showing 7 changed files with 198 additions and 0 deletions.
5 changes: 5 additions & 0 deletions pkg/transport/nclprotocol/compute/controlplane.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package compute
import (
"context"
"fmt"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -104,6 +105,10 @@ func (cp *ControlPlane) run(ctx context.Context) {

case <-heartbeat.C:
if err := cp.heartbeat(ctx); err != nil {
if strings.Contains(err.Error(), "handshake required") {
cp.healthTracker.HandshakeRequired()
return
}
log.Error().Err(err).Msg("Failed to send heartbeat")
}
case <-nodeInfo.C:
Expand Down
67 changes: 67 additions & 0 deletions pkg/transport/nclprotocol/compute/controlplane_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/bacalhau-project/bacalhau/pkg/lib/envelope"
"github.com/bacalhau-project/bacalhau/pkg/lib/ncl"
"github.com/bacalhau-project/bacalhau/pkg/models/messages"
"github.com/bacalhau-project/bacalhau/pkg/orchestrator/nodes"
"github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol"
nclprotocolcompute "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/compute"
ncltest "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/test"
Expand Down Expand Up @@ -179,6 +180,72 @@ func (s *ControlPlaneTestSuite) TestHeartbeat() {
}, 100*time.Millisecond, 10*time.Millisecond, "Heartbeat did not succeed")
}

// TestHeartbeatFailFastOnHandshakeRequired checks that a handshake-required
// error returned for a heartbeat is flagged on the health tracker after a
// single request, and that no further heartbeat is attempted afterwards.
func (s *ControlPlaneTestSuite) TestHeartbeatFailFastOnHandshakeRequired() {
	// Create control plane with only heartbeat enabled and short intervals
	controlPlane := s.createControlPlane(
		50*time.Millisecond, // heartbeat
		1*time.Hour,         // node info - disabled
		1*time.Hour,         // checkpoint - disabled
	)
	// The Stop call must live inside a closure: arguments of a deferred call
	// are evaluated at the defer statement, so deferring NoError(Stop(...))
	// directly would stop the control plane here, before it is even started.
	defer func() {
		s.Require().NoError(controlPlane.Stop(s.ctx))
	}()

	// Setup handshake required error response
	s.requester.EXPECT().
		Request(gomock.Any(), gomock.Any()).
		Return(nil, nodes.NewErrHandshakeRequired("test-node")).
		Times(1) // Should only try once

	// Start control plane
	s.Require().NoError(controlPlane.Start(s.ctx))

	// Wait a bit to allow for heartbeat attempt
	time.Sleep(50 * time.Millisecond)

	// Wait for the health tracker to report that a handshake is required
	s.Require().Eventually(func() bool {
		return s.healthTracker.GetHealth().HandshakeRequired
	}, 100*time.Millisecond, 10*time.Millisecond, "handshake required was not flagged")

	// Verify health tracker state
	health := s.healthTracker.GetHealth()
	s.True(health.HandshakeRequired, "handshake should be marked as required")
	s.Zero(health.LastSuccessfulHeartbeat, "no successful heartbeat should be recorded")

	// Wait for another heartbeat interval to verify the loop has stopped;
	// a second request would violate the Times(1) expectation above.
	time.Sleep(70 * time.Millisecond)

	s.Require().Eventually(func() bool {
		return s.healthTracker.GetHealth().CurrentState == nclprotocol.Disconnected
	}, 100*time.Millisecond, 10*time.Millisecond, "connection not marked as disconnected")
}

// TestHeartbeatContinuesOnOtherErrors checks that a heartbeat error other
// than handshake-required does not set the handshake flag and does not stop
// the heartbeat from being retried on the next interval.
func (s *ControlPlaneTestSuite) TestHeartbeatContinuesOnOtherErrors() {
	// Create control plane with only heartbeat enabled
	controlPlane := s.createControlPlane(
		50*time.Millisecond, // heartbeat
		1*time.Hour,         // node info - disabled
		1*time.Hour,         // checkpoint - disabled
	)
	// The Stop call must live inside a closure: arguments of a deferred call
	// are evaluated at the defer statement, so deferring NoError(Stop(...))
	// directly would stop the control plane here, before it is even started.
	defer func() {
		s.Require().NoError(controlPlane.Stop(s.ctx))
	}()

	// Setup regular error response that should not cause fail-fast
	s.requester.EXPECT().
		Request(gomock.Any(), gomock.Any()).
		Return(nil, fmt.Errorf("network error")).
		Times(2) // Should keep trying

	// Start control plane
	s.Require().NoError(controlPlane.Start(s.ctx))

	// Wait for two heartbeat attempts
	time.Sleep(120 * time.Millisecond)

	// Verify health tracker state
	health := s.healthTracker.GetHealth()
	s.False(health.HandshakeRequired, "handshake should not be marked as required")
	s.Zero(health.LastSuccessfulHeartbeat, "no successful heartbeat should be recorded")
}

func (s *ControlPlaneTestSuite) TestNodeInfoUpdate() {
// Create control plane with only checkpointing enabled
controlPlane := s.createControlPlane(
Expand Down
17 changes: 17 additions & 0 deletions pkg/transport/nclprotocol/compute/health_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func (ht *HealthTracker) MarkConnected() {
ht.health.LastSuccessfulHeartbeat = ht.clock.Now()
ht.health.ConsecutiveFailures = 0
ht.health.LastError = nil
ht.health.HandshakeRequired = false
}

// MarkDisconnected updates status when connection is lost
Expand All @@ -46,6 +47,7 @@ func (ht *HealthTracker) MarkDisconnected(err error) {
ht.health.CurrentState = nclprotocol.Disconnected
ht.health.LastError = err
ht.health.ConsecutiveFailures++
ht.health.HandshakeRequired = false
}

// MarkConnecting updates the status when a connection attempt is in progress
Expand All @@ -54,6 +56,7 @@ func (ht *HealthTracker) MarkConnecting() {
defer ht.mu.Unlock()

ht.health.CurrentState = nclprotocol.Connecting
ht.health.HandshakeRequired = false
}

// HeartbeatSuccess records successful heartbeat
Expand All @@ -70,6 +73,13 @@ func (ht *HealthTracker) UpdateSuccess() {
ht.health.LastSuccessfulUpdate = ht.clock.Now()
}

// HandshakeRequired flags the tracked connection health as needing a new
// handshake. The flag is written while holding the tracker's write lock.
func (ht *HealthTracker) HandshakeRequired() {
	ht.mu.Lock()
	ht.health.HandshakeRequired = true
	ht.mu.Unlock()
}

// GetState returns current connection state
func (ht *HealthTracker) GetState() nclprotocol.ConnectionState {
ht.mu.RLock()
Expand All @@ -83,3 +93,10 @@ func (ht *HealthTracker) GetHealth() nclprotocol.ConnectionHealth {
defer ht.mu.RUnlock()
return ht.health
}

// IsHandshakeRequired reports whether the tracked connection health is
// currently flagged as needing a handshake. The flag is read under the
// tracker's read lock.
func (ht *HealthTracker) IsHandshakeRequired() bool {
	ht.mu.RLock()
	required := ht.health.HandshakeRequired
	ht.mu.RUnlock()
	return required
}
31 changes: 31 additions & 0 deletions pkg/transport/nclprotocol/compute/health_tracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,13 @@ func (s *HealthTrackerTestSuite) TestInitialState() {
s.Require().Equal(0, health.ConsecutiveFailures)
s.Require().Nil(health.LastError)
s.Require().True(health.ConnectedSince.IsZero())
s.Require().False(health.HandshakeRequired) // Verify initial state of HandshakeRequired
}

func (s *HealthTrackerTestSuite) TestMarkConnected() {
// First mark handshake required
s.tracker.HandshakeRequired()

// Advance clock to have distinct timestamps
s.clock.Add(time.Second)
connectedTime := s.clock.Now()
Expand All @@ -54,6 +58,7 @@ func (s *HealthTrackerTestSuite) TestMarkConnected() {
s.Require().Equal(connectedTime, health.LastSuccessfulHeartbeat)
s.Require().Equal(0, health.ConsecutiveFailures)
s.Require().Nil(health.LastError)
s.Require().False(health.HandshakeRequired) // Should be reset when connected
}

func (s *HealthTrackerTestSuite) TestMarkDisconnected() {
Expand All @@ -68,13 +73,39 @@ func (s *HealthTrackerTestSuite) TestMarkDisconnected() {
s.Require().Equal(nclprotocol.Disconnected, health.CurrentState)
s.Require().Equal(expectedErr, health.LastError)
s.Require().Equal(1, health.ConsecutiveFailures)
s.Require().False(health.HandshakeRequired) // Should still be false after disconnect

// Multiple disconnections should increment failure count
s.tracker.MarkDisconnected(expectedErr)
health = s.tracker.GetHealth()
s.Require().Equal(2, health.ConsecutiveFailures)
}

// TestHandshakeRequired verifies that the handshake-required flag starts
// unset, is raised by HandshakeRequired, and is cleared again by each of the
// MarkConnected, MarkDisconnected and MarkConnecting transitions.
func (s *HealthTrackerTestSuite) TestHandshakeRequired() {
	// Initially handshake should not be required
	s.Require().False(s.tracker.IsHandshakeRequired())

	// Each transition is exercised in order: connected, disconnected, connecting.
	clearingTransitions := []func(){
		func() { s.tracker.MarkConnected() },
		func() { s.tracker.MarkDisconnected(fmt.Errorf("error")) },
		func() { s.tracker.MarkConnecting() },
	}

	for _, transition := range clearingTransitions {
		// Raise the flag and confirm it is visible
		s.tracker.HandshakeRequired()
		s.Require().True(s.tracker.IsHandshakeRequired())

		// The transition must reset it
		transition()
		s.Require().False(s.tracker.IsHandshakeRequired())
	}
}

func (s *HealthTrackerTestSuite) TestSuccessfulOperations() {
// Initial timestamps
s.clock.Add(time.Second)
Expand Down
4 changes: 4 additions & 0 deletions pkg/transport/nclprotocol/compute/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ func (cm *ConnectionManager) checkConnectionHealth() {
// Consider connection unhealthy if:
// 1. No heartbeat succeeded within HeartbeatMissFactor intervals
// 2. NATS connection is closed/draining
// 3. Health tracker reports a handshake required
now := cm.config.Clock.Now()
heartbeatDeadline := now.Add(-time.Duration(cm.config.HeartbeatMissFactor) * cm.config.HeartbeatInterval)

Expand All @@ -443,6 +444,9 @@ func (cm *ConnectionManager) checkConnectionHealth() {
} else if cm.natsConn.IsClosed() {
reason = "NATS connection closed"
unhealthy = true
} else if cm.healthTracker.IsHandshakeRequired() {
reason = "handshake required"
unhealthy = true
}

if unhealthy {
Expand Down
73 changes: 73 additions & 0 deletions pkg/transport/nclprotocol/compute/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
"fmt"
"reflect"
"strings"
"testing"
"time"

Expand All @@ -18,6 +19,7 @@ import (
"github.com/bacalhau-project/bacalhau/pkg/models"
"github.com/bacalhau-project/bacalhau/pkg/models/messages"
natsutil "github.com/bacalhau-project/bacalhau/pkg/nats"
"github.com/bacalhau-project/bacalhau/pkg/orchestrator/nodes"
testutils "github.com/bacalhau-project/bacalhau/pkg/test/utils"
"github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol"
nclprotocolcompute "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/compute"
Expand Down Expand Up @@ -239,6 +241,77 @@ func (s *ConnectionManagerTestSuite) TestHeartbeatFailure() {
}, time.Second, 10*time.Millisecond)
}

// TestHeartbeatHandshakeRequired drives the manager through a full cycle:
// initial connection, disconnect caused by a handshake-required heartbeat
// error, a new handshake attempt, and a successful reconnection with
// heartbeats flowing again.
func (s *ConnectionManagerTestSuite) TestHeartbeatHandshakeRequired() {
	s.Require().NoError(s.manager.Start(s.ctx))

	// Wait for initial connection
	initiallyConnected := func() bool {
		return s.manager.GetHealth().CurrentState == nclprotocol.Connected
	}
	s.Require().Eventually(initiallyConnected, time.Second, 10*time.Millisecond, "manager should connect initially")

	// Configure heartbeat to require handshake
	s.mockResponder.Behaviour().HeartbeatResponse.Error = nodes.NewErrHandshakeRequired("test-node")

	// Should disconnect quickly after handshake required error
	droppedForHandshake := func() bool {
		current := s.manager.GetHealth()
		if current.CurrentState != nclprotocol.Disconnected {
			return false
		}
		if current.LastError == nil {
			return false
		}
		return strings.Contains(current.LastError.Error(), "handshake required")
	}
	// NOTE: the health value passed as a format argument is captured when
	// Eventually is called, not when the assertion fails.
	s.Require().Eventually(droppedForHandshake, 1*time.Second, 5*time.Millisecond,
		"should disconnect due to handshake required: %+v", s.manager.GetHealth())

	// Reset heartbeat response to succeed
	s.mockResponder.Behaviour().HeartbeatResponse.Error = nil

	// Should automatically attempt reconnection: more than the initial
	// handshake must have been recorded by the responder.
	handshakeRepeated := func() bool {
		return len(s.mockResponder.GetHandshakes()) > 1
	}
	s.Require().Eventually(handshakeRepeated, time.Second, 10*time.Millisecond, "should attempt reconnection")

	// Should successfully reconnect, with the flag cleared after connecting
	reconnected := func() bool {
		current := s.manager.GetHealth()
		if current.CurrentState != nclprotocol.Connected {
			return false
		}
		return !current.HandshakeRequired
	}
	s.Require().Eventually(reconnected, time.Second, 10*time.Millisecond, "should reconnect successfully")

	// Verify heartbeats resume
	time.Sleep(s.config.HeartbeatInterval * 2)
	s.Require().NotEmpty(s.mockResponder.GetHeartbeats(), "should resume heartbeats after reconnection")
}

// TestHeartbeatHandshakeRequiredDifferentError verifies that a heartbeat
// error which merely mentions "handshake" does not set the handshake-required
// flag, and that the manager still ends up disconnected once heartbeats have
// been missed for long enough.
func (s *ConnectionManagerTestSuite) TestHeartbeatHandshakeRequiredDifferentError() {
	s.Require().NoError(s.manager.Start(s.ctx))

	// Wait for initial connection
	s.Require().Eventually(func() bool {
		return s.manager.GetHealth().CurrentState == nclprotocol.Connected
	}, time.Second, 10*time.Millisecond)

	// Configure heartbeat with error that mentions handshake but isn't the specific error
	s.mockResponder.Behaviour().HeartbeatResponse.Error = fmt.Errorf("failed to process handshake data")

	// Wait some heartbeat intervals - should not immediately disconnect
	time.Sleep(s.config.HeartbeatInterval * 2)
	snapshot := s.manager.GetHealth()
	s.False(snapshot.HandshakeRequired, "should not set handshake required for different errors")

	// Should eventually disconnect due to missed heartbeats
	missWindow := s.config.HeartbeatInterval * time.Duration(s.config.HeartbeatMissFactor+1)
	time.Sleep(missWindow)

	disconnectedWithoutFlag := func() bool {
		current := s.manager.GetHealth()
		if current.CurrentState != nclprotocol.Disconnected {
			return false
		}
		return !current.HandshakeRequired // Should not be set
	}
	s.Eventually(disconnectedWithoutFlag, time.Second, 10*time.Millisecond)
}

func (s *ConnectionManagerTestSuite) TestNodeInfoUpdates() {
// Configure heartbeat callback to trigger node info updates
s.mockResponder.Behaviour().OnHeartbeat = func(req messages.HeartbeatRequest) {
Expand Down
1 change: 1 addition & 0 deletions pkg/transport/nclprotocol/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ type ConnectionHealth struct {
ConsecutiveFailures int
LastError error
ConnectedSince time.Time
HandshakeRequired bool
}

const (
Expand Down

0 comments on commit 36c44b5

Please sign in to comment.