Skip to content

Commit

Permalink
Merge pull request #1 from covidtrace/feat/identifier-role-jwt
Browse files Browse the repository at this point in the history
Add identifier and role to JWT
  • Loading branch information
joshgummersall authored Apr 28, 2020
2 parents 695c77c + 06032d6 commit 4dfdcd6
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 31 deletions.
72 changes: 43 additions & 29 deletions jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ import (

// Claims represents the JWT claims covidtrace cares about
type Claims struct {
Hash string `json:"covidtrace:hash"`
Refreshed int `json:"covidtrace:refreshed"`
Hash string `json:"covidtrace:hash,omitempty"`
Identifier string `json:"covidtrace:identifier,omitempty"`
Refreshed int `json:"covidtrace:refreshed"`
Role string `json:"covidtrace:role,omitempty"`
jwt.StandardClaims
}

Expand Down Expand Up @@ -56,11 +58,13 @@ func (i *Issuer) WithDur(dur time.Duration) *Issuer {
}

// Claims constructs a new Claims object, filling details in from i
func (i *Issuer) Claims(hash string, refresh int) *Claims {
func (i *Issuer) Claims(hash string, refresh int, identifier, role string) *Claims {
return &Claims{
hash,
refresh,
jwt.StandardClaims{
Hash: hash,
Identifier: identifier,
Refreshed: refresh,
Role: role,
StandardClaims: jwt.StandardClaims{
Audience: i.aud,
Issuer: i.iss,
ExpiresAt: time.Now().Add(i.dur).Unix(),
Expand All @@ -70,11 +74,36 @@ func (i *Issuer) Claims(hash string, refresh int) *Claims {

// Token handles generating a signed JWT token with the given `hash` and
// `refresh` count
func (i *Issuer) Token(hash string, refresh int) (string, error) {
t := jwt.NewWithClaims(i.sm, i.Claims(hash, refresh))
func (i *Issuer) Token(hash string, refresh int, identifier, role string) (string, error) {
t := jwt.NewWithClaims(i.sm, i.Claims(hash, refresh, identifier, role))
return t.SignedString(i.key)
}

func getClaimString(claims jwt.MapClaims, key, def string) string {
result := def
if iface, ok := claims[key]; ok {
if str, ok := iface.(string); ok {
result = str
}
}
return result
}

func getClaimFloat64(claims jwt.MapClaims, key string, def float64) float64 {
result := def

iface, ok := claims[key]
if !ok {
iface = def
}

if flt, ok := iface.(float64); ok {
result = flt
}

return result
}

// Validate handles ensuring `signedString` is a valid JWT issued by this
// issuer. It returns the `hash` and `refreshed` claims, or an `error` if the
// token is invalid
Expand Down Expand Up @@ -116,25 +145,10 @@ func (i *Issuer) Validate(signedString string) (*Claims, error) {
return nil, fmt.Errorf("Invalid aud: %v", aud)
}

hashi, ok := claims["covidtrace:hash"]
if !ok {
return nil, errors.New("Missing hash")
}

hash, ok := hashi.(string)
if !ok {
return nil, fmt.Errorf("Invalid hash: %v", hashi)
}

refreshi, ok := claims["covidtrace:refreshed"]
if !ok {
refreshi = 0
}

refresh, ok := refreshi.(float64)
if !ok {
return nil, fmt.Errorf("Invalid refresh: %v", refreshi)
}

return i.Claims(hash, int(refresh)), nil
return i.Claims(
getClaimString(claims, "covidtrace:hash", ""),
int(getClaimFloat64(claims, "covidtrace:refresh", 0)),
getClaimString(claims, "covidtrace:identifier", ""),
getClaimString(claims, "covidtrace:role", ""),
), nil
}
4 changes: 2 additions & 2 deletions jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func init() {
}

func TestIssuer(t *testing.T) {
token, err := issuer.Token("hash", 0)
token, err := issuer.Token("hash", 0, "identifier", "role")
if err != nil {
t.Error(err)
}
Expand All @@ -51,7 +51,7 @@ func TestExpired(t *testing.T) {
t.Error(err)
}

token, err := issuer.Token("hash", 0)
token, err := issuer.Token("hash", 0, "identifier", "role")
if err != nil {
t.Error(err)
}
Expand Down

0 comments on commit 4dfdcd6

Please sign in to comment.