Skip to content

Commit

Permalink
always use orchestrator seqNum during handshake (#4740)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
	- Enhanced node management with improved event storage capabilities.
	- New methods added to validate handshake sequence number logic.
	- Updated dispatcher setup process to refine event handling.

- **Bug Fixes**
- Improved error handling during node manager initialization and
dispatcher setup.

- **Tests**
- Expanded test suite to cover edge cases in handshake sequence number
logic and event storage.
- Added tests for handshake sequence number logic and concurrent
operations.

- **Documentation**
- Updated method signatures to reflect new parameters and
functionalities.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
wdbaruni authored Dec 11, 2024
1 parent 04ec16c commit 559e78a
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 11 deletions.
11 changes: 6 additions & 5 deletions pkg/node/requester.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,18 +65,18 @@ func NewRequesterNode(
transportLayer *nats_transport.NATSTransport,
metadataStore MetadataStore,
) (*Requester, error) {
natsConn, err := transportLayer.CreateClient(ctx)
jobStore, err := createJobStore(ctx, cfg)
if err != nil {
return nil, err
}

nodeID := cfg.NodeID
nodesManager, nodeStore, err := createNodeManager(ctx, cfg, natsConn)
natsConn, err := transportLayer.CreateClient(ctx)
if err != nil {
return nil, err
}

jobStore, err := createJobStore(ctx, cfg)
nodeID := cfg.NodeID
nodesManager, nodeStore, err := createNodeManager(ctx, cfg, jobStore.GetEventStore(), natsConn)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -371,7 +371,7 @@ func createJobStore(ct context.Context, cfg NodeConfig) (jobstore.Store, error)
return jobStore, nil
}

func createNodeManager(ctx context.Context, cfg NodeConfig, natsConn *nats.Conn) (
func createNodeManager(ctx context.Context, cfg NodeConfig, eventStore watcher.EventStore, natsConn *nats.Conn) (
nodes.Manager, nodes.Store, error) {
nodeInfoStore, err := kvstore.NewNodeStore(ctx, kvstore.NodeStoreParams{
BucketName: kvstore.BucketNameCurrent,
Expand All @@ -385,6 +385,7 @@ func createNodeManager(ctx context.Context, cfg NodeConfig, natsConn *nats.Conn)
Store: nodeInfoStore,
NodeDisconnectedAfter: cfg.BacalhauConfig.Orchestrator.NodeManager.DisconnectTimeout.AsTimeDuration(),
ManualApproval: cfg.BacalhauConfig.Orchestrator.NodeManager.ManualApproval,
EventStore: eventStore,
})

if err != nil {
Expand Down
42 changes: 36 additions & 6 deletions pkg/orchestrator/nodes/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@ package nodes

import (
"context"
"errors"
"fmt"
"sync"
"time"

"github.com/benbjohnson/clock"
"github.com/rs/zerolog/log"

"github.com/bacalhau-project/bacalhau/pkg/bacerrors"
"github.com/bacalhau-project/bacalhau/pkg/lib/validate"
"github.com/bacalhau-project/bacalhau/pkg/lib/watcher"
"github.com/bacalhau-project/bacalhau/pkg/models"
"github.com/bacalhau-project/bacalhau/pkg/models/messages"
)
Expand Down Expand Up @@ -42,8 +46,9 @@ const (
// state persistence with configurable intervals.
type nodesManager struct {
// Core dependencies
store Store // Persistent storage for node states
clock clock.Clock // Time source (can be mocked for testing)
store Store // Persistent storage for node states
eventstore watcher.EventStore // Store for events
clock clock.Clock // Time source (can be mocked for testing)

// Configuration
defaultApprovalState models.NodeMembershipState // Initial membership state for new nodes
Expand Down Expand Up @@ -94,6 +99,10 @@ type ManagerParams struct {

// ShutdownTimeout is the timeout for graceful shutdown (optional)
ShutdownTimeout time.Duration

// EventStore provides storage for events so that node manager can assign
// new nodes with latest sequence number in the store
EventStore watcher.EventStore
}

// trackedLiveState holds the runtime state for an active node.
Expand Down Expand Up @@ -138,8 +147,16 @@ func NewManager(params ManagerParams) (Manager, error) {
params.ShutdownTimeout = defaultShutdownTimeout
}

if err := errors.Join(
validate.NotNil(params.Store, "store required"),
validate.NotNil(params.EventStore, "event store required"),
); err != nil {
return nil, fmt.Errorf("node manager invalid params: %w", err)
}

return &nodesManager{
store: params.Store,
eventstore: params.EventStore,
clock: params.Clock,
liveState: &sync.Map{},
defaultApprovalState: defaultApprovalState,
Expand Down Expand Up @@ -425,16 +442,29 @@ func (n *nodesManager) Handshake(
Info: request.NodeInfo,
Membership: n.defaultApprovalState,
ConnectionState: models.ConnectionState{
Status: models.NodeStates.CONNECTED,
ConnectedSince: n.clock.Now(),
LastHeartbeat: n.clock.Now(),
LastOrchestratorSeqNum: request.LastOrchestratorSeqNum,
Status: models.NodeStates.CONNECTED,
ConnectedSince: n.clock.Now(),
LastHeartbeat: n.clock.Now(),
},
}

// If a node is reconnecting, we trust and preserve the sequence numbers from its previous state,
// rather than using the sequence numbers from the handshake request. For new nodes,
// we assign them the latest event sequence number from the event store.
// This prevents several edge cases:
// - Compute node losing its state. The handshake will ask to start from 0.
// - Orchestrator losing their state and compute nodes asking to start from a later seqNum that no longer exist.
// - New compute nodes joining. The handshake will also ask to start from 0.
if isReconnect {
state.Membership = existing.Membership
state.ConnectionState.LastComputeSeqNum = existing.ConnectionState.LastComputeSeqNum
state.ConnectionState.LastOrchestratorSeqNum = existing.ConnectionState.LastOrchestratorSeqNum
} else {
// Assign the latest sequence number from the event store
state.ConnectionState.LastOrchestratorSeqNum, err = n.eventstore.GetLatestEventNum(ctx)
if err != nil {
return messages.HandshakeResponse{}, fmt.Errorf("failed to initialize node with latest event number: %w", err)
}
}

if err = n.store.Put(ctx, state); err != nil {
Expand Down
146 changes: 146 additions & 0 deletions pkg/orchestrator/nodes/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,20 @@ import (
"github.com/stretchr/testify/suite"

"github.com/bacalhau-project/bacalhau/pkg/bacerrors"
"github.com/bacalhau-project/bacalhau/pkg/lib/watcher"
"github.com/bacalhau-project/bacalhau/pkg/models"
"github.com/bacalhau-project/bacalhau/pkg/models/messages"
"github.com/bacalhau-project/bacalhau/pkg/orchestrator/nodes"
"github.com/bacalhau-project/bacalhau/pkg/orchestrator/nodes/inmemory"
testutils "github.com/bacalhau-project/bacalhau/pkg/test/utils"
)

type NodeManagerTestSuite struct {
suite.Suite
ctx context.Context
clock *clock.Mock
store nodes.Store
eventStore watcher.EventStore
manager nodes.Manager
disconnected time.Duration
}
Expand All @@ -44,8 +47,11 @@ func (s *NodeManagerTestSuite) SetupTest() {
TTL: time.Hour,
})

s.eventStore, _ = testutils.CreateStringEventStore(s.T())

manager, err := nodes.NewManager(nodes.ManagerParams{
Store: s.store,
EventStore: s.eventStore,
Clock: s.clock,
NodeDisconnectedAfter: s.disconnected,
HealthCheckFrequency: 1 * time.Second,
Expand All @@ -62,6 +68,10 @@ func (s *NodeManagerTestSuite) SetupTest() {
func (s *NodeManagerTestSuite) TearDownTest() {
err := s.manager.Stop(s.ctx)
s.Require().NoError(err)

// Cleanup event store
err = s.eventStore.Close(s.ctx)
s.Require().NoError(err)
}

func (s *NodeManagerTestSuite) createNodeInfo(id string) models.NodeInfo {
Expand Down Expand Up @@ -139,6 +149,134 @@ func (s *NodeManagerTestSuite) TestHeartbeatMaintainsConnection() {

// Edge Cases and Error Scenarios

func (s *NodeManagerTestSuite) TestHandshakeSequenceNumberLogic() {
// Test initial handshake with new node
nodeInfo := s.createNodeInfo("new-node")

// First add some events to the event store to have a non-zero latest sequence
ctx := context.Background()
for i := 0; i < 5; i++ {
err := s.eventStore.StoreEvent(ctx, watcher.StoreEventRequest{
Operation: watcher.OperationCreate,
ObjectType: testutils.TypeString,
Object: fmt.Sprintf("test-event-%d", i),
})
s.Require().NoError(err)
}

// Get the latest sequence number for verification
latestSeqNum, err := s.eventStore.GetLatestEventNum(ctx)
s.Require().NoError(err)

// Perform initial handshake
resp1, err := s.manager.Handshake(ctx, messages.HandshakeRequest{
NodeInfo: nodeInfo,
LastOrchestratorSeqNum: 100, // Should be ignored for new nodes
})
s.Require().NoError(err)
s.Require().True(resp1.Accepted)

// Verify the node was assigned the latest sequence number
state, err := s.manager.Get(ctx, nodeInfo.ID())
s.Require().NoError(err)
s.Assert().Equal(latestSeqNum, state.ConnectionState.LastOrchestratorSeqNum,
"New node should be assigned latest sequence number")

// Update sequence numbers via heartbeat
updatedOrchSeqNum := uint64(200)
updatedComputeSeqNum := uint64(150)
_, err = s.manager.Heartbeat(ctx, nodes.ExtendedHeartbeatRequest{
HeartbeatRequest: messages.HeartbeatRequest{
NodeID: nodeInfo.ID(),
LastOrchestratorSeqNum: updatedOrchSeqNum,
},
LastComputeSeqNum: updatedComputeSeqNum,
})
s.Require().NoError(err)

// Simulate disconnect
s.clock.Add(s.disconnected + time.Second)
s.Eventually(func() bool {
state, err := s.manager.Get(ctx, nodeInfo.ID())
s.Require().NoError(err)
return state.ConnectionState.Status == models.NodeStates.DISCONNECTED
}, 500*time.Millisecond, 20*time.Millisecond)

// Reconnect with different sequence number - should keep existing
resp2, err := s.manager.Handshake(ctx, messages.HandshakeRequest{
NodeInfo: nodeInfo,
LastOrchestratorSeqNum: 300, // Should be ignored for reconnecting nodes
})
s.Require().NoError(err)
s.Require().True(resp2.Accepted)
s.Assert().Contains(resp2.Reason, "reconnected")

// Verify sequence numbers were preserved from previous state
state, err = s.manager.Get(ctx, nodeInfo.ID())
s.Require().NoError(err)
s.Assert().Equal(updatedOrchSeqNum, state.ConnectionState.LastOrchestratorSeqNum,
"Reconnected node should preserve previous orchestrator sequence number")
s.Assert().Equal(updatedComputeSeqNum, state.ConnectionState.LastComputeSeqNum,
"Reconnected node should preserve previous compute sequence number")
}

func (s *NodeManagerTestSuite) TestHandshakeSequenceNumberEdgeCases() {
ctx := context.Background()

// Test zero sequence numbers in event store
nodeInfo1 := s.createNodeInfo("zero-seq-node")
resp1, err := s.manager.Handshake(ctx, messages.HandshakeRequest{
NodeInfo: nodeInfo1,
})
s.Require().NoError(err)
s.Require().True(resp1.Accepted)

state1, err := s.manager.Get(ctx, nodeInfo1.ID())
s.Require().NoError(err)
s.Assert().Equal(uint64(0), state1.ConnectionState.LastOrchestratorSeqNum,
"New node should get zero sequence when event store is empty")

// Test concurrent handshakes with sequence numbers
var wg sync.WaitGroup
const numConcurrent = 10

// Add some events first
for i := 0; i < 5; i++ {
err = s.eventStore.StoreEvent(ctx, watcher.StoreEventRequest{
Operation: watcher.OperationCreate,
ObjectType: testutils.TypeString,
Object: fmt.Sprintf("test-event-%d", i),
})
s.Require().NoError(err)
}

latestSeqNum, err := s.eventStore.GetLatestEventNum(ctx)
s.Require().NoError(err)

for i := 0; i < numConcurrent; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()

nodeInfo := s.createNodeInfo(fmt.Sprintf("concurrent-node-%d", id))
resp, err := s.manager.Handshake(ctx, messages.HandshakeRequest{
NodeInfo: nodeInfo,
LastOrchestratorSeqNum: 999, // Should be ignored
})
s.Require().NoError(err)
s.Require().True(resp.Accepted)

// Verify assigned sequence number
state, err := s.manager.Get(ctx, nodeInfo.ID())
s.Require().NoError(err)
s.Assert().Equal(latestSeqNum, state.ConnectionState.LastOrchestratorSeqNum,
"Concurrent new nodes should all get latest sequence number")
}(i)
}

wg.Wait()
}

func (s *NodeManagerTestSuite) TestHeartbeatWithoutHandshake() {
_, err := s.manager.Heartbeat(s.ctx, nodes.ExtendedHeartbeatRequest{
HeartbeatRequest: messages.HeartbeatRequest{
Expand Down Expand Up @@ -394,6 +532,7 @@ func (s *NodeManagerTestSuite) TestConcurrentOperations() {

manager, err := nodes.NewManager(nodes.ManagerParams{
Store: s.store,
EventStore: s.eventStore,
Clock: clock.New(), // Use real clock for this test
NodeDisconnectedAfter: s.disconnected,
HealthCheckFrequency: 1 * time.Second,
Expand Down Expand Up @@ -552,6 +691,7 @@ func (s *NodeManagerTestSuite) TestStartStop() {
// Create a new manager without starting it
manager, err := nodes.NewManager(nodes.ManagerParams{
Store: s.store,
EventStore: s.eventStore,
Clock: s.clock,
NodeDisconnectedAfter: s.disconnected,
HealthCheckFrequency: 1 * time.Second,
Expand All @@ -578,6 +718,7 @@ func (s *NodeManagerTestSuite) TestStartAlreadyStarted() {
// Create and start a manager
manager, err := nodes.NewManager(nodes.ManagerParams{
Store: s.store,
EventStore: s.eventStore,
Clock: s.clock,
NodeDisconnectedAfter: s.disconnected,
})
Expand All @@ -603,6 +744,7 @@ func (s *NodeManagerTestSuite) TestStartAlreadyStarted() {
func (s *NodeManagerTestSuite) TestStartContextCancellation() {
manager, err := nodes.NewManager(nodes.ManagerParams{
Store: s.store,
EventStore: s.eventStore,
Clock: s.clock,
NodeDisconnectedAfter: s.disconnected,
HealthCheckFrequency: 1 * time.Second,
Expand Down Expand Up @@ -630,6 +772,7 @@ func (s *NodeManagerTestSuite) TestStopAlreadyStopped() {
// Create and start a manager
manager, err := nodes.NewManager(nodes.ManagerParams{
Store: s.store,
EventStore: s.eventStore,
Clock: s.clock,
NodeDisconnectedAfter: s.disconnected,
})
Expand Down Expand Up @@ -657,6 +800,7 @@ func (s *NodeManagerTestSuite) TestPeriodicStatePersistence() {
persistInterval := 100 * time.Millisecond
manager, err := nodes.NewManager(nodes.ManagerParams{
Store: s.store,
EventStore: s.eventStore,
Clock: s.clock,
NodeDisconnectedAfter: s.disconnected,
PersistInterval: persistInterval,
Expand Down Expand Up @@ -720,6 +864,7 @@ func (s *NodeManagerTestSuite) TestStatePersistenceOnStop() {
// Create manager
manager, err := nodes.NewManager(nodes.ManagerParams{
Store: s.store,
EventStore: s.eventStore,
Clock: s.clock,
NodeDisconnectedAfter: s.disconnected,
PersistInterval: time.Hour, // Long interval to ensure persistence happens on stop
Expand Down Expand Up @@ -764,6 +909,7 @@ func (s *NodeManagerTestSuite) TestPersistenceWithContextCancellation() {
// Create manager with short persist interval
manager, err := nodes.NewManager(nodes.ManagerParams{
Store: s.store,
EventStore: s.eventStore,
Clock: s.clock,
NodeDisconnectedAfter: s.disconnected,
PersistInterval: 100 * time.Millisecond,
Expand Down
Loading

0 comments on commit 559e78a

Please sign in to comment.