Skip to content

Commit

Permalink
Added Oauth2/Keycloak configs, decorators
Browse files Browse the repository at this point in the history
  • Loading branch information
jcadam14 committed Feb 5, 2024
1 parent 644a803 commit a32c520
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 3 deletions.
11 changes: 10 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,16 @@ env = [
"INST_DB_USER=user",
"INST_DB_PWD=user",
"INST_DB_HOST=localhost:5432",
"INST_DB_NAME=filing"
"INST_DB_NAME=filing",
"KC_URL=http://localhost",
"KC_REALM=",
"KC_ADMIN_CLIENT_ID=",
"KC_ADMIN_CLIENT_SECRET=",
"KC_REALM_URL=http://localhost",
"AUTH_URL=http://localhost",
"TOKEN_URL=http://localhost",
"CERTS_URL=http://localhost",
"AUTH_CLIENT=",
]
testpaths = ["tests"]

Expand Down
4 changes: 4 additions & 0 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from pydantic.networks import PostgresDsn
from pydantic_settings import BaseSettings, SettingsConfigDict

from regtech_api_commons.oauth2.config import KeycloakSettings

env_files_to_load = [".env"]
if os.getenv("ENV", "LOCAL") == "LOCAL":
env_files_to_load.append(".env.local")
Expand Down Expand Up @@ -39,3 +41,5 @@ def build_postgres_dsn(cls, postgres_dsn, info: ValidationInfo) -> Any:


settings = Settings()

kc_settings = KeycloakSettings(_env_file=env_files_to_load)
21 changes: 21 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
import os

from fastapi import FastAPI
from fastapi.security import OAuth2AuthorizationCodeBearer
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.authentication import AuthenticationMiddleware

from regtech_api_commons.oauth2.oauth2_backend import BearerTokenAuthBackend
from regtech_api_commons.oauth2.oauth2_admin import OAuth2Admin

from routers import filing_router

from alembic.config import Config
from alembic import command

from config import kc_settings

app = FastAPI()


Expand All @@ -19,4 +27,17 @@ async def app_start():
command.upgrade(alembic_cfg, "head")


token_bearer = OAuth2AuthorizationCodeBearer(
authorizationUrl=kc_settings.auth_url.unicode_string(), tokenUrl=kc_settings.token_url.unicode_string()
)

app.add_middleware(AuthenticationMiddleware, backend=BearerTokenAuthBackend(token_bearer, OAuth2Admin(kc_settings)))
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["authorization"],
)


app.include_router(filing_router, prefix="/v1/filing")
4 changes: 4 additions & 0 deletions src/routers/filing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

from sqlalchemy.ext.asyncio import AsyncSession

from starlette.authentication import requires


async def set_db(request: Request, session: Annotated[AsyncSession, Depends(get_session)]):
request.state.db_session = session
Expand All @@ -19,6 +21,7 @@ async def set_db(request: Request, session: Annotated[AsyncSession, Depends(get_


@router.get("/periods", response_model=List[FilingPeriodDTO])
@requires("authenticated")
async def get_filing_periods(request: Request):
return await repo.get_filing_periods(request.state.db_session)

Expand All @@ -33,5 +36,6 @@ async def upload_file(


@router.get("/{lei}/filings/{filing_id}/submissions", response_model=List[SubmissionDTO])
@requires("authenticated")
async def get_submission(request: Request, lei: str, filing_id: int):
return await repo.get_submissions(request.state.db_session, filing_id)
29 changes: 29 additions & 0 deletions tests/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

from entities.models import FilingPeriodDAO, FilingType

from regtech_api_commons.models.auth import AuthenticatedUser
from starlette.authentication import AuthCredentials, UnauthenticatedUser


@pytest.fixture
def app_fixture(mocker: MockerFixture) -> FastAPI:
Expand All @@ -15,6 +18,32 @@ def app_fixture(mocker: MockerFixture) -> FastAPI:
return app


@pytest.fixture
def auth_mock(mocker: MockerFixture) -> Mock:
return mocker.patch("regtech_api_commons.oauth2.oauth2_backend.BearerTokenAuthBackend.authenticate")


@pytest.fixture
def authed_user_mock(auth_mock: Mock) -> Mock:
claims = {
"name": "test",
"preferred_username": "test_user",
"email": "[email protected]",
"institutions": "123456ABCDEF, 654321FEDCBA",
}
auth_mock.return_value = (
AuthCredentials(["authenticated"]),
AuthenticatedUser.from_claim(claims),
)
return auth_mock


@pytest.fixture
def unauthed_user_mock(auth_mock: Mock) -> Mock:
auth_mock.return_value = (AuthCredentials("unauthenticated"), UnauthenticatedUser())
return auth_mock


@pytest.fixture
def get_filing_period_mock(mocker: MockerFixture) -> Mock:
mock = mocker.patch("entities.repos.submission_repo.get_filing_periods")
Expand Down
20 changes: 18 additions & 2 deletions tests/api/routers/test_filing_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,30 @@


class TestFilingApi:
def test_get_periods(self, mocker: MockerFixture, app_fixture: FastAPI, get_filing_period_mock: Mock):
def test_unauthed_get_periods(
self, mocker: MockerFixture, app_fixture: FastAPI, get_filing_period_mock: Mock, unauthed_user_mock: Mock
):
client = TestClient(app_fixture)
res = client.get("/v1/filing/periods")
assert res.status_code == 403

def test_get_periods(
self, mocker: MockerFixture, app_fixture: FastAPI, get_filing_period_mock: Mock, authed_user_mock: Mock
):
client = TestClient(app_fixture)
res = client.get("/v1/filing/periods")
assert res.status_code == 200
assert len(res.json()) == 1
assert res.json()[0]["name"] == "FilingPeriod2024"

async def test_get_submissions(self, mocker: MockerFixture, app_fixture: FastAPI):
def test_unauthed_get_submissions(
self, mocker: MockerFixture, app_fixture: FastAPI, get_filing_period_mock: Mock, unauthed_user_mock: Mock
):
client = TestClient(app_fixture)
res = client.get("/v1/filing/123456790/filings/1/submissions")
assert res.status_code == 403

async def test_get_submissions(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
mock = mocker.patch("entities.repos.submission_repo.get_submissions")
mock.return_value = [
SubmissionDAO(
Expand Down

0 comments on commit a32c520

Please sign in to comment.