diff --git a/cli/cmd/start_cmd.go b/cli/cmd/start_cmd.go index f83e7ccb3e..b10c7bb921 100644 --- a/cli/cmd/start_cmd.go +++ b/cli/cmd/start_cmd.go @@ -2,6 +2,7 @@ package cmd import ( "context" + "os" agentConfig "github.com/kubeshop/tracetest/agent/config" "github.com/kubeshop/tracetest/cli/config" @@ -10,8 +11,9 @@ import ( ) var ( - start = starter.NewStarter(configurator, resources) - saveParams = &saveParameters{} + start = starter.NewStarter(configurator, resources) + defaultToken = os.Getenv("TRACETEST_TOKEN") + saveParams = &saveParameters{} ) var startCmd = &cobra.Command{ @@ -28,6 +30,7 @@ var startCmd = &cobra.Command{ EnvironmentID: saveParams.environmentID, Endpoint: saveParams.endpoint, AgentApiKey: saveParams.agentApiKey, + Token: saveParams.token, } cfg, err := agentConfig.LoadConfig() @@ -49,6 +52,7 @@ func init() { startCmd.Flags().StringVarP(&saveParams.organizationID, "organization", "", "", "organization id") startCmd.Flags().StringVarP(&saveParams.environmentID, "environment", "", "", "environment id") startCmd.Flags().StringVarP(&saveParams.agentApiKey, "api-key", "", "", "agent api key") + startCmd.Flags().StringVarP(&saveParams.token, "token", "", defaultToken, "token api key") startCmd.Flags().StringVarP(&saveParams.endpoint, "endpoint", "e", config.DefaultCloudEndpoint, "set the value for the endpoint, so the CLI won't ask for this value") rootCmd.AddCommand(startCmd) } @@ -58,4 +62,5 @@ type saveParameters struct { environmentID string endpoint string agentApiKey string + token string } diff --git a/cli/config/config.go b/cli/config/config.go index 58de74b0be..a72e1b133b 100644 --- a/cli/config/config.go +++ b/cli/config/config.go @@ -27,6 +27,7 @@ type ConfigFlags struct { EnvironmentID string CI bool AgentApiKey string + Token string } type Config struct { diff --git a/cli/config/configurator.go b/cli/config/configurator.go index ea9b8ff8b1..f8523fc637 100644 --- a/cli/config/configurator.go +++ b/cli/config/configurator.go @@ -6,6 +6,7 @@ import ( "net/http" "strings" + "github.com/golang-jwt/jwt" "github.com/kubeshop/tracetest/cli/analytics" "github.com/kubeshop/tracetest/cli/pkg/oauth" "github.com/kubeshop/tracetest/cli/pkg/resourcemanager" @@ -99,32 +100,46 @@ func (c Configurator) Start(ctx context.Context, prev Config, flags ConfigFlags) return Save(cfg) } - if flags.AgentApiKey != "" { - cfg.AgentApiKey = flags.AgentApiKey - c.ShowOrganizationSelector(ctx, cfg, flags) - return nil - } + oauthEndpoint := fmt.Sprintf("%s%s", cfg.URL(), cfg.Path()) if prev.Jwt != "" { cfg.Jwt = prev.Jwt cfg.Token = prev.Token + } + + if flags.Token != "" { + jwt, err := oauth.ExchangeToken(oauthEndpoint, flags.Token) + if err != nil { + return err + } + + cfg.Jwt = jwt + cfg.Token = flags.Token + + claims, err := GetTokenClaims(jwt) + if err != nil { + return err + } + flags.OrganizationID = claims["organization_id"].(string) + flags.EnvironmentID = claims["environment_id"].(string) + } + + if flags.AgentApiKey != "" { + cfg.AgentApiKey = flags.AgentApiKey c.ShowOrganizationSelector(ctx, cfg, flags) return nil } - confirmed := c.ui.Enter("Lets get to it! Press enter to launch a browser and authenticate:") - if !confirmed { - c.ui.Finish() + if cfg.Jwt != "" { + c.ShowOrganizationSelector(ctx, cfg, flags) return nil } - oauthServer := oauth.NewOAuthServer(fmt.Sprintf("%s%s", cfg.URL(), cfg.Path()), cfg.UIEndpoint) - err = oauthServer.WithOnSuccess(c.onOAuthSuccess(ctx, cfg)). + oauthServer := oauth.NewOAuthServer(oauthEndpoint, cfg.UIEndpoint) + return oauthServer.WithOnSuccess(c.onOAuthSuccess(ctx, cfg)). WithOnFailure(c.onOAuthFailure). GetAuthJWT() - - return err } func (c Configurator) onOAuthSuccess(ctx context.Context, cfg Config) func(token, jwt string) { @@ -182,3 +197,17 @@ func SetupHttpClient(cfg Config) *resourcemanager.HTTPClient { return resourcemanager.NewHTTPClient(fmt.Sprintf("%s%s", cfg.URL(), cfg.Path()), extraHeaders) } + +func GetTokenClaims(tokenString string) (jwt.MapClaims, error) { + token, _, err := new(jwt.Parser).ParseUnverified(tokenString, jwt.MapClaims{}) + if err != nil { + return jwt.MapClaims{}, err + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return jwt.MapClaims{}, fmt.Errorf("invalid token claims") + } + + return claims, nil +} diff --git a/cli/pkg/oauth/oauth.go b/cli/pkg/oauth/oauth.go index d89285953b..0db69d8373 100644 --- a/cli/pkg/oauth/oauth.go +++ b/cli/pkg/oauth/oauth.go @@ -22,6 +22,7 @@ type OAuthServer struct { port int server *http.Server mutex sync.Mutex + ui ui.UI } type Option func(*OAuthServer) @@ -30,6 +31,7 @@ func NewOAuthServer(endpoint, frontendEndpoint string) *OAuthServer { return &OAuthServer{ endpoint: endpoint, frontendEndpoint: frontendEndpoint, + ui: ui.DefaultUI, } } @@ -44,6 +46,12 @@ func (s *OAuthServer) WithOnFailure(onFailure OnAuthFailure) *OAuthServer { } func (s *OAuthServer) GetAuthJWT() error { + confirmed := s.ui.Enter("Lets get to it! Press enter to launch a browser and authenticate:") + if !confirmed { + s.ui.Finish() + return nil + } + url, err := s.getUrl() if err != nil { return fmt.Errorf("failed to start oauth server: %w", err) @@ -64,8 +72,8 @@ type JWTResponse struct { Jwt string `json:"jwt"` } -func (s *OAuthServer) ExchangeToken(token string) (string, error) { - req, err := http.NewRequest("GET", fmt.Sprintf("%s/tokens/%s/exchange", s.endpoint, token), nil) +func ExchangeToken(endpoint string, token string) (string, error) { + req, err := http.NewRequest("GET", fmt.Sprintf("%s/tokens/%s/exchange", endpoint, token), nil) if err != nil { return "", fmt.Errorf("failed to create request: %w", err) } @@ -133,7 +141,7 @@ func (s *OAuthServer) handleResult(r *http.Request) (string, string, error) { return "", "", fmt.Errorf("tokenId not found") } - jwt, err := s.ExchangeToken(tokenId) + jwt, err := ExchangeToken(s.endpoint, tokenId) if err != nil { return "", "", err } diff --git a/cli/pkg/starter/starter.go b/cli/pkg/starter/starter.go index dcc89060bf..5c98ff4eca 100644 --- a/cli/pkg/starter/starter.go +++ b/cli/pkg/starter/starter.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" - "github.com/golang-jwt/jwt/v4" agentConfig "github.com/kubeshop/tracetest/agent/config" "github.com/kubeshop/tracetest/agent/initialization" @@ -31,7 +30,11 @@ func (s *Starter) Run(ctx context.Context, cfg config.Config, flags config.Confi s.ui.Println(`Tracetest start launches a lightweight agent. It enables you to run tests and collect traces with Tracetest. Once started, Tracetest Agent exposes OTLP ports 4317 and 4318 to ingest traces via gRCP and HTTP.`) - return s.configurator.WithOnFinish(s.onStartAgent).Start(ctx, cfg, flags) + if flags.Token == "" || flags.AgentApiKey != "" { + s.configurator = s.configurator.WithOnFinish(s.onStartAgent) + } + + return s.configurator.Start(ctx, cfg, flags) } func (s *Starter) onStartAgent(ctx context.Context, cfg config.Config) { @@ -132,7 +135,7 @@ func (s *Starter) StartAgent(ctx context.Context, endpoint, agentApiKey, uiEndpo isStarted = true } - claims, err := s.getTokenClaims(session.Token) + claims, err := config.GetTokenClaims(session.Token) if err != nil { return err } @@ -160,17 +163,3 @@ You can` } return nil } - -func (s *Starter) getTokenClaims(tokenString string) (jwt.MapClaims, error) { - token, _, err := new(jwt.Parser).ParseUnverified(tokenString, jwt.MapClaims{}) - if err != nil { - return jwt.MapClaims{}, err - } - - claims, ok := token.Claims.(jwt.MapClaims) - if !ok { - return jwt.MapClaims{}, fmt.Errorf("invalid token claims") - } - - return claims, nil -}