diff --git a/cmd/oras/internal/errors/errors.go b/cmd/oras/internal/errors/errors.go index addd92117..bb7784660 100644 --- a/cmd/oras/internal/errors/errors.go +++ b/cmd/oras/internal/errors/errors.go @@ -161,15 +161,34 @@ func NewErrEmptyTagOrDigest(ref string, cmd *cobra.Command, needsTag bool) error // CheckMutuallyExclusiveFlags checks if any mutually exclusive flags are used // at the same time, returns an error when detecting used exclusive flags. func CheckMutuallyExclusiveFlags(fs *pflag.FlagSet, exclusiveFlagSet ...string) error { - var changedFlags []string - for _, flagName := range exclusiveFlagSet { - if fs.Changed(flagName) { - changedFlags = append(changedFlags, fmt.Sprintf("--%s", flagName)) - } - } + changedFlags, _ := checkChangedFlags(fs, exclusiveFlagSet...) if len(changedFlags) >= 2 { flags := strings.Join(changedFlags, ", ") return fmt.Errorf("%s cannot be used at the same time", flags) } return nil } + +// CheckRequiredTogetherFlags checks if any flags required together are all used, +// returns an error when detecting any flags not used while other flags have been used. +func CheckRequiredTogetherFlags(fs *pflag.FlagSet, requiredTogetherFlags ...string) error { + changed, unchanged := checkChangedFlags(fs, requiredTogetherFlags...) + unchangedCount := len(unchanged) + if unchangedCount != 0 && unchangedCount != len(requiredTogetherFlags) { + changed := strings.Join(changed, ", ") + unchanged := strings.Join(unchanged, ", ") + return fmt.Errorf("%s must be used in conjunction with %s", changed, unchanged) + } + return nil +} + +func checkChangedFlags(fs *pflag.FlagSet, flagSet ...string) (changedFlags []string, unchangedFlags []string) { + for _, flagName := range flagSet { + if fs.Changed(flagName) { + changedFlags = append(changedFlags, fmt.Sprintf("--%s", flagName)) + } else { + unchangedFlags = append(unchangedFlags, fmt.Sprintf("--%s", flagName)) + } + } + return +} diff --git a/cmd/oras/internal/errors/errors_test.go b/cmd/oras/internal/errors/errors_test.go index bf3c2314a..d5b3898dc 100644 --- a/cmd/oras/internal/errors/errors_test.go +++ b/cmd/oras/internal/errors/errors_test.go @@ -53,3 +53,42 @@ func TestCheckMutuallyExclusiveFlags(t *testing.T) { }) } } + +func TestCheckRequiredTogetherFlags(t *testing.T) { + fs := &pflag.FlagSet{} + var foo, bar, hello, world bool + fs.BoolVar(&foo, "foo", false, "foo test") + fs.BoolVar(&bar, "bar", false, "bar test") + fs.BoolVar(&hello, "hello", false, "hello test") + fs.BoolVar(&world, "world", false, "world test") + fs.Lookup("foo").Changed = true + fs.Lookup("bar").Changed = true + tests := []struct { + name string + requiredTogetherFlags []string + wantErr bool + }{ + { + "--foo and --bar are both used, no error is returned", + []string{"foo", "bar"}, + false, + }, + { + "--foo and --hello are not both used, an error is returned", + []string{"foo", "hello"}, + true, + }, + { + "none of --hello and --world is used, no error is returned", + []string{"hello", "world"}, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := CheckRequiredTogetherFlags(fs, tt.requiredTogetherFlags...); (err != nil) != tt.wantErr { + t.Errorf("CheckRequiredTogetherFlags() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/cmd/oras/internal/option/remote.go b/cmd/oras/internal/option/remote.go index 9b47d04bc..b9f761d8d 100644 --- a/cmd/oras/internal/option/remote.go +++ b/cmd/oras/internal/option/remote.go @@ -166,9 +166,9 @@ func (opts *Remote) Parse(cmd *cobra.Command) error { if err := opts.parseCustomHeaders(); err != nil { return err } - - cmd.MarkFlagsRequiredTogether(certFileAndKeyFileFlags...) - + if err := oerrors.CheckRequiredTogetherFlags(cmd.Flags(), certFileAndKeyFileFlags...); err != nil { + return err + } return opts.readSecret(cmd) } diff --git a/test/e2e/internal/utils/exec.go b/test/e2e/internal/utils/exec.go index 8e26d75dc..255306f7d 100644 --- a/test/e2e/internal/utils/exec.go +++ b/test/e2e/internal/utils/exec.go @@ -34,9 +34,9 @@ const ( orasBinary = "oras" // customize your own basic auth file via `htpasswd -cBb ` - Username = "hello" - Password = "oras-test" - DefaultTimeout = 10 * time.Second + Username = "hello" + Password = "oras-test" + DefaultTimeout = 10 * time.Second // If the command hasn't exited yet, ginkgo session ExitCode is -1 notResponding = -1 ) diff --git a/test/e2e/suite/auth/auth.go b/test/e2e/suite/auth/auth.go index 38c5df2b5..8e82b7937 100644 --- a/test/e2e/suite/auth/auth.go +++ b/test/e2e/suite/auth/auth.go @@ -166,6 +166,11 @@ var _ = Describe("Common registry user", func() { ORAS("login", ZOTHost, "--identity-token", Password). MatchErrKeyWords("WARNING", "Using --identity-token via the CLI is insecure", "Use --identity-token-stdin").ExpectFailure().Exec() }) + + It("should fail if --cert-file is not used with --key-file with correct error message", func() { + ORAS("login", ZOTHost, "--cert-file", "test"). + MatchErrKeyWords("--cert-file", "in conjunction with", "--key-file").ExpectFailure().Exec() + }) }) When("using legacy config", func() {