From b7449923174f4fad14d628e881b1c713e2e9636e Mon Sep 17 00:00:00 2001 From: juanifioren Date: Wed, 4 Dec 2024 22:56:33 -0300 Subject: [PATCH] Fix create_id_token with extra scope claims + add ruff as formatter. --- .vscode/settings.json | 19 +- docs/sections/contribute.rst | 4 +- oidc_provider/lib/endpoints/authorize.py | 263 ++++++++++++----------- oidc_provider/lib/endpoints/token.py | 223 +++++++++---------- oidc_provider/lib/utils/token.py | 63 +++--- oidc_provider/tests/app/utils.py | 96 +++++---- oidc_provider/tests/cases/test_utils.py | 102 ++++++--- pyproject.toml | 5 + 8 files changed, 418 insertions(+), 357 deletions(-) create mode 100644 pyproject.toml diff --git a/.vscode/settings.json b/.vscode/settings.json index ab96ebe3..e64211ad 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,17 +1,10 @@ { "[python]": { + "editor.formatOnSave": true, "editor.codeActionsOnSave": { - "source.sortImports": "explicit" - } - }, - "python.formatting.provider": "black", - "editor.formatOnSave": true, - "black-formatter.args": [ - "--line-length=100", - "--preview", - ], - "isort.args": [ - "--profile", - "black" - ], + "source.fixAll": "explicit", + "source.organizeImports": "explicit" + }, + "editor.defaultFormatter": "charliermarsh.ruff" + } } \ No newline at end of file diff --git a/docs/sections/contribute.rst b/docs/sections/contribute.rst index 8d29a332..121cc883 100644 --- a/docs/sections/contribute.rst +++ b/docs/sections/contribute.rst @@ -24,8 +24,8 @@ Use `tox `_ for running tests in each of the e # Run with Python 3.11 and Django 4.2. $ tox -e py311-django42 - # Run single test file on specific environment. - $ tox -e py311-django42 -- tests/cases/test_authorize_endpoint.py + # Run a single test method. + $ tox -e py311-django42 -- tests/cases/test_authorize_endpoint.py::TestClass::test_some_method We use `Github Actions `_ to automatically test every commit to the project. diff --git a/oidc_provider/lib/endpoints/authorize.py b/oidc_provider/lib/endpoints/authorize.py index 4728158e..c14d7689 100644 --- a/oidc_provider/lib/endpoints/authorize.py +++ b/oidc_provider/lib/endpoints/authorize.py @@ -1,42 +1,41 @@ -from datetime import timedelta -from hashlib import ( - md5, - sha256, -) import logging +from datetime import timedelta +from hashlib import md5 +from hashlib import sha256 + try: from urllib import urlencode - from urlparse import urlsplit, parse_qs, urlunsplit + + from urlparse import parse_qs + from urlparse import urlsplit + from urlparse import urlunsplit except ImportError: - from urllib.parse import urlsplit, parse_qs, urlunsplit, urlencode + from urllib.parse import parse_qs + from urllib.parse import urlencode + from urllib.parse import urlsplit + from urllib.parse import urlunsplit from uuid import uuid4 from django.utils import timezone -from oidc_provider.lib.claims import StandardScopeClaims -from oidc_provider.lib.errors import ( - AuthorizeError, - ClientIdError, - RedirectUriError, -) -from oidc_provider.lib.utils.token import ( - create_code, - create_id_token, - create_token, - encode_id_token, -) -from oidc_provider.models import ( - Client, - UserConsent, -) from oidc_provider import settings +from oidc_provider.lib.claims import StandardScopeClaims +from oidc_provider.lib.errors import AuthorizeError +from oidc_provider.lib.errors import ClientIdError +from oidc_provider.lib.errors import RedirectUriError from oidc_provider.lib.utils.common import get_browser_state_or_default +from oidc_provider.lib.utils.token import create_code +from oidc_provider.lib.utils.token import create_id_token +from oidc_provider.lib.utils.token import create_token +from oidc_provider.lib.utils.token import encode_id_token +from oidc_provider.models import Client +from oidc_provider.models import UserConsent logger = logging.getLogger(__name__) class AuthorizeEndpoint(object): - _allowed_prompt_params = {'none', 'login', 'consent', 'select_account'} + _allowed_prompt_params = {"none", "login", "consent", "select_account"} client_class = Client def __init__(self, request): @@ -46,18 +45,17 @@ def __init__(self, request): self._extract_params() # Determine which flow to use. - if self.params['response_type'] in ['code']: - self.grant_type = 'authorization_code' - elif self.params['response_type'] in ['id_token', 'id_token token', 'token']: - self.grant_type = 'implicit' - elif self.params['response_type'] in [ - 'code token', 'code id_token', 'code id_token token']: - self.grant_type = 'hybrid' + if self.params["response_type"] in ["code"]: + self.grant_type = "authorization_code" + elif self.params["response_type"] in ["id_token", "id_token token", "token"]: + self.grant_type = "implicit" + elif self.params["response_type"] in ["code token", "code id_token", "code id_token token"]: + self.grant_type = "hybrid" else: self.grant_type = None # Determine if it's an OpenID Authentication request (or OAuth2). - self.is_authentication = 'openid' in self.params['scope'] + self.is_authentication = "openid" in self.params["scope"] def _extract_params(self): """ @@ -68,73 +66,79 @@ def _extract_params(self): """ # Because in this endpoint we handle both GET # and POST request. - query_dict = (self.request.POST if self.request.method == 'POST' - else self.request.GET) + query_dict = self.request.POST if self.request.method == "POST" else self.request.GET - self.params['client_id'] = query_dict.get('client_id', '') - self.params['redirect_uri'] = query_dict.get('redirect_uri', '') - self.params['response_type'] = query_dict.get('response_type', '') - self.params['scope'] = query_dict.get('scope', '').split() - self.params['state'] = query_dict.get('state', '') - self.params['nonce'] = query_dict.get('nonce', '') + self.params["client_id"] = query_dict.get("client_id", "") + self.params["redirect_uri"] = query_dict.get("redirect_uri", "") + self.params["response_type"] = query_dict.get("response_type", "") + self.params["scope"] = query_dict.get("scope", "").split() + self.params["state"] = query_dict.get("state", "") + self.params["nonce"] = query_dict.get("nonce", "") - self.params['prompt'] = self._allowed_prompt_params.intersection( - set(query_dict.get('prompt', '').split())) + self.params["prompt"] = self._allowed_prompt_params.intersection( + set(query_dict.get("prompt", "").split()) + ) - self.params['code_challenge'] = query_dict.get('code_challenge', '') - self.params['code_challenge_method'] = query_dict.get('code_challenge_method', '') + self.params["code_challenge"] = query_dict.get("code_challenge", "") + self.params["code_challenge_method"] = query_dict.get("code_challenge_method", "") def validate_params(self): # Client validation. try: - self.client = self.client_class.objects.get(client_id=self.params['client_id']) + self.client = self.client_class.objects.get(client_id=self.params["client_id"]) except Client.DoesNotExist: - logger.debug('[Authorize] Invalid client identifier: %s', self.params['client_id']) + logger.debug("[Authorize] Invalid client identifier: %s", self.params["client_id"]) raise ClientIdError() # Redirect URI validation. - if self.is_authentication and not self.params['redirect_uri']: - logger.debug('[Authorize] Missing redirect uri.') + if self.is_authentication and not self.params["redirect_uri"]: + logger.debug("[Authorize] Missing redirect uri.") raise RedirectUriError() - if not (self.params['redirect_uri'] in self.client.redirect_uris): - logger.debug('[Authorize] Invalid redirect uri: %s', self.params['redirect_uri']) + if self.params["redirect_uri"] not in self.client.redirect_uris: + logger.debug("[Authorize] Invalid redirect uri: %s", self.params["redirect_uri"]) raise RedirectUriError() # Grant type validation. if not self.grant_type: - logger.debug('[Authorize] Invalid response type: %s', self.params['response_type']) + logger.debug("[Authorize] Invalid response type: %s", self.params["response_type"]) raise AuthorizeError( - self.params['redirect_uri'], 'unsupported_response_type', self.grant_type) + self.params["redirect_uri"], "unsupported_response_type", self.grant_type + ) - if (not self.is_authentication and (self.grant_type == 'hybrid' or - self.params['response_type'] in ['id_token', 'id_token token'])): - logger.debug('[Authorize] Missing openid scope.') - raise AuthorizeError(self.params['redirect_uri'], 'invalid_scope', self.grant_type) + if not self.is_authentication and ( + self.grant_type == "hybrid" + or self.params["response_type"] in ["id_token", "id_token token"] + ): + logger.debug("[Authorize] Missing openid scope.") + raise AuthorizeError(self.params["redirect_uri"], "invalid_scope", self.grant_type) # Nonce parameter validation. - if self.is_authentication and self.grant_type == 'implicit' and not self.params['nonce']: - raise AuthorizeError(self.params['redirect_uri'], 'invalid_request', self.grant_type) + if self.is_authentication and self.grant_type == "implicit" and not self.params["nonce"]: + raise AuthorizeError(self.params["redirect_uri"], "invalid_request", self.grant_type) # Response type parameter validation. - if self.is_authentication \ - and self.params['response_type'] not in self.client.response_type_values(): - raise AuthorizeError(self.params['redirect_uri'], 'invalid_request', self.grant_type) + if ( + self.is_authentication + and self.params["response_type"] not in self.client.response_type_values() + ): + raise AuthorizeError(self.params["redirect_uri"], "invalid_request", self.grant_type) # PKCE validation of the transformation method. - if self.params['code_challenge']: - if not (self.params['code_challenge_method'] in ['plain', 'S256']): + if self.params["code_challenge"]: + if self.params["code_challenge_method"] not in ["plain", "S256"]: raise AuthorizeError( - self.params['redirect_uri'], 'invalid_request', self.grant_type) + self.params["redirect_uri"], "invalid_request", self.grant_type + ) def create_code(self): code = create_code( user=self.request.user, client=self.client, - scope=self.params['scope'], - nonce=self.params['nonce'], + scope=self.params["scope"], + nonce=self.params["nonce"], is_authentication=self.is_authentication, - code_challenge=self.params['code_challenge'], - code_challenge_method=self.params['code_challenge_method'], + code_challenge=self.params["code_challenge"], + code_challenge_method=self.params["code_challenge_method"], ) return code @@ -143,50 +147,58 @@ def create_token(self): token = create_token( user=self.request.user, client=self.client, - scope=self.params['scope'], + scope=self.params["scope"], ) return token def create_response_uri(self): - uri = urlsplit(self.params['redirect_uri']) + uri = urlsplit(self.params["redirect_uri"]) query_params = parse_qs(uri.query) query_fragment = {} try: - if self.grant_type in ['authorization_code', 'hybrid']: + if self.grant_type in ["authorization_code", "hybrid"]: code = self.create_code() code.save() - if self.grant_type == 'authorization_code': - query_params['code'] = code.code - query_params['state'] = self.params['state'] if self.params['state'] else '' - elif self.grant_type in ['implicit', 'hybrid']: + if self.grant_type == "authorization_code": + query_params["code"] = code.code + query_params["state"] = self.params["state"] if self.params["state"] else "" + elif self.grant_type in ["implicit", "hybrid"]: token = self.create_token() # Check if response_type must include access_token in the response. - if (self.params['response_type'] in - ['id_token token', 'token', 'code token', 'code id_token token']): - query_fragment['access_token'] = token.access_token + if self.params["response_type"] in [ + "id_token token", + "token", + "code token", + "code id_token token", + ]: + query_fragment["access_token"] = token.access_token # We don't need id_token if it's an OAuth2 request. if self.is_authentication: kwargs = { - 'token': token, - 'user': self.request.user, - 'aud': self.client.client_id, - 'nonce': self.params['nonce'], - 'request': self.request, - 'scope': self.params['scope'], + "token": token, + "user": self.request.user, + "aud": self.client.client_id, + "nonce": self.params["nonce"], + "request": self.request, + "scope": self.params["scope"], } # Include at_hash when access_token is being returned. - if 'access_token' in query_fragment: - kwargs['at_hash'] = token.at_hash + if "access_token" in query_fragment: + kwargs["at_hash"] = token.at_hash id_token_dic = create_id_token(**kwargs) # Check if response_type must include id_token in the response. - if self.params['response_type'] in [ - 'id_token', 'id_token token', 'code id_token', 'code id_token token']: - query_fragment['id_token'] = encode_id_token(id_token_dic, self.client) + if self.params["response_type"] in [ + "id_token", + "id_token token", + "code id_token", + "code id_token token", + ]: + query_fragment["id_token"] = encode_id_token(id_token_dic, self.client) else: id_token_dic = {} @@ -195,20 +207,21 @@ def create_response_uri(self): token.save() # Code parameter must be present if it's Hybrid Flow. - if self.grant_type == 'hybrid': - query_fragment['code'] = code.code + if self.grant_type == "hybrid": + query_fragment["code"] = code.code - query_fragment['token_type'] = 'bearer' + query_fragment["token_type"] = "bearer" - query_fragment['expires_in'] = settings.get('OIDC_TOKEN_EXPIRE') + query_fragment["expires_in"] = settings.get("OIDC_TOKEN_EXPIRE") - query_fragment['state'] = self.params['state'] if self.params['state'] else '' + query_fragment["state"] = self.params["state"] if self.params["state"] else "" - if settings.get('OIDC_SESSION_MANAGEMENT_ENABLE'): + if settings.get("OIDC_SESSION_MANAGEMENT_ENABLE"): # Generate client origin URI from the redirect_uri param. - redirect_uri_parsed = urlsplit(self.params['redirect_uri']) - client_origin = '{0}://{1}'.format( - redirect_uri_parsed.scheme, redirect_uri_parsed.netloc) + redirect_uri_parsed = urlsplit(self.params["redirect_uri"]) + client_origin = "{0}://{1}".format( + redirect_uri_parsed.scheme, redirect_uri_parsed.netloc + ) # Create random salt. salt = md5(uuid4().hex.encode()).hexdigest() @@ -216,25 +229,27 @@ def create_response_uri(self): # The generation of suitable Session State values is based # on a salted cryptographic hash of Client ID, origin URL, # and OP browser state. - session_state = '{client_id} {origin} {browser_state} {salt}'.format( + session_state = "{client_id} {origin} {browser_state} {salt}".format( client_id=self.client.client_id, origin=client_origin, browser_state=get_browser_state_or_default(self.request), - salt=salt) - session_state = sha256(session_state.encode('utf-8')).hexdigest() - session_state += '.' + salt - if self.grant_type == 'authorization_code': - query_params['session_state'] = session_state - elif self.grant_type in ['implicit', 'hybrid']: - query_fragment['session_state'] = session_state + salt=salt, + ) + session_state = sha256(session_state.encode("utf-8")).hexdigest() + session_state += "." + salt + if self.grant_type == "authorization_code": + query_params["session_state"] = session_state + elif self.grant_type in ["implicit", "hybrid"]: + query_fragment["session_state"] = session_state except Exception as error: - logger.exception('[Authorize] Error when trying to create response uri: %s', error) - raise AuthorizeError(self.params['redirect_uri'], 'server_error', self.grant_type) + logger.exception("[Authorize] Error when trying to create response uri: %s", error) + raise AuthorizeError(self.params["redirect_uri"], "server_error", self.grant_type) uri = uri._replace( query=urlencode(query_params, doseq=True), - fragment=uri.fragment + urlencode(query_fragment, doseq=True)) + fragment=uri.fragment + urlencode(query_fragment, doseq=True), + ) return urlunsplit(uri) @@ -245,18 +260,17 @@ def set_client_user_consent(self): Return None. """ date_given = timezone.now() - expires_at = date_given + timedelta( - days=settings.get('OIDC_SKIP_CONSENT_EXPIRE')) + expires_at = date_given + timedelta(days=settings.get("OIDC_SKIP_CONSENT_EXPIRE")) uc, created = UserConsent.objects.get_or_create( user=self.request.user, client=self.client, defaults={ - 'expires_at': expires_at, - 'date_given': date_given, - } + "expires_at": expires_at, + "date_given": date_given, + }, ) - uc.scope = self.params['scope'] + uc.scope = self.params["scope"] # Rewrite expires_at and date_given if object already exists. if not created: @@ -274,7 +288,7 @@ def client_has_user_consent(self): value = False try: uc = UserConsent.objects.get(user=self.request.user, client=self.client) - if (set(self.params['scope']).issubset(uc.scope)) and not (uc.has_expired()): + if (set(self.params["scope"]).issubset(uc.scope)) and not (uc.has_expired()): value = True except UserConsent.DoesNotExist: pass @@ -282,23 +296,24 @@ def client_has_user_consent(self): return value def is_client_allowed_to_skip_consent(self): - implicit_flow_resp_types = {'id_token', 'id_token token'} + implicit_flow_resp_types = {"id_token", "id_token token"} return ( - self.client.client_type != 'public' or - self.params['response_type'] in implicit_flow_resp_types + self.client.client_type != "public" + or self.params["response_type"] in implicit_flow_resp_types ) def get_scopes_information(self): """ Return a list with the description of all the scopes requested. """ - scopes = StandardScopeClaims.get_scopes_info(self.params['scope']) - if settings.get('OIDC_EXTRA_SCOPE_CLAIMS'): - scopes_extra = settings.get( - 'OIDC_EXTRA_SCOPE_CLAIMS', import_str=True).get_scopes_info(self.params['scope']) + scopes = StandardScopeClaims.get_scopes_info(self.params["scope"]) + if settings.get("OIDC_EXTRA_SCOPE_CLAIMS"): + scopes_extra = settings.get("OIDC_EXTRA_SCOPE_CLAIMS", import_str=True).get_scopes_info( + self.params["scope"] + ) for index_extra, scope_extra in enumerate(scopes_extra): for index, scope in enumerate(scopes[:]): - if scope_extra['scope'] == scope['scope']: + if scope_extra["scope"] == scope["scope"]: del scopes[index] else: scopes_extra = [] diff --git a/oidc_provider/lib/endpoints/token.py b/oidc_provider/lib/endpoints/token.py index 73a29937..39cc871f 100644 --- a/oidc_provider/lib/endpoints/token.py +++ b/oidc_provider/lib/endpoints/token.py @@ -8,27 +8,20 @@ from django.http import JsonResponse from oidc_provider import settings -from oidc_provider.lib.errors import ( - TokenError, - UserAuthError, -) +from oidc_provider.lib.errors import TokenError +from oidc_provider.lib.errors import UserAuthError from oidc_provider.lib.utils.oauth2 import extract_client_auth -from oidc_provider.lib.utils.token import ( - create_id_token, - create_token, - encode_id_token, -) -from oidc_provider.models import ( - Client, - Code, - Token, -) +from oidc_provider.lib.utils.token import create_id_token +from oidc_provider.lib.utils.token import create_token +from oidc_provider.lib.utils.token import encode_id_token +from oidc_provider.models import Client +from oidc_provider.models import Code +from oidc_provider.models import Token logger = logging.getLogger(__name__) class TokenEndpoint(object): - def __init__(self, request): self.request = request self.params = {} @@ -38,72 +31,79 @@ def __init__(self, request): def _extract_params(self): client_id, client_secret = extract_client_auth(self.request) - self.params['client_id'] = client_id - self.params['client_secret'] = client_secret - self.params['redirect_uri'] = self.request.POST.get('redirect_uri', '') - self.params['grant_type'] = self.request.POST.get('grant_type', '') - self.params['code'] = self.request.POST.get('code', '') - self.params['state'] = self.request.POST.get('state', '') - self.params['scope'] = self.request.POST.get('scope', '') - self.params['refresh_token'] = self.request.POST.get('refresh_token', '') + self.params["client_id"] = client_id + self.params["client_secret"] = client_secret + self.params["redirect_uri"] = self.request.POST.get("redirect_uri", "") + self.params["grant_type"] = self.request.POST.get("grant_type", "") + self.params["code"] = self.request.POST.get("code", "") + self.params["state"] = self.request.POST.get("state", "") + self.params["scope"] = self.request.POST.get("scope", "") + self.params["refresh_token"] = self.request.POST.get("refresh_token", "") # PKCE parameter. - self.params['code_verifier'] = self.request.POST.get('code_verifier') + self.params["code_verifier"] = self.request.POST.get("code_verifier") - self.params['username'] = self.request.POST.get('username', '') - self.params['password'] = self.request.POST.get('password', '') + self.params["username"] = self.request.POST.get("username", "") + self.params["password"] = self.request.POST.get("password", "") def validate_params(self): try: - self.client = Client.objects.get(client_id=self.params['client_id']) + self.client = Client.objects.get(client_id=self.params["client_id"]) except Client.DoesNotExist: - logger.debug('[Token] Client does not exist: %s', self.params['client_id']) - raise TokenError('invalid_client') + logger.debug("[Token] Client does not exist: %s", self.params["client_id"]) + raise TokenError("invalid_client") - if self.client.client_type == 'confidential': - if not (self.client.client_secret == self.params['client_secret']): - logger.debug('[Token] Invalid client secret: client %s do not have secret %s', - self.client.client_id, self.client.client_secret) - raise TokenError('invalid_client') + if self.client.client_type == "confidential": + if not (self.client.client_secret == self.params["client_secret"]): + logger.debug( + "[Token] Invalid client secret: client %s do not have secret %s", + self.client.client_id, + self.client.client_secret, + ) + raise TokenError("invalid_client") - if self.params['grant_type'] == 'authorization_code': - if not (self.params['redirect_uri'] in self.client.redirect_uris): - logger.debug('[Token] Invalid redirect uri: %s', self.params['redirect_uri']) - raise TokenError('invalid_client') + if self.params["grant_type"] == "authorization_code": + if self.params["redirect_uri"] not in self.client.redirect_uris: + logger.debug("[Token] Invalid redirect uri: %s", self.params["redirect_uri"]) + raise TokenError("invalid_client") try: self.code = Code.objects.select_for_update(nowait=True).get( - code=self.params['code']) + code=self.params["code"] + ) except DatabaseError: - logger.debug('[Token] Code cannot be reused: %s', self.params['code']) - raise TokenError('invalid_grant') + logger.debug("[Token] Code cannot be reused: %s", self.params["code"]) + raise TokenError("invalid_grant") except Code.DoesNotExist: - logger.debug('[Token] Code does not exist: %s', self.params['code']) - raise TokenError('invalid_grant') + logger.debug("[Token] Code does not exist: %s", self.params["code"]) + raise TokenError("invalid_grant") - if not (self.code.client == self.client) \ - or self.code.has_expired(): - logger.debug('[Token] Invalid code: invalid client or code has expired') - raise TokenError('invalid_grant') + if not (self.code.client == self.client) or self.code.has_expired(): + logger.debug("[Token] Invalid code: invalid client or code has expired") + raise TokenError("invalid_grant") # Validate PKCE parameters. if self.code.code_challenge: - if self.params['code_verifier'] is None: - raise TokenError('invalid_grant') - - if self.code.code_challenge_method == 'S256': - new_code_challenge = urlsafe_b64encode( - hashlib.sha256(self.params['code_verifier'].encode('ascii')).digest() - ).decode('utf-8').replace('=', '') + if self.params["code_verifier"] is None: + raise TokenError("invalid_grant") + + if self.code.code_challenge_method == "S256": + new_code_challenge = ( + urlsafe_b64encode( + hashlib.sha256(self.params["code_verifier"].encode("ascii")).digest() + ) + .decode("utf-8") + .replace("=", "") + ) else: - new_code_challenge = self.params['code_verifier'] + new_code_challenge = self.params["code_verifier"] # TODO: We should explain the error. if not (new_code_challenge == self.code.code_challenge): - raise TokenError('invalid_grant') + raise TokenError("invalid_grant") - elif self.params['grant_type'] == 'password': - if not settings.get('OIDC_GRANT_TYPE_PASSWORD_ENABLE'): - raise TokenError('unsupported_grant_type') + elif self.params["grant_type"] == "password": + if not settings.get("OIDC_GRANT_TYPE_PASSWORD_ENABLE"): + raise TokenError("unsupported_grant_type") auth_args = (self.request,) try: @@ -112,9 +112,7 @@ def validate_params(self): auth_args = () user = authenticate( - *auth_args, - username=self.params['username'], - password=self.params['password'] + *auth_args, username=self.params["username"], password=self.params["password"] ) if not user: @@ -122,56 +120,61 @@ def validate_params(self): self.user = user - elif self.params['grant_type'] == 'refresh_token': - if not self.params['refresh_token']: - logger.debug('[Token] Missing refresh token') - raise TokenError('invalid_grant') + elif self.params["grant_type"] == "refresh_token": + if not self.params["refresh_token"]: + logger.debug("[Token] Missing refresh token") + raise TokenError("invalid_grant") try: - self.token = Token.objects.get(refresh_token=self.params['refresh_token'], - client=self.client) + self.token = Token.objects.get( + refresh_token=self.params["refresh_token"], client=self.client + ) except Token.DoesNotExist: logger.debug( - '[Token] Refresh token does not exist: %s', self.params['refresh_token']) - raise TokenError('invalid_grant') - elif self.params['grant_type'] == 'client_credentials': + "[Token] Refresh token does not exist: %s", self.params["refresh_token"] + ) + raise TokenError("invalid_grant") + elif self.params["grant_type"] == "client_credentials": if not self.client._scope: - logger.debug('[Token] Client using client credentials with empty scope') - raise TokenError('invalid_scope') + logger.debug("[Token] Client using client credentials with empty scope") + raise TokenError("invalid_scope") else: - logger.debug('[Token] Invalid grant type: %s', self.params['grant_type']) - raise TokenError('unsupported_grant_type') + logger.debug("[Token] Invalid grant type: %s", self.params["grant_type"]) + raise TokenError("unsupported_grant_type") def validate_requested_scopes(self): """ Handling validation of requested scope for grant_type=[password|client_credentials] """ token_scopes = [] - if self.params['scope']: + if self.params["scope"]: # See https://tools.ietf.org/html/rfc6749#section-3.3 # The value of the scope parameter is expressed # as a list of space-delimited, case-sensitive strings - for scope_requested in self.params['scope'].split(' '): + for scope_requested in self.params["scope"].split(" "): if scope_requested in self.client.scope: token_scopes.append(scope_requested) else: - logger.debug('[Token] The request scope %s is not supported by client %s', - scope_requested, self.client.client_id) - raise TokenError('invalid_scope') + logger.debug( + "[Token] The request scope %s is not supported by client %s", + scope_requested, + self.client.client_id, + ) + raise TokenError("invalid_scope") # if no scopes requested assign client's scopes else: token_scopes.extend(self.client.scope) return token_scopes def create_response_dic(self): - if self.params['grant_type'] == 'authorization_code': + if self.params["grant_type"] == "authorization_code": return self.create_code_response_dic() - elif self.params['grant_type'] == 'refresh_token': + elif self.params["grant_type"] == "refresh_token": return self.create_refresh_response_dic() - elif self.params['grant_type'] == 'password': + elif self.params["grant_type"] == "password": return self.create_access_token_response_dic() - elif self.params['grant_type'] == 'client_credentials': + elif self.params["grant_type"] == "client_credentials": return self.create_client_credentials_response_dic() def create_token(self, user, client, scope): @@ -213,11 +216,11 @@ def create_code_response_dic(self): self.code.delete() dic = { - 'access_token': token.access_token, - 'refresh_token': token.refresh_token, - 'token_type': 'bearer', - 'expires_in': settings.get('OIDC_TOKEN_EXPIRE'), - 'id_token': encode_id_token(id_token_dic, token.client), + "access_token": token.access_token, + "refresh_token": token.refresh_token, + "token_type": "bearer", + "expires_in": settings.get("OIDC_TOKEN_EXPIRE"), + "id_token": encode_id_token(id_token_dic, token.client), } return dic @@ -225,11 +228,11 @@ def create_code_response_dic(self): def create_refresh_response_dic(self): # See https://tools.ietf.org/html/rfc6749#section-6 - scope_param = self.params['scope'] - scope = (scope_param.split(' ') if scope_param else self.token.scope) + scope_param = self.params["scope"] + scope = scope_param.split(" ") if scope_param else self.token.scope unauthorized_scopes = set(scope) - set(self.token.scope) if unauthorized_scopes: - raise TokenError('invalid_scope') + raise TokenError("invalid_scope") token = self.create_token( user=self.token.user, @@ -259,11 +262,11 @@ def create_refresh_response_dic(self): self.token.delete() dic = { - 'access_token': token.access_token, - 'refresh_token': token.refresh_token, - 'token_type': 'bearer', - 'expires_in': settings.get('OIDC_TOKEN_EXPIRE'), - 'id_token': encode_id_token(id_token_dic, self.token.client), + "access_token": token.access_token, + "refresh_token": token.refresh_token, + "token_type": "bearer", + "expires_in": settings.get("OIDC_TOKEN_EXPIRE"), + "id_token": encode_id_token(id_token_dic, self.token.client), } return dic @@ -281,7 +284,7 @@ def create_access_token_response_dic(self): token=token, user=self.user, aud=self.client.client_id, - nonce='self.code.nonce', + nonce="self.code.nonce", at_hash=token.at_hash, request=self.request, scope=token.scope, @@ -291,12 +294,12 @@ def create_access_token_response_dic(self): token.save() return { - 'access_token': token.access_token, - 'refresh_token': token.refresh_token, - 'expires_in': settings.get('OIDC_TOKEN_EXPIRE'), - 'token_type': 'bearer', - 'id_token': encode_id_token(id_token_dic, token.client), - 'scope': ' '.join(token.scope) + "access_token": token.access_token, + "refresh_token": token.refresh_token, + "expires_in": settings.get("OIDC_TOKEN_EXPIRE"), + "token_type": "bearer", + "id_token": encode_id_token(id_token_dic, token.client), + "scope": " ".join(token.scope), } def create_client_credentials_response_dic(self): @@ -311,10 +314,10 @@ def create_client_credentials_response_dic(self): token.save() return { - 'access_token': token.access_token, - 'expires_in': settings.get('OIDC_TOKEN_EXPIRE'), - 'token_type': 'bearer', - 'scope': ' '.join(token.scope), + "access_token": token.access_token, + "expires_in": settings.get("OIDC_TOKEN_EXPIRE"), + "token_type": "bearer", + "scope": " ".join(token.scope), } @classmethod @@ -323,7 +326,7 @@ def response(cls, dic, status=200): Create and return a response object. """ response = JsonResponse(dic, status=status) - response['Cache-Control'] = 'no-store' - response['Pragma'] = 'no-cache' + response["Cache-Control"] = "no-store" + response["Pragma"] = "no-cache" return response diff --git a/oidc_provider/lib/utils/token.py b/oidc_provider/lib/utils/token.py index d3fd3ab2..403440ad 100644 --- a/oidc_provider/lib/utils/token.py +++ b/oidc_provider/lib/utils/token.py @@ -19,7 +19,7 @@ from oidc_provider import settings -def create_id_token(token, user, aud, nonce='', at_hash='', request=None, scope=None): +def create_id_token(token, user, aud, nonce="", at_hash="", request=None, scope=None): """ Creates the id_token dictionary. See: http://openid.net/specs/openid-connect-core-1_0.html#IDToken @@ -27,44 +27,44 @@ def create_id_token(token, user, aud, nonce='', at_hash='', request=None, scope= """ if scope is None: scope = [] - sub = settings.get('OIDC_IDTOKEN_SUB_GENERATOR', import_str=True)(user=user) + sub = settings.get("OIDC_IDTOKEN_SUB_GENERATOR", import_str=True)(user=user) - expires_in = settings.get('OIDC_IDTOKEN_EXPIRE') + expires_in = settings.get("OIDC_IDTOKEN_EXPIRE") # Convert datetimes into timestamps. now = int(time.time()) iat_time = now exp_time = int(now + expires_in) user_auth_time = user.last_login or user.date_joined - auth_time = int(dateformat.format(user_auth_time, 'U')) + auth_time = int(dateformat.format(user_auth_time, "U")) dic = { - 'iss': get_issuer(request=request), - 'sub': sub, - 'aud': str(aud), - 'exp': exp_time, - 'iat': iat_time, - 'auth_time': auth_time, + "iss": get_issuer(request=request), + "sub": sub, + "aud": str(aud), + "exp": exp_time, + "iat": iat_time, + "auth_time": auth_time, } if nonce: - dic['nonce'] = str(nonce) + dic["nonce"] = str(nonce) if at_hash: - dic['at_hash'] = at_hash + dic["at_hash"] = at_hash # Inlude (or not) user standard claims in the id_token. - if settings.get('OIDC_IDTOKEN_INCLUDE_CLAIMS'): - if settings.get('OIDC_EXTRA_SCOPE_CLAIMS'): - custom_claims = settings.get('OIDC_EXTRA_SCOPE_CLAIMS', import_str=True)(token) - claims = custom_claims.create_response_dic() - else: - claims = StandardScopeClaims(token).create_response_dic() - dic.update(claims) + if settings.get("OIDC_IDTOKEN_INCLUDE_CLAIMS"): + standard_claims = StandardScopeClaims(token) + dic.update(standard_claims.create_response_dic()) + + if settings.get("OIDC_EXTRA_SCOPE_CLAIMS"): + extra_claims = settings.get("OIDC_EXTRA_SCOPE_CLAIMS", import_str=True)(token) + dic.update(extra_claims.create_response_dic()) dic = run_processing_hook( - dic, 'OIDC_IDTOKEN_PROCESSING_HOOK', - user=user, token=token, request=request) + dic, "OIDC_IDTOKEN_PROCESSING_HOOK", user=user, token=token, request=request + ) return dic @@ -94,7 +94,7 @@ def client_id_from_id_token(id_token): Returns a string or None. """ payload = JWT().unpack(id_token).payload() - aud = payload.get('aud', None) + aud = payload.get("aud", None) if aud is None: return None if isinstance(aud, list): @@ -116,15 +116,15 @@ def create_token(user, client, scope, id_token_dic=None): token.id_token = id_token_dic token.refresh_token = uuid.uuid4().hex - token.expires_at = timezone.now() + timedelta( - seconds=settings.get('OIDC_TOKEN_EXPIRE')) + token.expires_at = timezone.now() + timedelta(seconds=settings.get("OIDC_TOKEN_EXPIRE")) token.scope = scope return token -def create_code(user, client, scope, nonce, is_authentication, - code_challenge=None, code_challenge_method=None): +def create_code( + user, client, scope, nonce, is_authentication, code_challenge=None, code_challenge_method=None +): """ Create and populate a Code object. Return a Code object. @@ -139,8 +139,7 @@ def create_code(user, client, scope, nonce, is_authentication, code.code_challenge = code_challenge code.code_challenge_method = code_challenge_method - code.expires_at = timezone.now() + timedelta( - seconds=settings.get('OIDC_CODE_EXPIRE')) + code.expires_at = timezone.now() + timedelta(seconds=settings.get("OIDC_CODE_EXPIRE")) code.scope = scope code.nonce = nonce code.is_authentication = is_authentication @@ -153,15 +152,15 @@ def get_client_alg_keys(client): Takes a client and returns the set of keys associated with it. Returns a list of keys. """ - if client.jwt_alg == 'RS256': + if client.jwt_alg == "RS256": keys = [] for rsakey in RSAKey.objects.all(): keys.append(jwk_RSAKey(key=importKey(rsakey.key), kid=rsakey.kid)) if not keys: - raise Exception('You must add at least one RSA Key.') - elif client.jwt_alg == 'HS256': + raise Exception("You must add at least one RSA Key.") + elif client.jwt_alg == "HS256": keys = [SYMKey(key=client.client_secret, alg=client.jwt_alg)] else: - raise Exception('Unsupported key algorithm.') + raise Exception("Unsupported key algorithm.") return keys diff --git a/oidc_provider/tests/app/utils.py b/oidc_provider/tests/app/utils.py index 51f51d4e..491af0e5 100644 --- a/oidc_provider/tests/app/utils.py +++ b/oidc_provider/tests/app/utils.py @@ -5,25 +5,27 @@ from django.contrib.auth.backends import ModelBackend try: - from urlparse import parse_qs, urlsplit + from urlparse import parse_qs + from urlparse import urlsplit except ImportError: - from urllib.parse import parse_qs, urlsplit + from urllib.parse import parse_qs + from urllib.parse import urlsplit -from django.utils import timezone from django.contrib.auth.models import User +from django.utils import timezone -from oidc_provider.models import ( - Client, - Code, - Token, - ResponseType) - +from oidc_provider.lib.claims import ScopeClaims +from oidc_provider.models import Client +from oidc_provider.models import Code +from oidc_provider.models import ResponseType +from oidc_provider.models import Token -FAKE_NONCE = 'cb584e44c43ed6bd0bc2d9c7e242837d' -FAKE_RANDOM_STRING = ''.join( - random.choice(string.ascii_uppercase + string.digits) for _ in range(32)) -FAKE_CODE_CHALLENGE = 'YlYXEqXuRm-Xgi2BOUiK50JW1KsGTX6F1TDnZSC8VTg' -FAKE_CODE_VERIFIER = 'SmxGa0XueyNh5bDgTcSrqzAh2_FmXEqU8kDT6CuXicw' +FAKE_NONCE = "cb584e44c43ed6bd0bc2d9c7e242837d" +FAKE_RANDOM_STRING = "".join( + random.choice(string.ascii_uppercase + string.digits) for _ in range(32) +) +FAKE_CODE_CHALLENGE = "YlYXEqXuRm-Xgi2BOUiK50JW1KsGTX6F1TDnZSC8VTg" +FAKE_CODE_VERIFIER = "SmxGa0XueyNh5bDgTcSrqzAh2_FmXEqU8kDT6CuXicw" def create_fake_user(): @@ -33,11 +35,11 @@ def create_fake_user(): Return a User object. """ user = User() - user.username = 'johndoe' - user.email = 'johndoe@example.com' - user.first_name = 'John' - user.last_name = 'Doe' - user.set_password('1234') + user.username = "johndoe" + user.email = "johndoe@example.com" + user.first_name = "John" + user.last_name = "Doe" + user.set_password("1234") user.save() @@ -52,20 +54,20 @@ def create_fake_client(response_type, is_public=False, require_consent=True): Return a Client object. """ client = Client() - client.name = 'Some Client' + client.name = "Some Client" client.client_id = str(random.randint(1, 999999)).zfill(6) if is_public: - client.client_type = 'public' - client.client_secret = '' + client.client_type = "public" + client.client_secret = "" else: client.client_secret = str(random.randint(1, 999999)).zfill(6) - client.redirect_uris = ['http://example.com/'] + client.redirect_uris = ["http://example.com/"] client.require_consent = require_consent - client.scope = ['openid', 'email'] + client.scope = ["openid", "email"] client.save() # check if response_type is a string in a python 2 and 3 compatible way - if isinstance(response_type, ("".__class__, u"".__class__)): + if isinstance(response_type, ("".__class__, "".__class__)): response_type = (response_type,) for value in response_type: client.response_types.add(ResponseType.objects.get(value=value)) @@ -90,7 +92,7 @@ def is_code_valid(url, user, client): try: parsed = urlsplit(url) params = parse_qs(parsed.query or parsed.fragment) - code = params['code'][0] + code = params["code"][0] code = Code.objects.get(code=code) is_code_ok = (code.client == client) and (code.user == user) except Exception: @@ -103,15 +105,28 @@ def userinfo(claims, user): """ Fake function for setting OIDC_USERINFO. """ - claims['given_name'] = 'John' - claims['family_name'] = 'Doe' - claims['name'] = '{0} {1}'.format(claims['given_name'], claims['family_name']) - claims['email'] = user.email - claims['email_verified'] = True - claims['address']['country'] = 'Argentina' + claims["given_name"] = "John" + claims["family_name"] = "Doe" + claims["name"] = "{0} {1}".format(claims["given_name"], claims["family_name"]) + claims["email"] = user.email + claims["email_verified"] = True + claims["address"]["country"] = "Argentina" return claims +class FakeScopeClaims(ScopeClaims): + info_pizza = ( + "Pizza", + "Some description for the scope.", + ) + + def scope_pizza(self): + dic = { + "pizza": "Margherita", + } + return dic + + def fake_sub_generator(user): """ Fake function for setting OIDC_IDTOKEN_SUB_GENERATOR. @@ -123,8 +138,8 @@ def fake_idtoken_processing_hook(id_token, user, **kwargs): """ Fake function for inserting some keys into token. Testing OIDC_IDTOKEN_PROCESSING_HOOK. """ - id_token['test_idtoken_processing_hook'] = FAKE_RANDOM_STRING - id_token['test_idtoken_processing_hook_user_email'] = user.email + id_token["test_idtoken_processing_hook"] = FAKE_RANDOM_STRING + id_token["test_idtoken_processing_hook_user_email"] = user.email return id_token @@ -133,8 +148,8 @@ def fake_idtoken_processing_hook2(id_token, user, **kwargs): Fake function for inserting some keys into token. Testing OIDC_IDTOKEN_PROCESSING_HOOK - tuple or list as param """ - id_token['test_idtoken_processing_hook2'] = FAKE_RANDOM_STRING - id_token['test_idtoken_processing_hook_user_email2'] = user.email + id_token["test_idtoken_processing_hook2"] = FAKE_RANDOM_STRING + id_token["test_idtoken_processing_hook_user_email2"] = user.email return id_token @@ -142,7 +157,7 @@ def fake_idtoken_processing_hook3(id_token, user, token, **kwargs): """ Fake function for checking scope is passed to processing hook. """ - id_token['scope_of_token_passed_to_processing_hook'] = token.scope + id_token["scope_of_token_passed_to_processing_hook"] = token.scope return id_token @@ -150,15 +165,14 @@ def fake_idtoken_processing_hook4(id_token, user, **kwargs): """ Fake function for checking kwargs passed to processing hook. """ - id_token['kwargs_passed_to_processing_hook'] = { - key: repr(value) - for (key, value) in kwargs.items() + id_token["kwargs_passed_to_processing_hook"] = { + key: repr(value) for (key, value) in kwargs.items() } return id_token def fake_introspection_processing_hook(response_dict, client, id_token): - response_dict['test_introspection_processing_hook'] = FAKE_RANDOM_STRING + response_dict["test_introspection_processing_hook"] = FAKE_RANDOM_STRING return response_dict diff --git a/oidc_provider/tests/cases/test_utils.py b/oidc_provider/tests/cases/test_utils.py index 787a3f56..24c9ae65 100644 --- a/oidc_provider/tests/cases/test_utils.py +++ b/oidc_provider/tests/cases/test_utils.py @@ -1,56 +1,59 @@ import time from datetime import datetime from hashlib import sha224 +from unittest import mock from django.http import HttpRequest -from django.test import TestCase, override_settings +from django.test import TestCase +from django.test import override_settings from django.utils import timezone -from mock import mock -from oidc_provider.lib.utils.common import get_issuer, get_browser_state_or_default -from oidc_provider.lib.utils.token import create_token, create_id_token -from oidc_provider.tests.app.utils import create_fake_user, create_fake_client +from oidc_provider.lib.utils.common import get_browser_state_or_default +from oidc_provider.lib.utils.common import get_issuer +from oidc_provider.lib.utils.token import create_id_token +from oidc_provider.lib.utils.token import create_token +from oidc_provider.tests.app.utils import create_fake_client +from oidc_provider.tests.app.utils import create_fake_user class Request(object): """ Mock request object. """ - scheme = 'http' + + scheme = "http" def get_host(self): - return 'host-from-request:8888' + return "host-from-request:8888" class CommonTest(TestCase): """ Test cases for common utils. """ + def test_get_issuer(self): request = Request() # from default settings - self.assertEqual(get_issuer(), - 'http://localhost:8000/openid') + self.assertEqual(get_issuer(), "http://localhost:8000/openid") # from custom settings - with self.settings(SITE_URL='http://otherhost:8000'): - self.assertEqual(get_issuer(), - 'http://otherhost:8000/openid') + with self.settings(SITE_URL="http://otherhost:8000"): + self.assertEqual(get_issuer(), "http://otherhost:8000/openid") # `SITE_URL` not set, from `request` - with self.settings(SITE_URL=''): - self.assertEqual(get_issuer(request=request), - 'http://host-from-request:8888/openid') + with self.settings(SITE_URL=""): + self.assertEqual(get_issuer(request=request), "http://host-from-request:8888/openid") # use settings first if both are provided - self.assertEqual(get_issuer(request=request), - 'http://localhost:8000/openid') + self.assertEqual(get_issuer(request=request), "http://localhost:8000/openid") # `site_url` can even be overridden manually - self.assertEqual(get_issuer(site_url='http://127.0.0.1:9000', - request=request), - 'http://127.0.0.1:9000/openid') + self.assertEqual( + get_issuer(site_url="http://127.0.0.1:9000", request=request), + "http://127.0.0.1:9000/openid", + ) def timestamp_to_datetime(timestamp): @@ -69,32 +72,61 @@ def test_create_id_token(self): self.user.last_login = timestamp_to_datetime(login_timestamp) client = create_fake_client("code") token = create_token(self.user, client, []) - id_token_data = create_id_token(token=token, user=self.user, aud='test-aud') - iat = id_token_data['iat'] + id_token_data = create_id_token(token=token, user=self.user, aud="test-aud") + iat = id_token_data["iat"] self.assertEqual(type(iat), int) self.assertGreaterEqual(iat, start_time) self.assertLessEqual(iat - start_time, 5) # Can't take more than 5 s - self.assertEqual(id_token_data, { - 'aud': 'test-aud', - 'auth_time': login_timestamp, - 'exp': iat + 600, - 'iat': iat, - 'iss': 'http://localhost:8000/openid', - 'sub': str(self.user.id), - }) + self.assertEqual( + id_token_data, + { + "aud": "test-aud", + "auth_time": login_timestamp, + "exp": iat + 600, + "iat": iat, + "iss": "http://localhost:8000/openid", + "sub": str(self.user.id), + }, + ) + + @override_settings(OIDC_IDTOKEN_INCLUDE_CLAIMS=True) + def test_create_id_token_with_include_claims_setting(self): + client = create_fake_client("code") + token = create_token(self.user, client, scope=["openid", "email"]) + id_token_data = create_id_token(token=token, user=self.user, aud="test-aud") + self.assertIn("email", id_token_data) + self.assertTrue(id_token_data["email"]) + self.assertIn("email_verified", id_token_data) + self.assertTrue(id_token_data["email_verified"]) + + @override_settings( + OIDC_IDTOKEN_INCLUDE_CLAIMS=True, + OIDC_EXTRA_SCOPE_CLAIMS="oidc_provider.tests.app.utils.FakeScopeClaims", + ) + def test_create_id_token_with_include_claims_setting_and_extra(self): + client = create_fake_client("code") + token = create_token(self.user, client, scope=["openid", "email", "pizza"]) + id_token_data = create_id_token(token=token, user=self.user, aud="test-aud") + # Standard claims included. + self.assertIn("email", id_token_data) + self.assertTrue(id_token_data["email"]) + self.assertIn("email_verified", id_token_data) + self.assertTrue(id_token_data["email_verified"]) + # Extra claims included. + self.assertIn("pizza", id_token_data) + self.assertEqual(id_token_data["pizza"], "Margherita") class BrowserStateTest(TestCase): - - @override_settings(OIDC_UNAUTHENTICATED_SESSION_MANAGEMENT_KEY='my_static_key') + @override_settings(OIDC_UNAUTHENTICATED_SESSION_MANAGEMENT_KEY="my_static_key") def test_get_browser_state_uses_value_from_settings_to_calculate_browser_state(self): request = HttpRequest() request.session = mock.Mock(session_key=None) state = get_browser_state_or_default(request) - self.assertEqual(state, sha224('my_static_key'.encode('utf-8')).hexdigest()) + self.assertEqual(state, sha224("my_static_key".encode("utf-8")).hexdigest()) def test_get_browser_state_uses_session_key_to_calculate_browser_state_if_available(self): request = HttpRequest() - request.session = mock.Mock(session_key='my_session_key') + request.session = mock.Mock(session_key="my_session_key") state = get_browser_state_or_default(request) - self.assertEqual(state, sha224('my_session_key'.encode('utf-8')).hexdigest()) + self.assertEqual(state, sha224("my_session_key".encode("utf-8")).hexdigest()) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..17304614 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,5 @@ +[tool.ruff] +line-length = 100 + +[tool.ruff.lint.isort] +force-single-line = true \ No newline at end of file