Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(cli): reauthenticate user in case of invalid token #3643

Merged
merged 6 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions cli/cmd/configure_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ var configureCmd = &cobra.Command{
Short: "Configure your tracetest CLI",
Long: "Configure your tracetest CLI",
PreRun: setupLogger,
Run: WithResultHandler(WithParamsHandler(configParams)(func(cmd *cobra.Command, _ []string) (string, error) {
ctx := context.Background()
Run: WithResultHandler(WithParamsHandler(configParams)(func(ctx context.Context, cmd *cobra.Command, _ []string) (string, error) {
flags := agentConfig.Flags{
CI: configParams.CI,
}
Expand Down
3 changes: 2 additions & 1 deletion cli/cmd/dashboard_cmd.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cmd

import (
"context"
"fmt"

"github.com/kubeshop/tracetest/cli/ui"
Expand All @@ -13,7 +14,7 @@ var dashboardCmd = &cobra.Command{
Short: "Opens the Tracetest Dashboard URL",
Long: "Opens the Tracetest Dashboard URL",
PreRun: setupCommand(),
Run: WithResultHandler(func(_ *cobra.Command, _ []string) (string, error) {
Run: WithResultHandler(func(_ context.Context, _ *cobra.Command, _ []string) (string, error) {
if cliConfig.IsEmpty() {
return "", fmt.Errorf("missing Tracetest endpoint configuration")
}
Expand Down
62 changes: 57 additions & 5 deletions cli/cmd/middleware.go
Original file line number Diff line number Diff line change
@@ -1,25 +1,55 @@
package cmd

import (
"context"
"errors"
"fmt"
"os"

"github.com/kubeshop/tracetest/cli/config"
"github.com/kubeshop/tracetest/cli/pkg/resourcemanager"
"github.com/kubeshop/tracetest/cli/ui"

"github.com/spf13/cobra"
)

type RunFn func(cmd *cobra.Command, args []string) (string, error)
type RunFn func(ctx context.Context, cmd *cobra.Command, args []string) (string, error)
type CobraRunFn func(cmd *cobra.Command, args []string)
type MiddlewareWrapper func(RunFn) RunFn

func rootCtx(cmd *cobra.Command) context.Context {
// cobra does not correctly progpagate rootcmd context to sub commands,
// so we need to manually traverse the command tree to find the root context
if cmd == nil {
return nil
}

var (
ctx = cmd.Context()
p = cmd.Parent()
)
if cmd.Parent() == nil {
return ctx
}
for {
ctx = p.Context()
p = p.Parent()
if p == nil {
break
}
}
return ctx
}

func WithResultHandler(runFn RunFn) CobraRunFn {
return func(cmd *cobra.Command, args []string) {
res, err := runFn(cmd, args)
// we need the root cmd context in case of an error caused rerun
ctx := rootCtx(cmd)

res, err := runFn(ctx, cmd, args)

if err != nil {
OnError(err)
handleError(ctx, err)
return
}

Expand All @@ -29,6 +59,28 @@ func WithResultHandler(runFn RunFn) CobraRunFn {
}
}

func handleError(ctx context.Context, err error) {
reqErr := resourcemanager.RequestError{}
if errors.As(err, &reqErr) && reqErr.IsAuthError {
handleAuthError(ctx)
} else {
OnError(err)
}
}

func handleAuthError(ctx context.Context) {
ui.DefaultUI.Warning("Your authentication token has expired, please log in again.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this handles environment tokens expiration?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good question. I didn't test that case, but this mechanism is triggered whenever the server returns a 401. So if an expired token throws a 401, then this will be triggered.

configurator.
WithOnFinish(func(ctx context.Context, _ config.Config) {
retryCommand(ctx)
}).
ExecuteUserLogin(ctx, cliConfig)
}

func retryCommand(ctx context.Context) {
handleRootExecErr(rootCmd.ExecuteContext(ctx))
}

type errorMessageRenderer interface {
Render()
}
Expand Down Expand Up @@ -66,7 +118,7 @@ func handleErrorMessage(err error) string {

func WithParamsHandler(validators ...Validator) MiddlewareWrapper {
return func(runFn RunFn) RunFn {
return func(cmd *cobra.Command, args []string) (string, error) {
return func(ctx context.Context, cmd *cobra.Command, args []string) (string, error) {
errors := make([]error, 0)

for _, validator := range validators {
Expand All @@ -82,7 +134,7 @@ func WithParamsHandler(validators ...Validator) MiddlewareWrapper {
return "", fmt.Errorf(errorText)
}

return runFn(cmd, args)
return runFn(ctx, cmd, args)
}
}
}
Expand Down
3 changes: 1 addition & 2 deletions cli/cmd/resource_apply_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@ func init() {
Short: "Apply resources",
Long: "Apply (create/update) resources to your Tracetest server",
PreRun: setupCommand(),
Run: WithResourceMiddleware(func(_ *cobra.Command, args []string) (string, error) {
Run: WithResourceMiddleware(func(ctx context.Context, _ *cobra.Command, args []string) (string, error) {
resourceType := resourceParams.ResourceName
ctx := context.Background()

resourceClient, err := resources.Get(resourceType)
if err != nil {
Expand Down
3 changes: 1 addition & 2 deletions cli/cmd/resource_delete_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@ func init() {
Short: "Delete resources",
Long: "Delete resources from your Tracetest server",
PreRun: setupCommand(),
Run: WithResourceMiddleware(func(_ *cobra.Command, args []string) (string, error) {
Run: WithResourceMiddleware(func(ctx context.Context, _ *cobra.Command, args []string) (string, error) {
resourceType := resourceParams.ResourceName
ctx := context.Background()

resourceClient, err := resources.Get(resourceType)
if err != nil {
Expand Down
3 changes: 1 addition & 2 deletions cli/cmd/resource_export_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@ func init() {
Long: "Export a resource from your Tracetest server",
Short: "Export resource",
PreRun: setupCommand(),
Run: WithResourceMiddleware(func(_ *cobra.Command, args []string) (string, error) {
Run: WithResourceMiddleware(func(ctx context.Context, _ *cobra.Command, args []string) (string, error) {
resourceType := resourceParams.ResourceName
ctx := context.Background()

resourceClient, err := resources.Get(resourceType)
if err != nil {
Expand Down
3 changes: 1 addition & 2 deletions cli/cmd/resource_get_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ func init() {
Short: "Get resource",
Long: "Get a resource from your Tracetest server",
PreRun: setupCommand(),
Run: WithResourceMiddleware(func(_ *cobra.Command, args []string) (string, error) {
Run: WithResourceMiddleware(func(ctx context.Context, _ *cobra.Command, args []string) (string, error) {
resourceType := resourceParams.ResourceName
ctx := context.Background()

resourceClient, err := resources.Get(resourceType)
if err != nil {
Expand Down
3 changes: 1 addition & 2 deletions cli/cmd/resource_list_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@ func init() {
Short: "List resources",
Long: "List resources from your Tracetest server",
PreRun: setupCommand(),
Run: WithResourceMiddleware(func(_ *cobra.Command, args []string) (string, error) {
Run: WithResourceMiddleware(func(ctx context.Context, _ *cobra.Command, args []string) (string, error) {
resourceType := resourceParams.ResourceName
ctx := context.Background()

resourceClient, err := resources.Get(resourceType)
if err != nil {
Expand Down
3 changes: 1 addition & 2 deletions cli/cmd/resource_run_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ func init() {
Short: "run resources",
Long: "run resources",
PreRun: setupCommand(WithOptionalResourceName()),
Run: WithResourceMiddleware(func(_ *cobra.Command, args []string) (string, error) {
ctx := context.Background()
Run: WithResourceMiddleware(func(ctx context.Context, _ *cobra.Command, args []string) (string, error) {
resourceType, err := getResourceType(runParams, resourceParams)
if err != nil {
return "", err
Expand Down
12 changes: 9 additions & 3 deletions cli/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,16 @@ var rootCmd = &cobra.Command{
}

func Execute() {
if err := rootCmd.Execute(); err != nil {
fmt.Fprintln(os.Stderr, err)
ExitCLI(1)
handleRootExecErr(rootCmd.Execute())
}

func handleRootExecErr(err error) {
if err == nil {
ExitCLI(0)
}

fmt.Fprintln(os.Stderr, err)
ExitCLI(1)
}

func ExitCLI(errorCode int) {
Expand Down
19 changes: 15 additions & 4 deletions cli/cmd/start_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@ import (
"context"
"os"

"github.com/davecgh/go-spew/spew"
agentConfig "github.com/kubeshop/tracetest/agent/config"
"github.com/kubeshop/tracetest/agent/runner"
"github.com/kubeshop/tracetest/agent/ui"
"github.com/kubeshop/tracetest/cli/config"
"github.com/spf13/cobra"
)

var (
agentRunner = runner.NewRunner(configurator, resources, ui.DefaultUI)
agentRunner = runner.NewRunner(configurator.WithErrorHandler(handleError), resources, ui.DefaultUI)
defaultToken = os.Getenv("TRACETEST_TOKEN")
defaultEndpoint = os.Getenv("TRACETEST_SERVER_URL")
defaultAPIKey = os.Getenv("TRACETEST_API_KEY")
Expand All @@ -24,9 +26,7 @@ var startCmd = &cobra.Command{
Short: "Start Tracetest",
Long: "Start using Tracetest",
PreRun: setupCommand(SkipConfigValidation(), SkipVersionMismatchCheck()),
Run: WithResultHandler((func(_ *cobra.Command, _ []string) (string, error) {
ctx := context.Background()

Run: WithResultHandler((func(ctx context.Context, _ *cobra.Command, _ []string) (string, error) {
flags := agentConfig.Flags{
OrganizationID: saveParams.organizationID,
EnvironmentID: saveParams.environmentID,
Expand All @@ -37,6 +37,17 @@ var startCmd = &cobra.Command{
LogLevel: saveParams.logLevel,
}

// override organization and environment id from context.
// this happens when auto rerunning the cmd after relogin
if orgID := config.ContextGetOrganizationID(ctx); orgID != "" {
flags.OrganizationID = orgID
}
if envID := config.ContextGetEnvironmentID(ctx); envID != "" {
flags.EnvironmentID = envID
}

spew.Dump(flags)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

spew 🔫


cfg, err := agentConfig.LoadConfig()
if err != nil {
return "", err
Expand Down
4 changes: 3 additions & 1 deletion cli/cmd/version_cmd.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package cmd

import (
"context"

"github.com/spf13/cobra"
)

Expand All @@ -10,7 +12,7 @@ var versionCmd = &cobra.Command{
Short: "Display this CLI tool version",
Long: "Display this CLI tool version",
PreRun: setupCommand(),
Run: WithResultHandler(func(_ *cobra.Command, _ []string) (string, error) {
Run: WithResultHandler(func(_ context.Context, _ *cobra.Command, _ []string) (string, error) {
return versionText, nil
}),
PostRun: teardownCommand,
Expand Down
44 changes: 39 additions & 5 deletions cli/config/config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package config

import (
"context"
"encoding/json"
"fmt"
"os"
Expand Down Expand Up @@ -143,26 +144,59 @@ func ParseServerURL(serverURL string) (scheme, endpoint, serverPath string, err
return url.Scheme, url.Host, url.Path, nil
}

func Save(config Config) error {
type orgIDKeyType struct{}
type envIDKeyType struct{}

var orgIDKey = orgIDKeyType{}
var envIDKey = envIDKeyType{}

func ContextWithOrganizationID(ctx context.Context, orgID string) context.Context {
return context.WithValue(ctx, orgIDKey, orgID)
}

func ContextWithEnvironmentID(ctx context.Context, envID string) context.Context {
return context.WithValue(ctx, envIDKey, envID)
}

func ContextGetOrganizationID(ctx context.Context) string {
v := ctx.Value(orgIDKey)
if v == nil {
return ""
}
return v.(string)
}

func ContextGetEnvironmentID(ctx context.Context) string {
v := ctx.Value(envIDKey)
if v == nil {
return ""
}
return v.(string)
}

func Save(ctx context.Context, config Config) (context.Context, error) {
configPath, err := GetConfigurationPath()
if err != nil {
return fmt.Errorf("could not get configuration path: %w", err)
return ctx, fmt.Errorf("could not get configuration path: %w", err)
}

configYml, err := yaml.Marshal(config)
if err != nil {
return fmt.Errorf("could not marshal configuration into yml: %w", err)
return ctx, fmt.Errorf("could not marshal configuration into yml: %w", err)
}

if _, err := os.Stat(configPath); os.IsNotExist(err) {
os.MkdirAll(filepath.Dir(configPath), 0700) // Ensure folder exists
}
err = os.WriteFile(configPath, configYml, 0755)
if err != nil {
return fmt.Errorf("could not write file: %w", err)
return ctx, fmt.Errorf("could not write file: %w", err)
}

return nil
ctx = ContextWithOrganizationID(ctx, config.OrganizationID)
ctx = ContextWithEnvironmentID(ctx, config.EnvironmentID)

return ctx, nil
}

func GetConfigurationPath() (string, error) {
Expand Down
Loading
Loading