From 6540ea0439192905bb1cc0c6ee721b45c28f6d64 Mon Sep 17 00:00:00 2001 From: Jonathan Green Date: Thu, 8 Feb 2024 14:07:56 -0400 Subject: [PATCH] Add a new key table to hold some of our sitewide keys (PP-893) (#1664) * Add keys table * Add migration * Fix comment --- .../20240207_fc3c9ccf0ad8_add_keys_table.py | 121 +++++++ .../password_admin_authentication_provider.py | 15 +- api/app.py | 7 +- api/authentication/access_token.py | 226 ++++++------- api/authentication/basic_token.py | 32 +- api/authenticator.py | 60 +--- api/controller/patron_auth_token.py | 13 +- core/jobs/rotate_jwe_key.py | 25 +- core/model/__init__.py | 12 +- core/model/key.py | 156 +++++++++ core/util/string_helpers.py | 11 + core/util/uuid.py | 18 +- pyproject.toml | 1 + scripts.py | 3 - tests/api/authentication/__init__.py | 0 tests/api/authentication/test_access_token.py | 299 ++++++++++++++++++ .../test_basic_token.py} | 31 +- tests/api/test_authenticator.py | 8 +- tests/api/test_jwe_provider.py | 122 ------- tests/core/jobs/test_rotate_jwe_key.py | 64 +++- tests/core/models/test_key.py | 159 ++++++++++ tests/core/util/test_string_helpers.py | 57 ++-- tests/core/util/test_uuid.py | 14 + tests/fixtures/api_admin.py | 6 - tests/fixtures/database.py | 3 +- 25 files changed, 1076 insertions(+), 387 deletions(-) create mode 100644 alembic/versions/20240207_fc3c9ccf0ad8_add_keys_table.py create mode 100644 core/model/key.py create mode 100644 tests/api/authentication/__init__.py create mode 100644 tests/api/authentication/test_access_token.py rename tests/api/{test_basic_token_authentication_provider.py => authentication/test_basic_token.py} (74%) delete mode 100644 tests/api/test_jwe_provider.py create mode 100644 tests/core/models/test_key.py diff --git a/alembic/versions/20240207_fc3c9ccf0ad8_add_keys_table.py b/alembic/versions/20240207_fc3c9ccf0ad8_add_keys_table.py new file mode 100644 index 0000000000..5af12253ce --- /dev/null +++ b/alembic/versions/20240207_fc3c9ccf0ad8_add_keys_table.py @@ -0,0 +1,121 @@ +"""Add keys table + +Revision ID: fc3c9ccf0ad8 +Revises: 993729d4bf97 +Create Date: 2024-02-07 17:51:44.823725+00:00 + +""" +import datetime +import uuid +from collections.abc import Callable + +import sqlalchemy as sa +from jwcrypto import jwk +from sqlalchemy.dialects import postgresql +from sqlalchemy.engine import Connection + +from alembic import op +from core.migration.util import migration_logger +from core.util.datetime_helpers import utc_now +from core.util.string_helpers import random_key + +# revision identifiers, used by Alembic. +revision = "fc3c9ccf0ad8" +down_revision = "993729d4bf97" +branch_labels = None +depends_on = None + +log = migration_logger(revision) + + +def get_sitewide_config(connection: Connection, key: str) -> str | None: + result = connection.execute( + "SELECT value from configurationsettings where key = %s and library_id is null and external_integration_id is null", + key, + ).one_or_none() + + if result is None: + return None + + return result.value + + +def insert_key( + connection: Connection, key_type: str, value: str, created: datetime.datetime +) -> None: + connection.execute( + "INSERT INTO keys (id, created, value, type) VALUES (%s, %s, %s, %s)", + (uuid.uuid4(), created, value, key_type), + ) + + +def migrate_configuration_setting( + connection: Connection, + key_type: str, + setting_value: str | None, + generate: Callable[[], str], +) -> None: + unknown_creation_time = datetime.datetime( + year=1970, month=1, day=1, tzinfo=datetime.timezone.utc + ) + + if setting_value: + log.info(f"Migrating {key_type} to new keys table") + insert_key(connection, key_type, setting_value, unknown_creation_time) + else: + log.warning(f"No {key_type} found. Generating a new one.") + insert_key(connection, key_type, generate(), utc_now()) + + +def upgrade() -> None: + op.create_table( + "keys", + sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("created", sa.DateTime(timezone=True), nullable=False), + sa.Column("value", sa.Unicode(), nullable=False), + sa.Column( + "type", + sa.Enum( + "AUTH_TOKEN_JWE", + "BEARER_TOKEN_SIGNING", + "ADMIN_SECRET_KEY", + name="keytype", + ), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_keys_created"), "keys", ["created"], unique=False) + op.create_index(op.f("ix_keys_type"), "keys", ["type"], unique=False) + + # Migrate in the data from the old table. + connection = op.get_bind() + + admin_secret_key = get_sitewide_config(connection, "secret_key") + bearer_token_signing_key = get_sitewide_config( + connection, "bearer_token_signing_secret" + ) + auth_token_jwe_key = get_sitewide_config(connection, "PATRON_JWE_KEY") + + migrate_configuration_setting( + connection, "ADMIN_SECRET_KEY", admin_secret_key, lambda: random_key(48) + ) + migrate_configuration_setting( + connection, + "BEARER_TOKEN_SIGNING", + bearer_token_signing_key, + lambda: random_key(48), + ) + migrate_configuration_setting( + connection, + "AUTH_TOKEN_JWE", + auth_token_jwe_key, + lambda: jwk.JWK.generate(kty="oct", size=256).export(), + ) + + +def downgrade() -> None: + op.drop_index(op.f("ix_keys_type"), table_name="keys") + op.drop_index(op.f("ix_keys_created"), table_name="keys") + op.drop_table("keys") + sa.Enum(name="keytype").drop(op.get_bind(), checkfirst=False) diff --git a/api/admin/password_admin_authentication_provider.py b/api/admin/password_admin_authentication_provider.py index 549f82af08..5e0d033cac 100644 --- a/api/admin/password_admin_authentication_provider.py +++ b/api/admin/password_admin_authentication_provider.py @@ -12,8 +12,8 @@ reset_password_template, sign_in_template, ) -from api.config import Configuration -from core.model import Admin, ConfigurationSetting +from core.model import Admin, Key +from core.model.key import KeyType from core.util.email import EmailManager from core.util.problem_detail import ProblemDetail @@ -33,6 +33,13 @@ class PasswordAdminAuthenticationProvider(AdminAuthenticationProvider): label=label_style, input=input_style, button=button_style ) + @staticmethod + def get_secret_key(db: Session) -> str: + key = Key.get_key(db, KeyType.ADMIN_SECRET_KEY, raise_exception=True).value + # We know .value is a str because its a non-null column in the DB, so + # we use an ignore to tell mypy to trust us. + return key # type: ignore[return-value] + def sign_in_template(self, redirect): password_sign_in_url = url_for("password_auth") forgot_password_url = url_for("admin_forgot_password") @@ -83,7 +90,7 @@ def active_credentials(self, admin): return True def generate_reset_password_token(self, admin: Admin, _db: Session) -> str: - secret_key = ConfigurationSetting.sitewide_secret(_db, Configuration.SECRET_KEY) + secret_key = self.get_secret_key(_db) reset_password_token = admin.generate_reset_password_token(secret_key) @@ -109,7 +116,7 @@ def send_reset_password_email(self, admin: Admin, reset_password_url: str) -> No def validate_token_and_extract_admin( self, reset_password_token: str, admin_id: int, _db: Session ) -> Admin | ProblemDetail: - secret_key = ConfigurationSetting.sitewide_secret(_db, Configuration.SECRET_KEY) + secret_key = self.get_secret_key(_db) return Admin.validate_reset_password_token_and_fetch_admin( reset_password_token, admin_id, _db, secret_key diff --git a/api/app.py b/api/app.py index 17a7c4615a..8044be2af3 100644 --- a/api/app.py +++ b/api/app.py @@ -18,7 +18,8 @@ ) from core.app_server import ErrorHandler from core.flask_sqlalchemy_session import flask_scoped_session -from core.model import ConfigurationSetting, SessionManager +from core.model import Key, SessionManager +from core.model.key import KeyType from core.service.container import Services, container_instance from core.util import LanguageCodes from core.util.cache import CachedData @@ -61,7 +62,9 @@ def initialize_admin(_db=None): setup_admin_controllers(app.manager) _db = _db or app._db # The secret key is used for signing cookies for admin login - app.secret_key = ConfigurationSetting.sitewide_secret(_db, Configuration.SECRET_KEY) + app.secret_key = Key.get_key( + _db, KeyType.ADMIN_SECRET_KEY, raise_exception=True + ).value def initialize_circulation_manager(container: Services): diff --git a/api/authentication/access_token.py b/api/authentication/access_token.py index ecc29df13d..5ffd1ac8d6 100644 --- a/api/authentication/access_token.py +++ b/api/authentication/access_token.py @@ -1,10 +1,9 @@ from __future__ import annotations -import logging -import time -from abc import ABC, abstractmethod +import uuid +from dataclasses import dataclass from datetime import timedelta -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from jwcrypto import jwe, jwk @@ -12,166 +11,181 @@ PATRON_AUTH_ACCESS_TOKEN_EXPIRED, PATRON_AUTH_ACCESS_TOKEN_INVALID, ) -from core.model.configuration import ConfigurationSetting +from core.model import Key +from core.model.key import KeyType from core.model.patron import Patron from core.util.datetime_helpers import utc_now -from core.util.problem_detail import ProblemDetail, ProblemError -from core.util.string_helpers import random_string +from core.util.log import LoggerMixin +from core.util.problem_detail import ProblemError +from core.util.uuid import uuid_encode if TYPE_CHECKING: from sqlalchemy.orm import Session -class PatronAccessTokenProvider(ABC): - """Provides access tokens for patron auth""" - - @classmethod - @abstractmethod - def generate_token( - cls, _db, patron: Patron, password: str, expires_in: int = 3600 - ) -> str: - ... - - @classmethod - @abstractmethod - def decode_token(cls, _db, token: str) -> dict | ProblemDetail: - ... - - @classmethod - @abstractmethod - def is_access_token(cls, token: str | None) -> bool: - ... +@dataclass +class TokenPatronInfo: + id: int + pwd: str -class PatronJWEAccessTokenProvider(PatronAccessTokenProvider): +class PatronJWEAccessTokenProvider(LoggerMixin): """Provide JWE based access tokens for patron auth""" - NAME = "Patron Access Token Provider" - KEY_NAME = "PATRON_JWE_KEY" + CTY = "pv1" @classmethod - def generate_key(cls) -> jwk.JWK: + def generate_jwk(cls, key_id: uuid.UUID) -> str: """Generate a new key compatible with the token encyption type""" - kid = random_string(16) - return jwk.JWK.generate(kty="oct", size=256, kid=kid) + kid = uuid_encode(key_id) + generated_key = jwk.JWK.generate(kty="oct", size=256, kid=kid) + return generated_key.export() @classmethod - def rotate_key(cls, _db: Session) -> jwk.JWK: - """Rotate the current JWK key in the DB""" - key = cls.generate_key() - setting = ConfigurationSetting.sitewide(_db, cls.KEY_NAME) - setting.value = key.export() + def create_key(cls, _db: Session) -> Key: + """Create a new key in the DB""" + key = Key.create_key(_db, KeyType.AUTH_TOKEN_JWE, cls.generate_jwk) return key @classmethod - def get_current_key( - cls, _db: Session, kid: str | None = None, create: bool = True - ) -> jwk.JWK | None: - """Get the current JWK key for the CM - :param kid: (Optional) If present, compare this value to the currently active kid, - raise a ValueError if found to be different - :param create: (Optional) Create a key of no key exists in the system - """ - stored_key = ConfigurationSetting.sitewide(_db, cls.KEY_NAME) - key: str | None = stored_key.value + def get_jwk(cls, key: Key) -> jwk.JWK: + """Get a JWK key from the DB""" + jwk_obj = jwk.JWK.from_json(key.value) + return jwk_obj - # First time run, we don't have a value yet - if key is None: - if create: - jwk_key = cls.rotate_key(_db) - else: - return None - else: - jwk_key = jwk.JWK.from_json(key) - - if kid is not None and kid != jwk_key.get("kid"): - raise ValueError( - "Current KID has changed, the key has probably been rotated" + @classmethod + def get_key(cls, _db: Session, key_id: str | uuid.UUID | None = None) -> Key: + """Get the most recently created AUTH_TOKEN_JWE key from the DB""" + key = Key.get_key( + _db, KeyType.AUTH_TOKEN_JWE, key_id=key_id, raise_exception=True + ) + if ( + key_id is None + and key.created is not None + and key.created < utc_now() - timedelta(days=2) + ): + cls.logger().warning( + "The most recently created AUTH_TOKEN_JWE key is more then two days old. " + "This may indicate a problem with the key rotation." ) - - return jwk_key + return key @classmethod def generate_token( cls, _db: Session, patron: Patron, password: str, expires_in: int = 3600 ) -> str: - """Generate a JWE token for a patron - :param patron: Generate a token for this patron - :param password: Encrypt this password within the token - :param expires_in: Seconds after which this token will expire - :return: A compacted JWE token - """ - key = cls.get_current_key(_db) - if not key: - raise RuntimeError("Could fetch the JWE key from the DB") - - payload = dict(id=patron.id, pwd=password, typ="patron") - + """Generate a JWE token for a patron""" + key = cls.get_key(_db) + jwk_obj = cls.get_jwk(key) token = jwe.JWE( - jwe.json_encode(payload), - dict( + plaintext=jwe.json_encode(dict(id=patron.id, pwd=password)), + protected=dict( alg="dir", - kid=key.get("kid"), + kid=uuid_encode(cast(uuid.UUID, key.id)), typ="JWE", enc="A128CBC-HS256", + cty=cls.CTY, exp=(utc_now() + timedelta(seconds=expires_in)).timestamp(), ), - recipient=key, + recipient=jwk_obj, ) return token.serialize(compact=True) @classmethod - def decode_token(cls, _db: Session, token: str) -> dict | ProblemDetail: + def decode_token(cls, token: str) -> jwe.JWE: """Decode the given token :param token: A serialized JWE token :return: The decrypted data dictionary from the token """ - jwe_token = cls._decode(token) + jwe_token = jwe.JWE() + + # Set the allowed algorithms + jwe_token.allowed_algs = ["dir", "A128CBC-HS256"] + + try: + jwe_token.deserialize(token) + except jwe.InvalidJWEData as ex: + cls.logger().exception(f"Invalid JWE data was encountered: {ex}") + raise ProblemError(problem_detail=PATRON_AUTH_ACCESS_TOKEN_INVALID) # Check expiry - exp = jwe.json_decode(jwe_token.objects["protected"])["exp"] - if time.time() > exp: - return PATRON_AUTH_ACCESS_TOKEN_EXPIRED + exp = jwe_token.jose_header.get("exp") + if exp is None or utc_now().timestamp() > exp: + raise ProblemError(problem_detail=PATRON_AUTH_ACCESS_TOKEN_EXPIRED) + + # Make sure there is a kid + kid = jwe_token.jose_header.get("kid") + if kid is None: + raise ProblemError(problem_detail=PATRON_AUTH_ACCESS_TOKEN_INVALID) + + # Make sure we have the token type + typ = jwe_token.jose_header.get("typ") + if typ != "JWE": + raise ProblemError(problem_detail=PATRON_AUTH_ACCESS_TOKEN_INVALID) + # Make sure we have the payload type + cty = jwe_token.jose_header.get("cty") + if cty != cls.CTY: + raise ProblemError(problem_detail=PATRON_AUTH_ACCESS_TOKEN_INVALID) + + return jwe_token + + @classmethod + def decrypt_token(cls, _db: Session, token: jwe.JWE | str) -> TokenPatronInfo: + if isinstance(token, str): + token = cls.decode_token(token) + + kid = token.jose_header.get("kid") try: - key = cls.get_current_key(_db, jwe_token.jose_header.get("kid")) + key = cls.get_key(_db, kid) except ValueError: - # The kid was incorrect, the key has probably rotated - return PATRON_AUTH_ACCESS_TOKEN_EXPIRED + key = None + + if key is None: + raise ProblemError(problem_detail=PATRON_AUTH_ACCESS_TOKEN_INVALID) try: - jwe_token.decrypt(key) + token.decrypt(cls.get_jwk(key)) except jwe.InvalidJWEData: - return PATRON_AUTH_ACCESS_TOKEN_INVALID - - return jwe.json_decode(jwe_token.payload) + raise ProblemError(problem_detail=PATRON_AUTH_ACCESS_TOKEN_INVALID) - @classmethod - def _decode(cls, token: str) -> jwe.JWE: - """Decode a JWE token without decryption""" try: - jwe_token = jwe.JWE.from_jose_token(token) - except jwe.InvalidJWEData as ex: - logging.getLogger(cls.__name__).error( - f"Invalid JWE data was encountered: {ex}" - ) - raise ProblemError(PATRON_AUTH_ACCESS_TOKEN_INVALID) - return jwe_token + payload = jwe.json_decode(token.payload) + except ValueError: + raise ProblemError(problem_detail=PATRON_AUTH_ACCESS_TOKEN_INVALID) + + # Validate the payload + if ( + not isinstance(payload, dict) + or "id" not in payload + or "pwd" not in payload + or len(payload) != 2 + ): + raise ProblemError(problem_detail=PATRON_AUTH_ACCESS_TOKEN_INVALID) + + return TokenPatronInfo(**payload) @classmethod def is_access_token(cls, token: str | None) -> bool: """Test if the given token is a valid JWE token""" - try: - jwe_token = cls._decode(token) if token else None - except Exception: + if token is None: return False - if jwe_token is None: - return False - if jwe.json_decode(jwe_token.objects["protected"])["typ"] != "JWE": + try: + cls.decode_token(token) + except Exception: return False return True + @classmethod + def delete_old_keys(cls, _db: Session) -> int: + """Delete old keys from the DB -AccessTokenProvider: type[PatronAccessTokenProvider] = PatronJWEAccessTokenProvider + We keep the two most recent keys in the DB. And delete any keys with a created date older than + two days. + """ + two_days_ago = utc_now() - timedelta(days=2) + return Key.delete_old_keys( + _db, KeyType.AUTH_TOKEN_JWE, keep=2, older_than=two_days_ago + ) diff --git a/api/authentication/basic_token.py b/api/authentication/basic_token.py index 9197f078f3..2819e06945 100644 --- a/api/authentication/basic_token.py +++ b/api/authentication/basic_token.py @@ -7,14 +7,13 @@ from sqlalchemy.orm import Session from werkzeug.datastructures import Authorization -from api.authentication.access_token import AccessTokenProvider +from api.authentication.access_token import PatronJWEAccessTokenProvider from api.authentication.base import ( AuthenticationProvider, AuthProviderLibrarySettings, AuthProviderSettings, ) from api.authentication.basic import BasicAuthenticationProvider -from api.problem_details import PATRON_AUTH_ACCESS_TOKEN_INVALID from core.integration.base import LibrarySettingsType, SettingsType from core.model import Patron, Session, get_one from core.selftest import SelfTestResult @@ -61,29 +60,15 @@ def authenticated_patron( ) -> Patron | ProblemDetail | None: """Authenticate the patron by decoding the JWE token and fetching the patron from the DB based on the patron ID""" - if type(token) is not str: + if not isinstance(token, str): return None try: - data = AccessTokenProvider.decode_token(_db, token) + data = PatronJWEAccessTokenProvider.decrypt_token(_db, token) except ProblemError as ex: - data = ex.problem_detail + return ex.problem_detail - if type(data) == ProblemDetail: - return data - - # This exists because of mypy - assert type(data) is dict - - try: - patron_id = data["id"] - # Ensure the password exists - if "pwd" not in data: - return PATRON_AUTH_ACCESS_TOKEN_INVALID - except KeyError: - return PATRON_AUTH_ACCESS_TOKEN_INVALID - - patron: Patron | None = get_one(_db, Patron, id=patron_id) + patron: Patron | None = get_one(_db, Patron, id=data.id) if patron is None: return None @@ -95,11 +80,10 @@ def get_credential_from_header(self, auth: Authorization) -> str | None: auth and auth.type.lower() == "bearer" and auth.token - and AccessTokenProvider.is_access_token(auth.token) + and PatronJWEAccessTokenProvider.is_access_token(auth.token) ): - token = AccessTokenProvider.decode_token(self._db, auth.token) - if type(token) == dict: - return token.get("pwd") + token = PatronJWEAccessTokenProvider.decrypt_token(self._db, auth.token) + return token.pwd return None diff --git a/api/authenticator.py b/api/authenticator.py index 53fa792271..b0803a7183 100644 --- a/api/authenticator.py +++ b/api/authenticator.py @@ -16,7 +16,7 @@ from api.adobe_vendor_id import AuthdataUtility from api.annotations import AnnotationWriter -from api.authentication.access_token import AccessTokenProvider +from api.authentication.access_token import PatronJWEAccessTokenProvider from api.authentication.base import ( AuthenticationProvider, LibrarySettingsType, @@ -30,9 +30,10 @@ from core.analytics import Analytics from core.integration.goals import Goals from core.integration.registry import IntegrationRegistry -from core.model import ConfigurationSetting, Library, Patron, PatronProfileStorage +from core.model import Key, Library, Patron, PatronProfileStorage from core.model.announcements import Announcement from core.model.integration import IntegrationLibraryConfiguration +from core.model.key import KeyType from core.user_profile import ProfileController from core.util.authentication_for_opds import AuthenticationForOPDSDocument from core.util.http import RemoteIntegrationException @@ -188,17 +189,6 @@ def from_config( (integration.parent.id, library.id) ] = e - if authenticator.saml_providers_by_name: - # NOTE: this will immediately commit the database session, - # which may not be what you want during a test. To avoid - # this, you can create the bearer token signing secret as - # a regular site-wide ConfigurationSetting. - authenticator.bearer_token_signing_secret = ( - BearerTokenSigner.bearer_token_signing_secret(_db) - ) - - authenticator.assert_ready_for_token_signing() - return authenticator def __init__( @@ -241,7 +231,12 @@ def __init__( ) self.saml_providers_by_name = {} - self.bearer_token_signing_secret = bearer_token_signing_secret + self.bearer_token_signing_secret = ( + bearer_token_signing_secret + or Key.get_key( + _db, KeyType.BEARER_TOKEN_SIGNING, raise_exception=True + ).value + ) self.initialization_exceptions: dict[ tuple[int | None, int | None], Exception ] = {} @@ -257,8 +252,6 @@ def __init__( for provider in saml_providers: self.saml_providers_by_name[provider.label()] = provider - self.assert_ready_for_token_signing() - @property def supports_patron_authentication(self) -> bool: """Does this library have any way of authenticating patrons at all?""" @@ -289,17 +282,6 @@ def library(self) -> Library | None: return None return Library.by_id(self._db, self.library_id) - def assert_ready_for_token_signing(self): - """If this LibraryAuthenticator has SAML providers, ensure that it - also has a secret it can use to sign bearer tokens. - """ - if self.saml_providers_by_name and not self.bearer_token_signing_secret: - raise CannotLoadConfiguration( - _( - "SAML providers are configured, but secret for signing bearer tokens is not." - ) - ) - def register_provider( self, integration: IntegrationLibraryConfiguration, @@ -453,7 +435,7 @@ def authenticated_patron( if ( self.access_token_authentication_provider - and AccessTokenProvider.is_access_token(auth.token) + and PatronJWEAccessTokenProvider.is_access_token(auth.token) ): provider = self.access_token_authentication_provider provider_token = auth.token @@ -795,28 +777,8 @@ def create_authentication_headers(self) -> Headers: return headers -class BearerTokenSigner: - """Mixin class used for storing a secret used for signing Bearer tokens""" - - # Name of the site-wide ConfigurationSetting containing the secret - # used to sign bearer tokens. - BEARER_TOKEN_SIGNING_SECRET = Configuration.BEARER_TOKEN_SIGNING_SECRET - - @classmethod - def bearer_token_signing_secret(cls, db): - """Find or generate the site-wide bearer token signing secret. - - :param db: Database session - :type db: sqlalchemy.orm.session.Session - - :return: ConfigurationSetting object containing the signing secret - :rtype: ConfigurationSetting - """ - return ConfigurationSetting.sitewide_secret(db, cls.BEARER_TOKEN_SIGNING_SECRET) - - class BaseSAMLAuthenticationProvider( - AuthenticationProvider[SettingsType, LibrarySettingsType], BearerTokenSigner, ABC + AuthenticationProvider[SettingsType, LibrarySettingsType], ABC ): """ Base class for SAML authentication providers diff --git a/api/controller/patron_auth_token.py b/api/controller/patron_auth_token.py index 9c5010f426..920f8dc2b3 100644 --- a/api/controller/patron_auth_token.py +++ b/api/controller/patron_auth_token.py @@ -1,17 +1,16 @@ from __future__ import annotations -import logging - import flask -from api.authentication.access_token import AccessTokenProvider +from api.authentication.access_token import PatronJWEAccessTokenProvider from api.controller.circulation_manager import CirculationManagerController from api.model.patron_auth import PatronAuthAccessToken from api.problem_details import PATRON_AUTH_ACCESS_TOKEN_NOT_POSSIBLE +from core.util.log import LoggerMixin from core.util.problem_detail import ProblemError -class PatronAuthTokenController(CirculationManagerController): +class PatronAuthTokenController(CirculationManagerController, LoggerMixin): def get_token(self): """Create a Patron Auth access token for an authenticated patron""" patron = flask.request.patron @@ -22,16 +21,14 @@ def get_token(self): return PATRON_AUTH_ACCESS_TOKEN_NOT_POSSIBLE try: - token = AccessTokenProvider.generate_token( + token = PatronJWEAccessTokenProvider.generate_token( self._db, patron, auth["password"], expires_in=token_expiry, ) except ProblemError as ex: - logging.getLogger(self.__class__.__name__).error( - f"Could not generate Patron Auth Access Token: {ex}" - ) + self.log.error(f"Could not generate Patron Auth Access Token: {ex}") return ex.problem_detail return PatronAuthAccessToken( diff --git a/core/jobs/rotate_jwe_key.py b/core/jobs/rotate_jwe_key.py index 22ea036969..a35db19532 100644 --- a/core/jobs/rotate_jwe_key.py +++ b/core/jobs/rotate_jwe_key.py @@ -1,15 +1,22 @@ -from api.authentication.access_token import AccessTokenProvider +from api.authentication.access_token import PatronJWEAccessTokenProvider from core.scripts import Script class RotateJWEKeyScript(Script): - def do_run(self): - current = AccessTokenProvider.get_current_key(self._db, create=False) - self.log.info( - f"Rotating out key {current and current.get('kid')}: {current and current.thumbprint()}" - ) - - new_key = AccessTokenProvider.rotate_key(self._db) - self.log.info(f"Rotated new key {new_key.get('kid')}: {new_key.thumbprint()}") + def do_run(self) -> None: + try: + current = PatronJWEAccessTokenProvider.get_key(self._db) + jwk = PatronJWEAccessTokenProvider.get_jwk(current) + self.log.info(f"Rotating out key {current.id}: {jwk.thumbprint()}") + except ValueError: + self.log.info("No current key found") + + new_key = PatronJWEAccessTokenProvider.create_key(self._db) + new_jwk = PatronJWEAccessTokenProvider.get_jwk(new_key) + self.log.info(f"Rotated in key {new_key.id}: {new_jwk.thumbprint()}") + + # Remove old / expired keys + removed = PatronJWEAccessTokenProvider.delete_old_keys(self._db) + self.log.info(f"Removed {removed} expired keys") self._db.commit() diff --git a/core/model/__init__.py b/core/model/__init__.py index 8d37b88be4..b7abf4822a 100644 --- a/core/model/__init__.py +++ b/core/model/__init__.py @@ -4,7 +4,7 @@ import logging import os from collections.abc import Generator -from typing import Any, List, Literal, Tuple, Type, TypeVar, Union +from typing import Any, Literal, TypeVar from contextlib2 import contextmanager from psycopg2.extensions import adapt as sqlescape @@ -12,7 +12,7 @@ from pydantic.json import pydantic_encoder from sqlalchemy import create_engine from sqlalchemy.engine import Connection -from sqlalchemy.exc import DatabaseError, IntegrityError +from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound @@ -438,6 +438,13 @@ def initialize_data(cls, session: Session): ) mechanism.default_client_can_fulfill = True + from api.authentication.access_token import PatronJWEAccessTokenProvider + + # Create our secret keys + Key.create_admin_secret_key(session) + Key.create_bearer_token_signing_key(session) + PatronJWEAccessTokenProvider.create_key(session) + # If there is currently no 'site configuration change' # Timestamp in the database, create one. timestamp, is_new = get_one_or_create( @@ -539,6 +546,7 @@ def _bulk_operation(self): IntegrationConfiguration, IntegrationLibraryConfiguration, ) +from core.model.key import Key from core.model.library import Library from core.model.licensing import ( DeliveryMechanism, diff --git a/core/model/key.py b/core/model/key.py new file mode 100644 index 0000000000..368a352689 --- /dev/null +++ b/core/model/key.py @@ -0,0 +1,156 @@ +import datetime +import uuid +from collections.abc import Callable +from enum import Enum +from typing import Literal, overload + +from sqlalchemy import Column, DateTime +from sqlalchemy import Enum as SaEnum +from sqlalchemy import Unicode, delete, select +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Session +from typing_extensions import Self + +from core.model import Base +from core.util.datetime_helpers import utc_now +from core.util.string_helpers import random_key +from core.util.uuid import uuid_decode + + +class KeyType(Enum): + AUTH_TOKEN_JWE = "auth_token" + BEARER_TOKEN_SIGNING = "bearer_token" + ADMIN_SECRET_KEY = "admin_auth" + + +class Key(Base): + __tablename__ = "keys" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + created = Column( + DateTime(timezone=True), index=True, nullable=False, default=utc_now + ) + value = Column(Unicode, nullable=False) + type = Column(SaEnum(KeyType), nullable=False, index=True) + + def __repr__(self) -> str: + return f"" + + @classmethod + @overload + def get_key( + cls, + db: Session, + key_type: KeyType, + key_id: str | uuid.UUID | None = None, + *, + raise_exception: Literal[True] = True, + ) -> Self: + ... + + @classmethod + @overload + def get_key( + cls, + db: Session, + key_type: KeyType, + key_id: str | uuid.UUID | None = None, + *, + raise_exception: bool = False, + ) -> Self | None: + ... + + @classmethod + def get_key( + cls, + db: Session, + key_type: KeyType, + key_id: str | uuid.UUID | None = None, + *, + raise_exception: bool = False, + ) -> Self | None: + """Get a key from the DB""" + key_query = select(Key).where(Key.type == key_type).order_by(Key.created.desc()) + + if key_id is not None: + decoded_kid = uuid_decode(key_id) if isinstance(key_id, str) else key_id + key_query = key_query.where(Key.id == decoded_kid) + + result_key = db.scalars(key_query).first() + if result_key is None and raise_exception: + raise ValueError(f"No key found in the database with type {key_type}") + + return result_key + + @classmethod + def create_key( + cls, db: Session, key_type: KeyType, create: Callable[[uuid.UUID], str] + ) -> Self: + """Create a new key in the DB""" + key_id = uuid.uuid4() + value = create(key_id) + + key = cls(id=key_id, value=value, type=key_type) + db.add(key) + db.flush() + return key + + @classmethod + def create_admin_secret_key(cls, db: Session) -> Self: + """Create a new admin secret key in the DB""" + # If we already have an admin secret key, we should not create a new one + existing_key = cls.get_key(db, KeyType.ADMIN_SECRET_KEY) + if existing_key: + return existing_key + + return cls.create_key(db, KeyType.ADMIN_SECRET_KEY, lambda _: random_key(48)) + + @classmethod + def create_bearer_token_signing_key(cls, db: Session) -> Self: + """Create a new admin secret key in the DB""" + # If we already have a bearer token signing key, we should not create a new one + existing_key = cls.get_key(db, KeyType.BEARER_TOKEN_SIGNING) + if existing_key: + return existing_key + + return cls.create_key( + db, KeyType.BEARER_TOKEN_SIGNING, lambda _: random_key(48) + ) + + @classmethod + def delete_old_keys( + cls, db: Session, key_type: KeyType, keep: int, older_than: datetime.datetime + ) -> int: + """ + Delete old keys from the DB + """ + if keep < 0: + raise ValueError("keep must be a non-negative integer") + + if keep > 0: + ids_to_keep = [ + row.id + for row in db.execute( + select(cls.id) + .where(cls.type == key_type) + .order_by(cls.created.desc()) + .limit(keep) + ) + ] + else: + ids_to_keep = [] + + delete_query = ( + delete(cls).where(cls.type == key_type).where(cls.created < older_than) + ) + + if ids_to_keep: + delete_query = delete_query.where(cls.id.notin_(ids_to_keep)) + + result = db.execute(delete_query) + # mypy doesn't recognize the rowcount attribute on the CursorResult + # since db.execute doesn't always return a CursorResult. + # The sqlalchemy docs say that the rowcount attribute is always present + # when doing a DELETE statement, so we can safely ignore this error. + # https://docs.sqlalchemy.org/en/20/tutorial/data_update.html#getting-affected-row-count-from-update-delete + return result.rowcount # type: ignore[attr-defined] diff --git a/core/util/string_helpers.py b/core/util/string_helpers.py index 3ef012f248..dfa05e31eb 100644 --- a/core/util/string_helpers.py +++ b/core/util/string_helpers.py @@ -3,6 +3,8 @@ # bytestrings. import binascii import os +import secrets +import string def random_string(size: int) -> str: @@ -12,3 +14,12 @@ def random_string(size: int) -> str: :return: A Unicode string. """ return binascii.hexlify(os.urandom(size)).decode("utf8") + + +def random_key(size: int) -> str: + """Generate a random string suitable for use as a key. + + :param: Size of the key to generate. + :return: A Unicode string. + """ + return "".join(secrets.choice(string.printable) for i in range(size)) diff --git a/core/util/uuid.py b/core/util/uuid.py index 07d81774a8..92b75e1a99 100644 --- a/core/util/uuid.py +++ b/core/util/uuid.py @@ -18,8 +18,16 @@ def uuid_decode(encoded: str) -> UUID: """ Decode a URL-safe base64 string to a UUID. Reverse of uuid_encode. """ - if len(encoded) != 22: - raise ValueError("Invalid base64 string for UUID") - padding = "==" - decoded_bytes = urlsafe_b64decode(encoded + padding) - return UUID(bytes=decoded_bytes) + if len(encoded) == 22: + # This looks like an encoded UUID, so add padding and try to decode it + padding = "==" + decoded_bytes = urlsafe_b64decode(encoded + padding) + return UUID(bytes=decoded_bytes) + + # See if this is a normal UUID hex string + encoded = encoded.replace("urn:", "").replace("uuid:", "") + encoded = encoded.strip("{}").replace("-", "") + if len(encoded) == 32: + return UUID(hex=encoded) + + raise ValueError("Invalid string for UUID") diff --git a/pyproject.toml b/pyproject.toml index 95728cd934..cf2047142d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,7 @@ module = [ "api.opds_for_distributors", "core.feed.*", "core.integration.*", + "core.jobs.rotate_jwe_key", "core.marc", "core.migration.*", "core.model.announcements", diff --git a/scripts.py b/scripts.py index 77265cf924..01b6b61750 100644 --- a/scripts.py +++ b/scripts.py @@ -524,9 +524,6 @@ def initialize_database(self, connection: Connection) -> None: # Initialize the database with default data SessionManager.initialize_data(session) - # Create a secret key if one doesn't already exist. - ConfigurationSetting.sitewide_secret(session, Configuration.SECRET_KEY) - # Stamp the most recent migration as the current state of the DB alembic_conf = self._get_alembic_config(connection) command.stamp(alembic_conf, "head") diff --git a/tests/api/authentication/__init__.py b/tests/api/authentication/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/api/authentication/test_access_token.py b/tests/api/authentication/test_access_token.py new file mode 100644 index 0000000000..71155e4667 --- /dev/null +++ b/tests/api/authentication/test_access_token.py @@ -0,0 +1,299 @@ +import base64 +import functools +import json +import uuid +from datetime import datetime, timedelta +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest +from _pytest.logging import LogCaptureFixture +from dateutil import tz +from freezegun import freeze_time +from jwcrypto import jwe, jwk +from sqlalchemy import delete + +from api.authentication.access_token import PatronJWEAccessTokenProvider +from api.problem_details import ( + PATRON_AUTH_ACCESS_TOKEN_EXPIRED, + PATRON_AUTH_ACCESS_TOKEN_INVALID, +) +from core.model.key import Key, KeyType +from core.util.datetime_helpers import utc_now +from core.util.problem_detail import ProblemError +from core.util.uuid import uuid_encode +from tests.fixtures.database import DatabaseTransactionFixture + + +class JWEProviderFixture: + def __init__(self, db: DatabaseTransactionFixture): + self.db = db + self.patron = db.patron() + self.generate_token = functools.partial( + PatronJWEAccessTokenProvider.generate_token, + self.db.session, + patron=self.patron, + password="password", + ) + self.key = PatronJWEAccessTokenProvider.create_key(self.db.session) + self.jwk = PatronJWEAccessTokenProvider.get_jwk(self.key) + assert self.key.id is not None + self.kid = uuid_encode(self.key.id) + self.one_hour_in_future = utc_now() + timedelta(hours=1, minutes=1) + self.one_hour_ago = utc_now() - timedelta(hours=1, minutes=1) + + def create_token(self, plaintext: str = "blah blah", **kwargs: Any) -> str: + headers: dict[str, Any] = { + "alg": "dir", + "enc": "A128CBC-HS256", + } + + if "kid" not in kwargs: + headers["kid"] = self.kid + elif kwargs["kid"] is not None: + headers["kid"] = kwargs["kid"] + + if "typ" not in kwargs: + headers["typ"] = "JWE" + elif kwargs["typ"] is not None: + headers["typ"] = kwargs["typ"] + + if "cty" not in kwargs: + headers["cty"] = PatronJWEAccessTokenProvider.CTY + elif kwargs["cty"] is not None: + headers["cty"] = kwargs["cty"] + + if "exp" not in kwargs: + headers["exp"] = self.one_hour_in_future.timestamp() + elif kwargs["exp"] is not None: + headers["exp"] = kwargs["exp"] + + token = jwe.JWE( + plaintext=plaintext, + protected=headers, + recipient=self.jwk, + ) + return token.serialize(compact=True) + + +@pytest.fixture +def jwe_provider(db: DatabaseTransactionFixture) -> JWEProviderFixture: + return JWEProviderFixture(db) + + +class TestJWEProvider: + def test_generate_jwk(self): + _id = uuid.uuid4() + key = PatronJWEAccessTokenProvider.generate_jwk(_id) + assert isinstance(key, str) + jwk_key = jwk.JWK.from_json(key) + assert jwk_key.get("kty") == "oct" + assert jwk_key.get("kid") == uuid_encode(_id) + + @freeze_time("1990-05-05") + def test_create_key(self, db: DatabaseTransactionFixture): + key = PatronJWEAccessTokenProvider.create_key(db.session) + assert key.created == utc_now() + assert isinstance(key.id, uuid.UUID) + jwk_key = PatronJWEAccessTokenProvider.get_jwk(key) + assert isinstance(jwk_key, jwk.JWK) + + def test_get_key(self, db: DatabaseTransactionFixture, caplog: LogCaptureFixture): + # Remove any existing keys before running tests + db.session.execute(delete(Key).where(Key.type == KeyType.AUTH_TOKEN_JWE)) + + # If no key exists, we raise an exception + with pytest.raises(ValueError): + PatronJWEAccessTokenProvider.get_key(db.session) + + key = PatronJWEAccessTokenProvider.create_key(db.session) + + # If a key exists, it should return it + assert PatronJWEAccessTokenProvider.get_key(db.session) == key + + # If a key exists, but it's too old, it should return it and log a warning + key.created = utc_now() - timedelta(days=3) + assert PatronJWEAccessTokenProvider.get_key(db.session) == key + assert ( + "The most recently created AUTH_TOKEN_JWE key is more then two days old" + in caplog.text + ) + + # If multiple keys exist, it should return the most recent one + key2 = PatronJWEAccessTokenProvider.create_key(db.session) + assert PatronJWEAccessTokenProvider.get_key(db.session) == key2 + + # If a key id is passed in, it should return that key + assert PatronJWEAccessTokenProvider.get_key(db.session, key.id) == key + + def test_generate_token(self, jwe_provider: JWEProviderFixture): + t = utc_now() + with freeze_time(t): + token = jwe_provider.generate_token() + _header, _, _ = token.partition(".") + header = json.loads(base64.b64decode(_header + "===")) + assert datetime.fromtimestamp(header["exp"], tz=tz.tzutc()) == t + timedelta( + hours=1 + ) + assert header["typ"] == "JWE" + assert header["kid"] == jwe_provider.kid + + def test_decode_token(self, jwe_provider: JWEProviderFixture): + token = jwe_provider.generate_token() + decoded = PatronJWEAccessTokenProvider.decode_token(token) + assert isinstance(decoded, jwe.JWE) + assert decoded.allowed_algs == ["dir", "A128CBC-HS256"] + + def test_decode_token_errors(self, jwe_provider: JWEProviderFixture): + # Completely invalid token + with pytest.raises(ProblemError) as exc: + PatronJWEAccessTokenProvider.decode_token("not-a-token") + assert exc.value.problem_detail == PATRON_AUTH_ACCESS_TOKEN_INVALID + + # Expired token + with freeze_time(jwe_provider.one_hour_ago): + token = jwe_provider.generate_token() + with pytest.raises(ProblemError) as exc: + PatronJWEAccessTokenProvider.decode_token(token) + assert exc.value.problem_detail == PATRON_AUTH_ACCESS_TOKEN_EXPIRED + + # Token with no exp + token = jwe_provider.create_token(exp=None) + with pytest.raises(ProblemError) as exc: + PatronJWEAccessTokenProvider.decode_token(token) + assert exc.value.problem_detail == PATRON_AUTH_ACCESS_TOKEN_EXPIRED + + # Token with no kid + token = jwe_provider.create_token(kid=None) + with pytest.raises(ProblemError) as exc: + PatronJWEAccessTokenProvider.decode_token(token) + assert exc.value.problem_detail == PATRON_AUTH_ACCESS_TOKEN_INVALID + + # Token with no typ + token = jwe_provider.create_token(typ=None) + with pytest.raises(ProblemError) as exc: + PatronJWEAccessTokenProvider.decode_token(token) + assert exc.value.problem_detail == PATRON_AUTH_ACCESS_TOKEN_INVALID + + # Token with wrong typ + token = jwe_provider.create_token(typ="foo") + with pytest.raises(ProblemError) as exc: + PatronJWEAccessTokenProvider.decode_token(token) + assert exc.value.problem_detail == PATRON_AUTH_ACCESS_TOKEN_INVALID + + # Token with no cty + token = jwe_provider.create_token(cty=None) + with pytest.raises(ProblemError) as exc: + PatronJWEAccessTokenProvider.decode_token(token) + assert exc.value.problem_detail == PATRON_AUTH_ACCESS_TOKEN_INVALID + + # Token with wrong cty + token = jwe_provider.create_token(cty="foo") + with pytest.raises(ProblemError) as exc: + PatronJWEAccessTokenProvider.decode_token(token) + assert exc.value.problem_detail == PATRON_AUTH_ACCESS_TOKEN_INVALID + + def test_decrypt_token( + self, db: DatabaseTransactionFixture, jwe_provider: JWEProviderFixture + ): + token = jwe_provider.generate_token() + decoded = PatronJWEAccessTokenProvider.decode_token(token) + decrypted = PatronJWEAccessTokenProvider.decrypt_token(db.session, decoded) + assert decrypted.id == jwe_provider.patron.id + assert decrypted.pwd == "password" + + # Decrypt can also directly take a token string + decrypted = PatronJWEAccessTokenProvider.decrypt_token(db.session, token) + assert decrypted.id == jwe_provider.patron.id + assert decrypted.pwd == "password" + + def test_decrypt_token_bad_key( + self, db: DatabaseTransactionFixture, jwe_provider: JWEProviderFixture + ): + token = jwe_provider.generate_token() + decoded = PatronJWEAccessTokenProvider.decode_token(token) + + # No key + db.session.delete(jwe_provider.key) + with pytest.raises(ProblemError) as exc: + PatronJWEAccessTokenProvider.decrypt_token(db.session, decoded) + assert exc.value.problem_detail == PATRON_AUTH_ACCESS_TOKEN_INVALID + + def test_decrypt_token_errors( + self, db: DatabaseTransactionFixture, jwe_provider: JWEProviderFixture + ): + # Bad kid + token = jwe_provider.create_token(kid="fake") + with pytest.raises(ProblemError) as exc: + PatronJWEAccessTokenProvider.decrypt_token(db.session, token) + assert exc.value.problem_detail == PATRON_AUTH_ACCESS_TOKEN_INVALID + + # Invalid token - Bad tag + token = jwe_provider.create_token() + "B" + with pytest.raises(ProblemError) as exc: + PatronJWEAccessTokenProvider.decrypt_token(db.session, token) + assert exc.value.problem_detail == PATRON_AUTH_ACCESS_TOKEN_INVALID + + # Invalid token - Bad enc type + token = jwe_provider.create_token(enc="A256GCM") + with pytest.raises(ProblemError) as exc: + PatronJWEAccessTokenProvider.decrypt_token(db.session, token) + assert exc.value.problem_detail == PATRON_AUTH_ACCESS_TOKEN_INVALID + + # Invalid payload - not json + token = jwe_provider.create_token(plaintext="not-json") + with pytest.raises(ProblemError) as exc: + PatronJWEAccessTokenProvider.decrypt_token(db.session, token) + assert exc.value.problem_detail == PATRON_AUTH_ACCESS_TOKEN_INVALID + + # Invalid payload - missing keys + token = jwe_provider.create_token(plaintext="{}") + with pytest.raises(ProblemError) as exc: + PatronJWEAccessTokenProvider.decrypt_token(db.session, token) + assert exc.value.problem_detail == PATRON_AUTH_ACCESS_TOKEN_INVALID + + # Invalid payload - missing id + token = jwe_provider.create_token(plaintext=json.dumps({"pwd": "password"})) + with pytest.raises(ProblemError) as exc: + PatronJWEAccessTokenProvider.decrypt_token(db.session, token) + assert exc.value.problem_detail == PATRON_AUTH_ACCESS_TOKEN_INVALID + + # Invalid payload - missing pwd + token = jwe_provider.create_token(plaintext=json.dumps({"id": "1234"})) + with pytest.raises(ProblemError) as exc: + PatronJWEAccessTokenProvider.decrypt_token(db.session, token) + assert exc.value.problem_detail == PATRON_AUTH_ACCESS_TOKEN_INVALID + + # Invalid payload - extra keys + token = jwe_provider.create_token( + plaintext=json.dumps({"id": "1234", "pwd": "password", "extra": "key"}) + ) + with pytest.raises(ProblemError) as exc: + PatronJWEAccessTokenProvider.decrypt_token(db.session, token) + assert exc.value.problem_detail == PATRON_AUTH_ACCESS_TOKEN_INVALID + + def test_is_access_token(self, jwe_provider: JWEProviderFixture): + # Happy path + token = jwe_provider.generate_token() + assert PatronJWEAccessTokenProvider.is_access_token(token) is True + + with patch.object(PatronJWEAccessTokenProvider, "decode_token") as decode: + # An incorrect type + decode.side_effect = Exception("Bang!") + assert PatronJWEAccessTokenProvider.is_access_token(token) is False + + # The token is not the right format + assert PatronJWEAccessTokenProvider.is_access_token("not-a-token") is False + + @freeze_time() + def test_delete_old_keys(self): + mock_session = MagicMock() + with patch("api.authentication.access_token.Key") as mock_key: + PatronJWEAccessTokenProvider.delete_old_keys(mock_session) + + mock_key.delete_old_keys.assert_called_once_with( + mock_session, + KeyType.AUTH_TOKEN_JWE, + keep=2, + older_than=utc_now() - timedelta(days=2), + ) diff --git a/tests/api/test_basic_token_authentication_provider.py b/tests/api/authentication/test_basic_token.py similarity index 74% rename from tests/api/test_basic_token_authentication_provider.py rename to tests/api/authentication/test_basic_token.py index 535e3abebd..a1cbf0ce3e 100644 --- a/tests/api/test_basic_token_authentication_provider.py +++ b/tests/api/authentication/test_basic_token.py @@ -2,7 +2,10 @@ from werkzeug.datastructures import Authorization -from api.authentication.access_token import AccessTokenProvider +from api.authentication.access_token import ( + PatronJWEAccessTokenProvider, + TokenPatronInfo, +) from api.authentication.basic import BasicAuthenticationProvider from api.authentication.basic_token import BasicTokenAuthenticationProvider from api.problem_details import PATRON_AUTH_ACCESS_TOKEN_INVALID @@ -18,32 +21,22 @@ def test_authenticated_patron(self, db: DatabaseTransactionFixture): db.session, db.default_library(), Mock() ) with patch( - "api.authentication.basic_token.AccessTokenProvider" + "api.authentication.basic_token.PatronJWEAccessTokenProvider" ) as token_provider: - token_provider.decode_token.return_value = dict( + assert isinstance(patron.id, int) + token_provider.decrypt_token.return_value = TokenPatronInfo( id=patron.id, pwd="password" ) got_patron = provider.authenticated_patron(db.session, "token-string") - assert type(got_patron) is Patron + assert isinstance(got_patron, Patron) assert got_patron.id == patron.id - # Any incorrect data would mean an invalid token - token_provider.decode_token.return_value = dict(id=patron.id, typ="patron") - assert PATRON_AUTH_ACCESS_TOKEN_INVALID == provider.authenticated_patron( - db.session, "token-string" - ) - - token_provider.decode_token.return_value = dict(pwd="password") - assert PATRON_AUTH_ACCESS_TOKEN_INVALID == provider.authenticated_patron( - db.session, "token-string" - ) - # Nonexistent patron - token_provider.decode_token.return_value = dict( + token_provider.decrypt_token.return_value = TokenPatronInfo( id=999999999, pwd="password" ) - assert None == provider.authenticated_patron(db.session, "token-string") + assert provider.authenticated_patron(db.session, "token-string") is None def test_authenticated_patron_errors(self, db: DatabaseTransactionFixture): provider = BasicTokenAuthenticationProvider( @@ -62,7 +55,9 @@ def test_credential_from_header(self, db: DatabaseTransactionFixture): db.session, db.default_library(), Mock() ) patron = db.patron() - token = AccessTokenProvider.generate_token(db.session, patron, "passworx") + token = PatronJWEAccessTokenProvider.generate_token( + db.session, patron, "passworx" + ) pwd = provider.get_credential_from_header( Authorization(auth_type="Bearer", token=token) diff --git a/tests/api/test_authenticator.py b/tests/api/test_authenticator.py index 62cc648e7d..458699b682 100644 --- a/tests/api/test_authenticator.py +++ b/tests/api/test_authenticator.py @@ -22,7 +22,7 @@ from werkzeug.datastructures import Authorization from api.annotations import AnnotationWriter -from api.authentication.access_token import AccessTokenProvider +from api.authentication.access_token import PatronJWEAccessTokenProvider from api.authentication.base import PatronData from api.authentication.basic import ( BarcodeFormats, @@ -924,7 +924,7 @@ def test_authenticated_patron_bearer_access_token( assert patron_lookup_provider == basic_auth_provider patron = db.patron() - token = AccessTokenProvider.generate_token(db.session, patron, "pass") + token = PatronJWEAccessTokenProvider.generate_token(db.session, patron, "pass") auth = Authorization(auth_type="bearer", token=token) auth_patron = authenticator.authenticated_patron(db.session, auth) @@ -975,7 +975,9 @@ def get_library_authenticator( authenticator = get_library_authenticator(basic_auth_provider=basic) patron = db.patron() - token = AccessTokenProvider.generate_token(db.session, patron, "passworx") + token = PatronJWEAccessTokenProvider.generate_token( + db.session, patron, "passworx" + ) credential = Authorization(auth_type="bearer", token=token) assert authenticator.get_credential_from_header(credential) == "passworx" diff --git a/tests/api/test_jwe_provider.py b/tests/api/test_jwe_provider.py deleted file mode 100644 index fe2a58ae10..0000000000 --- a/tests/api/test_jwe_provider.py +++ /dev/null @@ -1,122 +0,0 @@ -import base64 -import json -from datetime import datetime, timedelta -from unittest.mock import MagicMock, patch - -import pytest -from dateutil import tz -from freezegun import freeze_time -from jwcrypto import jwk - -from api.authentication.access_token import PatronJWEAccessTokenProvider -from api.problem_details import PATRON_AUTH_ACCESS_TOKEN_EXPIRED -from core.util.datetime_helpers import utc_now -from tests.fixtures.database import DatabaseTransactionFixture - - -class TestJWEProvider: - def test_generate_key(self): - key = PatronJWEAccessTokenProvider.generate_key() - assert type(key) == jwk.JWK - assert key.get("kty") == "oct" - - def test_generate_token(self, db: DatabaseTransactionFixture): - patron = db.patron() - t = utc_now() - with freeze_time(t): - token = PatronJWEAccessTokenProvider.generate_token( - db.session, patron, "password" - ) - _header, _, _ = token.partition(".") - header = json.loads(base64.b64decode(_header + "===")) - assert datetime.fromtimestamp(header["exp"], tz=tz.tzutc()) == t + timedelta( - hours=1 - ) - assert header["typ"] == "JWE" - - current_key = PatronJWEAccessTokenProvider.get_current_key(db.session) - assert isinstance(current_key, jwk.JWK) - assert header["kid"] == current_key.get("kid") - - def test_get_current_key(self, db: DatabaseTransactionFixture): - key1 = PatronJWEAccessTokenProvider.get_current_key(db.session) - key2 = PatronJWEAccessTokenProvider.get_current_key(db.session) - assert key1 == key2 - - with pytest.raises(ValueError): - PatronJWEAccessTokenProvider.get_current_key(db.session, kid="not-the-kid") - - assert isinstance(key1, jwk.JWK) - assert ( - PatronJWEAccessTokenProvider.get_current_key( - db.session, kid=key1.get("kid") - ) - == key1 - ) - - def test_decode_token(self, db: DatabaseTransactionFixture): - patron = db.patron() - token = PatronJWEAccessTokenProvider.generate_token( - db.session, patron, "password" - ) - decoded = PatronJWEAccessTokenProvider.decode_token(db.session, token) - assert isinstance(decoded, dict) - assert decoded["id"] == patron.id - assert decoded["pwd"] == "password" - assert decoded["typ"] == "patron" - - def test_decode_token_errors(self, db: DatabaseTransactionFixture): - patron = db.patron() - - with patch.object(PatronJWEAccessTokenProvider, "get_current_key") as mock_key: - mock_key.return_value = jwk.JWK.generate( - kty="oct", size=256, kid="some-kid" - ) - token = PatronJWEAccessTokenProvider.generate_token( - db.session, patron, "password", expires_in=1000 - ) - decoded = PatronJWEAccessTokenProvider.decode_token(db.session, token) - assert decoded == PATRON_AUTH_ACCESS_TOKEN_EXPIRED - - token = PatronJWEAccessTokenProvider.generate_token( - db.session, patron, "password", expires_in=-1 - ) - decoded = PatronJWEAccessTokenProvider.decode_token(db.session, token) - assert decoded == PATRON_AUTH_ACCESS_TOKEN_EXPIRED - - def test_rotate_key(self, db: DatabaseTransactionFixture): - key = PatronJWEAccessTokenProvider.rotate_key(db.session) - stored_key = PatronJWEAccessTokenProvider.get_current_key( - db.session, create=False - ) - assert stored_key == key - - key2 = PatronJWEAccessTokenProvider.rotate_key(db.session) - stored_key = PatronJWEAccessTokenProvider.get_current_key( - db.session, create=False - ) - assert key2.get("kid") != key.get("kid") - assert key2.thumbprint() != key.thumbprint() - assert stored_key == key2 - - def test_is_access_token(self, db: DatabaseTransactionFixture): - patron = db.patron() - # Happy path - token = PatronJWEAccessTokenProvider.generate_token( - db.session, patron, "password" - ) - assert PatronJWEAccessTokenProvider.is_access_token(token) == True - - with patch.object(PatronJWEAccessTokenProvider, "_decode") as decode: - # An incorrect type - decode.return_value = MagicMock( - objects=dict(protected=json.dumps(dict(typ="NotJWE"))) - ) - assert PatronJWEAccessTokenProvider.is_access_token(token) == False - - # Something failed during the decode - decode.return_value = None - assert PatronJWEAccessTokenProvider.is_access_token(token) == False - - # The token is not the right format - assert PatronJWEAccessTokenProvider.is_access_token("not-a-token") == False diff --git a/tests/core/jobs/test_rotate_jwe_key.py b/tests/core/jobs/test_rotate_jwe_key.py index 070a4da628..149cd21961 100644 --- a/tests/core/jobs/test_rotate_jwe_key.py +++ b/tests/core/jobs/test_rotate_jwe_key.py @@ -1,16 +1,72 @@ +from datetime import timedelta + +from freezegun import freeze_time +from sqlalchemy import delete, select + from api.authentication.access_token import PatronJWEAccessTokenProvider from core.jobs.rotate_jwe_key import RotateJWEKeyScript +from core.model import Key +from core.model.key import KeyType +from core.util.datetime_helpers import utc_now from tests.fixtures.database import DatabaseTransactionFixture class TestRotateJWEKeyScript: def test_do_run(self, db: DatabaseTransactionFixture): script = RotateJWEKeyScript(db.session) - current = PatronJWEAccessTokenProvider.get_current_key(db.session) + current = PatronJWEAccessTokenProvider.create_key(db.session) script.do_run() - db.session.expire_all() - new_key = PatronJWEAccessTokenProvider.get_current_key(db.session) + new_key = PatronJWEAccessTokenProvider.get_key(db.session) assert current is not None assert new_key is not None - assert current.thumbprint() != new_key.thumbprint() + assert current.id != new_key.id + + def test_do_run_no_current_key(self, db: DatabaseTransactionFixture): + db.session.execute(delete(Key).where(Key.type == KeyType.AUTH_TOKEN_JWE)) + assert Key.get_key(db.session, KeyType.AUTH_TOKEN_JWE) is None + script = RotateJWEKeyScript(db.session) + script.do_run() + created_key = PatronJWEAccessTokenProvider.get_key(db.session) + assert isinstance(created_key, Key) + + @freeze_time() + def test_do_run_remove_expired(self, db: DatabaseTransactionFixture): + db.session.execute(delete(Key).where(Key.type == KeyType.AUTH_TOKEN_JWE)) + script = RotateJWEKeyScript(db.session) + + key1 = PatronJWEAccessTokenProvider.create_key(db.session) + key1.created = utc_now() - timedelta(days=2, hours=4) + key2 = PatronJWEAccessTokenProvider.create_key(db.session) + key2.created = utc_now() - timedelta(days=3) + key3 = PatronJWEAccessTokenProvider.create_key(db.session) + key3.created = utc_now() - timedelta(days=4) + key4 = PatronJWEAccessTokenProvider.create_key(db.session) + key4.created = utc_now() - timedelta(days=5) + key5 = PatronJWEAccessTokenProvider.create_key(db.session) + key5.created = utc_now() - timedelta(days=6) + + script.do_run() + + queried_keys = db.session.scalars( + select(Key) + .where(Key.type == KeyType.AUTH_TOKEN_JWE) + .order_by(Key.created.desc()) + ).all() + assert len(queried_keys) == 2 + [queried_key_1, queried_key_2] = queried_keys + + # The most recent key is the one that was created by the script + assert queried_key_1.created == utc_now() + assert queried_key_1.id != key1.id + + # key1 was kept, even though it's more than two days old, because we always keep + # two keys, so that tokens created right before the key rotation can still be decrypted + # until the tokens expire. + assert queried_key_2.id == key1.id + + # The other keys were deleted + assert Key.get_key(db.session, KeyType.AUTH_TOKEN_JWE, key_id=key2.id) is None + assert Key.get_key(db.session, KeyType.AUTH_TOKEN_JWE, key_id=key3.id) is None + assert Key.get_key(db.session, KeyType.AUTH_TOKEN_JWE, key_id=key4.id) is None + assert Key.get_key(db.session, KeyType.AUTH_TOKEN_JWE, key_id=key5.id) is None diff --git a/tests/core/models/test_key.py b/tests/core/models/test_key.py new file mode 100644 index 0000000000..c945cfd691 --- /dev/null +++ b/tests/core/models/test_key.py @@ -0,0 +1,159 @@ +import functools +import uuid +from datetime import timedelta +from unittest.mock import MagicMock + +import pytest +from freezegun import freeze_time +from sqlalchemy import delete, select + +from core.model import Key +from core.model.key import KeyType +from core.util.datetime_helpers import utc_now +from core.util.uuid import uuid_encode +from tests.fixtures.database import DatabaseTransactionFixture + + +class KeyFixture: + def __init__(self, db: DatabaseTransactionFixture) -> None: + self.db = db + self.create_func = MagicMock(return_value="test_key") + self.create_key = functools.partial( + Key.create_key, + self.db.session, + KeyType.BEARER_TOKEN_SIGNING, + self.create_func, + ) + self.get_key = functools.partial( + Key.get_key, self.db.session, KeyType.BEARER_TOKEN_SIGNING + ) + self.delete_old_keys = functools.partial( + Key.delete_old_keys, self.db.session, KeyType.BEARER_TOKEN_SIGNING + ) + + # remove any existing keys before running tests + self.db.session.execute(delete(Key)) + + +@pytest.fixture +def key_fixture(db: DatabaseTransactionFixture) -> KeyFixture: + return KeyFixture(db) + + +class TestKey: + def test_create_key( + self, db: DatabaseTransactionFixture, key_fixture: KeyFixture + ) -> None: + with freeze_time("2020-01-01 00:00:00"): + key = key_fixture.create_key() + + assert key.id is not None + with freeze_time("2020-01-01 00:00:00"): + assert key.created == utc_now() + assert key.value == key_fixture.create_func.return_value + key_fixture.create_func.assert_called_once_with(key.id) + + db.session.expire_all() + assert key == db.session.execute(select(Key)).scalar_one() + + def test_get_key(self, key_fixture: KeyFixture) -> None: + key1 = key_fixture.create_key() + key2 = key_fixture.create_key() + + # If called without a key id, it should return the key that was created last + assert key_fixture.get_key() == key2 + + # If called with a key id, it should return the key with that id + assert key_fixture.get_key(key1.id) == key1 + + # The key id can also be passed in as an encoded uuid string + assert isinstance(key1.id, uuid.UUID) + assert key_fixture.get_key(uuid_encode(key1.id)) == key1 + + # Or as a UUID hex string + assert key_fixture.get_key(key1.id.hex) == key1 + + # If a key id is not found, it should return None + assert key_fixture.get_key("0000000000000000000000") is None + + # Unless raise_exception is True + with pytest.raises(ValueError): + key_fixture.get_key( + "0000000000000000000000", + raise_exception=True, + ) + + def test_create_admin_secret_key(self, db: DatabaseTransactionFixture) -> None: + key = Key.create_admin_secret_key(db.session) + assert key.type == KeyType.ADMIN_SECRET_KEY + assert key.value is not None + assert len(key.value) == 48 + + # If we already have an admin secret key, we should not create a new one + key2 = Key.create_admin_secret_key(db.session) + assert key2 == key + + def test_create_bearer_token_signing_key( + self, db: DatabaseTransactionFixture + ) -> None: + key = Key.create_bearer_token_signing_key(db.session) + assert key.type == KeyType.BEARER_TOKEN_SIGNING + assert key.value is not None + assert len(key.value) == 48 + + key2 = Key.create_bearer_token_signing_key(db.session) + assert key2 == key + + def test_delete_old_keys( + self, db: DatabaseTransactionFixture, key_fixture: KeyFixture + ) -> None: + one_day_ago = utc_now() - timedelta(days=1) + two_days_ago = utc_now() - timedelta(days=2) + three_days_ago = utc_now() - timedelta(days=3) + + # If there are no keys, nothing should happen + assert ( + Key.delete_old_keys( + db.session, KeyType.BEARER_TOKEN_SIGNING, keep=1, older_than=utc_now() + ) + == 0 + ) + + # If keep is negative, we raise an error + with pytest.raises(ValueError): + Key.delete_old_keys( + db.session, KeyType.BEARER_TOKEN_SIGNING, keep=-1, older_than=utc_now() + ) + + # Create some keys + key1 = key_fixture.create_key() + key1.created = three_days_ago + key2 = key_fixture.create_key() + key2.created = two_days_ago + key3 = key_fixture.create_key() + key3.created = one_day_ago + + # We always keep the number of keys specified in the keep parameter, even if they are older than the + # older_than parameter + assert key_fixture.delete_old_keys(keep=3, older_than=utc_now()) == 0 + + # If all the keys are newer than the older_than parameter, nothing should happen + assert ( + key_fixture.delete_old_keys( + keep=0, + older_than=three_days_ago, + ) + == 0 + ) + + # If we keep 2 keys, the oldest key should be deleted + assert key_fixture.delete_old_keys(keep=2, older_than=utc_now()) == 1 + assert db.session.execute( + select(Key).order_by(Key.created) + ).scalars().all() == [key2, key3] + + # If we keep 1 key, another key should be deleted + assert key_fixture.delete_old_keys(keep=1, older_than=utc_now()) == 1 + assert db.session.execute( + select(Key).order_by(Key.created) + ).scalars().all() == [key3] diff --git a/tests/core/util/test_string_helpers.py b/tests/core/util/test_string_helpers.py index 584d66a729..ba7e362634 100644 --- a/tests/core/util/test_string_helpers.py +++ b/tests/core/util/test_string_helpers.py @@ -1,31 +1,48 @@ # Test the helper objects in util.string. import re +import string -from core.util.string_helpers import random_string +from core.util.string_helpers import random_key, random_string -class TestRandomString: - def test_random_string(self): - m = random_string - assert "" == m(0) +def test_random_string(): + m = random_string + assert "" == m(0) - # The strings are random. - res1 = m(8) - res2 = m(8) - assert res1 != res2 + # The strings are random. + res1 = m(8) + res2 = m(8) + assert res1 != res2 - # We can't test exact values, because the randomness comes - # from /dev/urandom, but we can test some of their properties: - for size in range(1, 16): - x = m(size) + # We can't test exact values, because the randomness comes + # from /dev/urandom, but we can test some of their properties: + for size in range(1, 16): + x = m(size) - # The strings are Unicode strings, not bytestrings - assert isinstance(x, str) + # The strings are Unicode strings, not bytestrings + assert isinstance(x, str) - # The strings are entirely composed of lowercase hex digits. - assert None == re.compile("[^a-f0-9]").search(x) + # The strings are entirely composed of lowercase hex digits. + assert None == re.compile("[^a-f0-9]").search(x) - # Each byte is represented as two digits, so the length of the - # string is twice the length passed in to the function. - assert size * 2 == len(x) + # Each byte is represented as two digits, so the length of the + # string is twice the length passed in to the function. + assert size * 2 == len(x) + + +def test_random_key(): + m = random_key + assert "" == m(0) + + # The strings are random. + res1 = m(8) + res2 = m(8) + assert res1 != res2 + + # They match the length we asked for. + assert len(m(40)) == 40 + + # All characters are printable. + for letter in m(40): + assert letter in string.printable diff --git a/tests/core/util/test_uuid.py b/tests/core/util/test_uuid.py index 03bc27aa04..543e4c3e6d 100644 --- a/tests/core/util/test_uuid.py +++ b/tests/core/util/test_uuid.py @@ -29,6 +29,12 @@ def test_uuid_encode_decode(uuid: str, expected: str): assert str(decoded) == uuid assert decoded == uuid_obj + # Test that uuid_decode can also handle a normal UUID hex string + decoded = uuid_decode(uuid) + assert isinstance(decoded, UUID) + assert str(decoded) == uuid + assert decoded == uuid_obj + def test_uuid_decode_error(): # Invalid length @@ -38,3 +44,11 @@ def test_uuid_decode_error(): # Invalid characters with pytest.raises(ValueError): uuid_decode("/~") + + # Valid length but not valid + with pytest.raises(ValueError): + uuid_decode("~" * 22) + + # Valid length for a UUID hex string but not valid + with pytest.raises(ValueError): + uuid_decode("~" * 32) diff --git a/tests/fixtures/api_admin.py b/tests/fixtures/api_admin.py index b536219df3..c02ed32b78 100644 --- a/tests/fixtures/api_admin.py +++ b/tests/fixtures/api_admin.py @@ -8,7 +8,6 @@ from api.admin.controller.settings import SettingsController from api.app import initialize_admin from api.circulation_manager import CirculationManager -from api.config import Configuration from core.integration.goals import Goals from core.model import create from core.model.admin import Admin, AdminRole @@ -28,11 +27,6 @@ class AdminControllerFixture: def __init__(self, controller_fixture: ControllerFixture): self.ctrl = controller_fixture self.manager = self.ctrl.manager - - ConfigurationSetting.sitewide( - controller_fixture.db.session, Configuration.SECRET_KEY - ).value = "a secret" - initialize_admin(controller_fixture.db.session) setup_admin_controllers(controller_fixture.manager) self.admin, ignore = create( diff --git a/tests/fixtures/database.py b/tests/fixtures/database.py index acc782ef67..a6bd660ff6 100644 --- a/tests/fixtures/database.py +++ b/tests/fixtures/database.py @@ -59,7 +59,6 @@ from core.model.licensing import License, LicensePoolDeliveryMechanism, LicenseStatus from core.opds_import import OPDSAPI from core.util.datetime_helpers import utc_now -from core.util.string_helpers import random_string class ApplicationFixture: @@ -722,7 +721,7 @@ def integration_configuration( IntegrationConfiguration, protocol=protocol, goal=goal, - name=(name or random_string(16)), + name=(name or self.fresh_str()), ) if libraries is None: