diff --git a/src/db/methods/domains.py b/src/db/methods/domains.py index 9802582..5e8924f 100644 --- a/src/db/methods/domains.py +++ b/src/db/methods/domains.py @@ -1,11 +1,11 @@ -from db import types +from db.types import types from .collections import domains def attach_to_entity( domain: str, target_id: int, - target_type: types.domain.TargetType + target_type: types.EntityTargetType ) -> None: domains.find_one_and_update( {"_id": domain}, @@ -16,7 +16,7 @@ def attach_to_entity( upsert=True ) -def resolve_entity(domain: str) -> types.domain.Entity | None: +def resolve_entity(domain: str) -> types.Entity | None: if (entity := domains.find_one({"_id": domain})) is None: return None - return types.domain.Entity(**entity) + return types.Entity(**entity) diff --git a/src/db/methods/users.py b/src/db/methods/users.py index f50220b..6b87320 100644 --- a/src/db/methods/users.py +++ b/src/db/methods/users.py @@ -1,21 +1,21 @@ from pymongo.errors import DuplicateKeyError -from db import types +from db.types import types from .collections import users from .helpers import insert_with_auto_increment_id -def get(user_id: int) -> types.user.User | None: +def get(user_id: int) -> types.User | None: if (user := users.find_one({"_id": user_id})) is None: return None - return types.user.User(**user) + return types.User(**user) def get_by_email(email: str): if (user := users.find_one({"email": email})) is None: return None - return types.user.User(**user) + return types.User(**user) -def insert_user_with_id(user: types.user.UserWithoutID) -> int | None: +def insert_user_with_id(user: types.UserWithoutID) -> int | None: """ Returns: Inserted user id or None, if error occurred diff --git a/src/db/types/__init__.py b/src/db/types/__init__.py index 87c3b9a..4f6ce51 100644 --- a/src/db/types/__init__.py +++ b/src/db/types/__init__.py @@ -1,18 +1,9 @@ -from . import ( - auth, - common, - user, - domain, - requests, - responses -) - +from . import types +from .requests import RQ +from .responses import RS __all__ = [ - "auth", - "user", - "domain", - "common", - "requests", - "responses" + "types", + "RQ", + "RS" ] diff --git a/src/db/types/auth.py b/src/db/types/auth.py deleted file mode 100644 index 13bf3fe..0000000 --- a/src/db/types/auth.py +++ /dev/null @@ -1,6 +0,0 @@ -from utils.schemas import BaseModel - - -class JWTPair(BaseModel): - access_token: str - refresh_token: str diff --git a/src/db/types/domain.py b/src/db/types/domain.py deleted file mode 100644 index 33a74a7..0000000 --- a/src/db/types/domain.py +++ /dev/null @@ -1,12 +0,0 @@ -from typing import Literal -from pydantic import Field - -from utils.schemas import BaseModel - - -TargetType = Literal["user", "group", "contest"] - -class Entity(BaseModel): - id: str = Field(alias='_id') - target_type: TargetType - target_id: int diff --git a/src/db/types/requests.py b/src/db/types/requests.py index 51e2d49..f33073f 100644 --- a/src/db/types/requests.py +++ b/src/db/types/requests.py @@ -4,7 +4,7 @@ from pydantic import model_validator, Field from utils.schemas import BaseModel -from .common import is_email +from .validators import is_email class RQ: diff --git a/src/db/types/user.py b/src/db/types/types.py similarity index 53% rename from src/db/types/user.py rename to src/db/types/types.py index a555d30..5a35448 100644 --- a/src/db/types/user.py +++ b/src/db/types/types.py @@ -1,8 +1,20 @@ +from typing import Literal from pydantic import Field from utils.schemas import BaseModel +class JWTPair(BaseModel): + access_token: str + refresh_token: str + +EntityTargetType = Literal["user", "group", "contest"] + +class Entity(BaseModel): + id: str = Field(alias='_id') + target_type: EntityTargetType + target_id: int + class _UserBase(BaseModel): domain: str | None = None first_name: str diff --git a/src/db/types/common.py b/src/db/types/validators.py similarity index 96% rename from src/db/types/common.py rename to src/db/types/validators.py index b5074ef..b9e0d2a 100644 --- a/src/db/types/common.py +++ b/src/db/types/validators.py @@ -1,9 +1,9 @@ -import re -from pydantic import AfterValidator - - -def _is_email(email: str) -> str: - pattern = r"^[a-zA-Z0-9_\.]+@[a-zA-Z0-9_\.]+\.[a-z]{2,5}" - assert re.fullmatch(pattern, email) is not None, f"String {email} is not a valid email" - return email -is_email = AfterValidator(_is_email) +import re +from pydantic import AfterValidator + + +def _is_email(email: str) -> str: + pattern = r"^[a-zA-Z0-9_\.]+@[a-zA-Z0-9_\.]+\.[a-z]{2,5}" + assert re.fullmatch(pattern, email) is not None, f"String {email} is not a valid email" + return email +is_email = AfterValidator(_is_email) diff --git a/src/routers/auth.py b/src/routers/auth.py index a318f90..4cf77a6 100644 --- a/src/routers/auth.py +++ b/src/routers/auth.py @@ -5,17 +5,15 @@ from utils.auth import get_current_user_by_refresh_token from utils.response import SuccessfulResponse, ErrorCodes, ErrorResponse from db import methods -from db import types -from db.types import requests as RQ, responses as RS -from db.types.requests import RQ -from db.types.responses import RS +from db.types import types, RS, RQ router = APIRouter() + @router.post("/signin", response_model=SuccessfulResponse[RS.auth.signin]) async def signin(request: RQ.auth.signin): - user: types.user.User | None = None + user: types.User | None = None if request.id: user = methods.users.get(request.id) elif request.domain: @@ -42,5 +40,5 @@ async def signin(request: RQ.auth.signin): return utils.auth.create_jwt_pair_by_user_id(user.id) @router.get("/refresh", response_model=SuccessfulResponse[RS.auth.refresh]) -async def refresh(current_user: types.user.User = Depends(get_current_user_by_refresh_token)): +async def refresh(current_user: types.User = Depends(get_current_user_by_refresh_token)): return utils.auth.create_jwt_pair_by_user_id(current_user.id) diff --git a/src/routers/test.py b/src/routers/test.py index e5e56a3..35acd74 100644 --- a/src/routers/test.py +++ b/src/routers/test.py @@ -6,9 +6,9 @@ import utils.auth from config import config from utils.response import ErrorCodes, ErrorResponse, SuccessfulResponse -from db import client, methods, types -from db.types.requests import RQ -from db.types.responses import RS +from db import client, methods +from db.types import types, RS, RQ + router = APIRouter() @@ -27,7 +27,7 @@ async def signup(request: RQ.test.signup): hashed_password = utils.auth.hash_password(request.password) inserted_user_id = methods.users.insert_user_with_id( - types.user.UserWithoutID( + types.UserWithoutID( email=request.email, hashed_password=hashed_password, first_name=request.first_name diff --git a/src/routers/users.py b/src/routers/users.py index f582133..847e884 100644 --- a/src/routers/users.py +++ b/src/routers/users.py @@ -2,15 +2,14 @@ from utils.response import SuccessfulResponse from utils.auth import get_current_user -from db import types -from db.types.responses import RS +from db.types import types, RS router = APIRouter() @router.get("/current", response_model=SuccessfulResponse[RS.users.current]) -async def current(current_user: types.user.User = Depends(get_current_user)): +async def current(current_user: types.User = Depends(get_current_user)): return RS.users.current( id=current_user.id, email=current_user.email, diff --git a/src/utils/auth/auth.py b/src/utils/auth/auth.py index b54740e..e675eb4 100644 --- a/src/utils/auth/auth.py +++ b/src/utils/auth/auth.py @@ -5,8 +5,8 @@ from fastapi import Header from fastapi import status as http_status -from db import types from db import methods +from db.types import types from utils import response from config import config @@ -117,8 +117,8 @@ def decode_jwt( http_status_code=http_status.HTTP_401_UNAUTHORIZED, ) from exc -def create_jwt_pair_by_user_id(user_id: int) ->types.auth.JWTPair: - return types.auth.JWTPair( +def create_jwt_pair_by_user_id(user_id: int) -> types.JWTPair: + return types.JWTPair( access_token=create_jwt( str(user_id), expiration_time_minutes=config.auth.access_token_expire_minutes, @@ -131,7 +131,7 @@ def create_jwt_pair_by_user_id(user_id: int) ->types.auth.JWTPair: ) ) -def get_current_user(authorization: str = Header()) -> types.user.User: +def get_current_user(authorization: str = Header()) -> types.User: token = get_auth_header_credentials(authorization, "Bearer") subject = decode_jwt(token, config.auth.jwt_access_secret_key.get_secret_value())["subject"] @@ -145,7 +145,7 @@ def get_current_user(authorization: str = Header()) -> types.user.User: message="Could not found user by token" ) -def get_current_user_by_refresh_token(authorization: str = Header()) -> types.user.User: +def get_current_user_by_refresh_token(authorization: str = Header()) -> types.User: token = get_auth_header_credentials(authorization, "Bearer") subject = decode_jwt(token, config.auth.jwt_refresh_secret_key.get_secret_value())["subject"]