Skip to content

Commit

Permalink
Make sure expired tokens give correct status / problem detail (PP-970) (
Browse files Browse the repository at this point in the history
#1677)

* Make sure expired tokens return the correct problem detail.
* Fix import loop
* Add one more test for coverage.
  • Loading branch information
jonathangreen authored Feb 15, 2024
1 parent 762fda1 commit bd88741
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 67 deletions.
13 changes: 0 additions & 13 deletions api/authentication/access_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,19 +165,6 @@ def decrypt_token(cls, _db: Session, token: jwe.JWE | str) -> TokenPatronInfo:

return TokenPatronInfo(**payload)

@classmethod
def is_access_token(cls, token: str | None) -> bool:
"""Test if the given token is a valid JWE token"""
if token is None:
return False

try:
cls.decode_token(token)
except Exception:
return False

return True

@classmethod
def delete_old_keys(cls, _db: Session) -> int:
"""Delete old keys from the DB
Expand Down
14 changes: 6 additions & 8 deletions api/authentication/basic_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,12 @@ def authenticated_patron(

def get_credential_from_header(self, auth: Authorization) -> str | None:
"""If we are the right type of token, then decode the password from the token"""
if (
auth
and auth.type.lower() == "bearer"
and auth.token
and PatronJWEAccessTokenProvider.is_access_token(auth.token)
):
token = PatronJWEAccessTokenProvider.decrypt_token(self._db, auth.token)
return token.pwd
if auth and auth.type.lower() == "bearer" and auth.token:
try:
token = PatronJWEAccessTokenProvider.decrypt_token(self._db, auth.token)
return token.pwd
except ProblemError:
...

return None

Expand Down
61 changes: 38 additions & 23 deletions api/authenticator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import enum
import json
import logging
import sys
Expand All @@ -16,7 +17,6 @@

from api.adobe_vendor_id import AuthdataUtility
from api.annotations import AnnotationWriter
from api.authentication.access_token import PatronJWEAccessTokenProvider
from api.authentication.base import (
AuthenticationProvider,
LibrarySettingsType,
Expand Down Expand Up @@ -87,6 +87,28 @@ def profile_document(self):
return doc


class BearerTokenType(enum.Enum):
"""The type of token being used for authentication."""

JWE = enum.auto()
JWT = enum.auto()
UNKNOWN = enum.auto()

@classmethod
def from_token(cls, token: str | None) -> BearerTokenType:
"""Determine the type of token from its string representation."""
if token is None:
return cls.UNKNOWN

split_token = token.split(".")
if len(split_token) == 5:
return cls.JWE
elif len(split_token) == 3:
return cls.JWT
else:
return cls.UNKNOWN


class Authenticator(LoggerMixin):
"""Route requests to the appropriate LibraryAuthenticator."""

Expand Down Expand Up @@ -417,45 +439,38 @@ def authenticated_patron(
credentials do not authenticate any particular patron. A
ProblemDetail if an error occurs.
"""
provider: AuthenticationProvider | None = None
provider_token: dict[str, str | None] | str | None = None
if self.basic_auth_provider and auth.type.lower() == "basic":
# The patron wants to authenticate with the
# BasicAuthenticationProvider.
provider = self.basic_auth_provider
provider_token = auth.parameters
return self.basic_auth_provider.authenticated_patron(_db, auth.parameters)
elif auth.type.lower() == "bearer":
# The patron wants to use an
# SAMLAuthenticationProvider. Figure out which one.
if auth.token is None:
# The patron wants to use a bearer token. Figure out which type
# of token it is and which provider to use.
token_str = auth.token
if token_str is None:
return INVALID_SAML_BEARER_TOKEN

token_type = BearerTokenType.from_token(token_str)
if (
self.access_token_authentication_provider
and PatronJWEAccessTokenProvider.is_access_token(auth.token)
token_type == BearerTokenType.JWE
and self.access_token_authentication_provider
):
provider = self.access_token_authentication_provider
provider_token = auth.token
elif self.saml_providers_by_name:
# The patron wants to use an
# SAMLAuthenticationProvider. Figure out which one.
return self.access_token_authentication_provider.authenticated_patron(
_db, token_str
)
elif token_type == BearerTokenType.JWT:
# The patron wants to use an SAMLAuthenticationProvider. Figure out which one.
try:
provider_name, provider_token = self.decode_bearer_token(auth.token)
provider_name, provider_token = self.decode_bearer_token(token_str)
except jwt.exceptions.InvalidTokenError as e:
return INVALID_SAML_BEARER_TOKEN
saml_provider = self.saml_provider_lookup(provider_name)
if isinstance(saml_provider, ProblemDetail):
# There was a problem turning the provider name into
# a registered SAMLAuthenticationProvider.
return saml_provider
provider = saml_provider

if provider and provider_token:
# Turn the token/header into a patron
return provider.authenticated_patron(_db, provider_token)
return saml_provider.authenticated_patron(_db, provider_token)

# We were unable to determine what was going on with the
# Authenticate header.
return UNSUPPORTED_AUTHENTICATION_MECHANISM

def get_credential_from_header(self, auth: Authorization) -> str | None:
Expand Down
13 changes: 0 additions & 13 deletions tests/api/authentication/test_access_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,19 +272,6 @@ def test_decrypt_token_errors(
PatronJWEAccessTokenProvider.decrypt_token(db.session, token)
assert exc.value.problem_detail == PATRON_AUTH_ACCESS_TOKEN_INVALID

def test_is_access_token(self, jwe_provider: JWEProviderFixture):
# Happy path
token = jwe_provider.generate_token()
assert PatronJWEAccessTokenProvider.is_access_token(token) is True

with patch.object(PatronJWEAccessTokenProvider, "decode_token") as decode:
# An incorrect type
decode.side_effect = Exception("Bang!")
assert PatronJWEAccessTokenProvider.is_access_token(token) is False

# The token is not the right format
assert PatronJWEAccessTokenProvider.is_access_token("not-a-token") is False

@freeze_time()
def test_delete_old_keys(self):
mock_session = MagicMock()
Expand Down
20 changes: 15 additions & 5 deletions tests/api/authentication/test_basic_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,23 @@ def test_credential_from_header(self, db: DatabaseTransactionFixture):
db.session, patron, "passworx"
)

pwd = provider.get_credential_from_header(
Authorization(auth_type="Bearer", token=token)
assert (
provider.get_credential_from_header(
Authorization(auth_type="Bearer", token=token)
)
== "passworx"
)
assert pwd == "passworx"

pwd = provider.get_credential_from_header(Authorization(auth_type="Basic"))
assert pwd == None
assert (
provider.get_credential_from_header(Authorization(auth_type="Basic"))
is None
)
assert (
provider.get_credential_from_header(
Authorization(auth_type="Bearer", token="junk")
)
is None
)

def test_authentication_flow_document(
self, db: DatabaseTransactionFixture, controller_fixture: ControllerFixture
Expand Down
42 changes: 37 additions & 5 deletions tests/api/test_authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from api.authenticator import (
Authenticator,
BaseSAMLAuthenticationProvider,
BearerTokenType,
CirculationPatronProfileStorage,
LibraryAuthenticator,
)
Expand Down Expand Up @@ -891,12 +892,13 @@ def test_authenticated_patron_bearer(
# Mock the sign verification
with patch.object(authenticator, "decode_bearer_token") as decode:
decode.return_value = ("Mock", "decoded-token")
bearer_token = authenticator.create_bearer_token("test", "test")
response = authenticator.authenticated_patron(
db.session, Authorization(auth_type="Bearer", token="some-bearer-token")
db.session, Authorization(auth_type="Bearer", token=bearer_token)
)
# The token was decoded
assert decode.call_count == 1
decode.assert_called_with("some-bearer-token")
decode.assert_called_with(bearer_token)
# The right saml provider was used
assert response == "foo"
assert saml.authenticated_patron.call_count == 1
Expand All @@ -906,6 +908,9 @@ def test_authenticated_patron_bearer_access_token(
db: DatabaseTransactionFixture,
mock_basic: MockBasicFixture,
):
now = utc_now()
two_hours_in_the_future = now + datetime.timedelta(hours=2)

basic = mock_basic()
# TODO: We can remove this patch once basic token authentication is fully deployed.
with patch.object(
Expand All @@ -927,9 +932,16 @@ def test_authenticated_patron_bearer_access_token(
token = PatronJWEAccessTokenProvider.generate_token(db.session, patron, "pass")
auth = Authorization(auth_type="bearer", token=token)

auth_patron = authenticator.authenticated_patron(db.session, auth)
assert type(auth_patron) == Patron
assert auth_patron.id == patron.id
# Token is valid
with freeze_time(now):
auth_patron = authenticator.authenticated_patron(db.session, auth)
assert type(auth_patron) == Patron
assert auth_patron.id == patron.id

# The token is expired
with freeze_time(two_hours_in_the_future):
problem = authenticator.authenticated_patron(db.session, auth)
assert PATRON_AUTH_ACCESS_TOKEN_EXPIRED == problem

def test_authenticated_patron_unsupported_mechanism(
self, db: DatabaseTransactionFixture
Expand Down Expand Up @@ -2537,3 +2549,23 @@ def test_authentication_updates_outdated_patron_on_authorization_identifier_matc
# then we have no way of locating them in our database. They will
# appear no different to us than a patron who has never used the
# circulation manager before.


class TestBearerTokenType:
def test_from_token(self, db: DatabaseTransactionFixture) -> None:
PatronJWEAccessTokenProvider.create_key(db.session)
patron = db.patron()
jwe_token = PatronJWEAccessTokenProvider.generate_token(
db.session, patron, "password"
)

authenticator = LibraryAuthenticator(
_db=db.session,
library=db.default_library(),
bearer_token_signing_secret="secret",
)
jwt_token = authenticator.create_bearer_token("test", "test")

assert BearerTokenType.from_token(jwt_token) == BearerTokenType.JWT
assert BearerTokenType.from_token(jwe_token) == BearerTokenType.JWE
assert BearerTokenType.from_token("test") == BearerTokenType.UNKNOWN

0 comments on commit bd88741

Please sign in to comment.