From ca1ceae18fd0e5b4903233fe4b0d6861426d1e06 Mon Sep 17 00:00:00 2001 From: alowave223 Date: Mon, 16 Jan 2023 00:02:22 +0200 Subject: [PATCH 01/11] feat: Added `/v2/oauth/` authorization to the API v2 --- app/api/v2/__init__.py | 36 +++++ app/api/v2/common/json.py | 9 +- app/api/v2/common/oauth.py | 85 ++++++++++ app/api/v2/models/oauth.py | 22 +++ app/api/v2/oauth.py | 204 ++++++++++++++++++++++++ app/repositories/access_tokens.py | 113 +++++++++++++ app/repositories/authorization_codes.py | 88 ++++++++++ app/repositories/ouath_clients.py | 145 +++++++++++++++++ app/repositories/refresh_tokens.py | 101 ++++++++++++ app/state/services.py | 2 +- 10 files changed, 803 insertions(+), 2 deletions(-) create mode 100644 app/api/v2/common/oauth.py create mode 100644 app/api/v2/models/oauth.py create mode 100644 app/api/v2/oauth.py create mode 100644 app/repositories/access_tokens.py create mode 100644 app/repositories/authorization_codes.py create mode 100644 app/repositories/ouath_clients.py create mode 100644 app/repositories/refresh_tokens.py diff --git a/app/api/v2/__init__.py b/app/api/v2/__init__.py index 13faca62..4cd577f8 100644 --- a/app/api/v2/__init__.py +++ b/app/api/v2/__init__.py @@ -1,11 +1,46 @@ # isort: dont-add-imports +from typing import Any + from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException +from fastapi import status + +from app.api.v2.common.oauth import OAuth2Scheme +from app.repositories import access_tokens as access_tokens_repo + + +oauth2_scheme = OAuth2Scheme( + authorizationUrl="/v2/oauth/authorize", + tokenUrl="/v2/oauth/token", + refreshUrl="/v2/oauth/refresh", + scheme_name="OAuth2 for third-party clients.", + scopes={ + "public": "Access endpoints with public data.", + "identify": "Access endpoints with user's data.", + "admin": "Access admin endpoints.", + }, +) + + +async def get_current_client(token: str = Depends(oauth2_scheme)) -> dict[str, Any]: + """Look up the token in the Redis-based token store""" + access_token = await access_tokens_repo.fetch_one(token) + if not access_token: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={"WWW-Authenticate": "Bearer"}, + ) + return access_token + from . import clans from . import maps from . import players from . import scores +from . import oauth apiv2_router = APIRouter(tags=["API v2"], prefix="/v2") @@ -13,3 +48,4 @@ apiv2_router.include_router(maps.router) apiv2_router.include_router(players.router) apiv2_router.include_router(scores.router) +apiv2_router.include_router(oauth.router) diff --git a/app/api/v2/common/json.py b/app/api/v2/common/json.py index e5679918..f5c1ed8c 100644 --- a/app/api/v2/common/json.py +++ b/app/api/v2/common/json.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import Any +from uuid import UUID import orjson from fastapi.responses import JSONResponse @@ -14,6 +15,8 @@ def _default_processor(data: Any) -> Any: return {k: _default_processor(v) for k, v in data.items()} elif isinstance(data, list): return [_default_processor(v) for v in data] + elif isinstance(data, UUID): + return str(data) else: return data @@ -22,8 +25,12 @@ def dumps(data: Any) -> bytes: return orjson.dumps(data, default=_default_processor) +def loads(data: str) -> Any: + return orjson.loads(data) + + class ORJSONResponse(JSONResponse): - media_type = "application/json" + media_type = "application/json;charset=UTF-8" def render(self, content: Any) -> bytes: return dumps(content) diff --git a/app/api/v2/common/oauth.py b/app/api/v2/common/oauth.py new file mode 100644 index 00000000..cf094f80 --- /dev/null +++ b/app/api/v2/common/oauth.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +import base64 +from typing import Optional +from typing import Union + +from fastapi import Request +from fastapi import status +from fastapi.exceptions import HTTPException +from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel +from fastapi.security import OAuth2 +from fastapi.security.utils import get_authorization_scheme_param + + +class OAuth2Scheme(OAuth2): + def __init__( + self, + authorizationUrl: str, + tokenUrl: str, + refreshUrl: Optional[str] = None, + scheme_name: Optional[str] = None, + scopes: Optional[dict[str, str]] = None, + description: Optional[str] = None, + auto_error: bool = True, + ): + if not scopes: + scopes = {} + flows = OAuthFlowsModel( + authorizationCode={ + "authorizationUrl": authorizationUrl, + "tokenUrl": tokenUrl, + "refreshUrl": refreshUrl, + "scopes": scopes, + }, + clientCredentials={ + "tokenUrl": tokenUrl, + "refreshUrl": refreshUrl, + "scopes": scopes, + }, + ) + super().__init__( + flows=flows, + scheme_name=scheme_name, + description=description, + auto_error=auto_error, + ) + + async def __call__(self, request: Request) -> Optional[str]: + authorization = request.headers.get("Authorization") + scheme, param = get_authorization_scheme_param(authorization) + if not authorization or scheme.lower() != "bearer": + if self.auto_error: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={"WWW-Authenticate": "Bearer"}, + ) + else: + return None + return param + + +# https://developer.zendesk.com/api-reference/sales-crm/authentication/requests/#client-authentication +def get_credentials_from_basic_auth( + request: Request, +) -> Optional[dict[str, Union[str, int]]]: + authorization = request.headers.get("Authorization") + scheme, param = get_authorization_scheme_param(authorization) + if not authorization or scheme.lower() != "basic": + return None + + data = base64.b64decode(param).decode("utf-8") + if ":" not in data: + return None + + data = data.split(":") + if len(data) != 2: + return None + if not data[0].isdecimal(): + return None + + return { + "client_id": int(data[0]), + "client_secret": data[1], + } diff --git a/app/api/v2/models/oauth.py b/app/api/v2/models/oauth.py new file mode 100644 index 00000000..b97fc08c --- /dev/null +++ b/app/api/v2/models/oauth.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Literal +from typing import Optional + +from . import BaseModel + + +# input models + + +# output models + + +class Token(BaseModel): + access_token: str + refresh_token: Optional[str] + token_type: Literal["Bearer"] + expires_in: int + expires_at: str + scope: str diff --git a/app/api/v2/oauth.py b/app/api/v2/oauth.py new file mode 100644 index 00000000..b703350b --- /dev/null +++ b/app/api/v2/oauth.py @@ -0,0 +1,204 @@ +""" bancho.py's v2 apis for interacting with clans """ +from __future__ import annotations + +import uuid +from typing import Any +from typing import Optional +from typing import Union + +from fastapi import APIRouter +from fastapi import Depends +from fastapi import Response +from fastapi import status +from fastapi.param_functions import Form +from fastapi.param_functions import Query + +from app.api.v2 import get_current_client +from app.api.v2.common import responses +from app.api.v2.common.oauth import get_credentials_from_basic_auth +from app.api.v2.models.oauth import Token +from app.repositories import access_tokens as access_tokens_repo +from app.repositories import authorization_codes as authorization_codes_repo +from app.repositories import ouath_clients as clients_repo +from app.repositories import refresh_tokens as refresh_tokens_repo + +router = APIRouter() + + +@router.get("/oauth/authorize", status_code=status.HTTP_302_FOUND) +async def authorize( + client_id: int = Query(), + redirect_uri: str = Query(), + response_type: str = Query(regex="code"), + player_id: int = Query(), + scope: str = Query(default="", regex=r"\b\w+\b(?:,\s*\b\w+\b)*"), + state: str = Query(default=None), +) -> str: + """Authorize a client to access the API on behalf of a user.""" + # NOTE: We should have to implement the frontend part to request the user to authorize the client + # and then redirect to the redirect_uri with the code. + # For now, we just return the code and the state if it's provided. + client = await clients_repo.fetch_one(client_id) + if client is None: + return responses.failure("invalid_client") + + if client.redirect_uri != redirect_uri: + return responses.failure("invalid_client") + + if response_type != "code": + return responses.failure("unsupported_response_type") + + code = uuid.uuid4() + await authorization_codes_repo.create(code, client_id, scope, player_id) + + if state is None: + redirect_uri = f"{redirect_uri}?code={code}" + else: + redirect_uri = f"{redirect_uri}?code={code}&state={state}" + + # return RedirectResponse(redirect_uri, status_code=status.HTTP_302_FOUND) + return redirect_uri + + +@router.post("/oauth/token", status_code=status.HTTP_200_OK) +async def token( + response: Response, + grant_type: str = Form(), + client_id: int = Form(default=None), + client_secret: str = Form(default=None), + auth_credentials: Optional[dict[str, Union[str, int]]] = Depends( + get_credentials_from_basic_auth, + ), + code: Optional[str] = Form(default=None), + scope: Optional[str] = Form(default="", regex=r"\b\w+\b(?:,\s*\b\w+\b)*"), +) -> Token: + """Get an access token for the API.""" + response.headers["Content-Type"] = "application/json; charset=utf-8" + response.headers["Cache-Control"] = "no-store, private" + response.headers["Pragma"] = "no-cache" + + if (client_id is None or client_secret is None) and auth_credentials is None: + return responses.failure("invalid_request") + + if client_id is None and client_secret is None: + if auth_credentials is None: + return responses.failure("invalid_request") + else: + client_id = auth_credentials["client_id"] + client_secret = auth_credentials["client_secret"] + + client = await clients_repo.fetch_one(client_id) + if client is None: + return responses.failure("invalid_client") + + if client["secret"] != client_secret: + return responses.failure("invalid_client") + + if grant_type == "authorization_code": + if code is None: + return responses.failure("invalid_request") + + authorization_code = await authorization_codes_repo.fetch_one(code) + if not authorization_code: + return responses.failure("invalid_grant") + + if authorization_code["client_id"] != client_id: + return responses.failure("invalid_client") + + if authorization_code["scopes"] != scope: + return responses.failure("invalid_scope") + await authorization_codes_repo.delete(code) + + refresh_token = uuid.uuid4() + raw_access_token = uuid.uuid4() + + access_token = await access_tokens_repo.create( + raw_access_token, + client_id, + grant_type, + scope, + refresh_token, + authorization_code["player_id"], + ) + await refresh_tokens_repo.create( + refresh_token, + raw_access_token, + client_id, + scope, + ) + + return Token( + access_token=str(raw_access_token), + refresh_token=str(refresh_token), + token_type="Bearer", + expires_in=3600, + expires_at=access_token["expires_at"], + scope=scope, + ) + elif grant_type == "client_credentials": + client = await clients_repo.fetch_one(client_id) + if client is None: + return responses.failure("invalid_client") + + if client["secret"] != client_secret: + return responses.failure("invalid_client") + + raw_access_token = uuid.uuid4() + access_token = await access_tokens_repo.create( + raw_access_token, + client_id, + grant_type, + scope, + expires_in=86400, + ) + + return Token( + access_token=str(raw_access_token), + token_type="Bearer", + expires_in=86400, + expires_at=access_token["expires_at"], + scope=scope, + ) + else: + return responses.failure("unsupported_grant_type") + + +@router.post("/oauth/refresh", status_code=status.HTTP_200_OK) +async def refresh( + response: Response, + client: dict[str, Any] = Depends(get_current_client), + grant_type: str = Form(), + refresh_token: str = Form(), +) -> Token: + """Refresh an access token.""" + response.headers["Content-Type"] = "application/json; charset=utf-8" + response.headers["Cache-Control"] = "no-store, private" + response.headers["Pragma"] = "no-cache" + + if grant_type != "refresh_token": + return responses.failure("unsupported_grant_type") + + if client["grant_type"] != "authorization_code": + return responses.failure("invalid_grant") + + if client["refresh_token"] != refresh_token: + return responses.failure("invalid_grant") + + raw_access_token = uuid.uuid4() + access_token = await access_tokens_repo.create( + raw_access_token, + client["client_id"], + client["grant_type"], + client["scope"], + refresh_token, + client["player_id"], + ) + + return Token( + access_token=str(raw_access_token), + refresh_token=refresh_token, + token_type="Bearer", + expires_in=3600, + expires_at=access_token["expires_at"], + scope=token["scope"], + ) diff --git a/app/repositories/access_tokens.py b/app/repositories/access_tokens.py new file mode 100644 index 00000000..618034d3 --- /dev/null +++ b/app/repositories/access_tokens.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +from datetime import datetime +from datetime import timedelta +from typing import Any +from typing import Literal +from typing import Optional +from typing import Union +from uuid import UUID + +import app.state.services +from app.api.v2.common import json + + +def create_access_token_key(code: Union[UUID, str]) -> str: + return f"bancho:access_tokens:{code}" + + +async def create( + access_token: Union[UUID, str], + client_id: int, + grant_type: str, + scope: str, + refresh_token: Optional[Union[UUID, str]] = "", + player_id: Optional[int] = "", + expires_in: Optional[int] = "", +) -> dict[str, Any]: + access_token_key = create_access_token_key(access_token) + now = datetime.now() + access_token_expires_at = now + timedelta(seconds=expires_in or 3600) + + data = { + "refresh_token": refresh_token, + "client_id": client_id, + "grant_type": grant_type, + "scope": scope, + "player_id": player_id, + "created_at": now.isoformat(), + "expires_at": access_token_expires_at.isoformat(), + } + await app.state.services.redis.hmset(access_token_key, data) + await app.state.services.redis.expireat(access_token_key, access_token_expires_at) + + return data + + +async def fetch_one(access_token: Union[UUID, str]) -> Optional[dict[str, Any]]: + data = await app.state.services.redis.hgetall(create_access_token_key(access_token)) + + if data is None: + return None + + return data + + +async def fetch_all( + client_id: Optional[int] = None, + scope: Optional[str] = None, + grant_type: Optional[str] = None, + player_id: Optional[int] = None, + page: int = 1, + page_size: int = 10, +) -> list[dict[str, Any]]: + access_token_key = create_access_token_key("*") + + if page > 1: + cursor, keys = await app.state.services.redis.scan( + cursor=0, + match=access_token_key, + count=(page - 1) * page_size, + ) + else: + cursor = None + + access_tokens = [] + while cursor != 0: + cursor, keys = await app.state.services.redis.scan( + cursor=cursor or 0, + match=access_token_key, + count=page_size, + ) + + raw_access_token = await app.state.services.redis.mget(keys) + for raw_access_token in raw_access_token: + access_token = json.loads(raw_access_token) + + if client_id is not None and access_token["client_id"] != client_id: + continue + + if scope is not None and access_token["scopes"] != scope: + continue + + if grant_type is not None and access_token["grant_type"] != grant_type: + continue + + if player_id is not None and access_token["player_id"] != player_id: + continue + + access_tokens.append(access_token) + + return access_tokens + + +async def delete(access_token: Union[UUID, str]) -> Optional[dict[str, Any]]: + access_token_key = create_access_token_key(access_token) + + data = await app.state.services.redis.hgetall(access_token_key) + if data is None: + return None + + await app.state.services.redis.delete(access_token_key) + + return data diff --git a/app/repositories/authorization_codes.py b/app/repositories/authorization_codes.py new file mode 100644 index 00000000..e5352281 --- /dev/null +++ b/app/repositories/authorization_codes.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from typing import Any +from typing import Literal +from typing import Optional +from typing import Union +from uuid import UUID + +import app.state.services +from app.api.v2.common import json + + +def create_authorization_code_key(code: Union[UUID, str]) -> str: + return f"bancho:authorization_codes:{code}" + + +async def create( + code: Union[UUID, str], + client_id: int, + scope: str, + player_id: int, +) -> None: + await app.state.services.redis.setex( + create_authorization_code_key(code), + 180, + client_id, + json.dumps({"client_id": client_id, "scope": scope, "player_id": player_id}), + ) + + +async def fetch_one(code: Union[UUID, str]) -> Optional[dict[str, Any]]: + data = await app.state.services.redis.get(create_authorization_code_key(code)) + if data is None: + return None + + return json.loads(data) + + +async def fetch_all( + client_id: Optional[int] = None, + scope: Optional[str] = None, + page: int = 1, + page_size: int = 10, +) -> list[dict[str, Any]]: + authorization_code_key = create_authorization_code_key("*") + + if page > 1: + cursor, keys = await app.state.services.redis.scan( + cursor=0, + match=authorization_code_key, + count=(page - 1) * page_size, + ) + else: + cursor = None + + authorization_codes = [] + while cursor != 0: + cursor, keys = await app.state.services.redis.scan( + cursor=cursor or 0, + match=authorization_code_key, + count=page_size, + ) + + raw_authorization_code = await app.state.services.redis.mget(keys) + for raw_authorization_code in raw_authorization_code: + authorization_code = json.loads(raw_authorization_code) + + if client_id is not None and authorization_code["client_id"] != client_id: + continue + + if scope is not None and authorization_code["scope"] != scope: + continue + + authorization_codes.append(authorization_code) + + return authorization_codes + + +async def delete(code: Union[UUID, str]) -> Optional[dict[str, Any]]: + authorization_code_key = create_authorization_code_key(code) + + data = await app.state.services.redis.get(authorization_code_key) + if data is None: + return None + + await app.state.services.redis.delete(authorization_code_key) + + return json.loads(data) diff --git a/app/repositories/ouath_clients.py b/app/repositories/ouath_clients.py new file mode 100644 index 00000000..0dbc3794 --- /dev/null +++ b/app/repositories/ouath_clients.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +import textwrap +from typing import Any +from typing import Optional + +import app.state.services + +# +--------------+-------------+------+-----+---------+----------------+ +# | Field | Type | Null | Key | Default | Extra | +# +--------------+-------------+------+-----+---------+----------------+ +# | id | int | NO | PRI | NULL | auto_increment | +# | secret | varchar(32) | NO | | NULL | | +# | owner | int | NO | | NULL | | +# | redirect_uri | text | YES | | NULL | | +# +--------------+-------------+------+-----+---------+----------------+ + +READ_PARAMS = textwrap.dedent( + """\ + id, secret, owner, redirect_uri + """, +) + + +async def create( + secret: str, + owner: int, + redirect_uri: Optional[str] = None, +) -> dict[str, Any]: + """Create a new client in the database.""" + query = """\ + INSERT INTO oauth_clients (secret, owner, redirect_uri) + VALUES (:secret, :owner, :redirect_uri) + """ + params = { + "secret": secret, + "owner": owner, + "redirect_uri": redirect_uri, + } + rec_id = await app.state.services.database.execute(query, params) + + query = f"""\ + SELECT {READ_PARAMS} + FROM oauth_clients + WHERE id = :id + """ + params = { + "id": rec_id, + } + + rec = await app.state.services.database.fetch_one(query, params) + assert rec is not None + return dict(rec) + + +async def fetch_one( + id: Optional[int] = None, + owner: Optional[int] = None, + secret: Optional[str] = None, +) -> Optional[dict[str, Any]]: + """Fetch a signle client from the database.""" + if id is None and owner is None and secret is None: + raise ValueError("Must provide at least one parameter.") + + query = f"""\ + SELECT {READ_PARAMS} + FROM oauth_clients + WHERE id = COALESCE(:id, id) + AND owner = COALESCE(:owner, owner) + AND secret = COALESCE(:secret, secret) + """ + params = { + "id": id, + "owner": owner, + "secret": secret, + } + rec = await app.state.services.database.fetch_one(query, params) + return dict(rec) if rec is not None else None + + +async def fetch_many( + id: Optional[int] = None, + owner: Optional[int] = None, + secret: Optional[str] = None, + page: Optional[int] = None, + page_size: Optional[int] = None, +) -> Optional[list[dict[str, Any]]]: + """Fetch all clients from the database.""" + query = f"""\ + SELECT {READ_PARAMS} + FROM oauth_clients + WHERE id = COALESCE(:id, id) + AND owner = COALESCE(:owner, owner) + AND secret = COALESCE(:secret, secret) + """ + params = { + "id": id, + "owner": owner, + "secret": secret, + } + + if page is not None and page_size is not None: + query += """\ + LIMIT :limit + OFFSET :offset + """ + params["limit"] = page_size + params["offset"] = (page - 1) * page_size + + rec = await app.state.services.database.fetch_one(query, params) + return dict(rec) if rec is not None else None + + +async def update( + id: int, + secret: Optional[str] = None, + owner: Optional[int] = None, + redirect_uri: Optional[str] = None, +) -> Optional[dict[str, Any]]: + """Update an existing client in the database.""" + query = """\ + UPDATE oauth_clients + SET secret = COALESCE(:secret, secret), + owner = COALESCE(:owner, owner), + redirect_uri = COALESCE(:redirect_uri, redirect_uri) + WHERE id = :id + """ + params = { + "id": id, + "secret": secret, + "owner": owner, + "redirect_uri": redirect_uri, + } + await app.state.services.database.execute(query, params) + + query = f"""\ + SELECT {READ_PARAMS} + FROM oauth_clients + WHERE id = :id + """ + params = { + "id": id, + } + rec = await app.state.services.database.fetch_one(query, params) + return dict(rec) if rec is not None else None diff --git a/app/repositories/refresh_tokens.py b/app/repositories/refresh_tokens.py new file mode 100644 index 00000000..211918be --- /dev/null +++ b/app/repositories/refresh_tokens.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +from datetime import datetime +from datetime import timedelta +from typing import Any +from typing import Literal +from typing import Optional +from typing import Union +from uuid import UUID + +import app.state.services +from app.api.v2.common import json + + +def create_refresh_token_key(code: Union[UUID, str]) -> str: + return f"bancho:refresh_tokens:{code}" + + +async def create( + refresh_token: Union[UUID, str], + access_token: Union[UUID, str], + client_id: int, + scope: str, +) -> dict[str, Any]: + refresh_token_key = create_refresh_token_key(refresh_token) + now = datetime.now() + refresh_token_expires_at = now + timedelta(days=30) + + data = { + "client_id": client_id, + "scope": scope, + "access_token": access_token, + "created_at": now.isoformat(), + "expires_at": refresh_token_expires_at.isoformat(), + } + await app.state.services.redis.hmset(refresh_token_key, data) + await app.state.services.redis.expireat(refresh_token_key, refresh_token_expires_at) + + return data + + +async def fetch_one(refresh_token: Union[UUID, str]) -> Optional[dict[str, Any]]: + data = await app.state.services.redis.hgetall( + create_refresh_token_key(refresh_token), + ) + if data is None: + return None + + return data + + +async def fetch_all( + client_id: Optional[int] = None, + scope: Optional[str] = None, + page: int = 1, + page_size: int = 10, +) -> list[dict[str, Any]]: + refresh_token_key = create_refresh_token_key("*") + + if page > 1: + cursor, keys = await app.state.services.redis.scan( + cursor=0, + match=refresh_token_key, + count=(page - 1) * page_size, + ) + else: + cursor = None + + refresh_tokens = [] + while cursor != 0: + cursor, keys = await app.state.services.redis.scan( + cursor=cursor or 0, + match=refresh_token_key, + count=page_size, + ) + + raw_refresh_token = await app.state.services.redis.mget(keys) + for raw_refresh_token in raw_refresh_token: + refresh_token = json.loads(raw_refresh_token) + + if client_id is not None and refresh_token["client_id"] != client_id: + continue + + if scope is not None and refresh_token["scope"] != scope: + continue + + refresh_tokens.append(refresh_token) + + return refresh_tokens + + +async def delete(refresh_token: Union[UUID, str]) -> Optional[dict[str, Any]]: + refresh_token_key = create_refresh_token_key(refresh_token) + + data = await app.state.services.redis.hgetall(refresh_token_key) + if data is None: + return None + + await app.state.services.redis.delete(refresh_token_key) + + return data diff --git a/app/state/services.py b/app/state/services.py index 0365a552..3c729c8b 100644 --- a/app/state/services.py +++ b/app/state/services.py @@ -40,7 +40,7 @@ http_client = httpx.AsyncClient() database = databases.Database(app.settings.DB_DSN) -redis: aioredis.Redis = aioredis.from_url(app.settings.REDIS_DSN) +redis: aioredis.Redis = aioredis.from_url(app.settings.REDIS_DSN, decode_responses=True) datadog: datadog_client.ThreadStats | None = None if str(app.settings.DATADOG_API_KEY) and str(app.settings.DATADOG_APP_KEY): From d54ec0938cce564f7002a778aafb38474853dae4 Mon Sep 17 00:00:00 2001 From: alowave223 Date: Mon, 16 Jan 2023 01:20:41 +0200 Subject: [PATCH 02/11] chore: added comments for RFC, added `name` to oauth_clients. --- app/api/v2/oauth.py | 2 ++ app/repositories/ouath_clients.py | 15 ++++++++++++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/app/api/v2/oauth.py b/app/api/v2/oauth.py index b703350b..15c2ccb5 100644 --- a/app/api/v2/oauth.py +++ b/app/api/v2/oauth.py @@ -73,6 +73,7 @@ async def token( scope: Optional[str] = Form(default="", regex=r"\b\w+\b(?:,\s*\b\w+\b)*"), ) -> Token: """Get an access token for the API.""" + # https://www.rfc-editor.org/rfc/rfc6749#section-5.1 response.headers["Content-Type"] = "application/json; charset=utf-8" response.headers["Cache-Control"] = "no-store, private" response.headers["Pragma"] = "no-cache" @@ -171,6 +172,7 @@ async def refresh( refresh_token: str = Form(), ) -> Token: """Refresh an access token.""" + # https://www.rfc-editor.org/rfc/rfc6749#section-5.1 response.headers["Content-Type"] = "application/json; charset=utf-8" response.headers["Cache-Control"] = "no-store, private" response.headers["Pragma"] = "no-cache" diff --git a/app/repositories/ouath_clients.py b/app/repositories/ouath_clients.py index 0dbc3794..113c36c0 100644 --- a/app/repositories/ouath_clients.py +++ b/app/repositories/ouath_clients.py @@ -10,6 +10,7 @@ # | Field | Type | Null | Key | Default | Extra | # +--------------+-------------+------+-----+---------+----------------+ # | id | int | NO | PRI | NULL | auto_increment | +# | name | varchar(16) | YES | | NULL | | # | secret | varchar(32) | NO | | NULL | | # | owner | int | NO | | NULL | | # | redirect_uri | text | YES | | NULL | | @@ -17,7 +18,7 @@ READ_PARAMS = textwrap.dedent( """\ - id, secret, owner, redirect_uri + id, name, secret, owner, redirect_uri """, ) @@ -25,16 +26,18 @@ async def create( secret: str, owner: int, + name: Optional[str] = None, redirect_uri: Optional[str] = None, ) -> dict[str, Any]: """Create a new client in the database.""" query = """\ - INSERT INTO oauth_clients (secret, owner, redirect_uri) - VALUES (:secret, :owner, :redirect_uri) + INSERT INTO oauth_clients (secret, owner, name, redirect_uri) + VALUES (:secret, :owner, :name, :redirect_uri) """ params = { "secret": secret, "owner": owner, + "name": name, "redirect_uri": redirect_uri, } rec_id = await app.state.services.database.execute(query, params) @@ -57,6 +60,7 @@ async def fetch_one( id: Optional[int] = None, owner: Optional[int] = None, secret: Optional[str] = None, + name: Optional[str] = None, ) -> Optional[dict[str, Any]]: """Fetch a signle client from the database.""" if id is None and owner is None and secret is None: @@ -68,11 +72,13 @@ async def fetch_one( WHERE id = COALESCE(:id, id) AND owner = COALESCE(:owner, owner) AND secret = COALESCE(:secret, secret) + AND name = COALESCE(:name, name) """ params = { "id": id, "owner": owner, "secret": secret, + "name": name, } rec = await app.state.services.database.fetch_one(query, params) return dict(rec) if rec is not None else None @@ -115,6 +121,7 @@ async def update( id: int, secret: Optional[str] = None, owner: Optional[int] = None, + name: Optional[str] = None, redirect_uri: Optional[str] = None, ) -> Optional[dict[str, Any]]: """Update an existing client in the database.""" @@ -123,12 +130,14 @@ async def update( SET secret = COALESCE(:secret, secret), owner = COALESCE(:owner, owner), redirect_uri = COALESCE(:redirect_uri, redirect_uri) + name = COALESCE(:name, name) WHERE id = :id """ params = { "id": id, "secret": secret, "owner": owner, + "name": name, "redirect_uri": redirect_uri, } await app.state.services.database.execute(query, params) From 6b0f1702214530a6e1077799b4370f3f9fe746ed Mon Sep 17 00:00:00 2001 From: alowave223 Date: Mon, 16 Jan 2023 01:21:23 +0200 Subject: [PATCH 03/11] chore: bum version (v4.7.3) --- migrations/migrations.sql | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/migrations/migrations.sql b/migrations/migrations.sql index e524aace..4f57865a 100644 --- a/migrations/migrations.sql +++ b/migrations/migrations.sql @@ -409,3 +409,13 @@ alter table maps drop primary key; alter table maps add primary key (id); alter table maps modify column server enum('osu!', 'private') not null default 'osu!' after id; unlock tables; + +# v4.7.3 +CREATE TABLE oauth_clients ( + id INT(10) NOT NULL AUTO_INCREMENT, + name VARCHAR(16) NULL DEFAULT NULL, + secret VARCHAR(32) NOT NULL, + owner INT(10) NOT NULL, + redirect_uri TEXT NULL DEFAULT NULL, + PRIMARY KEY (`id`) USING BTREE +) From c64582ac800799a9f8fd1e3d33d6ebb3424348dd Mon Sep 17 00:00:00 2001 From: cmyui Date: Sat, 8 Apr 2023 20:02:08 -0400 Subject: [PATCH 04/11] fix errors in code --- app/api/v2/oauth.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/app/api/v2/oauth.py b/app/api/v2/oauth.py index 15c2ccb5..02ec04bd 100644 --- a/app/api/v2/oauth.py +++ b/app/api/v2/oauth.py @@ -42,7 +42,7 @@ async def authorize( if client is None: return responses.failure("invalid_client") - if client.redirect_uri != redirect_uri: + if client["redirect_uri"] != redirect_uri: return responses.failure("invalid_client") if response_type != "code": @@ -66,11 +66,11 @@ async def token( grant_type: str = Form(), client_id: int = Form(default=None), client_secret: str = Form(default=None), - auth_credentials: Optional[dict[str, Union[str, int]]] = Depends( + auth_credentials: Optional[dict[str, Any]] = Depends( get_credentials_from_basic_auth, ), code: Optional[str] = Form(default=None), - scope: Optional[str] = Form(default="", regex=r"\b\w+\b(?:,\s*\b\w+\b)*"), + scope: str = Form(default="", regex=r"\b\w+\b(?:,\s*\b\w+\b)*"), ) -> Token: """Get an access token for the API.""" # https://www.rfc-editor.org/rfc/rfc6749#section-5.1 @@ -155,6 +155,7 @@ async def token( return Token( access_token=str(raw_access_token), + refresh_token=None, token_type="Bearer", expires_in=86400, expires_at=access_token["expires_at"], @@ -202,5 +203,5 @@ async def refresh( token_type="Bearer", expires_in=3600, expires_at=access_token["expires_at"], - scope=token["scope"], + scope=access_token["scope"], ) From b675920b0ea2ff252cd457183b7dffa97e118d4a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 12 Feb 2024 21:21:18 +0000 Subject: [PATCH 05/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- app/api/v2/__init__.py | 3 +-- app/api/v2/common/oauth.py | 12 ++++----- app/api/v2/models/oauth.py | 3 +-- app/api/v2/oauth.py | 5 ++-- app/repositories/access_tokens.py | 22 +++++++-------- app/repositories/authorization_codes.py | 12 ++++----- app/repositories/ouath_clients.py | 36 ++++++++++++------------- app/repositories/refresh_tokens.py | 14 +++++----- 8 files changed, 53 insertions(+), 54 deletions(-) diff --git a/app/api/v2/__init__.py b/app/api/v2/__init__.py index 4cd577f8..418217d7 100644 --- a/app/api/v2/__init__.py +++ b/app/api/v2/__init__.py @@ -10,7 +10,6 @@ from app.api.v2.common.oauth import OAuth2Scheme from app.repositories import access_tokens as access_tokens_repo - oauth2_scheme = OAuth2Scheme( authorizationUrl="/v2/oauth/authorize", tokenUrl="/v2/oauth/token", @@ -38,9 +37,9 @@ async def get_current_client(token: str = Depends(oauth2_scheme)) -> dict[str, A from . import clans from . import maps +from . import oauth from . import players from . import scores -from . import oauth apiv2_router = APIRouter(tags=["API v2"], prefix="/v2") diff --git a/app/api/v2/common/oauth.py b/app/api/v2/common/oauth.py index cf094f80..7dfa40b0 100644 --- a/app/api/v2/common/oauth.py +++ b/app/api/v2/common/oauth.py @@ -17,10 +17,10 @@ def __init__( self, authorizationUrl: str, tokenUrl: str, - refreshUrl: Optional[str] = None, - scheme_name: Optional[str] = None, - scopes: Optional[dict[str, str]] = None, - description: Optional[str] = None, + refreshUrl: str | None = None, + scheme_name: str | None = None, + scopes: dict[str, str] | None = None, + description: str | None = None, auto_error: bool = True, ): if not scopes: @@ -45,7 +45,7 @@ def __init__( auto_error=auto_error, ) - async def __call__(self, request: Request) -> Optional[str]: + async def __call__(self, request: Request) -> str | None: authorization = request.headers.get("Authorization") scheme, param = get_authorization_scheme_param(authorization) if not authorization or scheme.lower() != "bearer": @@ -63,7 +63,7 @@ async def __call__(self, request: Request) -> Optional[str]: # https://developer.zendesk.com/api-reference/sales-crm/authentication/requests/#client-authentication def get_credentials_from_basic_auth( request: Request, -) -> Optional[dict[str, Union[str, int]]]: +) -> dict[str, str | int] | None: authorization = request.headers.get("Authorization") scheme, param = get_authorization_scheme_param(authorization) if not authorization or scheme.lower() != "basic": diff --git a/app/api/v2/models/oauth.py b/app/api/v2/models/oauth.py index b97fc08c..e331b96d 100644 --- a/app/api/v2/models/oauth.py +++ b/app/api/v2/models/oauth.py @@ -6,7 +6,6 @@ from . import BaseModel - # input models @@ -15,7 +14,7 @@ class Token(BaseModel): access_token: str - refresh_token: Optional[str] + refresh_token: str | None token_type: Literal["Bearer"] expires_in: int expires_at: str diff --git a/app/api/v2/oauth.py b/app/api/v2/oauth.py index 02ec04bd..311dacca 100644 --- a/app/api/v2/oauth.py +++ b/app/api/v2/oauth.py @@ -1,4 +1,5 @@ """ bancho.py's v2 apis for interacting with clans """ + from __future__ import annotations import uuid @@ -66,10 +67,10 @@ async def token( grant_type: str = Form(), client_id: int = Form(default=None), client_secret: str = Form(default=None), - auth_credentials: Optional[dict[str, Any]] = Depends( + auth_credentials: dict[str, Any] | None = Depends( get_credentials_from_basic_auth, ), - code: Optional[str] = Form(default=None), + code: str | None = Form(default=None), scope: str = Form(default="", regex=r"\b\w+\b(?:,\s*\b\w+\b)*"), ) -> Token: """Get an access token for the API.""" diff --git a/app/repositories/access_tokens.py b/app/repositories/access_tokens.py index 618034d3..8d71755d 100644 --- a/app/repositories/access_tokens.py +++ b/app/repositories/access_tokens.py @@ -12,18 +12,18 @@ from app.api.v2.common import json -def create_access_token_key(code: Union[UUID, str]) -> str: +def create_access_token_key(code: UUID | str) -> str: return f"bancho:access_tokens:{code}" async def create( - access_token: Union[UUID, str], + access_token: UUID | str, client_id: int, grant_type: str, scope: str, - refresh_token: Optional[Union[UUID, str]] = "", - player_id: Optional[int] = "", - expires_in: Optional[int] = "", + refresh_token: UUID | str | None = "", + player_id: int | None = "", + expires_in: int | None = "", ) -> dict[str, Any]: access_token_key = create_access_token_key(access_token) now = datetime.now() @@ -44,7 +44,7 @@ async def create( return data -async def fetch_one(access_token: Union[UUID, str]) -> Optional[dict[str, Any]]: +async def fetch_one(access_token: UUID | str) -> dict[str, Any] | None: data = await app.state.services.redis.hgetall(create_access_token_key(access_token)) if data is None: @@ -54,10 +54,10 @@ async def fetch_one(access_token: Union[UUID, str]) -> Optional[dict[str, Any]]: async def fetch_all( - client_id: Optional[int] = None, - scope: Optional[str] = None, - grant_type: Optional[str] = None, - player_id: Optional[int] = None, + client_id: int | None = None, + scope: str | None = None, + grant_type: str | None = None, + player_id: int | None = None, page: int = 1, page_size: int = 10, ) -> list[dict[str, Any]]: @@ -101,7 +101,7 @@ async def fetch_all( return access_tokens -async def delete(access_token: Union[UUID, str]) -> Optional[dict[str, Any]]: +async def delete(access_token: UUID | str) -> dict[str, Any] | None: access_token_key = create_access_token_key(access_token) data = await app.state.services.redis.hgetall(access_token_key) diff --git a/app/repositories/authorization_codes.py b/app/repositories/authorization_codes.py index e5352281..b0b7e3ab 100644 --- a/app/repositories/authorization_codes.py +++ b/app/repositories/authorization_codes.py @@ -10,12 +10,12 @@ from app.api.v2.common import json -def create_authorization_code_key(code: Union[UUID, str]) -> str: +def create_authorization_code_key(code: UUID | str) -> str: return f"bancho:authorization_codes:{code}" async def create( - code: Union[UUID, str], + code: UUID | str, client_id: int, scope: str, player_id: int, @@ -28,7 +28,7 @@ async def create( ) -async def fetch_one(code: Union[UUID, str]) -> Optional[dict[str, Any]]: +async def fetch_one(code: UUID | str) -> dict[str, Any] | None: data = await app.state.services.redis.get(create_authorization_code_key(code)) if data is None: return None @@ -37,8 +37,8 @@ async def fetch_one(code: Union[UUID, str]) -> Optional[dict[str, Any]]: async def fetch_all( - client_id: Optional[int] = None, - scope: Optional[str] = None, + client_id: int | None = None, + scope: str | None = None, page: int = 1, page_size: int = 10, ) -> list[dict[str, Any]]: @@ -76,7 +76,7 @@ async def fetch_all( return authorization_codes -async def delete(code: Union[UUID, str]) -> Optional[dict[str, Any]]: +async def delete(code: UUID | str) -> dict[str, Any] | None: authorization_code_key = create_authorization_code_key(code) data = await app.state.services.redis.get(authorization_code_key) diff --git a/app/repositories/ouath_clients.py b/app/repositories/ouath_clients.py index 113c36c0..d4626c73 100644 --- a/app/repositories/ouath_clients.py +++ b/app/repositories/ouath_clients.py @@ -26,8 +26,8 @@ async def create( secret: str, owner: int, - name: Optional[str] = None, - redirect_uri: Optional[str] = None, + name: str | None = None, + redirect_uri: str | None = None, ) -> dict[str, Any]: """Create a new client in the database.""" query = """\ @@ -57,11 +57,11 @@ async def create( async def fetch_one( - id: Optional[int] = None, - owner: Optional[int] = None, - secret: Optional[str] = None, - name: Optional[str] = None, -) -> Optional[dict[str, Any]]: + id: int | None = None, + owner: int | None = None, + secret: str | None = None, + name: str | None = None, +) -> dict[str, Any] | None: """Fetch a signle client from the database.""" if id is None and owner is None and secret is None: raise ValueError("Must provide at least one parameter.") @@ -85,12 +85,12 @@ async def fetch_one( async def fetch_many( - id: Optional[int] = None, - owner: Optional[int] = None, - secret: Optional[str] = None, - page: Optional[int] = None, - page_size: Optional[int] = None, -) -> Optional[list[dict[str, Any]]]: + id: int | None = None, + owner: int | None = None, + secret: str | None = None, + page: int | None = None, + page_size: int | None = None, +) -> list[dict[str, Any]] | None: """Fetch all clients from the database.""" query = f"""\ SELECT {READ_PARAMS} @@ -119,11 +119,11 @@ async def fetch_many( async def update( id: int, - secret: Optional[str] = None, - owner: Optional[int] = None, - name: Optional[str] = None, - redirect_uri: Optional[str] = None, -) -> Optional[dict[str, Any]]: + secret: str | None = None, + owner: int | None = None, + name: str | None = None, + redirect_uri: str | None = None, +) -> dict[str, Any] | None: """Update an existing client in the database.""" query = """\ UPDATE oauth_clients diff --git a/app/repositories/refresh_tokens.py b/app/repositories/refresh_tokens.py index 211918be..794885f2 100644 --- a/app/repositories/refresh_tokens.py +++ b/app/repositories/refresh_tokens.py @@ -12,13 +12,13 @@ from app.api.v2.common import json -def create_refresh_token_key(code: Union[UUID, str]) -> str: +def create_refresh_token_key(code: UUID | str) -> str: return f"bancho:refresh_tokens:{code}" async def create( - refresh_token: Union[UUID, str], - access_token: Union[UUID, str], + refresh_token: UUID | str, + access_token: UUID | str, client_id: int, scope: str, ) -> dict[str, Any]: @@ -39,7 +39,7 @@ async def create( return data -async def fetch_one(refresh_token: Union[UUID, str]) -> Optional[dict[str, Any]]: +async def fetch_one(refresh_token: UUID | str) -> dict[str, Any] | None: data = await app.state.services.redis.hgetall( create_refresh_token_key(refresh_token), ) @@ -50,8 +50,8 @@ async def fetch_one(refresh_token: Union[UUID, str]) -> Optional[dict[str, Any]] async def fetch_all( - client_id: Optional[int] = None, - scope: Optional[str] = None, + client_id: int | None = None, + scope: str | None = None, page: int = 1, page_size: int = 10, ) -> list[dict[str, Any]]: @@ -89,7 +89,7 @@ async def fetch_all( return refresh_tokens -async def delete(refresh_token: Union[UUID, str]) -> Optional[dict[str, Any]]: +async def delete(refresh_token: UUID | str) -> dict[str, Any] | None: refresh_token_key = create_refresh_token_key(refresh_token) data = await app.state.services.redis.hgetall(refresh_token_key) From 88ddd71043c2b92c1380a29c5c4d1171260478ac Mon Sep 17 00:00:00 2001 From: cmyui Date: Tue, 13 Feb 2024 00:12:34 -0500 Subject: [PATCH 06/11] use enum for oauth grant type --- app/api/v2/oauth.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/app/api/v2/oauth.py b/app/api/v2/oauth.py index 311dacca..72c8bbad 100644 --- a/app/api/v2/oauth.py +++ b/app/api/v2/oauth.py @@ -3,9 +3,8 @@ from __future__ import annotations import uuid +from enum import StrEnum from typing import Any -from typing import Optional -from typing import Union from fastapi import APIRouter from fastapi import Depends @@ -26,6 +25,13 @@ router = APIRouter() +class GrantType(StrEnum): + AUTHORIZATION_CODE = "authorization_code" + CLIENT_CREDENTIALS = "client_credentials" + + # TODO: Add support for other grant types + + @router.get("/oauth/authorize", status_code=status.HTTP_302_FOUND) async def authorize( client_id: int = Query(), @@ -64,7 +70,7 @@ async def authorize( @router.post("/oauth/token", status_code=status.HTTP_200_OK) async def token( response: Response, - grant_type: str = Form(), + grant_type: GrantType = Form(), client_id: int = Form(default=None), client_secret: str = Form(default=None), auth_credentials: dict[str, Any] | None = Depends( @@ -96,7 +102,7 @@ async def token( if client["secret"] != client_secret: return responses.failure("invalid_client") - if grant_type == "authorization_code": + if grant_type is GrantType.AUTHORIZATION_CODE: if code is None: return responses.failure("invalid_request") @@ -137,7 +143,7 @@ async def token( expires_at=access_token["expires_at"], scope=scope, ) - elif grant_type == "client_credentials": + elif grant_type is GrantType.CLIENT_CREDENTIALS: client = await clients_repo.fetch_one(client_id) if client is None: return responses.failure("invalid_client") From 87ee05c3f8cc7229cced5854c2aeed543a12d3bd Mon Sep 17 00:00:00 2001 From: cmyui Date: Tue, 13 Feb 2024 00:18:19 -0500 Subject: [PATCH 07/11] format failures correctly and some fixes/organization --- app/api/v2/models/oauth.py | 11 +++++-- app/api/v2/oauth.py | 50 ++++++++++++++----------------- app/repositories/access_tokens.py | 3 -- 3 files changed, 32 insertions(+), 32 deletions(-) diff --git a/app/api/v2/models/oauth.py b/app/api/v2/models/oauth.py index e331b96d..381525c8 100644 --- a/app/api/v2/models/oauth.py +++ b/app/api/v2/models/oauth.py @@ -1,8 +1,8 @@ from __future__ import annotations from datetime import datetime +from enum import StrEnum from typing import Literal -from typing import Optional from . import BaseModel @@ -12,10 +12,17 @@ # output models +class GrantType(StrEnum): + AUTHORIZATION_CODE = "authorization_code" + CLIENT_CREDENTIALS = "client_credentials" + + # TODO: Add support for other grant types + + class Token(BaseModel): access_token: str refresh_token: str | None token_type: Literal["Bearer"] expires_in: int - expires_at: str + expires_at: datetime scope: str diff --git a/app/api/v2/oauth.py b/app/api/v2/oauth.py index 72c8bbad..a5fe8c8e 100644 --- a/app/api/v2/oauth.py +++ b/app/api/v2/oauth.py @@ -3,7 +3,6 @@ from __future__ import annotations import uuid -from enum import StrEnum from typing import Any from fastapi import APIRouter @@ -14,8 +13,8 @@ from fastapi.param_functions import Query from app.api.v2 import get_current_client -from app.api.v2.common import responses from app.api.v2.common.oauth import get_credentials_from_basic_auth +from app.api.v2.models.oauth import GrantType from app.api.v2.models.oauth import Token from app.repositories import access_tokens as access_tokens_repo from app.repositories import authorization_codes as authorization_codes_repo @@ -25,11 +24,8 @@ router = APIRouter() -class GrantType(StrEnum): - AUTHORIZATION_CODE = "authorization_code" - CLIENT_CREDENTIALS = "client_credentials" - - # TODO: Add support for other grant types +def oauth_failure_response(reason: str) -> dict[str, Any]: + return {"error": reason} @router.get("/oauth/authorize", status_code=status.HTTP_302_FOUND) @@ -40,20 +36,20 @@ async def authorize( player_id: int = Query(), scope: str = Query(default="", regex=r"\b\w+\b(?:,\s*\b\w+\b)*"), state: str = Query(default=None), -) -> str: +): """Authorize a client to access the API on behalf of a user.""" # NOTE: We should have to implement the frontend part to request the user to authorize the client # and then redirect to the redirect_uri with the code. # For now, we just return the code and the state if it's provided. client = await clients_repo.fetch_one(client_id) if client is None: - return responses.failure("invalid_client") + return oauth_failure_response("invalid_client") if client["redirect_uri"] != redirect_uri: - return responses.failure("invalid_client") + return oauth_failure_response("invalid_client") if response_type != "code": - return responses.failure("unsupported_response_type") + return oauth_failure_response("unsupported_response_type") code = uuid.uuid4() await authorization_codes_repo.create(code, client_id, scope, player_id) @@ -78,7 +74,7 @@ async def token( ), code: str | None = Form(default=None), scope: str = Form(default="", regex=r"\b\w+\b(?:,\s*\b\w+\b)*"), -) -> Token: +): """Get an access token for the API.""" # https://www.rfc-editor.org/rfc/rfc6749#section-5.1 response.headers["Content-Type"] = "application/json; charset=utf-8" @@ -86,35 +82,35 @@ async def token( response.headers["Pragma"] = "no-cache" if (client_id is None or client_secret is None) and auth_credentials is None: - return responses.failure("invalid_request") + return oauth_failure_response("invalid_request") if client_id is None and client_secret is None: if auth_credentials is None: - return responses.failure("invalid_request") + return oauth_failure_response("invalid_request") else: client_id = auth_credentials["client_id"] client_secret = auth_credentials["client_secret"] client = await clients_repo.fetch_one(client_id) if client is None: - return responses.failure("invalid_client") + return oauth_failure_response("invalid_client") if client["secret"] != client_secret: - return responses.failure("invalid_client") + return oauth_failure_response("invalid_client") if grant_type is GrantType.AUTHORIZATION_CODE: if code is None: - return responses.failure("invalid_request") + return oauth_failure_response("invalid_request") authorization_code = await authorization_codes_repo.fetch_one(code) if not authorization_code: - return responses.failure("invalid_grant") + return oauth_failure_response("invalid_grant") if authorization_code["client_id"] != client_id: - return responses.failure("invalid_client") + return oauth_failure_response("invalid_client") if authorization_code["scopes"] != scope: - return responses.failure("invalid_scope") + return oauth_failure_response("invalid_scope") await authorization_codes_repo.delete(code) refresh_token = uuid.uuid4() @@ -146,10 +142,10 @@ async def token( elif grant_type is GrantType.CLIENT_CREDENTIALS: client = await clients_repo.fetch_one(client_id) if client is None: - return responses.failure("invalid_client") + return oauth_failure_response("invalid_client") if client["secret"] != client_secret: - return responses.failure("invalid_client") + return oauth_failure_response("invalid_client") raw_access_token = uuid.uuid4() access_token = await access_tokens_repo.create( @@ -169,7 +165,7 @@ async def token( scope=scope, ) else: - return responses.failure("unsupported_grant_type") + return oauth_failure_response("unsupported_grant_type") @router.post("/oauth/refresh", status_code=status.HTTP_200_OK) @@ -178,7 +174,7 @@ async def refresh( client: dict[str, Any] = Depends(get_current_client), grant_type: str = Form(), refresh_token: str = Form(), -) -> Token: +): """Refresh an access token.""" # https://www.rfc-editor.org/rfc/rfc6749#section-5.1 response.headers["Content-Type"] = "application/json; charset=utf-8" @@ -186,13 +182,13 @@ async def refresh( response.headers["Pragma"] = "no-cache" if grant_type != "refresh_token": - return responses.failure("unsupported_grant_type") + return oauth_failure_response("unsupported_grant_type") if client["grant_type"] != "authorization_code": - return responses.failure("invalid_grant") + return oauth_failure_response("invalid_grant") if client["refresh_token"] != refresh_token: - return responses.failure("invalid_grant") + return oauth_failure_response("invalid_grant") raw_access_token = uuid.uuid4() access_token = await access_tokens_repo.create( diff --git a/app/repositories/access_tokens.py b/app/repositories/access_tokens.py index 8d71755d..b684fca7 100644 --- a/app/repositories/access_tokens.py +++ b/app/repositories/access_tokens.py @@ -3,9 +3,6 @@ from datetime import datetime from datetime import timedelta from typing import Any -from typing import Literal -from typing import Optional -from typing import Union from uuid import UUID import app.state.services From 373e6c8edac887a9635d22bdabcecd1d98b4b7f2 Mon Sep 17 00:00:00 2001 From: cmyui Date: Tue, 13 Feb 2024 00:19:16 -0500 Subject: [PATCH 08/11] type fixes for params --- app/api/v2/oauth.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/app/api/v2/oauth.py b/app/api/v2/oauth.py index a5fe8c8e..6c17c09f 100644 --- a/app/api/v2/oauth.py +++ b/app/api/v2/oauth.py @@ -35,7 +35,7 @@ async def authorize( response_type: str = Query(regex="code"), player_id: int = Query(), scope: str = Query(default="", regex=r"\b\w+\b(?:,\s*\b\w+\b)*"), - state: str = Query(default=None), + state: str | None = Query(default=None), ): """Authorize a client to access the API on behalf of a user.""" # NOTE: We should have to implement the frontend part to request the user to authorize the client @@ -67,8 +67,8 @@ async def authorize( async def token( response: Response, grant_type: GrantType = Form(), - client_id: int = Form(default=None), - client_secret: str = Form(default=None), + client_id: int | None = Form(default=None), + client_secret: str | None = Form(default=None), auth_credentials: dict[str, Any] | None = Depends( get_credentials_from_basic_auth, ), From c7a84570e793d686f451f42d8e3f8a0c6f9007be Mon Sep 17 00:00:00 2001 From: cmyui Date: Tue, 13 Feb 2024 00:20:25 -0500 Subject: [PATCH 09/11] fix runtime bugs --- app/api/v2/oauth.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/app/api/v2/oauth.py b/app/api/v2/oauth.py index 6c17c09f..3e582175 100644 --- a/app/api/v2/oauth.py +++ b/app/api/v2/oauth.py @@ -106,7 +106,7 @@ async def token( if not authorization_code: return oauth_failure_response("invalid_grant") - if authorization_code["client_id"] != client_id: + if client_id is None or authorization_code["client_id"] != client_id: return oauth_failure_response("invalid_client") if authorization_code["scopes"] != scope: @@ -140,6 +140,9 @@ async def token( scope=scope, ) elif grant_type is GrantType.CLIENT_CREDENTIALS: + if client_id is None: + return oauth_failure_response("invalid_client") + client = await clients_repo.fetch_one(client_id) if client is None: return oauth_failure_response("invalid_client") From 90d60bd7e23abe17b4730b5c40580a9bb3a04e86 Mon Sep 17 00:00:00 2001 From: cmyui Date: Tue, 13 Feb 2024 00:36:55 -0500 Subject: [PATCH 10/11] fix: use fastapi classes rather than dicts --- app/api/v2/common/oauth.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/app/api/v2/common/oauth.py b/app/api/v2/common/oauth.py index 7dfa40b0..53050c08 100644 --- a/app/api/v2/common/oauth.py +++ b/app/api/v2/common/oauth.py @@ -1,13 +1,13 @@ from __future__ import annotations import base64 -from typing import Optional -from typing import Union from fastapi import Request from fastapi import status from fastapi.exceptions import HTTPException -from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel +from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlowClientCredentials +from fastapi.openapi.models import OAuthFlows from fastapi.security import OAuth2 from fastapi.security.utils import get_authorization_scheme_param @@ -25,18 +25,18 @@ def __init__( ): if not scopes: scopes = {} - flows = OAuthFlowsModel( - authorizationCode={ - "authorizationUrl": authorizationUrl, - "tokenUrl": tokenUrl, - "refreshUrl": refreshUrl, - "scopes": scopes, - }, - clientCredentials={ - "tokenUrl": tokenUrl, - "refreshUrl": refreshUrl, - "scopes": scopes, - }, + flows = OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl=authorizationUrl, + tokenUrl=tokenUrl, + scopes=scopes, + refreshUrl=refreshUrl, + ), + clientCredentials=OAuthFlowClientCredentials( + tokenUrl=tokenUrl, + scopes=scopes, + refreshUrl=refreshUrl, + ), ) super().__init__( flows=flows, From 47029a2da33ae0b25d2def514cb3d533994a2d6c Mon Sep 17 00:00:00 2001 From: cmyui Date: Tue, 13 Feb 2024 01:12:17 -0500 Subject: [PATCH 11/11] refactor oauth repos --- app/api/v2/common/oauth.py | 10 ++-- app/repositories/access_tokens.py | 73 +++++++++++++++---------- app/repositories/authorization_codes.py | 59 +++++++++++++------- app/repositories/refresh_tokens.py | 66 ++++++++++++---------- 4 files changed, 125 insertions(+), 83 deletions(-) diff --git a/app/api/v2/common/oauth.py b/app/api/v2/common/oauth.py index 53050c08..0e445544 100644 --- a/app/api/v2/common/oauth.py +++ b/app/api/v2/common/oauth.py @@ -73,13 +73,13 @@ def get_credentials_from_basic_auth( if ":" not in data: return None - data = data.split(":") - if len(data) != 2: + split = data.split(":") + if len(split) != 2: return None - if not data[0].isdecimal(): + if not split[0].isdecimal(): return None return { - "client_id": int(data[0]), - "client_secret": data[1], + "client_id": int(split[0]), + "client_secret": split[1], } diff --git a/app/repositories/access_tokens.py b/app/repositories/access_tokens.py index b684fca7..47a23491 100644 --- a/app/repositories/access_tokens.py +++ b/app/repositories/access_tokens.py @@ -3,51 +3,64 @@ from datetime import datetime from datetime import timedelta from typing import Any +from typing import Literal +from typing import TypedDict from uuid import UUID import app.state.services from app.api.v2.common import json +ACCESS_TOKEN_TTL = timedelta(hours=1) -def create_access_token_key(code: UUID | str) -> str: + +class AccessToken(TypedDict): + refresh_token: UUID | None + client_id: int + grant_type: str + scope: str + player_id: int | None + created_at: datetime + expires_at: datetime + + +def create_access_token_key(code: UUID | Literal["*"]) -> str: return f"bancho:access_tokens:{code}" async def create( - access_token: UUID | str, + access_token_id: UUID, client_id: int, grant_type: str, scope: str, - refresh_token: UUID | str | None = "", - player_id: int | None = "", - expires_in: int | None = "", -) -> dict[str, Any]: - access_token_key = create_access_token_key(access_token) + refresh_token: UUID | None = None, + player_id: int | None = None, +) -> AccessToken: now = datetime.now() - access_token_expires_at = now + timedelta(seconds=expires_in or 3600) - - data = { + expires_at = now + ACCESS_TOKEN_TTL + access_token: AccessToken = { "refresh_token": refresh_token, "client_id": client_id, "grant_type": grant_type, "scope": scope, "player_id": player_id, - "created_at": now.isoformat(), - "expires_at": access_token_expires_at.isoformat(), + "created_at": now, + "expires_at": expires_at, } - await app.state.services.redis.hmset(access_token_key, data) - await app.state.services.redis.expireat(access_token_key, access_token_expires_at) - - return data - - -async def fetch_one(access_token: UUID | str) -> dict[str, Any] | None: - data = await app.state.services.redis.hgetall(create_access_token_key(access_token)) - - if data is None: + await app.state.services.redis.set( + create_access_token_key(access_token_id), + json.dumps(access_token), + exat=expires_at, + ) + return access_token + + +async def fetch_one(access_token_id: UUID) -> AccessToken | None: + raw_access_token = await app.state.services.redis.get( + create_access_token_key(access_token_id), + ) + if raw_access_token is None: return None - - return data + return json.loads(raw_access_token) async def fetch_all( @@ -57,7 +70,7 @@ async def fetch_all( player_id: int | None = None, page: int = 1, page_size: int = 10, -) -> list[dict[str, Any]]: +) -> list[AccessToken]: access_token_key = create_access_token_key("*") if page > 1: @@ -98,13 +111,13 @@ async def fetch_all( return access_tokens -async def delete(access_token: UUID | str) -> dict[str, Any] | None: - access_token_key = create_access_token_key(access_token) +async def delete(access_token_id: UUID) -> AccessToken | None: + access_token_key = create_access_token_key(access_token_id) - data = await app.state.services.redis.hgetall(access_token_key) - if data is None: + raw_access_token = await app.state.services.redis.get(access_token_key) + if raw_access_token is None: return None await app.state.services.redis.delete(access_token_key) - return data + return json.loads(raw_access_token) diff --git a/app/repositories/authorization_codes.py b/app/repositories/authorization_codes.py index b0b7e3ab..f8fdc629 100644 --- a/app/repositories/authorization_codes.py +++ b/app/repositories/authorization_codes.py @@ -1,39 +1,60 @@ from __future__ import annotations -from typing import Any +from datetime import datetime +from datetime import timedelta from typing import Literal -from typing import Optional -from typing import Union +from typing import TypedDict from uuid import UUID import app.state.services from app.api.v2.common import json +AUTHORIZATION_CODE_TTL = timedelta(minutes=3) -def create_authorization_code_key(code: UUID | str) -> str: + +class AuthorizationCode(TypedDict): + client_id: int + scope: str + player_id: int + created_at: datetime + expires_at: datetime + + +def create_authorization_code_key(code: UUID | Literal["*"]) -> str: return f"bancho:authorization_codes:{code}" async def create( - code: UUID | str, + code: UUID, client_id: int, scope: str, player_id: int, -) -> None: - await app.state.services.redis.setex( +) -> AuthorizationCode: + now = datetime.now() + expires_at = now + AUTHORIZATION_CODE_TTL + authorization_code: AuthorizationCode = { + "client_id": client_id, + "scope": scope, + "player_id": player_id, + "created_at": now, + "expires_at": expires_at, + } + await app.state.services.redis.set( create_authorization_code_key(code), - 180, - client_id, - json.dumps({"client_id": client_id, "scope": scope, "player_id": player_id}), + json.dumps(authorization_code), + exat=expires_at, ) + return authorization_code -async def fetch_one(code: UUID | str) -> dict[str, Any] | None: - data = await app.state.services.redis.get(create_authorization_code_key(code)) - if data is None: +async def fetch_one(code: UUID) -> AuthorizationCode | None: + raw_authorization_code = await app.state.services.redis.get( + create_authorization_code_key(code), + ) + if raw_authorization_code is None: return None - return json.loads(data) + return json.loads(raw_authorization_code) async def fetch_all( @@ -41,7 +62,7 @@ async def fetch_all( scope: str | None = None, page: int = 1, page_size: int = 10, -) -> list[dict[str, Any]]: +) -> list[AuthorizationCode]: authorization_code_key = create_authorization_code_key("*") if page > 1: @@ -76,13 +97,13 @@ async def fetch_all( return authorization_codes -async def delete(code: UUID | str) -> dict[str, Any] | None: +async def delete(code: UUID) -> AuthorizationCode | None: authorization_code_key = create_authorization_code_key(code) - data = await app.state.services.redis.get(authorization_code_key) - if data is None: + raw_authorization_code = await app.state.services.redis.get(authorization_code_key) + if raw_authorization_code is None: return None await app.state.services.redis.delete(authorization_code_key) - return json.loads(data) + return json.loads(raw_authorization_code) diff --git a/app/repositories/refresh_tokens.py b/app/repositories/refresh_tokens.py index 794885f2..b1856a15 100644 --- a/app/repositories/refresh_tokens.py +++ b/app/repositories/refresh_tokens.py @@ -2,51 +2,59 @@ from datetime import datetime from datetime import timedelta -from typing import Any from typing import Literal -from typing import Optional -from typing import Union +from typing import TypedDict from uuid import UUID import app.state.services from app.api.v2.common import json -def create_refresh_token_key(code: UUID | str) -> str: +class RefreshToken(TypedDict): + client_id: int + scope: str + refresh_token_id: UUID + access_token_id: UUID + created_at: datetime + expires_at: datetime + + +def create_refresh_token_key(code: UUID | Literal["*"]) -> str: return f"bancho:refresh_tokens:{code}" async def create( - refresh_token: UUID | str, - access_token: UUID | str, + refresh_token_id: UUID, + access_token_id: UUID, client_id: int, scope: str, -) -> dict[str, Any]: - refresh_token_key = create_refresh_token_key(refresh_token) +) -> RefreshToken: now = datetime.now() - refresh_token_expires_at = now + timedelta(days=30) - - data = { + expires_at = now + timedelta(days=30) + refresh_token: RefreshToken = { "client_id": client_id, "scope": scope, - "access_token": access_token, - "created_at": now.isoformat(), - "expires_at": refresh_token_expires_at.isoformat(), + "refresh_token_id": refresh_token_id, + "access_token_id": access_token_id, + "created_at": now, + "expires_at": expires_at, } - await app.state.services.redis.hmset(refresh_token_key, data) - await app.state.services.redis.expireat(refresh_token_key, refresh_token_expires_at) - - return data + await app.state.services.redis.set( + create_refresh_token_key(refresh_token_id), + json.dumps(refresh_token), + exat=expires_at, + ) + return refresh_token -async def fetch_one(refresh_token: UUID | str) -> dict[str, Any] | None: - data = await app.state.services.redis.hgetall( - create_refresh_token_key(refresh_token), +async def fetch_one(refresh_token_id: UUID) -> RefreshToken | None: + raw_refresh_token = await app.state.services.redis.get( + create_refresh_token_key(refresh_token_id), ) - if data is None: + if raw_refresh_token is None: return None - return data + return json.loads(raw_refresh_token) async def fetch_all( @@ -54,7 +62,7 @@ async def fetch_all( scope: str | None = None, page: int = 1, page_size: int = 10, -) -> list[dict[str, Any]]: +) -> list[RefreshToken]: refresh_token_key = create_refresh_token_key("*") if page > 1: @@ -89,13 +97,13 @@ async def fetch_all( return refresh_tokens -async def delete(refresh_token: UUID | str) -> dict[str, Any] | None: - refresh_token_key = create_refresh_token_key(refresh_token) +async def delete(refresh_token_id: UUID) -> RefreshToken | None: + refresh_token_key = create_refresh_token_key(refresh_token_id) - data = await app.state.services.redis.hgetall(refresh_token_key) - if data is None: + raw_refresh_token = await app.state.services.redis.get(refresh_token_key) + if raw_refresh_token is None: return None await app.state.services.redis.delete(refresh_token_key) - return data + return json.loads(raw_refresh_token)