Skip to content

Commit

Permalink
Merge pull request #12 from aidantwoods/sentinel-errors
Browse files Browse the repository at this point in the history
Add sentinel error types
  • Loading branch information
aidantwoods authored Jan 1, 2023
2 parents 9baac02 + 315beca commit 0a59ee4
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 16 deletions.
8 changes: 8 additions & 0 deletions claims_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ func TestFutureIat(t *testing.T) {

_, err := parser.ParseV4Local(key, encrypted, nil)
require.Error(t, err)
require.ErrorIs(t, err, &paseto.RuleError{})
require.NotErrorIs(t, err, &paseto.TokenError{})
}

func TestFutureNbf(t *testing.T) {
Expand All @@ -146,6 +148,8 @@ func TestFutureNbf(t *testing.T) {

_, err := parser.ParseV4Local(key, encrypted, nil)
require.Error(t, err)
require.ErrorIs(t, err, &paseto.RuleError{})
require.NotErrorIs(t, err, &paseto.TokenError{})
}

func TestFutureNbfNotBeforeNbfRule(t *testing.T) {
Expand Down Expand Up @@ -184,6 +188,8 @@ func TestFutureNbfNotBeforeNbfRuleError(t *testing.T) {

_, err := parser.ParseV4Local(key, encrypted, nil)
require.Error(t, err)
require.ErrorIs(t, err, &paseto.RuleError{})
require.NotErrorIs(t, err, &paseto.TokenError{})
}

func TestPastExp(t *testing.T) {
Expand All @@ -203,6 +209,8 @@ func TestPastExp(t *testing.T) {

_, err := parser.ParseV4Local(key, encrypted, nil)
require.Error(t, err)
require.ErrorIs(t, err, &paseto.RuleError{})
require.NotErrorIs(t, err, &paseto.TokenError{})
}

func TestReadMeExample(t *testing.T) {
Expand Down
71 changes: 58 additions & 13 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,51 @@ package paseto

import "fmt"

// Any cryptography issue (with the token) or formatting error.
// This does not include cryptography errors with input key material, these will
// return regular errors.
type TokenError struct {
e error
}

func (e *TokenError) Error() string {
return e.e.Error()
}

func (_ *TokenError) Is(e error) bool {
_, ok := e.(*TokenError)
return ok
}

func (e *TokenError) Unwrap() error {
return e.e
}

func (e *TokenError) wrapWith(msg string) *TokenError {
return &TokenError{fmt.Errorf("%s: %w", msg, e)}
}

// Any error which is the result of a rule failure (distinct from a TokenError)
// Can be used to detect cryptographically valid tokens which have failed only
// due to a rule failure: which may warrant a slightly different processing
// follow up.
type RuleError struct {
e error
}

func (e *RuleError) Error() string {
return e.e.Error()
}

func (_ *RuleError) Is(e error) bool {
_, ok := e.(*RuleError)
return ok
}

func (e *RuleError) Unwrap() error {
return e.e
}

func errorKeyLength(expected, given int) error {
return fmt.Errorf("key length incorrect (%d), expected %d", given, expected)
}
Expand All @@ -10,32 +55,32 @@ func errorSeedLength(expected, given int) error {
return fmt.Errorf("seed length incorrect (%d), expected %d", given, expected)
}

func errorMessageParts(given int) error {
return fmt.Errorf("invalid number of message parts in token (%d)", given)
func errorMessageParts(given int) *TokenError {
return &TokenError{fmt.Errorf("invalid number of message parts in token (%d)", given)}
}

func errorMessageHeader(expected Protocol, givenHeader string) error {
return fmt.Errorf("message header `%s' is not valid, expected `%s'", givenHeader, expected.Header())
func errorMessageHeader(expected Protocol, givenHeader string) *TokenError {
return &TokenError{fmt.Errorf("message header `%s' is not valid, expected `%s'", givenHeader, expected.Header())}
}

func errorMessageHeaderDecrypt(expected Protocol, givenHeader string) error {
return fmt.Errorf("cannot decrypt message: %w", errorMessageHeader(expected, givenHeader))
func errorMessageHeaderDecrypt(expected Protocol, givenHeader string) *TokenError {
return errorMessageHeader(expected, givenHeader).wrapWith("cannot decrypt message")
}

func errorMessageHeaderVerify(expected Protocol, givenHeader string) error {
return fmt.Errorf("cannot verify message: %w", errorMessageHeader(expected, givenHeader))
func errorMessageHeaderVerify(expected Protocol, givenHeader string) *TokenError {
return errorMessageHeader(expected, givenHeader).wrapWith("cannot verify message")
}

var unsupportedPasetoVersion = fmt.Errorf("unsupported PASETO version")
var unsupportedPasetoPurpose = fmt.Errorf("unsupported PASETO purpose")
var unsupportedPayload = fmt.Errorf("unsupported payload")

var errorPayloadShort = fmt.Errorf("payload is not long enough to be a valid PASETO message")
var errorBadSignature = fmt.Errorf("bad signature")
var errorBadMAC = fmt.Errorf("bad message authentication code")
var errorPayloadShort = &TokenError{fmt.Errorf("payload is not long enough to be a valid PASETO message")}
var errorBadSignature = &TokenError{fmt.Errorf("bad signature")}
var errorBadMAC = &TokenError{fmt.Errorf("bad message authentication code")}

var errorKeyInvalid = fmt.Errorf("key was not valid")

func errorDecrypt(err error) error {
return fmt.Errorf("the message could not be decrypted: %w", err)
func errorDecrypt(err error) *TokenError {
return (&TokenError{err}).wrapWith("the message could not be decrypted")
}
4 changes: 4 additions & 0 deletions keys_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ func TestV2AsymmetricSecretKeyImport(t *testing.T) {

_, err := paseto.NewV2AsymmetricSecretKeyFromHex(badKey)
require.Error(t, err)
require.NotErrorIs(t, err, &paseto.RuleError{})
require.NotErrorIs(t, err, &paseto.TokenError{})

goodKey := "b4cbfb43df4ce210727d953e4a713307fa19bb7d9f85041438d9e11b942a37741eb9dbbbbc047c03fd70604e0071f0987e16b28b757225c11f00415d0e20b1a2"

Expand All @@ -24,6 +26,8 @@ func TestV4AsymmetricSecretKeyImport(t *testing.T) {

_, err := paseto.NewV4AsymmetricSecretKeyFromHex(badKey)
require.Error(t, err)
require.NotErrorIs(t, err, &paseto.RuleError{})
require.NotErrorIs(t, err, &paseto.TokenError{})

goodKey := "b4cbfb43df4ce210727d953e4a713307fa19bb7d9f85041438d9e11b942a37741eb9dbbbbc047c03fd70604e0071f0987e16b28b757225c11f00415d0e20b1a2"

Expand Down
4 changes: 2 additions & 2 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ func newMessage(protocol Protocol, token string) (message, error) {

payloadBytes, err := encoding.Decode(encodedPayload)
if err != nil {
return message{}, err
return message{}, &TokenError{err}
}

footer, err := encoding.Decode(encodedFooter)
if err != nil {
return message{}, err
return message{}, &TokenError{err}
}

payload, err := protocol.newPayload(payloadBytes)
Expand Down
2 changes: 1 addition & 1 deletion parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func (p *Parser) AddRule(rule ...Rule) {
func (p Parser) validate(token Token) (*Token, error) {
for _, rule := range p.rules {
if err := rule(token); err != nil {
return nil, err
return nil, &RuleError{err}
}
}

Expand Down
8 changes: 8 additions & 0 deletions token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ func TestSomeInt(t *testing.T) {
var output string
err = token.Get("foo", &output)
require.Error(t, err)
require.NotErrorIs(t, err, &paseto.RuleError{})
require.NotErrorIs(t, err, &paseto.TokenError{})

var intOutput int
err = token.Get("foo", &intOutput)
Expand All @@ -46,6 +48,8 @@ func TestSomeBool(t *testing.T) {
var intOutput int
err = token.Get("foo", &intOutput)
require.Error(t, err)
require.NotErrorIs(t, err, &paseto.RuleError{})
require.NotErrorIs(t, err, &paseto.TokenError{})

var output bool
err = token.Get("foo", &output)
Expand Down Expand Up @@ -92,6 +96,8 @@ func TestSomeWrongType(t *testing.T) {
var output bool
err = token.Get("baz", &output)
require.Error(t, err)
require.NotErrorIs(t, err, &paseto.RuleError{})
require.NotErrorIs(t, err, &paseto.TokenError{})
}

func TestSomeWrongKey(t *testing.T) {
Expand All @@ -103,6 +109,8 @@ func TestSomeWrongKey(t *testing.T) {
var output string
err = token.Get("bar", &output)
require.Error(t, err)
require.NotErrorIs(t, err, &paseto.RuleError{})
require.NotErrorIs(t, err, &paseto.TokenError{})
}

func TestFromMap(t *testing.T) {
Expand Down
24 changes: 24 additions & 0 deletions vectors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,17 @@ func TestV2(t *testing.T) {
message, err := paseto.NewMessage(paseto.V2Local, test.Token)
if test.ExpectFail {
require.Error(t, err)
require.ErrorIs(t, err, &paseto.TokenError{})
require.NotErrorIs(t, err, &paseto.RuleError{})
return
}
require.NoError(t, err)

decoded, err = paseto.V2LocalDecrypt(message, sk)
if test.ExpectFail {
require.Error(t, err)
require.ErrorIs(t, err, &paseto.TokenError{})
require.NotErrorIs(t, err, &paseto.RuleError{})
return
}
require.NoError(t, err)
Expand All @@ -69,13 +73,17 @@ func TestV2(t *testing.T) {
message, err := paseto.NewMessage(paseto.V2Public, test.Token)
if test.ExpectFail {
require.Error(t, err)
require.ErrorIs(t, err, &paseto.TokenError{})
require.NotErrorIs(t, err, &paseto.RuleError{})
return
}
require.NoError(t, err)

decoded, err = paseto.V2PublicVerify(message, pk)
if test.ExpectFail {
require.Error(t, err)
require.ErrorIs(t, err, &paseto.TokenError{})
require.NotErrorIs(t, err, &paseto.RuleError{})
return
}
require.NoError(t, err)
Expand Down Expand Up @@ -135,13 +143,17 @@ func TestV3(t *testing.T) {
message, err := paseto.NewMessage(paseto.V3Local, test.Token)
if test.ExpectFail {
require.Error(t, err)
require.ErrorIs(t, err, &paseto.TokenError{})
require.NotErrorIs(t, err, &paseto.RuleError{})
return
}
require.NoError(t, err)

decoded, err = paseto.V3LocalDecrypt(message, sk, []byte(test.ImplicitAssertation))
if test.ExpectFail {
require.Error(t, err)
require.ErrorIs(t, err, &paseto.TokenError{})
require.NotErrorIs(t, err, &paseto.RuleError{})
return
}
require.NoError(t, err)
Expand All @@ -154,13 +166,17 @@ func TestV3(t *testing.T) {
message, err := paseto.NewMessage(paseto.V3Public, test.Token)
if test.ExpectFail {
require.Error(t, err)
require.ErrorIs(t, err, &paseto.TokenError{})
require.NotErrorIs(t, err, &paseto.RuleError{})
return
}
require.NoError(t, err)

decoded, err = paseto.V3PublicVerify(message, pk, []byte(test.ImplicitAssertation))
if test.ExpectFail {
require.Error(t, err)
require.ErrorIs(t, err, &paseto.TokenError{})
require.NotErrorIs(t, err, &paseto.RuleError{})
return
}
require.NoError(t, err)
Expand Down Expand Up @@ -235,13 +251,17 @@ func TestV4(t *testing.T) {
message, err := paseto.NewMessage(paseto.V4Local, test.Token)
if test.ExpectFail {
require.Error(t, err)
require.ErrorIs(t, err, &paseto.TokenError{})
require.NotErrorIs(t, err, &paseto.RuleError{})
return
}
require.NoError(t, err)

decoded, err = paseto.V4LocalDecrypt(message, sk, []byte(test.ImplicitAssertation))
if test.ExpectFail {
require.Error(t, err)
require.ErrorIs(t, err, &paseto.TokenError{})
require.NotErrorIs(t, err, &paseto.RuleError{})
return
}
require.NoError(t, err)
Expand All @@ -254,13 +274,17 @@ func TestV4(t *testing.T) {
message, err := paseto.NewMessage(paseto.V4Public, test.Token)
if test.ExpectFail {
require.Error(t, err)
require.ErrorIs(t, err, &paseto.TokenError{})
require.NotErrorIs(t, err, &paseto.RuleError{})
return
}
require.NoError(t, err)

decoded, err = paseto.V4PublicVerify(message, pk, []byte(test.ImplicitAssertation))
if test.ExpectFail {
require.Error(t, err)
require.ErrorIs(t, err, &paseto.TokenError{})
require.NotErrorIs(t, err, &paseto.RuleError{})
return
}
require.NoError(t, err)
Expand Down

0 comments on commit 0a59ee4

Please sign in to comment.