From dc9b239e7509d12f0c6fa622df882759ffaf4280 Mon Sep 17 00:00:00 2001 From: Drewry Pope Date: Fri, 23 Feb 2024 04:08:53 -0600 Subject: [PATCH] refactor config --- sources/identity/auth/connection.go | 30 +- sources/identity/cmd/cmd.go | 378 +++++++++--------- .../identity/configuration/configuration.go | 8 + sources/identity/main.go | 4 +- sources/identity/start-server-stream.ps1 | 95 +++++ sources/identity/web/server.go | 16 + 6 files changed, 351 insertions(+), 180 deletions(-) create mode 100644 sources/identity/start-server-stream.ps1 diff --git a/sources/identity/auth/connection.go b/sources/identity/auth/connection.go index 3f2bb399..57b38e81 100644 --- a/sources/identity/auth/connection.go +++ b/sources/identity/auth/connection.go @@ -31,7 +31,7 @@ func GetConnectionMap(ctx context.Context) (*SafeConnectionMap, bool) { } type SafeConnectionMap struct { - mu sync.RWMutex + mu sync.RWMutex // todo ristretto data map[string]*Connection } @@ -123,6 +123,20 @@ func (sm *SafeConnectionMap) Values() []Connection { return values } +func (sm *SafeConnectionMap) ValuesUnique() []*Connection { + sm.mu.RLock() + defer sm.mu.RUnlock() + values := make([]*Connection, 0, len(sm.data)) + unique := make(map[string]bool) + for _, v := range sm.data { + if _, ok := unique[*v.ConnectionID]; !ok { + unique[*v.ConnectionID] = true + values = append(values, v) + } + } + return values +} + // ValuesRef safely retrieves all values from the map // Be careful with this, as it allows you to modify the map func (sm *SafeConnectionMap) ValuesRef() []*Connection { @@ -135,6 +149,20 @@ func (sm *SafeConnectionMap) ValuesRef() []*Connection { return values } +func (sm *SafeConnectionMap) ValuesRefUnique() []*Connection { + sm.mu.RLock() + defer sm.mu.RUnlock() + values := make([]*Connection, 0, len(sm.data)) + unique := make(map[string]bool) + for _, v := range sm.data { + if _, ok := unique[*v.ConnectionID]; !ok { + unique[*v.ConnectionID] = true + values = append(values, v) + } + } + return values +} + // Connection represents a connection type Connection struct { ConnectionID *string diff --git a/sources/identity/cmd/cmd.go b/sources/identity/cmd/cmd.go index 2d5cc9b4..776aeaa6 100644 --- a/sources/identity/cmd/cmd.go +++ b/sources/identity/cmd/cmd.go @@ -59,6 +59,198 @@ remote config ( // todo: dependency injection / remove globals and inits ??? */ + +var Separator = "." +var ConfigurationFilePath = "config.kdl" +var EmbeddedConfigurationFilePath = "embed/config.kdl" +var GeneratedKeyDirPath = ".ssh/generated" +var HostKeyPath = ".ssh/term_info_ed25519" +var ScpFileSystemDirPath = "scp" + +func NewConfiguration() *configuration.IdentityServerConfiguration { + return &configuration.IdentityServerConfiguration{ + Configuration: koanf.New(Separator), + ConfigurationLocations: &configuration.ConfigurationLocations{ + ConfigurationFilePaths: []string{ + ConfigurationFilePath, + // identity.kdl identity.config.kdl config.identity.kdl identity.config + // run these against ? binary dir ? pwd of execution ? appdata ? .config ? .local ??? + // then check for further locations/env-prefixes/etc from first pass, rerun on top with second pass + // (maybe config.kdl next to binary sets a new set of configurationPaths, finish out loading from defaults, then load from new paths) + // this pattern continues, after hard-code default env/file search, then custom file/env search, then eventually maybe nats/s3 or other remote or db config + }, + EmbeddedConfigurationFilePaths: []string{ + EmbeddedConfigurationFilePath, + }, + }, + EmbedFS: &configuration.EmbedFS, + } +} + +func StartCharmCmd(config *configuration.IdentityServerConfiguration) *cobra.Command { + result := charmcmd.ServeCmd + result.Use = "charm" + result.Aliases = []string{"ch", "c"} + return result +} + +func StartAllAltCmd(command cobra.Command) *cobra.Command { + result := command + result.Use = "all" + result.Aliases = []string{"al", "a"} + return &result +} + +func LoadDefaultConfiguration() *configuration.IdentityServerConfiguration { + config := NewConfiguration() + config.LoadConfiguration() + log.Info("Loaded config", "config", config.Configuration.Sprint()) + return config +} + +func RootCmd(config *configuration.IdentityServerConfiguration) *cobra.Command { + result := &cobra.Command{ + Use: "identity", + Short: "publish your identity", + Long: `publish your identity and allow others to connect to you.`, + } + result.AddCommand(charmcmd.RootCmd, StartAllCmd(config)) + return result +} + +func StartAllCmd(config *configuration.IdentityServerConfiguration) *cobra.Command { + result := &cobra.Command{ + Use: "start", + Short: "Starts the identity and charm servers", + Run: StartAll(config), + Aliases: []string{"s", "run", "serve", "publish", "pub", "p", "i", "y", "u", "o", "p", "q", "w", "e", "r", "t", "a", "s", "d", "f", "g", "h", "j", "k", "l", "z", "x", "c", "v", "b"}, + } + result.AddCommand(StartCharmCmd(config), StartIdentityCmd(config), StartStreamCmd(config), StartAllAltCmd(*result)) + return result +} + +func StartIdentityCmd(config *configuration.IdentityServerConfiguration) *cobra.Command { + return &cobra.Command{ + Use: "identity", + Short: "Starts only the identity server", + Run: StartIdentity(config), + Aliases: []string{"id", "i"}, + } +} + +func StartStreamCmd(config *configuration.IdentityServerConfiguration) *cobra.Command { + return &cobra.Command{ + Use: "stream", + Short: "Starts only the stream server", + Run: StartStream(config), + Aliases: []string{"tr", "t"}, + } +} + +func StartAll(config *configuration.IdentityServerConfiguration) func(*cobra.Command, []string) { + return func(cmd *cobra.Command, args []string) { + tasks := []func(*cobra.Command, []string){ + StartCharm(config), + StartIdentity(config), + StartStream(config), + } + + var wg sync.WaitGroup + wg.Add(len(tasks)) + + for _, task := range tasks { + go RunTask(&wg, task)(cmd, args) + } + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + <-done + fmt.Println("All tasks completed. Proceeding to cleanup and shutdown.") + } +} + +func RunTask(wg *sync.WaitGroup, taskFunc func(*cobra.Command, []string)) func(*cobra.Command, []string) { + return func(cmd *cobra.Command, args []string) { + defer wg.Done() + taskFunc(cmd, args) + } +} + +func StartStream(config *configuration.IdentityServerConfiguration) func(*cobra.Command, []string) { + return func(cmd *cobra.Command, args []string) { + log.Info("Starting stream server") + } +} + +func StartCharm(config *configuration.IdentityServerConfiguration) func(*cobra.Command, []string) { + return func(cmd *cobra.Command, args []string) { + log.Info("Starting charm server") + charmcmd.ServeCmdRunE(cmd, args) + } +} + +func StartIdentity(config *configuration.IdentityServerConfiguration) func(*cobra.Command, []string) { + return func(cmd *cobra.Command, args []string) { + log.Info("Starting identity server") + // todo split web and ssh into separate functions + connections := auth.NewSafeConnectionMap() + web.GoRunWebServer(connections, config) + handler := scp.NewFileSystemHandler(ScpFileSystemDirPath) + s, err := wish.NewServer( + wish.WithMiddleware( + scp.Middleware(handler, handler), + bubbletea.Middleware(TeaHandler), // todo: before bubbletea, use non-fullscreen teahandler to accept TOS if not this verion accepted in DB. check connection for previous tos, but this might need to be a charm user column? // separate todo: add tos table in database and pulldown latest tos on boot? + comment.Middleware("Thanks, have a nice day!"), + elapsed.Middleware(), + promwish.Middleware("0.0.0.0:9222", "identity"), + logging.Middleware(), + observability.Middleware(connections), + ), + wish.WithPasswordAuth(func(ctx ssh.Context, password string) bool { + log.Info("Accepting password", "password", password, "len", len(password)) + return Connect(ctx, nil, &password, nil, connections) + }), + wish.WithKeyboardInteractiveAuth(func(ctx ssh.Context, challenge gossh.KeyboardInteractiveChallenge) bool { + log.Info("Accepting keyboard interactive") + return Connect(ctx, nil, nil, challenge, connections) + }), + wish.WithPublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool { + log.Info("Accepting public key", "publicKeyType", key.Type(), "publicKeyString", base64.StdEncoding.EncodeToString(key.Marshal())) + return Connect(ctx, key, nil, nil, connections) + }), + wish.WithBannerHandler(Banner(config)), + wish.WithAddress(fmt.Sprintf("%s:%d", config.Configuration.String("identity.server.host"), config.Configuration.Int("identity.server.ssh.port"))), + wish.WithHostKeyPath(HostKeyPath), + ) + if err != nil { + log.Error("could not start server", "error", err) + return + } + + done := make(chan os.Signal, 1) + signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + log.Info("Starting ssh server", "identity.server.host", config.Configuration.String("identity.server.host"), "identity.server.ssh.port", config.Configuration.Int("identity.server.ssh.port"), "address", fmt.Sprintf("%s:%d", config.Configuration.String("identity.server.host"), config.Configuration.Int("identity.server.ssh.port"))) + go func() { + if err := s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) { + log.Error("could not start server", "error", err) + done <- os.Interrupt + } + }() + + <-done + log.Info("Stopping ssh server", "identity.server.host", config.Configuration.String("identity.server.host"), "identity.server.ssh.port", config.Configuration.Int("identity.server.ssh.port"), "address", fmt.Sprintf("%s:%d", config.Configuration.String("identity.server.host"), config.Configuration.Int("identity.server.ssh.port"))) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + if err := s.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) { + log.Error("could not stop server", "error", err) + } + } +} + type errMsg error type model struct { @@ -88,19 +280,19 @@ func (m model) Init() tea.Cmd { return m.spinner.Tick } -const useHighPerformanceRenderer = false +const UseHighPerformanceRenderer = false var ( - titleStyle = func() lipgloss.Style { + TitleStyle = func() lipgloss.Style { b := lipgloss.RoundedBorder() b.Right = "├" return lipgloss.NewStyle().BorderStyle(b).Padding(0, 1) }() - infoStyle = func() lipgloss.Style { + InfoStyle = func() lipgloss.Style { b := lipgloss.RoundedBorder() b.Left = "┤" - return titleStyle.Copy().BorderStyle(b) + return TitleStyle.Copy().BorderStyle(b) }() ) @@ -233,85 +425,7 @@ func (m model) View() string { return m.viewport.View() } -var separator = "." -var configurationFilePath = "config.kdl" -var embeddedConfigurationFilePath = "embed/config.kdl" -var generatedKeyDirPath = ".ssh/generated" -var hostKeyPath = ".ssh/term_info_ed25519" -var scpFileSystemDirPath = "scp" -var config = NewConfiguration() - -func initializeConfig() { - config = NewConfiguration() -} - -func NewConfiguration() configuration.IdentityServerConfiguration { - return configuration.IdentityServerConfiguration{ - Configuration: koanf.New(separator), - ConfigurationLocations: &configuration.ConfigurationLocations{ - ConfigurationFilePaths: []string{ - configurationFilePath, - // identity.kdl identity.config.kdl config.identity.kdl identity.config - // run these against ? binary dir ? pwd of execution ? appdata ? .config ? .local ??? - // then check for further locations/env-prefixes/etc from first pass, rerun on top with second pass - // (maybe config.kdl next to binary sets a new set of configurationPaths, finish out loading from defaults, then load from new paths) - // this pattern continues, after hard-code default env/file search, then custom file/env search, then eventually maybe nats/s3 or other remote or db config - }, - EmbeddedConfigurationFilePaths: []string{ - embeddedConfigurationFilePath, - }, - }, - EmbedFS: &configuration.EmbedFS, - } -} - -func initializeAndLoadConfiguration() { - initializeConfig() - log.Debug("Initialized config", "config", config.Configuration.Sprint()) - config.LoadConfiguration() - log.Info("Loaded config", "config", config.Configuration.Sprint()) -} - -func init() { - cobra.OnInitialize(initializeAndLoadConfiguration) - StartCharmCmd := charmcmd.ServeCmd - StartCharmCmd.Use = "charm" - StartCharmCmd.Aliases = []string{"ch", "c"} - StartAllCmd.AddCommand(StartCharmCmd) - StartAllCmd.AddCommand(StartIdentityCmd) - startAllAltCmd := *StartAllCmd - StartAllAltCmd = &startAllAltCmd - StartAllAltCmd.Use = "all" - StartAllAltCmd.Aliases = []string{"al", "a"} - StartAllCmd.AddCommand(StartAllAltCmd) - RootCmd.AddCommand(StartAllCmd) - RootCmd.AddCommand(charmcmd.RootCmd) -} - -var StartCharmCmd = &cobra.Command{} // constructed in init -var StartAllAltCmd = &cobra.Command{} // constructed in init - -var RootCmd = &cobra.Command{ - Use: "identity", - Short: "publish your identity", - Long: `publish your identity and allow others to connect to you.`, -} - -var StartAllCmd = &cobra.Command{ - Use: "start", - Short: "Starts the identity and charm servers", - Run: startAll, - Aliases: []string{"s", "run", "serve", "publish", "pub", "p", "i", "y", "u", "o", "p", "q", "w", "e", "r", "t", "a", "s", "d", "f", "g", "h", "j", "k", "l", "z", "x", "c", "v", "b"}, -} - -var StartIdentityCmd = &cobra.Command{ - Use: "identity", - Short: "Starts only the identity server", - Run: startIdentity, - Aliases: []string{"id", "i"}, -} - -func teaHandler(s ssh.Session) (tea.Model, []tea.ProgramOption) { +func TeaHandler(s ssh.Session) (tea.Model, []tea.ProgramOption) { pty, _, active := s.Pty() if !active { wish.Fatalln(s, "no active terminal, skipping") @@ -337,8 +451,9 @@ func teaHandler(s ssh.Session) (tea.Model, []tea.ProgramOption) { return m, []tea.ProgramOption{tea.WithAltScreen()} } -func Banner(ctx ssh.Context) string { - return ` +func Banner(config *configuration.IdentityServerConfiguration) func(ctx ssh.Context) string { + return func(ctx ssh.Context) string { + return ` Welcome to the identity server! ("The Service") By using The Service, you agree to all of the following terms and conditions. @@ -372,99 +487,6 @@ If you do not agree to all of the above terms and conditions, then you may not u ` + fmt.Sprintf("Your client version is %s\n", ctx.ClientVersion()) + ` ` + fmt.Sprintf("Your session id is %s\n", ctx.SessionID()) + ` ` + fmt.Sprintf("You are connecting with user %s\n", ctx.User()) -} - -func startAll(cmd *cobra.Command, args []string) { - var wg sync.WaitGroup - done := make(chan struct{}) // Channel to signal all tasks are done - - // Helper function for running tasks - runTask := func(taskFunc func(*cobra.Command, []string)) { - defer wg.Done() - - taskFunc(cmd, args) // Execute the task - // After task completion, optionally signal done for cleanup - } - - wg.Add(2) // Prepare for two goroutines - - // Start startCharm in its own goroutine - go runTask(func(cmd *cobra.Command, args []string) { - startCharm(cmd, args) - }) - - // Start startIdentity in its own goroutine - go runTask(func(cmd *cobra.Command, args []string) { - startIdentity(cmd, args) - }) - - go func() { - wg.Wait() // Wait for both tasks to complete - close(done) // Signal that all tasks are done - }() - - // Wait for the done signal before proceeding to cleanup or exit - <-done - fmt.Println("All tasks completed. Proceeding to cleanup and shutdown.") -} - -func startCharm(cmd *cobra.Command, args []string) { - charmcmd.ServeCmdRunE(cmd, args) -} - -func startIdentity(cmd *cobra.Command, args []string) { - // todo split web and ssh into separate functions - connections := auth.NewSafeConnectionMap() - web.GoRunWebServer(connections, &config) - handler := scp.NewFileSystemHandler(scpFileSystemDirPath) - s, err := wish.NewServer( - wish.WithMiddleware( - scp.Middleware(handler, handler), - bubbletea.Middleware(teaHandler), // todo: before bubbletea, use non-fullscreen teahandler to accept TOS if not this verion accepted in DB. check connection for previous tos, but this might need to be a charm user column? // separate todo: add tos table in database and pulldown latest tos on boot? - comment.Middleware("Thanks, have a nice day!"), - elapsed.Middleware(), - promwish.Middleware("0.0.0.0:9222", "identity"), - logging.Middleware(), - observability.Middleware(connections), - ), - wish.WithPasswordAuth(func(ctx ssh.Context, password string) bool { - log.Info("Accepting password", "password", password, "len", len(password)) - return Connect(ctx, nil, &password, nil, connections) - }), - wish.WithKeyboardInteractiveAuth(func(ctx ssh.Context, challenge gossh.KeyboardInteractiveChallenge) bool { - log.Info("Accepting keyboard interactive") - return Connect(ctx, nil, nil, challenge, connections) - }), - wish.WithPublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool { - log.Info("Accepting public key", "publicKeyType", key.Type(), "publicKeyString", base64.StdEncoding.EncodeToString(key.Marshal())) - return Connect(ctx, key, nil, nil, connections) - }), - wish.WithBannerHandler(Banner), - wish.WithAddress(fmt.Sprintf("%s:%d", config.Configuration.String("identity.server.host"), config.Configuration.Int("identity.server.ssh.port"))), - wish.WithHostKeyPath(hostKeyPath), - ) - if err != nil { - log.Error("could not start server", "error", err) - return - } - - done := make(chan os.Signal, 1) - signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) - log.Info("Starting ssh server", "identity.server.host", config.Configuration.String("identity.server.host"), "identity.server.ssh.port", config.Configuration.Int("identity.server.ssh.port"), "address", fmt.Sprintf("%s:%d", config.Configuration.String("identity.server.host"), config.Configuration.Int("identity.server.ssh.port"))) - go func() { - if err := s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) { - log.Error("could not start server", "error", err) - done <- os.Interrupt - } - }() - - <-done - log.Info("Stopping ssh server", "identity.server.host", config.Configuration.String("identity.server.host"), "identity.server.ssh.port", config.Configuration.Int("identity.server.ssh.port"), "address", fmt.Sprintf("%s:%d", config.Configuration.String("identity.server.host"), config.Configuration.Int("identity.server.ssh.port"))) - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - if err := s.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) { - log.Error("could not stop server", "error", err) } } diff --git a/sources/identity/configuration/configuration.go b/sources/identity/configuration/configuration.go index d8a3283c..4520af53 100644 --- a/sources/identity/configuration/configuration.go +++ b/sources/identity/configuration/configuration.go @@ -152,3 +152,11 @@ func (c *IdentityServerConfiguration) LoadConfiguration() { log.Info("Loaded file configuration", "config", c.Configuration.Sprint()) } + +func (c *IdentityServerConfiguration) SetConfiguration(config *IdentityServerConfiguration) { + c.Configuration = config.Configuration + c.ConfigurationLocations = config.ConfigurationLocations + c.EmbedFS = config.EmbedFS + c.JWTAudience = config.JWTAudience + c.JWKSProvider = config.JWKSProvider +} diff --git a/sources/identity/main.go b/sources/identity/main.go index 29929088..711475c9 100644 --- a/sources/identity/main.go +++ b/sources/identity/main.go @@ -8,7 +8,9 @@ import ( ) func main() { - if err := cmd.RootCmd.Execute(); err != nil { + configuration := cmd.LoadDefaultConfiguration() + + if err := cmd.RootCmd(configuration).Execute(); err != nil { fmt.Println(err) os.Exit(1) } diff --git a/sources/identity/start-server-stream.ps1 b/sources/identity/start-server-stream.ps1 new file mode 100644 index 00000000..f73520fb --- /dev/null +++ b/sources/identity/start-server-stream.ps1 @@ -0,0 +1,95 @@ +#!/usr/bin/env pwsh +param( + [switch]$FastBuild, + [switch]$Tidy, + [switch]$SkipBuild, + [switch]$SkipBuildWebJs, + [switch]$SkipBuildTempl, + [switch]$SkipBuildGoGenerate, + [switch]$SkipBuildGoModTidy, + [switch]$SkipBuildGoGet, + [switch]$SkipBuildGoBuild, + [switch]$SkipBuildGoExperiment, + [switch]$Update, + [switch]$ForceInstallTempl +) + +Set-StrictMode -Version Latest + +$PSNativeCommandUseErrorActionPreference = $true + +if ($PSNativeCommandUseErrorActionPreference) { + # always true, this is a linter workaround + $ErrorActionPreference = "Stop" + $PSDefaultParameterValues['*:ErrorAction'] = 'Stop' +} + +$originalVerbosePreference = $VerbosePreference +$VerbosePreference = 'Continue' + +Write-Verbose "script: $($MyInvocation.MyCommand.Name)" +Write-Verbose "psscriptroot: $PSScriptRoot" +Write-Verbose "full script path: $PSScriptRoot$([IO.Path]::DirectorySeparatorChar)$($MyInvocation.MyCommand.Name)" +Write-Verbose "originalVerbosePreference: $originalVerbosePreference" +Write-Verbose "VerbosePreference: $VerbosePreference" + +if ($FastBuild) { + $SkipBuildWebJs = $true + $SkipBuildTempl = $true + $SkipBuildGoGenerate = $true + $SkipBuildGoModTidy = $true + $SkipBuildGoGet = $true + $SkipBuildGoExperiment = $true +} + +if ($Tidy) { + $SkipBuildGoModTidy = $false +} + +try { + + $cwd = Get-Location + + Write-Verbose "Current directory: $cwd" + + try { + + Write-Verbose "Set-Location $PSScriptRoot" + + Set-Location $PSScriptRoot + + if (-not $SkipBuild) { + Write-Verbose "Building libsql" + ."$PSScriptRoot/build-libsql.ps1" -ForceInstallTempl:$ForceInstallTempl -Update:$Update -SkipBuildWebJs:$SkipBuildWebJs -SkipBuildTempl:$SkipBuildTempl -SkipBuildGoGenerate:$SkipBuildGoGenerate -SkipBuildGoModTidy:$SkipBuildGoModTidy -SkipBuildGoGet:$SkipBuildGoGet -SkipBuildGoBuild:$SkipBuildGoBuild -SkipBuildGoExperiment:$SkipBuildGoExperiment + } + else { + Write-Verbose "Skipping libsql build" + } + $env:CHARM_SERVER_DB_DRIVER = "libsql" + + if ([string]::IsNullOrEmpty($env:TURSO_HOST)) { + throw "TURSO_HOST environment variable must be set" + } + if ([string]::IsNullOrEmpty($env:TURSO_AUTH_TOKEN)) { + throw "TURSO_AUTH_TOKEN environment variable must be set" + } + $env:CHARM_SERVER_DB_DATA_SOURCE = "libsql://${env:TURSO_HOST}?authToken=${env:TURSO_AUTH_TOKEN}" + + $serverType = [System.IO.Path]::GetFileNameWithoutExtension($MyInvocation.MyCommand.Name) -replace '(?i)^start-server-', '' -replace '-', ' ' -replace ',', ' ' + + Write-Verbose "serverType: $serverType" + + Write-Verbose "./identity serve $serverType" + + Invoke-Expression "./identity serve $serverType" + } + finally { + Write-Verbose "Set-Location $cwd" + + Set-Location $cwd + } +} +finally { + Write-Verbose "Resetting VerbosePreference to $originalVerbosePreference" + $VerbosePreference = $originalVerbosePreference +} diff --git a/sources/identity/web/server.go b/sources/identity/web/server.go index a8949c6d..e4c4e35a 100644 --- a/sources/identity/web/server.go +++ b/sources/identity/web/server.go @@ -153,6 +153,17 @@ func NewJWKSValidator(jwksURL *url.URL, issuer *url.URL, audience []string, cach return func(token *jwt.Token) (interface{}, error) { log.Info("Validating token", "token", token) + alg, ok := token.Header["alg"].(string) + if !ok { + log.Error("Expecting JWT header to have string 'alg'") + return nil, fmt.Errorf("expecting JWT header to have string 'alg'") + } + log.Info("Algorithm", "alg", alg) + if alg == "" || strings.ToLower(alg) == "none" { + log.Error("Invalid algorithm", "alg", alg) + return nil, fmt.Errorf("invalid algorithm") + } + kid, ok := token.Header["kid"].(string) if !ok { log.Error("Expecting JWT header to have string 'kid'") @@ -190,6 +201,10 @@ func NewJWKSValidator(jwksURL *url.URL, issuer *url.URL, audience []string, cach log.Error("Key ID does not match", "kid", kid, "keyID", cacheKey.KeyID, "cacheKey", cacheKey) return nil, fmt.Errorf("key ID does not match") } + if cacheKey.Algorithm != alg { + log.Error("Algorithm does not match", "alg", alg, "algorithm", cacheKey.Algorithm, "cacheKey", cacheKey) + return nil, fmt.Errorf("algorithm does not match") + } if !cacheKey.Valid() { log.Error("Key is not valid", "cacheKey", cacheKey) return nil, fmt.Errorf("key is not valid") @@ -206,6 +221,7 @@ func NewJWKSValidator(jwksURL *url.URL, issuer *url.URL, audience []string, cach switch key := cacheKey.Key.(type) { case ed25519.PublicKey: log.Info("Key", "key", key, "type", "ed25519.PublicKey") + return key, nil default: log.Error("Key is not a valid type", "key", key, "type", fmt.Sprintf("%T", key), "expectedType", []string{"ed25519.PublicKey"})