diff --git a/pyproject.toml b/pyproject.toml index a6583a07..67216ab2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,16 @@ env = [ "INST_DB_USER=user", "INST_DB_PWD=user", "INST_DB_HOST=localhost:5432", - "INST_DB_NAME=filing" + "INST_DB_NAME=filing", + "KC_URL=http://localhost", + "KC_REALM=", + "KC_ADMIN_CLIENT_ID=", + "KC_ADMIN_CLIENT_SECRET=", + "KC_REALM_URL=http://localhost", + "AUTH_URL=http://localhost", + "TOKEN_URL=http://localhost", + "CERTS_URL=http://localhost", + "AUTH_CLIENT=", ] testpaths = ["tests"] diff --git a/src/config.py b/src/config.py index 627df8c3..47b604df 100644 --- a/src/config.py +++ b/src/config.py @@ -6,6 +6,8 @@ from pydantic.networks import PostgresDsn from pydantic_settings import BaseSettings, SettingsConfigDict +from regtech_api_commons.oauth2.config import KeycloakSettings + env_files_to_load = [".env"] if os.getenv("ENV", "LOCAL") == "LOCAL": env_files_to_load.append(".env.local") @@ -39,3 +41,5 @@ def build_postgres_dsn(cls, postgres_dsn, info: ValidationInfo) -> Any: settings = Settings() + +kc_settings = KeycloakSettings(_env_file=env_files_to_load) diff --git a/src/main.py b/src/main.py index ee7ddf14..b3516718 100644 --- a/src/main.py +++ b/src/main.py @@ -1,12 +1,20 @@ import os from fastapi import FastAPI +from fastapi.security import OAuth2AuthorizationCodeBearer +from fastapi.middleware.cors import CORSMiddleware +from starlette.middleware.authentication import AuthenticationMiddleware + +from regtech_api_commons.oauth2.oauth2_backend import BearerTokenAuthBackend +from regtech_api_commons.oauth2.oauth2_admin import OAuth2Admin from routers import filing_router from alembic.config import Config from alembic import command +from config import kc_settings + app = FastAPI() @@ -19,4 +27,17 @@ async def app_start(): command.upgrade(alembic_cfg, "head") +token_bearer = OAuth2AuthorizationCodeBearer( + authorizationUrl=kc_settings.auth_url.unicode_string(), tokenUrl=kc_settings.token_url.unicode_string() +) + +app.add_middleware(AuthenticationMiddleware, backend=BearerTokenAuthBackend(token_bearer, OAuth2Admin(kc_settings))) +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_methods=["*"], + allow_headers=["authorization"], +) + + app.include_router(filing_router, prefix="/v1/filing") diff --git a/src/routers/filing.py b/src/routers/filing.py index 3f3bb363..acde7202 100644 --- a/src/routers/filing.py +++ b/src/routers/filing.py @@ -10,6 +10,8 @@ from sqlalchemy.ext.asyncio import AsyncSession +from starlette.authentication import requires + async def set_db(request: Request, session: Annotated[AsyncSession, Depends(get_session)]): request.state.db_session = session @@ -19,6 +21,7 @@ async def set_db(request: Request, session: Annotated[AsyncSession, Depends(get_ @router.get("/periods", response_model=List[FilingPeriodDTO]) +@requires("authenticated") async def get_filing_periods(request: Request): return await repo.get_filing_periods(request.state.db_session) @@ -33,5 +36,6 @@ async def upload_file( @router.get("/{lei}/filings/{filing_id}/submissions", response_model=List[SubmissionDTO]) +@requires("authenticated") async def get_submission(request: Request, lei: str, filing_id: int): return await repo.get_submissions(request.state.db_session, filing_id) diff --git a/tests/api/conftest.py b/tests/api/conftest.py index 4ead5458..5bccf6bf 100644 --- a/tests/api/conftest.py +++ b/tests/api/conftest.py @@ -7,6 +7,9 @@ from entities.models import FilingPeriodDAO, FilingType +from regtech_api_commons.models.auth import AuthenticatedUser +from starlette.authentication import AuthCredentials, UnauthenticatedUser + @pytest.fixture def app_fixture(mocker: MockerFixture) -> FastAPI: @@ -15,6 +18,32 @@ def app_fixture(mocker: MockerFixture) -> FastAPI: return app +@pytest.fixture +def auth_mock(mocker: MockerFixture) -> Mock: + return mocker.patch("regtech_api_commons.oauth2.oauth2_backend.BearerTokenAuthBackend.authenticate") + + +@pytest.fixture +def authed_user_mock(auth_mock: Mock) -> Mock: + claims = { + "name": "test", + "preferred_username": "test_user", + "email": "test@local.host", + "institutions": "123456ABCDEF, 654321FEDCBA", + } + auth_mock.return_value = ( + AuthCredentials(["authenticated"]), + AuthenticatedUser.from_claim(claims), + ) + return auth_mock + + +@pytest.fixture +def unauthed_user_mock(auth_mock: Mock) -> Mock: + auth_mock.return_value = (AuthCredentials("unauthenticated"), UnauthenticatedUser()) + return auth_mock + + @pytest.fixture def get_filing_period_mock(mocker: MockerFixture) -> Mock: mock = mocker.patch("entities.repos.submission_repo.get_filing_periods") diff --git a/tests/api/routers/test_filing_api.py b/tests/api/routers/test_filing_api.py index 2a695fa1..9e711c11 100644 --- a/tests/api/routers/test_filing_api.py +++ b/tests/api/routers/test_filing_api.py @@ -8,14 +8,30 @@ class TestFilingApi: - def test_get_periods(self, mocker: MockerFixture, app_fixture: FastAPI, get_filing_period_mock: Mock): + def test_unauthed_get_periods( + self, mocker: MockerFixture, app_fixture: FastAPI, get_filing_period_mock: Mock, unauthed_user_mock: Mock + ): + client = TestClient(app_fixture) + res = client.get("/v1/filing/periods") + assert res.status_code == 403 + + def test_get_periods( + self, mocker: MockerFixture, app_fixture: FastAPI, get_filing_period_mock: Mock, authed_user_mock: Mock + ): client = TestClient(app_fixture) res = client.get("/v1/filing/periods") assert res.status_code == 200 assert len(res.json()) == 1 assert res.json()[0]["name"] == "FilingPeriod2024" - async def test_get_submissions(self, mocker: MockerFixture, app_fixture: FastAPI): + def test_unauthed_get_submissions( + self, mocker: MockerFixture, app_fixture: FastAPI, get_filing_period_mock: Mock, unauthed_user_mock: Mock + ): + client = TestClient(app_fixture) + res = client.get("/v1/filing/123456790/filings/1/submissions") + assert res.status_code == 403 + + async def test_get_submissions(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): mock = mocker.patch("entities.repos.submission_repo.get_submissions") mock.return_value = [ SubmissionDAO(