Skip to content

Commit

Permalink
feat: test black and ruff config
Browse files Browse the repository at this point in the history
  • Loading branch information
lchen-2101 committed Sep 18, 2023
1 parent fa4e532 commit 8b197b4
Show file tree
Hide file tree
Showing 14 changed files with 60 additions and 183 deletions.
13 changes: 3 additions & 10 deletions src/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,17 @@
}


async def check_domain(
request: Request, session: Annotated[AsyncSession, Depends(get_session)]
) -> None:
async def check_domain(request: Request, session: Annotated[AsyncSession, Depends(get_session)]) -> None:
if request_needs_domain_check(request):
if not request.user.is_authenticated:
raise HTTPException(status_code=HTTPStatus.FORBIDDEN)
if await email_domain_denied(session, request.user.email):
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN, detail="email domain denied"
)
raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="email domain denied")


def request_needs_domain_check(request: Request) -> bool:
path = request.scope["path"].rstrip("/")
return not (
path in OPEN_DOMAIN_REQUESTS
and request.scope["method"] in OPEN_DOMAIN_REQUESTS[path]
)
return not (path in OPEN_DOMAIN_REQUESTS and request.scope["method"] in OPEN_DOMAIN_REQUESTS[path])


async def email_domain_denied(session: AsyncSession, email: str) -> bool:
Expand Down
8 changes: 2 additions & 6 deletions src/entities/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,8 @@

DB_URL = os.getenv("INST_CONN")
DB_SCHEMA = os.getenv("INST_DB_SCHEMA", "public")
engine = create_async_engine(DB_URL, echo=True).execution_options(
schema_translate_map={None: DB_SCHEMA}
)
SessionLocal = async_scoped_session(
async_sessionmaker(engine, expire_on_commit=False), current_task
)
engine = create_async_engine(DB_URL, echo=True).execution_options(schema_translate_map={None: DB_SCHEMA})
SessionLocal = async_scoped_session(async_sessionmaker(engine, expire_on_commit=False), current_task)


async def get_session():
Expand Down
4 changes: 1 addition & 3 deletions src/entities/models/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ class FinancialInstitutionDao(AuditMixin, Base):
class FinancialInstitutionDomainDao(AuditMixin, Base):
__tablename__ = "financial_institution_domains"
domain: Mapped[str] = mapped_column(index=True, primary_key=True)
lei: Mapped[str] = mapped_column(
ForeignKey("financial_institutions.lei"), index=True, primary_key=True
)
lei: Mapped[str] = mapped_column(ForeignKey("financial_institutions.lei"), index=True, primary_key=True)
fi = relationship("FinancialInstitutionDao", back_populates="domains")


Expand Down
8 changes: 2 additions & 6 deletions src/entities/repos/institutions_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,9 @@ async def get_institution(session: AsyncSession, lei: str) -> FinancialInstituti
return await session.scalar(stmt)


async def upsert_institution(
session: AsyncSession, fi: FinancialInstitutionDto
) -> FinancialInstitutionDao:
async def upsert_institution(session: AsyncSession, fi: FinancialInstitutionDto) -> FinancialInstitutionDao:
async with session.begin():
stmt = select(FinancialInstitutionDao).filter(
FinancialInstitutionDao.lei == fi.lei
)
stmt = select(FinancialInstitutionDao).filter(FinancialInstitutionDao.lei == fi.lei)
res = await session.execute(stmt)
db_fi = res.scalar_one_or_none()
if db_fi is None:
Expand Down
20 changes: 5 additions & 15 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,33 +19,23 @@


@app.exception_handler(HTTPException)
async def http_exception_handler(
request: Request, exception: HTTPException
) -> JSONResponse:
async def http_exception_handler(request: Request, exception: HTTPException) -> JSONResponse:
log.error(exception, exc_info=True, stack_info=True)
return JSONResponse(
status_code=exception.status_code, content={"detail": exception.detail}
)
return JSONResponse(status_code=exception.status_code, content={"detail": exception.detail})


@app.exception_handler(Exception)
async def general_exception_handler(
request: Request, exception: Exception
) -> JSONResponse:
async def general_exception_handler(request: Request, exception: Exception) -> JSONResponse:
log.error(exception, exc_info=True, stack_info=True)
return JSONResponse(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
content={"detail": "server error"},
)


oauth2_scheme = OAuth2AuthorizationCodeBearer(
authorizationUrl=os.getenv("AUTH_URL"), tokenUrl=os.getenv("TOKEN_URL")
)
oauth2_scheme = OAuth2AuthorizationCodeBearer(authorizationUrl=os.getenv("AUTH_URL"), tokenUrl=os.getenv("TOKEN_URL"))

app.add_middleware(
AuthenticationMiddleware, backend=BearerTokenAuthBackend(oauth2_scheme)
)
app.add_middleware(AuthenticationMiddleware, backend=BearerTokenAuthBackend(oauth2_scheme))
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
Expand Down
12 changes: 3 additions & 9 deletions src/oauth2/oauth2_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,7 @@ def update_user(self, user_id: str, payload: Dict[str, Any]) -> None:
self._admin.update_user(user_id, payload)
except kce.KeycloakError as e:
log.exception("Failed to update user: %s", user_id, extra=payload)
raise HTTPException(
status_code=e.response_code, detail="Failed to update user"
)
raise HTTPException(status_code=e.response_code, detail="Failed to update user")

def upsert_group(self, lei: str, name: str) -> str:
try:
Expand All @@ -65,9 +63,7 @@ def upsert_group(self, lei: str, name: str) -> str:
return group["id"]
except kce.KeycloakError as e:
log.exception("Failed to upsert group, lei: %s, name: %s", lei, name)
raise HTTPException(
status_code=e.response_code, detail="Failed to upsert group"
)
raise HTTPException(status_code=e.response_code, detail="Failed to upsert group")

def get_group(self, lei: str) -> Dict[str, Any] | None:
try:
Expand All @@ -80,9 +76,7 @@ def associate_to_group(self, user_id: str, group_id: str) -> None:
self._admin.group_user_add(user_id, group_id)
except kce.KeycloakError as e:
log.exception("Failed to associate user %s to group %s", user_id, group_id)
raise HTTPException(
status_code=e.response_code, detail="Failed to associate user to group"
)
raise HTTPException(status_code=e.response_code, detail="Failed to associate user to group")

def associate_to_lei(self, user_id: str, lei: str) -> None:
group = self.get_group(lei)
Expand Down
8 changes: 2 additions & 6 deletions src/oauth2/oauth2_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,13 @@ class BearerTokenAuthBackend(AuthenticationBackend):
def __init__(self, token_bearer: OAuth2AuthorizationCodeBearer) -> None:
self.token_bearer = token_bearer

async def authenticate(
self, conn: HTTPConnection
) -> Coroutine[Any, Any, Tuple[AuthCredentials, BaseUser] | None]:
async def authenticate(self, conn: HTTPConnection) -> Coroutine[Any, Any, Tuple[AuthCredentials, BaseUser] | None]:
try:
token = await self.token_bearer(conn)
claims = oauth2_admin.get_claims(token)
if claims is not None:
auths = (
self.extract_nested(
claims, "resource_access", "realm-management", "roles"
)
self.extract_nested(claims, "resource_access", "realm-management", "roles")
+ self.extract_nested(claims, "resource_access", "account", "roles")
+ ["authenticated"]
)
Expand Down
8 changes: 2 additions & 6 deletions src/routers/institutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
from starlette.authentication import requires


async def set_db(
request: Request, session: Annotated[AsyncSession, Depends(get_session)]
):
async def set_db(request: Request, session: Annotated[AsyncSession, Depends(get_session)]):
request.state.db_session = session


Expand All @@ -34,9 +32,7 @@ async def get_institutions(
page: int = 0,
count: int = 100,
):
return await repo.get_institutions(
request.state.db_session, leis, domain, page, count
)
return await repo.get_institutions(request.state.db_session, leis, domain, page, count)


@router.post("/", response_model=Tuple[str, FinancialInstitutionDto])
Expand Down
4 changes: 1 addition & 3 deletions src/util/router_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ def api_route(
if path.endswith("/"):
path = path[:-1]

add_path = super().api_route(
path, include_in_schema=include_in_schema, **kwargs
)
add_path = super().api_route(path, include_in_schema=include_in_schema, **kwargs)

add_alt_path = super().api_route(f"{path}/", include_in_schema=False, **kwargs)

Expand Down
4 changes: 1 addition & 3 deletions tests/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ def authed_user_mock(auth_mock: Mock) -> Mock:
"sub": "testuser123",
}
auth_mock.return_value = (
AuthCredentials(
["manage-account", "query-groups", "manage-users", "authenticated"]
),
AuthCredentials(["manage-account", "query-groups", "manage-users", "authenticated"]),
AuthenticatedUser.from_claim(claims),
)
return auth_mock
Expand Down
24 changes: 6 additions & 18 deletions tests/api/routers/test_admin_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,15 @@ def test_get_me_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock):
res = client.get("/v1/admin/me")
assert res.status_code == 403

def test_get_me_authed(
self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock
):
def test_get_me_authed(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
client = TestClient(app_fixture)
res = client.get("/v1/admin/me")
assert res.status_code == 200
assert res.json().get("name") == "test"

def test_update_me_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock):
client = TestClient(app_fixture)
res = client.put(
"/v1/admin/me", json={"firstName": "testFirst", "lastName": "testLast"}
)
res = client.put("/v1/admin/me", json={"firstName": "testFirst", "lastName": "testLast"})
assert res.status_code == 403

def test_update_me_no_permission(self, app_fixture: FastAPI, auth_mock: Mock):
Expand All @@ -41,14 +37,10 @@ def test_update_me_no_permission(self, app_fixture: FastAPI, auth_mock: Mock):
AuthenticatedUser.from_claim(claims),
)
client = TestClient(app_fixture)
res = client.put(
"/v1/admin/me", json={"firstName": "testFirst", "lastName": "testLast"}
)
res = client.put("/v1/admin/me", json={"firstName": "testFirst", "lastName": "testLast"})
assert res.status_code == 403

def test_update_me(
self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock
):
def test_update_me(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
update_user_mock = mocker.patch("oauth2.oauth2_admin.OAuth2Admin.update_user")
update_user_mock.return_value = None
client = TestClient(app_fixture)
Expand All @@ -57,12 +49,8 @@ def test_update_me(
update_user_mock.assert_called_once_with("testuser123", data)
assert res.status_code == 202

def test_associate_institutions(
self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock
):
associate_lei_mock = mocker.patch(
"oauth2.oauth2_admin.OAuth2Admin.associate_to_lei"
)
def test_associate_institutions(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
associate_lei_mock = mocker.patch("oauth2.oauth2_admin.OAuth2Admin.associate_to_lei")
associate_lei_mock.return_value = None
client = TestClient(app_fixture)
data = ["testlei1", "testlei2"]
Expand Down
Loading

0 comments on commit 8b197b4

Please sign in to comment.