-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6e014f1
commit 06236ab
Showing
7 changed files
with
234 additions
and
64 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
"""This module contains authentication methods used to verify the JWT token sent by clients.""" | ||
|
||
import os | ||
from typing import Optional | ||
|
||
import jwt | ||
from fastapi import HTTPException, Request | ||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer | ||
from loguru import logger | ||
|
||
ADMIN_SECRET_KEY = os.environ['ENP_ADMIN_SECRET_KEY'] | ||
ALGORITHM = os.environ.get('ENP_ALGORITHM', 'HS256') | ||
ACCESS_TOKEN_EXPIRE_SECONDS = int(os.environ.get('ENP_ACCESS_TOKEN_EXPIRE_SECONDS', 60)) | ||
|
||
|
||
class JWTBearer(HTTPBearer): | ||
"""JWTBearer class to verify the JWT token sent by the client.""" | ||
|
||
def __init__(self, auto_error: bool = True) -> None: | ||
"""Initialize the JWTBearer class. | ||
Args: | ||
auto_error (bool, optional): If True, raise an HTTPException if the token is invalid or expired. Defaults to True. | ||
""" | ||
super(JWTBearer, self).__init__(auto_error=auto_error) | ||
|
||
async def __call__(self, request: Request) -> Optional[HTTPAuthorizationCredentials]: | ||
"""Override the __call__ method to verify the JWT token. A JWT token is considered valid if it is not expired, and the signature is valid. | ||
Args: | ||
request (Request): FastAPI request object | ||
Returns: | ||
Optional[HTTPAuthorizationCredentials]: HTTPAuthorizationCredentials object if the token is valid, None otherwise. | ||
Raises: | ||
HTTPException: If the token is invalid or expired | ||
""" | ||
credentials: HTTPAuthorizationCredentials | None = await super(JWTBearer, self).__call__(request) | ||
if credentials is None: | ||
raise HTTPException(status_code=403, detail='Not authenticated') | ||
if not self.verify_token(str(credentials.credentials)): | ||
raise HTTPException(status_code=403, detail='Invalid token or expired token.') | ||
return credentials | ||
|
||
def verify_token(self, jwtoken: str) -> bool: | ||
"""Verify the JWT token and check if it is expired. | ||
Args: | ||
jwtoken (str): JWT token | ||
Returns: | ||
bool: True if the token is valid and not expired, False otherwise. | ||
""" | ||
try: | ||
payload = jwt.decode( | ||
jwtoken, | ||
ADMIN_SECRET_KEY, | ||
algorithms=[ALGORITHM], | ||
options={ | ||
'verify_signature': True, | ||
}, | ||
) | ||
logger.info('JWT payload: {}', payload) | ||
return True | ||
except (jwt.PyJWTError, jwt.ImmatureSignatureError): | ||
return False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
"""Tests for the authentication module.""" | ||
|
||
import base64 | ||
import hmac | ||
import json | ||
import time | ||
from typing import TypedDict | ||
|
||
from fastapi.testclient import TestClient | ||
|
||
|
||
class PayloadDict(TypedDict): | ||
"""Payload dictionary type.""" | ||
|
||
iat: int | ||
exp: int | ||
jti: str | ||
|
||
|
||
def _get_jwt_token(client_secret: str, payload_dict: PayloadDict) -> str: | ||
"""Utility to generate a JWT token. | ||
Args: | ||
client_secret (str): Client secret | ||
payload_dict (dict[str, str]): Payload dictionary | ||
Returns: | ||
str: a signed JWT token | ||
""" | ||
header_dict = {'typ': 'JWT', 'alg': 'HS256'} | ||
header = json.dumps(header_dict) | ||
payload = json.dumps(payload_dict) | ||
|
||
header = base64.urlsafe_b64encode(bytes(str(header), 'utf-8')).decode().replace('=', '') | ||
payload = base64.urlsafe_b64encode(bytes(str(payload), 'utf-8')).decode().replace('=', '') | ||
|
||
signature = hmac.new( | ||
bytes(client_secret, 'utf-8'), bytes(header + '.' + payload, 'utf-8'), digestmod='sha256' | ||
).digest() | ||
sigb64 = base64.urlsafe_b64encode(bytes(signature)).decode().replace('=', '') | ||
|
||
token = header + '.' + payload + '.' + sigb64 | ||
return token | ||
|
||
|
||
def test_missing_authorization_scheme(client: TestClient) -> None: | ||
"""Test the invalid authorization scheme. | ||
Args: | ||
client (TestClient): FastAPI test client | ||
""" | ||
client_secret = 'not-very-secret' | ||
current_timestamp = int(time.time()) | ||
payload_dict: PayloadDict = { | ||
'iat': current_timestamp, | ||
'exp': current_timestamp + 60, | ||
'jti': 'jwt_nonce', | ||
} | ||
response = client.post( | ||
'/v3/device-registrations', headers={'Authorization': f'{_get_jwt_token(client_secret, payload_dict)}'} | ||
) | ||
assert response.status_code == 403 | ||
assert response.json() == {'detail': 'Not authenticated'} | ||
|
||
|
||
def test_expired_iat_in_token(client: TestClient) -> None: | ||
"""Test the missing iat in token. | ||
Args: | ||
client (TestClient): FastAPI test client | ||
""" | ||
client_secret = 'not-very-secret' | ||
current_timestamp = int(time.time()) | ||
payload_dict: PayloadDict = { | ||
'iat': current_timestamp - 300, | ||
'exp': current_timestamp - 240, | ||
'jti': 'jwt_nonce', | ||
} | ||
response = client.post( | ||
'/v3/device-registrations', headers={'Authorization': f'Bearer {_get_jwt_token(client_secret, payload_dict)}'} | ||
) | ||
assert response.status_code == 403 | ||
assert response.json() == {'detail': 'Invalid token or expired token.'} | ||
|
||
|
||
def test_future_iat_in_token(client: TestClient) -> None: | ||
"""Test the missing iat in token. | ||
Args: | ||
client (TestClient): FastAPI test client | ||
""" | ||
client_secret = 'not-very-secret' | ||
current_timestamp = int(time.time()) | ||
payload_dict: PayloadDict = { | ||
'iat': current_timestamp + 120, | ||
'exp': current_timestamp + 180, | ||
'jti': 'jwt_nonce', | ||
} | ||
response = client.post( | ||
'/v3/device-registrations', headers={'Authorization': f'Bearer {_get_jwt_token(client_secret, payload_dict)}'} | ||
) | ||
assert response.status_code == 403 | ||
assert response.json() == {'detail': 'Invalid token or expired token.'} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters