Skip to content

Commit efd12b5

Browse files
committed
handle context in broadcaster
1 parent 8019ab3 commit efd12b5

File tree

7 files changed

+28
-18
lines changed

7 files changed

+28
-18
lines changed

internal/mode/static/handler.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ func (h *eventHandlerImpl) HandleEventBatch(ctx context.Context, logger logr.Log
183183
// and Deployment.
184184
// If fully deleted, then delete the deployment from the Store and close the stopCh.
185185
stopCh := make(chan struct{})
186-
deployment := h.cfg.nginxDeployments.GetOrStore(deploymentName, stopCh)
186+
deployment := h.cfg.nginxDeployments.GetOrStore(ctx, deploymentName, stopCh)
187187
if deployment == nil {
188188
panic("expected deployment, got nil")
189189
}

internal/mode/static/nginx/agent/broadcast/broadcast.go

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package broadcast
22

33
import (
4+
"context"
45
"sync"
56

67
pb "github.com/nginx/agent/v3/api/grpc/mpi/v1"
@@ -48,15 +49,15 @@ type DeploymentBroadcaster struct {
4849
}
4950

5051
// NewDeploymentBroadcaster returns a new instance of a DeploymentBroadcaster.
51-
func NewDeploymentBroadcaster(stopCh chan struct{}) *DeploymentBroadcaster {
52+
func NewDeploymentBroadcaster(ctx context.Context, stopCh chan struct{}) *DeploymentBroadcaster {
5253
broadcaster := &DeploymentBroadcaster{
5354
listeners: make(map[string]storedChannels),
5455
publishCh: make(chan NginxAgentMessage),
5556
subCh: make(chan storedChannels),
5657
unsubCh: make(chan string),
5758
doneCh: make(chan struct{}),
5859
}
59-
go broadcaster.run(stopCh)
60+
go broadcaster.run(ctx, stopCh)
6061

6162
return broadcaster
6263
}
@@ -102,11 +103,13 @@ func (b *DeploymentBroadcaster) CancelSubscription(id string) {
102103
// - if receiving a new subscriber, add it to the subscriber list.
103104
// - if receiving a canceled subscription, remove it from the subscriber list.
104105
// - if receiving a message to publish, send it to all subscribers.
105-
func (b *DeploymentBroadcaster) run(stopCh chan struct{}) {
106+
func (b *DeploymentBroadcaster) run(ctx context.Context, stopCh chan struct{}) {
106107
for {
107108
select {
108109
case <-stopCh:
109110
return
111+
case <-ctx.Done():
112+
return
110113
case channels := <-b.subCh:
111114
b.listeners[channels.id] = channels
112115
case id := <-b.unsubCh:

internal/mode/static/nginx/agent/broadcast/broadcast_test.go

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package broadcast_test
22

33
import (
4+
"context"
45
"testing"
56

67
. "github.com/onsi/gomega"
@@ -15,7 +16,7 @@ func TestSubscribe(t *testing.T) {
1516
stopCh := make(chan struct{})
1617
defer close(stopCh)
1718

18-
broadcaster := broadcast.NewDeploymentBroadcaster(stopCh)
19+
broadcaster := broadcast.NewDeploymentBroadcaster(context.Background(), stopCh)
1920

2021
subscriber := broadcaster.Subscribe()
2122
g.Expect(subscriber.ID).NotTo(BeEmpty())
@@ -40,7 +41,7 @@ func TestSubscribe_MultipleListeners(t *testing.T) {
4041
stopCh := make(chan struct{})
4142
defer close(stopCh)
4243

43-
broadcaster := broadcast.NewDeploymentBroadcaster(stopCh)
44+
broadcaster := broadcast.NewDeploymentBroadcaster(context.Background(), stopCh)
4445

4546
subscriber1 := broadcaster.Subscribe()
4647
subscriber2 := broadcaster.Subscribe()
@@ -69,7 +70,7 @@ func TestSubscribe_NoListeners(t *testing.T) {
6970
stopCh := make(chan struct{})
7071
defer close(stopCh)
7172

72-
broadcaster := broadcast.NewDeploymentBroadcaster(stopCh)
73+
broadcaster := broadcast.NewDeploymentBroadcaster(context.Background(), stopCh)
7374

7475
message := broadcast.NginxAgentMessage{
7576
ConfigVersion: "v1",
@@ -87,7 +88,7 @@ func TestCancelSubscription(t *testing.T) {
8788
stopCh := make(chan struct{})
8889
defer close(stopCh)
8990

90-
broadcaster := broadcast.NewDeploymentBroadcaster(stopCh)
91+
broadcaster := broadcast.NewDeploymentBroadcaster(context.Background(), stopCh)
9192

9293
subscriber := broadcaster.Subscribe()
9394

internal/mode/static/nginx/agent/deployment.go

+7-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package agent
22

33
import (
4+
"context"
45
"errors"
56
"fmt"
67
"sync"
@@ -227,12 +228,16 @@ func (d *DeploymentStore) Get(nsName types.NamespacedName) *Deployment {
227228

228229
// GetOrStore returns the existing value for the key if present.
229230
// Otherwise, it stores and returns the given value.
230-
func (d *DeploymentStore) GetOrStore(nsName types.NamespacedName, stopCh chan struct{}) *Deployment {
231+
func (d *DeploymentStore) GetOrStore(
232+
ctx context.Context,
233+
nsName types.NamespacedName,
234+
stopCh chan struct{},
235+
) *Deployment {
231236
if deployment := d.Get(nsName); deployment != nil {
232237
return deployment
233238
}
234239

235-
deployment := newDeployment(broadcast.NewDeploymentBroadcaster(stopCh))
240+
deployment := newDeployment(broadcast.NewDeploymentBroadcaster(ctx, stopCh))
236241
d.deployments.Store(nsName, deployment)
237242

238243
return deployment

internal/mode/static/nginx/agent/deployment_test.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package agent
22

33
import (
4+
"context"
45
"errors"
56
"testing"
67

@@ -122,13 +123,13 @@ func TestDeploymentStore(t *testing.T) {
122123

123124
nsName := types.NamespacedName{Namespace: "default", Name: "test-deployment"}
124125

125-
deployment := store.GetOrStore(nsName, nil)
126+
deployment := store.GetOrStore(context.Background(), nsName, nil)
126127
g.Expect(deployment).ToNot(BeNil())
127128

128129
fetchedDeployment := store.Get(nsName)
129130
g.Expect(fetchedDeployment).To(Equal(deployment))
130131

131-
deployment = store.GetOrStore(nsName, nil)
132+
deployment = store.GetOrStore(context.Background(), nsName, nil)
132133
g.Expect(fetchedDeployment).To(Equal(deployment))
133134

134135
store.Remove(nsName)

internal/mode/static/nginx/agent/file_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ func TestGetFile(t *testing.T) {
3131
connTracker.GetConnectionReturns(conn)
3232

3333
depStore := NewDeploymentStore(connTracker)
34-
dep := depStore.GetOrStore(deploymentName, nil)
34+
dep := depStore.GetOrStore(context.Background(), deploymentName, nil)
3535

3636
fileMeta := &pb.FileMeta{
3737
Name: "test.conf",
@@ -154,7 +154,7 @@ func TestGetFile_FileNotFound(t *testing.T) {
154154
connTracker.GetConnectionReturns(conn)
155155

156156
depStore := NewDeploymentStore(connTracker)
157-
depStore.GetOrStore(deploymentName, nil)
157+
depStore.GetOrStore(context.Background(), deploymentName, nil)
158158

159159
fs := newFileService(logr.Discard(), depStore, connTracker)
160160

internal/mode/static/nginx/agent/grpc/connections_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,9 @@ func TestUntrackConnectionsForParent(t *testing.T) {
8181

8282
tracker := agentgrpc.NewConnectionsTracker()
8383

84-
parent := types.NamespacedName{Namespace: "default", Name: "parent1"}
85-
conn1 := agentgrpc.Connection{PodName: "pod1", InstanceID: "instance1", Parent: parent}
86-
conn2 := agentgrpc.Connection{PodName: "pod2", InstanceID: "instance2", Parent: parent}
84+
parent1 := types.NamespacedName{Namespace: "default", Name: "parent1"}
85+
conn1 := agentgrpc.Connection{PodName: "pod1", InstanceID: "instance1", Parent: parent1}
86+
conn2 := agentgrpc.Connection{PodName: "pod2", InstanceID: "instance2", Parent: parent1}
8787

8888
parent2 := types.NamespacedName{Namespace: "default", Name: "parent2"}
8989
conn3 := agentgrpc.Connection{PodName: "pod3", InstanceID: "instance3", Parent: parent2}
@@ -92,7 +92,7 @@ func TestUntrackConnectionsForParent(t *testing.T) {
9292
tracker.Track("key2", conn2)
9393
tracker.Track("key3", conn3)
9494

95-
tracker.UntrackConnectionsForParent(parent)
95+
tracker.UntrackConnectionsForParent(parent1)
9696
g.Expect(tracker.GetConnection("key1")).To(Equal(agentgrpc.Connection{}))
9797
g.Expect(tracker.GetConnection("key2")).To(Equal(agentgrpc.Connection{}))
9898
g.Expect(tracker.GetConnection("key3")).To(Equal(conn3))

0 commit comments

Comments
 (0)