Skip to content

Commit

Permalink
Fix create_id_token with extra scope claims + add ruff as formatter.
Browse files Browse the repository at this point in the history
  • Loading branch information
juanifioren committed Dec 5, 2024
1 parent 98b9810 commit b744992
Show file tree
Hide file tree
Showing 8 changed files with 418 additions and 357 deletions.
19 changes: 6 additions & 13 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,17 +1,10 @@
{
"[python]": {
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.sortImports": "explicit"
}
},
"python.formatting.provider": "black",
"editor.formatOnSave": true,
"black-formatter.args": [
"--line-length=100",
"--preview",
],
"isort.args": [
"--profile",
"black"
],
"source.fixAll": "explicit",
"source.organizeImports": "explicit"
},
"editor.defaultFormatter": "charliermarsh.ruff"
}
}
4 changes: 2 additions & 2 deletions docs/sections/contribute.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ Use `tox <https://pypi.python.org/pypi/tox>`_ for running tests in each of the e
# Run with Python 3.11 and Django 4.2.
$ tox -e py311-django42

# Run single test file on specific environment.
$ tox -e py311-django42 -- tests/cases/test_authorize_endpoint.py
# Run a single test method.
$ tox -e py311-django42 -- tests/cases/test_authorize_endpoint.py::TestClass::test_some_method

We use `Github Actions <https://github.com/juanifioren/django-oidc-provider/actions>`_ to automatically test every commit to the project.

Expand Down
263 changes: 139 additions & 124 deletions oidc_provider/lib/endpoints/authorize.py

Large diffs are not rendered by default.

223 changes: 113 additions & 110 deletions oidc_provider/lib/endpoints/token.py

Large diffs are not rendered by default.

63 changes: 31 additions & 32 deletions oidc_provider/lib/utils/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,52 +19,52 @@
from oidc_provider import settings


def create_id_token(token, user, aud, nonce='', at_hash='', request=None, scope=None):
def create_id_token(token, user, aud, nonce="", at_hash="", request=None, scope=None):
"""
Creates the id_token dictionary.
See: http://openid.net/specs/openid-connect-core-1_0.html#IDToken
Return a dic.
"""
if scope is None:
scope = []
sub = settings.get('OIDC_IDTOKEN_SUB_GENERATOR', import_str=True)(user=user)
sub = settings.get("OIDC_IDTOKEN_SUB_GENERATOR", import_str=True)(user=user)

expires_in = settings.get('OIDC_IDTOKEN_EXPIRE')
expires_in = settings.get("OIDC_IDTOKEN_EXPIRE")

# Convert datetimes into timestamps.
now = int(time.time())
iat_time = now
exp_time = int(now + expires_in)
user_auth_time = user.last_login or user.date_joined
auth_time = int(dateformat.format(user_auth_time, 'U'))
auth_time = int(dateformat.format(user_auth_time, "U"))

dic = {
'iss': get_issuer(request=request),
'sub': sub,
'aud': str(aud),
'exp': exp_time,
'iat': iat_time,
'auth_time': auth_time,
"iss": get_issuer(request=request),
"sub": sub,
"aud": str(aud),
"exp": exp_time,
"iat": iat_time,
"auth_time": auth_time,
}

if nonce:
dic['nonce'] = str(nonce)
dic["nonce"] = str(nonce)

if at_hash:
dic['at_hash'] = at_hash
dic["at_hash"] = at_hash

# Inlude (or not) user standard claims in the id_token.
if settings.get('OIDC_IDTOKEN_INCLUDE_CLAIMS'):
if settings.get('OIDC_EXTRA_SCOPE_CLAIMS'):
custom_claims = settings.get('OIDC_EXTRA_SCOPE_CLAIMS', import_str=True)(token)
claims = custom_claims.create_response_dic()
else:
claims = StandardScopeClaims(token).create_response_dic()
dic.update(claims)
if settings.get("OIDC_IDTOKEN_INCLUDE_CLAIMS"):
standard_claims = StandardScopeClaims(token)
dic.update(standard_claims.create_response_dic())

if settings.get("OIDC_EXTRA_SCOPE_CLAIMS"):
extra_claims = settings.get("OIDC_EXTRA_SCOPE_CLAIMS", import_str=True)(token)
dic.update(extra_claims.create_response_dic())

dic = run_processing_hook(
dic, 'OIDC_IDTOKEN_PROCESSING_HOOK',
user=user, token=token, request=request)
dic, "OIDC_IDTOKEN_PROCESSING_HOOK", user=user, token=token, request=request
)

return dic

Expand Down Expand Up @@ -94,7 +94,7 @@ def client_id_from_id_token(id_token):
Returns a string or None.
"""
payload = JWT().unpack(id_token).payload()
aud = payload.get('aud', None)
aud = payload.get("aud", None)
if aud is None:
return None
if isinstance(aud, list):
Expand All @@ -116,15 +116,15 @@ def create_token(user, client, scope, id_token_dic=None):
token.id_token = id_token_dic

token.refresh_token = uuid.uuid4().hex
token.expires_at = timezone.now() + timedelta(
seconds=settings.get('OIDC_TOKEN_EXPIRE'))
token.expires_at = timezone.now() + timedelta(seconds=settings.get("OIDC_TOKEN_EXPIRE"))
token.scope = scope

return token


def create_code(user, client, scope, nonce, is_authentication,
code_challenge=None, code_challenge_method=None):
def create_code(
user, client, scope, nonce, is_authentication, code_challenge=None, code_challenge_method=None
):
"""
Create and populate a Code object.
Return a Code object.
Expand All @@ -139,8 +139,7 @@ def create_code(user, client, scope, nonce, is_authentication,
code.code_challenge = code_challenge
code.code_challenge_method = code_challenge_method

code.expires_at = timezone.now() + timedelta(
seconds=settings.get('OIDC_CODE_EXPIRE'))
code.expires_at = timezone.now() + timedelta(seconds=settings.get("OIDC_CODE_EXPIRE"))
code.scope = scope
code.nonce = nonce
code.is_authentication = is_authentication
Expand All @@ -153,15 +152,15 @@ def get_client_alg_keys(client):
Takes a client and returns the set of keys associated with it.
Returns a list of keys.
"""
if client.jwt_alg == 'RS256':
if client.jwt_alg == "RS256":
keys = []
for rsakey in RSAKey.objects.all():
keys.append(jwk_RSAKey(key=importKey(rsakey.key), kid=rsakey.kid))
if not keys:
raise Exception('You must add at least one RSA Key.')
elif client.jwt_alg == 'HS256':
raise Exception("You must add at least one RSA Key.")
elif client.jwt_alg == "HS256":
keys = [SYMKey(key=client.client_secret, alg=client.jwt_alg)]
else:
raise Exception('Unsupported key algorithm.')
raise Exception("Unsupported key algorithm.")

return keys
96 changes: 55 additions & 41 deletions oidc_provider/tests/app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,27 @@
from django.contrib.auth.backends import ModelBackend

try:
from urlparse import parse_qs, urlsplit
from urlparse import parse_qs
from urlparse import urlsplit
except ImportError:
from urllib.parse import parse_qs, urlsplit
from urllib.parse import parse_qs
from urllib.parse import urlsplit

from django.utils import timezone
from django.contrib.auth.models import User
from django.utils import timezone

from oidc_provider.models import (
Client,
Code,
Token,
ResponseType)

from oidc_provider.lib.claims import ScopeClaims
from oidc_provider.models import Client
from oidc_provider.models import Code
from oidc_provider.models import ResponseType
from oidc_provider.models import Token

FAKE_NONCE = 'cb584e44c43ed6bd0bc2d9c7e242837d'
FAKE_RANDOM_STRING = ''.join(
random.choice(string.ascii_uppercase + string.digits) for _ in range(32))
FAKE_CODE_CHALLENGE = 'YlYXEqXuRm-Xgi2BOUiK50JW1KsGTX6F1TDnZSC8VTg'
FAKE_CODE_VERIFIER = 'SmxGa0XueyNh5bDgTcSrqzAh2_FmXEqU8kDT6CuXicw'
FAKE_NONCE = "cb584e44c43ed6bd0bc2d9c7e242837d"
FAKE_RANDOM_STRING = "".join(
random.choice(string.ascii_uppercase + string.digits) for _ in range(32)
)
FAKE_CODE_CHALLENGE = "YlYXEqXuRm-Xgi2BOUiK50JW1KsGTX6F1TDnZSC8VTg"
FAKE_CODE_VERIFIER = "SmxGa0XueyNh5bDgTcSrqzAh2_FmXEqU8kDT6CuXicw"


def create_fake_user():
Expand All @@ -33,11 +35,11 @@ def create_fake_user():
Return a User object.
"""
user = User()
user.username = 'johndoe'
user.email = '[email protected]'
user.first_name = 'John'
user.last_name = 'Doe'
user.set_password('1234')
user.username = "johndoe"
user.email = "[email protected]"
user.first_name = "John"
user.last_name = "Doe"
user.set_password("1234")

user.save()

Expand All @@ -52,20 +54,20 @@ def create_fake_client(response_type, is_public=False, require_consent=True):
Return a Client object.
"""
client = Client()
client.name = 'Some Client'
client.name = "Some Client"
client.client_id = str(random.randint(1, 999999)).zfill(6)
if is_public:
client.client_type = 'public'
client.client_secret = ''
client.client_type = "public"
client.client_secret = ""
else:
client.client_secret = str(random.randint(1, 999999)).zfill(6)
client.redirect_uris = ['http://example.com/']
client.redirect_uris = ["http://example.com/"]
client.require_consent = require_consent
client.scope = ['openid', 'email']
client.scope = ["openid", "email"]
client.save()

# check if response_type is a string in a python 2 and 3 compatible way
if isinstance(response_type, ("".__class__, u"".__class__)):
if isinstance(response_type, ("".__class__, "".__class__)):
response_type = (response_type,)
for value in response_type:
client.response_types.add(ResponseType.objects.get(value=value))
Expand All @@ -90,7 +92,7 @@ def is_code_valid(url, user, client):
try:
parsed = urlsplit(url)
params = parse_qs(parsed.query or parsed.fragment)
code = params['code'][0]
code = params["code"][0]
code = Code.objects.get(code=code)
is_code_ok = (code.client == client) and (code.user == user)
except Exception:
Expand All @@ -103,15 +105,28 @@ def userinfo(claims, user):
"""
Fake function for setting OIDC_USERINFO.
"""
claims['given_name'] = 'John'
claims['family_name'] = 'Doe'
claims['name'] = '{0} {1}'.format(claims['given_name'], claims['family_name'])
claims['email'] = user.email
claims['email_verified'] = True
claims['address']['country'] = 'Argentina'
claims["given_name"] = "John"
claims["family_name"] = "Doe"
claims["name"] = "{0} {1}".format(claims["given_name"], claims["family_name"])
claims["email"] = user.email
claims["email_verified"] = True
claims["address"]["country"] = "Argentina"
return claims


class FakeScopeClaims(ScopeClaims):
info_pizza = (
"Pizza",
"Some description for the scope.",
)

def scope_pizza(self):
dic = {
"pizza": "Margherita",
}
return dic


def fake_sub_generator(user):
"""
Fake function for setting OIDC_IDTOKEN_SUB_GENERATOR.
Expand All @@ -123,8 +138,8 @@ def fake_idtoken_processing_hook(id_token, user, **kwargs):
"""
Fake function for inserting some keys into token. Testing OIDC_IDTOKEN_PROCESSING_HOOK.
"""
id_token['test_idtoken_processing_hook'] = FAKE_RANDOM_STRING
id_token['test_idtoken_processing_hook_user_email'] = user.email
id_token["test_idtoken_processing_hook"] = FAKE_RANDOM_STRING
id_token["test_idtoken_processing_hook_user_email"] = user.email
return id_token


Expand All @@ -133,32 +148,31 @@ def fake_idtoken_processing_hook2(id_token, user, **kwargs):
Fake function for inserting some keys into token.
Testing OIDC_IDTOKEN_PROCESSING_HOOK - tuple or list as param
"""
id_token['test_idtoken_processing_hook2'] = FAKE_RANDOM_STRING
id_token['test_idtoken_processing_hook_user_email2'] = user.email
id_token["test_idtoken_processing_hook2"] = FAKE_RANDOM_STRING
id_token["test_idtoken_processing_hook_user_email2"] = user.email
return id_token


def fake_idtoken_processing_hook3(id_token, user, token, **kwargs):
"""
Fake function for checking scope is passed to processing hook.
"""
id_token['scope_of_token_passed_to_processing_hook'] = token.scope
id_token["scope_of_token_passed_to_processing_hook"] = token.scope
return id_token


def fake_idtoken_processing_hook4(id_token, user, **kwargs):
"""
Fake function for checking kwargs passed to processing hook.
"""
id_token['kwargs_passed_to_processing_hook'] = {
key: repr(value)
for (key, value) in kwargs.items()
id_token["kwargs_passed_to_processing_hook"] = {
key: repr(value) for (key, value) in kwargs.items()
}
return id_token


def fake_introspection_processing_hook(response_dict, client, id_token):
response_dict['test_introspection_processing_hook'] = FAKE_RANDOM_STRING
response_dict["test_introspection_processing_hook"] = FAKE_RANDOM_STRING
return response_dict


Expand Down
Loading

0 comments on commit b744992

Please sign in to comment.