From 993c45a92bf7386e7f34b6e387817ef44616b748 Mon Sep 17 00:00:00 2001 From: Giannis Katsanos Date: Tue, 16 Apr 2024 16:08:13 +0300 Subject: [PATCH] feat: Custom authorization failure handler (#283) Added an option to the http.WithHeaderAuthorization middleware to modify the default response in case of authentication failure. --- http/middleware.go | 27 +++++++++++- http/middleware_test.go | 98 ++++++++++++++++++++++++++++++++++++++--- 2 files changed, 117 insertions(+), 8 deletions(-) diff --git a/http/middleware.go b/http/middleware.go index 13b58e8f..9fb5770b 100644 --- a/http/middleware.go +++ b/http/middleware.go @@ -49,6 +49,9 @@ func WithHeaderAuthorization(opts ...AuthorizationOption) func(http.Handler) htt if params.Clock == nil { params.Clock = clerk.NewClock() } + if params.AuthorizationFailureHandler == nil { + params.AuthorizationFailureHandler = http.HandlerFunc(defaultAuthorizationFailureHandler) + } authorization := strings.TrimSpace(r.Header.Get("Authorization")) if authorization == "" { @@ -65,14 +68,14 @@ func WithHeaderAuthorization(opts ...AuthorizationOption) func(http.Handler) htt if params.JWK == nil { params.JWK, err = getJWK(r.Context(), params.JWKSClient, decoded.KeyID, params.Clock) if err != nil { - w.WriteHeader(http.StatusUnauthorized) + params.AuthorizationFailureHandler.ServeHTTP(w, r) return } } params.Token = token claims, err := jwt.Verify(r.Context(), ¶ms.VerifyParams) if err != nil { - w.WriteHeader(http.StatusUnauthorized) + params.AuthorizationFailureHandler.ServeHTTP(w, r) return } @@ -83,6 +86,10 @@ func WithHeaderAuthorization(opts ...AuthorizationOption) func(http.Handler) htt } } +func defaultAuthorizationFailureHandler(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) +} + // Retrieve the JSON web key for the provided token from the JWKS set. // Tries a cached value first, but if there's no value or the entry // has expired, it will fetch the JWK set from the API and cache the @@ -109,6 +116,11 @@ func getJWK(ctx context.Context, jwksClient *jwks.Client, kid string, clock cler type AuthorizationParams struct { jwt.VerifyParams + // AuthorizationFailureHandler gets executed when request authorization + // fails. Pass a custom http.Handler to control the http.Response for + // invalid authorization. The default is a Response with an empty body + // and 401 Unauthorized status. + AuthorizationFailureHandler http.Handler // JWKSClient is the jwks.Client that will be used to fetch the // JSON Web Key Set. A default client will be used if none is // provided. @@ -119,6 +131,17 @@ type AuthorizationParams struct { // authorization options. type AuthorizationOption func(*AuthorizationParams) error +// AuthorizationFailureHandler allows to provide a handler that +// writes the response in case of authorization failures. +// The default behavior is a response with an empty body and 401 +// Unauthorized status. +func AuthorizationFailureHandler(h http.Handler) AuthorizationOption { + return func(params *AuthorizationParams) error { + params.AuthorizationFailureHandler = h + return nil + } +} + // AuthorizedParty allows to provide a handler that accepts the // 'azp' claim. // The handler can be used to perform validations on the azp claim diff --git a/http/middleware_test.go b/http/middleware_test.go index 6d8155c6..9377e81c 100644 --- a/http/middleware_test.go +++ b/http/middleware_test.go @@ -13,6 +13,29 @@ import ( ) func TestWithHeaderAuthorization_InvalidAuthorization(t *testing.T) { + kid := "kid-" + t.Name() + // Mock the Clerk API server. We expect requests to GET /jwks. + clerkAPI := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/jwks" && r.Method == http.MethodGet { + _, err := w.Write([]byte( + fmt.Sprintf( + `{"keys":[{"use":"sig","kty":"RSA","kid":"%s","alg":"RS256","n":"ypsS9Iq26F71B3lPjT_IMtglDXo8Dko9h5UBmrvkWo6pdH_4zmMjeghozaHY1aQf1dHUBLsov_XvG_t-1yf7tFfO_ImC1JqSQwdSjrXZp3oMNFHwdwAknvtlBg3sBxJ8nM1WaCWaTlb2JhEmczIji15UG6V0M2cAp2VK_brcylQROaJLC2zVa4usGi4AHzAHaRUTv6XB9bGYMvkM-ZniuXgp9dPurisIIWg25DGrTaH-kg8LPaqGwa54eLEnvfAe0ZH_MvA4_bn_u_iDkQ9ZI_CD1vwf0EDnzLgd9ZG1khGsqmXY_4WiLRGsPqZe90HzaBJma9sAxXB4qj_aNnwD5w","e":"AQAB"}]}`, + kid, + ), + )) + require.NoError(t, err) + return + } + })) + defer clerkAPI.Close() + + // Mock the clerk backend + clerk.SetBackend(clerk.NewBackend(&clerk.BackendConfig{ + HTTPClient: clerkAPI.Client(), + URL: &clerkAPI.URL, + })) + + // This is the user's server, guarded by Clerk's middleware. ts := httptest.NewServer(WithHeaderAuthorization()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, ok := clerk.SessionClaimsFromContext(r.Context()) require.False(t, ok) @@ -21,11 +44,6 @@ func TestWithHeaderAuthorization_InvalidAuthorization(t *testing.T) { }))) defer ts.Close() - clerk.SetBackend(clerk.NewBackend(&clerk.BackendConfig{ - HTTPClient: ts.Client(), - URL: &ts.URL, - })) - // Request without Authorization header req, err := http.NewRequest(http.MethodGet, ts.URL, nil) require.NoError(t, err) @@ -38,6 +56,19 @@ func TestWithHeaderAuthorization_InvalidAuthorization(t *testing.T) { res, err = ts.Client().Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, res.StatusCode) + + // Request with unverifiable Bearer token + tokenClaims := map[string]any{ + "sid": "sess_123", + } + token, _ := clerktest.GenerateJWT(t, tokenClaims, kid) + req, err = http.NewRequest(http.MethodGet, ts.URL, nil) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+token) + require.NoError(t, err) + res, err = ts.Client().Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusUnauthorized, res.StatusCode) } func TestRequireHeaderAuthorization_InvalidAuthorization(t *testing.T) { @@ -67,7 +98,7 @@ func TestRequireHeaderAuthorization_InvalidAuthorization(t *testing.T) { } func TestWithHeaderAuthorization_Caching(t *testing.T) { - kid := "kid" + kid := "kid-" + t.Name() clock := clerktest.NewClockAt(time.Now().UTC()) // Mock the Clerk API server. We expect requests to GET /jwks. @@ -134,6 +165,61 @@ func TestWithHeaderAuthorization_Caching(t *testing.T) { require.Equal(t, 2, totalJWKSRequests) } +func TestWithHeaderAuthorization_CustomFailureHandler(t *testing.T) { + kid := "kid-" + t.Name() + // Mock the Clerk API server. We expect requests to GET /jwks. + clerkAPI := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/jwks" && r.Method == http.MethodGet { + _, err := w.Write([]byte( + fmt.Sprintf( + `{"keys":[{"use":"sig","kty":"RSA","kid":"%s","alg":"RS256","n":"ypsS9Iq26F71B3lPjT_IMtglDXo8Dko9h5UBmrvkWo6pdH_4zmMjeghozaHY1aQf1dHUBLsov_XvG_t-1yf7tFfO_ImC1JqSQwdSjrXZp3oMNFHwdwAknvtlBg3sBxJ8nM1WaCWaTlb2JhEmczIji15UG6V0M2cAp2VK_brcylQROaJLC2zVa4usGi4AHzAHaRUTv6XB9bGYMvkM-ZniuXgp9dPurisIIWg25DGrTaH-kg8LPaqGwa54eLEnvfAe0ZH_MvA4_bn_u_iDkQ9ZI_CD1vwf0EDnzLgd9ZG1khGsqmXY_4WiLRGsPqZe90HzaBJma9sAxXB4qj_aNnwD5w","e":"AQAB"}]}`, + kid, + ), + )) + require.NoError(t, err) + return + } + })) + defer clerkAPI.Close() + + // Define a custom failure handler which returns a custom HTTP + // status code. + customFailureHandler := func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusTeapot) + } + + // Apply the custom failure handler to the WithHeaderAuthorization + // middleware. + middleware := WithHeaderAuthorization( + AuthorizationFailureHandler(http.HandlerFunc(customFailureHandler)), + ) + // This is the user's server, guarded by Clerk's http middleware. + ts := httptest.NewServer(middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, ok := clerk.SessionClaimsFromContext(r.Context()) + require.False(t, ok) + _, err := w.Write([]byte("{}")) + require.NoError(t, err) + }))) + defer ts.Close() + + clerk.SetBackend(clerk.NewBackend(&clerk.BackendConfig{ + HTTPClient: clerkAPI.Client(), + URL: &clerkAPI.URL, + })) + + tokenClaims := map[string]any{ + "sid": "sess_123", + } + token, _ := clerktest.GenerateJWT(t, tokenClaims, kid) + // Request with invalid Authorization header + req, err := http.NewRequest(http.MethodGet, ts.URL, nil) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+token) + res, err := ts.Client().Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusTeapot, res.StatusCode) +} + func TestAuthorizedPartyFunc(t *testing.T) { t.Parallel() for _, tc := range []struct {