diff --git a/auth/login.go b/auth/login.go index 65e156b..eb7ae2d 100644 --- a/auth/login.go +++ b/auth/login.go @@ -2,6 +2,7 @@ package auth import ( "context" + "errors" "fmt" "net/url" "os" @@ -24,6 +25,8 @@ type LoginCmd struct { ClientID string `help:"Client ID is the OIDC client ID of the API." default:"nineapis.ch-f178254"` } +const ErrNonInteractiveEnvironmentEptyToken = "a static API token is required in non-interactive envirtonments" + func (l *LoginCmd) Run(ctx context.Context, command string, tk api.TokenGetter) error { apiURL, err := url.Parse(l.APIURL) if err != nil { @@ -58,6 +61,10 @@ func (l *LoginCmd) Run(ctx context.Context, command string, tk api.TokenGetter) return login(ctx, cfg, loadingRules.GetDefaultFilename(), userInfo.User, "", project(l.Organization)) } + if !format.IsInteractiveEnvironment(ctx, os.Stdout) { + return errors.New(ErrNonInteractiveEnvironmentEptyToken) + } + usePKCE := true token, err := tk.GetTokenString(ctx, l.IssuerURL, l.ClientID, usePKCE) diff --git a/auth/login_test.go b/auth/login_test.go index 273381d..eef8f27 100644 --- a/auth/login_test.go +++ b/auth/login_test.go @@ -8,7 +8,10 @@ import ( "path" "testing" + "github.com/ninech/nctl/api" + "github.com/ninech/nctl/internal/format" "github.com/ninech/nctl/internal/test" + "github.com/stretchr/testify/require" "k8s.io/client-go/tools/clientcmd" clientcmdapi "k8s.io/client-go/tools/clientcmd/api" ) @@ -19,7 +22,19 @@ func (f *fakeTokenGetter) GetTokenString(ctx context.Context, issuerURL, clientI return test.FakeJWTToken, nil } +func checkErrorRequire(t *testing.T, err error, expectError bool, expectedErrMsg string) { + t.Helper() + if expectError { + require.Error(t, err, "expected an error but got none") + require.EqualError(t, err, expectedErrMsg, "unexpected error message") + } else { + require.NoError(t, err, "expected no error but got one") + } +} + func TestLoginCmd(t *testing.T) { + ctx := format.WithForceInteractiveEnvironment(context.Background(), true) + // write our "existing" kubeconfig to a temp kubeconfig kubeconfig, err := os.CreateTemp("", "*-kubeconfig.yaml") if err != nil { @@ -39,7 +54,7 @@ func TestLoginCmd(t *testing.T) { APIURL: "https://" + apiHost, IssuerURL: "https://auth.example.org", } - if err := cmd.Run(context.Background(), "", tk); err != nil { + if err := cmd.Run(ctx, "", tk); err != nil { t.Fatal(err) } @@ -58,46 +73,83 @@ func TestLoginCmd(t *testing.T) { } func TestLoginStaticToken(t *testing.T) { - kubeconfig, err := os.CreateTemp("", "*-kubeconfig.yaml") - if err != nil { - log.Fatal(err) - } - defer os.Remove(kubeconfig.Name()) - - os.Setenv(clientcmd.RecommendedConfigPathEnvVar, kubeconfig.Name()) - apiHost := "api.example.org" - token := test.FakeJWTToken - cmd := &LoginCmd{APIURL: "https://" + apiHost, APIToken: token, Organization: "test"} - tk := &fakeTokenGetter{} - if err := cmd.Run(context.Background(), "", tk); err != nil { - t.Fatal(err) - } - - // read out the kubeconfig again to test the contents - b, err := io.ReadAll(kubeconfig) - if err != nil { - t.Fatal(err) - } - - kc, err := clientcmd.Load(b) - if err != nil { - t.Fatal(err) - } - - checkConfig(t, kc, 1, "") - - if token != kc.AuthInfos[apiHost].Token { - t.Fatalf("expected token to be %s, got %s", token, kc.AuthInfos[apiHost].Token) - } - - if kc.AuthInfos[apiHost].Exec != nil { - t.Fatalf("expected execConfig to be empty, got %v", kc.AuthInfos[apiHost].Exec) + tests := []struct { + name string + ctx context.Context + cmd *LoginCmd + tk api.TokenGetter + wantToken string + wantErr bool + wantErrMessage string + }{ + { + name: "interactive environment with token", + ctx: format.WithForceInteractiveEnvironment(context.Background(), true), + cmd: &LoginCmd{APIURL: "https://" + apiHost, APIToken: test.FakeJWTToken, Organization: "test"}, + tk: &fakeTokenGetter{}, + wantToken: test.FakeJWTToken, + }, + { + name: "non-interactive environment with token", + ctx: context.Background(), + cmd: &LoginCmd{APIURL: "https://" + apiHost, APIToken: test.FakeJWTToken, Organization: "test"}, + tk: &fakeTokenGetter{}, + wantToken: test.FakeJWTToken, + }, + { + name: "non-interactive environment with empty token", + ctx: context.Background(), + cmd: &LoginCmd{APIURL: "https://" + apiHost, APIToken: "", Organization: "test"}, + wantErr: true, + wantErrMessage: ErrNonInteractiveEnvironmentEptyToken, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + kubeconfig, err := os.CreateTemp("", "*-kubeconfig.yaml") + if err != nil { + log.Fatal(err) + } + defer os.Remove(kubeconfig.Name()) + os.Setenv(clientcmd.RecommendedConfigPathEnvVar, kubeconfig.Name()) + + err = tt.cmd.Run(tt.ctx, "", tt.tk) + checkErrorRequire(t, err, tt.wantErr, tt.wantErrMessage) + + if tt.wantErr { + return + } + + // read out the kubeconfig again to test the contents + b, err := io.ReadAll(kubeconfig) + if err != nil { + t.Fatal(err) + } + + kc, err := clientcmd.Load(b) + if err != nil { + t.Fatal(err) + } + + checkConfig(t, kc, 1, "") + + if tt.wantToken != kc.AuthInfos[apiHost].Token { + t.Fatalf("expected token to be %s, got %s", tt.wantToken, kc.AuthInfos[apiHost].Token) + } + + if kc.AuthInfos[apiHost].Exec != nil { + t.Fatalf("expected execConfig to be empty, got %v", kc.AuthInfos[apiHost].Exec) + } + }) } } func TestLoginCmdWithoutExistingKubeconfig(t *testing.T) { + ctx := format.WithForceInteractiveEnvironment(context.Background(), true) + dir, err := os.MkdirTemp("", "nctl-test-*") if err != nil { t.Fatal(err) @@ -113,7 +165,7 @@ func TestLoginCmdWithoutExistingKubeconfig(t *testing.T) { IssuerURL: "https://auth.example.org", } tk := &fakeTokenGetter{} - if err := cmd.Run(context.Background(), "", tk); err != nil { + if err := cmd.Run(ctx, "", tk); err != nil { t.Fatal(err) } diff --git a/get/all.go b/get/all.go index 4a686e8..5085a08 100644 --- a/get/all.go +++ b/get/all.go @@ -58,7 +58,7 @@ func (cmd *allCmd) Run(ctx context.Context, client *api.Client, get *Cmd) error case noHeader: return printItems(items, *get, defaultOut(cmd.out), false) case yamlOut: - return format.PrettyPrintObjects(items, format.PrintOpts{Out: cmd.out}) + return format.PrettyPrintObjects(ctx, items, format.PrintOpts{Out: cmd.out}) } return nil diff --git a/get/apiserviceaccount.go b/get/apiserviceaccount.go index 80b5866..b930811 100644 --- a/get/apiserviceaccount.go +++ b/get/apiserviceaccount.go @@ -54,7 +54,7 @@ func (asa *apiServiceAccountsCmd) Run(ctx context.Context, client *api.Client, g case noHeader: return asa.print(asaList.Items, get, false) case yamlOut: - return format.PrettyPrintObjects(asaList.GetItems(), format.PrintOpts{}) + return format.PrettyPrintObjects(ctx, asaList.GetItems(), format.PrintOpts{}) } return nil diff --git a/get/application.go b/get/application.go index 4607f8b..b76357f 100644 --- a/get/application.go +++ b/get/application.go @@ -43,14 +43,14 @@ func (cmd *applicationsCmd) Run(ctx context.Context, client *api.Client, get *Cm fmt.Fprintf(defaultOut(cmd.out), "no application with basic auth enabled found\n") return err } - if printErr := printCredentials(creds, get, defaultOut(cmd.out)); printErr != nil { + if printErr := printCredentials(ctx, creds, get, defaultOut(cmd.out)); printErr != nil { err = multierror.Append(err, printErr) } return err } if cmd.DNS { - return printDNSDetails(util.GatherDNSDetails(appList.Items), get, defaultOut(cmd.out)) + return printDNSDetails(ctx, util.GatherDNSDetails(appList.Items), get, defaultOut(cmd.out)) } switch get.Output { @@ -59,7 +59,7 @@ func (cmd *applicationsCmd) Run(ctx context.Context, client *api.Client, get *Cm case noHeader: return printApplication(appList.Items, get, defaultOut(cmd.out), false) case yamlOut: - return format.PrettyPrintObjects(appList.GetItems(), format.PrintOpts{Out: defaultOut(cmd.out)}) + return format.PrettyPrintObjects(ctx, appList.GetItems(), format.PrintOpts{Out: defaultOut(cmd.out)}) case stats: return cmd.printStats(ctx, client, appList.Items, get, defaultOut(cmd.out)) } @@ -101,9 +101,9 @@ func printApplication(apps []apps.Application, get *Cmd, out io.Writer, header b return w.Flush() } -func printCredentials(creds []appCredentials, get *Cmd, out io.Writer) error { +func printCredentials(ctx context.Context, creds []appCredentials, get *Cmd, out io.Writer) error { if get.Output == yamlOut { - return format.PrettyPrintObjects(creds, format.PrintOpts{Out: out}) + return format.PrettyPrintObjects(ctx, creds, format.PrintOpts{Out: out}) } return printCredentialsTabRow(creds, get, out) } @@ -162,9 +162,9 @@ func join(list []string) string { return strings.Join(list, ",") } -func printDNSDetails(items []util.DNSDetail, get *Cmd, out io.Writer) error { +func printDNSDetails(ctx context.Context, items []util.DNSDetail, get *Cmd, out io.Writer) error { if get.Output == yamlOut { - return format.PrettyPrintObjects(items, format.PrintOpts{Out: out}) + return format.PrettyPrintObjects(ctx, items, format.PrintOpts{Out: out}) } return printDNSDetailsTabRow(items, get, out) } diff --git a/get/build.go b/get/build.go index 989a139..c4435b3 100644 --- a/get/build.go +++ b/get/build.go @@ -64,7 +64,7 @@ func (cmd *buildCmd) Run(ctx context.Context, client *api.Client, get *Cmd) erro case noHeader: return printBuild(buildList.Items, get, defaultOut(cmd.out), false) case yamlOut: - return format.PrettyPrintObjects(buildList.GetItems(), format.PrintOpts{Out: defaultOut(cmd.out)}) + return format.PrettyPrintObjects(ctx, buildList.GetItems(), format.PrintOpts{Out: defaultOut(cmd.out)}) } return nil diff --git a/get/cloudvm.go b/get/cloudvm.go index 218b9fe..5cb868a 100644 --- a/get/cloudvm.go +++ b/get/cloudvm.go @@ -35,7 +35,7 @@ func (cmd *cloudVMCmd) Run(ctx context.Context, client *api.Client, get *Cmd) er case noHeader: return cmd.printCloudVirtualMachineInstances(cloudVMList.Items, get, false) case yamlOut: - return format.PrettyPrintObjects(cloudVMList.GetItems(), format.PrintOpts{}) + return format.PrettyPrintObjects(ctx, cloudVMList.GetItems(), format.PrintOpts{}) } return nil diff --git a/get/clusters.go b/get/clusters.go index cd85f7d..54afd8d 100644 --- a/get/clusters.go +++ b/get/clusters.go @@ -35,7 +35,7 @@ func (l *clustersCmd) Run(ctx context.Context, client *api.Client, get *Cmd) err case noHeader: return printClusters(clusterList.Items, get, false) case yamlOut: - return format.PrettyPrintObjects(clusterList.GetItems(), format.PrintOpts{}) + return format.PrettyPrintObjects(ctx, clusterList.GetItems(), format.PrintOpts{}) case contexts: for _, cluster := range clusterList.Items { fmt.Printf("%s\n", config.ContextName(&cluster)) diff --git a/get/keyvaluestore.go b/get/keyvaluestore.go index 4d356ac..d84f516 100644 --- a/get/keyvaluestore.go +++ b/get/keyvaluestore.go @@ -42,7 +42,7 @@ func (cmd *keyValueStoreCmd) Run(ctx context.Context, client *api.Client, get *C case noHeader: return cmd.printKeyValueStoreInstances(keyValueStoreList.Items, get, false) case yamlOut: - return format.PrettyPrintObjects(keyValueStoreList.GetItems(), format.PrintOpts{}) + return format.PrettyPrintObjects(ctx, keyValueStoreList.GetItems(), format.PrintOpts{}) } return nil diff --git a/get/mysql.go b/get/mysql.go index 63dccc0..17dd310 100644 --- a/get/mysql.go +++ b/get/mysql.go @@ -51,7 +51,7 @@ func (cmd *mySQLCmd) Run(ctx context.Context, client *api.Client, get *Cmd) erro case noHeader: return cmd.printMySQLInstances(mysqlList.Items, get, false) case yamlOut: - return format.PrettyPrintObjects(mysqlList.GetItems(), format.PrintOpts{}) + return format.PrettyPrintObjects(ctx, mysqlList.GetItems(), format.PrintOpts{}) } return nil diff --git a/get/postgres.go b/get/postgres.go index bf67829..ab48e55 100644 --- a/get/postgres.go +++ b/get/postgres.go @@ -51,7 +51,7 @@ func (cmd *postgresCmd) Run(ctx context.Context, client *api.Client, get *Cmd) e case noHeader: return cmd.printPostgresInstances(postgresList.Items, get, false) case yamlOut: - return format.PrettyPrintObjects(postgresList.GetItems(), format.PrintOpts{}) + return format.PrettyPrintObjects(ctx, postgresList.GetItems(), format.PrintOpts{}) } return nil diff --git a/get/project.go b/get/project.go index ec113f8..f6f3720 100644 --- a/get/project.go +++ b/get/project.go @@ -43,6 +43,7 @@ func (proj *projectCmd) Run(ctx context.Context, client *api.Client, get *Cmd) e return printProject(projectList, *get, defaultOut(proj.out), false) case yamlOut: return format.PrettyPrintObjects( + ctx, (&management.ProjectList{Items: projectList}).GetItems(), format.PrintOpts{ Out: proj.out, diff --git a/get/project_config.go b/get/project_config.go index d85184e..7e7a16a 100644 --- a/get/project_config.go +++ b/get/project_config.go @@ -35,7 +35,7 @@ func (cmd *configsCmd) Run(ctx context.Context, client *api.Client, get *Cmd) er case noHeader: return printProjectConfigs(projectConfigList.Items, get, defaultOut(cmd.out), false) case yamlOut: - return format.PrettyPrintObjects(projectConfigList.GetItems(), format.PrintOpts{Out: defaultOut(cmd.out)}) + return format.PrettyPrintObjects(ctx, projectConfigList.GetItems(), format.PrintOpts{Out: defaultOut(cmd.out)}) } return nil diff --git a/get/releases.go b/get/releases.go index 5a819fb..56cfa84 100644 --- a/get/releases.go +++ b/get/releases.go @@ -46,7 +46,7 @@ func (cmd *releasesCmd) Run(ctx context.Context, client *api.Client, get *Cmd) e case noHeader: return cmd.printReleases(releaseList.Items, get, false) case yamlOut: - return format.PrettyPrintObjects(releaseList.GetItems(), format.PrintOpts{Out: defaultOut(cmd.out)}) + return format.PrettyPrintObjects(ctx, releaseList.GetItems(), format.PrintOpts{Out: defaultOut(cmd.out)}) } return nil diff --git a/internal/format/print.go b/internal/format/print.go index 93b8721..438bbb4 100644 --- a/internal/format/print.go +++ b/internal/format/print.go @@ -1,6 +1,7 @@ package format import ( + "context" "fmt" "io" "os" @@ -124,9 +125,9 @@ func (p PrintOpts) defaultOut() io.Writer { // PrettyPrintObjects prints the supplied objects in "pretty" colored yaml // with some metadata, status and other default fields stripped out. If // multiple objects are supplied, they will be divided with a yaml divider. -func PrettyPrintObjects[T any](objs []T, opts PrintOpts) error { +func PrettyPrintObjects[T any](ctx context.Context, objs []T, opts PrintOpts) error { for i, obj := range objs { - if err := PrettyPrintObject(obj, opts); err != nil { + if err := PrettyPrintObject(ctx, obj, opts); err != nil { return err } // if there's another object we print a yaml divider @@ -140,13 +141,13 @@ func PrettyPrintObjects[T any](objs []T, opts PrintOpts) error { // PrettyPrintObject prints the supplied object in "pretty" colored yaml // with some metadata, status and other default fields stripped out. -func PrettyPrintObject(obj any, opts PrintOpts) error { +func PrettyPrintObject(ctx context.Context, obj any, opts PrintOpts) error { // we check if we can make a copy of the object as we might alter it. // If we can't make a copy of the object we print it directly as // altering it would change the source object runtimeObject, is := obj.(runtime.Object) if !is { - return printResource(obj, opts) + return printResource(ctx, obj, opts) } objCopy := runtimeObject.DeepCopyObject() @@ -163,12 +164,12 @@ func PrettyPrintObject(obj any, opts PrintOpts) error { if u, is := toPrint.(*unstructured.Unstructured); is { toPrint = u.Object } - return printResource(toPrint, opts) + return printResource(ctx, toPrint, opts) } // printResource prints the resource similar to how // https://github.com/goccy/go-yaml#ycat does it. -func printResource(obj any, opts PrintOpts) error { +func printResource(ctx context.Context, obj any, opts PrintOpts) error { b, err := yaml.Marshal(obj) if err != nil { return err @@ -177,7 +178,7 @@ func printResource(obj any, opts PrintOpts) error { if opts.Out == nil { opts.Out = os.Stdout } - p, err := getPrinter(opts.Out) + p, err := getPrinter(ctx, opts.Out) if err != nil { return err } @@ -189,15 +190,11 @@ func printResource(obj any, opts PrintOpts) error { // getPrinter returns a printer for printing tokens. It will have color output // if the given io.Writer is a terminal. -func getPrinter(out io.Writer) (printer.Printer, error) { +func getPrinter(ctx context.Context, out io.Writer) (printer.Printer, error) { p := printer.Printer{ LineNumber: false, } - f, isFile := out.(*os.File) - if !isFile { - return p, nil - } - if isatty.IsTerminal(f.Fd()) || isatty.IsCygwinTerminal(f.Fd()) { + if IsInteractiveEnvironment(ctx, out) { p.Bool = printerProperty(&printer.Property{ Prefix: format(color.FgHiMagenta), Suffix: format(color.Reset), @@ -218,6 +215,30 @@ func getPrinter(out io.Writer) (printer.Printer, error) { return p, nil } +type key int + +const forceInteractiveEnvironmentKey key = iota + +// WithForceInteractiveEnvironment allows to controll interactivity detection. +// We use context.Context parameter to propagate the forceInteractiveEnvironment +// flag to avoid race conditions. Can be used e.g. in sequential/parallel tests. +func WithForceInteractiveEnvironment(ctx context.Context, value bool) context.Context { + return context.WithValue(ctx, forceInteractiveEnvironmentKey, value) +} + +func IsInteractiveEnvironment(ctx context.Context, out io.Writer) bool { + if ctx != nil { + if v, ok := ctx.Value(forceInteractiveEnvironmentKey).(bool); ok && v { + return true + } + } + f, isFile := out.(*os.File) + if !isFile { + return false + } + return isatty.IsTerminal(f.Fd()) || isatty.IsCygwinTerminal(f.Fd()) +} + // stripObj removes some fields which simply add clutter to the yaml output. // The object should still be applyable afterwards as just defaults and // computed fields get removed.