Skip to content

Commit

Permalink
fix: upgrade code
Browse files Browse the repository at this point in the history
  • Loading branch information
Miguel Jesus committed Sep 1, 2023
1 parent 1c6a5ca commit aeec4c7
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 37 deletions.
2 changes: 1 addition & 1 deletion fastapi_jwt_auth/auth_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def _create_token(
secret_key,
algorithm=algorithm,
headers=headers
)
).decode('utf-8')

def _has_token_in_denylist_callback(self) -> bool:
"""
Expand Down
78 changes: 47 additions & 31 deletions fastapi_jwt_auth/config.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,29 @@
from datetime import timedelta
from typing import Optional, Union, Sequence, List
from pydantic import (
BaseModel,
validator,
StrictBool,
StrictInt,
StrictStr
)
from pydantic import BaseModel, validator, StrictBool, StrictInt, StrictStr


class LoadConfig(BaseModel):
authjwt_token_location: Optional[Sequence[StrictStr]] = {'headers'}
authjwt_token_location: Optional[Sequence[StrictStr]] = {"headers"}
authjwt_secret_key: Optional[StrictStr] = None
authjwt_public_key: Optional[StrictStr] = None
authjwt_private_key: Optional[StrictStr] = None
authjwt_algorithm: Optional[StrictStr] = "HS256"
authjwt_decode_algorithms: Optional[List[StrictStr]] = None
authjwt_decode_leeway: Optional[Union[StrictInt,timedelta]] = 0
authjwt_decode_leeway: Optional[Union[StrictInt, timedelta]] = 0
authjwt_encode_issuer: Optional[StrictStr] = None
authjwt_decode_issuer: Optional[StrictStr] = None
authjwt_decode_audience: Optional[Union[StrictStr,Sequence[StrictStr]]] = None
authjwt_decode_audience: Optional[Union[StrictStr, Sequence[StrictStr]]] = None
authjwt_denylist_enabled: Optional[StrictBool] = False
authjwt_denylist_token_checks: Optional[Sequence[StrictStr]] = {'access','refresh'}
authjwt_denylist_token_checks: Optional[Sequence[StrictStr]] = {"access", "refresh"}
authjwt_header_name: Optional[StrictStr] = "Authorization"
authjwt_header_type: Optional[StrictStr] = "Bearer"
authjwt_access_token_expires: Optional[Union[StrictBool,StrictInt,timedelta]] = timedelta(minutes=15)
authjwt_refresh_token_expires: Optional[Union[StrictBool,StrictInt,timedelta]] = timedelta(days=30)
authjwt_access_token_expires: Optional[
Union[StrictBool, StrictInt, timedelta]
] = timedelta(minutes=15)
authjwt_refresh_token_expires: Optional[
Union[StrictBool, StrictInt, timedelta]
] = timedelta(days=30)
# option for create cookies
authjwt_access_cookie_key: Optional[StrictStr] = "access_token_cookie"
authjwt_refresh_cookie_key: Optional[StrictStr] = "refresh_token_cookie"
Expand All @@ -42,44 +41,61 @@ class LoadConfig(BaseModel):
authjwt_refresh_csrf_cookie_path: Optional[StrictStr] = "/"
authjwt_access_csrf_header_name: Optional[StrictStr] = "X-CSRF-Token"
authjwt_refresh_csrf_header_name: Optional[StrictStr] = "X-CSRF-Token"
authjwt_csrf_methods: Optional[Sequence[StrictStr]] = {'POST','PUT','PATCH','DELETE'}
authjwt_csrf_methods: Optional[Sequence[StrictStr]] = {
"POST",
"PUT",
"PATCH",
"DELETE",
}

@validator('authjwt_access_token_expires')
@validator("authjwt_access_token_expires")
def validate_access_token_expires(cls, v):
if v is True:
raise ValueError("The 'authjwt_access_token_expires' only accept value False (bool)")
raise ValueError(
"The 'authjwt_access_token_expires' only accept value False (bool)"
)
return v

@validator('authjwt_refresh_token_expires')
@validator("authjwt_refresh_token_expires")
def validate_refresh_token_expires(cls, v):
if v is True:
raise ValueError("The 'authjwt_refresh_token_expires' only accept value False (bool)")
raise ValueError(
"The 'authjwt_refresh_token_expires' only accept value False (bool)"
)
return v

@validator('authjwt_denylist_token_checks', each_item=True)
@validator("authjwt_denylist_token_checks")
def validate_denylist_token_checks(cls, v):
if v not in ['access','refresh']:
raise ValueError("The 'authjwt_denylist_token_checks' must be between 'access' or 'refresh'")
if v not in ["access", "refresh"]:
raise ValueError(
"The 'authjwt_denylist_token_checks' must be between 'access' or 'refresh'"
)
return v

@validator('authjwt_token_location', each_item=True)
@validator("authjwt_token_location")
def validate_token_location(cls, v):
if v not in ['headers','cookies']:
raise ValueError("The 'authjwt_token_location' must be between 'headers' or 'cookies'")
if v not in ["headers", "cookies"]:
raise ValueError(
"The 'authjwt_token_location' must be between 'headers' or 'cookies'"
)
return v

@validator('authjwt_cookie_samesite')
@validator("authjwt_cookie_samesite")
def validate_cookie_samesite(cls, v):
if v not in ['strict','lax','none']:
raise ValueError("The 'authjwt_cookie_samesite' must be between 'strict', 'lax', 'none'")
if v not in ["strict", "lax", "none"]:
raise ValueError(
"The 'authjwt_cookie_samesite' must be between 'strict', 'lax', 'none'"
)
return v

@validator('authjwt_csrf_methods', each_item=True)
@validator("authjwt_csrf_methods")
def validate_csrf_methods(cls, v):
if v.upper() not in {"GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"}:
raise ValueError("The 'authjwt_csrf_methods' must be between http request methods")
raise ValueError(
"The 'authjwt_csrf_methods' must be between http request methods"
)
return v.upper()

class Config:
min_anystr_length = 1
anystr_strip_whitespace = True
str_min_length = 1
str_strip_whitespace = True
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ classifiers = [

requires = [
"fastapi>=0.61.0",
"PyJWT==2.1.0",
"setuptools>=45.2.0"
"PyJWT>=1.7.1"
]

description-file = "README.md"
Expand Down
6 changes: 3 additions & 3 deletions tests/test_decode_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def default_access_token():

@pytest.fixture(scope='function')
def encoded_token(default_access_token):
return jwt.encode(default_access_token,'secret-key',algorithm='HS256')
return jwt.encode(default_access_token,'secret-key',algorithm='HS256').decode('utf-8')

def test_verified_token(client,encoded_token,Authorize):
class SettingsOne(BaseSettings):
Expand All @@ -67,7 +67,7 @@ def get_settings_one():
assert response.status_code == 422
assert response.json() == {'detail': 'Not enough segments'}
# InvalidSignatureError
token = jwt.encode({'some': 'payload'}, 'secret', algorithm='HS256')
token = jwt.encode({'some': 'payload'}, 'secret', algorithm='HS256').decode('utf-8')
response = client.get('/protected',headers={"Authorization":f"Bearer {token}"})
assert response.status_code == 422
assert response.json() == {'detail': 'Signature verification failed'}
Expand All @@ -78,7 +78,7 @@ def get_settings_one():
assert response.status_code == 422
assert response.json() == {'detail': 'Signature has expired'}
# InvalidAlgorithmError
token = jwt.encode({'some': 'payload'}, 'secret', algorithm='HS384')
token = jwt.encode({'some': 'payload'}, 'secret', algorithm='HS384').decode('utf-8')
response = client.get('/protected',headers={"Authorization":f"Bearer {token}"})
assert response.status_code == 422
assert response.json() == {'detail': 'The specified alg value is not allowed'}
Expand Down

0 comments on commit aeec4c7

Please sign in to comment.