diff --git a/bootstrap/container/clients.go b/bootstrap/container/clients.go index fb8d6f87..18f94f52 100644 --- a/bootstrap/container/clients.go +++ b/bootstrap/container/clients.go @@ -180,3 +180,15 @@ func ScheduleActionRecordClientFrom(get di.Get) interfaces.ScheduleActionRecordC return get(ScheduleActionRecordClientName).(interfaces.ScheduleActionRecordClient) } + +// SecurityProxyAuthClientName contains the name of the AuthClient's implementation in the DIC. +var SecurityProxyAuthClientName = di.TypeInstanceToName((*interfaces.AuthClient)(nil)) + +// SecurityProxyAuthClientFrom helper function queries the DIC and returns the AuthClient's implementation. +func SecurityProxyAuthClientFrom(get di.Get) interfaces.AuthClient { + if get(SecurityProxyAuthClientName) == nil { + return nil + } + + return get(SecurityProxyAuthClientName).(interfaces.AuthClient) +} diff --git a/bootstrap/controller/commonapi.go b/bootstrap/controller/commonapi.go index 8701f0fc..26600b76 100644 --- a/bootstrap/controller/commonapi.go +++ b/bootstrap/controller/commonapi.go @@ -46,8 +46,7 @@ type config struct { func NewCommonController(dic *di.Container, r *echo.Echo, serviceName string, serviceVersion string) *CommonController { lc := container.LoggingClientFrom(dic.Get) - secretProvider := container.SecretProviderExtFrom(dic.Get) - authenticationHook := handlers.AutoConfigAuthenticationFunc(secretProvider, lc) + authenticationHook := handlers.AutoConfigAuthenticationFunc(dic) configuration := container.ConfigurationFrom(dic.Get) c := CommonController{ dic: dic, diff --git a/bootstrap/handlers/auth_func.go b/bootstrap/handlers/auth_func.go index 61d72d2a..dd4ffabd 100644 --- a/bootstrap/handlers/auth_func.go +++ b/bootstrap/handlers/auth_func.go @@ -18,10 +18,8 @@ import ( "os" "strconv" - "github.com/edgexfoundry/go-mod-core-contracts/v4/clients/logger" - - "github.com/edgexfoundry/go-mod-bootstrap/v4/bootstrap/interfaces" "github.com/edgexfoundry/go-mod-bootstrap/v4/bootstrap/secret" + "github.com/edgexfoundry/go-mod-bootstrap/v4/di" "github.com/labstack/echo/v4" ) @@ -44,12 +42,12 @@ func NilAuthenticationHandlerFunc() echo.MiddlewareFunc { // to disable JWT validation. This might be wanted for an EdgeX // adopter that wanted to only validate JWT's at the proxy layer, // or as an escape hatch for a caller that cannot authenticate. -func AutoConfigAuthenticationFunc(secretProvider interfaces.SecretProviderExt, lc logger.LoggingClient) echo.MiddlewareFunc { +func AutoConfigAuthenticationFunc(dic *di.Container) echo.MiddlewareFunc { // Golang standard library treats an error as false disableJWTValidation, _ := strconv.ParseBool(os.Getenv("EDGEX_DISABLE_JWT_VALIDATION")) authenticationHook := NilAuthenticationHandlerFunc() if secret.IsSecurityEnabled() && !disableJWTValidation { - authenticationHook = SecretStoreAuthenticationHandlerFunc(secretProvider, lc) + authenticationHook = AuthenticationHandlerFunc(dic) } return authenticationHook } diff --git a/bootstrap/handlers/auth_middleware.go b/bootstrap/handlers/auth_middleware.go index f277d4c2..8e99b963 100644 --- a/bootstrap/handlers/auth_middleware.go +++ b/bootstrap/handlers/auth_middleware.go @@ -22,34 +22,43 @@ import ( "net/http" "strings" + "github.com/edgexfoundry/go-mod-bootstrap/v4/bootstrap/container" + "github.com/edgexfoundry/go-mod-bootstrap/v4/bootstrap/handlers/headers" "github.com/edgexfoundry/go-mod-bootstrap/v4/bootstrap/interfaces" "github.com/edgexfoundry/go-mod-bootstrap/v4/bootstrap/zerotrust" + "github.com/edgexfoundry/go-mod-bootstrap/v4/di" "github.com/edgexfoundry/go-mod-core-contracts/v4/clients/logger" + dtoCommon "github.com/edgexfoundry/go-mod-core-contracts/v4/dtos/common" + "github.com/golang-jwt/jwt/v5" "github.com/labstack/echo/v4" "github.com/openziti/sdk-golang/ziti/edge" ) -// SecretStoreAuthenticationHandlerFunc prefixes an existing HandlerFunc -// with a OpenBao-based JWT authentication check. Usage: +// openBaoIssuer defines the issuer if JWT was issued from OpenBao +const openBaoIssuer = "/v1/identity/oidc" + +// AuthenticationHandlerFunc prefixes an existing HandlerFunc, +// performing authentication checks based on OpenBao-issued JWTs or external JWTs by checking the Authorization header. Usage: +// +// authenticationHook := handlers.NilAuthenticationHandlerFunc() // -// authenticationHook := handlers.NilAuthenticationHandlerFunc() -// if secret.IsSecurityEnabled() { -// lc := container.LoggingClientFrom(dic.Get) -// secretProvider := container.SecretProviderFrom(dic.Get) -// authenticationHook = handlers.SecretStoreAuthenticationHandlerFunc(secretProvider, lc) -// } -// For optionally-authenticated requests -// r.HandleFunc("path", authenticationHook(handlerFunc)).Methods(http.MethodGet) +// if secret.IsSecurityEnabled() { +// authenticationHook = handlers.AuthenticationHandlerFunc(dic) +// } +// For optionally-authenticated requests +// r.HandleFunc("path", authenticationHook(handlerFunc)).Methods(http.MethodGet) // -// For unauthenticated requests -// r.HandleFunc("path", handlerFunc).Methods(http.MethodGet) +// For unauthenticated requests +// r.HandleFunc("path", handlerFunc).Methods(http.MethodGet) // // For typical usage, it is preferred to use AutoConfigAuthenticationFunc which // will automatically select between a real and a fake JWT validation handler. -func SecretStoreAuthenticationHandlerFunc(secretProvider interfaces.SecretProviderExt, lc logger.LoggingClient) echo.MiddlewareFunc { +func AuthenticationHandlerFunc(dic *di.Container) echo.MiddlewareFunc { return func(inner echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { + lc := container.LoggingClientFrom(dic.Get) + secretProvider := container.SecretProviderExtFrom(dic.Get) r := c.Request() w := c.Response() authHeader := r.Header.Get("Authorization") @@ -70,20 +79,29 @@ func SecretStoreAuthenticationHandlerFunc(secretProvider interfaces.SecretProvid authParts := strings.Split(authHeader, " ") if len(authParts) >= 2 && strings.EqualFold(authParts[0], "Bearer") { token := authParts[1] - validToken, err := secretProvider.IsJWTValid(token) - if err != nil { - lc.Errorf("Error checking JWT validity: %v", err) - // set Response.Committed to true in order to rewrite the status code + + parser := jwt.NewParser() + parsedToken, _, jwtErr := parser.ParseUnverified(token, &jwt.MapClaims{}) + if jwtErr != nil { w.Committed = false - return echo.NewHTTPError(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) - } else if !validToken { - lc.Warnf("Request to '%s' UNAUTHORIZED", r.URL.Path) - // set Response.Committed to true in order to rewrite the status code + return echo.NewHTTPError(http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) + } + issuer, jwtErr := parsedToken.Claims.GetIssuer() + if jwtErr != nil { w.Committed = false return echo.NewHTTPError(http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) } - lc.Debugf("Request to '%s' authorized", r.URL.Path) - return inner(c) + + if issuer == openBaoIssuer { + return SecretStoreAuthenticationHandlerFunc(secretProvider, lc, token, c) + } else { + // Verify the JWT by invoking security-proxy-auth http client + err := headers.VerifyJWT(token, issuer, parsedToken.Method.Alg(), dic, r.Context()) + if err != nil { + errResp := dtoCommon.NewBaseResponse("", err.Error(), err.Code()) + return c.JSON(err.Code(), errResp) + } + } } err := fmt.Errorf("unable to parse JWT for call to '%s'; unauthorized", r.URL.Path) lc.Errorf("%v", err) @@ -93,3 +111,24 @@ func SecretStoreAuthenticationHandlerFunc(secretProvider interfaces.SecretProvid } } } + +// SecretStoreAuthenticationHandlerFunc verifies the JWT with a OpenBao-based JWT authentication check +func SecretStoreAuthenticationHandlerFunc(secretProvider interfaces.SecretProviderExt, lc logger.LoggingClient, token string, c echo.Context) error { + r := c.Request() + w := c.Response() + + validToken, err := secretProvider.IsJWTValid(token) + if err != nil { + lc.Errorf("Error checking JWT validity by the secret provider: %v ", err) + // set Response.Committed to true in order to rewrite the status code + w.Committed = false + return echo.NewHTTPError(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) + } else if !validToken { + lc.Warnf("Request to '%s' UNAUTHORIZED", r.URL.Path) + // set Response.Committed to true in order to rewrite the status code + w.Committed = false + return echo.NewHTTPError(http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) + } + lc.Debugf("Request to '%s' authorized", r.URL.Path) + return nil +} diff --git a/bootstrap/handlers/auth_middleware_no_ziti.go b/bootstrap/handlers/auth_middleware_no_ziti.go index c5deb293..b21d2d76 100644 --- a/bootstrap/handlers/auth_middleware_no_ziti.go +++ b/bootstrap/handlers/auth_middleware_no_ziti.go @@ -27,26 +27,27 @@ import ( "github.com/labstack/echo/v4" ) -// SecretStoreAuthenticationHandlerFunc prefixes an existing HandlerFunc -// with a OpenBao-based JWT authentication check. Usage: +// AuthenticationHandlerFunc prefixes an existing HandlerFunc, +// performing authentication checks based on OpenBao-issued JWTs or external JWTs by checking the Authorization header. Usage: // -// authenticationHook := handlers.NilAuthenticationHandlerFunc() -// if secret.IsSecurityEnabled() { -// lc := container.LoggingClientFrom(dic.Get) -// secretProvider := container.SecretProviderFrom(dic.Get) -// authenticationHook = handlers.SecretStoreAuthenticationHandlerFunc(secretProvider, lc) -// } -// For optionally-authenticated requests -// r.HandleFunc("path", authenticationHook(handlerFunc)).Methods(http.MethodGet) +// authenticationHook := handlers.NilAuthenticationHandlerFunc() // -// For unauthenticated requests -// r.HandleFunc("path", handlerFunc).Methods(http.MethodGet) +// if secret.IsSecurityEnabled() { +// authenticationHook = handlers.AuthenticationHandlerFunc(dic) +// } +// For optionally-authenticated requests +// r.HandleFunc("path", authenticationHook(handlerFunc)).Methods(http.MethodGet) +// +// For unauthenticated requests +// r.HandleFunc("path", handlerFunc).Methods(http.MethodGet) // // For typical usage, it is preferred to use AutoConfigAuthenticationFunc which // will automatically select between a real and a fake JWT validation handler. -func SecretStoreAuthenticationHandlerFunc(secretProvider interfaces.SecretProviderExt, lc logger.LoggingClient) echo.MiddlewareFunc { +func AuthenticationHandlerFunc(dic *di.Container) echo.MiddlewareFunc { return func(inner echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { + lc := container.LoggingClientFrom(dic.Get) + secretProvider := container.SecretProviderExtFrom(dic.Get) r := c.Request() w := c.Response() authHeader := r.Header.Get("Authorization") @@ -61,20 +62,29 @@ func SecretStoreAuthenticationHandlerFunc(secretProvider interfaces.SecretProvid authParts := strings.Split(authHeader, " ") if len(authParts) >= 2 && strings.EqualFold(authParts[0], "Bearer") { token := authParts[1] - validToken, err := secretProvider.IsJWTValid(token) - if err != nil { - lc.Errorf("Error checking JWT validity: %v", err) - // set Response.Committed to true in order to rewrite the status code + + parser := jwt.NewParser() + parsedToken, _, jwtErr := parser.ParseUnverified(token, &jwt.MapClaims{}) + if jwtErr != nil { w.Committed = false - return echo.NewHTTPError(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) - } else if !validToken { - lc.Warnf("Request to '%s' UNAUTHORIZED", r.URL.Path) - // set Response.Committed to true in order to rewrite the status code + return echo.NewHTTPError(http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) + } + issuer, jwtErr := parsedToken.Claims.GetIssuer() + if jwtErr != nil { w.Committed = false return echo.NewHTTPError(http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) } - lc.Debugf("Request to '%s' authorized", r.URL.Path) - return inner(c) + + if issuer == openBaoIssuer { + return SecretStoreAuthenticationHandlerFunc(secretProvider, lc, token, c) + } else { + // Verify the JWT by invoking security-proxy-auth http client + err := headers.VerifyJWT(token, issuer, parsedToken.Method.Alg(), dic, r.Context()) + if err != nil { + errResp := dtoCommon.NewBaseResponse("", err.Error(), err.Code()) + return c.JSON(err.Code(), errResp) + } + } } err := fmt.Errorf("unable to parse JWT for call to '%s'; unauthorized", r.URL.Path) lc.Errorf("%v", err) @@ -84,3 +94,24 @@ func SecretStoreAuthenticationHandlerFunc(secretProvider interfaces.SecretProvid } } } + +// SecretStoreAuthenticationHandlerFunc verifies the JWT with a OpenBao-based JWT authentication check +func SecretStoreAuthenticationHandlerFunc(secretProvider interfaces.SecretProviderExt, lc logger.LoggingClient, token string, c echo.Context) error { + r := c.Request() + w := c.Response() + + validToken, err := secretProvider.IsJWTValid(token) + if err != nil { + lc.Errorf("Error checking JWT validity by the secret provider: %v ", err) + // set Response.Committed to true in order to rewrite the status code + w.Committed = false + return echo.NewHTTPError(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) + } else if !validToken { + lc.Warnf("Request to '%s' UNAUTHORIZED", r.URL.Path) + // set Response.Committed to true in order to rewrite the status code + w.Committed = false + return echo.NewHTTPError(http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) + } + lc.Debugf("Request to '%s' authorized", r.URL.Path) + return nil +} diff --git a/bootstrap/handlers/headers/jwt.go b/bootstrap/handlers/headers/jwt.go new file mode 100644 index 00000000..4284f3d6 --- /dev/null +++ b/bootstrap/handlers/headers/jwt.go @@ -0,0 +1,70 @@ +// +// Copyright (C) 2024 IOTech Ltd +// +// SPDX-License-Identifier: Apache-2.0 + +package headers + +import ( + "context" + stdErrs "errors" + + "github.com/edgexfoundry/go-mod-bootstrap/v4/bootstrap/container" + "github.com/edgexfoundry/go-mod-bootstrap/v4/di" + "github.com/edgexfoundry/go-mod-core-contracts/v4/errors" + + "github.com/golang-jwt/jwt/v5" +) + +// VerifyJWT validates the JWT issued by security-proxy-auth by using the verification key provided by the security-proxy-auth service +func VerifyJWT(token string, + issuer string, + alg string, + dic *di.Container, + ctx context.Context) errors.EdgeX { + lc := container.LoggingClientFrom(dic.Get) + + verifyKey, edgexErr := GetVerificationKey(dic, issuer, alg, ctx) + if edgexErr != nil { + return errors.NewCommonEdgeXWrapper(edgexErr) + } + + err := ParseJWT(token, verifyKey, &jwt.MapClaims{}, jwt.WithExpirationRequired()) + if err != nil { + if stdErrs.Is(err, jwt.ErrTokenExpired) { + // Skip the JWT expired error + lc.Debug("JWT is valid but expired") + return nil + } else { + if stdErrs.Is(err, jwt.ErrTokenMalformed) || + stdErrs.Is(err, jwt.ErrTokenUnverifiable) || + stdErrs.Is(err, jwt.ErrTokenSignatureInvalid) || + stdErrs.Is(err, jwt.ErrTokenRequiredClaimMissing) { + lc.Errorf("Invalid jwt : %v\n", err) + return errors.NewCommonEdgeX(errors.KindUnauthorized, "invalid jwt", err) + } + lc.Errorf("Error occurred while validating JWT: %v", err) + return errors.NewCommonEdgeX(errors.Kind(err), "failed to parse jwt", err) + } + } + return nil +} + +// ParseJWT parses and validates the JWT with the passed ParserOptions and returns the token which implements the Claim interface +func ParseJWT(token string, verifyKey any, claims jwt.Claims, parserOption ...jwt.ParserOption) error { + _, err := jwt.ParseWithClaims(token, claims, func(_ *jwt.Token) (any, error) { + return verifyKey, nil + }, parserOption...) + if err != nil { + return err + } + + issuer, err := claims.GetIssuer() + if err != nil { + return errors.NewCommonEdgeX(errors.KindServerError, "failed to retrieve the issuer", err) + } + if len(issuer) == 0 { + return errors.NewCommonEdgeX(errors.KindUnauthorized, "issuer is empty", err) + } + return nil +} diff --git a/bootstrap/handlers/headers/jwt_test.go b/bootstrap/handlers/headers/jwt_test.go new file mode 100644 index 00000000..ee648dd8 --- /dev/null +++ b/bootstrap/handlers/headers/jwt_test.go @@ -0,0 +1,141 @@ +// +// Copyright (C) 2024 IOTech Ltd +// +// SPDX-License-Identifier: Apache-2.0 + +package headers + +import ( + "context" + "encoding/base64" + "net/http" + "testing" + + "github.com/edgexfoundry/go-mod-bootstrap/v4/bootstrap/container" + "github.com/edgexfoundry/go-mod-bootstrap/v4/di" + mockClients "github.com/edgexfoundry/go-mod-core-contracts/v4/clients/interfaces/mocks" + "github.com/edgexfoundry/go-mod-core-contracts/v4/clients/logger" + "github.com/edgexfoundry/go-mod-core-contracts/v4/dtos" + "github.com/edgexfoundry/go-mod-core-contracts/v4/dtos/responses" + edgexErr "github.com/edgexfoundry/go-mod-core-contracts/v4/errors" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" +) + +var ( + issuer = "testIssuer" + mockVerifyKey = "mysecret" + mockIncorrectKey = "notmysecret" + incorrectKeyIssuer = "incorrectKey" + failedIssuer = "failedIssuer" + notFoundIssuer = "notFoundIssuer" +) + +func mockDic() *di.Container { + acMock := &mockClients.AuthClient{} + + acMock.On("VerificationKeyByIssuer", context.Background(), issuer). + Return(responses.NewKeyDataResponse("", "", http.StatusOK, dtos.KeyData{ + Issuer: issuer, + Type: "verification", + Key: mockVerifyKey, + }), nil) + acMock.On("VerificationKeyByIssuer", context.Background(), incorrectKeyIssuer). + Return(responses.NewKeyDataResponse("", "", http.StatusOK, dtos.KeyData{ + Issuer: issuer, + Type: "verification", + Key: mockIncorrectKey, + }), nil) + acMock.On("VerificationKeyByIssuer", context.Background(), failedIssuer). + Return(responses.KeyDataResponse{}, edgexErr.NewCommonEdgeX(edgexErr.KindServerError, "internal error", nil)) + acMock.On("VerificationKeyByIssuer", context.Background(), notFoundIssuer). + Return(responses.KeyDataResponse{}, edgexErr.NewCommonEdgeX(edgexErr.KindEntityDoesNotExist, "verification key not found", nil)) + + return di.NewContainer(di.ServiceConstructorMap{ + container.SecurityProxyAuthClientName: func(get di.Get) interface{} { + return acMock + }, + container.LoggingClientInterfaceName: func(get di.Get) interface{} { + return logger.NewMockClient() + }, + }) +} + +func TestVerifyJWT(t *testing.T) { + dic := mockDic() + + alg := "HS256" + + validJWT := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJleHAiOjE5MjQ0MDU3OTYsImlzcyI6IklPVGVjaFN5c3RlbSJ9.iM2f5eXTBdV3HEdfp5xVIsuo2mlsdOrC-EY0kvBTgg4" + noIssuer := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJleHAiOjE5MjQ0MDU3OTZ9.OvQ2Ot2q8XpIaK9-hoStMVGdY8zW7fk62-FruNKQLhI" + noExp := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJJT1RlY2hTeXN0ZW0ifQ.Ead-LdhSPISMhVADR6Dq5qv88QAC0RG-Fc7CGVbuo7k" + expiredJWT := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJJT1RlY2hTeXN0ZW0iLCJleHAiOjE3MDM0ODA5OTZ9.X14GAFL5-6z8qh3mo49h8OgANkE9JBSiltxxc5j_n40" + invalidJWT := "invalid" + invalidSignature := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJpbmNvcnJlY3RLZXkiLCJleHAiOjE5MjQ0MDU3OTZ9.cczSNpaHtEgCP1_BTcs0A99UQReQCJgzA0Lld5FJt5w" + + tests := []struct { + name string + token string + issuer string + errorExpected bool + errType edgexErr.ErrKind + }{ + {"Valid JWT", validJWT, issuer, false, ""}, + {"Valid JWT - expired", expiredJWT, issuer, false, ""}, + {"Invalid JWT - no issuer", noIssuer, issuer, true, edgexErr.KindUnauthorized}, + {"Invalid JWT - no exp", noExp, issuer, true, edgexErr.KindUnauthorized}, + {"Invalid JWT - malformed", invalidJWT, issuer, true, edgexErr.KindUnauthorized}, + {"Invalid JWT - invalid signature", invalidSignature, incorrectKeyIssuer, true, edgexErr.KindUnauthorized}, + {"Invalid JWT - invalid signature", "", failedIssuer, true, edgexErr.KindServerError}, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + ctx := context.Background() + err := VerifyJWT(testCase.token, testCase.issuer, alg, dic, ctx) + if testCase.errorExpected { + require.Error(t, err) + require.Equal(t, testCase.errType, edgexErr.Kind(err)) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestParseJWT(t *testing.T) { + keyBytes, err := base64.StdEncoding.DecodeString(mockVerifyKey) + require.NoError(t, err) + + jwtWithNoExp := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJJT1RlY2hTeXN0ZW0ifQ.Ead-LdhSPISMhVADR6Dq5qv88QAC0RG-Fc7CGVbuo7k" + jwtWithExp := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE5MjQ0MDU3OTYsImlzcyI6IklPVGVjaFN5c3RlbSJ9.lbVl9cRRcXx7tLhbJU_wGyHB-Qj_h4VOjs-t3MjRIQ4" + jwtWithNoIssuer := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE5MjQ0MDU3OTZ9.JexgnJ50U_DT6gZwYQ-RHZu864wH0ilkwaABC0y_GIo" + + tests := []struct { + name string + token string + verifyKey any + parserOpts []jwt.ParserOption + errorExpected bool + }{ + {"Valid JWT", jwtWithNoExp, keyBytes, nil, false}, + {"Valid JWT - with exp", jwtWithExp, keyBytes, []jwt.ParserOption{jwt.WithExpirationRequired()}, false}, + {"Invalid JWT - no exp", jwtWithNoExp, keyBytes, []jwt.ParserOption{jwt.WithExpirationRequired()}, true}, + {"Invalid JWT - no issuer", jwtWithNoIssuer, keyBytes, []jwt.ParserOption{jwt.WithExpirationRequired()}, true}, + {"Invalid JWT - invalid signature", jwtWithNoExp, []byte(mockIncorrectKey), nil, true}, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + mockClaims := &jwt.MapClaims{} + parseErr := ParseJWT(testCase.token, testCase.verifyKey, mockClaims, testCase.parserOpts...) + if testCase.errorExpected { + require.Error(t, parseErr) + } else { + require.NoError(t, parseErr) + } + }) + + } +} diff --git a/bootstrap/handlers/headers/key.go b/bootstrap/handlers/headers/key.go new file mode 100644 index 00000000..f6807311 --- /dev/null +++ b/bootstrap/handlers/headers/key.go @@ -0,0 +1,106 @@ +// +// Copyright (C) 2024 IOTech Ltd +// +// SPDX-License-Identifier: Apache-2.0 + +package headers + +import ( + "context" + "crypto/ed25519" + "encoding/base64" + "encoding/pem" + "fmt" + "sync" + + "github.com/edgexfoundry/go-mod-bootstrap/v4/bootstrap/container" + "github.com/edgexfoundry/go-mod-bootstrap/v4/di" + "github.com/edgexfoundry/go-mod-core-contracts/v4/clients/logger" + "github.com/edgexfoundry/go-mod-core-contracts/v4/errors" + + "github.com/golang-jwt/jwt/v5" +) + +// A key cache to store the verification keys by issuer +var ( + keysCache = make(map[string]any) + mutex sync.RWMutex +) + +// GetVerificationKey returns the verification key obtained from local cache or security-proxy-auth http client +func GetVerificationKey(dic *di.Container, issuer, alg string, ctx context.Context) (any, errors.EdgeX) { + lc := container.LoggingClientFrom(dic.Get) + var verifyKey any + + // Check if the verification of the issuer already exists + mutex.RLock() + key, ok := keysCache[issuer] + mutex.RUnlock() + + if ok { + lc.Debugf("obtaining verification key from cache for JWT issuer '%s'", issuer) + + verifyKey = key + } else { + lc.Debugf("obtaining verification key from proxy-auth service client for JWT issuer '%s'", issuer) + + authClient := container.SecurityProxyAuthClientFrom(dic.Get) + keyResponse, edgexErr := authClient.VerificationKeyByIssuer(ctx, issuer) + if edgexErr != nil { + if errors.Kind(edgexErr) == errors.KindEntityDoesNotExist { + return nil, errors.NewCommonEdgeX(errors.KindServerError, fmt.Sprintf("verification key not found from proxy-auth service for JWT issuer '%s'", issuer), nil) + } + return nil, errors.NewCommonEdgeX(errors.KindServerError, fmt.Sprintf("failed to obtain the verification key from proxy-auth service for JWT issuer '%s'", issuer), edgexErr) + } + verifyKey, edgexErr = ProcessVerificationKey(keyResponse.KeyData.Key, alg, lc) + if edgexErr != nil { + return nil, errors.NewCommonEdgeX(errors.KindServerError, fmt.Sprintf("failed to process the verification key from proxy-auth service for JWT issuer '%s'", issuer), edgexErr) + } + + mutex.Lock() + keysCache[issuer] = verifyKey + mutex.Unlock() + } + return verifyKey, nil +} + +// ProcessVerificationKey handles the verification key retrieved from security-proxy-auth and returns the public key in the appropriate format according to the JWT signing algorithm +func ProcessVerificationKey(keyString string, alg string, lc logger.LoggingClient) (any, errors.EdgeX) { + keyBytes := []byte(keyString) + + switch alg { + case jwt.SigningMethodHS256.Alg(), jwt.SigningMethodHS384.Alg(), jwt.SigningMethodHS512.Alg(): + binaryKey, err := base64.StdEncoding.DecodeString(keyString) + if err != nil { + lc.Debugf("the key is not a valid base64, err: '%v', using the key '%s' without base64 encoding.", err, keyString) + return keyBytes, nil + } + + return binaryKey, nil + case jwt.SigningMethodEdDSA.Alg(): + block, _ := pem.Decode(keyBytes) + if block == nil || block.Type != "PUBLIC KEY" { + return nil, errors.NewCommonEdgeX(errors.KindServerError, "failed to decode the verification key PEM block", nil) + } + + edPublicKey := ed25519.PublicKey(block.Bytes) + return edPublicKey, nil + case jwt.SigningMethodRS256.Alg(), jwt.SigningMethodRS384.Alg(), jwt.SigningMethodRS512.Alg(), + jwt.SigningMethodPS256.Alg(), jwt.SigningMethodPS384.Alg(), jwt.SigningMethodPS512.Alg(): + rsaPublicKey, err := jwt.ParseRSAPublicKeyFromPEM(keyBytes) + if err != nil { + return nil, errors.NewCommonEdgeX(errors.KindServerError, fmt.Sprintf("failed to parse '%s' rsa verification key", alg), err) + } + + return rsaPublicKey, nil + case jwt.SigningMethodES256.Alg(), jwt.SigningMethodES384.Alg(), jwt.SigningMethodES512.Alg(): + ecdsaPublicKey, err := jwt.ParseECPublicKeyFromPEM(keyBytes) + if err != nil { + return nil, errors.NewCommonEdgeX(errors.KindServerError, fmt.Sprintf("failed to parse '%s' es verification key", alg), err) + } + + return ecdsaPublicKey, nil + default: + return nil, errors.NewCommonEdgeX(errors.KindContractInvalid, fmt.Sprintf("unsupported signing algorithm '%s'", alg), nil) + } +} diff --git a/bootstrap/handlers/headers/key_test.go b/bootstrap/handlers/headers/key_test.go new file mode 100644 index 00000000..975158ac --- /dev/null +++ b/bootstrap/handlers/headers/key_test.go @@ -0,0 +1,97 @@ +// +// Copyright (C) 2024 IOTech Ltd +// +// SPDX-License-Identifier: Apache-2.0 + +package headers + +import ( + "context" + "crypto/ed25519" + "encoding/base64" + "encoding/pem" + "fmt" + "testing" + + "github.com/edgexfoundry/go-mod-core-contracts/v4/clients/logger/mocks" + "github.com/edgexfoundry/go-mod-core-contracts/v4/errors" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" +) + +func TestGetVerificationKey(t *testing.T) { + dic := mockDic() + + expectedKeyBytes, err := base64.StdEncoding.DecodeString(mockVerifyKey) + require.NoError(t, err) + + tests := []struct { + name string + issuer string + keyInCache bool + expectedKey any + expectedError bool + expectedErrMsg string + }{ + {"Key in Cache", "cachedIssuer", true, []byte(mockVerifyKey), false, ""}, + {"Key not in Cache", issuer, false, expectedKeyBytes, false, ""}, + {"Key not found", notFoundIssuer, false, expectedKeyBytes, true, fmt.Sprintf("verification key not found from proxy-auth service for JWT issuer '%s'", notFoundIssuer)}, + {"Key processed error", failedIssuer, false, expectedKeyBytes, true, fmt.Sprintf("failed to obtain the verification key from proxy-auth service for JWT issuer '%s'", failedIssuer)}, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + keysCache = make(map[string]any) + + if testCase.keyInCache { + keysCache[testCase.issuer] = []byte(mockVerifyKey) + } + + key, err := GetVerificationKey(dic, testCase.issuer, "HS256", context.Background()) + if testCase.expectedError { + require.Error(t, err) + require.Equal(t, testCase.expectedErrMsg, err.Message()) + } else { + require.NoError(t, err) + require.Equal(t, testCase.expectedKey, key) + } + }) + } +} + +func TestProcessVerificationKey(t *testing.T) { + mockLogger := mocks.NewLoggingClient(t) + mockKey := "testKey" + + edDSAKey := "-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEAeDQLRoLzKZkHvXgU5nKiT2fp0zHt5nmY8YZykC1g+zE=\n-----END PUBLIC KEY-----" + block, _ := pem.Decode([]byte(edDSAKey)) + edDSAKeyBytes := block.Bytes + + invalidEdDSAKey := "-----BEGIN PUBLIC KEY-----\nINVALIDDATA\n-----END PUBLIC KEY-----" + + tests := []struct { + name string + keyString string + alg string + expectedKey any + errorExpected bool + expectedErrKind errors.ErrKind + }{ + {"Valid - HS256 alg", base64.StdEncoding.EncodeToString([]byte(mockKey)), jwt.SigningMethodHS256.Alg(), []byte(mockKey), false, ""}, + {"Valid - EdDSA alg", edDSAKey, jwt.SigningMethodEdDSA.Alg(), ed25519.PublicKey(edDSAKeyBytes), false, ""}, + {"Invalid - invalid EdDSA PEM Block", invalidEdDSAKey, jwt.SigningMethodEdDSA.Alg(), nil, true, errors.KindServerError}, + {"Invalid - unsupported signing algorithm", "anyKey", "UNSUPPORTED", nil, true, errors.KindContractInvalid}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + key, err := ProcessVerificationKey(test.keyString, test.alg, mockLogger) + if test.errorExpected { + require.Equal(t, test.expectedErrKind, errors.Kind(err)) + } else { + require.Equal(t, test.expectedKey, key) + } + }) + } +} diff --git a/go.mod b/go.mod index 4ab0409e..7dfd72c9 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/edgexfoundry/go-mod-messaging/v4 v4.0.0-dev.10 github.com/edgexfoundry/go-mod-registry/v4 v4.0.0-dev.2 github.com/edgexfoundry/go-mod-secrets/v4 v4.0.0-dev.4 + github.com/golang-jwt/jwt/v5 v5.2.1 github.com/google/uuid v1.6.0 github.com/hashicorp/go-multierror v1.1.1 github.com/labstack/echo/v4 v4.13.3 @@ -50,7 +51,6 @@ require ( github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.23.0 // indirect github.com/go-resty/resty/v2 v2.15.3 // indirect - github.com/golang-jwt/jwt/v5 v5.2.1 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/gorilla/mux v1.8.1 // indirect github.com/gorilla/schema v1.4.1 // indirect @@ -123,3 +123,5 @@ require ( gopkg.in/go-jose/go-jose.v2 v2.6.3 // indirect nhooyr.io/websocket v1.8.17 // indirect ) + +replace github.com/edgexfoundry/go-mod-core-contracts/v4 => github.com/lindseysimple/go-mod-core-contracts/v4 v4.0.0-20241224070246-4567c6a29d20 diff --git a/go.sum b/go.sum index b8abee31..ecd64bfb 100644 --- a/go.sum +++ b/go.sum @@ -70,8 +70,6 @@ github.com/eclipse/paho.mqtt.golang v1.5.0 h1:EH+bUVJNgttidWFkLLVKaQPGmkTUfQQqjO github.com/eclipse/paho.mqtt.golang v1.5.0/go.mod h1:du/2qNQVqJf/Sqs4MEL77kR8QTqANF7XU7Fk0aOTAgk= github.com/edgexfoundry/go-mod-configuration/v4 v4.0.0-dev.10 h1:DMv5LZDxcqUeb1dREMd/vK+reXmZYlpafgtm8XhYdHQ= github.com/edgexfoundry/go-mod-configuration/v4 v4.0.0-dev.10/go.mod h1:ltUpMcOpJSzmabBtZox5qg1AK2wEikvZJyIBXtJ7mUQ= -github.com/edgexfoundry/go-mod-core-contracts/v4 v4.0.0-dev.15 h1:4FbSL5rsNXVonrYz4K5v1oCNmi64LvcEx8xCgr6mXOo= -github.com/edgexfoundry/go-mod-core-contracts/v4 v4.0.0-dev.15/go.mod h1:M5JXcRrmnIVNAmqeDNVXd0PSOGdq96fgrEmzivx02c8= github.com/edgexfoundry/go-mod-messaging/v4 v4.0.0-dev.10 h1:xvDQDIJtmj/ZCmKzbAzg3h1F2ZdWz1MPoJSNfYZANGc= github.com/edgexfoundry/go-mod-messaging/v4 v4.0.0-dev.10/go.mod h1:ibaiw7r3RgLYDuuFfWT1kh//bjP+onDOOQsnSsdD4E8= github.com/edgexfoundry/go-mod-registry/v4 v4.0.0-dev.2 h1:iHu8JPpmrEOrIZdv0iYW69FlMmkyal/FpbXtC3pHt2c= @@ -287,6 +285,8 @@ github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0 github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/lindseysimple/go-mod-core-contracts/v4 v4.0.0-20241224070246-4567c6a29d20 h1:9AM7b578tXzt7SmvAfNUgyNPZj4PhbbnmsJRGTYmluU= +github.com/lindseysimple/go-mod-core-contracts/v4 v4.0.0-20241224070246-4567c6a29d20/go.mod h1:M5JXcRrmnIVNAmqeDNVXd0PSOGdq96fgrEmzivx02c8= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/magiconair/properties v1.8.5/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60=