diff --git a/cli/command/container/client_test.go b/cli/command/container/client_test.go index 621acbcdd097..193d2aeaffe6 100644 --- a/cli/command/container/client_test.go +++ b/cli/command/container/client_test.go @@ -25,7 +25,7 @@ type fakeClient struct { platform *specs.Platform, containerName string) (container.CreateResponse, error) containerStartFunc func(containerID string, options container.StartOptions) error - imageCreateFunc func(parentReference string, options image.CreateOptions) (io.ReadCloser, error) + imageCreateFunc func(ctx context.Context, parentReference string, options image.CreateOptions) (io.ReadCloser, error) infoFunc func() (system.Info, error) containerStatPathFunc func(containerID, path string) (container.PathStat, error) containerCopyFromFunc func(containerID, srcPath string) (io.ReadCloser, container.PathStat, error) @@ -94,9 +94,9 @@ func (f *fakeClient) ContainerRemove(ctx context.Context, containerID string, op return nil } -func (f *fakeClient) ImageCreate(_ context.Context, parentReference string, options image.CreateOptions) (io.ReadCloser, error) { +func (f *fakeClient) ImageCreate(ctx context.Context, parentReference string, options image.CreateOptions) (io.ReadCloser, error) { if f.imageCreateFunc != nil { - return f.imageCreateFunc(parentReference, options) + return f.imageCreateFunc(ctx, parentReference, options) } return nil, nil } diff --git a/cli/command/container/create_test.go b/cli/command/container/create_test.go index 43509a9b007e..f8c5a9bc14b4 100644 --- a/cli/command/container/create_test.go +++ b/cli/command/container/create_test.go @@ -133,7 +133,7 @@ func TestCreateContainerImagePullPolicy(t *testing.T) { return container.CreateResponse{ID: containerID}, nil } }, - imageCreateFunc: func(parentReference string, options image.CreateOptions) (io.ReadCloser, error) { + imageCreateFunc: func(ctx context.Context, parentReference string, options image.CreateOptions) (io.ReadCloser, error) { defer func() { pullCounter++ }() return io.NopCloser(strings.NewReader("")), nil }, diff --git a/cli/command/container/run.go b/cli/command/container/run.go index 9816c787656a..bd7afb9d3a7c 100644 --- a/cli/command/container/run.go +++ b/cli/command/container/run.go @@ -112,8 +112,6 @@ func runRun(ctx context.Context, dockerCli command.Cli, flags *pflag.FlagSet, ro //nolint:gocyclo func runContainer(ctx context.Context, dockerCli command.Cli, runOpts *runOptions, copts *containerOptions, containerCfg *containerConfig) error { - ctx = context.WithoutCancel(ctx) - config := containerCfg.Config stdout, stderr := dockerCli.Out(), dockerCli.Err() apiClient := dockerCli.Client() @@ -135,9 +133,6 @@ func runContainer(ctx context.Context, dockerCli command.Cli, runOpts *runOption config.StdinOnce = false } - ctx, cancelFun := context.WithCancel(ctx) - defer cancelFun() - containerID, err := createContainer(ctx, dockerCli, containerCfg, &runOpts.createOptions) if err != nil { reportError(stderr, "run", err.Error(), true) @@ -154,6 +149,9 @@ func runContainer(ctx context.Context, dockerCli command.Cli, runOpts *runOption defer signal.StopCatch(sigc) } + ctx, cancelFun := context.WithCancel(context.WithoutCancel(ctx)) + defer cancelFun() + var ( waitDisplayID chan struct{} errCh chan error diff --git a/cli/command/container/run_test.go b/cli/command/container/run_test.go index 406816664aa8..9a1569be9210 100644 --- a/cli/command/container/run_test.go +++ b/cli/command/container/run_test.go @@ -2,7 +2,9 @@ package container import ( "context" + "encoding/json" "errors" + "fmt" "io" "net" "syscall" @@ -16,7 +18,9 @@ import ( "github.com/docker/cli/internal/test/notary" "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/image" "github.com/docker/docker/api/types/network" + "github.com/docker/docker/pkg/jsonmessage" specs "github.com/opencontainers/image-spec/specs-go/v1" "github.com/spf13/pflag" "gotest.tools/v3/assert" @@ -189,6 +193,87 @@ func TestRunAttachTermination(t *testing.T) { } } +func TestRunPullTermination(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + attachCh := make(chan struct{}) + fakeCLI := test.NewFakeCli(&fakeClient{ + createContainerFunc: func(config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, + platform *specs.Platform, containerName string, + ) (container.CreateResponse, error) { + select { + case <-ctx.Done(): + return container.CreateResponse{}, ctx.Err() + default: + } + return container.CreateResponse{}, fakeNotFound{} + }, + containerAttachFunc: func(ctx context.Context, containerID string, options container.AttachOptions) (types.HijackedResponse, error) { + return types.HijackedResponse{}, errors.New("shouldn't try to attach to a container") + }, + imageCreateFunc: func(ctx context.Context, parentReference string, options image.CreateOptions) (io.ReadCloser, error) { + server, client := net.Pipe() + t.Cleanup(func() { + _ = server.Close() + }) + go func() { + enc := json.NewEncoder(server) + for i := 0; i < 100; i++ { + select { + case <-ctx.Done(): + assert.NilError(t, server.Close(), "failed to close imageCreateFunc server") + return + default: + } + assert.NilError(t, enc.Encode(jsonmessage.JSONMessage{ + Status: "Downloading", + ID: fmt.Sprintf("id-%d", i), + TimeNano: time.Now().UnixNano(), + Time: time.Now().Unix(), + Progress: &jsonmessage.JSONProgress{ + Current: int64(i), + Total: 100, + Start: 0, + }, + })) + time.Sleep(100 * time.Millisecond) + } + }() + attachCh <- struct{}{} + return client, nil + }, + Version: "1.30", + }) + + cmd := NewRunCommand(fakeCLI) + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + cmd.SetArgs([]string{"foobar:latest"}) + + cmdErrC := make(chan error, 1) + go func() { + cmdErrC <- cmd.ExecuteContext(ctx) + }() + + select { + case <-time.After(5 * time.Second): + t.Fatal("imageCreateFunc was not called before the timeout") + case <-attachCh: + } + + cancel() + + select { + case cmdErr := <-cmdErrC: + assert.Equal(t, cmdErr, cli.StatusError{ + StatusCode: 125, + }) + case <-time.After(10 * time.Second): + t.Fatal("cmd did not return before the timeout") + } +} + func TestRunCommandWithContentTrustErrors(t *testing.T) { testCases := []struct { name string