From aeec4c754a93aa29fd8c40886272889301a54373 Mon Sep 17 00:00:00 2001 From: Miguel Jesus Date: Fri, 1 Sep 2023 21:30:02 +0100 Subject: [PATCH] fix: upgrade code --- fastapi_jwt_auth/auth_jwt.py | 2 +- fastapi_jwt_auth/config.py | 78 ++++++++++++++++++++++-------------- pyproject.toml | 3 +- tests/test_decode_token.py | 6 +-- 4 files changed, 52 insertions(+), 37 deletions(-) diff --git a/fastapi_jwt_auth/auth_jwt.py b/fastapi_jwt_auth/auth_jwt.py index de1f96a..4110bdb 100644 --- a/fastapi_jwt_auth/auth_jwt.py +++ b/fastapi_jwt_auth/auth_jwt.py @@ -191,7 +191,7 @@ def _create_token( secret_key, algorithm=algorithm, headers=headers - ) + ).decode('utf-8') def _has_token_in_denylist_callback(self) -> bool: """ diff --git a/fastapi_jwt_auth/config.py b/fastapi_jwt_auth/config.py index c81b50c..27deab8 100644 --- a/fastapi_jwt_auth/config.py +++ b/fastapi_jwt_auth/config.py @@ -1,30 +1,29 @@ from datetime import timedelta from typing import Optional, Union, Sequence, List -from pydantic import ( - BaseModel, - validator, - StrictBool, - StrictInt, - StrictStr -) +from pydantic import BaseModel, validator, StrictBool, StrictInt, StrictStr + class LoadConfig(BaseModel): - authjwt_token_location: Optional[Sequence[StrictStr]] = {'headers'} + authjwt_token_location: Optional[Sequence[StrictStr]] = {"headers"} authjwt_secret_key: Optional[StrictStr] = None authjwt_public_key: Optional[StrictStr] = None authjwt_private_key: Optional[StrictStr] = None authjwt_algorithm: Optional[StrictStr] = "HS256" authjwt_decode_algorithms: Optional[List[StrictStr]] = None - authjwt_decode_leeway: Optional[Union[StrictInt,timedelta]] = 0 + authjwt_decode_leeway: Optional[Union[StrictInt, timedelta]] = 0 authjwt_encode_issuer: Optional[StrictStr] = None authjwt_decode_issuer: Optional[StrictStr] = None - authjwt_decode_audience: Optional[Union[StrictStr,Sequence[StrictStr]]] = None + authjwt_decode_audience: Optional[Union[StrictStr, Sequence[StrictStr]]] = None authjwt_denylist_enabled: Optional[StrictBool] = False - authjwt_denylist_token_checks: Optional[Sequence[StrictStr]] = {'access','refresh'} + authjwt_denylist_token_checks: Optional[Sequence[StrictStr]] = {"access", "refresh"} authjwt_header_name: Optional[StrictStr] = "Authorization" authjwt_header_type: Optional[StrictStr] = "Bearer" - authjwt_access_token_expires: Optional[Union[StrictBool,StrictInt,timedelta]] = timedelta(minutes=15) - authjwt_refresh_token_expires: Optional[Union[StrictBool,StrictInt,timedelta]] = timedelta(days=30) + authjwt_access_token_expires: Optional[ + Union[StrictBool, StrictInt, timedelta] + ] = timedelta(minutes=15) + authjwt_refresh_token_expires: Optional[ + Union[StrictBool, StrictInt, timedelta] + ] = timedelta(days=30) # option for create cookies authjwt_access_cookie_key: Optional[StrictStr] = "access_token_cookie" authjwt_refresh_cookie_key: Optional[StrictStr] = "refresh_token_cookie" @@ -42,44 +41,61 @@ class LoadConfig(BaseModel): authjwt_refresh_csrf_cookie_path: Optional[StrictStr] = "/" authjwt_access_csrf_header_name: Optional[StrictStr] = "X-CSRF-Token" authjwt_refresh_csrf_header_name: Optional[StrictStr] = "X-CSRF-Token" - authjwt_csrf_methods: Optional[Sequence[StrictStr]] = {'POST','PUT','PATCH','DELETE'} + authjwt_csrf_methods: Optional[Sequence[StrictStr]] = { + "POST", + "PUT", + "PATCH", + "DELETE", + } - @validator('authjwt_access_token_expires') + @validator("authjwt_access_token_expires") def validate_access_token_expires(cls, v): if v is True: - raise ValueError("The 'authjwt_access_token_expires' only accept value False (bool)") + raise ValueError( + "The 'authjwt_access_token_expires' only accept value False (bool)" + ) return v - @validator('authjwt_refresh_token_expires') + @validator("authjwt_refresh_token_expires") def validate_refresh_token_expires(cls, v): if v is True: - raise ValueError("The 'authjwt_refresh_token_expires' only accept value False (bool)") + raise ValueError( + "The 'authjwt_refresh_token_expires' only accept value False (bool)" + ) return v - @validator('authjwt_denylist_token_checks', each_item=True) + @validator("authjwt_denylist_token_checks") def validate_denylist_token_checks(cls, v): - if v not in ['access','refresh']: - raise ValueError("The 'authjwt_denylist_token_checks' must be between 'access' or 'refresh'") + if v not in ["access", "refresh"]: + raise ValueError( + "The 'authjwt_denylist_token_checks' must be between 'access' or 'refresh'" + ) return v - @validator('authjwt_token_location', each_item=True) + @validator("authjwt_token_location") def validate_token_location(cls, v): - if v not in ['headers','cookies']: - raise ValueError("The 'authjwt_token_location' must be between 'headers' or 'cookies'") + if v not in ["headers", "cookies"]: + raise ValueError( + "The 'authjwt_token_location' must be between 'headers' or 'cookies'" + ) return v - @validator('authjwt_cookie_samesite') + @validator("authjwt_cookie_samesite") def validate_cookie_samesite(cls, v): - if v not in ['strict','lax','none']: - raise ValueError("The 'authjwt_cookie_samesite' must be between 'strict', 'lax', 'none'") + if v not in ["strict", "lax", "none"]: + raise ValueError( + "The 'authjwt_cookie_samesite' must be between 'strict', 'lax', 'none'" + ) return v - @validator('authjwt_csrf_methods', each_item=True) + @validator("authjwt_csrf_methods") def validate_csrf_methods(cls, v): if v.upper() not in {"GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"}: - raise ValueError("The 'authjwt_csrf_methods' must be between http request methods") + raise ValueError( + "The 'authjwt_csrf_methods' must be between http request methods" + ) return v.upper() class Config: - min_anystr_length = 1 - anystr_strip_whitespace = True + str_min_length = 1 + str_strip_whitespace = True diff --git a/pyproject.toml b/pyproject.toml index f715ae9..3799bd3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,8 +24,7 @@ classifiers = [ requires = [ "fastapi>=0.61.0", - "PyJWT==2.1.0", - "setuptools>=45.2.0" + "PyJWT>=1.7.1" ] description-file = "README.md" diff --git a/tests/test_decode_token.py b/tests/test_decode_token.py index 38e367a..5344d48 100644 --- a/tests/test_decode_token.py +++ b/tests/test_decode_token.py @@ -51,7 +51,7 @@ def default_access_token(): @pytest.fixture(scope='function') def encoded_token(default_access_token): - return jwt.encode(default_access_token,'secret-key',algorithm='HS256') + return jwt.encode(default_access_token,'secret-key',algorithm='HS256').decode('utf-8') def test_verified_token(client,encoded_token,Authorize): class SettingsOne(BaseSettings): @@ -67,7 +67,7 @@ def get_settings_one(): assert response.status_code == 422 assert response.json() == {'detail': 'Not enough segments'} # InvalidSignatureError - token = jwt.encode({'some': 'payload'}, 'secret', algorithm='HS256') + token = jwt.encode({'some': 'payload'}, 'secret', algorithm='HS256').decode('utf-8') response = client.get('/protected',headers={"Authorization":f"Bearer {token}"}) assert response.status_code == 422 assert response.json() == {'detail': 'Signature verification failed'} @@ -78,7 +78,7 @@ def get_settings_one(): assert response.status_code == 422 assert response.json() == {'detail': 'Signature has expired'} # InvalidAlgorithmError - token = jwt.encode({'some': 'payload'}, 'secret', algorithm='HS384') + token = jwt.encode({'some': 'payload'}, 'secret', algorithm='HS384').decode('utf-8') response = client.get('/protected',headers={"Authorization":f"Bearer {token}"}) assert response.status_code == 422 assert response.json() == {'detail': 'The specified alg value is not allowed'}