diff --git a/api/authentication/access_token.py b/api/authentication/access_token.py index 5ffd1ac8d6..1f53f05a30 100644 --- a/api/authentication/access_token.py +++ b/api/authentication/access_token.py @@ -165,19 +165,6 @@ def decrypt_token(cls, _db: Session, token: jwe.JWE | str) -> TokenPatronInfo: return TokenPatronInfo(**payload) - @classmethod - def is_access_token(cls, token: str | None) -> bool: - """Test if the given token is a valid JWE token""" - if token is None: - return False - - try: - cls.decode_token(token) - except Exception: - return False - - return True - @classmethod def delete_old_keys(cls, _db: Session) -> int: """Delete old keys from the DB diff --git a/api/authentication/basic_token.py b/api/authentication/basic_token.py index 2819e06945..417799b786 100644 --- a/api/authentication/basic_token.py +++ b/api/authentication/basic_token.py @@ -76,14 +76,12 @@ def authenticated_patron( def get_credential_from_header(self, auth: Authorization) -> str | None: """If we are the right type of token, then decode the password from the token""" - if ( - auth - and auth.type.lower() == "bearer" - and auth.token - and PatronJWEAccessTokenProvider.is_access_token(auth.token) - ): - token = PatronJWEAccessTokenProvider.decrypt_token(self._db, auth.token) - return token.pwd + if auth and auth.type.lower() == "bearer" and auth.token: + try: + token = PatronJWEAccessTokenProvider.decrypt_token(self._db, auth.token) + return token.pwd + except ProblemError: + ... return None diff --git a/api/authenticator.py b/api/authenticator.py index 8b49feef59..0e5ab98f53 100644 --- a/api/authenticator.py +++ b/api/authenticator.py @@ -1,5 +1,6 @@ from __future__ import annotations +import enum import json import logging import sys @@ -16,7 +17,6 @@ from api.adobe_vendor_id import AuthdataUtility from api.annotations import AnnotationWriter -from api.authentication.access_token import PatronJWEAccessTokenProvider from api.authentication.base import ( AuthenticationProvider, LibrarySettingsType, @@ -87,6 +87,28 @@ def profile_document(self): return doc +class BearerTokenType(enum.Enum): + """The type of token being used for authentication.""" + + JWE = enum.auto() + JWT = enum.auto() + UNKNOWN = enum.auto() + + @classmethod + def from_token(cls, token: str | None) -> BearerTokenType: + """Determine the type of token from its string representation.""" + if token is None: + return cls.UNKNOWN + + split_token = token.split(".") + if len(split_token) == 5: + return cls.JWE + elif len(split_token) == 3: + return cls.JWT + else: + return cls.UNKNOWN + + class Authenticator(LoggerMixin): """Route requests to the appropriate LibraryAuthenticator.""" @@ -417,30 +439,29 @@ def authenticated_patron( credentials do not authenticate any particular patron. A ProblemDetail if an error occurs. """ - provider: AuthenticationProvider | None = None - provider_token: dict[str, str | None] | str | None = None if self.basic_auth_provider and auth.type.lower() == "basic": # The patron wants to authenticate with the # BasicAuthenticationProvider. - provider = self.basic_auth_provider - provider_token = auth.parameters + return self.basic_auth_provider.authenticated_patron(_db, auth.parameters) elif auth.type.lower() == "bearer": - # The patron wants to use an - # SAMLAuthenticationProvider. Figure out which one. - if auth.token is None: + # The patron wants to use a bearer token. Figure out which type + # of token it is and which provider to use. + token_str = auth.token + if token_str is None: return INVALID_SAML_BEARER_TOKEN + token_type = BearerTokenType.from_token(token_str) if ( - self.access_token_authentication_provider - and PatronJWEAccessTokenProvider.is_access_token(auth.token) + token_type == BearerTokenType.JWE + and self.access_token_authentication_provider ): - provider = self.access_token_authentication_provider - provider_token = auth.token - elif self.saml_providers_by_name: - # The patron wants to use an - # SAMLAuthenticationProvider. Figure out which one. + return self.access_token_authentication_provider.authenticated_patron( + _db, token_str + ) + elif token_type == BearerTokenType.JWT: + # The patron wants to use an SAMLAuthenticationProvider. Figure out which one. try: - provider_name, provider_token = self.decode_bearer_token(auth.token) + provider_name, provider_token = self.decode_bearer_token(token_str) except jwt.exceptions.InvalidTokenError as e: return INVALID_SAML_BEARER_TOKEN saml_provider = self.saml_provider_lookup(provider_name) @@ -448,14 +469,8 @@ def authenticated_patron( # There was a problem turning the provider name into # a registered SAMLAuthenticationProvider. return saml_provider - provider = saml_provider - - if provider and provider_token: - # Turn the token/header into a patron - return provider.authenticated_patron(_db, provider_token) + return saml_provider.authenticated_patron(_db, provider_token) - # We were unable to determine what was going on with the - # Authenticate header. return UNSUPPORTED_AUTHENTICATION_MECHANISM def get_credential_from_header(self, auth: Authorization) -> str | None: diff --git a/tests/api/authentication/test_access_token.py b/tests/api/authentication/test_access_token.py index 71155e4667..6fc196cf0d 100644 --- a/tests/api/authentication/test_access_token.py +++ b/tests/api/authentication/test_access_token.py @@ -272,19 +272,6 @@ def test_decrypt_token_errors( PatronJWEAccessTokenProvider.decrypt_token(db.session, token) assert exc.value.problem_detail == PATRON_AUTH_ACCESS_TOKEN_INVALID - def test_is_access_token(self, jwe_provider: JWEProviderFixture): - # Happy path - token = jwe_provider.generate_token() - assert PatronJWEAccessTokenProvider.is_access_token(token) is True - - with patch.object(PatronJWEAccessTokenProvider, "decode_token") as decode: - # An incorrect type - decode.side_effect = Exception("Bang!") - assert PatronJWEAccessTokenProvider.is_access_token(token) is False - - # The token is not the right format - assert PatronJWEAccessTokenProvider.is_access_token("not-a-token") is False - @freeze_time() def test_delete_old_keys(self): mock_session = MagicMock() diff --git a/tests/api/authentication/test_basic_token.py b/tests/api/authentication/test_basic_token.py index a1cbf0ce3e..345669791c 100644 --- a/tests/api/authentication/test_basic_token.py +++ b/tests/api/authentication/test_basic_token.py @@ -59,13 +59,23 @@ def test_credential_from_header(self, db: DatabaseTransactionFixture): db.session, patron, "passworx" ) - pwd = provider.get_credential_from_header( - Authorization(auth_type="Bearer", token=token) + assert ( + provider.get_credential_from_header( + Authorization(auth_type="Bearer", token=token) + ) + == "passworx" ) - assert pwd == "passworx" - pwd = provider.get_credential_from_header(Authorization(auth_type="Basic")) - assert pwd == None + assert ( + provider.get_credential_from_header(Authorization(auth_type="Basic")) + is None + ) + assert ( + provider.get_credential_from_header( + Authorization(auth_type="Bearer", token="junk") + ) + is None + ) def test_authentication_flow_document( self, db: DatabaseTransactionFixture, controller_fixture: ControllerFixture diff --git a/tests/api/test_authenticator.py b/tests/api/test_authenticator.py index 40c8dca3ab..383380e03c 100644 --- a/tests/api/test_authenticator.py +++ b/tests/api/test_authenticator.py @@ -36,6 +36,7 @@ from api.authenticator import ( Authenticator, BaseSAMLAuthenticationProvider, + BearerTokenType, CirculationPatronProfileStorage, LibraryAuthenticator, ) @@ -891,12 +892,13 @@ def test_authenticated_patron_bearer( # Mock the sign verification with patch.object(authenticator, "decode_bearer_token") as decode: decode.return_value = ("Mock", "decoded-token") + bearer_token = authenticator.create_bearer_token("test", "test") response = authenticator.authenticated_patron( - db.session, Authorization(auth_type="Bearer", token="some-bearer-token") + db.session, Authorization(auth_type="Bearer", token=bearer_token) ) # The token was decoded assert decode.call_count == 1 - decode.assert_called_with("some-bearer-token") + decode.assert_called_with(bearer_token) # The right saml provider was used assert response == "foo" assert saml.authenticated_patron.call_count == 1 @@ -906,6 +908,9 @@ def test_authenticated_patron_bearer_access_token( db: DatabaseTransactionFixture, mock_basic: MockBasicFixture, ): + now = utc_now() + two_hours_in_the_future = now + datetime.timedelta(hours=2) + basic = mock_basic() # TODO: We can remove this patch once basic token authentication is fully deployed. with patch.object( @@ -927,9 +932,16 @@ def test_authenticated_patron_bearer_access_token( token = PatronJWEAccessTokenProvider.generate_token(db.session, patron, "pass") auth = Authorization(auth_type="bearer", token=token) - auth_patron = authenticator.authenticated_patron(db.session, auth) - assert type(auth_patron) == Patron - assert auth_patron.id == patron.id + # Token is valid + with freeze_time(now): + auth_patron = authenticator.authenticated_patron(db.session, auth) + assert type(auth_patron) == Patron + assert auth_patron.id == patron.id + + # The token is expired + with freeze_time(two_hours_in_the_future): + problem = authenticator.authenticated_patron(db.session, auth) + assert PATRON_AUTH_ACCESS_TOKEN_EXPIRED == problem def test_authenticated_patron_unsupported_mechanism( self, db: DatabaseTransactionFixture @@ -2537,3 +2549,23 @@ def test_authentication_updates_outdated_patron_on_authorization_identifier_matc # then we have no way of locating them in our database. They will # appear no different to us than a patron who has never used the # circulation manager before. + + +class TestBearerTokenType: + def test_from_token(self, db: DatabaseTransactionFixture) -> None: + PatronJWEAccessTokenProvider.create_key(db.session) + patron = db.patron() + jwe_token = PatronJWEAccessTokenProvider.generate_token( + db.session, patron, "password" + ) + + authenticator = LibraryAuthenticator( + _db=db.session, + library=db.default_library(), + bearer_token_signing_secret="secret", + ) + jwt_token = authenticator.create_bearer_token("test", "test") + + assert BearerTokenType.from_token(jwt_token) == BearerTokenType.JWT + assert BearerTokenType.from_token(jwe_token) == BearerTokenType.JWE + assert BearerTokenType.from_token("test") == BearerTokenType.UNKNOWN