Skip to content

Commit 3f0d90a

Browse files
committed
feat: global signal handling with context cancellation
Signed-off-by: Alano Terblanche <[email protected]>
1 parent 8b924a5 commit 3f0d90a

File tree

9 files changed

+160
-49
lines changed

9 files changed

+160
-49
lines changed

cli/command/container/attach.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,12 @@ func RunAttach(ctx context.Context, dockerCLI command.Cli, containerID string, o
105105

106106
if opts.Proxy && !c.Config.Tty {
107107
sigc := notifyAllSignals()
108-
go ForwardAllSignals(ctx, apiClient, containerID, sigc)
108+
// since we're explicitly setting up signal handling here, and the daemon will
109+
// get notified independently of the clients ctx cancellation, we use this context
110+
// but without cancellation to avoid ForwardAllSignals from returning
111+
// before all signals are forwarded.
112+
bgCtx := context.WithoutCancel(ctx)
113+
go ForwardAllSignals(bgCtx, apiClient, containerID, sigc)
109114
defer signal.StopCatch(sigc)
110115
}
111116

cli/command/container/client_test.go

+8
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ type fakeClient struct {
3737
containerRemoveFunc func(ctx context.Context, containerID string, options container.RemoveOptions) error
3838
containerKillFunc func(ctx context.Context, containerID, signal string) error
3939
containerPruneFunc func(ctx context.Context, pruneFilters filters.Args) (types.ContainersPruneReport, error)
40+
containerAttachFunc func(ctx context.Context, containerID string, options container.AttachOptions) (types.HijackedResponse, error)
4041
Version string
4142
}
4243

@@ -173,3 +174,10 @@ func (f *fakeClient) ContainersPrune(ctx context.Context, pruneFilters filters.A
173174
}
174175
return types.ContainersPruneReport{}, nil
175176
}
177+
178+
func (f *fakeClient) ContainerAttach(ctx context.Context, containerID string, options container.AttachOptions) (types.HijackedResponse, error) {
179+
if f.containerAttachFunc != nil {
180+
return f.containerAttachFunc(ctx, containerID, options)
181+
}
182+
return types.HijackedResponse{}, nil
183+
}

cli/command/container/run.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,12 @@ func runContainer(ctx context.Context, dockerCli command.Cli, runOpts *runOption
150150
}
151151
if runOpts.sigProxy {
152152
sigc := notifyAllSignals()
153-
go ForwardAllSignals(ctx, apiClient, containerID, sigc)
153+
// since we're explicitly setting up signal handling here, and the daemon will
154+
// get notified independently of the clients ctx cancellation, we use this context
155+
// but without cancellation to avoid ForwardAllSignals from returning
156+
// before all signals are forwarded.
157+
bgCtx := context.WithoutCancel(ctx)
158+
go ForwardAllSignals(bgCtx, apiClient, containerID, sigc)
154159
defer signal.StopCatch(sigc)
155160
}
156161

cli/command/container/run_test.go

+69
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,18 @@ import (
55
"errors"
66
"fmt"
77
"io"
8+
"net"
9+
"os/signal"
10+
"syscall"
811
"testing"
12+
"time"
913

14+
"github.com/creack/pty"
1015
"github.com/docker/cli/cli"
16+
"github.com/docker/cli/cli/streams"
1117
"github.com/docker/cli/internal/test"
1218
"github.com/docker/cli/internal/test/notary"
19+
"github.com/docker/docker/api/types"
1320
"github.com/docker/docker/api/types/container"
1421
"github.com/docker/docker/api/types/network"
1522
specs "github.com/opencontainers/image-spec/specs-go/v1"
@@ -32,6 +39,68 @@ func TestRunLabel(t *testing.T) {
3239
assert.NilError(t, cmd.Execute())
3340
}
3441

42+
func TestRunAttachTermination(t *testing.T) {
43+
p, tty, err := pty.Open()
44+
assert.NilError(t, err)
45+
46+
defer func() {
47+
_ = tty.Close()
48+
_ = p.Close()
49+
}()
50+
51+
killCh := make(chan struct{})
52+
attachCh := make(chan struct{})
53+
fakeCLI := test.NewFakeCli(&fakeClient{
54+
createContainerFunc: func(_ *container.Config, _ *container.HostConfig, _ *network.NetworkingConfig, _ *specs.Platform, _ string) (container.CreateResponse, error) {
55+
return container.CreateResponse{
56+
ID: "id",
57+
}, nil
58+
},
59+
containerKillFunc: func(ctx context.Context, containerID, signal string) error {
60+
killCh <- struct{}{}
61+
return nil
62+
},
63+
containerAttachFunc: func(ctx context.Context, containerID string, options container.AttachOptions) (types.HijackedResponse, error) {
64+
server, client := net.Pipe()
65+
t.Cleanup(func() {
66+
_ = server.Close()
67+
})
68+
attachCh <- struct{}{}
69+
return types.NewHijackedResponse(client, types.MediaTypeRawStream), nil
70+
},
71+
Version: "1.36",
72+
}, func(fc *test.FakeCli) {
73+
fc.SetOut(streams.NewOut(tty))
74+
fc.SetIn(streams.NewIn(tty))
75+
})
76+
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM)
77+
defer cancel()
78+
79+
assert.Equal(t, fakeCLI.In().IsTerminal(), true)
80+
assert.Equal(t, fakeCLI.Out().IsTerminal(), true)
81+
82+
cmd := NewRunCommand(fakeCLI)
83+
cmd.SetArgs([]string{"-it", "busybox"})
84+
cmd.SilenceUsage = true
85+
go func() {
86+
assert.ErrorIs(t, cmd.ExecuteContext(ctx), context.Canceled)
87+
}()
88+
89+
select {
90+
case <-time.After(5 * time.Second):
91+
t.Fatal("containerAttachFunc was not called before the 5 second timeout")
92+
case <-attachCh:
93+
}
94+
95+
assert.NilError(t, syscall.Kill(syscall.Getpid(), syscall.SIGTERM))
96+
select {
97+
case <-time.After(5 * time.Second):
98+
cancel()
99+
t.Fatal("containerKillFunc was not called before the 5 second timeout")
100+
case <-killCh:
101+
}
102+
}
103+
35104
func TestRunCommandWithContentTrustErrors(t *testing.T) {
36105
testCases := []struct {
37106
name string

cli/command/container/start.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ func RunStart(ctx context.Context, dockerCli command.Cli, opts *StartOptions) er
8787
// We always use c.ID instead of container to maintain consistency during `docker start`
8888
if !c.Config.Tty {
8989
sigc := notifyAllSignals()
90-
go ForwardAllSignals(ctx, dockerCli.Client(), c.ID, sigc)
90+
bgCtx := context.WithoutCancel(ctx)
91+
go ForwardAllSignals(bgCtx, dockerCli.Client(), c.ID, sigc)
9192
defer signal.StopCatch(sigc)
9293
}
9394

cli/command/utils.go

+1-9
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,9 @@ import (
99
"fmt"
1010
"io"
1111
"os"
12-
"os/signal"
1312
"path/filepath"
1413
"runtime"
1514
"strings"
16-
"syscall"
1715

1816
"github.com/docker/cli/cli/streams"
1917
"github.com/docker/docker/api/types/filters"
@@ -103,11 +101,6 @@ func PromptForConfirmation(ctx context.Context, ins io.Reader, outs io.Writer, m
103101

104102
result := make(chan bool)
105103

106-
// Catch the termination signal and exit the prompt gracefully.
107-
// The caller is responsible for properly handling the termination.
108-
notifyCtx, notifyCancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
109-
defer notifyCancel()
110-
111104
go func() {
112105
var res bool
113106
scanner := bufio.NewScanner(ins)
@@ -121,8 +114,7 @@ func PromptForConfirmation(ctx context.Context, ins io.Reader, outs io.Writer, m
121114
}()
122115

123116
select {
124-
case <-notifyCtx.Done():
125-
// print a newline on termination
117+
case <-ctx.Done():
126118
_, _ = fmt.Fprintln(outs, "")
127119
return false, ErrPromptTerminated
128120
case r := <-result:

cli/command/utils_test.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
"io"
99
"os"
10+
"os/signal"
1011
"path/filepath"
1112
"strings"
1213
"syscall"
@@ -135,6 +136,9 @@ func TestPromptForConfirmation(t *testing.T) {
135136
}, promptResult{false, nil}},
136137
} {
137138
t.Run("case="+tc.desc, func(t *testing.T) {
139+
notifyCtx, notifyCancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
140+
t.Cleanup(notifyCancel)
141+
138142
buf.Reset()
139143
promptReader, promptWriter = io.Pipe()
140144

@@ -145,7 +149,7 @@ func TestPromptForConfirmation(t *testing.T) {
145149

146150
result := make(chan promptResult, 1)
147151
go func() {
148-
r, err := command.PromptForConfirmation(ctx, promptReader, promptOut, "")
152+
r, err := command.PromptForConfirmation(notifyCtx, promptReader, promptOut, "")
149153
result <- promptResult{r, err}
150154
}()
151155

cmd/docker/docker.go

+58-33
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,20 @@ import (
2828
)
2929

3030
func main() {
31-
ctx := context.Background()
31+
statusCode := dockerMain()
32+
if statusCode != 0 {
33+
os.Exit(statusCode)
34+
}
35+
}
36+
37+
func dockerMain() int {
38+
ctx, cancelNotify := signal.NotifyContext(context.Background(), platformsignals.TerminationSignals...)
39+
defer cancelNotify()
3240

3341
dockerCli, err := command.NewDockerCli(command.WithBaseContext(ctx))
3442
if err != nil {
3543
fmt.Fprintln(os.Stderr, err)
36-
os.Exit(1)
44+
return 1
3745
}
3846
logrus.SetOutput(dockerCli.Err())
3947
otel.SetErrorHandler(debug.OTELErrorHandler)
@@ -46,16 +54,17 @@ func main() {
4654
// StatusError should only be used for errors, and all errors should
4755
// have a non-zero exit status, so never exit with 0
4856
if sterr.StatusCode == 0 {
49-
os.Exit(1)
57+
return 1
5058
}
51-
os.Exit(sterr.StatusCode)
59+
return sterr.StatusCode
5260
}
5361
if errdefs.IsCancelled(err) {
54-
os.Exit(0)
62+
return 0
5563
}
5664
fmt.Fprintln(dockerCli.Err(), err)
57-
os.Exit(1)
65+
return 1
5866
}
67+
return 0
5968
}
6069

6170
func newDockerCommand(dockerCli *command.DockerCli) *cli.TopLevelCommand {
@@ -224,7 +233,7 @@ func setValidateArgs(dockerCli command.Cli, cmd *cobra.Command) {
224233
})
225234
}
226235

227-
func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string, envs []string) error {
236+
func tryPluginRun(ctx context.Context, dockerCli command.Cli, cmd *cobra.Command, subcommand string, envs []string) error {
228237
plugincmd, err := pluginmanager.PluginRunCommand(dockerCli, subcommand, cmd)
229238
if err != nil {
230239
return err
@@ -242,40 +251,56 @@ func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string,
242251

243252
// Background signal handling logic: block on the signals channel, and
244253
// notify the plugin via the PluginServer (or signal) as appropriate.
245-
const exitLimit = 3
246-
signals := make(chan os.Signal, exitLimit)
247-
signal.Notify(signals, platformsignals.TerminationSignals...)
254+
const exitLimit = 2
255+
256+
tryTerminatePlugin := func(force bool) {
257+
// If stdin is a TTY, the kernel will forward
258+
// signals to the subprocess because the shared
259+
// pgid makes the TTY a controlling terminal.
260+
//
261+
// The plugin should have it's own copy of this
262+
// termination logic, and exit after 3 retries
263+
// on it's own.
264+
if dockerCli.Out().IsTerminal() {
265+
return
266+
}
267+
268+
// Terminate the plugin server, which will
269+
// close all connections with plugin
270+
// subprocesses, and signal them to exit.
271+
//
272+
// Repeated invocations will result in EINVAL,
273+
// or EBADF; but that is fine for our purposes.
274+
_ = srv.Close()
275+
276+
// force the process to terminate if it hasn't already
277+
if force {
278+
_ = plugincmd.Process.Kill()
279+
_, _ = fmt.Fprint(dockerCli.Err(), "got 3 SIGTERM/SIGINTs, forcefully exiting\n")
280+
os.Exit(1)
281+
}
282+
}
283+
248284
go func() {
249285
retries := 0
250-
for range signals {
251-
// If stdin is a TTY, the kernel will forward
252-
// signals to the subprocess because the shared
253-
// pgid makes the TTY a controlling terminal.
254-
//
255-
// The plugin should have it's own copy of this
256-
// termination logic, and exit after 3 retries
257-
// on it's own.
258-
if dockerCli.Out().IsTerminal() {
259-
continue
260-
}
286+
force := false
287+
// catch the first signal through context cancellation
288+
<-ctx.Done()
289+
tryTerminatePlugin(force)
261290

262-
// Terminate the plugin server, which will
263-
// close all connections with plugin
264-
// subprocesses, and signal them to exit.
265-
//
266-
// Repeated invocations will result in EINVAL,
267-
// or EBADF; but that is fine for our purposes.
268-
_ = srv.Close()
291+
// register subsequent signals
292+
signals := make(chan os.Signal, exitLimit)
293+
signal.Notify(signals, platformsignals.TerminationSignals...)
269294

295+
for range signals {
296+
retries++
270297
// If we're still running after 3 interruptions
271298
// (SIGINT/SIGTERM), send a SIGKILL to the plugin as a
272299
// final attempt to terminate, and exit.
273-
retries++
274300
if retries >= exitLimit {
275-
_, _ = fmt.Fprintf(dockerCli.Err(), "got %d SIGTERM/SIGINTs, forcefully exiting\n", retries)
276-
_ = plugincmd.Process.Kill()
277-
os.Exit(1)
301+
force = true
278302
}
303+
tryTerminatePlugin(force)
279304
}
280305
}()
281306

@@ -338,7 +363,7 @@ func runDocker(ctx context.Context, dockerCli *command.DockerCli) error {
338363
ccmd, _, err := cmd.Find(args)
339364
subCommand = ccmd
340365
if err != nil || pluginmanager.IsPluginCommand(ccmd) {
341-
err := tryPluginRun(dockerCli, cmd, args[0], envs)
366+
err := tryPluginRun(ctx, dockerCli, cmd, args[0], envs)
342367
if err == nil {
343368
if dockerCli.HooksEnabled() && dockerCli.Out().IsTerminal() && ccmd != nil {
344369
pluginmanager.RunPluginHooks(ctx, dockerCli, cmd, ccmd, args)

internal/test/cmd.go

+5-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package test
33
import (
44
"context"
55
"os"
6-
"syscall"
76
"testing"
87
"time"
98

@@ -32,8 +31,11 @@ func TerminatePrompt(ctx context.Context, t *testing.T, cmd *cobra.Command, cli
3231
assert.NilError(t, err)
3332
cli.SetIn(streams.NewIn(r))
3433

34+
notifyCtx, notifyCancel := context.WithCancel(ctx)
35+
t.Cleanup(notifyCancel)
36+
3537
go func() {
36-
errChan <- cmd.ExecuteContext(ctx)
38+
errChan <- cmd.ExecuteContext(notifyCtx)
3739
}()
3840

3941
writeCtx, writeCancel := context.WithTimeout(ctx, 100*time.Millisecond)
@@ -66,7 +68,7 @@ func TerminatePrompt(ctx context.Context, t *testing.T, cmd *cobra.Command, cli
6668

6769
// sigint and sigterm are caught by the prompt
6870
// this allows us to gracefully exit the prompt with a 0 exit code
69-
syscall.Kill(syscall.Getpid(), syscall.SIGINT)
71+
notifyCancel()
7072

7173
select {
7274
case <-errCtx.Done():

0 commit comments

Comments
 (0)