From 76836dd95d0a317c881427c44e6355f8b76336c9 Mon Sep 17 00:00:00 2001 From: Walid Baruni Date: Wed, 18 Dec 2024 12:07:56 +0200 Subject: [PATCH] Replace NodeID field access with ID() (#4784) ## Summary by CodeRabbit - **New Features** - Enhanced encapsulation by updating node ID retrieval methods across various components. - **Bug Fixes** - Improved robustness in heartbeat handling and node information updates in tests. - **Tests** - Updated test cases to reflect changes in node ID handling, ensuring accurate validation of control plane behavior. --- pkg/models/node_info.go | 1 - pkg/nats/proxy/management_proxy.go | 2 +- pkg/orchestrator/selection/ranking/available_capacity.go | 8 ++++---- pkg/transport/bprotocol/orchestrator/heartbeat_test.go | 2 +- pkg/transport/nclprotocol/compute/controlplane.go | 2 +- pkg/transport/nclprotocol/compute/controlplane_test.go | 2 +- pkg/transport/nclprotocol/compute/manager_test.go | 4 ++-- 7 files changed, 10 insertions(+), 11 deletions(-) diff --git a/pkg/models/node_info.go b/pkg/models/node_info.go index c889618216..8fd73417e4 100644 --- a/pkg/models/node_info.go +++ b/pkg/models/node_info.go @@ -90,7 +90,6 @@ func (n NoopNodeInfoDecorator) DecorateNodeInfo(ctx context.Context, nodeInfo No // to further its view of the networks conditions. ComputeNodeInfo is non-nil iff the NodeType is NodeTypeCompute. // TODO(walid): add Validate() method to NodeInfo and make sure it is called in all the places where it is initialized type NodeInfo struct { - // TODO replace all access on this field with the `ID()` method NodeID string `json:"NodeID"` NodeType NodeType `json:"NodeType"` Labels map[string]string `json:"Labels"` diff --git a/pkg/nats/proxy/management_proxy.go b/pkg/nats/proxy/management_proxy.go index 92178e1bce..1087afc10f 100644 --- a/pkg/nats/proxy/management_proxy.go +++ b/pkg/nats/proxy/management_proxy.go @@ -52,7 +52,7 @@ func (p *ManagementProxy) Register(ctx context.Context, var asyncRes *concurrency.AsyncResult[legacy.RegisterResponse] asyncRes, err = send[legacy.RegisterRequest, legacy.RegisterResponse]( - ctx, p.conn, request.Info.NodeID, request, RegisterNode) + ctx, p.conn, request.Info.ID(), request, RegisterNode) if err != nil { return nil, errors.Wrap(err, "failed to send response to registration request") diff --git a/pkg/orchestrator/selection/ranking/available_capacity.go b/pkg/orchestrator/selection/ranking/available_capacity.go index 8c8a8de547..eadc0d80e0 100644 --- a/pkg/orchestrator/selection/ranking/available_capacity.go +++ b/pkg/orchestrator/selection/ranking/available_capacity.go @@ -103,8 +103,8 @@ func (s *AvailableCapacityNodeRanker) calculateWeightedCapacities(nodes []models weightedAvailableCapacity := weightedCapacity(node.ComputeNodeInfo.AvailableCapacity, weights) weightedQueueUsedCapacity := weightedCapacity(node.ComputeNodeInfo.QueueUsedCapacity, weights) - weightedAvailableCapacities[node.NodeID] = weightedAvailableCapacity - weightedQueueCapacities[node.NodeID] = weightedQueueUsedCapacity + weightedAvailableCapacities[node.ID()] = weightedAvailableCapacity + weightedQueueCapacities[node.ID()] = weightedQueueUsedCapacity if weightedAvailableCapacity > maxWeightedAvailableCapacity { maxWeightedAvailableCapacity = weightedAvailableCapacity @@ -124,8 +124,8 @@ func (s *AvailableCapacityNodeRanker) rankNodesBasedOnCapacities(ctx context.Con ranks := make([]orchestrator.NodeRank, len(nodes)) for i, node := range nodes { - weightedAvailableCapacity := wAvailableCapacities[node.NodeID] - weightedQueueUsedCapacity := wQueueCapacities[node.NodeID] + weightedAvailableCapacity := wAvailableCapacities[node.ID()] + weightedQueueUsedCapacity := wQueueCapacities[node.ID()] // Calculate the ratios of available and queue capacities availableRatio := 0.0 diff --git a/pkg/transport/bprotocol/orchestrator/heartbeat_test.go b/pkg/transport/bprotocol/orchestrator/heartbeat_test.go index aac607cb6a..3abd83894e 100644 --- a/pkg/transport/bprotocol/orchestrator/heartbeat_test.go +++ b/pkg/transport/bprotocol/orchestrator/heartbeat_test.go @@ -224,7 +224,7 @@ func (s *HeartbeatTestSuite) TestHeartbeatScenarios() { s.clock.Add(tc.waitUntil) - nodeState, err := s.nodeManager.Get(ctx, nodeInfo.NodeID) + nodeState, err := s.nodeManager.Get(ctx, nodeInfo.ID()) if tc.handshake { s.Require().NoError(err) s.Require().Equal(tc.expectedState, nodeState.ConnectionState.Status, fmt.Sprintf("incorrect state in %s", tc.name)) diff --git a/pkg/transport/nclprotocol/compute/controlplane.go b/pkg/transport/nclprotocol/compute/controlplane.go index f3ed2b969d..e04a5c5fc0 100644 --- a/pkg/transport/nclprotocol/compute/controlplane.go +++ b/pkg/transport/nclprotocol/compute/controlplane.go @@ -139,7 +139,7 @@ func (cp *ControlPlane) heartbeat(ctx context.Context) error { cp.latestNodeInfo = nodeInfo msg := envelope.NewMessage(messages.HeartbeatRequest{ - NodeID: cp.latestNodeInfo.NodeID, + NodeID: cp.latestNodeInfo.ID(), AvailableCapacity: nodeInfo.ComputeNodeInfo.AvailableCapacity, QueueUsedCapacity: nodeInfo.ComputeNodeInfo.QueueUsedCapacity, LastOrchestratorSeqNum: cp.incomingSeqTracker.GetLastSeqNum(), diff --git a/pkg/transport/nclprotocol/compute/controlplane_test.go b/pkg/transport/nclprotocol/compute/controlplane_test.go index bd3cd4cc70..753ac9fb0f 100644 --- a/pkg/transport/nclprotocol/compute/controlplane_test.go +++ b/pkg/transport/nclprotocol/compute/controlplane_test.go @@ -158,7 +158,7 @@ func (s *ControlPlaneTestSuite) TestHeartbeat() { nodeInfo := s.nodeInfoProvider.GetNodeInfo(s.ctx) heartbeatMsg := envelope.NewMessage(messages.HeartbeatRequest{ - NodeID: nodeInfo.NodeID, + NodeID: nodeInfo.ID(), AvailableCapacity: nodeInfo.ComputeNodeInfo.AvailableCapacity, QueueUsedCapacity: nodeInfo.ComputeNodeInfo.QueueUsedCapacity, LastOrchestratorSeqNum: s.seqTracker.GetLastSeqNum(), diff --git a/pkg/transport/nclprotocol/compute/manager_test.go b/pkg/transport/nclprotocol/compute/manager_test.go index 769983969c..e97db3ef3d 100644 --- a/pkg/transport/nclprotocol/compute/manager_test.go +++ b/pkg/transport/nclprotocol/compute/manager_test.go @@ -159,7 +159,7 @@ func (s *ConnectionManagerTestSuite) TestSuccessfulConnection() { heartbeats := s.mockResponder.GetHeartbeats() s.Require().Len(heartbeats, 1) s.Require().Equal(messages.HeartbeatRequest{ - NodeID: nodeInfo.NodeID, + NodeID: nodeInfo.ID(), AvailableCapacity: nodeInfo.ComputeNodeInfo.AvailableCapacity, QueueUsedCapacity: nodeInfo.ComputeNodeInfo.QueueUsedCapacity, LastOrchestratorSeqNum: handshakeSeqNum, // Should use sequence number from handshake response @@ -178,7 +178,7 @@ func (s *ConnectionManagerTestSuite) TestSuccessfulConnection() { s.Require().Eventually(func() bool { lastHeartbeat := s.mockResponder.GetHeartbeats()[len(s.mockResponder.GetHeartbeats())-1] return reflect.DeepEqual(lastHeartbeat, messages.HeartbeatRequest{ - NodeID: nodeInfo.NodeID, + NodeID: nodeInfo.ID(), AvailableCapacity: nodeInfo.ComputeNodeInfo.AvailableCapacity, QueueUsedCapacity: nodeInfo.ComputeNodeInfo.QueueUsedCapacity, LastOrchestratorSeqNum: handshakeSeqNum, // Should continue using sequence number from handshake