From c803c14e908a1c8ed8dd1355d94f40bf83d4987b Mon Sep 17 00:00:00 2001 From: Giannis Katsanos Date: Fri, 16 Feb 2024 14:18:36 +0200 Subject: [PATCH] feat: JWT parsing with custom claims Replaced the CustomClaims parameter with a CustomClaimsConstructor function when verifying a session JWT. The option is also available in the HTTP middleware. The constructor function will be called when the JWT is parsed, producing a new struct instance instead of writing on a single instance. The custom claims will be made available in the SessionClaims.Custom field. --- clerktest/clerktest.go | 29 ++++++++ http/middleware.go | 27 ++++++-- jwt.go | 2 + jwt/jwt.go | 39 +++++++---- jwt/jwt_test.go | 146 +++++++++++++++++++++++++++++++++++++++-- 5 files changed, 222 insertions(+), 21 deletions(-) diff --git a/clerktest/clerktest.go b/clerktest/clerktest.go index cd048ba0..ff7f5d3a 100644 --- a/clerktest/clerktest.go +++ b/clerktest/clerktest.go @@ -3,12 +3,17 @@ package clerktest import ( "bytes" + "crypto" + "crypto/rand" + "crypto/rsa" "encoding/json" "io" "net/http" "net/url" "testing" + "github.com/go-jose/go-jose/v3" + "github.com/go-jose/go-jose/v3/jwt" "github.com/stretchr/testify/require" ) @@ -59,3 +64,27 @@ func (rt *RoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { Body: io.NopCloser(bytes.NewReader(rt.Out)), }, nil } + +// GenerateJWT creates a JSON web token with the provided claims +// and key ID. +func GenerateJWT(t *testing.T, claims any, kid string) (string, crypto.PublicKey) { + t.Helper() + + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + signerOpts := &jose.SignerOptions{} + signerOpts.WithType("JWT") + if kid != "" { + signerOpts.WithHeader("kid", kid) + } + signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: privKey}, signerOpts) + require.NoError(t, err) + + builder := jwt.Signed(signer) + builder = builder.Claims(claims) + token, err := builder.CompactSerialize() + require.NoError(t, err) + + return token, privKey.Public() +} diff --git a/http/middleware.go b/http/middleware.go index 8936415f..2b63c9e3 100644 --- a/http/middleware.go +++ b/http/middleware.go @@ -188,11 +188,30 @@ func AuthorizedPartyMatches(parties ...string) func(string) bool { } } -// CustomClaims allows to pass a type (e.g. struct), which will be populated with the token claims based on json tags. -// You must pass a pointer for this option to work. -func CustomClaims(claims any) AuthorizationOption { +// CustomClaimsConstructor allows to pass a constructor function +// which returns a pointer to a type (struct) to hold custom token +// claims. +// The instance of the custom claims type will be then made available +// through the clerk.SessionClaims struct. +// +// // Define a type to describe the custom claims. +// type MyCustomClaims struct { +// ACustomClaim string `json:"a_custom_claim"` +// } +// +// // In your HTTP server mux, configure the middleware with +// // the custom claims constructor. +// WithHeaderAuthorization(CustomClaimsConstructor(func(_ context.Context) any { +// return &MyCustomClaims{} +// }) +// +// // In the HTTP handler, access the active session claims. The +// // custom claims are available in the SessionClaims.Custom field. +// sessionClaims, ok := clerk.SessionClaimsFromContext(r.Context()) +// customClaims, ok := sessionClaims.Custom.(*MyCustomClaims) +func CustomClaimsConstructor(constructor func(context.Context) any) AuthorizationOption { return func(params *AuthorizationParams) error { - params.CustomClaims = claims + params.CustomClaimsConstructor = constructor return nil } } diff --git a/jwt.go b/jwt.go index 6e2a92a1..e3796587 100644 --- a/jwt.go +++ b/jwt.go @@ -34,6 +34,8 @@ type SessionClaims struct { ActiveOrganizationRole string `json:"org_role"` ActiveOrganizationPermissions []string `json:"org_permissions"` Actor json.RawMessage `json:"act,omitempty"` + // Custom can hold any custom claims that might be found in a JWT. + Custom any `json:"-"` } // HasPermission checks if the session claims contain the provided diff --git a/jwt/jwt.go b/jwt/jwt.go index b42aff83..42d257bd 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go @@ -12,15 +12,32 @@ import ( "github.com/go-jose/go-jose/v3/jwt" ) +// AuthorizedPartyHandler is a type that can be used to perform checks +// on the 'azp' claim. type AuthorizedPartyHandler func(string) bool +// CustomClaimsConstructor can initialize structs for holding custom +// JWT claims. +type CustomClaimsConstructor func(context.Context) any + type VerifyParams struct { // Token is the JWT that will be verified. Required. Token string // JWK the custom JSON Web Key that will be used to verify the // Token with. Required. - JWK *clerk.JSONWebKey - CustomClaims any + JWK *clerk.JSONWebKey + // CustomClaimsConstructor will be called when parsing the Token's + // claims. It's useful for parsing custom claims into user-defined + // types. + // Make sure it returns a pointer to a type (struct) that describes + // any custom claims schema with the correct JSON tags. + // type MyCustomClaims struct {} + // VerifyParams{ + // CustomClaimsConstructor: func(_ context.Context) any { + // return &MyCustomClaims{} + // }, + // } + CustomClaimsConstructor CustomClaimsConstructor // Leeway is the duration which the JWT is considered valid after // it's expired. Useful for defending against server clock skews. Leeway time.Duration @@ -51,8 +68,9 @@ func Verify(ctx context.Context, params *VerifyParams) (*clerk.SessionClaims, er claims := &clerk.SessionClaims{} allClaims := []any{claims} - if params.CustomClaims != nil { - allClaims = append(allClaims, params.CustomClaims) + if params.CustomClaimsConstructor != nil { + claims.Custom = params.CustomClaimsConstructor(ctx) + allClaims = append(allClaims, claims.Custom) } err = parsedToken.Claims(jwk.Key, allClaims...) if err != nil { @@ -64,13 +82,9 @@ func Verify(ctx context.Context, params *VerifyParams) (*clerk.SessionClaims, er return nil, err } - iss := claims.Issuer - if params.ProxyURL != nil && *params.ProxyURL != "" { - iss = *params.ProxyURL - } // Non-satellite domains must validate the issuer. - if !params.IsSatellite && !isValidIssuer(iss) { - return nil, fmt.Errorf("invalid issuer %s", iss) + if !params.IsSatellite && !isValidIssuer(claims.Issuer, params.ProxyURL) { + return nil, fmt.Errorf("invalid issuer %s", claims.Issuer) } if params.AuthorizedPartyHandler != nil && !params.AuthorizedPartyHandler(claims.AuthorizedParty) { @@ -80,7 +94,10 @@ func Verify(ctx context.Context, params *VerifyParams) (*clerk.SessionClaims, er return claims, nil } -func isValidIssuer(iss string) bool { +func isValidIssuer(iss string, proxyURL *string) bool { + if proxyURL != nil { + return iss == *proxyURL + } return strings.HasPrefix(iss, "https://clerk.") || strings.Contains(iss, ".clerk.accounts") } diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go index 6f83223e..494a53a6 100644 --- a/jwt/jwt_test.go +++ b/jwt/jwt_test.go @@ -5,12 +5,16 @@ import ( "testing" "github.com/clerk/clerk-sdk-go/v2" + "github.com/clerk/clerk-sdk-go/v2/clerktest" + "github.com/go-jose/go-jose/v3" "github.com/stretchr/testify/require" ) func TestVerify_InvalidParams(t *testing.T) { + t.Parallel() ctx := context.Background() - token := "eyJhbGciOiJSUzI1NiIsImNhdCI6ImNsX0I3ZDRQRDExMUFBQSIsImtpZCI6Imluc18yOWR6bUdmQ3JydzdSMDRaVFFZRDNKSTB5dkYiLCJ0eXAiOiJKV1QifQ.eyJhenAiOiJodHRwczovL2Rhc2hib2FyZC5wcm9kLmxjbGNsZXJrLmNvbSIsImV4cCI6MTcwNzMwMDMyMiwiaWF0IjoxNzA3MzAwMjYyLCJpc3MiOiJodHRwczovL2NsZXJrLnByb2QubGNsY2xlcmsuY29tIiwibmJmIjoxNzA3MzAwMjUyLCJvcmdzIjp7Im9yZ18ySUlwcVIxenFNeHJQQkhSazNzTDJOSnJUQkQiOiJvcmc6YWRtaW4iLCJvcmdfMllHMlNwd0IzWEJoNUo0ZXF5elFVb0dXMjVhIjoib3JnOmFkbWluIiwib3JnXzJhZzJ6bmgxWGFjTXI0dGRXYjZRbEZSQ2RuaiI6Im9yZzphZG1pbiIsIm9yZ18yYWlldHlXa3VFSEhaRmRSUTFvVjYzMnZWaFciOiJvcmc6YWRtaW4ifSwic2lkIjoic2Vzc18yYm84b2gyRnIyeTNueVoyRVZQYktBd2ZvaU0iLCJzdWIiOiJ1c2VyXzI5ZTBXTnp6M245V1Q5S001WlpJYTBVVjNDNyJ9.6GtQafMBYY3Ij3pKHOyBYKt76LoLeBC71QUY_ho3k5nb0FBSvV0upKFLPBvIXNuF7hH0FK2QqDcAmrhbzAI-2qF_Ynve8Xl4VZCRpbTuZI7uL-tVjCvMffEIH-BHtrZ-QcXhEmNFQNIPyZTu21242he7U6o4S8st_aLmukWQzj_4qir7o5_fmVhm7YkLa0gYG5SLjkr2czwem1VGFHEVEOrHjun-g6eMnDNMMMysIOkZFxeqiCnqpc4u1V7Z7jfoK0r_-Unp8mGGln5KWYMCQyp1l1SkGwugtxeWfSbE4eklKRmItGOdVftvTyG16kDGpzsb22AQGtg65Iygni4PHg" + kid := "kid" + token, pubKey := clerktest.GenerateJWT(t, map[string]any{"iss": "https://clerk.com"}, kid) // Verifying without providing a key returns an error. _, err := Verify(ctx, &VerifyParams{ @@ -19,19 +23,149 @@ func TestVerify_InvalidParams(t *testing.T) { require.Error(t, err) require.Contains(t, err.Error(), "missing json web key") - // Verify needs a key. + // Verifying with wrong public key for the key. _, err = Verify(ctx, &VerifyParams{ Token: token, - JWK: &clerk.JSONWebKey{}, + JWK: &clerk.JSONWebKey{ + Key: nil, + KeyID: kid, + Algorithm: string(jose.EdDSA), + Use: "sig", + }, }) - if err != nil { - require.NotContains(t, err.Error(), "missing json web key") + require.Error(t, err) + + // Verifying with wrong algorithm for the key. + _, err = Verify(ctx, &VerifyParams{ + Token: token, + JWK: &clerk.JSONWebKey{ + Key: pubKey, + KeyID: kid, + Algorithm: string(jose.EdDSA), + Use: "sig", + }, + }) + require.Error(t, err) + + // Verify with correct JSON web key. + validKey := &clerk.JSONWebKey{ + Key: pubKey, + KeyID: kid, + Algorithm: string(jose.RS256), + Use: "sig", } + _, err = Verify(ctx, &VerifyParams{ + Token: token, + JWK: validKey, + }) + require.NoError(t, err) // Try an invalid token. _, err = Verify(ctx, &VerifyParams{ Token: "this-is-not-a-token", - JWK: &clerk.JSONWebKey{}, + JWK: validKey, }) require.Error(t, err) + + // Generate a token with an invalid issuer + token, pubKey = clerktest.GenerateJWT(t, map[string]any{"iss": "https://whatever.com"}, kid) + // Cannot verify if token has invalid issuer + validKey = &clerk.JSONWebKey{ + Key: pubKey, + KeyID: kid, + Algorithm: string(jose.RS256), + Use: "sig", + } + _, err = Verify(ctx, &VerifyParams{ + Token: token, + JWK: validKey, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "issuer") + // Satellite domains don't validate the issuer + _, err = Verify(ctx, &VerifyParams{ + Token: token, + JWK: validKey, + IsSatellite: true, + }) + require.NoError(t, err) + // Issuer must match the proxy + _, err = Verify(ctx, &VerifyParams{ + Token: token, + JWK: validKey, + ProxyURL: clerk.String("https://whatever.com"), + }) + require.NoError(t, err) + _, err = Verify(ctx, &VerifyParams{ + Token: token, + JWK: validKey, + ProxyURL: clerk.String("https://another.com/proxy"), + }) + require.Error(t, err) + require.Contains(t, err.Error(), "issuer") + + // Generate a token with the 'azp' claim. + token, pubKey = clerktest.GenerateJWT( + t, + map[string]any{ + "iss": "https://clerk.com", + "azp": "whatever.com", + }, + kid, + ) + // Cannot verify if 'azp' does not match + validKey = &clerk.JSONWebKey{ + Key: pubKey, + KeyID: kid, + Algorithm: string(jose.RS256), + Use: "sig", + } + _, err = Verify(ctx, &VerifyParams{ + Token: token, + JWK: validKey, + AuthorizedPartyHandler: func(azp string) bool { + return azp == "clerk.com" + }, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "authorized party") +} + +type testCustomClaims struct { + Domain string `json:"domain"` + Environment string `json:"environment"` +} + +func TestVerify_CustomClaims(t *testing.T) { + t.Parallel() + ctx := context.Background() + kid := "kid" + // Generate a JWT for the following custom claims. + tokenClaims := map[string]any{ + "domain": "clerk.com", + "environment": "production", + "sub": "user_123", + "iss": "https://clerk.com", + } + token, pubKey := clerktest.GenerateJWT(t, tokenClaims, kid) + + customClaimsConstructor := func(_ context.Context) any { + return &testCustomClaims{} + } + claims, err := Verify(ctx, &VerifyParams{ + Token: token, + JWK: &clerk.JSONWebKey{ + Key: pubKey, + KeyID: kid, + Algorithm: string(jose.RS256), + Use: "sig", + }, + CustomClaimsConstructor: customClaimsConstructor, + }) + require.NoError(t, err) + customClaims, ok := claims.Custom.(*testCustomClaims) + require.True(t, ok) + require.Equal(t, "user_123", claims.Subject) + require.Equal(t, "clerk.com", customClaims.Domain) + require.Equal(t, "production", customClaims.Environment) }