diff --git a/pkg/transport/nclprotocol/compute/controlplane.go b/pkg/transport/nclprotocol/compute/controlplane.go index 845af7ba5c..5a55fac937 100644 --- a/pkg/transport/nclprotocol/compute/controlplane.go +++ b/pkg/transport/nclprotocol/compute/controlplane.go @@ -3,6 +3,7 @@ package compute import ( "context" "fmt" + "strings" "sync" "time" @@ -104,6 +105,10 @@ func (cp *ControlPlane) run(ctx context.Context) { case <-heartbeat.C: if err := cp.heartbeat(ctx); err != nil { + if strings.Contains(err.Error(), "handshake required") { + cp.healthTracker.HandshakeRequired() + return + } log.Error().Err(err).Msg("Failed to send heartbeat") } case <-nodeInfo.C: diff --git a/pkg/transport/nclprotocol/compute/controlplane_test.go b/pkg/transport/nclprotocol/compute/controlplane_test.go index a76944c278..bd3cd4cc70 100644 --- a/pkg/transport/nclprotocol/compute/controlplane_test.go +++ b/pkg/transport/nclprotocol/compute/controlplane_test.go @@ -15,6 +15,7 @@ import ( "github.com/bacalhau-project/bacalhau/pkg/lib/envelope" "github.com/bacalhau-project/bacalhau/pkg/lib/ncl" "github.com/bacalhau-project/bacalhau/pkg/models/messages" + "github.com/bacalhau-project/bacalhau/pkg/orchestrator/nodes" "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol" nclprotocolcompute "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/compute" ncltest "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/test" @@ -179,6 +180,72 @@ func (s *ControlPlaneTestSuite) TestHeartbeat() { }, 100*time.Millisecond, 10*time.Millisecond, "Heartbeat did not succeed") } +func (s *ControlPlaneTestSuite) TestHeartbeatFailFastOnHandshakeRequired() { + // Create control plane with only heartbeat enabled and short intervals + controlPlane := s.createControlPlane( + 50*time.Millisecond, // heartbeat + 1*time.Hour, // node info - disabled + 1*time.Hour, // checkpoint - disabled + ) + defer s.Require().NoError(controlPlane.Stop(s.ctx)) + + // Setup handshake required error response + s.requester.EXPECT(). + Request(gomock.Any(), gomock.Any()). + Return(nil, nodes.NewErrHandshakeRequired("test-node")). + Times(1) // Should only try once + + // Start control plane + s.Require().NoError(controlPlane.Start(s.ctx)) + + // Wait a bit to allow for heartbeat attempt + time.Sleep(50 * time.Millisecond) + + // wait health tracker state + s.Require().Eventually(func() bool { + return s.healthTracker.GetHealth().HandshakeRequired + }, 100*time.Millisecond, 10*time.Millisecond, "Heartbeat did not succeed") + + // Verify health tracker state + health := s.healthTracker.GetHealth() + s.True(health.HandshakeRequired, "handshake should be marked as required") + s.Zero(health.LastSuccessfulHeartbeat, "no successful heartbeat should be recorded") + + // Wait for another heartbeat interval to verify the loop has stopped + time.Sleep(70 * time.Millisecond) + + s.Require().Eventually(func() bool { + return s.healthTracker.GetHealth().CurrentState == nclprotocol.Disconnected + }, 100*time.Millisecond, 10*time.Millisecond, "connection not marked as disconnected") +} + +func (s *ControlPlaneTestSuite) TestHeartbeatContinuesOnOtherErrors() { + // Create control plane with only heartbeat enabled + controlPlane := s.createControlPlane( + 50*time.Millisecond, // heartbeat + 1*time.Hour, // node info - disabled + 1*time.Hour, // checkpoint - disabled + ) + defer s.Require().NoError(controlPlane.Stop(s.ctx)) + + // Setup regular error response that should not cause fail-fast + s.requester.EXPECT(). + Request(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("network error")). + Times(2) // Should keep trying + + // Start control plane + s.Require().NoError(controlPlane.Start(s.ctx)) + + // Wait for two heartbeat attempts + time.Sleep(120 * time.Millisecond) + + // Verify health tracker state + health := s.healthTracker.GetHealth() + s.False(health.HandshakeRequired, "handshake should not be marked as required") + s.Zero(health.LastSuccessfulHeartbeat, "no successful heartbeat should be recorded") +} + func (s *ControlPlaneTestSuite) TestNodeInfoUpdate() { // Create control plane with only checkpointing enabled controlPlane := s.createControlPlane( diff --git a/pkg/transport/nclprotocol/compute/health_tracker.go b/pkg/transport/nclprotocol/compute/health_tracker.go index 873bb8bedf..2234dadfd1 100644 --- a/pkg/transport/nclprotocol/compute/health_tracker.go +++ b/pkg/transport/nclprotocol/compute/health_tracker.go @@ -36,6 +36,7 @@ func (ht *HealthTracker) MarkConnected() { ht.health.LastSuccessfulHeartbeat = ht.clock.Now() ht.health.ConsecutiveFailures = 0 ht.health.LastError = nil + ht.health.HandshakeRequired = false } // MarkDisconnected updates status when connection is lost @@ -46,6 +47,7 @@ func (ht *HealthTracker) MarkDisconnected(err error) { ht.health.CurrentState = nclprotocol.Disconnected ht.health.LastError = err ht.health.ConsecutiveFailures++ + ht.health.HandshakeRequired = false } // MarkConnecting update status when connection is in progress @@ -54,6 +56,7 @@ func (ht *HealthTracker) MarkConnecting() { defer ht.mu.Unlock() ht.health.CurrentState = nclprotocol.Connecting + ht.health.HandshakeRequired = false } // HeartbeatSuccess records successful heartbeat @@ -70,6 +73,13 @@ func (ht *HealthTracker) UpdateSuccess() { ht.health.LastSuccessfulUpdate = ht.clock.Now() } +// HandshakeRequired marks that a handshake is required +func (ht *HealthTracker) HandshakeRequired() { + ht.mu.Lock() + defer ht.mu.Unlock() + ht.health.HandshakeRequired = true +} + // GetState returns current connection state func (ht *HealthTracker) GetState() nclprotocol.ConnectionState { ht.mu.RLock() @@ -83,3 +93,10 @@ func (ht *HealthTracker) GetHealth() nclprotocol.ConnectionHealth { defer ht.mu.RUnlock() return ht.health } + +// IsHandshakeRequired returns true if a handshake is required +func (ht *HealthTracker) IsHandshakeRequired() bool { + ht.mu.RLock() + defer ht.mu.RUnlock() + return ht.health.HandshakeRequired +} diff --git a/pkg/transport/nclprotocol/compute/health_tracker_test.go b/pkg/transport/nclprotocol/compute/health_tracker_test.go index 9fe680d0ad..e118d3494e 100644 --- a/pkg/transport/nclprotocol/compute/health_tracker_test.go +++ b/pkg/transport/nclprotocol/compute/health_tracker_test.go @@ -39,9 +39,13 @@ func (s *HealthTrackerTestSuite) TestInitialState() { s.Require().Equal(0, health.ConsecutiveFailures) s.Require().Nil(health.LastError) s.Require().True(health.ConnectedSince.IsZero()) + s.Require().False(health.HandshakeRequired) // Verify initial state of HandshakeRequired } func (s *HealthTrackerTestSuite) TestMarkConnected() { + // First mark handshake required + s.tracker.HandshakeRequired() + // Advance clock to have distinct timestamps s.clock.Add(time.Second) connectedTime := s.clock.Now() @@ -54,6 +58,7 @@ func (s *HealthTrackerTestSuite) TestMarkConnected() { s.Require().Equal(connectedTime, health.LastSuccessfulHeartbeat) s.Require().Equal(0, health.ConsecutiveFailures) s.Require().Nil(health.LastError) + s.Require().False(health.HandshakeRequired) // Should be reset when connected } func (s *HealthTrackerTestSuite) TestMarkDisconnected() { @@ -68,6 +73,7 @@ func (s *HealthTrackerTestSuite) TestMarkDisconnected() { s.Require().Equal(nclprotocol.Disconnected, health.CurrentState) s.Require().Equal(expectedErr, health.LastError) s.Require().Equal(1, health.ConsecutiveFailures) + s.Require().False(health.HandshakeRequired) // Should still be false after disconnect // Multiple disconnections should increment failure count s.tracker.MarkDisconnected(expectedErr) @@ -75,6 +81,31 @@ func (s *HealthTrackerTestSuite) TestMarkDisconnected() { s.Require().Equal(2, health.ConsecutiveFailures) } +func (s *HealthTrackerTestSuite) TestHandshakeRequired() { + // Initially handshake should not be required + s.Require().False(s.tracker.IsHandshakeRequired()) + + // Mark handshake as required + s.tracker.HandshakeRequired() + s.Require().True(s.tracker.IsHandshakeRequired()) + + // Verify it's cleared when connected + s.tracker.MarkConnected() + s.Require().False(s.tracker.IsHandshakeRequired()) + + // Verify it's cleared when disconnected + s.tracker.HandshakeRequired() + s.Require().True(s.tracker.IsHandshakeRequired()) + s.tracker.MarkDisconnected(fmt.Errorf("error")) + s.Require().False(s.tracker.IsHandshakeRequired()) + + // Verify it's cleared when connecting + s.tracker.HandshakeRequired() + s.Require().True(s.tracker.IsHandshakeRequired()) + s.tracker.MarkConnecting() + s.Require().False(s.tracker.IsHandshakeRequired()) +} + func (s *HealthTrackerTestSuite) TestSuccessfulOperations() { // Initial timestamps s.clock.Add(time.Second) diff --git a/pkg/transport/nclprotocol/compute/manager.go b/pkg/transport/nclprotocol/compute/manager.go index addee100d8..d93b85a382 100644 --- a/pkg/transport/nclprotocol/compute/manager.go +++ b/pkg/transport/nclprotocol/compute/manager.go @@ -431,6 +431,7 @@ func (cm *ConnectionManager) checkConnectionHealth() { // Consider connection unhealthy if: // 1. No heartbeat succeeded within HeartbeatMissFactor intervals // 2. NATS connection is closed/draining + // 3. Health tracker reports a handshake required now := cm.config.Clock.Now() heartbeatDeadline := now.Add(-time.Duration(cm.config.HeartbeatMissFactor) * cm.config.HeartbeatInterval) @@ -443,6 +444,9 @@ func (cm *ConnectionManager) checkConnectionHealth() { } else if cm.natsConn.IsClosed() { reason = "NATS connection closed" unhealthy = true + } else if cm.healthTracker.IsHandshakeRequired() { + reason = "handshake required" + unhealthy = true } if unhealthy { diff --git a/pkg/transport/nclprotocol/compute/manager_test.go b/pkg/transport/nclprotocol/compute/manager_test.go index 7a7926ff5c..769983969c 100644 --- a/pkg/transport/nclprotocol/compute/manager_test.go +++ b/pkg/transport/nclprotocol/compute/manager_test.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "reflect" + "strings" "testing" "time" @@ -18,6 +19,7 @@ import ( "github.com/bacalhau-project/bacalhau/pkg/models" "github.com/bacalhau-project/bacalhau/pkg/models/messages" natsutil "github.com/bacalhau-project/bacalhau/pkg/nats" + "github.com/bacalhau-project/bacalhau/pkg/orchestrator/nodes" testutils "github.com/bacalhau-project/bacalhau/pkg/test/utils" "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol" nclprotocolcompute "github.com/bacalhau-project/bacalhau/pkg/transport/nclprotocol/compute" @@ -239,6 +241,77 @@ func (s *ConnectionManagerTestSuite) TestHeartbeatFailure() { }, time.Second, 10*time.Millisecond) } +func (s *ConnectionManagerTestSuite) TestHeartbeatHandshakeRequired() { + err := s.manager.Start(s.ctx) + s.Require().NoError(err) + + // Wait for initial connection + s.Require().Eventually(func() bool { + health := s.manager.GetHealth() + return health.CurrentState == nclprotocol.Connected + }, time.Second, 10*time.Millisecond, "manager should connect initially") + + // Configure heartbeat to require handshake + s.mockResponder.Behaviour().HeartbeatResponse.Error = nodes.NewErrHandshakeRequired("test-node") + + // Should disconnect quickly after handshake required error + s.Require().Eventually(func() bool { + health := s.manager.GetHealth() + return health.CurrentState == nclprotocol.Disconnected && + health.LastError != nil && + strings.Contains(health.LastError.Error(), "handshake required") + }, 1*time.Second, 5*time.Millisecond, "should disconnect due to handshake required: %+v", s.manager.GetHealth()) + + // Reset heartbeat response to succeed + s.mockResponder.Behaviour().HeartbeatResponse.Error = nil + + // Should automatically attempt reconnection + s.Require().Eventually(func() bool { + // Get new handshakes after disconnect + handshakes := s.mockResponder.GetHandshakes() + return len(handshakes) > 1 // More than initial handshake + }, time.Second, 10*time.Millisecond, "should attempt reconnection") + + // Should successfully reconnect + s.Require().Eventually(func() bool { + health := s.manager.GetHealth() + return health.CurrentState == nclprotocol.Connected && + !health.HandshakeRequired // Should be cleared after successful connection + }, time.Second, 10*time.Millisecond, "should reconnect successfully") + + // Verify heartbeats resume + time.Sleep(s.config.HeartbeatInterval * 2) + heartbeats := s.mockResponder.GetHeartbeats() + s.Require().NotEmpty(heartbeats, "should resume heartbeats after reconnection") +} + +func (s *ConnectionManagerTestSuite) TestHeartbeatHandshakeRequiredDifferentError() { + err := s.manager.Start(s.ctx) + s.Require().NoError(err) + + // Wait for initial connection + s.Require().Eventually(func() bool { + health := s.manager.GetHealth() + return health.CurrentState == nclprotocol.Connected + }, time.Second, 10*time.Millisecond) + + // Configure heartbeat with error that mentions handshake but isn't the specific error + s.mockResponder.Behaviour().HeartbeatResponse.Error = fmt.Errorf("failed to process handshake data") + + // Wait some heartbeat intervals - should not immediately disconnect + time.Sleep(s.config.HeartbeatInterval * 2) + health := s.manager.GetHealth() + s.False(health.HandshakeRequired, "should not set handshake required for different errors") + + // Should eventually disconnect due to missed heartbeats + time.Sleep(s.config.HeartbeatInterval * time.Duration(s.config.HeartbeatMissFactor+1)) + s.Eventually(func() bool { + health := s.manager.GetHealth() + return health.CurrentState == nclprotocol.Disconnected && + !health.HandshakeRequired // Should not be set + }, time.Second, 10*time.Millisecond) +} + func (s *ConnectionManagerTestSuite) TestNodeInfoUpdates() { // Configure heartbeat callback to trigger node info updates s.mockResponder.Behaviour().OnHeartbeat = func(req messages.HeartbeatRequest) { diff --git a/pkg/transport/nclprotocol/types.go b/pkg/transport/nclprotocol/types.go index 358dc5093e..8ce6b637d6 100644 --- a/pkg/transport/nclprotocol/types.go +++ b/pkg/transport/nclprotocol/types.go @@ -50,6 +50,7 @@ type ConnectionHealth struct { ConsecutiveFailures int LastError error ConnectedSince time.Time + HandshakeRequired bool } const (