Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/30 gh actions update #32

Merged
merged 2 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/linters.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@ name: Linters
on: [push]

jobs:
linting:
black:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: psf/black@stable
ruff:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: chartboost/ruff-action@v1
32 changes: 17 additions & 15 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ version = "0.1.0"
description = ""
authors = ["lchen-2101 <[email protected]>"]
readme = "README.md"
packages = [{include = "regtech-user-fi-management"}]
packages = [{ include = "regtech-user-fi-management" }]

[tool.poetry.dependencies]
python = "^3.11"
Expand All @@ -30,22 +30,24 @@ pytest-mock = "^3.11.1"

[tool.pytest.ini_options]
asyncio_mode = "auto"
pythonpath = [
"src"
]
pythonpath = ["src"]
addopts = [
"--cov-report=term-missing",
"--cov-branch",
"--cov-report=xml",
"--cov-report=term",
"--cov=src",
"-vv",
"--strict-markers",
"-rfE",
]
testpaths = [
"tests",
"--cov-report=term-missing",
"--cov-branch",
"--cov-report=xml",
"--cov-report=term",
"--cov=src",
"-vv",
"--strict-markers",
"-rfE",
]
testpaths = ["tests"]

[tool.black]
line-length = 120

[tool.ruff]
line-length = 120

[tool.coverage.run]
relative_files = true
Expand Down
13 changes: 3 additions & 10 deletions src/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,17 @@
}


async def check_domain(
request: Request, session: Annotated[AsyncSession, Depends(get_session)]
) -> None:
async def check_domain(request: Request, session: Annotated[AsyncSession, Depends(get_session)]) -> None:
if request_needs_domain_check(request):
if not request.user.is_authenticated:
raise HTTPException(status_code=HTTPStatus.FORBIDDEN)
if await email_domain_denied(session, request.user.email):
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN, detail="email domain denied"
)
raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="email domain denied")


def request_needs_domain_check(request: Request) -> bool:
path = request.scope["path"].rstrip("/")
return not (
path in OPEN_DOMAIN_REQUESTS
and request.scope["method"] in OPEN_DOMAIN_REQUESTS[path]
)
return not (path in OPEN_DOMAIN_REQUESTS and request.scope["method"] in OPEN_DOMAIN_REQUESTS[path])


async def email_domain_denied(session: AsyncSession, email: str) -> bool:
Expand Down
8 changes: 2 additions & 6 deletions src/entities/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,8 @@

DB_URL = os.getenv("INST_CONN")
DB_SCHEMA = os.getenv("INST_DB_SCHEMA", "public")
engine = create_async_engine(DB_URL, echo=True).execution_options(
schema_translate_map={None: DB_SCHEMA}
)
SessionLocal = async_scoped_session(
async_sessionmaker(engine, expire_on_commit=False), current_task
)
engine = create_async_engine(DB_URL, echo=True).execution_options(schema_translate_map={None: DB_SCHEMA})
SessionLocal = async_scoped_session(async_sessionmaker(engine, expire_on_commit=False), current_task)


async def get_session():
Expand Down
4 changes: 1 addition & 3 deletions src/entities/models/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ class FinancialInstitutionDao(AuditMixin, Base):
class FinancialInstitutionDomainDao(AuditMixin, Base):
__tablename__ = "financial_institution_domains"
domain: Mapped[str] = mapped_column(index=True, primary_key=True)
lei: Mapped[str] = mapped_column(
ForeignKey("financial_institutions.lei"), index=True, primary_key=True
)
lei: Mapped[str] = mapped_column(ForeignKey("financial_institutions.lei"), index=True, primary_key=True)
fi = relationship("FinancialInstitutionDao", back_populates="domains")


Expand Down
8 changes: 2 additions & 6 deletions src/entities/repos/institutions_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,9 @@ async def get_institution(session: AsyncSession, lei: str) -> FinancialInstituti
return await session.scalar(stmt)


async def upsert_institution(
session: AsyncSession, fi: FinancialInstitutionDto
) -> FinancialInstitutionDao:
async def upsert_institution(session: AsyncSession, fi: FinancialInstitutionDto) -> FinancialInstitutionDao:
async with session.begin():
stmt = select(FinancialInstitutionDao).filter(
FinancialInstitutionDao.lei == fi.lei
)
stmt = select(FinancialInstitutionDao).filter(FinancialInstitutionDao.lei == fi.lei)
res = await session.execute(stmt)
db_fi = res.scalar_one_or_none()
if db_fi is None:
Expand Down
20 changes: 5 additions & 15 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,33 +19,23 @@


@app.exception_handler(HTTPException)
async def http_exception_handler(
request: Request, exception: HTTPException
) -> JSONResponse:
async def http_exception_handler(request: Request, exception: HTTPException) -> JSONResponse:
log.error(exception, exc_info=True, stack_info=True)
return JSONResponse(
status_code=exception.status_code, content={"detail": exception.detail}
)
return JSONResponse(status_code=exception.status_code, content={"detail": exception.detail})


@app.exception_handler(Exception)
async def general_exception_handler(
request: Request, exception: Exception
) -> JSONResponse:
async def general_exception_handler(request: Request, exception: Exception) -> JSONResponse:
log.error(exception, exc_info=True, stack_info=True)
return JSONResponse(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
content={"detail": "server error"},
)


oauth2_scheme = OAuth2AuthorizationCodeBearer(
authorizationUrl=os.getenv("AUTH_URL"), tokenUrl=os.getenv("TOKEN_URL")
)
oauth2_scheme = OAuth2AuthorizationCodeBearer(authorizationUrl=os.getenv("AUTH_URL"), tokenUrl=os.getenv("TOKEN_URL"))

app.add_middleware(
AuthenticationMiddleware, backend=BearerTokenAuthBackend(oauth2_scheme)
)
app.add_middleware(AuthenticationMiddleware, backend=BearerTokenAuthBackend(oauth2_scheme))
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
Expand Down
12 changes: 3 additions & 9 deletions src/oauth2/oauth2_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,7 @@ def update_user(self, user_id: str, payload: Dict[str, Any]) -> None:
self._admin.update_user(user_id, payload)
except kce.KeycloakError as e:
log.exception("Failed to update user: %s", user_id, extra=payload)
raise HTTPException(
status_code=e.response_code, detail="Failed to update user"
)
raise HTTPException(status_code=e.response_code, detail="Failed to update user")

def upsert_group(self, lei: str, name: str) -> str:
try:
Expand All @@ -65,9 +63,7 @@ def upsert_group(self, lei: str, name: str) -> str:
return group["id"]
except kce.KeycloakError as e:
log.exception("Failed to upsert group, lei: %s, name: %s", lei, name)
raise HTTPException(
status_code=e.response_code, detail="Failed to upsert group"
)
raise HTTPException(status_code=e.response_code, detail="Failed to upsert group")

def get_group(self, lei: str) -> Dict[str, Any] | None:
try:
Expand All @@ -80,9 +76,7 @@ def associate_to_group(self, user_id: str, group_id: str) -> None:
self._admin.group_user_add(user_id, group_id)
except kce.KeycloakError as e:
log.exception("Failed to associate user %s to group %s", user_id, group_id)
raise HTTPException(
status_code=e.response_code, detail="Failed to associate user to group"
)
raise HTTPException(status_code=e.response_code, detail="Failed to associate user to group")

def associate_to_lei(self, user_id: str, lei: str) -> None:
group = self.get_group(lei)
Expand Down
8 changes: 2 additions & 6 deletions src/oauth2/oauth2_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,13 @@ class BearerTokenAuthBackend(AuthenticationBackend):
def __init__(self, token_bearer: OAuth2AuthorizationCodeBearer) -> None:
self.token_bearer = token_bearer

async def authenticate(
self, conn: HTTPConnection
) -> Coroutine[Any, Any, Tuple[AuthCredentials, BaseUser] | None]:
async def authenticate(self, conn: HTTPConnection) -> Coroutine[Any, Any, Tuple[AuthCredentials, BaseUser] | None]:
try:
token = await self.token_bearer(conn)
claims = oauth2_admin.get_claims(token)
if claims is not None:
auths = (
self.extract_nested(
claims, "resource_access", "realm-management", "roles"
)
self.extract_nested(claims, "resource_access", "realm-management", "roles")
+ self.extract_nested(claims, "resource_access", "account", "roles")
+ ["authenticated"]
)
Expand Down
8 changes: 2 additions & 6 deletions src/routers/institutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
from starlette.authentication import requires


async def set_db(
request: Request, session: Annotated[AsyncSession, Depends(get_session)]
):
async def set_db(request: Request, session: Annotated[AsyncSession, Depends(get_session)]):
request.state.db_session = session


Expand All @@ -34,9 +32,7 @@ async def get_institutions(
page: int = 0,
count: int = 100,
):
return await repo.get_institutions(
request.state.db_session, leis, domain, page, count
)
return await repo.get_institutions(request.state.db_session, leis, domain, page, count)


@router.post("/", response_model=Tuple[str, FinancialInstitutionDto])
Expand Down
4 changes: 1 addition & 3 deletions src/util/router_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ def api_route(
if path.endswith("/"):
path = path[:-1]

add_path = super().api_route(
path, include_in_schema=include_in_schema, **kwargs
)
add_path = super().api_route(path, include_in_schema=include_in_schema, **kwargs)

add_alt_path = super().api_route(f"{path}/", include_in_schema=False, **kwargs)

Expand Down
4 changes: 1 addition & 3 deletions tests/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ def authed_user_mock(auth_mock: Mock) -> Mock:
"sub": "testuser123",
}
auth_mock.return_value = (
AuthCredentials(
["manage-account", "query-groups", "manage-users", "authenticated"]
),
AuthCredentials(["manage-account", "query-groups", "manage-users", "authenticated"]),
AuthenticatedUser.from_claim(claims),
)
return auth_mock
Expand Down
24 changes: 6 additions & 18 deletions tests/api/routers/test_admin_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,15 @@ def test_get_me_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock):
res = client.get("/v1/admin/me")
assert res.status_code == 403

def test_get_me_authed(
self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock
):
def test_get_me_authed(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
client = TestClient(app_fixture)
res = client.get("/v1/admin/me")
assert res.status_code == 200
assert res.json().get("name") == "test"

def test_update_me_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock):
client = TestClient(app_fixture)
res = client.put(
"/v1/admin/me", json={"firstName": "testFirst", "lastName": "testLast"}
)
res = client.put("/v1/admin/me", json={"firstName": "testFirst", "lastName": "testLast"})
assert res.status_code == 403

def test_update_me_no_permission(self, app_fixture: FastAPI, auth_mock: Mock):
Expand All @@ -41,14 +37,10 @@ def test_update_me_no_permission(self, app_fixture: FastAPI, auth_mock: Mock):
AuthenticatedUser.from_claim(claims),
)
client = TestClient(app_fixture)
res = client.put(
"/v1/admin/me", json={"firstName": "testFirst", "lastName": "testLast"}
)
res = client.put("/v1/admin/me", json={"firstName": "testFirst", "lastName": "testLast"})
assert res.status_code == 403

def test_update_me(
self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock
):
def test_update_me(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
update_user_mock = mocker.patch("oauth2.oauth2_admin.OAuth2Admin.update_user")
update_user_mock.return_value = None
client = TestClient(app_fixture)
Expand All @@ -57,12 +49,8 @@ def test_update_me(
update_user_mock.assert_called_once_with("testuser123", data)
assert res.status_code == 202

def test_associate_institutions(
self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock
):
associate_lei_mock = mocker.patch(
"oauth2.oauth2_admin.OAuth2Admin.associate_to_lei"
)
def test_associate_institutions(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
associate_lei_mock = mocker.patch("oauth2.oauth2_admin.OAuth2Admin.associate_to_lei")
associate_lei_mock.return_value = None
client = TestClient(app_fixture)
data = ["testlei1", "testlei2"]
Expand Down
Loading