diff --git a/.env.example b/.env.example index 8da06f014..62426bf42 100644 --- a/.env.example +++ b/.env.example @@ -57,6 +57,10 @@ DISALLOW_OLD_CLIENTS=True DISCORD_AUDIT_LOG_WEBHOOK= +JWT_PUBLIC_KEY= +JWT_PRIVATE_KEY= +ROTATION_JWT_PRIVATE_KEY= + # automatically share information with the primary # developer of bancho.py (https://github.com/cmyui) # for debugging & development purposes. diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 796351cc5..b392447f4 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -31,6 +31,9 @@ env: DISALLOWED_PASSWORDS: ${{ vars.DISALLOWED_PASSWORDS }} DISALLOW_OLD_CLIENTS: ${{ vars.DISALLOW_OLD_CLIENTS }} DISCORD_AUDIT_LOG_WEBHOOK: ${{ vars.DISCORD_AUDIT_LOG_WEBHOOK }} + JWT_PUBLIC_KEY: ${{ vars.JWT_PUBLIC_KEY }} + JWT_PRIVATE_KEY: ${{ vars.JWT_PRIVATE_KEY }} + ROTATION_JWT_PRIVATE_KEY: ${{ vars.ROTATION_JWT_PRIVATE_KEY }} AUTOMATICALLY_REPORT_PROBLEMS: ${{ vars.AUTOMATICALLY_REPORT_PROBLEMS }} SSL_CERT_PATH: ${{ vars.SSL_CERT_PATH }} SSL_KEY_PATH: ${{ vars.SSL_KEY_PATH }} diff --git a/app/api/v2/__init__.py b/app/api/v2/__init__.py index 418217d7f..c5623bda8 100644 --- a/app/api/v2/__init__.py +++ b/app/api/v2/__init__.py @@ -1,14 +1,16 @@ # isort: dont-add-imports from typing import Any +from typing import TypedDict +import jwt from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException from fastapi import status +from app import settings from app.api.v2.common.oauth import OAuth2Scheme -from app.repositories import access_tokens as access_tokens_repo oauth2_scheme = OAuth2Scheme( authorizationUrl="/v2/oauth/authorize", @@ -23,16 +25,45 @@ ) -async def get_current_client(token: str = Depends(oauth2_scheme)) -> dict[str, Any]: - """Look up the token in the Redis-based token store""" - access_token = await access_tokens_repo.fetch_one(token) - if not access_token: +class AuthorizationContext(TypedDict): + verified_claims: dict[str, Any] + + +async def authenticate_api_request( + token: str = Depends(oauth2_scheme), +) -> AuthorizationContext: + verified_claims: dict[str, Any] | None = None + try: + verified_claims = jwt.decode( + token, + settings.JWT_PRIVATE_KEY, + algorithms=["HS256"], + options={"require": ["exp", "nbf", "iss", "aud", "iat"]}, + ) + except jwt.InvalidTokenError: + pass + + if verified_claims is None: + try: + verified_claims = jwt.decode( + token, + settings.ROTATION_JWT_PRIVATE_KEY, + algorithms=["HS256"], + options={"require": ["exp", "nbf", "iss", "aud", "iat"]}, + ) + except jwt.InvalidTokenError: + pass + + if verified_claims is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated", headers={"WWW-Authenticate": "Bearer"}, ) - return access_token + + return AuthorizationContext( + verified_claims=verified_claims, + ) from . import clans diff --git a/app/api/v2/common/oauth.py b/app/api/v2/common/oauth.py index 0e4455442..1faa50b55 100644 --- a/app/api/v2/common/oauth.py +++ b/app/api/v2/common/oauth.py @@ -1,6 +1,7 @@ from __future__ import annotations import base64 +from typing import TypedDict from fastapi import Request from fastapi import status @@ -60,10 +61,15 @@ async def __call__(self, request: Request) -> str | None: return param +class BasicAuthCredentials(TypedDict): + client_id: str + client_secret: str + + # https://developer.zendesk.com/api-reference/sales-crm/authentication/requests/#client-authentication def get_credentials_from_basic_auth( request: Request, -) -> dict[str, str | int] | None: +) -> BasicAuthCredentials | None: authorization = request.headers.get("Authorization") scheme, param = get_authorization_scheme_param(authorization) if not authorization or scheme.lower() != "basic": diff --git a/app/api/v2/models/oauth.py b/app/api/v2/models/oauth.py index 381525c85..f6c217f40 100644 --- a/app/api/v2/models/oauth.py +++ b/app/api/v2/models/oauth.py @@ -9,20 +9,40 @@ # input models +class ClientCredentialsGrantData(BaseModel): + scope: str | None + + +class AuthorizationCodeGrantData(BaseModel): + code: str + redirect_uri: str + client_id: str + + +class RefreshGrantData(BaseModel): + refresh_token: str + scope: str | None + + # output models class GrantType(StrEnum): AUTHORIZATION_CODE = "authorization_code" CLIENT_CREDENTIALS = "client_credentials" + REFRESH_TOKEN = "refresh_token" # TODO: Add support for other grant types +class TokenType(StrEnum): + BEARER = "Bearer" + + class Token(BaseModel): access_token: str refresh_token: str | None token_type: Literal["Bearer"] expires_in: int expires_at: datetime - scope: str + scope: str | None diff --git a/app/api/v2/oauth.py b/app/api/v2/oauth.py index 3e5821752..7e41c117b 100644 --- a/app/api/v2/oauth.py +++ b/app/api/v2/oauth.py @@ -2,9 +2,15 @@ from __future__ import annotations +import urllib.parse import uuid +from dataclasses import dataclass +from datetime import datetime +from datetime import timedelta from typing import Any +from typing import Literal +import jwt from fastapi import APIRouter from fastapi import Depends from fastapi import Response @@ -12,32 +18,49 @@ from fastapi.param_functions import Form from fastapi.param_functions import Query -from app.api.v2 import get_current_client +from app import settings +from app.api.v2 import AuthorizationContext +from app.api.v2 import authenticate_api_request +from app.api.v2.common.oauth import BasicAuthCredentials from app.api.v2.common.oauth import get_credentials_from_basic_auth +from app.api.v2.models.oauth import AuthorizationCodeGrantData +from app.api.v2.models.oauth import ClientCredentialsGrantData from app.api.v2.models.oauth import GrantType from app.api.v2.models.oauth import Token -from app.repositories import access_tokens as access_tokens_repo +from app.api.v2.models.oauth import TokenType from app.repositories import authorization_codes as authorization_codes_repo from app.repositories import ouath_clients as clients_repo from app.repositories import refresh_tokens as refresh_tokens_repo router = APIRouter() +ACCESS_TOKEN_TTL = timedelta(minutes=5) + def oauth_failure_response(reason: str) -> dict[str, Any]: return {"error": reason} +def generate_authorization_code() -> str: + return str(uuid.uuid4()) + + @router.get("/oauth/authorize", status_code=status.HTTP_302_FOUND) async def authorize( - client_id: int = Query(), - redirect_uri: str = Query(), - response_type: str = Query(regex="code"), - player_id: int = Query(), - scope: str = Query(default="", regex=r"\b\w+\b(?:,\s*\b\w+\b)*"), - state: str | None = Query(default=None), + client_id: str = Query(...), + redirect_uri: str = Query(...), + # TODO: support for "token" response type in implcit flow? + # https://www.rfc-editor.org/rfc/rfc6749#section-3.1.1 + response_type: Literal["code"] = Query(...), + player_id: int = Query(...), + scope: str | None = Query(default=None, regex=r"\b\w+\b(?:,\s*\b\w+\b)*"), + state: str | None = Query(default=None), # csrf protection ): - """Authorize a client to access the API on behalf of a user.""" + """\ + Authorize a client to access the API on behalf of a user. + + Used by the authorizaton_grant and implicit grant flows. + """ # NOTE: We should have to implement the frontend part to request the user to authorize the client # and then redirect to the redirect_uri with the code. # For now, we just return the code and the state if it's provided. @@ -51,29 +74,103 @@ async def authorize( if response_type != "code": return oauth_failure_response("unsupported_response_type") - code = uuid.uuid4() + code = generate_authorization_code() await authorization_codes_repo.create(code, client_id, scope, player_id) - if state is None: - redirect_uri = f"{redirect_uri}?code={code}" - else: - redirect_uri = f"{redirect_uri}?code={code}&state={state}" + params: dict[str, Any] = { + "code": code, + } + if state is not None: + params["state"] = state + + redirect_uri = redirect_uri + "?" + urllib.parse.urlencode(params) - # return RedirectResponse(redirect_uri, status_code=status.HTTP_302_FOUND) return redirect_uri +def generate_access_token( + access_token_id: uuid.UUID, + issued_at: datetime, + expires_at: datetime, + client_id: str, + grant_type: GrantType, + scope: str | None, + issuer: str = "bancho", + audiences: list[str] = ["bancho"], + additional_claims: dict[str, Any] | None = None, +) -> str: + if additional_claims is None: + additional_claims = {} + new_claims = { + # registered claims + "exp": expires_at, + "nbf": issued_at, + "iss": issuer, + "aud": audiences, + "iat": issued_at, + # unregistered claims + "access_token_id": access_token_id, + "client_id": client_id, + "grant_type": grant_type.value, + "scope": scope, + **additional_claims, + } + raw_access_token = jwt.encode( + new_claims, + settings.JWT_PRIVATE_KEY, + algorithm="HS256", + ) + return raw_access_token + + +def generate_refresh_token( + refresh_token_id: uuid.UUID, + issued_at: datetime, + expires_at: datetime, + client_id: str, + scope: str | None, + issuer: str = "bancho", + audiences: list[str] = ["bancho"], + additional_claims: dict[str, Any] | None = None, +) -> str: + if additional_claims is None: + additional_claims = {} + new_claims = { + # registered claims + "exp": expires_at, + "nbf": issued_at, + "iss": issuer, + "aud": audiences, + "iat": issued_at, + # unregistered claims + "refresh_token_id": refresh_token_id, + "client_id": client_id, + "scope": scope, + **additional_claims, + } + raw_refresh_token = jwt.encode( + new_claims, + settings.JWT_PRIVATE_KEY, + algorithm="HS256", + ) + return raw_refresh_token + + @router.post("/oauth/token", status_code=status.HTTP_200_OK) async def token( response: Response, grant_type: GrantType = Form(), - client_id: int | None = Form(default=None), - client_secret: str | None = Form(default=None), - auth_credentials: dict[str, Any] | None = Depends( - get_credentials_from_basic_auth, - ), - code: str | None = Form(default=None), - scope: str = Form(default="", regex=r"\b\w+\b(?:,\s*\b\w+\b)*"), + scope: str | None = Form(default=None, regex=r"\b\w+\b(?:,\s*\b\w+\b)*"), + # specific args to authorization code grant + code: str | None = Form(None), + redirect_uri: str | None = Form(None), + client_id: str | None = Form(None), + # args specific to refresh grant + refresh_token: str = Form(...), + # TODO: support basic authentication + # auth_credentials: BasicAuthCredentials | None = Depends( + # get_credentials_from_basic_auth, + # ), ): """Get an access token for the API.""" # https://www.rfc-editor.org/rfc/rfc6749#section-5.1 @@ -81,6 +178,31 @@ async def token( response.headers["Cache-Control"] = "no-store, private" response.headers["Pragma"] = "no-cache" + if grant_type is GrantType.CLIENT_CREDENTIALS: + client_credentials_grant_data = ClientCredentialsGrantData(scope=scope) + + elif grant_type is GrantType.AUTHORIZATION_CODE: + if code is None or redirect_uri is None or client_id is None: + return oauth_failure_response("invalid_request") + + authorization_code_grant_form = AuthorizationCodeGrantData( + code=code, + redirect_uri=redirect_uri, + client_id=client_id, + ) + + client = await clients_repo.fetch_one(client_id) + if client is None: + return oauth_failure_response("invalid_client") + + if client["secret"] != code: + return oauth_failure_response("invalid_client") + ... + elif grant_type is GrantType.REFRESH_TOKEN: + ... + else: + return oauth_failure_response("unsupported_grant_type") + if (client_id is None or client_secret is None) and auth_credentials is None: return oauth_failure_response("invalid_request") @@ -99,44 +221,48 @@ async def token( return oauth_failure_response("invalid_client") if grant_type is GrantType.AUTHORIZATION_CODE: - if code is None: + if authorization_code_grant_form is None: return oauth_failure_response("invalid_request") - authorization_code = await authorization_codes_repo.fetch_one(code) + if authorization_code_grant_form.code is None: + return oauth_failure_response("invalid_request") + + authorization_code = await authorization_codes_repo.fetch_one( + authorization_code_grant_form.code, + ) if not authorization_code: return oauth_failure_response("invalid_grant") if client_id is None or authorization_code["client_id"] != client_id: return oauth_failure_response("invalid_client") - if authorization_code["scopes"] != scope: + if authorization_code["scope"] != scope: return oauth_failure_response("invalid_scope") - await authorization_codes_repo.delete(code) - refresh_token = uuid.uuid4() - raw_access_token = uuid.uuid4() + await authorization_codes_repo.delete(code) - access_token = await access_tokens_repo.create( - raw_access_token, - client_id, - grant_type, - scope, - refresh_token, - authorization_code["player_id"], - ) - await refresh_tokens_repo.create( - refresh_token, - raw_access_token, - client_id, - scope, + access_token_id = uuid.uuid4() + now = datetime.now() + expires_at = now + ACCESS_TOKEN_TTL + raw_access_token = generate_access_token( + access_token_id=access_token_id, + issued_at=now, + expires_at=expires_at, + client_id=str(client_id), + grant_type=grant_type, + scope=scope, + additional_claims={"player_id": authorization_code["player_id"]}, ) + refresh_token_id = uuid.uuid4() + + await refresh_tokens_repo.create(refresh_token_id, client_id, scope) return Token( access_token=str(raw_access_token), - refresh_token=str(refresh_token), - token_type="Bearer", + refresh_token=str(refresh_token_id), + token_type=TokenType.BEARER.value, expires_in=3600, - expires_at=access_token["expires_at"], + expires_at=expires_at, scope=scope, ) elif grant_type is GrantType.CLIENT_CREDENTIALS: @@ -150,21 +276,23 @@ async def token( if client["secret"] != client_secret: return oauth_failure_response("invalid_client") - raw_access_token = uuid.uuid4() - access_token = await access_tokens_repo.create( - raw_access_token, - client_id, - grant_type, - scope, - expires_in=86400, + access_token_id = uuid.uuid4() + now = datetime.now() + expires_at = now + ACCESS_TOKEN_TTL + raw_access_token = generate_access_token( + access_token_id=access_token_id, + issued_at=now, + expires_at=expires_at, + client_id=str(client_id), + grant_type=grant_type, + scope=scope, ) - return Token( - access_token=str(raw_access_token), + access_token=raw_access_token, refresh_token=None, - token_type="Bearer", - expires_in=86400, - expires_at=access_token["expires_at"], + token_type=TokenType.BEARER.value, + expires_in=int((expires_at - now).total_seconds()), + expires_at=expires_at, scope=scope, ) else: @@ -174,9 +302,9 @@ async def token( @router.post("/oauth/refresh", status_code=status.HTTP_200_OK) async def refresh( response: Response, - client: dict[str, Any] = Depends(get_current_client), - grant_type: str = Form(), - refresh_token: str = Form(), + auth_ctx: AuthorizationContext = Depends(authenticate_api_request), + grant_type: GrantType = Form(), + raw_refresh_token: str = Form(), ): """Refresh an access token.""" # https://www.rfc-editor.org/rfc/rfc6749#section-5.1 @@ -184,30 +312,36 @@ async def refresh( response.headers["Cache-Control"] = "no-store, private" response.headers["Pragma"] = "no-cache" - if grant_type != "refresh_token": + verified_claims = auth_ctx["verified_claims"] + + if grant_type is not GrantType.REFRESH_TOKEN: return oauth_failure_response("unsupported_grant_type") - if client["grant_type"] != "authorization_code": + if verified_claims["grant_type"] != "authorization_code": return oauth_failure_response("invalid_grant") - if client["refresh_token"] != refresh_token: + if verified_claims["refresh_token"] != raw_refresh_token: return oauth_failure_response("invalid_grant") - raw_access_token = uuid.uuid4() - access_token = await access_tokens_repo.create( - raw_access_token, - client["client_id"], - client["grant_type"], - client["scope"], - refresh_token, - client["player_id"], + access_token_id = uuid.uuid4() + now = datetime.now() + expires_at = now + ACCESS_TOKEN_TTL + raw_access_token = generate_access_token( + access_token_id=access_token_id, + issued_at=now, + expires_at=expires_at, + client_id=verified_claims["client_id"], + grant_type=grant_type, + scope=verified_claims["scope"], + additional_claims={"player_id": verified_claims["player_id"]}, ) + # TODO: should we generate a new refresh token? return Token( - access_token=str(raw_access_token), - refresh_token=refresh_token, - token_type="Bearer", - expires_in=3600, - expires_at=access_token["expires_at"], - scope=access_token["scope"], + access_token=raw_access_token, + refresh_token=raw_refresh_token, + token_type=TokenType.BEARER.value, + expires_in=int((expires_at - now).total_seconds()), + expires_at=expires_at, + scope=verified_claims["scope"], ) diff --git a/app/repositories/access_tokens.py b/app/repositories/access_tokens.py deleted file mode 100644 index 47a23491e..000000000 --- a/app/repositories/access_tokens.py +++ /dev/null @@ -1,123 +0,0 @@ -from __future__ import annotations - -from datetime import datetime -from datetime import timedelta -from typing import Any -from typing import Literal -from typing import TypedDict -from uuid import UUID - -import app.state.services -from app.api.v2.common import json - -ACCESS_TOKEN_TTL = timedelta(hours=1) - - -class AccessToken(TypedDict): - refresh_token: UUID | None - client_id: int - grant_type: str - scope: str - player_id: int | None - created_at: datetime - expires_at: datetime - - -def create_access_token_key(code: UUID | Literal["*"]) -> str: - return f"bancho:access_tokens:{code}" - - -async def create( - access_token_id: UUID, - client_id: int, - grant_type: str, - scope: str, - refresh_token: UUID | None = None, - player_id: int | None = None, -) -> AccessToken: - now = datetime.now() - expires_at = now + ACCESS_TOKEN_TTL - access_token: AccessToken = { - "refresh_token": refresh_token, - "client_id": client_id, - "grant_type": grant_type, - "scope": scope, - "player_id": player_id, - "created_at": now, - "expires_at": expires_at, - } - await app.state.services.redis.set( - create_access_token_key(access_token_id), - json.dumps(access_token), - exat=expires_at, - ) - return access_token - - -async def fetch_one(access_token_id: UUID) -> AccessToken | None: - raw_access_token = await app.state.services.redis.get( - create_access_token_key(access_token_id), - ) - if raw_access_token is None: - return None - return json.loads(raw_access_token) - - -async def fetch_all( - client_id: int | None = None, - scope: str | None = None, - grant_type: str | None = None, - player_id: int | None = None, - page: int = 1, - page_size: int = 10, -) -> list[AccessToken]: - access_token_key = create_access_token_key("*") - - if page > 1: - cursor, keys = await app.state.services.redis.scan( - cursor=0, - match=access_token_key, - count=(page - 1) * page_size, - ) - else: - cursor = None - - access_tokens = [] - while cursor != 0: - cursor, keys = await app.state.services.redis.scan( - cursor=cursor or 0, - match=access_token_key, - count=page_size, - ) - - raw_access_token = await app.state.services.redis.mget(keys) - for raw_access_token in raw_access_token: - access_token = json.loads(raw_access_token) - - if client_id is not None and access_token["client_id"] != client_id: - continue - - if scope is not None and access_token["scopes"] != scope: - continue - - if grant_type is not None and access_token["grant_type"] != grant_type: - continue - - if player_id is not None and access_token["player_id"] != player_id: - continue - - access_tokens.append(access_token) - - return access_tokens - - -async def delete(access_token_id: UUID) -> AccessToken | None: - access_token_key = create_access_token_key(access_token_id) - - raw_access_token = await app.state.services.redis.get(access_token_key) - if raw_access_token is None: - return None - - await app.state.services.redis.delete(access_token_key) - - return json.loads(raw_access_token) diff --git a/app/repositories/authorization_codes.py b/app/repositories/authorization_codes.py index f8fdc6294..d9998d28f 100644 --- a/app/repositories/authorization_codes.py +++ b/app/repositories/authorization_codes.py @@ -13,21 +13,21 @@ class AuthorizationCode(TypedDict): - client_id: int - scope: str + client_id: str + scope: str | None player_id: int created_at: datetime expires_at: datetime -def create_authorization_code_key(code: UUID | Literal["*"]) -> str: +def create_authorization_code_key(code: str | Literal["*"]) -> str: return f"bancho:authorization_codes:{code}" async def create( - code: UUID, - client_id: int, - scope: str, + code: str, + client_id: str, + scope: str | None, player_id: int, ) -> AuthorizationCode: now = datetime.now() @@ -47,7 +47,7 @@ async def create( return authorization_code -async def fetch_one(code: UUID) -> AuthorizationCode | None: +async def fetch_one(code: str) -> AuthorizationCode | None: raw_authorization_code = await app.state.services.redis.get( create_authorization_code_key(code), ) @@ -58,7 +58,7 @@ async def fetch_one(code: UUID) -> AuthorizationCode | None: async def fetch_all( - client_id: int | None = None, + client_id: str | None = None, scope: str | None = None, page: int = 1, page_size: int = 10, @@ -97,7 +97,7 @@ async def fetch_all( return authorization_codes -async def delete(code: UUID) -> AuthorizationCode | None: +async def delete(code: str) -> AuthorizationCode | None: authorization_code_key = create_authorization_code_key(code) raw_authorization_code = await app.state.services.redis.get(authorization_code_key) diff --git a/app/repositories/ouath_clients.py b/app/repositories/ouath_clients.py index d4626c731..1b13568af 100644 --- a/app/repositories/ouath_clients.py +++ b/app/repositories/ouath_clients.py @@ -1,15 +1,15 @@ from __future__ import annotations import textwrap -from typing import Any -from typing import Optional +from typing import TypedDict +from typing import cast import app.state.services # +--------------+-------------+------+-----+---------+----------------+ # | Field | Type | Null | Key | Default | Extra | # +--------------+-------------+------+-----+---------+----------------+ -# | id | int | NO | PRI | NULL | auto_increment | +# | id | varchar(64) | NO | PRI | NULL | auto_increment | # | name | varchar(16) | YES | | NULL | | # | secret | varchar(32) | NO | | NULL | | # | owner | int | NO | | NULL | | @@ -23,12 +23,20 @@ ) +class OAuthClient(TypedDict): + id: int + name: str | None + secret: str + owner: int + redirect_uri: str | None + + async def create( secret: str, owner: int, name: str | None = None, redirect_uri: str | None = None, -) -> dict[str, Any]: +) -> OAuthClient: """Create a new client in the database.""" query = """\ INSERT INTO oauth_clients (secret, owner, name, redirect_uri) @@ -53,58 +61,37 @@ async def create( rec = await app.state.services.database.fetch_one(query, params) assert rec is not None - return dict(rec) + return cast(OAuthClient, dict(rec._mapping)) -async def fetch_one( - id: int | None = None, - owner: int | None = None, - secret: str | None = None, - name: str | None = None, -) -> dict[str, Any] | None: +async def fetch_one(id: str) -> OAuthClient | None: """Fetch a signle client from the database.""" - if id is None and owner is None and secret is None: - raise ValueError("Must provide at least one parameter.") - query = f"""\ SELECT {READ_PARAMS} FROM oauth_clients - WHERE id = COALESCE(:id, id) - AND owner = COALESCE(:owner, owner) - AND secret = COALESCE(:secret, secret) - AND name = COALESCE(:name, name) + WHERE id = :id """ params = { "id": id, - "owner": owner, - "secret": secret, - "name": name, } rec = await app.state.services.database.fetch_one(query, params) - return dict(rec) if rec is not None else None + return cast(OAuthClient, dict(rec._mapping)) if rec is not None else None async def fetch_many( - id: int | None = None, owner: int | None = None, - secret: str | None = None, page: int | None = None, page_size: int | None = None, -) -> list[dict[str, Any]] | None: +) -> list[OAuthClient]: """Fetch all clients from the database.""" query = f"""\ SELECT {READ_PARAMS} FROM oauth_clients - WHERE id = COALESCE(:id, id) - AND owner = COALESCE(:owner, owner) - AND secret = COALESCE(:secret, secret) + WHERE owner = COALESCE(:owner, owner) """ params = { - "id": id, "owner": owner, - "secret": secret, } - if page is not None and page_size is not None: query += """\ LIMIT :limit @@ -113,8 +100,8 @@ async def fetch_many( params["limit"] = page_size params["offset"] = (page - 1) * page_size - rec = await app.state.services.database.fetch_one(query, params) - return dict(rec) if rec is not None else None + recs = await app.state.services.database.fetch_all(query, params) + return cast(list[OAuthClient], [dict(rec._mapping) for rec in recs]) async def update( @@ -123,13 +110,13 @@ async def update( owner: int | None = None, name: str | None = None, redirect_uri: str | None = None, -) -> dict[str, Any] | None: +) -> OAuthClient | None: """Update an existing client in the database.""" query = """\ UPDATE oauth_clients SET secret = COALESCE(:secret, secret), owner = COALESCE(:owner, owner), - redirect_uri = COALESCE(:redirect_uri, redirect_uri) + redirect_uri = COALESCE(:redirect_uri, redirect_uri), name = COALESCE(:name, name) WHERE id = :id """ @@ -151,4 +138,4 @@ async def update( "id": id, } rec = await app.state.services.database.fetch_one(query, params) - return dict(rec) if rec is not None else None + return cast(OAuthClient, dict(rec._mapping)) if rec is not None else None diff --git a/app/repositories/refresh_tokens.py b/app/repositories/refresh_tokens.py index b1856a15c..641922f52 100644 --- a/app/repositories/refresh_tokens.py +++ b/app/repositories/refresh_tokens.py @@ -11,23 +11,21 @@ class RefreshToken(TypedDict): - client_id: int - scope: str + client_id: str + scope: str | None refresh_token_id: UUID - access_token_id: UUID created_at: datetime expires_at: datetime -def create_refresh_token_key(code: UUID | Literal["*"]) -> str: - return f"bancho:refresh_tokens:{code}" +def create_refresh_token_key(refresh_token_id: UUID | Literal["*"]) -> str: + return f"bancho:refresh_tokens:{refresh_token_id}" async def create( refresh_token_id: UUID, - access_token_id: UUID, - client_id: int, - scope: str, + client_id: str, + scope: str | None, ) -> RefreshToken: now = datetime.now() expires_at = now + timedelta(days=30) @@ -35,7 +33,6 @@ async def create( "client_id": client_id, "scope": scope, "refresh_token_id": refresh_token_id, - "access_token_id": access_token_id, "created_at": now, "expires_at": expires_at, } @@ -58,7 +55,7 @@ async def fetch_one(refresh_token_id: UUID) -> RefreshToken | None: async def fetch_all( - client_id: int | None = None, + client_id: str | None = None, scope: str | None = None, page: int = 1, page_size: int = 10, diff --git a/app/settings.py b/app/settings.py index f925e7566..0efc38ebe 100644 --- a/app/settings.py +++ b/app/settings.py @@ -70,6 +70,11 @@ DISCORD_AUDIT_LOG_WEBHOOK = os.environ["DISCORD_AUDIT_LOG_WEBHOOK"] +# TODO: store public keys in db; abstract towards jwks +JWT_PUBLIC_KEY = os.environ["JWT_PUBLIC_KEY"] +JWT_PRIVATE_KEY = os.environ["JWT_PRIVATE_KEY"] +ROTATION_JWT_PRIVATE_KEY = os.environ["ROTATION_JWT_PRIVATE_KEY"] + AUTOMATICALLY_REPORT_PROBLEMS = read_bool(os.environ["AUTOMATICALLY_REPORT_PROBLEMS"]) # advanced dev settings diff --git a/docker-compose.yml b/docker-compose.yml index 6484f0722..a5bf27b6f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -81,6 +81,9 @@ services: - DISALLOWED_PASSWORDS=${DISALLOWED_PASSWORDS} - DISALLOW_OLD_CLIENTS=${DISALLOW_OLD_CLIENTS} - DISCORD_AUDIT_LOG_WEBHOOK=${DISCORD_AUDIT_LOG_WEBHOOK} + - JWT_PUBLIC_KEY=${JWT_PUBLIC_KEY} + - JWT_PRIVATE_KEY=${JWT_PRIVATE_KEY} + - ROTATION_JWT_PRIVATE_KEY=${ROTATION_JWT_PRIVATE_KEY} - AUTOMATICALLY_REPORT_PROBLEMS=${AUTOMATICALLY_REPORT_PROBLEMS} - SSL_CERT_PATH=${SSL_CERT_PATH} - SSL_KEY_PATH=${SSL_KEY_PATH} diff --git a/migrations/migrations.sql b/migrations/migrations.sql index 4f57865aa..9a3e31659 100644 --- a/migrations/migrations.sql +++ b/migrations/migrations.sql @@ -410,12 +410,13 @@ alter table maps add primary key (id); alter table maps modify column server enum('osu!', 'private') not null default 'osu!' after id; unlock tables; -# v4.7.3 -CREATE TABLE oauth_clients ( - id INT(10) NOT NULL AUTO_INCREMENT, - name VARCHAR(16) NULL DEFAULT NULL, - secret VARCHAR(32) NOT NULL, - owner INT(10) NOT NULL, - redirect_uri TEXT NULL DEFAULT NULL, - PRIMARY KEY (`id`) USING BTREE +# v5.1.0 +create table oauth_clients ( + rec_id int(11) not null auto_increment, + client_id varchar(256) not null, + name varchar(16) null default null, + secret varchar(32) not null, + owner int(10) not null, + redirect_uri text null default null, + primary key (`rec_id`) ) diff --git a/poetry.lock b/poetry.lock index e6eeaf94e..395f56023 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1310,6 +1310,23 @@ files = [ {file = "pyflakes-3.2.0.tar.gz", hash = "sha256:1c61603ff154621fb2a9172037d84dca3500def8c8b630657d1701f026f8af3f"}, ] +[[package]] +name = "pyjwt" +version = "2.8.0" +description = "JSON Web Token implementation in Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "PyJWT-2.8.0-py3-none-any.whl", hash = "sha256:59127c392cc44c2da5bb3192169a91f429924e17aff6534d70fdc02ab3e04320"}, + {file = "PyJWT-2.8.0.tar.gz", hash = "sha256:57e28d156e3d5c10088e0c68abb90bfac3df82b40a71bd0daa20c65ccd5c23de"}, +] + +[package.extras] +crypto = ["cryptography (>=3.4.0)"] +dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] +docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] +tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] + [[package]] name = "pymysql" version = "1.1.0" @@ -1824,4 +1841,4 @@ cython = "*" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "e7e641515cdcf47b22feeccfde0b7853022fb56712b640c1bae9a6cff1772f54" +content-hash = "ed6552fb694329d4f1459e38ec9dac3d713335fa195aadee07638ad14254c52f" diff --git a/pyproject.toml b/pyproject.toml index 616c63f9a..06895b18c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ pytest = "8.0.0" pytest-asyncio = "0.23.5" asgi-lifespan = "2.1.0" respx = "0.20.2" +pyjwt = "^2.8.0" [tool.poetry.group.dev.dependencies] pre-commit = "3.6.1"