Skip to content

Commit

Permalink
incorporate the generic source type into validation
Browse files Browse the repository at this point in the history
  • Loading branch information
johnabass committed Jul 11, 2024
1 parent 5d8a265 commit 7d7da18
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 98 deletions.
12 changes: 6 additions & 6 deletions basculehttp/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ func WithTokenParser(scheme bascule.Scheme, tp bascule.TokenParser[*http.Request
// WithAuthentication adds validators used for authentication to this Middleware. Each
// invocation of this option is cumulative. Authentication validators are run in the order
// supplied by this option.
func WithAuthentication(v ...bascule.Validator) MiddlewareOption {
func WithAuthentication(v ...bascule.Validator[*http.Request]) MiddlewareOption {
return middlewareOptionFunc(func(m *Middleware) error {
m.authentication.Add(v...)
m.authentication = m.authentication.Append(v...)
return nil
})
}
Expand Down Expand Up @@ -114,7 +114,7 @@ func WithErrorMarshaler(em ErrorMarshaler) MiddlewareOption {
type Middleware struct {
credentialsParser bascule.CredentialsParser[*http.Request]
tokenParsers bascule.TokenParsers[*http.Request]
authentication bascule.Validators
authentication bascule.Validators[*http.Request]
authorization bascule.Authorizers[*http.Request]
challenges Challenges

Expand Down Expand Up @@ -194,8 +194,8 @@ func (m *Middleware) getCredentialsAndToken(ctx context.Context, request *http.R
return
}

func (m *Middleware) authenticate(ctx context.Context, token bascule.Token) error {
return m.authentication.Validate(ctx, token)
func (m *Middleware) authenticate(ctx context.Context, request *http.Request, token bascule.Token) (bascule.Token, error) {
return m.authentication.Validate(ctx, request, token)
}

func (m *Middleware) authorize(ctx context.Context, token bascule.Token, request *http.Request) error {
Expand All @@ -220,7 +220,7 @@ func (fd *frontDoor) ServeHTTP(response http.ResponseWriter, request *http.Reque
}

ctx = bascule.WithCredentials(ctx, creds)
err = fd.middleware.authenticate(ctx, token)
token, err = fd.middleware.authenticate(ctx, request, token)
if err != nil {
// at this point in the workflow, the request has valid credentials. we use
// StatusForbidden as the default because any failure to authenticate isn't a
Expand Down
224 changes: 198 additions & 26 deletions validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,50 +5,222 @@ package bascule

import (
"context"
"reflect"
)

// Validator represents a general strategy for validating tokens. Token validation
// typically happens during authentication, but it can also happen during parsing
// if a caller uses NewValidatingTokenParser.
type Validator interface {
// typically happens during authentication.
type Validator[S any] interface {
// Validate validates a token. If this validator needs to interact
// with external systems, the supplied context can be passed to honor
// cancelation semantics.
// cancelation semantics. Additionally, the source object from which the
// token was taken is made available.
//
// This method may be passed a token that it doesn't support, e.g. a Basic
// validator can be passed a JWT token. In that case, this method should
// simply return nil.
Validate(context.Context, Token) error
// simply return a nil error.
//
// If this method returns a nil token, then the supplied token should be used
// as is. If this method returns a non-nil token, that new new token should be
// used instead. This allows a validator to augment a token with additional
// data, possibly from an external system or database.
Validate(ctx context.Context, source S, t Token) (Token, error)
}

// ValidatorFunc is a closure type that implements Validator.
type ValidatorFunc func(context.Context, Token) error
// Validate applies several validators to the given token. Although each individual
// validator may return a nil Token to indicate that there is no change in the token,
// this function will always return a non-nil Token.
//
// This function returns the validated Token and a nil error to indicate success.
// If any validator fails, this function halts further validation and returns
// the error.
func Validate[S any](ctx context.Context, source S, original Token, v ...Validator[S]) (validated Token, err error) {
next := original
for i, prev := 0, next; err == nil && i < len(v); i, prev = i+1, next {
next, err = v[i].Validate(ctx, source, prev)
if next == nil {
// no change in the token
next = prev
}
}

if err == nil {
validated = next
}

func (vf ValidatorFunc) Validate(ctx context.Context, token Token) error {
return vf(ctx, token)
return
}

// Validators is an aggregate Validator.
type Validators []Validator
// Validators is an aggregate Validator that returns validity if and only if
// all of its contained validators return validity.
type Validators[S any] []Validator[S]

// Append tacks on more validators to this aggregate, returning the possibly new
// instance. The semantics of this method are the same as the built-in append.
func (vs Validators[S]) Append(more ...Validator[S]) Validators[S] {
return append(vs, more...)
}

// Validate executes each contained validator in order, returning validity only
// if all validators pass. Any validation failure prevents subsequent validators
// from running.
func (vs Validators[S]) Validate(ctx context.Context, source S, t Token) (Token, error) {
return Validate(ctx, source, t, vs...)
}

// ValidatorFunc defines the closure signatures that are allowed as Validator instances.
type ValidatorFunc[S any] interface {
~func(Token) error |
~func(S, Token) error |
~func(Token) (Token, error) |
~func(S, Token) (Token, error) |
~func(context.Context, Token) error |
~func(context.Context, S, Token) error |
~func(context.Context, Token) (Token, error) |
~func(context.Context, S, Token) (Token, error)
}

// validatorFunc is an internal type that implements Validator. Used to normalize
// and uncurry a closure.
type validatorFunc[S any] func(context.Context, S, Token) (Token, error)

func (vf validatorFunc[S]) Validate(ctx context.Context, source S, t Token) (Token, error) {
return vf(ctx, source, t)
}

var (
tokenReturnsError = reflect.TypeOf((func(Token) error)(nil))
tokenReturnsTokenError = reflect.TypeOf((func(Token) (Token, error))(nil))
contextTokenReturnsError = reflect.TypeOf((func(context.Context, Token) error)(nil))
contextTokenReturnsTokenError = reflect.TypeOf((func(context.Context, Token) (Token, error))(nil))
)

// asValidatorSimple tries simple conversions on f. This function will not catch
// user-defined types.
func asValidatorSimple[S any, F ValidatorFunc[S]](f F) (v Validator[S]) {
switch vf := any(f).(type) {
case func(Token) error:
v = validatorFunc[S](
func(ctx context.Context, source S, t Token) (Token, error) {
return nil, vf(t)
},
)

// Add appends validators to this aggregate Validators.
func (vs *Validators) Add(v ...Validator) {
if *vs == nil {
*vs = make(Validators, 0, len(v))
case func(S, Token) error:
v = validatorFunc[S](
func(ctx context.Context, source S, t Token) (Token, error) {
return nil, vf(source, t)
},
)

case func(Token) (Token, error):
v = validatorFunc[S](
func(ctx context.Context, source S, t Token) (next Token, err error) {
next, err = vf(t)
if next == nil {
next = t
}

return
},
)

case func(S, Token) (Token, error):
v = validatorFunc[S](
func(ctx context.Context, source S, t Token) (next Token, err error) {
next, err = vf(source, t)
if next == nil {
next = t
}

return
},
)

case func(context.Context, Token) error:
v = validatorFunc[S](
func(ctx context.Context, source S, t Token) (Token, error) {
return nil, vf(ctx, t)
},
)

case func(context.Context, S, Token) error:
v = validatorFunc[S](
func(ctx context.Context, source S, t Token) (Token, error) {
return nil, vf(ctx, source, t)
},
)

case func(context.Context, Token) (Token, error):
v = validatorFunc[S](
func(ctx context.Context, source S, t Token) (next Token, err error) {
next, err = vf(ctx, t)
if next == nil {
next = t
}

return
},
)

case func(context.Context, S, Token) (Token, error):
v = validatorFunc[S](vf)
}

*vs = append(*vs, v...)
return
}

// Validate applies each validator in sequence. Execution stops at the first validator
// that returns an error, and that error is returned. If all validators return nil,
// this method returns nil, indicating the Token is valid.
func (vs Validators) Validate(ctx context.Context, token Token) error {
for _, v := range vs {
if err := v.Validate(ctx, token); err != nil {
return err
}
// AsValidator takes a ValidatorFunc closure and returns a Validator instance that
// executes that closure.
func AsValidator[S any, F ValidatorFunc[S]](f F) Validator[S] {
// first, try the simple way:
if v := asValidatorSimple[S](f); v != nil {
return v
}

// next, support user-defined types that are closures that do not
// require the source type.
fVal := reflect.ValueOf(f)
switch {
case fVal.CanConvert(tokenReturnsError):
return asValidatorSimple[S](
fVal.Convert(tokenReturnsError).Interface().(func(Token) error),
)

case fVal.CanConvert(tokenReturnsTokenError):
return asValidatorSimple[S](
fVal.Convert(tokenReturnsError).Interface().(func(Token) (Token, error)),
)

case fVal.CanConvert(contextTokenReturnsError):
return asValidatorSimple[S](
fVal.Convert(contextTokenReturnsError).Interface().(func(context.Context, Token) error),
)

case fVal.CanConvert(contextTokenReturnsTokenError):
return asValidatorSimple[S](
fVal.Convert(contextTokenReturnsError).Interface().(func(context.Context, Token) (Token, error)),
)
}

return nil
// finally: user-defined types that are closures involving the source type S.
// we have to look these up here, due to the way generics in golang work.
if ft := reflect.TypeOf((func(S, Token) error)(nil)); fVal.CanConvert(ft) {
return asValidatorSimple[S](
fVal.Convert(ft).Interface().(func(S, Token) error),
)
} else if ft := reflect.TypeOf((func(S, Token) (Token, error))(nil)); fVal.CanConvert(ft) {
return asValidatorSimple[S](
fVal.Convert(ft).Interface().(func(S, Token) (Token, error)),
)
} else if ft := reflect.TypeOf((func(context.Context, S, Token) error)(nil)); fVal.CanConvert(ft) {
return asValidatorSimple[S](
fVal.Convert(ft).Interface().(func(context.Context, S, Token) error),
)
} else {
// we know this can be converted to this final type
return asValidatorSimple[S](
fVal.Convert(ft).Interface().(func(context.Context, S, Token) (Token, error)),
)
}
}
66 changes: 0 additions & 66 deletions validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
package bascule

import (
"context"
"errors"
"testing"

"github.com/stretchr/testify/suite"
Expand All @@ -15,70 +13,6 @@ type ValidatorsTestSuite struct {
TestSuite
}

func (suite *ValidatorsTestSuite) TestValidate() {
validateErr := errors.New("expected Validate error")

testCases := []struct {
name string
results []error
expectedErr error
}{
{
name: "EmptyValidators",
results: nil,
},
{
name: "OneSuccess",
results: []error{nil},
},
{
name: "OneFailure",
results: []error{validateErr},
expectedErr: validateErr,
},
{
name: "FirstFailure",
results: []error{validateErr, errors.New("should not be called")},
expectedErr: validateErr,
},
{
name: "MiddleFailure",
results: []error{nil, validateErr, errors.New("should not be called")},
expectedErr: validateErr,
},
{
name: "AllSuccess",
results: []error{nil, nil, nil},
},
}

for _, testCase := range testCases {
suite.Run(testCase.name, func() {
var (
testCtx = suite.testContext()
testToken = suite.testToken()
vs Validators
)

for _, err := range testCase.results {
err := err
vs.Add(
ValidatorFunc(func(ctx context.Context, token Token) error {
suite.Same(testCtx, ctx)
suite.Same(testToken, token)
return err
}),
)
}

suite.Equal(
testCase.expectedErr,
vs.Validate(testCtx, testToken),
)
})
}
}

func TestValidators(t *testing.T) {
suite.Run(t, new(ValidatorsTestSuite))
}

0 comments on commit 7d7da18

Please sign in to comment.