Skip to content
This repository has been archived by the owner on Jun 1, 2023. It is now read-only.

Commit

Permalink
Merge pull request #42 from IdentityPython/reread_jwks_uri
Browse files Browse the repository at this point in the history
Needed to complete certification.
  • Loading branch information
rohe authored Oct 27, 2021
2 parents 0a1532f + 48186d6 commit 40fe88a
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 46 deletions.
2 changes: 1 addition & 1 deletion src/oidcmsg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
__author__ = "Roland Hedberg"
__version__ = "1.4.0"
__version__ = "1.4.1"

import os
from typing import Dict
Expand Down
3 changes: 2 additions & 1 deletion src/oidcmsg/impexp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List
from typing import Optional

from cryptojwt import as_unicode
from cryptojwt.utils import as_bytes
from cryptojwt.utils import importer
from cryptojwt.utils import qualified_name
Expand All @@ -25,7 +26,7 @@ def __init__(self):
def dump_attr(self, cls, item, exclude_attributes: Optional[List[str]] = None) -> dict:
if cls in [None, 0, "", [], {}, bool, b'']:
if cls == b'':
val = as_bytes(item)
val = as_unicode(item)
else:
val = item
elif cls == "DICT_TYPE":
Expand Down
68 changes: 39 additions & 29 deletions src/oidcmsg/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,10 +314,10 @@ def from_dict(self, dictionary, **kwargs):
self._dict[key] = val
continue

self._add_value(skey, vtyp, key, val, _deser, null_allowed)
self._add_value(skey, vtyp, key, val, _deser, null_allowed, sformat="dict")
return self

def _add_value(self, skey, vtyp, key, val, _deser, null_allowed):
def _add_value(self, skey, vtyp, key, val, _deser, null_allowed, sformat="urlencoded"):
"""
Main method for adding a value to the instance. Does all the
checking on type of value and if among allowed values.
Expand Down Expand Up @@ -350,7 +350,7 @@ def _add_value(self, skey, vtyp, key, val, _deser, null_allowed):
self._dict[skey] = [val]
elif _deser:
try:
self._dict[skey] = _deser(val, sformat="urlencoded")
self._dict[skey] = _deser(val, sformat=sformat)
except Exception as exc:
raise DecodeError(ERRTXT % (key, exc))
else:
Expand Down Expand Up @@ -402,16 +402,6 @@ def _add_value(self, skey, vtyp, key, val, _deser, null_allowed):
except Exception as exc:
raise DecodeError(ERRTXT % (key, exc))
else:
# if isinstance(val, str):
# self._dict[skey] = val
# elif isinstance(val, list):
# if len(val) == 1:
# self._dict[skey] = val[0]
# elif not len(val):
# pass
# else:
# raise TooManyValues(key)
# else:
self._dict[skey] = val
elif vtyp is int:
try:
Expand Down Expand Up @@ -468,6 +458,28 @@ def to_jwt(self, key=None, algorithm="", lev=0, lifetime=0):
_jws = JWS(self.to_json(lev), alg=algorithm)
return _jws.sign_compact(key)

def _gather_keys(self, keyjar, jwt, header, **kwargs):
key = []

if keyjar:
_keys = keyjar.get_jwt_verify_keys(jwt, **kwargs)
if not _keys:
keyjar.update()
_keys = keyjar.get_jwt_verify_keys(jwt, **kwargs)
key.extend(_keys)

if "alg" in header and header["alg"] != "none":
if not key:
if keyjar:
keyjar.update()
key = keyjar.get_jwt_verify_keys(jwt, **kwargs)
if not key:
raise MissingSigningKey("alg=%s" % header["alg"])
else:
raise MissingSigningKey("alg=%s" % header["alg"])

return key

def from_jwt(self, txt, keyjar, verify=True, **kwargs):
"""
Given a signed and/or encrypted JWT, verify its correctness and then
Expand Down Expand Up @@ -515,7 +527,6 @@ def from_jwt(self, txt, keyjar, verify=True, **kwargs):
jso = _jwt.payload()
_header = _jwt.headers

key = []
# if "sender" in kwargs:
# key.extend(keyjar.get_verify_key(owner=kwargs["sender"]))

Expand All @@ -524,21 +535,13 @@ def from_jwt(self, txt, keyjar, verify=True, **kwargs):
if _header["alg"] == "none":
pass
elif verify:
if keyjar:
key.extend(keyjar.get_jwt_verify_keys(_jwt, **kwargs))
key = self._gather_keys(keyjar, _jwt, _header, **kwargs)

if "alg" in _header and _header["alg"] != "none":
if not key:
raise MissingSigningKey("alg=%s" % _header["alg"])
if not key:
raise MissingSigningKey("alg=%s" % _header["alg"])

logger.debug("Found signing key.")
try:
_verifier.verify_compact(txt, key)
except NoSuitableSigningKeys:
if keyjar:
keyjar.update()
key = keyjar.get_jwt_verify_keys(_jwt, **kwargs)
_verifier.verify_compact(txt, key)
_verifier.verify_compact(txt, key)

self.jws_header = _jwt.headers
else:
Expand Down Expand Up @@ -850,8 +853,12 @@ def add_non_standard(msg1, msg2):


def list_serializer(vals, sformat="urlencoded", lev=0):
if isinstance(vals, str) or not isinstance(vals, list):
if isinstance(vals, str) and sformat == "dict":
return [vals]

if not isinstance(vals, list):
raise ValueError("Expected list: %s" % vals)

if sformat == "urlencoded":
return " ".join(vals)
else:
Expand All @@ -864,8 +871,11 @@ def list_deserializer(val, sformat="urlencoded"):
return val.split(" ")
elif isinstance(val, list) and len(val) == 1:
return val[0].split(" ")
else:
return val
elif sformat == "dict":
if isinstance(val, str):
val = [val]

return val


def sp_sep_list_serializer(vals, sformat="urlencoded", lev=0):
Expand Down
12 changes: 2 additions & 10 deletions src/oidcmsg/oidc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ class RegistrationRequest(Message):
# "client_id": SINGLE_OPTIONAL_STRING,
# "client_secret": SINGLE_OPTIONAL_STRING,
# "access_token": SINGLE_OPTIONAL_STRING,
"post_logout_redirect_uris": OPTIONAL_LIST_OF_STRINGS,
"post_logout_redirect_uri": SINGLE_OPTIONAL_STRING,
"frontchannel_logout_uri": SINGLE_OPTIONAL_STRING,
"frontchannel_logout_session_required": SINGLE_OPTIONAL_BOOLEAN,
"backchannel_logout_uri": SINGLE_OPTIONAL_STRING,
Expand Down Expand Up @@ -771,14 +771,6 @@ def pack(self, alg="", **kwargs):
else:
self.pack_init()

# if 'jti' in self.c_param:
# try:
# _jti = kwargs['jti']
# except KeyError:
# _jti = uuid.uuid4().hex
#
# self['jti'] = _jti

def to_jwt(self, key=None, algorithm="", lev=0, lifetime=0):
self.pack(alg=algorithm, lifetime=lifetime)
return Message.to_jwt(self, key=key, algorithm=algorithm, lev=lev)
Expand All @@ -797,7 +789,7 @@ def verify(self, **kwargs):
# check that I'm among the recipients
if kwargs["client_id"] not in self["aud"]:
raise NotForMe(
"{} not in aud:{}".format(kwargs["client_id"], self["aud"]), self
'"{}" not in {}'.format(kwargs["client_id"], self["aud"]), self
)

# Then azp has to be present and be one of the aud values
Expand Down
2 changes: 1 addition & 1 deletion src/oidcmsg/oidc/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,6 @@ def verify(self, **kwargs):
return False

self[verified_claim_name("logout_token")] = idt
logger.info("Verified ID Token: {}".format(idt.to_dict()))
logger.info("Verified Logout Token: {}".format(idt.to_dict()))

return True
29 changes: 25 additions & 4 deletions tests/test_06_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,23 @@
from urllib.parse import parse_qs
from urllib.parse import urlencode

import pytest
from cryptojwt.exception import BadSignature
from cryptojwt.exception import UnsupportedAlgorithm
from cryptojwt.jws.exception import SignerAlgError
from cryptojwt.jws.utils import left_hash
from cryptojwt.jwt import JWT
from cryptojwt.key_bundle import KeyBundle
from cryptojwt.key_jar import KeyJar
import pytest

from oidcmsg import proper_path
from oidcmsg import time_util
from oidcmsg.exception import MessageException
from oidcmsg.exception import MissingRequiredAttribute
from oidcmsg.exception import NotAllowedValue
from oidcmsg.exception import OidcMsgError
from oidcmsg.oauth2 import ResponseMessage
from oidcmsg.oauth2 import ROPCAccessTokenRequest
from oidcmsg.oidc import JRD
from oidcmsg.oauth2 import ResponseMessage
from oidcmsg.oidc import AccessTokenRequest
from oidcmsg.oidc import AccessTokenResponse
from oidcmsg.oidc import AddressClaim
Expand All @@ -38,6 +37,7 @@
from oidcmsg.oidc import EXPError
from oidcmsg.oidc import IATError
from oidcmsg.oidc import IdToken
from oidcmsg.oidc import JRD
from oidcmsg.oidc import Link
from oidcmsg.oidc import OpenIDSchema
from oidcmsg.oidc import ProviderConfigurationResponse
Expand Down Expand Up @@ -661,7 +661,7 @@ def test_deserialize(self):
"client_secret_expires_at": 1577858400,
"registration_access_token": "this.is.an.access.token.value.ffx83",
"registration_client_uri": "https://server.example.com/connect/register?client_id"
"=s6BhdRkqt3",
"=s6BhdRkqt3",
"token_endpoint_auth_method": "client_secret_basic",
"application_type": "web",
"redirect_uris": [
Expand Down Expand Up @@ -1601,3 +1601,24 @@ def test_correct_sign_alg():
client_id="554295ce3770612820620000",
allowed_sign_alg="HS256",
)


def test_ID_Token_space_in_id():
idt = IdToken(**{
"at_hash": "buCCujNN632UIV8-VbKhgw",
"sub": "user-subject-1234531",
"aud": "client_ifCttPphtLxtPWd20602 ^.+/",
"iss": "https://www.certification.openid.net/test/a/idpy/",
"exp": 1632495959,
"nonce": "B88En9UpdHkQZMQXK9U3KHzV",
"iat": 1632495659
})

assert idt["aud"] == ["client_ifCttPphtLxtPWd20602 ^.+/"]

idt = IdToken(**{'at_hash': 'rgMbiR-Dj11dQjxhCyLkOw', 'sub': 'user-subject-1234531',
'aud': 'client_dVCwIQuSKklinFP70742;#__$',
'iss': 'https://www.certification.openid.net/test/a/idpy/', 'exp': 1632639462,
'nonce': 'hUT3RhSooxC9CilrD8al6bGx', 'iat': 1632639162})

assert idt["aud"] == ["client_dVCwIQuSKklinFP70742;#__$"]

0 comments on commit 40fe88a

Please sign in to comment.