diff --git a/.coveragerc b/.coveragerc index abc2344c..05bb0ced 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,3 +1,8 @@ # .coveragerc to control coverage.py [run] +branch = true +omit = */tests/*, */wsgi.py, fabfile.py, /usr/local/*, ./setup.py source = . + +[report] +show_missing = true \ No newline at end of file diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 9b550e8f..a9f4eaac 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -18,10 +18,10 @@ jobs: fail-fast: false matrix: python-version: - - '3.7' - '3.8' - '3.9' - '3.10' + - '3.11' steps: - uses: actions/checkout@v2 diff --git a/doc/client/add_on/dpop.rst b/doc/client/add_on/dpop.rst index f5dae1b2..c46648a4 100644 --- a/doc/client/add_on/dpop.rst +++ b/doc/client/add_on/dpop.rst @@ -40,7 +40,7 @@ in a client configuration. 'add_ons': { "dpop": { - "function": "oidcrp.oauth2.add_on.dpop.add_support", + "function": "idpyoidc.client.oauth2.add_on.dpop.add_support", "kwargs": { "signing_algorithms": ["ES256", "ES512"] } diff --git a/doc/client/add_on/pushed_authorization.rst b/doc/client/add_on/pushed_authorization.rst index 2b65fb8d..14a02432 100644 --- a/doc/client/add_on/pushed_authorization.rst +++ b/doc/client/add_on/pushed_authorization.rst @@ -8,10 +8,39 @@ Pushed Authorization Introduction ------------ -https://tools.ietf.org/id/draft-lodderstedt-oauth-par-00.html +https://datatracker.ietf.org/doc/html/rfc9126 -The Internet draft defines the pushed authorization request endpoint, +The Internet draft defines the pushed authorization request (PAR) endpoint, which allows clients to push the payload of an OAuth 2.0 authorization request to the authorization server via a direct request and provides them with a request URI that is used as reference to the data in a -subsequent authorization request. \ No newline at end of file +subsequent authorization request. + +------------- +Configuration +------------- + +There is basically one things you can configure: + +- authn_method + Which client authentication method that should be used at the pushed authorization endpoint. + Default is none. + +------- +Example +------- + +What you have to do is to add a *par* section to an *add_ons* section +in a client configuration. + +.. code:: python + + 'add_ons': { + "par": { + "function": "idpyoidc.client.oauth2.add_on.par.add_support", + "kwargs": { + "authn_method": "private_key_jwt" + } + } + } + diff --git a/example/flask_op/config.json b/example/flask_op/config.json index 2fe378fc..444d4d91 100644 --- a/example/flask_op/config.json +++ b/example/flask_op/config.json @@ -91,7 +91,7 @@ } } }, - "capabilities": { + "preference": { "subject_types_supported": [ "public", "pairwise" @@ -260,6 +260,7 @@ "verify": false }, "issuer": "https://{domain}:{port}", + "entity_id": "https://{domain}:{port}", "keys": { "private_path": "private/jwks.json", "key_defs": [ @@ -277,9 +278,8 @@ ] } ], - "public_path": "static/jwks.json", "read_only": false, - "uri_path": "static/jwks.json" + "uri_path": "jwks" }, "login_hint2acrs": { "class": "idpyoidc.server.login_hint.LoginHint2Acrs", @@ -349,6 +349,6 @@ "verify_user": false, "port": 5000, "domain": "127.0.0.1", - "debug": true + "debug": false } } diff --git a/example/flask_op/views.py b/example/flask_op/views.py index 7846af50..615872a6 100644 --- a/example/flask_op/views.py +++ b/example/flask_op/views.py @@ -1,26 +1,24 @@ import json -import os import sys import traceback from typing import Union from urllib.parse import urlparse +import werkzeug from cryptojwt import as_unicode from flask import Blueprint -from flask import Response from flask import current_app from flask import redirect from flask import render_template from flask import request +from flask import Response from flask.helpers import make_response -from flask.helpers import send_from_directory + from idpyoidc.message.oauth2 import ResponseMessage from idpyoidc.message.oidc import AccessTokenRequest from idpyoidc.message.oidc import AuthorizationRequest -import werkzeug - -from idpyoidc.server.exception import FailedAuthentication from idpyoidc.server.exception import ClientAuthenticationError +from idpyoidc.server.exception import FailedAuthentication from idpyoidc.server.oidc.token import Token # logger = logging.getLogger(__name__) @@ -29,8 +27,8 @@ def _add_cookie(resp: Response, cookie_spec: Union[dict, list]): - kwargs = {k:v - for k,v in cookie_spec.items() + kwargs = {k: v + for k, v in cookie_spec.items() if k not in ('name',)} kwargs["path"] = "/" kwargs["samesite"] = "Lax" @@ -44,15 +42,22 @@ def add_cookie(resp: Response, cookie_spec: Union[dict, list]): elif isinstance(cookie_spec, dict): _add_cookie(resp, cookie_spec) -@oidc_op_views.route('/static/') -def send_js(path): - return send_from_directory('static', path) +# @oidc_op_views.route('/static/') +# def send_js(path): +# return send_from_directory('static', path) +# +# +# @oidc_op_views.route('/keys/') +# def keys(jwks): +# fname = os.path.join('static', jwks) +# return open(fname).read() +# -@oidc_op_views.route('/keys/') -def keys(jwks): - fname = os.path.join('static', jwks) - return open(fname).read() +@oidc_op_views.route('/jwks') +def jwks(): + _context = current_app.server.get_context() + return _context.keyjar.export_jwks() @oidc_op_views.route('/') @@ -188,11 +193,13 @@ def token(): return service_endpoint( current_app.server.get_endpoint('token')) + @oidc_op_views.route('/introspection', methods=['POST']) def introspection_endpoint(): return service_endpoint( current_app.server.get_endpoint('introspection')) + @oidc_op_views.route('/userinfo', methods=['GET', 'POST']) def userinfo(): return service_endpoint( diff --git a/pyproject.toml b/pyproject.toml index 8f06b36a..2d146e9e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ classifiers =[ [options] package_dir = "src" packages = "find:" -python= "^3.6" +python= "^3.8" [tool.black] line-length = 100 diff --git a/setup.py b/setup.py index 9c15c688..6d99efdf 100644 --- a/setup.py +++ b/setup.py @@ -63,14 +63,13 @@ def run_tests(self): classifiers=[ "Development Status :: 4 - Beta", "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Topic :: Software Development :: Libraries :: Python Modules"], install_requires=[ - "cryptojwt>=1.8.3", + "cryptojwt>=1.8.4", "pyOpenSSL", "filelock>=3.0.12", 'pyyaml>=5.1.2', diff --git a/src/idpyoidc/__init__.py b/src/idpyoidc/__init__.py index 626ef98f..76d83a3e 100644 --- a/src/idpyoidc/__init__.py +++ b/src/idpyoidc/__init__.py @@ -1,5 +1,5 @@ __author__ = "Roland Hedberg" -__version__ = "4.2.0" +__version__ = "4.3.0" VERIFIED_CLAIM_PREFIX = "__verified" diff --git a/src/idpyoidc/client/oauth2/add_on/dpop.py b/src/idpyoidc/client/oauth2/add_on/dpop.py index f92574ff..b8450500 100644 --- a/src/idpyoidc/client/oauth2/add_on/dpop.py +++ b/src/idpyoidc/client/oauth2/add_on/dpop.py @@ -99,6 +99,7 @@ def dpop_header( headers: Optional[dict] = None, token: Optional[str] = "", nonce: Optional[str] = "", + endpoint_url: Optional[str] = "", **kwargs ) -> dict: """ @@ -114,7 +115,11 @@ def dpop_header( :return: """ - provider_info = service_context.provider_info + if not endpoint_url: + endpoint_url = kwargs.get("endpoint") + if not endpoint_url: + endpoint_url = service_context.provider_info[service_endpoint] + _dpop_conf = service_context.add_on.get("dpop") if not _dpop_conf: logger.warning("Asked to do dpop when I do not support it") @@ -139,7 +144,7 @@ def dpop_header( "jwk": dpop_key.serialize(), "jti": uuid.uuid4().hex, "htm": http_method, - "htu": provider_info[service_endpoint], + "htu": endpoint_url, "iat": utc_time_sans_frac(), } diff --git a/src/idpyoidc/client/oauth2/add_on/par.py b/src/idpyoidc/client/oauth2/add_on/par.py index 705a2223..afa94058 100644 --- a/src/idpyoidc/client/oauth2/add_on/par.py +++ b/src/idpyoidc/client/oauth2/add_on/par.py @@ -1,6 +1,5 @@ import logging -from cryptojwt import JWT from cryptojwt.utils import importer from idpyoidc.client.client_auth import CLIENT_AUTHN_METHOD @@ -8,10 +7,11 @@ from idpyoidc.message.oauth2 import JWTSecuredAuthorizationRequest from idpyoidc.server.util import execute from idpyoidc.util import instantiate -from requests import request logger = logging.getLogger(__name__) +HTTP_METHOD = "POST" + def push_authorization(request_args, service, **kwargs): """ @@ -25,12 +25,6 @@ def push_authorization(request_args, service, **kwargs): logger.debug(f"PAR method args: {method_args}") logger.debug(f"PAR kwargs: {kwargs}") - if method_args["apply"] is False: - return request_args - - _http_method = method_args["http_client"] - _httpc_params = service.upstream_get("unit").httpc_params - # Add client authentication if needed _headers = {} authn_method = method_args["authn_method"] @@ -50,29 +44,22 @@ def push_authorization(request_args, service, **kwargs): _args["iss"] = _context.issuer _headers = service.get_headers( - request_args, http_method=_http_method, authn_method=authn_method, **_args + request_args, http_method=HTTP_METHOD, authn_method=authn_method, **_args ) _headers["Content-Type"] = "application/x-www-form-urlencoded" # construct the message body - if method_args["body_format"] == "urlencoded": - _body = request_args.to_urlencoded() - else: - _jwt = JWT( - key_jar=service.upstream_get("attribute", "keyjar"), - iss=_context.claims.prefer["client_id"], - ) - _jws = _jwt.pack(request_args.to_dict()) + _body = request_args.to_urlencoded() - _msg = Message(request=_jws) - for param in request_args.required_parameters(): - _msg[param] = request_args.get(param) + _http_client = method_args.get("http_client", None) + if not _http_client: + _http_client = service.upstream_get("unit").httpc - _body = _msg.to_urlencoded() + _httpc_params = service.upstream_get("unit").httpc_params # Send it to the Pushed Authorization Request Endpoint using POST - resp = _http_method( - method="POST", + resp = _http_client( + method=HTTP_METHOD, url=_context.provider_info["pushed_authorization_request_endpoint"], data=_body, headers=_headers, @@ -95,41 +82,30 @@ def push_authorization(request_args, service, **kwargs): def add_support( - services, - body_format="jws", - signing_algorithm="RS256", - http_client=None, - merge_rule="strict", - authn_method="", + services, + http_client=None, + authn_method="", ): """ Add the necessary pieces to support Pushed authorization. - :param merge_rule: - :param http_client: - :param signing_algorithm: + :param http_client: Specification for a HTTP client to use different from the default + :param authn_method: The client authentication method to use :param services: A dictionary with all the services the client has access to. - :param body_format: jws or urlencoded """ - if http_client is None: - _http_client = request - else: + if http_client is not None: if isinstance(http_client, dict): if "class" in http_client: - _http_client = instantiate(http_client["class"], **http_client.get("kwargs", {})) + http_client = instantiate(http_client["class"], **http_client.get("kwargs", {})) else: - _http_client = importer(http_client["function"]) + http_client = importer(http_client["function"]) else: - _http_client = importer(http_client) + http_client = importer(http_client) - _service = services["authorization"] + _service = services["authorization"] # There must be such a service _service.upstream_get("context").add_on["pushed_authorization"] = { - "body_format": body_format, - "signing_algorithm": signing_algorithm, - "http_client": _http_client, - "merge_rule": merge_rule, - "apply": True, + "http_client": http_client, "authn_method": authn_method, } diff --git a/src/idpyoidc/client/oidc/access_token.py b/src/idpyoidc/client/oidc/access_token.py index ffbee3a7..af629fab 100644 --- a/src/idpyoidc/client/oidc/access_token.py +++ b/src/idpyoidc/client/oidc/access_token.py @@ -44,8 +44,12 @@ def gather_verify_arguments( _context = self.upstream_get("context") _entity = self.upstream_get("unit") + _client_id = _entity.get_client_id() + if not _client_id: + _client_id = _context.get_client_id() + kwargs = { - "client_id": _entity.get_client_id(), + "client_id": _client_id, "iss": _context.issuer, "keyjar": self.upstream_get("attribute", "keyjar"), "verify": True, diff --git a/src/idpyoidc/client/rp_handler.py b/src/idpyoidc/client/rp_handler.py index d771150d..eea94c0b 100644 --- a/src/idpyoidc/client/rp_handler.py +++ b/src/idpyoidc/client/rp_handler.py @@ -5,54 +5,65 @@ from typing import Optional from cryptojwt import KeyJar -from cryptojwt import as_unicode from cryptojwt.key_jar import init_key_jar from cryptojwt.utils import as_bytes from cryptojwt.utils import importer -from idpyoidc import verified_claim_name from idpyoidc.client.defaults import DEFAULT_CLIENT_CONFIGS from idpyoidc.client.defaults import DEFAULT_OIDC_SERVICES from idpyoidc.client.defaults import DEFAULT_RP_KEY_DEFS -from idpyoidc.client.exception import ConfigurationError -from idpyoidc.client.exception import OidcServiceError from idpyoidc.client.oauth2.stand_alone_client import StandAloneClient -from idpyoidc.exception import MessageException -from idpyoidc.exception import MissingRequiredAttribute -from idpyoidc.exception import NotForMe -from idpyoidc.message.oauth2 import is_error_message -from idpyoidc.message.oidc import AuthorizationRequest -from idpyoidc.message.oidc import AuthorizationResponse -from idpyoidc.message.oidc import OpenIDSchema -from idpyoidc.message.oidc.session import BackChannelLogoutRequest -from idpyoidc.time_util import utc_time_sans_frac from idpyoidc.util import add_path from idpyoidc.util import rndstr - -from ..message import Message -from ..message.oauth2 import ResponseMessage from .oauth2 import Client +from ..message import Message logger = logging.getLogger(__name__) class RPHandler(object): + def __init__( - self, - base_url: Optional[str] = "", - client_configs=None, - services=None, - keyjar=None, - hash_seed="", - verify_ssl=True, - state_db=None, - httpc=None, - httpc_params=None, - config=None, - **kwargs, + self, + base_url: Optional[str] = "", + client_configs=None, + services=None, + keyjar=None, + hash_seed="", + verify_ssl=True, + state_db=None, + httpc=None, + httpc_params=None, + config=None, + **kwargs, ): self.base_url = base_url - _jwks_path = kwargs.get("jwks_path") + + if keyjar is None: + keyjar_defs = {} + if config: + keyjar_defs = getattr(config, "key_conf", None) + + if not keyjar_defs: + keyjar_defs = kwargs.get("key_conf", DEFAULT_RP_KEY_DEFS) + + _jwks_path = kwargs.get("jwks_path", keyjar_defs.get("uri_path", keyjar_defs.get("public_path", ""))) + if "uri_path" in keyjar_defs: + del keyjar_defs["uri_path"] + self.keyjar = init_key_jar(**keyjar_defs, issuer_id="") + self.keyjar.import_jwks_as_json(self.keyjar.export_jwks_as_json(True, ""), base_url) + else: + self.keyjar = keyjar + _jwks_path = kwargs.get("jwks_path", "") + + if _jwks_path: + self.jwks_uri = add_path(base_url, _jwks_path) + else: + self.jwks_uri = "" + if len(self.keyjar): + self.jwks = self.keyjar.export_jwks() + else: + self.jwks = {} if config: if not hash_seed: @@ -65,7 +76,7 @@ def __init__( if "client_class" in config: if isinstance(config["client_class"], str): self.client_cls = importer(config["client_class"]) - else: # assume it's a class + else: # assume it's a class self.client_cls = config["client_class"] else: self.client_cls = StandAloneClient @@ -75,29 +86,23 @@ def __init__( else: self.hash_seed = as_bytes(rndstr(32)) - if keyjar is None: - self.keyjar = init_key_jar(**DEFAULT_RP_KEY_DEFS, issuer_id="") - self.keyjar.import_jwks_as_json(self.keyjar.export_jwks_as_json(True, ""), base_url) - if _jwks_path is None: - _jwks_path = DEFAULT_RP_KEY_DEFS["public_path"] - else: - self.keyjar = keyjar - if client_configs is None: self.client_configs = DEFAULT_CLIENT_CONFIGS + for param in ["client_type", "preference", "add_ons"]: + val = kwargs.get(param, None) + if val: + self.client_configs[""][param] = val else: self.client_configs = client_configs - self.client_cls = StandAloneClient - - if _jwks_path: - self.jwks_uri = add_path(base_url, _jwks_path) - else: - self.jwks_uri = "" - if len(self.keyjar): - self.jwks = self.keyjar.export_jwks() + _cc = kwargs.get("client_class", None) + if _cc: + if isinstance(_cc, str): + _cc = importer(_cc) + self.client_cls =_cc else: - self.jwks = {} + self.client_cls = StandAloneClient + if state_db: self.state_db = state_db @@ -124,6 +129,8 @@ def __init__( if not self.keyjar.httpc_params: self.keyjar.httpc_params = self.httpc_params + self.upstream_get = kwargs.get("upstream_get", None) + def state2issuer(self, state): """ Given the state value find the Issuer ID of the OP/AS that state value @@ -203,6 +210,7 @@ def init_client(self, issuer): config=_cnf, httpc=self.httpc, httpc_params=self.httpc_params, + upstream_get=self.upstream_get ) except Exception as err: logger.error("Failed initiating client: {}".format(err)) @@ -230,10 +238,10 @@ def init_client(self, issuer): return client def do_provider_info( - self, - client: Optional[Client] = None, - state: Optional[str] = "", - behaviour_args: Optional[dict] = None, + self, + client: Optional[Client] = None, + state: Optional[str] = "", + behaviour_args: Optional[dict] = None, ) -> str: """ Either get the provider info from configuration or through dynamic @@ -254,12 +262,12 @@ def do_provider_info( return client.do_provider_info(behaviour_args=behaviour_args) def do_client_registration( - self, - client=None, - iss_id: Optional[str] = "", - state: Optional[str] = "", - request_args: Optional[dict] = None, - behaviour_args: Optional[dict] = None, + self, + client=None, + iss_id: Optional[str] = "", + state: Optional[str] = "", + request_args: Optional[dict] = None, + behaviour_args: Optional[dict] = None, ): """ Prepare for and do client registration if configured to do so @@ -301,10 +309,10 @@ def do_webfinger(self, user: str) -> Client: return temporary_client def client_setup( - self, - iss_id: Optional[str] = "", - user: Optional[str] = "", - behaviour_args: Optional[dict] = None, + self, + iss_id: Optional[str] = "", + user: Optional[str] = "", + behaviour_args: Optional[dict] = None, ) -> StandAloneClient: """ First if no issuer ID is given then the identifier for the user is @@ -360,11 +368,11 @@ def _get_response_type(self, context, req_args: Optional[dict] = None): return context.claims.get_usage("response_types")[0] def init_authorization( - self, - client: Optional[Client] = None, - state: Optional[str] = "", - req_args: Optional[dict] = None, - behaviour_args: Optional[dict] = None, + self, + client: Optional[Client] = None, + state: Optional[str] = "", + req_args: Optional[dict] = None, + behaviour_args: Optional[dict] = None, ) -> str: """ Constructs the URL that will redirect the user to the authorization @@ -514,7 +522,7 @@ def userinfo_in_id_token(id_token: Message, user_info_claims: Optional[List] = N return StandAloneClient.userinfo_in_id_token(id_token, user_info_claims) def finalize_auth( - self, client, issuer: str, response: dict, behaviour_args: Optional[dict] = None + self, client, issuer: str, response: dict, behaviour_args: Optional[dict] = None ): """ Given the response returned to the redirect_uri, parse and verify it. @@ -533,11 +541,11 @@ def finalize_auth( return client.finalize_auth(response, behaviour_args=behaviour_args) def get_access_and_id_token( - self, - authorization_response=None, - state: Optional[str] = "", - client: Optional[object] = None, - behaviour_args: Optional[dict] = None, + self, + authorization_response=None, + state: Optional[str] = "", + client: Optional[object] = None, + behaviour_args: Optional[dict] = None, ): """ There are a number of services where access tokens and ID tokens can @@ -609,10 +617,10 @@ def get_valid_access_token(self, state): return client.get_valid_access_token(state) def logout( - self, - state: str, - client: Optional[Client] = None, - post_logout_redirect_uri: Optional[str] = "", + self, + state: str, + client: Optional[Client] = None, + post_logout_redirect_uri: Optional[str] = "", ) -> dict: """ Does an RP initiated logout from an OP. After logout the user will be @@ -631,7 +639,7 @@ def logout( return client.logout(state, post_logout_redirect_uri=post_logout_redirect_uri) def close( - self, state: str, issuer: Optional[str] = "", post_logout_redirect_uri: Optional[str] = "" + self, state: str, issuer: Optional[str] = "", post_logout_redirect_uri: Optional[str] = "" ) -> dict: if issuer: diff --git a/src/idpyoidc/client/service.py b/src/idpyoidc/client/service.py index f630583c..8c0730fb 100644 --- a/src/idpyoidc/client/service.py +++ b/src/idpyoidc/client/service.py @@ -451,17 +451,19 @@ def get_request_parameters( if _context.issuer: _args["iss"] = _context.issuer - # Client authentication by usage of the Authorization HTTP header - # or by modifying the request object - _args.update(self.get_headers_args()) - _headers = self.get_headers(request, http_method=method, authn_method=authn_method, **_args) - # Find out where to send this request try: endpoint_url = kwargs["endpoint"] except KeyError: endpoint_url = self.get_endpoint() + _args["endpoint_url"] = endpoint_url + + # Client authentication by usage of the Authorization HTTP header + # or by modifying the request object + _args.update(self.get_headers_args()) + _headers = self.get_headers(request, http_method=method, authn_method=authn_method, **_args) + _info["url"] = get_http_url(endpoint_url, request, method=method) # If there is to be a body part diff --git a/src/idpyoidc/client/service_context.py b/src/idpyoidc/client/service_context.py index 041b0bab..d101fae2 100644 --- a/src/idpyoidc/client/service_context.py +++ b/src/idpyoidc/client/service_context.py @@ -295,7 +295,13 @@ def set(self, key, value): setattr(self, key, value) def get_client_id(self): - return self.claims.get_usage("client_id") + res = self.claims.get_usage("client_id") + if not res: + res = self.entity_id + if not res and self.upstream_get: + res = self.upstream_get("unit").entity_id + + return res def collect_usage(self): return self.claims.use diff --git a/src/idpyoidc/message/oidc/__init__.py b/src/idpyoidc/message/oidc/__init__.py index a6da5063..a1c9949f 100644 --- a/src/idpyoidc/message/oidc/__init__.py +++ b/src/idpyoidc/message/oidc/__init__.py @@ -777,7 +777,7 @@ def verify(self, **kwargs): try: if kwargs["iss"] != self["iss"]: - raise IssuerMismatch("{} != {}".format(kwargs["iss"], self["iss"])) + raise IssuerMismatch("{kwargs['iss']} != {self['iss']}") except KeyError: pass @@ -785,7 +785,7 @@ def verify(self, **kwargs): if "client_id" in kwargs: # check that I'm among the recipients if kwargs["client_id"] not in self["aud"]: - raise NotForMe('"{}" not in {}'.format(kwargs["client_id"], self["aud"]), self) + raise NotForMe(f'{kwargs["client_id"]} not in {self["aud"]}') # Then azp has to be present and be one of the aud values if len(self["aud"]) > 1: diff --git a/src/idpyoidc/server/__init__.py b/src/idpyoidc/server/__init__.py index 14fe949e..78c2370b 100644 --- a/src/idpyoidc/server/__init__.py +++ b/src/idpyoidc/server/__init__.py @@ -46,7 +46,11 @@ def __init__( entity_id: Optional[str] = "", key_conf: Optional[dict] = None, ): - self.entity_id = entity_id or conf.get("entity_id") + self.entity_id = entity_id or conf.get("entity_id", None) + if not self.entity_id: + _conf = conf.get("conf", None) + if _conf: + self.entity_id = _conf.get("entity_id", "") self.issuer = conf.get("issuer", self.entity_id) self.persistence = None @@ -80,6 +84,7 @@ def __init__( cwd=cwd, cookie_handler=cookie_handler, keyjar=self.keyjar, + entity_id=self.entity_id ) # Need to have context in place before doing this diff --git a/src/idpyoidc/server/client_authn.py b/src/idpyoidc/server/client_authn.py index b1d4bfda..8a0c72da 100755 --- a/src/idpyoidc/server/client_authn.py +++ b/src/idpyoidc/server/client_authn.py @@ -490,11 +490,13 @@ def verify_client( _method = None _cdb = _cinfo = None + _tested = [] for _method in (methods[meth] for meth in allowed_methods): if not _method.is_usable(request=request, authorization_token=authorization_token): continue try: logger.info(f"Verifying client authentication using {_method.tag}") + _tested.append(_method.tag) auth_info = _method.verify( keyjar=endpoint.upstream_get("attribute", "keyjar"), request=request, @@ -529,7 +531,12 @@ def verify_client( try: _cinfo = _cdb[client_id] except KeyError: - raise UnknownClient("Unknown Client ID") + _auto_reg = getattr(endpoint, "automatic_registration", None) + if _auto_reg: + _cinfo = {"client_id": client_id} + _auto_reg.set(client_id, _cinfo) + else: + raise UnknownClient("Unknown Client ID") if not _cinfo: raise UnknownClient("Unknown Client ID") @@ -552,6 +559,7 @@ def verify_client( break logger.debug(f"Authn methods applied") + logger.debug(f"Method tested: {_tested}") # store what authn method was used if "method" in auth_info and client_id and _cdb: diff --git a/src/idpyoidc/server/endpoint_context.py b/src/idpyoidc/server/endpoint_context.py index 0b5d2ffa..3b46ef3e 100755 --- a/src/idpyoidc/server/endpoint_context.py +++ b/src/idpyoidc/server/endpoint_context.py @@ -29,8 +29,10 @@ logger = logging.getLogger(__name__) -def init_user_info(conf, cwd: str): +def init_user_info(conf, cwd: str, upstream_get: Optional[Callable] = None): kwargs = conf.get("kwargs", {}) + if upstream_get: + kwargs["upstream_get"] = upstream_get if isinstance(conf["class"], str): return importer(conf["class"])(**kwargs) @@ -337,7 +339,7 @@ def do_userinfo(self): _conf = self.conf.get("userinfo") if _conf: if self.session_manager: - self.userinfo = init_user_info(_conf, self.cwd) + self.userinfo = init_user_info(_conf, self.cwd, upstream_get=self.unit_get) self.session_manager.userinfo = self.userinfo else: logger.warning("Cannot init_user_info if no session manager was provided.") diff --git a/src/idpyoidc/server/oauth2/add_on/dpop.py b/src/idpyoidc/server/oauth2/add_on/dpop.py index 8deba5cd..5148cfe0 100644 --- a/src/idpyoidc/server/oauth2/add_on/dpop.py +++ b/src/idpyoidc/server/oauth2/add_on/dpop.py @@ -4,16 +4,16 @@ from typing import Optional from typing import Union -from cryptojwt import JWS from cryptojwt import as_unicode +from cryptojwt import JWS from cryptojwt.jwk.jwk import key_from_jwk_dict from cryptojwt.jws.jws import factory +from idpyoidc.message import Message from idpyoidc.message import SINGLE_OPTIONAL_STRING from idpyoidc.message import SINGLE_REQUIRED_INT from idpyoidc.message import SINGLE_REQUIRED_JSON from idpyoidc.message import SINGLE_REQUIRED_STRING -from idpyoidc.message import Message from idpyoidc.metadata import get_signing_algs from idpyoidc.server.client_authn import BearerHeader @@ -130,6 +130,7 @@ def userinfo_post_parse_request(request, client_id, context, auth_info, **kwargs """ Expect http_info attribute in kwargs. http_info should be a dictionary containing HTTP information. + This function is ment for DPoP-protected resources. :param request: :param client_id: @@ -179,10 +180,19 @@ def token_args(context, client_id, token_args: Optional[dict] = None): return token_args +def _add_to_context(endpoint, algs_supported): + _context = endpoint.upstream_get("context") + _context.provider_info["dpop_signing_alg_values_supported"] = algs_supported + _context.add_on["dpop"] = {"algs_supported": algs_supported} + _context.client_authn_methods["dpop"] = DPoPClientAuth + + def add_support(endpoint: dict, **kwargs): - # - _token_endp = endpoint["token"] - _token_endp.post_parse_request.append(token_post_parse_request) + # Pick the token endpoint + _endp = endpoint.get("token", None) + if _endp: + _endp.post_parse_request.append(token_post_parse_request) + _added_to_context = False _algs_supported = kwargs.get("dpop_signing_alg_values_supported") if not _algs_supported: @@ -190,17 +200,18 @@ def add_support(endpoint: dict, **kwargs): else: _algs_supported = [alg for alg in _algs_supported if alg in get_signing_algs()] - _token_endp.upstream_get("context").provider_info[ - "dpop_signing_alg_values_supported" - ] = _algs_supported + if _endp: + _add_to_context(_endp, _algs_supported) + _added_to_context = True - _context = _token_endp.upstream_get("context") - _context.add_on["dpop"] = {"algs_supported": _algs_supported} - _context.client_authn_methods["dpop"] = DPoPClientAuth + for _dpop_endpoint in kwargs.get("dpop_endpoints", ["userinfo"]): + _endpoint = endpoint.get(_dpop_endpoint, None) + if _endpoint: + if not _added_to_context: + _add_to_context(_endp, _algs_supported) + _added_to_context = True - _userinfo_endpoint = endpoint.get("userinfo") - if _userinfo_endpoint: - _userinfo_endpoint.post_parse_request.append(userinfo_post_parse_request) + _endpoint.post_parse_request.append(userinfo_post_parse_request) # DPoP-bound access token in the "Authorization" header and the DPoP proof in the "DPoP" header @@ -215,12 +226,12 @@ def is_usable(self, request=None, authorization_token=None, http_headers=None): return False def verify( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - get_client_id_from_token: Optional[Callable] = None, - **kwargs, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + get_client_id_from_token: Optional[Callable] = None, + **kwargs, ): # info contains token and client_id info = BearerHeader._verify( diff --git a/src/idpyoidc/server/oauth2/authorization.py b/src/idpyoidc/server/oauth2/authorization.py index 9f7b2723..917c96a0 100755 --- a/src/idpyoidc/server/oauth2/authorization.py +++ b/src/idpyoidc/server/oauth2/authorization.py @@ -866,7 +866,12 @@ def create_authn_response(self, request: Union[dict, Message], sid: str) -> dict list(set(scope + resource_scopes)), _sinfo["client_id"] ) - rtype = set(request["response_type"][:]) + if isinstance(request["response_type"], list): + rtype = set(request["response_type"][:]) + else: # assume it's a string + rtype = set() + rtype.add(request["response_type"]) + handled_response_type = [] fragment_enc = True diff --git a/src/idpyoidc/server/oauth2/introspection.py b/src/idpyoidc/server/oauth2/introspection.py index 096d666e..1a7f19c9 100644 --- a/src/idpyoidc/server/oauth2/introspection.py +++ b/src/idpyoidc/server/oauth2/introspection.py @@ -147,7 +147,7 @@ def process_request(self, request=None, release: Optional[list] = None, **kwargs ) if _claims_restriction: user_info = _context.claims_interface.get_user_claims( - _session_info["user_id"], _claims_restriction + _session_info["user_id"], _claims_restriction, client_id=_session_info["client_id"] ) if user_info: _resp.update(user_info) diff --git a/src/idpyoidc/server/oauth2/token_helper/access_token.py b/src/idpyoidc/server/oauth2/token_helper/access_token.py index afb74337..a5c94c2c 100755 --- a/src/idpyoidc/server/oauth2/token_helper/access_token.py +++ b/src/idpyoidc/server/oauth2/token_helper/access_token.py @@ -139,6 +139,8 @@ def process_request(self, req: Union[Message, dict], **kwargs): _response["expires_in"] = token.expires_at - utc_time_sans_frac() if issue_refresh and "refresh_token" in _supports_minting: + if token: + _based_on.used -= 1 try: refresh_token = self._mint_token( token_class="refresh_token", diff --git a/src/idpyoidc/server/oidc/session.py b/src/idpyoidc/server/oidc/session.py index 9e6071e6..03ddfd2a 100644 --- a/src/idpyoidc/server/oidc/session.py +++ b/src/idpyoidc/server/oidc/session.py @@ -94,7 +94,7 @@ def __init__(self, upstream_get, **kwargs): _csi = kwargs.get("check_session_iframe") if _csi and not _csi.startswith("http"): # unit since context does not exist at this point in time - kwargs["check_session_iframe"] = add_path(upstream_get("unit").issuer, _csi) + kwargs["check_session_iframe"] = add_path(upstream_get("unit").entity_id, _csi) Endpoint.__init__(self, upstream_get, **kwargs) self.iv = as_bytes(rndstr(24)) diff --git a/src/idpyoidc/server/oidc/token_helper/access_token.py b/src/idpyoidc/server/oidc/token_helper/access_token.py index a1e7e8e1..2594748e 100755 --- a/src/idpyoidc/server/oidc/token_helper/access_token.py +++ b/src/idpyoidc/server/oidc/token_helper/access_token.py @@ -121,6 +121,8 @@ def process_request(self, req: Union[Message, dict], **kwargs): _response["expires_in"] = token.expires_at - utc_time_sans_frac() if issue_refresh and "refresh_token" in _supports_minting: + if token: + _based_on.used -= 1 try: refresh_token = self._mint_token( token_class="refresh_token", @@ -139,6 +141,8 @@ def process_request(self, req: Union[Message, dict], **kwargs): if "openid" in _authn_req["scope"] and "id_token" in _supports_minting: if "id_token" in _based_on.usage_rules.get("supports_minting"): + if token: + _based_on.used -= 1 try: _idtoken = self._mint_token( token_class="id_token", diff --git a/src/idpyoidc/server/oidc/userinfo.py b/src/idpyoidc/server/oidc/userinfo.py index d497de0a..32d77506 100755 --- a/src/idpyoidc/server/oidc/userinfo.py +++ b/src/idpyoidc/server/oidc/userinfo.py @@ -157,7 +157,9 @@ def process_request(self, request=None, **kwargs): _session_info["branch_id"], scopes=token.scope, claims_release_point="userinfo" ) info = _cntxt.claims_interface.get_user_claims( - _session_info["user_id"], claims_restriction=_claims_restriction + _session_info["user_id"], + claims_restriction=_claims_restriction, + client_id=_session_info["client_id"] ) info["sub"] = _grant.sub if _grant.add_acr_value("userinfo"): diff --git a/src/idpyoidc/server/session/claims.py b/src/idpyoidc/server/session/claims.py index 63396188..272b4636 100755 --- a/src/idpyoidc/server/session/claims.py +++ b/src/idpyoidc/server/session/claims.py @@ -200,7 +200,7 @@ def get_claims_all_usage(self, session_id: str, scopes: str) -> dict: auth_req = {} return self.get_claims_all_usage_from_request(auth_req, scopes) - def get_user_claims(self, user_id: str, claims_restriction: dict) -> dict: + def get_user_claims(self, user_id: str, claims_restriction: dict, client_id: str) -> dict: """ :param user_id: User identifier @@ -212,7 +212,7 @@ def get_user_claims(self, user_id: str, claims_restriction: dict) -> dict: raise ImproperlyConfigured("userinfo MUST be defined in the configuration") if claims_restriction: # Get all possible claims - user_info = meth(user_id, client_id=None) + user_info = meth(user_id, client_id) # Filter out the claims that can be returned return { k: user_info.get(k) diff --git a/src/idpyoidc/server/session/database.py b/src/idpyoidc/server/session/database.py index c9cdaf55..abed53c0 100644 --- a/src/idpyoidc/server/session/database.py +++ b/src/idpyoidc/server/session/database.py @@ -50,7 +50,7 @@ def branch_key(*args): return DIVIDER.join(args) @staticmethod - def unpack_branch_key(key): + def unpack_branch_key(key: str) -> list: """Translate a key into an ordered list of names""" return key.split(DIVIDER) diff --git a/src/idpyoidc/server/session/grant.py b/src/idpyoidc/server/session/grant.py index 286827b2..d7ee7c39 100644 --- a/src/idpyoidc/server/session/grant.py +++ b/src/idpyoidc/server/session/grant.py @@ -243,8 +243,8 @@ def payload_arguments( ) if _claims_restriction and context.session_manager.node_type[0] == "user": - user_id, _, _ = context.session_manager.decrypt_branch_id(session_id) - user_info = context.claims_interface.get_user_claims(user_id, _claims_restriction) + user_id, client_id, _ = context.session_manager.decrypt_branch_id(session_id) + user_info = context.claims_interface.get_user_claims(user_id, _claims_restriction, client_id=client_id) payload.update(user_info) # Should I add the acr value @@ -382,6 +382,8 @@ def mint_token( session_id=session_id, usage_rules=usage_rules, **token_payload ) + if based_on: + based_on.used += 1 else: raise ValueError("Can not mint that kind of token") @@ -596,8 +598,8 @@ def payload_arguments( secondary_identifier=secondary_identifier, ) - user_id, _, _ = endpoint_context.session_manager.decrypt_session_id(session_id) - user_info = endpoint_context.claims_interface.get_user_claims(user_id, _claims_restriction) + user_id, client_id, _ = endpoint_context.session_manager.decrypt_session_id(session_id) + user_info = endpoint_context.claims_interface.get_user_claims(user_id, _claims_restriction, client_id) payload.update(user_info) # Should I add the acr value diff --git a/src/idpyoidc/server/session/manager.py b/src/idpyoidc/server/session/manager.py index 1c7c35a4..d064433d 100644 --- a/src/idpyoidc/server/session/manager.py +++ b/src/idpyoidc/server/session/manager.py @@ -198,6 +198,11 @@ def create_grant( resources = [] if "resource" in auth_req: resources = auth_req["resource"] + if "audience" in auth_req: + if isinstance(auth_req["audience"], str): + resources.append(auth_req["audience"]) + else: + resources.extend(auth_req["audience"]) return self.add_grant( path=self.make_path(user_id=user_id, client_id=client_id), diff --git a/src/idpyoidc/server/token/id_token.py b/src/idpyoidc/server/token/id_token.py index 7aa74752..7c66da46 100755 --- a/src/idpyoidc/server/token/id_token.py +++ b/src/idpyoidc/server/token/id_token.py @@ -169,6 +169,7 @@ def payload( user_info = _context.claims_interface.get_user_claims( user_id=session_information["user_id"], claims_restriction=_claims_restriction, + client_id=session_information["client_id"] ) if _claims_restriction and "acr" in _claims_restriction and "acr" in _args: if claims_match(_args["acr"], _claims_restriction["acr"]) is False: diff --git a/src/idpyoidc/server/user_info/__init__.py b/src/idpyoidc/server/user_info/__init__.py index f8206017..a3a2683b 100755 --- a/src/idpyoidc/server/user_info/__init__.py +++ b/src/idpyoidc/server/user_info/__init__.py @@ -30,7 +30,7 @@ def dict_subset(a, b): class UserInfo(object): """Read only interface to a user info store""" - def __init__(self, db=None, db_file=""): + def __init__(self, db=None, db_file="", **kwargs): if db is not None: self.db = db elif db_file: diff --git a/src/idpyoidc/storage/listfile.py b/src/idpyoidc/storage/listfile.py index 1fdfd009..77520de3 100644 --- a/src/idpyoidc/storage/listfile.py +++ b/src/idpyoidc/storage/listfile.py @@ -134,3 +134,6 @@ def _read_info(self, fname): _msg = f"No such file: '{fname}'" logger.error(_msg) return None + + def __call__(self): + return self._read_info(self.file_name) diff --git a/tests/private/token_jwks.json b/tests/private/token_jwks.json index 3790bbf7..d3e0f070 100644 --- a/tests/private/token_jwks.json +++ b/tests/private/token_jwks.json @@ -1 +1 @@ -{"keys": [{"kty": "oct", "use": "enc", "kid": "code", "k": "vSHDkLBHhDStkR0NWu8519rmV5zmnm5_"}, {"kty": "oct", "use": "enc", "kid": "refresh", "k": "LxZS5XrgcQJSv6ldNDNYd3xNE7ldifF2"}]} \ No newline at end of file +{"keys": [{"kty": "oct", "use": "enc", "kid": "code", "k": "vSHDkLBHhDStkR0NWu8519rmV5zmnm5_"}, {"kty": "oct", "use": "enc", "kid": "refresh", "k": "vrjoMrmgK8SmJJPc318zTxqG_tvBqF5l"}]} \ No newline at end of file diff --git a/tests/static/jwks.json b/tests/static/jwks.json index 8322d976..161a407b 100644 --- a/tests/static/jwks.json +++ b/tests/static/jwks.json @@ -1 +1 @@ -{"keys": [{"kty": "RSA", "use": "sig", "kid": "YnNESFhyQjloMnYzV2VqRGR2a3VCblFLX2h4VGl3TDVlY3FUNkViUE90bw", "n": "2iMaDALTQolz4UaT--GhjriLMyNbrDGlIXxSmgRh17Cm3cuHiyPOIQv1pjZVg4ATU1aafxmFyTfrmtf56tPuJ8yqcNNZC8XadYPAw7PTW9g8GJgLtC8GURJ9GQZD6FYIE6YCou8fYo6yd4b99y2y_vsl06cm9xQnstfp6eyMkcgQyrmdmlbyeuXwvcxsxtGX61MTJtCp4VELmDctJiYP_bD7HNRPV7uqXDMNmWSY0TYL-tg0As4y8-w3wSwmtcfWhnQEraFT0-m4hBpEWHlouuFNXRQIrXbamKxeh6kJNO0wJN8fZ4Ovygf8sE4kEwBPfWO59wxDF7camTpDUqg29Q", "e": "AQAB"}, {"kty": "EC", "use": "sig", "kid": "aWhtalRSTDZmNmRTd1ZDNWZmY3ZGMTNqM1dnLVA2RjQyMi1CNGdOSUNKVQ", "crv": "P-256", "x": "Ww5XVT3CxYN88BpJDZGodRiar0qr8UvPFaRoqzyD1Io", "y": "w23EDFAvwe03NjL5NKtUXwxuVMFmEn3ecJOPbljiDkg"}]} \ No newline at end of file +{"keys": [{"kty": "RSA", "use": "sig", "kid": "YnNESFhyQjloMnYzV2VqRGR2a3VCblFLX2h4VGl3TDVlY3FUNkViUE90bw", "e": "AQAB", "n": "2iMaDALTQolz4UaT--GhjriLMyNbrDGlIXxSmgRh17Cm3cuHiyPOIQv1pjZVg4ATU1aafxmFyTfrmtf56tPuJ8yqcNNZC8XadYPAw7PTW9g8GJgLtC8GURJ9GQZD6FYIE6YCou8fYo6yd4b99y2y_vsl06cm9xQnstfp6eyMkcgQyrmdmlbyeuXwvcxsxtGX61MTJtCp4VELmDctJiYP_bD7HNRPV7uqXDMNmWSY0TYL-tg0As4y8-w3wSwmtcfWhnQEraFT0-m4hBpEWHlouuFNXRQIrXbamKxeh6kJNO0wJN8fZ4Ovygf8sE4kEwBPfWO59wxDF7camTpDUqg29Q"}, {"kty": "EC", "use": "sig", "kid": "aWhtalRSTDZmNmRTd1ZDNWZmY3ZGMTNqM1dnLVA2RjQyMi1CNGdOSUNKVQ", "crv": "P-256", "x": "Ww5XVT3CxYN88BpJDZGodRiar0qr8UvPFaRoqzyD1Io", "y": "w23EDFAvwe03NjL5NKtUXwxuVMFmEn3ecJOPbljiDkg"}]} \ No newline at end of file diff --git a/tests/test_client_29_pushed_auth.py b/tests/test_client_29_pushed_auth.py index 02734cf7..696db493 100644 --- a/tests/test_client_29_pushed_auth.py +++ b/tests/test_client_29_pushed_auth.py @@ -36,14 +36,9 @@ def create_client(self): "add_ons": { "pushed_authorization": { "function": "idpyoidc.client.oauth2.add_on.par.add_support", - "kwargs": { - "body_format": "jws", - "signing_algorithm": "RS256", - "http_client": None, - "merge_rule": "lax", - }, + "kwargs": {} } - }, + } } self.entity = Client(keyjar=CLI_KEY, config=config, services=DEFAULT_OAUTH2_SERVICES) diff --git a/tests/test_server_06_grant.py b/tests/test_server_06_grant.py index 2f163811..9a69d043 100644 --- a/tests/test_server_06_grant.py +++ b/tests/test_server_06_grant.py @@ -170,7 +170,7 @@ def test_grant(self): token_handler=TOKEN_HANDLER["access_token"], based_on=code, ) - + code.used = 0 refresh_token = grant.mint_token( session_id, context=self.context, @@ -231,7 +231,7 @@ def test_grant_revoked_based_on(self): token_handler=TOKEN_HANDLER["access_token"], based_on=code, ) - + code.used = 0 refresh_token = grant.mint_token( session_id, context=self.context, @@ -271,9 +271,11 @@ def test_revoke(self): grant.revoke_token(based_on=code.value) - assert code.is_active() is True + assert code.is_active() is False assert access_token.is_active() is False + # Reset code usage + code.used = 0 access_token_2 = grant.mint_token( session_id, context=self.context, @@ -535,6 +537,8 @@ def test_grant_remove_based_on_code(self): based_on=code, ) + # reset code usage + code.used = 0 refresh_token = grant.mint_token( session_id, context=self.context, @@ -567,6 +571,8 @@ def test_grant_remove_one_by_one(self): based_on=code, ) + # reset code usage + code.used = 0 refresh_token = grant.mint_token( session_id, context=self.context, diff --git a/tests/test_server_08_id_token.py b/tests/test_server_08_id_token.py index 001d7ff7..50df1d1e 100644 --- a/tests/test_server_08_id_token.py +++ b/tests/test_server_08_id_token.py @@ -278,6 +278,9 @@ def test_id_token_payload_with_access_token(self): code = self._mint_code(grant, session_id) access_token = self._mint_access_token(grant, session_id, code) + # reset code usage + code.used = 0 + id_token = self._mint_id_token( grant, session_id, token_ref=code, access_token=access_token.value ) @@ -308,6 +311,9 @@ def test_id_token_payload_with_code_and_access_token(self): code = self._mint_code(grant, session_id) access_token = self._mint_access_token(grant, session_id, code) + # reset code usage + code.used = 0 + id_token = self._mint_id_token( grant, session_id, @@ -369,6 +375,8 @@ def test_id_token_payload_many_0(self): code = self._mint_code(grant, session_id) access_token = self._mint_access_token(grant, session_id, code) + # reset code usage + code.used = 0 id_token = self._mint_id_token( grant, session_id, @@ -622,6 +630,8 @@ def test_id_token_info(self): code = self._mint_code(grant, session_id) access_token = self._mint_access_token(grant, session_id, code) + # reset code usage + code.used = 0 id_token = self._mint_id_token( grant, session_id, token_ref=code, access_token=access_token.value ) @@ -654,6 +664,8 @@ def test_id_token_acr_claim(self): code = self._mint_code(grant, session_id) access_token = self._mint_access_token(grant, session_id, code) + # reset code usage + code.used = 0 id_token = self._mint_id_token( grant, session_id, token_ref=code, access_token=access_token.value ) @@ -671,6 +683,8 @@ def test_id_token_acr_none(self): code = self._mint_code(grant, session_id) access_token = self._mint_access_token(grant, session_id, code) + # reset code usage + code.used = 0 id_token = self._mint_id_token( grant, session_id, token_ref=code, access_token=access_token.value ) diff --git a/tests/test_server_10_session_manager.py b/tests/test_server_10_session_manager.py index 8812a68f..3ca10166 100644 --- a/tests/test_server_10_session_manager.py +++ b/tests/test_server_10_session_manager.py @@ -259,6 +259,7 @@ def test_code_usage(self): assert access_token.is_active() assert len(grant.issued_token) == 2 + code.used = 0 refresh_token = self._mint_token("refresh_token", grant, session_id, code) assert isinstance(refresh_token, RefreshToken) assert refresh_token.is_active() @@ -411,6 +412,8 @@ def test_token_usage_default(self): assert token.usage_rules == {} + # reset code usage + code.used = 0 refresh_token = self._mint_token("refresh_token", grant, _session_id, code) assert refresh_token.usage_rules == {"supports_minting": ["access_token", "refresh_token"]} @@ -443,6 +446,8 @@ def test_token_usage_grant(self): token = self._mint_token("access_token", grant, _session_id, code) assert token.usage_rules == {"expires_in": 3600} + # reset code usage + code.used = 0 refresh_token = self._mint_token("refresh_token", grant, _session_id, code) assert refresh_token.usage_rules == { "supports_minting": ["access_token", "refresh_token", "id_token"] @@ -546,6 +551,8 @@ def test_token_usage_client_config(self): token = self._mint_token("access_token", grant, _session_id, code) assert token.usage_rules == {"expires_in": 600} + # reset code usage + code.used = 0 refresh_token = self._mint_token("refresh_token", grant, _session_id, code) assert refresh_token.usage_rules == {"supports_minting": ["access_token"]} @@ -694,9 +701,15 @@ def test_find_latest_idtoken(self): grant = self.session_manager[_session_id] code = self._mint_token("authorization_code", grant, _session_id) + # reset code usage + code.used = 0 id_token_1 = self._mint_token("id_token", grant, _session_id) + # reset code usage + code.used = 0 refresh_token = self._mint_token("refresh_token", grant, _session_id, code) + # reset code usage + code.used = 0 id_token_2 = self._mint_token("id_token", grant, _session_id, code) _jwt1 = factory(id_token_1.value) diff --git a/tests/test_server_12_session_life.py b/tests/test_server_12_session_life.py index d36ecdb6..50de8c8b 100644 --- a/tests/test_server_12_session_life.py +++ b/tests/test_server_12_session_life.py @@ -154,6 +154,8 @@ def test_code_flow(self): assert tok.supports_minting("refresh_token") + # reset code usage + code.used = 0 refresh_token = grant.mint_token( session_id=session_id, context=self.context, @@ -388,6 +390,8 @@ def test_code_flow(self): # this test is include in the mint_token methods # assert tok.supports_minting("refresh_token") + # reset code usage + code.used = 0 refresh_token = grant.mint_token( session_id=session_id, context=self.context, diff --git a/tests/test_server_20b_claims.py b/tests/test_server_20b_claims.py index f84572ab..8389fbcd 100644 --- a/tests/test_server_20b_claims.py +++ b/tests/test_server_20b_claims.py @@ -19,9 +19,12 @@ def full_path(local_file): {"type": "EC", "crv": "P-256", "use": ["sig"]}, ] +CLIENT_ID = "client_1" + + AREQ = AuthorizationRequest( response_type="code", - client_id="client_1", + client_id=CLIENT_ID, redirect_uri="http://example.com/authz", scope=["openid"], state="state000", @@ -30,7 +33,7 @@ def full_path(local_file): AREQ_2 = AuthorizationRequest( response_type="code", - client_id="client_1", + client_id=CLIENT_ID, redirect_uri="http://example.com/authz", scope=["openid", "address", "email"], state="state000", @@ -40,7 +43,7 @@ def full_path(local_file): AREQ_3 = AuthorizationRequest( response_type="code", - client_id="client_1", + client_id=CLIENT_ID, redirect_uri="http://example.com/authz", scope=["openid", "address", "email"], state="state000", @@ -116,7 +119,7 @@ class TestEndpoint(object): @pytest.fixture(autouse=True) def create_idtoken(self): server = Server(conf) - server.context.cdb["client_1"] = { + server.context.cdb[CLIENT_ID] = { "client_secret": "hemligtochintekort", "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", @@ -127,7 +130,7 @@ def create_idtoken(self): }, "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], } - server.keyjar.add_symmetric("client_1", "hemligtochintekort", ["sig", "enc"]) + server.keyjar.add_symmetric(CLIENT_ID, "hemligtochintekort", ["sig", "enc"]) self.claims_interface = server.context.claims_interface self.context = server.context self.session_manager = self.context.session_manager @@ -161,7 +164,7 @@ def test_get_claims_userinfo_3(self): "enable_claims_per_client": True, "add_claims_by_scope": True, } - self.context.cdb["client_1"]["add_claims"]["always"]["userinfo"] = [ + self.context.cdb[CLIENT_ID]["add_claims"]["always"]["userinfo"] = [ "name", "email", ] @@ -182,7 +185,7 @@ def test_get_claims_introspection_3(self): "enable_claims_per_client": True, "add_claims_by_scope": True, } - self.context.cdb["client_1"]["add_claims"]["always"]["introspection"] = [ + self.context.cdb[CLIENT_ID]["add_claims"]["always"]["introspection"] = [ "name", "email", ] @@ -227,7 +230,7 @@ def test_get_claims_all_usage_2(self): self.server.get_endpoint("userinfo").kwargs = { "enable_claims_per_client": True, } - self.context.cdb["client_1"]["add_claims"]["always"]["userinfo"] = [ + self.context.cdb[CLIENT_ID]["add_claims"]["always"]["userinfo"] = [ "name", "email", ] @@ -259,7 +262,7 @@ def test_get_user_claims(self): self.server.get_endpoint("userinfo").kwargs = { "enable_claims_per_client": True, } - self.context.cdb["client_1"]["add_claims"]["always"]["userinfo"] = [ + self.context.cdb[CLIENT_ID]["add_claims"]["always"]["userinfo"] = [ "name", "email", ] @@ -273,14 +276,14 @@ def test_get_user_claims(self): session_id, ["openid", "address"] ) - _claims = self.claims_interface.get_user_claims(USER_ID, claims_restriction["userinfo"]) + _claims = self.claims_interface.get_user_claims(USER_ID, claims_restriction["userinfo"], CLIENT_ID) assert _claims == {"name": "Diana Krall", "email": "diana@example.org"} - _claims = self.claims_interface.get_user_claims(USER_ID, claims_restriction["id_token"]) + _claims = self.claims_interface.get_user_claims(USER_ID, claims_restriction["id_token"], CLIENT_ID) assert _claims == {"email_verified": False, "email": "diana@example.org"} _claims = self.claims_interface.get_user_claims( - USER_ID, claims_restriction["introspection"] + USER_ID, claims_restriction["introspection"], CLIENT_ID ) # Note that sub is not a user claim assert _claims == { @@ -292,5 +295,5 @@ def test_get_user_claims(self): } } - _claims = self.claims_interface.get_user_claims(USER_ID, claims_restriction["access_token"]) + _claims = self.claims_interface.get_user_claims(USER_ID, claims_restriction["access_token"], CLIENT_ID) assert _claims == {} diff --git a/tests/test_server_20f_userinfo.py b/tests/test_server_20f_userinfo.py index 544a059d..03ecd2a6 100644 --- a/tests/test_server_20f_userinfo.py +++ b/tests/test_server_20f_userinfo.py @@ -41,10 +41,11 @@ "email_verified": {"essential": True}, } } +CLIENT_ID = "client1" OIDR = OpenIDRequest( response_type="code", - client_id="client1", + client_id=CLIENT_ID, redirect_uri="http://example.com/authz", scope=["openid"], state="state000", @@ -194,7 +195,7 @@ def create_endpoint_context(self): server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) self.endpoint_context = server.context # Just has to be there - self.endpoint_context.cdb["client1"] = { + self.endpoint_context.cdb[CLIENT_ID] = { "add_claims": { "always": {}, "by_scope": {}, @@ -228,7 +229,7 @@ def test_collect_user_info(self): session_id=session_id, scopes=OIDR["scope"], claims_release_point="userinfo" ) - res = self.claims_interface.get_user_claims("diana", _userinfo_restriction) + res = self.claims_interface.get_user_claims("diana", _userinfo_restriction, CLIENT_ID) assert res == { "eduperson_scoped_affiliation": ["staff@example.org"], @@ -241,7 +242,7 @@ def test_collect_user_info(self): session_id=session_id, scopes=OIDR["scope"], claims_release_point="id_token" ) - res = self.claims_interface.get_user_claims("diana", _id_token_restriction) + res = self.claims_interface.get_user_claims("diana", _id_token_restriction, CLIENT_ID) assert res == { "email": "diana@example.org", @@ -252,7 +253,7 @@ def test_collect_user_info(self): session_id=session_id, scopes=OIDR["scope"], claims_release_point="introspection" ) - res = self.claims_interface.get_user_claims("diana", _restriction) + res = self.claims_interface.get_user_claims("diana", _restriction, CLIENT_ID) assert res == {} @@ -268,7 +269,7 @@ def test_collect_user_info_2(self): session_id=session_id, scopes=_req["scope"], claims_release_point="userinfo" ) - res = self.claims_interface.get_user_claims("diana", _userinfo_restriction) + res = self.claims_interface.get_user_claims("diana", _userinfo_restriction, CLIENT_ID) assert res == { "address": { @@ -299,7 +300,7 @@ def test_collect_user_info_scope_not_supported_no_base_claims(self): session_id=session_id, scopes=_req["scope"], claims_release_point="userinfo" ) - res = self.claims_interface.get_user_claims("diana", _userinfo_restriction) + res = self.claims_interface.get_user_claims("diana", _userinfo_restriction, CLIENT_ID) assert res == {} @@ -324,7 +325,7 @@ def test_collect_user_info_enable_claims_per_client(self): session_id=session_id, scopes=_req["scope"], claims_release_point="userinfo" ) - res = self.claims_interface.get_user_claims("diana", _userinfo_restriction) + res = self.claims_interface.get_user_claims("diana", _userinfo_restriction, CLIENT_ID) assert res == {"phone_number": "+46907865000"} @@ -423,7 +424,7 @@ def conf(self): def create_endpoint_context(self, conf): self.server = Server(conf) self.endpoint_context = self.server.context - self.endpoint_context.cdb["client1"] = { + self.endpoint_context.cdb[CLIENT_ID] = { "allowed_scopes": [ "openid", "profile", @@ -470,7 +471,7 @@ def test_collect_user_info_custom_scope(self): session_id=session_id, scopes=_req["scope"], claims_release_point="userinfo" ) - res = self.claims_interface.get_user_claims("diana", _restriction) + res = self.claims_interface.get_user_claims("diana", _restriction, CLIENT_ID) assert res == { "eduperson_scoped_affiliation": ["staff@example.org"], @@ -487,7 +488,7 @@ def test_collect_user_info_scope_mapping_per_client(self, conf): endpoint_context = server.context self.session_manager = endpoint_context.session_manager claims_interface = endpoint_context.claims_interface - endpoint_context.cdb["client1"] = { + endpoint_context.cdb[CLIENT_ID] = { "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] } @@ -501,7 +502,7 @@ def test_collect_user_info_scope_mapping_per_client(self, conf): session_id=session_id, scopes=_req["scope"], claims_release_point="userinfo" ) - res = claims_interface.get_user_claims("diana", _restriction) + res = claims_interface.get_user_claims("diana", _restriction, CLIENT_ID) assert res == { "eduperson_scoped_affiliation": ["staff@example.org"], "email": "diana@example.org", diff --git a/tox.ini b/tox.ini index 74bf2d21..79cae5ea 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py{37,38,39,310},docs,quality +envlist = py{38,39,310,311},docs,quality [testenv] passenv = CI TRAVIS TRAVIS_*