From efd12b546d1c8cfce7a476c97ffc72c8fdcc7fa9 Mon Sep 17 00:00:00 2001 From: Saylor Berman Date: Tue, 28 Jan 2025 11:49:10 -0700 Subject: [PATCH] handle context in broadcaster --- internal/mode/static/handler.go | 2 +- internal/mode/static/nginx/agent/broadcast/broadcast.go | 9 ++++++--- .../mode/static/nginx/agent/broadcast/broadcast_test.go | 9 +++++---- internal/mode/static/nginx/agent/deployment.go | 9 +++++++-- internal/mode/static/nginx/agent/deployment_test.go | 5 +++-- internal/mode/static/nginx/agent/file_test.go | 4 ++-- .../mode/static/nginx/agent/grpc/connections_test.go | 8 ++++---- 7 files changed, 28 insertions(+), 18 deletions(-) diff --git a/internal/mode/static/handler.go b/internal/mode/static/handler.go index 92fd6d62d0..66523b95ff 100644 --- a/internal/mode/static/handler.go +++ b/internal/mode/static/handler.go @@ -183,7 +183,7 @@ func (h *eventHandlerImpl) HandleEventBatch(ctx context.Context, logger logr.Log // and Deployment. // If fully deleted, then delete the deployment from the Store and close the stopCh. stopCh := make(chan struct{}) - deployment := h.cfg.nginxDeployments.GetOrStore(deploymentName, stopCh) + deployment := h.cfg.nginxDeployments.GetOrStore(ctx, deploymentName, stopCh) if deployment == nil { panic("expected deployment, got nil") } diff --git a/internal/mode/static/nginx/agent/broadcast/broadcast.go b/internal/mode/static/nginx/agent/broadcast/broadcast.go index 972e56963f..2b21ae1117 100644 --- a/internal/mode/static/nginx/agent/broadcast/broadcast.go +++ b/internal/mode/static/nginx/agent/broadcast/broadcast.go @@ -1,6 +1,7 @@ package broadcast import ( + "context" "sync" pb "github.com/nginx/agent/v3/api/grpc/mpi/v1" @@ -48,7 +49,7 @@ type DeploymentBroadcaster struct { } // NewDeploymentBroadcaster returns a new instance of a DeploymentBroadcaster. -func NewDeploymentBroadcaster(stopCh chan struct{}) *DeploymentBroadcaster { +func NewDeploymentBroadcaster(ctx context.Context, stopCh chan struct{}) *DeploymentBroadcaster { broadcaster := &DeploymentBroadcaster{ listeners: make(map[string]storedChannels), publishCh: make(chan NginxAgentMessage), @@ -56,7 +57,7 @@ func NewDeploymentBroadcaster(stopCh chan struct{}) *DeploymentBroadcaster { unsubCh: make(chan string), doneCh: make(chan struct{}), } - go broadcaster.run(stopCh) + go broadcaster.run(ctx, stopCh) return broadcaster } @@ -102,11 +103,13 @@ func (b *DeploymentBroadcaster) CancelSubscription(id string) { // - if receiving a new subscriber, add it to the subscriber list. // - if receiving a canceled subscription, remove it from the subscriber list. // - if receiving a message to publish, send it to all subscribers. -func (b *DeploymentBroadcaster) run(stopCh chan struct{}) { +func (b *DeploymentBroadcaster) run(ctx context.Context, stopCh chan struct{}) { for { select { case <-stopCh: return + case <-ctx.Done(): + return case channels := <-b.subCh: b.listeners[channels.id] = channels case id := <-b.unsubCh: diff --git a/internal/mode/static/nginx/agent/broadcast/broadcast_test.go b/internal/mode/static/nginx/agent/broadcast/broadcast_test.go index 5f69360bb4..f42e250a01 100644 --- a/internal/mode/static/nginx/agent/broadcast/broadcast_test.go +++ b/internal/mode/static/nginx/agent/broadcast/broadcast_test.go @@ -1,6 +1,7 @@ package broadcast_test import ( + "context" "testing" . "github.com/onsi/gomega" @@ -15,7 +16,7 @@ func TestSubscribe(t *testing.T) { stopCh := make(chan struct{}) defer close(stopCh) - broadcaster := broadcast.NewDeploymentBroadcaster(stopCh) + broadcaster := broadcast.NewDeploymentBroadcaster(context.Background(), stopCh) subscriber := broadcaster.Subscribe() g.Expect(subscriber.ID).NotTo(BeEmpty()) @@ -40,7 +41,7 @@ func TestSubscribe_MultipleListeners(t *testing.T) { stopCh := make(chan struct{}) defer close(stopCh) - broadcaster := broadcast.NewDeploymentBroadcaster(stopCh) + broadcaster := broadcast.NewDeploymentBroadcaster(context.Background(), stopCh) subscriber1 := broadcaster.Subscribe() subscriber2 := broadcaster.Subscribe() @@ -69,7 +70,7 @@ func TestSubscribe_NoListeners(t *testing.T) { stopCh := make(chan struct{}) defer close(stopCh) - broadcaster := broadcast.NewDeploymentBroadcaster(stopCh) + broadcaster := broadcast.NewDeploymentBroadcaster(context.Background(), stopCh) message := broadcast.NginxAgentMessage{ ConfigVersion: "v1", @@ -87,7 +88,7 @@ func TestCancelSubscription(t *testing.T) { stopCh := make(chan struct{}) defer close(stopCh) - broadcaster := broadcast.NewDeploymentBroadcaster(stopCh) + broadcaster := broadcast.NewDeploymentBroadcaster(context.Background(), stopCh) subscriber := broadcaster.Subscribe() diff --git a/internal/mode/static/nginx/agent/deployment.go b/internal/mode/static/nginx/agent/deployment.go index a205a477e1..6257626ba0 100644 --- a/internal/mode/static/nginx/agent/deployment.go +++ b/internal/mode/static/nginx/agent/deployment.go @@ -1,6 +1,7 @@ package agent import ( + "context" "errors" "fmt" "sync" @@ -227,12 +228,16 @@ func (d *DeploymentStore) Get(nsName types.NamespacedName) *Deployment { // GetOrStore returns the existing value for the key if present. // Otherwise, it stores and returns the given value. -func (d *DeploymentStore) GetOrStore(nsName types.NamespacedName, stopCh chan struct{}) *Deployment { +func (d *DeploymentStore) GetOrStore( + ctx context.Context, + nsName types.NamespacedName, + stopCh chan struct{}, +) *Deployment { if deployment := d.Get(nsName); deployment != nil { return deployment } - deployment := newDeployment(broadcast.NewDeploymentBroadcaster(stopCh)) + deployment := newDeployment(broadcast.NewDeploymentBroadcaster(ctx, stopCh)) d.deployments.Store(nsName, deployment) return deployment diff --git a/internal/mode/static/nginx/agent/deployment_test.go b/internal/mode/static/nginx/agent/deployment_test.go index dd6d9d328f..456316a005 100644 --- a/internal/mode/static/nginx/agent/deployment_test.go +++ b/internal/mode/static/nginx/agent/deployment_test.go @@ -1,6 +1,7 @@ package agent import ( + "context" "errors" "testing" @@ -122,13 +123,13 @@ func TestDeploymentStore(t *testing.T) { nsName := types.NamespacedName{Namespace: "default", Name: "test-deployment"} - deployment := store.GetOrStore(nsName, nil) + deployment := store.GetOrStore(context.Background(), nsName, nil) g.Expect(deployment).ToNot(BeNil()) fetchedDeployment := store.Get(nsName) g.Expect(fetchedDeployment).To(Equal(deployment)) - deployment = store.GetOrStore(nsName, nil) + deployment = store.GetOrStore(context.Background(), nsName, nil) g.Expect(fetchedDeployment).To(Equal(deployment)) store.Remove(nsName) diff --git a/internal/mode/static/nginx/agent/file_test.go b/internal/mode/static/nginx/agent/file_test.go index fc8029cc0b..8ffe80de18 100644 --- a/internal/mode/static/nginx/agent/file_test.go +++ b/internal/mode/static/nginx/agent/file_test.go @@ -31,7 +31,7 @@ func TestGetFile(t *testing.T) { connTracker.GetConnectionReturns(conn) depStore := NewDeploymentStore(connTracker) - dep := depStore.GetOrStore(deploymentName, nil) + dep := depStore.GetOrStore(context.Background(), deploymentName, nil) fileMeta := &pb.FileMeta{ Name: "test.conf", @@ -154,7 +154,7 @@ func TestGetFile_FileNotFound(t *testing.T) { connTracker.GetConnectionReturns(conn) depStore := NewDeploymentStore(connTracker) - depStore.GetOrStore(deploymentName, nil) + depStore.GetOrStore(context.Background(), deploymentName, nil) fs := newFileService(logr.Discard(), depStore, connTracker) diff --git a/internal/mode/static/nginx/agent/grpc/connections_test.go b/internal/mode/static/nginx/agent/grpc/connections_test.go index 5434048442..9d72e2a290 100644 --- a/internal/mode/static/nginx/agent/grpc/connections_test.go +++ b/internal/mode/static/nginx/agent/grpc/connections_test.go @@ -81,9 +81,9 @@ func TestUntrackConnectionsForParent(t *testing.T) { tracker := agentgrpc.NewConnectionsTracker() - parent := types.NamespacedName{Namespace: "default", Name: "parent1"} - conn1 := agentgrpc.Connection{PodName: "pod1", InstanceID: "instance1", Parent: parent} - conn2 := agentgrpc.Connection{PodName: "pod2", InstanceID: "instance2", Parent: parent} + parent1 := types.NamespacedName{Namespace: "default", Name: "parent1"} + conn1 := agentgrpc.Connection{PodName: "pod1", InstanceID: "instance1", Parent: parent1} + conn2 := agentgrpc.Connection{PodName: "pod2", InstanceID: "instance2", Parent: parent1} parent2 := types.NamespacedName{Namespace: "default", Name: "parent2"} conn3 := agentgrpc.Connection{PodName: "pod3", InstanceID: "instance3", Parent: parent2} @@ -92,7 +92,7 @@ func TestUntrackConnectionsForParent(t *testing.T) { tracker.Track("key2", conn2) tracker.Track("key3", conn3) - tracker.UntrackConnectionsForParent(parent) + tracker.UntrackConnectionsForParent(parent1) g.Expect(tracker.GetConnection("key1")).To(Equal(agentgrpc.Connection{})) g.Expect(tracker.GetConnection("key2")).To(Equal(agentgrpc.Connection{})) g.Expect(tracker.GetConnection("key3")).To(Equal(conn3))