Skip to content

Commit

Permalink
refactored validator code to allow access to the source during valida…
Browse files Browse the repository at this point in the history
…tion
  • Loading branch information
johnabass committed Jul 29, 2024
1 parent 539a004 commit 7b342c4
Show file tree
Hide file tree
Showing 5 changed files with 419 additions and 32 deletions.
4 changes: 2 additions & 2 deletions authorizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (suite *AuthorizersTestSuite) TestAuthorize() {
as = as.Append(
AuthorizerFunc[string](func(ctx context.Context, resource string, token Token) error {
suite.Same(testCtx, ctx)
suite.Same(testToken, token)
suite.Equal(testToken, token)
suite.Equal(placeholderResource, resource)
return err
}),
Expand Down Expand Up @@ -126,7 +126,7 @@ func (suite *AuthorizersTestSuite) TestAny() {
as = as.Append(
AuthorizerFunc[string](func(ctx context.Context, resource string, token Token) error {
suite.Same(testCtx, ctx)
suite.Same(testToken, token)
suite.Equal(testToken, token)
suite.Equal(placeholderResource, resource)
return err
}),
Expand Down
36 changes: 36 additions & 0 deletions mocks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// SPDX-FileCopyrightText: 2020 Comcast Cable Communications Management, LLC
// SPDX-License-Identifier: Apache-2.0

package bascule

import (
"context"

"github.com/stretchr/testify/mock"
)

type testToken string

func (tt testToken) Principal() string { return string(tt) }

type mockValidator[S any] struct {
mock.Mock
}

func (m *mockValidator[S]) Validate(ctx context.Context, source S, token Token) (Token, error) {
args := m.Called(ctx, source, token)
t, _ := args.Get(0).(Token)
return t, args.Error(1)
}

func (m *mockValidator[S]) ExpectValidate(ctx context.Context, source S, token Token) *mock.Call {
return m.On("Validate", ctx, source, token)
}

func assertValidators[S any](t mock.TestingT, vs ...Validator[S]) (passed bool) {
for _, v := range vs {
passed = v.(*mockValidator[S]).AssertExpectations(t) && passed
}

return
}
12 changes: 1 addition & 11 deletions testSuite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,6 @@ const (
testScheme Scheme = "Test"
)

type testToken struct {
principal string
}

func (tt *testToken) Principal() string {
return tt.principal
}

// TestSuite holds generally useful functionality for testing bascule.
type TestSuite struct {
suite.Suite
Expand All @@ -43,9 +35,7 @@ func (suite *TestSuite) testCredentials() Credentials {
}

func (suite *TestSuite) testToken() Token {
return &testToken{
principal: "test",
}
return testToken("test")
}

func (suite *TestSuite) contexter(ctx context.Context) Contexter {
Expand Down
45 changes: 26 additions & 19 deletions validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,20 @@ type ValidatorFunc[S any] interface {
// and uncurry a closure.
type validatorFunc[S any] func(context.Context, S, Token) (Token, error)

func (vf validatorFunc[S]) Validate(ctx context.Context, source S, t Token) (Token, error) {
return vf(ctx, source, t)
func (vf validatorFunc[S]) Validate(ctx context.Context, source S, t Token) (next Token, err error) {
next, err = vf(ctx, source, t)
if next == nil {
next = t
}

return
}

var (
tokenReturnsError = reflect.TypeOf((func(Token) error)(nil))
tokenReturnsTokenError = reflect.TypeOf((func(Token) (Token, error))(nil))
contextTokenReturnsError = reflect.TypeOf((func(context.Context, Token) error)(nil))
contextTokenReturnsTokenError = reflect.TypeOf((func(context.Context, Token) (Token, error))(nil))
tokenReturnError = reflect.TypeOf((func(Token) error)(nil))
tokenReturnTokenAndError = reflect.TypeOf((func(Token) (Token, error))(nil))
contextTokenReturnError = reflect.TypeOf((func(context.Context, Token) error)(nil))
contextTokenReturnTokenError = reflect.TypeOf((func(context.Context, Token) (Token, error))(nil))
)

// asValidatorSimple tries simple conversions on f. This function will not catch
Expand All @@ -102,14 +107,14 @@ func asValidatorSimple[S any, F ValidatorFunc[S]](f F) (v Validator[S]) {
case func(Token) error:
v = validatorFunc[S](
func(ctx context.Context, source S, t Token) (Token, error) {
return nil, vf(t)
return t, vf(t)
},
)

case func(S, Token) error:
v = validatorFunc[S](
func(ctx context.Context, source S, t Token) (Token, error) {
return nil, vf(source, t)
return t, vf(source, t)
},
)

Expand Down Expand Up @@ -140,14 +145,14 @@ func asValidatorSimple[S any, F ValidatorFunc[S]](f F) (v Validator[S]) {
case func(context.Context, Token) error:
v = validatorFunc[S](
func(ctx context.Context, source S, t Token) (Token, error) {
return nil, vf(ctx, t)
return t, vf(ctx, t)
},
)

case func(context.Context, S, Token) error:
v = validatorFunc[S](
func(ctx context.Context, source S, t Token) (Token, error) {
return nil, vf(ctx, source, t)
return t, vf(ctx, source, t)
},
)

Expand All @@ -171,7 +176,8 @@ func asValidatorSimple[S any, F ValidatorFunc[S]](f F) (v Validator[S]) {
}

// AsValidator takes a ValidatorFunc closure and returns a Validator instance that
// executes that closure.
// executes that closure. This function can also convert custom types which can
// be converted to any of the closure signatures.
func AsValidator[S any, F ValidatorFunc[S]](f F) Validator[S] {
// first, try the simple way:
if v := asValidatorSimple[S](f); v != nil {
Expand All @@ -182,24 +188,24 @@ func AsValidator[S any, F ValidatorFunc[S]](f F) Validator[S] {
// require the source type.
fVal := reflect.ValueOf(f)
switch {
case fVal.CanConvert(tokenReturnsError):
case fVal.CanConvert(tokenReturnError):
return asValidatorSimple[S](
fVal.Convert(tokenReturnsError).Interface().(func(Token) error),
fVal.Convert(tokenReturnError).Interface().(func(Token) error),
)

case fVal.CanConvert(tokenReturnsTokenError):
case fVal.CanConvert(tokenReturnTokenAndError):
return asValidatorSimple[S](
fVal.Convert(tokenReturnsError).Interface().(func(Token) (Token, error)),
fVal.Convert(tokenReturnTokenAndError).Interface().(func(Token) (Token, error)),
)

case fVal.CanConvert(contextTokenReturnsError):
case fVal.CanConvert(contextTokenReturnError):
return asValidatorSimple[S](
fVal.Convert(contextTokenReturnsError).Interface().(func(context.Context, Token) error),
fVal.Convert(contextTokenReturnError).Interface().(func(context.Context, Token) error),
)

case fVal.CanConvert(contextTokenReturnsTokenError):
case fVal.CanConvert(contextTokenReturnTokenError):
return asValidatorSimple[S](
fVal.Convert(contextTokenReturnsError).Interface().(func(context.Context, Token) (Token, error)),
fVal.Convert(contextTokenReturnTokenError).Interface().(func(context.Context, Token) (Token, error)),
)
}

Expand All @@ -219,6 +225,7 @@ func AsValidator[S any, F ValidatorFunc[S]](f F) Validator[S] {
)
} else {
// we know this can be converted to this final type
ft := reflect.TypeOf((func(context.Context, S, Token) (Token, error))(nil))
return asValidatorSimple[S](
fVal.Convert(ft).Interface().(func(context.Context, S, Token) (Token, error)),
)
Expand Down
Loading

0 comments on commit 7b342c4

Please sign in to comment.