diff --git a/fastapi_jwt_auth/__init__.py b/fastapi_jwt_auth/__init__.py index 1c4486a..8b796d2 100644 --- a/fastapi_jwt_auth/__init__.py +++ b/fastapi_jwt_auth/__init__.py @@ -1,5 +1,6 @@ """FastAPI extension that provides JWT Auth support (secure, easy to use and lightweight)""" -__version__ = "0.6.2" +__version__ = "0.6.3" from .auth_jwt import AuthJWT +from .auth_jwt import AuthJWTRefresh diff --git a/fastapi_jwt_auth/auth_config.py b/fastapi_jwt_auth/auth_config.py index b259f2e..cdda61f 100644 --- a/fastapi_jwt_auth/auth_config.py +++ b/fastapi_jwt_auth/auth_config.py @@ -3,9 +3,10 @@ from typing import Callable, List from datetime import timedelta + class AuthConfig: _token = None - _token_location = {'headers'} + _token_location = {"headers"} _secret_key = None _public_key = None @@ -17,7 +18,7 @@ class AuthConfig: _decode_issuer = None _decode_audience = None _denylist_enabled = False - _denylist_token_checks = {'access','refresh'} + _denylist_token_checks = {"access", "refresh"} _header_name = "Authorization" _header_type = "Bearer" _token_in_denylist_callback = None @@ -42,20 +43,24 @@ class AuthConfig: _refresh_csrf_cookie_path = "/" _access_csrf_header_name = "X-CSRF-Token" _refresh_csrf_header_name = "X-CSRF-Token" - _csrf_methods = {'POST','PUT','PATCH','DELETE'} + _csrf_methods = {"POST", "PUT", "PATCH", "DELETE"} @property def jwt_in_cookies(self) -> bool: - return 'cookies' in self._token_location + return "cookies" in self._token_location @property def jwt_in_headers(self) -> bool: - return 'headers' in self._token_location + return "headers" in self._token_location + + @property + def jwt_in_body(self) -> bool: + return "body" in self._token_location @classmethod - def load_config(cls, settings: Callable[...,List[tuple]]) -> "AuthConfig": + def load_config(cls, settings: Callable[..., List[tuple]]) -> "AuthConfig": try: - config = LoadConfig(**{key.lower():value for key,value in settings()}) + config = LoadConfig(**{key.lower(): value for key, value in settings()}) cls._token_location = config.authjwt_token_location cls._secret_key = config.authjwt_secret_key @@ -97,7 +102,7 @@ def load_config(cls, settings: Callable[...,List[tuple]]) -> "AuthConfig": raise TypeError("Config must be pydantic 'BaseSettings' or list of tuple") @classmethod - def token_in_denylist_loader(cls, callback: Callable[...,bool]) -> "AuthConfig": + def token_in_denylist_loader(cls, callback: Callable[..., bool]) -> "AuthConfig": """ This decorator sets the callback function that will be called when a protected endpoint is accessed and will check if the JWT has been diff --git a/fastapi_jwt_auth/auth_jwt.py b/fastapi_jwt_auth/auth_jwt.py index 4110bdb..81c707b 100644 --- a/fastapi_jwt_auth/auth_jwt.py +++ b/fastapi_jwt_auth/auth_jwt.py @@ -1,5 +1,7 @@ -import jwt, re, uuid, hmac -from jwt.algorithms import requires_cryptography, has_crypto +import hmac +import jwt +import re +import uuid from datetime import datetime, timezone, timedelta from typing import Optional, Dict, Union, Sequence from fastapi import Request, Response, WebSocket @@ -12,11 +14,18 @@ MissingTokenError, AccessTokenRequired, RefreshTokenRequired, - FreshTokenRequired + FreshTokenRequired, ) +from jwt.algorithms import requires_cryptography, has_crypto +from pydantic import BaseModel + + +class BodyToken(BaseModel): + refresh_token: Optional[str] = None + class AuthJWT(AuthConfig): - def __init__(self,req: Request = None, res: Response = None): + def __init__(self, req: Request = None, res: Response = None): """ Get jwt header from incoming request or get request and response object if jwt in the cookie @@ -24,6 +33,7 @@ def __init__(self,req: Request = None, res: Response = None): :param req: all incoming request :param res: response from endpoint """ + if res and self.jwt_in_cookies: self._response = res @@ -34,9 +44,10 @@ def __init__(self,req: Request = None, res: Response = None): # get jwt in headers when headers in token location if self.jwt_in_headers: auth = req.headers.get(self._header_name.lower()) - if auth: self._get_jwt_from_headers(auth) + if auth: + self._get_jwt_from_headers(auth) - def _get_jwt_from_headers(self,auth: str) -> "AuthJWT": + def _get_jwt_from_headers(self, auth: str) -> None: """ Get token from the headers @@ -51,26 +62,28 @@ def _get_jwt_from_headers(self,auth: str) -> "AuthJWT": # : if len(parts) != 1: msg = "Bad {} header. Expected value ''".format(header_name) - raise InvalidHeaderError(status_code=422,message=msg) + raise InvalidHeaderError(status_code=422, message=msg) self._token = parts[0] else: # : - if not re.match(r"{}\s".format(header_type),auth) or len(parts) != 2: - msg = "Bad {} header. Expected value '{} '".format(header_name,header_type) - raise InvalidHeaderError(status_code=422,message=msg) + if not re.match(r"{}\s".format(header_type), auth) or len(parts) != 2: + msg = "Bad {} header. Expected value '{} '".format( + header_name, header_type + ) + raise InvalidHeaderError(status_code=422, message=msg) self._token = parts[1] def _get_jwt_identifier(self) -> str: return str(uuid.uuid4()) - def _get_int_from_datetime(self,value: datetime) -> int: + def _get_int_from_datetime(self, value: datetime) -> int: """ :param value: datetime with or without timezone, if don't contains timezone it will managed as it is UTC :return: Seconds since the Epoch """ if not isinstance(value, datetime): # pragma: no cover - raise TypeError('a datetime is required') + raise TypeError("a datetime is required") return int(value.timestamp()) def _get_secret_key(self, algorithm: str, process: str) -> str: @@ -82,15 +95,24 @@ def _get_secret_key(self, algorithm: str, process: str) -> str: :return: plain text or RSA depends on algorithm """ - symmetric_algorithms, asymmetric_algorithms = {"HS256","HS384","HS512"}, requires_cryptography + symmetric_algorithms, asymmetric_algorithms = { + "HS256", + "HS384", + "HS512", + }, requires_cryptography - if algorithm not in symmetric_algorithms and algorithm not in asymmetric_algorithms: + if ( + algorithm not in symmetric_algorithms + and algorithm not in asymmetric_algorithms + ): raise ValueError("Algorithm {} could not be found".format(algorithm)) if algorithm in symmetric_algorithms: if not self._secret_key: raise RuntimeError( - "authjwt_secret_key must be set when using symmetric algorithm {}".format(algorithm) + "authjwt_secret_key must be set when using symmetric algorithm {}".format( + algorithm + ) ) return self._secret_key @@ -103,7 +125,9 @@ def _get_secret_key(self, algorithm: str, process: str) -> str: if process == "encode": if not self._private_key: raise RuntimeError( - "authjwt_private_key must be set when using asymmetric algorithm {}".format(algorithm) + "authjwt_private_key must be set when using asymmetric algorithm {}".format( + algorithm + ) ) return self._private_key @@ -111,22 +135,24 @@ def _get_secret_key(self, algorithm: str, process: str) -> str: if process == "decode": if not self._public_key: raise RuntimeError( - "authjwt_public_key must be set when using asymmetric algorithm {}".format(algorithm) + "authjwt_public_key must be set when using asymmetric algorithm {}".format( + algorithm + ) ) return self._public_key def _create_token( self, - subject: Union[str,int], + subject: Union[str, int], type_token: str, exp_time: Optional[int], fresh: Optional[bool] = False, algorithm: Optional[str] = None, headers: Optional[Dict] = None, issuer: Optional[str] = None, - audience: Optional[Union[str,Sequence[str]]] = None, - user_claims: Optional[Dict] = {} + audience: Optional[Union[str, Sequence[str]]] = None, + user_claims: Optional[Dict] = {}, ) -> str: """ Create token for access_token and refresh_token (utf-8) @@ -144,7 +170,7 @@ def _create_token( :return: Encoded token """ # Validation type data - if not isinstance(subject, (str,int)): + if not isinstance(subject, (str, int)): raise TypeError("subject must be a string or integer") if not isinstance(fresh, bool): raise TypeError("fresh must be a boolean") @@ -160,29 +186,29 @@ def _create_token( "sub": subject, "iat": self._get_int_from_datetime(datetime.now(timezone.utc)), "nbf": self._get_int_from_datetime(datetime.now(timezone.utc)), - "jti": self._get_jwt_identifier() + "jti": self._get_jwt_identifier(), } custom_claims = {"type": type_token} # for access_token only fresh needed - if type_token == 'access': - custom_claims['fresh'] = fresh + if type_token == "access": + custom_claims["fresh"] = fresh # if cookie in token location and csrf protection enabled if self.jwt_in_cookies and self._cookie_csrf_protect: - custom_claims['csrf'] = self._get_jwt_identifier() + custom_claims["csrf"] = self._get_jwt_identifier() if exp_time: - reserved_claims['exp'] = exp_time + reserved_claims["exp"] = exp_time if issuer: - reserved_claims['iss'] = issuer + reserved_claims["iss"] = issuer if audience: - reserved_claims['aud'] = audience + reserved_claims["aud"] = audience algorithm = algorithm or self._algorithm try: - secret_key = self._get_secret_key(algorithm,"encode") + secret_key = self._get_secret_key(algorithm, "encode") except Exception: raise @@ -190,8 +216,8 @@ def _create_token( {**reserved_claims, **custom_claims, **user_claims}, secret_key, algorithm=algorithm, - headers=headers - ).decode('utf-8') + headers=headers, + ) def _has_token_in_denylist_callback(self) -> bool: """ @@ -199,7 +225,9 @@ def _has_token_in_denylist_callback(self) -> bool: """ return self._token_in_denylist_callback is not None - def _check_token_is_revoked(self, raw_token: Dict[str,Union[str,int,bool]]) -> None: + def _check_token_is_revoked( + self, raw_token: Dict[str, Union[str, int, bool]] + ) -> None: """ Ensure that AUTHJWT_DENYLIST_ENABLED is true and callback regulated, and then call function denylist callback with passing decode JWT, if true @@ -209,18 +237,20 @@ def _check_token_is_revoked(self, raw_token: Dict[str,Union[str,int,bool]]) -> N return if not self._has_token_in_denylist_callback(): - raise RuntimeError("A token_in_denylist_callback must be provided via " + raise RuntimeError( + "A token_in_denylist_callback must be provided via " "the '@AuthJWT.token_in_denylist_loader' if " - "authjwt_denylist_enabled is 'True'") + "authjwt_denylist_enabled is 'True'" + ) if self._token_in_denylist_callback.__func__(raw_token): - raise RevokedTokenError(status_code=401,message="Token has been revoked") + raise RevokedTokenError(status_code=401, message="Token has been revoked") def _get_expired_time( self, type_token: str, - expires_time: Optional[Union[timedelta,int,bool]] = None - ) -> Union[None,int]: + expires_time: Optional[Union[timedelta, int, bool]] = None, + ) -> Union[None, int]: """ Dynamic token expired, if expires_time is False exp claim not created @@ -229,37 +259,39 @@ def _get_expired_time( :return: duration exp claim jwt """ - if expires_time and not isinstance(expires_time, (timedelta,int,bool)): + if expires_time and not isinstance(expires_time, (timedelta, int, bool)): raise TypeError("expires_time must be between timedelta, int, bool") if expires_time is not False: - if type_token == 'access': + if type_token == "access": expires_time = expires_time or self._access_token_expires - if type_token == 'refresh': + if type_token == "refresh": expires_time = expires_time or self._refresh_token_expires if expires_time is not False: if isinstance(expires_time, bool): - if type_token == 'access': + if type_token == "access": expires_time = self._access_token_expires - if type_token == 'refresh': + if type_token == "refresh": expires_time = self._refresh_token_expires if isinstance(expires_time, timedelta): expires_time = int(expires_time.total_seconds()) - return self._get_int_from_datetime(datetime.now(timezone.utc)) + expires_time + return ( + self._get_int_from_datetime(datetime.now(timezone.utc)) + expires_time + ) else: return None def create_access_token( self, - subject: Union[str,int], + subject: Union[str, int], fresh: Optional[bool] = False, algorithm: Optional[str] = None, headers: Optional[Dict] = None, - expires_time: Optional[Union[timedelta,int,bool]] = None, - audience: Optional[Union[str,Sequence[str]]] = None, - user_claims: Optional[Dict] = {} + expires_time: Optional[Union[timedelta, int, bool]] = None, + audience: Optional[Union[str, Sequence[str]]] = None, + user_claims: Optional[Dict] = {}, ) -> str: """ Create a access token with 15 minutes for expired time (default), @@ -270,23 +302,23 @@ def create_access_token( return self._create_token( subject=subject, type_token="access", - exp_time=self._get_expired_time("access",expires_time), + exp_time=self._get_expired_time("access", expires_time), fresh=fresh, algorithm=algorithm, headers=headers, audience=audience, user_claims=user_claims, - issuer=self._encode_issuer + issuer=self._encode_issuer, ) def create_refresh_token( self, - subject: Union[str,int], + subject: Union[str, int], algorithm: Optional[str] = None, headers: Optional[Dict] = None, - expires_time: Optional[Union[timedelta,int,bool]] = None, - audience: Optional[Union[str,Sequence[str]]] = None, - user_claims: Optional[Dict] = {} + expires_time: Optional[Union[timedelta, int, bool]] = None, + audience: Optional[Union[str, Sequence[str]]] = None, + user_claims: Optional[Dict] = {}, ) -> str: """ Create a refresh token with 30 days for expired time (default), @@ -297,27 +329,27 @@ def create_refresh_token( return self._create_token( subject=subject, type_token="refresh", - exp_time=self._get_expired_time("refresh",expires_time), + exp_time=self._get_expired_time("refresh", expires_time), algorithm=algorithm, headers=headers, audience=audience, - user_claims=user_claims + user_claims=user_claims, ) - def _get_csrf_token(self,encoded_token: str) -> str: + def _get_csrf_token(self, encoded_token: str) -> str: """ Returns the CSRF double submit token from an encoded JWT. :param encoded_token: The encoded JWT :return: The CSRF double submit token """ - return self._verified_token(encoded_token)['csrf'] + return self._verified_token(encoded_token)["csrf"] def set_access_cookies( self, encoded_access_token: str, response: Optional[Response] = None, - max_age: Optional[int] = None + max_age: Optional[int] = None, ) -> None: """ Configures the response to set access token in a cookie. @@ -332,9 +364,9 @@ def set_access_cookies( "set_access_cookies() called without 'authjwt_token_location' configured to use cookies" ) - if max_age and not isinstance(max_age,int): + if max_age and not isinstance(max_age, int): raise TypeError("max_age must be a integer") - if response and not isinstance(response,Response): + if response and not isinstance(response, Response): raise TypeError("The response must be an object response FastAPI") response = response or self._response @@ -348,7 +380,7 @@ def set_access_cookies( domain=self._cookie_domain, secure=self._cookie_secure, httponly=True, - samesite=self._cookie_samesite + samesite=self._cookie_samesite, ) # If enabled, set the csrf double submit access cookie @@ -361,14 +393,14 @@ def set_access_cookies( domain=self._cookie_domain, secure=self._cookie_secure, httponly=False, - samesite=self._cookie_samesite + samesite=self._cookie_samesite, ) def set_refresh_cookies( self, encoded_refresh_token: str, response: Optional[Response] = None, - max_age: Optional[int] = None + max_age: Optional[int] = None, ) -> None: """ Configures the response to set refresh token in a cookie. @@ -383,9 +415,9 @@ def set_refresh_cookies( "set_refresh_cookies() called without 'authjwt_token_location' configured to use cookies" ) - if max_age and not isinstance(max_age,int): + if max_age and not isinstance(max_age, int): raise TypeError("max_age must be a integer") - if response and not isinstance(response,Response): + if response and not isinstance(response, Response): raise TypeError("The response must be an object response FastAPI") response = response or self._response @@ -399,7 +431,7 @@ def set_refresh_cookies( domain=self._cookie_domain, secure=self._cookie_secure, httponly=True, - samesite=self._cookie_samesite + samesite=self._cookie_samesite, ) # If enabled, set the csrf double submit refresh cookie @@ -412,10 +444,10 @@ def set_refresh_cookies( domain=self._cookie_domain, secure=self._cookie_secure, httponly=False, - samesite=self._cookie_samesite + samesite=self._cookie_samesite, ) - def unset_jwt_cookies(self,response: Optional[Response] = None) -> None: + def unset_jwt_cookies(self, response: Optional[Response] = None) -> None: """ Unset (delete) all jwt stored in a cookie @@ -424,7 +456,7 @@ def unset_jwt_cookies(self,response: Optional[Response] = None) -> None: self.unset_access_cookies(response) self.unset_refresh_cookies(response) - def unset_access_cookies(self,response: Optional[Response] = None) -> None: + def unset_access_cookies(self, response: Optional[Response] = None) -> None: """ Remove access token and access CSRF double submit from the response cookies @@ -435,7 +467,7 @@ def unset_access_cookies(self,response: Optional[Response] = None) -> None: "unset_access_cookies() called without 'authjwt_token_location' configured to use cookies" ) - if response and not isinstance(response,Response): + if response and not isinstance(response, Response): raise TypeError("The response must be an object response FastAPI") response = response or self._response @@ -443,17 +475,17 @@ def unset_access_cookies(self,response: Optional[Response] = None) -> None: response.delete_cookie( self._access_cookie_key, path=self._access_cookie_path, - domain=self._cookie_domain + domain=self._cookie_domain, ) if self._cookie_csrf_protect: response.delete_cookie( self._access_csrf_cookie_key, path=self._access_csrf_cookie_path, - domain=self._cookie_domain + domain=self._cookie_domain, ) - def unset_refresh_cookies(self,response: Optional[Response] = None) -> None: + def unset_refresh_cookies(self, response: Optional[Response] = None) -> None: """ Remove refresh token and refresh CSRF double submit from the response cookies @@ -464,7 +496,7 @@ def unset_refresh_cookies(self,response: Optional[Response] = None) -> None: "unset_refresh_cookies() called without 'authjwt_token_location' configured to use cookies" ) - if response and not isinstance(response,Response): + if response and not isinstance(response, Response): raise TypeError("The response must be an object response FastAPI") response = response or self._response @@ -472,19 +504,19 @@ def unset_refresh_cookies(self,response: Optional[Response] = None) -> None: response.delete_cookie( self._refresh_cookie_key, path=self._refresh_cookie_path, - domain=self._cookie_domain + domain=self._cookie_domain, ) if self._cookie_csrf_protect: response.delete_cookie( self._refresh_csrf_cookie_key, path=self._refresh_csrf_cookie_path, - domain=self._cookie_domain + domain=self._cookie_domain, ) def _verify_and_get_jwt_optional_in_cookies( self, - request: Union[Request,WebSocket], + request: Union[Request, WebSocket], csrf_token: Optional[str] = None, ) -> "AuthJWT": """ @@ -495,7 +527,7 @@ def _verify_and_get_jwt_optional_in_cookies( :param request: for identity get cookies from HTTP or WebSocket :param csrf_token: the CSRF double submit token """ - if not isinstance(request,(Request,WebSocket)): + if not isinstance(request, (Request, WebSocket)): raise TypeError("request must be an instance of 'Request' or 'WebSocket'") cookie_key = self._access_cookie_key @@ -505,7 +537,7 @@ def _verify_and_get_jwt_optional_in_cookies( if cookie and self._cookie_csrf_protect and not csrf_token: if isinstance(request, WebSocket) or request.method in self._csrf_methods: - raise CSRFError(status_code=401,message="Missing CSRF Token") + raise CSRFError(status_code=401, message="Missing CSRF Token") # set token from cookie and verify jwt self._token = cookie @@ -515,15 +547,18 @@ def _verify_and_get_jwt_optional_in_cookies( if decoded_token and self._cookie_csrf_protect and csrf_token: if isinstance(request, WebSocket) or request.method in self._csrf_methods: - if 'csrf' not in decoded_token: - raise JWTDecodeError(status_code=422,message="Missing claim: csrf") - if not hmac.compare_digest(csrf_token,decoded_token['csrf']): - raise CSRFError(status_code=401,message="CSRF double submit tokens do not match") + if "csrf" not in decoded_token: + raise JWTDecodeError(status_code=422, message="Missing claim: csrf") + if not hmac.compare_digest(csrf_token, decoded_token["csrf"]): + raise CSRFError( + status_code=401, + message="CSRF double submit tokens do not match", + ) def _verify_and_get_jwt_in_cookies( self, type_token: str, - request: Union[Request,WebSocket], + request: Union[Request, WebSocket], csrf_token: Optional[str] = None, fresh: Optional[bool] = False, ) -> "AuthJWT": @@ -537,59 +572,67 @@ def _verify_and_get_jwt_in_cookies( :param csrf_token: the CSRF double submit token :param fresh: check freshness token if True """ - if type_token not in ['access','refresh']: + if type_token not in ["access", "refresh"]: raise ValueError("type_token must be between 'access' or 'refresh'") - if not isinstance(request,(Request,WebSocket)): + if not isinstance(request, (Request, WebSocket)): raise TypeError("request must be an instance of 'Request' or 'WebSocket'") - if type_token == 'access': + if type_token == "access": cookie_key = self._access_cookie_key cookie = request.cookies.get(cookie_key) if not isinstance(request, WebSocket): csrf_token = request.headers.get(self._access_csrf_header_name) - if type_token == 'refresh': + if type_token == "refresh": cookie_key = self._refresh_cookie_key cookie = request.cookies.get(cookie_key) if not isinstance(request, WebSocket): csrf_token = request.headers.get(self._refresh_csrf_header_name) if not cookie: - raise MissingTokenError(status_code=401,message="Missing cookie {}".format(cookie_key)) + raise MissingTokenError( + status_code=401, message="Missing cookie {}".format(cookie_key) + ) if self._cookie_csrf_protect and not csrf_token: if isinstance(request, WebSocket) or request.method in self._csrf_methods: - raise CSRFError(status_code=401,message="Missing CSRF Token") + raise CSRFError(status_code=401, message="Missing CSRF Token") # set token from cookie and verify jwt self._token = cookie - self._verify_jwt_in_request(self._token,type_token,'cookies',fresh) + self._verify_jwt_in_request(self._token, type_token, "cookies", fresh) decoded_token = self.get_raw_jwt() if self._cookie_csrf_protect and csrf_token: if isinstance(request, WebSocket) or request.method in self._csrf_methods: - if 'csrf' not in decoded_token: - raise JWTDecodeError(status_code=422,message="Missing claim: csrf") - if not hmac.compare_digest(csrf_token,decoded_token['csrf']): - raise CSRFError(status_code=401,message="CSRF double submit tokens do not match") + if "csrf" not in decoded_token: + raise JWTDecodeError(status_code=422, message="Missing claim: csrf") + if not hmac.compare_digest(csrf_token, decoded_token["csrf"]): + raise CSRFError( + status_code=401, + message="CSRF double submit tokens do not match", + ) - def _verify_jwt_optional_in_request(self,token: str) -> None: + def _verify_jwt_optional_in_request(self, token: str) -> None: """ Optionally check if this request has a valid access token :param token: The encoded JWT """ - if token: self._verifying_token(token) + if token: + self._verifying_token(token) - if token and self.get_raw_jwt(token)['type'] != 'access': - raise AccessTokenRequired(status_code=422,message="Only access tokens are allowed") + if token and self.get_raw_jwt(token)["type"] != "access": + raise AccessTokenRequired( + status_code=422, message="Only access tokens are allowed" + ) def _verify_jwt_in_request( self, token: str, type_token: str, token_from: str, - fresh: Optional[bool] = False + fresh: Optional[bool] = False, ) -> None: """ Ensure that the requester has a valid token. this also check the freshness of the access token @@ -599,43 +642,60 @@ def _verify_jwt_in_request( :param token_from: indicate token from headers cookies, websocket :param fresh: check freshness token if True """ - if type_token not in ['access','refresh']: + if type_token not in ["access", "refresh"]: raise ValueError("type_token must be between 'access' or 'refresh'") - if token_from not in ['headers','cookies','websocket']: - raise ValueError("token_from must be between 'headers', 'cookies', 'websocket'") + if token_from not in ["headers", "cookies", "websocket", "body"]: + raise ValueError( + "token_from must be between 'headers', 'cookies', 'websocket', 'body'" + ) if not token: - if token_from == 'headers': - raise MissingTokenError(status_code=401,message="Missing {} Header".format(self._header_name)) - if token_from == 'websocket': - raise MissingTokenError(status_code=1008,message="Missing {} token from Query or Path".format(type_token)) + if token_from == "headers": + raise MissingTokenError( + status_code=401, + message="Missing {} Header".format(self._header_name), + ) + if token_from == "websocket": + raise MissingTokenError( + status_code=1008, + message="Missing {} token from Query or Path".format(type_token), + ) + if token_from == "body": + raise MissingTokenError( + status_code=401, + message="Missing {} token from Body".format(type_token), + ) # verify jwt - issuer = self._decode_issuer if type_token == 'access' else None - self._verifying_token(token,issuer) + issuer = self._decode_issuer if type_token == "access" else None + self._verifying_token(token, issuer) - if self.get_raw_jwt(token)['type'] != type_token: + if self.get_raw_jwt(token)["type"] != type_token: msg = "Only {} tokens are allowed".format(type_token) - if type_token == 'access': - raise AccessTokenRequired(status_code=422,message=msg) - if type_token == 'refresh': - raise RefreshTokenRequired(status_code=422,message=msg) + if type_token == "access": + raise AccessTokenRequired(status_code=422, message=msg) + if type_token == "refresh": + raise RefreshTokenRequired(status_code=422, message=msg) - if fresh and not self.get_raw_jwt(token)['fresh']: - raise FreshTokenRequired(status_code=401,message="Fresh token required") + if fresh and not self.get_raw_jwt(token)["fresh"]: + raise FreshTokenRequired(status_code=401, message="Fresh token required") - def _verifying_token(self,encoded_token: str, issuer: Optional[str] = None) -> None: + def _verifying_token( + self, encoded_token: str, issuer: Optional[str] = None + ) -> None: """ Verified token and check if token is revoked :param encoded_token: token hash :param issuer: expected issuer in the JWT """ - raw_token = self._verified_token(encoded_token,issuer) - if raw_token['type'] in self._denylist_token_checks: + raw_token = self._verified_token(encoded_token, issuer) + if raw_token["type"] in self._denylist_token_checks: self._check_token_is_revoked(raw_token) - def _verified_token(self,encoded_token: str, issuer: Optional[str] = None) -> Dict[str,Union[str,int,bool]]: + def _verified_token( + self, encoded_token: str, issuer: Optional[str] = None + ) -> Dict[str, Union[str, int, bool]]: """ Verified token and catch all error from jwt package and return decode token @@ -649,10 +709,10 @@ def _verified_token(self,encoded_token: str, issuer: Optional[str] = None) -> Di try: unverified_headers = self.get_unverified_jwt_headers(encoded_token) except Exception as err: - raise InvalidHeaderError(status_code=422,message=str(err)) + raise InvalidHeaderError(status_code=422, message=str(err)) try: - secret_key = self._get_secret_key(unverified_headers['alg'],"decode") + secret_key = self._get_secret_key(unverified_headers["alg"], "decode") except Exception: raise @@ -663,10 +723,10 @@ def _verified_token(self,encoded_token: str, issuer: Optional[str] = None) -> Di issuer=issuer, audience=self._decode_audience, leeway=self._decode_leeway, - algorithms=algorithms + algorithms=algorithms, ) except Exception as err: - raise JWTDecodeError(status_code=422,message=str(err)) + raise JWTDecodeError(status_code=422, message=str(err)) def jwt_required( self, @@ -686,20 +746,22 @@ def jwt_required( its must be passing csrf_token manually and can achieve by Query Url or Path """ if auth_from == "websocket": - if websocket: self._verify_and_get_jwt_in_cookies('access',websocket,csrf_token) - else: self._verify_jwt_in_request(token,'access','websocket') + if websocket: + self._verify_and_get_jwt_in_cookies("access", websocket, csrf_token) + else: + self._verify_jwt_in_request(token, "access", "websocket") if auth_from == "request": if len(self._token_location) == 2: if self._token and self.jwt_in_headers: - self._verify_jwt_in_request(self._token,'access','headers') + self._verify_jwt_in_request(self._token, "access", "headers") if not self._token and self.jwt_in_cookies: - self._verify_and_get_jwt_in_cookies('access',self._request) + self._verify_and_get_jwt_in_cookies("access", self._request) else: if self.jwt_in_headers: - self._verify_jwt_in_request(self._token,'access','headers') + self._verify_jwt_in_request(self._token, "access", "headers") if self.jwt_in_cookies: - self._verify_and_get_jwt_in_cookies('access',self._request) + self._verify_and_get_jwt_in_cookies("access", self._request) def jwt_optional( self, @@ -721,8 +783,10 @@ def jwt_optional( its must be passing csrf_token manually and can achieve by Query Url or Path """ if auth_from == "websocket": - if websocket: self._verify_and_get_jwt_optional_in_cookies(websocket,csrf_token) - else: self._verify_jwt_optional_in_request(token) + if websocket: + self._verify_and_get_jwt_optional_in_cookies(websocket, csrf_token) + else: + self._verify_jwt_optional_in_request(token) if auth_from == "request": if len(self._token_location) == 2: @@ -754,20 +818,24 @@ def jwt_refresh_token_required( its must be passing csrf_token manually and can achieve by Query Url or Path """ if auth_from == "websocket": - if websocket: self._verify_and_get_jwt_in_cookies('refresh',websocket,csrf_token) - else: self._verify_jwt_in_request(token,'refresh','websocket') + if websocket: + self._verify_and_get_jwt_in_cookies("refresh", websocket, csrf_token) + else: + self._verify_jwt_in_request(token, "refresh", "websocket") if auth_from == "request": if len(self._token_location) == 2: if self._token and self.jwt_in_headers: - self._verify_jwt_in_request(self._token,'refresh','headers') + self._verify_jwt_in_request(self._token, "refresh", "headers") if not self._token and self.jwt_in_cookies: - self._verify_and_get_jwt_in_cookies('refresh',self._request) + self._verify_and_get_jwt_in_cookies("refresh", self._request) else: + if "body" in self._token_location: + self._verify_jwt_in_request(self._token, "refresh", "body") if self.jwt_in_headers: - self._verify_jwt_in_request(self._token,'refresh','headers') + self._verify_jwt_in_request(self._token, "refresh", "headers") if self.jwt_in_cookies: - self._verify_and_get_jwt_in_cookies('refresh',self._request) + self._verify_and_get_jwt_in_cookies("refresh", self._request) def fresh_jwt_required( self, @@ -787,22 +855,32 @@ def fresh_jwt_required( its must be passing csrf_token manually and can achieve by Query Url or Path """ if auth_from == "websocket": - if websocket: self._verify_and_get_jwt_in_cookies('access',websocket,csrf_token,True) - else: self._verify_jwt_in_request(token,'access','websocket',True) + if websocket: + self._verify_and_get_jwt_in_cookies( + "access", websocket, csrf_token, True + ) + else: + self._verify_jwt_in_request(token, "access", "websocket", True) if auth_from == "request": if len(self._token_location) == 2: if self._token and self.jwt_in_headers: - self._verify_jwt_in_request(self._token,'access','headers',True) + self._verify_jwt_in_request(self._token, "access", "headers", True) if not self._token and self.jwt_in_cookies: - self._verify_and_get_jwt_in_cookies('access',self._request,fresh=True) + self._verify_and_get_jwt_in_cookies( + "access", self._request, fresh=True + ) else: if self.jwt_in_headers: - self._verify_jwt_in_request(self._token,'access','headers',True) + self._verify_jwt_in_request(self._token, "access", "headers", True) if self.jwt_in_cookies: - self._verify_and_get_jwt_in_cookies('access',self._request,fresh=True) + self._verify_and_get_jwt_in_cookies( + "access", self._request, fresh=True + ) - def get_raw_jwt(self,encoded_token: Optional[str] = None) -> Optional[Dict[str,Union[str,int,bool]]]: + def get_raw_jwt( + self, encoded_token: Optional[str] = None + ) -> Optional[Dict[str, Union[str, int, bool]]]: """ this will return the python dictionary which has all of the claims of the JWT that is accessing the endpoint. If no JWT is currently present, return None instead @@ -816,16 +894,16 @@ def get_raw_jwt(self,encoded_token: Optional[str] = None) -> Optional[Dict[str,U return self._verified_token(token) return None - def get_jti(self,encoded_token: str) -> str: + def get_jti(self, encoded_token: str) -> str: """ Returns the JTI (unique identifier) of an encoded JWT :param encoded_token: The encoded JWT from parameter :return: string of JTI """ - return self._verified_token(encoded_token)['jti'] + return self._verified_token(encoded_token)["jti"] - def get_jwt_subject(self) -> Optional[Union[str,int]]: + def get_jwt_subject(self) -> Optional[Union[str, int]]: """ this will return the subject of the JWT that is accessing this endpoint. If no JWT is present, `None` is returned instead. @@ -833,10 +911,10 @@ def get_jwt_subject(self) -> Optional[Union[str,int]]: :return: sub of JWT """ if self._token: - return self._verified_token(self._token)['sub'] + return self._verified_token(self._token)["sub"] return None - def get_unverified_jwt_headers(self,encoded_token: Optional[str] = None) -> dict: + def get_unverified_jwt_headers(self, encoded_token: Optional[str] = None) -> dict: """ Returns the Headers of an encoded JWT without verifying the actual signature of JWT @@ -846,3 +924,15 @@ def get_unverified_jwt_headers(self,encoded_token: Optional[str] = None) -> dict encoded_token = encoded_token or self._token return jwt.get_unverified_header(encoded_token) + + +class AuthJWTRefresh(AuthJWT): + def __init__( + self, req: Request = None, res: Response = None, refresh_token: BodyToken = None + ): + if refresh_token is not None: + self._token = refresh_token.refresh_token + self._token_location = {"body"} + + else: + AuthJWT.__init__(self, req, res) diff --git a/fastapi_jwt_auth/config.py b/fastapi_jwt_auth/config.py index 27deab8..8d27f8e 100644 --- a/fastapi_jwt_auth/config.py +++ b/fastapi_jwt_auth/config.py @@ -64,7 +64,7 @@ def validate_refresh_token_expires(cls, v): ) return v - @validator("authjwt_denylist_token_checks") + @validator("authjwt_denylist_token_checks", each_item=True) def validate_denylist_token_checks(cls, v): if v not in ["access", "refresh"]: raise ValueError( @@ -72,11 +72,11 @@ def validate_denylist_token_checks(cls, v): ) return v - @validator("authjwt_token_location") + @validator("authjwt_token_location", each_item=True) def validate_token_location(cls, v): - if v not in ["headers", "cookies"]: + if v not in ["headers", "cookies", "body"]: raise ValueError( - "The 'authjwt_token_location' must be between 'headers' or 'cookies'" + "The 'authjwt_token_location' must be between 'headers' or 'cookies' 'body'" ) return v @@ -88,7 +88,7 @@ def validate_cookie_samesite(cls, v): ) return v - @validator("authjwt_csrf_methods") + @validator("authjwt_csrf_methods", each_item=True) def validate_csrf_methods(cls, v): if v.upper() not in {"GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"}: raise ValueError( diff --git a/fastapi_jwt_auth/exceptions.py b/fastapi_jwt_auth/exceptions.py index 590423c..1057571 100644 --- a/fastapi_jwt_auth/exceptions.py +++ b/fastapi_jwt_auth/exceptions.py @@ -2,71 +2,88 @@ class AuthJWTException(Exception): """ Base except which all fastapi_jwt_auth errors extend """ + pass + class InvalidHeaderError(AuthJWTException): """ An error getting jwt in header or jwt header information from a request """ - def __init__(self,status_code: int, message: str): + + def __init__(self, status_code: int, message: str): self.status_code = status_code self.message = message + class JWTDecodeError(AuthJWTException): """ An error decoding a JWT """ - def __init__(self,status_code: int, message: str): + + def __init__(self, status_code: int, message: str): self.status_code = status_code self.message = message + class CSRFError(AuthJWTException): """ An error with CSRF protection """ - def __init__(self,status_code: int, message: str): + + def __init__(self, status_code: int, message: str): self.status_code = status_code self.message = message + class MissingTokenError(AuthJWTException): """ Error raised when token not found """ - def __init__(self,status_code: int, message: str): + + def __init__(self, status_code: int, message: str): self.status_code = status_code self.message = message + class RevokedTokenError(AuthJWTException): """ Error raised when a revoked token attempt to access a protected endpoint """ - def __init__(self,status_code: int, message: str): + + def __init__(self, status_code: int, message: str): self.status_code = status_code self.message = message + class AccessTokenRequired(AuthJWTException): """ Error raised when a valid, non-access JWT attempt to access an endpoint protected by jwt_required, jwt_optional, fresh_jwt_required """ - def __init__(self,status_code: int, message: str): + + def __init__(self, status_code: int, message: str): self.status_code = status_code self.message = message + class RefreshTokenRequired(AuthJWTException): """ Error raised when a valid, non-refresh JWT attempt to access an endpoint protected by jwt_refresh_token_required """ - def __init__(self,status_code: int, message: str): + + def __init__(self, status_code: int, message: str): self.status_code = status_code self.message = message + class FreshTokenRequired(AuthJWTException): """ Error raised when a valid, non-fresh JWT attempt to access an endpoint protected by fresh_jwt_required """ - def __init__(self,status_code: int, message: str): + + def __init__(self, status_code: int, message: str): self.status_code = status_code self.message = message