Skip to content

Commit

Permalink
feat: Enhance the auth hook func to support external JWT
Browse files Browse the repository at this point in the history
Resolves edgexfoundry#810. Enhance the auth middleware func to support external JWT verifcation.

Signed-off-by: Lindsey Cheng <[email protected]>
  • Loading branch information
lindseysimple committed Dec 26, 2024
1 parent 48a6707 commit 1da7d09
Show file tree
Hide file tree
Showing 11 changed files with 551 additions and 56 deletions.
12 changes: 12 additions & 0 deletions bootstrap/container/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,15 @@ func ScheduleActionRecordClientFrom(get di.Get) interfaces.ScheduleActionRecordC

return get(ScheduleActionRecordClientName).(interfaces.ScheduleActionRecordClient)
}

// SecurityProxyAuthClientName contains the name of the AuthClient's implementation in the DIC.
var SecurityProxyAuthClientName = di.TypeInstanceToName((*interfaces.AuthClient)(nil))

// SecurityProxyAuthClientFrom helper function queries the DIC and returns the AuthClient's implementation.
func SecurityProxyAuthClientFrom(get di.Get) interfaces.AuthClient {
if get(SecurityProxyAuthClientName) == nil {
return nil
}

return get(SecurityProxyAuthClientName).(interfaces.AuthClient)
}
3 changes: 1 addition & 2 deletions bootstrap/controller/commonapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ type config struct {

func NewCommonController(dic *di.Container, r *echo.Echo, serviceName string, serviceVersion string) *CommonController {
lc := container.LoggingClientFrom(dic.Get)
secretProvider := container.SecretProviderExtFrom(dic.Get)
authenticationHook := handlers.AutoConfigAuthenticationFunc(secretProvider, lc)
authenticationHook := handlers.AutoConfigAuthenticationFunc(dic)
configuration := container.ConfigurationFrom(dic.Get)
c := CommonController{
dic: dic,
Expand Down
8 changes: 3 additions & 5 deletions bootstrap/handlers/auth_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,8 @@ import (
"os"
"strconv"

"github.com/edgexfoundry/go-mod-core-contracts/v4/clients/logger"

"github.com/edgexfoundry/go-mod-bootstrap/v4/bootstrap/interfaces"
"github.com/edgexfoundry/go-mod-bootstrap/v4/bootstrap/secret"
"github.com/edgexfoundry/go-mod-bootstrap/v4/di"

"github.com/labstack/echo/v4"
)
Expand All @@ -44,12 +42,12 @@ func NilAuthenticationHandlerFunc() echo.MiddlewareFunc {
// to disable JWT validation. This might be wanted for an EdgeX
// adopter that wanted to only validate JWT's at the proxy layer,
// or as an escape hatch for a caller that cannot authenticate.
func AutoConfigAuthenticationFunc(secretProvider interfaces.SecretProviderExt, lc logger.LoggingClient) echo.MiddlewareFunc {
func AutoConfigAuthenticationFunc(dic *di.Container) echo.MiddlewareFunc {
// Golang standard library treats an error as false
disableJWTValidation, _ := strconv.ParseBool(os.Getenv("EDGEX_DISABLE_JWT_VALIDATION"))
authenticationHook := NilAuthenticationHandlerFunc()
if secret.IsSecurityEnabled() && !disableJWTValidation {
authenticationHook = SecretStoreAuthenticationHandlerFunc(secretProvider, lc)
authenticationHook = AuthenticationHandlerFunc(dic)
}
return authenticationHook
}
85 changes: 62 additions & 23 deletions bootstrap/handlers/auth_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,34 +22,43 @@ import (
"net/http"
"strings"

"github.com/edgexfoundry/go-mod-bootstrap/v4/bootstrap/container"
"github.com/edgexfoundry/go-mod-bootstrap/v4/bootstrap/handlers/headers"
"github.com/edgexfoundry/go-mod-bootstrap/v4/bootstrap/interfaces"
"github.com/edgexfoundry/go-mod-bootstrap/v4/bootstrap/zerotrust"
"github.com/edgexfoundry/go-mod-bootstrap/v4/di"
"github.com/edgexfoundry/go-mod-core-contracts/v4/clients/logger"
dtoCommon "github.com/edgexfoundry/go-mod-core-contracts/v4/dtos/common"

"github.com/golang-jwt/jwt/v5"
"github.com/labstack/echo/v4"
"github.com/openziti/sdk-golang/ziti/edge"
)

// SecretStoreAuthenticationHandlerFunc prefixes an existing HandlerFunc
// with a OpenBao-based JWT authentication check. Usage:
// openBaoIssuer defines the issuer if JWT was issued from OpenBao
const openBaoIssuer = "/v1/identity/oidc"

// AuthenticationHandlerFunc prefixes an existing HandlerFunc,
// performing authentication checks based on OpenBao-issued JWTs or external JWTs by checking the Authorization header. Usage:
//
// authenticationHook := handlers.NilAuthenticationHandlerFunc()
//
// authenticationHook := handlers.NilAuthenticationHandlerFunc()
// if secret.IsSecurityEnabled() {
// lc := container.LoggingClientFrom(dic.Get)
// secretProvider := container.SecretProviderFrom(dic.Get)
// authenticationHook = handlers.SecretStoreAuthenticationHandlerFunc(secretProvider, lc)
// }
// For optionally-authenticated requests
// r.HandleFunc("path", authenticationHook(handlerFunc)).Methods(http.MethodGet)
// if secret.IsSecurityEnabled() {
// authenticationHook = handlers.AuthenticationHandlerFunc(dic)
// }
// For optionally-authenticated requests
// r.HandleFunc("path", authenticationHook(handlerFunc)).Methods(http.MethodGet)
//
// For unauthenticated requests
// r.HandleFunc("path", handlerFunc).Methods(http.MethodGet)
// For unauthenticated requests
// r.HandleFunc("path", handlerFunc).Methods(http.MethodGet)
//
// For typical usage, it is preferred to use AutoConfigAuthenticationFunc which
// will automatically select between a real and a fake JWT validation handler.
func SecretStoreAuthenticationHandlerFunc(secretProvider interfaces.SecretProviderExt, lc logger.LoggingClient) echo.MiddlewareFunc {
func AuthenticationHandlerFunc(dic *di.Container) echo.MiddlewareFunc {
return func(inner echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
lc := container.LoggingClientFrom(dic.Get)
secretProvider := container.SecretProviderExtFrom(dic.Get)
r := c.Request()
w := c.Response()
authHeader := r.Header.Get("Authorization")
Expand All @@ -70,20 +79,29 @@ func SecretStoreAuthenticationHandlerFunc(secretProvider interfaces.SecretProvid
authParts := strings.Split(authHeader, " ")
if len(authParts) >= 2 && strings.EqualFold(authParts[0], "Bearer") {
token := authParts[1]
validToken, err := secretProvider.IsJWTValid(token)
if err != nil {
lc.Errorf("Error checking JWT validity: %v", err)
// set Response.Committed to true in order to rewrite the status code

parser := jwt.NewParser()
parsedToken, _, jwtErr := parser.ParseUnverified(token, &jwt.MapClaims{})
if jwtErr != nil {
w.Committed = false
return echo.NewHTTPError(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
} else if !validToken {
lc.Warnf("Request to '%s' UNAUTHORIZED", r.URL.Path)
// set Response.Committed to true in order to rewrite the status code
return echo.NewHTTPError(http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
}
issuer, jwtErr := parsedToken.Claims.GetIssuer()
if jwtErr != nil {
w.Committed = false
return echo.NewHTTPError(http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
}
lc.Debugf("Request to '%s' authorized", r.URL.Path)
return inner(c)

if issuer == openBaoIssuer {
return SecretStoreAuthenticationHandlerFunc(secretProvider, lc, token, c)
} else {
// Verify the JWT by invoking security-proxy-auth http client
err := headers.VerifyJWT(token, issuer, parsedToken.Method.Alg(), dic, r.Context())
if err != nil {
errResp := dtoCommon.NewBaseResponse("", err.Error(), err.Code())
return c.JSON(err.Code(), errResp)
}
}
}
err := fmt.Errorf("unable to parse JWT for call to '%s'; unauthorized", r.URL.Path)
lc.Errorf("%v", err)
Expand All @@ -93,3 +111,24 @@ func SecretStoreAuthenticationHandlerFunc(secretProvider interfaces.SecretProvid
}
}
}

// SecretStoreAuthenticationHandlerFunc verifies the JWT with a OpenBao-based JWT authentication check
func SecretStoreAuthenticationHandlerFunc(secretProvider interfaces.SecretProviderExt, lc logger.LoggingClient, token string, c echo.Context) error {
r := c.Request()
w := c.Response()

validToken, err := secretProvider.IsJWTValid(token)
if err != nil {
lc.Errorf("Error checking JWT validity by the secret provider: %v ", err)
// set Response.Committed to true in order to rewrite the status code
w.Committed = false
return echo.NewHTTPError(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
} else if !validToken {
lc.Warnf("Request to '%s' UNAUTHORIZED", r.URL.Path)
// set Response.Committed to true in order to rewrite the status code
w.Committed = false
return echo.NewHTTPError(http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
}
lc.Debugf("Request to '%s' authorized", r.URL.Path)
return nil
}
77 changes: 54 additions & 23 deletions bootstrap/handlers/auth_middleware_no_ziti.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,27 @@ import (
"github.com/labstack/echo/v4"
)

// SecretStoreAuthenticationHandlerFunc prefixes an existing HandlerFunc
// with a OpenBao-based JWT authentication check. Usage:
// AuthenticationHandlerFunc prefixes an existing HandlerFunc,
// performing authentication checks based on OpenBao-issued JWTs or external JWTs by checking the Authorization header. Usage:
//
// authenticationHook := handlers.NilAuthenticationHandlerFunc()
// if secret.IsSecurityEnabled() {
// lc := container.LoggingClientFrom(dic.Get)
// secretProvider := container.SecretProviderFrom(dic.Get)
// authenticationHook = handlers.SecretStoreAuthenticationHandlerFunc(secretProvider, lc)
// }
// For optionally-authenticated requests
// r.HandleFunc("path", authenticationHook(handlerFunc)).Methods(http.MethodGet)
// authenticationHook := handlers.NilAuthenticationHandlerFunc()
//
// For unauthenticated requests
// r.HandleFunc("path", handlerFunc).Methods(http.MethodGet)
// if secret.IsSecurityEnabled() {
// authenticationHook = handlers.AuthenticationHandlerFunc(dic)
// }
// For optionally-authenticated requests
// r.HandleFunc("path", authenticationHook(handlerFunc)).Methods(http.MethodGet)
//
// For unauthenticated requests
// r.HandleFunc("path", handlerFunc).Methods(http.MethodGet)
//
// For typical usage, it is preferred to use AutoConfigAuthenticationFunc which
// will automatically select between a real and a fake JWT validation handler.
func SecretStoreAuthenticationHandlerFunc(secretProvider interfaces.SecretProviderExt, lc logger.LoggingClient) echo.MiddlewareFunc {
func AuthenticationHandlerFunc(dic *di.Container) echo.MiddlewareFunc {
return func(inner echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
lc := container.LoggingClientFrom(dic.Get)
secretProvider := container.SecretProviderExtFrom(dic.Get)
r := c.Request()
w := c.Response()
authHeader := r.Header.Get("Authorization")
Expand All @@ -61,20 +62,29 @@ func SecretStoreAuthenticationHandlerFunc(secretProvider interfaces.SecretProvid
authParts := strings.Split(authHeader, " ")
if len(authParts) >= 2 && strings.EqualFold(authParts[0], "Bearer") {
token := authParts[1]
validToken, err := secretProvider.IsJWTValid(token)
if err != nil {
lc.Errorf("Error checking JWT validity: %v", err)
// set Response.Committed to true in order to rewrite the status code

parser := jwt.NewParser()
parsedToken, _, jwtErr := parser.ParseUnverified(token, &jwt.MapClaims{})
if jwtErr != nil {
w.Committed = false
return echo.NewHTTPError(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
} else if !validToken {
lc.Warnf("Request to '%s' UNAUTHORIZED", r.URL.Path)
// set Response.Committed to true in order to rewrite the status code
return echo.NewHTTPError(http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
}
issuer, jwtErr := parsedToken.Claims.GetIssuer()
if jwtErr != nil {
w.Committed = false
return echo.NewHTTPError(http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
}
lc.Debugf("Request to '%s' authorized", r.URL.Path)
return inner(c)

if issuer == openBaoIssuer {
return SecretStoreAuthenticationHandlerFunc(secretProvider, lc, token, c)
} else {
// Verify the JWT by invoking security-proxy-auth http client
err := headers.VerifyJWT(token, issuer, parsedToken.Method.Alg(), dic, r.Context())
if err != nil {
errResp := dtoCommon.NewBaseResponse("", err.Error(), err.Code())
return c.JSON(err.Code(), errResp)
}
}
}
err := fmt.Errorf("unable to parse JWT for call to '%s'; unauthorized", r.URL.Path)
lc.Errorf("%v", err)
Expand All @@ -84,3 +94,24 @@ func SecretStoreAuthenticationHandlerFunc(secretProvider interfaces.SecretProvid
}
}
}

// SecretStoreAuthenticationHandlerFunc verifies the JWT with a OpenBao-based JWT authentication check
func SecretStoreAuthenticationHandlerFunc(secretProvider interfaces.SecretProviderExt, lc logger.LoggingClient, token string, c echo.Context) error {
r := c.Request()
w := c.Response()

validToken, err := secretProvider.IsJWTValid(token)
if err != nil {
lc.Errorf("Error checking JWT validity by the secret provider: %v ", err)
// set Response.Committed to true in order to rewrite the status code
w.Committed = false
return echo.NewHTTPError(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
} else if !validToken {
lc.Warnf("Request to '%s' UNAUTHORIZED", r.URL.Path)
// set Response.Committed to true in order to rewrite the status code
w.Committed = false
return echo.NewHTTPError(http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
}
lc.Debugf("Request to '%s' authorized", r.URL.Path)
return nil
}
70 changes: 70 additions & 0 deletions bootstrap/handlers/headers/jwt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
//
// Copyright (C) 2024 IOTech Ltd
//
// SPDX-License-Identifier: Apache-2.0

package headers

import (
"context"
stdErrs "errors"

"github.com/edgexfoundry/go-mod-bootstrap/v4/bootstrap/container"
"github.com/edgexfoundry/go-mod-bootstrap/v4/di"
"github.com/edgexfoundry/go-mod-core-contracts/v4/errors"

"github.com/golang-jwt/jwt/v5"
)

// VerifyJWT validates the JWT issued by security-proxy-auth by using the verification key provided by the security-proxy-auth service
func VerifyJWT(token string,
issuer string,
alg string,
dic *di.Container,
ctx context.Context) errors.EdgeX {
lc := container.LoggingClientFrom(dic.Get)

verifyKey, edgexErr := GetVerificationKey(dic, issuer, alg, ctx)
if edgexErr != nil {
return errors.NewCommonEdgeXWrapper(edgexErr)
}

err := ParseJWT(token, verifyKey, &jwt.MapClaims{}, jwt.WithExpirationRequired())
if err != nil {
if stdErrs.Is(err, jwt.ErrTokenExpired) {
// Skip the JWT expired error
lc.Debug("JWT is valid but expired")
return nil
} else {
if stdErrs.Is(err, jwt.ErrTokenMalformed) ||
stdErrs.Is(err, jwt.ErrTokenUnverifiable) ||
stdErrs.Is(err, jwt.ErrTokenSignatureInvalid) ||
stdErrs.Is(err, jwt.ErrTokenRequiredClaimMissing) {
lc.Errorf("Invalid jwt : %v\n", err)
return errors.NewCommonEdgeX(errors.KindUnauthorized, "invalid jwt", err)
}
lc.Errorf("Error occurred while validating JWT: %v", err)
return errors.NewCommonEdgeX(errors.Kind(err), "failed to parse jwt", err)
}
}
return nil
}

// ParseJWT parses and validates the JWT with the passed ParserOptions and returns the token which implements the Claim interface
func ParseJWT(token string, verifyKey any, claims jwt.Claims, parserOption ...jwt.ParserOption) error {
_, err := jwt.ParseWithClaims(token, claims, func(_ *jwt.Token) (any, error) {
return verifyKey, nil
}, parserOption...)
if err != nil {
return err
}

issuer, err := claims.GetIssuer()
if err != nil {
return errors.NewCommonEdgeX(errors.KindServerError, "failed to retrieve the issuer", err)
}
if len(issuer) == 0 {
return errors.NewCommonEdgeX(errors.KindUnauthorized, "issuer is empty", err)
}
return nil
}
Loading

0 comments on commit 1da7d09

Please sign in to comment.