From 7d7da1852873e366c82cf629324a38881934a6ba Mon Sep 17 00:00:00 2001 From: johnabass Date: Wed, 10 Jul 2024 21:55:35 -0700 Subject: [PATCH] incorporate the generic source type into validation --- basculehttp/middleware.go | 12 +- validator.go | 224 +++++++++++++++++++++++++++++++++----- validator_test.go | 66 ----------- 3 files changed, 204 insertions(+), 98 deletions(-) diff --git a/basculehttp/middleware.go b/basculehttp/middleware.go index 0c9f4b5..a94c392 100644 --- a/basculehttp/middleware.go +++ b/basculehttp/middleware.go @@ -51,9 +51,9 @@ func WithTokenParser(scheme bascule.Scheme, tp bascule.TokenParser[*http.Request // WithAuthentication adds validators used for authentication to this Middleware. Each // invocation of this option is cumulative. Authentication validators are run in the order // supplied by this option. -func WithAuthentication(v ...bascule.Validator) MiddlewareOption { +func WithAuthentication(v ...bascule.Validator[*http.Request]) MiddlewareOption { return middlewareOptionFunc(func(m *Middleware) error { - m.authentication.Add(v...) + m.authentication = m.authentication.Append(v...) return nil }) } @@ -114,7 +114,7 @@ func WithErrorMarshaler(em ErrorMarshaler) MiddlewareOption { type Middleware struct { credentialsParser bascule.CredentialsParser[*http.Request] tokenParsers bascule.TokenParsers[*http.Request] - authentication bascule.Validators + authentication bascule.Validators[*http.Request] authorization bascule.Authorizers[*http.Request] challenges Challenges @@ -194,8 +194,8 @@ func (m *Middleware) getCredentialsAndToken(ctx context.Context, request *http.R return } -func (m *Middleware) authenticate(ctx context.Context, token bascule.Token) error { - return m.authentication.Validate(ctx, token) +func (m *Middleware) authenticate(ctx context.Context, request *http.Request, token bascule.Token) (bascule.Token, error) { + return m.authentication.Validate(ctx, request, token) } func (m *Middleware) authorize(ctx context.Context, token bascule.Token, request *http.Request) error { @@ -220,7 +220,7 @@ func (fd *frontDoor) ServeHTTP(response http.ResponseWriter, request *http.Reque } ctx = bascule.WithCredentials(ctx, creds) - err = fd.middleware.authenticate(ctx, token) + token, err = fd.middleware.authenticate(ctx, request, token) if err != nil { // at this point in the workflow, the request has valid credentials. we use // StatusForbidden as the default because any failure to authenticate isn't a diff --git a/validator.go b/validator.go index 487f16b..5aaa32b 100644 --- a/validator.go +++ b/validator.go @@ -5,50 +5,222 @@ package bascule import ( "context" + "reflect" ) // Validator represents a general strategy for validating tokens. Token validation -// typically happens during authentication, but it can also happen during parsing -// if a caller uses NewValidatingTokenParser. -type Validator interface { +// typically happens during authentication. +type Validator[S any] interface { // Validate validates a token. If this validator needs to interact // with external systems, the supplied context can be passed to honor - // cancelation semantics. + // cancelation semantics. Additionally, the source object from which the + // token was taken is made available. // // This method may be passed a token that it doesn't support, e.g. a Basic // validator can be passed a JWT token. In that case, this method should - // simply return nil. - Validate(context.Context, Token) error + // simply return a nil error. + // + // If this method returns a nil token, then the supplied token should be used + // as is. If this method returns a non-nil token, that new new token should be + // used instead. This allows a validator to augment a token with additional + // data, possibly from an external system or database. + Validate(ctx context.Context, source S, t Token) (Token, error) } -// ValidatorFunc is a closure type that implements Validator. -type ValidatorFunc func(context.Context, Token) error +// Validate applies several validators to the given token. Although each individual +// validator may return a nil Token to indicate that there is no change in the token, +// this function will always return a non-nil Token. +// +// This function returns the validated Token and a nil error to indicate success. +// If any validator fails, this function halts further validation and returns +// the error. +func Validate[S any](ctx context.Context, source S, original Token, v ...Validator[S]) (validated Token, err error) { + next := original + for i, prev := 0, next; err == nil && i < len(v); i, prev = i+1, next { + next, err = v[i].Validate(ctx, source, prev) + if next == nil { + // no change in the token + next = prev + } + } + + if err == nil { + validated = next + } -func (vf ValidatorFunc) Validate(ctx context.Context, token Token) error { - return vf(ctx, token) + return } -// Validators is an aggregate Validator. -type Validators []Validator +// Validators is an aggregate Validator that returns validity if and only if +// all of its contained validators return validity. +type Validators[S any] []Validator[S] + +// Append tacks on more validators to this aggregate, returning the possibly new +// instance. The semantics of this method are the same as the built-in append. +func (vs Validators[S]) Append(more ...Validator[S]) Validators[S] { + return append(vs, more...) +} + +// Validate executes each contained validator in order, returning validity only +// if all validators pass. Any validation failure prevents subsequent validators +// from running. +func (vs Validators[S]) Validate(ctx context.Context, source S, t Token) (Token, error) { + return Validate(ctx, source, t, vs...) +} + +// ValidatorFunc defines the closure signatures that are allowed as Validator instances. +type ValidatorFunc[S any] interface { + ~func(Token) error | + ~func(S, Token) error | + ~func(Token) (Token, error) | + ~func(S, Token) (Token, error) | + ~func(context.Context, Token) error | + ~func(context.Context, S, Token) error | + ~func(context.Context, Token) (Token, error) | + ~func(context.Context, S, Token) (Token, error) +} + +// validatorFunc is an internal type that implements Validator. Used to normalize +// and uncurry a closure. +type validatorFunc[S any] func(context.Context, S, Token) (Token, error) + +func (vf validatorFunc[S]) Validate(ctx context.Context, source S, t Token) (Token, error) { + return vf(ctx, source, t) +} + +var ( + tokenReturnsError = reflect.TypeOf((func(Token) error)(nil)) + tokenReturnsTokenError = reflect.TypeOf((func(Token) (Token, error))(nil)) + contextTokenReturnsError = reflect.TypeOf((func(context.Context, Token) error)(nil)) + contextTokenReturnsTokenError = reflect.TypeOf((func(context.Context, Token) (Token, error))(nil)) +) + +// asValidatorSimple tries simple conversions on f. This function will not catch +// user-defined types. +func asValidatorSimple[S any, F ValidatorFunc[S]](f F) (v Validator[S]) { + switch vf := any(f).(type) { + case func(Token) error: + v = validatorFunc[S]( + func(ctx context.Context, source S, t Token) (Token, error) { + return nil, vf(t) + }, + ) -// Add appends validators to this aggregate Validators. -func (vs *Validators) Add(v ...Validator) { - if *vs == nil { - *vs = make(Validators, 0, len(v)) + case func(S, Token) error: + v = validatorFunc[S]( + func(ctx context.Context, source S, t Token) (Token, error) { + return nil, vf(source, t) + }, + ) + + case func(Token) (Token, error): + v = validatorFunc[S]( + func(ctx context.Context, source S, t Token) (next Token, err error) { + next, err = vf(t) + if next == nil { + next = t + } + + return + }, + ) + + case func(S, Token) (Token, error): + v = validatorFunc[S]( + func(ctx context.Context, source S, t Token) (next Token, err error) { + next, err = vf(source, t) + if next == nil { + next = t + } + + return + }, + ) + + case func(context.Context, Token) error: + v = validatorFunc[S]( + func(ctx context.Context, source S, t Token) (Token, error) { + return nil, vf(ctx, t) + }, + ) + + case func(context.Context, S, Token) error: + v = validatorFunc[S]( + func(ctx context.Context, source S, t Token) (Token, error) { + return nil, vf(ctx, source, t) + }, + ) + + case func(context.Context, Token) (Token, error): + v = validatorFunc[S]( + func(ctx context.Context, source S, t Token) (next Token, err error) { + next, err = vf(ctx, t) + if next == nil { + next = t + } + + return + }, + ) + + case func(context.Context, S, Token) (Token, error): + v = validatorFunc[S](vf) } - *vs = append(*vs, v...) + return } -// Validate applies each validator in sequence. Execution stops at the first validator -// that returns an error, and that error is returned. If all validators return nil, -// this method returns nil, indicating the Token is valid. -func (vs Validators) Validate(ctx context.Context, token Token) error { - for _, v := range vs { - if err := v.Validate(ctx, token); err != nil { - return err - } +// AsValidator takes a ValidatorFunc closure and returns a Validator instance that +// executes that closure. +func AsValidator[S any, F ValidatorFunc[S]](f F) Validator[S] { + // first, try the simple way: + if v := asValidatorSimple[S](f); v != nil { + return v + } + + // next, support user-defined types that are closures that do not + // require the source type. + fVal := reflect.ValueOf(f) + switch { + case fVal.CanConvert(tokenReturnsError): + return asValidatorSimple[S]( + fVal.Convert(tokenReturnsError).Interface().(func(Token) error), + ) + + case fVal.CanConvert(tokenReturnsTokenError): + return asValidatorSimple[S]( + fVal.Convert(tokenReturnsError).Interface().(func(Token) (Token, error)), + ) + + case fVal.CanConvert(contextTokenReturnsError): + return asValidatorSimple[S]( + fVal.Convert(contextTokenReturnsError).Interface().(func(context.Context, Token) error), + ) + + case fVal.CanConvert(contextTokenReturnsTokenError): + return asValidatorSimple[S]( + fVal.Convert(contextTokenReturnsError).Interface().(func(context.Context, Token) (Token, error)), + ) } - return nil + // finally: user-defined types that are closures involving the source type S. + // we have to look these up here, due to the way generics in golang work. + if ft := reflect.TypeOf((func(S, Token) error)(nil)); fVal.CanConvert(ft) { + return asValidatorSimple[S]( + fVal.Convert(ft).Interface().(func(S, Token) error), + ) + } else if ft := reflect.TypeOf((func(S, Token) (Token, error))(nil)); fVal.CanConvert(ft) { + return asValidatorSimple[S]( + fVal.Convert(ft).Interface().(func(S, Token) (Token, error)), + ) + } else if ft := reflect.TypeOf((func(context.Context, S, Token) error)(nil)); fVal.CanConvert(ft) { + return asValidatorSimple[S]( + fVal.Convert(ft).Interface().(func(context.Context, S, Token) error), + ) + } else { + // we know this can be converted to this final type + return asValidatorSimple[S]( + fVal.Convert(ft).Interface().(func(context.Context, S, Token) (Token, error)), + ) + } } diff --git a/validator_test.go b/validator_test.go index 4ea2601..94fd939 100644 --- a/validator_test.go +++ b/validator_test.go @@ -4,8 +4,6 @@ package bascule import ( - "context" - "errors" "testing" "github.com/stretchr/testify/suite" @@ -15,70 +13,6 @@ type ValidatorsTestSuite struct { TestSuite } -func (suite *ValidatorsTestSuite) TestValidate() { - validateErr := errors.New("expected Validate error") - - testCases := []struct { - name string - results []error - expectedErr error - }{ - { - name: "EmptyValidators", - results: nil, - }, - { - name: "OneSuccess", - results: []error{nil}, - }, - { - name: "OneFailure", - results: []error{validateErr}, - expectedErr: validateErr, - }, - { - name: "FirstFailure", - results: []error{validateErr, errors.New("should not be called")}, - expectedErr: validateErr, - }, - { - name: "MiddleFailure", - results: []error{nil, validateErr, errors.New("should not be called")}, - expectedErr: validateErr, - }, - { - name: "AllSuccess", - results: []error{nil, nil, nil}, - }, - } - - for _, testCase := range testCases { - suite.Run(testCase.name, func() { - var ( - testCtx = suite.testContext() - testToken = suite.testToken() - vs Validators - ) - - for _, err := range testCase.results { - err := err - vs.Add( - ValidatorFunc(func(ctx context.Context, token Token) error { - suite.Same(testCtx, ctx) - suite.Same(testToken, token) - return err - }), - ) - } - - suite.Equal( - testCase.expectedErr, - vs.Validate(testCtx, testToken), - ) - }) - } -} - func TestValidators(t *testing.T) { suite.Run(t, new(ValidatorsTestSuite)) }