From 771cbeb3d0db5321cddee11c06cc582dd55a6516 Mon Sep 17 00:00:00 2001 From: Daniel Perrefort Date: Wed, 24 Jul 2024 13:23:14 -0400 Subject: [PATCH] Refactors code for testability (#30) --- .github/workflows/Release.yml | 2 +- .github/workflows/Test.yml | 8 +- keystone_client/__init__.py | 4 +- keystone_client/authentication.py | 173 ++++++++++ keystone_client/client.py | 320 +++++++----------- keystone_client/schema.py | 49 +++ pyproject.toml | 2 +- tests/__init__.py | 5 + tests/authentication/__init__.py | 0 .../test_authentication_manager.py | 118 +++++++ tests/authentication/test_jwt.py | 45 +++ tests/client/__init__.py | 0 .../test_http_client.py} | 10 +- tests/client/test_keystone_client.py | 43 +++ tests/schema/__init__.py | 0 tests/schema/test_endpoint.py | 49 +++ tests/test_crud_methods.py | 17 - 17 files changed, 610 insertions(+), 235 deletions(-) create mode 100644 keystone_client/authentication.py create mode 100644 keystone_client/schema.py create mode 100644 tests/authentication/__init__.py create mode 100644 tests/authentication/test_authentication_manager.py create mode 100644 tests/authentication/test_jwt.py create mode 100644 tests/client/__init__.py rename tests/{test_url_handling.py => client/test_http_client.py} (54%) create mode 100644 tests/client/test_keystone_client.py create mode 100644 tests/schema/__init__.py create mode 100644 tests/schema/test_endpoint.py delete mode 100644 tests/test_crud_methods.py diff --git a/.github/workflows/Release.yml b/.github/workflows/Release.yml index 0bb922c..21eff1e 100644 --- a/.github/workflows/Release.yml +++ b/.github/workflows/Release.yml @@ -12,7 +12,7 @@ jobs: name: Get Release Version runs-on: ubuntu-latest outputs: - version: ${{ steps.get_version.outputs.version }} + version: ${{ steps.get_version.outputs.version }} steps: - name: Determine version from release tag diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index ee98bc5..9dce94a 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -9,10 +9,10 @@ jobs: runs-on: ubuntu-latest services: - api: - image: ghcr.io/pitt-crc/keystone-api:latest - ports: - - 8000:8000 + api: + image: ghcr.io/pitt-crc/keystone-api:latest + ports: + - 8000:8000 strategy: fail-fast: false diff --git a/keystone_client/__init__.py b/keystone_client/__init__.py index 421945b..02bad32 100644 --- a/keystone_client/__init__.py +++ b/keystone_client/__init__.py @@ -1 +1,3 @@ -from .client import * +"""A light-weight Python client for wrapping the Keystone API.""" + +from .client import KeystoneClient diff --git a/keystone_client/authentication.py b/keystone_client/authentication.py new file mode 100644 index 0000000..ab11bb8 --- /dev/null +++ b/keystone_client/authentication.py @@ -0,0 +1,173 @@ +"""User authentication and credential management.""" + +from __future__ import annotations + +from datetime import datetime +from warnings import warn + +import jwt +import requests + +from keystone_client.schema import Schema + + +class JWT: + """JWT authentication tokens""" + + def __init__(self, access: str, refresh: str, algorithm='HS256') -> None: + """Initialize a new pair of JWT tokens + + Args: + access: The access token + refresh: The refresh token + algorithm: The algorithm used for encoding the JWT + """ + + self.algorithm = algorithm + self.access = access + self.refresh = refresh + + def _date_from_token(self, token: str) -> datetime: + """Return a token's expiration datetime""" + + token_data = jwt.decode(token, options={"verify_signature": False}, algorithms=self.algorithm) + exp = datetime.fromtimestamp(token_data["exp"]) + return exp + + @property + def access_expiration(self) -> datetime: + """Return the expiration datetime of the JWT access token""" + + return self._date_from_token(self.access) + + @property + def refresh_expiration(self) -> datetime: + """Return the expiration datetime of the JWT refresh token""" + + return self._date_from_token(self.refresh) + + +class AuthenticationManager: + """User authentication and JWT token manager""" + + def __init__(self, url: str, schema: Schema = Schema()) -> None: + """Initialize the class + + Args: + url: Base URL for the authentication API + schema: Schema defining API endpoints for fetching/managing JWTs + """ + + self.jwt: JWT | None = None + self.auth_url = schema.auth.new.join_url(url) + self.refresh_url = schema.auth.refresh.join_url(url) + self.blacklist_url = schema.auth.blacklist.join_url(url) + + def is_authenticated(self) -> bool: + """Return whether the client instance has active credentials""" + + if self.jwt is None: + return False + + now = datetime.now() + access_token_valid = self.jwt.access_expiration > now + access_token_refreshable = self.jwt.refresh_expiration > now + return access_token_valid or access_token_refreshable + + def get_auth_headers(self, refresh: bool = True, timeout: int = None) -> dict[str, str]: + """Return headers data for authenticating API requests + + The returned dictionary is empty when not authenticated. + + Args: + refresh: Automatically refresh the JWT credentials if necessary + timeout: Seconds before the token refresh request times out + + Returns: + A dictionary with header ata for JWT authentication + """ + + if refresh: + self.refresh(timeout=timeout) + + if not self.is_authenticated(): + return dict() + + return {"Authorization": f"Bearer {self.jwt.access}"} + + def login(self, username: str, password: str, timeout: int = None) -> None: + """Log in to the Keystone API and cache the returned credentials + + Args: + username: The authentication username + password: The authentication password + timeout: Seconds before the request times out + + Raises: + requests.HTTPError: If the login request fails + """ + + response = requests.post( + self.auth_url, + json={"username": username, "password": password}, + timeout=timeout + ) + + response.raise_for_status() + response_data = response.json() + self.jwt = JWT(response_data.get("access"), response_data.get("refresh")) + + def logout(self, timeout: int = None) -> None: + """Log out of the current session and blacklist any current credentials + + Args: + timeout: Seconds before the request times out + """ + + # Tell the API to blacklist the current token + if self.jwt is not None: + response = requests.post( + self.blacklist_url, + data={"refresh": self.jwt.refresh}, + timeout=timeout + ) + + try: + response.raise_for_status() + + except Exception as error: + warn(f"Token blacklist request failed: {error}") + + self.jwt = None + + def refresh(self, force: bool = False, timeout: int = None) -> None: + """Refresh the current session credetials if necessary + + This method will do nothing and exit silently if the current session + has not been authenticated. + + Args: + timeout: Seconds before the request times out + force: Refresh the access token even if it has not expired yet + """ + + if self.jwt is None: + return + + # Don't refresh the token if it's not necessary + now = datetime.now() + if self.jwt.access_expiration > now and not force: + return + + # Alert the user when a refresh is not possible + if self.jwt.refresh_expiration > now: + raise RuntimeError("Refresh token has expired. Login again to continue.") + + response = requests.post( + self.refresh_url, + data={"refresh": self.jwt.refresh}, + timeout=timeout + ) + + response.raise_for_status() + self.jwt.refresh = response.json().get("refresh") diff --git a/keystone_client/client.py b/keystone_client/client.py index d34dc2f..7a90f95 100644 --- a/keystone_client/client.py +++ b/keystone_client/client.py @@ -7,16 +7,16 @@ from __future__ import annotations -from collections import namedtuple -from datetime import datetime -from functools import partial +from functools import cached_property, partial from typing import Literal, Union -from warnings import warn -import urllib.parse -import jwt +from urllib.parse import urljoin + import requests -__all__ = ["KeystoneClient"] +from keystone_client.authentication import AuthenticationManager +from keystone_client.schema import Endpoint, Schema + +DEFAULT_TIMEOUT = 15 # Custom types ContentType = Literal["json", "text", "content"] @@ -24,31 +24,11 @@ QueryResult = Union[None, dict, list[dict]] HTTPMethod = Literal["get", "post", "put", "patch", "delete"] -# API schema mapping human-readable, python-friendly names to API endpoints -Schema = namedtuple("Schema", [ - "allocations", - "requests", - "research_groups", - "users", -]) - - -class KeystoneClient: - """Client class for submitting requests to the Keystone API""" - # Default API behavior - default_timeout = 15 +class HTTPClient: + """Low level API client for sending standard HTTP operations""" - # API endpoints - authentication_new = "authentication/new/" - authentication_blacklist = "authentication/blacklist/" - authentication_refresh = "authentication/refresh/" - schema = Schema( - allocations="allocations/allocations/", - requests="allocations/requests/", - research_groups="users/researchgroups/", - users="users/users/", - ) + schema = Schema() def __init__(self, url: str) -> None: """Initialize the class @@ -57,89 +37,44 @@ def __init__(self, url: str) -> None: url: The base URL for a running Keystone API server """ - self._url = url - self._api_version: str | None = None - self._access_token: str | None = None - self._access_expiration: datetime | None = None - self._refresh_token: str | None = None - self._refresh_expiration: datetime | None = None - - def __new__(cls, *args, **kwargs) -> KeystoneClient: - """Dynamically create CRUD methods for each endpoint in the API schema + self._url = url.rstrip('/') + '/' + self._auth = AuthenticationManager(url, self.schema) - Dynamic method are only generated of they do not already implemented - in the class definition. - """ - - instance: KeystoneClient = super().__new__(cls) - for key, endpoint in zip(cls.schema._fields, cls.schema): - - # Create a retrieve method - retrieve_name = f"retrieve_{key}" - if not hasattr(instance, retrieve_name): - retrieve_method = partial(instance._retrieve_records, _endpoint=endpoint) - setattr(instance, f"retrieve_{key}", retrieve_method) - - return instance + @property + def url(self) -> str: + """Return the server URL""" - def _retrieve_records( - self, - _endpoint: str, - pk: int | None = None, - filters: dict | None = None, - timeout=default_timeout - ) -> QueryResult: - """Retrieve data from the specified endpoint with optional primary key and filters + return self._url - A single record is returned when specifying a primary key, otherwise the returned - object is a list of records. In either case, the return value is `None` when no data - is available for the query. + def login(self, username: str, password: str, timeout: int = DEFAULT_TIMEOUT) -> None: + """Authenticate a new user session Args: - pk: Optional primary key to fetch a specific record - filters: Optional query parameters to include in the request + username: The authentication username + password: The authentication password timeout: Seconds before the request times out - Returns: - The response from the API in JSON format + Raises: + requests.HTTPError: If the login request fails """ - if pk is not None: - _endpoint = f"{_endpoint}/{pk}/" + self._auth.login(username, password, timeout) # pragma: nocover - try: - response = self.http_get(_endpoint, params=filters, timeout=timeout) - response.raise_for_status() - return response.json() + def logout(self, timeout: int = DEFAULT_TIMEOUT) -> None: + """Log out and blacklist any active credentials - except requests.HTTPError as exception: - if exception.response.status_code == 404: - return None - - raise - - def _get_headers(self) -> dict[str, str]: - """Return header data for API requests - - Returns: - A dictionary with header data + Args: + timeout: Seconds before the blacklist request times out """ - if not self._access_token: - return dict() + self._auth.logout(timeout) # pragma: nocover - return { - "Authorization": f"Bearer {self._access_token}", - "Content-Type": "application/json" - } + def is_authenticated(self) -> bool: + """Return whether the client instance has active credentials""" - def _send_request( - self, - method: HTTPMethod, - endpoint: str, - timeout: int = default_timeout, - **kwargs - ) -> requests.Response: + return self._auth.is_authenticated() # pragma: nocover + + def _send_request(self, method: HTTPMethod, endpoint: str, **kwargs) -> requests.Response: """Send an HTTP request Args: @@ -151,25 +86,16 @@ def _send_request( An HTTP response """ - self._refresh_tokens(force=False, timeout=timeout) - - url = urllib.parse.urljoin(self.url, endpoint) - response = requests.request(method, url, headers=self._get_headers(), **kwargs) + url = urljoin(self.url, endpoint) + response = requests.request(method, url, **kwargs) response.raise_for_status() return response - @property - def url(self) -> str: - """Return the server URL""" - - # Make sure the url includes a single trailing slash - return self._url.rstrip('/') + '/' - def http_get( self, endpoint: str, params: dict[str, any] | None = None, - timeout: int = default_timeout + timeout: int = DEFAULT_TIMEOUT ) -> requests.Response: """Send a GET request to an API endpoint @@ -185,13 +111,19 @@ def http_get( requests.HTTPError: If the request returns an error code """ - return self._send_request("get", endpoint, params=params, timeout=timeout) + return self._send_request( + "get", + endpoint, + params=params, + headers=self._auth.get_auth_headers(), + timeout=timeout + ) def http_post( self, endpoint: str, data: dict[str, any] | None = None, - timeout: int = default_timeout + timeout: int = DEFAULT_TIMEOUT ) -> requests.Response: """Send a POST request to an API endpoint @@ -207,13 +139,19 @@ def http_post( requests.HTTPError: If the request returns an error code """ - return self._send_request("post", endpoint, data=data, timeout=timeout) + return self._send_request( + "post", + endpoint, + data=data, + headers=self._auth.get_auth_headers(), + timeout=timeout + ) def http_patch( self, endpoint: str, data: dict[str, any] | None = None, - timeout: int = default_timeout + timeout: int = DEFAULT_TIMEOUT ) -> requests.Response: """Send a PATCH request to an API endpoint @@ -229,13 +167,19 @@ def http_patch( requests.HTTPError: If the request returns an error code """ - return self._send_request("patch", endpoint, data=data, timeout=timeout) + return self._send_request( + "patch", + endpoint, + data=data, + headers=self._auth.get_auth_headers(), + timeout=timeout + ) def http_put( self, endpoint: str, data: dict[str, any] | None = None, - timeout: int = default_timeout + timeout: int = DEFAULT_TIMEOUT ) -> requests.Response: """Send a PUT request to an endpoint @@ -251,12 +195,18 @@ def http_put( requests.HTTPError: If the request returns an error code """ - return self._send_request("put", endpoint, data=data, timeout=timeout) + return self._send_request( + "put", + endpoint, + data=data, + headers=self._auth.get_auth_headers(), + timeout=timeout + ) def http_delete( self, endpoint: str, - timeout: int = default_timeout + timeout: int = DEFAULT_TIMEOUT ) -> requests.Response: """Send a DELETE request to an endpoint @@ -271,112 +221,70 @@ def http_delete( requests.HTTPError: If the request returns an error code """ - return self._send_request("delete", endpoint, timeout=timeout) + return self._send_request( + "delete", + endpoint, + headers=self._auth.get_auth_headers(), + timeout=timeout + ) - @property - def is_authenticated(self) -> None: - """Return whether the client instance has been authenticated""" - now = datetime.now() - has_token = self._refresh_token is not None - access_token_valid = self._access_expiration is not None and self._access_expiration > now - access_token_refreshable = self._refresh_expiration is not None and self._refresh_expiration > now - return has_token and (access_token_valid or access_token_refreshable) +class KeystoneClient(HTTPClient): + """Client class for submitting requests to the Keystone API""" - @property + @cached_property def api_version(self) -> str: """Return the version number of the API server""" - if self._api_version is None: - response = self.http_get("version") - response.raise_for_status() - self._api_version = response.text - - return self._api_version - - def login(self, username: str, password: str, timeout: int = default_timeout) -> None: - """Log in to the Keystone API and cache the returned JWT - - Args: - username: The authentication username - password: The authentication password - timeout: Seconds before the request times out - - Raises: - requests.HTTPError: If the login request fails - """ - - response = requests.post( - f"{self.url}/{self.authentication_new}", - json={"username": username, "password": password}, - timeout=timeout - ) - + response = self.http_get("version") response.raise_for_status() + return response.text - # Parse data from the refresh token - self._refresh_token = response.json().get("refresh") - refresh_payload = jwt.decode(self._refresh_token, options={"verify_signature": False}, algorithms='HS256') - self._refresh_expiration = datetime.fromtimestamp(refresh_payload["exp"]) - - # Parse data from the access token - self._access_token = response.json().get("access") - access_payload = jwt.decode(self._access_token, options={"verify_signature": False}, algorithms='HS256') - self._access_expiration = datetime.fromtimestamp(access_payload["exp"]) - - def logout(self, timeout: int = default_timeout) -> None: - """Log out and blacklist any active JWTs - - Args: - timeout: Seconds before the request times out - """ + def __new__(cls, *args, **kwargs) -> KeystoneClient: + """Dynamically create CRUD methods for each data endpoint in the API schema""" - if self._refresh_token is not None: - response = requests.post( - f"{self.url}/{self.authentication_blacklist}", - data={"refresh": self._refresh_token}, - timeout=timeout - ) + new: KeystoneClient = super().__new__(cls) - try: - response.raise_for_status() + new.retrieve_allocations = partial(new._retrieve_records, cls.schema.data.allocations) + new.retrieve_requests = partial(new._retrieve_records, cls.schema.data.requests) + new.retrieve_research_groups = partial(new._retrieve_records, cls.schema.data.research_groups) + new.retrieve_users = partial(new._retrieve_records, cls.schema.data.users) - except Exception as exception: - warn(str(exception)) + return new - self._refresh_token = None - self._refresh_expiration = None - self._access_token = None - self._access_expiration = None + def _retrieve_records( + self, + _endpoint: Endpoint, + pk: int | None = None, + filters: dict | None = None, + timeout=DEFAULT_TIMEOUT + ) -> QueryResult: + """Retrieve data from the specified endpoint with optional primary key and filters - def _refresh_tokens(self, force: bool = True, timeout: int = default_timeout) -> None: - """Refresh the JWT access token + A single record is returned when specifying a primary key, otherwise the returned + object is a list of records. In either case, the return value is `None` when no data + is available for the query. Args: + pk: Optional primary key to fetch a specific record + filters: Optional query parameters to include in the request timeout: Seconds before the request times out - force: Refresh the access token even if it has not expired yet - """ - - if not self.is_authenticated: - return - # Don't refresh the token if it's not necessary - now = datetime.now() - if self._access_expiration > now and not force: - return + Returns: + The response from the API in JSON format + """ - # Alert the user when a refresh is not possible - if self._refresh_expiration > now: - raise RuntimeError("Refresh token has expired. Login again to continue.") + url = _endpoint.join_url(self.url) + if pk is not None: + url = urljoin(url, str(pk)) - response = requests.post( - f"{self.url}/{self.authentication_refresh}", - data={"refresh": self._refresh_token}, - timeout=timeout - ) + try: + response = self.http_get(url, params=filters, timeout=timeout) + response.raise_for_status() + return response.json() - response.raise_for_status() - self._refresh_token = response.json().get("refresh") - refresh_payload = jwt.decode(self._refresh_token, options={"verify_signature": False}, algorithms='HS256') - self._refresh_expiration = datetime.fromtimestamp(refresh_payload["exp"]) + except requests.HTTPError as exception: + if exception.response.status_code == 404: + return None + raise diff --git a/keystone_client/schema.py b/keystone_client/schema.py new file mode 100644 index 0000000..64a06e7 --- /dev/null +++ b/keystone_client/schema.py @@ -0,0 +1,49 @@ +"""Schema objects used to define available API endpoints.""" + +from dataclasses import dataclass, field +from urllib.parse import urljoin + + +class Endpoint(str): + + def join_url(self, url: str) -> str: + """Join the endpoint with a base URL + + This method returns URLs in a format that avoids trailing slash + redirects from the Keystone API. + + Args: + url: The base URL + + Returns: + The base URL join with the endpoint + """ + + return urljoin(url, self).rstrip('/') + '/' + + +@dataclass +class AuthSchema: + """Schema defining API endpoints used for JWT authentication""" + + new: Endpoint = Endpoint("authentication/new") + refresh: Endpoint = Endpoint("authentication/refresh") + blacklist: Endpoint = Endpoint("authentication/blacklist") + + +@dataclass +class DataSchema: + """Schema defining API endpoints for data access""" + + allocations: Endpoint = Endpoint("allocations/allocations") + requests: Endpoint = Endpoint("allocations/requests") + research_groups: Endpoint = Endpoint("users/researchgroups") + users: Endpoint = Endpoint("users/users") + + +@dataclass +class Schema: + """Schema defining the complete set of API endpoints""" + + auth: AuthSchema = field(default_factory=AuthSchema) + data: DataSchema = field(default_factory=DataSchema) diff --git a/pyproject.toml b/pyproject.toml index 8811613..b31a319 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,8 +21,8 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.9" -requests = "^2.32.3" pyjwt = "^2.8.0" +requests = "^2.32.3" [tool.poetry.group.tests] optional = true diff --git a/tests/__init__.py b/tests/__init__.py index e69de29..e34b2ee 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,5 @@ +import os + +API_HOST = os.environ.get('TEST_API_HOST', 'http://localhost:8000') +API_USER = os.environ.get('TEST_API_USER', 'admin') +API_PASSWORD = os.environ.get('TEST_API_PASSWORD', 'quickstart') diff --git a/tests/authentication/__init__.py b/tests/authentication/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/authentication/test_authentication_manager.py b/tests/authentication/test_authentication_manager.py new file mode 100644 index 0000000..a9a1c6c --- /dev/null +++ b/tests/authentication/test_authentication_manager.py @@ -0,0 +1,118 @@ +"""Tests for the `AuthenticationManager` class.""" + +from datetime import datetime, timedelta +from unittest import TestCase + +import jwt +from requests.exceptions import HTTPError + +from keystone_client.authentication import AuthenticationManager, JWT +from tests import API_HOST, API_PASSWORD, API_USER + + +def create_token(access_expires: datetime, refresh_expires: datetime) -> JWT: + """Create a JWT token + + Args: + access_expires: The expiration datetime for the access token + refresh_expires: The expiration datetime for the refresh token + + Returns: + A JWT instance with the given expiration dates + """ + + return JWT( + access=jwt.encode({'exp': access_expires.timestamp()}, 'secret'), + refresh=jwt.encode({'exp': refresh_expires.timestamp()}, 'secret') + ) + + +class IsAuthenticated(TestCase): + """Tests for the `is_authenticated` method""" + + def test_not_authenticated(self) -> None: + """Test the return value is `false` when the manager has no JWT data""" + + manager = AuthenticationManager(API_HOST) + self.assertIsNone(manager.jwt) + self.assertFalse(manager.is_authenticated()) + + def test_valid_jwt(self) -> None: + """Test the return value is `True` when the JWT token is not expired""" + + manager = AuthenticationManager(API_HOST) + manager.jwt = create_token( + access_expires=datetime.now() + timedelta(hours=1), + refresh_expires=datetime.now() + timedelta(days=1) + ) + + self.assertTrue(manager.is_authenticated()) + + def test_refreshable_jwt(self) -> None: + """Test the return value is `True` when the JWT token expired but refreshable""" + + manager = AuthenticationManager(API_HOST) + manager.jwt = create_token( + access_expires=datetime.now() - timedelta(hours=1), + refresh_expires=datetime.now() + timedelta(days=1) + ) + + self.assertTrue(manager.is_authenticated()) + + def test_expired_jwt(self) -> None: + """Test the return value is `False` when the JWT token is expired""" + + manager = AuthenticationManager(API_HOST) + manager.jwt = create_token( + access_expires=datetime.now() - timedelta(days=1), + refresh_expires=datetime.now() - timedelta(hours=1) + ) + + self.assertFalse(manager.is_authenticated()) + + +class GetAuthHeaders(TestCase): + """Tests for the `get_auth_headers` method""" + + def test_not_authenticated(self) -> None: + """Test the returned headers are empty when not authenticated""" + + manager = AuthenticationManager(API_HOST) + headers = manager.get_auth_headers() + self.assertEqual(dict(), headers) + + def test_headers_match_jwt(self) -> None: + """Test the returned data matches the JWT token""" + + manager = AuthenticationManager(API_HOST) + manager.jwt = create_token( + access_expires=datetime.now() + timedelta(hours=1), + refresh_expires=datetime.now() + timedelta(days=1) + ) + + headers = manager.get_auth_headers() + self.assertEqual(f"Bearer {manager.jwt.access}", headers["Authorization"]) + + +class LoginLogout(TestCase): + """Test the logging in/out of users""" + + def test_correct_credentials(self) -> None: + """Test users are successfully logged in/out when providing correct credentials""" + + manager = AuthenticationManager(API_HOST) + self.assertFalse(manager.is_authenticated()) + + manager.login(API_USER, API_PASSWORD) + self.assertTrue(manager.is_authenticated()) + + manager.logout() + self.assertFalse(manager.is_authenticated()) + + def test_incorrect_credentials(self) -> None: + """Test an error is raised when authenticating with invalid credentials""" + + manager = AuthenticationManager(API_HOST) + with self.assertRaises(HTTPError) as error: + manager.login('foo', 'bar') + self.assertEqual(401, error.response.status_code) diff --git a/tests/authentication/test_jwt.py b/tests/authentication/test_jwt.py new file mode 100644 index 0000000..618881d --- /dev/null +++ b/tests/authentication/test_jwt.py @@ -0,0 +1,45 @@ +"""Tests for the `JWT` class.""" + +from datetime import datetime, timedelta +from unittest import TestCase + +import jwt + +from keystone_client.authentication import JWT + + +class BaseParsingTests: + """Base class containing reusable tests for token parsing""" + + algorithm: str + + @classmethod + def setUpClass(cls) -> None: + """Test the parsing of JWT data""" + + # Build a JWT + cls.access_expiration = datetime.now() + timedelta(hours=1) + cls.access_token = jwt.encode({'exp': cls.access_expiration.timestamp()}, 'secret', algorithm=cls.algorithm) + + cls.refresh_expiration = datetime.now() + timedelta(days=1) + cls.refresh_token = jwt.encode({'exp': cls.refresh_expiration.timestamp()}, 'secret', algorithm=cls.algorithm) + + cls.jwt = JWT(cls.access_token, cls.refresh_token, cls.algorithm) + + def test_access_token(self) -> None: + """Test the access token is parsed correctly""" + + self.assertEqual(self.access_token, self.jwt.access) + self.assertEqual(self.access_expiration, self.jwt.access_expiration) + + def test_refresh_token(self) -> None: + """Test the refresh token is parsed correctly""" + + self.assertEqual(self.refresh_token, self.jwt.refresh) + self.assertEqual(self.refresh_expiration, self.jwt.refresh_expiration) + + +class HS256Parsing(BaseParsingTests, TestCase): + """Test JWT token parsing using the HS256 algorithm.""" + + algorithm = 'HS256' diff --git a/tests/client/__init__.py b/tests/client/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_url_handling.py b/tests/client/test_http_client.py similarity index 54% rename from tests/test_url_handling.py rename to tests/client/test_http_client.py index 41d8f8e..03a4e87 100644 --- a/tests/test_url_handling.py +++ b/tests/client/test_http_client.py @@ -1,8 +1,8 @@ -"""Tests for the handling to the base API url.""" +"""Tests for the `HTTPClient` class.""" from unittest import TestCase -from keystone_client import KeystoneClient +from keystone_client.client import HTTPClient class TestUrl(TestCase): @@ -15,6 +15,6 @@ def test_trailing_slash_removed(self): expected_url = base_url + '/' # Test for various numbers of trailing slashes provided at init - self.assertEqual(expected_url, KeystoneClient(base_url).url) - self.assertEqual(expected_url, KeystoneClient(base_url + '/').url) - self.assertEqual(expected_url, KeystoneClient(base_url + '////').url) + self.assertEqual(expected_url, HTTPClient(base_url).url) + self.assertEqual(expected_url, HTTPClient(base_url + '/').url) + self.assertEqual(expected_url, HTTPClient(base_url + '////').url) diff --git a/tests/client/test_keystone_client.py b/tests/client/test_keystone_client.py new file mode 100644 index 0000000..b4b9f7f --- /dev/null +++ b/tests/client/test_keystone_client.py @@ -0,0 +1,43 @@ +"""Tests for CRUD operations.""" + +import re +from dataclasses import asdict +from unittest import TestCase + +from keystone_client import KeystoneClient +from keystone_client.schema import Schema +from tests import API_HOST + + +class APIVersion(TestCase): + """Tests for the `api_version` method""" + + def test_version_is_returned(self) -> None: + """Test a version number is returned""" + + # Simplified version identification from PEP 440 + version_regex = re.compile(r""" + ^ + (?P[0-9]+)\. # Major version number + (?P[0-9]+)\. # Minor version number + (?P[0-9]+) # Patch version number + (?:\. # Optional dot + (?P[a-zA-Z0-9]+) # Optional suffix (letters or numbers) + )? # Make the entire suffix part optional + $ + """, re.VERBOSE) + + client = KeystoneClient(API_HOST) + self.assertRegex(client.api_version, version_regex) + + +class RetrieveMethods(TestCase): + """Tests for retrieve methods""" + + def test_methods_exist(self) -> None: + """Test a method exists for each endpoint in the class schema""" + + client = KeystoneClient('http://test.domain.com') + for endpoint in asdict(Schema().data): + method_name = f'retrieve_{endpoint}' + self.assertTrue(hasattr(client, method_name), f'Method does not exist {method_name}') diff --git a/tests/schema/__init__.py b/tests/schema/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/schema/test_endpoint.py b/tests/schema/test_endpoint.py new file mode 100644 index 0000000..17f54df --- /dev/null +++ b/tests/schema/test_endpoint.py @@ -0,0 +1,49 @@ +"""Tests for the `Endpoint` class.""" + +from unittest import TestCase + +from keystone_client.schema import Endpoint + + +class JoinUrl(TestCase): + """Tests for the `join_url` method""" + + def test_with_trailing_slash(self) -> None: + """Test join_url with a base URL that has a trailing slash""" + + endpoint = Endpoint("authentication/new") + base_url = "https://api.example.com/" + expected_result = "https://api.example.com/authentication/new/" + self.assertEqual(expected_result, endpoint.join_url(base_url)) + + def test_without_trailing_slash(self) -> None: + """Test join_url with a base URL that does not have a trailing slash""" + + endpoint = Endpoint("authentication/new") + base_url = "https://api.example.com" + expected_result = "https://api.example.com/authentication/new/" + self.assertEqual(expected_result, endpoint.join_url(base_url)) + + def test_with_endpoint_trailing_slash(self) -> None: + """Test join_url with an endpoint that has a trailing slash""" + + endpoint = Endpoint("authentication/new/") + base_url = "https://api.example.com" + expected_result = "https://api.example.com/authentication/new/" + self.assertEqual(expected_result, endpoint.join_url(base_url)) + + def test_without_endpoint_trailing_slash(self) -> None: + """Test join_url with an endpoint that does not have a trailing slash""" + + endpoint = Endpoint("authentication/new") + base_url = "https://api.example.com" + expected_result = "https://api.example.com/authentication/new/" + self.assertEqual(expected_result, endpoint.join_url(base_url)) + + def test_with_complete_url_as_endpoint(self) -> None: + """Test join_url when the endpoint is a complete URL""" + + endpoint = Endpoint("https://anotherapi.com/authentication/new") + base_url = "https://api.example.com" + expected_result = "https://anotherapi.com/authentication/new/" + self.assertEqual(expected_result, endpoint.join_url(base_url)) diff --git a/tests/test_crud_methods.py b/tests/test_crud_methods.py deleted file mode 100644 index e3920ba..0000000 --- a/tests/test_crud_methods.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Tests for CRUD operations.""" - -from unittest import TestCase - -from keystone_client import KeystoneClient - - -class Retrieve(TestCase): - """Tests for retrieve methods""" - - def test_methods_exist(self) -> None: - """Test a method exists for each endpoint in the class schema""" - - client = KeystoneClient('http://test.domain.com') - for endpoint in KeystoneClient.schema._fields: - method_name = f'retrieve_{endpoint}' - self.assertTrue(hasattr(client, method_name), f'Method does not exist {method_name}')