Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: ctx cancellation on login prompt #5168

Merged
merged 1 commit into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 14 additions & 33 deletions cli/command/registry.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package command

import (
"bufio"
"context"
"fmt"
"io"
"os"
"runtime"
"strings"
Expand All @@ -18,7 +16,6 @@ import (
"github.com/docker/docker/api/types"
registrytypes "github.com/docker/docker/api/types/registry"
"github.com/docker/docker/registry"
"github.com/moby/term"
"github.com/pkg/errors"
)

Expand All @@ -44,7 +41,7 @@ func RegistryAuthenticationPrivilegedFunc(cli Cli, index *registrytypes.IndexInf
default:
}

err = ConfigureAuth(cli, "", "", &authConfig, isDefaultRegistry)
err = ConfigureAuth(ctx, cli, "", "", &authConfig, isDefaultRegistry)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -90,7 +87,7 @@ func GetDefaultAuthConfig(cfg *configfile.ConfigFile, checkCredStore bool, serve
}

// ConfigureAuth handles prompting of user's username and password if needed
func ConfigureAuth(cli Cli, flUser, flPassword string, authconfig *registrytypes.AuthConfig, isDefaultRegistry bool) error {
func ConfigureAuth(ctx context.Context, cli Cli, flUser, flPassword string, authconfig *registrytypes.AuthConfig, isDefaultRegistry bool) error {
// On Windows, force the use of the regular OS stdin stream.
//
// See:
Expand Down Expand Up @@ -125,9 +122,15 @@ func ConfigureAuth(cli Cli, flUser, flPassword string, authconfig *registrytypes
fmt.Fprintln(cli.Out())
}
}
promptWithDefault(cli.Out(), "Username", authconfig.Username)

var prompt string
if authconfig.Username == "" {
prompt = "Username: "
} else {
prompt = fmt.Sprintf("Username (%s): ", authconfig.Username)
}
var err error
flUser, err = readInput(cli.In())
flUser, err = PromptForInput(ctx, cli.In(), cli.Out(), prompt)
if err != nil {
return err
}
Expand All @@ -139,16 +142,13 @@ func ConfigureAuth(cli Cli, flUser, flPassword string, authconfig *registrytypes
return errors.Errorf("Error: Non-null Username Required")
}
if flPassword == "" {
oldState, err := term.SaveState(cli.In().FD())
restoreInput, err := DisableInputEcho(cli.In())
if err != nil {
return err
}
fmt.Fprintf(cli.Out(), "Password: ")
_ = term.DisableEcho(cli.In().FD(), oldState)
defer func() {
_ = term.RestoreTerminal(cli.In().FD(), oldState)
}()
flPassword, err = readInput(cli.In())
defer restoreInput()

flPassword, err = PromptForInput(ctx, cli.In(), cli.Out(), "Password: ")
if err != nil {
return err
}
Expand All @@ -164,25 +164,6 @@ func ConfigureAuth(cli Cli, flUser, flPassword string, authconfig *registrytypes
return nil
}

// readInput reads, and returns user input from in. It tries to return a
// single line, not including the end-of-line bytes, and trims leading
// and trailing whitespace.
func readInput(in io.Reader) (string, error) {
line, _, err := bufio.NewReader(in).ReadLine()
if err != nil {
return "", errors.Wrap(err, "error while reading input")
}
return strings.TrimSpace(string(line)), nil
}

func promptWithDefault(out io.Writer, prompt string, configDefault string) {
if configDefault == "" {
fmt.Fprintf(out, "%s: ", prompt)
} else {
fmt.Fprintf(out, "%s (%s): ", prompt, configDefault)
}
}

// RetrieveAuthTokenFromImage retrieves an encoded auth token given a complete
// image. The auth configuration is serialized as a base64url encoded RFC4648,
// section 5) JSON string for sending through the X-Registry-Auth header.
Expand Down
2 changes: 1 addition & 1 deletion cli/command/registry/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func runLogin(ctx context.Context, dockerCli command.Cli, opts loginOptions) err
response, err = loginWithCredStoreCreds(ctx, dockerCli, &authConfig)
}
if err != nil || authConfig.Username == "" || authConfig.Password == "" {
err = command.ConfigureAuth(dockerCli, opts.user, opts.password, &authConfig, isDefaultRegistry)
err = command.ConfigureAuth(ctx, dockerCli, opts.user, opts.password, &authConfig, isDefaultRegistry)
if err != nil {
return err
}
Expand Down
41 changes: 41 additions & 0 deletions cli/command/registry/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ import (
"errors"
"fmt"
"testing"
"time"

"github.com/creack/pty"
"github.com/docker/cli/cli/command"
configtypes "github.com/docker/cli/cli/config/types"
"github.com/docker/cli/cli/streams"
"github.com/docker/cli/internal/test"
Expand Down Expand Up @@ -185,3 +188,41 @@ func TestRunLogin(t *testing.T) {
})
}
}

func TestLoginTermination(t *testing.T) {
p, tty, err := pty.Open()
assert.NilError(t, err)

t.Cleanup(func() {
_ = tty.Close()
_ = p.Close()
})

cli := test.NewFakeCli(&fakeClient{}, func(fc *test.FakeCli) {
fc.SetOut(streams.NewOut(tty))
fc.SetIn(streams.NewIn(tty))
})
tmpFile := fs.NewFile(t, "test-login-termination")
defer tmpFile.Remove()

configFile := cli.ConfigFile()
configFile.Filename = tmpFile.Path()

ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

runErr := make(chan error)
go func() {
runErr <- runLogin(ctx, cli, loginOptions{})
}()

// Let the prompt get canceled by the context
cancel()

select {
case <-time.After(1 * time.Second):
t.Fatal("timed out after 1 second. `runLogin` did not return")
case err := <-runErr:
assert.ErrorIs(t, err, command.ErrPromptTerminated)
}
}
43 changes: 43 additions & 0 deletions cli/command/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/docker/docker/api/types/versions"
"github.com/docker/docker/errdefs"
"github.com/moby/sys/sequential"
"github.com/moby/term"
"github.com/pkg/errors"
"github.com/spf13/pflag"
)
Expand Down Expand Up @@ -76,6 +77,48 @@ func PrettyPrint(i any) string {

var ErrPromptTerminated = errdefs.Cancelled(errors.New("prompt terminated"))

// DisableInputEcho disables input echo on the provided streams.In.
// This is useful when the user provides sensitive information like passwords.
// The function returns a restore function that should be called to restore the
// terminal state.
func DisableInputEcho(ins *streams.In) (restore func() error, err error) {
oldState, err := term.SaveState(ins.FD())
if err != nil {
return nil, err
}
restore = func() error {
return term.RestoreTerminal(ins.FD(), oldState)
}
return restore, term.DisableEcho(ins.FD(), oldState)
}

// PromptForInput requests input from the user.
//
// If the user terminates the CLI with SIGINT or SIGTERM while the prompt is
// active, the prompt will return an empty string ("") with an ErrPromptTerminated error.
// When the prompt returns an error, the caller should propagate the error up
// the stack and close the io.Reader used for the prompt which will prevent the
// background goroutine from blocking indefinitely.
func PromptForInput(ctx context.Context, in io.Reader, out io.Writer, message string) (string, error) {
_, _ = fmt.Fprint(out, message)

result := make(chan string)
go func() {
scanner := bufio.NewScanner(in)
if scanner.Scan() {
result <- strings.TrimSpace(scanner.Text())
}
}()

select {
case <-ctx.Done():
_, _ = fmt.Fprintln(out, "")
return "", ErrPromptTerminated
case r := <-result:
return r, nil
}
}

// PromptForConfirmation requests and checks confirmation from the user.
// This will display the provided message followed by ' [y/N] '. If the user
// input 'y' or 'Y' it returns true otherwise false. If no message is provided,
Expand Down
61 changes: 61 additions & 0 deletions cli/command/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"time"

"github.com/docker/cli/cli/command"
"github.com/docker/cli/cli/streams"
"github.com/docker/cli/internal/test"
"github.com/pkg/errors"
"gotest.tools/v3/assert"
Expand Down Expand Up @@ -80,6 +81,66 @@ func TestValidateOutputPath(t *testing.T) {
}
}

func TestPromptForInput(t *testing.T) {
t.Run("case=cancelling the context", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
reader, _ := io.Pipe()

buf := new(bytes.Buffer)
bufioWriter := bufio.NewWriter(buf)

wroteHook := make(chan struct{}, 1)
promptOut := test.NewWriterWithHook(bufioWriter, func(p []byte) {
wroteHook <- struct{}{}
})

promptErr := make(chan error, 1)
go func() {
_, err := command.PromptForInput(ctx, streams.NewIn(reader), streams.NewOut(promptOut), "Enter something")
promptErr <- err
}()

select {
case <-time.After(1 * time.Second):
t.Fatal("timeout waiting for prompt to write to buffer")
case <-wroteHook:
cancel()
}

select {
case <-time.After(1 * time.Second):
t.Fatal("timeout waiting for prompt to be canceled")
case err := <-promptErr:
assert.ErrorIs(t, err, command.ErrPromptTerminated)
}
})

t.Run("case=user input should be properly trimmed", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
t.Cleanup(cancel)

reader, writer := io.Pipe()

buf := new(bytes.Buffer)
bufioWriter := bufio.NewWriter(buf)

wroteHook := make(chan struct{}, 1)
promptOut := test.NewWriterWithHook(bufioWriter, func(p []byte) {
wroteHook <- struct{}{}
})

go func() {
<-wroteHook
writer.Write([]byte(" foo \n"))
}()

answer, err := command.PromptForInput(ctx, streams.NewIn(reader), streams.NewOut(promptOut), "Enter something")
assert.NilError(t, err)
assert.Equal(t, answer, "foo")
})
}

func TestPromptForConfirmation(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
Expand Down
Loading