Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster reconnect on handshake required response #4772

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pkg/transport/nclprotocol/compute/controlplane.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package compute
import (
"context"
"fmt"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -104,6 +105,10 @@ func (cp *ControlPlane) run(ctx context.Context) {

case <-heartbeat.C:
if err := cp.heartbeat(ctx); err != nil {
if strings.Contains(err.Error(), "handshake required") {
cp.healthTracker.HandshakeRequired()
return
}
Comment on lines +108 to +111
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Consider using error type assertion instead of string matching.

The current implementation uses string matching to detect handshake required errors, which can be fragile if error messages change. Consider using type assertion or error wrapping to make the error handling more robust.

-if strings.Contains(err.Error(), "handshake required") {
+var handshakeErr *nodes.ErrHandshakeRequired
+if errors.As(err, &handshakeErr) {
   cp.healthTracker.HandshakeRequired()
   return
}

Committable suggestion skipped: line range outside the PR's diff.

log.Error().Err(err).Msg("Failed to send heartbeat")
}
case <-nodeInfo.C:
Expand Down
67 changes: 67 additions & 0 deletions pkg/transport/nclprotocol/compute/controlplane_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/bacalhau-project/bacalhau/pkg/lib/envelope"
"github.com/bacalhau-project/bacalhau/pkg/lib/ncl"
"github.com/bacalhau-project/bacalhau/pkg/models/messages"
"github.com/bacalhau-project/bacalhau/pkg/orchestrator/nodes"
"github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol"
nclprotocolcompute "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/compute"
ncltest "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/test"
Expand Down Expand Up @@ -179,6 +180,72 @@ func (s *ControlPlaneTestSuite) TestHeartbeat() {
}, 100*time.Millisecond, 10*time.Millisecond, "Heartbeat did not succeed")
}

// TestHeartbeatFailFastOnHandshakeRequired verifies that a heartbeat rejected
// with a handshake-required error is flagged on the health tracker and that
// the control plane does not issue any further heartbeat requests.
func (s *ControlPlaneTestSuite) TestHeartbeatFailFastOnHandshakeRequired() {
	// Create control plane with only heartbeat enabled and short intervals
	controlPlane := s.createControlPlane(
		50*time.Millisecond, // heartbeat
		1*time.Hour,         // node info - disabled
		1*time.Hour,         // checkpoint - disabled
	)
	// The closure is required: arguments of a deferred call are evaluated when
	// the defer statement executes, so deferring NoError(controlPlane.Stop(...))
	// directly would call Stop right here, before the control plane is started.
	defer func() {
		s.Require().NoError(controlPlane.Stop(s.ctx))
	}()

	// Setup handshake required error response
	s.requester.EXPECT().
		Request(gomock.Any(), gomock.Any()).
		Return(nil, nodes.NewErrHandshakeRequired("test-node")).
		Times(1) // Should only try once

	// Start control plane
	s.Require().NoError(controlPlane.Start(s.ctx))

	// Wait a bit to allow for heartbeat attempt
	time.Sleep(50 * time.Millisecond)

	// Wait for the health tracker to report that a handshake is required
	s.Require().Eventually(func() bool {
		return s.healthTracker.GetHealth().HandshakeRequired
	}, 100*time.Millisecond, 10*time.Millisecond, "handshake required was not flagged")

	// Verify health tracker state
	health := s.healthTracker.GetHealth()
	s.True(health.HandshakeRequired, "handshake should be marked as required")
	s.Zero(health.LastSuccessfulHeartbeat, "no successful heartbeat should be recorded")

	// Wait for another heartbeat interval to verify the loop has stopped;
	// a second request would violate the Times(1) expectation above.
	time.Sleep(70 * time.Millisecond)

	s.Require().Eventually(func() bool {
		return s.healthTracker.GetHealth().CurrentState == nclprotocol.Disconnected
	}, 100*time.Millisecond, 10*time.Millisecond, "connection not marked as disconnected")
}

// TestHeartbeatContinuesOnOtherErrors verifies that a heartbeat failure that is
// not a handshake-required error does not set the handshake-required flag and
// that the control plane keeps sending heartbeats afterwards.
func (s *ControlPlaneTestSuite) TestHeartbeatContinuesOnOtherErrors() {
	// Create control plane with only heartbeat enabled
	controlPlane := s.createControlPlane(
		50*time.Millisecond, // heartbeat
		1*time.Hour,         // node info - disabled
		1*time.Hour,         // checkpoint - disabled
	)
	// The closure is required: arguments of a deferred call are evaluated when
	// the defer statement executes, so deferring NoError(controlPlane.Stop(...))
	// directly would call Stop right here, before the control plane is started.
	defer func() {
		s.Require().NoError(controlPlane.Stop(s.ctx))
	}()

	// Setup regular error response that should not cause fail-fast.
	// MinTimes rather than Times: the loop is expected to keep retrying, so an
	// extra tick caused by a slow scheduler must not fail the test.
	s.requester.EXPECT().
		Request(gomock.Any(), gomock.Any()).
		Return(nil, fmt.Errorf("network error")).
		MinTimes(2) // Should keep trying

	// Start control plane
	s.Require().NoError(controlPlane.Start(s.ctx))

	// Wait for two heartbeat attempts
	time.Sleep(120 * time.Millisecond)

	// Verify health tracker state
	health := s.healthTracker.GetHealth()
	s.False(health.HandshakeRequired, "handshake should not be marked as required")
	s.Zero(health.LastSuccessfulHeartbeat, "no successful heartbeat should be recorded")
}

func (s *ControlPlaneTestSuite) TestNodeInfoUpdate() {
// Create control plane with only checkpointing enabled
controlPlane := s.createControlPlane(
Expand Down
17 changes: 17 additions & 0 deletions pkg/transport/nclprotocol/compute/health_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func (ht *HealthTracker) MarkConnected() {
ht.health.LastSuccessfulHeartbeat = ht.clock.Now()
ht.health.ConsecutiveFailures = 0
ht.health.LastError = nil
ht.health.HandshakeRequired = false
}

// MarkDisconnected updates status when connection is lost
Expand All @@ -46,6 +47,7 @@ func (ht *HealthTracker) MarkDisconnected(err error) {
ht.health.CurrentState = nclprotocol.Disconnected
ht.health.LastError = err
ht.health.ConsecutiveFailures++
ht.health.HandshakeRequired = false
}

// MarkConnecting updates status when connection is in progress
Expand All @@ -54,6 +56,7 @@ func (ht *HealthTracker) MarkConnecting() {
defer ht.mu.Unlock()

ht.health.CurrentState = nclprotocol.Connecting
ht.health.HandshakeRequired = false
}

// HeartbeatSuccess records successful heartbeat
Expand All @@ -70,6 +73,13 @@ func (ht *HealthTracker) UpdateSuccess() {
ht.health.LastSuccessfulUpdate = ht.clock.Now()
}

// HandshakeRequired sets the HandshakeRequired flag on the tracked health
// while holding the tracker's write lock.
func (ht *HealthTracker) HandshakeRequired() {
	ht.mu.Lock()
	ht.health.HandshakeRequired = true
	ht.mu.Unlock()
}

// GetState returns current connection state
func (ht *HealthTracker) GetState() nclprotocol.ConnectionState {
ht.mu.RLock()
Expand All @@ -83,3 +93,10 @@ func (ht *HealthTracker) GetHealth() nclprotocol.ConnectionHealth {
defer ht.mu.RUnlock()
return ht.health
}

// IsHandshakeRequired reports the current value of the HandshakeRequired flag,
// read under the tracker's read lock.
func (ht *HealthTracker) IsHandshakeRequired() bool {
	ht.mu.RLock()
	required := ht.health.HandshakeRequired
	ht.mu.RUnlock()
	return required
}
31 changes: 31 additions & 0 deletions pkg/transport/nclprotocol/compute/health_tracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,13 @@ func (s *HealthTrackerTestSuite) TestInitialState() {
s.Require().Equal(0, health.ConsecutiveFailures)
s.Require().Nil(health.LastError)
s.Require().True(health.ConnectedSince.IsZero())
s.Require().False(health.HandshakeRequired) // Verify initial state of HandshakeRequired
}

func (s *HealthTrackerTestSuite) TestMarkConnected() {
// First mark handshake required
s.tracker.HandshakeRequired()

// Advance clock to have distinct timestamps
s.clock.Add(time.Second)
connectedTime := s.clock.Now()
Expand All @@ -54,6 +58,7 @@ func (s *HealthTrackerTestSuite) TestMarkConnected() {
s.Require().Equal(connectedTime, health.LastSuccessfulHeartbeat)
s.Require().Equal(0, health.ConsecutiveFailures)
s.Require().Nil(health.LastError)
s.Require().False(health.HandshakeRequired) // Should be reset when connected
}

func (s *HealthTrackerTestSuite) TestMarkDisconnected() {
Expand All @@ -68,13 +73,39 @@ func (s *HealthTrackerTestSuite) TestMarkDisconnected() {
s.Require().Equal(nclprotocol.Disconnected, health.CurrentState)
s.Require().Equal(expectedErr, health.LastError)
s.Require().Equal(1, health.ConsecutiveFailures)
s.Require().False(health.HandshakeRequired) // Should still be false after disconnect

// Multiple disconnections should increment failure count
s.tracker.MarkDisconnected(expectedErr)
health = s.tracker.GetHealth()
s.Require().Equal(2, health.ConsecutiveFailures)
}

// TestHandshakeRequired checks that the handshake-required flag starts out
// unset, can be raised, and is cleared again by MarkConnected,
// MarkDisconnected and MarkConnecting.
func (s *HealthTrackerTestSuite) TestHandshakeRequired() {
	// Nothing has been flagged yet.
	s.Require().False(s.tracker.IsHandshakeRequired())

	// State transitions that must each clear a previously raised flag,
	// exercised in this order: connected, disconnected, connecting.
	transitions := []func(){
		func() { s.tracker.MarkConnected() },
		func() { s.tracker.MarkDisconnected(fmt.Errorf("error")) },
		func() { s.tracker.MarkConnecting() },
	}

	for _, transition := range transitions {
		// Raise the flag and confirm it is visible.
		s.tracker.HandshakeRequired()
		s.Require().True(s.tracker.IsHandshakeRequired())

		// The transition must reset it.
		transition()
		s.Require().False(s.tracker.IsHandshakeRequired())
	}
}

func (s *HealthTrackerTestSuite) TestSuccessfulOperations() {
// Initial timestamps
s.clock.Add(time.Second)
Expand Down
4 changes: 4 additions & 0 deletions pkg/transport/nclprotocol/compute/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ func (cm *ConnectionManager) checkConnectionHealth() {
// Consider connection unhealthy if:
// 1. No heartbeat succeeded within HeartbeatMissFactor intervals
// 2. NATS connection is closed/draining
// 3. Health tracker reports a handshake required
now := cm.config.Clock.Now()
heartbeatDeadline := now.Add(-time.Duration(cm.config.HeartbeatMissFactor) * cm.config.HeartbeatInterval)

Expand All @@ -443,6 +444,9 @@ func (cm *ConnectionManager) checkConnectionHealth() {
} else if cm.natsConn.IsClosed() {
reason = "NATS connection closed"
unhealthy = true
} else if cm.healthTracker.IsHandshakeRequired() {
reason = "handshake required"
unhealthy = true
}

if unhealthy {
Expand Down
73 changes: 73 additions & 0 deletions pkg/transport/nclprotocol/compute/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
"fmt"
"reflect"
"strings"
"testing"
"time"

Expand All @@ -18,6 +19,7 @@ import (
"github.com/bacalhau-project/bacalhau/pkg/models"
"github.com/bacalhau-project/bacalhau/pkg/models/messages"
natsutil "github.com/bacalhau-project/bacalhau/pkg/nats"
"github.com/bacalhau-project/bacalhau/pkg/orchestrator/nodes"
testutils "github.com/bacalhau-project/bacalhau/pkg/test/utils"
"github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol"
nclprotocolcompute "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/compute"
Expand Down Expand Up @@ -239,6 +241,77 @@ func (s *ConnectionManagerTestSuite) TestHeartbeatFailure() {
}, time.Second, 10*time.Millisecond)
}

// TestHeartbeatHandshakeRequired verifies that a handshake-required heartbeat
// error makes the manager disconnect, perform a new handshake, reconnect, and
// send further heartbeats once connected again.
func (s *ConnectionManagerTestSuite) TestHeartbeatHandshakeRequired() {
	err := s.manager.Start(s.ctx)
	s.Require().NoError(err)

	// Wait for initial connection
	s.Require().Eventually(func() bool {
		health := s.manager.GetHealth()
		return health.CurrentState == nclprotocol.Connected
	}, time.Second, 10*time.Millisecond, "manager should connect initially")

	// Configure heartbeat to require handshake
	s.mockResponder.Behaviour().HeartbeatResponse.Error = nodes.NewErrHandshakeRequired("test-node")

	// Should disconnect quickly after handshake required error.
	// NOTE: the health value in the failure message is captured when this call
	// is made (arguments are evaluated eagerly), i.e. before the wait begins.
	s.Require().Eventually(func() bool {
		health := s.manager.GetHealth()
		return health.CurrentState == nclprotocol.Disconnected &&
			health.LastError != nil &&
			strings.Contains(health.LastError.Error(), "handshake required")
	}, 1*time.Second, 5*time.Millisecond, "should disconnect due to handshake required: %+v", s.manager.GetHealth())

	// Reset heartbeat response to succeed
	s.mockResponder.Behaviour().HeartbeatResponse.Error = nil

	// Should automatically attempt reconnection
	s.Require().Eventually(func() bool {
		// Get new handshakes after disconnect
		handshakes := s.mockResponder.GetHandshakes()
		return len(handshakes) > 1 // More than initial handshake
	}, time.Second, 10*time.Millisecond, "should attempt reconnection")

	// Should successfully reconnect
	s.Require().Eventually(func() bool {
		health := s.manager.GetHealth()
		return health.CurrentState == nclprotocol.Connected &&
			!health.HandshakeRequired // Should be cleared after successful connection
	}, time.Second, 10*time.Millisecond, "should reconnect successfully")

	// Verify heartbeats resume. Heartbeats were already sent before the
	// disconnect, so a plain non-empty check would presumably pass even if
	// nothing is sent after reconnecting (assumes the responder accumulates
	// them — TODO confirm). Compare against a baseline taken after the
	// reconnect instead.
	baseline := len(s.mockResponder.GetHeartbeats())
	s.Require().Eventually(func() bool {
		return len(s.mockResponder.GetHeartbeats()) > baseline
	}, time.Second+2*s.config.HeartbeatInterval, 10*time.Millisecond,
		"should resume heartbeats after reconnection")
}

// TestHeartbeatHandshakeRequiredDifferentError checks that a heartbeat error
// which merely mentions "handshake" does not set the handshake-required flag,
// and that the manager still ends up disconnected afterwards.
func (s *ConnectionManagerTestSuite) TestHeartbeatHandshakeRequiredDifferentError() {
	s.Require().NoError(s.manager.Start(s.ctx))

	// Block until the first connection is established.
	s.Require().Eventually(func() bool {
		return s.manager.GetHealth().CurrentState == nclprotocol.Connected
	}, time.Second, 10*time.Millisecond)

	// Fail heartbeats with an error that talks about handshakes but is not
	// the dedicated handshake-required error.
	s.mockResponder.Behaviour().HeartbeatResponse.Error = fmt.Errorf("failed to process handshake data")

	// Let a couple of heartbeat intervals pass; the flag must stay unset.
	interval := s.config.HeartbeatInterval
	time.Sleep(interval * 2)
	s.False(s.manager.GetHealth().HandshakeRequired, "should not set handshake required for different errors")

	// After enough missed heartbeats the manager is expected to disconnect,
	// still without the handshake-required flag.
	time.Sleep(interval * time.Duration(s.config.HeartbeatMissFactor+1))
	s.Eventually(func() bool {
		current := s.manager.GetHealth()
		if current.CurrentState != nclprotocol.Disconnected {
			return false
		}
		return !current.HandshakeRequired // Should not be set
	}, time.Second, 10*time.Millisecond)
}

func (s *ConnectionManagerTestSuite) TestNodeInfoUpdates() {
// Configure heartbeat callback to trigger node info updates
s.mockResponder.Behaviour().OnHeartbeat = func(req messages.HeartbeatRequest) {
Expand Down
1 change: 1 addition & 0 deletions pkg/transport/nclprotocol/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ type ConnectionHealth struct {
ConsecutiveFailures int
LastError error
ConnectedSince time.Time
HandshakeRequired bool
}

const (
Expand Down
Loading