Skip to content

Commit

Permalink
Added JWT auth and updated tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mchlwellman committed Jan 2, 2025
1 parent 6e014f1 commit 06236ab
Show file tree
Hide file tree
Showing 7 changed files with 234 additions and 64 deletions.
6 changes: 6 additions & 0 deletions .talismanrc
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
fileignoreconfig:
- filename: .env.example
checksum: f04bd8b8a51131d412bea4b8098e714ca01f4c7475eeb0a40d196d922070ba43
- filename: app/auth.py
checksum: 31e75b44f03afab392d8bfe0a05dc66db227e32b0bec57dd0069b7f38307976b
- filename: app/db/__init__.py
checksum: b079eb426725423f18be76949e3e10ca50ebfdd419f8768c2b7b58ddc41e5e01
- filename: app/db/models.py
Expand All @@ -19,4 +21,8 @@ fileignoreconfig:
checksum: 0e3ae2fd3a50245a8c143d31c4316b164b51161d2db6e660aa956b78eda1b4d8
- filename: tests/app/providers/test_provider_aws.py
checksum: f2ed526b1a81c5c0facfd4d1e7f1a3e26194778c1909a354c98c3b042945d34d
- filename: tests/app/test_auth.py
checksum: 43a63d5642c744fab8e703b3ee4a1e65d1aa51e4004304cd50cab36f2ec93594
- filename: tests/conftest.py
checksum: 3656032eac92d9070439d581529790de4f26dd810f4d09ae504afc455799a6f1
version: "1.0"
67 changes: 67 additions & 0 deletions app/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""This module contains authentication methods used to verify the JWT token sent by clients."""

import os
from typing import Optional

import jwt
from fastapi import HTTPException, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from loguru import logger

ADMIN_SECRET_KEY = os.environ['ENP_ADMIN_SECRET_KEY']
ALGORITHM = os.environ.get('ENP_ALGORITHM', 'HS256')
ACCESS_TOKEN_EXPIRE_SECONDS = int(os.environ.get('ENP_ACCESS_TOKEN_EXPIRE_SECONDS', 60))


class JWTBearer(HTTPBearer):
"""JWTBearer class to verify the JWT token sent by the client."""

def __init__(self, auto_error: bool = True) -> None:
"""Initialize the JWTBearer class.
Args:
auto_error (bool, optional): If True, raise an HTTPException if the token is invalid or expired. Defaults to True.
"""
super(JWTBearer, self).__init__(auto_error=auto_error)

async def __call__(self, request: Request) -> Optional[HTTPAuthorizationCredentials]:
"""Override the __call__ method to verify the JWT token. A JWT token is considered valid if it is not expired, and the signature is valid.
Args:
request (Request): FastAPI request object
Returns:
Optional[HTTPAuthorizationCredentials]: HTTPAuthorizationCredentials object if the token is valid, None otherwise.
Raises:
HTTPException: If the token is invalid or expired
"""
credentials: HTTPAuthorizationCredentials | None = await super(JWTBearer, self).__call__(request)
if credentials is None:
raise HTTPException(status_code=403, detail='Not authenticated')
if not self.verify_token(str(credentials.credentials)):
raise HTTPException(status_code=403, detail='Invalid token or expired token.')
return credentials

def verify_token(self, jwtoken: str) -> bool:
"""Verify the JWT token and check if it is expired.
Args:
jwtoken (str): JWT token
Returns:
bool: True if the token is valid and not expired, False otherwise.
"""
try:
payload = jwt.decode(
jwtoken,
ADMIN_SECRET_KEY,
algorithms=[ALGORITHM],
options={
'verify_signature': True,
},
)
logger.info('JWT payload: {}', payload)
return True
except (jwt.PyJWTError, jwt.ImmatureSignatureError):
return False
6 changes: 4 additions & 2 deletions app/legacy/v2/notifications/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import json

from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, status
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, status
from loguru import logger

from app.auth import JWTBearer
from app.dao.notifications_dao import dao_create_notification
from app.db.models import Notification, Template
from app.legacy.v2.notifications.route_schema import (
Expand All @@ -15,9 +16,10 @@
from app.routers import TimedAPIRoute

v2_notification_router = APIRouter(
dependencies=[Depends(JWTBearer())],
prefix='/legacy/v2/notifications',
tags=['v2 Notification Endpoints'],
route_class=TimedAPIRoute,
tags=['v2 Notification Endpoints'],
)


Expand Down
62 changes: 2 additions & 60 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,14 @@
import os
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Annotated, Any, AsyncContextManager, Callable, Mapping, Never
from typing import Any, AsyncContextManager, Callable, Mapping, Never

from fastapi import Depends, FastAPI, status
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from loguru import logger
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session

from app.db.db_init import (
close_db,
get_read_session_with_depends,
get_write_session_with_depends,
init_db,
)
from app.legacy.v2.notifications.rest import v2_notification_router
Expand Down Expand Up @@ -91,57 +87,3 @@ def simple_route() -> dict[str, str]:
"""
logger.info('Hello World')
return {'Hello': 'World'}


@app.post('/db/test', status_code=status.HTTP_201_CREATED)
async def test_db_create(
*,
data: str = 'hello',
db_session: Annotated[async_scoped_session[AsyncSession], Depends(get_write_session_with_depends)],
) -> dict[str, str]:
"""Test inserting Templates into the database. This is a temporary test endpoint.
Args:
data (str): The data to insert
db_session: The database session
Returns:
dict[str, str]: The inserted item
"""
from app.db.models import Template

template = Template(name=data)

async with db_session() as session:
session.add(template)
await session.commit()
return {
'id': str(template.id),
'name': template.name,
'created_at': str(template.created_at),
'updated_at': str(template.updated_at),
}


@app.get('/db/test', status_code=status.HTTP_200_OK)
async def test_db_read(
db_session: Annotated[async_scoped_session[AsyncSession], Depends(get_read_session_with_depends)],
) -> list[dict[str, str]]:
"""Test getting items from the database. This is a temporary test endpoint.
Args:
db_session: The database session
Returns:
list[dict[str,str]]: The items in the tests table
"""
from app.db.models import Template

items = []
async with db_session() as session:
results = await session.scalars(select(Template))
for r in results:
items.append({'id': str(r.id), 'name': r.name})
return items
6 changes: 4 additions & 2 deletions app/v3/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from fastapi import APIRouter
from fastapi import APIRouter, Depends

from app.auth import JWTBearer
from app.constants import RESPONSE_404
from app.routers import TimedAPIRoute
from app.v3.device_registrations.rest import api_router as device_registrations_router
from app.v3.notifications.rest import api_router as notifications_router

api_router = APIRouter(
dependencies=[Depends(JWTBearer())],
prefix='/v3',
tags=['v3 Endpoints'],
responses={404: {'description': RESPONSE_404}},
route_class=TimedAPIRoute,
tags=['v3 Endpoints'],
)

api_router.include_router(device_registrations_router)
Expand Down
103 changes: 103 additions & 0 deletions tests/app/test_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""Tests for the authentication module."""

import base64
import hmac
import json
import time
from typing import TypedDict

from fastapi.testclient import TestClient


class PayloadDict(TypedDict):
"""Payload dictionary type."""

iat: int
exp: int
jti: str


def _get_jwt_token(client_secret: str, payload_dict: PayloadDict) -> str:
"""Utility to generate a JWT token.
Args:
client_secret (str): Client secret
payload_dict (dict[str, str]): Payload dictionary
Returns:
str: a signed JWT token
"""
header_dict = {'typ': 'JWT', 'alg': 'HS256'}
header = json.dumps(header_dict)
payload = json.dumps(payload_dict)

header = base64.urlsafe_b64encode(bytes(str(header), 'utf-8')).decode().replace('=', '')
payload = base64.urlsafe_b64encode(bytes(str(payload), 'utf-8')).decode().replace('=', '')

signature = hmac.new(
bytes(client_secret, 'utf-8'), bytes(header + '.' + payload, 'utf-8'), digestmod='sha256'
).digest()
sigb64 = base64.urlsafe_b64encode(bytes(signature)).decode().replace('=', '')

token = header + '.' + payload + '.' + sigb64
return token


def test_missing_authorization_scheme(client: TestClient) -> None:
"""Test the invalid authorization scheme.
Args:
client (TestClient): FastAPI test client
"""
client_secret = 'not-very-secret'
current_timestamp = int(time.time())
payload_dict: PayloadDict = {
'iat': current_timestamp,
'exp': current_timestamp + 60,
'jti': 'jwt_nonce',
}
response = client.post(
'/v3/device-registrations', headers={'Authorization': f'{_get_jwt_token(client_secret, payload_dict)}'}
)
assert response.status_code == 403
assert response.json() == {'detail': 'Not authenticated'}


def test_expired_iat_in_token(client: TestClient) -> None:
"""Test the missing iat in token.
Args:
client (TestClient): FastAPI test client
"""
client_secret = 'not-very-secret'
current_timestamp = int(time.time())
payload_dict: PayloadDict = {
'iat': current_timestamp - 300,
'exp': current_timestamp - 240,
'jti': 'jwt_nonce',
}
response = client.post(
'/v3/device-registrations', headers={'Authorization': f'Bearer {_get_jwt_token(client_secret, payload_dict)}'}
)
assert response.status_code == 403
assert response.json() == {'detail': 'Invalid token or expired token.'}


def test_future_iat_in_token(client: TestClient) -> None:
"""Test the missing iat in token.
Args:
client (TestClient): FastAPI test client
"""
client_secret = 'not-very-secret'
current_timestamp = int(time.time())
payload_dict: PayloadDict = {
'iat': current_timestamp + 120,
'exp': current_timestamp + 180,
'jti': 'jwt_nonce',
}
response = client.post(
'/v3/device-registrations', headers={'Authorization': f'Bearer {_get_jwt_token(client_secret, payload_dict)}'}
)
assert response.status_code == 403
assert response.json() == {'detail': 'Invalid token or expired token.'}
48 changes: 48 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
"""Fixtures and setup to test the app."""

import base64
import hmac
import json
import time
from unittest.mock import Mock

import pytest
Expand All @@ -18,6 +22,50 @@ class ENPTestClient(TestClient):
"""

app: CustomFastAPI
token_expiry = 60

def __init__(self, app: CustomFastAPI) -> None:
"""Initialize the ENPTestClient.
Args:
app (CustomFastAPI): The FastAPI application instance.
"""
headers = {
'Authorization': f'Bearer {self.get_jwt_token('test', 'not-very-secret')}',
}
super().__init__(app, headers=headers)

def get_jwt_token(cls, client_id: str, client_secret: str) -> str:
"""Utility to generate a JWT token.
Args:
client_id (str): Client ID
client_secret (str): Client secret
Returns:
str: a signed JWT token
"""
header_dict = {'typ': 'JWT', 'alg': 'HS256'}
header = json.dumps(header_dict)
current_timestamp = int(time.time())
payload_dict = {
'iss': client_id,
'iat': current_timestamp,
'exp': current_timestamp + cls.token_expiry,
'jti': 'jwt_nonce',
}
payload = json.dumps(payload_dict)

header = base64.urlsafe_b64encode(bytes(str(header), 'utf-8')).decode().replace('=', '')
payload = base64.urlsafe_b64encode(bytes(str(payload), 'utf-8')).decode().replace('=', '')

signature = hmac.new(
bytes(client_secret, 'utf-8'), bytes(header + '.' + payload, 'utf-8'), digestmod='sha256'
).digest()
sigb64 = base64.urlsafe_b64encode(bytes(signature)).decode().replace('=', '')

token = header + '.' + payload + '.' + sigb64
return token


@pytest.fixture(scope='session')
Expand Down

0 comments on commit 06236ab

Please sign in to comment.