diff --git a/jwt.go b/jwt.go index 73ffcbb..36af74f 100644 --- a/jwt.go +++ b/jwt.go @@ -10,8 +10,10 @@ import ( // Claims represents the JWT claims covidtrace cares about type Claims struct { - Hash string `json:"covidtrace:hash"` - Refreshed int `json:"covidtrace:refreshed"` + Hash string `json:"covidtrace:hash,omitempty"` + Identifier string `json:"covidtrace:identifier,omitempty"` + Refreshed int `json:"covidtrace:refreshed"` + Role string `json:"covidtrace:role,omitempty"` jwt.StandardClaims } @@ -56,11 +58,13 @@ func (i *Issuer) WithDur(dur time.Duration) *Issuer { } // Claims constructs a new Claims object, filling details in from i -func (i *Issuer) Claims(hash string, refresh int) *Claims { +func (i *Issuer) Claims(hash string, refresh int, identifier, role string) *Claims { return &Claims{ - hash, - refresh, - jwt.StandardClaims{ + Hash: hash, + Identifier: identifier, + Refreshed: refresh, + Role: role, + StandardClaims: jwt.StandardClaims{ Audience: i.aud, Issuer: i.iss, ExpiresAt: time.Now().Add(i.dur).Unix(), @@ -70,11 +74,36 @@ func (i *Issuer) Claims(hash string, refresh int) *Claims { // Token handles generating a signed JWT token with the given `hash` and // `refresh` count -func (i *Issuer) Token(hash string, refresh int) (string, error) { - t := jwt.NewWithClaims(i.sm, i.Claims(hash, refresh)) +func (i *Issuer) Token(hash string, refresh int, identifier, role string) (string, error) { + t := jwt.NewWithClaims(i.sm, i.Claims(hash, refresh, identifier, role)) return t.SignedString(i.key) } +func getClaimString(claims jwt.MapClaims, key, def string) string { + result := def + if iface, ok := claims[key]; ok { + if str, ok := iface.(string); ok { + result = str + } + } + return result +} + +func getClaimFloat64(claims jwt.MapClaims, key string, def float64) float64 { + result := def + + iface, ok := claims[key] + if !ok { + iface = def + } + + if flt, ok := iface.(float64); ok { + result = flt + } + + return result +} + // Validate handles ensuring `signedString` is a valid JWT issued by this // issuer. It returns the `hash` and `refreshed` claims, or an `error` if the // token is invalid @@ -116,25 +145,10 @@ func (i *Issuer) Validate(signedString string) (*Claims, error) { return nil, fmt.Errorf("Invalid aud: %v", aud) } - hashi, ok := claims["covidtrace:hash"] - if !ok { - return nil, errors.New("Missing hash") - } - - hash, ok := hashi.(string) - if !ok { - return nil, fmt.Errorf("Invalid hash: %v", hashi) - } - - refreshi, ok := claims["covidtrace:refreshed"] - if !ok { - refreshi = 0 - } - - refresh, ok := refreshi.(float64) - if !ok { - return nil, fmt.Errorf("Invalid refresh: %v", refreshi) - } - - return i.Claims(hash, int(refresh)), nil + return i.Claims( + getClaimString(claims, "covidtrace:hash", ""), + int(getClaimFloat64(claims, "covidtrace:refresh", 0)), + getClaimString(claims, "covidtrace:identifier", ""), + getClaimString(claims, "covidtrace:role", ""), + ), nil } diff --git a/jwt_test.go b/jwt_test.go index 2d83757..71b5e89 100644 --- a/jwt_test.go +++ b/jwt_test.go @@ -25,7 +25,7 @@ func init() { } func TestIssuer(t *testing.T) { - token, err := issuer.Token("hash", 0) + token, err := issuer.Token("hash", 0, "identifier", "role") if err != nil { t.Error(err) } @@ -51,7 +51,7 @@ func TestExpired(t *testing.T) { t.Error(err) } - token, err := issuer.Token("hash", 0) + token, err := issuer.Token("hash", 0, "identifier", "role") if err != nil { t.Error(err) }