Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(cli): Token Support #3255

Merged
merged 4 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions cli/cmd/start_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cmd

import (
"context"
"os"

agentConfig "github.com/kubeshop/tracetest/agent/config"
"github.com/kubeshop/tracetest/cli/config"
Expand All @@ -10,8 +11,9 @@ import (
)

var (
start = starter.NewStarter(configurator, resources)
saveParams = &saveParameters{}
start = starter.NewStarter(configurator, resources)
defaultToken = os.Getenv("TRACETEST_CLI_API_KEY")
saveParams = &saveParameters{}
)

var startCmd = &cobra.Command{
Expand All @@ -28,6 +30,7 @@ var startCmd = &cobra.Command{
EnvironmentID: saveParams.environmentID,
Endpoint: saveParams.endpoint,
AgentApiKey: saveParams.agentApiKey,
CLIApiKey: saveParams.cliApiKey,
}

cfg, err := agentConfig.LoadConfig()
Expand All @@ -49,6 +52,7 @@ func init() {
startCmd.Flags().StringVarP(&saveParams.organizationID, "organization", "", "", "organization id")
startCmd.Flags().StringVarP(&saveParams.environmentID, "environment", "", "", "environment id")
startCmd.Flags().StringVarP(&saveParams.agentApiKey, "api-key", "", "", "agent api key")
startCmd.Flags().StringVarP(&saveParams.cliApiKey, "cli-api-key", "", defaultToken, "CLI api key")
startCmd.Flags().StringVarP(&saveParams.endpoint, "endpoint", "e", config.DefaultCloudEndpoint, "set the value for the endpoint, so the CLI won't ask for this value")
rootCmd.AddCommand(startCmd)
}
Expand All @@ -58,4 +62,5 @@ type saveParameters struct {
environmentID string
endpoint string
agentApiKey string
cliApiKey string
}
1 change: 1 addition & 0 deletions cli/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type ConfigFlags struct {
EnvironmentID string
CI bool
AgentApiKey string
CLIApiKey string
}

type Config struct {
Expand Down
53 changes: 41 additions & 12 deletions cli/config/configurator.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net/http"
"strings"

"github.com/golang-jwt/jwt"
"github.com/kubeshop/tracetest/cli/analytics"
"github.com/kubeshop/tracetest/cli/pkg/oauth"
"github.com/kubeshop/tracetest/cli/pkg/resourcemanager"
Expand Down Expand Up @@ -99,32 +100,46 @@ func (c Configurator) Start(ctx context.Context, prev Config, flags ConfigFlags)
return Save(cfg)
}

if flags.AgentApiKey != "" {
cfg.AgentApiKey = flags.AgentApiKey
c.ShowOrganizationSelector(ctx, cfg, flags)
return nil
}
oauthEndpoint := fmt.Sprintf("%s%s", cfg.URL(), cfg.Path())

if prev.Jwt != "" {
cfg.Jwt = prev.Jwt
cfg.Token = prev.Token
}

if flags.CLIApiKey != "" {
jwt, err := oauth.ExchangeToken(oauthEndpoint, flags.CLIApiKey)
if err != nil {
return err
}

cfg.Jwt = jwt
cfg.Token = flags.CLIApiKey

claims, err := GetTokenClaims(jwt)
if err != nil {
return err
}

flags.OrganizationID = claims["organization_id"].(string)
flags.EnvironmentID = claims["environment_id"].(string)
}

if flags.AgentApiKey != "" {
cfg.AgentApiKey = flags.AgentApiKey
c.ShowOrganizationSelector(ctx, cfg, flags)
return nil
}

confirmed := c.ui.Enter("Lets get to it! Press enter to launch a browser and authenticate:")
if !confirmed {
c.ui.Finish()
if cfg.Jwt != "" {
c.ShowOrganizationSelector(ctx, cfg, flags)
return nil
}

oauthServer := oauth.NewOAuthServer(fmt.Sprintf("%s%s", cfg.URL(), cfg.Path()), cfg.UIEndpoint)
err = oauthServer.WithOnSuccess(c.onOAuthSuccess(ctx, cfg)).
oauthServer := oauth.NewOAuthServer(oauthEndpoint, cfg.UIEndpoint)
return oauthServer.WithOnSuccess(c.onOAuthSuccess(ctx, cfg)).
WithOnFailure(c.onOAuthFailure).
GetAuthJWT()

return err
}

func (c Configurator) onOAuthSuccess(ctx context.Context, cfg Config) func(token, jwt string) {
Expand Down Expand Up @@ -182,3 +197,17 @@ func SetupHttpClient(cfg Config) *resourcemanager.HTTPClient {

return resourcemanager.NewHTTPClient(fmt.Sprintf("%s%s", cfg.URL(), cfg.Path()), extraHeaders)
}

func GetTokenClaims(tokenString string) (jwt.MapClaims, error) {
token, _, err := new(jwt.Parser).ParseUnverified(tokenString, jwt.MapClaims{})
if err != nil {
return jwt.MapClaims{}, err
}

claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return jwt.MapClaims{}, fmt.Errorf("invalid token claims")
}

return claims, nil
}
14 changes: 11 additions & 3 deletions cli/pkg/oauth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type OAuthServer struct {
port int
server *http.Server
mutex sync.Mutex
ui ui.UI
}

type Option func(*OAuthServer)
Expand All @@ -30,6 +31,7 @@ func NewOAuthServer(endpoint, frontendEndpoint string) *OAuthServer {
return &OAuthServer{
endpoint: endpoint,
frontendEndpoint: frontendEndpoint,
ui: ui.DefaultUI,
}
}

Expand All @@ -44,6 +46,12 @@ func (s *OAuthServer) WithOnFailure(onFailure OnAuthFailure) *OAuthServer {
}

func (s *OAuthServer) GetAuthJWT() error {
confirmed := s.ui.Enter("Lets get to it! Press enter to launch a browser and authenticate:")
if !confirmed {
s.ui.Finish()
return nil
}

url, err := s.getUrl()
if err != nil {
return fmt.Errorf("failed to start oauth server: %w", err)
Expand All @@ -64,8 +72,8 @@ type JWTResponse struct {
Jwt string `json:"jwt"`
}

func (s *OAuthServer) ExchangeToken(token string) (string, error) {
req, err := http.NewRequest("GET", fmt.Sprintf("%s/tokens/%s/exchange", s.endpoint, token), nil)
func ExchangeToken(endpoint string, token string) (string, error) {
req, err := http.NewRequest("GET", fmt.Sprintf("%s/tokens/%s/exchange", endpoint, token), nil)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
Expand Down Expand Up @@ -133,7 +141,7 @@ func (s *OAuthServer) handleResult(r *http.Request) (string, string, error) {
return "", "", fmt.Errorf("tokenId not found")
}

jwt, err := s.ExchangeToken(tokenId)
jwt, err := ExchangeToken(s.endpoint, tokenId)
if err != nil {
return "", "", err
}
Expand Down
23 changes: 6 additions & 17 deletions cli/pkg/starter/starter.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"errors"
"fmt"

"github.com/golang-jwt/jwt/v4"
agentConfig "github.com/kubeshop/tracetest/agent/config"
"github.com/kubeshop/tracetest/agent/initialization"

Expand All @@ -31,7 +30,11 @@ func (s *Starter) Run(ctx context.Context, cfg config.Config, flags config.Confi
s.ui.Println(`Tracetest start launches a lightweight agent. It enables you to run tests and collect traces with Tracetest.
Once started, Tracetest Agent exposes OTLP ports 4317 and 4318 to ingest traces via gRCP and HTTP.`)

return s.configurator.WithOnFinish(s.onStartAgent).Start(ctx, cfg, flags)
if flags.CLIApiKey == "" || flags.AgentApiKey != "" {
s.configurator = s.configurator.WithOnFinish(s.onStartAgent)
}

return s.configurator.Start(ctx, cfg, flags)
}

func (s *Starter) onStartAgent(ctx context.Context, cfg config.Config) {
Expand Down Expand Up @@ -132,7 +135,7 @@ func (s *Starter) StartAgent(ctx context.Context, endpoint, agentApiKey, uiEndpo
isStarted = true
}

claims, err := s.getTokenClaims(session.Token)
claims, err := config.GetTokenClaims(session.Token)
if err != nil {
return err
}
Expand Down Expand Up @@ -160,17 +163,3 @@ You can`
}
return nil
}

func (s *Starter) getTokenClaims(tokenString string) (jwt.MapClaims, error) {
token, _, err := new(jwt.Parser).ParseUnverified(tokenString, jwt.MapClaims{})
if err != nil {
return jwt.MapClaims{}, err
}

claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return jwt.MapClaims{}, fmt.Errorf("invalid token claims")
}

return claims, nil
}