diff --git a/basculehttp/error.go b/basculehttp/error.go index 8299c25..08badb2 100644 --- a/basculehttp/error.go +++ b/basculehttp/error.go @@ -27,6 +27,9 @@ type ErrorStatusCoder func(request *http.Request, err error) int // (3) If err has bascule.ErrMissingCredentials in its chain, this function returns // http.StatusUnauthorized. // +// (3) If err has bascule.ErrBadCredentials in its chain, this function returns +// http.StatusUnauthorized. +// // (4) If err has bascule.ErrUnauthorized in its chain, this function returns // http.StatusForbidden. // @@ -52,6 +55,9 @@ func DefaultErrorStatusCoder(_ *http.Request, err error) int { case errors.Is(err, bascule.ErrMissingCredentials): return http.StatusUnauthorized + case errors.Is(err, bascule.ErrBadCredentials): + return http.StatusUnauthorized + case errors.Is(err, bascule.ErrUnauthorized): return http.StatusForbidden diff --git a/basculehttp/error_test.go b/basculehttp/error_test.go index fba79bf..570e350 100644 --- a/basculehttp/error_test.go +++ b/basculehttp/error_test.go @@ -31,6 +31,13 @@ func (suite *ErrorTestSuite) TestDefaultErrorStatusCoder() { ) }) + suite.Run("ErrBadCredentials", func() { + suite.Equal( + http.StatusUnauthorized, + DefaultErrorStatusCoder(nil, bascule.ErrBadCredentials), + ) + }) + suite.Run("ErrUnauthorized", func() { suite.Equal( http.StatusForbidden, diff --git a/basculehttp/middleware_examples_test.go b/basculehttp/middleware_examples_test.go index 18d2f83..e77181f 100644 --- a/basculehttp/middleware_examples_test.go +++ b/basculehttp/middleware_examples_test.go @@ -4,6 +4,7 @@ package basculehttp import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -55,3 +56,111 @@ func ExampleMiddleware_basicauth() { // principal: joe // with basic auth response code: 200 } + +// ExampleMiddleware_authentication shows how to authenticate a token. +func ExampleMiddleware_authentication() { + tp, _ := NewAuthorizationParser( + WithBasic(), + ) + + m, _ := NewMiddleware( + UseAuthenticator( + NewAuthenticator( + bascule.WithTokenParsers(tp), + bascule.WithValidators( + AsValidator( + // the signature of validator closures is very flexible + // see bascule.AsValidator + func(token bascule.Token) error { + if basic, ok := token.(BasicToken); ok && basic.Password() == "correct_password" { + return nil + } + + return bascule.ErrBadCredentials + }, + ), + ), + ), + ), + ) + + h := m.ThenFunc( + func(response http.ResponseWriter, request *http.Request) { + t, _ := bascule.GetFrom(request) + fmt.Println("principal:", t.Principal()) + }, + ) + + requestForJoe := httptest.NewRequest("GET", "/", nil) + requestForJoe.SetBasicAuth("joe", "correct_password") + response := httptest.NewRecorder() + h.ServeHTTP(response, requestForJoe) + fmt.Println("we let joe in with the code:", response.Code) + + requestForCurly := httptest.NewRequest("GET", "/", nil) + requestForCurly.SetBasicAuth("joe", "bad_password") + response = httptest.NewRecorder() + h.ServeHTTP(response, requestForCurly) + fmt.Println("this isn't joe:", response.Code) + + // Output: + // principal: joe + // we let joe in with the code: 200 + // this isn't joe: 401 +} + +// ExampleMiddleware_authorization shows how to set up custom +// authorization for tokens. +func ExampleMiddleware_authorization() { + tp, _ := NewAuthorizationParser( + WithBasic(), + ) + + m, _ := NewMiddleware( + UseAuthenticator( + NewAuthenticator( + bascule.WithTokenParsers(tp), + ), + ), + UseAuthorizer( + NewAuthorizer( + bascule.WithApproverFuncs( + // this can also be a type that implements the bascule.Approver interface, + // when used with bascule.WithApprovers + func(_ context.Context, resource *http.Request, token bascule.Token) error { + if token.Principal() != "joe" { + // only joe can access this resource + return bascule.ErrUnauthorized + } + + return nil // approved + }, + ), + ), + ), + ) + + h := m.ThenFunc( + func(response http.ResponseWriter, request *http.Request) { + t, _ := bascule.GetFrom(request) + fmt.Println("principal:", t.Principal()) + }, + ) + + requestForJoe := httptest.NewRequest("GET", "/", nil) + requestForJoe.SetBasicAuth("joe", "password") + response := httptest.NewRecorder() + h.ServeHTTP(response, requestForJoe) + fmt.Println("we let joe in with the code:", response.Code) + + requestForCurly := httptest.NewRequest("GET", "/", nil) + requestForCurly.SetBasicAuth("curly", "another_password") + response = httptest.NewRecorder() + h.ServeHTTP(response, requestForCurly) + fmt.Println("we didn't authorize curly:", response.Code) + + // Output: + // principal: joe + // we let joe in with the code: 200 + // we didn't authorize curly: 403 +} diff --git a/basculehttp/validator.go b/basculehttp/validator.go new file mode 100644 index 0000000..e2b83b1 --- /dev/null +++ b/basculehttp/validator.go @@ -0,0 +1,16 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package basculehttp + +import ( + "net/http" + + "github.com/xmidt-org/bascule" +) + +// AsValidator an HTTP-specific version of bascule.AsValidator. This function +// eases the syntactic pain of using golang's generics. +func AsValidator[F bascule.ValidatorFunc[*http.Request]](f F) bascule.Validator[*http.Request] { + return bascule.AsValidator[*http.Request](f) +} diff --git a/token.go b/token.go index d0627e5..1ebf068 100644 --- a/token.go +++ b/token.go @@ -15,13 +15,17 @@ var ( // because of configuration, possibly intentionally. ErrNoTokenParsers = errors.New("no token parsers") - // ErrMissingCredentials is returned by TokenParser.Parse to indicate that a source - // object did not have any credentials recognized by that parser. + // ErrMissingCredentials indicates that a source object did not have any credentials + // recognized by that parser. ErrMissingCredentials = errors.New("missing credentials") - // ErrInvalidCredentials is returned by TokenParser.Parse to indicate that a source - // did contain recognizable credentials, but those credentials could not be parsed, - // possibly due to bad formatting. + // ErrBadCredentials indicates that parseable credentials were present in the source, + // but that the credentials did not match what the application expects. For example, + // a password mismatch should return this error. + ErrBadCredentials = errors.New("bad credentials") + + // ErrInvalidCredentials indicates that a source did contain recognizable credentials, + // but those credentials could not be parsed, possibly due to bad formatting. ErrInvalidCredentials = errors.New("invalid credentials") )