Skip to content

Commit

Permalink
Formatting and QA inspections (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
djperrefort authored Jul 2, 2024
1 parent 2405f73 commit a358e69
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 23 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,3 @@ A light-weight Python client for wrapping the Keystone API.
- [Keystone-API](https://github.com/pitt-crc/keystone-api): Backend REST API for managing HPC allocations and resources.
- [Keystone-Web](https://github.com/pitt-crc/keystone-web): Website frontend for HPC administration and self-service.
- [Keystone-Docs](https://github.com/pitt-crc/keystone-docs): Documentation for the Keystone project and its components.

2 changes: 1 addition & 1 deletion keystone_client/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .client import *
from .client import *
53 changes: 33 additions & 20 deletions keystone_client/client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
"""Keystone API Client
This module provides a client class `KeystoneAPIClient` for interacting with the
Keystone API. It streamlines communication with the API, providing methods for
authentication, data retrieval, and data manipulation.
"""

from __future__ import annotations

from collections import namedtuple
Expand All @@ -12,17 +19,17 @@
__all__ = ["KeystoneClient"]

# Custom types
ContentType = Literal['json', 'text', 'content']
ContentType = Literal["json", "text", "content"]
ResponseContent = Union[Dict[str, Any], str, bytes]
QueryResult = Union[None, dict, List[dict]]
HTTPMethod = Literal['get', 'post', 'put', 'patch', 'delete']
HTTPMethod = Literal["get", "post", "put", "patch", "delete"]

# API schema mapping human-readable, python-friendly names to API endpoints
Schema = namedtuple('Schema', [
'allocations',
'requests',
'research_groups',
'users',
Schema = namedtuple("Schema", [
"allocations",
"requests",
"research_groups",
"users",
])


Expand All @@ -37,10 +44,10 @@ class KeystoneClient:
authentication_blacklist = "authentication/blacklist/"
authentication_refresh = "authentication/refresh/"
schema = Schema(
allocations='allocations/allocations/',
requests='allocations/requests/',
research_groups='users/researchgroups/',
users='users/users/',
allocations="allocations/allocations/",
requests="allocations/requests/",
research_groups="users/researchgroups/",
users="users/users/",
)

def __init__(self, url: str, auto_refresh: bool = True) -> None:
Expand Down Expand Up @@ -71,10 +78,10 @@ def __new__(cls, *args, **kwargs) -> KeystoneClient:
for key, endpoint in zip(cls.schema._fields, cls.schema):

# Create a retrieve method
retrieve_name = f'retrieve_{key}'
retrieve_name = f"retrieve_{key}"
if not hasattr(instance, retrieve_name):
retrieve_method = partial(instance._retrieve_records, _endpoint=endpoint)
setattr(instance, f'retrieve_{key}', retrieve_method)
setattr(instance, f"retrieve_{key}", retrieve_method)

return instance

Expand All @@ -101,7 +108,7 @@ def _retrieve_records(
"""

if pk is not None:
_endpoint = f'{_endpoint}/{pk}/'
_endpoint = f"{_endpoint}/{pk}/"

try:
response = self.http_get(_endpoint, params=filters, timeout=timeout)
Expand Down Expand Up @@ -129,7 +136,13 @@ def _get_headers(self) -> Dict[str, str]:
"Content-Type": "application/json"
}

def _send_request(self, method: HTTPMethod, url: str, timeout: int = default_timeout, **kwargs) -> requests.Response:
def _send_request(
self,
method: HTTPMethod,
url: str,
timeout: int = default_timeout,
**kwargs
) -> requests.Response:
"""Send an HTTP request
Args:
Expand All @@ -142,7 +155,7 @@ def _send_request(self, method: HTTPMethod, url: str, timeout: int = default_tim
"""

if self.auto_refresh:
self.refresh(force=False, timeout=timeout)
self._refresh_tokens(force=False, timeout=timeout)

response = requests.request(method, url, **kwargs)
response.raise_for_status()
Expand Down Expand Up @@ -298,12 +311,12 @@ def login(self, username: str, password: str, timeout: int = default_timeout) ->
response.raise_for_status()

# Parse data from the refresh token
refresh_payload = jwt.decode(self._refresh_token, options={"verify_signature": False})
refresh_payload = jwt.decode(self._refresh_token)
self._refresh_token = response.json().get("refresh")
self._refresh_expiration = datetime.fromtimestamp(refresh_payload["exp"])

# Parse data from the access token
access_payload = jwt.decode(self._access_token, options={"verify_signature": False})
access_payload = jwt.decode(self._access_token)
self._access_token = response.json().get("access")
self._access_expiration = datetime.fromtimestamp(access_payload["exp"])

Expand Down Expand Up @@ -332,7 +345,7 @@ def logout(self, timeout: int = default_timeout) -> None:
self._access_token = None
self._access_expiration = None

def refresh(self, force: bool = True, timeout: int = default_timeout) -> None:
def _refresh_tokens(self, force: bool = True, timeout: int = default_timeout) -> None:
"""Refresh the JWT access token
Args:
Expand All @@ -347,7 +360,7 @@ def refresh(self, force: bool = True, timeout: int = default_timeout) -> None:

# Alert the user when a refresh is not possible
if self._refresh_expiration > now:
raise RuntimeError('Refresh token has expired. Login again to continue.')
raise RuntimeError("Refresh token has expired. Login again to continue.")

response = requests.post(
f"{self.url}/{self.authentication_refresh}",
Expand Down
6 changes: 5 additions & 1 deletion tests/test_dummy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
"""Placeholder test module"""

from unittest import TestCase


class TestDummy(TestCase):
"""Placeholder test class"""

def test_dummy(self):
pass
"""Placeholder test"""

0 comments on commit a358e69

Please sign in to comment.