Skip to content

Commit

Permalink
Faster reconnect on handshake required response
Browse files Browse the repository at this point in the history
  • Loading branch information
wdbaruni committed Dec 16, 2024
1 parent f9751b5 commit 69b1b42
Show file tree
Hide file tree
Showing 7 changed files with 198 additions and 0 deletions.
5 changes: 5 additions & 0 deletions pkg/transport/nclprotocol/compute/controlplane.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package compute
import (
"context"
"fmt"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -104,6 +105,10 @@ func (cp *ControlPlane) run(ctx context.Context) {

case <-heartbeat.C:
if err := cp.heartbeat(ctx); err != nil {
if strings.Contains(err.Error(), "handshake required") {
cp.healthTracker.HandshakeRequired()
return
}
log.Error().Err(err).Msg("Failed to send heartbeat")
}
case <-nodeInfo.C:
Expand Down
67 changes: 67 additions & 0 deletions pkg/transport/nclprotocol/compute/controlplane_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/bacalhau-project/bacalhau/pkg/lib/envelope"
"github.com/bacalhau-project/bacalhau/pkg/lib/ncl"
"github.com/bacalhau-project/bacalhau/pkg/models/messages"
"github.com/bacalhau-project/bacalhau/pkg/orchestrator/nodes"
"github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol"
nclprotocolcompute "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/compute"
ncltest "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/test"
Expand Down Expand Up @@ -179,6 +180,72 @@ func (s *ControlPlaneTestSuite) TestHeartbeat() {
}, 100*time.Millisecond, 10*time.Millisecond, "Heartbeat did not succeed")
}

// TestHeartbeatFailFastOnHandshakeRequired checks that a "handshake required"
// heartbeat error is flagged on the health tracker and that the control plane
// does not attempt another heartbeat afterwards.
func (s *ControlPlaneTestSuite) TestHeartbeatFailFastOnHandshakeRequired() {
	// Create control plane with only heartbeat enabled and short intervals
	controlPlane := s.createControlPlane(
		50*time.Millisecond, // heartbeat
		1*time.Hour,         // node info - disabled
		1*time.Hour,         // checkpoint - disabled
	)
	// Arguments of a deferred call are evaluated at the defer statement, so
	// Stop must be wrapped in a closure to run at test exit instead of
	// running right here, before Start.
	defer func() {
		s.Require().NoError(controlPlane.Stop(s.ctx))
	}()

	// Setup handshake required error response
	s.requester.EXPECT().
		Request(gomock.Any(), gomock.Any()).
		Return(nil, nodes.NewErrHandshakeRequired("test-node")).
		Times(1) // Should only try once

	// Start control plane
	s.Require().NoError(controlPlane.Start(s.ctx))

	// Wait a bit to allow for heartbeat attempt
	time.Sleep(50 * time.Millisecond)

	// Wait for the health tracker to record the handshake requirement
	s.Require().Eventually(func() bool {
		return s.healthTracker.GetHealth().HandshakeRequired
	}, 100*time.Millisecond, 10*time.Millisecond, "handshake required was not recorded")

	// Verify health tracker state
	health := s.healthTracker.GetHealth()
	s.True(health.HandshakeRequired, "handshake should be marked as required")
	s.Zero(health.LastSuccessfulHeartbeat, "no successful heartbeat should be recorded")

	// Wait for another heartbeat interval to verify the loop has stopped;
	// a second request would exceed the Times(1) expectation above.
	time.Sleep(70 * time.Millisecond)

	s.Require().Eventually(func() bool {
		return s.healthTracker.GetHealth().CurrentState == nclprotocol.Disconnected
	}, 100*time.Millisecond, 10*time.Millisecond, "connection not marked as disconnected")
}

// TestHeartbeatContinuesOnOtherErrors checks that a heartbeat error other than
// "handshake required" does not set the handshake flag and that heartbeats
// keep being attempted.
func (s *ControlPlaneTestSuite) TestHeartbeatContinuesOnOtherErrors() {
	// Create control plane with only heartbeat enabled
	controlPlane := s.createControlPlane(
		50*time.Millisecond, // heartbeat
		1*time.Hour,         // node info - disabled
		1*time.Hour,         // checkpoint - disabled
	)
	// Arguments of a deferred call are evaluated at the defer statement, so
	// Stop must be wrapped in a closure to run at test exit instead of
	// running right here, before Start.
	defer func() {
		s.Require().NoError(controlPlane.Stop(s.ctx))
	}()

	// Setup regular error response that should not cause fail-fast.
	// MinTimes rather than an exact count: the number of ticks that fit in
	// the sleep window below depends on scheduling, and the intent is only
	// that the loop keeps trying after the first failure.
	s.requester.EXPECT().
		Request(gomock.Any(), gomock.Any()).
		Return(nil, fmt.Errorf("network error")).
		MinTimes(2) // Should keep trying

	// Start control plane
	s.Require().NoError(controlPlane.Start(s.ctx))

	// Wait for two heartbeat attempts
	time.Sleep(120 * time.Millisecond)

	// Verify health tracker state
	health := s.healthTracker.GetHealth()
	s.False(health.HandshakeRequired, "handshake should not be marked as required")
	s.Zero(health.LastSuccessfulHeartbeat, "no successful heartbeat should be recorded")
}

func (s *ControlPlaneTestSuite) TestNodeInfoUpdate() {
// Create control plane with only checkpointing enabled
controlPlane := s.createControlPlane(
Expand Down
17 changes: 17 additions & 0 deletions pkg/transport/nclprotocol/compute/health_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func (ht *HealthTracker) MarkConnected() {
ht.health.LastSuccessfulHeartbeat = ht.clock.Now()
ht.health.ConsecutiveFailures = 0
ht.health.LastError = nil
ht.health.HandshakeRequired = false
}

// MarkDisconnected updates status when connection is lost
Expand All @@ -46,6 +47,7 @@ func (ht *HealthTracker) MarkDisconnected(err error) {
ht.health.CurrentState = nclprotocol.Disconnected
ht.health.LastError = err
ht.health.ConsecutiveFailures++
ht.health.HandshakeRequired = false
}

// MarkConnecting updates status when connection is in progress
Expand All @@ -54,6 +56,7 @@ func (ht *HealthTracker) MarkConnecting() {
defer ht.mu.Unlock()

ht.health.CurrentState = nclprotocol.Connecting
ht.health.HandshakeRequired = false
}

// HeartbeatSuccess records successful heartbeat
Expand All @@ -70,6 +73,13 @@ func (ht *HealthTracker) UpdateSuccess() {
ht.health.LastSuccessfulUpdate = ht.clock.Now()
}

// HandshakeRequired sets the HandshakeRequired flag on the tracked health
// while holding the write lock.
func (ht *HealthTracker) HandshakeRequired() {
	ht.mu.Lock()
	ht.health.HandshakeRequired = true
	ht.mu.Unlock()
}

// GetState returns current connection state
func (ht *HealthTracker) GetState() nclprotocol.ConnectionState {
ht.mu.RLock()
Expand All @@ -83,3 +93,10 @@ func (ht *HealthTracker) GetHealth() nclprotocol.ConnectionHealth {
defer ht.mu.RUnlock()
return ht.health
}

// IsHandshakeRequired reports the current value of the HandshakeRequired flag,
// read under the read lock.
func (ht *HealthTracker) IsHandshakeRequired() bool {
	ht.mu.RLock()
	required := ht.health.HandshakeRequired
	ht.mu.RUnlock()
	return required
}
31 changes: 31 additions & 0 deletions pkg/transport/nclprotocol/compute/health_tracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,13 @@ func (s *HealthTrackerTestSuite) TestInitialState() {
s.Require().Equal(0, health.ConsecutiveFailures)
s.Require().Nil(health.LastError)
s.Require().True(health.ConnectedSince.IsZero())
s.Require().False(health.HandshakeRequired) // Verify initial state of HandshakeRequired
}

func (s *HealthTrackerTestSuite) TestMarkConnected() {
// First mark handshake required
s.tracker.HandshakeRequired()

// Advance clock to have distinct timestamps
s.clock.Add(time.Second)
connectedTime := s.clock.Now()
Expand All @@ -54,6 +58,7 @@ func (s *HealthTrackerTestSuite) TestMarkConnected() {
s.Require().Equal(connectedTime, health.LastSuccessfulHeartbeat)
s.Require().Equal(0, health.ConsecutiveFailures)
s.Require().Nil(health.LastError)
s.Require().False(health.HandshakeRequired) // Should be reset when connected
}

func (s *HealthTrackerTestSuite) TestMarkDisconnected() {
Expand All @@ -68,13 +73,39 @@ func (s *HealthTrackerTestSuite) TestMarkDisconnected() {
s.Require().Equal(nclprotocol.Disconnected, health.CurrentState)
s.Require().Equal(expectedErr, health.LastError)
s.Require().Equal(1, health.ConsecutiveFailures)
s.Require().False(health.HandshakeRequired) // Should still be false after disconnect

// Multiple disconnections should increment failure count
s.tracker.MarkDisconnected(expectedErr)
health = s.tracker.GetHealth()
s.Require().Equal(2, health.ConsecutiveFailures)
}

// TestHandshakeRequired exercises the HandshakeRequired flag: it starts unset,
// is set by HandshakeRequired(), and is cleared again by each of the state
// transitions MarkConnected, MarkDisconnected and MarkConnecting.
func (s *HealthTrackerTestSuite) TestHandshakeRequired() {
	// Initially handshake should not be required
	s.Require().False(s.tracker.IsHandshakeRequired())

	// Each transition below is expected to clear the flag; they run in this
	// order, each preceded by setting the flag again.
	clearingTransitions := []func(){
		func() { s.tracker.MarkConnected() },
		func() { s.tracker.MarkDisconnected(fmt.Errorf("error")) },
		func() { s.tracker.MarkConnecting() },
	}

	for _, transition := range clearingTransitions {
		// Mark handshake as required and confirm it is visible
		s.tracker.HandshakeRequired()
		s.Require().True(s.tracker.IsHandshakeRequired())

		// The transition must reset the flag
		transition()
		s.Require().False(s.tracker.IsHandshakeRequired())
	}
}

func (s *HealthTrackerTestSuite) TestSuccessfulOperations() {
// Initial timestamps
s.clock.Add(time.Second)
Expand Down
4 changes: 4 additions & 0 deletions pkg/transport/nclprotocol/compute/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ func (cm *ConnectionManager) checkConnectionHealth() {
// Consider connection unhealthy if:
// 1. No heartbeat succeeded within HeartbeatMissFactor intervals
// 2. NATS connection is closed/draining
// 3. Health tracker reports a handshake required
now := cm.config.Clock.Now()
heartbeatDeadline := now.Add(-time.Duration(cm.config.HeartbeatMissFactor) * cm.config.HeartbeatInterval)

Expand All @@ -433,6 +434,9 @@ func (cm *ConnectionManager) checkConnectionHealth() {
} else if cm.natsConn.IsClosed() {
reason = "NATS connection closed"
unhealthy = true
} else if cm.healthTracker.IsHandshakeRequired() {
reason = "handshake required"
unhealthy = true
}

if unhealthy {
Expand Down
73 changes: 73 additions & 0 deletions pkg/transport/nclprotocol/compute/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
"fmt"
"reflect"
"strings"
"testing"
"time"

Expand All @@ -18,6 +19,7 @@ import (
"github.com/bacalhau-project/bacalhau/pkg/models"
"github.com/bacalhau-project/bacalhau/pkg/models/messages"
natsutil "github.com/bacalhau-project/bacalhau/pkg/nats"
"github.com/bacalhau-project/bacalhau/pkg/orchestrator/nodes"
testutils "github.com/bacalhau-project/bacalhau/pkg/test/utils"
"github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol"
nclprotocolcompute "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/compute"
Expand Down Expand Up @@ -239,6 +241,77 @@ func (s *ConnectionManagerTestSuite) TestHeartbeatFailure() {
}, time.Second, 10*time.Millisecond)
}

// TestHeartbeatHandshakeRequired drives the manager through a full cycle:
// connect, receive a "handshake required" heartbeat error, disconnect,
// perform a new handshake, reconnect, and resume heartbeats.
func (s *ConnectionManagerTestSuite) TestHeartbeatHandshakeRequired() {
	err := s.manager.Start(s.ctx)
	s.Require().NoError(err)

	// Wait for initial connection
	s.Require().Eventually(func() bool {
		health := s.manager.GetHealth()
		return health.CurrentState == nclprotocol.Connected
	}, time.Second, 10*time.Millisecond, "manager should connect initially")

	// Configure heartbeat to require handshake
	s.mockResponder.Behaviour().HeartbeatResponse.Error = nodes.NewErrHandshakeRequired("test-node")

	// Should disconnect quickly after handshake required error.
	// NOTE: the health value formatted into the failure message is captured
	// when this call is made, i.e. before polling starts.
	s.Require().Eventually(func() bool {
		health := s.manager.GetHealth()
		return health.CurrentState == nclprotocol.Disconnected &&
			health.LastError != nil &&
			strings.Contains(health.LastError.Error(), "handshake required")
	}, 1*time.Second, 5*time.Millisecond, "should disconnect due to handshake required: %+v", s.manager.GetHealth())

	// Reset heartbeat response to succeed
	s.mockResponder.Behaviour().HeartbeatResponse.Error = nil

	// Should automatically attempt reconnection
	s.Require().Eventually(func() bool {
		// Get new handshakes after disconnect
		handshakes := s.mockResponder.GetHandshakes()
		return len(handshakes) > 1 // More than initial handshake
	}, time.Second, 10*time.Millisecond, "should attempt reconnection")

	// Should successfully reconnect
	s.Require().Eventually(func() bool {
		health := s.manager.GetHealth()
		return health.CurrentState == nclprotocol.Connected &&
			!health.HandshakeRequired // Should be cleared after successful connection
	}, time.Second, 10*time.Millisecond, "should reconnect successfully")

	// Verify heartbeats resume. The responder already holds heartbeats from
	// before the disconnect, so a plain non-empty check would always pass;
	// require the count to grow past what was recorded at reconnect time.
	heartbeatsAtReconnect := len(s.mockResponder.GetHeartbeats())
	s.Require().Eventually(func() bool {
		return len(s.mockResponder.GetHeartbeats()) > heartbeatsAtReconnect
	}, time.Second, 10*time.Millisecond, "should resume heartbeats after reconnection")
}

// TestHeartbeatHandshakeRequiredDifferentError checks that a heartbeat error
// which merely mentions "handshake" does not set the HandshakeRequired flag,
// and that the manager still ends up disconnected afterwards.
func (s *ConnectionManagerTestSuite) TestHeartbeatHandshakeRequiredDifferentError() {
	s.Require().NoError(s.manager.Start(s.ctx))

	// Block until the initial connection is established
	isConnected := func() bool {
		return s.manager.GetHealth().CurrentState == nclprotocol.Connected
	}
	s.Require().Eventually(isConnected, time.Second, 10*time.Millisecond)

	// Heartbeats now fail with an error that talks about handshakes but is
	// not the specific "handshake required" error
	s.mockResponder.Behaviour().HeartbeatResponse.Error = fmt.Errorf("failed to process handshake data")

	// Let a couple of heartbeat intervals pass; the flag must stay unset
	time.Sleep(s.config.HeartbeatInterval * 2)
	snapshot := s.manager.GetHealth()
	s.False(snapshot.HandshakeRequired, "should not set handshake required for different errors")

	// Give the missed-heartbeat window time to elapse, then expect a
	// disconnect that still has the flag unset
	missedIntervals := time.Duration(s.config.HeartbeatMissFactor + 1)
	time.Sleep(s.config.HeartbeatInterval * missedIntervals)
	s.Eventually(func() bool {
		current := s.manager.GetHealth()
		if current.CurrentState != nclprotocol.Disconnected {
			return false
		}
		return !current.HandshakeRequired // Should not be set
	}, time.Second, 10*time.Millisecond)
}

func (s *ConnectionManagerTestSuite) TestNodeInfoUpdates() {
// Configure heartbeat callback to trigger node info updates
s.mockResponder.Behaviour().OnHeartbeat = func(req messages.HeartbeatRequest) {
Expand Down
1 change: 1 addition & 0 deletions pkg/transport/nclprotocol/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ type ConnectionHealth struct {
ConsecutiveFailures int
LastError error
ConnectedSince time.Time
HandshakeRequired bool
}

const (
Expand Down

0 comments on commit 69b1b42

Please sign in to comment.