Skip to content

Commit

Permalink
handle context in broadcaster
Browse files Browse the repository at this point in the history
  • Loading branch information
sjberman committed Jan 28, 2025
1 parent 8019ab3 commit efd12b5
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 18 deletions.
2 changes: 1 addition & 1 deletion internal/mode/static/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func (h *eventHandlerImpl) HandleEventBatch(ctx context.Context, logger logr.Log
// and Deployment.
// If fully deleted, then delete the deployment from the Store and close the stopCh.
stopCh := make(chan struct{})
deployment := h.cfg.nginxDeployments.GetOrStore(deploymentName, stopCh)
deployment := h.cfg.nginxDeployments.GetOrStore(ctx, deploymentName, stopCh)
if deployment == nil {
panic("expected deployment, got nil")
}
Expand Down
9 changes: 6 additions & 3 deletions internal/mode/static/nginx/agent/broadcast/broadcast.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package broadcast

import (
"context"
"sync"

pb "github.com/nginx/agent/v3/api/grpc/mpi/v1"
Expand Down Expand Up @@ -48,15 +49,15 @@ type DeploymentBroadcaster struct {
}

// NewDeploymentBroadcaster returns a new instance of a DeploymentBroadcaster.
func NewDeploymentBroadcaster(stopCh chan struct{}) *DeploymentBroadcaster {
func NewDeploymentBroadcaster(ctx context.Context, stopCh chan struct{}) *DeploymentBroadcaster {
broadcaster := &DeploymentBroadcaster{
listeners: make(map[string]storedChannels),
publishCh: make(chan NginxAgentMessage),
subCh: make(chan storedChannels),
unsubCh: make(chan string),
doneCh: make(chan struct{}),
}
go broadcaster.run(stopCh)
go broadcaster.run(ctx, stopCh)

return broadcaster
}
Expand Down Expand Up @@ -102,11 +103,13 @@ func (b *DeploymentBroadcaster) CancelSubscription(id string) {
// - if receiving a new subscriber, add it to the subscriber list.
// - if receiving a canceled subscription, remove it from the subscriber list.
// - if receiving a message to publish, send it to all subscribers.
func (b *DeploymentBroadcaster) run(stopCh chan struct{}) {
func (b *DeploymentBroadcaster) run(ctx context.Context, stopCh chan struct{}) {
for {
select {
case <-stopCh:
return
case <-ctx.Done():
return
case channels := <-b.subCh:
b.listeners[channels.id] = channels
case id := <-b.unsubCh:
Expand Down
9 changes: 5 additions & 4 deletions internal/mode/static/nginx/agent/broadcast/broadcast_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package broadcast_test

import (
"context"
"testing"

. "github.com/onsi/gomega"
Expand All @@ -15,7 +16,7 @@ func TestSubscribe(t *testing.T) {
stopCh := make(chan struct{})
defer close(stopCh)

broadcaster := broadcast.NewDeploymentBroadcaster(stopCh)
broadcaster := broadcast.NewDeploymentBroadcaster(context.Background(), stopCh)

subscriber := broadcaster.Subscribe()
g.Expect(subscriber.ID).NotTo(BeEmpty())
Expand All @@ -40,7 +41,7 @@ func TestSubscribe_MultipleListeners(t *testing.T) {
stopCh := make(chan struct{})
defer close(stopCh)

broadcaster := broadcast.NewDeploymentBroadcaster(stopCh)
broadcaster := broadcast.NewDeploymentBroadcaster(context.Background(), stopCh)

subscriber1 := broadcaster.Subscribe()
subscriber2 := broadcaster.Subscribe()
Expand Down Expand Up @@ -69,7 +70,7 @@ func TestSubscribe_NoListeners(t *testing.T) {
stopCh := make(chan struct{})
defer close(stopCh)

broadcaster := broadcast.NewDeploymentBroadcaster(stopCh)
broadcaster := broadcast.NewDeploymentBroadcaster(context.Background(), stopCh)

message := broadcast.NginxAgentMessage{
ConfigVersion: "v1",
Expand All @@ -87,7 +88,7 @@ func TestCancelSubscription(t *testing.T) {
stopCh := make(chan struct{})
defer close(stopCh)

broadcaster := broadcast.NewDeploymentBroadcaster(stopCh)
broadcaster := broadcast.NewDeploymentBroadcaster(context.Background(), stopCh)

subscriber := broadcaster.Subscribe()

Expand Down
9 changes: 7 additions & 2 deletions internal/mode/static/nginx/agent/deployment.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package agent

import (
"context"
"errors"
"fmt"
"sync"
Expand Down Expand Up @@ -227,12 +228,16 @@ func (d *DeploymentStore) Get(nsName types.NamespacedName) *Deployment {

// GetOrStore returns the existing value for the key if present.
// Otherwise, it stores and returns the given value.
func (d *DeploymentStore) GetOrStore(nsName types.NamespacedName, stopCh chan struct{}) *Deployment {
func (d *DeploymentStore) GetOrStore(
ctx context.Context,
nsName types.NamespacedName,
stopCh chan struct{},
) *Deployment {
if deployment := d.Get(nsName); deployment != nil {
return deployment
}

deployment := newDeployment(broadcast.NewDeploymentBroadcaster(stopCh))
deployment := newDeployment(broadcast.NewDeploymentBroadcaster(ctx, stopCh))
d.deployments.Store(nsName, deployment)

return deployment
Expand Down
5 changes: 3 additions & 2 deletions internal/mode/static/nginx/agent/deployment_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package agent

import (
"context"
"errors"
"testing"

Expand Down Expand Up @@ -122,13 +123,13 @@ func TestDeploymentStore(t *testing.T) {

nsName := types.NamespacedName{Namespace: "default", Name: "test-deployment"}

deployment := store.GetOrStore(nsName, nil)
deployment := store.GetOrStore(context.Background(), nsName, nil)
g.Expect(deployment).ToNot(BeNil())

fetchedDeployment := store.Get(nsName)
g.Expect(fetchedDeployment).To(Equal(deployment))

deployment = store.GetOrStore(nsName, nil)
deployment = store.GetOrStore(context.Background(), nsName, nil)
g.Expect(fetchedDeployment).To(Equal(deployment))

store.Remove(nsName)
Expand Down
4 changes: 2 additions & 2 deletions internal/mode/static/nginx/agent/file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestGetFile(t *testing.T) {
connTracker.GetConnectionReturns(conn)

depStore := NewDeploymentStore(connTracker)
dep := depStore.GetOrStore(deploymentName, nil)
dep := depStore.GetOrStore(context.Background(), deploymentName, nil)

fileMeta := &pb.FileMeta{
Name: "test.conf",
Expand Down Expand Up @@ -154,7 +154,7 @@ func TestGetFile_FileNotFound(t *testing.T) {
connTracker.GetConnectionReturns(conn)

depStore := NewDeploymentStore(connTracker)
depStore.GetOrStore(deploymentName, nil)
depStore.GetOrStore(context.Background(), deploymentName, nil)

fs := newFileService(logr.Discard(), depStore, connTracker)

Expand Down
8 changes: 4 additions & 4 deletions internal/mode/static/nginx/agent/grpc/connections_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ func TestUntrackConnectionsForParent(t *testing.T) {

tracker := agentgrpc.NewConnectionsTracker()

parent := types.NamespacedName{Namespace: "default", Name: "parent1"}
conn1 := agentgrpc.Connection{PodName: "pod1", InstanceID: "instance1", Parent: parent}
conn2 := agentgrpc.Connection{PodName: "pod2", InstanceID: "instance2", Parent: parent}
parent1 := types.NamespacedName{Namespace: "default", Name: "parent1"}
conn1 := agentgrpc.Connection{PodName: "pod1", InstanceID: "instance1", Parent: parent1}
conn2 := agentgrpc.Connection{PodName: "pod2", InstanceID: "instance2", Parent: parent1}

parent2 := types.NamespacedName{Namespace: "default", Name: "parent2"}
conn3 := agentgrpc.Connection{PodName: "pod3", InstanceID: "instance3", Parent: parent2}
Expand All @@ -92,7 +92,7 @@ func TestUntrackConnectionsForParent(t *testing.T) {
tracker.Track("key2", conn2)
tracker.Track("key3", conn3)

tracker.UntrackConnectionsForParent(parent)
tracker.UntrackConnectionsForParent(parent1)
g.Expect(tracker.GetConnection("key1")).To(Equal(agentgrpc.Connection{}))
g.Expect(tracker.GetConnection("key2")).To(Equal(agentgrpc.Connection{}))
g.Expect(tracker.GetConnection("key3")).To(Equal(conn3))
Expand Down

0 comments on commit efd12b5

Please sign in to comment.