From c51be777674e3e311d5790eea8428b3b6d1ca8aa Mon Sep 17 00:00:00 2001 From: Alano Terblanche <18033717+Benehiko@users.noreply.github.com> Date: Thu, 5 Dec 2024 13:11:17 +0100 Subject: [PATCH] cmd/docker: add cause to user-terminated `context.Context` This patch adds a "cause" to the `context.Context` error when the user terminates the process through SIGINT/SIGTERM. This allows us to distinguish the cause of the `context.Context` cancellation. In future we would also be able to improve the UX of printed errors based on the underlying cause. Signed-off-by: Alano Terblanche <18033717+Benehiko@users.noreply.github.com> cmd/docker: fix possible race between ctx channel and signal channel Signed-off-by: Alano Terblanche <18033717+Benehiko@users.noreply.github.com> test: notifyContext Signed-off-by: Alano Terblanche <18033717+Benehiko@users.noreply.github.com> cmd/docker: print status on SIGTERM and not SIGINT Signed-off-by: Alano Terblanche <18033717+Benehiko@users.noreply.github.com> --- cmd/docker/docker.go | 55 +++++++++++++++++++++++++++++++++++++-- cmd/docker/docker_test.go | 33 +++++++++++++++++++++++ 2 files changed, 86 insertions(+), 2 deletions(-) diff --git a/cmd/docker/docker.go b/cmd/docker/docker.go index 29a3ed5957..a757fdb693 100644 --- a/cmd/docker/docker.go +++ b/cmd/docker/docker.go @@ -28,16 +28,57 @@ import ( "go.opentelemetry.io/otel" ) +type errCtxSignalTerminated struct { + signal os.Signal +} + +func (e errCtxSignalTerminated) Error() string { + return "" +} + func main() { - err := dockerMain(context.Background()) + ctx := context.Background() + err := dockerMain(ctx) + + if errors.As(err, &errCtxSignalTerminated{}) { + os.Exit(getExitCode(err)) + return + } + if err != nil && !errdefs.IsCancelled(err) { _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(getExitCode(err)) } } +func notifyContext(ctx context.Context, signals ...os.Signal) (context.Context, context.CancelFunc) { + ch := make(chan os.Signal, 1) + signal.Notify(ch, signals...) + + ctxCause, cancel := context.WithCancelCause(ctx) + + go func() { + select { + case <-ctx.Done(): + signal.Stop(ch) + return + case sig := <-ch: + cancel(errCtxSignalTerminated{ + signal: sig, + }) + signal.Stop(ch) + return + } + }() + + return ctxCause, func() { + signal.Stop(ch) + cancel(nil) + } +} + func dockerMain(ctx context.Context) error { - ctx, cancelNotify := signal.NotifyContext(ctx, platformsignals.TerminationSignals...) + ctx, cancelNotify := notifyContext(ctx, platformsignals.TerminationSignals...) defer cancelNotify() dockerCli, err := command.NewDockerCli(command.WithBaseContext(ctx)) @@ -57,6 +98,16 @@ func getExitCode(err error) int { if err == nil { return 0 } + + var userTerminatedErr errCtxSignalTerminated + if errors.As(err, &userTerminatedErr) { + s, ok := userTerminatedErr.signal.(syscall.Signal) + if !ok { + return 1 + } + return 128 + int(s) + } + var stErr cli.StatusError if errors.As(err, &stErr) && stErr.StatusCode != 0 { // FIXME(thaJeztah): StatusCode should never be used with a zero status-code. Check if we do this anywhere. return stErr.StatusCode diff --git a/cmd/docker/docker_test.go b/cmd/docker/docker_test.go index 1998b77c99..cd3625b670 100644 --- a/cmd/docker/docker_test.go +++ b/cmd/docker/docker_test.go @@ -5,10 +5,14 @@ import ( "context" "io" "os" + "syscall" "testing" + "time" "github.com/docker/cli/cli/command" "github.com/docker/cli/cli/debug" + platformsignals "github.com/docker/cli/cmd/docker/internal/signals" + "github.com/pkg/errors" "github.com/sirupsen/logrus" "gotest.tools/v3/assert" is "gotest.tools/v3/assert/cmp" @@ -75,3 +79,32 @@ func TestVersion(t *testing.T) { assert.NilError(t, err) assert.Check(t, is.Contains(b.String(), "Docker version")) } + +func TestUserTerminatedError(t *testing.T) { + ctx, cancel := context.WithTimeoutCause(context.Background(), time.Second*1, errors.New("test timeout")) + t.Cleanup(cancel) + + notifyCtx, cancelNotify := notifyContext(ctx, platformsignals.TerminationSignals...) + t.Cleanup(cancelNotify) + + syscall.Kill(syscall.Getpid(), syscall.SIGINT) + + <-notifyCtx.Done() + assert.ErrorIs(t, context.Cause(notifyCtx), errCtxSignalTerminated{ + signal: syscall.SIGINT, + }) + + assert.Equal(t, getExitCode(context.Cause(notifyCtx)), 130) + + notifyCtx, cancelNotify = notifyContext(ctx, platformsignals.TerminationSignals...) + t.Cleanup(cancelNotify) + + syscall.Kill(syscall.Getpid(), syscall.SIGTERM) + + <-notifyCtx.Done() + assert.ErrorIs(t, context.Cause(notifyCtx), errCtxSignalTerminated{ + signal: syscall.SIGTERM, + }) + + assert.Equal(t, getExitCode(context.Cause(notifyCtx)), 143) +}