diff --git a/demo/README.md b/demo/README.md new file mode 100644 index 00000000..f51ac468 --- /dev/null +++ b/demo/README.md @@ -0,0 +1,315 @@ +# Usage stories + +This is a set of usage stories. +Here to display what you can do with IdpyOIDC using OAuth2 or OIDC. + +Every story follows the same pattern it starts by initiating one client/RP and +one AS/OP. +After that a sequence of requests/responses are performed. Each one follows this +pattern: + +- The client/RP constructs the request and possible client authentication information +- The request and client authentication information is printed +- The AS/OP does client authentication based on the authentication information received +- The AS/OP parses and verifies the client request +- The AS/OP constructs the server response +- The client/RP parses and verifies the server response +- The parsed and verified response is printed + +This pattern is repeated for each request/response in the sequence. + +To understand the descriptions below you have to remember that an AS/OP provides +**endpoints** while a client/RP accesses **services**. An endpoint can +support more than one service. A service can only reside at one endpoint. + +## Basic OAuth2 Stories + +These are based on the two basic OAuth2 RFCs; +* [The OAuth 2.0 Authorization Framework](https://www.rfc-editor.org/rfc/rfc6749) +* [The OAuth 2.0 Authorization Framework: Bearer Token Usage](https://www.rfc-editor.org/rfc/rfc6750) + +### Client Credentials Grant (oauth2_cc.py) + +Displays the usage of the +[client credentials grant](https://www.rfc-editor.org/rfc/rfc6749#section-4.4) . + +The client can request an access token using only its client +credentials (or other supported means of authentication). + +The request/response sequence only contains the client credential exchange. + +The client is statically registered with the AS. + +#### configuration + +The server configuration expresses these points: + +- The server needs only one endpoint, the token endpoint. +- The token released form the token endpoint is a signed JSON Web token (JWT) +- The server deals only with access tokens. The default lifetime of a token is 3600 +seconds. +- The server can deal with 2 client authentication methods at the token endpoint: + client_secret_basic and client_secret_post +- In this example the audience for the token (the resource server) is statically set. + + + "endpoint": { + "token": { + "path": "token", + "class": Token, + "kwargs": { + "client_authn_method": ["client_secret_basic", "client_secret_post"], + }, + }, + }, + "token_handler_args": { + "jwks_defs": {"key_defs": KEYDEFS}, + "token": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "aud": ["https://example.org/appl"], + } + } + } + +The client configuration + +- lists only one service - client credentials +- specifies client ID and client secret since the client is statically + registered with the server. + + + "client_id": "client_1", + "client_secret": "another password", + "base_url": "https://example.com", + "services": { + "client_credentials": { + "class": "idpyoidc.client.oauth2.client_credentials.CCAccessTokenRequest" + } + } + +**services** is a dictionary. The keys in that dictionary is for your usage only. +Internally the software uses identifiers that are statically assigned to every Service class. +This means that you can not have two instances of the same class in a _services_ +definition. + +### Resource Owners Password Credentials (oauth2_ropc.py) + +**NOTE** Resource Owners Password Credentials is not part of OAuth2.1 + +Displays the usage of the +[resource owners username and password](https://www.rfc-editor.org/rfc/rfc6749#section-4.3) +for doing authorization. + +The resource owner password credentials grant type is suitable in +cases where the resource owner has a trust relationship with the +client, such as the device operating system or a highly privileged application. + +#### Configuration + +The big difference between Client Credentials and Resource Owners Passsword credentials +is that the server also most support user authentication. Therefor this +part is added to the server configuration: + + "authentication": { + "user": { + "acr": "urn:oasis:names:tc:SAML:2.0:ac:classes:InternetProtocolPassword", + "class": "idpyoidc.server.user_authn.user.UserPass", + "kwargs": { + "db_conf": { + "class": "idpyoidc.server.util.JSONDictDB", + "kwargs": {"filename": full_path("passwd.json")} + } + } + } + } + +This allows for a very simple username/password check against a static file. + +On the client side the change is that the service configuration now looks +like this: + + services = { + "ropc": { + "class": "idpyoidc.client.oauth2.resource_owner_password_credentials.ROPCAccessTokenRequest" + } + } + + +### Authorization Code Grant (oauth2_code.py) + +The +[authorization code grant](https://www.rfc-editor.org/rfc/rfc6749#section-4.1) +is used to obtain both access tokens and possibly refresh tokens and is optimized +for confidential clients. + +Since this is a redirection-based flow, the client must be capable of +interacting with the resource owner's user-agent (typically a web +browser) and capable of receiving incoming requests (via redirection) +from the authorization server. + +In the demo implementation the response is transmitted directly from the server +to the client no user agent is involved. + +In this story the flow contains three request/responses + +- Fetching server metadata +- Authorization +- Access token + +#### Configuration + +Let's take it part by part. +First the endpoints, straight forward support for the sequence of exchanges we +want to exercise. + + "endpoint": { + "metadata": { + "path": ".well-known/oauth-authorization-server", + "class": "idpyoidc.server.oauth2.server_metadata.ServerMetadata", + "kwargs": {}, + }, + "authorization": { + "path": "authorization", + "class": "idpyoidc.server.oauth2.authorization.Authorization", + "kwargs": {}, + }, + "token": { + "path": "token", + "class": "idpyoidc.server.oauth2.token.Token", + "kwargs": {}, + } + }, + +Next comes the type of tokens the grant manager can issue. +In this case authorization codes and access tokens. + + "token_handler_args": { + "key_conf": {"key_defs": KEYDEFS}, + "code": { + "lifetime": 600, + "kwargs": { + "crypt_conf": CRYPT_CONFIG + } + }, + "token": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "aud": ["https://example.org/appl"], + }, + } + }, + +The software can produce 3 types of tokens. + +- An encrypted value, unreadable by anyone but the server +- A signed JSON Web Token following the pattern described in +[JSON Web Token (JWT) Profile for OAuth 2.0 Access Tokens](https://datatracker.ietf.org/doc/rfc9068/) +- An IDToken which only is used to represent ID Tokens. + +In this example only the two first types are used since no ID Tokens are produced. + +The next part is about the grant manager. + + "authz": { + "class": AuthzHandling, + "kwargs": { + "grant_config": { + "usage_rules": { + "authorization_code": { + "supports_minting": ["access_token"], + "max_usage": 1, + }, + "access_token": { + "expires_in": 600, + } + } + } + }, + }, + +What this says is that an authorization code can only be used once and +only to mint an access token. The lifetime for an authorization code is +the default which is 300 seconds (5 minutes). +The access token can not be used to mint anything. Note that in the +token handler arguments the lifetime is set to 3600 seconds for a token +while in the authz part and access tokens lifetime is defined to be +600 seconds. It's the later that is used since it is more specific. + + "authentication": { + "anon": { + "acr": INTERNETPROTOCOLPASSWORD, + "class": "idpyoidc.server.user_authn.user.NoAuthn", + "kwargs": {"user": "diana"}, + } + }, + +It's convenient to use this no-authentication method in this context since we +can't deal with user interaction. +What happens is that authentication is assumed to have happened and that +it resulted in that **diana** was authenticated. + +## OAuth2 Extension Stories + +The stories display support for a set of OAuth2 extension RFCs + +### PKCE (oauth2_add_on_pkce.py) + +[Proof Key for Code Exchange by OAuth Public Clients](https://datatracker.ietf.org/doc/rfc7636/). +A technique to mitigate against the authorization code interception attack through +the use of Proof Key for Code Exchange (PKCE). + +#### Configuration + +On the server side only one thing is added: + + "add_ons": { + "pkce": { + "function": "idpyoidc.server.oauth2.add_on.pkce.add_support", + "kwargs": {}, + }, + } + +Similar on the client side: + + "add_ons": { + "pkce": { + "function": "idpyoidc.client.oauth2.add_on.pkce.add_support", + "kwargs": { + "code_challenge_length": 64, + "code_challenge_method": "S256" + }, + }, + } + +### JAR (oauth2_add_on_jar.py) + +[JWT-Secured Authorization Request (JAR)](https://datatracker.ietf.org/doc/rfc9101/) +This document introduces the ability to send request parameters in a +JSON Web Token (JWT) instead, which allows the request to be signed +with JSON Web Signature (JWS) and encrypted with JSON Web Encryption +(JWE) so that the integrity, source authentication, and +confidentiality properties of the authorization request are attained. +The request can be sent by value or by reference. + +#### Configuration + +On the server side nothing has to be done. The support for the +request and request_uri parameters are built in to begin with. +The reason for this is that OIDC had this from the beginning. + +On the client side this had to be added: + + "add_ons": { + "jar": { + "function": "idpyoidc.client.oauth2.add_on.jar.add_support", + "kwargs": { + 'request_type': 'request_parameter', + 'request_object_signing_alg': "ES256", + 'expires_in': 600 + }, + }, + } + diff --git a/demo/common.py b/demo/common.py new file mode 100644 index 00000000..cac17922 --- /dev/null +++ b/demo/common.py @@ -0,0 +1,27 @@ +import os + +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +CRYPT_CONFIG = { + "kwargs": { + "keys": { + "key_defs": [ + {"type": "OCT", "use": ["enc"], "kid": "password"}, + {"type": "OCT", "use": ["enc"], "kid": "salt"}, + ] + }, + "iterations": 1, + } +} + +SESSION_PARAMS = {"encrypter": CRYPT_CONFIG} + +KEYDEFS = [ + {"type": "RSA", "key": "", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] diff --git a/demo/flow.py b/demo/flow.py new file mode 100755 index 00000000..149e810e --- /dev/null +++ b/demo/flow.py @@ -0,0 +1,253 @@ +import json + +import responses + +from idpyoidc.message import Message +from idpyoidc.message.oauth2 import is_error_message +from idpyoidc.util import rndstr + + +class Flow(object): + + def __init__(self, client, server): + self.client = client + self.server = server + + def print(self, proc, msg): + print(30 * '=' + f' {proc} ' + 30 * '=') + print("-- REQUEST --") + print(f" METHOD: {msg['method']}") + if 'url' in msg: + print(f" URL: {msg['url']}") + if msg['headers']: + print(' HEADERS') + for line in json.dumps(msg['headers'], sort_keys=True, indent=4).split('\n'): + print(' ' + line) + if not msg['request']: + print('{}') + else: + print(json.dumps(msg['request'].to_dict(), sort_keys=True, indent=4)) + print('-- RESPONSE --') + _resp = msg['response'] + if isinstance(_resp, Message): + print(json.dumps(_resp.to_dict(), sort_keys=True, indent=4)) + else: + print(json.dumps(_resp, sort_keys=True, indent=4)) + print() + + def do_query(self, service_type, endpoint_type, request_args=None, msg=None): + if request_args is None: + request_args = {} + if msg is None: + msg = {} + + _client_service = self.client.get_service(service_type) + + _additions = msg.get('request_additions') + if _additions: + _info = _additions.get(service_type) + if _info: + request_args.update(_info) + + _args = msg.get('get_request_parameters', {}) + kwargs = _args.get(service_type, {}) + + if service_type in ["userinfo", 'refresh_token']: + kwargs['state'] = msg['authorization']['request']['state'] + + _mock_resp = msg.get('mock_response') + if _mock_resp: + _func = _mock_resp.get(service_type) + _info = _func(_client_service) + with responses.RequestsMock() as rsps: + rsps.add( + "GET", + _info["uri"], + json=_info["data"], + content_type="application/json", + status=200, + ) + req_info = _client_service.get_request_parameters(request_args=request_args, + **kwargs) + else: + req_info = _client_service.get_request_parameters(request_args=request_args, **kwargs) + + areq = req_info.get("request") + headers = req_info.get("headers") + + _server_endpoint = self.server.get_endpoint(endpoint_type) + if headers: + argv = {"http_info": {"headers": headers}} + argv['http_info']['url'] = req_info['url'] + argv['http_info']['method'] = req_info['method'] + else: + argv = {} + + if areq: + if _server_endpoint.request_format == 'json': + _pr_req = _server_endpoint.parse_request(areq.to_json(), **argv) + else: + _pr_req = _server_endpoint.parse_request(areq.to_urlencoded(), **argv) + else: + if areq is None: + _pr_req = _server_endpoint.parse_request(areq) + else: + _pr_req = _server_endpoint.parse_request(areq, **argv) + + if is_error_message(_pr_req): + result = {'request': _pr_req, 'headers': headers, + 'method': req_info['method'], 'url': req_info['url']} + self.print(f"{service_type} - ERROR", result) + return result + + args = msg.get('process_request_args', {}) + _resp = _server_endpoint.process_request(_pr_req, **args.get(endpoint_type, {})) + if is_error_message(_resp): + result = {'request': areq, 'response': _resp, 'headers': headers, + 'method': req_info['method'], 'url': req_info['url']} + self.print(f"{service_type} - ERROR", result) + return result + + _response = _server_endpoint.do_response(**_resp) + + # resp = _client_service.parse_response(_response["response"]) + _state = '' + if service_type == 'authorization': + _state = areq.get('state', _pr_req.get('state')) + else: + _authz = msg.get('authorization') + if _authz: + _state = _authz['request']['state'] + + if 'response_args' in _resp: + if _client_service.service_name in ['server_metadata', 'provider_info']: + if 'server_jwks_uri' in msg and 'server_jwks' in msg: + with responses.RequestsMock() as rsps: + rsps.add( + "GET", + msg["server_jwks_uri"], + json=msg["server_jwks"], + content_type="application/json", + status=200, + ) + + _client_service.update_service_context(_resp["response_args"], key=_state) + else: + _client_service.update_service_context(_resp["response_args"], key=_state) + else: + _client_service.update_service_context(_resp["response_args"], key=_state) + + _response = _resp.get('response_args', _resp.get('response', _resp.get('response_msg'))) + result = {'request': areq, 'response': _response, 'headers': headers, + 'method': req_info['method'], 'url': req_info['url']} + self.print(service_type, result) + return result + + def server_metadata_request(self, msg): + return {} + + def provider_info_request(self, msg): + return {} + + def authorization_request(self, msg): + # ***** Authorization Request ********** + _nonce = rndstr(24) + _context = self.client.get_service_context() + # Need a new state for a new authorization request + _state = _context.cstate.create_state(iss=_context.get("issuer")) + _context.cstate.bind_key(_nonce, _state) + _response_type = msg.get('response_type', ['code']) + + req_args = { + "response_type": _response_type, + "nonce": _nonce, + "state": _state + } + + scope = msg.get('scope') + if scope: + if 'openid' not in scope: + scope.append('openid') + _scope = scope + else: + _scope = ["openid"] + + req_args["scope"] = _scope + + return req_args + + def accesstoken_request(self, msg): + # ***** Token Request ********** + _context = self.client.get_service_context() + + auth_resp = msg['authorization']['response'] + req_args = { + "code": auth_resp["code"], + "state": auth_resp["state"], + "redirect_uri": msg['authorization']['request']["redirect_uri"], + "grant_type": "authorization_code", + "client_id": self.client.get_client_id(), + "client_secret": _context.get_usage("client_secret"), + } + + return req_args + + def introspection_request(self, msg): + _context = self.client.get_context() + auth_resp = msg['authorization']['response'] + _state = _context.cstate.get(auth_resp["state"]) + + return { + "token": _state['access_token'], + "token_type_hint": 'access_token' + } + + def token_revocation_request(self, msg): + _context = self.client.get_context() + auth_resp = msg['authorization']['response'] + _state = _context.cstate.get(auth_resp["state"]) + + return { + "token": _state['access_token'], + "token_type_hint": 'access_token' + } + + def token_exchange_request(self, msg): + _token = msg['accesstoken']['response']['access_token'] + _state = msg['authorization']['request']['state'] + + return { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "requested_token_type": 'urn:ietf:params:oauth:token-type:access_token', + "subject_token": _token, + "subject_token_type": 'urn:ietf:params:oauth:token-type:access_token', + "state": _state + } + + def refresh_token_request(self, msg): + _state = msg['authorization']['request']['state'] + + return { + "grant_type": "refresh_token", + "state": _state, + } + + def registration_request(self, msg): + return {} + + def userinfo_request(self, msg): + return {} + + def client_credentials_request(self, msg): + return {} + + def resource_owner_password_credentials_request(self, msg): + return {} + + def __call__(self, request_responses: list[list], **kwargs): + msg = kwargs + for request, response in request_responses: + func = getattr(self, f"{request}_request") + req_args = func(msg) + msg[request] = self.do_query(request, response, req_args, msg) + return msg diff --git a/demo/oauth2_add_on_dpop.py b/demo/oauth2_add_on_dpop.py new file mode 100755 index 00000000..bfb37bc2 --- /dev/null +++ b/demo/oauth2_add_on_dpop.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +import json + +from common import BASEDIR +from common import KEYDEFS +from common import full_path +from flow import Flow +from idpyoidc.claims import get_signing_algs +from idpyoidc.client.oauth2 import Client +from idpyoidc.server import Server +from idpyoidc.server.configure import ASConfiguration +from idpyoidc.server.user_info import UserInfo +from oauth2_client_conf import CLIENT_CONFIG +from oauth2_client_conf import CLIENT_ID +from oauth2_server_conf import SERVER_CONF + +# ================ Server side =================================== + +USERINFO = UserInfo(json.loads(open(full_path("users.json")).read())) + +server_conf = SERVER_CONF.copy() +server_conf["keys"] = {"uri_path": "jwks.json", "key_defs": KEYDEFS} +server_conf["token_handler_args"]["key_conf"] = {"key_defs": KEYDEFS} + +server_conf['add_ons'] = { + "dpop": { + "function": "idpyoidc.server.oauth2.add_on.dpop.add_support", + "kwargs": { + 'dpop_signing_alg_values_supported': get_signing_algs() + } + } +} + +server = Server(ASConfiguration(conf=server_conf, base_path=BASEDIR), cwd=BASEDIR) + +# ================ Client side =================================== + +client_conf = CLIENT_CONFIG +client_conf['issuer'] = SERVER_CONF['issuer'] +client_conf['key_conf'] = {'key_defs': KEYDEFS} + +client_conf['add_ons'] = { + "dpop": { + "function": "idpyoidc.client.oauth2.add_on.dpop.add_support", + "kwargs": { + "dpop_signing_alg_values_supported": ["ES256"] + } + } +} + +client = Client(config=client_conf) + +# ==== What the server needs to know about the client. + +server.context.cdb[CLIENT_ID] = {k: v for k, v in CLIENT_CONFIG.items() if k not in ['services']} +server.context.keyjar.import_jwks(client.keyjar.export_jwks(), CLIENT_ID) + +# Initiating the Server's metadata + +server.context.set_provider_info() + +# ==== And now for the protocol exchange sequence + +flow = Flow(client, server) +msg = flow( + [ + ['server_metadata', 'server_metadata'], + ['authorization', 'authorization'], + ["accesstoken", 'token'], + ], + scope=['foobar'], + server_jwks=server.keyjar.export_jwks(''), + server_jwks_uri=server.context.provider_info['jwks_uri'] +) diff --git a/demo/oauth2_add_on_jar.py b/demo/oauth2_add_on_jar.py new file mode 100755 index 00000000..71c6825b --- /dev/null +++ b/demo/oauth2_add_on_jar.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +import os + +from common import BASEDIR +from common import KEYDEFS +from flow import Flow +from idpyoidc.client.oauth2 import Client +from idpyoidc.server import Server +from idpyoidc.server.configure import ASConfiguration +from oauth2_client_conf import CLIENT_CONFIG +from oauth2_client_conf import CLIENT_ID +from oauth2_server_conf import SERVER_CONF + + +# ================ Server side =================================== + +server_conf = SERVER_CONF.copy() +server_conf["keys"] = {"uri_path": "jwks.json", "key_defs": KEYDEFS} +server_conf["token_handler_args"]["key_conf"] = {"key_defs": KEYDEFS} + +# The server knows how to deal with JAR without an add-on + +server = Server(ASConfiguration(conf=server_conf, base_path=BASEDIR), cwd=BASEDIR) + +# ================ Client side =================================== + +client_conf = CLIENT_CONFIG.copy() +client_conf['issuer'] = SERVER_CONF['issuer'] +client_conf['key_conf'] = {'key_defs': KEYDEFS} + +client_conf['add_ons'] = { + "jar": { + "function": "idpyoidc.client.oauth2.add_on.jar.add_support", + "kwargs": { + 'request_type': 'request_parameter', + 'request_object_signing_alg': "ES256", + 'expires_in': 600 + } + } +} + +client = Client(config=client_conf) + +# ==== What the server needs to know about the client. + +server.context.cdb[CLIENT_ID] = {k: v for k, v in CLIENT_CONFIG.items() if k not in ['services']} +server.context.keyjar.import_jwks(client.keyjar.export_jwks(), CLIENT_ID) + +# Initiating the server's metadata + +server.context.set_provider_info() + +# ==== And now for the protocol exchange sequence + +flow = Flow(client, server) +msg = flow( + [ + ['server_metadata', 'server_metadata'], + ['authorization', 'authorization'] + ], + scope=['foobar'], + server_jwks=server.keyjar.export_jwks(''), + server_jwks_uri=server.context.provider_info['jwks_uri'] +) diff --git a/demo/oauth2_add_on_par.py b/demo/oauth2_add_on_par.py new file mode 100755 index 00000000..b9f86a82 --- /dev/null +++ b/demo/oauth2_add_on_par.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 + +from common import BASEDIR +from common import KEYDEFS +from flow import Flow +from idpyoidc.client.oauth2 import Client +from idpyoidc.server import Server +from idpyoidc.server.configure import ASConfiguration +from oauth2_client_conf import CLIENT_CONFIG +from oauth2_client_conf import CLIENT_ID +from oauth2_server_conf import SERVER_CONF + +# ================ Server side =================================== + +server_conf = SERVER_CONF.copy() +server_conf["keys"] = {"uri_path": "jwks.json", "key_defs": KEYDEFS} +server_conf["token_handler_args"]["key_conf"] = {"key_defs": KEYDEFS} +server_conf['endpoint'] = { + "metadata": { + "path": ".well-known/oauth-authorization-server", + "class": "idpyoidc.server.oauth2.server_metadata.ServerMetadata", + "kwargs": {}, + }, + "authorization": { + "path": "authorization", + "class": "idpyoidc.server.oauth2.authorization.Authorization", + "kwargs": {}, + }, + "token": { + "path": "token", + "class": "idpyoidc.server.oauth2.token.Token", + "kwargs": {}, + }, + "pushed_authorization": { + "path": "pushed_authorization", + "class": "idpyoidc.server.oauth2.pushed_authorization.PushedAuthorization", + "kwargs": { + "client_authn_method": [ + "client_secret_post", + "client_secret_basic", + "client_secret_jwt", + "private_key_jwt", + ] + }, + }, +} + +# The server knows how to deal with JAR without an add-on + +server = Server(ASConfiguration(conf=server_conf, base_path=BASEDIR), cwd=BASEDIR) + +# ================ Client side =================================== + + +client_conf = CLIENT_CONFIG.copy() +client_conf['issuer'] = SERVER_CONF['issuer'] +client_conf['key_conf'] = {'key_defs': KEYDEFS} + +client_conf['add_ons'] = { + "par": { + "function": "idpyoidc.client.oauth2.add_on.par.add_support", + "kwargs": { + 'http_client': { + 'class': 'utils.EmulatePARCall' + }, + 'authn_method': 'client_secret_basic' + } + } +} + +client = Client(config=client_conf) + +# ==== What the server needs to know about the client. + +server.context.cdb[CLIENT_ID] = {k: v for k, v in CLIENT_CONFIG.items() if k not in ['services']} +server.context.keyjar.import_jwks(client.keyjar.export_jwks(), CLIENT_ID) + +# Initiating the server's metadata + +server.context.set_provider_info() + +# ==== And now for the protocol exchange sequence + +client.context.add_on['pushed_authorization']['http_client'].server = server + +flow = Flow(client, server) +msg = flow( + [ + ['server_metadata', 'server_metadata'], + ['authorization', 'authorization'] + ], + scope=['foobar'], + server_jwks=server.keyjar.export_jwks(''), + server_jwks_uri=server.context.provider_info['jwks_uri'], +) diff --git a/demo/oauth2_add_on_pkce.py b/demo/oauth2_add_on_pkce.py new file mode 100755 index 00000000..16d723fd --- /dev/null +++ b/demo/oauth2_add_on_pkce.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +import os + +from common import BASEDIR +from common import KEYDEFS +from flow import Flow +from idpyoidc.client.oauth2 import Client +from idpyoidc.server import Server +from idpyoidc.server.configure import ASConfiguration +from oauth2_client_conf import CLIENT_CONFIG +from oauth2_client_conf import CLIENT_ID +from oauth2_server_conf import SERVER_CONF + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +# ================ Server side =================================== + +server_conf = SERVER_CONF.copy() +server_conf["keys"] = {"uri_path": "jwks.json", "key_defs": KEYDEFS} +server_conf["token_handler_args"]["key_conf"] = {"key_defs": KEYDEFS} + +server_conf['add_ons'] = { + "pkce": { + "function": "idpyoidc.server.oauth2.add_on.pkce.add_support", + "kwargs": {}, + }, +} +server = Server(ASConfiguration(conf=server_conf, base_path=BASEDIR), cwd=BASEDIR) + +# ================ Client side =================================== + +client_config = CLIENT_CONFIG +client_config['issuer'] = SERVER_CONF['issuer'] +client_config['key_conf'] = {'key_defs': KEYDEFS} + +client_config['add_ons'] = { + "pkce": { + "function": "idpyoidc.client.oauth2.add_on.pkce.add_support", + "kwargs": { + "code_challenge_length": 64, + "code_challenge_method": "S256" + }, + }, +} + +client = Client(config=client_config) + +# ==== What the server needs to know about the client. + +server.context.cdb[CLIENT_ID] = {k: v for k, v in CLIENT_CONFIG.items() if k not in ['services']} +server.context.keyjar.import_jwks(client.keyjar.export_jwks(), CLIENT_ID) + +# Initiating the server's metadata + +server.context.set_provider_info() + +# ==== And now for the exchange sequence + +flow = Flow(client, server) +msg = flow( + [ + ['server_metadata', 'server_metadata'], + ['authorization', 'authorization'], + ["accesstoken", 'token'] + ], + scope=['foobar'], + server_jwks=server.keyjar.export_jwks(''), + server_jwks_uri=server.context.provider_info['jwks_uri'] +) diff --git a/demo/oauth2_cc.py b/demo/oauth2_cc.py new file mode 100755 index 00000000..852e7541 --- /dev/null +++ b/demo/oauth2_cc.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 +""" +Displaying how Client Credentials works +""" + +from common import BASEDIR +from common import KEYDEFS +from common import SESSION_PARAMS +from flow import Flow +from idpyoidc.client.oauth2 import Client +from idpyoidc.server import Server +from idpyoidc.server.authz import AuthzHandling +from idpyoidc.server.client_authn import verify_client +from idpyoidc.server.configure import ASConfiguration +from idpyoidc.server.oauth2.token import Token + +SERVER_CONFIG = { + "issuer": "https://example.net/", + "httpc_params": {"verify": False}, + "preference": { + "grant_types_supported": ["client_credentials", "password"] + }, + "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS, 'read_only': False}, + "endpoint": { + "token": { + "path": "token", + "class": Token, + "kwargs": { + "client_authn_method": ["client_secret_basic", "client_secret_post"], + } + } + }, + "token_handler_args": { + "jwks_defs": {"key_defs": KEYDEFS}, + "token": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "aud": ["https://example.org/appl"], + } + } + }, + "client_authn": verify_client, + "claims_interface": { + "class": "idpyoidc.server.session.claims.OAuth2ClaimsInterface", + "kwargs": {}, + }, + "authz": { + "class": AuthzHandling, + "kwargs": { + "grant_config": { + "usage_rules": { + "access_token": {}, + } + } + }, + }, + "session_params": {"encrypter": SESSION_PARAMS}, +} + +CLIENT_CONFIG = { + "client_id": "client_1", + "client_secret": "another password", + "base_url": "https://example.com", + 'services': { + "client_credentials": { + "class": "idpyoidc.client.oauth2.client_credentials.CCAccessTokenRequest" + } + } +} + +# Client side + +client = Client(config=CLIENT_CONFIG) + +client_credentials_service = client.get_service('client_credentials') +client_credentials_service.endpoint = "https://example.com/token" + +# Server side + +server = Server(ASConfiguration(conf=SERVER_CONFIG, base_path=BASEDIR), cwd=BASEDIR) +server.context.cdb["client_1"] = { + "client_secret": CLIENT_CONFIG['client_secret'], + "allowed_scopes": ["resourceA"], +} + +flow = Flow(client, server) +msg = flow( + [ + ["client_credentials", 'token'] + ] +) diff --git a/demo/oauth2_client_conf.py b/demo/oauth2_client_conf.py new file mode 100644 index 00000000..9680b63f --- /dev/null +++ b/demo/oauth2_client_conf.py @@ -0,0 +1,14 @@ +CLIENT_ID = 'client' + +CLIENT_CONFIG = { + "client_secret": "SUPERhemligtlösenord", + "client_id": CLIENT_ID, + "redirect_uris": ["https://example.com/cb"], + "token_endpoint_auth_methods_supported": ["client_secret_post"], + "response_types_supported": ["code"], + "services": { + "metadata": {"class": "idpyoidc.client.oauth2.server_metadata.ServerMetadata"}, + "authorization": {"class": "idpyoidc.client.oauth2.authorization.Authorization"}, + "access_token": {"class": "idpyoidc.client.oauth2.access_token.AccessToken"}, + } +} diff --git a/demo/oauth2_code.py b/demo/oauth2_code.py new file mode 100755 index 00000000..19e7264d --- /dev/null +++ b/demo/oauth2_code.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +import os + +from common import BASEDIR +from common import KEYDEFS +from flow import Flow +from idpyoidc.client.oauth2 import Client +from idpyoidc.server import Server +from idpyoidc.server.configure import ASConfiguration +from oauth2_client_conf import CLIENT_CONFIG +from oauth2_client_conf import CLIENT_ID +from oauth2_server_conf import SERVER_CONF + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +# ================ Server side =================================== + +server_conf = SERVER_CONF.copy() +server_conf["keys"] = {"uri_path": "jwks.json", "key_defs": KEYDEFS} +server_conf["token_handler_args"]["key_conf"] = {"key_defs": KEYDEFS} + +server = Server(ASConfiguration(conf=server_conf, base_path=BASEDIR), cwd=BASEDIR) + +# ================ Client side =================================== + +client_conf = CLIENT_CONFIG.copy() +client_conf['issuer'] = SERVER_CONF['issuer'] +client_conf['key_conf'] = {'key_defs': KEYDEFS} + +client = Client(config=client_conf) + +# ==== What the server needs to know about the client. + +server.context.cdb[CLIENT_ID] = {k: v for k, v in CLIENT_CONFIG.items() if k not in ['services']} +server.context.keyjar.import_jwks(client.keyjar.export_jwks(), CLIENT_ID) + +# Initiating the server's metadata + +server.context.set_provider_info() + +# ==== And now for the protocol exchange sequence + +flow = Flow(client, server) +msg = flow( + [ + ['server_metadata', 'server_metadata'], + ['authorization', 'authorization'], + ["accesstoken", 'token'], + ], + scope=['foobar'], + server_jwks=server.keyjar.export_jwks(''), + server_jwks_uri=server.context.provider_info['jwks_uri'] +) diff --git a/demo/oauth2_ropc.py b/demo/oauth2_ropc.py new file mode 100755 index 00000000..feaecba2 --- /dev/null +++ b/demo/oauth2_ropc.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 + +from common import BASEDIR +from common import KEYDEFS +from common import SESSION_PARAMS +from common import full_path +from flow import Flow +from idpyoidc.client.oauth2 import Client +from idpyoidc.server import ASConfiguration +from idpyoidc.server import Server +from idpyoidc.server.authz import AuthzHandling +from idpyoidc.server.client_authn import verify_client +from idpyoidc.server.oauth2.token import Token + +SERVER_CONFIG = { + "issuer": "https://example.net/", + "httpc_params": {"verify": False}, + "preference": { + "grant_types_supported": ["client_credentials", "password"] + }, + "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS, 'read_only': False}, + "endpoint": { + "token": { + "path": "token", + "class": Token, + "kwargs": { + "client_authn_method": ["client_secret_basic", "client_secret_post"], + }, + }, + }, + "token_handler_args": { + "jwks_defs": {"key_defs": KEYDEFS}, + "token": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "add_claims_by_scope": True, + "aud": ["https://example.org/appl"], + } + } + }, + "client_authn": verify_client, + "claims_interface": { + "class": "idpyoidc.server.session.claims.OAuth2ClaimsInterface", + "kwargs": {}, + }, + "authz": { + "class": AuthzHandling, + "kwargs": { + "grant_config": { + "usage_rules": { + "access_token": {"expires_in": 3600} + } + } + } + }, + "session_params": {"encrypter": SESSION_PARAMS}, + "authentication": { + "user": { + "acr": "urn:oasis:names:tc:SAML:2.0:ac:classes:InternetProtocolPassword", + "class": "idpyoidc.server.user_authn.user.UserPass", + "kwargs": { + "db_conf": { + "class": "idpyoidc.server.util.JSONDictDB", + "kwargs": {"filename": full_path("passwd.json")} + } + } + } + } +} + +CLIENT_BASE_URL = "https://example.com" + +CLIENT_CONFIG = { + "client_id": "client_1", + "client_secret": "another password", + "base_url": CLIENT_BASE_URL, + 'services': { + "resource_owner_password_credentials": { + "class": "idpyoidc.client.oauth2.resource_owner_password_credentials" + ".ROPCAccessTokenRequest" + } + } +} + +# Client side + +client = Client(config=CLIENT_CONFIG) + +ropc_service = client.get_service('resource_owner_password_credentials') +ropc_service.endpoint = "https://example.com/token" + +# Server side + +server = Server(ASConfiguration(conf=SERVER_CONFIG, base_path=BASEDIR), cwd=BASEDIR) +server.context.cdb["client_1"] = { + "client_secret": "another password", + "redirect_uris": [("https://example.com/cb", None)], + "client_salt": "salted", + "endpoint_auth_method": "client_secret_post", + "response_types": ["code", "code id_token", "id_token"], + "allowed_scopes": ["resourceA"], +} + +flow = Flow(client, server) +msg = flow( + [ + ["resource_owner_password_credentials", 'token'] + ], + request_additions={ + 'resource_owner_password_credentials': {'username': 'diana', 'password': 'krall'} + } +) diff --git a/demo/oauth2_server_conf.py b/demo/oauth2_server_conf.py new file mode 100644 index 00000000..163ba944 --- /dev/null +++ b/demo/oauth2_server_conf.py @@ -0,0 +1,67 @@ +from common import CRYPT_CONFIG +from common import SESSION_PARAMS +from idpyoidc.server.authz import AuthzHandling +from idpyoidc.server.client_authn import verify_client +from idpyoidc.server.user_authn.authn_context import INTERNETPROTOCOLPASSWORD + +SERVER_CONF = { + "issuer": "https://example.com/", + "httpc_params": {"verify": False, "timeout": 1}, + "endpoint": { + "metadata": { + "path": ".well-known/oauth-authorization-server", + "class": "idpyoidc.server.oauth2.server_metadata.ServerMetadata", + "kwargs": {}, + }, + "authorization": { + "path": "authorization", + "class": "idpyoidc.server.oauth2.authorization.Authorization", + "kwargs": {}, + }, + "token": { + "path": "token", + "class": "idpyoidc.server.oauth2.token.Token", + "kwargs": {}, + } + }, + "authentication": { + "anon": { + "acr": INTERNETPROTOCOLPASSWORD, + "class": "idpyoidc.server.user_authn.user.NoAuthn", + "kwargs": {"user": "diana"}, + } + }, + "client_authn": verify_client, + "authz": { + "class": AuthzHandling, + "kwargs": { + "grant_config": { + "usage_rules": { + "authorization_code": { + "supports_minting": ["access_token"], + "max_usage": 1, + }, + "access_token": { + "expires_in": 600, + } + } + } + }, + }, + "token_handler_args": { + "code": { + "lifetime": 600, + "kwargs": { + "crypt_conf": CRYPT_CONFIG + } + }, + "token": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "aud": ["https://example.org/appl"], + }, + } + }, + "session_params": SESSION_PARAMS, +} diff --git a/demo/oauth2_token_exchange.py b/demo/oauth2_token_exchange.py new file mode 100755 index 00000000..70b7b43e --- /dev/null +++ b/demo/oauth2_token_exchange.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +import os + +from flow import Flow +from idpyoidc.client.oauth2 import Client +from idpyoidc.server import Server +from idpyoidc.server.configure import ASConfiguration +from oauth2_client_conf import CLIENT_CONFIG +from oauth2_client_conf import CLIENT_ID +from oauth2_server_conf import SERVER_CONF + +KEYDEFS = [ + {"type": "RSA", "key": "", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +# ================ Server side =================================== + +server_conf = SERVER_CONF.copy() +server_conf["keys"] = {"uri_path": "jwks.json", "key_defs": KEYDEFS} +server_conf["token_handler_args"]["key_conf"] = {"key_defs": KEYDEFS} +server_conf["authz"]["kwargs"]["grant_config"]["usage_rules"]["access_token"] = { + "supports_minting": ["access_token"], + "expires_in": 600, +} + +server = Server(ASConfiguration(conf=server_conf, base_path=BASEDIR), cwd=BASEDIR) + +# ================ Client side =================================== + +client_conf = CLIENT_CONFIG.copy() +client_conf['issuer'] = SERVER_CONF['issuer'] +client_conf['key_conf'] = {'key_defs': KEYDEFS} +client_conf["services"] = { + "metadata": {"class": "idpyoidc.client.oauth2.server_metadata.ServerMetadata"}, + "authorization": {"class": "idpyoidc.client.oauth2.authorization.Authorization"}, + "access_token": {"class": "idpyoidc.client.oauth2.access_token.AccessToken"}, + "token_exchange": {"class": "idpyoidc.client.oauth2.token_exchange.TokenExchange"} +} +client_conf["allowed_scopes"] = ["foobar"] + +client = Client(config=client_conf) + +# ==== What the server needs to know about the client. + +server.context.cdb[CLIENT_ID] = {k: v for k, v in CLIENT_CONFIG.items() if k not in ['services']} +server.context.cdb[CLIENT_ID]['allowed_scopes'] = client_conf['allowed_scopes'] + +server.context.keyjar.import_jwks(client.keyjar.export_jwks(), CLIENT_ID) + +# Initiating the server's metadata + +server.context.set_provider_info() + +# ==== And now for the protocol exchange sequence + +flow = Flow(client, server) +msg = flow( + [ + ['server_metadata', 'server_metadata'], + ['authorization', 'authorization'], + ["accesstoken", 'token'], + ['token_exchange', 'token'] + ], + scope=['foobar'], + server_jwks=server.keyjar.export_jwks(''), + server_jwks_uri=server.context.provider_info['jwks_uri'], + request_additions={ + 'authorization': {'scope': 'foobar'} + } +) diff --git a/demo/oauth2_token_refresh.py b/demo/oauth2_token_refresh.py new file mode 100755 index 00000000..375fca8a --- /dev/null +++ b/demo/oauth2_token_refresh.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 +import os + +from common import BASEDIR +from common import KEYDEFS +from flow import Flow +from idpyoidc.client.oauth2 import Client +from idpyoidc.server import Server +from idpyoidc.server.configure import ASConfiguration +from oauth2_client_conf import CLIENT_CONFIG +from oauth2_client_conf import CLIENT_ID +from oauth2_server_conf import SERVER_CONF + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +# ================ Server side =================================== + +server_conf = SERVER_CONF.copy() +server_conf["keys"] = {"uri_path": "jwks.json", "key_defs": KEYDEFS} +server_conf["token_handler_args"]["key_conf"] = {"key_defs": KEYDEFS} +server_conf["authz"]["kwargs"] = { + "grant_config": { + "usage_rules": { + "authorization_code": { + "supports_minting": ["access_token", "refresh_token"], + "max_usage": 1, + }, + "access_token": { + "supports_minting": ["access_token", "refresh_token"], + "expires_in": 600, + }, + "refresh_token": { + "supports_minting": ["access_token"], + "audience": ["https://example.com", "https://example2.com"], + "expires_in": 43200, + }, + }, + "expires_in": 43200, + } +} +server_conf['token_handler_args']["refresh"] = { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "aud": ["https://example.org/appl"], + } +} + +server = Server(ASConfiguration(conf=server_conf, base_path=BASEDIR), cwd=BASEDIR) + +# ================ Client side =================================== + +client_conf = CLIENT_CONFIG.copy() +client_conf['issuer'] = SERVER_CONF['issuer'] +client_conf['key_conf'] = {'key_defs': KEYDEFS} +client_conf["services"] = { + "metadata": {"class": "idpyoidc.client.oauth2.server_metadata.ServerMetadata"}, + "authorization": {"class": "idpyoidc.client.oauth2.authorization.Authorization"}, + "access_token": {"class": "idpyoidc.client.oauth2.access_token.AccessToken"}, + "refresh_token": {"class": "idpyoidc.client.oauth2.refresh_access_token.RefreshAccessToken"} +} +client_conf["allowed_scopes"] = ["profile", "offline_access", "foobar"] + +client = Client(config=client_conf) + +# ==== What the server needs to know about the client. + +server.context.cdb[CLIENT_ID] = {k: v for k, v in CLIENT_CONFIG.items() if k not in ['services']} +server.context.cdb[CLIENT_ID]['allowed_scopes'] = client_conf['allowed_scopes'] + +server.context.keyjar.import_jwks(client.keyjar.export_jwks(), CLIENT_ID) + +# Initiating the server's metadata + +server.context.set_provider_info() + +# ==== And now for the protocol exchange sequence + +flow = Flow(client, server) +msg = flow( + [ + ['server_metadata', 'server_metadata'], + ['authorization', 'authorization'], + ["accesstoken", 'token'], + ['refresh_token', 'token'] + ], + scope=['foobar'], + server_jwks=server.keyjar.export_jwks(''), + server_jwks_uri=server.context.provider_info['jwks_uri'], + process_request_args={'token': {'issue_refresh': True}}, + get_request_parameters={'refresh_token': {'authn_method': 'client_secret_post'}} +) diff --git a/demo/oauth2_token_revocation.py b/demo/oauth2_token_revocation.py new file mode 100755 index 00000000..42f6bf17 --- /dev/null +++ b/demo/oauth2_token_revocation.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 + +from common import BASEDIR +from common import KEYDEFS +from flow import Flow +from idpyoidc.client.oauth2 import Client +from idpyoidc.server import ASConfiguration +from idpyoidc.server import Server +from oauth2_client_conf import CLIENT_CONFIG +from oauth2_client_conf import CLIENT_ID +from oauth2_server_conf import SERVER_CONF + +# ================ Server side =================================== + +server_conf = SERVER_CONF.copy() +server_conf["keys"] = {"uri_path": "jwks.json", "key_defs": KEYDEFS} +server_conf["token_handler_args"]["key_conf"] = {"key_defs": KEYDEFS} +server_conf["authz"]["kwargs"] = { + "grant_config": { + "usage_rules": { + "authorization_code": { + "supports_minting": ["access_token", "refresh_token"], + "max_usage": 1, + }, + "access_token": { + "supports_minting": ["access_token", "refresh_token"], + "expires_in": 600, + }, + "refresh_token": { + "supports_minting": ["access_token"], + "audience": ["https://example.com", "https://example2.com"], + "expires_in": 43200, + }, + }, + "expires_in": 43200, + } +} +server_conf['token_handler_args']["refresh"] = { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "aud": ["https://example.org/appl"], + } +} +server_conf['endpoint'] = { + 'discovery': { + 'path': "/.well-known/oauth-authorization-server", + 'class': "idpyoidc.server.oauth2.server_metadata.ServerMetadata", + "kwargs": {}, + }, + "authorization": { + "path": "authorization", + "class": "idpyoidc.server.oauth2.authorization.Authorization", + "kwargs": {}, + }, + "token": { + "path": "token", + "class": "idpyoidc.server.oauth2.token.Token", + "kwargs": {}, + }, + "token_revocation": { + 'path': 'revocation', + "class": "idpyoidc.server.oauth2.token_revocation.TokenRevocation", + "kwargs": {}, + }, + 'introspection': { + 'path': 'introspection', + 'class': "idpyoidc.server.oauth2.introspection.Introspection" + } +} + +server = Server(ASConfiguration(conf=server_conf, base_path=BASEDIR), cwd=BASEDIR) + +# ================ Client side =================================== + +client_conf = CLIENT_CONFIG.copy() +client_conf['issuer'] = SERVER_CONF['issuer'] +client_conf['key_conf'] = {'key_defs': KEYDEFS} +client_conf["services"] = { + "metadata": {"class": "idpyoidc.client.oauth2.server_metadata.ServerMetadata"}, + "authorization": {"class": "idpyoidc.client.oauth2.authorization.Authorization"}, + "access_token": {"class": "idpyoidc.client.oauth2.access_token.AccessToken"}, + 'token_revocation': { + 'class': 'idpyoidc.client.oauth2.token_revocation.TokenRevocation' + }, + 'introspection': { + 'class': 'idpyoidc.client.oauth2.introspection.Introspection' + } +} +client_conf["allowed_scopes"] = ["profile", "offline_access", "foobar"] + +client = Client(config=client_conf) + +# ==== What the server needs to know about the client. + +server.context.cdb[CLIENT_ID] = {k: v for k, v in CLIENT_CONFIG.items() if k not in ['services']} +server.context.cdb[CLIENT_ID]['allowed_scopes'] = client_conf['allowed_scopes'] + +server.context.keyjar.import_jwks(client.keyjar.export_jwks(), CLIENT_ID) + +# Initiating the server's metadata + +server.context.set_provider_info() + +# ------- tell the server about the client ---------------- + +flow = Flow(client, server) +msg = flow( + [ + ['server_metadata', 'server_metadata'], + ['authorization', 'authorization'], + ["accesstoken", 'token'], + ['introspection', 'introspection'], + ['token_revocation', 'token_revocation'], + ['introspection', 'introspection'], + ], + scope=['foobar'], + server_jwks=server.keyjar.export_jwks(''), + server_jwks_uri=server.context.provider_info['jwks_uri'] +) diff --git a/demo/oidc_add_on_dpop.py b/demo/oidc_add_on_dpop.py new file mode 100755 index 00000000..bfb37bc2 --- /dev/null +++ b/demo/oidc_add_on_dpop.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +import json + +from common import BASEDIR +from common import KEYDEFS +from common import full_path +from flow import Flow +from idpyoidc.claims import get_signing_algs +from idpyoidc.client.oauth2 import Client +from idpyoidc.server import Server +from idpyoidc.server.configure import ASConfiguration +from idpyoidc.server.user_info import UserInfo +from oauth2_client_conf import CLIENT_CONFIG +from oauth2_client_conf import CLIENT_ID +from oauth2_server_conf import SERVER_CONF + +# ================ Server side =================================== + +USERINFO = UserInfo(json.loads(open(full_path("users.json")).read())) + +server_conf = SERVER_CONF.copy() +server_conf["keys"] = {"uri_path": "jwks.json", "key_defs": KEYDEFS} +server_conf["token_handler_args"]["key_conf"] = {"key_defs": KEYDEFS} + +server_conf['add_ons'] = { + "dpop": { + "function": "idpyoidc.server.oauth2.add_on.dpop.add_support", + "kwargs": { + 'dpop_signing_alg_values_supported': get_signing_algs() + } + } +} + +server = Server(ASConfiguration(conf=server_conf, base_path=BASEDIR), cwd=BASEDIR) + +# ================ Client side =================================== + +client_conf = CLIENT_CONFIG +client_conf['issuer'] = SERVER_CONF['issuer'] +client_conf['key_conf'] = {'key_defs': KEYDEFS} + +client_conf['add_ons'] = { + "dpop": { + "function": "idpyoidc.client.oauth2.add_on.dpop.add_support", + "kwargs": { + "dpop_signing_alg_values_supported": ["ES256"] + } + } +} + +client = Client(config=client_conf) + +# ==== What the server needs to know about the client. + +server.context.cdb[CLIENT_ID] = {k: v for k, v in CLIENT_CONFIG.items() if k not in ['services']} +server.context.keyjar.import_jwks(client.keyjar.export_jwks(), CLIENT_ID) + +# Initiating the Server's metadata + +server.context.set_provider_info() + +# ==== And now for the protocol exchange sequence + +flow = Flow(client, server) +msg = flow( + [ + ['server_metadata', 'server_metadata'], + ['authorization', 'authorization'], + ["accesstoken", 'token'], + ], + scope=['foobar'], + server_jwks=server.keyjar.export_jwks(''), + server_jwks_uri=server.context.provider_info['jwks_uri'] +) diff --git a/demo/oidc_client_conf.py b/demo/oidc_client_conf.py new file mode 100644 index 00000000..e37ca950 --- /dev/null +++ b/demo/oidc_client_conf.py @@ -0,0 +1,17 @@ +CLIENT_ID = 'client' + +CLIENT_CONFIG = { + "client_secret": "SUPERhemligtlösenord", + "client_id": CLIENT_ID, + "redirect_uris": ["https://example.com/cb"], + "token_endpoint_auth_methods_supported": ["client_secret_post"], + "response_types_supported": ["code"], + "allowed_scopes": ["foobar", "openid"], + "services": { + "provider_info": { + "class": "idpyoidc.client.oidc.provider_info_discovery.ProviderInfoDiscovery"}, + "authorization": {"class": "idpyoidc.client.oidc.authorization.Authorization"}, + "access_token": {"class": "idpyoidc.client.oidc.access_token.AccessToken"}, + 'userinfo': {'class': "idpyoidc.client.oidc.userinfo.UserInfo"} + } +} diff --git a/demo/oidc_code.py b/demo/oidc_code.py new file mode 100755 index 00000000..1d8a9414 --- /dev/null +++ b/demo/oidc_code.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +import os + +from common import BASEDIR +from common import KEYDEFS +from flow import Flow +from idpyoidc.client.oidc import RP +from idpyoidc.server import OPConfiguration +from idpyoidc.server import Server +from oidc_client_conf import CLIENT_CONFIG +from oidc_client_conf import CLIENT_ID +from oidc_server_conf import SERVER_CONF + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +# ================ Server side =================================== + +server_conf = SERVER_CONF.copy() +server_conf["key_conf"] = {"uri_path": "jwks.json", "key_defs": KEYDEFS} +server_conf["token_handler_args"]["key_conf"] = {"key_defs": KEYDEFS} + +server = Server(OPConfiguration(conf=server_conf, base_path=BASEDIR), cwd=BASEDIR) + +# ================ Client side =================================== + +client_conf = CLIENT_CONFIG.copy() +client_conf['issuer'] = SERVER_CONF['issuer'] +client_conf['key_conf'] = {'key_defs': KEYDEFS} +client_conf["allowed_scopes"] = ["foobar", "openid", 'offline_access'] + +client = RP(config=client_conf) + +# ==== What the server needs to know about the client. + +server.context.cdb[CLIENT_ID] = CLIENT_CONFIG +server.context.cdb[CLIENT_ID]['allowed_scopes'] = client_conf["allowed_scopes"] +server.context.keyjar.import_jwks(client.keyjar.export_jwks(), CLIENT_ID) + +# Initiating the server's metadata + +server.context.set_provider_info() + +flow = Flow(client, server) +msg = flow( + [ + ['provider_info', 'provider_config'], + ['authorization', 'authorization'], + ["accesstoken", 'token'], + ['userinfo', 'userinfo'] + ], + scope=['foobar', 'offline_access', 'email'], + server_jwks=server.keyjar.export_jwks(''), + server_jwks_uri=server.context.provider_info['jwks_uri'] +) diff --git a/demo/oidc_code_claims.py b/demo/oidc_code_claims.py new file mode 100755 index 00000000..15be0060 --- /dev/null +++ b/demo/oidc_code_claims.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +import os + +from common import BASEDIR +from common import KEYDEFS +from flow import Flow +from idpyoidc.client.oidc import RP +from idpyoidc.server import OPConfiguration +from idpyoidc.server import Server +from oidc_client_conf import CLIENT_CONFIG +from oidc_client_conf import CLIENT_ID +from oidc_server_conf import SERVER_CONF + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +# ================ Server side =================================== + +server_conf = SERVER_CONF.copy() +server_conf["key_conf"] = {"uri_path": "jwks.json", "key_defs": KEYDEFS} +server_conf["token_handler_args"]["key_conf"] = {"key_defs": KEYDEFS} + +server = Server(OPConfiguration(conf=server_conf, base_path=BASEDIR), cwd=BASEDIR) + +# ================ Client side =================================== + +client_conf = CLIENT_CONFIG.copy() +client_conf['issuer'] = SERVER_CONF['issuer'] +client_conf['key_conf'] = {'key_defs': KEYDEFS} +client_conf["allowed_scopes"] = ["foobar", "openid", 'offline_access'] + +client = RP(config=client_conf) + +# ==== What the server needs to know about the client. + +server.context.cdb[CLIENT_ID] = CLIENT_CONFIG +server.context.cdb[CLIENT_ID]['allowed_scopes'] = client_conf["allowed_scopes"] +server.context.keyjar.import_jwks(client.keyjar.export_jwks(), CLIENT_ID) + +# Initiating the server's metadata + +server.context.set_provider_info() + +flow = Flow(client, server) +msg = flow( + [ + ['provider_info', 'provider_config'], + ['authorization', 'authorization'], + ["accesstoken", 'token'], + ['userinfo', 'userinfo'] + ], + scope=['foobar'], + server_jwks=server.keyjar.export_jwks(''), + server_jwks_uri=server.context.provider_info['jwks_uri'], + request_additions={ + 'authorization': { + 'claims': { + "id_token": {"nickname": None}, + "userinfo": {"name": None, "email": None, "email_verified": None}, + } + } + } +) diff --git a/demo/oidc_code_dyn_reg.py b/demo/oidc_code_dyn_reg.py new file mode 100755 index 00000000..c8580745 --- /dev/null +++ b/demo/oidc_code_dyn_reg.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +import os + +from common import BASEDIR +from common import KEYDEFS +from flow import Flow +from idpyoidc.client.oidc import RP +from idpyoidc.server import OPConfiguration +from idpyoidc.server import Server +from oidc_client_conf import CLIENT_CONFIG +from oidc_server_conf import SERVER_CONF + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +# ================ Server side =================================== + +server_conf = SERVER_CONF.copy() +server_conf["key_conf"] = {"uri_path": "jwks.json", "key_defs": KEYDEFS} +server_conf["token_handler_args"]["key_conf"] = {"key_defs": KEYDEFS} +server_conf["endpoint"] = { + "provider_info": { + "path": ".well-known/oauth-authorization-server", + "class": "idpyoidc.server.oidc.provider_config.ProviderConfiguration", + "kwargs": {}, + }, + "authorization": { + "path": "authorization", + "class": "idpyoidc.server.oidc.authorization.Authorization", + "kwargs": {}, + }, + "token": { + "path": "token", + "class": "idpyoidc.server.oidc.token.Token", + "kwargs": {}, + }, + "registration": { + "path": 'register', + "class": "idpyoidc.server.oidc.registration.Registration" + } +} + +server = Server(OPConfiguration(conf=server_conf, base_path=BASEDIR), cwd=BASEDIR) + +# ================ Client side =================================== + +client_conf = CLIENT_CONFIG.copy() +client_conf['issuer'] = SERVER_CONF['issuer'] +client_conf['key_conf'] = {'key_defs': KEYDEFS} +client_conf["allowed_scopes"] = ["foobar", "openid", 'offline_access'] +client_conf['services'] = { + "provider_info": { + "class": "idpyoidc.client.oidc.provider_info_discovery.ProviderInfoDiscovery"}, + "register": {"class": "idpyoidc.client.oidc.registration.Registration"}, + "authorization": {"class": "idpyoidc.client.oidc.authorization.Authorization"}, + "access_token": {"class": "idpyoidc.client.oidc.access_token.AccessToken"}, +} + +client = RP(config=client_conf) + +# Initiating the server's metadata + +server.context.set_provider_info() + +flow = Flow(client, server) +msg = flow( + [ + ['provider_info', 'provider_config'], + ['registration', 'registration'], + ['authorization', 'authorization'], + ["accesstoken", 'token'] + ], + scope=['foobar'], + server_jwks=server.keyjar.export_jwks(''), + server_jwks_uri=server.context.provider_info['jwks_uri'] +) diff --git a/demo/oidc_code_id_token.py b/demo/oidc_code_id_token.py new file mode 100755 index 00000000..723a55b6 --- /dev/null +++ b/demo/oidc_code_id_token.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 + +from common import BASEDIR +from common import CRYPT_CONFIG +from common import KEYDEFS +from flow import Flow +from idpyoidc.client.oidc import RP +from idpyoidc.server import OPConfiguration +from idpyoidc.server import Server +from oidc_client_conf import CLIENT_CONFIG +from oidc_client_conf import CLIENT_ID +from oidc_server_conf import SERVER_CONF + +# ================ Server side =================================== + +server_conf = SERVER_CONF.copy() +server_conf["key_conf"] = {"uri_path": "jwks.json", "key_defs": KEYDEFS} +server_conf["token_handler_args"]["key_conf"] = {"key_defs": KEYDEFS} + +del server_conf['endpoint']['userinfo'] + +server_conf['authz']['kwargs'] = { + "grant_config": { + "usage_rules": { + "authorization_code": { + "supports_minting": ["access_token"], + "max_usage": 1, + "expires_in": 300 + }, + "access_token": { + "expires_in": 600, + } + } + } +} + +server_conf['token_handler_args'] = { + "code": { + "lifetime": 600, + "kwargs": { + "crypt_conf": CRYPT_CONFIG + } + }, + "token": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "add_claims_by_scope": True, + "aud": ["https://example.org/appl"], + }, + }, + "id_token": { + "class": "idpyoidc.server.token.id_token.IDToken", + "kwargs": { + "lifetime": 86400, + "add_claims_by_scope": True + } + } +} + +server = Server(OPConfiguration(conf=server_conf, base_path=BASEDIR), cwd=BASEDIR) + +# ================ Client side =================================== + +client_conf = CLIENT_CONFIG.copy() +client_conf['issuer'] = SERVER_CONF['issuer'] +client_conf['key_conf'] = {'key_defs': KEYDEFS} +client_conf["allowed_scopes"] = ["foobar", "openid", 'offline_access'] +client_conf["response_types_supported"] = ["code id_token"] + +client = RP(config=client_conf) + +# ==== What the server needs to know about the client. + +server.context.cdb[CLIENT_ID] = CLIENT_CONFIG +for claim in ['allowed_scopes', 'response_types_supported']: + server.context.cdb["client"][claim] = client_conf[claim] + +server.context.keyjar.import_jwks(client.keyjar.export_jwks(), CLIENT_ID) + +# Initiating the server's metadata + +server.context.set_provider_info() + +flow = Flow(client, server) +msg = flow( + [ + ['provider_info', 'provider_config'], + ['authorization', 'authorization'], + ["accesstoken", 'token'] + ], + scope=['foobar'], + server_jwks=server.keyjar.export_jwks(''), + server_jwks_uri=server.context.provider_info['jwks_uri'], + response_type=['code id_token'] +) diff --git a/demo/oidc_id_token.py b/demo/oidc_id_token.py new file mode 100755 index 00000000..3a8278d2 --- /dev/null +++ b/demo/oidc_id_token.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +import os + +from common import BASEDIR +from common import KEYDEFS +from flow import Flow +from idpyoidc.client.oidc import RP +from idpyoidc.server import OPConfiguration +from idpyoidc.server import Server +from oidc_client_conf import CLIENT_CONFIG +from oidc_client_conf import CLIENT_ID +from oidc_server_conf import SERVER_CONF + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +# ================ Server side =================================== + +server_conf = SERVER_CONF.copy() +server_conf["key_conf"] = {"uri_path": "jwks.json", "key_defs": KEYDEFS} +server_conf["token_handler_args"]["key_conf"] = {"key_defs": KEYDEFS} + +del server_conf['endpoint']['userinfo'] +server_conf['authz']['kwargs'] = {} +server_conf['token_handler_args'] = { + "id_token": { + "class": "idpyoidc.server.token.id_token.IDToken", + "kwargs": { + "lifetime": 86400, + "add_claims_by_scope": True + } + } +} + +server = Server(OPConfiguration(conf=server_conf, base_path=BASEDIR), cwd=BASEDIR) + +# ================ Client side =================================== + +client_conf = CLIENT_CONFIG.copy() +client_conf['issuer'] = SERVER_CONF['issuer'] +client_conf['key_conf'] = {'key_defs': KEYDEFS} +client_conf["allowed_scopes"] = ["foobar", "openid", 'offline_access'] +client_conf["response_types_supported"] = ["id_token"] + +client = RP(config=client_conf) + +# ==== What the server needs to know about the client. + +server.context.cdb[CLIENT_ID] = CLIENT_CONFIG +for claim in ['allowed_scopes', 'response_types_supported']: + server.context.cdb["client"][claim] = client_conf[claim] + +server.context.keyjar.import_jwks(client.keyjar.export_jwks(), CLIENT_ID) + +# Initiating the server's metadata + +server.context.set_provider_info() + +flow = Flow(client, server) +msg = flow( + [ + ['provider_info', 'provider_config'], + ['authorization', 'authorization'] + ], + scope=['foobar'], + server_jwks=server.keyjar.export_jwks(''), + server_jwks_uri=server.context.provider_info['jwks_uri'], + response_type=['id_token'] +) diff --git a/demo/oidc_server_conf.py b/demo/oidc_server_conf.py new file mode 100644 index 00000000..e4b1a6e9 --- /dev/null +++ b/demo/oidc_server_conf.py @@ -0,0 +1,103 @@ +from common import CRYPT_CONFIG +from common import SESSION_PARAMS +from common import full_path +from idpyoidc.server.authz import AuthzHandling +from idpyoidc.server.client_authn import verify_client +from idpyoidc.server.user_authn.authn_context import INTERNETPROTOCOLPASSWORD + +SERVER_CONF = { + "issuer": "https://example.com/", + "httpc_params": {"verify": False, "timeout": 1}, + "subject_types_supported": ["public", "pairwise", "ephemeral"], + "endpoint": { + "provider_info": { + "path": ".well-known/oauth-authorization-server", + "class": "idpyoidc.server.oidc.provider_config.ProviderConfiguration", + "kwargs": {}, + }, + "authorization": { + "path": "authorization", + "class": "idpyoidc.server.oidc.authorization.Authorization", + "kwargs": {}, + }, + "token": { + "path": "token", + "class": "idpyoidc.server.oidc.token.Token", + "kwargs": {}, + }, + "userinfo": { + "path": "userinfo", + "class": "idpyoidc.server.oidc.userinfo.UserInfo", + "kwargs": { + "client_authn_method": ["bearer_header", "bearer_body"], + "base_claims": { + "email": {"essential": True}, + "email_verified": {"essential": True}, + } + }, + } + }, + "authentication": { + "anon": { + "acr": INTERNETPROTOCOLPASSWORD, + "class": "idpyoidc.server.user_authn.user.NoAuthn", + "kwargs": {"user": "diana"}, + } + }, + "userinfo": { + "class": "idpyoidc.server.user_info.UserInfo", + "kwargs": {"db_file": full_path("users.json")}, + }, + "client_authn": verify_client, + "authz": { + "class": AuthzHandling, + "kwargs": { + "grant_config": { + "usage_rules": { + "authorization_code": { + "supports_minting": ["access_token", "refresh_token", "id_token"], + "max_usage": 1, + "expires_in": 300 + }, + "access_token": { + "expires_in": 600, + }, + "refresh_token": { + "supports_minting": ["access_token"], + "audience": ["https://example.com", "https://example2.com"], + "expires_in": 43200, + }, + }, + "expires_in": 43200, + } + }, + }, + "token_handler_args": { + "code": { + "kwargs": { + "crypt_conf": CRYPT_CONFIG + } + }, + "token": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "add_claims_by_scope": True, + "aud": ["https://example.org/appl"], + }, + }, + "refresh": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "aud": ["https://example.org/appl"], + }, + }, + "id_token": { + "class": "idpyoidc.server.token.id_token.IDToken", + "kwargs": { + "lifetime": 86400, + "add_claims_by_scope": True + }, + } + }, + "session_params": SESSION_PARAMS, +} diff --git a/demo/passwd.json b/demo/passwd.json new file mode 100644 index 00000000..d07df8c1 --- /dev/null +++ b/demo/passwd.json @@ -0,0 +1,5 @@ +{ + "diana": "krall", + "babs": "howes", + "upper": "crust" +} \ No newline at end of file diff --git a/demo/users.json b/demo/users.json new file mode 100755 index 00000000..71aac3f9 --- /dev/null +++ b/demo/users.json @@ -0,0 +1,43 @@ +{ + "diana": { + "name": "Diana Krall", + "given_name": "Diana", + "family_name": "Krall", + "nickname": "Dina", + "email": "diana@example.org", + "email_verified": false, + "phone_number": "+46907865000", + "address": { + "street_address": "Umeå Universitet", + "locality": "Umeå", + "postal_code": "SE-90187", + "country": "Sweden" + }, + "eduperson_scoped_affiliation": [ + "staff@example.org" + ], + "webid": "http://bblfish.net/#hjs" + }, + "babs": { + "name": "Barbara J Jensen", + "given_name": "Barbara", + "family_name": "Jensen", + "nickname": "babs", + "email": "babs@example.com", + "email_verified": true, + "address": { + "street_address": "100 Universal City Plaza", + "locality": "Hollywood", + "region": "CA", + "postal_code": "91608", + "country": "USA" + } + }, + "upper": { + "name": "Upper Crust", + "given_name": "Upper", + "family_name": "Crust", + "email": "uc@example.com", + "email_verified": true + } +} \ No newline at end of file diff --git a/demo/utils.py b/demo/utils.py new file mode 100644 index 00000000..c49c7ceb --- /dev/null +++ b/demo/utils.py @@ -0,0 +1,19 @@ +import json + + +class DummyResponse(): + def __init__(self, status_code, text): + self.text = text + self.status_code = status_code + + +class EmulatePARCall(): + def __init__(self, server=None): + self.server = server + + def __call__(self, method, url, data, headers): + # I can ignore the method and url. Only interested in the data + _endp = self.server.endpoint['pushed_authorization'] + _request = _endp.parse_request(data, http_info={'headers': headers}) + _resp = _endp.process_request(request=_request) + return DummyResponse(text=json.dumps(_resp['http_response']), status_code=200) diff --git a/example/flask_rp/views.py b/example/flask_rp/views.py index 2cb47934..d833a2d2 100644 --- a/example/flask_rp/views.py +++ b/example/flask_rp/views.py @@ -186,7 +186,7 @@ def repost_fragment(): return finalize(op_identifier, args) -@oidc_rp_views.route('/authz_im_cb') +@oidc_rp_views.route('/authz_tok_cb') def authz_im_cb(op_identifier='', **kwargs): logger.debug('implicit_hybrid_flow kwargs: {}'.format(kwargs)) return render_template('repost_fragment.html', op_identifier=op_identifier) @@ -244,9 +244,10 @@ def session_logout(op_identifier): @oidc_rp_views.route('/logout') def logout(): logger.debug('logout') - _info = current_app.rph.logout(state=session['state']) - logger.debug('logout redirect to "{}"'.format(_info['url'])) - return redirect(_info['url'], 303) + _request_info = current_app.rph.logout(state=session['state']) + _url = _request_info["url"] + logger.debug(f'logout redirect to "{_url}"') + return redirect(_url, 303) @oidc_rp_views.route('/bc_logout/', methods=['GET', 'POST']) diff --git a/pyproject.toml b/pyproject.toml index 32305ba9..a155aa8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta" [metadata] name = "idpyoidc" -version = "2.0.0" +version = "2.1.0" author = "Roland Hedberg" author_email = "roland@catalogix.se" description = "Everything OAuth2 and OIDC" diff --git a/src/idpyoidc/__init__.py b/src/idpyoidc/__init__.py index c7216254..4fb4fa96 100644 --- a/src/idpyoidc/__init__.py +++ b/src/idpyoidc/__init__.py @@ -1,5 +1,5 @@ __author__ = "Roland Hedberg" -__version__ = "2.0.0" +__version__ = "2.1.0" VERIFIED_CLAIM_PREFIX = "__verified" diff --git a/src/idpyoidc/claims.py b/src/idpyoidc/claims.py index 05893a29..6328fc17 100644 --- a/src/idpyoidc/claims.py +++ b/src/idpyoidc/claims.py @@ -26,18 +26,11 @@ def claims_load(item: dict, **kwargs): class Claims(ImpExp): - parameter = { - "prefer": None, - "use": None, - "callback_path": None, - "_local": None - } + parameter = {"prefer": None, "use": None, "callback_path": None, "_local": None} _supports = {} - def __init__(self, - prefer: Optional[dict] = None, - callback_path: Optional[dict] = None): + def __init__(self, prefer: Optional[dict] = None, callback_path: Optional[dict] = None): ImpExp.__init__(self) if isinstance(prefer, dict): @@ -70,11 +63,11 @@ def remove_preference(self, key): def _callback_uris(self, base_url, hex): _uri = [] - for type in self.get_usage("response_types", self._supports['response_types']): + for type in self.get_usage("response_types", self._supports["response_types"]): if "code" in type: - _uri.append('code') - elif type in ["id_token", "id_token token"]: - _uri.append('implicit') + _uri.append("code") + elif type in ["id_token"]: + _uri.append("implicit") if "form_post" in self._supports: _uri.append("form_post") @@ -84,36 +77,33 @@ def _callback_uris(self, base_url, hex): callback_uri[key] = get_uri(base_url, self.callback_path[key], hex) return callback_uri - def construct_redirect_uris(self, - base_url: str, - hex: str, - callbacks: Optional[dict] = None): + def construct_redirect_uris(self, base_url: str, hex: str, callbacks: Optional[dict] = None): if not callbacks: callbacks = self._callback_uris(base_url, hex) if callbacks: - self.set_preference('callbacks', callbacks) + self.set_preference("callbacks", callbacks) self.set_preference("redirect_uris", [v for k, v in callbacks.items()]) self.callback = callbacks - def verify_rules(self): + def verify_rules(self, supports): return True def locals(self, info): pass def _keyjar(self, keyjar=None, conf=None, entity_id=""): - _uri_path = '' + _uri_path = "" if keyjar is None: if "keys" in conf: keys_args = {k: v for k, v in conf["keys"].items() if k != "uri_path"} _keyjar = init_key_jar(**keys_args) - _uri_path = conf['keys'].get('uri_path') + _uri_path = conf["keys"].get("uri_path") elif "key_conf" in conf and conf["key_conf"]: keys_args = {k: v for k, v in conf["key_conf"].items() if k != "uri_path"} _keyjar = init_key_jar(**keys_args) - _uri_path = conf['key_conf'].get('uri_path') + _uri_path = conf["key_conf"].get("uri_path") else: _keyjar = KeyJar() if "jwks" in conf: @@ -129,9 +119,9 @@ def _keyjar(self, keyjar=None, conf=None, entity_id=""): return _keyjar, _uri_path else: if "keys" in conf: - _uri_path = conf['keys'].get('uri_path') + _uri_path = conf["keys"].get("uri_path") elif "key_conf" in conf and conf["key_conf"]: - _uri_path = conf['key_conf'].get('uri_path') + _uri_path = conf["key_conf"].get("uri_path") return keyjar, _uri_path @@ -157,22 +147,21 @@ def handle_keys(self, configuration: dict, keyjar: Optional[KeyJar] = None): keyjar = _kj # now that keys are in the Key Jar, now for how to publish it - if 'jwks_uri' in configuration: # simple - _jwks_uri = configuration.get('jwks_uri') + if "jwks_uri" in configuration: # simple + _jwks_uri = configuration.get("jwks_uri") elif uri_path: _base_url = self.get_base_url(configuration) _jwks_uri = add_path(_base_url, uri_path) else: # jwks or nothing _jwks = self.get_jwks(keyjar) - return {'keyjar': keyjar, 'jwks': _jwks, 'jwks_uri': _jwks_uri} + return {"keyjar": keyjar, "jwks": _jwks, "jwks_uri": _jwks_uri} - def load_conf(self, - configuration: dict, - supports: dict, - keyjar: Optional[KeyJar] = None) -> KeyJar: + def load_conf( + self, configuration: dict, supports: dict, keyjar: Optional[KeyJar] = None + ) -> KeyJar: for attr, val in configuration.items(): - if attr == "preference": + if attr in ["preference", "capabilities"]: for k, v in val.items(): if k in supports: self.set_preference(k, v) @@ -182,12 +171,12 @@ def load_conf(self, self.locals(configuration) for key, val in self.handle_keys(configuration, keyjar=keyjar).items(): - if key == 'keyjar': + if key == "keyjar": keyjar = val elif val: self.set_preference(key, val) - self.verify_rules() + self.verify_rules(supports) return keyjar def get(self, key, default=None): @@ -227,7 +216,8 @@ def get_claim(self, key, default=None): else: return _val -SIGNING_ALGORITHM_SORT_ORDER = ['RS', 'ES', 'PS', 'HS'] + +SIGNING_ALGORITHM_SORT_ORDER = ["RS", "ES", "PS", "HS"] def cmp(a, b): @@ -235,9 +225,9 @@ def cmp(a, b): def alg_cmp(a, b): - if a == 'none': + if a == "none": return 1 - elif b == 'none': + elif b == "none": return -1 _pos1 = SIGNING_ALGORITHM_SORT_ORDER.index(a[0:2]) @@ -252,12 +242,15 @@ def alg_cmp(a, b): def get_signing_algs(): # Assumes Cryptojwt - return sorted(list(SIGNER_ALGS.keys()), key=cmp_to_key(alg_cmp)) + _list = list(SIGNER_ALGS.keys()) + # know how to do none but should not + _list.remove("none") + return sorted(_list, key=cmp_to_key(alg_cmp)) def get_encryption_algs(): - return SUPPORTED['alg'] + return SUPPORTED["alg"] def get_encryption_encs(): - return SUPPORTED['enc'] + return SUPPORTED["enc"] diff --git a/src/idpyoidc/client/claims/__init__.py b/src/idpyoidc/client/claims/__init__.py index f303e9e2..12a0a358 100644 --- a/src/idpyoidc/client/claims/__init__.py +++ b/src/idpyoidc/client/claims/__init__.py @@ -11,16 +11,15 @@ def get_client_authn_methods(): class Claims(claims.Claims): - def get_base_url(self, configuration: dict): - _base = configuration.get('base_url') + _base = configuration.get("base_url") if not _base: - _base = configuration.get('client_id') + _base = configuration.get("client_id") return _base def get_id(self, configuration: dict): - return self.get_preference('client_id') + return self.get_preference("client_id") def _add_key_if_missing(self, keyjar, id, key): try: @@ -33,12 +32,12 @@ def _add_key_if_missing(self, keyjar, id, key): keyjar.add_symmetric(issuer_id=id, key=key) def add_extra_keys(self, keyjar, id): - _secret = self.get_preference('client_secret') + _secret = self.get_preference("client_secret") if _secret: if keyjar is None: keyjar = KeyJar() self._add_key_if_missing(keyjar, id, _secret) - self._add_key_if_missing(keyjar, '', _secret) + self._add_key_if_missing(keyjar, "", _secret) def get_jwks(self, keyjar): if keyjar is None: @@ -46,14 +45,17 @@ def get_jwks(self, keyjar): _jwks = None try: - _own_keys = keyjar.get_issuer_keys('') + _own_keys = keyjar.get_issuer_keys("") except IssuerNotFound: pass else: # if only one key under the id == "", that key being a SYMKey I assume it's # and I have a client_secret then don't publish a JWKS - if len(_own_keys) == 1 and isinstance(_own_keys[0], SYMKey) and self.prefer[ - 'client_secret']: + if ( + len(_own_keys) == 1 + and isinstance(_own_keys[0], SYMKey) + and self.prefer["client_secret"] + ): pass else: _jwks = keyjar.export_jwks() diff --git a/src/idpyoidc/client/claims/oauth2.py b/src/idpyoidc/client/claims/oauth2.py index a979faa9..9d093d40 100644 --- a/src/idpyoidc/client/claims/oauth2.py +++ b/src/idpyoidc/client/claims/oauth2.py @@ -10,7 +10,7 @@ class Claims(claims.Claims): "grant_types_supported": ["authorization_code", "implicit", "refresh_token"], "response_types_supported": ["code"], "client_id": None, - 'client_secret': None, + "client_secret": None, "client_name": None, "client_uri": None, "logo_uri": None, @@ -21,16 +21,14 @@ class Claims(claims.Claims): "jwks_uri": None, "jwks": None, "software_id": None, - "software_version": None + "software_version": None, } callback_path = {} callback_uris = ["redirect_uris"] - def __init__(self, - prefer: Optional[dict] = None, - callback_path: Optional[dict] = None): + def __init__(self, prefer: Optional[dict] = None, callback_path: Optional[dict] = None): claims.Claims.__init__(self, prefer=prefer, callback_path=callback_path) def create_registration_request(self): diff --git a/src/idpyoidc/client/claims/oidc.py b/src/idpyoidc/client/claims/oidc.py index dfad0c17..787e3443 100644 --- a/src/idpyoidc/client/claims/oidc.py +++ b/src/idpyoidc/client/claims/oidc.py @@ -40,24 +40,22 @@ PREFERRED2REGISTER = dict([(v, k) for k, v in REGISTER2PREFERRED.items()]) REQUEST2REGISTER = { - 'client_id': "client_id", + "client_id": "client_id", "client_secret": "client_secret", # 'acr_values': "default_acr_values" , # 'max_age': "default_max_age", - 'redirect_uri': "redirect_uris", - 'response_type': "response_types", - 'request_uri': "request_uris", - 'grant_type': "grant_types", - "scope": 'scopes_supported', - 'post_logout_redirect_uri': "post_logout_redirect_uris" + "redirect_uri": "redirect_uris", + "response_type": "response_types", + "request_uri": "request_uris", + "grant_type": "grant_types", + "scope": "scopes_supported", + "post_logout_redirect_uri": "post_logout_redirect_uris", } class Claims(client_claims.Claims): parameter = client_claims.Claims.parameter.copy() - parameter.update({ - "requests_dir": None - }) + parameter.update({"requests_dir": None}) register2preferred = REGISTER2PREFERRED registration_response = RegistrationResponse @@ -92,32 +90,38 @@ class Claims(client_claims.Claims): "tos_uri": None, } - def __init__(self, - prefer: Optional[dict] = None, - callback_path: Optional[dict] = None - ): - client_claims.Claims.__init__(self, - prefer=prefer, - callback_path=callback_path) + def __init__(self, prefer: Optional[dict] = None, callback_path: Optional[dict] = None): + client_claims.Claims.__init__(self, prefer=prefer, callback_path=callback_path) - def verify_rules(self): + def verify_rules(self, supports): if self.get_preference("request_parameter_supported") and self.get_preference( - "request_uri_parameter_supported"): + "request_uri_parameter_supported" + ): raise ValueError( "You have to chose one of 'request_parameter_supported' and " - "'request_uri_parameter_supported'. You can't have both.") - - if not self.get_preference('encrypt_userinfo_supported'): - self.set_preference('userinfo_encryption_alg_values_supported', []) - self.set_preference('userinfo_encryption_enc_values_supported', []) - - if not self.get_preference('encrypt_request_object_supported'): - self.set_preference('request_object_encryption_alg_values_supported', []) - self.set_preference('request_object_encryption_enc_values_supported', []) - - if not self.get_preference('encrypt_id_token_supported'): - self.set_preference('id_token_encryption_alg_values_supported', []) - self.set_preference('id_token_encryption_enc_values_supported', []) + "'request_uri_parameter_supported'. You can't have both." + ) + + if self.get_preference("request_parameter_supported") or self.get_preference( + "request_uri_parameter_supported" + ): + if not self.get_preference("request_object_signing_alg_values_supported"): + self.set_preference( + "request_object_signing_alg_values_supported", + supports["request_object_signing_alg_values_supported"], + ) + + if not self.get_preference("encrypt_userinfo_supported"): + self.set_preference("userinfo_encryption_alg_values_supported", []) + self.set_preference("userinfo_encryption_enc_values_supported", []) + + if not self.get_preference("encrypt_request_object_supported"): + self.set_preference("request_object_encryption_alg_values_supported", []) + self.set_preference("request_object_encryption_enc_values_supported", []) + + if not self.get_preference("encrypt_id_token_supported"): + self.set_preference("id_token_encryption_alg_values_supported", []) + self.set_preference("id_token_encryption_enc_values_supported", []) def locals(self, info): requests_dir = info.get("requests_dir") diff --git a/src/idpyoidc/client/claims/transform.py b/src/idpyoidc/client/claims/transform.py index 744f1a77..ac63b2c8 100644 --- a/src/idpyoidc/client/claims/transform.py +++ b/src/idpyoidc/client/claims/transform.py @@ -21,6 +21,7 @@ "subject_type": "subject_types_supported", "token_endpoint_auth_method": "token_endpoint_auth_methods_supported", "response_types": "response_types_supported", + "response_modes": "response_modes_supported", "grant_types": "grant_types_supported", # In OAuth2 but not in OIDC "scope": "scopes_supported", @@ -36,24 +37,25 @@ PREFERRED2REGISTER = dict([(v, k) for k, v in REGISTER2PREFERRED.items()]) REQUEST2REGISTER = { - 'client_id': "client_id", + "client_id": "client_id", "client_secret": "client_secret", # 'acr_values': "default_acr_values" , # 'max_age': "default_max_age", - 'redirect_uri': "redirect_uris", - 'response_type': "response_types", - 'request_uri': "request_uris", - 'grant_type': "grant_types", - "scope": 'scopes_supported', - 'post_logout_redirect_uri': "post_logout_redirect_uris" + "redirect_uri": "redirect_uris", + "response_type": "response_types", + "request_uri": "request_uris", + "grant_type": "grant_types", + "scope": "scopes_supported", + "post_logout_redirect_uri": "post_logout_redirect_uris", } -def supported_to_preferred(supported: dict, - preference: dict, - base_url: str, - info: Optional[dict] = None, - ): +def supported_to_preferred( + supported: dict, + preference: dict, + base_url: str, + info: Optional[dict] = None, +): if info: # The provider info for key, val in supported.items(): if key in preference: @@ -61,7 +63,7 @@ def supported_to_preferred(supported: dict, _info_val = info.get(key) if _info_val: # Only use provider setting if less or equal to what I support - if key.endswith('supported'): # list + if key.endswith("supported"): # list preference[key] = [x for x in _pref_val if x in _info_val] else: pass @@ -72,7 +74,7 @@ def supported_to_preferred(supported: dict, # there is a default _info_val = info.get(key) if _info_val: # The OP has an opinion - if key.endswith('supported'): # list + if key.endswith("supported"): # list preference[key] = [x for x in val if x in _info_val] else: pass @@ -80,11 +82,11 @@ def supported_to_preferred(supported: dict, preference[key] = val # special case -> must have a request_uris value - if 'require_request_uri_registration' in info: + if "require_request_uri_registration" in info: # only makes sense if I want to use request_uri - if preference.get('request_parameter') == 'request_uri': - if 'request_uri' not in preference: - preference['request_uris'] = [f'{base_url}/requests'] + if preference.get("request_parameter") == "request_uri": + if "request_uri" not in preference: + preference["request_uris"] = [f"{base_url}/requests"] else: # just ignore logger.info('Asked for "request_uri" which it did not plan to use') else: @@ -121,6 +123,7 @@ def _is_subset(a, b): else: return a == b + def _intersection(a, b): res = None if isinstance(a, list): @@ -138,8 +141,10 @@ def _intersection(a, b): res = [] return res -def preferred_to_registered(prefers: dict, supported: dict, - registration_response: Optional[dict] = None): + +def preferred_to_registered( + prefers: dict, supported: dict, registration_response: Optional[dict] = None +): """ The claims with values that are returned from the OP is what goes unless (!!) the values returned are not within the supported values. @@ -159,13 +164,14 @@ def preferred_to_registered(prefers: dict, supported: dict, registered[key] = val else: logger.warning( - f'OP tells me to do something I do not support: {key} = {val} not within ' - f'{_supports}') + f"OP tells me to do something I do not support: {key} = {val} not within " + f"{_supports}" + ) _val = _intersection(val, _supports) if _val: registered[key] = _val else: - raise ValueError(f'Not able to support the OPs choice: {key}={val}') + raise ValueError(f"Not able to support the OPs choice: {key}={val}") else: registered[key] = val # Should I just accept with the OP says ?? diff --git a/src/idpyoidc/client/client_auth.py b/src/idpyoidc/client/client_auth.py index 6bcff13d..7e49969a 100755 --- a/src/idpyoidc/client/client_auth.py +++ b/src/idpyoidc/client/client_auth.py @@ -98,7 +98,7 @@ def _get_passwd(request, service, **kwargs): try: passwd = request["client_secret"] except KeyError: - passwd = service.upstream_get("context").get_usage('client_secret') + passwd = service.upstream_get("context").get_usage("client_secret") return passwd @staticmethod @@ -136,8 +136,8 @@ def _with_or_without_client_id(request, service): :param service: A :py:class:`idpyoidc.client.service.Service` instance """ if ( - isinstance(request, AccessTokenRequest) - and request["grant_type"] == "authorization_code" + isinstance(request, AccessTokenRequest) + and request["grant_type"] == "authorization_code" ): if "client_id" not in request: try: @@ -223,7 +223,7 @@ def modify_request(self, request, service, **kwargs): try: request["client_secret"] = kwargs["client_secret"] except (KeyError, TypeError): - request["client_secret"] = _context.get_usage('client_secret') + request["client_secret"] = _context.get_usage("client_secret") if not request["client_secret"]: raise AuthnFailure("Missing client secret") @@ -442,9 +442,7 @@ def _get_signing_key(self, algorithm, keyjar, key_types, kid=None): signing_key = [self._get_key_by_kid(kid, algorithm, keyjar)] elif ktype in key_types: try: - signing_key = [ - self._get_key_by_kid(key_types[ktype], algorithm, keyjar) - ] + signing_key = [self._get_key_by_kid(key_types[ktype], algorithm, keyjar)] except KeyError: signing_key = self.get_signing_key_from_keyjar(algorithm, keyjar) else: @@ -470,9 +468,7 @@ def _get_audience_and_algorithm(self, context, keyjar, **kwargs): algorithm = "RS256" # default else: for alg in algs: # pick the first one I support and have keys for - if alg in SIGNER_ALGS and self.get_signing_key_from_keyjar( - alg, keyjar - ): + if alg in SIGNER_ALGS and self.get_signing_key_from_keyjar(alg, keyjar): algorithm = alg break @@ -487,12 +483,13 @@ def _get_audience_and_algorithm(self, context, keyjar, **kwargs): def _construct_client_assertion(self, service, **kwargs): _context = service.upstream_get("context") _entity = service.upstream_get("entity") - _keyjar = service.upstream_get('attribute', 'keyjar') + _keyjar = service.upstream_get("attribute", "keyjar") audience, algorithm = self._get_audience_and_algorithm(_context, _keyjar, **kwargs) if "kid" in kwargs: - signing_key = self._get_signing_key(algorithm, _keyjar, _context.kid["sig"], - kid=kwargs["kid"]) + signing_key = self._get_signing_key( + algorithm, _keyjar, _context.kid["sig"], kid=kwargs["kid"] + ) else: signing_key = self._get_signing_key(algorithm, _keyjar, _context.kid["sig"]) diff --git a/src/idpyoidc/client/configure.py b/src/idpyoidc/client/configure.py index 3a7fa911..50740986 100755 --- a/src/idpyoidc/client/configure.py +++ b/src/idpyoidc/client/configure.py @@ -25,16 +25,15 @@ class RPHConfiguration(Base): - def __init__( - self, - conf: Dict, - base_path: Optional[str] = "", - entity_conf: Optional[List[dict]] = None, - domain: Optional[str] = "127.0.0.1", - port: Optional[int] = 80, - file_attributes: Optional[List[str]] = None, - dir_attributes: Optional[List[str]] = None, + self, + conf: Dict, + base_path: Optional[str] = "", + entity_conf: Optional[List[dict]] = None, + domain: Optional[str] = "127.0.0.1", + port: Optional[int] = 80, + file_attributes: Optional[List[str]] = None, + dir_attributes: Optional[List[str]] = None, ): Base.__init__( @@ -71,7 +70,7 @@ def __init__( self.clients = lower_or_upper(conf, "clients") if self.clients: for id, client in self.clients.items(): - for param in ["services", "usage", "add_ons", 'claims']: + for param in ["services", "usage", "add_ons", "claims"]: if param not in client: if param in self.default: client[param] = self.default[param] @@ -88,17 +87,17 @@ def __init__( class Configuration(Base): - """ Configuration for a single RP """ + """Configuration for a single RP""" def __init__( - self, - conf: Dict, - base_path: str = "", - entity_conf: Optional[List[dict]] = None, - file_attributes: Optional[List[str]] = None, - domain: Optional[str] = "", - port: Optional[int] = 0, - dir_attributes: Optional[List[str]] = None, + self, + conf: Dict, + base_path: str = "", + entity_conf: Optional[List[dict]] = None, + file_attributes: Optional[List[str]] = None, + domain: Optional[str] = "", + port: Optional[int] = 0, + dir_attributes: Optional[List[str]] = None, ): Base.__init__( self, diff --git a/src/idpyoidc/client/current.py b/src/idpyoidc/client/current.py index 196ec19b..435ea9ea 100644 --- a/src/idpyoidc/client/current.py +++ b/src/idpyoidc/client/current.py @@ -51,22 +51,19 @@ def set(self, key: str, info: Union[Message, dict]): def get_claim(self, key: str, claim: str) -> Union[str, None]: return self.get(key).get(claim) - def get_set(self, - key: str, - message: Optional[type(Message)] = None, - claim: Optional[list] = None) -> dict: + def get_set( + self, key: str, message: Optional[type(Message)] = None, claim: Optional[list] = None + ) -> dict: """ - @param key: The key to a seet of current claims + @param key: The key to a set of current claims @param message: A message class @param claim: A list of claims @return: Dictionary + @raise KeyError if no such key """ - try: - _current = self.get(key) - except KeyError: - return {} + _current = self.get(key) if message: _res = {k: _current[k] for k in message.c_param.keys() if k in _current} diff --git a/src/idpyoidc/client/defaults.py b/src/idpyoidc/client/defaults.py index b8d50659..fbacba9b 100644 --- a/src/idpyoidc/client/defaults.py +++ b/src/idpyoidc/client/defaults.py @@ -31,10 +31,7 @@ "response_types": [ "code", "id_token", - "id_token token", "code id_token", - "code id_token token", - "code token", ], "token_endpoint_auth_method": "client_secret_basic", "scopes_supported": ["openid"], @@ -48,6 +45,7 @@ # Using PKCE is default DEFAULT_CLIENT_CONFIGS = { "": { + "client_type": "oidc", "preference": DEFAULT_CLIENT_PREFERENCES, "add_ons": { "pkce": { @@ -71,6 +69,8 @@ } OIDCONF_PATTERN = "{}/.well-known/openid-configuration" +OAUTH2_SERVER_METADATA_URL = "{}/.well-known/oauth-authorization-server" + CC_METHOD = { "S256": hashlib.sha256, "S384": hashlib.sha384, @@ -92,3 +92,13 @@ SAML2_BEARER_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:saml2-bearer" BASECHR = string.ascii_letters + string.digits + +DEFAULT_RESPONSE_MODE = { + "code": "query", + "id_token": "fragment", + "token": "fragment", + "code token": "fragment", + "code id_token": "fragment", + "id_token token": "fragment", + "code id_token token": "fragment", +} diff --git a/src/idpyoidc/client/entity.py b/src/idpyoidc/client/entity.py index 9e8b7a8e..d9f29ca7 100644 --- a/src/idpyoidc/client/entity.py +++ b/src/idpyoidc/client/entity.py @@ -24,10 +24,10 @@ RESPONSE_TYPES2GRANT_TYPES = { "code": ["authorization_code"], "id_token": ["implicit"], - "id_token token": ["implicit"], + # "id_token token": ["implicit"], "code id_token": ["authorization_code", "implicit"], - "code token": ["authorization_code", "implicit"], - "code id_token token": ["authorization_code", "implicit"], + # "code token": ["authorization_code", "implicit"], + # "code id_token token": ["authorization_code", "implicit"], } @@ -48,7 +48,7 @@ def response_types_to_grant_types(response_types): def _set_jwks(service_context, config: Configuration, keyjar: Optional[KeyJar]): - _key_conf = config.get("key_conf") or config.conf.get('key_conf') + _key_conf = config.get("key_conf") or config.conf.get("key_conf") if _key_conf: keys_args = {k: v for k, v in _key_conf.items() if k != "uri_path"} @@ -71,44 +71,51 @@ def set_jwks_uri_or_jwks(service_context, config, jwks_uri, keyjar): def redirect_uris_from_callback_uris(callback_uris): res = [] - for k, v in callback_uris['redirect_uris'].items(): + for k, v in callback_uris["redirect_uris"].items(): res.extend(v) return res class Entity(Unit): # This is a Client. What type is undefined here. parameter = { - 'entity_id': None, - 'jwks_uri': None, - 'httpc_params': None, - 'key_conf': None, - 'keyjar': KeyJar, - 'context': None + "entity_id": None, + "jwks_uri": None, + "httpc_params": None, + "key_conf": None, + "keyjar": KeyJar, + "context": None, } def __init__( - self, - keyjar: Optional[KeyJar] = None, - config: Optional[Union[dict, Configuration]] = None, - services: Optional[dict] = None, - jwks_uri: Optional[str] = "", - httpc: Optional[Callable] = None, - httpc_params: Optional[dict] = None, - client_type: Optional[str] = "oauth2", - context: Optional[OidcContext] = None, - upstream_get: Optional[Callable] = None, - key_conf: Optional[dict] = None, - entity_id: Optional[str] = '' + self, + keyjar: Optional[KeyJar] = None, + config: Optional[Union[dict, Configuration]] = None, + services: Optional[dict] = None, + jwks_uri: Optional[str] = "", + httpc: Optional[Callable] = None, + httpc_params: Optional[dict] = None, + client_type: Optional[str] = "oauth2", + context: Optional[OidcContext] = None, + upstream_get: Optional[Callable] = None, + key_conf: Optional[dict] = None, + entity_id: Optional[str] = "", ): if config is None: config = {} - _id = config.get('client_id') - self.client_id = self.entity_id = entity_id or config.get('entity_id', _id) + _id = config.get("client_id") + self.client_id = self.entity_id = entity_id or config.get("entity_id", _id) - Unit.__init__(self, upstream_get=upstream_get, keyjar=keyjar, httpc=httpc, - httpc_params=httpc_params, config=config, key_conf=key_conf, - client_id=self.client_id) + Unit.__init__( + self, + upstream_get=upstream_get, + keyjar=keyjar, + httpc=httpc, + httpc_params=httpc_params, + config=config, + key_conf=key_conf, + client_id=self.client_id, + ) if services: _srvs = services @@ -118,7 +125,7 @@ def __init__( _srvs = None if not _srvs: - if client_type == 'oauth2': + if client_type == "oauth2": _srvs = DEFAULT_OAUTH2_SERVICES else: _srvs = DEFAULT_OIDC_SERVICES @@ -128,8 +135,13 @@ def __init__( if context: self.context = context else: - self.context = ServiceContext(config=config, jwks_uri=jwks_uri, keyjar=self.keyjar, - upstream_get=self.unit_get, client_type=client_type) + self.context = ServiceContext( + config=config, + jwks_uri=jwks_uri, + keyjar=self.keyjar, + upstream_get=self.unit_get, + client_type=client_type, + ) self.setup_client_authn_methods(config) @@ -161,11 +173,11 @@ def get_entity(self): return self def get_client_id(self): - _val = self.context.claims.get_usage('client_id') + _val = self.context.claims.get_usage("client_id") if _val: return _val else: - return self.context.claims.get_preference('client_id') + return self.context.claims.get_preference("client_id") def setup_client_authn_methods(self, config): if config and "client_authn_methods" in config: @@ -183,7 +195,7 @@ def import_keys(self, keyspec): :param keyspec: """ - _keyjar = self.get_attribute('keyjar') + _keyjar = self.get_attribute("keyjar") if _keyjar is None: _keyjar = KeyJar() @@ -192,8 +204,7 @@ def import_keys(self, keyspec): for typ, files in spec.items(): if typ == "rsa": for fil in files: - _key = RSAKey(priv_key=import_private_rsa_key_from_file(fil), - use="sig") + _key = RSAKey(priv_key=import_private_rsa_key_from_file(fil), use="sig") _bundle = KeyBundle() _bundle.append(_key) _keyjar.add_kb("", _bundle) diff --git a/src/idpyoidc/client/oauth2/__init__.py b/src/idpyoidc/client/oauth2/__init__.py index ab15c941..312b0cfa 100755 --- a/src/idpyoidc/client/oauth2/__init__.py +++ b/src/idpyoidc/client/oauth2/__init__.py @@ -1,5 +1,5 @@ -from json import JSONDecodeError import logging +from json import JSONDecodeError from typing import Callable from typing import Optional from typing import Union @@ -12,8 +12,8 @@ from idpyoidc.client.exception import OidcServiceError from idpyoidc.client.exception import ParseError from idpyoidc.client.service import REQUEST_INFO -from idpyoidc.client.service import SUCCESSFUL from idpyoidc.client.service import Service +from idpyoidc.client.service import SUCCESSFUL from idpyoidc.client.util import do_add_ons from idpyoidc.client.util import get_deserialization_method from idpyoidc.configure import Configuration @@ -26,8 +26,6 @@ logger = logging.getLogger(__name__) -Version = "2.0" - class ExpiredToken(Exception): pass @@ -37,7 +35,8 @@ class ExpiredToken(Exception): class Client(Entity): - client_type = 'oauth2' + client_type = "oauth2" + def __init__( self, keyjar: Optional[KeyJar] = None, @@ -48,7 +47,7 @@ def __init__( context: Optional[OidcContext] = None, upstream_get: Optional[Callable] = None, key_conf: Optional[dict] = None, - entity_id: Optional[str] = '', + entity_id: Optional[str] = "", verify_ssl: Optional[bool] = True, jwks_uri: Optional[str] = "", client_type: Optional[str] = "", @@ -69,15 +68,21 @@ def __init__( :return: Client instance """ - if not client_type: + if client_type: + self.client_type = client_type + elif config and 'client_type' in config: + client_type = self.client_type = config["client_type"] + else: client_type = self.client_type if verify_ssl is False: # just ignore verify_ssl until it goes away if httpc_params: - httpc_params['verify'] = False + httpc_params["verify"] = False else: - httpc_params = {'verify': False} + httpc_params = {"verify": False} + + jwks_uri = jwks_uri or config.get('jwks_uri', '') Entity.__init__( self, @@ -91,7 +96,7 @@ def __init__( context=context, upstream_get=upstream_get, key_conf=key_conf, - entity_id=entity_id + entity_id=entity_id, ) self.httpc = httpc or request @@ -163,7 +168,7 @@ def get_response( if resp.status_code < 300: if "keyjar" not in kwargs: - kwargs["keyjar"] = self.get_attribute('keyjar') + kwargs["keyjar"] = self.get_attribute("keyjar") if not response_body_type: response_body_type = service.response_body_type @@ -217,7 +222,7 @@ def service_request( if "error" in response: pass else: - service.update_service_context(response, key=kwargs.get('state'), **kwargs) + service.update_service_context(response, key=kwargs.get("state"), **kwargs) return response def parse_request_response(self, service, reqresp, response_body_type="", state="", **kwargs): @@ -311,17 +316,20 @@ def dynamic_provider_info_discovery(client: Client, behaviour_args: Optional[dic :param behaviour_args: :param client: A :py:class:`idpyoidc.client.oidc.Client` instance """ + + if client.client_type == 'oidc' and client.get_service("provider_info"): + service = 'provider_info' + elif client.client_type == 'oauth2' and client.get_service('server_metadata'): + service = 'server_metadata' + else: + raise ConfigurationError("Can not do dynamic provider info discovery") + + _context = client.get_context() try: - client.get_service("provider_info") + _context.set("issuer", _context.config["srv_discovery_url"]) except KeyError: - raise ConfigurationError("Can not do dynamic provider info discovery") - else: - _context = client.get_context() - try: - _context.set("issuer", _context.config["srv_discovery_url"]) - except KeyError: - pass + pass - response = client.do_request("provider_info", behaviour_args=behaviour_args) - if is_error_message(response): - raise OidcServiceError(response["error"]) + response = client.do_request(service, behaviour_args=behaviour_args) + if is_error_message(response): + raise OidcServiceError(response["error"]) diff --git a/src/idpyoidc/client/oauth2/access_token.py b/src/idpyoidc/client/oauth2/access_token.py index a100a830..83f1ef96 100644 --- a/src/idpyoidc/client/oauth2/access_token.py +++ b/src/idpyoidc/client/oauth2/access_token.py @@ -27,7 +27,7 @@ class AccessToken(Service): request_body_type = "urlencoded" response_body_type = "json" - _include = {"grant_types_supported": ['authorization_code']} + _include = {"grant_types_supported": ["authorization_code"]} _supports = { "token_endpoint_auth_methods_supported": get_client_authn_methods, @@ -38,7 +38,7 @@ def __init__(self, upstream_get, conf=None): Service.__init__(self, upstream_get, conf=conf) self.pre_construct.append(self.oauth_pre_construct) - def update_service_context(self, resp, key: Optional[str] = '', **kwargs): + def update_service_context(self, resp, key: Optional[str] = "", **kwargs): if "expires_in" in resp: resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) if key: diff --git a/src/idpyoidc/client/oauth2/add_on/dpop.py b/src/idpyoidc/client/oauth2/add_on/dpop.py index 3122a55a..9c8e7f23 100644 --- a/src/idpyoidc/client/oauth2/add_on/dpop.py +++ b/src/idpyoidc/client/oauth2/add_on/dpop.py @@ -1,18 +1,26 @@ +import logging import uuid +from hashlib import sha256 from typing import Optional +from cryptography.hazmat.primitives import hashes from cryptojwt.jwk.jwk import key_from_jwk_dict from cryptojwt.jws.jws import JWS from cryptojwt.jws.jws import factory +from cryptojwt.jws.jws import SIGNER_ALGS from cryptojwt.key_bundle import key_by_alg +from idpyoidc.claims import get_signing_algs from idpyoidc.client.service_context import ServiceContext +from idpyoidc.message import SINGLE_OPTIONAL_STRING from idpyoidc.message import SINGLE_REQUIRED_INT from idpyoidc.message import SINGLE_REQUIRED_JSON from idpyoidc.message import SINGLE_REQUIRED_STRING from idpyoidc.message import Message from idpyoidc.time_util import utc_time_sans_frac +logger = logging.getLogger(__name__) + class DPoPProof(Message): c_param = { @@ -25,9 +33,10 @@ class DPoPProof(Message): "htm": SINGLE_REQUIRED_STRING, "htu": SINGLE_REQUIRED_STRING, "iat": SINGLE_REQUIRED_INT, + "ath": SINGLE_OPTIONAL_STRING, } header_params = {"typ", "alg", "jwk"} - body_params = {"jti", "htm", "htu", "iat"} + body_params = {"jti", "htm", "htu", "iat", "ath"} def __init__(self, set_defaults=True, **kwargs): self.key = None @@ -56,7 +65,7 @@ def verify(self, **kwargs): raise ValueError("'none' is not allowed as signing algorithm") def create_header(self) -> str: - payload = {k: self[k] for k in self.body_params} + payload = {k: self[k] for k in self.body_params if k in self} _jws = JWS(payload, alg=self["alg"]) _jws_headers = {k: self[k] for k in self.header_params} _signed_jwt = _jws.sign_compact(keys=[self.key], **_jws_headers) @@ -88,6 +97,8 @@ def dpop_header( service_endpoint: str, http_method: str, headers: Optional[dict] = None, + token: Optional[str] = "", + nonce: Optional[str] = "", **kwargs ) -> dict: """ @@ -95,36 +106,36 @@ def dpop_header( :param service_context: :param service_endpoint: :param http_method: - :param headers: + :param headers: The HTTP headers to which the DPoP header should be added. + :param token: If the DPoP Proof is sent together with an access token this should lead to + the addition of the ath claim (hash of the token as value) + :param nonce: AS or RS provided nonce. :param kwargs: :return: """ provider_info = service_context.provider_info - dpop_key = service_context.add_on["dpop"].get("key") + _dpop_conf = service_context.add_on.get("dpop") + if not _dpop_conf: + logger.warning("Asked to do dpop when I do not support it") + return headers - if not dpop_key: - algs_supported = provider_info["dpop_signing_alg_values_supported"] - if not algs_supported: # does not support DPoP - return headers + dpop_key = _dpop_conf.get("key") - chosen_alg = "" - for alg in service_context.add_on["dpop"]["sign_algs"]: - if alg in algs_supported: - chosen_alg = alg - break + if not dpop_key: + chosen_alg = _dpop_conf.get("algs_supported", [])[0] if not chosen_alg: return headers # Mint a new key dpop_key = key_by_alg(chosen_alg) - service_context.add_on["dpop"]["key"] = dpop_key - service_context.add_on["dpop"]["alg"] = chosen_alg + _dpop_conf["key"] = dpop_key + _dpop_conf["alg"] = chosen_alg header_dict = { "typ": "dpop+jwt", - "alg": service_context.add_on["dpop"]["alg"], + "alg": _dpop_conf["alg"], "jwk": dpop_key.serialize(), "jti": uuid.uuid4().hex, "htm": http_method, @@ -132,6 +143,12 @@ def dpop_header( "iat": utc_time_sans_frac(), } + if token: + header_dict["ath"] = sha256(token.encode("utf8")).hexdigest() + + if nonce: + header_dict["nonce"] = nonce + _dpop = DPoPProof(**header_dict) _dpop.key = dpop_key jws = _dpop.create_header() @@ -155,10 +172,15 @@ def add_support(services, dpop_signing_alg_values_supported): # Access token request should use DPoP header _service = services["accesstoken"] _context = _service.upstream_get("context") + _algs_supported = [ + alg for alg in dpop_signing_alg_values_supported if alg in get_signing_algs() + ] _context.add_on["dpop"] = { # "key": key_by_alg(signing_algorithm), - "sign_algs": dpop_signing_alg_values_supported + "algs_supported": _algs_supported } + _context.set_preference("dpop_signing_alg_values_supported", _algs_supported) + _service.construct_extra_headers.append(dpop_header) # The same for userinfo requests diff --git a/src/idpyoidc/client/oauth2/add_on/identity_assurance.py b/src/idpyoidc/client/oauth2/add_on/identity_assurance.py index ea1253cd..a50d3366 100644 --- a/src/idpyoidc/client/oauth2/add_on/identity_assurance.py +++ b/src/idpyoidc/client/oauth2/add_on/identity_assurance.py @@ -35,8 +35,7 @@ def format_response(format, response, verified_response): def identity_assurance_process(response, service_context, state): - auth_request = service_context.cstate.get_set(state, - message=AuthorizationRequest) + auth_request = service_context.cstate.get_set(state, message=AuthorizationRequest) claims_request = auth_request.get("claims") if claims_request and "userinfo" in claims_request: vc = VerifiedClaims(**response["verified_claims"]) diff --git a/src/idpyoidc/client/oauth2/add_on/jar.py b/src/idpyoidc/client/oauth2/add_on/jar.py new file mode 100644 index 00000000..349050ce --- /dev/null +++ b/src/idpyoidc/client/oauth2/add_on/jar.py @@ -0,0 +1,226 @@ +import logging +from typing import Optional + +from idpyoidc import claims +from idpyoidc.client.oidc.utils import construct_request_uri +from idpyoidc.client.oidc.utils import request_object_encryption +from idpyoidc.message.oidc import make_openid_request +from idpyoidc.time_util import utc_time_sans_frac + +logger = logging.getLogger(__name__) + +DEFAULT_EXPIRES_IN = 3600 + + +def store_request_on_file(service, req, **kwargs): + """ + Stores the request parameter in a file. + :param req: The request + :param kwargs: Extra keyword arguments + :return: The URL the OP should use to access the file + """ + _context = service.upstream_get("context") + _webname = _context.get_usage("request_uris") + if _webname is None: + filename, _webname = construct_request_uri(**kwargs) + else: + # webname should be a list + _webname = _webname[0] + filename = _context.filename_from_webname(_webname) + + fid = open(filename, mode="w") + fid.write(req) + fid.close() + return _webname + + +def get_request_object_signing_alg(service, **kwargs): + alg = "" + for arg in ["request_object_signing_alg", "algorithm"]: + try: # Trumps everything + alg = kwargs[arg] + except KeyError: + pass + else: + break + + if not alg: + _context = service.upstream_get("context") + alg = _context.add_on["jar"].get("request_object_signing_alg") + if alg is None: + alg = "RS256" + return alg + + +def construct_request_parameter(service, req, audience=None, **kwargs): + """Construct a request parameter""" + alg = get_request_object_signing_alg(service, **kwargs) + kwargs["request_object_signing_alg"] = alg + + _context = service.upstream_get("context") + if "keys" not in kwargs and alg and alg != "none": + kwargs["keys"] = service.upstream_get("attribute", "keyjar") + + if alg == "none": + kwargs["keys"] = [] + + # This is the issuer of the JWT, that is me ! + _issuer = kwargs.get("issuer") + if _issuer is None: + kwargs["issuer"] = _context.get_client_id() + + if kwargs.get("recv") is None: + try: + kwargs["recv"] = _context.provider_info["issuer"] + except KeyError: + kwargs["recv"] = _context.issuer + + try: + del kwargs["service"] + except KeyError: + pass + + _jar_conf = _context.add_on["jar"] + expires_in = _jar_conf.get("expires_in", DEFAULT_EXPIRES_IN) + if expires_in: + req["exp"] = utc_time_sans_frac() + int(expires_in) + + if _jar_conf.get("with_jti", False): + kwargs["with_jti"] = True + + _enc_enc = _jar_conf.get("request_object_encryption_enc", "") + if _enc_enc: + kwargs["request_object_encryption_enc"] = _enc_enc + kwargs["request_object_encryption_alg"] = _jar_conf.get("request_object_encryption_alg") + + # Filter out only the arguments I want + _mor_args = { + k: kwargs[k] + for k in [ + "keys", + "issuer", + "request_object_signing_alg", + "recv", + "with_jti", + "lifetime", + ] + if k in kwargs + } + + if audience: + _mor_args["aud"] = audience + + _req_jwt = make_openid_request(req, **_mor_args) + + if "target" not in kwargs: + kwargs["target"] = _context.provider_info.get("issuer", _context.issuer) + + # Should the request be encrypted + _req_jwte = request_object_encryption( + _req_jwt, _context, service.upstream_get("attribute", "keyjar"), **kwargs + ) + return _req_jwte + + +def jar_post_construct(request_args, service, **kwargs): + """ + Modify the request arguments. + + :param request_args: The request + :param service: The service that uses this post_constructor + :param kwargs: Extra keyword arguments + :return: A possibly modified request. + """ + _context = service.upstream_get("context") + + # Overrides what's in the configuration + _request_param = kwargs.get("request_param") + _local_dir = "" + if _request_param: + del kwargs["request_param"] + else: + _jar_config = _context.add_on["jar"] + if "request_uri" in _context.add_on["jar"]: + _request_param = "request_uri" + _local_dir = _jar_config.get("requests_dir", "./requests") + elif "request_parameter" in _jar_config: + _request_param = "request" + + _req = None # just a flag + _state = request_args["state"] + if _request_param == "request_uri": + kwargs["base_path"] = _context.get("base_url") + "/" + "requests" + if _local_dir: + kwargs["local_dir"] = _local_dir + else: + kwargs["local_dir"] = kwargs.get("requests_dir", "./requests") + + _req = construct_request_parameter(service, request_args, _request_param, **kwargs) + request_args["request_uri"] = store_request_on_file(service, _req, **kwargs) + elif _request_param == "request": + _req = construct_request_parameter(service, request_args, **kwargs) + request_args["request"] = _req + + if _req: + _leave = ["request", "request_uri"] + _leave.extend(request_args.required_parameters()) + _keys = [k for k in request_args.keys() if k not in _leave] + for k in _keys: + del request_args[k] + + _context.cstate.update(_state, request_args) + + return request_args + + +def add_support( + service, + request_type: Optional[str] = "request_parameter", + request_dir: Optional[str] = "", + request_object_signing_alg: Optional[str] = "RS256", + expires_in: Optional[int] = DEFAULT_EXPIRES_IN, + with_jti: Optional[bool] = False, + request_object_encryption_alg: Optional[str] = "", + request_object_encryption_enc: Optional[str] = "", +): + """ + JAR support can only be considered if this client can access an authorization service. + + :param service: Dictionary of services + :return: + """ + if "authorization" in service: + _service = service["authorization"] + _context = _service.upstream_get("context") + + _service.post_construct.append(jar_post_construct) + args = { + "request_object_signing_alg": request_object_signing_alg, + "expires_in": expires_in, + "with_jti": with_jti, + } + if request_type == "request_parameter": + args["request_parameter"] = True + elif request_type == "request_uri": + args["request_uri"] = True + if request_dir: + args["request_dir"] = request_dir + + if request_object_encryption_enc and request_object_encryption_alg: + if request_object_encryption_enc in claims.get_encryption_encs(): + if request_object_encryption_alg in claims.get_encryption_algs(): + args["request_object_encryption_enc"] = request_object_encryption_enc + args["request_object_encryption_alg"] = request_object_encryption_alg + else: + AttributeError( + f"An encryption alg {request_object_encryption_alg} there is no support " + f"for" + ) + else: + AttributeError( + f"An encryption enc {request_object_encryption_enc} there is no support for" + ) + + _context.add_on["jar"] = args + else: + logger.warning("JAR support could NOT be added") diff --git a/src/idpyoidc/client/oauth2/add_on/par.py b/src/idpyoidc/client/oauth2/add_on/par.py new file mode 100644 index 00000000..03416c5d --- /dev/null +++ b/src/idpyoidc/client/oauth2/add_on/par.py @@ -0,0 +1,121 @@ +import logging + +from cryptojwt import JWT +from cryptojwt.utils import importer +from requests import request + +from idpyoidc.client.client_auth import CLIENT_AUTHN_METHOD +from idpyoidc.message import Message +from idpyoidc.message.oauth2 import JWTSecuredAuthorizationRequest +from idpyoidc.util import instantiate + +logger = logging.getLogger(__name__) + + +def push_authorization(request_args, service, **kwargs): + """ + :param request_args: All the request arguments as a AuthorizationRequest instance + :param service: The service to which this post construct method is applied. + :param kwargs: Extra keyword arguments. + """ + + _context = service.upstream_get("context") + method_args = _context.add_on["pushed_authorization"] + if method_args["apply"] is False: + return request_args + + _http_method = method_args["http_client"] + + # Add client authentication if needed + _headers = {} + authn_method = method_args["authn_method"] + if authn_method: + if authn_method not in _context.client_authn_methods: + _context.client_authn_methods[authn_method] = CLIENT_AUTHN_METHOD[authn_method]() + + _args = {} + if _context.issuer: + _args["iss"] = _context.issuer + _headers = service.get_headers( + request_args, http_method=_http_method, authn_method=authn_method, **_args + ) + + # construct the message body + if method_args["body_format"] == "urlencoded": + _body = request_args.to_urlencoded() + else: + _jwt = JWT( + key_jar=service.upstream_get("attribute", "keyjar"), + iss=_context.claims.prefer["client_id"], + ) + _jws = _jwt.pack(request_args.to_dict()) + + _msg = Message(request=_jws) + for param in request_args.required_parameters(): + _msg[param] = request_args.get(param) + + _body = _msg.to_urlencoded() + + # Send it to the Pushed Authorization Request Endpoint + resp = _http_method( + method="GET", + url=_context.provider_info["pushed_authorization_request_endpoint"], + data=_body, + headers=_headers, + ) + + if resp.status_code == 200: + _resp = Message().from_json(resp.text) + _req = JWTSecuredAuthorizationRequest(request_uri=_resp["request_uri"]) + for param in request_args.required_parameters(): + _req[param] = request_args.get(param) + request_args = _req + else: + raise ConnectionError( + f"Could not connect to " + f'{_context.provider_info["pushed_authorization_request_endpoint"]}' + ) + + return request_args + + +def add_support( + services, + body_format="jws", + signing_algorithm="RS256", + http_client=None, + merge_rule="strict", + authn_method="", +): + """ + Add the necessary pieces to support Pushed authorization. + + :param merge_rule: + :param http_client: + :param signing_algorithm: + :param services: A dictionary with all the services the client has access to. + :param body_format: jws or urlencoded + """ + + if http_client is None: + _http_client = request + else: + if isinstance(http_client, dict): + if "class" in http_client: + _http_client = instantiate(http_client["class"], **http_client.get("kwargs", {})) + else: + _http_client = importer(http_client["function"]) + else: + _http_client = importer(http_client) + + _service = services["authorization"] + _service.upstream_get("context").add_on["pushed_authorization"] = { + "body_format": body_format, + "signing_algorithm": signing_algorithm, + "http_client": _http_client, + "merge_rule": merge_rule, + "apply": True, + "authn_method": authn_method, + } + + _service.post_construct.append(push_authorization) diff --git a/src/idpyoidc/client/oauth2/add_on/pkce.py b/src/idpyoidc/client/oauth2/add_on/pkce.py index f9491975..738067f6 100644 --- a/src/idpyoidc/client/oauth2/add_on/pkce.py +++ b/src/idpyoidc/client/oauth2/add_on/pkce.py @@ -69,7 +69,7 @@ def add_code_verifier(request_args, service, **kwargs): _state = request_args.get("state") if _state is None: _state = kwargs.get("state") - _item = service.upstream_get("context").cstate.get_set(_state, claim=['code_verifier']) + _item = service.upstream_get("context").cstate.get_set(_state, claim=["code_verifier"]) request_args.update(_item) return request_args diff --git a/src/idpyoidc/client/oauth2/add_on/pushed_authorization.py b/src/idpyoidc/client/oauth2/add_on/pushed_authorization.py deleted file mode 100644 index 611a0008..00000000 --- a/src/idpyoidc/client/oauth2/add_on/pushed_authorization.py +++ /dev/null @@ -1,83 +0,0 @@ -import logging - -from cryptojwt import JWT -from requests import request - -from idpyoidc.message import Message -from idpyoidc.message.oauth2 import JWTSecuredAuthorizationRequest - -logger = logging.getLogger(__name__) - - -def push_authorization(request_args, service, **kwargs): - """ - :param request_args: All the request arguments as a AuthorizationRequest instance - :param service: The service to which this post construct method is applied. - :param kwargs: Extra keyword arguments. - """ - - _context = service.upstream_get("context") - method_args = _context.add_on["pushed_authorization"] - if method_args['apply'] is False: - return request_args - - # construct the message body - if method_args["body_format"] == "urlencoded": - _body = request_args.to_urlencoded() - else: - _jwt = JWT(key_jar=service.upstream_get('attribute', 'keyjar'), - iss=_context.base_url) - _jws = _jwt.pack(request_args.to_dict()) - - _msg = Message(request=_jws) - if method_args["merge_rule"] == "lax": - for param in request_args.required_parameters(): - _msg[param] = request_args.get(param) - - _body = _msg.to_urlencoded() - - # Send it to the Pushed Authorization Request Endpoint - resp = method_args["http_client"]( - method="GET", - url=_context.provider_info["pushed_authorization_request_endpoint"], - data=_body - ) - - if resp.status_code == 200: - _resp = Message().from_json(resp.text) - _req = JWTSecuredAuthorizationRequest(request_uri=_resp["request_uri"]) - if method_args["merge_rule"] == "lax": - for param in request_args.required_parameters(): - _req[param] = request_args.get(param) - request_args = _req - - return request_args - - -def add_support( - services, body_format="jws", signing_algorithm="RS256", http_client=None, - merge_rule="strict" -): - """ - Add the necessary pieces to make Demonstration of proof of possession (DPOP). - - :param merge_rule: - :param http_client: - :param signing_algorithm: - :param services: A dictionary with all the services the client has access to. - :param body_format: jws or urlencoded - """ - - if http_client is None: - http_client = request - - _service = services["authorization"] - _service.upstream_get("context").add_on["pushed_authorization"] = { - "body_format": body_format, - "signing_algorithm": signing_algorithm, - "http_client": http_client, - "merge_rule": merge_rule, - 'apply': True - } - - _service.post_construct.append(push_authorization) diff --git a/src/idpyoidc/client/oauth2/authorization.py b/src/idpyoidc/client/oauth2/authorization.py index 39f5ff7d..1ce76728 100644 --- a/src/idpyoidc/client/oauth2/authorization.py +++ b/src/idpyoidc/client/oauth2/authorization.py @@ -3,12 +3,12 @@ from typing import List from typing import Optional -from idpyoidc import claims from idpyoidc.client.oauth2.utils import get_state_parameter from idpyoidc.client.oauth2.utils import pre_construct_pick_redirect_uri from idpyoidc.client.oauth2.utils import set_state_parameter from idpyoidc.client.service import Service from idpyoidc.client.service_context import ServiceContext +from idpyoidc.client.util import IMPLICIT_RESPONSE_TYPES from idpyoidc.client.util import implicit_response_types from idpyoidc.exception import MissingParameter from idpyoidc.message import oauth2 @@ -30,8 +30,8 @@ class Authorization(Service): response_body_type = "urlencoded" _supports = { - "response_types_supported": ["code", 'token'], - "response_modes_supported": ['query', 'fragment'], + "response_types_supported": ["code"], + "response_modes_supported": ["query", "fragment"], # Below not OAuth2 functionality # "request_object_signing_alg_values_supported": claims.get_signing_algs, # "request_object_encryption_alg_values_supported": claims.get_encryption_algs, @@ -40,9 +40,9 @@ class Authorization(Service): } _callback_path = { - "redirect_uris": { # based on response_types - "code": "authz_cb", - "implicit": "authz_im_cb", + "redirect_uris": { # based on response_mode + "query": "authz_cb", + "fragment": "authz_im_cb", # "form_post": "form" } } @@ -68,8 +68,7 @@ def gather_request_args(self, **kwargs): if "redirect_uri" not in ar_args: try: - ar_args["redirect_uri"] = self.upstream_get("context").get_usage( - "redirect_uris")[0] + ar_args["redirect_uri"] = self.upstream_get("context").get_usage("redirect_uris")[0] except (KeyError, AttributeError): raise MissingParameter("redirect_uri") @@ -93,64 +92,82 @@ def post_parse_response(self, response, **kwargs): else: if _key: item = self.upstream_get("context").cstate.get_set( - _key, message=oauth2.AuthorizationRequest) + _key, message=oauth2.AuthorizationRequest + ) try: response["scope"] = item["scope"] except KeyError: pass return response - def _do_flow(self, flow_type, response_types): - if flow_type == 'code' and 'code' in response_types: - return True - elif flow_type == 'implicit': + def _do_flow(self, flow_type, response_types, context) -> str: + if flow_type == "query": + if "code" in response_types: + return "query" + elif flow_type == "fragment": if implicit_response_types(response_types): - return True - return False + return "fragment" + elif flow_type == 'form_post': + rm = context.get_preference('response_modes_supported') + if rm and 'form_post' in rm: + if context.config.conf.get("separate_form_post_cb", True): + return "form_post" + else: + return "query" + return '' def _do_redirect_uris(self, base_url, hex, context, callback_uris, response_types): - _redirect_uris = context.get_preference('redirect_uris', []) + _redirect_uris = context.get_preference("redirect_uris", []) if _redirect_uris: - if not callback_uris or 'redirect_uris' not in callback_uris: + if not callback_uris or "redirect_uris" not in callback_uris: # the same redirect_uris for all flow types - callback_uris['redirect_uris'] = {} - for flow_type in self._callback_path['redirect_uris'].keys(): - if self._do_flow(flow_type, response_types): - callback_uris['redirect_uris'][flow_type] = _redirect_uris + callback_uris["redirect_uris"] = {} + for flow_type in self._callback_path["redirect_uris"].keys(): + if self._do_flow(flow_type, response_types, context): + callback_uris["redirect_uris"][flow_type] = _redirect_uris elif callback_uris: - if 'redirect_uris' in callback_uris: + if "redirect_uris" in callback_uris: pass else: - callback_uris['redirect_uris'] = {} - for flow_type, path in self._callback_path['redirect_uris'].items(): - if self._do_flow(flow_type, response_types): - callback_uris['redirect_uris'][flow_type] = [ - self.get_uri(base_url, path, hex)] + callback_uris["redirect_uris"] = {} + for flow_type in self._callback_path["redirect_uris"].keys(): + _var = self._do_flow(flow_type, response_types, context) + if _var: + _path = self._callback_path["redirect_uris"][_var] + callback_uris["redirect_uris"][flow_type] = [ + self.get_uri(base_url, _path, hex) + ] else: - callback_uris['redirect_uris'] = {} - for flow_type, path in self._callback_path['redirect_uris'].items(): - if self._do_flow(flow_type, response_types): - callback_uris['redirect_uris'][flow_type] = [self.get_uri(base_url, path, hex)] + callback_uris["redirect_uris"] = {} + for flow_type in self._callback_path["redirect_uris"].keys(): + _var = self._do_flow(flow_type, response_types, context) + if _var: + _path = self._callback_path["redirect_uris"][_var] + callback_uris["redirect_uris"][flow_type] = [self.get_uri(base_url, _path, hex)] return callback_uris - def construct_uris(self, - base_url: str, - hex: bytes, - context: ServiceContext, - targets: Optional[List[str]] = None, - response_types: Optional[List[str]] = None): - _callback_uris = context.get_preference('callback_uris', {}) + def construct_uris( + self, + base_url: str, + hex: bytes, + context: ServiceContext, + targets: Optional[List[str]] = None, + response_types: Optional[List[str]] = None, + ): + _callback_uris = context.get_preference("callback_uris", {}) for uri_name in self._callback_path.keys(): - if uri_name == 'redirect_uris': - _callback_uris = self._do_redirect_uris(base_url, hex, context, _callback_uris, - response_types) + if uri_name == "redirect_uris": + _callback_uris = self._do_redirect_uris( + base_url, hex, context, _callback_uris, response_types + ) _redirect_uris = set() - for flow, _uris in _callback_uris['redirect_uris'].items(): + for flow, _uris in _callback_uris["redirect_uris"].items(): _redirect_uris.update(set(_uris)) - context.set_preference('redirect_uris', list(_redirect_uris)) + context.set_preference("redirect_uris", list(_redirect_uris)) else: - _callback_uris[uri_name] = self.get_uri(base_url, self._callback_path[uri_name], - hex) + _callback_uris[uri_name] = self.get_uri( + base_url, self._callback_path[uri_name], hex + ) return _callback_uris diff --git a/src/idpyoidc/client/oauth2/client_credentials.py b/src/idpyoidc/client/oauth2/client_credentials.py index 3c7459de..4eedb465 100644 --- a/src/idpyoidc/client/oauth2/client_credentials.py +++ b/src/idpyoidc/client/oauth2/client_credentials.py @@ -24,16 +24,14 @@ def __init__(self, upstream_get, conf=None): Service.__init__(self, upstream_get, conf=conf) self.pre_construct.append(self.cc_pre_construct) - def cc_pre_construct(self, - request: Union[Message, dict], - service: Service, - post_args: Optional[dict], - **_args): - _grant_type = request.get('grant_type') + def cc_pre_construct( + self, request: Union[Message, dict], service: Service, post_args: Optional[dict], **_args + ): + _grant_type = request.get("grant_type") if not _grant_type: - request['grant_type'] = 'client_credentials' - elif _grant_type != 'client_credentials': - logging.error('Wrong grant_type') + request["grant_type"] = "client_credentials" + elif _grant_type != "client_credentials": + logging.error("Wrong grant_type") return request, post_args diff --git a/src/idpyoidc/client/oauth2/introspection.py b/src/idpyoidc/client/oauth2/introspection.py new file mode 100644 index 00000000..419d3e8f --- /dev/null +++ b/src/idpyoidc/client/oauth2/introspection.py @@ -0,0 +1,27 @@ +"""The service that talks to the OAuth2 refresh access token endpoint.""" +import logging +from typing import Optional + +from idpyoidc.client.oauth2.utils import get_state_parameter +from idpyoidc.client.service import Service +from idpyoidc.message import oauth2 +from idpyoidc.message.oauth2 import ResponseMessage +from idpyoidc.time_util import time_sans_frac + +LOGGER = logging.getLogger(__name__) + + +class Introspection(Service): + """The service that talks to the OAuth2 introspection endpoint.""" + + msg_type = oauth2.TokenIntrospectionRequest + response_cls = oauth2.TokenIntrospectionResponse + error_msg = oauth2.ResponseMessage + endpoint_name = "introspection_endpoint" + synchronous = True + service_name = "introspection" + default_authn_method = "client_secret_basic" + http_method = "POST" + + def __init__(self, upstream_get, conf=None): + Service.__init__(self, upstream_get, conf=conf) diff --git a/src/idpyoidc/client/oauth2/refresh_access_token.py b/src/idpyoidc/client/oauth2/refresh_access_token.py index 69400787..a88e251d 100644 --- a/src/idpyoidc/client/oauth2/refresh_access_token.py +++ b/src/idpyoidc/client/oauth2/refresh_access_token.py @@ -20,10 +20,10 @@ class RefreshAccessToken(Service): endpoint_name = "token_endpoint" synchronous = True service_name = "refresh_token" - default_authn_method = "bearer_header" + default_authn_method = "client_secret_post" http_method = "POST" - _include = {"grant_types_supported": ['refresh_token']} + _include = {"grant_types_supported": ["refresh_token"]} def __init__(self, upstream_get, conf=None): Service.__init__(self, upstream_get, conf=conf) diff --git a/src/idpyoidc/client/oauth2/registration.py b/src/idpyoidc/client/oauth2/registration.py new file mode 100644 index 00000000..19da4982 --- /dev/null +++ b/src/idpyoidc/client/oauth2/registration.py @@ -0,0 +1,110 @@ +import logging + +from cryptojwt import KeyJar + +from idpyoidc.client.entity import response_types_to_grant_types +from idpyoidc.client.service import Service +from idpyoidc.message import oauth2 +from idpyoidc.message.oauth2 import ResponseMessage + +__author__ = "Roland Hedberg" + +logger = logging.getLogger(__name__) + + +class Registration(Service): + msg_type = oauth2.OauthClientMetadata + response_cls = oauth2.OauthClientInformationResponse + error_msg = ResponseMessage + endpoint_name = "registration_endpoint" + synchronous = True + service_name = "registration" + request_body_type = "json" + http_method = "POST" + + callback_path = {} + + def __init__(self, upstream_get, conf=None): + Service.__init__(self, upstream_get, conf=conf) + self.pre_construct = [self.add_client_preference] + self.post_construct = [self.oauth2_post_construct] + + def add_client_preference(self, request_args=None, **kwargs): + _context = self.upstream_get("context") + _use = _context.map_preferred_to_registered() + for prop, spec in self.msg_type.c_param.items(): + if prop in request_args: + continue + + _val = _use.get(prop) + if _val: + if isinstance(_val, list): + if isinstance(spec[0], list): + request_args[prop] = _val + else: + request_args[prop] = _val[0] # get the first one + else: + request_args[prop] = _val + return request_args, {} + + def oauth2_post_construct(self, request_args=None, **kwargs): + try: + request_args["grant_types"] = response_types_to_grant_types( + request_args["response_types"] + ) + except KeyError: + pass + + # If a Client can use jwks_uri, it MUST NOT use jwks. + if "jwks_uri" in request_args and "jwks" in request_args: + del request_args["jwks"] + + return request_args + + def update_service_context(self, resp, key="", **kwargs): + # if "token_endpoint_auth_method" not in resp: + # resp["token_endpoint_auth_method"] = "client_secret_basic" + + _context = self.upstream_get("context") + _context.map_preferred_to_registered(resp) + + _context.registration_response = resp + _client_id = _context.get_usage("client_id") + if _client_id: + _context.client_id = _client_id + _keyjar = self.upstream_get("attribute", "keyjar") + if _keyjar: + if _client_id not in _keyjar: + _keyjar.import_jwks(_keyjar.export_jwks(True, ""), issuer_id=_client_id) + _client_secret = _context.get_usage("client_secret") + if _client_secret: + if not _keyjar: + _entity = self.upstream_get("unit") + _keyjar = _entity.keyjar = KeyJar() + + _context.client_secret = _client_secret + _keyjar.add_symmetric("", _client_secret) + _keyjar.add_symmetric(_client_id, _client_secret) + try: + _context.set_usage("client_secret_expires_at", resp["client_secret_expires_at"]) + except KeyError: + pass + + try: + _context.set_usage("registration_access_token", resp["registration_access_token"]) + except KeyError: + pass + + def gather_request_args(self, **kwargs): + """ + + @param kwargs: + @return: + """ + _context = self.upstream_get("context") + req_args = _context.claims.create_registration_request() + if "request_args" in self.conf: + req_args.update(self.conf["request_args"]) + + req_args.update(kwargs) + return req_args diff --git a/src/idpyoidc/client/oauth2/resource.py b/src/idpyoidc/client/oauth2/resource.py new file mode 100644 index 00000000..efec4db6 --- /dev/null +++ b/src/idpyoidc/client/oauth2/resource.py @@ -0,0 +1,28 @@ +import logging +from typing import Optional +from typing import Union + +from idpyoidc import verified_claim_name +from idpyoidc.client.oauth2.utils import get_state_parameter +from idpyoidc.client.service import Service +from idpyoidc.claims import get_encryption_algs +from idpyoidc.claims import get_encryption_encs +from idpyoidc.claims import get_signing_algs +from idpyoidc.exception import MissingSigningKey +from idpyoidc.message import Message +from idpyoidc.message import oauth2 +from idpyoidc.message import oidc + +logger = logging.getLogger(__name__) + + +class Resource(Service): + msg_type = Message + response_cls = Message + error_msg = oauth2.ResponseMessage + endpoint_name = "" + service_name = "resource" + default_authn_method = "bearer_header" + + def __init__(self, upstream_get, conf=None): + Service.__init__(self, upstream_get, conf=conf) diff --git a/src/idpyoidc/client/oauth2/resource_owner_password_credentials.py b/src/idpyoidc/client/oauth2/resource_owner_password_credentials.py index e2148035..6db5cf76 100644 --- a/src/idpyoidc/client/oauth2/resource_owner_password_credentials.py +++ b/src/idpyoidc/client/oauth2/resource_owner_password_credentials.py @@ -24,16 +24,14 @@ def __init__(self, upstream_get, conf=None): Service.__init__(self, upstream_get, conf=conf) self.pre_construct.append(self.ropc_pre_construct) - def ropc_pre_construct(self, - request: Union[Message, dict], - service: Service, - post_args: Optional[dict], - **_args): - _grant_type = request.get('grant_type') + def ropc_pre_construct( + self, request: Union[Message, dict], service: Service, post_args: Optional[dict], **_args + ): + _grant_type = request.get("grant_type") if not _grant_type: - request['grant_type'] = 'password' - elif _grant_type != 'password': - logging.error('Wrong grant_type') + request["grant_type"] = "password" + elif _grant_type != "password": + logging.error("Wrong grant_type") return request, post_args diff --git a/src/idpyoidc/client/oauth2/server_metadata.py b/src/idpyoidc/client/oauth2/server_metadata.py index 9bc868f4..8fe4ecc9 100644 --- a/src/idpyoidc/client/oauth2/server_metadata.py +++ b/src/idpyoidc/client/oauth2/server_metadata.py @@ -4,6 +4,7 @@ from cryptojwt.key_jar import KeyJar +from idpyoidc.client.defaults import OAUTH2_SERVER_METADATA_URL from idpyoidc.client.defaults import OIDCONF_PATTERN from idpyoidc.client.exception import OidcServiceError from idpyoidc.client.service import Service @@ -23,6 +24,7 @@ class ServerMetadata(Service): synchronous = True service_name = "server_metadata" http_method = "GET" + url_pattern = OAUTH2_SERVER_METADATA_URL _supports = {} @@ -41,9 +43,9 @@ def get_endpoint(self): _iss = self.endpoint if _iss.endswith("/"): - return OIDCONF_PATTERN.format(_iss[:-1]) + return self.url_pattern.format(_iss[:-1]) - return OIDCONF_PATTERN.format(_iss) + return self.url_pattern.format(_iss) def get_request_parameters(self, method="GET", **kwargs): """ @@ -117,7 +119,7 @@ def _update_service_context(self, resp): # If I already have a Key Jar then I'll add then provider keys to # that. Otherwise, a new Key Jar is minted try: - _keyjar = self.upstream_get('attribute', 'keyjar') + _keyjar = self.upstream_get("attribute", "keyjar") if _keyjar is None: _keyjar = KeyJar() except KeyError: @@ -135,7 +137,7 @@ def _update_service_context(self, resp): _info = resp.to_dict() else: _info = resp - _context.map_supported_to_preferred(_info) + _context.map_service_against_endpoint(_info) def update_service_context(self, resp, key: Optional[str] = "", **kwargs): return self._update_service_context(resp) diff --git a/src/idpyoidc/client/oauth2/stand_alone_client.py b/src/idpyoidc/client/oauth2/stand_alone_client.py new file mode 100644 index 00000000..db14bf1f --- /dev/null +++ b/src/idpyoidc/client/oauth2/stand_alone_client.py @@ -0,0 +1,761 @@ +import logging +import sys +import traceback +from typing import List +from typing import Optional + +from cryptojwt import as_unicode +from cryptojwt.key_bundle import keybundle_from_local_file + +from idpyoidc import verified_claim_name +from idpyoidc.client.defaults import DEFAULT_RESPONSE_MODE +from idpyoidc.client.exception import ConfigurationError +from idpyoidc.client.exception import OidcServiceError +from idpyoidc.client.exception import Unsupported +from idpyoidc.client.oauth2 import Client +from idpyoidc.client.oauth2 import dynamic_provider_info_discovery +from idpyoidc.client.oauth2.utils import pick_redirect_uri +from idpyoidc.exception import MessageException +from idpyoidc.exception import MissingRequiredAttribute +from idpyoidc.exception import NotForMe +from idpyoidc.message import Message +from idpyoidc.message.oauth2 import is_error_message +from idpyoidc.message.oauth2 import ResponseMessage +from idpyoidc.message.oidc import AuthorizationRequest +from idpyoidc.message.oidc import AuthorizationResponse +from idpyoidc.message.oidc import Claims +from idpyoidc.message.oidc import OpenIDSchema +from idpyoidc.message.oidc import RegistrationRequest +from idpyoidc.message.oidc.session import BackChannelLogoutRequest +from idpyoidc.time_util import utc_time_sans_frac +from idpyoidc.util import rndstr + +logger = logging.getLogger(__name__) + + +class StandAloneClient(Client): + + def get_session_information(self, key): + """ + This is the second of the methods users of this class should know about. + It will return the complete session information as an + :py:class:`idpyoidc.client.current.Current` instance. + + :param key: The session key (state) + :return: A State instance + """ + + return self.get_context().cstate.get(key) + + def do_provider_info( + self, + behaviour_args: Optional[dict] = None, + ) -> str: + """ + Either get the provider info from configuration or through dynamic + discovery. + + :param behaviour_args: Behaviour specific attributes + :return: issuer ID + """ + logger.debug(20 * "*" + " do_provider_info " + 20 * "*") + + _context = self.get_context() + _pi = _context.get("provider_info") + if _pi is None or _pi == {}: + dynamic_provider_info_discovery(self, behaviour_args=behaviour_args) + _pi = _context.provider_info + elif len(_pi) == 1 and 'issuer' in _pi: + _context.issuer = _pi['issuer'] + dynamic_provider_info_discovery(self, behaviour_args=behaviour_args) + _pi = _context.provider_info + else: + for key, val in _pi.items(): + # All service endpoint parameters in the provider info has + # a name ending in '_endpoint' so I can look specifically + # for those + if key.endswith("_endpoint"): + for _srv in self.get_services().values(): + # Every service has an endpoint_name assigned + # when initiated. This name *MUST* match the + # endpoint names used in the provider info + if _srv.endpoint_name == key: + _srv.endpoint = val + + if "keys" in _pi: + _kj = self.get_attribute("keyjar") + for typ, _spec in _pi["keys"].items(): + if typ == "url": + for _iss, _url in _spec.items(): + _kj.add_url(_iss, _url) + elif typ == "file": + for kty, _name in _spec.items(): + if kty == "jwks": + _kj.import_jwks_from_file(_name, _context.get("issuer")) + elif kty == "rsa": # PEM file + _kb = keybundle_from_local_file(_name, "der", ["sig"]) + _kj.add_kb(_context.get("issuer"), _kb) + else: + raise ValueError("Unknown provider JWKS type: {}".format(typ)) + + _context.map_supported_to_preferred(info=_pi) + + try: + return _context.provider_info['issuer'] + except: + return _context.issuer + + def do_client_registration( + self, + request_args: Optional[dict] = None, + behaviour_args: Optional[dict] = None, + ): + """ + Prepare for and do client registration if configured to do so + + :param iss_id: Issuer ID + :param behaviour_args: To fine tune behaviour + :param client: A Client instance + :param state: A key by which the state of the session can be + retrieved + """ + + logger.debug(20 * "*" + " do_client_registration " + 20 * "*") + + _context = self.get_context() + + # This should only be interesting if the client supports Single Log Out + # if _context.callback.get("post_logout_redirect_uri") is None: + # _context.callback["post_logout_redirect_uri"] = [self.base_url] + + if not self.get_client_id(): # means I have to do dynamic client registration + if request_args is None: + request_args = {} + + if behaviour_args: + _params = RegistrationRequest().parameters() + request_args.update({k: v for k, v in behaviour_args.items() if k in _params}) + + load_registration_response(self, request_args=request_args) + else: + _context.map_preferred_to_registered() + + def _get_response_type(self, context, req_args: Optional[dict] = None): + if req_args: + return req_args.get("response_type", context.claims.get_usage("response_types")[0]) + else: + return context.claims.get_usage("response_types")[0] + + def _get_response_mode(self, context, response_type, request_args): + if request_args: + _requested = request_args.get('response_mode') + else: + _requested = None + _supported = context.claims.get_usage('response_modes') + if _requested: + if _supported and _requested not in _supported: + raise ValueError( + "You can not use a response_mode you have not stated should be supported") + + if DEFAULT_RESPONSE_MODE[response_type] == _requested: + return None + else: + return _requested + elif _supported: + _type = response_type.split(' ') + _type.sort() + response_type = " ".join(_type) + # Is it the default response mode + if DEFAULT_RESPONSE_MODE[response_type] in _supported: + return None + else: + return _supported[0] + else: + return None + + def init_authorization( + self, + req_args: Optional[dict] = None, + behaviour_args: Optional[dict] = None, + ) -> str: + """ + Constructs the URL that will redirect the user to the authorization + endpoint of the OP/AS. + + :param behaviour_args: + :param req_args: Non-default Request arguments + :return: A dictionary with 2 keys: **url** The authorization redirect + URL and **state** the key to the session information in the + state data store. + """ + + logger.debug(20 * "*" + " init_authorization " + 20 * "*") + + _context = self.get_context() + _response_type = self._get_response_type(_context, req_args) + _response_mode = self._get_response_mode(_context, _response_type, req_args) + try: + _redirect_uri = pick_redirect_uri( + _context, request_args=req_args, response_type=_response_type, + response_mode=_response_mode + ) + except KeyError: + raise Unsupported( + 'Could not pick a redirect_uri based on the given response_type and response_mode') + except [MissingRequiredAttribute, ValueError]: + raise + + request_args = { + "redirect_uri": _redirect_uri, + "response_type": _response_type, + } + + if _response_mode: + request_args['response_mode'] = _response_mode + + _nonce = '' + if self.client_type == 'oidc': + _nonce = rndstr(24) + request_args['nonce'] = _nonce + + _scope = _context.claims.get_usage("scope") + if _scope: + request_args["scope"] = _scope + + _req_args = _context.config.get("request_args") + if _req_args: + if "claims" in _req_args: + _req_args["claims"] = Claims(**_req_args["claims"]) + request_args.update(_req_args) + + if req_args is not None: + request_args.update(req_args) + + # Need a new state for a new authorization request + _current = _context.cstate + _state = _current.create_key() + request_args["state"] = _state + if _nonce: + _current.bind_key(_nonce, _state) + + _current.set(_state, {"iss": _context.get("issuer")}) + + logger.debug("Authorization request args: {}".format(request_args)) + + # if behaviour_args and "request_param" not in behaviour_args: + # _pi = _context.get("provider_info") + + _srv = self.get_service("authorization") + _info = _srv.get_request_parameters( + request_args=request_args, behaviour_args=behaviour_args + ) + logger.debug("Authorization info: {}".format(_info)) + return _info["url"] + + @staticmethod + def get_client_authn_method(self, endpoint): + """ + Return the client authentication method a client wants to use at a + specific endpoint + + :param endpoint: The endpoint at which the client has to authenticate + :return: The client authentication method + """ + if endpoint == "token_endpoint": + auth_method = self.get_context().get_usage("token_endpoint_auth_method") + if not auth_method: + return "" + else: + if isinstance(auth_method, str): + return auth_method + else: # a list + return auth_method[0] + return "" + + def get_tokens(self, state): + """ + Use the 'accesstoken' service to get an access token from the OP/AS. + + :param state: The state key (the state parameter in the + authorization request) + :return: A :py:class:`idpyoidc.message.oidc.AccessTokenResponse` or + :py:class:`idpyoidc.message.oauth2.AuthorizationResponse` + """ + logger.debug(20 * "*" + " get_tokens " + 20 * "*") + + _context = self.get_context() + _claims = _context.cstate.get_set(state, claim=["code", "redirect_uri"]) + + req_args = { + "code": _claims["code"], + "state": state, + "redirect_uri": _claims["redirect_uri"], + "grant_type": "authorization_code", + "client_id": self.get_client_id(), + "client_secret": _context.claims.get_usage("client_secret"), + } + logger.debug("request_args: {}".format(req_args)) + try: + tokenresp = self.do_request( + "accesstoken", + request_args=req_args, + authn_method=self.get_client_authn_method(self, "token_endpoint"), + state=state, + ) + except Exception: + message = traceback.format_exception(*sys.exc_info()) + logger.error(message) + raise + else: + if is_error_message(tokenresp): + raise OidcServiceError(tokenresp["error"]) + + return tokenresp + + def refresh_access_token(self, state, scope=""): + """ + Refresh an access token using a refresh_token. When asking for a new + access token the RP can ask for another scope for the new token. + + :param state: The state key (the state parameter in the + authorization request) + :param scope: What the returned token should be valid for. + :return: A :py:class:`idpyoidc.message.oidc.AccessTokenResponse` instance + """ + + logger.debug(20 * "*" + " refresh_access_token " + 20 * "*") + + if scope: + req_args = {"scope": scope} + else: + req_args = {} + + try: + tokenresp = self.do_request( + "refresh_token", + authn_method=self.get_client_authn_method(self, "token_endpoint"), + state=state, + request_args=req_args, + ) + except Exception: + message = traceback.format_exception(*sys.exc_info()) + logger.error(message) + raise + else: + if is_error_message(tokenresp): + raise OidcServiceError(tokenresp["error"]) + + return tokenresp + + def get_user_info(self, state, access_token="", **kwargs): + """ + use the access token previously acquired to get some userinfo + + :param state: The state value, this is the key into the session + data store + :param access_token: An access token + :param kwargs: Extra keyword arguments + :return: A :py:class:`idpyoidc.message.oidc.OpenIDSchema` instance + """ + + logger.debug(20 * "*" + " get_user_info " + 20 * "*") + + if not access_token: + _arg = self.get_context().cstate.get_set(state, claim=["access_token"]) + access_token = _arg["access_token"] + + request_args = {"access_token": access_token} + + resp = self.do_request("userinfo", state=state, request_args=request_args, **kwargs) + if is_error_message(resp): + raise OidcServiceError(resp["error"]) + + return resp + + @staticmethod + def userinfo_in_id_token(id_token: Message, user_info_claims: Optional[List] = None) -> dict: + """ + Given a verified ID token return all the claims that may be user information. + + :param id_token: An :py:class:`idpyoidc.message.oidc.IDToken` instance + :return: A dictionary with user information + """ + if user_info_claims is None: + user_info_claims = list(OpenIDSchema.c_param.keys()) + + res = dict([(k, id_token[k]) for k in user_info_claims if k in id_token]) + res.update(id_token.extra()) + return res + + def finalize_auth( + self, response: dict, behaviour_args: Optional[dict] = None + ): + """ + Given the response returned to the redirect_uri, parse and verify it. + + :param behaviour_args: For finetuning behaviour + :param response: The authorization response as a dictionary + :return: An :py:class:`idpyoidc.message.oidc.AuthorizationResponse` or + :py:class:`idpyoidc.message.oauth2.AuthorizationResponse` instance. + """ + + logger.debug(20 * "*" + " finalize_auth " + 20 * "*") + + _srv = self.get_service("authorization") + try: + authorization_response = _srv.parse_response( + response, sformat="dict", behaviour_args=behaviour_args + ) + except Exception as err: + logger.error("Parsing authorization_response: {}".format(err)) + message = traceback.format_exception(*sys.exc_info()) + logger.error(message) + raise + else: + logger.debug("Authz response: {}".format(authorization_response.to_dict())) + + if is_error_message(authorization_response): + return authorization_response + + _context = self.get_context() + try: + _iss = _context.cstate.get_set(authorization_response["state"], claim=["iss"]).get( + "iss" + ) + except KeyError: + raise KeyError("Unknown state value") + + try: + issuer = _context.provider_info['issuer'] + except KeyError: + issuer = _context.issuer + + if _iss != issuer: + logger.error("Issuer problem: {} != {}".format(_iss, issuer)) + # got it from the wrong bloke + raise ValueError("Impersonator {}".format(issuer)) + + _srv.update_service_context(authorization_response, key=authorization_response["state"]) + return authorization_response + + def get_access_and_id_token( + self, + authorization_response: Optional[Message] = None, + state: Optional[str] = "", + behaviour_args: Optional[dict] = None, + ): + """ + There are a number of services where access tokens and ID tokens can + occur in the response. This method goes through the possible places + based on the response_type the client uses. + + :param behaviour_args: For finetuning behaviour + :param authorization_response: The Authorization response + :param state: The state key (the state parameter in the + authorization request) + :return: A dictionary with 2 keys: **access_token** with the access + token as value and **id_token** with a verified ID Token if one + was returned otherwise None. + """ + + logger.debug(20 * "*" + " get_access_and_id_token " + 20 * "*") + + _context = self.get_context() + + resp_attr = authorization_response or _context.cstate.get_set( + state, message=AuthorizationResponse + ) + if resp_attr is None: + raise ValueError("One of authorization_response or state must be provided") + + if not state: + state = authorization_response["state"] + + _req_attr = _context.cstate.get_set(state, AuthorizationRequest) + if isinstance(_req_attr["response_type"], list): + _resp_type = set(_req_attr["response_type"]) + else: + _resp_type = set(_req_attr["response_type"].split(" ")) + + access_token = None + id_token = None + if _resp_type in [{"id_token"}, {"id_token", "token"}, {"code", "id_token", "token"}]: + id_token = authorization_response["__verified_id_token"] + + if _resp_type in [ + {"token"}, + {"id_token", "token"}, + {"code", "token"}, + {"code", "id_token", "token"}, + ]: + access_token = authorization_response["access_token"] + if behaviour_args: + if behaviour_args.get("collect_tokens", False): + # get what you can from the token endpoint + token_resp = self.get_tokens(state) + if is_error_message(token_resp): + return False, "Invalid response %s." % token_resp["error"] + # Now which access_token should I use + access_token = token_resp["access_token"] + # May or may not get an ID Token + id_token = token_resp.get("__verified_id_token") + + elif _resp_type in [{"code"}, {"code", "id_token"}]: + # get the access token + token_resp = self.get_tokens(state) + if is_error_message(token_resp): + return False, "Invalid response %s." % token_resp["error"] + + access_token = token_resp["access_token"] + # May or may not get an ID Token + id_token = token_resp.get("__verified_id_token") + + return {"access_token": access_token, "id_token": id_token} + + # noinspection PyUnusedLocal + def finalize(self, response, behaviour_args: Optional[dict] = None): + """ + The third of the high level methods that a user of this Class should + know about. + Once the consumer has redirected the user back to the + callback URL there might be a number of services that the client should + use. Which one those are defined by the client configuration. + + :param behaviour_args: For finetuning + :param issuer: Who sent the response + :param response: The Authorization response as a dictionary + :returns: A dictionary with the following keys: + **state** The key under which the session information is + stored in the data store and + **token** The access token + **id_token:: the ID Token + **userinfo** The collected user information + **session_state** If logout is supported the special session_state claim + """ + + authorization_response = self.finalize_auth(response) + if is_error_message(authorization_response): + return { + "state": authorization_response["state"], + "error": authorization_response["error"], + } + + _state = authorization_response["state"] + token = self.get_access_and_id_token( + authorization_response, state=_state, behaviour_args=behaviour_args + ) + _id_token = token.get("id_token") + logger.debug(f"ID Token: {_id_token}") + + if self.get_service("userinfo") and token["access_token"]: + inforesp = self.get_user_info( + state=authorization_response["state"], + access_token=token["access_token"], + ) + + if isinstance(inforesp, ResponseMessage) and "error" in inforesp: + return {"error": "Invalid response %s." % inforesp["error"], "state": _state} + + elif _id_token: # look for it in the ID Token + inforesp = self.userinfo_in_id_token(_id_token) + else: + inforesp = {} + + logger.debug("UserInfo: %s", inforesp) + + _context = self.get_context() + try: + _sid_support = _context.get("provider_info")["backchannel_logout_session_required"] + except KeyError: + try: + _sid_support = _context.get("provider_info")["frontchannel_logout_session_required"] + except Exception: + _sid_support = False + + if _sid_support and _id_token: + try: + sid = _id_token["sid"] + except KeyError: + pass + else: + _context.cstate.bind_key(sid, _state) + + if _id_token: + _context.cstate.bind_key(_id_token["sub"], _state) + else: + _context.cstate.bind_key(inforesp["sub"], _state) + + return { + "userinfo": inforesp, + "state": authorization_response["state"], + "token": token["access_token"], + "id_token": _id_token, + "session_state": authorization_response.get("session_state", ""), + "issuer": _context.issuer + } + + def has_active_authentication(self, state): + """ + Find out if the user has an active authentication + + :param state: + :return: True/False + """ + + # Look for an IdToken + _arg = self.get_context().cstate.get_set(state, claim=["__verified_id_token"]) + + if _arg: + _now = utc_time_sans_frac() + exp = _arg["__verified_id_token"]["exp"] + return _now < exp + else: + return False + + def get_valid_access_token(self, state: str) -> tuple: + """ + Find a valid access token. + + :param state: + :return: An access token if a valid one exists and when it + expires else raise exception. + """ + + token_info = None + indefinite = [] + now = utc_time_sans_frac() + + _context = self.get_context() + _args = _context.cstate.get_set(state, claim=["access_token", "__expires_at"]) + if "access_token" in _args: + access_token = _args["access_token"] + _exp = _args.get("__expires_at", 0) + if not _exp: # No expiry date, lives forever + indefinite.append((access_token, 0)) + else: + if _exp > now: # expires sometime in the future + token_info = (access_token, _exp) + + if indefinite: + return indefinite[0] + else: + if token_info: + return token_info + else: + raise OidcServiceError("No valid access token") + + def logout( + self, + state: str, + post_logout_redirect_uri: Optional[str] = "", + ) -> dict: + """ + Does an RP initiated logout from an OP. After logout the user will be + redirected by the OP to a URL of choice (post_logout_redirect_uri). + + :param state: Key to an active session + :param client: Which client to use + :param post_logout_redirect_uri: If a special post_logout_redirect_uri + should be used + :return: Request arguments + """ + + logger.debug(20 * "*" + " logout " + 20 * "*") + + try: + srv = self.get_service("end_session") + except KeyError: + raise OidcServiceError("Does not know how to logout") + + if post_logout_redirect_uri: + request_args = {"post_logout_redirect_uri": post_logout_redirect_uri} + else: + request_args = {} + + _info = srv.get_request_parameters(state=state, request_args=request_args) + + logger.debug(f"EndSession Request: {_info['request'].to_dict()}") + return _info + + def close( + self, state: str, post_logout_redirect_uri: Optional[str] = "" + ) -> dict: + + logger.debug(20 * "*" + " close " + 20 * "*") + + return self.logout( + state=state, post_logout_redirect_uri=post_logout_redirect_uri + ) + + def clear_session(self, state): + self.get_context().cstate.remove_state(state) + + +def backchannel_logout(client, request="", request_args=None): + """ + + :param request: URL encoded logout request + :return: + """ + if request: + req = BackChannelLogoutRequest().from_urlencoded(as_unicode(request)) + elif request_args: + req = BackChannelLogoutRequest(**request_args) + else: + raise MissingRequiredAttribute("logout_token") + + _context = client.get_context() + kwargs = { + "aud": client.get_client_id(), + "iss": _context.get("issuer"), + "keyjar": client.get_attribute("keyjar"), + "allowed_sign_alg": _context.get("registration_response").get( + "id_token_signed_response_alg", "RS256" + ), + } + + logger.debug(f"(backchannel_logout) Verifying request using: {kwargs}") + try: + req.verify(**kwargs) + except (MessageException, ValueError, NotForMe) as err: + raise MessageException("Bogus logout request: {}".format(err)) + else: + logger.debug("Request verified OK") + + # Find the subject through 'sid' or 'sub' + sub = req[verified_claim_name("logout_token")].get("sub") + sid = None + if not sub: + sid = req[verified_claim_name("logout_token")].get("sid") + + if not sub and not sid: + raise MessageException('Neither "sid" nor "sub"') + elif sub: + _state = _context.cstate.get_base_key(sub) + elif sid: + _state = _context.cstate.get_base_key(sid) + else: + _state = None + + return _state + + +def load_registration_response(client, request_args=None): + """ + If the client has been statically registered that information + must be provided during the configuration. If expected to be + done dynamically this method will do dynamic client registration. + + :param client: A :py:class:`idpyoidc.client.oidc.Client` instance + """ + if not client.get_context().get_client_id(): + try: + response = client.do_request("registration", request_args=request_args) + except KeyError: + raise ConfigurationError("No registration info") + except Exception as err: + logger.error(err) + raise + else: + if "error" in response: + raise OidcServiceError(response.to_json()) diff --git a/src/idpyoidc/client/oauth2/token_exchange.py b/src/idpyoidc/client/oauth2/token_exchange.py index 36a3658a..ab182fb9 100644 --- a/src/idpyoidc/client/oauth2/token_exchange.py +++ b/src/idpyoidc/client/oauth2/token_exchange.py @@ -27,7 +27,7 @@ class TokenExchange(Service): request_body_type = "urlencoded" response_body_type = "json" - _include = {'grant_types_supported': ['urn:ietf:params:oauth:grant-type:token-exchange']} + _include = {"grant_types_supported": ["urn:ietf:params:oauth:grant-type:token-exchange"]} def __init__(self, upstream_get, conf=None): Service.__init__(self, upstream_get, conf=conf) @@ -48,21 +48,21 @@ def oauth_pre_construct(self, request_args=None, post_args=None, **kwargs): if request_args is None: request_args = {} - if 'subject_token' not in request_args: + if "subject_token" not in request_args: try: _key = get_state_parameter(request_args, kwargs) except MissingParameter: raise MissingRequiredAttribute("subject_token") - parameters = {'access_token', 'scope'} + parameters = {"access_token", "scope"} _current = self.upstream_get("service_context").cstate _args = _current.get_set(_key, claim=parameters) request_args["subject_token"] = _args["access_token"] - request_args["subject_token_type"] = 'urn:ietf:params:oauth:token-type:access_token' - if 'scope' not in request_args and "scope" in _args: + request_args["subject_token_type"] = "urn:ietf:params:oauth:token-type:access_token" + if "scope" not in request_args and "scope" in _args: request_args["scope"] = _args["scope"] return request_args, post_args diff --git a/src/idpyoidc/client/oauth2/token_revocation.py b/src/idpyoidc/client/oauth2/token_revocation.py new file mode 100644 index 00000000..a9fbcd1f --- /dev/null +++ b/src/idpyoidc/client/oauth2/token_revocation.py @@ -0,0 +1,28 @@ +"""The service that talks to the OAuth2 refresh access token endpoint.""" +import logging +from typing import Optional + +from idpyoidc.client.oauth2.utils import get_state_parameter +from idpyoidc.client.service import Service +from idpyoidc.message import oauth2 +from idpyoidc.message.oauth2 import ResponseMessage +from idpyoidc.time_util import time_sans_frac + +LOGGER = logging.getLogger(__name__) + + +class TokenRevocation(Service): + """The service that talks to the OAuth2 refresh access token endpoint.""" + + msg_type = oauth2.TokenRevocationRequest + response_cls = oauth2.TokenRevocationResponse + error_msg = oauth2.TokenRevocationErrorResponse + endpoint_name = "revocation_endpoint" + response_body_type = "text" + synchronous = True + service_name = "token_revocation" + default_authn_method = "client_secret_basic" + http_method = "POST" + + def __init__(self, upstream_get, conf=None): + Service.__init__(self, upstream_get, conf=conf) diff --git a/src/idpyoidc/client/oauth2/utils.py b/src/idpyoidc/client/oauth2/utils.py index 15d2c04c..1933a2d0 100644 --- a/src/idpyoidc/client/oauth2/utils.py +++ b/src/idpyoidc/client/oauth2/utils.py @@ -2,6 +2,7 @@ from typing import Optional from typing import Union +from idpyoidc.client.defaults import DEFAULT_RESPONSE_MODE from idpyoidc.client.service import Service from idpyoidc.exception import MissingParameter from idpyoidc.exception import MissingRequiredAttribute @@ -24,9 +25,10 @@ def get_state_parameter(request_args, kwargs): def pick_redirect_uri( - context, - request_args: Optional[Union[Message, dict]] = None, - response_type: Optional[str] = "", + context, + request_args: Optional[Union[Message, dict]] = None, + response_type: Optional[str] = "", + response_mode: Optional[str] = "" ): if request_args is None: request_args = {} @@ -36,29 +38,34 @@ def pick_redirect_uri( _callback_uris = context.get_preference("callback_uris") if _callback_uris: - _callback_uris = _callback_uris.get("redirect_uris") - - if _callback_uris: - if not response_type: - _conf_resp_types = context.get_usage("response_types", []) - response_type = request_args.get("response_type") - if not response_type and _conf_resp_types: - response_type = _conf_resp_types[0] - - _response_mode = request_args.get("response_mode") + _redirect_uris = _callback_uris.get("redirect_uris") + _response_mode = request_args.get("response_mode") or response_mode if _response_mode: if _response_mode == "form_post": - redirect_uri = _callback_uris["form_post"][0] - elif response_type == "code" or response_type == ["code"]: - redirect_uri = _callback_uris["code"][0] + try: + redirect_uri = _redirect_uris["form_post"][0] + except KeyError: + redirect_uri = _redirect_uris["query"][0] else: - redirect_uri = _callback_uris["implicit"][0] + redirect_uri = _redirect_uris[_response_mode] else: - if 'code' == response_type: - redirect_uri = _callback_uris["code"][0] - else: - redirect_uri = _callback_uris["implicit"][0] + if not response_type: + _conf_resp_types = context.get_usage("response_types", []) + response_type = request_args.get("response_type") + if not response_type and _conf_resp_types: + response_type = _conf_resp_types[0] + + if isinstance(response_type, list): + response_type.sort() + response_type = " ".join(response_type) + + try: + _response_mode = DEFAULT_RESPONSE_MODE[response_type] + except KeyError: + raise ValueError(f"Unknown response_type: {response_type}") + + redirect_uri = _redirect_uris[_response_mode][0] logger.debug( f"pick_redirect_uris: response_type={response_type}, response_mode={_response_mode}, " @@ -76,11 +83,11 @@ def pick_redirect_uri( def pre_construct_pick_redirect_uri( - request_args: Optional[Union[Message, dict]] = None, service: Optional[Service] = None, - **kwargs + request_args: Optional[Union[Message, dict]] = None, service: Optional[Service] = None, **kwargs ): - request_args["redirect_uri"] = pick_redirect_uri(service.upstream_get("context"), - request_args=request_args) + request_args["redirect_uri"] = pick_redirect_uri( + service.upstream_get("context"), request_args=request_args + ) return request_args, {} diff --git a/src/idpyoidc/client/oidc/__init__.py b/src/idpyoidc/client/oidc/__init__.py index 7d171ef9..ab446455 100755 --- a/src/idpyoidc/client/oidc/__init__.py +++ b/src/idpyoidc/client/oidc/__init__.py @@ -61,6 +61,7 @@ "token_endpoint_auth_method": "token_endpoint_auth_methods_supported", "token_endpoint_auth_signing_alg": "token_endpoint_auth_signing_alg_values_supported", "response_types": "response_types_supported", + "response_modes": "response_modes_supported", "grant_types": "grant_types_supported", } @@ -77,21 +78,21 @@ class FetchException(Exception): class RP(oauth2.Client): - client_type = 'oidc' + client_type = "oidc" def __init__( - self, - keyjar: Optional[KeyJar] = None, - config: Optional[Union[dict, Configuration]] = None, - services: Optional[dict] = None, - httpc: Optional[Callable] = None, - httpc_params: Optional[dict] = None, - upstream_get: Optional[Callable] = None, - key_conf: Optional[dict] = None, - entity_id: Optional[str] = '', - verify_ssl: Optional[bool] = True, - jwks_uri: Optional[str] = "", - **kwargs + self, + keyjar: Optional[KeyJar] = None, + config: Optional[Union[dict, Configuration]] = None, + services: Optional[dict] = None, + httpc: Optional[Callable] = None, + httpc_params: Optional[dict] = None, + upstream_get: Optional[Callable] = None, + key_conf: Optional[dict] = None, + entity_id: Optional[str] = "", + verify_ssl: Optional[bool] = True, + jwks_uri: Optional[str] = "", + **kwargs ): self.upstream_get = upstream_get if services: @@ -111,13 +112,13 @@ def __init__( entity_id=entity_id, verify_ssl=verify_ssl, jwks_uri=jwks_uri, - client_type='oidc', + client_type="oidc", **kwargs ) _context = self.get_service_context() - if _context.get_preference('callback_uris') is None: - _context.set_preference('callback_uris', {}) + if _context.get_preference("callback_uris") is None: + _context.set_preference("callback_uris", {}) def fetch_distributed_claims(self, userinfo, callback=None): """ diff --git a/src/idpyoidc/client/oidc/access_token.py b/src/idpyoidc/client/oidc/access_token.py index c39a404d..39f778fb 100644 --- a/src/idpyoidc/client/oidc/access_token.py +++ b/src/idpyoidc/client/oidc/access_token.py @@ -23,19 +23,18 @@ class AccessToken(access_token.AccessToken): error_msg = oidc.ResponseMessage default_authn_method = "client_secret_basic" - _include = {"grant_types_supported": ['authorization_code']} + _include = {"grant_types_supported": ["authorization_code"]} _supports = { "token_endpoint_auth_methods_supported": get_client_authn_methods, - "token_endpoint_auth_signing_alg_values_supported": get_signing_algs + "token_endpoint_auth_signing_alg_values_supported": get_signing_algs, } def __init__(self, upstream_get, conf: Optional[dict] = None): access_token.AccessToken.__init__(self, upstream_get, conf=conf) def gather_verify_arguments( - self, response: Optional[Union[dict, Message]] = None, - behaviour_args: Optional[dict] = None + self, response: Optional[Union[dict, Message]] = None, behaviour_args: Optional[dict] = None ): """ Need to add some information before running verify() @@ -48,7 +47,7 @@ def gather_verify_arguments( kwargs = { "client_id": _entity.get_client_id(), "iss": _context.issuer, - "keyjar": self.upstream_get('attribute', 'keyjar'), + "keyjar": self.upstream_get("attribute", "keyjar"), "verify": True, "skew": _context.clock_skew, } diff --git a/src/idpyoidc/client/oidc/authorization.py b/src/idpyoidc/client/oidc/authorization.py index 44a7ada9..dfccc14d 100644 --- a/src/idpyoidc/client/oidc/authorization.py +++ b/src/idpyoidc/client/oidc/authorization.py @@ -35,24 +35,23 @@ class Authorization(authorization.Authorization): "request_object_signing_alg_values_supported": claims.get_signing_algs, "request_object_encryption_alg_values_supported": claims.get_encryption_algs, "request_object_encryption_enc_values_supported": claims.get_encryption_encs, - "response_types_supported": ["code", "token", "code token", 'id_token', 'id_token token', - 'code id_token', 'code idtoken token'], - 'request_parameter_supported': None, - 'request_uri_parameter_supported': None, + "response_types_supported": ["code", "id_token", "code id_token"], + "request_parameter_supported": None, + "request_uri_parameter_supported": None, "request_uris": None, "request_parameter": None, "encrypt_request_object_supported": False, "redirect_uris": None, - "response_modes_supported": ['query', 'fragment', 'form_post'] + "response_modes_supported": ["query", "fragment", "form_post"], } _callback_path = { "request_uris": ["req"], - "redirect_uris": { # based on response_types - "code": "authz_cb", - "token": "authz_tok_cb", - "form_post": "form" - } + "redirect_uris": { # based on response_mode + "query": "authz_cb", + "fragment": "authz_tok_cb", + "form_post": "authz_cb_form", + }, } def __init__(self, upstream_get, conf=None, request_args: Optional[dict] = None): @@ -66,8 +65,8 @@ def __init__(self, upstream_get, conf=None, request_args: Optional[dict] = None) self.oidc_pre_construct, ] self.post_construct = [self.oidc_post_construct] - if 'scope' not in self.default_request_args: - self.default_request_args['scope'] = ['openid'] + if "scope" not in self.default_request_args: + self.default_request_args["scope"] = ["openid"] def set_state(self, request_args, **kwargs): _context = self.upstream_get("context") @@ -80,7 +79,7 @@ def set_state(self, request_args, **kwargs): _state = _context.cstate.create_key() request_args["state"] = _state - _context.cstate.set(_state, {'iss': _context.issuer}) + _context.cstate.set(_state, {"iss": _context.issuer}) return request_args, {} def update_service_context(self, resp, key="", **kwargs): @@ -101,8 +100,11 @@ def post_parse_response(self, response, **kwargs): if _idt: # If there is a verified ID Token then we have to do nonce # verification. - _req_nonce = self.upstream_get("context").cstate.get_set( - response["state"], claim=['nonce']).get('nonce') + _req_nonce = ( + self.upstream_get("context") + .cstate.get_set(response["state"], claim=["nonce"]) + .get("nonce") + ) if _req_nonce: _id_token_nonce = _idt.get("nonce") if not _id_token_nonce: @@ -123,6 +125,7 @@ def oidc_pre_construct(self, request_args=None, post_args=None, **kwargs): if _response_types: request_args["response_type"] = _response_types[0] else: + _response_types = ["code"] request_args["response_type"] = "code" # For OIDC 'openid' is required in scope @@ -133,7 +136,7 @@ def oidc_pre_construct(self, request_args=None, post_args=None, **kwargs): else: _scope = _context.get_preference("scopes_supported") if _scope: - request_args['scope'] = _scope + request_args["scope"] = _scope else: request_args["scope"] = "openid" elif "openid" not in request_args["scope"]: @@ -210,7 +213,7 @@ def store_request_on_file(self, req, **kwargs): return _webname def construct_request_parameter( - self, req, request_param, audience=None, expires_in=0, **kwargs + self, req, request_param, audience=None, expires_in=0, **kwargs ): """Construct a request parameter""" alg = self.get_request_object_signing_alg(**kwargs) @@ -218,7 +221,7 @@ def construct_request_parameter( _context = self.upstream_get("context") if "keys" not in kwargs and alg and alg != "none": - kwargs["keys"] = self.upstream_get('attribute', 'keyjar') + kwargs["keys"] = self.upstream_get("attribute", "keyjar") if alg == "none": kwargs["keys"] = [] @@ -257,13 +260,13 @@ def construct_request_parameter( _req_jwt = make_openid_request(req, **_mor_args) - if 'target' not in kwargs: - kwargs['target'] = _context.provider_info.get("issuer", _context.issuer) + if "target" not in kwargs: + kwargs["target"] = _context.provider_info.get("issuer", _context.issuer) # Should the request be encrypted - _req_jwte = request_object_encryption(_req_jwt, _context, - self.upstream_get('attribute', 'keyjar'), - **kwargs) + _req_jwte = request_object_encryption( + _req_jwt, _context, self.upstream_get("attribute", "keyjar"), **kwargs + ) return _req_jwte def oidc_post_construct(self, req, **kwargs): @@ -316,8 +319,7 @@ def oidc_post_construct(self, req, **kwargs): return req def gather_verify_arguments( - self, response: Optional[Union[dict, Message]] = None, - behaviour_args: Optional[dict] = None + self, response: Optional[Union[dict, Message]] = None, behaviour_args: Optional[dict] = None ): """ Need to add some information before running verify() @@ -327,7 +329,7 @@ def gather_verify_arguments( _context = self.upstream_get("context") kwargs = { "iss": _context.issuer, - "keyjar": self.upstream_get('attribute', 'keyjar'), + "keyjar": self.upstream_get("attribute", "keyjar"), "verify": True, "skew": _context.clock_skew, } @@ -355,43 +357,47 @@ def gather_verify_arguments( return kwargs def _do_request_uris(self, base_url, hex, context, callback_uris): - _uri_name = 'request_uris' - if context.get_preference('request_parameter') == _uri_name: + _uri_name = "request_uris" + if context.get_preference("request_parameter") == _uri_name: if _uri_name not in callback_uris: - callback_uris[_uri_name] = self.get_uri(base_url, - self._callback_path[_uri_name], - hex) + callback_uris[_uri_name] = self.get_uri( + base_url, self._callback_path[_uri_name], hex + ) return callback_uris def _do_type(self, context, typ, response_types): - if typ == 'code' and 'code' in response_types: - if typ in context.get_preference('response_modes_supported'): - return True - elif typ == 'implicit': - if typ in context.get_preference('response_modes_supported'): + if typ == "code" and "code" in response_types: + if typ in context.get_preference("response_modes_supported"): + return "query" + elif typ == "implicit": + if typ in context.get_preference("response_modes_supported"): if implicit_response_types(response_types): - return True - elif typ == 'form_post': - if typ in context.get_preference('response_modes_supported'): - return True - return False - - def construct_uris(self, - base_url: str, - hex: bytes, - context: ServiceContext, - targets: Optional[List[str]] = None, - response_types: Optional[List[str]] = None): - _callback_uris = context.get_preference('callback_uris', {}) + return "fragment" + elif typ == "form_post": + if typ in context.get_preference("response_modes_supported"): + return "form_post" + return '' + + def construct_uris( + self, + base_url: str, + hex: bytes, + context: ServiceContext, + targets: Optional[List[str]] = None, + response_types: Optional[List[str]] = None, + ): + _callback_uris = context.get_preference("callback_uris", {}) for uri_name in self._callback_path.keys(): - if uri_name == 'redirect_uris': - _callback_uris = self._do_redirect_uris(base_url, hex, context, _callback_uris, - response_types) - elif uri_name == 'request_uris': + if uri_name == "redirect_uris": + _callback_uris = self._do_redirect_uris( + base_url, hex, context, _callback_uris, response_types + ) + elif uri_name == "request_uris": _callback_uris = self._do_request_uris(base_url, hex, context, _callback_uris) else: - _callback_uris[uri_name] = self.get_uri(base_url, self._callback_path[uri_name], - hex) + _callback_uris[uri_name] = self.get_uri( + base_url, self._callback_path[uri_name], hex + ) return _callback_uris diff --git a/src/idpyoidc/client/oidc/check_id.py b/src/idpyoidc/client/oidc/check_id.py index 3e33e3c7..569df02b 100644 --- a/src/idpyoidc/client/oidc/check_id.py +++ b/src/idpyoidc/client/oidc/check_id.py @@ -24,10 +24,7 @@ def __init__(self, upstream_get, conf=None): self.pre_construct = [self.oidc_pre_construct] def oidc_pre_construct(self, request_args: Optional[dict] = None, **kwargs): - _args = self.upstream_get("context").cstate.get_set( - kwargs["state"], - claim=["id_token"] - ) + _args = self.upstream_get("context").cstate.get_set(kwargs["state"], claim=["id_token"]) if request_args: request_args.update() else: diff --git a/src/idpyoidc/client/oidc/check_session.py b/src/idpyoidc/client/oidc/check_session.py index b089e2d3..89056bd2 100644 --- a/src/idpyoidc/client/oidc/check_session.py +++ b/src/idpyoidc/client/oidc/check_session.py @@ -23,8 +23,7 @@ def __init__(self, upstream_get, conf=None): self.pre_construct = [self.oidc_pre_construct] def oidc_pre_construct(self, request_args=None, **kwargs): - _args = self.upstream_get("context").cstate.get_set(kwargs["state"], - claim=["id_token"]) + _args = self.upstream_get("context").cstate.get_set(kwargs["state"], claim=["id_token"]) if request_args: request_args.update(_args) else: diff --git a/src/idpyoidc/client/oidc/end_session.py b/src/idpyoidc/client/oidc/end_session.py index 315e672d..4396d531 100644 --- a/src/idpyoidc/client/oidc/end_session.py +++ b/src/idpyoidc/client/oidc/end_session.py @@ -22,18 +22,18 @@ class EndSession(Service): _supports = { "post_logout_redirect_uris": None, - 'frontchannel_logout_supported': None, + "frontchannel_logout_supported": None, "frontchannel_logout_uri": None, "frontchannel_logout_session_required": None, - 'backchannel_logout_supported': None, + "backchannel_logout_supported": None, "backchannel_logout_uri": None, - "backchannel_logout_session_required": None + "backchannel_logout_session_required": None, } _callback_path = { "frontchannel_logout_uri": "fc_logout", "backchannel_logout_uri": "bc_logout", - "post_logout_redirect_uris": ["session_logout"] + "post_logout_redirect_uris": ["session_logout"], } def __init__(self, upstream_get, conf=None): @@ -53,7 +53,7 @@ def get_id_token_hint(self, request_args=None, **kwargs): :return: """ - _id_token = self.upstream_get("context").cstate.get_claim(kwargs["state"], claim='id_token') + _id_token = self.upstream_get("context").cstate.get_claim(kwargs["state"], claim="id_token") if _id_token: request_args["id_token_hint"] = _id_token diff --git a/src/idpyoidc/client/oidc/provider_info_discovery.py b/src/idpyoidc/client/oidc/provider_info_discovery.py index a05fde77..0f8b0a00 100644 --- a/src/idpyoidc/client/oidc/provider_info_discovery.py +++ b/src/idpyoidc/client/oidc/provider_info_discovery.py @@ -1,6 +1,7 @@ import logging from typing import Optional +from idpyoidc.client.defaults import OIDCONF_PATTERN from idpyoidc.client.oauth2 import server_metadata from idpyoidc.message import oidc from idpyoidc.message.oauth2 import ResponseMessage @@ -28,14 +29,15 @@ def add_redirect_uris(request_args, service=None, **kwargs): if "redirect_uris" not in request_args: # Callbacks is a dictionary with callback type 'code', 'implicit', # 'form_post' as keys. - _callback = _work_environment.get_preference('callback') + _callback = _work_environment.get_preference("callback") if _callback: # Filter out local additions. _uris = [v for k, v in _callback.items() if not k.startswith("__")] request_args["redirect_uris"] = _uris else: request_args["redirect_uris"] = _work_environment.get_preference( - "redirect_uris", _work_environment.supports.get('redirect_uris')) + "redirect_uris", _work_environment.supports.get("redirect_uris") + ) return request_args, {} @@ -45,6 +47,7 @@ class ProviderInfoDiscovery(server_metadata.ServerMetadata): response_cls = oidc.ProviderConfigurationResponse error_msg = ResponseMessage service_name = "provider_info" + url_pattern = OIDCONF_PATTERN _include = {} _supports = {} @@ -52,13 +55,14 @@ class ProviderInfoDiscovery(server_metadata.ServerMetadata): def __init__(self, upstream_get, conf=None): server_metadata.ServerMetadata.__init__(self, upstream_get, conf=conf) - def update_service_context(self, resp, key: Optional[str] = '', **kwargs): + def update_service_context(self, resp, key: Optional[str] = "", **kwargs): _context = self.upstream_get("context") self._update_service_context(resp) _context.map_supported_to_preferred(resp) if "pre_load_keys" in self.conf and self.conf["pre_load_keys"]: - _jwks = self.upstream_get('attribute', 'keyjar').export_jwks_as_json( - issuer=resp["issuer"]) + _jwks = self.upstream_get("attribute", "keyjar").export_jwks_as_json( + issuer=resp["issuer"] + ) logger.info("Preloaded keys for {}: {}".format(resp["issuer"], _jwks)) def match_preferences(self, pcr=None, issuer=None): diff --git a/src/idpyoidc/client/oidc/read_registration.py b/src/idpyoidc/client/oidc/read_registration.py index cf0a02a9..e4fdc04e 100644 --- a/src/idpyoidc/client/oidc/read_registration.py +++ b/src/idpyoidc/client/oidc/read_registration.py @@ -19,9 +19,7 @@ class RegistrationRead(Service): def get_endpoint(self): try: - return self.upstream_get("context").registration_response[ - "registration_client_uri" - ] + return self.upstream_get("context").registration_response["registration_client_uri"] except KeyError: return "" @@ -40,9 +38,7 @@ def get_authn_header(self, request, authn_method, **kwargs): if authn_method == "client_secret_basic": LOGGER.debug("Client authn method: %s", authn_method) headers["Authorization"] = "Bearer {}".format( - self.upstream_get("context").registration_response[ - "registration_access_token" - ] + self.upstream_get("context").registration_response["registration_access_token"] ) return headers diff --git a/src/idpyoidc/client/oidc/registration.py b/src/idpyoidc/client/oidc/registration.py index 3c6ac713..49339053 100644 --- a/src/idpyoidc/client/oidc/registration.py +++ b/src/idpyoidc/client/oidc/registration.py @@ -72,22 +72,21 @@ def update_service_context(self, resp, key="", **kwargs): _client_id = _context.get_usage("client_id") if _client_id: _context.client_id = _client_id - _keyjar = self.upstream_get('attribute', 'keyjar') + _keyjar = self.upstream_get("attribute", "keyjar") if _keyjar: if _client_id not in _keyjar: _keyjar.import_jwks(_keyjar.export_jwks(True, ""), issuer_id=_client_id) _client_secret = _context.get_usage("client_secret") if _client_secret: if not _keyjar: - _entity = self.upstream_get('unit') + _entity = self.upstream_get("unit") _keyjar = _entity.keyjar = KeyJar() _context.client_secret = _client_secret _keyjar.add_symmetric("", _client_secret) _keyjar.add_symmetric(_client_id, _client_secret) try: - _context.set_usage("client_secret_expires_at", - resp["client_secret_expires_at"]) + _context.set_usage("client_secret_expires_at", resp["client_secret_expires_at"]) except KeyError: pass diff --git a/src/idpyoidc/client/oidc/userinfo.py b/src/idpyoidc/client/oidc/userinfo.py index 0a4cf22b..602d2ce9 100644 --- a/src/idpyoidc/client/oidc/userinfo.py +++ b/src/idpyoidc/client/oidc/userinfo.py @@ -39,12 +39,13 @@ class UserInfo(Service): endpoint_name = "userinfo_endpoint" service_name = "userinfo" default_authn_method = "bearer_header" + response_body_type = "jose" _supports = { "userinfo_signing_alg_values_supported": get_signing_algs, "userinfo_encryption_alg_values_supported": get_encryption_algs, "userinfo_encryption_enc_values_supported": get_encryption_encs, - "encrypt_userinfo_supported": False + "encrypt_userinfo_supported": False, } def __init__(self, upstream_get, conf=None): @@ -59,8 +60,7 @@ def oidc_pre_construct(self, request_args=None, **kwargs): pass else: request_args = self.upstream_get("context").cstate.get_set( - kwargs["state"], - claim=["access_token"] + kwargs["state"], claim=["access_token"] ) return request_args, {} @@ -88,7 +88,7 @@ def post_parse_response(self, response, **kwargs): try: aggregated_claims = Message().from_jwt( spec["JWT"].encode("utf-8"), - keyjar=self.upstream_get('attribute', 'keyjar') + keyjar=self.upstream_get("attribute", "keyjar"), ) except MissingSigningKey as err: logger.warning( @@ -110,8 +110,7 @@ def post_parse_response(self, response, **kwargs): return response def gather_verify_arguments( - self, response: Optional[Union[dict, Message]] = None, - behaviour_args: Optional[dict] = None + self, response: Optional[Union[dict, Message]] = None, behaviour_args: Optional[dict] = None ): """ Need to add some information before running verify() @@ -122,7 +121,7 @@ def gather_verify_arguments( kwargs = { "client_id": _context.get_client_id(), "iss": _context.issuer, - "keyjar": self.upstream_get('attribute', 'keyjar'), + "keyjar": self.upstream_get("attribute", "keyjar"), "verify": True, "skew": _context.clock_skew, } diff --git a/src/idpyoidc/client/oidc/utils.py b/src/idpyoidc/client/oidc/utils.py index 4ccd9f1c..2b428feb 100644 --- a/src/idpyoidc/client/oidc/utils.py +++ b/src/idpyoidc/client/oidc/utils.py @@ -46,7 +46,7 @@ def request_object_encryption(msg, service_context, keyjar, **kwargs): except KeyError: _kid = "" - _target = kwargs.get('target', kwargs.get('recv', None)) + _target = kwargs.get("target", kwargs.get("recv", None)) if _target is None: raise MissingRequiredAttribute("No target specified") diff --git a/src/idpyoidc/client/oidc/webfinger.py b/src/idpyoidc/client/oidc/webfinger.py index c97e8284..e71285bc 100644 --- a/src/idpyoidc/client/oidc/webfinger.py +++ b/src/idpyoidc/client/oidc/webfinger.py @@ -49,8 +49,8 @@ def update_service_context(self, resp, key="", **kwargs): for link in links: if link["rel"] == self.rel: _href = link["href"] - _context = self.upstream_get('service_context') - _http_allowed = 'http_links' in _context.get("allow", default={}) + _context = self.upstream_get("service_context") + _http_allowed = "http_links" in _context.get("allow", default={}) if _href.startswith("http://") and not _http_allowed: raise ValueError("http link not allowed ({})".format(_href)) diff --git a/src/idpyoidc/client/provider/github.py b/src/idpyoidc/client/provider/github.py index 123b1191..7749bfb4 100644 --- a/src/idpyoidc/client/provider/github.py +++ b/src/idpyoidc/client/provider/github.py @@ -29,7 +29,7 @@ class AccessToken(access_token.AccessToken): _supports = { "token_endpoint_auth_methods_supported": get_client_authn_methods, - "token_endpoint_auth_signing_alg_values_supported": get_signing_algs + "token_endpoint_auth_signing_alg_values_supported": get_signing_algs, } diff --git a/src/idpyoidc/client/provider/linkedin.py b/src/idpyoidc/client/provider/linkedin.py index aec69216..f889dffb 100644 --- a/src/idpyoidc/client/provider/linkedin.py +++ b/src/idpyoidc/client/provider/linkedin.py @@ -35,7 +35,7 @@ class AccessToken(access_token.AccessToken): _supports = { "token_endpoint_auth_methods_supported": get_client_authn_methods, - "token_endpoint_auth_signing_alg_values_supported": get_signing_algs + "token_endpoint_auth_signing_alg_values_supported": get_signing_algs, } diff --git a/src/idpyoidc/client/rp_handler.py b/src/idpyoidc/client/rp_handler.py index 2ceb0e50..fe44054d 100644 --- a/src/idpyoidc/client/rp_handler.py +++ b/src/idpyoidc/client/rp_handler.py @@ -1,11 +1,11 @@ import logging import sys import traceback +from typing import List from typing import Optional from cryptojwt import as_unicode from cryptojwt import KeyJar -from cryptojwt.key_bundle import keybundle_from_local_file from cryptojwt.key_jar import init_key_jar from cryptojwt.utils import as_bytes @@ -15,23 +15,20 @@ from idpyoidc.client.defaults import DEFAULT_RP_KEY_DEFS from idpyoidc.client.exception import ConfigurationError from idpyoidc.client.exception import OidcServiceError +from idpyoidc.client.oauth2.stand_alone_client import StandAloneClient from idpyoidc.exception import MessageException from idpyoidc.exception import MissingRequiredAttribute from idpyoidc.exception import NotForMe from idpyoidc.message.oauth2 import is_error_message from idpyoidc.message.oidc import AuthorizationRequest from idpyoidc.message.oidc import AuthorizationResponse -from idpyoidc.message.oidc import Claims from idpyoidc.message.oidc import OpenIDSchema -from idpyoidc.message.oidc import RegistrationRequest from idpyoidc.message.oidc.session import BackChannelLogoutRequest from idpyoidc.time_util import utc_time_sans_frac from idpyoidc.util import add_path from idpyoidc.util import rndstr -from . import oidc from .oauth2 import Client -from .oauth2 import dynamic_provider_info_discovery -from .oauth2.utils import pick_redirect_uri +from ..message import Message from ..message.oauth2 import ResponseMessage logger = logging.getLogger(__name__) @@ -47,7 +44,6 @@ def __init__( keyjar=None, hash_seed="", verify_ssl=True, - client_cls=None, state_db=None, httpc=None, httpc_params=None, @@ -99,7 +95,7 @@ def __init__( self.extra = kwargs - self.client_cls = client_cls or oidc.RP + self.client_cls = StandAloneClient if services is None: self.services = DEFAULT_OIDC_SERVICES else: @@ -128,8 +124,12 @@ def state2issuer(self, state): :return: An Issuer ID """ for _rp in self.issuer2rp.values(): - _iss = _rp.get_context().cstate.get_set( - state, claim=['iss']).get('iss') + try: + _set = _rp.get_context().cstate.get_set(state, claim=["iss"]) + except KeyError: + continue + + _iss = _set.get("iss") if _iss: return _iss return None @@ -156,7 +156,7 @@ def get_session_information(self, key, client=None): if not client: client = self.get_client_from_session_key(key) - return client.get_context().cstate.get(key) + return client.get_session_information(key) def init_client(self, issuer): """ @@ -180,11 +180,11 @@ def init_client(self, issuer): except KeyError: _services = self.services - if 'base_url' not in _cnf: - _cnf['base_url'] = self.base_url + if "base_url" not in _cnf: + _cnf["base_url"] = self.base_url if self.jwks_uri: - _cnf['jwks_uri'] = self.jwks_uri + _cnf["jwks_uri"] = self.jwks_uri try: client = self.client_cls( @@ -234,54 +234,13 @@ def do_provider_info( retrieved :return: issuer ID """ - logger.debug(20 * "*" + " do_provider_info " + 20 * "*") - if not client: if state: client = self.get_client_from_session_key(state) else: raise ValueError("Missing state/session key") - _context = client.get_context() - if not _context.get("provider_info"): - dynamic_provider_info_discovery(client, behaviour_args=behaviour_args) - return _context.get("provider_info")["issuer"] - else: - _pi = _context.get("provider_info") - for key, val in _pi.items(): - # All service endpoint parameters in the provider info has - # a name ending in '_endpoint' so I can look specifically - # for those - if key.endswith("_endpoint"): - for _srv in client.get_services().values(): - # Every service has an endpoint_name assigned - # when initiated. This name *MUST* match the - # endpoint names used in the provider info - if _srv.endpoint_name == key: - _srv.endpoint = val - - if "keys" in _pi: - _kj = client.get_attribute('keyjar') - for typ, _spec in _pi["keys"].items(): - if typ == "url": - for _iss, _url in _spec.items(): - _kj.add_url(_iss, _url) - elif typ == "file": - for kty, _name in _spec.items(): - if kty == "jwks": - _kj.import_jwks_from_file(_name, _context.get("issuer")) - elif kty == "rsa": # PEM file - _kb = keybundle_from_local_file(_name, "der", ["sig"]) - _kj.add_kb(_context.get("issuer"), _kb) - else: - raise ValueError("Unknown provider JWKS type: {}".format(typ)) - - _context.map_supported_to_preferred(info=_pi) - - try: - return _context.get("provider_info")["issuer"] - except KeyError: - return _context.get("issuer") + return client.do_provider_info(behaviour_args=behaviour_args) def do_client_registration( self, @@ -301,8 +260,6 @@ def do_client_registration( retrieved """ - logger.debug(20 * "*" + " do_client_registration " + 20 * "*") - if not client: if state: client = self.get_client_from_session_key(state) @@ -313,21 +270,8 @@ def do_client_registration( _iss = _context.get("issuer") self.hash2issuer[iss_id] = _iss - # This should only be interesting if the client supports Single Log Out - # if _context.callback.get("post_logout_redirect_uri") is None: - # _context.callback["post_logout_redirect_uri"] = [self.base_url] - - if not client.get_client_id(): # means I have to do dynamic client registration - if request_args is None: - request_args = {} - - if behaviour_args: - _params = RegistrationRequest().parameters() - request_args.update({k: v for k, v in behaviour_args.items() if k in _params}) - - load_registration_response(client, request_args=request_args) - else: - _context.map_preferred_to_registered() + return client.do_client_registration(request_args=request_args, + behaviour_args=behaviour_args) def do_webfinger(self, user: str) -> Client: """ @@ -349,7 +293,7 @@ def client_setup( iss_id: Optional[str] = "", user: Optional[str] = "", behaviour_args: Optional[dict] = None, - ) -> Client: + ) -> StandAloneClient: """ First if no issuer ID is given then the identifier for the user is used by the webfinger service to try to find the issuer ID. @@ -389,18 +333,17 @@ def client_setup( return client logger.debug("Get provider info") - issuer = self.do_provider_info(client, behaviour_args=behaviour_args) + issuer = client.do_provider_info(behaviour_args=behaviour_args) logger.debug("Do client registration") - self.do_client_registration(client, iss_id, behaviour_args=behaviour_args) + client.do_client_registration(behaviour_args=behaviour_args) self.issuer2rp[issuer] = client return client def _get_response_type(self, context, req_args: Optional[dict] = None): if req_args: - return req_args.get("response_type", - context.claims.get_usage("response_types")[0]) + return req_args.get("response_type", context.claims.get_usage("response_types")[0]) else: return context.claims.get_usage("response_types")[0] @@ -410,7 +353,7 @@ def init_authorization( state: Optional[str] = "", req_args: Optional[dict] = None, behaviour_args: Optional[dict] = None, - ) -> dict: + ) -> str: """ Constructs the URL that will redirect the user to the authorization endpoint of the OP/AS. @@ -431,55 +374,13 @@ def init_authorization( else: raise ValueError("Missing state/session key") - _context = client.get_context() - # _entity = client.upstream_get("entity") - _nonce = rndstr(24) - _response_type = self._get_response_type(_context, req_args) - request_args = { - "redirect_uri": pick_redirect_uri( - _context, request_args=req_args, response_type=_response_type - ), - "response_type": _response_type, - "nonce": _nonce, - } - - _scope = _context.claims.get_usage("scope") - if _scope: - request_args['scope'] = _scope - - _req_args = _context.config.get("request_args") - if _req_args: - if "claims" in _req_args: - _req_args["claims"] = Claims(**_req_args["claims"]) - request_args.update(_req_args) - - if req_args is not None: - request_args.update(req_args) - - # Need a new state for a new authorization request - _current = _context.cstate - _state = _current.create_key() - request_args["state"] = _state - _current.bind_key(_nonce, _state) - _current.set(_state, {'iss': _context.get("issuer")}) - - logger.debug("Authorization request args: {}".format(request_args)) - - # if behaviour_args and "request_param" not in behaviour_args: - # _pi = _context.get("provider_info") - - _srv = client.get_service("authorization") - _info = _srv.get_request_parameters( - request_args=request_args, behaviour_args=behaviour_args - ) - logger.debug("Authorization info: {}".format(_info)) - return {"url": _info["url"], "state": _state} + return client.init_authorization(req_args=req_args, behaviour_args=behaviour_args) def begin(self, issuer_id="", user_id="", req_args=None, behaviour_args=None): """ This is the first of the 3 high level methods that most users of this library should confine them self to use. - If will use client_setup to produce a Client instance ready to be used + It will use client_setup to produce a Client instance ready to be used against the OP/AS the user wants to use. Once it has the client it will construct an Authorization request. @@ -497,7 +398,7 @@ def begin(self, issuer_id="", user_id="", req_args=None, behaviour_args=None): client = self.client_setup(issuer_id, user_id, behaviour_args=behaviour_args) try: - res = self.init_authorization(client, req_args=req_args, behaviour_args=behaviour_args) + res = client.init_authorization(req_args=req_args, behaviour_args=behaviour_args) except Exception: message = traceback.format_exception(*sys.exc_info()) logger.error(message) @@ -550,39 +451,10 @@ def get_tokens(self, state, client: Optional[Client] = None): :return: A :py:class:`idpyoidc.message.oidc.AccessTokenResponse` or :py:class:`idpyoidc.message.oauth2.AuthorizationResponse` """ - logger.debug(20 * "*" + " get_tokens " + 20 * "*") - if client is None: client = self.get_client_from_session_key(state) - _context = client.get_context() - _claims = _context.cstate.get_set(state, claim=['code', 'redirect_uri']) - - req_args = { - "code": _claims["code"], - "state": state, - "redirect_uri": _claims["redirect_uri"], - "grant_type": "authorization_code", - "client_id": client.get_client_id(), - "client_secret": _context.get("client_secret"), - } - logger.debug("request_args: {}".format(req_args)) - try: - tokenresp = client.do_request( - "accesstoken", - request_args=req_args, - authn_method=self.get_client_authn_method(client, "token_endpoint"), - state=state, - ) - except Exception: - message = traceback.format_exception(*sys.exc_info()) - logger.error(message) - raise - else: - if is_error_message(tokenresp): - raise OidcServiceError(tokenresp["error"]) - - return tokenresp + return client.get_tokens(state) def refresh_access_token(self, state, client=None, scope=""): """ @@ -596,32 +468,10 @@ def refresh_access_token(self, state, client=None, scope=""): :return: A :py:class:`idpyoidc.message.oidc.AccessTokenResponse` instance """ - logger.debug(20 * "*" + " refresh_access_token " + 20 * "*") - - if scope: - req_args = {"scope": scope} - else: - req_args = {} - if client is None: client = self.get_client_from_session_key(state) - try: - tokenresp = client.do_request( - "refresh_token", - authn_method=self.get_client_authn_method(client, "token_endpoint"), - state=state, - request_args=req_args, - ) - except Exception: - message = traceback.format_exception(*sys.exc_info()) - logger.error(message) - raise - else: - if is_error_message(tokenresp): - raise OidcServiceError(tokenresp["error"]) - - return tokenresp + return client.refresh_access_token(state, scope="") def get_user_info(self, state, client=None, access_token="", **kwargs): """ @@ -635,35 +485,21 @@ def get_user_info(self, state, client=None, access_token="", **kwargs): :return: A :py:class:`idpyoidc.message.oidc.OpenIDSchema` instance """ - logger.debug(20 * "*" + " get_user_info " + 20 * "*") - if client is None: client = self.get_client_from_session_key(state) - if not access_token: - _arg = client.get_context().cstate.get_set(state, claim=["access_token"]) - access_token = _arg["access_token"] - - request_args = {"access_token": access_token} - - resp = client.do_request("userinfo", state=state, request_args=request_args, **kwargs) - if is_error_message(resp): - raise OidcServiceError(resp["error"]) - - return resp + return client.get_user_info(state, access_token=access_token, **kwargs) @staticmethod - def userinfo_in_id_token(id_token): + def userinfo_in_id_token(id_token: Message, user_info_claims: Optional[List] = None) -> dict: """ - Given an verified ID token return all the claims that may been user + Given a verified ID token return all the claims that may be user information. :param id_token: An :py:class:`idpyoidc.message.oidc.IDToken` instance :return: A dictionary with user information """ - res = dict([(k, id_token[k]) for k in OpenIDSchema.c_param.keys() if k in id_token]) - res.update(id_token.extra()) - return res + return StandAloneClient.userinfo_in_id_token(id_token, user_info_claims) def finalize_auth( self, client, issuer: str, response: dict, behaviour_args: Optional[dict] = None @@ -679,38 +515,10 @@ def finalize_auth( :py:class:`idpyoidc.message.oauth2.AuthorizationResponse` instance. """ - logger.debug(20 * "*" + " finalize_auth " + 20 * "*") - - _srv = client.get_service("authorization") - try: - authorization_response = _srv.parse_response( - response, sformat="dict", behaviour_args=behaviour_args - ) - except Exception as err: - logger.error("Parsing authorization_response: {}".format(err)) - message = traceback.format_exception(*sys.exc_info()) - logger.error(message) - raise - else: - logger.debug("Authz response: {}".format(authorization_response.to_dict())) - - if is_error_message(authorization_response): - return authorization_response - - _context = client.get_context() - try: - _iss = _context.cstate.get_set( - authorization_response["state"], claim=['iss']).get('iss') - except KeyError: - raise KeyError("Unknown state value") - - if _iss != issuer: - logger.error("Issuer problem: {} != {}".format(_iss, issuer)) - # got it from the wrong bloke - raise ValueError("Impersonator {}".format(issuer)) + if not client: + client = self.issuer2rp[issuer] - _srv.update_service_context(authorization_response, key=authorization_response["state"]) - return authorization_response + return client.finalize_auth(response, behaviour_args=behaviour_args) def get_access_and_id_token( self, @@ -733,58 +541,11 @@ def get_access_and_id_token( was returned otherwise None. """ - logger.debug(20 * "*" + " get_access_and_id_token " + 20 * "*") - if client is None: client = self.get_client_from_session_key(state) - _context = client.get_context() - - resp_attr = authorization_response or _context.cstate.get_set(state, - message=AuthorizationResponse) - if resp_attr is None: - raise ValueError("One of authorization_response or state must be provided") - - if not state: - state = authorization_response["state"] - - _req_attr = _context.cstate.get_set(state, AuthorizationRequest) - _resp_type = set(_req_attr["response_type"].split(' ')) - - access_token = None - id_token = None - if _resp_type in [{"id_token"}, {"id_token", "token"}, {"code", "id_token", "token"}]: - id_token = authorization_response["__verified_id_token"] - - if _resp_type in [ - {"token"}, - {"id_token", "token"}, - {"code", "token"}, - {"code", "id_token", "token"}, - ]: - access_token = authorization_response["access_token"] - if behaviour_args: - if behaviour_args.get("collect_tokens", False): - # get what you can from the token endpoint - token_resp = self.get_tokens(state, client=client) - if is_error_message(token_resp): - return False, "Invalid response %s." % token_resp["error"] - # Now which access_token should I use - access_token = token_resp["access_token"] - # May or may not get an ID Token - id_token = token_resp.get("__verified_id_token") - - elif _resp_type in [{"code"}, {"code", "id_token"}]: - # get the access token - token_resp = self.get_tokens(state, client=client) - if is_error_message(token_resp): - return False, "Invalid response %s." % token_resp["error"] - - access_token = token_resp["access_token"] - # May or may not get an ID Token - id_token = token_resp.get("__verified_id_token") - - return {"access_token": access_token, "id_token": id_token} + return client.get_access_and_id_token(authorization_response=authorization_response, + state=state, behaviour_args=behaviour_args) # noinspection PyUnusedLocal def finalize(self, issuer, response, behaviour_args: Optional[dict] = None): @@ -807,71 +568,7 @@ def finalize(self, issuer, response, behaviour_args: Optional[dict] = None): client = self.issuer2rp[issuer] - if behaviour_args: - logger.debug(f"Finalize behaviour args: {behaviour_args}") - - authorization_response = self.finalize_auth(client, issuer, response) - if is_error_message(authorization_response): - return { - "state": authorization_response["state"], - "error": authorization_response["error"], - } - - _state = authorization_response["state"] - token = self.get_access_and_id_token( - authorization_response, state=_state, client=client, behaviour_args=behaviour_args - ) - _id_token = token.get("id_token") - logger.debug(f"ID Token: {_id_token}") - - if client.get_service("userinfo") and token["access_token"]: - inforesp = self.get_user_info( - state=authorization_response["state"], - client=client, - access_token=token["access_token"], - ) - - if isinstance(inforesp, ResponseMessage) and "error" in inforesp: - return {"error": "Invalid response %s." % inforesp["error"], "state": _state} - - elif _id_token: # look for it in the ID Token - inforesp = self.userinfo_in_id_token(_id_token) - else: - inforesp = {} - - logger.debug("UserInfo: %s", inforesp) - - _context = client.get_context() - try: - _sid_support = _context.get("provider_info")["backchannel_logout_session_required"] - except KeyError: - try: - _sid_support = _context.get("provider_info")[ - "frontchannel_logout_session_required" - ] - except Exception: - _sid_support = False - - if _sid_support and _id_token: - try: - sid = _id_token["sid"] - except KeyError: - pass - else: - _context.cstate.bind_key(sid, _state) - - if _id_token: - _context.cstate.bind_key(_id_token["sub"], _state) - else: - _context.cstate.bind_key(inforesp["sub"], _state) - - return { - "userinfo": inforesp, - "state": authorization_response["state"], - "token": token["access_token"], - "id_token": _id_token, - "session_state": authorization_response.get("session_state", ""), - } + return client.finalize(response, behaviour_args) def has_active_authentication(self, state): """ @@ -882,17 +579,7 @@ def has_active_authentication(self, state): """ client = self.get_client_from_session_key(state) - - # Look for an IdToken - _arg = client.get_context().cstate.get_set(state, - claim=["__verified_id_token"]) - - if _arg: - _now = utc_time_sans_frac() - exp = _arg["__verified_id_token"]["exp"] - return _now < exp - else: - return False + return client.has_active_authentication(state) def get_valid_access_token(self, state): """ @@ -900,32 +587,11 @@ def get_valid_access_token(self, state): :param state: :return: An access token if a valid one exists and when it - expires. Otherwise raise exception. + expires. Other wise raise exception. """ - token = None - indefinite = [] - now = utc_time_sans_frac() - client = self.get_client_from_session_key(state) - _context = client.get_context() - _args = _context.cstate.get_set(state, claim=["access_token", "__expires_at"]) - if "access_token" in _args: - access_token = _args["access_token"] - _exp = _args.get("__expires_at", 0) - if not _exp: # No expiry date, lives for ever - indefinite.append((access_token, 0)) - else: - if _exp > now: # expires sometime in the future - token = (access_token, _exp) - - if indefinite: - return indefinite[0] - else: - if token: - return token - else: - raise OidcServiceError("No valid access token") + return client.get_valid_access_token(state) def logout( self, @@ -944,112 +610,27 @@ def logout( :return: Request arguments """ - logger.debug(20 * "*" + " logout " + 20 * "*") - if client is None: client = self.get_client_from_session_key(state) - try: - srv = client.get_service("end_session") - except KeyError: - raise OidcServiceError("Does not know how to logout") + return client.logout(state, post_logout_redirect_uri=post_logout_redirect_uri) - if post_logout_redirect_uri: - request_args = {"post_logout_redirect_uri": post_logout_redirect_uri} - else: - request_args = {} - - resp = srv.get_request_parameters(state=state, request_args=request_args) - - logger.debug(f"EndSession Request: {resp}") - return resp def close( self, state: str, issuer: Optional[str] = "", post_logout_redirect_uri: Optional[str] = "" ) -> dict: - logger.debug(20 * "*" + " close " + 20 * "*") - if issuer: client = self.issuer2rp[issuer] else: client = self.get_client_from_session_key(state) - return self.logout( - state=state, client=client, post_logout_redirect_uri=post_logout_redirect_uri + return client.logout( + state=state, post_logout_redirect_uri=post_logout_redirect_uri ) def clear_session(self, state): client = self.get_client_from_session_key(state) client.get_context().cstate.remove_state(state) - -def backchannel_logout(client, request="", request_args=None): - """ - - :param request: URL encoded logout request - :return: - """ - if request: - req = BackChannelLogoutRequest().from_urlencoded(as_unicode(request)) - elif request_args: - req = BackChannelLogoutRequest(**request_args) - else: - raise MissingRequiredAttribute("logout_token") - - _context = client.get_context() - kwargs = { - "aud": client.get_client_id(), - "iss": _context.get("issuer"), - "keyjar": client.get_attribute('keyjar'), - "allowed_sign_alg": _context.get("registration_response").get( - "id_token_signed_response_alg", "RS256" - ), - } - - logger.debug(f"(backchannel_logout) Verifying request using: {kwargs}") - try: - req.verify(**kwargs) - except (MessageException, ValueError, NotForMe) as err: - raise MessageException("Bogus logout request: {}".format(err)) - else: - logger.debug("Request verified OK") - - # Find the subject through 'sid' or 'sub' - sub = req[verified_claim_name("logout_token")].get("sub") - sid = None - if not sub: - sid = req[verified_claim_name("logout_token")].get("sid") - - if not sub and not sid: - raise MessageException('Neither "sid" nor "sub"') - elif sub: - _state = _context.cstate.get_base_key(sub) - elif sid: - _state = _context.cstate.get_base_key(sid) - else: - _state = None - - return _state - - -def load_registration_response(client, request_args=None): - """ - If the client has been statically registered that information - must be provided during the configuration. If expected to be - done dynamically this method will do dynamic client registration. - - :param client: A :py:class:`idpyoidc.client.oidc.Client` instance - """ - if not client.get_context().get_client_id(): - try: - response = client.do_request("registration", request_args=request_args) - except KeyError: - raise ConfigurationError("No registration info") - except Exception as err: - logger.error(err) - raise - else: - if "error" in response: - raise OidcServiceError(response.to_json()) diff --git a/src/idpyoidc/client/service.py b/src/idpyoidc/client/service.py index e17bbf49..0b9dd641 100644 --- a/src/idpyoidc/client/service.py +++ b/src/idpyoidc/client/service.py @@ -8,14 +8,16 @@ from typing import Union from urllib.parse import urlparse +from cryptojwt.jwe.jwe import factory as jwe_factory +from cryptojwt.jws.jws import factory as jws_factory from cryptojwt.jwt import JWT from idpyoidc.client.exception import Unsupported from idpyoidc.impexp import ImpExp from idpyoidc.item import DLDict from idpyoidc.message import Message -from idpyoidc.message.oauth2 import is_error_message from idpyoidc.message.oauth2 import ResponseMessage +from idpyoidc.message.oauth2 import is_error_message from idpyoidc.util import importer from .client_auth import client_auth_setup from .client_auth import method_to_item @@ -74,10 +76,7 @@ class Service(ImpExp): _callback_path = {} def __init__( - self, - upstream_get: Callable, - conf: Optional[Union[dict, Configuration]] = None, - **kwargs + self, upstream_get: Callable, conf: Optional[Union[dict, Configuration]] = None, **kwargs ): ImpExp.__init__(self) @@ -94,7 +93,7 @@ def __init__( "http_method", "request_body_type", "response_body_type", - "default_authn_method" + "default_authn_method", ]: if param in conf: setattr(self, param, conf[param]) @@ -111,13 +110,15 @@ def __init__( if self.default_authn_method: if self.default_authn_method not in self.client_authn_methods: self.client_authn_methods[self.default_authn_method] = single_authn_setup( - self.default_authn_method, None) + self.default_authn_method, None + ) else: self.conf = {} if self.default_authn_method: self.client_authn_methods[self.default_authn_method] = single_authn_setup( - self.default_authn_method, None) + self.default_authn_method, None + ) # pull in all the modifiers self.pre_construct = [] @@ -221,7 +222,7 @@ def do_post_construct(self, request_args, **kwargs): return request_args - def update_service_context(self, resp: Message, key: Optional[str] = '', **kwargs): + def update_service_context(self, resp: Message, key: Optional[str] = "", **kwargs): """ A method run after the response has been parsed and verified. @@ -297,7 +298,7 @@ def init_authentication_method(self, request, authn_method, http_args=None, **kw LOGGER.error(f"Unknown client authentication method: {authn_method}") raise Unsupported(f"Unknown client authentication method: {authn_method}") - return _func.construct(request, self, http_args=http_args, **kwargs) + return _func.construct(request=request, service=self, http_args=http_args, **kwargs) return http_args @@ -329,7 +330,7 @@ def get_endpoint(self): return self.upstream_get("context").provider_info[self.endpoint_name] def get_authn_header( - self, request: Union[dict, Message], authn_method: Optional[str] = "", **kwargs + self, request: Union[dict, Message], authn_method: Optional[str] = "", **kwargs ) -> dict: """ Construct an authorization specification to be sent in the @@ -361,11 +362,11 @@ def get_authn_method(self) -> str: return self.default_authn_method def get_headers( - self, - request: Union[dict, Message], - http_method: str, - authn_method: Optional[str] = "", - **kwargs, + self, + request: Union[dict, Message], + http_method: str, + authn_method: Optional[str] = "", + **kwargs, ) -> dict: """ @@ -381,6 +382,10 @@ def get_headers( request, authn_method=authn_method, authn_endpoint=self.endpoint_name, **kwargs ) + _authz = _headers.get("Authorization") + if _authz and _authz.startswith("Bearer"): + kwargs["token"] = _authz.split(" ")[1] + for meth in self.construct_extra_headers: _headers = meth( self.upstream_get("context"), @@ -395,7 +400,7 @@ def get_headers( return _headers def get_request_parameters( - self, request_args=None, method="", request_body_type="", authn_method="", **kwargs + self, request_args=None, method="", request_body_type="", authn_method="", **kwargs ) -> dict: """ Builds the request message and constructs the HTTP headers. @@ -497,8 +502,7 @@ def post_parse_response(self, response, **kwargs): return response def gather_verify_arguments( - self, response: Optional[Union[dict, Message]] = None, - behaviour_args: Optional[dict] = None + self, response: Optional[Union[dict, Message]] = None, behaviour_args: Optional[dict] = None ): """ Need to add some information before running verify() @@ -509,7 +513,7 @@ def gather_verify_arguments( _context = self.upstream_get("context") kwargs = { "iss": _context.issuer, - "keyjar": self.upstream_get('attribute', 'keyjar'), + "keyjar": self.upstream_get("attribute", "keyjar"), "verify": True, "client_id": _context.get_client_id(), } @@ -527,7 +531,7 @@ def _do_jwt(self, info): args["allowed_enc_algs"] = enc_algs["alg"] args["allowed_enc_encs"] = enc_algs["enc"] - _jwt = JWT(key_jar=self.upstream_get('attribute', 'keyjar'), **args) + _jwt = JWT(key_jar=self.upstream_get("attribute", "keyjar"), **args) _jwt.iss = _context.get_client_id() return _jwt.unpack(info) @@ -555,12 +559,12 @@ def _do_response(self, info, sformat, **kwargs): return resp def parse_response( - self, - info, - sformat: Optional[str] = "", - state: Optional[str] = "", - behaviour_args: Optional[dict] = None, - **kwargs, + self, + info, + sformat: Optional[str] = "", + state: Optional[str] = "", + behaviour_args: Optional[dict] = None, + **kwargs, ): """ This the start of a pipeline that will: @@ -587,16 +591,23 @@ def parse_response( LOGGER.debug("response format: %s", sformat) resp = None - if sformat == "jose": + if sformat == "jose": # can be jwe, jws or json + # the checks for JWS and JWE will be replaced with functions from cryptojwt try: - self._do_jwt(info) - sformat = "dict" - except Exception: - _keyjar = self.upstream_get("attribute", 'keyjar') - resp = self.response_cls().from_jwe(info, keys=_keyjar) + if jws_factory(info): + info = self._do_jwt(info) + except: + try: + if jwe_factory(info): + info = self._do_jwt(info) + except: + LOGGER.debug('jwe detected') + if info and isinstance(info, str): + info = json.loads(info) + sformat = "dict" elif sformat == "jwe": - _keyjar = self.upstream_get("attribute", 'keyjar') - _client_id = self.upstream_get("attribute", 'client_id') + _keyjar = self.upstream_get("attribute", "keyjar") + _client_id = self.upstream_get("attribute", "client_id") resp = self.response_cls().from_jwe(info, keys=_keyjar.get_issuer_keys(_client_id)) # If format is urlencoded 'info' may be a URL # in which case I have to get at the query/fragment part @@ -616,12 +627,16 @@ def parse_response( LOGGER.error("Missing or faulty response") raise ResponseError("Missing or faulty response") - resp = self._do_response(info, sformat, **kwargs) - - LOGGER.debug('Initial response parsing => "%s"', resp.to_dict()) + if sformat == "text": + resp = info + else: + resp = self._do_response(info, sformat, **kwargs) + LOGGER.debug('Initial response parsing => "%s"', resp.to_dict()) # is this an error message - if is_error_message(resp): + if sformat == "text": + pass + elif is_error_message(resp): LOGGER.debug("Error response: %s", resp) else: vargs = self.gather_verify_arguments(response=resp, behaviour_args=behaviour_args) @@ -665,19 +680,21 @@ def get_callback_path(self, callback): def get_uri(base_url, path, hex): return f"{base_url}/{path}/{hex}" - def construct_uris(self, - base_url: str, - hex: bytes, - context: OidcContext, - targets: Optional[List[str]] = None, - response_types: Optional[list] = None): + def construct_uris( + self, + base_url: str, + hex: bytes, + context: OidcContext, + targets: Optional[List[str]] = None, + response_types: Optional[list] = None, + ): if not targets: targets = self._callback_path.keys() if not targets: return {} - _callback_uris = context.get_preference('callback_uris', {}) + _callback_uris = context.get_preference("callback_uris", {}) for uri in targets: if uri in _callback_uris: pass diff --git a/src/idpyoidc/client/service_context.py b/src/idpyoidc/client/service_context.py index ae6e75d0..57dfc7dc 100644 --- a/src/idpyoidc/client/service_context.py +++ b/src/idpyoidc/client/service_context.py @@ -72,7 +72,7 @@ "redirect_uris": [], "provider_info": {}, "callback": {}, - "issuer": "" + "issuer": "", } @@ -97,7 +97,7 @@ class ServiceContext(ImpExp): "httpc_params": None, "iss_hash": None, "issuer": None, - 'keyjar': KeyJar, + "keyjar": KeyJar, "claims": Claims, "provider_info": None, "requests_dir": None, @@ -111,16 +111,18 @@ class ServiceContext(ImpExp): "specs": {"load": claims_load, "dump": claims_dump}, } - init_args = ['upstream_get'] - - def __init__(self, - upstream_get: Optional[Callable] = None, - base_url: Optional[str] = "", - keyjar: Optional[KeyJar] = None, - config: Optional[Union[dict, Configuration]] = None, - cstate: Optional[Current] = None, - client_type: Optional[str] = 'oauth2', - **kwargs): + init_args = ["upstream_get"] + + def __init__( + self, + upstream_get: Optional[Callable] = None, + base_url: Optional[str] = "", + keyjar: Optional[KeyJar] = None, + config: Optional[Union[dict, Configuration]] = None, + cstate: Optional[Current] = None, + client_type: Optional[str] = "oauth2", + **kwargs, + ): ImpExp.__init__(self) config = get_configuration(config) self.config = config @@ -138,7 +140,7 @@ def __init__(self, self.kid = {"sig": {}, "enc": {}} - self.allow = config.conf.get('allow', {}) + self.allow = config.conf.get("allow", {}) self.base_url = base_url or config.conf.get("base_url", "") self.provider_info = config.conf.get("provider_info", {}) @@ -167,12 +169,15 @@ def __init__(self, for key, val in kwargs.items(): setattr(self, key, val) - self.keyjar = self.claims.load_conf(config.conf, supports=self.supports(), - keyjar=keyjar) + self.keyjar = self.claims.load_conf(config.conf, supports=self.supports(), keyjar=keyjar) + + _jwks_uri = self.provider_info.get('jwks_uri') + if _jwks_uri: + self.keyjar.load_keys(self.provider_info.get('issuer'), jwks_uri=_jwks_uri) _response_types = self.get_preference( - 'response_types_supported', - self.supports().get('response_types_supported', [])) + "response_types_supported", self.supports().get("response_types_supported", []) + ) self.construct_uris(response_types=_response_types) @@ -195,7 +200,7 @@ def filename_from_webname(self, webname): if not webname.startswith(self.base_url): raise ValueError("Webname doesn't match base_url") - _name = webname[len(self.base_url):] + _name = webname[len(self.base_url) :] if _name.startswith("/"): return _name[1:] @@ -210,7 +215,7 @@ def import_keys(self, keyspec): :param keyspec: """ - _keyjar = self.upstream_get('attribute', 'keyjar') + _keyjar = self.upstream_get("attribute", "keyjar") if _keyjar is None: _keyjar = KeyJar() new = True @@ -232,12 +237,12 @@ def import_keys(self, keyspec): _keyjar.add_kb(iss, _bundle) if new: - _unit = self.upstream_get('unit') - _unit.setattribute('keyjar', _keyjar) + _unit = self.upstream_get("unit") + _unit.setattribute("keyjar", _keyjar) def _get_crypt(self, typ, attr): _item_typ = CLI_REG_MAP.get(typ) - _alg = '' + _alg = "" if _item_typ: _alg = self.claims.get_usage(_item_typ[attr]) if not _alg: @@ -256,7 +261,7 @@ def get_sign_alg(self, typ): :param typ: ['id_token', 'userinfo', 'request_object'] :return: signing algorithm """ - return self._get_crypt(typ, 'sign') + return self._get_crypt(typ, "sign") def get_enc_alg_enc(self, typ): """ @@ -286,7 +291,7 @@ def collect_usage(self): def supports(self): res = {} if self.upstream_get: - services = self.upstream_get('services') + services = self.upstream_get("services") if not services: pass else: @@ -313,7 +318,7 @@ def set_usage(self, claim, value): def _callback_per_service(self): _cb = {} - for service in self.upstream_get('services').values(): + for service in self.upstream_get("services").values(): _cbs = service._callback_path.keys() if _cbs: _cb[service.service_name] = _cbs @@ -329,45 +334,66 @@ def construct_uris(self, response_types: Optional[list] = None): _base_url = self.get("base_url") - _callback_uris = self.get_preference('callback_uris', {}) + _callback_uris = self.get_preference("callback_uris", {}) if self.upstream_get: - services = self.upstream_get('services') + services = self.upstream_get("services") if services: for service in services.values(): - _callback_uris.update(service.construct_uris(base_url=_base_url, hex=_hex, - context=self, - response_types=response_types)) - - self.set_preference('callback_uris', _callback_uris) - if 'redirect_uris' in _callback_uris: + _callback_uris.update( + service.construct_uris( + base_url=_base_url, + hex=_hex, + context=self, + response_types=response_types, + ) + ) + + self.set_preference("callback_uris", _callback_uris) + if "redirect_uris" in _callback_uris: _redirect_uris = set() - for flow, _uris in _callback_uris['redirect_uris'].items(): + for flow, _uris in _callback_uris["redirect_uris"].items(): _redirect_uris.update(set(_uris)) - self.set_preference('redirect_uris', list(_redirect_uris)) + self.set_preference("redirect_uris", list(_redirect_uris)) def prefer_or_support(self, claim): if claim in self.claims.prefer: - return 'prefer' + return "prefer" else: - for service in self.upstream_get('services').values(): + for service in self.upstream_get("services").values(): _res = service.prefer_or_support(claim) if _res: return _res if claim in self.claims.supported(claim): - return 'support' + return "support" return None def map_supported_to_preferred(self, info: Optional[dict] = None): - self.claims.prefer = supported_to_preferred(self.supports(), - self.claims.prefer, - base_url=self.base_url, - info=info) + self.claims.prefer = supported_to_preferred( + self.supports(), self.claims.prefer, base_url=self.base_url, info=info + ) return self.claims.prefer + def map_service_against_endpoint(self, provider_config): + # Check endpoints against services + remove = [] + for srv_name, srv in self.upstream_get("services").items(): + if srv.endpoint_name: + _match = provider_config.get(srv.endpoint_name) + if _match is None: + for key in srv._supports.keys(): + if key in self.claims.prefer: + del self.claims.prefer[key] + remove.append(srv_name) + + for item in remove: + del self.upstream_get("services")[item] + def map_preferred_to_registered(self, registration_response: Optional[dict] = None): self.claims.use = preferred_to_registered( self.claims.prefer, supported=self.supports(), - registration_response=registration_response) + registration_response=registration_response, + ) + return self.claims.use diff --git a/src/idpyoidc/client/util.py b/src/idpyoidc/client/util.py index e2418cd2..38cb3a09 100755 --- a/src/idpyoidc/client/util.py +++ b/src/idpyoidc/client/util.py @@ -310,15 +310,19 @@ def lower_or_upper(config, param, default=None): IMPLICIT_RESPONSE_TYPES = [ - {'id_token'}, {'id_token', 'token'}, {'code', 'token'}, ['code', 'id_token'], - {'code', 'id_token', 'token'}, {'token'} + {"id_token"}, + {"id_token", "token"}, + {"code", "token"}, + {"code", "id_token"}, + {"code", "id_token", "token"}, + {"token"}, ] def implicit_response_types(a): res = [] for typ in a: - if set(typ.split(' ')) in IMPLICIT_RESPONSE_TYPES: + if set(typ.split(" ")) in IMPLICIT_RESPONSE_TYPES: res.append(typ) return res diff --git a/src/idpyoidc/context.py b/src/idpyoidc/context.py index 55ec0e6a..0fbcdf48 100644 --- a/src/idpyoidc/context.py +++ b/src/idpyoidc/context.py @@ -25,11 +25,11 @@ def __init__(self, config=None, entity_id=""): self.entity_id = entity_id else: if config: - val = '' - for alt in ['client_id', 'issuer', 'entity_id']: + val = "" + for alt in ["client_id", "issuer", "entity_id"]: val = config.get(alt) if val: break self.entity_id = val else: - self.entity_id = '' + self.entity_id = "" diff --git a/src/idpyoidc/impexp.py b/src/idpyoidc/impexp.py index efa2ac62..94592c0f 100644 --- a/src/idpyoidc/impexp.py +++ b/src/idpyoidc/impexp.py @@ -78,11 +78,11 @@ def local_load_adjustments(self, **kwargs): pass def load_attr( - self, - cls: Any, - item: dict, - init_args: Optional[dict] = None, - load_args: Optional[dict] = None, + self, + cls: Any, + item: dict, + init_args: Optional[dict] = None, + load_args: Optional[dict] = None, ) -> Any: if load_args: _kwargs = {"load_args": load_args} diff --git a/src/idpyoidc/message/oauth2/__init__.py b/src/idpyoidc/message/oauth2/__init__.py index e0841847..9b411790 100644 --- a/src/idpyoidc/message/oauth2/__init__.py +++ b/src/idpyoidc/message/oauth2/__init__.py @@ -137,6 +137,7 @@ class AuthorizationRequest(Message): "redirect_uri": SINGLE_OPTIONAL_STRING, "state": SINGLE_OPTIONAL_STRING, "request": SINGLE_OPTIONAL_STRING, + "resource": OPTIONAL_LIST_OF_STRINGS, # From RFC8707 } def merge(self, request_object, treatement="strict", whitelist=None): @@ -227,7 +228,7 @@ class AuthorizationResponse(ResponseMessage): { "code": SINGLE_REQUIRED_STRING, "state": SINGLE_OPTIONAL_STRING, - "iss": SINGLE_OPTIONAL_STRING, + "iss": SINGLE_OPTIONAL_STRING, # RFC 9207 "client_id": SINGLE_OPTIONAL_STRING, } ) @@ -301,12 +302,12 @@ class CCAccessTokenRequest(Message): "client_id": SINGLE_OPTIONAL_STRING, "client_secret": SINGLE_OPTIONAL_STRING, "grant_type": SINGLE_REQUIRED_STRING, - "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS + "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS, } def verify(self, **kwargs): - if self['grant_type'] != 'client_credentials': - raise ValueError('Grant type MUST be client_credentials') + if self["grant_type"] != "client_credentials": + raise ValueError("Grant type MUST be client_credentials") class RefreshAccessTokenRequest(Message): @@ -345,7 +346,7 @@ class ASConfigurationResponse(Message): "scopes_supported": OPTIONAL_LIST_OF_STRINGS, "response_types_supported": REQUIRED_LIST_OF_STRINGS, "response_modes_supported": OPTIONAL_LIST_OF_STRINGS, - "grant_types_supported": REQUIRED_LIST_OF_STRINGS, + "grant_types_supported": OPTIONAL_LIST_OF_STRINGS, "token_endpoint_auth_methods_supported": OPTIONAL_LIST_OF_STRINGS, "token_endpoint_auth_signing_alg_values_supported": OPTIONAL_LIST_OF_STRINGS, "service_documentation": SINGLE_OPTIONAL_STRING, @@ -378,6 +379,7 @@ def deserialize_from_one_of(val, msgtype, sformat): class OauthClientMetadata(Message): """Metadata for an OAuth2 Client.""" + c_param = { "redirect_uris": OPTIONAL_LIST_OF_STRINGS, "token_endpoint_auth_method": SINGLE_OPTIONAL_STRING, @@ -393,28 +395,42 @@ class OauthClientMetadata(Message): "jwks_uri": SINGLE_OPTIONAL_STRING, "jwks": SINGLE_OPTIONAL_JSON, "software_id": SINGLE_OPTIONAL_STRING, - "software_version": SINGLE_OPTIONAL_STRING + "software_version": SINGLE_OPTIONAL_STRING, + "software_statement": SINGLE_OPTIONAL_JSON, } + def verify(self, **kwargs): + super(OauthClientMetadata, self).verify(**kwargs) + + # if grant type is present and if contains the values authorization_code or + # implicit then redirect_uris must be present + + _grant_types = self.get("grant_types", []) + if set(_grant_types).intersection({"authorization_code", "implicit"}): + if "redirect_uris" not in self: + raise ValueError("Missing redirect_uris claim") + def oauth_client_metadata_deser(val, sformat="json"): """Deserializes a JSON object (most likely) into a OauthClientMetadata.""" return deserialize_from_one_of(val, OauthClientMetadata, sformat) -OPTIONAL_OAUTH_CLIENT_METADATA = (Message, False, msg_ser, - oauth_client_metadata_deser, False) +OPTIONAL_OAUTH_CLIENT_METADATA = (Message, False, msg_ser, oauth_client_metadata_deser, False) class OauthClientInformationResponse(OauthClientMetadata): """The information returned by a OAuth2 Server about an OAuth2 client.""" + c_param = OauthClientMetadata.c_param.copy() - c_param.update({ - "client_id": SINGLE_REQUIRED_STRING, - "client_secret": SINGLE_OPTIONAL_STRING, - "client_id_issued_at": SINGLE_OPTIONAL_INT, - "client_secret_expires_at": SINGLE_OPTIONAL_INT - }) + c_param.update( + { + "client_id": SINGLE_REQUIRED_STRING, + "client_secret": SINGLE_OPTIONAL_STRING, + "client_id_issued_at": SINGLE_OPTIONAL_INT, + "client_secret_expires_at": SINGLE_OPTIONAL_INT, + } + ) def verify(self, **kwargs): super(OauthClientInformationResponse, self).verify(**kwargs) @@ -422,7 +438,8 @@ def verify(self, **kwargs): if "client_secret" in self: if "client_secret_expires_at" not in self: raise MissingRequiredAttribute( - "client_secret_expires_at is a MUST if client_secret is present") + "client_secret_expires_at is a MUST if client_secret is present" + ) def oauth_client_registration_response_deser(val, sformat="json"): @@ -431,7 +448,12 @@ def oauth_client_registration_response_deser(val, sformat="json"): OPTIONAL_OAUTH_CLIENT_REGISTRATION_RESPONSE = ( - Message, False, msg_ser, oauth_client_registration_response_deser, False) + Message, + False, + msg_ser, + oauth_client_registration_response_deser, + False, +) # RFC 7662 @@ -565,30 +587,30 @@ class JWTAccessToken(Message): "auth_time": SINGLE_OPTIONAL_INT, "acr": SINGLE_OPTIONAL_STRING, "amr": OPTIONAL_LIST_OF_STRINGS, - 'scope': OPTIONAL_LIST_OF_SP_SEP_STRINGS, - 'groups': OPTIONAL_LIST_OF_STRINGS, - 'roles': OPTIONAL_LIST_OF_STRINGS, - 'entitlements': OPTIONAL_LIST_OF_STRINGS + "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS, + "groups": OPTIONAL_LIST_OF_STRINGS, + "roles": OPTIONAL_LIST_OF_STRINGS, + "entitlements": OPTIONAL_LIST_OF_STRINGS, } class JSONWebToken(Message): # implements RFC 9068 c_param = { - 'iss': SINGLE_REQUIRED_STRING, - 'exp': SINGLE_REQUIRED_STRING, - 'aud': SINGLE_REQUIRED_STRING, - 'sub': SINGLE_REQUIRED_STRING, + "iss": SINGLE_REQUIRED_STRING, + "exp": SINGLE_REQUIRED_STRING, + "aud": SINGLE_REQUIRED_STRING, + "sub": SINGLE_REQUIRED_STRING, "client_id": SINGLE_REQUIRED_STRING, - 'iat': SINGLE_REQUIRED_STRING, - 'jti': SINGLE_REQUIRED_STRING, - 'auth_time': SINGLE_OPTIONAL_INT, - 'acr': SINGLE_OPTIONAL_STRING, - 'amr': OPTIONAL_LIST_OF_STRINGS, - 'scope': OPTIONAL_LIST_OF_SP_SEP_STRINGS, - 'groups': OPTIONAL_LIST_OF_STRINGS, - 'roles': OPTIONAL_LIST_OF_STRINGS, - 'entitlements': OPTIONAL_LIST_OF_STRINGS + "iat": SINGLE_REQUIRED_STRING, + "jti": SINGLE_REQUIRED_STRING, + "auth_time": SINGLE_OPTIONAL_INT, + "acr": SINGLE_OPTIONAL_STRING, + "amr": OPTIONAL_LIST_OF_STRINGS, + "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS, + "groups": OPTIONAL_LIST_OF_STRINGS, + "roles": OPTIONAL_LIST_OF_STRINGS, + "entitlements": OPTIONAL_LIST_OF_STRINGS, } @@ -611,12 +633,9 @@ class TokenRevocationErrorResponse(ResponseMessage): """ Error response from the revocation endpoint """ + c_allowed_values = ResponseMessage.c_allowed_values.copy() - c_allowed_values.update({ - "error": [ - "unsupported_token_type" - ] - }) + c_allowed_values.update({"error": ["unsupported_token_type"]}) def factory(msgtype, **kwargs): diff --git a/src/idpyoidc/message/oidc/__init__.py b/src/idpyoidc/message/oidc/__init__.py index 98af8d66..a6da5063 100644 --- a/src/idpyoidc/message/oidc/__init__.py +++ b/src/idpyoidc/message/oidc/__init__.py @@ -240,6 +240,7 @@ def check_char_set(string, allowed): "encenc", "sigalg", "issuer", + "iss", "allow_missing_kid", "no_kid_issuer", "trusting", @@ -635,6 +636,7 @@ class RegistrationRequest(Message): "backchannel_logout_session_required": SINGLE_OPTIONAL_BOOLEAN, # "federation_type": OPTIONAL_LIST_OF_STRINGS, # "organization_name": SINGLE_OPTIONAL_STRING, + "response_modes": OPTIONAL_LIST_OF_STRINGS, } c_default = {"application_type": "web", "response_types": ["code"]} c_allowed_values = { diff --git a/src/idpyoidc/message/oidc/session.py b/src/idpyoidc/message/oidc/session.py index 9ac4cd5f..8136b052 100644 --- a/src/idpyoidc/message/oidc/session.py +++ b/src/idpyoidc/message/oidc/session.py @@ -135,7 +135,7 @@ def verify(self, **kwargs): except KeyError: _skew = 0 - if 'iat' in self and self["iat"] > (_now + _skew): + if "iat" in self and self["iat"] > (_now + _skew): raise ValueError("Invalid issued_at time") _allowed = kwargs.get("allowed_sign_alg") diff --git a/src/idpyoidc/metadata.py b/src/idpyoidc/metadata.py index 55a90fbc..c879a17d 100644 --- a/src/idpyoidc/metadata.py +++ b/src/idpyoidc/metadata.py @@ -29,18 +29,11 @@ def metadata_load(item: dict, **kwargs): class Metadata(ImpExp): - parameter = { - "prefer": None, - "use": None, - "callback_path": None, - "_local": None - } + parameter = {"prefer": None, "use": None, "callback_path": None, "_local": None} _supports = {} - def __init__(self, - prefer: Optional[dict] = None, - callback_path: Optional[dict] = None): + def __init__(self, prefer: Optional[dict] = None, callback_path: Optional[dict] = None): ImpExp.__init__(self) if isinstance(prefer, dict): @@ -73,11 +66,11 @@ def remove_preference(self, key): def _callback_uris(self, base_url, hex): _uri = [] - for type in self.get_usage("response_types", self._supports['response_types']): + for type in self.get_usage("response_types", self._supports["response_types"]): if "code" in type: - _uri.append('code') + _uri.append("code") elif type in ["id_token", "id_token token"]: - _uri.append('implicit') + _uri.append("implicit") if "form_post" in self.supports: _uri.append("form_post") @@ -87,36 +80,33 @@ def _callback_uris(self, base_url, hex): callback_uri[key] = get_uri(base_url, self.callback_path[key], hex) return callback_uri - def construct_redirect_uris(self, - base_url: str, - hex: str, - callbacks: Optional[dict] = None): + def construct_redirect_uris(self, base_url: str, hex: str, callbacks: Optional[dict] = None): if not callbacks: callbacks = self._callback_uris(base_url, hex) if callbacks: - self.set_preference('callbacks', callbacks) + self.set_preference("callbacks", callbacks) self.set_preference("redirect_uris", [v for k, v in callbacks.items()]) self.callback = callbacks - def verify_rules(self): + def verify_rules(self, supports): return True def locals(self, info): pass def _keyjar(self, keyjar=None, conf=None, entity_id=""): - _uri_path = '' + _uri_path = "" if keyjar is None: if "keys" in conf: keys_args = {k: v for k, v in conf["keys"].items() if k != "uri_path"} _keyjar = init_key_jar(**keys_args) - _uri_path = conf['keys'].get('uri_path') + _uri_path = conf["keys"].get("uri_path") elif "key_conf" in conf and conf["key_conf"]: keys_args = {k: v for k, v in conf["key_conf"].items() if k != "uri_path"} _keyjar = init_key_jar(**keys_args) - _uri_path = conf['key_conf'].get('uri_path') + _uri_path = conf["key_conf"].get("uri_path") else: _keyjar = KeyJar() if "jwks" in conf: @@ -133,9 +123,9 @@ def _keyjar(self, keyjar=None, conf=None, entity_id=""): return _keyjar, _uri_path else: if "keys" in conf: - _uri_path = conf['keys'].get('uri_path') + _uri_path = conf["keys"].get("uri_path") elif "key_conf" in conf and conf["key_conf"]: - _uri_path = conf['key_conf'].get('uri_path') + _uri_path = conf["key_conf"].get("uri_path") return keyjar, _uri_path def get_base_url(self, configuration: dict): @@ -150,8 +140,9 @@ def add_extra_keys(self, keyjar, id): def get_jwks(self, keyjar): return None - def handle_keys(self, configuration: dict, keyjar: Optional[KeyJar] = None, - base_url: Optional[str] = ''): + def handle_keys( + self, configuration: dict, keyjar: Optional[KeyJar] = None, base_url: Optional[str] = "" + ): _jwks = _jwks_uri = None _id = self.get_id(configuration) keyjar, uri_path = self._keyjar(keyjar, configuration, entity_id=_id) @@ -159,8 +150,8 @@ def handle_keys(self, configuration: dict, keyjar: Optional[KeyJar] = None, self.add_extra_keys(keyjar, _id) # now that keys are in the Key Jar, now for how to publish it - if 'jwks_uri' in configuration: # simple - _jwks_uri = configuration.get('jwks_uri') + if "jwks_uri" in configuration: # simple + _jwks_uri = configuration.get("jwks_uri") elif uri_path: if not base_url: base_url = self.get_base_url(configuration) @@ -168,10 +159,11 @@ def handle_keys(self, configuration: dict, keyjar: Optional[KeyJar] = None, else: # jwks or nothing _jwks = self.get_jwks(keyjar) - return {'keyjar': keyjar, 'jwks': _jwks, 'jwks_uri': _jwks_uri} + return {"keyjar": keyjar, "jwks": _jwks, "jwks_uri": _jwks_uri} - def load_conf(self, configuration, supports, keyjar: Optional[KeyJar] = None, - base_url: Optional[str] = ''): + def load_conf( + self, configuration, supports, keyjar: Optional[KeyJar] = None, base_url: Optional[str] = "" + ): for attr, val in configuration.items(): if attr == "preference": for k, v in val.items(): @@ -183,12 +175,12 @@ def load_conf(self, configuration, supports, keyjar: Optional[KeyJar] = None, self.locals(configuration) for key, val in self.handle_keys(configuration, keyjar=keyjar, base_url=base_url).items(): - if key == 'keyjar': + if key == "keyjar": keyjar = val elif val: self.set_preference(key, val) - self.verify_rules() + self.verify_rules(supports) return keyjar def get(self, key, default=None): @@ -219,7 +211,7 @@ def prefers(self): return self.prefer -SIGNING_ALGORITHM_SORT_ORDER = ['RS', 'ES', 'PS', 'HS'] +SIGNING_ALGORITHM_SORT_ORDER = ["RS", "ES", "PS", "HS"] def cmp(a, b): @@ -227,9 +219,9 @@ def cmp(a, b): def alg_cmp(a, b): - if a == 'none': + if a == "none": return 1 - elif b == 'none': + elif b == "none": return -1 _pos1 = SIGNING_ALGORITHM_SORT_ORDER.index(a[0:2]) @@ -244,16 +236,16 @@ def alg_cmp(a, b): def get_signing_algs(): # Assumes Cryptojwt - _algs = [name for name in list(SIGNER_ALGS.keys()) if name != 'none'] + _algs = [name for name in list(SIGNER_ALGS.keys()) if name != "none"] return sorted(_algs, key=cmp_to_key(alg_cmp)) def get_encryption_algs(): - return SUPPORTED['alg'] + return SUPPORTED["alg"] def get_encryption_encs(): - return SUPPORTED['enc'] + return SUPPORTED["enc"] def array_or_singleton(claim_spec, values): diff --git a/src/idpyoidc/node.py b/src/idpyoidc/node.py index 498e83e1..f1148612 100644 --- a/src/idpyoidc/node.py +++ b/src/idpyoidc/node.py @@ -3,8 +3,10 @@ from typing import Union from cryptojwt import KeyJar +from cryptojwt.key_jar import build_keyjar from cryptojwt.key_jar import init_key_jar +from idpyoidc.client.defaults import DEFAULT_KEY_DEFS from idpyoidc.configure import Configuration from idpyoidc.impexp import ImpExp from idpyoidc.util import instantiate @@ -43,8 +45,50 @@ def create_keyjar( return keyjar -class Node: +def make_keyjar( + keyjar: Optional[Union[KeyJar, bool]] = None, + config: Optional[Union[Configuration, dict]] = None, + key_conf: Optional[dict] = None, + issuer_id: Optional[str] = "", + client_id: Optional[str] = "", + ): + if keyjar is False: + return None + + keyjar = keyjar or config.get("keyjar") + key_conf = key_conf or config.get("key_conf", config.get("keys")) + + if not keyjar and not key_conf: + keyjar = KeyJar() + _jwks = config.get("jwks") + if _jwks: + keyjar.import_jwks_as_json(_jwks, client_id) + + if keyjar or key_conf: + # Should be either one + id = issuer_id or client_id + keyjar = create_keyjar(keyjar, conf=config, key_conf=key_conf, id=id) + if client_id: + _key = config.get("client_secret") + if _key: + keyjar.add_symmetric(client_id, _key) + keyjar.add_symmetric("", _key) + else: + if client_id: + _key = config.get("client_secret") + if _key: + keyjar = KeyJar() + keyjar.add_symmetric(client_id, _key) + keyjar.add_symmetric("", _key) + else: + keyjar = build_keyjar(DEFAULT_KEY_DEFS) + if issuer_id: + keyjar.import_jwks(keyjar.export_jwks(private=True), issuer_id) + + return keyjar + +class Node: def __init__(self, upstream_get: Callable = None): self.upstream_get = upstream_get @@ -76,20 +120,21 @@ def get_unit(self, *args): class Unit(ImpExp): - name = '' - - init_args = ['upstream_get'] - - def __init__(self, - upstream_get: Callable = None, - keyjar: Optional[KeyJar] = None, - httpc: Optional[object] = None, - httpc_params: Optional[dict] = None, - config: Optional[Union[Configuration, dict]] = None, - key_conf: Optional[dict] = None, - issuer_id: Optional[str] = '', - client_id: Optional[str] = '' - ): + name = "" + + init_args = ["upstream_get"] + + def __init__( + self, + upstream_get: Callable = None, + keyjar: Optional[Union[KeyJar, bool]] = None, + httpc: Optional[object] = None, + httpc_params: Optional[dict] = None, + config: Optional[Union[Configuration, dict]] = None, + key_conf: Optional[dict] = None, + issuer_id: Optional[str] = "", + client_id: Optional[str] = "", + ): ImpExp.__init__(self) self.upstream_get = upstream_get self.httpc = httpc @@ -97,30 +142,7 @@ def __init__(self, if config is None: config = {} - keyjar = keyjar or config.get('keyjar') - key_conf = key_conf or config.get('key_conf', config.get('keys')) - - if not keyjar and not key_conf: - _jwks = config.get('jwks') - if _jwks: - keyjar = KeyJar() - keyjar.import_jwks_as_json(_jwks, client_id) - - if keyjar or key_conf: - # Should be either one - id = issuer_id or client_id - self.keyjar = create_keyjar(keyjar, conf=config, key_conf=key_conf, id=id) - if client_id: - self.keyjar.add_symmetric('', client_id) - else: - if client_id: - _key = config.get("client_secret") - if _key: - self.keyjar = KeyJar() - self.keyjar.add_symmetric(client_id, _key) - self.keyjar.add_symmetric('', _key) - else: - self.keyjar = None + self.keyjar = make_keyjar(keyjar, config, key_conf, issuer_id, client_id) self.httpc_params = httpc_params or config.get("httpc_params", {}) @@ -156,9 +178,9 @@ def get_unit(self, *args): def topmost_unit(unit): - if hasattr(unit, 'upstream_get'): + if hasattr(unit, "upstream_get"): if unit.upstream_get: - next_unit = unit.upstream_get('unit') + next_unit = unit.upstream_get("unit") if next_unit: unit = topmost_unit(next_unit) @@ -166,64 +188,78 @@ def topmost_unit(unit): class ClientUnit(Unit): - name = '' - - def __init__(self, - upstream_get: Callable = None, - httpc: Optional[object] = None, - httpc_params: Optional[dict] = None, - keyjar: Optional[KeyJar] = None, - context: Optional[ImpExp] = None, - config: Optional[Union[Configuration, dict]] = None, - # jwks_uri: Optional[str] = "", - entity_id: Optional[str] = "", - key_conf: Optional[dict] = None - ): + name = "" + + def __init__( + self, + upstream_get: Callable = None, + httpc: Optional[object] = None, + httpc_params: Optional[dict] = None, + keyjar: Optional[KeyJar] = None, + context: Optional[ImpExp] = None, + config: Optional[Union[Configuration, dict]] = None, + # jwks_uri: Optional[str] = "", + entity_id: Optional[str] = "", + key_conf: Optional[dict] = None, + ): if config is None: config = {} - self.entity_id = entity_id or config.get('entity_id') - self.client_id = config.get('client_id', entity_id) + self.entity_id = entity_id or config.get("entity_id") + self.client_id = config.get("client_id", entity_id) - Unit.__init__(self, upstream_get=upstream_get, keyjar=keyjar, httpc=httpc, - httpc_params=httpc_params, config=config, client_id=self.client_id, - key_conf=key_conf) + Unit.__init__( + self, + upstream_get=upstream_get, + keyjar=keyjar, + httpc=httpc, + httpc_params=httpc_params, + config=config, + client_id=self.client_id, + key_conf=key_conf, + ) self.context = context or None def get_context_attribute(self, attr, *args): _val = getattr(self.context, attr) if not _val and self.upstream_get: - return self.upstream_get('context_attribute', attr) + return self.upstream_get("context_attribute", attr) else: return _val # Neither client nor Server class Collection(Unit): - - def __init__(self, - upstream_get: Callable = None, - keyjar: Optional[KeyJar] = None, - httpc: Optional[object] = None, - httpc_params: Optional[dict] = None, - config: Optional[Union[Configuration, dict]] = None, - entity_id: Optional[str] = "", - key_conf: Optional[dict] = None, - functions: Optional[dict] = None, - claims: Optional[dict] = None - ): + def __init__( + self, + upstream_get: Callable = None, + keyjar: Optional[KeyJar] = None, + httpc: Optional[object] = None, + httpc_params: Optional[dict] = None, + config: Optional[Union[Configuration, dict]] = None, + entity_id: Optional[str] = "", + key_conf: Optional[dict] = None, + functions: Optional[dict] = None, + claims: Optional[dict] = None, + ): if config is None: config = {} - self.entity_id = entity_id or config.get('entity_id') + self.entity_id = entity_id or config.get("entity_id") - Unit.__init__(self, upstream_get, keyjar, httpc, httpc_params, config, - issuer_id=self.entity_id, key_conf=key_conf) + Unit.__init__( + self, + upstream_get, + keyjar, + httpc, + httpc_params, + config, + issuer_id=self.entity_id, + key_conf=key_conf, + ) - _args = { - 'upstream_get': self.unit_get - } + _args = {"upstream_get": self.unit_get} self.claims = claims or {} self.upstream_get = upstream_get @@ -236,14 +272,14 @@ def __init__(self, setattr(self, key, instantiate(val["class"], **_kwargs)) def get_context_attribute(self, attr, *args): - _cntx = getattr(self, 'context', None) + _cntx = getattr(self, "context", None) if _cntx: _val = getattr(_cntx, attr, None) if _val: return _val if self.upstream_get: - return self.upstream_get('context_attribute', attr) + return self.upstream_get("context_attribute", attr) else: return None @@ -253,6 +289,6 @@ def get_attribute(self, attr, *args): return val if self.upstream_get: - return self.upstream_get('attribute', attr) + return self.upstream_get("attribute", attr) else: return None diff --git a/src/idpyoidc/server/__init__.py b/src/idpyoidc/server/__init__.py index 7f3d7d94..0509feeb 100644 --- a/src/idpyoidc/server/__init__.py +++ b/src/idpyoidc/server/__init__.py @@ -8,12 +8,14 @@ from cryptojwt import KeyJar from idpyoidc.node import Unit + # from idpyoidc.server import authz # from idpyoidc.server.client_authn import client_auth_setup from idpyoidc.server.configure import ASConfiguration from idpyoidc.server.configure import OPConfiguration from idpyoidc.server.endpoint import Endpoint from idpyoidc.server.endpoint_context import EndpointContext + # from idpyoidc.server.session.manager import create_session_manager # from idpyoidc.server.user_authn.authn_context import populate_authn_broker from idpyoidc.server.util import allow_refresh_token @@ -34,23 +36,30 @@ class Server(Unit): parameter = {"endpoint": [Endpoint], "context": EndpointContext} def __init__( - self, - conf: Union[dict, OPConfiguration, ASConfiguration], - keyjar: Optional[KeyJar] = None, - cwd: Optional[str] = "", - cookie_handler: Optional[Any] = None, - httpc: Optional[Callable] = None, - upstream_get: Optional[Callable] = None, - httpc_params: Optional[dict] = None, - entity_id: Optional[str] = "", - key_conf: Optional[dict] = None + self, + conf: Union[dict, OPConfiguration, ASConfiguration], + keyjar: Optional[KeyJar] = None, + cwd: Optional[str] = "", + cookie_handler: Optional[Any] = None, + httpc: Optional[Callable] = None, + upstream_get: Optional[Callable] = None, + httpc_params: Optional[dict] = None, + entity_id: Optional[str] = "", + key_conf: Optional[dict] = None, ): - self.entity_id = entity_id or conf.get('entity_id') - self.issuer = conf.get('issuer', self.entity_id) + self.entity_id = entity_id or conf.get("entity_id") + self.issuer = conf.get("issuer", self.entity_id) - Unit.__init__(self, config=conf, keyjar=keyjar, httpc=httpc, upstream_get=upstream_get, - httpc_params=httpc_params, key_conf=key_conf, - issuer_id=self.issuer) + Unit.__init__( + self, + config=conf, + keyjar=keyjar, + httpc=httpc, + upstream_get=upstream_get, + httpc_params=httpc_params, + key_conf=key_conf, + issuer_id=self.issuer, + ) self.upstream_get = upstream_get if isinstance(conf, OPConfiguration) or isinstance(conf, ASConfiguration): @@ -65,7 +74,7 @@ def __init__( upstream_get=self.unit_get, # points to me cwd=cwd, cookie_handler=cookie_handler, - keyjar=self.keyjar + keyjar=self.keyjar, ) # Need to have context in place before doing this @@ -75,10 +84,10 @@ def __init__( self.endpoint[endpoint_name].upstream_get = self.unit_get _token_endp = self.endpoint.get("token") - if _token_endp: - _token_endp.allow_refresh = allow_refresh_token(self.context) self.context.map_supported_to_preferred() + if _token_endp: + _token_endp.allow_refresh = allow_refresh_token(self.context) def get_endpoints(self, *arg): return self.endpoint @@ -104,4 +113,4 @@ def get_entity(self, *args): def get_context_attribute(self, attr, *args): _val = getattr(self.context, attr) if not _val and self.upstream_get: - return self.upstream_get('context_attribute', attr) + return self.upstream_get("context_attribute", attr) diff --git a/src/idpyoidc/server/authz/__init__.py b/src/idpyoidc/server/authz/__init__.py index 8fdcb268..b326e322 100755 --- a/src/idpyoidc/server/authz/__init__.py +++ b/src/idpyoidc/server/authz/__init__.py @@ -62,11 +62,9 @@ def __call__( resources: Optional[list] = None, ) -> Grant: _context = self.upstream_get("context") - session_info = _context.session_manager.get_session_info( - session_id=session_id, grant=True - ) + session_info = _context.session_manager.get_session_info(session_id=session_id, grant=True) grant = session_info["grant"] - _client_id = session_info['client_id'] + _client_id = session_info["client_id"] args = self.grant_config.copy() diff --git a/src/idpyoidc/server/claims/__init__.py b/src/idpyoidc/server/claims/__init__.py index 4c37b47f..6ca13ecc 100644 --- a/src/idpyoidc/server/claims/__init__.py +++ b/src/idpyoidc/server/claims/__init__.py @@ -4,21 +4,19 @@ class Claims(claims.Claims): - def get_base_url(self, configuration: dict): - _base = configuration.get('base_url') + _base = configuration.get("base_url") if not _base: - _base = configuration.get('issuer') + _base = configuration.get("issuer") return _base def get_id(self, configuration: dict): - return configuration.get('issuer') + return configuration.get("issuer") - def supported_to_preferred(self, - supported: dict, - base_url: Optional[str] = '', - info: Optional[dict] = None): + def supported_to_preferred( + self, supported: dict, base_url: Optional[str] = "", info: Optional[dict] = None + ): # Add defaults for key, val in supported.items(): if val is None: diff --git a/src/idpyoidc/server/claims/oauth2.py b/src/idpyoidc/server/claims/oauth2.py index f0137543..86e969df 100644 --- a/src/idpyoidc/server/claims/oauth2.py +++ b/src/idpyoidc/server/claims/oauth2.py @@ -35,9 +35,7 @@ class Claims(claims.Claims): callback_uris = ["redirect_uris"] - def __init__(self, - prefer: Optional[dict] = None, - callback_path: Optional[dict] = None): + def __init__(self, prefer: Optional[dict] = None, callback_path: Optional[dict] = None): claims.Claims.__init__(self, prefer=prefer, callback_path=callback_path) def provider_info(self, supports): diff --git a/src/idpyoidc/server/claims/oidc.py b/src/idpyoidc/server/claims/oidc.py index f2b57506..a6620a05 100644 --- a/src/idpyoidc/server/claims/oidc.py +++ b/src/idpyoidc/server/claims/oidc.py @@ -21,6 +21,7 @@ "subject_type": "subject_types_supported", "token_endpoint_auth_method": "token_endpoint_auth_methods_supported", "response_types": "response_types_supported", + "response_modes": "response_modes_supported", "grant_types": "grant_types_supported", # In OAuth2 but not in OIDC "scope": "scopes_supported", @@ -57,7 +58,7 @@ class Claims(server_claims.Claims): "require_auth_time": None, "scopes_supported": ["openid"], "service_documentation": None, - 'subject_types_supported': ['public', 'pairwise', 'ephemeral'], + "subject_types_supported": ["public", "pairwise", "ephemeral"], "op_tos_uri": None, "ui_locales_supported": None, # "version": '3.0' @@ -66,30 +67,29 @@ class Claims(server_claims.Claims): register2preferred = REGISTER2PREFERRED - def __init__(self, - prefer: Optional[dict] = None, - callback_path: Optional[dict] = None - ): + def __init__(self, prefer: Optional[dict] = None, callback_path: Optional[dict] = None): server_claims.Claims.__init__(self, prefer=prefer, callback_path=callback_path) - def verify_rules(self): + def verify_rules(self, supports): if self.get_preference("request_parameter_supported") and self.get_preference( - "request_uri_parameter_supported"): + "request_uri_parameter_supported" + ): raise ValueError( "You have to chose one of 'request_parameter_supported' and " - "'request_uri_parameter_supported'. You can't have both.") + "'request_uri_parameter_supported'. You can't have both." + ) - if not self.get_preference('encrypt_userinfo_supported'): - self.set_preference('userinfo_encryption_alg_values_supported', []) - self.set_preference('userinfo_encryption_enc_values_supported', []) + if not self.get_preference("encrypt_userinfo_supported"): + self.set_preference("userinfo_encryption_alg_values_supported", []) + self.set_preference("userinfo_encryption_enc_values_supported", []) - if not self.get_preference('encrypt_request_object_supported'): - self.set_preference('request_object_encryption_alg_values_supported', []) - self.set_preference('request_object_encryption_enc_values_supported', []) + if not self.get_preference("encrypt_request_object_supported"): + self.set_preference("request_object_encryption_alg_values_supported", []) + self.set_preference("request_object_encryption_enc_values_supported", []) - if not self.get_preference('encrypt_id_token_supported'): - self.set_preference('id_token_encryption_alg_values_supported', []) - self.set_preference('id_token_encryption_enc_values_supported', []) + if not self.get_preference("encrypt_id_token_supported"): + self.set_preference("id_token_encryption_alg_values_supported", []) + self.set_preference("id_token_encryption_enc_values_supported", []) def provider_info(self, supports): _info = {} diff --git a/src/idpyoidc/server/client_authn.py b/src/idpyoidc/server/client_authn.py index 1bcd95b4..06321c0a 100755 --- a/src/idpyoidc/server/client_authn.py +++ b/src/idpyoidc/server/client_authn.py @@ -41,11 +41,11 @@ def __init__(self, upstream_get): self.upstream_get = upstream_get def _verify( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - **kwargs, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + **kwargs, ): """ Verify authentication information in a request @@ -55,12 +55,12 @@ def _verify( raise NotImplementedError() def verify( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - get_client_id_from_token: Optional[Callable] = None, - **kwargs, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + get_client_id_from_token: Optional[Callable] = None, + **kwargs, ): """ Verify authentication information in a request @@ -78,9 +78,9 @@ def verify( return res def is_usable( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, ): """ Verify that this authentication method is applicable. @@ -117,11 +117,11 @@ def is_usable(self, request=None, authorization_token=None): return request is not None def _verify( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - **kwargs, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + **kwargs, ): return {"client_id": request.get("client_id")} @@ -138,11 +138,11 @@ def is_usable(self, request=None, authorization_token=None): return request and "client_id" in request def _verify( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - **kwargs, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + **kwargs, ): return {"client_id": request["client_id"]} @@ -162,14 +162,14 @@ def is_usable(self, request=None, authorization_token=None): return False def _verify( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - **kwargs, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + **kwargs, ): client_info = basic_authn(authorization_token) - _context = self.upstream_get('context') + _context = self.upstream_get("context") if _context.cdb[client_info["id"]]["client_secret"] == client_info["secret"]: return {"client_id": client_info["id"]} else: @@ -194,13 +194,13 @@ def is_usable(self, request=None, authorization_token=None): return False def _verify( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - **kwargs, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + **kwargs, ): - _context = self.upstream_get('context') + _context = self.upstream_get("context") if _context.cdb[request["client_id"]]["client_secret"] == request["client_secret"]: return {"client_id": request["client_id"]} else: @@ -218,15 +218,15 @@ def is_usable(self, request=None, authorization_token=None): return False def _verify( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - get_client_id_from_token: Optional[Callable] = None, - **kwargs, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + get_client_id_from_token: Optional[Callable] = None, + **kwargs, ): token = authorization_token.split(" ", 1)[1] - _context = self.upstream_get('context') + _context = self.upstream_get("context") try: client_id = get_client_id_from_token(_context, token, request) except ToOld: @@ -249,19 +249,19 @@ def is_usable(self, request=None, authorization_token=None): return False def _verify( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - get_client_id_from_token: Optional[Callable] = None, - **kwargs, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + get_client_id_from_token: Optional[Callable] = None, + **kwargs, ): _token = request.get("access_token") if _token is None: raise ClientAuthenticationError("No access token") res = {"token": _token} - _context = self.upstream_get('context') + _context = self.upstream_get("context") _client_id = get_client_id_from_token(_context, _token, request) if _client_id: res["client_id"] = _client_id @@ -269,7 +269,6 @@ def _verify( class JWSAuthnMethod(ClientAuthnMethod): - def is_usable(self, request=None, authorization_token=None): if request is None: return False @@ -278,15 +277,15 @@ def is_usable(self, request=None, authorization_token=None): return False def _verify( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - key_type: Optional[str] = None, - **kwargs, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + key_type: Optional[str] = None, + **kwargs, ): - _context = self.upstream_get('context') - _keyjar = self.upstream_get('attribute', 'keyjar') + _context = self.upstream_get("context") + _keyjar = self.upstream_get("attribute", "keyjar") _jwt = JWT(_keyjar, msg_cls=JsonWebToken) try: ca_jwt = _jwt.unpack(request["client_assertion"]) @@ -298,9 +297,7 @@ def _verify( if _sign_alg and _sign_alg.startswith("HS"): if key_type == "private_key": raise AttributeError("Wrong key type") - keys = _keyjar.get( - "sig", "oct", ca_jwt["iss"], ca_jwt.jws_header.get("kid") - ) + keys = _keyjar.get("sig", "oct", ca_jwt["iss"], ca_jwt.jws_header.get("kid")) _secret = _context.cdb[ca_jwt["iss"]].get("client_secret") if _secret and keys[0].key != as_bytes(_secret): raise AttributeError("Oct key used for signing not client_secret") @@ -348,11 +345,11 @@ class ClientSecretJWT(JWSAuthnMethod): tag = "client_secret_jwt" def _verify( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - **kwargs, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + **kwargs, ): res = super()._verify( request=request, key_type="client_secret", endpoint=endpoint, **kwargs @@ -369,11 +366,11 @@ class PrivateKeyJWT(JWSAuthnMethod): tag = "private_key_jwt" def _verify( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - **kwargs, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + **kwargs, ): res = super()._verify( request=request, @@ -394,14 +391,14 @@ def is_usable(self, request=None, authorization_token=None): return True def _verify( - self, - request: Optional[Union[dict, Message]] = None, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - **kwargs, + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + **kwargs, ): - _context = self.upstream_get('context') - _jwt = JWT(self.upstream_get('attribute', 'keyjar'), msg_cls=JsonWebToken) + _context = self.upstream_get("context") + _jwt = JWT(self.upstream_get("attribute", "keyjar"), msg_cls=JsonWebToken) try: _jwt = _jwt.unpack(request["request"]) except (Invalid, MissingKey, BadSignature) as err: @@ -446,11 +443,12 @@ def valid_client_info(cinfo): def verify_client( - request: Union[dict, Message], - http_info: Optional[dict] = None, - get_client_id_from_token: Optional[Callable] = None, - endpoint=None, # Optional[Endpoint] - also_known_as: Optional[Dict[str, str]] = None, + request: Union[dict, Message], + http_info: Optional[dict] = None, + get_client_id_from_token: Optional[Callable] = None, + endpoint=None, # Optional[Endpoint] + also_known_as: Optional[Dict[str, str]] = None, + **kwargs, ) -> dict: """ Initiated Guessing ! @@ -473,7 +471,7 @@ def verify_client( authorization_token = None auth_info = {} - _context = endpoint.upstream_get('context') + _context = endpoint.upstream_get("context") methods = _context.client_authn_methods client_id = None allowed_methods = getattr(endpoint, "client_authn_method") @@ -487,7 +485,7 @@ def verify_client( try: logger.info(f"Verifying client authentication using {_method.tag}") auth_info = _method.verify( - keyjar=endpoint.upstream_get('attribute', 'keyjar'), + keyjar=endpoint.upstream_get("attribute", "keyjar"), request=request, authorization_token=authorization_token, endpoint=endpoint, @@ -510,10 +508,17 @@ def verify_client( client_id = also_known_as[client_id] auth_info["client_id"] = client_id - if client_id not in _context.cdb: - raise UnknownClient("Unknown Client ID") + _get_client_info = kwargs.get("get_client_info") + if _get_client_info: + _cinfo = _get_client_info(client_id, _context) + else: + try: + _cinfo = _context.cdb[client_id] + except KeyError: + raise UnknownClient("Unknown Client ID") - _cinfo = _context.cdb[client_id] + if not _cinfo: + raise UnknownClient("Unknown Client ID") if not valid_client_info(_cinfo): logger.warning("Client registration has timed out or " "client secret is expired.") diff --git a/src/idpyoidc/server/configure.py b/src/idpyoidc/server/configure.py index 3ba7449d..952d3929 100755 --- a/src/idpyoidc/server/configure.py +++ b/src/idpyoidc/server/configure.py @@ -7,16 +7,14 @@ from typing import List from typing import Optional +from idpyoidc.client.defaults import OAUTH2_SERVER_METADATA_URL from idpyoidc.configure import Base from idpyoidc.server.client_configure import verify_oidc_client_information from idpyoidc.server.scopes import SCOPE2CLAIMS logger = logging.getLogger(__name__) -OP_DEFAULT_CONFIG = { - "preference": { - "subject_types_supported": ["public", "pairwise"], - }, +_DEFAULT_CONFIG = { "cookie_handler": { "class": "idpyoidc.server.cookie_handler.CookieHandler", "kwargs": { @@ -39,7 +37,83 @@ }, }, }, - "claims_interface": {"class": "idpyoidc.server.session.claims.ClaimsInterface", "kwargs": {}}, + "claims_interface": { + "class": "idpyoidc.server.session.claims.ClaimsInterface", + "kwargs": {} + }, + "httpc_params": {"verify": False, "timeout": 4}, + "issuer": "https://{domain}:{port}", + "template_dir": "templates" +} + +AS_DEFAULT_CONFIG = copy.deepcopy(_DEFAULT_CONFIG) +_C = { + "authz": { + "class": "idpyoidc.server.authz.AuthzHandling", + "kwargs": { + "grant_config": { + "usage_rules": { + "authorization_code": { + "supports_minting": ["access_token", "refresh_token"], + "max_usage": 1, + "expires_in": 120 # 2 minutes + }, + "access_token": {"expires_in": 3600}, # An hour + "refresh_token": { + "supports_minting": ["access_token", "refresh_token"], + "expires_in": 86400, # One day + }, + }, + "expires_in": 2592000, # a month, 30 days + } + } + }, + "claims_interface": { + "class": "idpyoidc.server.session.claims.ClaimsInterface", + "kwargs": { + "claims_release_points": ["introspection", "access_token"] + } + }, + "endpoint": { + "provider_info": { + "path": OAUTH2_SERVER_METADATA_URL[3:], + "class": "idpyoidc.server.oauth2.server_metadata.ServerMetadata", + "kwargs": {"client_authn_method": None}, + }, + "authorization": { + "path": "authorization", + "class": "idpyoidc.server.oauth2.authorization.Authorization", + "kwargs": { + "client_authn_method": None, + "claims_parameter_supported": True, + "request_parameter_supported": True, + "request_uri_parameter_supported": True, + "response_types_supported": ["code"], + "response_modes_supported": ["query", "fragment", "form_post"], + }, + }, + "token": { + "path": "token", + "class": "idpyoidc.server.oauth2.token.Token", + "kwargs": { + "client_authn_method": [ + "client_secret_post", + "client_secret_basic", + "client_secret_jwt", + "private_key_jwt", + ] + } + } + } +} + +AS_DEFAULT_CONFIG.update(_C) + +OP_DEFAULT_CONFIG = copy.deepcopy(_DEFAULT_CONFIG) +OP_DEFAULT_CONFIG.update({ + "preference": { + "subject_types_supported": ["public", "pairwise"], + }, "authz": { "class": "idpyoidc.server.authz.AuthzHandling", "kwargs": { @@ -52,18 +126,22 @@ "id_token", ], "max_usage": 1, + 'expires_in': 120 # 2 minutes }, - "access_token": {}, + "access_token": {'expires_in': 3600}, # An hour "refresh_token": { - "supports_minting": ["access_token", "refresh_token"], - "expires_in": -1, + "supports_minting": ["access_token", "refresh_token", "id_token"], + "expires_in": 86400, # One day }, }, - "expires_in": 43200, + "expires_in": 2592000, # a month, 30 days } }, }, - "httpc_params": {"verify": False, "timeout": 4}, + "claims_interface": { + "class": "idpyoidc.server.session.claims.ClaimsInterface", + "kwargs": {} + }, "endpoint": { "provider_info": { "path": ".well-known/openid-configuration", @@ -80,12 +158,12 @@ "request_uri_parameter_supported": True, "response_types_supported": [ "code", - "token", + # "token", "id_token", - "code token", + # "code token", "code id_token", - "id_token token", - "code id_token token", + # "id_token token", + # "code id_token token", # "none" ], "response_modes_supported": ["query", "fragment", "form_post"], @@ -109,8 +187,6 @@ "kwargs": {"claim_types_supported": ["normal", "aggregated", "distributed"]}, }, }, - "issuer": "https://{domain}:{port}", - "template_dir": "templates", "token_handler_args": { "jwks_file": "private/token_jwks.json", "code": {"kwargs": {"lifetime": 600}}, @@ -125,13 +201,8 @@ "id_token": {"class": "idpyoidc.server.token.id_token.IDToken", "kwargs": {}}, }, "scopes_to_claims": SCOPE2CLAIMS, -} +}) -AS_DEFAULT_CONFIG = copy.deepcopy(OP_DEFAULT_CONFIG) -AS_DEFAULT_CONFIG["claims_interface"] = { - "class": "idpyoidc.server.session.claims.OAuth2ClaimsInterface", - "kwargs": {}, -} class EntityConfiguration(Base): @@ -151,24 +222,24 @@ class EntityConfiguration(Base): "httpc_params": {}, "issuer": "", "key_conf": None, - 'preference': {}, + "preference": {}, "session_params": None, "template_dir": None, "token_handler_args": {}, "userinfo": None, - "scopes_handler": None + "scopes_handler": None, } def __init__( - self, - conf: Dict, - base_path: Optional[str] = "", - entity_conf: Optional[List[dict]] = None, - domain: Optional[str] = "", - port: Optional[int] = 0, - file_attributes: Optional[List[str]] = None, - dir_attributes: Optional[List[str]] = None, - upstream_get: Optional[Callable] = None + self, + conf: Dict, + base_path: Optional[str] = "", + entity_conf: Optional[List[dict]] = None, + domain: Optional[str] = "", + port: Optional[int] = 0, + file_attributes: Optional[List[str]] = None, + dir_attributes: Optional[List[str]] = None, + upstream_get: Optional[Callable] = None, ): conf = copy.deepcopy(conf) @@ -232,14 +303,14 @@ class OPConfiguration(EntityConfiguration): ) def __init__( - self, - conf: Dict, - base_path: Optional[str] = "", - entity_conf: Optional[List[dict]] = None, - domain: Optional[str] = "", - port: Optional[int] = 0, - file_attributes: Optional[List[str]] = None, - dir_attributes: Optional[List[str]] = None, + self, + conf: Dict, + base_path: Optional[str] = "", + entity_conf: Optional[List[dict]] = None, + domain: Optional[str] = "", + port: Optional[int] = 0, + file_attributes: Optional[List[str]] = None, + dir_attributes: Optional[List[str]] = None, ): super().__init__( conf=conf, @@ -256,14 +327,14 @@ class ASConfiguration(EntityConfiguration): "Authorization server configuration" def __init__( - self, - conf: Dict, - base_path: Optional[str] = "", - entity_conf: Optional[List[dict]] = None, - domain: Optional[str] = "", - port: Optional[int] = 0, - file_attributes: Optional[List[str]] = None, - dir_attributes: Optional[List[str]] = None, + self, + conf: Dict, + base_path: Optional[str] = "", + entity_conf: Optional[List[dict]] = None, + domain: Optional[str] = "", + port: Optional[int] = 0, + file_attributes: Optional[List[str]] = None, + dir_attributes: Optional[List[str]] = None, ): EntityConfiguration.__init__( self, @@ -280,7 +351,7 @@ def __init__( DEFAULT_EXTENDED_CONF = { "add_on": { "pkce": { - "function": "idpyoidc.server.oidc.add_on.pkce.add_pkce_support", + "function": "idpyoidc.server.oauth2.add_on.pkce.add_support", "kwargs": {"essential": False, "code_challenge_method": "S256 S384 S512"}, }, "claims": { @@ -349,9 +420,7 @@ def __init__( "refresh_token", ], }, - "scopes_handler": { - "class": "idpyoidc.server.scopes.Scopes" - }, + "scopes_handler": {"class": "idpyoidc.server.scopes.Scopes"}, "claims_interface": {"class": "idpyoidc.server.session.claims.ClaimsInterface", "kwargs": {}}, "cookie_handler": { "class": "idpyoidc.server.cookie_handler.CookieHandler", @@ -417,12 +486,12 @@ def __init__( "request_uri_parameter_supported": True, "response_types_supported": [ "code", - "token", + # "token", "id_token", - "code token", + # "code token", "code id_token", - "id_token token", - "code id_token token", + # "id_token token", + # "code id_token token", # "none" ], "response_modes_supported": ["query", "fragment", "form_post"], diff --git a/src/idpyoidc/server/cookie_handler.py b/src/idpyoidc/server/cookie_handler.py index c9b87bae..1d3fc977 100755 --- a/src/idpyoidc/server/cookie_handler.py +++ b/src/idpyoidc/server/cookie_handler.py @@ -144,9 +144,7 @@ def _sign_enc_payload(self, payload: str, timestamp: Optional[Union[int, str]] = ] elif self.crypt: msg = lv_pack(timestamp, payload) - cookie_payload = [ - bytes_timestamp, - base64.b64encode(self.crypt.encrypt(msg.encode()))] + cookie_payload = [bytes_timestamp, base64.b64encode(self.crypt.encrypt(msg.encode()))] else: cookie_payload = [bytes_timestamp, bytes_load, base64.b64encode(mac)] @@ -169,7 +167,7 @@ def _ver_dec_content(self, parts): msg = self.crypt.decrypt(base64.b64decode(as_bytes(enc_payload))) t1, payload = lv_unpack(msg.decode("utf-8")) if t0 != t1: - raise VerificationError('Suspicious timestamp') + raise VerificationError("Suspicious timestamp") return payload, t1 elif len(parts) == 3: # verify the cookie signature diff --git a/src/idpyoidc/server/endpoint.py b/src/idpyoidc/server/endpoint.py index a0763ceb..849bb318 100755 --- a/src/idpyoidc/server/endpoint.py +++ b/src/idpyoidc/server/endpoint.py @@ -170,16 +170,21 @@ def verify_request(self, request, keyjar, client_id, verify_args, lap=0): return self.error_cls(error=err) else: # Fund a client ID I believe will work - self.verify_request(request=request, keyjar=keyjar, client_id=client_id, - verify_args=verify_args, lap=1) + self.verify_request( + request=request, + keyjar=keyjar, + client_id=client_id, + verify_args=verify_args, + lap=1, + ) return None def parse_request( - self, - request: Union[Message, dict, str], - http_info: Optional[dict] = None, - verify_args: Optional[dict] = None, - **kwargs + self, + request: Union[Message, dict, str], + http_info: Optional[dict] = None, + verify_args: Optional[dict] = None, + **kwargs ): """ @@ -193,7 +198,7 @@ def parse_request( LOGGER.info("Request: %s" % sanitize(request)) _context = self.upstream_get("context") - _keyjar = self.upstream_get('attribute', 'keyjar') + _keyjar = self.upstream_get("attribute", "keyjar") if http_info is None: http_info = {} @@ -226,17 +231,18 @@ def parse_request( if "client_id" in auth_info: req["client_id"] = auth_info["client_id"] - _auth_method = auth_info.get('method') - if _auth_method and _auth_method not in ['public', 'none']: - req['authenticated'] = True + _auth_method = auth_info.get("method") + if _auth_method and _auth_method not in ["public", "none"]: + req["authenticated"] = True _client_id = auth_info["client_id"] else: _client_id = req.get("client_id") # verify that the request message is correct, may have to do it twice - err_response = self.verify_request(request=req, keyjar=_keyjar, client_id=_client_id, - verify_args=verify_args) + err_response = self.verify_request( + request=req, keyjar=_keyjar, client_id=_client_id, verify_args=verify_args + ) if err_response: return err_response @@ -263,11 +269,7 @@ def client_authentication(self, request: Message, http_info: Optional[dict] = No if not get_client_id_from_token: kwargs["get_client_id_from_token"] = getattr(self, "get_client_id_from_token", None) - authn_info = verify_client( - request=request, - http_info=http_info, - **kwargs - ) + authn_info = verify_client(request=request, http_info=http_info, **kwargs) LOGGER.debug("authn_info: %s", authn_info) if authn_info == {} and self.client_authn_method and len(self.client_authn_method): @@ -278,7 +280,7 @@ def client_authentication(self, request: Message, http_info: Optional[dict] = No return authn_info def do_post_parse_request( - self, request: Message, client_id: Optional[str] = "", **kwargs + self, request: Message, client_id: Optional[str] = "", **kwargs ) -> Message: _context = self.upstream_get("context") for meth in self.post_parse_request: @@ -288,7 +290,7 @@ def do_post_parse_request( return request def do_pre_construct( - self, response_args: dict, request: Optional[Union[Message, dict]] = None, **kwargs + self, response_args: dict, request: Optional[Union[Message, dict]] = None, **kwargs ) -> dict: _context = self.upstream_get("context") for meth in self.pre_construct: @@ -297,10 +299,10 @@ def do_pre_construct( return response_args def do_post_construct( - self, - response_args: Union[Message, dict], - request: Optional[Union[Message, dict]] = None, - **kwargs + self, + response_args: Union[Message, dict], + request: Optional[Union[Message, dict]] = None, + **kwargs ) -> dict: _context = self.upstream_get("context") for meth in self.post_construct: @@ -309,10 +311,10 @@ def do_post_construct( return response_args def process_request( - self, - request: Optional[Union[Message, dict]] = None, - http_info: Optional[dict] = None, - **kwargs + self, + request: Optional[Union[Message, dict]] = None, + http_info: Optional[dict] = None, + **kwargs ) -> Union[Message, dict]: """ @@ -323,10 +325,10 @@ def process_request( return {} def construct( - self, - response_args: Optional[dict] = None, - request: Optional[Union[Message, dict]] = None, - **kwargs + self, + response_args: Optional[dict] = None, + request: Optional[Union[Message, dict]] = None, + **kwargs ): """ Construct the response @@ -344,19 +346,19 @@ def construct( return self.do_post_construct(response, request, **kwargs) def response_info( - self, - response_args: Optional[dict] = None, - request: Optional[Union[Message, dict]] = None, - **kwargs + self, + response_args: Optional[dict] = None, + request: Optional[Union[Message, dict]] = None, + **kwargs ) -> dict: return self.construct(response_args, request, **kwargs) def do_response( - self, - response_args: Optional[dict] = None, - request: Optional[Union[Message, dict]] = None, - error: Optional[str] = "", - **kwargs + self, + response_args: Optional[dict] = None, + request: Optional[Union[Message, dict]] = None, + error: Optional[str] = "", + **kwargs ) -> dict: """ :param response_args: Information to use when constructing the response @@ -391,6 +393,8 @@ def do_response( content_type = "application/json" elif self.response_format in ["jws", "jwe", "jose"]: content_type = "application/jose" + elif self.response_format == "text": + content_type = "text/plain" else: content_type = "application/x-www-form-urlencoded" else: diff --git a/src/idpyoidc/server/endpoint_context.py b/src/idpyoidc/server/endpoint_context.py index 742084c2..ad5acad6 100755 --- a/src/idpyoidc/server/endpoint_context.py +++ b/src/idpyoidc/server/endpoint_context.py @@ -102,19 +102,19 @@ class EndpointContext(OidcContext): "client_authn_method": {}, } - init_args = ['upstream_get', 'handler'] + init_args = ["upstream_get", "handler"] def __init__( - self, - conf: Union[dict, OPConfiguration], - upstream_get: Callable, - cwd: Optional[str] = "", - cookie_handler: Optional[Any] = None, - httpc: Optional[Any] = None, - server_type: Optional[str] = '', - entity_id: Optional[str] = "", - keyjar: Optional[KeyJar] = None, - claims_class: Optional[Claims] = None + self, + conf: Union[dict, OPConfiguration], + upstream_get: Callable, + cwd: Optional[str] = "", + cookie_handler: Optional[Any] = None, + httpc: Optional[Any] = None, + server_type: Optional[str] = "", + entity_id: Optional[str] = "", + keyjar: Optional[KeyJar] = None, + claims_class: Optional[Claims] = None, ): _id = entity_id or conf.get("issuer", "") OidcContext.__init__(self, conf, entity_id=_id) @@ -240,7 +240,7 @@ def __init__( _supports = self.supports() self.keyjar = self.claims.load_conf(conf, supports=_supports, keyjar=keyjar) self.provider_info = self.claims.provider_info(_supports) - self.provider_info['issuer'] = self.issuer + self.provider_info["issuer"] = self.issuer self.provider_info.update(self._get_endpoint_info()) # INTERFACES @@ -316,7 +316,10 @@ def set_scopes_handler(self): ) def do_add_on(self, endpoints): - _add_on_conf = self.conf.get("add_on") + _add_on_conf = self.conf.get("add_ons", self.conf.get("add_on")) + if not _add_on_conf: + _add_on_conf = self.conf.conf.get("add_ons") + if _add_on_conf: for spec in _add_on_conf.values(): if isinstance(spec["function"], str): @@ -399,21 +402,21 @@ def do_login_hint_lookup(self): def supports(self): res = {} if self.upstream_get: - for endpoint in self.upstream_get('endpoints').values(): + for endpoint in self.upstream_get("endpoints").values(): res.update(endpoint.supports()) res.update(self.claims.supports()) return res def set_provider_info(self): _info = self.claims.provider_info(self.supports()) - _info.update({'issuer': self.issuer, 'version': "3.0"}) + _info.update({"issuer": self.issuer, "version": "3.0"}) - for endp in self.upstream_get('endpoints').values(): + for endp in self.upstream_get("endpoints").values(): if endp.endpoint_name: _info[endp.endpoint_name] = endp.full_path # acr_values - if 'acr_values_supported' not in _info: + if "acr_values_supported" not in _info: if self.authn_broker: acr_values = self.authn_broker.get_acr_values() if acr_values is not None: @@ -484,7 +487,7 @@ def map_supported_to_preferred(self): def _get_endpoint_info(self): _res = {} - for name, endp in self.upstream_get('endpoints').items(): + for name, endp in self.upstream_get("endpoints").items(): if endp.endpoint_name: _res[endp.endpoint_name] = endp.full_path return _res diff --git a/src/idpyoidc/server/exception.py b/src/idpyoidc/server/exception.py index 3fbc552b..bfe5a30d 100755 --- a/src/idpyoidc/server/exception.py +++ b/src/idpyoidc/server/exception.py @@ -128,3 +128,7 @@ class MultipleCodeUsage(OidcEndpointError): class InvalidBranchID(OidcEndpointError): pass + + +class ClientGrantMismatch(OidcEndpointError): + pass diff --git a/src/idpyoidc/server/oauth2/add_on/dpop.py b/src/idpyoidc/server/oauth2/add_on/dpop.py index e426acd3..849d151a 100644 --- a/src/idpyoidc/server/oauth2/add_on/dpop.py +++ b/src/idpyoidc/server/oauth2/add_on/dpop.py @@ -1,17 +1,23 @@ +import logging +from hashlib import sha256 +from typing import Callable from typing import Optional +from typing import Union -from cryptojwt import JWS from cryptojwt import as_unicode +from cryptojwt import JWS from cryptojwt.jwk.jwk import key_from_jwk_dict from cryptojwt.jws.jws import factory +from idpyoidc.claims import get_signing_algs +from idpyoidc.message import Message +from idpyoidc.message import SINGLE_OPTIONAL_STRING from idpyoidc.message import SINGLE_REQUIRED_INT from idpyoidc.message import SINGLE_REQUIRED_JSON from idpyoidc.message import SINGLE_REQUIRED_STRING -from idpyoidc.message import Message -from idpyoidc.server.client_authn import ClientAuthnMethod -from idpyoidc.server.client_authn import basic_authn -from idpyoidc.server.exception import ClientAuthenticationError +from idpyoidc.server.client_authn import BearerHeader + +logger = logging.getLogger(__name__) class DPoPProof(Message): @@ -25,6 +31,7 @@ class DPoPProof(Message): "htm": SINGLE_REQUIRED_STRING, "htu": SINGLE_REQUIRED_STRING, "iat": SINGLE_REQUIRED_INT, + "ath": SINGLE_OPTIONAL_STRING, } header_params = {"typ", "alg", "jwk"} body_params = {"jti", "htm", "htu", "iat"} @@ -84,7 +91,7 @@ def verify_header(self, dpop_header) -> Optional["DPoPProof"]: return None -def post_parse_request(request, client_id, context, **kwargs): +def token_post_parse_request(request, client_id, context, **kwargs): """ Expect http_info attribute in kwargs. http_info should be a dictionary containing HTTP information. @@ -119,6 +126,47 @@ def post_parse_request(request, client_id, context, **kwargs): return request +def userinfo_post_parse_request(request, client_id, context, auth_info, **kwargs): + """ + Expect http_info attribute in kwargs. http_info should be a dictionary + containing HTTP information. + + :param request: + :param client_id: + :param context: + :param kwargs: + :return: + """ + + _http_info = kwargs.get("http_info") + if not _http_info: + return request + + _dpop = DPoPProof().verify_header(_http_info["headers"]["dpop"]) + + # The signature of the JWS is verified, now for checking the + # content + + if _dpop["htu"] != _http_info["url"]: + raise ValueError("htu in DPoP does not match the HTTP URI") + + if _dpop["htm"] != _http_info["method"]: + raise ValueError("htm in DPoP does not match the HTTP method") + + if not _dpop.key: + _dpop.key = key_from_jwk_dict(_dpop["jwk"]) + + ath = sha256(auth_info["token"].encode("utf8")).hexdigest() + + if _dpop["ath"] != ath: + raise ValueError("'ath' in DPoP does not match the token hash") + + # Need something I can add as a reference when minting tokens + request["dpop_jkt"] = as_unicode(_dpop.key.thumbprint("SHA-256")) + logger.debug("DPoP verified") + return request + + def token_args(context, client_id, token_args: Optional[dict] = None): dpop_jkt = context.cdb[client_id]["dpop_jkt"] _jkt = list(dpop_jkt.keys())[0] @@ -134,24 +182,31 @@ def token_args(context, client_id, token_args: Optional[dict] = None): def add_support(endpoint: dict, **kwargs): # _token_endp = endpoint["token"] - _token_endp.post_parse_request.append(post_parse_request) + _token_endp.post_parse_request.append(token_post_parse_request) _algs_supported = kwargs.get("dpop_signing_alg_values_supported") if not _algs_supported: _algs_supported = ["RS256"] + else: + _algs_supported = [alg for alg in _algs_supported if alg in get_signing_algs()] _token_endp.upstream_get("context").provider_info[ "dpop_signing_alg_values_supported" ] = _algs_supported _context = _token_endp.upstream_get("context") - _context.dpop_enabled = True + _context.add_on["dpop"] = {"algs_supported": _algs_supported} + _context.client_authn_methods["dpop"] = DPoPClientAuth + + _userinfo_endpoint = endpoint.get("userinfo") + if _userinfo_endpoint: + _userinfo_endpoint.post_parse_request.append(userinfo_post_parse_request) # DPoP-bound access token in the "Authorization" header and the DPoP proof in the "DPoP" header -class DPoPClientAuth(ClientAuthnMethod): +class DPoPClientAuth(BearerHeader): tag = "dpop_client_auth" def is_usable(self, request=None, authorization_info=None, http_headers=None): @@ -159,10 +214,21 @@ def is_usable(self, request=None, authorization_info=None, http_headers=None): return True return False - def verify(self, authorization_info, **kwargs): - client_info = basic_authn(authorization_info) + def verify( + self, + request: Optional[Union[dict, Message]] = None, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + get_client_id_from_token: Optional[Callable] = None, + **kwargs, + ): + # info contains token and client_id + info = BearerHeader._verify( + self, request, authorization_token, endpoint, get_client_id_from_token, **kwargs + ) _context = self.upstream_get("context") - if _context.cdb[client_info["id"]]["client_secret"] == client_info["secret"]: - return {"client_id": client_info["id"]} - else: - raise ClientAuthenticationError() + return {"client_id": ""} + # if _context.cdb[client_info["id"]]["client_secret"] == client_info["secret"]: + # return {"client_id": client_info["id"]} + # else: + # raise ClientAuthenticationError() diff --git a/src/idpyoidc/server/oidc/add_on/pkce.py b/src/idpyoidc/server/oauth2/add_on/pkce.py similarity index 78% rename from src/idpyoidc/server/oidc/add_on/pkce.py rename to src/idpyoidc/server/oauth2/add_on/pkce.py index ccd8d506..0b1e697a 100644 --- a/src/idpyoidc/server/oidc/add_on/pkce.py +++ b/src/idpyoidc/server/oauth2/add_on/pkce.py @@ -1,6 +1,7 @@ import hashlib import logging from typing import Dict +from typing import Optional from cryptojwt.utils import b64e @@ -44,7 +45,7 @@ def post_authn_parse(request, client_id, context, **kwargs): if "pkce_essential" in client: essential = client["pkce_essential"] else: - essential = context.args["pkce"].get("essential", False) + essential = context.add_on["pkce"].get("essential", False) if essential and "code_challenge" not in request: return AuthorizationErrorResponse( error="invalid_request", @@ -52,11 +53,10 @@ def post_authn_parse(request, client_id, context, **kwargs): ) if "code_challenge_method" not in request: - request["code_challenge_method"] = "S256" + request["code_challenge_method"] = "plain" if "code_challenge" in request and ( - request["code_challenge_method"] - not in context.args["pkce"]["code_challenge_methods"] + request["code_challenge_method"] not in context.add_on["pkce"]["code_challenge_methods"] ): return AuthorizationErrorResponse( error="invalid_request", @@ -126,7 +126,12 @@ def post_token_parse(request, client_id, context, **kwargs): return request -def add_pkce_support(endpoint: Dict[str, Endpoint], **kwargs): +def add_support( + endpoint: Dict[str, Endpoint], + code_challenge_methods: Optional[dict] = None, + essential: Optional[bool] = False, + **kwargs +): authn_endpoint = endpoint.get("authorization") if authn_endpoint is None: LOGGER.warning("No authorization endpoint found, skipping PKCE configuration") @@ -140,22 +145,16 @@ def add_pkce_support(endpoint: Dict[str, Endpoint], **kwargs): authn_endpoint.post_parse_request.append(post_authn_parse) token_endpoint.post_parse_request.append(post_token_parse) - code_challenge_methods = kwargs.get("code_challenge_methods", CC_METHOD.keys()) - code_challenge_methods = list( - set(code_challenge_methods).intersection( - authn_endpoint._supports["code_challenge_methods_supported"] - ) - ) - if not code_challenge_methods: - raise ValueError( - "Unsupported method: {}".format( - ", ".join(kwargs.get("code_challenge_methods", CC_METHOD.keys())) - ) - ) - kwargs["code_challenge_methods"] = {} - for method in code_challenge_methods: - if method not in CC_METHOD: - raise ValueError("Unsupported method: {}".format(method)) - kwargs["code_challenge_methods"][method] = CC_METHOD[method] - - authn_endpoint.upstream_get("context").args["pkce"] = kwargs + if code_challenge_methods is None: + code_challenge_methods = CC_METHOD + else: + for method in code_challenge_methods: + if method not in CC_METHOD: + raise ValueError("Unsupported method: {}".format(method)) + + _context = authn_endpoint.upstream_get("context") + _context.add_on["pkce"] = { + "code_challenge_methods": code_challenge_methods, + "essential": essential, + } + _context.set_preference("code_challenge_methods_supported", list(code_challenge_methods.keys())) diff --git a/src/idpyoidc/server/oauth2/authorization.py b/src/idpyoidc/server/oauth2/authorization.py index f6f60f99..46b12699 100755 --- a/src/idpyoidc/server/oauth2/authorization.py +++ b/src/idpyoidc/server/oauth2/authorization.py @@ -94,9 +94,9 @@ def max_age(request): def verify_uri( context: EndpointContext, - request: Union[dict, Message], - uri_type: str, - client_id: Optional[str] = None, + request: Union[dict, Message], + uri_type: str, + client_id: Optional[str] = None, ): """ A redirect URI @@ -226,10 +226,10 @@ def get_uri(context, request, uri_type): def authn_args_gather( - request: Union[AuthorizationRequest, dict], - authn_class_ref: str, - cinfo: dict, - **kwargs, + request: Union[AuthorizationRequest, dict], + authn_class_ref: str, + cinfo: dict, + **kwargs, ): """ Gather information to be used by the authentication method @@ -291,7 +291,10 @@ def validate_resource_indicators_policy(request, context, **kwargs): resource_servers_per_client = kwargs["resource_servers_per_client"] client_id = request["client_id"] - if isinstance(resource_servers_per_client, dict) and client_id not in resource_servers_per_client: + if ( + isinstance(resource_servers_per_client, dict) + and client_id not in resource_servers_per_client + ): return oauth2.AuthorizationErrorResponse( error="invalid_target", error_description=f"Resources for client {client_id} not found", @@ -342,7 +345,7 @@ class Authorization(Endpoint): "claims_parameter_supported": True, "request_parameter_supported": True, "request_uri_parameter_supported": True, - "response_types_supported": ["code", "token", "code token"], + "response_types_supported": ["code"], "response_modes_supported": ["query", "fragment", "form_post"], "request_object_signing_alg_values_supported": claims.get_signing_algs, "request_object_encryption_alg_values_supported": claims.get_encryption_algs, @@ -360,7 +363,7 @@ def __init__(self, upstream_get, **kwargs): self.post_parse_request.append(self._do_request_uri) self.post_parse_request.append(self._post_parse_request) self.allowed_request_algorithms = AllowedAlgorithms(ALG_PARAMS) - self.resource_indicators_config = kwargs.get('resource_indicators', None) + self.resource_indicators_config = kwargs.get("resource_indicators", None) def filter_request(self, context, req): return req @@ -377,7 +380,7 @@ def authentication_error_response(self, request, error, error_description, **kwa def verify_response_type(self, request: Union[Message, dict], cinfo: dict) -> bool: # Checking response types - _registered = [set(rt.split(" ")) for rt in cinfo.get("response_types", [])] + _registered = [set(rt.split(" ")) for rt in cinfo.get("response_types_supported", [])] if not _registered: # If no response_type is registered by the client then we'll use code. _registered = [{"code"}] @@ -437,12 +440,9 @@ def _do_request_uri(self, request, client_id, context, **kwargs): raise ValueError("A request_uri outside the registered") # Fetch the request - _resp = context.httpc('GET', _request_uri, **context.httpc_params) + _resp = context.httpc("GET", _request_uri, **context.httpc_params) if _resp.status_code == 200: - args = { - "keyjar": self.upstream_get('attribute', 'keyjar'), - "issuer": client_id - } + args = {"keyjar": self.upstream_get("attribute", "keyjar"), "issuer": client_id} _ver_request = self.request_cls().from_jwt(_resp.text, **args) self.allowed_request_algorithms( client_id, @@ -518,8 +518,10 @@ def _post_parse_request(self, request, client_id, context, **kwargs): else: request["redirect_uri"] = redirect_uri - if ("resource_indicators" in _cinfo - and "authorization_code" in _cinfo["resource_indicators"]): + if ( + "resource_indicators" in _cinfo + and "authorization_code" in _cinfo["resource_indicators"] + ): resource_indicators_config = _cinfo["resource_indicators"]["authorization_code"] else: resource_indicators_config = self.resource_indicators_config @@ -540,9 +542,7 @@ def _enforce_resource_indicators_policy(self, request, config): kwargs = policy.get("kwargs", {}) if kwargs.get("resource_servers_per_client", None) is None: - kwargs["resource_servers_per_client"] = { - request["client_id"]: request["client_id"] - } + kwargs["resource_servers_per_client"] = {request["client_id"]: request["client_id"]} if isinstance(function, str): try: @@ -618,7 +618,7 @@ def _unwrap_identity(self, identity): # identity is a dict or a json object # the value of 'uid' in the dictionary might be a base64 encoded (b64e) json object if isinstance(identity, dict): - _uid = as_unicode(identity['uid']) + _uid = as_unicode(identity["uid"]) try: _id = b64d(as_bytes(_uid)) except Exception: @@ -635,13 +635,13 @@ def _unwrap_identity(self, identity): return identity def setup_auth( - self, - request: Optional[Union[Message, dict]], - redirect_uri: str, - cinfo: dict, - cookie: List[dict] = None, - acr: str = None, - **kwargs, + self, + request: Optional[Union[Message, dict]], + redirect_uri: str, + cinfo: dict, + cookie: List[dict] = None, + acr: str = None, + **kwargs, ) -> dict: """ @@ -765,12 +765,12 @@ def aresp_check(self, aresp, request): return "" def response_mode( - self, - request: Union[dict, AuthorizationRequest], - response_args: Optional[Union[dict, AuthorizationResponse]] = None, - return_uri: Optional[str] = "", - fragment_enc: Optional[bool] = None, - **kwargs, + self, + request: Union[dict, AuthorizationRequest], + response_args: Optional[Union[dict, AuthorizationResponse]] = None, + return_uri: Optional[str] = "", + fragment_enc: Optional[bool] = None, + **kwargs, ) -> dict: resp_mode = request["response_mode"] if resp_mode == "form_post": @@ -849,11 +849,15 @@ def create_authn_response(self, request: Union[dict, Message], sid: str) -> dict if request.get("scope"): scope = request.get("scope") if request.get("resource"): - resource_scopes = [_context.cdb[s]["scope"] for s in request.get("resource") if s in _context.cdb.keys() and _context.cdb[s].get("scope")] + resource_scopes = [ + _context.cdb[s]["scope"] + for s in request.get("resource") + if s in _context.cdb.keys() and _context.cdb[s].get("scope") + ] resource_scopes = [item for sublist in resource_scopes for item in sublist] aresp["scope"] = _context.scopes_handler.filter_scopes( - list(set(scope+resource_scopes)), _sinfo["client_id"] + list(set(scope + resource_scopes)), _sinfo["client_id"] ) rtype = set(request["response_type"][:]) @@ -865,7 +869,7 @@ def create_authn_response(self, request: Union[dict, Message], sid: str) -> dict grant = _sinfo["grant"] - if "code" in request["response_type"]: + if "code" in rtype: _code = self.mint_token( token_class="authorization_code", grant=grant, @@ -890,7 +894,7 @@ def create_authn_response(self, request: Union[dict, Message], sid: str) -> dict else: _access_token = None - if "id_token" in request["response_type"]: + if "id_token" in rtype: kwargs = {} if {"code", "id_token", "token"}.issubset(rtype): kwargs = {"code": _code.value, "access_token": _access_token.value} @@ -899,7 +903,7 @@ def create_authn_response(self, request: Union[dict, Message], sid: str) -> dict elif {"id_token", "token"}.issubset(rtype): kwargs = {"access_token": _access_token.value} - if request["response_type"] == ["id_token"]: + if rtype == {"id_token"}: kwargs["as_if"] = "userinfo" try: @@ -1076,10 +1080,10 @@ def do_request_user(self, request_info, **kwargs): return kwargs def process_request( - self, - request: Optional[Union[Message, dict]] = None, - http_info: Optional[dict] = None, - **kwargs, + self, + request: Optional[Union[Message, dict]] = None, + http_info: Optional[dict] = None, + **kwargs, ): """The AuthorizationRequest endpoint diff --git a/src/idpyoidc/server/oauth2/introspection.py b/src/idpyoidc/server/oauth2/introspection.py index dbc3ccfb..5937d0d5 100644 --- a/src/idpyoidc/server/oauth2/introspection.py +++ b/src/idpyoidc/server/oauth2/introspection.py @@ -32,6 +32,7 @@ class Introspection(Endpoint): def __init__(self, upstream_get, **kwargs): Endpoint.__init__(self, upstream_get, **kwargs) self.offset = kwargs.get("offset", 0) + self.enforce_aud_restriction = kwargs.get("enforce_audience_restriction", True) def _introspect(self, token, client_id, grant): # Make sure that the token is an access_token or a refresh_token @@ -113,9 +114,18 @@ def process_request(self, request=None, release: Optional[list] = None, **kwargs aud = _token.resources if not aud: aud = grant.resources - - if request["client_id"] not in aud: - return {"response_args": _resp} + + client_id = request["client_id"] + try: + _cinfo = _context.cdb[client_id] + enforce_aud_restriction = _cinfo.get( + "enforce_audience_restriction", self.enforce_aud_restriction + ) + except: + enforce_aud_restriction = self.enforce_aud_restriction + if enforce_aud_restriction: + if request["client_id"] not in aud: + return {"response_args": _resp} _info = self._introspect(_token, _session_info["client_id"], _session_info["grant"]) if _info is None: diff --git a/src/idpyoidc/server/oauth2/pushed_authorization.py b/src/idpyoidc/server/oauth2/pushed_authorization.py index 40d319d8..d71c6b34 100644 --- a/src/idpyoidc/server/oauth2/pushed_authorization.py +++ b/src/idpyoidc/server/oauth2/pushed_authorization.py @@ -1,6 +1,10 @@ +from typing import Optional +from typing import Union import uuid +from idpyoidc.message import Message from idpyoidc.message import oauth2 +from idpyoidc.message.oauth2 import AuthorizationRequest from idpyoidc.server.oauth2.authorization import Authorization @@ -20,7 +24,7 @@ def __init__(self, upstream_get, **kwargs): self.post_parse_request.append(self._post_parse_request) self.ttl = kwargs.get("ttl", 3600) - def process_request(self, request=None, **kwargs): + def process_request(self, request: Optional[Union[Message, str]] = None, **kwargs): """ Store the request and return a URI. @@ -28,10 +32,18 @@ def process_request(self, request=None, **kwargs): """ # create URN + if isinstance(request, str): + _request = AuthorizationRequest().from_urlencoded(request) + else: + _request = AuthorizationRequest(**request) + + _request.verify(keyjar=self.upstream_get("attribute", "keyjar")) + _urn = "urn:uuid:{}".format(uuid.uuid4()) - self.upstream_get("context").par_db[_urn] = request + # Store the parsed and verified request + self.upstream_get("context").par_db[_urn] = _request return { "http_response": {"request_uri": _urn, "expires_in": self.ttl}, - "return_uri": request["redirect_uri"], + "return_uri": _request["redirect_uri"], } diff --git a/src/idpyoidc/server/oauth2/token.py b/src/idpyoidc/server/oauth2/token.py index 98bc9fa8..c6a53d1c 100755 --- a/src/idpyoidc/server/oauth2/token.py +++ b/src/idpyoidc/server/oauth2/token.py @@ -22,7 +22,6 @@ logger = logging.getLogger(__name__) - class Token(Endpoint): request_cls = Message response_cls = AccessTokenResponse @@ -44,21 +43,20 @@ class Token(Endpoint): "password": ResourceOwnerPasswordCredentials, } - _supports = { - "grant_types_supported": list(helper_by_grant_type.keys()) - } + _supports = {"grant_types_supported": list(helper_by_grant_type.keys())} def __init__(self, upstream_get, new_refresh_token=False, **kwargs): Endpoint.__init__(self, upstream_get, **kwargs) self.post_parse_request.append(self._post_parse_request) self.allow_refresh = False self.new_refresh_token = new_refresh_token - self.grant_type_helper = self.configure_types(kwargs.get("grant_types_helpers"), - self.helper_by_grant_type) + self.grant_type_helper = self.configure_types( + kwargs.get("grant_types_helpers"), self.helper_by_grant_type + ) # self.grant_types_supported = kwargs.get("grant_types_supported", # list(self.grant_type_helper.keys())) self.revoke_refresh_on_issue = kwargs.get("revoke_refresh_on_issue", False) - self.resource_indicators_config = kwargs.get('resource_indicators', None) + self.resource_indicators_config = kwargs.get("resource_indicators", None) def configure_types(self, helpers, default_helpers): if helpers is None: @@ -93,18 +91,18 @@ def configure_types(self, helpers, default_helpers): return _helper - def _get_helper(self, - request: Union[Message, dict], - client_id: Optional[str] = "") -> Optional[Union[Message, TokenEndpointHelper]]: - grant_type = request.get('grant_type') + def _get_helper( + self, request: Union[Message, dict], client_id: Optional[str] = "" + ) -> Optional[Union[Message, TokenEndpointHelper]]: + grant_type = request.get("grant_type") if grant_type: - _client_id = client_id or request.get('client_id') + _client_id = client_id or request.get("client_id") if client_id: - client = self.upstream_get('context').cdb[client_id] - _grant_types_supported = client.get("grant_types_supported", - self.upstream_get('context').claims.get_claim( - "grant_types_supported", []) - ) + client = self.upstream_get("context").cdb[client_id] + _grant_types_supported = client.get( + "grant_types_supported", + self.upstream_get("context").claims.get_claim("grant_types_supported", []), + ) if grant_type not in _grant_types_supported: return self.error_cls( error="invalid_request", @@ -119,7 +117,7 @@ def _get_helper(self, ) def _post_parse_request( - self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs + self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs ): _resp = self._get_helper(request, client_id) if isinstance(_resp, TokenEndpointHelper): @@ -191,4 +189,4 @@ def process_request(self, request: Optional[Union[Message, dict]] = None, **kwar return resp def supports(self): - return {'grant_types_supported': list(self.grant_type_helper.keys())} + return {"grant_types_supported": list(self.grant_type_helper.keys())} diff --git a/src/idpyoidc/server/oauth2/token_helper/__init__.py b/src/idpyoidc/server/oauth2/token_helper/__init__.py index e9bbc96e..43c2a6ca 100644 --- a/src/idpyoidc/server/oauth2/token_helper/__init__.py +++ b/src/idpyoidc/server/oauth2/token_helper/__init__.py @@ -13,14 +13,13 @@ class TokenEndpointHelper(object): - def __init__(self, endpoint, config=None): self.endpoint = endpoint self.config = config self.error_cls = self.endpoint.error_cls def post_parse_request( - self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs + self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs ): """Context specific parsing of the request. This is done after general request parsing and before processing @@ -33,15 +32,15 @@ def process_request(self, req: Union[Message, dict], **kwargs): raise NotImplementedError def _mint_token( - self, - token_class: str, - grant: Grant, - session_id: str, - client_id: str, - based_on: Optional[SessionToken] = None, - scope: Optional[list] = None, - token_args: Optional[dict] = None, - token_type: Optional[str] = "", + self, + token_class: str, + grant: Grant, + session_id: str, + client_id: str, + based_on: Optional[SessionToken] = None, + scope: Optional[list] = None, + token_args: Optional[dict] = None, + token_type: Optional[str] = "", ) -> SessionToken: _context = self.endpoint.upstream_get("context") _mngr = _context.session_manager @@ -49,7 +48,8 @@ def _mint_token( if usage_rules: _exp_in = usage_rules.get("expires_in") else: - _exp_in = DEFAULT_TOKEN_LIFETIME + _token_handler = _mngr.token_handler[token_class] + _exp_in = _token_handler.lifetime token_args = token_args or {} for meth in _context.token_args_methods: @@ -95,8 +95,10 @@ def validate_resource_indicators_policy(request, context, **kwargs): resource_servers_per_client = kwargs.get("resource_servers_per_client", []) - if isinstance(resource_servers_per_client, - dict) and client_id not in resource_servers_per_client: + if ( + isinstance(resource_servers_per_client, dict) + and client_id not in resource_servers_per_client + ): return TokenErrorResponse( error="invalid_target", error_description=f"Resources for client {client_id} not found", @@ -154,14 +156,14 @@ def validate_token_exchange_policy(request, context, subject_token, **kwargs): ) if ( - "requested_token_type" in request - and request["requested_token_type"] == "urn:ietf:params:oauth:token-type:refresh_token" + "requested_token_type" in request + and request["requested_token_type"] == "urn:ietf:params:oauth:token-type:refresh_token" ): if "offline_access" not in subject_token.scope: return TokenErrorResponse( error="invalid_request", error_description=f"Exchange {request['subject_token_type']} to refresh token " - f"forbidden", + f"forbidden", ) scopes = request.get("scope", subject_token.scope) @@ -170,7 +172,7 @@ def validate_token_exchange_policy(request, context, subject_token, **kwargs): scopes = list(set(scopes).intersection(kwargs.get("scope"))) if scopes: request["scope"] = scopes - else: - request.pop("scope") + elif "scope" in request: + del request["scope"] return request diff --git a/src/idpyoidc/server/oauth2/token_helper/access_token.py b/src/idpyoidc/server/oauth2/token_helper/access_token.py index 96e64c1c..46dad9c7 100755 --- a/src/idpyoidc/server/oauth2/token_helper/access_token.py +++ b/src/idpyoidc/server/oauth2/token_helper/access_token.py @@ -19,7 +19,6 @@ class AccessTokenHelper(TokenEndpointHelper): - def process_request(self, req: Union[Message, dict], **kwargs): """ @@ -50,8 +49,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): _cinfo = self.endpoint.upstream_get("context").cdb.get(client_id) - if ("resource_indicators" in _cinfo - and "access_token" in _cinfo["resource_indicators"]): + if "resource_indicators" in _cinfo and "access_token" in _cinfo["resource_indicators"]: resource_indicators_config = _cinfo["resource_indicators"]["access_token"] else: resource_indicators_config = self.endpoint.kwargs.get("resource_indicators", None) @@ -66,12 +64,20 @@ def process_request(self, req: Union[Message, dict], **kwargs): if isinstance(req, TokenErrorResponse): return req - # if "grant_types_supported" in _context.cdb[client_id]: - # grant_types_supported = _context.cdb[client_id].get("grant_types_supported") - # else: - # grant_types_supported = _context.provider_info["grant_types_supported"] - grant = _session_info["grant"] + token_type = "Bearer" + + # Is DPOP supported + try: + _dpop_enabled = _context.add_on.get("dpop") + except AttributeError: + _dpop_enabled = False + + if _dpop_enabled: + _dpop_jkt = req.get("dpop_jkt") + if _dpop_jkt: + grant.extra["dpop_jkt"] = _dpop_jkt + token_type = "DPoP" _based_on = grant.get_token(_access_code) _supports_minting = _based_on.usage_rules.get("supports_minting", []) @@ -88,15 +94,18 @@ def process_request(self, req: Union[Message, dict], **kwargs): logger.debug("All checks OK") - issue_refresh = kwargs.get("issue_refresh", False) - if resource_indicators_config is not None: scope = req["scope"] else: scope = grant.scope + if "offline_access" in scope and "refresh_token" in _supports_minting: + issue_refresh = True + else: + issue_refresh = kwargs.get("issue_refresh", False) + _response = { - "token_type": "Bearer", + "token_type": token_type, "scope": scope, } @@ -115,7 +124,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): session_id=_session_info["branch_id"], client_id=_session_info["client_id"], based_on=_based_on, - token_args=token_args + token_args=token_args, ) except MintingNotAllowed as err: logger.warning(err) @@ -124,10 +133,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): if token.expires_at: _response["expires_in"] = token.expires_at - utc_time_sans_frac() - if ( - issue_refresh - and "refresh_token" in _supports_minting - ): + if issue_refresh and "refresh_token" in _supports_minting: try: refresh_token = self._mint_token( token_class="refresh_token", @@ -149,7 +155,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): return _response def _enforce_resource_indicators_policy(self, request, config): - _context = self.endpoint.upstream_get('context') + _context = self.endpoint.upstream_get("context") policy = config["policy"] function = policy["function"] @@ -169,7 +175,7 @@ def _enforce_resource_indicators_policy(self, request, config): return self.error_cls(error="server_error", error_description="Internal server error") def post_parse_request( - self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs + self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs ): """ This is where clients come to get their access tokens diff --git a/src/idpyoidc/server/oauth2/token_helper/client_credentials.py b/src/idpyoidc/server/oauth2/token_helper/client_credentials.py index 2c37ba93..622efd3a 100755 --- a/src/idpyoidc/server/oauth2/token_helper/client_credentials.py +++ b/src/idpyoidc/server/oauth2/token_helper/client_credentials.py @@ -12,7 +12,6 @@ class ClientCredentials(TokenEndpointHelper): - def __init__(self, endpoint, config=None): TokenEndpointHelper.__init__(self, endpoint, config) @@ -23,32 +22,33 @@ def process_request(self, req: Union[Message, dict], **kwargs): # verify the client and the user - client_id = req['client_id'] + client_id = req["client_id"] _authenticated = req.get("authenticated", False) if not _authenticated: - if _context.cdb[client_id] != req['client_secret']: + if _context.cdb[client_id] != req["client_secret"]: logger.warning("Client authentication failed") return self.error_cls(error="invalid_request", error_description="Wrong client") - _grant_types_supported = _context.cdb[client_id].get('grant_types_supported') - if _grant_types_supported and 'client_credentials' not in _grant_types_supported: - return self.error_cls(error="invalid_request", - error_description="Unsupported grant type") + _grant_types_supported = _context.cdb[client_id].get("grant_types_supported") + if _grant_types_supported and "client_credentials" not in _grant_types_supported: + return self.error_cls( + error="invalid_request", error_description="Unsupported grant type" + ) # Is there a previous session ? try: - _session_info = _mngr.get(['client_credentials', client_id]) + _session_info = _mngr.get(["client_credentials", client_id]) _grant = _session_info["grant"] except KeyError: - logger.debug('No previous session') - branch_id = _mngr.add_grant(['client_credentials', client_id]) + logger.debug("No previous session") + branch_id = _mngr.add_grant(["client_credentials", client_id]) _session_info = _mngr.get_session_info(branch_id) _grant = _session_info["grant"] token_type = "Bearer" - _allowed = _context.cdb[client_id].get('allowed_scopes', []) + _allowed = _context.cdb[client_id].get("allowed_scopes", []) access_token = self._mint_token( token_class="access_token", grant=_grant, @@ -71,10 +71,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): return _resp def post_parse_request( - self, - request: Union[Message, dict], - client_id: Optional[str] = "", - **kwargs + self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs ): request = CCAccessTokenRequest(**request.to_dict()) logger.debug("%s: %s" % (request.__class__.__name__, sanitize(request))) diff --git a/src/idpyoidc/server/oauth2/token_helper/refresh_token.py b/src/idpyoidc/server/oauth2/token_helper/refresh_token.py index 62341149..bb6150af 100755 --- a/src/idpyoidc/server/oauth2/token_helper/refresh_token.py +++ b/src/idpyoidc/server/oauth2/token_helper/refresh_token.py @@ -13,7 +13,6 @@ class RefreshTokenHelper(TokenEndpointHelper): - def process_request(self, req: Union[Message, dict], **kwargs): _context = self.endpoint.upstream_get("context") _mngr = _context.session_manager @@ -84,9 +83,9 @@ def process_request(self, req: Union[Message, dict], **kwargs): token.register_usage() if ( - "client_id" in req - and req["client_id"] in _context.cdb - and "revoke_refresh_on_issue" in _context.cdb[req["client_id"]] + "client_id" in req + and req["client_id"] in _context.cdb + and "revoke_refresh_on_issue" in _context.cdb[req["client_id"]] ): revoke_refresh = _context.cdb[req["client_id"]].get("revoke_refresh_on_issue") else: @@ -98,7 +97,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): return _resp def post_parse_request( - self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs + self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs ): """ This is where clients come to refresh their access tokens @@ -112,8 +111,8 @@ def post_parse_request( _context = self.endpoint.upstream_get("context") request.verify( - keyjar=self.endpoint.upstream_get('sttribute', 'keyjar'), - opponent_id=client_id) + keyjar=self.endpoint.upstream_get("sttribute", "keyjar"), opponent_id=client_id + ) _mngr = _context.session_manager try: diff --git a/src/idpyoidc/server/oauth2/token_helper/resource_owner_password_credentials.py b/src/idpyoidc/server/oauth2/token_helper/resource_owner_password_credentials.py index 75eee741..b6003f38 100755 --- a/src/idpyoidc/server/oauth2/token_helper/resource_owner_password_credentials.py +++ b/src/idpyoidc/server/oauth2/token_helper/resource_owner_password_credentials.py @@ -13,12 +13,11 @@ class ResourceOwnerPasswordCredentials(TokenEndpointHelper): - def __init__(self, endpoint, config=None): TokenEndpointHelper.__init__(self, endpoint, config) self.user_db = {} if config: - _db = config.get('db') + _db = config.get("db") if _db: _db_kwargs = _db.get("kwargs", {}) self.user_db = instantiate(_db["class"], **_db_kwargs) @@ -30,18 +29,18 @@ def process_request(self, req: Union[Message, dict], **kwargs): # verify the client and the user - client_id = req['client_id'] + client_id = req["client_id"] _cinfo = _context.cdb.get(client_id) if not _cinfo: - logger.error('Unknown client') + logger.error("Unknown client") return self.error_cls(error="invalid_grant", error_description="Unknown client") - if _cinfo['client_secret'] != req['client_secret']: + if _cinfo["client_secret"] != req["client_secret"]: logger.warning("Client secret mismatch") return self.error_cls(error="invalid_grant", error_description="Wrong client") _auth_method = None - _acr = kwargs.get('acr') + _acr = kwargs.get("acr") if _acr: _auth_method = _context.authn_broker.pick(_acr) else: @@ -51,14 +50,15 @@ def process_request(self, req: Union[Message, dict], **kwargs): logger.exception(f"An error occurred while picking the authN broker: {exc}") if not _auth_method: - return self.error_cls(error="invalid_request", - error_description="Can't authenticate user") + return self.error_cls( + error="invalid_request", error_description="Can't authenticate user" + ) authn = _auth_method["method"] # authn_class_ref = _auth_method["acr"] try: - _username = authn.verify(username=req['username'], password=req['password']) + _username = authn.verify(username=req["username"], password=req["password"]) except FailedAuthentication: logger.warning("User password did not match") return self.error_cls(error="invalid_grant", error_description="Wrong user") @@ -68,7 +68,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): _session_info = _mngr.get([_username, client_id]) _grant = _session_info["grant"] except KeyError: - logger.debug('No previous session') + logger.debug("No previous session") branch_id = _mngr.add_grant([_username, client_id]) _session_info = _mngr.get_session_info(branch_id) @@ -76,7 +76,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): token_type = "Bearer" - _allowed = _context.cdb[client_id].get('allowed_scopes', []) + _allowed = _context.cdb[client_id].get("allowed_scopes", []) access_token = self._mint_token( token_class="access_token", grant=_grant, @@ -90,7 +90,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): _resp = { "access_token": access_token.value, "token_type": access_token.token_class, - "scope": _allowed + "scope": _allowed, } if access_token.expires_at: @@ -99,9 +99,6 @@ def process_request(self, req: Union[Message, dict], **kwargs): return _resp def post_parse_request( - self, - request: Union[Message, dict], - client_id: Optional[str] = "", - **kwargs + self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs ): return request diff --git a/src/idpyoidc/server/oauth2/token_helper/token_exchange.py b/src/idpyoidc/server/oauth2/token_helper/token_exchange.py index 0b5a0524..9fdcb8e7 100755 --- a/src/idpyoidc/server/oauth2/token_helper/token_exchange.py +++ b/src/idpyoidc/server/oauth2/token_helper/token_exchange.py @@ -57,14 +57,13 @@ def post_parse_request(self, request, client_id="", **kwargs): try: request.verify( - keyjar=self.endpoint.upstream_get('attribute', 'keyjar'), - opponent_id=client_id + keyjar=self.endpoint.upstream_get("attribute", "keyjar"), opponent_id=client_id ) except ( - MissingRequiredAttribute, - ValueError, - MissingRequiredValue, - JWKESTException, + MissingRequiredAttribute, + ValueError, + MissingRequiredValue, + JWKESTException, ) as err: return self.endpoint.error_cls(error="invalid_request", error_description="%s" % err) @@ -133,8 +132,8 @@ def _enforce_policy(self, request, token, config): ) if ( - "requested_token_type" in request - and request["requested_token_type"] not in config["requested_token_types_supported"] + "requested_token_type" in request + and request["requested_token_type"] not in config["requested_token_types_supported"] ): return TokenErrorResponse( error="invalid_request", @@ -282,16 +281,15 @@ def _validate_configuration(self, config): if "policy" not in config: raise ImproperlyConfigured("Missing 'policy' from Token Exchange configuration") if "" not in config["policy"]: - raise ImproperlyConfigured( - "Default Token Exchange policy configuration is not defined" - ) + raise ImproperlyConfigured("Default Token Exchange policy configuration is not defined") if "function" not in config["policy"][""]: raise ImproperlyConfigured( "Missing 'function' from default Token Exchange policy configuration" ) - _default_requested_token_type = config.get("default_requested_token_type", - DEFAULT_REQUESTED_TOKEN_TYPE) + _default_requested_token_type = config.get( + "default_requested_token_type", DEFAULT_REQUESTED_TOKEN_TYPE + ) if _default_requested_token_type not in config["requested_token_types_supported"]: raise ImproperlyConfigured( f"Unsupported default requested_token_type {_default_requested_token_type}" @@ -300,11 +298,9 @@ def _validate_configuration(self, config): def get_handler_key(self, request, endpoint_context): client_info = endpoint_context.cdb.get(request["client_id"], {}) - default_requested_token_type = ( - client_info.get("token_exchange", {}).get("default_requested_token_type", None) - or - self.config.get("default_requested_token_type", DEFAULT_REQUESTED_TOKEN_TYPE) - ) + default_requested_token_type = client_info.get("token_exchange", {}).get( + "default_requested_token_type", None + ) or self.config.get("default_requested_token_type", DEFAULT_REQUESTED_TOKEN_TYPE) requested_token_type = request.get("requested_token_type", default_requested_token_type) return TOKEN_TYPES_MAPPING[requested_token_type] diff --git a/src/idpyoidc/server/oauth2/token_revocation.py b/src/idpyoidc/server/oauth2/token_revocation.py index 7db5e184..d36ed28b 100644 --- a/src/idpyoidc/server/oauth2/token_revocation.py +++ b/src/idpyoidc/server/oauth2/token_revocation.py @@ -19,7 +19,8 @@ class TokenRevocation(Endpoint): response_cls = oauth2.TokenRevocationResponse error_cls = oauth2.TokenRevocationErrorResponse request_format = "urlencoded" - response_format = "json" + response_format = "text" + response_body_type = "text" endpoint_name = "revocation_endpoint" name = "token_revocation" default_capabilities = { @@ -77,16 +78,19 @@ def process_request(self, request=None, **kwargs): try: self.token_types_supported = _context.cdb[client_id]["token_revocation"][ - "token_types_supported"] + "token_types_supported" + ] except Exception: - self.token_types_supported = self.token_revocation_kwargs.get("token_types_supported", - self.token_types_supported) + self.token_types_supported = self.token_revocation_kwargs.get( + "token_types_supported", self.token_types_supported + ) try: self.policy = _context.cdb[client_id]["token_revocation"]["policy"] except Exception: - self.policy = self.token_revocation_kwargs.get("policy", { - "": {"function": validate_token_revocation_policy}}) + self.policy = self.token_revocation_kwargs.get( + "policy", {"": {"function": validate_token_revocation_policy}} + ) if _token.token_class not in self.token_types_supported: desc = ( @@ -130,5 +134,5 @@ def validate_token_revocation_policy(token, session_info, **kwargs): _token = token _token.revoke() - response_args = {"response_args": {}} - return oauth2.TokenRevocationResponse(**response_args) + response_args = {"response_msg": "OK"} + return response_args diff --git a/src/idpyoidc/server/oidc/authorization.py b/src/idpyoidc/server/oidc/authorization.py index ac14a754..ca441f69 100644 --- a/src/idpyoidc/server/oidc/authorization.py +++ b/src/idpyoidc/server/oidc/authorization.py @@ -87,11 +87,10 @@ class Authorization(authorization.Authorization): "request_parameter_supported": True, "request_uri_parameter_supported": True, "require_request_uri_registration": False, - "response_types_supported": ["code", "token", "code token", 'id_token', 'id_token token', - 'code id_token', 'code id_token token'], - "response_modes_supported": ['query', 'fragment', 'form_post'], + "response_types_supported": ["code", "id_token", "code id_token"], + "response_modes_supported": ["query", "fragment", "form_post"], "subject_types_supported": ["public", "pairwise", "ephemeral"], - }, + }, } def __init__(self, upstream_get: Callable, **kwargs): diff --git a/src/idpyoidc/server/oidc/backchannel_authentication.py b/src/idpyoidc/server/oidc/backchannel_authentication.py index 50350590..b94dbdf0 100644 --- a/src/idpyoidc/server/oidc/backchannel_authentication.py +++ b/src/idpyoidc/server/oidc/backchannel_authentication.py @@ -66,7 +66,7 @@ def do_request_user(self, request): _context = self.upstream_get("context") _request_user = execute( self.parse_login_hint_token, - keyjar=self.upstream_get('attribute', 'keyjar'), + keyjar=self.upstream_get("attribute", "keyjar"), login_hint_token=request.get("login_hint_token"), context=_context, ) @@ -85,10 +85,10 @@ def allowed_target_uris(self): return set(res) def process_request( - self, - request: Optional[Union[Message, dict]] = None, - http_info: Optional[dict] = None, - **kwargs, + self, + request: Optional[Union[Message, dict]] = None, + http_info: Optional[dict] = None, + **kwargs, ): try: request_user = self.do_request_user(request) @@ -136,7 +136,7 @@ def _get_session_info(self, request, session_manager): return session_info, _grant def post_parse_request( - self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs + self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs ) -> Union[Message, dict]: _context = self.endpoint.upstream_get("context") _mngr = _context.session_manager @@ -302,10 +302,10 @@ def __init__(self, upstream_get: Callable, **kwargs): Endpoint.__init__(self, upstream_get, **kwargs) def process_request( - self, - request: Optional[Union[Message, dict]] = None, - http_info: Optional[dict] = None, - **kwargs, + self, + request: Optional[Union[Message, dict]] = None, + http_info: Optional[dict] = None, + **kwargs, ) -> Union[Message, dict]: return {} @@ -321,11 +321,11 @@ def is_usable(self, request=None, authorization_token=None): return False def _verify( - self, - authorization_token: Optional[str] = None, - endpoint=None, # Optional[Endpoint] - get_client_id_from_token: Optional[Callable] = None, - **kwargs, + self, + authorization_token: Optional[str] = None, + endpoint=None, # Optional[Endpoint] + get_client_id_from_token: Optional[Callable] = None, + **kwargs, ): ttype, token = authorization_token.split(" ", 1) if ttype != "Bearer": diff --git a/src/idpyoidc/server/oidc/registration.py b/src/idpyoidc/server/oidc/registration.py index 7b9d4a7f..5437f74e 100644 --- a/src/idpyoidc/server/oidc/registration.py +++ b/src/idpyoidc/server/oidc/registration.py @@ -153,8 +153,6 @@ def match_claim(self, claim, val): if isinstance(val, str): if val in _val: return val - else: - return None else: _ret = list(set(_val).intersection(set(val))) if len(_ret) > 0: @@ -162,12 +160,13 @@ def match_claim(self, claim, val): else: raise CapabilitiesMisMatch(_my_key) else: - if val == _val: + if isinstance(_val, list): + if val in _val: + return val + elif val == _val: return val - else: - return None - else: - return None + + return None def filter_client_request(self, request: dict) -> dict: _args = {} @@ -250,12 +249,18 @@ def do_client_registration(self, request, client_id, ignore=None): error_description="%s pointed to illegal URL" % item, ) - _keyjar = self.upstream_get('attribute', 'keyjar') + _keyjar = self.upstream_get("attribute", "keyjar") # Do I have the necessary keys for item in ["id_token_signed_response_alg", "userinfo_signed_response_alg"]: if item in request: - if request[item] in _context.provider_info[ - _context.claims.register2preferred[item]]: + _claim = _context.claims.register2preferred[item] + _support = _context.provider_info.get(_claim) + if _support is None: + logger.warning(f'Lacking support for "{item}"') + del _cinfo[item] + continue + + if request[item] in _support: ktyp = alg2keytype(request[item]) # do I have this ktyp and for EC type keys the curve if ktyp not in ["none", "oct"]: @@ -465,7 +470,7 @@ def client_registration_setup(self, request, new_id=True, set_secret=True): # Add the client_secret as a symmetric key to the key jar if client_secret: - self.upstream_get('attribute', 'keyjar').add_symmetric(client_id, str(client_secret)) + self.upstream_get("attribute", "keyjar").add_symmetric(client_id, str(client_secret)) logger.debug("Stored updated client info in CDB under cid={}".format(client_id)) logger.debug("ClientInfo: {}".format(_cinfo)) diff --git a/src/idpyoidc/server/oidc/session.py b/src/idpyoidc/server/oidc/session.py index 99e30b0c..ee1a8460 100644 --- a/src/idpyoidc/server/oidc/session.py +++ b/src/idpyoidc/server/oidc/session.py @@ -136,10 +136,12 @@ def do_back_channel_logout(self, cinfo, sid): except KeyError: alg = _context.provider_info["id_token_signing_alg_values_supported"][0] - _jws = JWT(self.upstream_get('attribute', 'keyjar'), - iss=_context.issuer, - lifetime=86400, - sign_alg=alg) + _jws = JWT( + self.upstream_get("attribute", "keyjar"), + iss=_context.issuer, + lifetime=86400, + sign_alg=alg, + ) _jws.with_jti = True _logout_token = _jws.pack(payload=payload, recv=cinfo["client_id"]) @@ -221,7 +223,7 @@ def unpack_signed_jwt(self, sjwt, sig_alg=""): else: alg = self.kwargs["signing_alg"] - sign_keys = self.upstream_get('attribute', 'keyjar').get_signing_key(alg2keytype(alg)) + sign_keys = self.upstream_get("attribute", "keyjar").get_signing_key(alg2keytype(alg)) _info = _jwt.verify_compact(keys=sign_keys, sigalg=alg) return _info else: @@ -342,7 +344,7 @@ def process_request( logger.debug("JWS payload: {}".format(payload)) # From me to me _jws = JWT( - self.upstream_get('attribute', 'keyjar'), + self.upstream_get("attribute", "keyjar"), iss=_context.issuer, lifetime=86400, sign_alg=self.kwargs["signing_alg"], @@ -377,7 +379,7 @@ def parse_request(self, request, http_info=None, **kwargs): if isinstance(request, dict): _context = self.upstream_get("context") request = self.request_cls(**request) - if not request.verify(keyjar=self.upstream_get('attribute', 'keyjar'), sigalg=""): + if not request.verify(keyjar=self.upstream_get("attribute", "keyjar"), sigalg=""): raise InvalidRequest("Request didn't verify") # id_token_signing_alg_values_supported try: diff --git a/src/idpyoidc/server/oidc/token_helper/access_token.py b/src/idpyoidc/server/oidc/token_helper/access_token.py index bad2873b..3431121b 100755 --- a/src/idpyoidc/server/oidc/token_helper/access_token.py +++ b/src/idpyoidc/server/oidc/token_helper/access_token.py @@ -17,7 +17,6 @@ class AccessTokenHelper(TokenEndpointHelper): - def _get_session_info(self, request, session_manager): if request["grant_type"] != "authorization_code": return self.error_cls(error="invalid_request", error_description="Unknown grant_type") @@ -63,10 +62,10 @@ def process_request(self, req: Union[Message, dict], **kwargs): token_type = "Bearer" # Is DPOP supported - try: - _dpop_enabled = _context.dpop_enabled - except AttributeError: - _dpop_enabled = False + _dpop_enabled = False + _dpop_args = _context.add_on.get("dpop") + if _dpop_args: + _dpop_enabled = True if _dpop_enabled: _dpop_jkt = req.get("dpop_jkt") @@ -116,10 +115,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): if token.expires_at: _response["expires_in"] = token.expires_at - utc_time_sans_frac() - if ( - issue_refresh - and "refresh_token" in _supports_minting - ): + if issue_refresh and "refresh_token" in _supports_minting: try: refresh_token = self._mint_token( token_class="refresh_token", @@ -161,7 +157,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): return _response def post_parse_request( - self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs + self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs ) -> Union[Message, dict]: """ This is where clients come to get their access tokens diff --git a/src/idpyoidc/server/oidc/token_helper/refresh_token.py b/src/idpyoidc/server/oidc/token_helper/refresh_token.py index 534109a3..80792dcd 100755 --- a/src/idpyoidc/server/oidc/token_helper/refresh_token.py +++ b/src/idpyoidc/server/oidc/token_helper/refresh_token.py @@ -18,8 +18,8 @@ logger = logging.getLogger(__name__) -class RefreshTokenHelper(TokenEndpointHelper): +class RefreshTokenHelper(TokenEndpointHelper): def process_request(self, req: Union[Message, dict], **kwargs): _context = self.endpoint.upstream_get("context") _mngr = _context.session_manager @@ -113,9 +113,9 @@ def process_request(self, req: Union[Message, dict], **kwargs): token.register_usage() if ( - "client_id" in req - and req["client_id"] in _context.cdb - and "revoke_refresh_on_issue" in _context.cdb[req["client_id"]] + "client_id" in req + and req["client_id"] in _context.cdb + and "revoke_refresh_on_issue" in _context.cdb[req["client_id"]] ): revoke_refresh = _context.cdb[req["client_id"]].get("revoke_refresh_on_issue") else: @@ -127,10 +127,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): return _resp def post_parse_request( - self, - request: Union[Message, dict], - client_id: Optional[str] = "", - **kwargs + self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs ): """ This is where clients come to refresh their access tokens @@ -143,8 +140,9 @@ def post_parse_request( request = RefreshAccessTokenRequest(**request.to_dict()) _context = self.endpoint.upstream_get("context") - request.verify(keyjar=self.endpoint.upstream_get('attribute', 'keyjar'), - opponent_id=client_id) + request.verify( + keyjar=self.endpoint.upstream_get("attribute", "keyjar"), opponent_id=client_id + ) _mngr = _context.session_manager try: @@ -176,4 +174,3 @@ def post_parse_request( ) return request - diff --git a/src/idpyoidc/server/oidc/token_helper/token_exchange.py b/src/idpyoidc/server/oidc/token_helper/token_exchange.py index 39025a56..2246f71a 100755 --- a/src/idpyoidc/server/oidc/token_helper/token_exchange.py +++ b/src/idpyoidc/server/oidc/token_helper/token_exchange.py @@ -1,7 +1,8 @@ import logging -from idpyoidc.server.oauth2.token_helper.token_exchange import TokenExchangeHelper as \ - OAuth2TokenExchangeHelper +from idpyoidc.server.oauth2.token_helper.token_exchange import ( + TokenExchangeHelper as OAuth2TokenExchangeHelper, +) logger = logging.getLogger(__name__) diff --git a/src/idpyoidc/server/oidc/userinfo.py b/src/idpyoidc/server/oidc/userinfo.py index 58ffb107..962c0326 100755 --- a/src/idpyoidc/server/oidc/userinfo.py +++ b/src/idpyoidc/server/oidc/userinfo.py @@ -26,7 +26,7 @@ class UserInfo(Endpoint): request_cls = Message response_cls = oidc.OpenIDSchema request_format = "json" - response_format = "json" + response_format = "jose" response_placement = "body" endpoint_name = "userinfo_endpoint" name = "userinfo" @@ -38,7 +38,9 @@ class UserInfo(Endpoint): "userinfo_encryption_enc_values_supported": claims.get_encryption_encs, } - def __init__(self, upstream_get: Callable, add_claims_by_scope: Optional[bool] = True, **kwargs): + def __init__( + self, upstream_get: Callable, add_claims_by_scope: Optional[bool] = True, **kwargs + ): Endpoint.__init__( self, upstream_get, @@ -47,29 +49,18 @@ def __init__(self, upstream_get: Callable, add_claims_by_scope: Optional[bool] = ) # Add the issuer ID as an allowed JWT target self.allowed_targets.append("") + self.config = kwargs or {} - if kwargs is None: - self.config = { - "policy": { - "function": "/path/to/callable", - "kwargs": {} - }, - } - else: - self.config = kwargs - - def get_client_id_from_token(self, endpoint_context, token, request=None): - _info = endpoint_context.session_manager.get_session_info_by_token( - token, handler_key="access_token" - ) + def get_client_id_from_token(self, context, token, request=None): + _info = context.session_manager.get_session_info_by_token(token, handler_key="access_token") return _info["client_id"] def do_response( - self, - response_args: Optional[Union[Message, dict]] = None, - request: Optional[Union[Message, dict]] = None, - client_id: Optional[str] = "", - **kwargs + self, + response_args: Optional[Union[Message, dict]] = None, + request: Optional[Union[Message, dict]] = None, + client_id: Optional[str] = "", + **kwargs, ) -> dict: if "error" in kwargs and kwargs["error"]: @@ -100,7 +91,7 @@ def do_response( if encrypt or sign: _jwt = JWT( - self.upstream_get('attribute', 'keyjar'), + self.upstream_get("attribute", "keyjar"), iss=_context.issuer, sign=sign, sign_alg=sign_alg, @@ -144,7 +135,7 @@ def process_request(self, request=None, **kwargs): allowed = True _auth_event = _grant.authentication_event - # if the authenticate is still active or offline_access is granted. + # if the authentication is still active or offline_access is granted. if not _auth_event["valid_until"] >= utc_time_sans_frac(): logger.debug( "authentication not valid: {} > {}".format( @@ -207,7 +198,14 @@ def parse_request(self, request, http_info=None, **kwargs): request["client_id"] = auth_info["client_id"] request["access_token"] = auth_info["token"] - return request + # Do any endpoint specific parsing + return self.do_post_parse_request( + request=request, + client_id=auth_info["client_id"], + http_info=http_info, + auth_info=auth_info, + **kwargs, + ) def _enforce_policy(self, request, response_info, token, config): policy = config["policy"] diff --git a/src/idpyoidc/server/session/claims.py b/src/idpyoidc/server/session/claims.py index 179ce4ca..fef2c953 100755 --- a/src/idpyoidc/server/session/claims.py +++ b/src/idpyoidc/server/session/claims.py @@ -1,4 +1,5 @@ import logging +from typing import List from typing import Optional from typing import Union @@ -26,8 +27,10 @@ class ClaimsInterface: init_args = {"add_claims_by_scope": False, "enable_claims_per_client": False} claims_release_points = ["userinfo", "introspection", "id_token", "access_token"] - def __init__(self, upstream_get): + def __init__(self, upstream_get, claims_release_points:List[str] = None): self.upstream_get = upstream_get + if claims_release_points: + self.claims_release_points = claims_release_points def authorization_request_claims( self, @@ -168,6 +171,7 @@ def get_claims( auth_req = grant.authorization_request else: auth_req = {} + claims = self.get_claims_from_request( auth_req=auth_req, claims_release_point=claims_release_point, diff --git a/src/idpyoidc/server/session/database.py b/src/idpyoidc/server/session/database.py index 2985a083..1a8191ff 100644 --- a/src/idpyoidc/server/session/database.py +++ b/src/idpyoidc/server/session/database.py @@ -44,16 +44,16 @@ def __init__(self, crypt_config: Optional[dict] = None, **kwargs): @staticmethod def branch_key(*args): - """ Construct a key using a list of names """ + """Construct a key using a list of names""" return DIVIDER.join(args) @staticmethod def unpack_branch_key(key): - """ Translate a key into an ordered list of names """ + """Translate a key into an ordered list of names""" return key.split(DIVIDER) def encrypted_branch_id(self, *args) -> str: - """ Provided an ordered list of names construct a key and then encrypt it. """ + """Provided an ordered list of names construct a key and then encrypt it.""" rnd = rndstr(32) return base64.b64encode( self.crypt.encrypt(lv_pack(rnd, self.branch_key(*args)).encode()) @@ -61,8 +61,8 @@ def encrypted_branch_id(self, *args) -> str: def decrypt_branch_id(self, key: str) -> List[str]: """ - Given an encrypted key, decrypt it and then unpack the key to return an ordered list - of names. + Given an encrypted key, decrypt it and then unpack the key to return an ordered list + of names. """ try: plain = self.crypt.decrypt(base64.b64decode(key)) @@ -88,7 +88,7 @@ def set(self, path: List[str], value: Union[NodeInfo, Grant]): _superior = None for i in range(_len): - _key = self.branch_key(*path[0:i + 1]) + _key = self.branch_key(*path[0 : i + 1]) # _key = path[i] _info = self.db.get(_key) if _info is None: @@ -115,7 +115,7 @@ def set(self, path: List[str], value: Union[NodeInfo, Grant]): _superior = _info def get(self, path: List[str]) -> Union[NodeInfo, Grant]: - """ Given a path return the node that matches the path. """ + """Given a path return the node that matches the path.""" _key = self.branch_key(*path) return self.db[_key] @@ -156,7 +156,7 @@ def delete(self, path: List[str]): _sub = None for i in range(0, len(path)): - _key = self.branch_key(*path[0:_len - i]) + _key = self.branch_key(*path[0 : _len - i]) if _key in self.db: _node = self.db[_key] if _sub: diff --git a/src/idpyoidc/server/session/grant.py b/src/idpyoidc/server/session/grant.py index 6f193adb..e59c54d2 100644 --- a/src/idpyoidc/server/session/grant.py +++ b/src/idpyoidc/server/session/grant.py @@ -31,11 +31,11 @@ class GrantMessage(ImpExp): } def __init__( - self, - scope: Optional[str] = "", - authorization_details: Optional[dict] = None, - claims: Optional[list] = None, - resources: Optional[list] = None, + self, + scope: Optional[str] = "", + authorization_details: Optional[dict] = None, + claims: Optional[list] = None, + resources: Optional[list] = None, ): ImpExp.__init__(self) self.scope = scope @@ -104,24 +104,24 @@ class Grant(Item): } def __init__( - self, - scope: Optional[list] = None, - claims: Optional[dict] = None, - resources: Optional[list] = None, - authorization_details: Optional[dict] = None, - authorization_request: Optional[Message] = None, - authentication_event: Optional[AuthnEvent] = None, - issued_token: Optional[list] = None, - usage_rules: Optional[dict] = None, - issued_at: int = 0, - expires_in: int = 0, - expires_at: int = 0, - revoked: bool = False, - token_map: Optional[dict] = None, - sub: Optional[str] = "", - extra: Optional[Dict[str, str]] = None, - remember_token: Optional[Callable] = None, - remove_inactive_token: Optional[bool] = False, + self, + scope: Optional[list] = None, + claims: Optional[dict] = None, + resources: Optional[list] = None, + authorization_details: Optional[dict] = None, + authorization_request: Optional[Message] = None, + authentication_event: Optional[AuthnEvent] = None, + issued_token: Optional[list] = None, + usage_rules: Optional[dict] = None, + issued_at: int = 0, + expires_in: int = 0, + expires_at: int = 0, + revoked: bool = False, + token_map: Optional[dict] = None, + sub: Optional[str] = "", + extra: Optional[Dict[str, str]] = None, + remember_token: Optional[Callable] = None, + remove_inactive_token: Optional[bool] = False, ): Item.__init__( self, @@ -171,6 +171,9 @@ def find_scope(self, based_on): return self.scope def add_acr_value(self, claims_release_point): + # if claims_release_point == "userinfo": + # return False + _release = self.claims.get(claims_release_point) if _release: _acr_request = _release.get("acr") @@ -179,14 +182,14 @@ def add_acr_value(self, claims_release_point): return False def payload_arguments( - self, - session_id: str, - context: object, - item: SessionToken, - claims_release_point: str, - scope: Optional[dict] = None, - extra_payload: Optional[dict] = None, - secondary_identifier: str = "", + self, + session_id: str, + context: object, + item: SessionToken, + claims_release_point: str, + scope: Optional[dict] = None, + extra_payload: Optional[dict] = None, + secondary_identifier: str = "", ) -> dict: """ @@ -238,10 +241,9 @@ def payload_arguments( secondary_identifier=secondary_identifier, ) - if context.session_manager.node_type[0] == "user": + if _claims_restriction and context.session_manager.node_type[0] == "user": user_id, _, _ = context.session_manager.decrypt_branch_id(session_id) - user_info = context.claims_interface.get_user_claims(user_id, - _claims_restriction) + user_info = context.claims_interface.get_user_claims(user_id, _claims_restriction) payload.update(user_info) # Should I add the acr value @@ -253,19 +255,19 @@ def payload_arguments( return payload def mint_token( - self, - session_id: str, - context: object, - token_class: str, - token_handler: TokenHandler = None, - based_on: Optional[SessionToken] = None, - usage_rules: Optional[dict] = None, - scope: Optional[list] = None, - token_type: Optional[str] = "", - expires_in: Optional[int] = 0, - not_before: Optional[int] = 0, - claims: Optional[List[str]] = None, - **kwargs, + self, + session_id: str, + context: object, + token_class: str, + token_handler: TokenHandler = None, + based_on: Optional[SessionToken] = None, + usage_rules: Optional[dict] = None, + scope: Optional[list] = None, + token_type: Optional[str] = "", + expires_in: Optional[int] = 0, + not_before: Optional[int] = 0, + claims: Optional[List[str]] = None, + **kwargs, ) -> Optional[SessionToken]: """ @@ -389,7 +391,7 @@ def get_token(self, value: str) -> Optional[SessionToken]: return None def revoke_token( - self, value: Optional[str] = "", based_on: Optional[str] = "", recursive: bool = True + self, value: Optional[str] = "", based_on: Optional[str] = "", recursive: bool = True ): remain = [] for t in self.issued_token: @@ -484,24 +486,24 @@ class ExchangeGrant(Grant): type = "exchange_grant" def __init__( - self, - scope: Optional[list] = None, - claims: Optional[dict] = None, - resources: Optional[list] = None, - authorization_details: Optional[dict] = None, - authorization_request: Optional[Message] = None, - authentication_event: Optional[AuthnEvent] = None, - issued_token: Optional[list] = None, - usage_rules: Optional[dict] = None, - exchange_request: Optional[TokenExchangeRequest] = None, - original_branch_id: str = "", - issued_at: int = 0, - expires_in: int = 0, - expires_at: int = 0, - revoked: bool = False, - token_map: Optional[dict] = None, - users: list = None, - sub: Optional[str] = "", + self, + scope: Optional[list] = None, + claims: Optional[dict] = None, + resources: Optional[list] = None, + authorization_details: Optional[dict] = None, + authorization_request: Optional[Message] = None, + authentication_event: Optional[AuthnEvent] = None, + issued_token: Optional[list] = None, + usage_rules: Optional[dict] = None, + exchange_request: Optional[TokenExchangeRequest] = None, + original_branch_id: str = "", + issued_at: int = 0, + expires_in: int = 0, + expires_at: int = 0, + revoked: bool = False, + token_map: Optional[dict] = None, + users: list = None, + sub: Optional[str] = "", ): Grant.__init__( self, @@ -529,14 +531,14 @@ def __init__( self.original_branch_id = original_branch_id def payload_arguments( - self, - session_id: str, - endpoint_context, - item: SessionToken, - claims_release_point: str, - scope: Optional[dict] = None, - extra_payload: Optional[dict] = None, - secondary_identifier: str = "", + self, + session_id: str, + endpoint_context, + item: SessionToken, + claims_release_point: str, + scope: Optional[dict] = None, + extra_payload: Optional[dict] = None, + secondary_identifier: str = "", ) -> dict: """ :param session_id: Session ID diff --git a/src/idpyoidc/server/session/grant_manager.py b/src/idpyoidc/server/session/grant_manager.py index ec99134c..3d357aaa 100644 --- a/src/idpyoidc/server/session/grant_manager.py +++ b/src/idpyoidc/server/session/grant_manager.py @@ -28,17 +28,17 @@ class GrantManager(Database): init_args = ["handler"] def __init__( - self, - handler: TokenHandler, - conf: Optional[dict] = None, - remember_token: Optional[Callable] = None, - remove_inactive_token: Optional[bool] = False, + self, + handler: TokenHandler, + conf: Optional[dict] = None, + remember_token: Optional[Callable] = None, + remove_inactive_token: Optional[bool] = False, ): self.conf = conf or { "session_params": { "encrypter": default_crypt_config(), "node_type": ["client", "grant"], - "node_info_class": {"client": ClientSessionInfo, "grant": Grant} + "node_info_class": {"client": ClientSessionInfo, "grant": Grant}, } } @@ -69,7 +69,7 @@ def __setitem__(self, branch_id: str, value): def _setup_branch(self, path): for i in range(len(path)): - _id = path[0:i + 1] + _id = path[0 : i + 1] try: _si = self.get(_id) @@ -81,16 +81,16 @@ def _setup_branch(self, path): def _get_nodes(self, path): res = [] for i in range(len(path)): - _id = path[0:i + 1] + _id = path[0 : i + 1] res.append(self.get(_id)) return res def add_grant( - self, - path: List[str], - token_usage_rules: Optional[dict] = None, - scope: Optional[list] = None, - **kwargs + self, + path: List[str], + token_usage_rules: Optional[dict] = None, + scope: Optional[list] = None, + **kwargs, ) -> str: """ Creates a Grant instance and adds it as a leaf to a branch @@ -111,7 +111,7 @@ def add_grant( remember_token=self.remember_token, remove_inactive_token=self.remove_inactive_token, scope=scope, - **grant_args + **grant_args, ) _id = path[:] @@ -121,12 +121,12 @@ def add_grant( return self.encrypted_branch_id(*_id) def add_exchange_grant( - self, - exchange_request: TokenExchangeRequest, - original_branch_id: str, - path: List[str], - token_usage_rules: Optional[dict] = None, - **grant_args + self, + exchange_request: TokenExchangeRequest, + original_branch_id: str, + path: List[str], + token_usage_rules: Optional[dict] = None, + **grant_args, ) -> str: """ @@ -156,8 +156,9 @@ def add_exchange_grant( return self.encrypted_branch_id(*_id) - def get_node_info(self, branch_id: str, level: Optional[int] = None, - node_type: Optional[str] = None) -> (str, NodeInfo): + def get_node_info( + self, branch_id: str, level: Optional[int] = None, node_type: Optional[str] = None + ) -> (str, NodeInfo): """ Return session information for a specific node in the grant path. @@ -173,7 +174,7 @@ def get_node_info(self, branch_id: str, level: Optional[int] = None, else: raise ValueError("One of level or node_type MUST be defined") - return _path[level], self.get(_path[0:level + 1]) + return _path[level], self.get(_path[0 : level + 1]) def branch_info(self, branch_id: str, *args) -> dict: """ @@ -230,7 +231,7 @@ def revoke_sub_tree(self, branch_id: str, level: Optional[int] = None): else: if level > len(_path): raise ValueError("Looking for level beyond what is available") - _node = self.get(_path[0:level + 1]) + _node = self.get(_path[0 : level + 1]) self._revoke_tree(_node) def _grants(self, path): @@ -243,9 +244,9 @@ def _grants(self, path): return _res def grants( - self, - branch_id: Optional[str] = "", - path: Optional[List[str]] = "", + self, + branch_id: Optional[str] = "", + path: Optional[List[str]] = "", ) -> List[Grant]: """ Find all grants connected to a branch diff --git a/src/idpyoidc/server/session/manager.py b/src/idpyoidc/server/session/manager.py index 6a33f8ac..8c017f27 100644 --- a/src/idpyoidc/server/session/manager.py +++ b/src/idpyoidc/server/session/manager.py @@ -48,7 +48,7 @@ def __init__(self, salt: Optional[str] = "", filename: Optional[str] = ""): if os.path.isfile(filename): self.salt = open(filename).read() elif not os.path.isfile(filename) and os.path.exists( - filename + filename ): # Not a file, Something else raise ConfigurationError("Salt filename points to something that is not a file") else: @@ -83,12 +83,12 @@ class SessionManager(GrantManager): init_args = ["handler"] def __init__( - self, - handler: TokenHandler, - conf: Optional[dict] = None, - sub_func: Optional[dict] = None, - remember_token: Optional[Callable] = None, - remove_inactive_token: Optional[bool] = False, + self, + handler: TokenHandler, + conf: Optional[dict] = None, + sub_func: Optional[dict] = None, + remember_token: Optional[Callable] = None, + remove_inactive_token: Optional[bool] = False, ): self.conf = conf or {"session_params": {"encrypter": default_crypt_config()}} @@ -102,12 +102,10 @@ def __init__( if len(self.node_type) == 0: raise ValueError("SessionManager node_type must at least contain one value") - self.node_info_class = session_params.get("node_info_class", - { - "user": UserSessionInfo, - "client": ClientSessionInfo, - "grant": Grant - }) + self.node_info_class = session_params.get( + "node_info_class", + {"user": UserSessionInfo, "client": ClientSessionInfo, "grant": Grant}, + ) self.token_handler = handler self.remember_token = remember_token @@ -161,14 +159,14 @@ def make_path(self, **kwargs): return _path def create_grant( - self, - authn_event: AuthnEvent, - auth_req: AuthorizationRequest, - user_id: Optional[str] = "", - client_id: Optional[str] = "", - sub_type: Optional[str] = "public", - token_usage_rules: Optional[dict] = None, - scopes: Optional[list] = None, + self, + authn_event: AuthnEvent, + auth_req: AuthorizationRequest, + user_id: Optional[str] = "", + client_id: Optional[str] = "", + sub_type: Optional[str] = "public", + token_usage_rules: Optional[dict] = None, + scopes: Optional[list] = None, ) -> str: """ @@ -211,15 +209,15 @@ def create_grant( ) def create_exchange_grant( - self, - exchange_request: TokenExchangeRequest, - original_grant: Grant, - original_session_id: str, - user_id: str, - client_id: Optional[str] = "", - sub_type: Optional[str] = "public", - token_usage_rules: Optional[dict] = None, - scopes: Optional[list] = None, + self, + exchange_request: TokenExchangeRequest, + original_grant: Grant, + original_session_id: str, + user_id: str, + client_id: Optional[str] = "", + sub_type: Optional[str] = "public", + token_usage_rules: Optional[dict] = None, + scopes: Optional[list] = None, ) -> str: """ @@ -239,18 +237,18 @@ def create_exchange_grant( path=self.make_path(user_id=user_id, client_id=client_id), sub=original_grant.sub, token_usage_rules=token_usage_rules, - scope=scopes + scope=scopes, ) def create_session( - self, - authn_event: AuthnEvent, - auth_req: AuthorizationRequest, - user_id: Optional[str] = "", - client_id: Optional[str] = "", - sub_type: Optional[str] = "public", - token_usage_rules: Optional[dict] = None, - scopes: Optional[list] = None, + self, + authn_event: AuthnEvent, + auth_req: AuthorizationRequest, + user_id: Optional[str] = "", + client_id: Optional[str] = "", + sub_type: Optional[str] = "public", + token_usage_rules: Optional[dict] = None, + scopes: Optional[list] = None, ) -> str: """ Create part of a user session. The parts added are user- and client @@ -277,15 +275,15 @@ def create_session( ) def create_exchange_session( - self, - exchange_request: TokenExchangeRequest, - original_grant: Grant, - original_session_id: str, - user_id: str, - client_id: Optional[str] = "", - sub_type: Optional[str] = "public", - token_usage_rules: Optional[dict] = None, - scopes: Optional[list] = None, + self, + exchange_request: TokenExchangeRequest, + original_grant: Grant, + original_session_id: str, + user_id: str, + client_id: Optional[str] = "", + sub_type: Optional[str] = "public", + token_usage_rules: Optional[dict] = None, + scopes: Optional[list] = None, ) -> str: """ Create part of a user session. The parts added are user- and client @@ -319,7 +317,7 @@ def get_client_session_info(self, session_id: str) -> ClientSessionInfo: :param session_id: Session identifier :return: ClientSessionInfo instance """ - _id, csi = self.get_node_info(session_id, node_type='client') + _id, csi = self.get_node_info(session_id, node_type="client") if isinstance(csi, ClientSessionInfo): return csi @@ -333,7 +331,7 @@ def get_user_session_info(self, session_id: str) -> UserSessionInfo: :param session_id: Session identifier :return: ClientSessionInfo instance """ - _id, usi = self.get_node_info(session_id, node_type='user') + _id, usi = self.get_node_info(session_id, node_type="user") if isinstance(usi, UserSessionInfo): return usi @@ -347,7 +345,7 @@ def get_grant(self, session_id: str) -> Grant: :param session_id: Session identifier :return: ClientSessionInfo instance """ - _id, grant = self.get_node_info(session_id, node_type='grant') + _id, grant = self.get_node_info(session_id, node_type="grant") if isinstance(grant, Grant): return grant @@ -373,10 +371,10 @@ def revoke_token(self, session_id: str, token_value: str, recursive: bool = Fals grant.revoke_token(value=token.value) def get_authentication_events( - self, - session_id: Optional[str] = "", - user_id: Optional[str] = "", - client_id: Optional[str] = "", + self, + session_id: Optional[str] = "", + user_id: Optional[str] = "", + client_id: Optional[str] = "", ) -> List[AuthnEvent]: """ Return the authentication events that exists for a user/client combination. @@ -387,7 +385,7 @@ def get_authentication_events( :return: None if no authentication event could be found or an AuthnEvent instance. """ if session_id: - cid, c_info = self.get_node_info(session_id, node_type='client') + cid, c_info = self.get_node_info(session_id, node_type="client") elif user_id and client_id: c_info = self.get([user_id, client_id]) else: @@ -450,13 +448,13 @@ def revoke_grant(self, session_id: str): # return [self.get([user_id, client_id, gid]) for gid in _csi.subordinate] def get_session_info( - self, - session_id: str, - user_session_info: bool = False, - client_session_info: bool = False, - grant: bool = False, - authentication_event: bool = False, - authorization_request: bool = False, + self, + session_id: str, + user_session_info: bool = False, + client_session_info: bool = False, + grant: bool = False, + authentication_event: bool = False, + authorization_request: bool = False, ) -> dict: """ Returns information connected to a session. @@ -481,14 +479,14 @@ def get_session_info( return res def get_session_info_by_token( - self, - token_value: str, - user_session_info: Optional[bool] = False, - client_session_info: Optional[bool] = False, - grant: Optional[bool] = False, - authentication_event: Optional[bool] = False, - authorization_request: Optional[bool] = False, - handler_key: Optional[str] = "", + self, + token_value: str, + user_session_info: Optional[bool] = False, + client_session_info: Optional[bool] = False, + grant: Optional[bool] = False, + authentication_event: Optional[bool] = False, + authorization_request: Optional[bool] = False, + handler_key: Optional[str] = "", ) -> dict: if handler_key: diff --git a/src/idpyoidc/server/token/handler.py b/src/idpyoidc/server/token/handler.py index 20b36fa5..ad2ae4e9 100755 --- a/src/idpyoidc/server/token/handler.py +++ b/src/idpyoidc/server/token/handler.py @@ -24,11 +24,11 @@ class TokenHandler(ImpExp): parameter = {"handler": DLDict, "handler_order": [""]} def __init__( - self, - access_token: Optional[Token] = None, - authorization_code: Optional[Token] = None, - refresh_token: Optional[Token] = None, - id_token: Optional[Token] = None, + self, + access_token: Optional[Token] = None, + authorization_code: Optional[Token] = None, + refresh_token: Optional[Token] = None, + id_token: Optional[Token] = None, ): ImpExp.__init__(self) self.handler = {"authorization_code": authorization_code, "access_token": access_token} @@ -141,13 +141,13 @@ def default_token(spec): def factory( - upstream_get, - code: Optional[dict] = None, - token: Optional[dict] = None, - refresh: Optional[dict] = None, - id_token: Optional[dict] = None, - jwks_file: Optional[str] = "", - **kwargs + upstream_get, + code: Optional[dict] = None, + token: Optional[dict] = None, + refresh: Optional[dict] = None, + id_token: Optional[dict] = None, + jwks_file: Optional[str] = "", + **kwargs ) -> TokenHandler: """ Create a token handler diff --git a/src/idpyoidc/server/token/id_token.py b/src/idpyoidc/server/token/id_token.py index 0840ef5f..181a000c 100755 --- a/src/idpyoidc/server/token/id_token.py +++ b/src/idpyoidc/server/token/id_token.py @@ -57,9 +57,7 @@ def include_session_id(context, client_id, where): return True -def get_sign_and_encrypt_algorithms( - context, client_info, payload_type, sign=False, encrypt=False -): +def get_sign_and_encrypt_algorithms(context, client_info, payload_type, sign=False, encrypt=False): args = {"sign": sign, "encrypt": encrypt} if sign: try: @@ -257,10 +255,11 @@ def sign_encrypt( lifetime = self.lifetime _jwt = JWT( - self.upstream_get('attribute', 'keyjar'), + self.upstream_get("attribute", "keyjar"), iss=_context.issuer, lifetime=lifetime, - **alg_dict) + **alg_dict, + ) return _jwt.pack(_payload, recv=client_id) @@ -324,8 +323,8 @@ def info(self, token): alg_dict = get_sign_and_encrypt_algorithms(_context, client_info, "id_token", sign=True) verifier = JWT( - key_jar=self.upstream_get('attribute', 'keyjar'), - allowed_sign_algs=alg_dict["sign_alg"]) + key_jar=self.upstream_get("attribute", "keyjar"), allowed_sign_algs=alg_dict["sign_alg"] + ) try: _payload = verifier.unpack(token) except JWSException: diff --git a/src/idpyoidc/server/token/jwt_token.py b/src/idpyoidc/server/token/jwt_token.py index ec125921..698d3ea1 100644 --- a/src/idpyoidc/server/token/jwt_token.py +++ b/src/idpyoidc/server/token/jwt_token.py @@ -17,20 +17,19 @@ class JWTToken(Token): - def __init__( - self, - token_class, - # keyjar: KeyJar = None, - issuer: str = None, - aud: Optional[list] = None, - alg: str = "ES256", - lifetime: int = DEFAULT_TOKEN_LIFETIME, - upstream_get: Callable = None, - token_type: str = "Bearer", - profile: Optional[Union[Message, str]] = JWTAccessToken, - with_jti: Optional[bool] = False, - **kwargs + self, + token_class, + # keyjar: KeyJar = None, + issuer: str = None, + aud: Optional[list] = None, + alg: str = "ES256", + lifetime: int = DEFAULT_TOKEN_LIFETIME, + upstream_get: Callable = None, + token_type: str = "Bearer", + profile: Optional[Union[Message, str]] = JWTAccessToken, + with_jti: Optional[bool] = False, + **kwargs ): Token.__init__(self, token_class, **kwargs) self.token_type = token_type @@ -59,13 +58,13 @@ def load_custom_claims(self, payload: dict = None): return payload def __call__( - self, - session_id: Optional[str] = "", - token_class: Optional[str] = "", - usage_rules: Optional[dict] = None, - profile: Optional[Message] = None, - with_jti: Optional[bool] = None, - **payload + self, + session_id: Optional[str] = "", + token_class: Optional[str] = "", + usage_rules: Optional[dict] = None, + profile: Optional[Message] = None, + with_jti: Optional[bool] = None, + **payload ) -> str: """ Return a token. @@ -90,7 +89,7 @@ def __call__( else: lifetime = self.lifetime signer = JWT( - key_jar=self.upstream_get('attribute', 'keyjar'), + key_jar=self.upstream_get("attribute", "keyjar"), iss=self.issuer, lifetime=lifetime, sign_alg=self.alg, @@ -111,8 +110,9 @@ def __call__( return signer.pack(payload) def get_payload(self, token): - verifier = JWT(key_jar=self.upstream_get('attribute', 'keyjar'), - allowed_sign_algs=[self.alg]) + verifier = JWT( + key_jar=self.upstream_get("attribute", "keyjar"), allowed_sign_algs=[self.alg] + ) try: _payload = verifier.unpack(token) except JWSException: diff --git a/src/idpyoidc/server/user_authn/user.py b/src/idpyoidc/server/user_authn/user.py index c0307dc9..b623d6a0 100755 --- a/src/idpyoidc/server/user_authn/user.py +++ b/src/idpyoidc/server/user_authn/user.py @@ -119,8 +119,13 @@ def cookie_info(self, cookie: List[dict], client_id: str) -> dict: # verify session ID try: _context.session_manager[_info["sid"]] - except (KeyError, ValueError, InconsistentDatabase, - NoSuchClientSession, NoSuchGrant) as err: + except ( + KeyError, + ValueError, + InconsistentDatabase, + NoSuchClientSession, + NoSuchGrant, + ) as err: logger.info(f"Verifying session ID fail due to {err}") return {} @@ -153,13 +158,13 @@ class UserPassJinja2(UserAuthnMethod): url_endpoint = "/verify/user_pass_jinja" def __init__( - self, - db, - template_handler, - template="user_pass.jinja2", - upstream_get=None, - verify_endpoint="", - **kwargs, + self, + db, + template_handler, + template="user_pass.jinja2", + upstream_get=None, + verify_endpoint="", + **kwargs, ): super(UserPassJinja2, self).__init__(upstream_get=upstream_get) @@ -193,7 +198,7 @@ def __call__(self, **kwargs): if not self.upstream_get: raise Exception(f"{self.__class__.__name__} doesn't have a working upstream_get") _context = self.upstream_get("context") - _keyjar = self.upstream_get("attribute", 'keyjar') + _keyjar = self.upstream_get("attribute", "keyjar") # Stores information need afterwards in a signed JWT that then # appears as a hidden input in the form jws = create_signed_jwt(_context.issuer, _keyjar, **kwargs) @@ -219,12 +224,11 @@ def verify(self, *args, **kwargs): class UserPass(UserAuthnMethod): - def __init__( - self, - db_conf, - upstream_get=None, - **kwargs, + self, + db_conf, + upstream_get=None, + **kwargs, ): super(UserPass, self).__init__(upstream_get=upstream_get) @@ -242,7 +246,6 @@ def verify(self, *args, **kwargs): class BasicAuthn(UserAuthnMethod): - def __init__(self, pwd, ttl=5, upstream_get=None): UserAuthnMethod.__init__(self, upstream_get=upstream_get) self.passwd = pwd diff --git a/src/idpyoidc/server/util.py b/src/idpyoidc/server/util.py index 4ec0eaa9..1105da88 100755 --- a/src/idpyoidc/server/util.py +++ b/src/idpyoidc/server/util.py @@ -57,7 +57,6 @@ def build_endpoints(conf, upstream_get, issuer): class JSONDictDB(object): - def __init__(self, filename): with open(filename, "r") as f: self._db = json.load(f) @@ -94,7 +93,7 @@ def lv_unpack(txt): while txt: l, v = txt.split(":", 1) res.append(v[: int(l)]) - txt = v[int(l):] + txt = v[int(l) :] return res @@ -123,9 +122,9 @@ def get_http_params(config): def allow_refresh_token(context): # Are there a refresh_token handler - refresh_token_handler = context.session_manager.token_handler.handler.get( - "refresh_token" - ) + refresh_token_handler = context.session_manager.token_handler.handler.get("refresh_token") + if refresh_token_handler is None: + return False # Is refresh_token grant type supported _token_supported = False diff --git a/src/idpyoidc/storage/abfile.py b/src/idpyoidc/storage/abfile.py index e6f980c0..cb80182d 100644 --- a/src/idpyoidc/storage/abfile.py +++ b/src/idpyoidc/storage/abfile.py @@ -24,10 +24,11 @@ class AbstractFileSystem(DictType): """ def __init__( - self, fdir: Optional[str] = "", - key_conv: Optional[str] = "", - value_conv: Optional[str] = "", - **kwargs + self, + fdir: Optional[str] = "", + key_conv: Optional[str] = "", + value_conv: Optional[str] = "", + **kwargs ): """ items = FileSystem( diff --git a/src/idpyoidc/time_util.py b/src/idpyoidc/time_util.py index 3ff0838b..9e0d7e18 100644 --- a/src/idpyoidc/time_util.py +++ b/src/idpyoidc/time_util.py @@ -104,11 +104,11 @@ def parse_duration(duration): try: mod = duration[index:].index(code) try: - dic[typ] = int(duration[index: index + mod]) + dic[typ] = int(duration[index : index + mod]) except ValueError: if code == "S": try: - dic[typ] = float(duration[index: index + mod]) + dic[typ] = float(duration[index : index + mod]) except ValueError: raise TimeUtilError("Not a float") else: @@ -185,7 +185,7 @@ def time_in_a_while(days=0, seconds=0, microseconds=0, milliseconds=0, minutes=0 def time_a_while_ago( - days=0, seconds=0, microseconds=0, milliseconds=0, minutes=0, hours=0, weeks=0 + days=0, seconds=0, microseconds=0, milliseconds=0, minutes=0, hours=0, weeks=0 ): """ Will return a time specification for a time sometime in the past. @@ -205,14 +205,14 @@ def time_a_while_ago( def in_a_while( - days=0, - seconds=0, - microseconds=0, - milliseconds=0, - minutes=0, - hours=0, - weeks=0, - time_format=TIME_FORMAT, + days=0, + seconds=0, + microseconds=0, + milliseconds=0, + minutes=0, + hours=0, + weeks=0, + time_format=TIME_FORMAT, ): """ :param days: @@ -234,14 +234,14 @@ def in_a_while( def a_while_ago( - days=0, - seconds=0, - microseconds=0, - milliseconds=0, - minutes=0, - hours=0, - weeks=0, - time_format=TIME_FORMAT, + days=0, + seconds=0, + microseconds=0, + milliseconds=0, + minutes=0, + hours=0, + weeks=0, + time_format=TIME_FORMAT, ): """ @@ -361,7 +361,7 @@ def time_sans_frac(): def epoch_in_a_while( - days=0, seconds=0, microseconds=0, milliseconds=0, minutes=0, hours=0, weeks=0 + days=0, seconds=0, microseconds=0, milliseconds=0, minutes=0, hours=0, weeks=0 ): """ Return the number of seconds since epoch a while from now. diff --git a/src/idpyoidc/util.py b/src/idpyoidc/util.py index ab515d00..21c66afa 100644 --- a/src/idpyoidc/util.py +++ b/src/idpyoidc/util.py @@ -84,7 +84,6 @@ def split_uri(uri: str) -> [str, Union[dict, None]]: class QPKey: - def serialize(self, str): return quote_plus(str) @@ -93,7 +92,6 @@ def deserialize(self, str): class JSON: - def serialize(self, str): return json.dumps(str) @@ -102,7 +100,6 @@ def deserialize(self, str): class PassThru: - def serialize(self, str): return str diff --git a/tests/private/token_jwks.json b/tests/private/token_jwks.json index 9e18a977..c8197bd2 100644 --- a/tests/private/token_jwks.json +++ b/tests/private/token_jwks.json @@ -1 +1 @@ -{"keys": [{"kty": "oct", "use": "enc", "kid": "code", "k": "vSHDkLBHhDStkR0NWu8519rmV5zmnm5_"}, {"kty": "oct", "use": "enc", "kid": "refresh", "k": "XeeoaV1P5eINXBFEDU2U_YBXqsjJE0uD"}]} \ No newline at end of file +{"keys": [{"kty": "oct", "use": "enc", "kid": "code", "k": "vSHDkLBHhDStkR0NWu8519rmV5zmnm5_"}, {"kty": "oct", "use": "enc", "kid": "refresh", "k": "ASSv_faXjb13FYcRMP7ht8f641jIWw3W"}]} \ No newline at end of file diff --git a/tests/request123456.jwt b/tests/request123456.jwt index 1d5c9d1d..391614ae 100644 --- a/tests/request123456.jwt +++ b/tests/request123456.jwt @@ -1 +1 @@ -eyJhbGciOiJSUzI1NiIsImtpZCI6IlNIRXlZV2N3TlZrMExUZFJPVFp6WjJGVVduZElWWGRhY2sweFdVTTVTRXB3Y1MwM2RWVXhXVTR6UlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2xpL2F1dGh6X2NiIiwgInNjb3BlIjogIm9wZW5pZCIsICJub25jZSI6ICJBWGV0Wm1SVXFWT2NPX0NTMFZrNF9oM05vRjlJRHpzYUEwZHBWRFpZVS1BIiwgImNsaWVudF9pZCI6ICJjbGllbnRfaWQiLCAiaXNzIjogImNsaWVudF9pZCIsICJpYXQiOiAxNjc4OTU2Mzg1LCAiYXVkIjogWyJodHRwczovL2V4YW1wbGUuY29tIl19.axJ7C32rBbu5jWwnZAa04_3QSPwytuRtUjRTOpcHnSa1D_XsnPjVuVmRbYWFPepcaPeMN6GYuOn22_6quVSRktnMvVPfh-C1YttosfWOYavq60H3Hav3mLa357gGgCSRJJG1RGXQlSf5PU7P1hdiJoCaiejpVaA7efkBcQagTndlxFoE3oRoeKr9RqLKPRvRnlB-qv6FpanLwm4gY4NnAOjHo_1BOP6tvJTfad6aQwW5sRL-NaKLLrfkHgKnsTpyEUrBtl6-63O8_w9ckBsT1B9JBH1T6vhkjY-vGBptTnrAf_0giDi_Lw7jZMrETqJjnyMlQIDd88AOlnHV0IDvew \ No newline at end of file +eyJhbGciOiJSUzI1NiIsImtpZCI6IlNIRXlZV2N3TlZrMExUZFJPVFp6WjJGVVduZElWWGRhY2sweFdVTTVTRXB3Y1MwM2RWVXhXVTR6UlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2xpL2F1dGh6X2NiIiwgInNjb3BlIjogIm9wZW5pZCIsICJub25jZSI6ICI1ams1WkdMLTE3NXpHT0FPQW5yTlZra2paZXltS0JwOVFoek81QV90eDQwIiwgImNsaWVudF9pZCI6ICJjbGllbnRfaWQiLCAiaXNzIjogImNsaWVudF9pZCIsICJpYXQiOiAxNjg3MTgwNzk0LCAiYXVkIjogWyJodHRwczovL2V4YW1wbGUuY29tIl19.HUbiyiC0pypd8hamG9JJ-xQaJ7FEAVjDoy4jH00hJ5FtLqm87PAKIvD5aptYv8VzdpA5X8hCDUW4g0noNBbEsmvXeJpoXHSeVz_A4Ue8Ziz7z6dnrYf7BNFt3NyTibKVlkcWNGPBhEjyw0k4r6O86lQ2mSQjINJuqpR7VeEQyK7CBhDl5bicPctB4yGm4VksvC39695hhyGtUrUyrGW539g54VkG-x0kKv2HMc_ZGsnsEgFrT0fHKWuc1hPRkGi2XuSyhhD20zhnZhMGyTovwoZxmbx2seiIinjd0_wZVMZS277yUvMQTCvjOHJyu80XLLZqI71GguonCWdxIIrblQ \ No newline at end of file diff --git a/tests/test_05_oauth2.py b/tests/test_05_oauth2.py index fac4d7ad..fc187db3 100644 --- a/tests/test_05_oauth2.py +++ b/tests/test_05_oauth2.py @@ -578,7 +578,7 @@ def test_init(self): class TestCCAccessTokenRequest(object): def test_init(self): - cc = CCAccessTokenRequest(scope="/foo", grant_type='client_credentials') + cc = CCAccessTokenRequest(scope="/foo", grant_type="client_credentials") cc.verify() assert cc["scope"] == ["/foo"] diff --git a/tests/test_08_transform.py b/tests/test_08_transform.py index 52020451..27eed5ca 100644 --- a/tests/test_08_transform.py +++ b/tests/test_08_transform.py @@ -13,24 +13,23 @@ class TestTransform: - @pytest.fixture(autouse=True) def setup(self): supported = OIDC_Claims._supports.copy() for service in [ - 'idpyoidc.client.oidc.access_token.AccessToken', - 'idpyoidc.client.oidc.authorization.Authorization', - 'idpyoidc.client.oidc.backchannel_authentication.BackChannelAuthentication', - 'idpyoidc.client.oidc.backchannel_authentication.ClientNotification', - 'idpyoidc.client.oidc.check_id.CheckID', - 'idpyoidc.client.oidc.check_session.CheckSession', - 'idpyoidc.client.oidc.end_session.EndSession', - 'idpyoidc.client.oidc.provider_info_discovery.ProviderInfoDiscovery', - 'idpyoidc.client.oidc.read_registration.RegistrationRead', - 'idpyoidc.client.oidc.refresh_access_token.RefreshAccessToken', - 'idpyoidc.client.oidc.registration.Registration', - 'idpyoidc.client.oidc.userinfo.UserInfo', - 'idpyoidc.client.oidc.webfinger.WebFinger' + "idpyoidc.client.oidc.access_token.AccessToken", + "idpyoidc.client.oidc.authorization.Authorization", + "idpyoidc.client.oidc.backchannel_authentication.BackChannelAuthentication", + "idpyoidc.client.oidc.backchannel_authentication.ClientNotification", + "idpyoidc.client.oidc.check_id.CheckID", + "idpyoidc.client.oidc.check_session.CheckSession", + "idpyoidc.client.oidc.end_session.EndSession", + "idpyoidc.client.oidc.provider_info_discovery.ProviderInfoDiscovery", + "idpyoidc.client.oidc.read_registration.RegistrationRead", + "idpyoidc.client.oidc.refresh_access_token.RefreshAccessToken", + "idpyoidc.client.oidc.registration.Registration", + "idpyoidc.client.oidc.userinfo.UserInfo", + "idpyoidc.client.oidc.webfinger.WebFinger", ]: cls = importer(service) supported.update(cls._supports) @@ -44,137 +43,146 @@ def setup(self): def test_supported(self): # These are all the available configuration parameters assert set(self.supported.keys()) == { - 'acr_values_supported', - 'application_type', - 'backchannel_logout_session_required', - 'backchannel_logout_supported', - 'backchannel_logout_uri', - 'callback_uris', - 'client_id', - 'client_name', - 'client_secret', - 'client_uri', - 'contacts', - 'default_max_age', - 'encrypt_id_token_supported', - 'encrypt_request_object_supported', - 'encrypt_userinfo_supported', - 'frontchannel_logout_session_required', - 'frontchannel_logout_supported', - 'frontchannel_logout_uri', - 'id_token_encryption_alg_values_supported', - 'id_token_encryption_enc_values_supported', - 'id_token_signing_alg_values_supported', - 'initiate_login_uri', - 'jwks', - 'jwks_uri', - 'logo_uri', - 'policy_uri', - 'post_logout_redirect_uris', - 'redirect_uris', - 'request_object_encryption_alg_values_supported', - 'request_object_encryption_enc_values_supported', - 'request_object_signing_alg_values_supported', - 'request_parameter', - 'request_parameter_supported', - 'request_uri_parameter_supported', - 'request_uris', - 'requests_dir', - 'require_auth_time', - 'response_modes_supported', - 'response_types_supported', - 'scopes_supported', - 'sector_identifier_uri', - 'subject_types_supported', + "acr_values_supported", + "application_type", + "backchannel_logout_session_required", + "backchannel_logout_supported", + "backchannel_logout_uri", + "callback_uris", + "client_id", + "client_name", + "client_secret", + "client_uri", + "contacts", + "default_max_age", + "encrypt_id_token_supported", + "encrypt_request_object_supported", + "encrypt_userinfo_supported", + "frontchannel_logout_session_required", + "frontchannel_logout_supported", + "frontchannel_logout_uri", + "id_token_encryption_alg_values_supported", + "id_token_encryption_enc_values_supported", + "id_token_signing_alg_values_supported", + "initiate_login_uri", + "jwks", + "jwks_uri", + "logo_uri", + "policy_uri", + "post_logout_redirect_uris", + "redirect_uris", + "request_object_encryption_alg_values_supported", + "request_object_encryption_enc_values_supported", + "request_object_signing_alg_values_supported", + "request_parameter", + "request_parameter_supported", + "request_uri_parameter_supported", + "request_uris", + "requests_dir", + "require_auth_time", + "response_modes_supported", + "response_types_supported", + "scopes_supported", + "sector_identifier_uri", + "subject_types_supported", # 'token_endpoint_auth_method', - 'token_endpoint_auth_methods_supported', - 'token_endpoint_auth_signing_alg_values_supported', - 'tos_uri', - 'userinfo_encryption_alg_values_supported', - 'userinfo_encryption_enc_values_supported', - 'userinfo_signing_alg_values_supported'} + "token_endpoint_auth_methods_supported", + "token_endpoint_auth_signing_alg_values_supported", + "tos_uri", + "userinfo_encryption_alg_values_supported", + "userinfo_encryption_enc_values_supported", + "userinfo_signing_alg_values_supported", + } def test_oidc_setup(self): # This is OP specified stuff assert set(ProviderConfigurationResponse.c_param.keys()).difference( - set(self.supported)) == {'authorization_endpoint', - 'check_session_iframe', - 'claim_types_supported', - 'claims_locales_supported', - 'claims_parameter_supported', - 'claims_supported', - 'display_values_supported', - 'end_session_endpoint', - 'error', - 'error_description', - 'error_uri', - 'grant_types_supported', - 'issuer', - 'op_policy_uri', - 'op_tos_uri', - 'registration_endpoint', - 'require_request_uri_registration', - 'service_documentation', - 'token_endpoint', - 'ui_locales_supported', - 'userinfo_endpoint', - 'code_challenge_methods_supported'} + set(self.supported) + ) == { + "authorization_endpoint", + "check_session_iframe", + "claim_types_supported", + "claims_locales_supported", + "claims_parameter_supported", + "claims_supported", + "display_values_supported", + "end_session_endpoint", + "error", + "error_description", + "error_uri", + "grant_types_supported", + "issuer", + "op_policy_uri", + "op_tos_uri", + "registration_endpoint", + "require_request_uri_registration", + "service_documentation", + "token_endpoint", + "ui_locales_supported", + "userinfo_endpoint", + "code_challenge_methods_supported", + } # parameters that are not mapped against what the OP's provider info says assert set(self.supported).difference( - set(ProviderConfigurationResponse.c_param.keys())) == {'application_type', - 'backchannel_logout_uri', - 'callback_uris', - 'client_id', - 'client_name', - 'client_secret', - 'client_uri', - 'contacts', - 'default_max_age', - 'encrypt_id_token_supported', - 'encrypt_request_object_supported', - 'encrypt_userinfo_supported', - 'frontchannel_logout_uri', - 'initiate_login_uri', - 'jwks', - 'logo_uri', - 'policy_uri', - 'post_logout_redirect_uris', - 'redirect_uris', - 'request_parameter', - 'request_uris', - 'requests_dir', - 'require_auth_time', - 'sector_identifier_uri', - 'tos_uri'} + set(ProviderConfigurationResponse.c_param.keys()) + ) == { + "application_type", + "backchannel_logout_uri", + "callback_uris", + "client_id", + "client_name", + "client_secret", + "client_uri", + "contacts", + "default_max_age", + "encrypt_id_token_supported", + "encrypt_request_object_supported", + "encrypt_userinfo_supported", + "frontchannel_logout_uri", + "initiate_login_uri", + "jwks", + "logo_uri", + "policy_uri", + "post_logout_redirect_uris", + "redirect_uris", + "request_parameter", + "request_uris", + "requests_dir", + "require_auth_time", + "sector_identifier_uri", + "tos_uri", + } claims = OIDC_Claims() # No input from the IDP so info is absent - claims.prefer = supported_to_preferred(supported=self.supported, - preference=claims.prefer, - base_url='https://example.com') + claims.prefer = supported_to_preferred( + supported=self.supported, preference=claims.prefer, base_url="https://example.com" + ) # These are the claims that has default values. A default value may be an empty list. # This is the case for claims like id_token_encryption_enc_values_supported. - assert set(claims.prefer.keys()) == {'application_type', - 'default_max_age', - 'encrypt_request_object_supported', - 'encrypt_userinfo_supported', - 'id_token_encryption_alg_values_supported', - 'id_token_encryption_enc_values_supported', - 'id_token_signing_alg_values_supported', - 'request_object_encryption_alg_values_supported', - 'request_object_encryption_enc_values_supported', - 'request_object_signing_alg_values_supported', - 'response_modes_supported', - 'response_types_supported', - 'scopes_supported', - 'subject_types_supported', - 'token_endpoint_auth_methods_supported', - 'token_endpoint_auth_signing_alg_values_supported', - 'userinfo_encryption_alg_values_supported', - 'userinfo_encryption_enc_values_supported', - 'userinfo_signing_alg_values_supported'} + assert set(claims.prefer.keys()) == { + "application_type", + "default_max_age", + "encrypt_request_object_supported", + "encrypt_userinfo_supported", + "id_token_encryption_alg_values_supported", + "id_token_encryption_enc_values_supported", + "id_token_signing_alg_values_supported", + "request_object_encryption_alg_values_supported", + "request_object_encryption_enc_values_supported", + "request_object_signing_alg_values_supported", + "response_modes_supported", + "response_types_supported", + "scopes_supported", + "subject_types_supported", + "token_endpoint_auth_methods_supported", + "token_endpoint_auth_signing_alg_values_supported", + "userinfo_encryption_alg_values_supported", + "userinfo_encryption_enc_values_supported", + "userinfo_signing_alg_values_supported", + } # To verify that I have all the necessary claims to do client registration reg_claim = [] @@ -184,7 +192,10 @@ def test_oidc_setup(self): reg_claim.append(key) assert set(RegistrationRequest.c_param.keys()).difference(set(reg_claim)) == { - 'post_logout_redirect_uri', 'grant_types'} + "post_logout_redirect_uri", + "grant_types", + "response_modes" # Extra item + } # Which ones are list -> singletons @@ -197,11 +208,11 @@ def test_oidc_setup(self): elif isinstance(spec[0], list): l_to_s.append(key) - assert set(non_oidc) == {'scopes_supported'} - assert set(l_to_s) == {'response_types', 'grant_types', 'default_acr_values'} + assert set(non_oidc) == {"scopes_supported"} + assert set(l_to_s) == {"response_types", "grant_types", "default_acr_values"} def test_provider_info(self): - OP_BASEURL = 'https://example.com' + OP_BASEURL = "https://example.com" provider_info_response = { "version": "3.0", "token_endpoint_auth_methods_supported": [ @@ -218,69 +229,75 @@ def test_provider_info(self): "registration_endpoint": f"{OP_BASEURL}/registration", "end_session_endpoint": f"{OP_BASEURL}/end_session", # below are a set which the RP has default values but the OP overwrites - "scopes_supported": ['openid', 'fee', 'faa', 'foo', 'fum'], - "response_types_supported": ['code', 'id_token', 'code id_token'], - "response_modes_supported": ['query', 'form_post', 'new_fangled'], + "scopes_supported": ["openid", "fee", "faa", "foo", "fum"], + "response_types_supported": ["code", "id_token", "code id_token"], + "response_modes_supported": ["query", "form_post", "new_fangled"], # this does not have a default value - "acr_values_supported": ['mfa'], + "acr_values_supported": ["mfa"], } claims = OIDC_Claims() - claims.prefer = supported_to_preferred(supported=self.supported, - preference=claims.prefer, - base_url='https://example.com', - info=provider_info_response) + claims.prefer = supported_to_preferred( + supported=self.supported, + preference=claims.prefer, + base_url="https://example.com", + info=provider_info_response, + ) # These are the claims that has default values - assert set(claims.prefer.keys()) == {'application_type', - 'default_max_age', - 'encrypt_request_object_supported', - 'encrypt_userinfo_supported', - 'id_token_encryption_alg_values_supported', - 'id_token_encryption_enc_values_supported', - 'id_token_signing_alg_values_supported', - 'request_object_encryption_alg_values_supported', - 'request_object_encryption_enc_values_supported', - 'request_object_signing_alg_values_supported', - 'response_modes_supported', - 'response_types_supported', - 'scopes_supported', - 'subject_types_supported', - 'token_endpoint_auth_methods_supported', - 'token_endpoint_auth_signing_alg_values_supported', - 'userinfo_encryption_alg_values_supported', - 'userinfo_encryption_enc_values_supported', - 'userinfo_signing_alg_values_supported'} + assert set(claims.prefer.keys()) == { + "application_type", + "default_max_age", + "encrypt_request_object_supported", + "encrypt_userinfo_supported", + "id_token_encryption_alg_values_supported", + "id_token_encryption_enc_values_supported", + "id_token_signing_alg_values_supported", + "request_object_encryption_alg_values_supported", + "request_object_encryption_enc_values_supported", + "request_object_signing_alg_values_supported", + "response_modes_supported", + "response_types_supported", + "scopes_supported", + "subject_types_supported", + "token_endpoint_auth_methods_supported", + "token_endpoint_auth_signing_alg_values_supported", + "userinfo_encryption_alg_values_supported", + "userinfo_encryption_enc_values_supported", + "userinfo_signing_alg_values_supported", + } # least common denominator # The RP supports less than the OP - assert claims.get_preference('scopes_supported') == ['openid'] - assert claims.get_preference("response_modes_supported") == ['query', 'form_post'] + assert claims.get_preference("scopes_supported") == ["openid"] + assert claims.get_preference("response_modes_supported") == ["query", "form_post"] # The OP supports less than the RP - assert claims.get_preference("response_types_supported") == ['code', 'id_token', - 'code id_token'] + assert claims.get_preference("response_types_supported") == [ + "code", + "id_token", + "code id_token", + ] class TestTransform2: - @pytest.fixture(autouse=True) def setup(self): self.claims = OIDC_Claims() supported = self.claims._supports.copy() for service in [ - 'idpyoidc.client.oidc.access_token.AccessToken', - 'idpyoidc.client.oidc.authorization.Authorization', - 'idpyoidc.client.oidc.backchannel_authentication.BackChannelAuthentication', - 'idpyoidc.client.oidc.backchannel_authentication.ClientNotification', - 'idpyoidc.client.oidc.check_id.CheckID', - 'idpyoidc.client.oidc.check_session.CheckSession', - 'idpyoidc.client.oidc.end_session.EndSession', - 'idpyoidc.client.oidc.provider_info_discovery.ProviderInfoDiscovery', - 'idpyoidc.client.oidc.read_registration.RegistrationRead', - 'idpyoidc.client.oidc.refresh_access_token.RefreshAccessToken', - 'idpyoidc.client.oidc.registration.Registration', - 'idpyoidc.client.oidc.userinfo.UserInfo', - 'idpyoidc.client.oidc.webfinger.WebFinger' + "idpyoidc.client.oidc.access_token.AccessToken", + "idpyoidc.client.oidc.authorization.Authorization", + "idpyoidc.client.oidc.backchannel_authentication.BackChannelAuthentication", + "idpyoidc.client.oidc.backchannel_authentication.ClientNotification", + "idpyoidc.client.oidc.check_id.CheckID", + "idpyoidc.client.oidc.check_session.CheckSession", + "idpyoidc.client.oidc.end_session.EndSession", + "idpyoidc.client.oidc.provider_info_discovery.ProviderInfoDiscovery", + "idpyoidc.client.oidc.read_registration.RegistrationRead", + "idpyoidc.client.oidc.refresh_access_token.RefreshAccessToken", + "idpyoidc.client.oidc.registration.Registration", + "idpyoidc.client.oidc.userinfo.UserInfo", + "idpyoidc.client.oidc.webfinger.WebFinger", ]: cls = importer(service) supported.update(cls._supports) @@ -292,18 +309,20 @@ def setup(self): self.supported = supported preference = { "application_type": "web", - "redirect_uris": ["https://client.example.org/callback", - "https://client.example.org/callback2"], + "redirect_uris": [ + "https://client.example.org/callback", + "https://client.example.org/callback2", + ], "client_name": "My Example", # "client_name#ja-Jpan-JP": "クライアント名", "logo_uri": "https://client.example.org/logo.png", - 'contacts': ["ve7jtb@example.org", "mary@example.org"] + "contacts": ["ve7jtb@example.org", "mary@example.org"], } self.claims.load_conf(preference, self.supported) def test_registration_response(self): - OP_BASEURL = 'https://example.com' + OP_BASEURL = "https://example.com" provider_info_response = { "version": "3.0", "token_endpoint_auth_methods_supported": [ @@ -320,81 +339,92 @@ def test_registration_response(self): "registration_endpoint": f"{OP_BASEURL}/registration", "end_session_endpoint": f"{OP_BASEURL}/end_session", # below are a set which the RP has default values but the OP overwrites - "scopes_supported": ['openid', 'fee', 'faa', 'foo', 'fum'], - "response_types_supported": ['code', 'id_token', 'code id_token'], - "response_modes_supported": ['query', 'form_post', 'new_fangled'], + "scopes_supported": ["openid", "fee", "faa", "foo", "fum"], + "response_types_supported": ["code", "id_token", "code id_token"], + "response_modes_supported": ["query", "form_post", "new_fangled"], # this does not have a default value - "acr_values_supported": ['mfa'], + "acr_values_supported": ["mfa"], + } + + self.claims.prefer = supported_to_preferred( + supported=self.supported, + preference=self.claims.prefer, + base_url="https://example.com", + info=provider_info_response, + ) + + registration_request = create_registration_request( + prefers=self.claims.prefer, supported=self.supported + ) + + assert set(registration_request.keys()) == { + "application_type", + "client_name", + "contacts", + "default_max_age", + "id_token_signed_response_alg", + "logo_uri", + "redirect_uris", + "request_object_signing_alg", + "response_types", + "response_modes", # non-standard + "subject_type", + "token_endpoint_auth_method", + "token_endpoint_auth_signing_alg", + "userinfo_signed_response_alg", } - self.claims.prefer = supported_to_preferred(supported=self.supported, - preference=self.claims.prefer, - base_url='https://example.com', - info=provider_info_response) - - registration_request = create_registration_request(prefers=self.claims.prefer, - supported=self.supported) - - assert set(registration_request.keys()) == {'application_type', - 'client_name', - 'contacts', - 'default_max_age', - 'id_token_signed_response_alg', - 'logo_uri', - 'redirect_uris', - 'request_object_signing_alg', - 'response_types', - 'subject_type', - 'token_endpoint_auth_method', - 'token_endpoint_auth_signing_alg', - 'userinfo_signed_response_alg'} - - assert registration_request["subject_type"] == 'public' + assert registration_request["subject_type"] == "public" registration_response = { "application_type": "web", - "redirect_uris": - ["https://client.example.org/callback", - "https://client.example.org/callback2"], + "redirect_uris": [ + "https://client.example.org/callback", + "https://client.example.org/callback2", + ], "client_name": "My Example", "logo_uri": "https://client.example.org/logo.png", "subject_type": "pairwise", - "sector_identifier_uri": - "https://other.example.net/file_of_redirect_uris.json", + "sector_identifier_uri": "https://other.example.net/file_of_redirect_uris.json", "token_endpoint_auth_method": "client_secret_basic", "jwks_uri": "https://client.example.org/my_public_keys.jwks", "userinfo_encrypted_response_alg": "RSA1_5", "userinfo_encrypted_response_enc": "A128CBC-HS256", "contacts": ["ve7jtb@example.org", "mary@example.org"], "request_uris": [ - "https://client.example.org/rf.txt#qpXaRLh_n93TTR9F252ValdatUQvQiJi5BDub2BeznA"] + "https://client.example.org/rf.txt#qpXaRLh_n93TTR9F252ValdatUQvQiJi5BDub2BeznA" + ], + } + + to_use = preferred_to_registered( + supported=self.supported, + prefers=self.claims.prefer, + registration_response=registration_response, + ) + + assert set(to_use.keys()) == { + "application_type", + "client_name", + "contacts", + "default_max_age", + "encrypt_request_object_supported", + "encrypt_userinfo_supported", + "id_token_signed_response_alg", + "jwks_uri", + "logo_uri", + "redirect_uris", + "request_object_signing_alg", + "request_uris", + "response_types", + "response_modes", # non-standard + "scope", + "sector_identifier_uri", + "subject_type", + "token_endpoint_auth_method", + "token_endpoint_auth_signing_alg", + "userinfo_encrypted_response_alg", + "userinfo_encrypted_response_enc", + "userinfo_signed_response_alg", } - to_use = preferred_to_registered(supported=self.supported, - prefers=self.claims.prefer, - registration_response=registration_response) - - assert set(to_use.keys()) == {'application_type', - 'client_name', - 'contacts', - 'default_max_age', - 'encrypt_request_object_supported', - 'encrypt_userinfo_supported', - 'id_token_signed_response_alg', - 'jwks_uri', - 'logo_uri', - 'redirect_uris', - 'request_object_signing_alg', - 'request_uris', - 'response_modes_supported', - 'response_types', - 'scope', - 'sector_identifier_uri', - 'subject_type', - 'token_endpoint_auth_method', - 'token_endpoint_auth_signing_alg', - 'userinfo_encrypted_response_alg', - 'userinfo_encrypted_response_enc', - 'userinfo_signed_response_alg'} - - assert to_use["subject_type"] == 'pairwise' + assert to_use["subject_type"] == "pairwise" diff --git a/tests/test_09_work_condition.py b/tests/test_09_work_condition.py index 6ebeb5e3..44f918a8 100644 --- a/tests/test_09_work_condition.py +++ b/tests/test_09_work_condition.py @@ -15,25 +15,24 @@ class TestWorkEnvironment: - @pytest.fixture(autouse=True) def setup(self): self.claims = Claims() supported = self.claims._supports.copy() for service in [ - 'idpyoidc.client.oidc.access_token.AccessToken', - 'idpyoidc.client.oidc.authorization.Authorization', - 'idpyoidc.client.oidc.backchannel_authentication.BackChannelAuthentication', - 'idpyoidc.client.oidc.backchannel_authentication.ClientNotification', - 'idpyoidc.client.oidc.check_id.CheckID', - 'idpyoidc.client.oidc.check_session.CheckSession', - 'idpyoidc.client.oidc.end_session.EndSession', - 'idpyoidc.client.oidc.provider_info_discovery.ProviderInfoDiscovery', - 'idpyoidc.client.oidc.read_registration.RegistrationRead', - 'idpyoidc.client.oidc.refresh_access_token.RefreshAccessToken', - 'idpyoidc.client.oidc.registration.Registration', - 'idpyoidc.client.oidc.userinfo.UserInfo', - 'idpyoidc.client.oidc.webfinger.WebFinger' + "idpyoidc.client.oidc.access_token.AccessToken", + "idpyoidc.client.oidc.authorization.Authorization", + "idpyoidc.client.oidc.backchannel_authentication.BackChannelAuthentication", + "idpyoidc.client.oidc.backchannel_authentication.ClientNotification", + "idpyoidc.client.oidc.check_id.CheckID", + "idpyoidc.client.oidc.check_session.CheckSession", + "idpyoidc.client.oidc.end_session.EndSession", + "idpyoidc.client.oidc.provider_info_discovery.ProviderInfoDiscovery", + "idpyoidc.client.oidc.read_registration.RegistrationRead", + "idpyoidc.client.oidc.refresh_access_token.RefreshAccessToken", + "idpyoidc.client.oidc.registration.Registration", + "idpyoidc.client.oidc.userinfo.UserInfo", + "idpyoidc.client.oidc.webfinger.WebFinger", ]: cls = importer(service) supported.update(cls._supports) @@ -48,91 +47,103 @@ def test_load_conf(self): # Only symmetric key client_conf = { "application_type": "web", - "redirect_uris": ["https://client.example.org/callback", - "https://client.example.org/callback2"], + "redirect_uris": [ + "https://client.example.org/callback", + "https://client.example.org/callback2", + ], "client_name": "My Example", "client_id": "client_id", "client_secret": "a longesh password", "logo_uri": "https://client.example.org/logo.png", - 'contacts': ["ve7jtb@example.org", "mary@example.org"] + "contacts": ["ve7jtb@example.org", "mary@example.org"], } self.claims.load_conf(client_conf, self.supported) - assert self.claims.get_preference('jwks') is None - assert self.claims.get_preference('jwks_uri') is None + assert self.claims.get_preference("jwks") is None + assert self.claims.get_preference("jwks_uri") is None def test_load_jwks(self): # Symmetric and asymmetric keys published as JWKS client_conf = { "application_type": "web", - 'base_url': "https://client.example.org/", - "redirect_uris": ["https://client.example.org/callback", - "https://client.example.org/callback2"], + "base_url": "https://client.example.org/", + "redirect_uris": [ + "https://client.example.org/callback", + "https://client.example.org/callback2", + ], "client_name": "My Example", "client_id": "client_id", "keys": {"key_defs": KEYSPEC, "read_only": True}, "client_secret": "a longesh password", "logo_uri": "https://client.example.org/logo.png", - 'contacts': ["ve7jtb@example.org", "mary@example.org"] + "contacts": ["ve7jtb@example.org", "mary@example.org"], } self.claims.load_conf(client_conf, self.supported) - assert self.claims.get_preference('jwks') is not None - assert self.claims.get_preference('jwks_uri') is None + assert self.claims.get_preference("jwks") is not None + assert self.claims.get_preference("jwks_uri") is None def test_load_jwks_uri1(self): # Symmetric and asymmetric keys published through a jwks_uri client_conf = { "application_type": "web", - 'base_url': "https://client.example.org/", - "redirect_uris": ["https://client.example.org/callback", - "https://client.example.org/callback2"], + "base_url": "https://client.example.org/", + "redirect_uris": [ + "https://client.example.org/callback", + "https://client.example.org/callback2", + ], "client_name": "My Example", "keys": {"uri_path": "static/jwks.json", "key_defs": KEYSPEC, "read_only": True}, "logo_uri": "https://client.example.org/logo.png", - 'contacts': ["ve7jtb@example.org", "mary@example.org"] + "contacts": ["ve7jtb@example.org", "mary@example.org"], } self.claims.load_conf(client_conf, self.supported) - assert self.claims.get_preference('jwks') is None - assert self.claims.get_preference( - 'jwks_uri') == f"{client_conf['base_url']}{client_conf['keys']['uri_path']}" + assert self.claims.get_preference("jwks") is None + assert ( + self.claims.get_preference("jwks_uri") + == f"{client_conf['base_url']}{client_conf['keys']['uri_path']}" + ) def test_load_jwks_uri2(self): # Symmetric and asymmetric keys published through a jwks_uri client_conf = { "application_type": "web", - 'base_url': "https://client.example.org/", - "redirect_uris": ["https://client.example.org/callback", - "https://client.example.org/callback2"], + "base_url": "https://client.example.org/", + "redirect_uris": [ + "https://client.example.org/callback", + "https://client.example.org/callback2", + ], "client_name": "My Example", "keys": {"key_defs": KEYSPEC, "read_only": True}, - "jwks_uri": 'https://client.example.org/keys/jwks.json', + "jwks_uri": "https://client.example.org/keys/jwks.json", "logo_uri": "https://client.example.org/logo.png", - 'contacts': ["ve7jtb@example.org", "mary@example.org"] + "contacts": ["ve7jtb@example.org", "mary@example.org"], } self.claims.load_conf(client_conf, self.supported) - assert self.claims.get_preference('jwks') is None - assert self.claims.get_preference('jwks_uri') == client_conf['jwks_uri'] + assert self.claims.get_preference("jwks") is None + assert self.claims.get_preference("jwks_uri") == client_conf["jwks_uri"] def test_registration_response(self): client_conf = { "application_type": "web", - 'base_url': "https://client.example.org/", - "redirect_uris": ["https://client.example.org/callback", - "https://client.example.org/callback2"], + "base_url": "https://client.example.org/", + "redirect_uris": [ + "https://client.example.org/callback", + "https://client.example.org/callback2", + ], "client_name": "My Example", "client_id": "client_id", "keys": {"key_defs": KEYSPEC, "read_only": True}, "client_secret": "a longesh password", "logo_uri": "https://client.example.org/logo.png", - 'contacts': ["ve7jtb@example.org", "mary@example.org"] + "contacts": ["ve7jtb@example.org", "mary@example.org"], } self.claims.load_conf(client_conf, self.supported) - OP_BASEURL = 'https://example.com' + OP_BASEURL = "https://example.com" provider_info_response = { "version": "3.0", "token_endpoint_auth_methods_supported": [ @@ -149,85 +160,95 @@ def test_registration_response(self): "registration_endpoint": f"{OP_BASEURL}/registration", "end_session_endpoint": f"{OP_BASEURL}/end_session", # below are a set which the RP has default values but the OP overwrites - "scopes_supported": ['openid', 'fee', 'faa', 'foo', 'fum'], - "response_types_supported": ['code', 'id_token', 'code id_token'], - "response_modes_supported": ['query', 'form_post', 'new_fangled'], + "scopes_supported": ["openid", "fee", "faa", "foo", "fum"], + "response_types_supported": ["code", "id_token", "code id_token"], + "response_modes_supported": ["query", "form_post", "new_fangled"], # this does not have a default value - "acr_values_supported": ['mfa'], + "acr_values_supported": ["mfa"], } - pref = self.claims.prefer = supported_to_preferred(supported=self.supported, - preference=self.claims.prefer, - base_url='https://example.com', - info=provider_info_response) + pref = self.claims.prefer = supported_to_preferred( + supported=self.supported, + preference=self.claims.prefer, + base_url="https://example.com", + info=provider_info_response, + ) registration_request = create_registration_request(self.claims.prefer, self.supported) - assert set(registration_request.keys()) == {'application_type', - 'client_name', - 'contacts', - 'default_max_age', - 'id_token_signed_response_alg', - 'jwks', - 'logo_uri', - 'redirect_uris', - 'request_object_signing_alg', - 'response_types', - 'subject_type', - 'token_endpoint_auth_method', - 'token_endpoint_auth_signing_alg', - 'userinfo_signed_response_alg'} - - assert registration_request["subject_type"] == 'public' + assert set(registration_request.keys()) == { + "application_type", + "client_name", + "contacts", + "default_max_age", + "id_token_signed_response_alg", + "jwks", + "logo_uri", + "redirect_uris", + "request_object_signing_alg", + "response_modes", # non-standard + "response_types", + "subject_type", + "token_endpoint_auth_method", + "token_endpoint_auth_signing_alg", + "userinfo_signed_response_alg", + } + + assert registration_request["subject_type"] == "public" registration_response = { "application_type": "web", - "redirect_uris": - ["https://client.example.org/callback", - "https://client.example.org/callback2"], + "redirect_uris": [ + "https://client.example.org/callback", + "https://client.example.org/callback2", + ], "client_name": "My Example", "logo_uri": "https://client.example.org/logo.png", "subject_type": "pairwise", - "sector_identifier_uri": - "https://other.example.net/file_of_redirect_uris.json", + "sector_identifier_uri": "https://other.example.net/file_of_redirect_uris.json", "token_endpoint_auth_method": "client_secret_basic", "jwks_uri": "https://client.example.org/my_public_keys.jwks", "userinfo_encrypted_response_alg": "RSA1_5", "userinfo_encrypted_response_enc": "A128CBC-HS256", "contacts": ["ve7jtb@example.org", "mary@example.org"], "request_uris": [ - "https://client.example.org/rf.txt#qpXaRLh_n93TTR9F252ValdatUQvQiJi5BDub2BeznA"] + "https://client.example.org/rf.txt#qpXaRLh_n93TTR9F252ValdatUQvQiJi5BDub2BeznA" + ], } - to_use = preferred_to_registered(prefers=self.claims.prefer, - supported=self.supported, - registration_response=registration_response) - - assert set(to_use.keys()) == {'application_type', - 'client_id', - 'client_name', - 'client_secret', - 'contacts', - 'default_max_age', - 'encrypt_request_object_supported', - 'encrypt_userinfo_supported', - 'id_token_signed_response_alg', - 'jwks', - 'jwks_uri', - 'logo_uri', - 'redirect_uris', - 'request_object_signing_alg', - 'request_uris', - 'response_modes_supported', - 'response_types', - 'scope', - 'sector_identifier_uri', - 'subject_type', - 'token_endpoint_auth_method', - 'token_endpoint_auth_signing_alg', - 'userinfo_encrypted_response_alg', - 'userinfo_encrypted_response_enc', - 'userinfo_signed_response_alg'} + to_use = preferred_to_registered( + prefers=self.claims.prefer, + supported=self.supported, + registration_response=registration_response, + ) + + assert set(to_use.keys()) == { + "application_type", + "client_id", + "client_name", + "client_secret", + "contacts", + "default_max_age", + "encrypt_request_object_supported", + "encrypt_userinfo_supported", + "id_token_signed_response_alg", + "jwks", + "jwks_uri", + "logo_uri", + "redirect_uris", + "request_object_signing_alg", + "request_uris", + "response_modes", + "response_types", + "scope", + "sector_identifier_uri", + "subject_type", + "token_endpoint_auth_method", + "token_endpoint_auth_signing_alg", + "userinfo_encrypted_response_alg", + "userinfo_encrypted_response_enc", + "userinfo_signed_response_alg", + } # Not what I asked for but something I can handle - assert to_use["subject_type"] == 'pairwise' + assert to_use["subject_type"] == "pairwise" diff --git a/tests/test_12_context.py b/tests/test_12_context.py index 2448a86a..7c0c30a4 100644 --- a/tests/test_12_context.py +++ b/tests/test_12_context.py @@ -1,7 +1,7 @@ from idpyoidc.context import OidcContext -ENTITY_ID = 'https://example.com' +ENTITY_ID = "https://example.com" class TestDumpLoad(object): diff --git a/tests/test_client_00_current.py b/tests/test_client_00_current.py index b701d6dc..414f71f7 100644 --- a/tests/test_client_00_current.py +++ b/tests/test_client_00_current.py @@ -13,11 +13,11 @@ def test_setup(self): def test_create_key_no_key(self): state_key = self.current.create_key() - self.current.set(state_key, {'iss': ISSUER}) - _iss = self.current.get(state_key)['iss'] + self.current.set(state_key, {"iss": ISSUER}) + _iss = self.current.get(state_key)["iss"] assert _iss == ISSUER - _item = self.current.get_set(state_key, claim=['iss']) - assert _item['iss'] == ISSUER + _item = self.current.get_set(state_key, claim=["iss"]) + assert _item["iss"] == ISSUER def test_store_and_retrieve_state_item(self): state_key = self.current.create_key() @@ -51,7 +51,7 @@ def test_other_id(self): assert _state_key == state_key def test_remove(self): - state_key = self.current.create_state(iss='foo') + state_key = self.current.create_state(iss="foo") self.current.bind_key("subject_id", state_key) self.current.bind_key("nonce", state_key) self.current.bind_key("session_id", state_key) diff --git a/tests/test_client_01_service_context.py b/tests/test_client_01_service_context.py index 0143be2f..bba666a3 100644 --- a/tests/test_client_01_service_context.py +++ b/tests/test_client_01_service_context.py @@ -15,26 +15,24 @@ "base_url": "https://example.com/cli", "key_conf": {"key_defs": KEYDEFS}, "issuer": "https://op.example.com", - "preference": { - "response_types": ["code"] - } + "preference": {"response_types": ["code"]}, } class TestServiceContext: - @pytest.fixture(autouse=True) def setup(self): self.unit = Unit() - self.service_context = ServiceContext(config=MINI_CONFIG, upstream_get=self.unit.unit_get, - base_url="https://example.com/cli") + self.service_context = ServiceContext( + config=MINI_CONFIG, upstream_get=self.unit.unit_get, base_url="https://example.com/cli" + ) def test_init(self): assert self.service_context def test_filename_from_webname(self): _filename = self.service_context.filename_from_webname("https://example.com/cli/jwks.json") - assert _filename == 'jwks.json' + assert _filename == "jwks.json" def test_get_sign_alg(self): _alg = self.service_context.get_sign_alg("id_token") @@ -57,8 +55,9 @@ def test_get_enc_alg_enc(self): assert _alg_enc == {"alg": None, "enc": None} self.service_context.claims.set_preference("userinfo_encrypted_response_alg", "RSA1_5") - self.service_context.claims.set_preference("userinfo_encrypted_response_enc", - "A128CBC+HS256") + self.service_context.claims.set_preference( + "userinfo_encrypted_response_enc", "A128CBC+HS256" + ) _alg_enc = self.service_context.get_enc_alg_enc("userinfo") assert _alg_enc == {"alg": "RSA1_5", "enc": "A128CBC+HS256"} diff --git a/tests/test_client_02_entity.py b/tests/test_client_02_entity.py index 2492dc53..f8929b5c 100644 --- a/tests/test_client_02_entity.py +++ b/tests/test_client_02_entity.py @@ -20,7 +20,8 @@ class TestEntity: @pytest.fixture(autouse=True) def setup(self): self.entity = Entity( - config=MINI_CONFIG.copy(), services={"xyz": {"class": "idpyoidc.client.service.Service"}} + config=MINI_CONFIG.copy(), + services={"xyz": {"class": "idpyoidc.client.service.Service"}}, ) def test_1(self): @@ -66,7 +67,7 @@ def test_client_authn_default(): "keys": {"key_defs": KEYSPEC, "read_only": True}, } - entity = Entity(config=config, client_type='oidc') + entity = Entity(config=config, client_type="oidc") assert entity.get_context().client_authn_methods == {} @@ -77,13 +78,15 @@ def test_client_authn_by_names(): "contacts": ["ops@example.org"], "redirect_uris": [f"{RP_BASEURL}/authz_cb"], "keys": {"key_defs": KEYSPEC, "read_only": True}, - "client_authn_methods": ['client_secret_basic', 'client_secret_post'] + "client_authn_methods": ["client_secret_basic", "client_secret_post"], } - entity = Entity(config=config, client_type='oidc') + entity = Entity(config=config, client_type="oidc") - assert set(entity.get_context().client_authn_methods.keys()) == {'client_secret_basic', - 'client_secret_post'} + assert set(entity.get_context().client_authn_methods.keys()) == { + "client_secret_basic", + "client_secret_post", + } class FooBar(ClientAuthnMethod): @@ -101,20 +104,19 @@ def test_client_authn_full(): "redirect_uris": [f"{RP_BASEURL}/authz_cb"], "keys": {"key_defs": KEYSPEC, "read_only": True}, "client_authn_methods": { - 'client_secret_basic': {}, - 'client_secret_post': None, - 'home_brew': { - 'class': FooBar, - 'kwargs': {'one': 'bar'} - } - } + "client_secret_basic": {}, + "client_secret_post": None, + "home_brew": {"class": FooBar, "kwargs": {"one": "bar"}}, + }, } - entity = Entity(config=config, client_type='oidc') + entity = Entity(config=config, client_type="oidc") - assert set(entity.get_context().client_authn_methods.keys()) == {'client_secret_basic', - 'client_secret_post', - 'home_brew'} + assert set(entity.get_context().client_authn_methods.keys()) == { + "client_secret_basic", + "client_secret_post", + "home_brew", + } def test_service_specific(): @@ -123,24 +125,27 @@ def test_service_specific(): "contacts": ["ops@example.org"], "redirect_uris": [f"{RP_BASEURL}/authz_cb"], "keys": {"key_defs": KEYSPEC, "read_only": True}, - "client_authn_methods": ['client_secret_basic', 'client_secret_post'] + "client_authn_methods": ["client_secret_basic", "client_secret_post"], } - entity = Entity(config=config, client_type='oidc', - services={ - "xyz": { - "class": "idpyoidc.client.service.Service", - "kwargs": { - "client_authn_methods": ['private_key_jwt'] - } - } - }) + entity = Entity( + config=config, + client_type="oidc", + services={ + "xyz": { + "class": "idpyoidc.client.service.Service", + "kwargs": {"client_authn_methods": ["private_key_jwt"]}, + } + }, + ) # A specific does not change the general - assert set(entity.get_context().client_authn_methods.keys()) == {'client_secret_basic', - 'client_secret_post'} + assert set(entity.get_context().client_authn_methods.keys()) == { + "client_secret_basic", + "client_secret_post", + } - assert set(entity.get_service('').client_authn_methods.keys()) == {'private_key_jwt'} + assert set(entity.get_service("").client_authn_methods.keys()) == {"private_key_jwt"} def test_service_specific2(): @@ -149,26 +154,28 @@ def test_service_specific2(): "contacts": ["ops@example.org"], "redirect_uris": [f"{RP_BASEURL}/authz_cb"], "keys": {"key_defs": KEYSPEC, "read_only": True}, - "client_authn_methods": ['client_secret_basic', 'client_secret_post'] + "client_authn_methods": ["client_secret_basic", "client_secret_post"], } - entity = Entity(config=config, client_type='oidc', - services={ - "xyz": { - "class": "idpyoidc.client.service.Service", - "kwargs": { - "client_authn_methods": { - 'home_brew': { - 'class': FooBar, - 'kwargs': {'one': 'bar'} - } - } - } - } - }) + entity = Entity( + config=config, + client_type="oidc", + services={ + "xyz": { + "class": "idpyoidc.client.service.Service", + "kwargs": { + "client_authn_methods": { + "home_brew": {"class": FooBar, "kwargs": {"one": "bar"}} + } + }, + } + }, + ) # A specific does not change the general - assert set(entity.get_context().client_authn_methods.keys()) == {'client_secret_basic', - 'client_secret_post'} + assert set(entity.get_context().client_authn_methods.keys()) == { + "client_secret_basic", + "client_secret_post", + } - assert set(entity.get_service('').client_authn_methods.keys()) == {'home_brew'} + assert set(entity.get_service("").client_authn_methods.keys()) == {"home_brew"} diff --git a/tests/test_client_02b_entity_metadata.py b/tests/test_client_02b_entity_metadata.py index fbc40ef8..fd542125 100644 --- a/tests/test_client_02b_entity_metadata.py +++ b/tests/test_client_02b_entity_metadata.py @@ -15,7 +15,7 @@ "application_type": "web", "contacts": "support@example.com", "response_types_supported": ["code"], - 'request_parameter': "request_uri", + "request_parameter": "request_uri", "request_object_signing_alg_values_supported": ["ES256"], "scope": ["openid", "profile", "email", "address", "phone"], "token_endpoint_auth_methods_supported": ["private_key_jwt"], @@ -24,78 +24,69 @@ "post_logout_redirect_uris": ["https://rp.example.com/post"], "backchannel_logout_uri": "https://rp.example.com/back", "backchannel_logout_session_required": True, - "client_authn_methods": ['bearer_header'] + "client_authn_methods": ["bearer_header"], }, - "services": { "discovery": { "class": "idpyoidc.client.oidc.provider_info_discovery.ProviderInfoDiscovery", - "kwargs": {} - }, - "registration": { - "class": "idpyoidc.client.oidc.registration.Registration", - "kwargs": {} + "kwargs": {}, }, + "registration": {"class": "idpyoidc.client.oidc.registration.Registration", "kwargs": {}}, "authorization": { "class": "idpyoidc.client.oidc.authorization.Authorization", - "kwargs": {} - }, - "accesstoken": { - "class": "idpyoidc.client.oidc.access_token.AccessToken", - "kwargs": {} - }, - "userinfo": { - "class": "idpyoidc.client.oidc.userinfo.UserInfo", - "kwargs": {} + "kwargs": {}, }, - "end_session": { - "class": "idpyoidc.client.oidc.end_session.EndSession", - "kwargs": {} - } - } + "accesstoken": {"class": "idpyoidc.client.oidc.access_token.AccessToken", "kwargs": {}}, + "userinfo": {"class": "idpyoidc.client.oidc.userinfo.UserInfo", "kwargs": {}}, + "end_session": {"class": "idpyoidc.client.oidc.end_session.EndSession", "kwargs": {}}, + }, } KEY_CONF = { "private_path": "private/jwks.json", - "key_defs": [{"type": "RSA", "key": "", "use": ["sig"]}, - {"type": "EC", "crv": "P-256", "use": ["sig"]}], - "read_only": False + "key_defs": [ + {"type": "RSA", "key": "", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, + ], + "read_only": False, } def test_create_client(): - client = Entity(config=CLIENT_CONFIG, client_type='oidc') + client = Entity(config=CLIENT_CONFIG, client_type="oidc") _context = client.get_context() _context.map_supported_to_preferred() _pref = _context.prefers() _pref_with_values = [k for k, v in _pref.items() if v] - assert set(_pref_with_values) == {'application_type', - 'backchannel_logout_session_required', - 'backchannel_logout_uri', - 'callback_uris', - 'client_id', - 'client_secret', - 'contacts', - 'default_max_age', - 'grant_types_supported', - 'id_token_signing_alg_values_supported', - 'post_logout_redirect_uris', - 'redirect_uris', - 'request_object_signing_alg_values_supported', - 'request_parameter', - 'response_modes_supported', - 'response_types_supported', - 'scopes_supported', - 'subject_types_supported', - 'token_endpoint_auth_methods_supported', - 'token_endpoint_auth_signing_alg_values_supported', - 'userinfo_signing_alg_values_supported'} + assert set(_pref_with_values) == { + "application_type", + "backchannel_logout_session_required", + "backchannel_logout_uri", + "callback_uris", + "client_id", + "client_secret", + "contacts", + "default_max_age", + "grant_types_supported", + "id_token_signing_alg_values_supported", + "post_logout_redirect_uris", + "redirect_uris", + "request_object_signing_alg_values_supported", + "request_parameter", + "response_modes_supported", + "response_types_supported", + "scopes_supported", + "subject_types_supported", + "token_endpoint_auth_methods_supported", + "token_endpoint_auth_signing_alg_values_supported", + "userinfo_signing_alg_values_supported", + } # What's in service configuration has higher priority then what's just supported. _context = client.get_service_context() - assert _context.get_preference("contacts") == 'support@example.com' + assert _context.get_preference("contacts") == "support@example.com" # - assert _context.get_preference("userinfo_signing_alg_values_supported") == ['ES256'] + assert _context.get_preference("userinfo_signing_alg_values_supported") == ["ES256"] # How to act _context.map_preferred_to_registered() @@ -107,37 +98,36 @@ def test_create_client(): rr = set(RegistrationRequest.c_param.keys()) # The ones that are not defined and will therefore not appear in a registration request d = rr.difference(set(_conf_args)) - assert d == {'client_name', - 'client_uri', - 'default_acr_values', - 'frontchannel_logout_session_required', - 'frontchannel_logout_uri', - 'id_token_encrypted_response_alg', - 'id_token_encrypted_response_enc', - 'initiate_login_uri', - 'logo_uri', - 'jwks', - 'jwks_uri', - 'policy_uri', - 'post_logout_redirect_uri', - 'request_object_encryption_alg', - 'request_object_encryption_enc', - 'request_uris', - 'require_auth_time', - 'sector_identifier_uri', - 'tos_uri', - 'userinfo_encrypted_response_alg', - 'userinfo_encrypted_response_enc'} + assert d == { + "client_name", + "client_uri", + "default_acr_values", + "frontchannel_logout_session_required", + "frontchannel_logout_uri", + "id_token_encrypted_response_alg", + "id_token_encrypted_response_enc", + "initiate_login_uri", + "logo_uri", + "jwks", + "jwks_uri", + "policy_uri", + "post_logout_redirect_uri", + "request_object_encryption_alg", + "request_object_encryption_enc", + "request_uris", + "require_auth_time", + "sector_identifier_uri", + "tos_uri", + "userinfo_encrypted_response_alg", + "userinfo_encrypted_response_enc", + } def test_create_client_key_conf(): client_config = CLIENT_CONFIG.copy() - client_config.update({ - "key_conf": KEY_CONF, - "jwks_uri": "https://example.com/keys/jwks.json" - }) + client_config.update({"key_conf": KEY_CONF, "jwks_uri": "https://example.com/keys/jwks.json"}) - client = Entity(config=client_config, client_type='oidc') + client = Entity(config=client_config, client_type="oidc") assert client.get_service_context().get_preference("jwks_uri") @@ -145,13 +135,13 @@ def test_create_client_keyjar(): _keyjar = init_key_jar(**KEY_CONF) client_config = CLIENT_CONFIG.copy() - client = Entity(config=client_config, keyjar=_keyjar, client_type='oidc') + client = Entity(config=client_config, keyjar=_keyjar, client_type="oidc") _jwks = client.get_service_context().get_preference("jwks") assert _jwks def test_create_client_jwks_uri(): client_config = CLIENT_CONFIG.copy() - client_config['jwks_uri'] = "https://rp.example.com/jwks_uri.json" + client_config["jwks_uri"] = "https://rp.example.com/jwks_uri.json" client = Entity(config=client_config) assert client.get_service_context().get_preference("jwks_uri") diff --git a/tests/test_client_04_service.py b/tests/test_client_04_service.py index d0ded3a6..95e09348 100644 --- a/tests/test_client_04_service.py +++ b/tests/test_client_04_service.py @@ -6,7 +6,6 @@ class Response(object): - def __init__(self, status_code, text, headers=None): self.status_code = status_code self.text = text @@ -22,20 +21,19 @@ def __init__(self, status_code, text, headers=None): "redirect_uris": ["https://example.com/cli/authz_cb"], "preference": {"response_types_supported": ["code"]}, "key_conf": {"key_defs": KEYDEFS}, - "client_id": 'CLIENT', - 'base_url': "https://example.com/cli" + "client_id": "CLIENT", + "base_url": "https://example.com/cli", } class TestService: - @pytest.fixture(autouse=True) def create_service(self): self.entity = Entity( config=CLIENT_CONF.copy(), services={"authz": {"class": "idpyoidc.client.oidc.authorization.Authorization"}}, - client_type='oidc', - jwks_uri='https://example.com/cli/jwks.json' + client_type="oidc", + jwks_uri="https://example.com/cli/jwks.json", ) self.service = self.entity.get_service("authorization") @@ -45,8 +43,8 @@ def create_service(self): def upstream_get(self, *args): if args[0] == "context": return self.service_context - elif args[0] == 'attribute' and args[1] == 'keyjar': - return self.upstream_get('attribute', 'keyjar') + elif args[0] == "attribute" and args[1] == "keyjar": + return self.upstream_get("attribute", "keyjar") def test_1(self): assert self.service @@ -54,30 +52,42 @@ def test_1(self): def test_use(self): use = self.service_context.map_preferred_to_registered() - assert set(use.keys()) == {'application_type', - 'callback_uris', - 'client_id', - 'default_max_age', - 'encrypt_request_object_supported', - 'id_token_signed_response_alg', - 'jwks', - 'redirect_uris', - 'request_object_signing_alg', - 'response_modes_supported', - 'response_types', - 'scope', - 'subject_type'} + assert set(use.keys()) == { + "application_type", + "callback_uris", + "client_id", + "default_max_age", + "encrypt_request_object_supported", + "id_token_signed_response_alg", + "jwks", + "redirect_uris", + "request_object_signing_alg", + "response_modes", + "response_types", + "scope", + "subject_type", + } def test_gather_request_args(self): self.service.conf["request_args"] = {"response_type": "code"} args = self.service.gather_request_args(state="state") - assert args == {"response_type": "code", "state": "state", 'client_id': 'CLIENT', - 'redirect_uri': 'https://example.com/cli/authz_cb', 'scope': ['openid']} + assert args == { + "response_type": "code", + "state": "state", + "client_id": "CLIENT", + "redirect_uri": "https://example.com/cli/authz_cb", + "scope": ["openid"], + } self.service_context.set_usage("client_id", "client") args = self.service.gather_request_args(state="state") - assert args == {"client_id": "client", "response_type": "code", "state": "state", - 'redirect_uri': 'https://example.com/cli/authz_cb', 'scope': ['openid']} + assert args == { + "client_id": "client", + "response_type": "code", + "state": "state", + "redirect_uri": "https://example.com/cli/authz_cb", + "scope": ["openid"], + } self.service_context.set_usage("scope", ["openid", "foo"]) args = self.service.gather_request_args(state="state") @@ -86,7 +96,7 @@ def test_gather_request_args(self): "response_type": "code", "scope": ["openid", "foo"], "state": "state", - 'redirect_uri': 'https://example.com/cli/authz_cb', + "redirect_uri": "https://example.com/cli/authz_cb", } self.service_context.set_usage("redirect_uri", "https://rp.example.com") @@ -115,7 +125,7 @@ def test_parse_response_json(self): self.service_context.issuer = "https://op.example.com/" self.service_context.client_id = "client" - _sign_key = self.service.upstream_get('attribute', 'keyjar').get_signing_key() + _sign_key = self.service.upstream_get("attribute", "keyjar").get_signing_key() resp1 = AuthorizationResponse(code="auth_grant", state="state").to_json() arg = self.service.parse_response(resp1) assert isinstance(arg, AuthorizationResponse) @@ -127,7 +137,7 @@ def test_parse_response_jwt(self): self.service_context.issuer = "https://op.example.com/" self.service_context.client_id = "client" - _sign_key = self.service.upstream_get('attribute', 'keyjar').get_signing_key() + _sign_key = self.service.upstream_get("attribute", "keyjar").get_signing_key() resp1 = AuthorizationResponse(code="auth_grant", state="state").to_jwt( key=_sign_key, algorithm="RS256" ) @@ -141,7 +151,7 @@ def test_parse_response_err(self): self.service_context.issuer = "https://op.example.com/" self.service_context.client_id = "client" - _sign_key = self.service.upstream_get('attribute', 'keyjar').get_signing_key() + _sign_key = self.service.upstream_get("attribute", "keyjar").get_signing_key() resp1 = AuthorizationResponse(code="auth_grant", state="state").to_jwt( key=_sign_key, algorithm="RS256" ) @@ -150,11 +160,11 @@ def test_parse_response_err(self): class TestAuthorization(object): - @pytest.fixture(autouse=True) def create_service(self): self.entity = Entity( - config=CLIENT_CONF.copy(), services={"base": {"class": "idpyoidc.client.service.Service"}} + config=CLIENT_CONF.copy(), + services={"base": {"class": "idpyoidc.client.service.Service"}}, ) self.service = self.entity.get_service("") diff --git a/tests/test_client_06_client_authn.py b/tests/test_client_06_client_authn.py index 69eb264a..04d098c0 100644 --- a/tests/test_client_06_client_authn.py +++ b/tests/test_client_06_client_authn.py @@ -44,13 +44,15 @@ # "redirect_uris": ["https://example.com/cli/authz_cb"], "client_secret": "white boarding pass", "client_id": CLIENT_ID, - "key_conf": {'key_defs': KEYSPEC} + "key_conf": {"key_defs": KEYSPEC}, } KEY_CONF = { - "key_defs": [{"type": "RSA", "key": "", "use": ["sig"]}, - {"type": "EC", "crv": "P-256", "use": ["sig"]}], - "read_only": False + "key_defs": [ + {"type": "RSA", "key": "", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, + ], + "read_only": False, } @@ -65,14 +67,10 @@ def entity(): config=CLIENT_CONF, services={ "base": {"class": "idpyoidc.client.service.Service"}, - "accesstoken": { - "class": "idpyoidc.client.oidc.access_token.AccessToken", - "kwargs": { - } - } + "accesstoken": {"class": "idpyoidc.client.oidc.access_token.AccessToken", "kwargs": {}}, }, keyjar=keyjar, - client_type='oidc' + client_type="oidc", ) # The following two lines is necessary since they replace provider info collection and # client registration. @@ -90,17 +88,17 @@ def test_quote(): ) assert ( - http_args["headers"]["Authorization"] == "Basic " - 'Nzk2ZDhmYWUtYTQyZi00ZTRmLWFiMjUtZDYyMDViNmQ0ZmEyOk1LRU0vQTdQa243SnVVMExBY3h5SFZLdndkY3pzdWdhUFUwQmllTGI0Q2JRQWdRait5cGNhbkZPQ2IwL0ZBNWg=' + http_args["headers"]["Authorization"] == "Basic " + "Nzk2ZDhmYWUtYTQyZi00ZTRmLWFiMjUtZDYyMDViNmQ0ZmEyOk1LRU0vQTdQa243SnVVMExBY3h5SFZLdndkY3pzdWdhUFUwQmllTGI0Q2JRQWdRait5cGNhbkZPQ2IwL0ZBNWg=" ) class TestClientSecretBasic(object): - def test_construct(self, entity): _service = entity.get_service("") request = _service.construct( - request_args={'redirect_uri': "http://example.com", 'state': "ABCDE"}) + request_args={"redirect_uri": "http://example.com", "state": "ABCDE"} + ) csb = ClientSecretBasic() http_args = csb.construct(request, _service) @@ -108,7 +106,7 @@ def test_construct(self, entity): _authz = http_args["headers"]["Authorization"] assert _authz.startswith("Basic ") _token = _authz.split(" ", 1)[1] - assert base64.urlsafe_b64decode(_token) == b'A:white boarding pass' + assert base64.urlsafe_b64decode(_token) == b"A:white boarding pass" def test_does_not_remove_padding(self): request = AccessTokenRequest(code="foo", redirect_uri="http://example.com") @@ -129,7 +127,6 @@ def test_construct_cc(self): class TestBearerHeader(object): - def test_construct(self, entity): request = ResourceRequest(access_token="Sesame") bh = BearerHeader() @@ -141,9 +138,7 @@ def test_construct_with_http_args(self, entity): request = ResourceRequest(access_token="Sesame") bh = BearerHeader() # Any HTTP args should just be passed on - http_args = bh.construct( - request, service=entity.get_service(""), http_args={"foo": "bar"} - ) + http_args = bh.construct(request, service=entity.get_service(""), http_args={"foo": "bar"}) assert _eq(http_args.keys(), ["foo", "headers"]) assert http_args["headers"] == {"Authorization": "Bearer Sesame"} @@ -175,7 +170,7 @@ def test_construct_with_token(self, entity): _service = entity.get_service("") srv_cntx = _service.upstream_get("context") _state = srv_cntx.cstate.create_key() - srv_cntx.cstate.set(_state, {'iss': "Issuer"}) + srv_cntx.cstate.set(_state, {"iss": "Issuer"}) req = AuthorizationRequest( state=_state, response_type="code", redirect_uri="https://example.com", scope=["openid"] ) @@ -201,7 +196,6 @@ def test_construct_with_token(self, entity): class TestBearerBody(object): - def test_construct(self, entity): _token_service = entity.get_service("") request = ResourceRequest(access_token="Sesame") @@ -214,7 +208,7 @@ def test_construct_with_state(self, entity): _auth_service = entity.get_service("accesstoken") _cntx = _auth_service.upstream_get("context") _key = _cntx.cstate.create_key() - _cntx.cstate.set(_key, {'iss': "Issuer"}) + _cntx.cstate.set(_key, {"iss": "Issuer"}) resp = AuthorizationResponse(code="code", state=_key) _cntx.cstate.update(_key, resp) @@ -235,10 +229,10 @@ def test_construct_with_state(self, entity): def test_construct_with_request(self, entity): authz_service = entity.get_service("") - _cntx = authz_service.upstream_get('context') + _cntx = authz_service.upstream_get("context") _key = _cntx.cstate.create_key() - _cntx.cstate.set(_key, {'iss': "Issuer"}) + _cntx.cstate.set(_key, {"iss": "Issuer"}) resp1 = AuthorizationResponse(code="auth_grant", state=_key) response = authz_service.parse_response(resp1.to_urlencoded(), "urlencoded") authz_service.update_service_context(response, key=_key) @@ -258,11 +252,11 @@ def test_construct_with_request(self, entity): class TestClientSecretPost(object): - def test_construct(self, entity): _token_service = entity.get_service("") - request = _token_service.construct(request_args={'redirect_uri': "http://example.com", - 'state': "ABCDE"}) + request = _token_service.construct( + request_args={"redirect_uri": "http://example.com", "state": "ABCDE"} + ) csp = ClientSecretPost() http_args = csp.construct(request, service=_token_service) @@ -278,25 +272,26 @@ def test_construct(self, entity): def test_modify_1(self, entity): token_service = entity.get_service("") - request = token_service.construct(request_args={'redirect_uri': "http://example.com", - 'state': "ABCDE"}) + request = token_service.construct( + request_args={"redirect_uri": "http://example.com", "state": "ABCDE"} + ) csp = ClientSecretPost() http_args = csp.construct(request, service=token_service) assert "client_secret" in request def test_modify_2(self, entity): _service = entity.get_service("") - request = _service.construct(request_args={'redirect_uri': "http://example.com", - 'state': "ABCDE"}) + request = _service.construct( + request_args={"redirect_uri": "http://example.com", "state": "ABCDE"} + ) csp = ClientSecretPost() - _service.upstream_get("context").set_usage('client_secret', "") + _service.upstream_get("context").set_usage("client_secret", "") # this will fail with pytest.raises(AuthnFailure): http_args = csp.construct(request, service=_service) class TestPrivateKeyJWT(object): - def test_construct(self, entity): token_service = entity.get_service("") kb_rsa = KeyBundle( @@ -307,8 +302,8 @@ def test_construct(self, entity): for key in kb_rsa: key.add_kid() - _context = token_service.upstream_get('context') - _keyjar = token_service.upstream_get('attribute', 'keyjar') + _context = token_service.upstream_get("context") + _keyjar = token_service.upstream_get("attribute", "keyjar") _keyjar.add_kb("", kb_rsa) _context.provider_info = { "issuer": "https://example.com/", @@ -343,7 +338,7 @@ def test_construct_client_assertion(self, entity): request = AccessTokenRequest() pkj = PrivateKeyJWT() _ca = assertion_jwt( - token_service.upstream_get('context').get_client_id(), + token_service.upstream_get("context").get_client_id(), kb_rsa.get("RSA"), "https://example.com/token", "RS256", @@ -355,7 +350,6 @@ def test_construct_client_assertion(self, entity): class TestClientSecretJWT_TE(object): - def test_client_secret_jwt(self, entity): _service_context = entity.get_context() _service_context.token_endpoint = "https://example.com/token" @@ -371,22 +365,22 @@ def test_client_secret_jwt(self, entity): csj = ClientSecretJWT() request = AccessTokenRequest() - csj.construct( - request, service=entity.get_service(""), authn_endpoint="token_endpoint" - ) + csj.construct(request, service=entity.get_service(""), authn_endpoint="token_endpoint") assert request["client_assertion_type"] == JWT_BEARER assert "client_assertion" in request cas = request["client_assertion"] _kj = KeyJar() - _kj.add_symmetric(_service_context.get_client_id(), - _service_context.get_usage('client_secret'), ["sig"]) + _kj.add_symmetric( + _service_context.get_client_id(), _service_context.get_usage("client_secret"), ["sig"] + ) jso = JWT(key_jar=_kj, sign_alg="HS256").unpack(cas) assert _eq(jso.keys(), ["aud", "iss", "sub", "exp", "iat", "jti"]) _rj = JWS(alg="HS256") - info = _rj.verify_compact(cas, _kj.get_signing_key( - issuer_id=_service_context.get_client_id())) + info = _rj.verify_compact( + cas, _kj.get_signing_key(issuer_id=_service_context.get_client_id()) + ) assert _eq(info.keys(), ["aud", "iss", "sub", "jti", "exp", "iat"]) assert info["aud"] == [_service_context.provider_info["token_endpoint"]] @@ -406,9 +400,9 @@ def test_get_key_by_kid(self, entity): request = AccessTokenRequest() # get a kid for a symmetric key - kid = '' - for _key in entity.get_attribute('keyjar').get_issuer_keys(""): - if _key.kty == 'oct': + kid = "" + for _key in entity.get_attribute("keyjar").get_issuer_keys(""): + if _key.kty == "oct": kid = _key.kid break @@ -419,7 +413,7 @@ def test_get_key_by_kid(self, entity): def test_get_key_by_kid_fail(self, entity): token_service = entity.get_service("") - _service_context = token_service.upstream_get('context') + _service_context = token_service.upstream_get("context") _service_context.token_endpoint = "https://example.com/token" _service_context.provider_info = { @@ -460,7 +454,7 @@ def test_get_audience_and_algorithm_default_alg(self, entity): # Since I have an RSA key this doesn't fail csj.construct(request, service=token_service, authn_endpoint="token_endpoint") - _rsa_key = entity.keyjar.get(key_use='sig', key_type='rsa', issuer_id='')[0] + _rsa_key = entity.keyjar.get(key_use="sig", key_type="rsa", issuer_id="")[0] _jws = factory(request["client_assertion"]) assert _jws.jwt.headers["alg"] == "RS256" _rsa_key = entity.keyjar.get_signing_key(key_type="RSA")[0] @@ -484,7 +478,7 @@ def test_get_audience_and_algorithm_default_alg(self, entity): ] csj.construct(request, service=token_service, authn_endpoint="token_endpoint") - _ec_key = entity.keyjar.get(key_use='sig', key_type='ec', issuer_id='')[0] + _ec_key = entity.keyjar.get(key_use="sig", key_type="ec", issuer_id="")[0] _jws = factory(request["client_assertion"]) # Should be ES256 since I have a key for ES256 assert _jws.jwt.headers["alg"] == "ES256" @@ -493,11 +487,10 @@ def test_get_audience_and_algorithm_default_alg(self, entity): class TestClientSecretJWT_UI(object): - def test_client_secret_jwt(self, entity): access_token_service = entity.get_service("") - _service_context = access_token_service.upstream_get('context') + _service_context = access_token_service.upstream_get("context") _service_context.token_endpoint = "https://example.com/token" _service_context.provider_info = { "issuer": "https://example.com/", @@ -515,21 +508,24 @@ def test_client_secret_jwt(self, entity): cas = request["client_assertion"] _kj = KeyJar() - _kj.add_symmetric(_service_context.get_client_id(), - _service_context.get_usage('client_secret'), usage=["sig"]) + _kj.add_symmetric( + _service_context.get_client_id(), + _service_context.get_usage("client_secret"), + usage=["sig"], + ) jso = JWT(key_jar=_kj, sign_alg="HS256").unpack(cas) assert _eq(jso.keys(), ["aud", "iss", "sub", "jti", "exp", "iat"]) _rj = JWS(alg="HS256") - info = _rj.verify_compact(cas, _kj.get_signing_key( - issuer_id=_service_context.get_client_id())) + info = _rj.verify_compact( + cas, _kj.get_signing_key(issuer_id=_service_context.get_client_id()) + ) assert _eq(info.keys(), ["aud", "iss", "sub", "jti", "exp", "iat"]) assert info["aud"] == [_service_context.provider_info["issuer"]] class TestValidClientInfo(object): - def test_valid_service_context(self, entity): _service_context = entity.get_context() diff --git a/tests/test_client_10_entity.py b/tests/test_client_10_entity.py index 3a3f3a7f..6a1fd603 100644 --- a/tests/test_client_10_entity.py +++ b/tests/test_client_10_entity.py @@ -42,7 +42,7 @@ def test_import_keys_url(self): rsps.add( "GET", _jwks_url, - body=self.entity.get_attribute('keyjar').export_jwks_as_json(), + body=self.entity.get_attribute("keyjar").export_jwks_as_json(), status=200, adding_headers={"Content-Type": "application/json"}, ) @@ -50,8 +50,9 @@ def test_import_keys_url(self): self.entity.import_keys(keyspec) # Now there should be one belonging to https://example.com - assert len(self.entity.get_attribute('keyjar').get_issuer_keys( - "https://foobar.com")) == 1 + assert ( + len(self.entity.get_attribute("keyjar").get_issuer_keys("https://foobar.com")) == 1 + ) def test_import_keys_file_json(self): # Should only be one and that a symmetric key (client_secret) usable diff --git a/tests/test_client_12_client_auth.py b/tests/test_client_12_client_auth.py index 9149ffd0..bb81133e 100755 --- a/tests/test_client_12_client_auth.py +++ b/tests/test_client_12_client_auth.py @@ -47,7 +47,7 @@ def _eq(l1, l2): @pytest.fixture def entity(): - entity = Entity(config=CLIENT_CONF, client_type='oidc') + entity = Entity(config=CLIENT_CONF, client_type="oidc") # The following two lines is necessary since they replace provider info collection and # client registration. entity.get_service_context().map_supported_to_preferred() @@ -64,17 +64,19 @@ def test_quote(): ) assert ( - http_args["headers"]["Authorization"] == "Basic " - 'Nzk2ZDhmYWUtYTQyZi00ZTRmLWFiMjUtZDYyMDViNmQ0ZmEyOk1LRU0vQTdQa243SnVVMExBY3h5SFZLdndkY3pzdWdhUFUwQmllTGI0Q2JRQWdRait5cGNhbkZPQ2IwL0ZBNWg=' + http_args["headers"]["Authorization"] == "Basic " + "Nzk2ZDhmYWUtYTQyZi00ZTRmLWFiMjUtZDYyMDViNmQ0ZmEyOk1LRU0vQTdQa243SnVVMExBY3h5SFZLdndkY3pzdWdhUFUwQmllTGI0Q2JRQWdRait5cGNhbkZPQ2IwL0ZBNWg=" ) class TestClientSecretBasic(object): - def test_construct(self, entity): + entity.context.cstate.update("ABCDE", {'code': 'abcdefghijklmnopqrst'}) + _token_service = entity.get_service("accesstoken") - request = _token_service.construct(request_args={'redirect_uri': "http://example.com", - 'state': "ABCDE"}) + request = _token_service.construct( + request_args={"redirect_uri": "http://example.com", "state": "ABCDE"} + ) csb = ClientSecretBasic() http_args = csb.construct(request, _token_service) @@ -108,7 +110,6 @@ def test_construct_cc(self): class TestBearerHeader(object): - def test_construct(self, entity): request = ResourceRequest(access_token="Sesame") bh = BearerHeader() @@ -180,7 +181,6 @@ def test_construct_with_token(self, entity): class TestBearerBody(object): - def test_construct(self, entity): _token_service = entity.get_service("accesstoken") request = ResourceRequest(access_token="Sesame") @@ -235,8 +235,9 @@ def test_construct_with_request(self, entity): class TestClientSecretPost(object): - def test_construct(self, entity): + entity.context.cstate.update("ABCDE", {'code': 'abcdefghijklmnopqrst'}) + _token_service = entity.get_service("accesstoken") request = _token_service.construct(redirect_uri="http://example.com", state="ABCDE") csp = ClientSecretPost() @@ -253,6 +254,8 @@ def test_construct(self, entity): assert http_args is None def test_modify_1(self, entity): + entity.context.cstate.update("ABCDE", {'code': 'abcdefghijklmnopqrst'}) + token_service = entity.get_service("accesstoken") request = token_service.construct(redirect_uri="http://example.com", state="ABCDE") csp = ClientSecretPost() @@ -262,19 +265,20 @@ def test_modify_1(self, entity): assert "client_secret" in request def test_modify_2(self, entity): + entity.context.cstate.update("ABCDE", {'code': 'abcdefghijklmnopqrst'}) + token_service = entity.get_service("accesstoken") request = token_service.construct(redirect_uri="http://example.com", state="ABCDE") csp = ClientSecretPost() # client secret not in request or kwargs del request["client_secret"] - token_service.upstream_get("context").set_usage('client_secret', "") + token_service.upstream_get("context").set_usage("client_secret", "") # this will fail with pytest.raises(AuthnFailure): csp.construct(request, service=token_service) class TestPrivateKeyJWT(object): - def test_construct(self, entity): token_service = entity.get_service("accesstoken") kb_rsa = KeyBundle( @@ -332,7 +336,6 @@ def test_construct_client_assertion(self, entity): class TestClientSecretJWT_TE(object): - def test_client_secret_jwt(self, entity): _service_context = entity.get_context() _service_context.token_endpoint = "https://example.com/token" @@ -357,14 +360,16 @@ def test_client_secret_jwt(self, entity): cas = request["client_assertion"] _kj = KeyJar() - _kj.add_symmetric(_service_context.get_client_id(), - _service_context.get_usage('client_secret'), ["sig"]) + _kj.add_symmetric( + _service_context.get_client_id(), _service_context.get_usage("client_secret"), ["sig"] + ) jso = JWT(key_jar=_kj, sign_alg="HS256").unpack(cas) assert _eq(jso.keys(), ["aud", "iss", "sub", "exp", "iat", "jti"]) _rj = JWS(alg="HS256") info = _rj.verify_compact( - cas, _kj.get_signing_key(issuer_id=_service_context.get_client_id())) + cas, _kj.get_signing_key(issuer_id=_service_context.get_client_id()) + ) assert _eq(info.keys(), ["aud", "iss", "sub", "jti", "exp", "iat"]) assert info["aud"] == [_service_context.provider_info["token_endpoint"]] @@ -466,7 +471,6 @@ def test_get_audience_and_algorithm_default_alg(self, entity): class TestClientSecretJWT_UI(object): - def test_client_secret_jwt(self, entity): access_token_service = entity.get_service("accesstoken") @@ -488,21 +492,24 @@ def test_client_secret_jwt(self, entity): cas = request["client_assertion"] _kj = KeyJar() - _kj.add_symmetric(_service_context.get_client_id(), - _service_context.get_usage('client_secret'), usage=["sig"]) + _kj.add_symmetric( + _service_context.get_client_id(), + _service_context.get_usage("client_secret"), + usage=["sig"], + ) jso = JWT(key_jar=_kj, sign_alg="HS256").unpack(cas) assert _eq(jso.keys(), ["aud", "iss", "sub", "jti", "exp", "iat"]) _rj = JWS(alg="HS256") info = _rj.verify_compact( - cas, _kj.get_signing_key(issuer_id=_service_context.get_client_id())) + cas, _kj.get_signing_key(issuer_id=_service_context.get_client_id()) + ) assert _eq(info.keys(), ["aud", "iss", "sub", "jti", "exp", "iat"]) assert info["aud"] == [_service_context.provider_info["issuer"]] class TestValidClientInfo(object): - def test_valid_service_context(self, entity): _service_context = entity.get_context() diff --git a/tests/test_client_14_service_context_impexp.py b/tests/test_client_14_service_context_impexp.py index f0ec76e3..1a994cd9 100644 --- a/tests/test_client_14_service_context_impexp.py +++ b/tests/test_client_14_service_context_impexp.py @@ -19,7 +19,7 @@ def test_client_info_init(): "base_url": BASE_URL, "requests_dir": "requests", } - ci = ServiceContext(config=config, client_type='oidc', base_url=BASE_URL) + ci = ServiceContext(config=config, client_type="oidc", base_url=BASE_URL) ci.claims.load_conf(config, supports=ci.supports()) ci.map_supported_to_preferred() ci.map_preferred_to_registered() @@ -40,11 +40,11 @@ def test_client_info_init(): def test_set_and_get_client_secret(): service_context = ServiceContext(base_url=BASE_URL) - service_context.set_usage('client_secret', "longenoughsupersecret") + service_context.set_usage("client_secret", "longenoughsupersecret") srvcnx2 = ServiceContext(base_url=BASE_URL).load(service_context.dump()) - assert srvcnx2.get_usage('client_secret') == "longenoughsupersecret" + assert srvcnx2.get_usage("client_secret") == "longenoughsupersecret" def test_set_and_get_client_id(): @@ -97,7 +97,6 @@ def verify_alg_support(service_context, alg, usage, typ): class TestClientInfo(object): - @pytest.fixture(autouse=True) def create_client_info_instance(self): config = { @@ -276,17 +275,18 @@ def test_import_keys_file_json(self): keyspec = {"file": {"rsa": [file_path]}} self.service_context.import_keys(keyspec) - _sc_state = self.service_context.dump(exclude_attributes=["context", 'upstream_get']) + _sc_state = self.service_context.dump(exclude_attributes=["context", "upstream_get"]) _jsc_state = json.dumps(_sc_state) _o_state = json.loads(_jsc_state) - srvcntx = ServiceContext(base_url=BASE_URL).load(_o_state, init_args={ - 'upstream_get': self.service_context.upstream_get}) + srvcntx = ServiceContext(base_url=BASE_URL).load( + _o_state, init_args={"upstream_get": self.service_context.upstream_get} + ) # Now there should be 2, the second a RSA key for signing - assert len(srvcntx.upstream_get('attribute', 'keyjar').get_issuer_keys("")) == 2 + assert len(srvcntx.upstream_get("attribute", "keyjar").get_issuer_keys("")) == 2 def test_import_keys_url(self): - _keyjar = self.service_context.upstream_get('attribute', 'keyjar') + _keyjar = self.service_context.upstream_get("attribute", "keyjar") assert len(_keyjar.get_issuer_keys("")) == 1 # One EC key for signing @@ -309,9 +309,15 @@ def test_import_keys_url(self): srvcntx = ServiceContext(base_url=BASE_URL).load( self.service_context.dump(exclude_attributes=["context"]), - init_args={'upstream_get': self.service_context.upstream_get} + init_args={"upstream_get": self.service_context.upstream_get}, ) # Now there should be one belonging to https://example.com - assert len(srvcntx.upstream_get('attribute', 'keyjar').get_issuer_keys( - "https://foobar.com")) == 1 + assert ( + len( + srvcntx.upstream_get("attribute", "keyjar").get_issuer_keys( + "https://foobar.com" + ) + ) + == 1 + ) diff --git a/tests/test_client_20_oauth2.py b/tests/test_client_20_oauth2.py index 81defa1b..cb227a68 100644 --- a/tests/test_client_20_oauth2.py +++ b/tests/test_client_20_oauth2.py @@ -65,7 +65,7 @@ def test_construct_authorization_request(self): "response_type": ["code"], } - self.client.get_context().cstate.set("ABCDE", {"iss": 'issuer'}) + self.client.get_context().cstate.set("ABCDE", {"iss": "issuer"}) msg = self.client.get_service("authorization").construct(request_args=req_args) assert isinstance(msg, AuthorizationRequest) assert msg["client_id"] == "client_1" @@ -87,9 +87,7 @@ def test_construct_accesstoken_request(self): self.client.get_context().cstate.update("ABCDE", auth_response) - msg = self.client.get_service("accesstoken").construct( - request_args=req_args, state="ABCDE" - ) + msg = self.client.get_service("accesstoken").construct(request_args=req_args, state="ABCDE") assert isinstance(msg, AccessTokenRequest) assert msg.to_dict() == { @@ -104,7 +102,7 @@ def test_construct_accesstoken_request(self): def test_construct_refresh_token_request(self): _context = self.client.get_context() _state = "ABCDE" - _context.cstate.set(_state, {'iss': "issuer"}) + _context.cstate.set(_state, {"iss": "issuer"}) auth_request = AuthorizationRequest( redirect_uri="https://example.com/cli/authz_cb", state="state" @@ -146,9 +144,7 @@ def test_error_response_500(self): err = ResponseMessage(error="Illegal") http_resp = MockResponse(500, err.to_urlencoded()) with pytest.raises(ParseError): - self.client.parse_request_response( - self.client.get_service("authorization"), http_resp - ) + self.client.parse_request_response(self.client.get_service("authorization"), http_resp) def test_error_response_2(self): err = ResponseMessage(error="Illegal") @@ -157,9 +153,7 @@ def test_error_response_2(self): ) with pytest.raises(OidcServiceError): - self.client.parse_request_response( - self.client.get_service("authorization"), http_resp - ) + self.client.parse_request_response(self.client.get_service("authorization"), http_resp) BASE_URL = "https://example.com" @@ -196,7 +190,7 @@ def create_client(self): assert self.client def test_keyjar(self): - _keyjar = self.client.get_attribute('keyjar') + _keyjar = self.client.get_attribute("keyjar") assert len(_keyjar) == 2 # one issuer assert len(_keyjar[""]) == 3 assert len(_keyjar.get("sig")) == 3 diff --git a/tests/test_client_21_oidc_service.py b/tests/test_client_21_oidc_service.py index fb3ac1b2..18eb38c5 100644 --- a/tests/test_client_21_oidc_service.py +++ b/tests/test_client_21_oidc_service.py @@ -29,7 +29,6 @@ class Response(object): - def __init__(self, status_code, text, headers=None): self.status_code = status_code self.text = text @@ -70,7 +69,6 @@ def make_keyjar(): class TestAuthorization(object): - @pytest.fixture(autouse=True) def create_request(self): client_config = { @@ -78,15 +76,19 @@ def create_request(self): "client_secret": "a longesh password", "callback_uris": { "redirect_uris": { # different flows - "code": ["https://example.com/cli/authz_cb"], - "implicit": ["https://example.com/cli/imp_cb"], - "form_post": ["https://example.com/cli/form"] + "query": ["https://example.com/cli/authz_cb"], + "fragment": ["https://example.com/cli/imp_cb"], + "form_post": ["https://example.com/cli/form"], } }, - "response_types_supported": ['code', 'token'] + "response_types_supported": ["code", "token"], } - entity = Entity(services=DEFAULT_OIDC_SERVICES, keyjar=make_keyjar(), config=client_config, - client_type='oidc') + entity = Entity( + services=DEFAULT_OIDC_SERVICES, + keyjar=make_keyjar(), + config=client_config, + client_type="oidc", + ) _context = entity.get_context() _context.issuer = "https://example.com" _context.map_supported_to_preferred() @@ -184,7 +186,7 @@ def test_request_init(self): def test_request_init_request_method(self): req_args = {"response_type": "code", "state": "state"} self.service.endpoint = "https://example.com/authorize" - self.context.set_usage('request_object_encryption_alg', None) + self.context.set_usage("request_object_encryption_alg", None) _info = self.service.get_request_parameters(request_args=req_args, request_method="value") assert set(_info.keys()) == {"url", "method", "request"} msg = AuthorizationRequest().from_urlencoded(self.service.get_urlinfo(_info["url"])) @@ -297,9 +299,9 @@ def test_allow_unsigned_idtoken(self, allow_sign_alg_none): idt = JWT(ISS_KEY, iss=ISS, lifetime=3600, sign_alg="none") payload = {"sub": "123456789", "aud": ["client_id"], "nonce": req_args["nonce"]} _idt = idt.pack(payload) - self.service.upstream_get("context").claims.set_usage("verify_args", { - "allow_sign_alg_none": allow_sign_alg_none - }) + self.service.upstream_get("context").claims.set_usage( + "verify_args", {"allow_sign_alg_none": allow_sign_alg_none} + ) resp = AuthorizationResponse(state="state", code="code", id_token=_idt) if allow_sign_alg_none: self.service.parse_response(resp.to_urlencoded()) @@ -309,7 +311,6 @@ def test_allow_unsigned_idtoken(self, allow_sign_alg_none): class TestAuthorizationCallback(object): - @pytest.fixture(autouse=True) def create_request(self): client_config = { @@ -317,14 +318,18 @@ def create_request(self): "client_secret": "a longesh password", "callback_uris": { "redirect_uris": { - "code": ["https://example.com/cli/authz_cb"], - "implicit": ["https://example.com/cli/authz_im_cb"], - "form_post": ["https://example.com/cli/authz_fp_cb"] + "query": ["https://example.com/cli/authz_cb"], + "fragment": ["https://example.com/cli/authz_im_cb"], + "form_post": ["https://example.com/cli/authz_fp_cb"], }, }, } - entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES, - client_type='oidc') + entity = Entity( + keyjar=make_keyjar(), + config=client_config, + services=DEFAULT_OIDC_SERVICES, + client_type="oidc", + ) _context = entity.get_context() _context.issuer = "https://example.com" _context.map_supported_to_preferred() @@ -391,19 +396,18 @@ def test_construct_form_post(self): class TestAccessTokenRequest(object): - @pytest.fixture(autouse=True) def create_request(self): client_config = { "client_id": "client_id", "client_secret": "a longesh password", "redirect_uris": ["https://example.com/cli/authz_cb"], - 'client_authn_methods': ['client_secret_basic'] + "client_authn_methods": ["client_secret_basic"], } entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES) _context = entity.get_context() _context.issuer = "https://example.com" - _context.provider_info = {'token_endpoint': f'{_context.issuer}/token'} + _context.provider_info = {"token_endpoint": f"{_context.issuer}/token"} self.service = entity.get_service("accesstoken") # add some history @@ -457,7 +461,7 @@ def test_request_init(self): assert set(_info.keys()) == {"body", "url", "headers", "method", "request"} assert _info["url"] == "https://example.com/authorize" msg = AccessTokenRequest().from_urlencoded(self.service.get_urlinfo(_info["body"])) - assert set(msg.keys()) == {'redirect_uri', 'grant_type', 'state', 'code', 'client_id'} + assert set(msg.keys()) == {"redirect_uri", "grant_type", "state", "code", "client_id"} def test_id_token_nonce_match(self): _cstate = self.service.upstream_get("context").cstate @@ -470,7 +474,6 @@ def test_id_token_nonce_match(self): class TestProviderInfo(object): - @pytest.fixture(autouse=True) def create_service(self): self._iss = ISS @@ -492,7 +495,7 @@ def create_service(self): "userinfo_signing_alg_values_supported": ["ES256"], "post_logout_redirect_uris": ["https://rp.example.com/post"], "backchannel_logout_uri": "https://rp.example.com/back", - "backchannel_logout_session_required": True + "backchannel_logout_session_required": True, }, "services": { "web_finger": {"class": "idpyoidc.client.oidc.webfinger.WebFinger"}, @@ -501,27 +504,24 @@ def create_service(self): }, "registration": { "class": "idpyoidc.client.oidc.registration.Registration", - "kwargs": {} + "kwargs": {}, }, "authorization": { "class": "idpyoidc.client.oidc.authorization.Authorization", - "kwargs": {} + "kwargs": {}, }, "accesstoken": { "class": "idpyoidc.client.oidc.access_token.AccessToken", - "kwargs": {} - }, - "userinfo": { - "class": "idpyoidc.client.oidc.userinfo.UserInfo", - "kwargs": {} + "kwargs": {}, }, + "userinfo": {"class": "idpyoidc.client.oidc.userinfo.UserInfo", "kwargs": {}}, "end_session": { "class": "idpyoidc.client.oidc.end_session.EndSession", - "kwargs": {} - } - } + "kwargs": {}, + }, + }, } - entity = Entity(keyjar=make_keyjar(), config=client_config, client_type='oidc') + entity = Entity(keyjar=make_keyjar(), config=client_config, client_type="oidc") entity.get_context().issuer = "https://example.com" self.service = entity.get_service("provider_info") @@ -732,39 +732,41 @@ def test_post_parse(self): with responses.RequestsMock() as rsps: rsps.add("GET", resp["jwks_uri"], body=iss_jwks, status=200) - self.service.update_service_context(resp, '') + self.service.update_service_context(resp, "") # static client registration _context.map_preferred_to_registered() use_copy = self.service.upstream_get("context").claims.use.copy() # jwks content will change dynamically between runs - assert 'jwks' in use_copy - del use_copy['jwks'] - del use_copy['callback_uris'] - - assert use_copy == {'application_type': 'web', - 'backchannel_logout_session_required': True, - 'backchannel_logout_uri': 'https://rp.example.com/back', - 'client_id': 'client_id', - 'client_secret': 'a longesh password', - 'contacts': ['ops@example.org'], - 'default_max_age': 86400, - 'encrypt_id_token_supported': False, - 'encrypt_request_object_supported': False, - 'encrypt_userinfo_supported': False, - 'grant_types': ['authorization_code'], - 'id_token_signed_response_alg': 'RS256', - 'post_logout_redirect_uris': ['https://rp.example.com/post'], - 'redirect_uris': ['https://example.com/cli/authz_cb'], - 'request_object_signing_alg': 'ES256', - 'response_modes_supported': ['query', 'fragment', 'form_post'], - 'response_types': ['code'], - 'scope': ['openid'], - 'subject_type': 'public', - 'token_endpoint_auth_method': 'private_key_jwt', - 'token_endpoint_auth_signing_alg': 'ES256', - 'userinfo_signed_response_alg': 'ES256'} + assert "jwks" in use_copy + del use_copy["jwks"] + del use_copy["callback_uris"] + + assert use_copy == { + "application_type": "web", + "backchannel_logout_session_required": True, + "backchannel_logout_uri": "https://rp.example.com/back", + "client_id": "client_id", + "client_secret": "a longesh password", + "contacts": ["ops@example.org"], + "default_max_age": 86400, + "encrypt_id_token_supported": False, + "encrypt_request_object_supported": False, + "encrypt_userinfo_supported": False, + "grant_types": ["authorization_code"], + "id_token_signed_response_alg": "RS256", + "post_logout_redirect_uris": ["https://rp.example.com/post"], + "redirect_uris": ["https://example.com/cli/authz_cb"], + "request_object_signing_alg": "ES256", + "response_modes": ["query", "fragment", "form_post"], + "response_types": ["code"], + "scope": ["openid"], + "subject_type": "public", + "token_endpoint_auth_method": "private_key_jwt", + "token_endpoint_auth_signing_alg": "ES256", + "userinfo_signed_response_alg": "ES256", + } def test_post_parse_2(self): OP_BASEURL = ISS @@ -793,39 +795,41 @@ def test_post_parse_2(self): with responses.RequestsMock() as rsps: rsps.add("GET", resp["jwks_uri"], body=iss_jwks, status=200) - self.service.update_service_context(resp, '') + self.service.update_service_context(resp, "") # static client registration _context.map_preferred_to_registered() use_copy = self.service.upstream_get("context").claims.use.copy() # jwks content will change dynamically between runs - assert 'jwks' in use_copy - del use_copy['jwks'] - del use_copy['callback_uris'] - - assert use_copy == {'application_type': 'web', - 'backchannel_logout_session_required': True, - 'backchannel_logout_uri': 'https://rp.example.com/back', - 'client_id': 'client_id', - 'client_secret': 'a longesh password', - 'contacts': ['ops@example.org'], - 'default_max_age': 86400, - 'encrypt_id_token_supported': False, - 'encrypt_request_object_supported': False, - 'encrypt_userinfo_supported': False, - 'grant_types': ['authorization_code'], - 'id_token_signed_response_alg': 'RS256', - 'post_logout_redirect_uris': ['https://rp.example.com/post'], - 'redirect_uris': ['https://example.com/cli/authz_cb'], - 'request_object_signing_alg': 'ES256', - 'response_modes_supported': ['query', 'fragment', 'form_post'], - 'response_types': ['code'], - 'scope': ['openid'], - 'subject_type': 'public', - 'token_endpoint_auth_method': 'private_key_jwt', - 'token_endpoint_auth_signing_alg': 'ES256', - 'userinfo_signed_response_alg': 'ES256'} + assert "jwks" in use_copy + del use_copy["jwks"] + del use_copy["callback_uris"] + + assert use_copy == { + "application_type": "web", + "backchannel_logout_session_required": True, + "backchannel_logout_uri": "https://rp.example.com/back", + "client_id": "client_id", + "client_secret": "a longesh password", + "contacts": ["ops@example.org"], + "default_max_age": 86400, + "encrypt_id_token_supported": False, + "encrypt_request_object_supported": False, + "encrypt_userinfo_supported": False, + "grant_types": ["authorization_code"], + "id_token_signed_response_alg": "RS256", + "post_logout_redirect_uris": ["https://rp.example.com/post"], + "redirect_uris": ["https://example.com/cli/authz_cb"], + "request_object_signing_alg": "ES256", + "response_modes": ["query", "fragment", "form_post"], + "response_types": ["code"], + "scope": ["openid"], + "subject_type": "public", + "token_endpoint_auth_method": "private_key_jwt", + "token_endpoint_auth_signing_alg": "ES256", + "userinfo_signed_response_alg": "ES256", + } def test_response_types_to_grant_types(): @@ -848,7 +852,6 @@ def create_jws(val): class TestRegistration(object): - @pytest.fixture(autouse=True) def create_request(self): self._iss = ISS @@ -858,10 +861,12 @@ def create_request(self): "requests_dir": "requests", "base_url": "https://example.com/cli/", } - entity = Entity(keyjar=make_keyjar(), - config=client_config, - services=DEFAULT_OIDC_SERVICES, - client_type='oidc') + entity = Entity( + keyjar=make_keyjar(), + config=client_config, + services=DEFAULT_OIDC_SERVICES, + client_type="oidc", + ) entity.get_context().issuer = "https://example.com" entity.get_context().map_supported_to_preferred() self.service = entity.get_service("registration") @@ -869,38 +874,45 @@ def create_request(self): def test_construct(self): _req = self.service.construct() assert isinstance(_req, RegistrationRequest) - assert set(_req.keys()) == {'application_type', - 'default_max_age', - 'grant_types', - 'id_token_signed_response_alg', - 'jwks', - 'redirect_uris', - 'request_object_signing_alg', - 'response_types', - 'subject_type', - 'token_endpoint_auth_method', - 'token_endpoint_auth_signing_alg', - 'userinfo_signed_response_alg'} + assert set(_req.keys()) == { + "application_type", + "default_max_age", + "grant_types", + "id_token_signed_response_alg", + "jwks", + "redirect_uris", + "request_object_signing_alg", + "response_modes", + "response_types", + "subject_type", + "token_endpoint_auth_method", + "token_endpoint_auth_signing_alg", + "userinfo_signed_response_alg", + } def test_config_with_post_logout(self): self.service.upstream_get("context").claims.set_preference( - "post_logout_redirect_uri", "https://example.com/post_logout") + "post_logout_redirect_uri", "https://example.com/post_logout" + ) _req = self.service.construct() assert isinstance(_req, RegistrationRequest) - assert set(_req.keys()) == {'application_type', - 'default_max_age', - 'grant_types', - 'id_token_signed_response_alg', - 'jwks', - 'post_logout_redirect_uri', - 'redirect_uris', - 'request_object_signing_alg', - 'response_types', - 'subject_type', - 'token_endpoint_auth_method', - 'token_endpoint_auth_signing_alg', - 'userinfo_signed_response_alg'} + assert set(_req.keys()) == { + "application_type", + "default_max_age", + "grant_types", + "id_token_signed_response_alg", + "jwks", + "post_logout_redirect_uri", + "redirect_uris", + "request_object_signing_alg", + "response_modes", + "response_types", + "subject_type", + "token_endpoint_auth_method", + "token_endpoint_auth_signing_alg", + "userinfo_signed_response_alg", + } assert "post_logout_redirect_uri" in _req.keys() @@ -911,12 +923,16 @@ def test_config_with_required_request_uri(): "redirect_uris": ["https://example.com/cli/authz_cb"], "issuer": ISS, "requests_dir": "requests", - "request_parameter": 'request_uri', - 'request_uris': ["https://example.com/cli/requests"], + "request_parameter": "request_uri", + "request_uris": ["https://example.com/cli/requests"], "base_url": "https://example.com/cli", } - entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES, - client_type='oidc') + entity = Entity( + keyjar=make_keyjar(), + config=client_config, + services=DEFAULT_OIDC_SERVICES, + client_type="oidc", + ) entity.get_context().issuer = "https://example.com" pi_service = entity.get_service("provider_info") @@ -925,11 +941,22 @@ def test_config_with_required_request_uri(): reg_service = entity.get_service("registration") _req = reg_service.construct() assert isinstance(_req, RegistrationRequest) - assert set(_req.keys()) == {"application_type", "response_types", "jwks", - "redirect_uris", "grant_types", "id_token_signed_response_alg", - "request_uris", 'default_max_age', 'request_object_signing_alg', - 'subject_type', 'token_endpoint_auth_method', - 'token_endpoint_auth_signing_alg', 'userinfo_signed_response_alg'} + assert set(_req.keys()) == { + "application_type", + "response_modes", + "response_types", + "jwks", + "redirect_uris", + "grant_types", + "id_token_signed_response_alg", + "request_uris", + "default_max_age", + "request_object_signing_alg", + "subject_type", + "token_endpoint_auth_method", + "token_endpoint_auth_signing_alg", + "userinfo_signed_response_alg", + } def test_config_logout_uri(): @@ -950,11 +977,15 @@ def test_config_logout_uri(): "post_logout_redirect_uri": "https://rp.example.com/post", "backchannel_logout_uri": "https://rp.example.com/back", "backchannel_logout_session_required": True, - 'backchannel_logout_supported': True - } + "backchannel_logout_supported": True, + }, } - entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES, - client_type='oidc') + entity = Entity( + keyjar=make_keyjar(), + config=client_config, + services=DEFAULT_OIDC_SERVICES, + client_type="oidc", + ) _context = entity.get_context() _context.issuer = "https://example.com" @@ -965,23 +996,25 @@ def test_config_logout_uri(): reg_service = entity.get_service("registration") _req = reg_service.construct() assert isinstance(_req, RegistrationRequest) - assert set(_req.keys()) == {'application_type', - 'default_max_age', - 'grant_types', - 'id_token_signed_response_alg', - 'jwks', - 'redirect_uris', - 'request_object_signing_alg', - 'request_uris', - 'response_types', - 'subject_type', - 'token_endpoint_auth_method', - 'token_endpoint_auth_signing_alg', - 'userinfo_signed_response_alg'} + assert set(_req.keys()) == { + "application_type", + "default_max_age", + "grant_types", + "id_token_signed_response_alg", + "jwks", + "redirect_uris", + "request_object_signing_alg", + "request_uris", + "response_modes", + "response_types", + "subject_type", + "token_endpoint_auth_method", + "token_endpoint_auth_signing_alg", + "userinfo_signed_response_alg", + } class TestUserInfo(object): - @pytest.fixture(autouse=True) def create_request(self): self._iss = ISS @@ -993,8 +1026,12 @@ def create_request(self): "requests_dir": "requests", "base_url": "https://example.com/cli/", } - entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES, - client_type='oidc') + entity = Entity( + keyjar=make_keyjar(), + config=client_config, + services=DEFAULT_OIDC_SERVICES, + client_type="oidc", + ) entity.get_context().issuer = "https://example.com" self.service = entity.get_service("userinfo") @@ -1106,7 +1143,7 @@ def test_unpack_encrypted_response(self): # Add encryption key _kj = build_keyjar([{"type": "RSA", "use": ["enc"]}], issuer_id="") # Own key jar gets the private key - self.service.upstream_get("attribute", 'keyjar').import_jwks( + self.service.upstream_get("attribute", "keyjar").import_jwks( _kj.export_jwks(private=True), issuer_id="client_id" ) # opponent gets the public key @@ -1116,9 +1153,7 @@ def test_unpack_encrypted_response(self): sub="diana", given_name="Diana", family_name="krall", iss=ISS, aud="client_id" ) enckey = ISS_KEY.get_encrypt_key("rsa", issuer_id="client_id") - algspec = self.service.upstream_get("context").get_enc_alg_enc( - self.service.service_name - ) + algspec = self.service.upstream_get("context").get_enc_alg_enc(self.service.service_name) enc_resp = resp.to_jwe(enckey, **algspec) _resp = self.service.parse_response(enc_resp, state="abcde", sformat="jwe") @@ -1126,7 +1161,6 @@ def test_unpack_encrypted_response(self): class TestCheckSession(object): - @pytest.fixture(autouse=True) def create_request(self): self._iss = ISS @@ -1154,7 +1188,6 @@ def test_construct(self): class TestCheckID(object): - @pytest.fixture(autouse=True) def create_request(self): self._iss = ISS @@ -1182,7 +1215,6 @@ def test_construct(self): class TestEndSession(object): - @pytest.fixture(autouse=True) def create_request(self): self._iss = ISS @@ -1193,7 +1225,7 @@ def create_request(self): "issuer": self._iss, "requests_dir": "requests", "base_url": "https://example.com/cli/", - "post_logout_redirect_uris": ["https://example.com/post_logout"] + "post_logout_redirect_uris": ["https://example.com/post_logout"], } services = {"checksession": {"class": "idpyoidc.client.oidc.end_session.EndSession"}} entity = Entity(keyjar=make_keyjar(), config=client_config, services=services) @@ -1205,7 +1237,8 @@ def create_request(self): def test_construct(self): self.service.upstream_get("service_context").cstate.update( - "abcde", {"id_token": "a.signed.jwt"}) + "abcde", {"id_token": "a.signed.jwt"} + ) _req = self.service.construct(state="abcde") assert isinstance(_req, EndSessionRequest) assert len(_req) == 3 @@ -1217,7 +1250,7 @@ def test_authz_service_conf(): "client_id": "client_id", "client_secret": "a longesh password", "redirect_uris": ["https://example.com/cli/authz_cb"], - "response_types": ["code"], + "response_types": ["code", "id_token"], } services = { @@ -1235,8 +1268,9 @@ def test_authz_service_conf(): }, } } - entity = Entity(keyjar=make_keyjar(), config=client_config, services=services, - client_type='oidc') + entity = Entity( + keyjar=make_keyjar(), config=client_config, services=services, client_type="oidc" + ) _context = entity.get_context() _context.issuer = "https://example.com" _context.map_supported_to_preferred() @@ -1244,13 +1278,15 @@ def test_authz_service_conf(): service = entity.get_service("authorization") req = service.construct() - assert set(req.keys()) == {'claims', - 'client_id', - 'nonce', - 'redirect_uri', - 'response_type', - 'scope', - 'state'} + assert set(req.keys()) == { + "claims", + "client_id", + "nonce", + "redirect_uri", + "response_type", + "scope", + "state", + } assert set(req["claims"].keys()) == {"id_token"} @@ -1265,8 +1301,12 @@ def test_jwks_uri_conf(): "id_token_signed_response_alg": "RS384", "userinfo_signed_response_alg": "RS384", } - entity = Entity(keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES, - client_type='oidc') + entity = Entity( + keyjar=make_keyjar(), + config=client_config, + services=DEFAULT_OIDC_SERVICES, + client_type="oidc", + ) _context = entity.get_context() _context.issuer = "https://example.com" _context.map_supported_to_preferred() @@ -1291,7 +1331,7 @@ def test_jwks_uri_arg(): keyjar=make_keyjar(), config=client_config, services=DEFAULT_OIDC_SERVICES, - client_type='oidc' + client_type="oidc", ) _context = entity.get_context() _context.issuer = "https://example.com" diff --git a/tests/test_client_22_oidc.py b/tests/test_client_22_oidc.py index 9bbdd65e..94b06dbf 100755 --- a/tests/test_client_22_oidc.py +++ b/tests/test_client_22_oidc.py @@ -50,7 +50,7 @@ def create_client(self): "redirect_uris": ["https://example.com/cli/authz_cb"], "client_id": "client_1", "client_secret": "abcdefghijklmnop", - 'client_authn_methods': ['bearer_header'] + "client_authn_methods": ["bearer_header"], } self.client = RP(config=conf) @@ -62,7 +62,7 @@ def test_construct_authorization_request(self): "nonce": "nonce", } - self.client.get_context().cstate.set("ABCDE", {'iss': "issuer"}) + self.client.get_context().cstate.set("ABCDE", {"iss": "issuer"}) msg = self.client.get_service("authorization").construct(request_args=req_args) assert isinstance(msg, AuthorizationRequest) @@ -73,7 +73,7 @@ def test_construct_accesstoken_request(self): auth_request = AuthorizationRequest(redirect_uri="https://example.com/cli/authz_cb") _state = _context.cstate.create_key() - _context.cstate.set(_state, {'iss': "issuer"}) + _context.cstate.set(_state, {"iss": "issuer"}) auth_request["state"] = _state _context.cstate.update(_state, auth_request) @@ -84,9 +84,7 @@ def test_construct_accesstoken_request(self): # Bind access code to state req_args = {} - msg = self.client.get_service("accesstoken").construct( - request_args=req_args, state=_state - ) + msg = self.client.get_service("accesstoken").construct(request_args=req_args, state=_state) assert isinstance(msg, AccessTokenRequest) assert msg.to_dict() == { "client_id": "client_1", @@ -99,7 +97,7 @@ def test_construct_accesstoken_request(self): def test_construct_refresh_token_request(self): _context = self.client.get_context() - _context.cstate.set("ABCDE", {'iss':"issuer"}) + _context.cstate.set("ABCDE", {"iss": "issuer"}) auth_request = AuthorizationRequest( redirect_uri="https://example.com/cli/authz_cb", state="state" @@ -128,7 +126,7 @@ def test_construct_refresh_token_request(self): def test_do_userinfo_request_init(self): _context = self.client.get_context() _state = _context.cstate.create_key() - _context.cstate.set(_state, {'iss': "issuer"}) + _context.cstate.set(_state, {"iss": "issuer"}) auth_request = AuthorizationRequest( redirect_uri="https://example.com/cli/authz_cb", state="state" diff --git a/tests/test_client_23_pkce.py b/tests/test_client_23_pkce.py index e7882822..c26fd4c3 100644 --- a/tests/test_client_23_pkce.py +++ b/tests/test_client_23_pkce.py @@ -48,23 +48,17 @@ def create_client(self): "client_id": "client_id", "client_secret": "a longesh password", "redirect_uris": ["https://example.com/cli/authz_cb"], - "preference": { - "response_types": ["code"] - }, + "preference": {"response_types": ["code"]}, "add_ons": { "pkce": { "function": "idpyoidc.client.oauth2.add_on.pkce.add_support", - "kwargs": { - "code_challenge_length": 64, - "code_challenge_method": "S256" - }, + "kwargs": {"code_challenge_length": 64, "code_challenge_method": "S256"}, } }, } - self.entity = Entity(keyjar=CLI_KEY, - config=config, - services=DEFAULT_OAUTH2_SERVICES, - client_type='oauth2') + self.entity = Entity( + keyjar=CLI_KEY, config=config, services=DEFAULT_OAUTH2_SERVICES, client_type="oauth2" + ) if "add_ons" in config: do_add_ons(config["add_ons"], self.entity.get_services()) @@ -105,8 +99,8 @@ def test_access_token_and_pkce(self): auth_response = AuthorizationResponse(code="access code") _context = self.entity.get_context() _context.cstate.update(_state, auth_response) - #auth_serv = self.entity.get_service("authorization") - #_state = _context.cstate.create_state(iss="Issuer") + # auth_serv = self.entity.get_service("authorization") + # _state = _context.cstate.create_state(iss="Issuer") token_service = self.entity.get_service("accesstoken") request = token_service.construct_request(state=_state) diff --git a/tests/test_client_25_oauth2_cc_ropc.py b/tests/test_client_25_oauth2_cc_ropc.py index e03382fb..4b43ffd5 100644 --- a/tests/test_client_25_oauth2_cc_ropc.py +++ b/tests/test_client_25_oauth2_cc_ropc.py @@ -11,13 +11,12 @@ class TestCC: - @pytest.fixture(autouse=True) def create_service(self): client_config = { "client_id": "client_id", "client_secret": "another password", - "base_url": BASE_URL + "base_url": BASE_URL, } services = { "client_credentials": { @@ -34,8 +33,10 @@ def test_token_get_request(self): _info = _srv.get_request_parameters() assert _info["method"] == "POST" assert _info["url"] == "https://example.com/token" - assert _info[ - "body"] == "grant_type=client_credentials&client_id=client_id&client_secret=another+password" + assert ( + _info["body"] + == "grant_type=client_credentials&client_id=client_id&client_secret=another+password" + ) assert _info["headers"] == { "Content-Type": "application/x-www-form-urlencoded", @@ -63,30 +64,29 @@ def test_token_parse_response(self): class TestROPC: - @pytest.fixture(autouse=True) def create_service(self): client_config = { "client_id": "client_id", "client_secret": "another password", - "base_url": BASE_URL + "base_url": BASE_URL, } services = { "resource_owner_password_credentials": { - "class": - "idpyoidc.client.oauth2.resource_owner_password_credentials" - ".ROPCAccessTokenRequest" + "class": "idpyoidc.client.oauth2.resource_owner_password_credentials" + ".ROPCAccessTokenRequest" } } self.entity = Entity(config=client_config, services=services) self.entity.get_service( - "resource_owner_password_credentials").endpoint = "https://example.com/token" + "resource_owner_password_credentials" + ).endpoint = "https://example.com/token" def test_token_get_request(self): _srv = self.entity.get_service("resource_owner_password_credentials") - _info = _srv.get_request_parameters({'username': 'diana', 'password': 'krall'}) + _info = _srv.get_request_parameters({"username": "diana", "password": "krall"}) assert _info["method"] == "POST" assert _info["url"] == "https://example.com/token" assert _info["body"] == ( @@ -94,7 +94,8 @@ def test_token_get_request(self): "password=krall&" "grant_type=password&" "client_id=client_id&" - "client_secret=another+password") + "client_secret=another+password" + ) assert _info["headers"] == { "Content-Type": "application/x-www-form-urlencoded", diff --git a/tests/test_client_26_read_registration.py b/tests/test_client_26_read_registration.py index cb8026f9..32eacbd6 100644 --- a/tests/test_client_26_read_registration.py +++ b/tests/test_client_26_read_registration.py @@ -34,8 +34,8 @@ def create_request(self): "read_registration": { "class": "idpyoidc.client.oidc.read_registration.RegistrationRead" }, - 'authorization': {'class': 'idpyoidc.client.oidc.authorization.Authorization'}, - 'accesstoken': {'class': 'idpyoidc.client.oidc.access_token.AccessToken'} + "authorization": {"class": "idpyoidc.client.oidc.authorization.Authorization"}, + "accesstoken": {"class": "idpyoidc.client.oidc.access_token.AccessToken"}, } self.entity = Entity(config=client_config, services=services) diff --git a/tests/test_client_27_conversation.py b/tests/test_client_27_conversation.py index f1117ca9..fa0d24cc 100644 --- a/tests/test_client_27_conversation.py +++ b/tests/test_client_27_conversation.py @@ -27,22 +27,21 @@ "keys": [ { "d": "mcAW1xeNsjzyV1M7F7_cUHz0MIR" - "-tcnKFJnbbo5UXxMRUPu17qwRHr8ttep1Ie64r2L9QlphcT9BjYd0KQ8ll3flIzLtiJv__MNPQVjk5bsYzb_erQRzSwLJU-aCcNFB8dIyQECzu-p44UVEPQUGzykImsSShvMQhcvrKiqqg7NlijJuEKHaKynV9voPsjwKYSqk6lH8kMloCaVS-dOkK-r7bZtbODUxx9GJWnxhX0JWXcdrPZRb29y9cdthrMcEaCXG23AxnMEfp-enDqarLHYTQrCBJXs_b-9k2d8v9zLm7E-Pf-0YGmaoJtX89lwQkO_SmFF3sXsnI2cFreqU3Q", + "-tcnKFJnbbo5UXxMRUPu17qwRHr8ttep1Ie64r2L9QlphcT9BjYd0KQ8ll3flIzLtiJv__MNPQVjk5bsYzb_erQRzSwLJU-aCcNFB8dIyQECzu-p44UVEPQUGzykImsSShvMQhcvrKiqqg7NlijJuEKHaKynV9voPsjwKYSqk6lH8kMloCaVS-dOkK-r7bZtbODUxx9GJWnxhX0JWXcdrPZRb29y9cdthrMcEaCXG23AxnMEfp-enDqarLHYTQrCBJXs_b-9k2d8v9zLm7E-Pf-0YGmaoJtX89lwQkO_SmFF3sXsnI2cFreqU3Q", "e": "AQAB", "kid": "c19uYlBJXzVfNjNZeGVnYmxncHZwUzZTZDVwUFdxdVJLU3AxQXdwaFdfbw", "kty": "RSA", "n": "3ZblhNL2CjRktLM9vyDn8jnA4G1B1HCpPh" - "-gv2AK4m9qDBZPYZGOGqzeW3vanvLTBlqnPm0GHg4rOrfMEwwLrfMcgmg1y4GD0vVU8G9HP1" - "-oUPtKUqaKOp313tFKzFh9_OHGQ6EmhxG7gegPR9kQXduTDXqBFi81MzRplIQ8DHLM3-n2CyDW1V" - "-dhRVh" - "-AM0ZcJyzR_DvZ3mhG44DysPdHQOSeWnpdn1d81" - "-PriqZfhAF9tn1ihgtjXd5swf1HTSjLd7xv1hitGf2245Xmr" - "-V2pQFzeMukLM3JKbTYbElsB7Zm0wZx49hZMtgx35XMoO04bifdbO3yLtTA5ovXN3fQ", + "-gv2AK4m9qDBZPYZGOGqzeW3vanvLTBlqnPm0GHg4rOrfMEwwLrfMcgmg1y4GD0vVU8G9HP1" + "-oUPtKUqaKOp313tFKzFh9_OHGQ6EmhxG7gegPR9kQXduTDXqBFi81MzRplIQ8DHLM3-n2CyDW1V" + "-dhRVh" + "-AM0ZcJyzR_DvZ3mhG44DysPdHQOSeWnpdn1d81" + "-PriqZfhAF9tn1ihgtjXd5swf1HTSjLd7xv1hitGf2245Xmr" + "-V2pQFzeMukLM3JKbTYbElsB7Zm0wZx49hZMtgx35XMoO04bifdbO3yLtTA5ovXN3fQ", "p": "88aNu59aBn0elksaVznzoVKkdbT5B4euhOIEqJoFvFbEocw9mC4k" - "-yozIAQSV5FEakoSPOl8lrymCoM3Q1fVHfaM9Rbb9RCRlsV1JOeVVZOE05HUdz8zOIqLBDEGM_oQqDwF_kp" - "-4nDTZ1-dtnGdTo4Cf7QRuApzE_dwVabUCTc", - "q": - "6LOHuM7H_0kDrMTwUEX7Aubzr792GoJ6EgTKIQY25SAFTZpYwuC3NnqlAdy8foIa3d7eGU2yICRbBG0S_ITcooDFrOa7nZ6enMUclMTxW8FwwvBXeIHo9cIsrKYtOThGplz43Cvl73MK5M58ZRmuhaNYa6Mk4PL4UokARfEiDus", + "-yozIAQSV5FEakoSPOl8lrymCoM3Q1fVHfaM9Rbb9RCRlsV1JOeVVZOE05HUdz8zOIqLBDEGM_oQqDwF_kp" + "-4nDTZ1-dtnGdTo4Cf7QRuApzE_dwVabUCTc", + "q": "6LOHuM7H_0kDrMTwUEX7Aubzr792GoJ6EgTKIQY25SAFTZpYwuC3NnqlAdy8foIa3d7eGU2yICRbBG0S_ITcooDFrOa7nZ6enMUclMTxW8FwwvBXeIHo9cIsrKYtOThGplz43Cvl73MK5M58ZRmuhaNYa6Mk4PL4UokARfEiDus", "use": "sig", }, { @@ -70,15 +69,13 @@ "kid": "Mk0yN2w0N3BZLWtyOEpQWGFmNDZvQi1hbDl2azR3ai1WNElGdGZQSFd6MA", "e": "AQAB", "n": "yPrOADZtGoa9jxFCmDsJ1nAYmzgznUxCtUlb_ty33" - "-AFNEqzW_pSLr5g6RQAPGsvVQqbsb9AB18QNgz" - "-eG7cnvKIIR7JXWCuGv_Q9MwoRD0-zaYGRbRvFoTZokZMB6euBfMo6kijJ" - "-gdKuSaxIE84X_Fcf1ESAKJ0EX6Cxdm8hKkBelGIDPMW5z7EHQ8OuLCQtTJnDvbjEOk9sKzkKqVj53XFs5vjd4WUhxS6xIDcWE-lTafUpm0BsobklLePidHxyAMGOunL_Pt3RCLZGlWeWOO9fZhLtydiDWiZlcNR0FQEX_mfV1kCOHHBFN1VKOY2pyJpjp9djdtHxPZ9fP35w", - "d": - "aRBTqGDLYFaXuba4LYSPe_5Vnq8erFg1dzfGU9Fmfi5KCjAS2z5cv_reBnpiNTODJt3Izn7AJhpYCyl3zdWGl8EJ0OabNalY2txoi9A-LI4nyrHEDaRpfkgszVwaWtYZbxrShMc8I5x_wvCGx7sX7Hoy6YgQreRFzw8Fy86MDncpmcUwQTnXVUMLgioeYz5gW6rwXkqj_NVyuHPiheykJG026cXFNBWplCk4ET1bvf_6ZB9QmLwO16Pu2O-dtu1HHDOqI7y6-YgKIC6mcLrQrF9-FO7NkilcOB7zODNiYzhDBQ2YJAbcdn_3M_lkhaFwR-n4WB7vCM0vNqz7lEg6QQ", - "p": - "_STNoJFkX9_uw8whytVmTrHP5K7vcZBIH9nuCTvj137lC48ZpR1UARx4qShxHLfK7DrufHd7TYnJkEMNUHFmdKvkaVQMY0_BsBSvCrUl10gzxsI08hg53L17E1Pe73iZp3f5nA4eB-1YB-km1Cc-Xs10OPWedJHf9brlCPDLAb8", + "-AFNEqzW_pSLr5g6RQAPGsvVQqbsb9AB18QNgz" + "-eG7cnvKIIR7JXWCuGv_Q9MwoRD0-zaYGRbRvFoTZokZMB6euBfMo6kijJ" + "-gdKuSaxIE84X_Fcf1ESAKJ0EX6Cxdm8hKkBelGIDPMW5z7EHQ8OuLCQtTJnDvbjEOk9sKzkKqVj53XFs5vjd4WUhxS6xIDcWE-lTafUpm0BsobklLePidHxyAMGOunL_Pt3RCLZGlWeWOO9fZhLtydiDWiZlcNR0FQEX_mfV1kCOHHBFN1VKOY2pyJpjp9djdtHxPZ9fP35w", + "d": "aRBTqGDLYFaXuba4LYSPe_5Vnq8erFg1dzfGU9Fmfi5KCjAS2z5cv_reBnpiNTODJt3Izn7AJhpYCyl3zdWGl8EJ0OabNalY2txoi9A-LI4nyrHEDaRpfkgszVwaWtYZbxrShMc8I5x_wvCGx7sX7Hoy6YgQreRFzw8Fy86MDncpmcUwQTnXVUMLgioeYz5gW6rwXkqj_NVyuHPiheykJG026cXFNBWplCk4ET1bvf_6ZB9QmLwO16Pu2O-dtu1HHDOqI7y6-YgKIC6mcLrQrF9-FO7NkilcOB7zODNiYzhDBQ2YJAbcdn_3M_lkhaFwR-n4WB7vCM0vNqz7lEg6QQ", + "p": "_STNoJFkX9_uw8whytVmTrHP5K7vcZBIH9nuCTvj137lC48ZpR1UARx4qShxHLfK7DrufHd7TYnJkEMNUHFmdKvkaVQMY0_BsBSvCrUl10gzxsI08hg53L17E1Pe73iZp3f5nA4eB-1YB-km1Cc-Xs10OPWedJHf9brlCPDLAb8", "q": "yz9T0rPEc0ZPjSi45gsYiQL2KJ3UsPHmLrgOHq0D4UvsB6UFtUtOWh7A1UpQdmBuHjIJz" - "-Iq7VH4kzlI6VxoXhwE69oxBXr4I7fBudZRvlLuIJS9M2wvsTVouj0DBYSR6ZlAQHCCou89P2P6zQCEaqu7bWXNcpyTixbbvOU1w9k", + "-Iq7VH4kzlI6VxoXhwE69oxBXr4I7fBudZRvlLuIJS9M2wvsTVouj0DBYSR6ZlAQHCCou89P2P6zQCEaqu7bWXNcpyTixbbvOU1w9k", }, { "kty": "EC", @@ -106,31 +103,14 @@ "WebFinger": {"class": WebFinger}, "discovery": { "class": "idpyoidc.client.oidc.provider_info_discovery.ProviderInfoDiscovery", - "kwargs": {} + "kwargs": {}, }, - "registration": { - "class": "idpyoidc.client.oidc.registration.Registration", - "kwargs": {} - }, - "authorization": { - "class": "idpyoidc.client.oidc.authorization.Authorization", - "kwargs": {} - }, - "accesstoken": { - "class": "idpyoidc.client.oidc.access_token.AccessToken", - "kwargs": {} - }, - "refresh_token": { - "class": "idpyoidc.client.oidc.refresh_access_token.RefreshAccessToken" - }, - "userinfo": { - "class": "idpyoidc.client.oidc.userinfo.UserInfo", - "kwargs": {} - }, - "end_session": { - "class": "idpyoidc.client.oidc.end_session.EndSession", - "kwargs": {} - } + "registration": {"class": "idpyoidc.client.oidc.registration.Registration", "kwargs": {}}, + "authorization": {"class": "idpyoidc.client.oidc.authorization.Authorization", "kwargs": {}}, + "accesstoken": {"class": "idpyoidc.client.oidc.access_token.AccessToken", "kwargs": {}}, + "refresh_token": {"class": "idpyoidc.client.oidc.refresh_access_token.RefreshAccessToken"}, + "userinfo": {"class": "idpyoidc.client.oidc.userinfo.UserInfo", "kwargs": {}}, + "end_session": {"class": "idpyoidc.client.oidc.end_session.EndSession", "kwargs": {}}, } @@ -149,12 +129,12 @@ def test_conversation(): "post_logout_redirect_uri": "https://rp.example.com/post", "backchannel_logout_uri": "https://rp.example.com/back", "backchannel_logout_session_required": True, - 'allow': {'missing_kid': True}, - "client_authn_methods": ['bearer_header'], - "services": SERVICES + "allow": {"missing_kid": True}, + "client_authn_methods": ["bearer_header"], + "services": SERVICES, } - entity = Entity(config=config, keyjar=RP_KEYJAR, client_type='oidc') + entity = Entity(config=config, keyjar=RP_KEYJAR, client_type="oidc") assert set(entity.get_services().keys()) == { "accesstoken", @@ -164,7 +144,7 @@ def test_conversation(): "refresh_token", "userinfo", "provider_info", - 'end_session', + "end_session", } service_context = entity.get_context() @@ -174,11 +154,11 @@ def test_conversation(): info = webfinger_service.get_request_parameters(request_args={"resource": "foobar@example.org"}) assert ( - info["url"] == "https://example.org/.well-known/webfinger?rel=http" - "%3A%2F" - "%2Fopenid.net%2Fspecs%2Fconnect%2F1.0%2Fissuer" - "&resource" - "=acct%3Afoobar%40example.org" + info["url"] == "https://example.org/.well-known/webfinger?rel=http" + "%3A%2F" + "%2Fopenid.net%2Fspecs%2Fconnect%2F1.0%2Fissuer" + "&resource" + "=acct%3Afoobar%40example.org" ) webfinger_response = json.dumps( @@ -405,7 +385,7 @@ def test_conversation(): resp = provider_info_service.parse_response(provider_info_response) assert isinstance(resp, ProviderConfigurationResponse) - provider_info_service.update_service_context(resp, '') + provider_info_service.update_service_context(resp, "") _pi = entity.get_context().provider_info assert _pi["issuer"] == OP_BASEURL @@ -418,22 +398,25 @@ def test_conversation(): assert info["url"] == "https://example.org/op/registration" _body = json.loads(info["body"]) - assert set(_body.keys()) == {'application_type', - 'backchannel_logout_session_required', - 'backchannel_logout_uri', - 'contacts', - 'default_max_age', - 'grant_types', - 'id_token_signed_response_alg', - 'jwks', - 'redirect_uris', - 'request_object_signing_alg', - 'request_uris', - 'response_types', - 'subject_type', - 'token_endpoint_auth_method', - 'token_endpoint_auth_signing_alg', - 'userinfo_signed_response_alg'} + assert set(_body.keys()) == { + "application_type", + "backchannel_logout_session_required", + "backchannel_logout_uri", + "contacts", + "default_max_age", + "grant_types", + "id_token_signed_response_alg", + "jwks", + "redirect_uris", + "request_object_signing_alg", + "request_uris", + "response_modes", + "response_types", + "subject_type", + "token_endpoint_auth_method", + "token_endpoint_auth_signing_alg", + "userinfo_signed_response_alg", + } assert info["headers"] == {"Content-Type": "application/json"} now = int(time.time()) @@ -460,7 +443,7 @@ def test_conversation(): registration_service.update_service_context(response) assert service_context.get_client_id() == "zls2qhN1jO6A" - assert service_context.get_usage('client_secret') == "c8434f28cf9375d9a7" + assert service_context.get_usage("client_secret") == "c8434f28cf9375d9a7" assert set(service_context.registration_response.keys()) == { "client_secret_expires_at", "contacts", @@ -525,11 +508,7 @@ def test_conversation(): assert info["url"] == "https://example.org/op/token" _qp = parse_qs(info["body"]) # since the default is private_key_jwt !!! - assert set(_qp.keys()) == {'client_id', - 'code', - 'grant_type', - 'redirect_uri', - 'state'} + assert set(_qp.keys()) == {"client_id", "code", "grant_type", "redirect_uri", "state"} assert info["headers"]["Content-Type"] == "application/x-www-form-urlencoded" # create the IdToken @@ -570,20 +549,22 @@ def test_conversation(): _item = _cstate.get(STATE) - assert set(_item.keys()) == {'__expires_at', - '__verified_id_token', - 'access_token', - 'client_id', - 'code', - 'expires_in', - 'id_token', - 'iss', - 'nonce', - 'redirect_uri', - 'response_type', - 'scope', - 'state', - 'token_type'} + assert set(_item.keys()) == { + "__expires_at", + "__verified_id_token", + "access_token", + "client_id", + "code", + "expires_in", + "id_token", + "iss", + "nonce", + "redirect_uri", + "response_type", + "scope", + "state", + "token_type", + } assert _item["token_type"] == "Bearer" assert _item["access_token"] == "Z0FBQUFBQmFkdFF" diff --git a/tests/test_client_28_stand_alone.py b/tests/test_client_28_stand_alone.py new file mode 100644 index 00000000..05cce795 --- /dev/null +++ b/tests/test_client_28_stand_alone.py @@ -0,0 +1,598 @@ +from urllib.parse import parse_qs +from urllib.parse import urlsplit + +import pytest +import responses +from cryptojwt.key_jar import build_keyjar + +from idpyoidc.client.defaults import DEFAULT_KEY_DEFS +from idpyoidc.client.defaults import DEFAULT_OIDC_SERVICES +from idpyoidc.client.defaults import OIDCONF_PATTERN +from idpyoidc.client.exception import Unsupported +from idpyoidc.client.oauth2.stand_alone_client import StandAloneClient +from idpyoidc.exception import VerificationError +from idpyoidc.message.oidc import AccessTokenResponse +from idpyoidc.message.oidc import AuthorizationResponse +from idpyoidc.message.oidc import IdToken +from idpyoidc.message.oidc import OpenIDSchema +from idpyoidc.message.oidc import ProviderConfigurationResponse +from idpyoidc.message.oidc import RegistrationResponse + +ISSUER = "https://op.example.com" + +STATIC_CONFIG = { + "base_url": "https://example.com/cli/", + "client_id": "Number5", + "client_type": "oidc", + "client_secret": "asdflkjh0987654321", + "provider_info": { + "issuer": ISSUER, + "authorization_endpoint": "https://op.example.com/authn", + "token_endpoint": "https://op.example.com/token", + "userinfo_endpoint": "https://op.example.com/user", + } +} + + +def get_state_from_url(url): + p = urlsplit(url) + qs = parse_qs(p.query) + return qs['state'][0] + + +class TestStandAloneClientOIDCStatic(object): + + @pytest.fixture(autouse=True) + def client_setup(self): + self.client = StandAloneClient(config=STATIC_CONFIG) + + def test_get_services(self): + assert set(self.client.get_services().keys()) == {'provider_info', 'registration', + 'authorization', 'accesstoken', + 'refresh_token', 'userinfo'} + + def test_do_provider_info(self): + issuer = self.client.do_provider_info() + assert issuer == STATIC_CONFIG['provider_info']['issuer'] + assert self.client.context.get('issuer') == issuer + + def test_client_registration(self): + self.client.do_provider_info() + self.client.do_client_registration() + assert self.client.context.get_usage('client_id') == STATIC_CONFIG['client_id'] + + def test_init_authorization(self): + self.client.do_provider_info() + self.client.do_client_registration() + url = self.client.init_authorization() + assert url + p = urlsplit(url) + qs = parse_qs(p.query) + assert qs['client_id'][0] == STATIC_CONFIG['client_id'] + assert qs['response_type'][0] == 'code' + + def test_response_type_id_token(self): + self.client.do_provider_info() + self.client.do_client_registration() + + # Explicitly set + url = self.client.init_authorization(req_args={'response_type': 'id_token'}) + + assert url + p = urlsplit(url) + qs = parse_qs(p.query) + assert qs['client_id'][0] == STATIC_CONFIG['client_id'] + assert qs['response_type'][0] == 'id_token' + + +def test_response_mode(): + conf = STATIC_CONFIG.copy() + conf.update({ + "response_modes_supported": ['query', 'form_post'], + 'separate_form_post_cb': True + }) + client = StandAloneClient(config=conf) + client.do_provider_info() + client.do_client_registration() + + # Explicitly set + url = client.init_authorization(req_args={'response_mode': 'form_post'}) + + assert url + p = urlsplit(url) + qs = parse_qs(p.query) + assert 'authz_cb_form' in qs['redirect_uri'][0] + assert qs['client_id'][0] == STATIC_CONFIG['client_id'] + assert qs['response_type'][0] == 'code' + assert qs['response_mode'][0] == 'form_post' + + +def test_response_mode_not_separate_endpoint(): + conf = STATIC_CONFIG.copy() + conf.update({ + "response_modes_supported": ['query', 'form_post'], + 'separate_form_post_cb': False + }) + client = StandAloneClient(config=conf) + client.do_provider_info() + client.do_client_registration() + + # Explicitly set + url = client.init_authorization(req_args={'response_mode': 'form_post'}) + + assert url + p = urlsplit(url) + qs = parse_qs(p.query) + assert 'authz_cb_form' not in qs['redirect_uri'][0] + assert 'authz_cb' in qs['redirect_uri'][0] + assert qs['client_id'][0] == STATIC_CONFIG['client_id'] + assert qs['response_type'][0] == 'code' + assert qs['response_mode'][0] == 'form_post' + + +SEMI_DYN_CONFIG = { + "base_url": "https://example.com/cli/", + "client_id": "Number5", + "client_secret": "asdflkjh0987654321", + "client_type": "oidc", + "provider_info": { + "issuer": "https://op.example.com" + } +} + +PROVIDER_INFO = ProviderConfigurationResponse( + issuer=ISSUER, + authorization_endpoint="https://op.example.com/authn", + token_endpoint="https://op.example.com/token", + userinfo_endpoint="https://op.example.com/user", + registration_endpoint="https://op.example.com/register", + jwks_uri="https://op.example.com/keys/jwks.json", + response_types_supported=["code"], + subject_types_supported=['public'], + id_token_signing_alg_values_supported=['RS256'] +) + +OP_KEYS = build_keyjar(DEFAULT_KEY_DEFS) + + +class TestStandAloneClientOIDCDynProviderInfo(object): + + @pytest.fixture(autouse=True) + def client_setup(self): + self.client = StandAloneClient(config=SEMI_DYN_CONFIG) + + def test_do_provider_info(self): + with responses.RequestsMock() as rsps: + rsps.add( + "GET", + OIDCONF_PATTERN.format(SEMI_DYN_CONFIG['provider_info']['issuer']), + body=PROVIDER_INFO.to_json(), + adding_headers={"Content-Type": "application/json"}, + status=200, + ) + rsps.add( + "GET", + PROVIDER_INFO['jwks_uri'], + body=OP_KEYS.export_jwks_as_json(), + adding_headers={"Content-Type": "application/json"}, + status=200, + ) + issuer = self.client.do_provider_info() + + assert issuer == SEMI_DYN_CONFIG['provider_info']['issuer'] + assert self.client.context.get('issuer') == issuer + + +DYN_CONFIG = { + "base_url": "https://rp.example.com", + "redirect_uris": ["https://rp.example.com/cb"], + "key_conf": {"key_defs": DEFAULT_KEY_DEFS}, + "client_type": "oidc", + "provider_info": { + "issuer": "https://op.example.com" + } +} + + +class TestStandAloneClientOIDCDyn(object): + + @pytest.fixture(autouse=True) + def client_setup(self): + self.client = StandAloneClient(config=DYN_CONFIG) + + def test_do_provider_info(self): + with responses.RequestsMock() as rsps: + rsps.add( + "GET", + OIDCONF_PATTERN.format(SEMI_DYN_CONFIG['provider_info']['issuer']), + body=PROVIDER_INFO.to_json(), + adding_headers={"Content-Type": "application/json"}, + status=200, + ) + rsps.add( + "GET", + PROVIDER_INFO['jwks_uri'], + body=OP_KEYS.export_jwks_as_json(), + adding_headers={"Content-Type": "application/json"}, + status=200, + ) + issuer = self.client.do_provider_info() + + assert issuer == DYN_CONFIG['provider_info']['issuer'] + assert self.client.context.get('issuer') == issuer + + registration_response = RegistrationResponse( + client_id="client_1", + client_secret="a0b1c2d3e4f5g6h7i8j9", + redirect_uris=["https://rp.example.com/cb"] + ) + with responses.RequestsMock() as rsps: + # registration response + rsps.add( + "POST", + PROVIDER_INFO['registration_endpoint'], + body=registration_response.to_json(), + adding_headers={"Content-Type": "application/json"}, + status=200, + ) + + self.client.do_client_registration() + + assert self.client.context.get_usage('client_id') == 'client_1' + + +def test_request_type_mode_1(): + config = STATIC_CONFIG.copy() + config.update({ + "response_modes_supported": ['query', 'form_post'], + "response_types_supported": ['code', 'code idtoken'] + }) + client = StandAloneClient(config=config) + client.do_provider_info() + client.do_client_registration() + + # Explicitly set + url = client.init_authorization() + + assert url + p = urlsplit(url) + qs = parse_qs(p.query) + assert 'authz_cb' in qs['redirect_uri'][0] + assert qs['client_id'][0] == STATIC_CONFIG['client_id'] + assert qs['response_type'][0] == 'code' + + assert 'response_mode' not in qs + + +def test_request_type_mode_2(): + config = STATIC_CONFIG.copy() + config.update({ + "response_modes_supported": ['form_post'], + "response_types_supported": ['code', 'code id_token'] + }) + client = StandAloneClient(config=config) + client.do_provider_info() + client.do_client_registration() + + # Explicitly set + url = client.init_authorization() + + assert url + p = urlsplit(url) + qs = parse_qs(p.query) + assert 'authz_cb' in qs['redirect_uri'][0] + assert qs['client_id'][0] == STATIC_CONFIG['client_id'] + assert qs['response_type'][0] == 'code' + assert qs['response_mode'][0] == 'form_post' + + +def test_request_type_mode_3(): + config = STATIC_CONFIG.copy() + config.update({ + "response_modes_supported": ['form_post'], + "response_types_supported": ['id_token code'] + }) + client = StandAloneClient(config=config) + client.do_provider_info() + client.do_client_registration() + + # Explicitly set + url = client.init_authorization() + + assert url + p = urlsplit(url) + qs = parse_qs(p.query) + assert 'authz_cb' in qs['redirect_uri'][0] + assert qs['client_id'][0] == STATIC_CONFIG['client_id'] + assert qs['response_type'][0] == 'id_token code' + assert qs['response_mode'][0] == 'form_post' + + +def test_request_type_mode_4(): + config = STATIC_CONFIG.copy() + config.update({ + "response_modes_supported": ['query'], + "response_types_supported": ['id_token code'] + }) + client = StandAloneClient(config=config) + client.do_provider_info() + client.do_client_registration() + + # Explicitly set + with pytest.raises(Unsupported): + client.init_authorization() + + +class TestFinalizeAuth(object): + + @pytest.fixture(autouse=True) + def client_setup(self): + self.client = StandAloneClient(config=STATIC_CONFIG) + self.client.do_provider_info() + self.client.do_client_registration() + + def test_one(self): + url = self.client.init_authorization() + + _state = get_state_from_url(url) + _response = AuthorizationResponse( + code=24 * 'x', + state=_state, + iss=self.client.context.issuer, + client_id=self.client.context.get_client_id() + ) + _auth_response = self.client.finalize_auth(_response.to_dict()) + assert _auth_response + + def test_imposter(self): + url = self.client.init_authorization() + + _state = get_state_from_url(url) + _response = AuthorizationResponse( + code=24 * 'x', + state=_state, + iss="https://fake.example.com", + client_id=self.client.context.get_client_id() + ) + + with pytest.raises(VerificationError): + self.client.finalize_auth(_response.to_dict()) + + def test_wrong_state(self): + url = self.client.init_authorization() + + _state = get_state_from_url(url) + _response = AuthorizationResponse( + code=24 * 'x', + state="_state", + iss=self.client.context.issuer, + client_id=self.client.context.get_client_id() + ) + + with pytest.raises(KeyError): + self.client.finalize_auth(_response.to_dict()) + + +ISSUER_KEYS = build_keyjar(DEFAULT_KEY_DEFS, issuer_id=ISSUER) +SUBJECT_NAME = "Subject" +_services = DEFAULT_OIDC_SERVICES.copy() +_services["end_session"] = {'class': "idpyoidc.client.oidc.end_session.EndSession"} + +EXTENDED_STATIC_CONFIG = { + "base_url": "https://example.com/cli/", + "client_id": "Number5", + "client_type": "oidc", + "client_secret": "asdflkjh0987654321", + "post_logout_redirect_uri": "https://example.com/cli/logout", + "services": _services, + "provider_info": { + "issuer": ISSUER, + "authorization_endpoint": "https://op.example.com/authn", + "token_endpoint": "https://op.example.com/token", + "userinfo_endpoint": "https://op.example.com/user", + "end_session_endpoint": "https://op.example.com/end_session" + } +} + + +class TestPostAuthn(object): + + @pytest.fixture(autouse=True) + def client_setup(self): + self.client = StandAloneClient(config=EXTENDED_STATIC_CONFIG) + self.client.do_provider_info() + self.client.do_client_registration() + url = self.client.init_authorization() + + self.state = get_state_from_url(url) + _response = AuthorizationResponse( + code=24 * 'x', + state=self.state, + iss=self.client.context.issuer, + client_id=self.client.context.get_client_id() + ) + self.client.finalize_auth(_response.to_dict()) + + def _create_id_token(self, subject): + _context = self.client.get_context() + _session = self.client.get_session_information(self.state) + _nonce = _session["nonce"] + _iss = _session["iss"] + _aud = _context.get_client_id() + idval = {"nonce": _nonce, "sub": subject, "iss": _iss, "aud": _aud} + + _keyjar = _context.upstream_get("attribute", "keyjar") + _keyjar.import_jwks(ISSUER_KEYS.export_jwks(issuer_id=ISSUER), ISSUER) + + idts = IdToken(**idval) + return idts.to_jwt( + key=ISSUER_KEYS.get_signing_key("rsa", issuer_id=ISSUER), + algorithm="RS256", + lifetime=300, + ) + + def test_get_access_token(self): + with responses.RequestsMock() as rsps: + token_response = AccessTokenResponse( + access_token='access_token', + token_type='Bearer' + ) + rsps.add( + "POST", + STATIC_CONFIG['provider_info']['token_endpoint'], + body=token_response.to_json(), + adding_headers={"Content-Type": "application/json"}, + status=200, + ) + + response = self.client.get_tokens(self.state) + assert isinstance(response, AccessTokenResponse) + assert 'access_token' in response + + def test_get_access_and_id_token(self): + with responses.RequestsMock() as rsps: + token_response = AccessTokenResponse( + access_token='access_token', + token_type='Bearer', + id_token=self._create_id_token('Subject') + ) + rsps.add( + "POST", + STATIC_CONFIG['provider_info']['token_endpoint'], + body=token_response.to_json(), + adding_headers={"Content-Type": "application/json"}, + status=200, + ) + + response = self.client.get_access_and_id_token(state=self.state) + + assert response + assert set(response.keys()) == {'access_token', 'id_token'} + assert response['access_token'] == "access_token" + assert response['id_token']['iss'] == ISSUER + + def test_userinfo(self): + with responses.RequestsMock() as rsps: + token_response = AccessTokenResponse( + access_token='access_token', + token_type='Bearer' + ) + rsps.add( + "POST", + STATIC_CONFIG['provider_info']['token_endpoint'], + body=token_response.to_json(), + adding_headers={"Content-Type": "application/json"}, + status=200, + ) + + self.client.get_tokens(self.state) + + with responses.RequestsMock() as rsps: + _response = OpenIDSchema( + sub=SUBJECT_NAME, + email='subject@example.com' + ) + rsps.add( + "GET", + STATIC_CONFIG['provider_info']['userinfo_endpoint'], + body=_response.to_json(), + adding_headers={"Content-Type": "application/json"}, + status=200, + ) + + response = self.client.get_user_info(self.state) + assert response + + def test_finalize_1(self): + _auth_response = AuthorizationResponse( + code=24 * 'x', + state=self.state, + iss=self.client.context.issuer, + client_id=self.client.context.get_client_id() + ) + + with responses.RequestsMock() as rsps: + token_response = AccessTokenResponse( + access_token='access_token', + token_type='Bearer' + ) + rsps.add( + "POST", + STATIC_CONFIG['provider_info']['token_endpoint'], + body=token_response.to_json(), + adding_headers={"Content-Type": "application/json"}, + status=200, + ) + _response = OpenIDSchema( + sub=SUBJECT_NAME, + email='subject@example.com' + ) + rsps.add( + "GET", + STATIC_CONFIG['provider_info']['userinfo_endpoint'], + body=_response.to_json(), + adding_headers={"Content-Type": "application/json"}, + status=200, + ) + + response = self.client.finalize(_auth_response.to_dict()) + assert response + assert set(response.keys()) == {'userinfo', 'state', 'token', 'id_token', + 'session_state', 'issuer'} + assert response['token'] == 'access_token' + assert response['id_token'] is None + assert response['userinfo']['sub'] == SUBJECT_NAME + assert response['issuer'] == ISSUER + + def test_finalize_2(self): + _auth_response = AuthorizationResponse( + code=24 * 'x', + state=self.state, + iss=self.client.context.issuer, + client_id=self.client.context.get_client_id() + ) + + with responses.RequestsMock() as rsps: + token_response = AccessTokenResponse( + access_token='access_token', + expires_in=300, + token_type='Bearer', + id_token=self._create_id_token(SUBJECT_NAME) + ) + rsps.add( + "POST", + STATIC_CONFIG['provider_info']['token_endpoint'], + body=token_response.to_json(), + adding_headers={"Content-Type": "application/json"}, + status=200, + ) + _response = OpenIDSchema( + sub=SUBJECT_NAME, + email='subject@example.com' + ) + rsps.add( + "GET", + STATIC_CONFIG['provider_info']['userinfo_endpoint'], + body=_response.to_json(), + adding_headers={"Content-Type": "application/json"}, + status=200, + ) + + response = self.client.finalize(_auth_response.to_dict()) + assert response + assert set(response.keys()) == {'userinfo', 'state', 'token', 'id_token', + 'session_state', 'issuer'} + assert response['token'] == 'access_token' + assert response['id_token'] is not None + assert response['userinfo']['sub'] == SUBJECT_NAME + assert response['issuer'] == ISSUER + + assert self.client.has_active_authentication(self.state) + + token, eat = self.client.get_valid_access_token(self.state) + assert token == "access_token" + assert eat > 0 + logout_info = self.client.logout(self.state, "https://example.com/cli/logout") + assert set(logout_info.keys()) == {'method', 'request', 'url'} + assert set(logout_info['request'].keys()) == {'post_logout_redirect_uri', 'id_token_hint', + 'state'} diff --git a/tests/test_client_29_pushed_auth.py b/tests/test_client_29_pushed_auth.py index 1f8901a3..995fa17c 100644 --- a/tests/test_client_29_pushed_auth.py +++ b/tests/test_client_29_pushed_auth.py @@ -35,7 +35,7 @@ def create_client(self): "preference": {"response_types": ["code"]}, "add_ons": { "pushed_authorization": { - "function": "idpyoidc.client.oauth2.add_on.pushed_authorization.add_support", + "function": "idpyoidc.client.oauth2.add_on.par.add_support", "kwargs": { "body_format": "jws", "signing_algorithm": "RS256", diff --git a/tests/test_client_28_rp_handler_oidc.py b/tests/test_client_30_rp_handler_oidc.py similarity index 79% rename from tests/test_client_28_rp_handler_oidc.py rename to tests/test_client_30_rp_handler_oidc.py index e6d14553..b992a55d 100644 --- a/tests/test_client_28_rp_handler_oidc.py +++ b/tests/test_client_30_rp_handler_oidc.py @@ -4,9 +4,9 @@ from urllib.parse import urlparse from urllib.parse import urlsplit -from cryptojwt.key_jar import init_key_jar import pytest import responses +from cryptojwt.key_jar import init_key_jar from idpyoidc.client.entity import Entity from idpyoidc.client.rp_handler import RPHandler @@ -28,10 +28,7 @@ "response_types_supported": [ "code", "id_token", - "id_token token", "code id_token", - "code id_token token", - "code token", ], "token_endpoint_auth_methods_supported": ["client_secret_basic"], "scopes_supported": ["openid", "profile", "email", "address", "phone"], @@ -44,6 +41,7 @@ "redirect_uris": None, "base_url": BASE_URL, "request_parameter": "request_uris", + "client_type": "oidc", "services": { "web_finger": {"class": "idpyoidc.client.oidc.webfinger.WebFinger"}, "discovery": { @@ -111,10 +109,11 @@ "issuer": "https://github.com/login/oauth/authorize", "client_id": "eeeeeeeee", "client_secret": "aaaaaaaaaaaaaaaaaaaa", + "client_type": "oidc", "redirect_uris": ["{}/authz_cb/github".format(BASE_URL)], "preference": { "response_types_supported": ["code"], - "scopes_supported": ["user", "public_repo", 'openid'], + "scopes_supported": ["user", "public_repo", "openid"], "token_endpoint_auth_methods_supported": [], "verify_args": {"allow_sign_alg_none": True}, }, @@ -138,13 +137,14 @@ "issuer": "https://github.com/login/oauth/authorize", "client_id": "eeeeeeeee", "client_secret": "aaaaaaaaaaaaaaaaaaaa", + "client_type": "oidc", "redirect_uris": ["{}/authz_cb/github".format(BASE_URL)], "preference": { "response_types_supported": ["code"], "scopes_supported": ["user", "public_repo"], "token_endpoint_auth_methods_supported": [], "verify_args": {"allow_sign_alg_none": True}, - 'encrypt_request_object': False + "encrypt_request_object": False, }, "provider_info": { "authorization_endpoint": "https://github.com/login/oauth/authorize", @@ -205,11 +205,18 @@ ) +def get_state_from_url(url): + p = urlsplit(url) + qp = parse_qs(p.query) + return qp["state"][0] + + def iss_id(iss): return CLIENT_CONFIG[iss]["issuer"] class TestRPHandler(object): + @pytest.fixture(autouse=True) def rphandler_setup(self): self.rph = RPHandler( @@ -242,7 +249,7 @@ def test_init_client(self): # Neither provider info discovery not client registration has been done # So only preferences so far. - assert _context.get_preference('client_id') == "eeeeeeeee" + assert _context.get_preference("client_id") == "eeeeeeeee" assert _context.get_preference("client_secret") == "aaaaaaaaaaaaaaaaaaaa" assert _context.issuer == "https://github.com/login/oauth/authorize" @@ -254,18 +261,24 @@ def test_init_client(self): } _pref = [k for k, v in _context.prefers().items() if v] - assert set(_pref) == {'client_id', 'client_secret', 'redirect_uris', - 'response_types_supported', 'callback_uris', 'scopes_supported'} + assert set(_pref) == { + "client_id", + "client_secret", + "redirect_uris", + "response_types_supported", + "callback_uris", + "scopes_supported", + } _github_id = iss_id("github") - _keyjar = _context.upstream_get('attribute', 'keyjar') + _keyjar = _context.upstream_get("attribute", "keyjar") _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) # The key jar should only contain a symmetric key that is the clients # secret. 2 because one is marked for encryption and the other signing # usage. - assert set(_keyjar.owners()) == {"", 'eeeeeeeee', _github_id} + assert set(_keyjar.owners()) == {"", "eeeeeeeee", _github_id} keys = _keyjar.get_issuer_keys("") assert len(keys) == 3 @@ -292,8 +305,9 @@ def test_do_client_registration(self): assert self.rph.hash2issuer["github"] == issuer assert ( - client.get_context().get_preference('callback_uris').get( - "post_logout_redirect_uris") is None + client.get_context().get_preference("callback_uris").get( + "post_logout_redirect_uris") + is None ) def test_do_client_setup(self): @@ -303,11 +317,11 @@ def test_do_client_setup(self): # Neither provider info discovery not client registration has been done # So only preferences so far. - assert _context.get_preference('client_id') == "eeeeeeeee" + assert _context.get_preference("client_id") == "eeeeeeeee" assert _context.get_preference("client_secret") == "aaaaaaaaaaaaaaaaaaaa" assert _context.issuer == _github_id - _keyjar = _context.upstream_get('attribute', 'keyjar') + _keyjar = _context.upstream_get("attribute", "keyjar") _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) assert set(_keyjar.owners()) == {"", "eeeeeeeee", _github_id} @@ -319,34 +333,31 @@ def test_do_client_setup(self): _endp = _srv.upstream_get("context").get("provider_info")[_srv.endpoint_name] assert _srv.endpoint == _endp - assert self.rph.hash2issuer["github"] == _context.get("issuer") - def test_create_callbacks(self): client = self.rph.init_client("https://op.example.com/") _srv = client.get_service("registration") _context = _srv.upstream_get("context") - cb = _context.get_preference('callback_uris') + cb = _context.get_preference("callback_uris") assert set(cb.keys()) == {"request_uris", "redirect_uris"} - assert set(cb['redirect_uris'].keys()) == {'code'} + assert set(cb["redirect_uris"].keys()) == {"query", "fragment"} _hash = _context.iss_hash - assert cb['redirect_uris']["code"] == [f"https://example.com/rp/authz_cb/{_hash}"] + assert cb["redirect_uris"]["query"] == [f"https://example.com/rp/authz_cb/{_hash}"] assert list(self.rph.hash2issuer.keys()) == [_hash] assert self.rph.hash2issuer[_hash] == "https://op.example.com/" def test_begin(self): - res = self.rph.begin(issuer_id="github") - assert set(res.keys()) == {"url", "state"} + url = self.rph.begin(issuer_id="github") _github_id = iss_id("github") client = self.rph.issuer2rp[_github_id] assert client.get_context().issuer == _github_id - part = urlsplit(res["url"]) + part = urlsplit(url) assert part.scheme == "https" assert part.netloc == "github.com" assert part.path == "/login/oauth/authorize" @@ -366,70 +377,77 @@ def test_begin(self): assert query["client_id"] == ["eeeeeeeee"] assert query["redirect_uri"] == ["https://example.com/rp/authz_cb/github"] assert query["response_type"] == ["code"] - assert set(query["scope"][0].split(' ')) == {"openid", "user", "public_repo"} + assert set(query["scope"][0].split(" ")) == {"openid", "user", "public_repo"} def test_get_session_information(self): - res = self.rph.begin(issuer_id="github") - _session = self.rph.get_session_information(res["state"]) + url = self.rph.begin(issuer_id="github") + _session = self.rph.get_session_information(get_state_from_url(url)) assert self.rph.client_configs["github"]["issuer"] == _session["iss"] def test_get_client_from_session_key(self): - res = self.rph.begin(issuer_id="linkedin") - cli1 = self.rph.get_client_from_session_key(state=res["state"]) - _session = self.rph.get_session_information(res["state"]) - cli2 = self.rph.issuer2rp[_session['iss']] + url = self.rph.begin(issuer_id="linkedin") + _state = get_state_from_url(url) + cli1 = self.rph.get_client_from_session_key(state=_state) + _session = self.rph.get_session_information(_state) + cli2 = self.rph.issuer2rp[_session["iss"]] assert cli1 == cli2 # redo - self.rph.do_provider_info(state=res["state"]) + self.rph.do_provider_info(state=_state) # get new redirect_uris cli2.get_context().set_preference("redirect_uris", []) - self.rph.do_client_registration(state=res["state"]) + self.rph.do_client_registration(state=_state) def test_finalize_auth(self): - res = self.rph.begin(issuer_id="linkedin") - _session = self.rph.get_session_information(res["state"]) - client = self.rph.issuer2rp[_session['iss']] + url = self.rph.begin(issuer_id="linkedin") + _state = get_state_from_url(url) + _session = self.rph.get_session_information(_state) + client = self.rph.issuer2rp[_session["iss"]] - auth_response = AuthorizationResponse(code="access_code", state=res["state"]) - resp = self.rph.finalize_auth(client, _session['iss'], auth_response.to_dict()) + auth_response = AuthorizationResponse(code="access_code", state=_state) + resp = self.rph.finalize_auth(client, _session["iss"], auth_response.to_dict()) assert set(resp.keys()) == {"state", "code"} - _state = client.get_context().cstate.get(res["state"]) - assert set(_state.keys()) == {'client_id', - 'code', - 'iss', - 'nonce', - 'redirect_uri', - 'response_type', - 'scope', - 'state'} + _state = client.get_context().cstate.get(_state) + assert set(_state.keys()) == { + "client_id", + "code", + "iss", + "nonce", + "redirect_uri", + "response_type", + "scope", + "state", + } def test_get_client_authn_method(self): - res = self.rph.begin(issuer_id="github") - _session = self.rph.get_session_information(res["state"]) - client = self.rph.issuer2rp[_session['iss']] + url = self.rph.begin(issuer_id="github") + _state = get_state_from_url(url) + _session = self.rph.get_session_information(_state) + client = self.rph.issuer2rp[_session["iss"]] authn_method = self.rph.get_client_authn_method(client, "token_endpoint") - assert authn_method == '' + assert authn_method == "" - res = self.rph.begin(issuer_id="linkedin") - _session = self.rph.get_session_information(res["state"]) - client = self.rph.issuer2rp[_session['iss']] + url = self.rph.begin(issuer_id="linkedin") + _state = get_state_from_url(url) + _session = self.rph.get_session_information(_state) + client = self.rph.issuer2rp[_session["iss"]] authn_method = self.rph.get_client_authn_method(client, "token_endpoint") assert authn_method == "client_secret_post" def test_get_tokens(self): - res = self.rph.begin(issuer_id="github") - _session = self.rph.get_session_information(res["state"]) - client = self.rph.issuer2rp[_session['iss']] + url = self.rph.begin(issuer_id="github") + _state = get_state_from_url(url) + _session = self.rph.get_session_information(_state) + client = self.rph.issuer2rp[_session["iss"]] _github_id = iss_id("github") _context = client.get_context() - _keyjar = _context.upstream_get('attribute', 'keyjar') + _keyjar = _context.upstream_get("attribute", "keyjar") _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) _nonce = _session["nonce"] - _iss = _session['iss'] + _iss = _session["iss"] _aud = _context.get_client_id() - idval = {"nonce": _nonce, "sub": "EndUserSubject", 'iss': _iss, "aud": _aud} + idval = {"nonce": _nonce, "sub": "EndUserSubject", "iss": _iss, "aud": _aud} idts = IdToken(**idval) _signed_jwt = idts.to_jwt( @@ -455,10 +473,10 @@ def test_get_tokens(self): ) client.get_service("accesstoken").endpoint = _url - auth_response = AuthorizationResponse(code="access_code", state=res["state"]) - resp = self.rph.finalize_auth(client, _session['iss'], auth_response.to_dict()) + auth_response = AuthorizationResponse(code="access_code", state=_state) + resp = self.rph.finalize_auth(client, _session["iss"], auth_response.to_dict()) - resp = self.rph.get_tokens(res["state"], client) + resp = self.rph.get_tokens(_state, client) assert set(resp.keys()) == { "access_token", "expires_in", @@ -468,34 +486,37 @@ def test_get_tokens(self): "__expires_at", } - _curr = client.get_context().cstate.get(res["state"]) - assert set(_curr.keys()) == {'__expires_at', - '__verified_id_token', - 'access_token', - 'client_id', - 'code', - 'expires_in', - 'id_token', - 'iss', - 'nonce', - 'redirect_uri', - 'response_type', - 'scope', - 'state', - 'token_type'} + _curr = client.get_context().cstate.get(_state) + assert set(_curr.keys()) == { + "__expires_at", + "__verified_id_token", + "access_token", + "client_id", + "code", + "expires_in", + "id_token", + "iss", + "nonce", + "redirect_uri", + "response_type", + "scope", + "state", + "token_type", + } def test_access_and_id_token(self): - res = self.rph.begin(issuer_id="github") - _session = self.rph.get_session_information(res["state"]) - client = self.rph.issuer2rp[_session['iss']] + url = self.rph.begin(issuer_id="github") + _state = get_state_from_url(url) + _session = self.rph.get_session_information(_state) + client = self.rph.issuer2rp[_session["iss"]] _context = client.get_context() _nonce = _session["nonce"] - _iss = _session['iss'] + _iss = _session["iss"] _aud = _context.get_client_id() - idval = {"nonce": _nonce, "sub": "EndUserSubject", 'iss': _iss, "aud": _aud} + idval = {"nonce": _nonce, "sub": "EndUserSubject", "iss": _iss, "aud": _aud} _github_id = iss_id("github") - _keyjar = _context.upstream_get('attribute', 'keyjar') + _keyjar = _context.upstream_get("attribute", "keyjar") _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) idts = IdToken(**idval) @@ -524,24 +545,25 @@ def test_access_and_id_token(self): ) client.get_service("accesstoken").endpoint = _url - _response = AuthorizationResponse(code="access_code", state=res["state"]) - auth_response = self.rph.finalize_auth(client, _session['iss'], _response.to_dict()) + _response = AuthorizationResponse(code="access_code", state=_state) + auth_response = self.rph.finalize_auth(client, _session["iss"], _response.to_dict()) resp = self.rph.get_access_and_id_token(auth_response, client=client) assert resp["access_token"] == "accessTok" assert isinstance(resp["id_token"], IdToken) def test_access_and_id_token_by_reference(self): - res = self.rph.begin(issuer_id="github") - _session = self.rph.get_session_information(res["state"]) - client = self.rph.issuer2rp[_session['iss']] + url = self.rph.begin(issuer_id="github") + _state = get_state_from_url(url) + _session = self.rph.get_session_information(_state) + client = self.rph.issuer2rp[_session["iss"]] _context = client.get_context() _nonce = _session["nonce"] - _iss = _session['iss'] + _iss = _session["iss"] _aud = _context.get_client_id() - idval = {"nonce": _nonce, "sub": "EndUserSubject", 'iss': _iss, "aud": _aud} + idval = {"nonce": _nonce, "sub": "EndUserSubject", "iss": _iss, "aud": _aud} _github_id = iss_id("github") - _keyjar = _context.upstream_get('attribute', 'keyjar') + _keyjar = _context.upstream_get("attribute", "keyjar") _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) idts = IdToken(**idval) @@ -570,24 +592,25 @@ def test_access_and_id_token_by_reference(self): ) client.get_service("accesstoken").endpoint = _url - _response = AuthorizationResponse(code="access_code", state=res["state"]) - _ = self.rph.finalize_auth(client, _session['iss'], _response.to_dict()) - resp = self.rph.get_access_and_id_token(state=res["state"]) + _response = AuthorizationResponse(code="access_code", state=_state) + _ = self.rph.finalize_auth(client, _session["iss"], _response.to_dict()) + resp = self.rph.get_access_and_id_token(state=_state) assert resp["access_token"] == "accessTok" assert isinstance(resp["id_token"], IdToken) def test_get_user_info(self): - res = self.rph.begin(issuer_id="github") - _session = self.rph.get_session_information(res["state"]) - client = self.rph.issuer2rp[_session['iss']] + url = self.rph.begin(issuer_id="github") + _state = get_state_from_url(url) + _session = self.rph.get_session_information(_state) + client = self.rph.issuer2rp[_session["iss"]] _context = client.get_context() _nonce = _session["nonce"] - _iss = _session['iss'] + _iss = _session["iss"] _aud = _context.get_client_id() - idval = {"nonce": _nonce, "sub": "EndUserSubject", 'iss': _iss, "aud": _aud} + idval = {"nonce": _nonce, "sub": "EndUserSubject", "iss": _iss, "aud": _aud} _github_id = iss_id("github") - _keyjar = _context.upstream_get('attribute', 'keyjar') + _keyjar = _context.upstream_get("attribute", "keyjar") _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) idts = IdToken(**idval) @@ -616,8 +639,8 @@ def test_get_user_info(self): ) client.get_service("accesstoken").endpoint = _url - _response = AuthorizationResponse(code="access_code", state=res["state"]) - auth_response = self.rph.finalize_auth(client, _session['iss'], _response.to_dict()) + _response = AuthorizationResponse(code="access_code", state=_state) + auth_response = self.rph.finalize_auth(client, _session["iss"], _response.to_dict()) token_resp = self.rph.get_access_and_id_token(auth_response, client=client) @@ -632,21 +655,22 @@ def test_get_user_info(self): ) client.get_service("userinfo").endpoint = _url - userinfo_resp = self.rph.get_user_info(res["state"], client, token_resp["access_token"]) + userinfo_resp = self.rph.get_user_info(_state, client, token_resp["access_token"]) assert userinfo_resp def test_userinfo_in_id_token(self): - res = self.rph.begin(issuer_id="github") - _session = self.rph.get_session_information(res["state"]) - client = self.rph.issuer2rp[_session['iss']] + url = self.rph.begin(issuer_id="github") + _state = get_state_from_url(url) + _session = self.rph.get_session_information(_state) + client = self.rph.issuer2rp[_session["iss"]] _context = client.get_context() _nonce = _session["nonce"] - _iss = _session['iss'] + _iss = _session["iss"] _aud = _context.get_client_id() idval = { "nonce": _nonce, "sub": "EndUserSubject", - 'iss': _iss, + "iss": _iss, "aud": _aud, "given_name": "Diana", "family_name": "Krall", @@ -666,20 +690,22 @@ def test_get_provider_specific_service(): class TestRPHandlerTier2(object): + @pytest.fixture(autouse=True) def rphandler_setup(self): self.rph = RPHandler(BASE_URL, CLIENT_CONFIG, keyjar=CLI_KEY) - res = self.rph.begin(issuer_id="github") - _session = self.rph.get_session_information(res["state"]) - client = self.rph.issuer2rp[_session['iss']] + url = self.rph.begin(issuer_id="github") + _state = get_state_from_url(url) + _session = self.rph.get_session_information(_state) + client = self.rph.issuer2rp[_session["iss"]] _context = client.get_context() _nonce = _session["nonce"] - _iss = _session['iss'] + _iss = _session["iss"] _aud = _context.get_client_id() - idval = {"nonce": _nonce, "sub": "EndUserSubject", 'iss': _iss, "aud": _aud} + idval = {"nonce": _nonce, "sub": "EndUserSubject", "iss": _iss, "aud": _aud} _github_id = iss_id("github") - _keyjar = _context.upstream_get('attribute', 'keyjar') + _keyjar = _context.upstream_get("attribute", "keyjar") _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) idts = IdToken(**idval) @@ -710,8 +736,8 @@ def rphandler_setup(self): client.get_service("accesstoken").endpoint = _url - _response = AuthorizationResponse(code="access_code", state=res["state"]) - auth_response = self.rph.finalize_auth(client, _session['iss'], _response.to_dict()) + _response = AuthorizationResponse(code="access_code", state=_state) + auth_response = self.rph.finalize_auth(client, _session["iss"], _response.to_dict()) token_resp = self.rph.get_access_and_id_token(auth_response, client=client) @@ -726,20 +752,20 @@ def rphandler_setup(self): ) client.get_service("userinfo").endpoint = _url - self.rph.get_user_info(res["state"], client, token_resp["access_token"]) - self.state = res["state"] + self.rph.get_user_info(_state, client, token_resp["access_token"]) + self.state = _state def test_init_authorization(self): _session = self.rph.get_session_information(self.state) - client = self.rph.issuer2rp[_session['iss']] - res = self.rph.init_authorization(client, req_args={"scope": ["openid", "email"]}) - part = urlsplit(res["url"]) + client = self.rph.issuer2rp[_session["iss"]] + _url = self.rph.init_authorization(client, req_args={"scope": ["openid", "email"]}) + part = urlsplit(_url) _qp = parse_qs(part.query) assert _qp["scope"] == ["openid email"] def test_refresh_access_token(self): _session = self.rph.get_session_information(self.state) - client = self.rph.issuer2rp[_session['iss']] + client = self.rph.issuer2rp[_session["iss"]] _info = {"access_token": "2nd_accessTok", "token_type": "Bearer", "expires_in": 3600} at = AccessTokenResponse(**_info) @@ -759,7 +785,7 @@ def test_refresh_access_token(self): def test_get_user_info(self): _session = self.rph.get_session_information(self.state) - client = self.rph.issuer2rp[_session['iss']] + client = self.rph.issuer2rp[_session["iss"]] _url = "https://github.com/userinfo" with responses.RequestsMock() as rsps: @@ -786,6 +812,7 @@ def test_get_valid_access_token(self): class MockResponse: + def __init__(self, status_code, text, headers=None): self.status_code = status_code self.text = text @@ -793,6 +820,7 @@ def __init__(self, status_code, text, headers=None): class MockOP(object): + def __init__(self, issuer, keyjar=None): self.keyjar = keyjar self.issuer = issuer @@ -841,7 +869,7 @@ def __call__(self, url, method="GET", data=None, headers=None, **kwargs): def construct_access_token_response(nonce, issuer, client_id, key_jar): _aud = client_id - idval = {"nonce": nonce, "sub": "EndUserSubject", 'iss': issuer, "aud": _aud} + idval = {"nonce": nonce, "sub": "EndUserSubject", "iss": issuer, "aud": _aud} idts = IdToken(**idval) _signed_jwt = idts.to_jwt( @@ -868,23 +896,20 @@ def registration_callback(data): def test_rphandler_request_uri(): rph = RPHandler(BASE_URL, CLIENT_CONFIG, keyjar=CLI_KEY) - res = rph.begin(issuer_id="github2", behaviour_args={"request_param": "request_uri"}) - _session = rph.get_session_information(res["state"]) - _url = res["url"] + _url = rph.begin(issuer_id="github2", behaviour_args={"request_param": "request_uri"}) _qp = parse_qs(urlparse(_url).query) assert "request_uri" in _qp def test_rphandler_request(): rph = RPHandler(BASE_URL, CLIENT_CONFIG, keyjar=CLI_KEY) - res = rph.begin(issuer_id="github2", behaviour_args={"request_param": "request"}) - _session = rph.get_session_information(res["state"]) - _url = res["url"] + _url = rph.begin(issuer_id="github2", behaviour_args={"request_param": "request"}) _qp = parse_qs(urlparse(_url).query) assert "request" in _qp class TestRPHandlerWithMockOP(object): + @pytest.fixture(autouse=True) def rphandler_setup(self): self.issuer = "https://github.com/login/oauth/authorize" @@ -892,9 +917,10 @@ def rphandler_setup(self): self.rph = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG, keyjar=CLI_KEY) def test_finalize(self): - auth_query = self.rph.begin(issuer_id="github") + url = self.rph.begin(issuer_id="github") + _state = get_state_from_url(url) # The authorization query is sent and after successful authentication - client = self.rph.get_client_from_session_key(state=auth_query["state"]) + client = self.rph.get_client_from_session_key(state=_state) # register a response _url = CLIENT_CONFIG["github"]["provider_info"]["authorization_endpoint"] with responses.RequestsMock() as rsps: @@ -903,10 +929,10 @@ def test_finalize(self): _url, status=302, ) - _ = client.httpc("GET", auth_query["url"]) + _ = client.httpc("GET", url) # the user is redirected back to the RP with a positive response - auth_response = AuthorizationResponse(code="access_code", state=auth_query["state"]) + auth_response = AuthorizationResponse(code="access_code", state=_state) # need session information and the client instance _session = self.rph.get_session_information(auth_response["state"]) @@ -926,7 +952,7 @@ def test_finalize(self): sub="EndUserSubject", given_name="Diana", family_name="Krall", occupation="Jazz pianist" ) _github_id = iss_id("github") - _keyjar = client.get_attribute('keyjar') + _keyjar = client.get_attribute("keyjar") _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) with responses.RequestsMock() as rsps: rsps.add( @@ -946,9 +972,10 @@ def test_finalize(self): # do the rest (= get access token and user info) # assume code flow - resp = self.rph.finalize(_session['iss'], auth_response.to_dict()) + resp = self.rph.finalize(_session["iss"], auth_response.to_dict()) - assert set(resp.keys()) == {"userinfo", "state", "token", "id_token", "session_state"} + assert set(resp.keys()) == {'token', 'session_state', 'userinfo', 'state', 'issuer', + 'id_token'} def test_dynamic_setup(self): user_id = "acct:foobar@example.com" @@ -983,16 +1010,24 @@ def test_dynamic_setup(self): "request_object_algs_supported": ["HS256", "RS256", "A128CBC", "A128KW", "RSA1_5"], } pcr = ProviderConfigurationResponse(**resp) - _crr = {"application_type": "web", "response_types": ["code", "code id_token"], - "redirect_uris": [ - "https://example.com/rp/authz_cb" - "/7b7308fecf10c90b29303b6ae35ad1ef0f1914e49187f163335ae0b26a769e4f"], - "grant_types": ["authorization_code", "implicit"], "contacts": ["ops@example.com"], - "subject_type": "public", "id_token_signed_response_alg": "RS256", - "userinfo_signed_response_alg": "RS256", "request_object_signing_alg": "RS256", - "token_endpoint_auth_signing_alg": "RS256", "default_max_age": 86400, - "token_endpoint_auth_method": "client_secret_basic"} - _crr.update({'client_id':'abcdefghijkl', 'client_secret':rndstr(32)}) + _crr = { + "application_type": "web", + "response_types": ["code", "code id_token"], + "redirect_uris": [ + "https://example.com/rp/authz_cb" + "/7b7308fecf10c90b29303b6ae35ad1ef0f1914e49187f163335ae0b26a769e4f" + ], + "grant_types": ["authorization_code", "implicit"], + "contacts": ["ops@example.com"], + "subject_type": "public", + "id_token_signed_response_alg": "RS256", + "userinfo_signed_response_alg": "RS256", + "request_object_signing_alg": "RS256", + "token_endpoint_auth_signing_alg": "RS256", + "default_max_age": 86400, + "token_endpoint_auth_method": "client_secret_basic", + } + _crr.update({"client_id": "abcdefghijkl", "client_secret": rndstr(32)}) cli_reg_resp = RegistrationResponse(**_crr) with responses.RequestsMock() as rsps: rsps.add( diff --git a/tests/test_client_30_rph_defaults.py b/tests/test_client_30_rph_defaults.py index dbef6550..30dadd6e 100644 --- a/tests/test_client_30_rph_defaults.py +++ b/tests/test_client_30_rph_defaults.py @@ -49,10 +49,10 @@ def test_init_client(self): 'userinfo_encryption_alg_values_supported', 'userinfo_encryption_enc_values_supported'} - _keyjar = client.get_attribute('keyjar') + _keyjar = client.get_attribute("keyjar") assert list(_keyjar.owners()) == ["", BASE_URL] keys = _keyjar.get_issuer_keys("") - assert len(keys) == 2 + assert len(keys) == 4 assert _context.base_url == BASE_URL @@ -96,36 +96,34 @@ def test_begin(self): self.rph.issuer2rp[issuer] = client - assert set(_context.claims.use.keys()) == {'application_type', - 'callback_uris', - 'client_id', - 'client_secret', - 'default_max_age', - 'encrypt_request_object_supported', - 'encrypt_userinfo_supported', - 'grant_types', - 'id_token_signed_response_alg', - 'jwks_uri', - 'redirect_uris', - 'request_object_signing_alg', - 'response_modes_supported', - 'response_types', - 'scope', - 'subject_type', - 'token_endpoint_auth_method', - 'token_endpoint_auth_signing_alg', - 'userinfo_signed_response_alg'} + assert set(_context.claims.use.keys()) == { + "application_type", + "callback_uris", + "client_id", + "client_secret", + "default_max_age", + "encrypt_request_object_supported", + "grant_types", + "id_token_signed_response_alg", + "jwks_uri", + "redirect_uris", + "request_object_signing_alg", + "response_modes", + "response_types", + "scope", + "subject_type", + "token_endpoint_auth_method", + "token_endpoint_auth_signing_alg", + } assert _context.get_client_id() == "client uno" assert _context.get_usage("client_secret") == "VerySecretAndLongEnough" assert _context.get("issuer") == ISS_ID - res = self.rph.init_authorization(client) - assert set(res.keys()) == {"url", "state"} - p = urlparse(res["url"]) + url = self.rph.init_authorization(client) + p = urlparse(url) assert p.hostname == "op.example.org" assert p.path == "/authorization" qs = parse_qs(p.query) - assert qs["state"] == [res["state"]] # PKCE stuff assert "code_challenge" in qs assert qs["code_challenge_method"] == ["S256"] diff --git a/tests/test_client_31_oauth2_persistent.py b/tests/test_client_31_oauth2_persistent.py index 16b275bf..8469f9e5 100644 --- a/tests/test_client_31_oauth2_persistent.py +++ b/tests/test_client_31_oauth2_persistent.py @@ -111,9 +111,7 @@ def test_construct_refresh_token_request(self): client_1.get_context().load(_state_dump) req_args = {} - msg = client_1.get_service("refresh_token").construct( - request_args=req_args, state=_state - ) + msg = client_1.get_service("refresh_token").construct(request_args=req_args, state=_state) assert isinstance(msg, RefreshAccessTokenRequest) assert msg.to_dict() == { "client_id": "client_1", diff --git a/tests/test_client_32_oidc_persistent.py b/tests/test_client_32_oidc_persistent.py index 0f5c34ae..3744fc60 100755 --- a/tests/test_client_32_oidc_persistent.py +++ b/tests/test_client_32_oidc_persistent.py @@ -67,9 +67,7 @@ def test_construct_accesstoken_request(self): # Bind access code to state req_args = {} - msg = client_2.get_service("accesstoken").construct( - request_args=req_args, state=_state - ) + msg = client_2.get_service("accesstoken").construct(request_args=req_args, state=_state) assert isinstance(msg, AccessTokenRequest) assert msg.to_dict() == { "client_id": "client_1", @@ -89,7 +87,7 @@ def test_construct_refresh_token_request(self): redirect_uri="https://example.com/cli/authz_cb", state=_state ) - client_1.get_context().cstate.update(_state,auth_request) + client_1.get_context().cstate.update(_state, auth_request) # Client 2 carries on client_2 = RP(config=CONF) @@ -100,16 +98,14 @@ def test_construct_refresh_token_request(self): client_2.get_context().cstate.update(_state, auth_response) token_response = AccessTokenResponse(refresh_token="refresh_with_me", access_token="access") - client_2.get_context().cstate.update(_state,token_response ) + client_2.get_context().cstate.update(_state, token_response) # Back to Client 1 _state_dump = client_2.get_context().dump() client_1.get_context().load(_state_dump) req_args = {} - msg = client_1.get_service("refresh_token").construct( - request_args=req_args, state=_state - ) + msg = client_1.get_service("refresh_token").construct(request_args=req_args, state=_state) assert isinstance(msg, RefreshAccessTokenRequest) assert msg.to_dict() == { "client_id": "client_1", @@ -133,10 +129,10 @@ def test_do_userinfo_request_init(self): client_2.get_context().load(_state_dump) auth_response = AuthorizationResponse(code="access_code") - client_2.get_context().cstate.update(_state,auth_response) + client_2.get_context().cstate.update(_state, auth_response) token_response = AccessTokenResponse(refresh_token="refresh_with_me", access_token="access") - client_2.get_context().cstate.update(_state,token_response) + client_2.get_context().cstate.update(_state, token_response) # Back to Client 1 _state_dump = client_2.get_context().dump() diff --git a/tests/test_client_40_dpop.py b/tests/test_client_40_dpop.py index f96661e5..8a29167d 100644 --- a/tests/test_client_40_dpop.py +++ b/tests/test_client_40_dpop.py @@ -87,9 +87,7 @@ def create_client(self): } services = { - "discovery": { - "class": "idpyoidc.client.oauth2.server_metadata.ServerMetadata" - }, + "discovery": {"class": "idpyoidc.client.oauth2.server_metadata.ServerMetadata"}, "authorization": {"class": "idpyoidc.client.oauth2.authorization.Authorization"}, "access_token": {"class": "idpyoidc.client.oauth2.access_token.AccessToken"}, "refresh_access_token": { diff --git a/tests/test_client_41_rp_handler_persistent.py b/tests/test_client_41_rp_handler_persistent.py index db08eafe..be07ad4a 100644 --- a/tests/test_client_41_rp_handler_persistent.py +++ b/tests/test_client_41_rp_handler_persistent.py @@ -2,8 +2,8 @@ from urllib.parse import parse_qs from urllib.parse import urlsplit -from cryptojwt.key_jar import init_key_jar import responses +from cryptojwt.key_jar import init_key_jar from idpyoidc.client.rp_handler import RPHandler from idpyoidc.message.oidc import AccessTokenResponse @@ -51,7 +51,7 @@ "client_id": "xxxxxxx", "client_secret": "yyyyyyyyyyyyyyyyyyyy", "redirect_uris": ["{}/authz_cb/linkedin".format(BASE_URL)], - 'client_type': 'oauth2', + "client_type": "oauth2", "preference": { "response_types": ["code"], "scope": ["r_basicprofile", "r_emailaddress"], @@ -168,7 +168,14 @@ def iss_id(iss): return CLIENT_CONFIG[iss]["issuer"] +def get_state_from_url(url): + p = urlsplit(url) + qp = parse_qs(p.query) + return qp["state"][0] + + class TestRPHandler(object): + def test_pick_config(self): rph_1 = RPHandler( BASE_URL, client_configs=CLIENT_CONFIG, keyjar=CLI_KEY, module_dirs=["oidc"] @@ -241,10 +248,10 @@ def test_do_client_setup(self): assert _context.get_usage("client_secret") == "aaaaaaaaaaaaaaaaaaaa" assert _context.get("issuer") == _github_id - _keyjar = _context.upstream_get('attribute', 'keyjar') + _keyjar = _context.upstream_get("attribute", "keyjar") _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) - assert set(_keyjar.owners()) == {"", 'eeeeeeeee', _github_id} + assert set(_keyjar.owners()) == {"", "eeeeeeeee", _github_id} keys = _keyjar.get_issuer_keys("") assert len(keys) == 3 # one symmetric, one RSA and one EC @@ -253,22 +260,19 @@ def test_do_client_setup(self): _endp = client.get_context().get("provider_info")[_srv.endpoint_name] assert _srv.endpoint == _endp - assert rph_1.hash2issuer["github"] == _context.get("issuer") - def test_begin(self): rph_1 = RPHandler( BASE_URL, client_configs=CLIENT_CONFIG, keyjar=CLI_KEY, module_dirs=["oidc"] ) - res = rph_1.begin(issuer_id="github") - assert set(res.keys()) == {"url", "state"} + url = rph_1.begin(issuer_id="github") _github_id = iss_id("github") client = rph_1.issuer2rp[_github_id] assert client.get_context().get("issuer") == _github_id - part = urlsplit(res["url"]) + part = urlsplit(url) assert part.scheme == "https" assert part.netloc == "github.com" assert part.path == "/login/oauth/authorize" @@ -294,8 +298,8 @@ def test_get_session_information(self): BASE_URL, client_configs=CLIENT_CONFIG, keyjar=CLI_KEY, module_dirs=["oidc"] ) - res = rph_1.begin(issuer_id="github") - _session = rph_1.get_session_information(res["state"]) + url = rph_1.begin(issuer_id="github") + _session = rph_1.get_session_information(get_state_from_url(url)) assert rph_1.client_configs["github"]["issuer"] == _session["iss"] def test_get_client_from_session_key(self): @@ -303,49 +307,58 @@ def test_get_client_from_session_key(self): BASE_URL, client_configs=CLIENT_CONFIG, keyjar=CLI_KEY, module_dirs=["oidc"] ) - res = rph_1.begin(issuer_id="linkedin") - cli1 = rph_1.get_client_from_session_key(state=res["state"]) - _session = rph_1.get_session_information(res["state"]) + url = rph_1.begin(issuer_id="linkedin") + _state = get_state_from_url(url) + cli1 = rph_1.get_client_from_session_key(state=_state) + _session = rph_1.get_session_information(_state) cli2 = rph_1.issuer2rp[_session["iss"]] assert cli1 == cli2 # redo - rph_1.do_provider_info(state=res["state"]) + rph_1.do_provider_info(state=_state) # get new redirect_uris cli2.get_context().set_usage("redirect_uris", []) - rph_1.do_client_registration(state=res["state"]) + rph_1.do_client_registration(state=_state) def test_finalize_auth(self): rph_1 = RPHandler( BASE_URL, client_configs=CLIENT_CONFIG, keyjar=CLI_KEY, module_dirs=["oidc"] ) - res = rph_1.begin(issuer_id="linkedin") - _session = rph_1.get_session_information(res["state"]) + url = rph_1.begin(issuer_id="linkedin") + _state = get_state_from_url(url) + _session = rph_1.get_session_information(_state) client = rph_1.issuer2rp[_session["iss"]] - auth_response = AuthorizationResponse(code="access_code", state=res["state"]) + auth_response = AuthorizationResponse(code="access_code", state=_state) resp = rph_1.finalize_auth(client, _session["iss"], auth_response.to_dict()) assert set(resp.keys()) == {"state", "code"} - aresp = ( - client.get_service("authorization").upstream_get("context").cstate.get(res["state"]) - ) + aresp = client.get_service("authorization").upstream_get("context").cstate.get(_state) assert set(aresp.keys()) == { - "state", "code", 'iss', 'client_id', - 'scope', 'nonce', 'response_type', 'redirect_uri'} + "state", + "code", + "iss", + "client_id", + "scope", + "nonce", + "response_type", + "redirect_uri", + } def test_get_client_authn_method(self): rph_1 = RPHandler( BASE_URL, client_configs=CLIENT_CONFIG, keyjar=CLI_KEY, module_dirs=["oidc"] ) - res = rph_1.begin(issuer_id="github") - _session = rph_1.get_session_information(res["state"]) + url = rph_1.begin(issuer_id="github") + _state = get_state_from_url(url) + _session = rph_1.get_session_information(_state) client = rph_1.issuer2rp[_session["iss"]] authn_method = rph_1.get_client_authn_method(client, "token_endpoint") assert authn_method == "" - res = rph_1.begin(issuer_id="linkedin") - _session = rph_1.get_session_information(res["state"]) + url = rph_1.begin(issuer_id="linkedin") + _state = get_state_from_url(url) + _session = rph_1.get_session_information(_state) client = rph_1.issuer2rp[_session["iss"]] authn_method = rph_1.get_client_authn_method(client, "token_endpoint") assert authn_method == "client_secret_post" @@ -355,13 +368,14 @@ def test_get_tokens(self): BASE_URL, client_configs=CLIENT_CONFIG, keyjar=CLI_KEY, module_dirs=["oidc"] ) - res = rph_1.begin(issuer_id="github") - _session = rph_1.get_session_information(res["state"]) + url = rph_1.begin(issuer_id="github") + _state = get_state_from_url(url) + _session = rph_1.get_session_information(_state) client = rph_1.issuer2rp[_session["iss"]] _github_id = iss_id("github") _context = client.get_context() - _keyjar = _context.upstream_get('attribute', 'keyjar') + _keyjar = _context.upstream_get("attribute", "keyjar") _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) _nonce = _session["nonce"] @@ -393,10 +407,10 @@ def test_get_tokens(self): ) client.get_service("accesstoken").endpoint = _url - auth_response = AuthorizationResponse(code="access_code", state=res["state"]) + auth_response = AuthorizationResponse(code="access_code", state=_state) resp = rph_1.finalize_auth(client, _session["iss"], auth_response.to_dict()) - resp = rph_1.get_tokens(res["state"], client) + resp = rph_1.get_tokens(_state, client) assert set(resp.keys()) == { "access_token", "expires_in", @@ -409,23 +423,23 @@ def test_get_tokens(self): atresp = ( client.get_service("accesstoken") .upstream_get("service_context") - .cstate.get(res["state"]) + .cstate.get(_state) ) assert set(atresp.keys()) == { "__expires_at", "__verified_id_token", "access_token", - 'client_id', - 'code', + "client_id", + "code", "expires_in", "id_token", - 'iss', - 'nonce', - 'redirect_uri', - 'response_type', - 'scope', - 'state', - "token_type" + "iss", + "nonce", + "redirect_uri", + "response_type", + "scope", + "state", + "token_type", } def test_access_and_id_token(self): @@ -433,8 +447,9 @@ def test_access_and_id_token(self): BASE_URL, client_configs=CLIENT_CONFIG, keyjar=CLI_KEY, module_dirs=["oidc"] ) - res = rph_1.begin(issuer_id="github") - _session = rph_1.get_session_information(res["state"]) + url = rph_1.begin(issuer_id="github") + _state = get_state_from_url(url) + _session = rph_1.get_session_information(_state) client = rph_1.issuer2rp[_session["iss"]] _context = client.get_context() _nonce = _session["nonce"] @@ -443,7 +458,7 @@ def test_access_and_id_token(self): idval = {"nonce": _nonce, "sub": "EndUserSubject", "iss": _iss, "aud": _aud} _github_id = iss_id("github") - _keyjar = _context.upstream_get('attribute', 'keyjar') + _keyjar = _context.upstream_get("attribute", "keyjar") _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) idts = IdToken(**idval) @@ -472,7 +487,7 @@ def test_access_and_id_token(self): ) client.get_service("accesstoken").endpoint = _url - _response = AuthorizationResponse(code="access_code", state=res["state"]) + _response = AuthorizationResponse(code="access_code", state=_state) auth_response = rph_1.finalize_auth(client, _session["iss"], _response.to_dict()) resp = rph_1.get_access_and_id_token(auth_response, client=client) assert resp["access_token"] == "accessTok" @@ -483,8 +498,9 @@ def test_access_and_id_token_by_reference(self): BASE_URL, client_configs=CLIENT_CONFIG, keyjar=CLI_KEY, module_dirs=["oidc"] ) - res = rph_1.begin(issuer_id="github") - _session = rph_1.get_session_information(res["state"]) + url = rph_1.begin(issuer_id="github") + _state = get_state_from_url(url) + _session = rph_1.get_session_information(_state) client = rph_1.issuer2rp[_session["iss"]] _context = client.get_context() _nonce = _session["nonce"] @@ -493,7 +509,7 @@ def test_access_and_id_token_by_reference(self): idval = {"nonce": _nonce, "sub": "EndUserSubject", "iss": _iss, "aud": _aud} _github_id = iss_id("github") - _keyjar = _context.upstream_get('attribute', 'keyjar') + _keyjar = _context.upstream_get("attribute", "keyjar") _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) idts = IdToken(**idval) @@ -522,9 +538,9 @@ def test_access_and_id_token_by_reference(self): ) client.get_service("accesstoken").endpoint = _url - _response = AuthorizationResponse(code="access_code", state=res["state"]) + _response = AuthorizationResponse(code="access_code", state=_state) _ = rph_1.finalize_auth(client, _session["iss"], _response.to_dict()) - resp = rph_1.get_access_and_id_token(state=res["state"]) + resp = rph_1.get_access_and_id_token(state=_state) assert resp["access_token"] == "accessTok" assert isinstance(resp["id_token"], IdToken) @@ -533,8 +549,9 @@ def test_get_user_info(self): BASE_URL, client_configs=CLIENT_CONFIG, keyjar=CLI_KEY, module_dirs=["oidc"] ) - res = rph_1.begin(issuer_id="github") - _session = rph_1.get_session_information(res["state"]) + url = rph_1.begin(issuer_id="github") + _state = get_state_from_url(url) + _session = rph_1.get_session_information(_state) client = rph_1.issuer2rp[_session["iss"]] _context = client.get_context() _nonce = _session["nonce"] @@ -543,7 +560,7 @@ def test_get_user_info(self): idval = {"nonce": _nonce, "sub": "EndUserSubject", "iss": _iss, "aud": _aud} _github_id = iss_id("github") - _keyjar = _context.upstream_get('attribute', 'keyjar') + _keyjar = _context.upstream_get("attribute", "keyjar") _keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) idts = IdToken(**idval) @@ -572,7 +589,7 @@ def test_get_user_info(self): ) client.get_service("accesstoken").endpoint = _url - _response = AuthorizationResponse(code="access_code", state=res["state"]) + _response = AuthorizationResponse(code="access_code", state=_state) auth_response = rph_1.finalize_auth(client, _session["iss"], _response.to_dict()) token_resp = rph_1.get_access_and_id_token(auth_response, client=client) @@ -588,7 +605,7 @@ def test_get_user_info(self): ) client.get_service("userinfo").endpoint = _url - userinfo_resp = rph_1.get_user_info(res["state"], client, token_resp["access_token"]) + userinfo_resp = rph_1.get_user_info(_state, client, token_resp["access_token"]) assert userinfo_resp def test_userinfo_in_id_token(self): @@ -596,8 +613,9 @@ def test_userinfo_in_id_token(self): BASE_URL, client_configs=CLIENT_CONFIG, keyjar=CLI_KEY, module_dirs=["oidc"] ) - res = rph_1.begin(issuer_id="github") - _session = rph_1.get_session_information(res["state"]) + url = rph_1.begin(issuer_id="github") + _state = get_state_from_url(url) + _session = rph_1.get_session_information(_state) client = rph_1.issuer2rp[_session["iss"]] # _context = client.client_get("service_context") _nonce = _session["nonce"] diff --git a/tests/test_client_51_identity_assurance.py b/tests/test_client_51_identity_assurance.py index f7fbf39f..671cc9ae 100644 --- a/tests/test_client_51_identity_assurance.py +++ b/tests/test_client_51_identity_assurance.py @@ -72,7 +72,7 @@ def test_unpack_aggregated_response(self): }, } - _jwt = JWT(key_jar=self.service.upstream_get("attribute",'keyjar')) + _jwt = JWT(key_jar=self.service.upstream_get("attribute", "keyjar")) _jws = _jwt.pack(payload=_distributed_respone) resp = { diff --git a/tests/test_client_55_token_exchange.py b/tests/test_client_55_token_exchange.py index 976d3b6a..d2951070 100644 --- a/tests/test_client_55_token_exchange.py +++ b/tests/test_client_55_token_exchange.py @@ -52,21 +52,16 @@ def create_request(self): "requests_dir": "requests", "base_url": "https://example.com/cli/", } - entity = Entity(keyjar=make_keyjar(), config=client_config, - services={ - "discovery": { - "class": - "idpyoidc.client.oauth2.server_metadata.ServerMetadata"}, - "authorization": { - "class": "idpyoidc.client.oauth2.authorization.Authorization"}, - "access_token": { - "class": "idpyoidc.client.oauth2.access_token.AccessToken"}, - "token_exchange": { - "class": - "idpyoidc.client.oauth2.token_exchange.TokenExchange" - }, - } - ) + entity = Entity( + keyjar=make_keyjar(), + config=client_config, + services={ + "discovery": {"class": "idpyoidc.client.oauth2.server_metadata.ServerMetadata"}, + "authorization": {"class": "idpyoidc.client.oauth2.authorization.Authorization"}, + "access_token": {"class": "idpyoidc.client.oauth2.access_token.AccessToken"}, + "token_exchange": {"class": "idpyoidc.client.oauth2.token_exchange.TokenExchange"}, + }, + ) entity.get_context().issuer = "https://example.com" self.service = entity.get_service("token_exchange") _cstate = self.service.upstream_get("context").cstate @@ -79,8 +74,9 @@ def create_request(self): ver_idt = IdToken().from_jwt(idt, make_keyjar()) - token_response = AccessTokenResponse(access_token="access_token", id_token=idt, - __verified_id_token=ver_idt) + token_response = AccessTokenResponse( + access_token="access_token", id_token=idt, __verified_id_token=ver_idt + ) _cstate.update("abcde", token_response) def test_construct(self): diff --git a/tests/test_server_01_claims.py b/tests/test_server_01_claims.py index 9162e329..9ca4ba6e 100644 --- a/tests/test_server_01_claims.py +++ b/tests/test_server_01_claims.py @@ -139,10 +139,11 @@ def create_idtoken(self): "add_claims": { "always": {}, }, - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], } - self.server.get_attribute('keyjar').add_symmetric("client_1", "hemligtochintekort", - ["sig", "enc"]) + self.server.get_attribute("keyjar").add_symmetric( + "client_1", "hemligtochintekort", ["sig", "enc"] + ) self.claims_interface = self.context.claims_interface self.user_id = USER_ID diff --git a/tests/test_server_03_authz_handling.py b/tests/test_server_03_authz_handling.py index 4edc5c92..d5effeef 100644 --- a/tests/test_server_03_authz_handling.py +++ b/tests/test_server_03_authz_handling.py @@ -132,9 +132,9 @@ def create_idtoken(self): "client_salt": "salted", "token_endpoint_auth_method": "client_secret_post", "response_types": ["code", "token", "code id_token", "id_token"], - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], } - server.get_attribute('keyjar').add_symmetric( + server.get_attribute("keyjar").add_symmetric( "client_1", "hemligtochintekort", ["sig", "enc"] ) server.endpoint = do_endpoints(conf, server.upstream_get) diff --git a/tests/test_server_06_grant.py b/tests/test_server_06_grant.py index f56a80e0..2f163811 100644 --- a/tests/test_server_06_grant.py +++ b/tests/test_server_06_grant.py @@ -439,7 +439,7 @@ def test_get_usage_rules(self): # Default usage rules self.context.cdb["client_id"] = {} rules = get_usage_rules("access_token", self.context, grant, "client_id") - assert rules == {"supports_minting": [], "expires_in": 3600} + assert rules == {"expires_in": 3600} # client specific usage rules self.context.cdb["client_id"] = {"access_token": {"expires_in": 600}} diff --git a/tests/test_server_07_sess_mngm_db.py b/tests/test_server_07_sess_mngm_db.py index 80aaaff3..746889a4 100644 --- a/tests/test_server_07_sess_mngm_db.py +++ b/tests/test_server_07_sess_mngm_db.py @@ -25,13 +25,17 @@ class TestDB: @pytest.fixture(autouse=True) def setup_environment(self): - self.db = Database(crypt_config=CRYPT_CONFIG, - session_params={"node_type": ["user", "client", "grant"], - "node_info_class": { - "user": UserSessionInfo, - "client": ClientSessionInfo, - "grant": Grant} - }) + self.db = Database( + crypt_config=CRYPT_CONFIG, + session_params={ + "node_type": ["user", "client", "grant"], + "node_info_class": { + "user": UserSessionInfo, + "client": ClientSessionInfo, + "grant": Grant, + }, + }, + ) def test_user_info(self): with pytest.raises(KeyError): @@ -104,13 +108,13 @@ def test_client_info_add2(self): self.db.set(["diana", "client_1", "G1"], grant) stored_client_info = self.db.get(["diana", "client_1"]) - assert isinstance(stored_client_info,ClientSessionInfo) + assert isinstance(stored_client_info, ClientSessionInfo) assert set(stored_client_info.keys()) == { "subordinate", "revoked", "type", "extra_args", - "id" + "id", } stored_grant_info = self.db.get(["diana", "client_1", "G1"]) @@ -126,7 +130,7 @@ def test_jump_ahead(self): user_info = self.db.get(["diana"]) assert user_info.subordinate == ["diana;;client_1"] client_info = self.db.get(["diana", "client_1"]) - assert client_info.subordinate == ['diana;;client_1;;G1'] + assert client_info.subordinate == ["diana;;client_1;;G1"] grant_info = self.db.get(["diana", "client_1", "G1"]) assert grant_info.issued_at assert len(grant_info.issued_token) == 1 diff --git a/tests/test_server_08_id_token.py b/tests/test_server_08_id_token.py index fddf289f..0229b26d 100644 --- a/tests/test_server_08_id_token.py +++ b/tests/test_server_08_id_token.py @@ -173,7 +173,7 @@ def create_session_manager(self): "always": {}, "by_scope": {}, }, - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], } self.server.keyjar.add_symmetric("client_1", "hemligtochintekort", ["sig", "enc"]) self.session_manager = self.context.session_manager @@ -435,11 +435,7 @@ def test_get_sign_algorithm(self): self.context, client_info, "id_token", sign=True, encrypt=True ) # default signing alg - assert algs == { - "sign": True, - "encrypt": True, - "sign_alg": "RS256" - } + assert algs == {"sign": True, "encrypt": True, "sign_alg": "RS256"} def test_available_claims(self): req = dict(AREQ) @@ -506,9 +502,7 @@ def test_client_claims(self): grant = self.session_manager[session_id] self.session_manager.token_handler["id_token"].kwargs["enable_claims_per_client"] = True - self.context.cdb["client_1"]["add_claims"]["always"]["id_token"] = { - "address": None - } + self.context.cdb["client_1"]["add_claims"]["always"]["id_token"] = {"address": None} _claims = self.context.claims_interface.get_claims( session_id=session_id, scopes=AREQ["scope"], claims_release_point="id_token" diff --git a/tests/test_server_09_authn_context.py b/tests/test_server_09_authn_context.py index 0d77920b..4b9a72ea 100644 --- a/tests/test_server_09_authn_context.py +++ b/tests/test_server_09_authn_context.py @@ -65,7 +65,7 @@ "claim_types_supported": ["normal", "aggregated", "distributed"], "claims_parameter_supported": True, "request_parameter_supported": True, - "request_uri_parameter_supported": True, + # "request_uri_parameter_supported": True, } BASEDIR = os.path.abspath(os.path.dirname(__file__)) diff --git a/tests/test_server_10_session_manager.py b/tests/test_server_10_session_manager.py index 1518e7bb..8812a68f 100644 --- a/tests/test_server_10_session_manager.py +++ b/tests/test_server_10_session_manager.py @@ -107,7 +107,14 @@ def create_session_manager(self): }, "refresh_token": {"supports_minting": ["id_token"]}, }, - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] + "allowed_scopes": [ + "openid", + "profile", + "email", + "address", + "phone", + "offline_access", + ], } } @@ -151,7 +158,7 @@ def test_create_session_sub_type(self, sub_type, sector_identifier): ) _user_info_1 = self.session_manager.get_user_session_info(session_key_1) - assert _user_info_1.subordinate == ['diana;;client_1'] + assert _user_info_1.subordinate == ["diana;;client_1"] _client_info_1 = self.session_manager.get_client_session_info(session_key_1) assert len(_client_info_1.subordinate) == 1 # grant = self.session_manager.get_grant(session_key_1) @@ -353,7 +360,7 @@ def test_get_general_session_info(self): "user", "client", "grant", - "branch_id" + "branch_id", } assert _session_info["user_id"] == "diana" assert _session_info["client_id"] == "client_1" @@ -379,7 +386,7 @@ def test_get_session_info_by_token(self): "user_id", "user", "client", - "grant" + "grant", } assert _session_info["user_id"] == "diana" assert _session_info["client_id"] == "client_1" @@ -515,7 +522,7 @@ def test_token_usage_client_config(self): }, "refresh_token": {"supports_minting": ["access_token"]}, }, - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], } token_usage_rules = self.endpoint_context.authz.usage_rules("client_1") @@ -671,7 +678,8 @@ def test_grants(self): grant = self.session_manager[_session_id] grant_kwargs = grant.parameter for i in ("not_before", "used"): - grant_kwargs.pop(i) + if i in grant_kwargs: + del grant_kwargs[i] self.session_manager.add_grant(["diana", "client_1"], **grant_kwargs) def test_find_latest_idtoken(self): diff --git a/tests/test_server_12_session_life.py b/tests/test_server_12_session_life.py index 2bbd3856..d36ecdb6 100644 --- a/tests/test_server_12_session_life.py +++ b/tests/test_server_12_session_life.py @@ -229,7 +229,7 @@ def test_code_flow(self): "claim_types_supported": ["normal", "aggregated", "distributed"], "claims_parameter_supported": True, "request_parameter_supported": True, - "request_uri_parameter_supported": True, + # "request_uri_parameter_supported": True, } BASEDIR = os.path.abspath(os.path.dirname(__file__)) diff --git a/tests/test_server_16_endpoint.py b/tests/test_server_16_endpoint.py index 9ebf8173..5a3b59de 100755 --- a/tests/test_server_16_endpoint.py +++ b/tests/test_server_16_endpoint.py @@ -76,7 +76,7 @@ def create_endpoint(self): } server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - server.context.cdb["client_id"] = {} + server.context.cdb["client_id"] = {"redirect_uris": [("https://example.com/cb", None)]} self.context = server.context _endpoints = do_endpoints(conf, server.unit_get) self.endpoint = _endpoints[""] @@ -108,7 +108,7 @@ def test_parse_dict(self): def test_parse_jwt(self): self.endpoint.request_format = "jwt" - kj = self.endpoint.upstream_get('attribute','keyjar') + kj = self.endpoint.upstream_get("attribute", "keyjar") request = REQ.to_jwt(kj.get_signing_key("RSA"), "RS256") req = self.endpoint.parse_request(request) assert req == REQ diff --git a/tests/test_server_16_endpoint_context.py b/tests/test_server_16_endpoint_context.py index 38d3dc2c..39c1113b 100644 --- a/tests/test_server_16_endpoint_context.py +++ b/tests/test_server_16_endpoint_context.py @@ -46,14 +46,7 @@ class Endpoint_1(Endpoint): "client_secret_basic", ], "subject_types_supported": ["public", "pairwise"], - "endpoint": { - - "userinfo": { - "path": "userinfo", - "class": Endpoint_1, - "kwargs": {} - } - }, + "endpoint": {"userinfo": {"path": "userinfo", "class": Endpoint_1, "kwargs": {}}}, "token_handler_args": { "jwks_def": { "private_path": "private/token_jwks.json", @@ -88,7 +81,6 @@ class Endpoint_1(Endpoint): class TestEndpointContext: - @pytest.fixture(autouse=True) def create_endpoint_context(self): server = Server(conf) @@ -98,13 +90,14 @@ def create_endpoint_context(self): def test(self): self.context.set_provider_info() assert set(self.context.provider_info.keys()) == { - 'id_token_signing_alg_values_supported', - 'issuer', - 'jwks_uri', - 'scopes_supported', - 'subject_types_supported', - 'userinfo_signing_alg_values_supported', - 'version'} + "id_token_signing_alg_values_supported", + "issuer", + "jwks_uri", + "scopes_supported", + "subject_types_supported", + "userinfo_signing_alg_values_supported", + "version", + } class Tokenish(Endpoint): @@ -171,27 +164,35 @@ def test_provider_configuration(kwargs): server.context.cdb["client_id"] = {} server.context.set_provider_info() pi = server.context.provider_info - assert set(pi.keys()) == {'acr_values_supported', - 'id_token_signing_alg_values_supported', - 'issuer', - 'jwks_uri', - 'scopes_supported', - 'subject_types_supported', - 'token_endpoint_auth_methods_supported', - 'version'} + assert set(pi.keys()) == { + "acr_values_supported", + "id_token_signing_alg_values_supported", + "issuer", + "jwks_uri", + "scopes_supported", + "subject_types_supported", + "token_endpoint_auth_methods_supported", + "version", + } if kwargs: - if 'token_endpoint_auth_methods_supported' in kwargs: - assert pi["token_endpoint_auth_methods_supported"] == ['client_secret_jwt', - 'private_key_jwt'] + if "token_endpoint_auth_methods_supported" in kwargs: + assert pi["token_endpoint_auth_methods_supported"] == [ + "client_secret_jwt", + "private_key_jwt", + ] else: - assert pi["token_endpoint_auth_methods_supported"] == ['client_secret_post', - 'client_secret_basic', - 'client_secret_jwt', - 'private_key_jwt'] + assert pi["token_endpoint_auth_methods_supported"] == [ + "client_secret_post", + "client_secret_basic", + "client_secret_jwt", + "private_key_jwt", + ] else: - assert pi["token_endpoint_auth_methods_supported"] == ['client_secret_post', - 'client_secret_basic', - 'client_secret_jwt', - 'private_key_jwt'] + assert pi["token_endpoint_auth_methods_supported"] == [ + "client_secret_post", + "client_secret_basic", + "client_secret_jwt", + "private_key_jwt", + ] diff --git a/tests/test_server_17_client_authn.py b/tests/test_server_17_client_authn.py index 0fe2d533..b329644e 100644 --- a/tests/test_server_17_client_authn.py +++ b/tests/test_server_17_client_authn.py @@ -48,7 +48,9 @@ class Endpoint_2(Endpoint): class Endpoint_3(Endpoint): name = "endpoint_3" - def __init__(self, upstream_get: Callable, add_claims_by_scope: Optional[bool] = True, **kwargs): + def __init__( + self, upstream_get: Callable, add_claims_by_scope: Optional[bool] = True, **kwargs + ): Endpoint.__init__( self, upstream_get, @@ -251,29 +253,21 @@ def test_private_key_jwt_reusage_other_endpoint(self): _jwt = JWT(client_keyjar, iss=client_id, sign_alg="RS256") _jwt.with_jti = True - _assertion = _jwt.pack( - {"aud": [self.server.get_endpoint("endpoint_1").full_path]} - ) + _assertion = _jwt.pack({"aud": [self.server.get_endpoint("endpoint_1").full_path]}) request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} # This should be OK assert self.method.is_usable(request=request) - self.method.verify( - request=request, endpoint=self.server.get_endpoint("endpoint_1") - ) + self.method.verify(request=request, endpoint=self.server.get_endpoint("endpoint_1")) # This should NOT be OK with pytest.raises(InvalidToken): - self.method.verify( - request=request, endpoint=self.server.get_endpoint("authorization") - ) + self.method.verify(request=request, endpoint=self.server.get_endpoint("authorization")) # This should NOT be OK because this is the second time the token appears with pytest.raises(InvalidToken): - self.method.verify( - request=request, endpoint=self.server.get_endpoint("endpoint_1") - ) + self.method.verify(request=request, endpoint=self.server.get_endpoint("endpoint_1")) def test_private_key_jwt_auth_endpoint(self): # Own dynamic keys @@ -286,9 +280,7 @@ def test_private_key_jwt_auth_endpoint(self): _jwt = JWT(client_keyjar, iss=client_id, sign_alg="RS256") _jwt.with_jti = True - _assertion = _jwt.pack( - {"aud": [self.server.get_endpoint("endpoint_2").full_path]} - ) + _assertion = _jwt.pack({"aud": [self.server.get_endpoint("endpoint_2").full_path]}) request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} @@ -337,7 +329,10 @@ def create_method(self): def test_bearer_body(self): request = {"access_token": "1234567890"} - assert self.method.verify(request, get_client_id_from_token=get_client_id_from_token) == {"token": "1234567890", "method": "bearer_body"} + assert self.method.verify(request, get_client_id_from_token=get_client_id_from_token) == { + "token": "1234567890", + "method": "bearer_body", + } def test_bearer_body_no_token(self): request = {} @@ -488,9 +483,7 @@ def test_verify_per_client(self): assert res == {"method": "public", "client_id": client_id} def test_verify_per_client_per_endpoint(self): - self.server.context.cdb[client_id]["registration_endpoint_client_authn_method"] = [ - "public" - ] + self.server.context.cdb[client_id]["registration_endpoint_client_authn_method"] = ["public"] self.server.context.cdb[client_id]["token_endpoint_client_authn_method"] = [ "client_secret_post" ] @@ -710,10 +703,7 @@ class Mock: server.endpoint = do_endpoints(CONF, server.unit_get) request = {"redirect_uris": ["https://example.com/cb"]} - res = verify_client( - request=request, - endpoint=server.get_endpoint("endpoint_4") - ) + res = verify_client(request=request, endpoint=server.get_endpoint("endpoint_4")) assert res == {"client_id": "client_id", "method": "custom"} mock.is_usable.assert_called_once() diff --git a/tests/test_server_20a_server.py b/tests/test_server_20a_server.py index 3f41200f..11e6d6fb 100755 --- a/tests/test_server_20a_server.py +++ b/tests/test_server_20a_server.py @@ -1,16 +1,16 @@ +from copy import copy +from copy import deepcopy import io import json import os -from copy import copy -from copy import deepcopy -import yaml from cryptojwt.key_jar import build_keyjar +import yaml from idpyoidc.server import Server from idpyoidc.server.configure import OPConfiguration from idpyoidc.server.login_hint import LoginHintLookup -from idpyoidc.server.oidc.add_on.pkce import add_pkce_support +from idpyoidc.server.oauth2.add_on.pkce import add_support from idpyoidc.server.oidc.authorization import Authorization from idpyoidc.server.oidc.provider_config import ProviderConfiguration from idpyoidc.server.oidc.registration import Registration @@ -73,7 +73,7 @@ def full_path(local_file): } }, "claims_interface": {"class": "idpyoidc.server.session.claims.ClaimsInterface", "kwargs": {}}, - "add_on": {"pkce": {"function": add_pkce_support, "kwargs": {"essential": True}}}, + "add_on": {"pkce": {"function": add_support, "kwargs": {"essential": True}}}, "template_dir": "template", "login_hint_lookup": {"class": LoginHintLookup, "kwargs": {}}, "session_params": SESSION_PARAMS, @@ -119,16 +119,11 @@ def test_capabilities_default(): server = Server(configuration) assert set(server.context.provider_info["response_types_supported"]) == { "code", - "token", "id_token", - "code token", "code id_token", - "id_token token", - "code id_token token", } assert server.context.provider_info["request_uri_parameter_supported"] is True - assert server.context.get_preference('jwks_uri') == \ - "https://127.0.0.1:443/static/jwks.json" + assert server.context.get_preference("jwks_uri") == "https://127.0.0.1:443/static/jwks.json" def test_capabilities_subset1(): diff --git a/tests/test_server_20b_claims.py b/tests/test_server_20b_claims.py index 1d95fece..f84572ab 100644 --- a/tests/test_server_20b_claims.py +++ b/tests/test_server_20b_claims.py @@ -125,11 +125,9 @@ def create_idtoken(self): "add_claims": { "always": {}, }, - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], } - server.keyjar.add_symmetric( - "client_1", "hemligtochintekort", ["sig", "enc"] - ) + server.keyjar.add_symmetric("client_1", "hemligtochintekort", ["sig", "enc"]) self.claims_interface = server.context.claims_interface self.context = server.context self.session_manager = self.context.session_manager diff --git a/tests/test_server_20c_authz_handling.py b/tests/test_server_20c_authz_handling.py index 797e5450..f4c8ba8f 100644 --- a/tests/test_server_20c_authz_handling.py +++ b/tests/test_server_20c_authz_handling.py @@ -108,11 +108,9 @@ def create_idtoken(self): "client_salt": "salted", "token_endpoint_auth_method": "client_secret_post", "response_types": ["code", "token", "code id_token", "id_token"], - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], } - server.keyjar.add_symmetric( - "client_1", "hemligtochintekort", ["sig", "enc"] - ) + server.keyjar.add_symmetric("client_1", "hemligtochintekort", ["sig", "enc"]) self.session_manager = server.context.session_manager self.user_id = USER_ID self.server = server diff --git a/tests/test_server_20d_client_authn.py b/tests/test_server_20d_client_authn.py index 0c392c5b..21beb359 100755 --- a/tests/test_server_20d_client_authn.py +++ b/tests/test_server_20d_client_authn.py @@ -222,15 +222,11 @@ def test_private_key_jwt_reusage_other_endpoint(self): # This should NOT be OK with pytest.raises(InvalidToken): - self.method.verify( - request=request, endpoint=self.server.get_endpoint("authorization") - ) + self.method.verify(request=request, endpoint=self.server.get_endpoint("authorization")) # This should NOT be OK because this is the second time the token appears with pytest.raises(InvalidToken): - self.method.verify( - request=request, endpoint=self.server.get_endpoint("token") - ) + self.method.verify(request=request, endpoint=self.server.get_endpoint("token")) def test_private_key_jwt_auth_endpoint(self): # Own dynamic keys @@ -243,9 +239,7 @@ def test_private_key_jwt_auth_endpoint(self): _jwt = JWT(client_keyjar, iss=client_id, sign_alg="RS256") _jwt.with_jti = True - _assertion = _jwt.pack( - {"aud": [self.server.get_endpoint("authorization").full_path]} - ) + _assertion = _jwt.pack({"aud": [self.server.get_endpoint("authorization").full_path]}) request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} @@ -293,7 +287,9 @@ def create_method(self): def test_bearer_body(self): request = {"access_token": "1234567890"} assert self.method.verify(request, get_client_id_from_token=get_client_id_from_token) == { - "token": "1234567890", "method": "bearer_body"} + "token": "1234567890", + "method": "bearer_body", + } def test_bearer_body_no_token(self): request = {} @@ -442,9 +438,7 @@ def test_verify_per_client(self): assert res == {"method": "public", "client_id": client_id} def test_verify_per_client_per_endpoint(self): - self.server.context.cdb[client_id]["registration_endpoint_client_authn_method"] = [ - "public" - ] + self.server.context.cdb[client_id]["registration_endpoint_client_authn_method"] = ["public"] self.server.context.cdb[client_id]["token_endpoint_client_authn_method"] = [ "client_secret_post" ] @@ -662,10 +656,7 @@ class Mock: server.context.cdb[client_id] = {"client_secret": client_secret} request = {"redirect_uris": ["https://example.com/cb"]} - res = verify_client( - request=request, - endpoint=server.get_endpoint("registration") - ) + res = verify_client(request=request, endpoint=server.get_endpoint("registration")) assert res == {"client_id": "client_id", "method": "custom"} mock.is_usable.assert_called_once() diff --git a/tests/test_server_20e_jwt_token.py b/tests/test_server_20e_jwt_token.py index 87ea8209..d824bd15 100644 --- a/tests/test_server_20e_jwt_token.py +++ b/tests/test_server_20e_jwt_token.py @@ -62,7 +62,7 @@ "claim_types_supported": ["normal", "aggregated", "distributed"], "claims_parameter_supported": True, "request_parameter_supported": True, - "request_uri_parameter_supported": True, + # "request_uri_parameter_supported": True, } AUTH_REQ = AuthorizationRequest( @@ -207,7 +207,7 @@ def create_endpoint(self): "always": {}, "by_scope": {}, }, - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], } self.session_manager = self.context.session_manager self.user_id = "diana" @@ -269,9 +269,7 @@ def test_info(self): @pytest.mark.parametrize("enable_claims_per_client", [True, False]) def test_enable_claims_per_client(self, enable_claims_per_client): # Set up configuration - self.context.cdb["client_1"]["add_claims"]["always"]["access_token"] = { - "address": None - } + self.context.cdb["client_1"]["add_claims"]["always"]["access_token"] = {"address": None} self.context.session_manager.token_handler.handler["access_token"].kwargs[ "enable_claims_per_client" ] = enable_claims_per_client @@ -411,7 +409,15 @@ def create_endpoint(self): "always": {}, "by_scope": {}, }, - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access", "webid"] + "allowed_scopes": [ + "openid", + "profile", + "email", + "address", + "phone", + "offline_access", + "webid", + ], } self.session_manager = self.context.session_manager self.user_id = "diana" @@ -517,7 +523,7 @@ def test_mint_with_scope(self): grant, session_id, code, - scope=["openid", 'foobar'], + scope=["openid", "foobar"], aud=["https://audience.example.com"], ) diff --git a/tests/test_server_20f_userinfo.py b/tests/test_server_20f_userinfo.py index 8a5a8a64..544a059d 100644 --- a/tests/test_server_20f_userinfo.py +++ b/tests/test_server_20f_userinfo.py @@ -199,7 +199,7 @@ def create_endpoint_context(self): "always": {}, "by_scope": {}, }, - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], } self.session_manager = self.endpoint_context.session_manager self.claims_interface = ClaimsInterface(server.unit_get) @@ -424,7 +424,15 @@ def create_endpoint_context(self, conf): self.server = Server(conf) self.endpoint_context = self.server.context self.endpoint_context.cdb["client1"] = { - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access", "research_and_scholarship"] + "allowed_scopes": [ + "openid", + "profile", + "email", + "address", + "phone", + "offline_access", + "research_and_scholarship", + ] } self.session_manager = self.endpoint_context.session_manager self.claims_interface = ClaimsInterface(self.server.unit_get) diff --git a/tests/test_server_22_oidc_provider_config_endpoint.py b/tests/test_server_22_oidc_provider_config_endpoint.py index bd5f20a4..7000d724 100755 --- a/tests/test_server_22_oidc_provider_config_endpoint.py +++ b/tests/test_server_22_oidc_provider_config_endpoint.py @@ -57,7 +57,7 @@ def conf(self): return { "issuer": "https://example.com/", "httpc_params": {"verify": False}, - "capabilities": CAPABILITIES, + "preference": CAPABILITIES, "keys": {"uri_path": "static/jwks.json", "key_defs": KEYDEFS}, "endpoint": { "provider_config": { diff --git a/tests/test_server_23_oidc_registration_endpoint.py b/tests/test_server_23_oidc_registration_endpoint.py index 8550d6a1..44f1f2ce 100755 --- a/tests/test_server_23_oidc_registration_endpoint.py +++ b/tests/test_server_23_oidc_registration_endpoint.py @@ -163,7 +163,9 @@ def create_endpoint(self): "session_params": SESSION_PARAMS, } server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) - server.context.cdb["client_id"] = {} + server.context.cdb["client_id"] = { + "redirect_uris": [("https://example.com/cb", None)], + } self.endpoint = server.get_endpoint("registration") def test_parse(self): @@ -339,7 +341,10 @@ def test_register_initiate_login_uri_wrong_scheme(self): assert _resp["error"] == "invalid_configuration_request" def test_register_unsupported_response_type(self): - self.endpoint.upstream_get("context").provider_info["response_types_supported"] = ["token", "id_token"] + self.endpoint.upstream_get("context").provider_info["response_types_supported"] = [ + "token", + "id_token", + ] _msg = MSG.copy() _msg["response_types"] = ["id_token token"] _req = self.endpoint.parse_request(RegistrationRequest(**_msg).to_json()) diff --git a/tests/test_server_24_oauth2_authorization_endpoint.py b/tests/test_server_24_oauth2_authorization_endpoint.py index e5f0a74d..efbf6942 100755 --- a/tests/test_server_24_oauth2_authorization_endpoint.py +++ b/tests/test_server_24_oauth2_authorization_endpoint.py @@ -270,7 +270,7 @@ def create_endpoint(self): self.rp_keyjar = KeyJar() self.rp_keyjar.add_symmetric("client_1", "hemligtkodord1234567890") - self.endpoint.upstream_get("attribute",'keyjar').add_symmetric( + self.endpoint.upstream_get("attribute", "keyjar").add_symmetric( "client_1", "hemligtkodord1234567890" ) @@ -491,7 +491,7 @@ def test_create_authn_response(self): "client_id": "client_id", "redirect_uris": [("https://rp.example.com/cb", {})], "id_token_signed_response_alg": "ES256", - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], } session_id = self._create_session(request) @@ -570,7 +570,7 @@ def test_setup_auth_invalid_scope(self): "client_id": "client_id", "redirect_uris": [("https://rp.example.com/cb", {})], "id_token_signed_response_alg": "RS256", - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], } _context = self.endpoint.upstream_get("context") @@ -708,14 +708,14 @@ def test_req_user_no_prompt(self): def test_unwrap_identity(self): identity = { - 'sid': - 'Z0FBQUFBQmlZQXFBeDlvSjRENVVYSDBFeTZ6YzVQWTRGVy1laFk2ZmJIbWdPeUhzbVJYbWo5clVPQ045MXpiSVYwS0pfZkREaVUwX2VaVU9HMk9hUktxaGR0R0dQMlRLOXVWQWVTYWJMdDFsVWZJUEItWS1NVi1WQXllNEVlYm9KMDJsSmFYU0pLYWVJeVRKZkJCYmE1T2RpWXRPM3ZmanRlMThfLUNvcnd4ZXVxcFBWdDY0M18tbXNzbjFvbGl4OFdJRTF6YTcwQ3dqNjdsRHdUa1V4ZTlZMjU3SVlXaXdSSTVJSFJJNENwand3a2pOdmV2WGFPRGZhSnZma2NkZ01ZZk1iS3hma1phcQ==', - 'state': '80ec120d9a322e70e02503e9a99e734174c1e6cb', - 'timestamp': 1650461312, - 'uid': '6260077f56d8970e543aa380', - 'grant_id': 'c636b820c0ad11ecbdd1acde48001122'} + "sid": "Z0FBQUFBQmlZQXFBeDlvSjRENVVYSDBFeTZ6YzVQWTRGVy1laFk2ZmJIbWdPeUhzbVJYbWo5clVPQ045MXpiSVYwS0pfZkREaVUwX2VaVU9HMk9hUktxaGR0R0dQMlRLOXVWQWVTYWJMdDFsVWZJUEItWS1NVi1WQXllNEVlYm9KMDJsSmFYU0pLYWVJeVRKZkJCYmE1T2RpWXRPM3ZmanRlMThfLUNvcnd4ZXVxcFBWdDY0M18tbXNzbjFvbGl4OFdJRTF6YTcwQ3dqNjdsRHdUa1V4ZTlZMjU3SVlXaXdSSTVJSFJJNENwand3a2pOdmV2WGFPRGZhSnZma2NkZ01ZZk1iS3hma1phcQ==", + "state": "80ec120d9a322e70e02503e9a99e734174c1e6cb", + "timestamp": 1650461312, + "uid": "6260077f56d8970e543aa380", + "grant_id": "c636b820c0ad11ecbdd1acde48001122", + } _id = self.endpoint._unwrap_identity(identity) - assert _id["uid"] == '6260077f56d8970e543aa380' + assert _id["uid"] == "6260077f56d8970e543aa380" # def test_sso(self): # _pr_resp = self.endpoint.parse_request(AUTH_REQ_DICT) diff --git a/tests/test_server_24_oauth2_resource_indicators.py b/tests/test_server_24_oauth2_resource_indicators.py index b991bfd2..a98638ed 100644 --- a/tests/test_server_24_oauth2_resource_indicators.py +++ b/tests/test_server_24_oauth2_resource_indicators.py @@ -38,8 +38,12 @@ from idpyoidc.server.oauth2.authorization import inputs from idpyoidc.server.oauth2.authorization import join_query from idpyoidc.server.oauth2.authorization import verify_uri -from idpyoidc.server.oauth2.authorization import validate_resource_indicators_policy as validate_authorization_resource_indicators_policy -from idpyoidc.server.oauth2.token_helper import validate_resource_indicators_policy as validate_token_resource_indicators_policy +from idpyoidc.server.oauth2.authorization import ( + validate_resource_indicators_policy as validate_authorization_resource_indicators_policy, +) +from idpyoidc.server.oauth2.token_helper import ( + validate_resource_indicators_policy as validate_token_resource_indicators_policy, +) from idpyoidc.server.user_info import UserInfo from idpyoidc.time_util import in_a_while from tests import CRYPT_CONFIG @@ -47,7 +51,7 @@ KEYDEFS = [ {"type": "RSA", "key": "", "use": ["sig"]}, - {"type": "EC", "crv": "P-256", "use": ["sig"]} + {"type": "EC", "crv": "P-256", "use": ["sig"]}, ] COOKIE_KEYDEFS = [ @@ -352,9 +356,7 @@ def get_cookie_value(cookie=None, name=None): "policy": { "function": validate_token_resource_indicators_policy, "kwargs": { - "resource_servers_per_client": { - "client_1": ["client_2", "client_3"] - }, + "resource_servers_per_client": {"client_1": ["client_2", "client_3"]}, }, } }, @@ -410,6 +412,7 @@ def get_cookie_value(cookie=None, name=None): "session_params": SESSION_PARAMS, } + class TestEndpoint(object): @pytest.fixture(autouse=False) def create_endpoint_ri_disabled(self): @@ -482,7 +485,7 @@ def _mint_code(self, grant, client_id): token_class="authorization_code", token_handler=self.session_manager.token_handler["authorization_code"], usage_rules=usage_rules, - resources=grant.resources + resources=grant.resources, ) if _exp_in: @@ -521,7 +524,9 @@ def test_authorization_code_req_no_resource(self, create_endpoint_ri_enabled): assert "error" in msg assert msg["error_description"] == "Missing resource parameter" - def test_authorization_code_req_no_resource_indicators_disabled(self, create_endpoint_ri_disabled): + def test_authorization_code_req_no_resource_indicators_disabled( + self, create_endpoint_ri_disabled + ): """ Test successful authorization request when resource indicators is disabled. """ @@ -552,9 +557,7 @@ def test_authorization_code_req_per_client(self, create_endpoint_ri_disabled): "authorization_code": { "policy": { "function": validate_authorization_resource_indicators_policy, - "kwargs": { - "resource_servers_per_client":["client_3"] - }, + "kwargs": {"resource_servers_per_client": ["client_3"]}, }, }, } @@ -589,7 +592,7 @@ def test_authorization_code_req_invalid_resource_client(self, create_endpoint_ri for the authorization endpoint and requested resource is not permitted for client. """ request = AUTH_REQ.copy() - request["resource"] = "client_2" + request["resource"] = "client_3" client_id = request["client_id"] endpoint_context = self.endpoint.upstream_get("context") @@ -607,7 +610,7 @@ def test_access_token_req(self, create_endpoint_ri_enabled): "client_id": "client_3", "redirect_uris": [("https://rp.example.com/cb", {})], "id_token_signed_response_alg": "ES256", - "allowed_scopes": ["openid"] + "allowed_scopes": ["openid"], } session_id = self._create_session(AUTH_REQ) grant = self.session_manager[session_id] @@ -661,7 +664,7 @@ def test_create_authn_response(self, create_endpoint_ri_enabled): "client_id": "client_3", "redirect_uris": [("https://rp.example.com/cb", {})], "id_token_signed_response_alg": "ES256", - "allowed_scopes": ["openid"] + "allowed_scopes": ["openid"], } session_id = self._create_session(AUTH_REQ) diff --git a/tests/test_server_24_oauth2_token_endpoint.py b/tests/test_server_24_oauth2_token_endpoint.py index 262f7331..55dcebf3 100644 --- a/tests/test_server_24_oauth2_token_endpoint.py +++ b/tests/test_server_24_oauth2_token_endpoint.py @@ -172,7 +172,6 @@ def conf(): class TestEndpoint(object): - @pytest.fixture(autouse=True) def create_endpoint(self, conf): server = Server(ASConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) @@ -183,7 +182,7 @@ def create_endpoint(self, conf): "client_salt": "salted", "endpoint_auth_method": "client_secret_post", "response_types": ["code", "token", "code id_token", "id_token"], - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], } server.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") self.session_manager = context.session_manager @@ -256,7 +255,7 @@ def test_parse(self): _token_request["code"] = code.value _req = self.token_endpoint.parse_request(_token_request) - assert set(_req.keys()).difference(set(_token_request.keys())) == {'authenticated'} + assert set(_req.keys()).difference(set(_token_request.keys())) == {"authenticated"} def test_auth_code_grant_disallowed_per_client(self): areq = AUTH_REQ.copy() @@ -457,7 +456,7 @@ def test_new_refresh_token(self, conf): "client_salt": "salted", "endpoint_auth_method": "client_secret_post", "response_types": ["code", "token", "code id_token", "id_token"], - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], } areq = AUTH_REQ.copy() @@ -497,7 +496,7 @@ def test_revoke_on_issue_refresh_token(self, conf): "client_salt": "salted", "endpoint_auth_method": "client_secret_post", "response_types": ["code", "token", "code id_token", "id_token"], - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], } self.token_endpoint.revoke_refresh_on_issue = True @@ -535,7 +534,7 @@ def test_revoke_on_issue_refresh_token_per_client(self, conf): "client_salt": "salted", "endpoint_auth_method": "client_secret_post", "response_types": ["code", "token", "code id_token", "id_token"], - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], } self.context.cdb[AUTH_REQ["client_id"]]["revoke_refresh_on_issue"] = True areq = AUTH_REQ.copy() @@ -706,7 +705,7 @@ def test_do_refresh_access_token_not_allowed(self): _req = self.token_endpoint.parse_request(_request.to_json()) res = self.token_endpoint.process_request(_req) assert "error" in res - assert res["error_description"] == 'Minting of access_token not supported' + assert res["error_description"] == "Minting of access_token not supported" def test_do_refresh_access_token_revoked(self): areq = AUTH_REQ.copy() @@ -735,8 +734,9 @@ def test_do_refresh_access_token_revoked(self): def test_configure_grant_types(self): conf = {"access_token": {"class": "idpyoidc.server.oidc.token.AccessTokenHelper"}} - _helper = self.token_endpoint.configure_types(conf, - self.token_endpoint.helper_by_grant_type) + _helper = self.token_endpoint.configure_types( + conf, self.token_endpoint.helper_by_grant_type + ) assert len(_helper) == 1 assert "access_token" in _helper @@ -799,7 +799,7 @@ def test_refresh_token_request_other_client(self): "kwargs": { "lifetime": 3600, "add_claims_by_scope": True, - "aud": ["https://example.org/appl"] + "aud": ["https://example.org/appl"], }, }, "refresh": { @@ -819,8 +819,8 @@ def test_refresh_token_request_other_client(self): "lifetime": 3600, "add_claims_by_scope": True, "aud": ["https://example.org/appl"], - "profile": 'idpyoidc.message.oauth2.JWTAccessToken', - "with_jti": True + "profile": "idpyoidc.message.oauth2.JWTAccessToken", + "with_jti": True, }, }, "refresh": { @@ -835,9 +835,7 @@ def test_refresh_token_request_other_client(self): CONTEXT = OidcContext() CONTEXT.cwd = BASEDIR CONTEXT.issuer = "https://op.example.com" -CONTEXT.cdb = { - "client_1": {} -} +CONTEXT.cdb = {"client_1": {}} KEYJAR = KeyJar() KEYJAR.import_jwks(CLIENT_KEYJAR.export_jwks(private=True), "client_1") KEYJAR.import_jwks(CLIENT_KEYJAR.export_jwks(private=True), "") @@ -847,20 +845,16 @@ def upstream_get(what, *args): if what == "context": if not args: return CONTEXT - elif what == 'attribute': - if args[0] == 'keyjar': + elif what == "attribute": + if args[0] == "keyjar": return KEYJAR def test_def_jwttoken(): _handler = handler.factory(upstream_get=upstream_get, **DEFAULT_TOKEN_HANDLER_ARGS) - token_handler = _handler['access_token'] - token_payload = { - 'sub': 'subject_id', - 'aud': 'resource_1', - 'client_id': 'client_1' - } - value = token_handler(session_id='session_id', **token_payload) + token_handler = _handler["access_token"] + token_payload = {"sub": "subject_id", "aud": "resource_1", "client_id": "client_1"} + value = token_handler(session_id="session_id", **token_payload) _jws = factory(value) msg = JWTAccessToken(**_jws.jwt.payload()) @@ -871,13 +865,9 @@ def test_def_jwttoken(): def test_jwttoken(): _handler = handler.factory(upstream_get=upstream_get, **TOKEN_HANDLER_ARGS) - token_handler = _handler['access_token'] - token_payload = { - 'sub': 'subject_id', - 'aud': 'resource_1', - 'client_id': 'client_1' - } - value = token_handler(session_id='session_id', **token_payload) + token_handler = _handler["access_token"] + token_payload = {"sub": "subject_id", "aud": "resource_1", "client_id": "client_1"} + value = token_handler(session_id="session_id", **token_payload) _jws = factory(value) msg = JWTAccessToken(**_jws.jwt.payload()) @@ -893,19 +883,15 @@ class MyAccessToken(Message): "aud": REQUIRED_LIST_OF_STRINGS, "sub": SINGLE_REQUIRED_STRING, "iat": SINGLE_REQUIRED_INT, - 'usage': SINGLE_REQUIRED_STRING + "usage": SINGLE_REQUIRED_STRING, } def test_jwttoken_2(): _handler = handler.factory(upstream_get=upstream_get, **TOKEN_HANDLER_ARGS) - token_handler = _handler['access_token'] - token_payload = { - 'sub': 'subject_id', - 'aud': 'Skiresort', - 'usage': 'skilift' - } - value = token_handler(session_id='session_id', profile=MyAccessToken, **token_payload) + token_handler = _handler["access_token"] + token_payload = {"sub": "subject_id", "aud": "Skiresort", "usage": "skilift"} + value = token_handler(session_id="session_id", profile=MyAccessToken, **token_payload) _jws = factory(value) msg = MyAccessToken(**_jws.jwt.payload()) @@ -915,7 +901,6 @@ def test_jwttoken_2(): class TestClientCredentialsFlow(object): - @pytest.fixture(autouse=True) def create_endpoint(self, conf): server = Server(ASConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) @@ -927,7 +912,7 @@ def create_endpoint(self, conf): "endpoint_auth_method": "client_secret_post", "response_types": ["code", "token", "code id_token", "id_token"], "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], - "grant_types_supported": ['client_credentials', 'password'] + "grant_types_supported": ["client_credentials", "password"], } self.session_manager = context.session_manager self.token_endpoint = server.get_endpoint("token") @@ -935,17 +920,24 @@ def create_endpoint(self, conf): self.context = context def test_client_credentials(self): - request = CCAccessTokenRequest(client_id="client_1", client_secret='hemligt', - grant_type='client_credentials', scope="whatever") + request = CCAccessTokenRequest( + client_id="client_1", + client_secret="hemligt", + grant_type="client_credentials", + scope="whatever", + ) request = self.token_endpoint.parse_request(request) response = self.token_endpoint.process_request(request) - assert set(response.keys()) == {'response_args', 'cookie', 'http_headers'} - assert set(response["response_args"].keys()) == {'access_token', 'token_type', 'scope', - 'expires_in'} + assert set(response.keys()) == {"response_args", "cookie", "http_headers"} + assert set(response["response_args"].keys()) == { + "access_token", + "token_type", + "scope", + "expires_in", + } class TestResourceOwnerPasswordCredentialsFlow(object): - @pytest.fixture(autouse=True) def create_endpoint(self, conf): conf["authentication"] = { @@ -955,9 +947,9 @@ def create_endpoint(self, conf): "kwargs": { "db_conf": { "class": "idpyoidc.server.util.JSONDictDB", - "kwargs": {"filename": "passwd.json"} + "kwargs": {"filename": "passwd.json"}, } - } + }, } } @@ -970,21 +962,27 @@ def create_endpoint(self, conf): "endpoint_auth_method": "client_secret_post", "response_types": ["code", "token", "code id_token", "id_token"], "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], - "grant_types_supported": ['client_credentials', 'password'], + "grant_types_supported": ["client_credentials", "password"], } self.session_manager = context.session_manager self.token_endpoint = server.get_endpoint("token") self.context = context def test_resource_owner_password_credentials(self): - request = ROPCAccessTokenRequest(client_id="client_1", - client_secret='hemligt', - grant_type='password', - username='diana', - password='krall', - scope="whatever") + request = ROPCAccessTokenRequest( + client_id="client_1", + client_secret="hemligt", + grant_type="password", + username="diana", + password="krall", + scope="whatever", + ) request = self.token_endpoint.parse_request(request) response = self.token_endpoint.process_request(request) - assert set(response.keys()) == {'response_args', 'cookie', 'http_headers'} - assert set(response["response_args"].keys()) == {'access_token', 'token_type', 'scope', - 'expires_in'} + assert set(response.keys()) == {"response_args", "cookie", "http_headers"} + assert set(response["response_args"].keys()) == { + "access_token", + "token_type", + "scope", + "expires_in", + } diff --git a/tests/test_server_24_oauth2_token_endpoint_def_conf.py b/tests/test_server_24_oauth2_token_endpoint_def_conf.py new file mode 100644 index 00000000..a367d4dc --- /dev/null +++ b/tests/test_server_24_oauth2_token_endpoint_def_conf.py @@ -0,0 +1,904 @@ +import os + +import pytest +from cryptojwt import JWT +from cryptojwt import KeyJar +from cryptojwt.jws.jws import factory +from cryptojwt.key_jar import build_keyjar + +from idpyoidc.context import OidcContext +from idpyoidc.defaults import JWT_BEARER +from idpyoidc.message import Message +from idpyoidc.message import REQUIRED_LIST_OF_STRINGS +from idpyoidc.message import SINGLE_REQUIRED_INT +from idpyoidc.message import SINGLE_REQUIRED_STRING +from idpyoidc.message.oauth2 import AccessTokenRequest +from idpyoidc.message.oauth2 import AuthorizationRequest +from idpyoidc.message.oauth2 import CCAccessTokenRequest +from idpyoidc.message.oauth2 import JWTAccessToken +from idpyoidc.message.oauth2 import RefreshAccessTokenRequest +from idpyoidc.message.oauth2 import ROPCAccessTokenRequest +from idpyoidc.message.oauth2 import TokenErrorResponse +from idpyoidc.server import Server +from idpyoidc.server.authn_event import create_authn_event +from idpyoidc.server.configure import ASConfiguration +from idpyoidc.server.exception import InvalidToken +from idpyoidc.server.token import handler +from idpyoidc.time_util import utc_time_sans_frac +from tests import CRYPT_CONFIG + +KEYDEFS = [ + {"type": "RSA", "key": "", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +CLIENT_KEYJAR = build_keyjar(KEYDEFS) + +AUTH_REQ = AuthorizationRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + scope=["email"], + state="STATE", + response_type="code", +) + +TOKEN_REQ = AccessTokenRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + state="STATE", + grant_type="authorization_code", + client_secret="hemligt", +) + +REFRESH_TOKEN_REQ = RefreshAccessTokenRequest( + grant_type="refresh_token", client_id="client_1", client_secret="hemligt" +) + +TOKEN_REQ_DICT = TOKEN_REQ.to_dict() + +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +class TestEndpoint(object): + + @pytest.fixture(autouse=True) + def create_endpoint(self): + conf = { + "issuer": "https://example.com/", + 'userinfo': { + "class": "idpyoidc.server.user_info.UserInfo", + "kwargs": {"db_file": full_path("users.json")}, + } + } + server = Server(ASConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) + context = server.context + context.cdb["client_1"] = { + "client_secret": "hemligt", + "redirect_uris": [("https://example.com/cb", None)], + "client_salt": "salted", + "endpoint_auth_method": "client_secret_post", + "response_types": ["code", "token", "code id_token", "id_token"], + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], + } + server.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") + self.session_manager = context.session_manager + self.token_endpoint = server.get_endpoint("token") + self.user_id = "diana" + self.context = context + + def test_init(self): + assert self.token_endpoint + + def _create_session(self, auth_req, sub_type="public", sector_identifier=""): + if sector_identifier: + authz_req = auth_req.copy() + authz_req["sector_identifier_uri"] = sector_identifier + else: + authz_req = auth_req + client_id = authz_req["client_id"] + ae = create_authn_event(self.user_id) + return self.session_manager.create_session( + ae, authz_req, self.user_id, client_id=client_id, sub_type=sub_type + ) + + def _mint_code(self, grant, client_id): + session_id = self.session_manager.encrypted_session_id(self.user_id, client_id, grant.id) + usage_rules = grant.usage_rules.get("authorization_code", {}) + _exp_in = usage_rules.get("expires_in") + + # Constructing an authorization code is now done + _code = grant.mint_token( + session_id=session_id, + context=self.context, + token_class="authorization_code", + token_handler=self.session_manager.token_handler["authorization_code"], + usage_rules=usage_rules, + ) + + if _exp_in: + if isinstance(_exp_in, str): + _exp_in = int(_exp_in) + if _exp_in: + _code.expires_at = utc_time_sans_frac() + _exp_in + return _code + + def _mint_access_token(self, grant, session_id, token_ref=None): + _session_info = self.session_manager.get_session_info(session_id) + usage_rules = grant.usage_rules.get("access_token", {}) + _exp_in = usage_rules.get("expires_in", 0) + + _token = grant.mint_token( + _session_info, + context=self.context, + token_class="access_token", + token_handler=self.session_manager.token_handler["access_token"], + based_on=token_ref, # Means the token (tok) was used to mint this token + usage_rules=usage_rules, + ) + if isinstance(_exp_in, str): + _exp_in = int(_exp_in) + if _exp_in: + _token.expires_at = utc_time_sans_frac() + _exp_in + + return _token + + def test_parse(self): + session_id = self._create_session(AUTH_REQ) + grant = self.session_manager[session_id] + code = self._mint_code(grant, AUTH_REQ["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + + assert set(_req.keys()).difference(set(_token_request.keys())) == {"authenticated"} + + def test_auth_code_grant_disallowed_per_client(self): + areq = AUTH_REQ.copy() + areq["scope"] = ["email"] + self.context.cdb["client_1"]["grant_types_supported"] = [] + + session_id = self._create_session(areq) + grant = self.context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _cntx = self.context + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req, issue_refresh=True) + + assert isinstance(_req, TokenErrorResponse) + assert _req.to_dict() == { + "error": "invalid_request", + "error_description": "Unsupported grant_type: authorization_code", + } + + def test_process_request(self): + session_id = self._create_session(AUTH_REQ) + grant = self.session_manager[session_id] + code = self._mint_code(grant, AUTH_REQ["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _context = self.context + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req) + + assert _resp + assert set(_resp.keys()) == {"cookie", "http_headers", "response_args"} + + def test_process_request_using_code_twice(self): + session_id = self._create_session(AUTH_REQ) + grant = self.session_manager[session_id] + code = self._mint_code(grant, AUTH_REQ["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _context = self.context + _token_request["code"] = code.value + + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req) + + # 2nd time used + _2nd_response = self.token_endpoint.parse_request(_token_request) + assert "error" in _2nd_response + + def test_do_response(self): + session_id = self._create_session(AUTH_REQ) + grant = self.session_manager[session_id] + code = self._mint_code(grant, AUTH_REQ["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + + _resp = self.token_endpoint.process_request(request=_req) + msg = self.token_endpoint.do_response(request=_req, **_resp) + assert isinstance(msg, dict) + + def test_process_request_using_private_key_jwt(self): + session_id = self._create_session(AUTH_REQ) + grant = self.session_manager[session_id] + code = self._mint_code(grant, AUTH_REQ["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + del _token_request["client_id"] + del _token_request["client_secret"] + _context = self.context + + _jwt = JWT(CLIENT_KEYJAR, iss=AUTH_REQ["client_id"], sign_alg="RS256") + _jwt.with_jti = True + _assertion = _jwt.pack({"aud": [self.token_endpoint.full_path]}) + _token_request.update({"client_assertion": _assertion, "client_assertion_type": JWT_BEARER}) + _token_request["code"] = code.value + + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req) + + # 2nd time used + with pytest.raises(InvalidToken): + self.token_endpoint.parse_request(_token_request) + + def test_do_refresh_access_token(self): + areq = AUTH_REQ.copy() + areq["scope"] = ["email", "foobar"] + + session_id = self._create_session(areq) + grant = self.context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _cntx = self.context + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req, issue_refresh=True) + + _request = REFRESH_TOKEN_REQ.copy() + _request["refresh_token"] = _resp["response_args"]["refresh_token"] + + _token_value = _resp["response_args"]["refresh_token"] + _session_info = self.session_manager.get_session_info_by_token( + _token_value, handler_key="refresh_token" + ) + _token = self.session_manager.find_token(_session_info["branch_id"], _token_value) + _token.usage_rules["supports_minting"] = ["access_token", "refresh_token"] + + _req = self.token_endpoint.parse_request(_request.to_json()) + _resp = self.token_endpoint.process_request(request=_req, issue_refresh=True) + assert set(_resp.keys()) == {"cookie", "response_args", "http_headers"} + assert set(_resp["response_args"].keys()) == { + "access_token", + "token_type", + "expires_in", + "refresh_token", + "scope", + } + msg = self.token_endpoint.do_response(request=_req, **_resp) + assert isinstance(msg, dict) + + def test_refresh_grant_disallowed_per_client(self): + areq = AUTH_REQ.copy() + areq["scope"] = ["email"] + self.context.cdb["client_1"]["grant_types_supported"] = ["authorization_code"] + + session_id = self._create_session(areq) + grant = self.context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _cntx = self.context + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req, issue_refresh=True) + + assert "refresh_token" not in _resp + + def test_do_2nd_refresh_access_token(self): + areq = AUTH_REQ.copy() + areq["scope"] = ["email"] + + session_id = self._create_session(areq) + grant = self.context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + self.token_endpoint.revoke_refresh_on_issue = False + _cntx = self.context + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req, issue_refresh=True) + + _request = REFRESH_TOKEN_REQ.copy() + _request["refresh_token"] = _resp["response_args"]["refresh_token"] + + # Make sure ID Tokens can also be used by this refesh token + _token_value = _resp["response_args"]["refresh_token"] + _session_info = self.session_manager.get_session_info_by_token( + _token_value, handler_key="refresh_token" + ) + _token = self.session_manager.find_token(_session_info["branch_id"], _token_value) + _token.usage_rules["supports_minting"] = [ + "access_token", + "refresh_token", + ] + + _req = self.token_endpoint.parse_request(_request.to_json()) + _resp = self.token_endpoint.process_request(request=_req, issue_refresh=True) + + _2nd_request = REFRESH_TOKEN_REQ.copy() + _2nd_request["refresh_token"] = _resp["response_args"]["refresh_token"] + _2nd_req = self.token_endpoint.parse_request(_request.to_json()) + _2nd_resp = self.token_endpoint.process_request(request=_2nd_req, issue_refresh=True) + assert set(_2nd_resp.keys()) == {"cookie", "response_args", "http_headers"} + assert set(_2nd_resp["response_args"].keys()) == { + "access_token", + "token_type", + "expires_in", + "refresh_token", + "scope", + } + msg = self.token_endpoint.do_response(request=_req, **_resp) + assert isinstance(msg, dict) + + def test_new_refresh_token(self): + self.context.cdb["client_1"] = { + "client_secret": "hemligt", + "redirect_uris": [("https://example.com/cb", None)], + "client_salt": "salted", + "endpoint_auth_method": "client_secret_post", + "response_types": ["code", "token", "code id_token", "id_token"], + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], + } + + areq = AUTH_REQ.copy() + areq["scope"] = ["email"] + + session_id = self._create_session(areq) + grant = self.context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req, issue_refresh=True) + assert "refresh_token" in _resp["response_args"] + first_refresh_token = _resp["response_args"]["refresh_token"] + + _refresh_request = REFRESH_TOKEN_REQ.copy() + _refresh_request["refresh_token"] = first_refresh_token + _2nd_req = self.token_endpoint.parse_request(_refresh_request.to_json()) + _2nd_resp = self.token_endpoint.process_request(request=_2nd_req, issue_refresh=True) + assert "refresh_token" in _2nd_resp["response_args"] + second_refresh_token = _2nd_resp["response_args"]["refresh_token"] + + _2d_refresh_request = REFRESH_TOKEN_REQ.copy() + _2d_refresh_request["refresh_token"] = second_refresh_token + _3rd_req = self.token_endpoint.parse_request(_2d_refresh_request.to_json()) + _3rd_resp = self.token_endpoint.process_request(request=_3rd_req, issue_refresh=True) + assert "access_token" in _3rd_resp["response_args"] + assert "refresh_token" in _3rd_resp["response_args"] + + assert first_refresh_token != second_refresh_token + + def test_revoke_on_issue_refresh_token(self): + self.context.cdb["client_1"] = { + "client_secret": "hemligt", + "redirect_uris": [("https://example.com/cb", None)], + "client_salt": "salted", + "endpoint_auth_method": "client_secret_post", + "response_types": ["code", "token", "code id_token", "id_token"], + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], + } + + self.token_endpoint.revoke_refresh_on_issue = True + areq = AUTH_REQ.copy() + areq["scope"] = ["email"] + + session_id = self._create_session(areq) + grant = self.context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req, issue_refresh=True) + assert "refresh_token" in _resp["response_args"] + first_refresh_token = _resp["response_args"]["refresh_token"] + + _refresh_request = REFRESH_TOKEN_REQ.copy() + _refresh_request["refresh_token"] = first_refresh_token + _2nd_req = self.token_endpoint.parse_request(_refresh_request.to_json()) + _2nd_resp = self.token_endpoint.process_request(request=_2nd_req, issue_refresh=True) + assert "refresh_token" in _2nd_resp["response_args"] + second_refresh_token = _2nd_resp["response_args"]["refresh_token"] + + assert first_refresh_token != second_refresh_token + first_refresh_token = grant.get_token(first_refresh_token) + second_refresh_token = grant.get_token(second_refresh_token) + assert first_refresh_token.revoked is True + assert second_refresh_token.revoked is False + + def test_revoke_on_issue_refresh_token_per_client(self): + self.context.cdb["client_1"] = { + "client_secret": "hemligt", + "redirect_uris": [("https://example.com/cb", None)], + "client_salt": "salted", + "endpoint_auth_method": "client_secret_post", + "response_types": ["code", "token", "code id_token", "id_token"], + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], + } + self.context.cdb[AUTH_REQ["client_id"]]["revoke_refresh_on_issue"] = True + areq = AUTH_REQ.copy() + areq["scope"] = ["openid", "offline_access"] + + session_id = self._create_session(areq) + grant = self.context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req, issue_refresh=True) + assert "refresh_token" in _resp["response_args"] + first_refresh_token = _resp["response_args"]["refresh_token"] + + _refresh_request = REFRESH_TOKEN_REQ.copy() + _refresh_request["refresh_token"] = first_refresh_token + _2nd_req = self.token_endpoint.parse_request(_refresh_request.to_json()) + _2nd_resp = self.token_endpoint.process_request(request=_2nd_req, issue_refresh=True) + assert "refresh_token" in _2nd_resp["response_args"] + second_refresh_token = _2nd_resp["response_args"]["refresh_token"] + + _2d_refresh_request = REFRESH_TOKEN_REQ.copy() + _2d_refresh_request["refresh_token"] = second_refresh_token + + assert first_refresh_token != second_refresh_token + first_refresh_token = grant.get_token(first_refresh_token) + second_refresh_token = grant.get_token(second_refresh_token) + assert first_refresh_token.revoked is True + assert second_refresh_token.revoked is False + + def test_refresh_scopes(self): + areq = AUTH_REQ.copy() + areq["scope"] = ["email", "profile"] + + session_id = self._create_session(areq) + grant = self.context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req, issue_refresh=True) + + _request = REFRESH_TOKEN_REQ.copy() + _request["refresh_token"] = _resp["response_args"]["refresh_token"] + _request["scope"] = ["email"] + + _req = self.token_endpoint.parse_request(_request.to_json()) + _resp = self.token_endpoint.process_request(request=_req, issue_refresh=True) + assert set(_resp.keys()) == {"cookie", "response_args", "http_headers"} + assert set(_resp["response_args"].keys()) == { + "access_token", + "token_type", + "expires_in", + "refresh_token", + "scope", + } + + _token_value = _resp["response_args"]["access_token"] + _session_info = self.session_manager.get_session_info_by_token( + _token_value, handler_key="access_token" + ) + at = self.session_manager.find_token(_session_info["branch_id"], _token_value) + rt = self.session_manager.find_token( + _session_info["branch_id"], _resp["response_args"]["refresh_token"] + ) + + assert at.scope == rt.scope == _request["scope"] == _resp["response_args"]["scope"] + + def test_refresh_more_scopes(self): + areq = AUTH_REQ.copy() + areq["scope"] = ["email"] + + session_id = self._create_session(areq) + grant = self.context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req, issue_refresh=True) + + _request = REFRESH_TOKEN_REQ.copy() + _request["refresh_token"] = _resp["response_args"]["refresh_token"] + _request["scope"] = ["ema"] + + _req = self.token_endpoint.parse_request(_request.to_json()) + assert isinstance(_req, TokenErrorResponse) + _resp = self.token_endpoint.process_request(request=_req, issue_refresh=True) + + assert _resp.to_dict() == { + "error": "invalid_request", + "error_description": "Invalid refresh scopes", + } + + def test_refresh_more_scopes_2(self): + areq = AUTH_REQ.copy() + areq["scope"] = ["email", "profile"] + + session_id = self._create_session(areq) + grant = self.context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req, issue_refresh=True) + + _request = REFRESH_TOKEN_REQ.copy() + _request["refresh_token"] = _resp["response_args"]["refresh_token"] + _request["scope"] = ["email"] + + _token_value = _resp["response_args"]["refresh_token"] + + _req = self.token_endpoint.parse_request(_request.to_json()) + _resp = self.token_endpoint.process_request(request=_req, issue_refresh=True) + + _token_value = _resp["response_args"]["refresh_token"] + _request["refresh_token"] = _token_value + # We should be able to request the original requests scopes + _request["scope"] = ["email", "profile"] + + _req = self.token_endpoint.parse_request(_request.to_json()) + _resp = self.token_endpoint.process_request(request=_req, issue_refresh=True) + + assert set(_resp.keys()) == {"cookie", "response_args", "http_headers"} + assert set(_resp["response_args"].keys()) == { + "access_token", + "token_type", + "expires_in", + "refresh_token", + "scope", + } + + _token_value = _resp["response_args"]["access_token"] + _session_info = self.session_manager.get_session_info_by_token( + _token_value, handler_key="access_token" + ) + at = self.session_manager.find_token(_session_info["branch_id"], _token_value) + rt = self.session_manager.find_token( + _session_info["branch_id"], _resp["response_args"]["refresh_token"] + ) + + assert at.scope == rt.scope == _request["scope"] == _resp["response_args"]["scope"] + + def test_do_refresh_access_token_not_allowed(self): + areq = AUTH_REQ.copy() + areq["scope"] = ["email"] + + session_id = self._create_session(areq) + grant = self.context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _cntx = self.token_endpoint.upstream_get("context") + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + # This is weird, issuing a refresh token that can't be used to mint anything + # but it's testing so anything goes. + grant.usage_rules["refresh_token"] = {"supports_minting": []} + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req, issue_refresh=True) + + _request = REFRESH_TOKEN_REQ.copy() + _request["refresh_token"] = _resp["response_args"]["refresh_token"] + _req = self.token_endpoint.parse_request(_request.to_json()) + res = self.token_endpoint.process_request(_req) + assert "error" in res + assert res["error_description"] == "Minting of access_token not supported" + + def test_do_refresh_access_token_revoked(self): + areq = AUTH_REQ.copy() + areq["scope"] = ["email"] + + session_id = self._create_session(areq) + grant = self.context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _cntx = self.token_endpoint.upstream_get("context") + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req, issue_refresh=True) + + _refresh_token = _resp["response_args"]["refresh_token"] + _cntx.session_manager.revoke_token(session_id, _refresh_token) + + _request = REFRESH_TOKEN_REQ.copy() + _request["refresh_token"] = _refresh_token + _req = self.token_endpoint.parse_request(_request.to_json()) + # A revoked token is caught already when parsing the query. + assert isinstance(_req, TokenErrorResponse) + + def test_configure_grant_types(self): + conf = {"access_token": {"class": "idpyoidc.server.oidc.token.AccessTokenHelper"}} + + _helper = self.token_endpoint.configure_types( + conf, self.token_endpoint.helper_by_grant_type + ) + + assert len(_helper) == 1 + assert "access_token" in _helper + assert "refresh_token" not in _helper + + def test_token_request_other_client(self): + _context = self.context + _context.cdb["client_2"] = _context.cdb["client_1"] + session_id = self._create_session(AUTH_REQ) + grant = self.session_manager[session_id] + code = self._mint_code(grant, AUTH_REQ["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["client_id"] = "client_2" + _token_request["code"] = code.value + + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req) + + assert isinstance(_resp, TokenErrorResponse) + assert _resp.to_dict() == {"error": "invalid_grant", "error_description": "Wrong client"} + + def test_refresh_token_request_other_client(self): + _context = self.context + _context.cdb["client_2"] = _context.cdb["client_1"] + session_id = self._create_session(AUTH_REQ) + grant = self.session_manager[session_id] + code = self._mint_code(grant, AUTH_REQ["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req, issue_refresh=True) + + _request = REFRESH_TOKEN_REQ.copy() + _request["client_id"] = "client_2" + _request["refresh_token"] = _resp["response_args"]["refresh_token"] + + _token_value = _resp["response_args"]["refresh_token"] + _session_info = self.session_manager.get_session_info_by_token( + _token_value, handler_key="refresh_token" + ) + _token = self.session_manager.find_token(_session_info["branch_id"], _token_value) + _token.usage_rules["supports_minting"] = ["access_token", "refresh_token"] + + _req = self.token_endpoint.parse_request(_request.to_json()) + _resp = self.token_endpoint.process_request( + request=_req, + ) + assert isinstance(_resp, TokenErrorResponse) + assert _resp.to_dict() == {"error": "invalid_grant", "error_description": "Wrong client"} + + +DEFAULT_TOKEN_HANDLER_ARGS = { + "jwks_file": "private/token_jwks.json", + "code": {"lifetime": 600, "kwargs": {"crypt_conf": CRYPT_CONFIG}}, + "token": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "add_claims_by_scope": True, + "aud": ["https://example.org/appl"], + }, + }, + "refresh": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "aud": ["https://example.org/appl"], + }, + }, +} +TOKEN_HANDLER_ARGS = { + "jwks_file": "private/token_jwks.json", + "code": {"lifetime": 600, "kwargs": {"crypt_conf": CRYPT_CONFIG}}, + "token": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "add_claims_by_scope": True, + "aud": ["https://example.org/appl"], + "profile": "idpyoidc.message.oauth2.JWTAccessToken", + "with_jti": True, + }, + }, + "refresh": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "aud": ["https://example.org/appl"], + }, + }, +} + +CONTEXT = OidcContext() +CONTEXT.cwd = BASEDIR +CONTEXT.issuer = "https://op.example.com" +CONTEXT.cdb = {"client_1": {}} +KEYJAR = KeyJar() +KEYJAR.import_jwks(CLIENT_KEYJAR.export_jwks(private=True), "client_1") +KEYJAR.import_jwks(CLIENT_KEYJAR.export_jwks(private=True), "") + + +def upstream_get(what, *args): + if what == "context": + if not args: + return CONTEXT + elif what == "attribute": + if args[0] == "keyjar": + return KEYJAR + + +def test_def_jwttoken(): + _handler = handler.factory(upstream_get=upstream_get, **DEFAULT_TOKEN_HANDLER_ARGS) + token_handler = _handler["access_token"] + token_payload = {"sub": "subject_id", "aud": "resource_1", "client_id": "client_1"} + value = token_handler(session_id="session_id", **token_payload) + + _jws = factory(value) + msg = JWTAccessToken(**_jws.jwt.payload()) + # test if all required claims are there + msg.verify() + assert True + + +def test_jwttoken(): + _handler = handler.factory(upstream_get=upstream_get, **TOKEN_HANDLER_ARGS) + token_handler = _handler["access_token"] + token_payload = {"sub": "subject_id", "aud": "resource_1", "client_id": "client_1"} + value = token_handler(session_id="session_id", **token_payload) + + _jws = factory(value) + msg = JWTAccessToken(**_jws.jwt.payload()) + # test if all required claims are there + msg.verify() + assert True + + +class MyAccessToken(Message): + c_param = { + "iss": SINGLE_REQUIRED_STRING, + "exp": SINGLE_REQUIRED_INT, + "aud": REQUIRED_LIST_OF_STRINGS, + "sub": SINGLE_REQUIRED_STRING, + "iat": SINGLE_REQUIRED_INT, + "usage": SINGLE_REQUIRED_STRING, + } + + +def test_jwttoken_2(): + _handler = handler.factory(upstream_get=upstream_get, **TOKEN_HANDLER_ARGS) + token_handler = _handler["access_token"] + token_payload = {"sub": "subject_id", "aud": "Skiresort", "usage": "skilift"} + value = token_handler(session_id="session_id", profile=MyAccessToken, **token_payload) + + _jws = factory(value) + msg = MyAccessToken(**_jws.jwt.payload()) + # test if all required claims are there + msg.verify() + assert True + + +class TestClientCredentialsFlow(object): + + @pytest.fixture(autouse=True) + def create_endpoint(self): + conf = { + "issuer": "https://example.com/", + 'userinfo': { + "class": "idpyoidc.server.user_info.UserInfo", + "kwargs": {"db_file": full_path("users.json")}, + } + } + + server = Server(ASConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) + context = server.context + context.cdb["client_1"] = { + "client_secret": "hemligt", + "redirect_uris": [("https://example.com/cb", None)], + "client_salt": "salted", + "endpoint_auth_method": "client_secret_post", + "response_types": ["code", "token", "code id_token", "id_token"], + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], + "grant_types_supported": ["client_credentials", "password"], + } + self.session_manager = context.session_manager + self.token_endpoint = server.get_endpoint("token") + self.user_id = "diana" + self.context = context + + def test_client_credentials(self): + request = CCAccessTokenRequest( + client_id="client_1", + client_secret="hemligt", + grant_type="client_credentials", + scope="whatever", + ) + request = self.token_endpoint.parse_request(request) + response = self.token_endpoint.process_request(request) + assert set(response.keys()) == {"response_args", "cookie", "http_headers"} + assert set(response["response_args"].keys()) == { + "access_token", + "token_type", + "scope", + "expires_in", + } + + +class TestResourceOwnerPasswordCredentialsFlow(object): + + @pytest.fixture(autouse=True) + def create_endpoint(self): + conf = { + "issuer": "https://example.com/", + 'userinfo': { + "class": "idpyoidc.server.user_info.UserInfo", + "kwargs": {"db_file": full_path("users.json")}, + }, + "authentication": { + "user": { + "acr": "urn:oasis:names:tc:SAML:2.0:ac:classes:InternetProtocolPassword", + "class": "idpyoidc.server.user_authn.user.UserPass", + "kwargs": { + "db_conf": { + "class": "idpyoidc.server.util.JSONDictDB", + "kwargs": {"filename": "passwd.json"}, + } + }, + } + }} + + server = Server(ASConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) + context = server.context + context.cdb["client_1"] = { + "client_secret": "hemligt", + "redirect_uris": [("https://example.com/cb", None)], + "client_salt": "salted", + "endpoint_auth_method": "client_secret_post", + "response_types": ["code", "token", "code id_token", "id_token"], + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], + "grant_types_supported": ["client_credentials", "password"], + } + self.session_manager = context.session_manager + self.token_endpoint = server.get_endpoint("token") + self.context = context + + def test_resource_owner_password_credentials(self): + request = ROPCAccessTokenRequest( + client_id="client_1", + client_secret="hemligt", + grant_type="password", + username="diana", + password="krall", + scope="whatever", + ) + request = self.token_endpoint.parse_request(request) + response = self.token_endpoint.process_request(request) + assert set(response.keys()) == {"response_args", "cookie", "http_headers"} + assert set(response["response_args"].keys()) == { + "access_token", + "token_type", + "scope", + "expires_in", + } diff --git a/tests/test_server_24_oidc_authorization_endpoint.py b/tests/test_server_24_oidc_authorization_endpoint.py index fc0bcca8..dd53ac8a 100755 --- a/tests/test_server_24_oidc_authorization_endpoint.py +++ b/tests/test_server_24_oidc_authorization_endpoint.py @@ -64,13 +64,8 @@ RESPONSE_TYPES_SUPPORTED = [ ["code"], - ["token"], ["id_token"], - ["code", "token"], ["code", "id_token"], - ["id_token", "token"], - ["code", "token", "id_token"], - ["none"], ] CAPABILITIES = { @@ -120,12 +115,10 @@ def full_path(local_file): - ['https://example.com/cb', ''] "client_salt": "salted" 'token_endpoint_auth_method': 'client_secret_post' - 'response_types': + response_types_supported: - 'code' - - 'token' - 'code id_token' - 'id_token' - - 'code id_token token' allowed_scopes: - 'openid' - 'profile' @@ -138,14 +131,14 @@ def full_path(local_file): redirect_uris: - ['https://app1.example.net/foo', ''] - ['https://app2.example.net/bar', ''] - response_types: + response_types_supported: - code client3: client_secret: '2222222222222222222222222222222222222222' redirect_uris: - ['https://127.0.0.1:8090/authz_cb/bobcat', ''] post_logout_redirect_uri: ['https://openidconnect.net/', ''] - response_types: + response_types_supported: - code allowed_scopes: - 'openid' @@ -294,9 +287,7 @@ def create_endpoint(self): _clients = yaml.safe_load(io.StringIO(client_yaml)) context.cdb = _clients["oidc_clients"] - server.keyjar.import_jwks( - server.keyjar.export_jwks(True, ""), conf["issuer"] - ) + server.keyjar.import_jwks(server.keyjar.export_jwks(True, ""), conf["issuer"]) self.context = context self.endpoint = server.get_endpoint("authorization") self.session_manager = context.session_manager @@ -377,14 +368,6 @@ def test_do_response_id_token(self): assert "code" not in _frag_msg assert "token" not in _frag_msg - def test_do_response_id_token_token(self): - _orig_req = AUTH_REQ_DICT.copy() - _orig_req["response_type"] = "id_token token" - _orig_req["nonce"] = "rnd_nonce" - _pr_resp = self.endpoint.parse_request(_orig_req) - assert isinstance(_pr_resp, AuthorizationErrorResponse) - assert _pr_resp["error"] == "invalid_request" - def test_do_response_code_token(self): _orig_req = AUTH_REQ_DICT.copy() _orig_req["response_type"] = "code token" @@ -409,23 +392,6 @@ def test_do_response_code_id_token(self): assert "code" in _frag_msg assert "access_token" not in _frag_msg - def test_do_response_code_id_token_token(self): - _orig_req = AUTH_REQ_DICT.copy() - _orig_req["response_type"] = "code id_token token" - _orig_req["nonce"] = "rnd_nonce" - _pr_resp = self.endpoint.parse_request(_orig_req) - _resp = self.endpoint.process_request(_pr_resp) - msg = self.endpoint.do_response(**_resp) - assert isinstance(msg, dict) - part = urlparse(msg["response"]) - assert part.query == "" - assert part.fragment - _frag_msg = parse_qs(part.fragment) - assert _frag_msg - assert "id_token" in _frag_msg - assert "code" in _frag_msg - assert "access_token" in _frag_msg - def test_id_token_claims(self): _req = AUTH_REQ_DICT.copy() _req["claims"] = CLAIMS @@ -434,8 +400,7 @@ def test_id_token_claims(self): _pr_resp = self.endpoint.parse_request(_req) _resp = self.endpoint.process_request(_pr_resp) idt = verify_id_token( - _resp["response_args"], - keyjar=self.endpoint.upstream_get("attribute","keyjar") + _resp["response_args"], keyjar=self.endpoint.upstream_get("attribute", "keyjar") ) assert idt # from config @@ -454,13 +419,13 @@ def test_id_token_acr(self): _req["claims"] = { "id_token": {"acr": {"value": "http://www.swamid.se/policy/assurance/al1"}} } - _req["response_type"] = "code id_token token" + _req["response_type"] = "code id_token" _req["nonce"] = "rnd_nonce" _pr_resp = self.endpoint.parse_request(_req) _resp = self.endpoint.process_request(_pr_resp) res = verify_id_token( _resp["response_args"], - keyjar=self.endpoint.upstream_get("attribute","keyjar"), + keyjar=self.endpoint.upstream_get("attribute", "keyjar"), ) assert res res = _resp["response_args"][verified_claim_name("id_token")] @@ -607,7 +572,7 @@ def test_create_authn_response_id_token(self): "client_id": "client_id", "redirect_uris": [("https://rp.example.com/cb", {})], "id_token_signed_response_alg": "ES256", - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], } session_id = self._create_session(request) @@ -635,7 +600,7 @@ def test_create_authn_response_id_token_request_claims(self): "client_id": "client_id", "redirect_uris": [("https://rp.example.com/cb", {})], "id_token_signed_response_alg": "ES256", - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], } session_id = self._create_session(request) @@ -830,9 +795,7 @@ def test_setup_auth_login_hint2acrs(self): "kwargs": {"user": "knoll"}, "class": NoAuthn, } - self.endpoint.upstream_get("context").authn_broker["foo"] = init_method( - method_spec, None - ) + self.endpoint.upstream_get("context").authn_broker["foo"] = init_method(method_spec, None) item = self.endpoint.upstream_get("context").authn_broker.db["anon"] item["method"].fail = NoSuchAuthentication @@ -864,15 +827,17 @@ def test_parse_request(self): "scope": AUTH_REQ.get("scope"), } ) - assert set(_req.keys()) == {'__verified_request', - 'aud', - 'client_id', - 'iat', - 'iss', - 'redirect_uri', - 'response_type', - 'scope', - 'state'} + assert set(_req.keys()) == { + "__verified_request", + "aud", + "client_id", + "iat", + "iss", + "redirect_uri", + "response_type", + "scope", + "state", + } def test_parse_request_uri(self): _jwt = JWT(key_jar=self.rp_keyjar, iss="client_1", sign_alg="HS256") @@ -901,7 +866,7 @@ def test_verify_response_type(self): request = AuthorizationRequest( client_id="client_id", redirect_uri="https://rp.example.com/cb", - response_type=["id_token token"], + response_type=["id_token"], state="state", nonce="nonce", scope="openid", @@ -915,32 +880,14 @@ def test_verify_response_type(self): assert self.endpoint.verify_response_type(request, client_info) is False - client_info["response_types"] = [ + client_info["response_types_supported"] = [ "code", "code id_token", "id_token", - "id_token token", ] assert self.endpoint.verify_response_type(request, client_info) is True - # @pytest.mark.parametrize("exp_in", [360, "360", 0]) - # def test_mint_token_exp_at(self, exp_in): - # request = AuthorizationRequest( - # client_id="client_1", - # response_type=["code"], - # redirect_uri="https://example.com/cb", - # state="state", - # scope="openid", - # ) - # self.session_manager.set(["user_id", "client_id", "grant.id"], grant) - # - # code = self.endpoint.mint_token("authorization_code", grant, sid) - # if exp_in in [360, "360"]: - # assert code.expires_at - # else: - # assert code.expires_at == 0 - def test_do_request_uri(self): request = AuthorizationRequest( redirect_uri="https://rp.example.com/cb", @@ -950,7 +897,7 @@ def test_do_request_uri(self): orig_request = AuthorizationRequest( client_id="client_id", redirect_uri="https://rp.example.com/cb", - response_type=["id_token token"], + response_type=["id_token"], state="state", nonce="nonce", scope="openid", @@ -1244,9 +1191,7 @@ def create_endpoint(self): _clients = yaml.safe_load(io.StringIO(client_yaml)) context.cdb = _clients["oidc_clients"] - server.keyjar.import_jwks( - server.keyjar.export_jwks(True, ""), conf["issuer"] - ) + server.keyjar.import_jwks(server.keyjar.export_jwks(True, ""), conf["issuer"]) self.endpoint = server.get_endpoint("authorization") self.session_manager = context.session_manager self.user_id = "diana" @@ -1472,11 +1417,14 @@ def test_authenticated_as_with_goobledigook(self): client_id=authn_req["client_id"], ) - kakor = [{ - 'value': '{"sub": "adam", "sid": "Z0FBQUFBQmlhVl", "state": "state_identifier", ' - '"client_id": "client 12345"}', - 'type': '', - 'timestamp': '1651070251'}] + kakor = [ + { + "value": '{"sub": "adam", "sid": "Z0FBQUFBQmlhVl", "state": "state_identifier", ' + '"client_id": "client 12345"}', + "type": "", + "timestamp": "1651070251", + } + ] _info, _time_stamp = method.authenticated_as(client_id="client 12345", cookie=kakor) assert _info == {} diff --git a/tests/test_server_26_oidc_userinfo_endpoint.py b/tests/test_server_26_oidc_userinfo_endpoint.py index bef8921b..50313ca4 100755 --- a/tests/test_server_26_oidc_userinfo_endpoint.py +++ b/tests/test_server_26_oidc_userinfo_endpoint.py @@ -21,30 +21,22 @@ from idpyoidc.server.scopes import SCOPE2CLAIMS from idpyoidc.server.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from idpyoidc.server.user_info import UserInfo -from idpyoidc.server.oidc.userinfo import validate_userinfo_policy from idpyoidc.time_util import utc_time_sans_frac from tests import CRYPT_CONFIG from tests import SESSION_PARAMS - KEYDEFS = [ {"type": "RSA", "key": "", "use": ["sig"]}, {"type": "EC", "crv": "P-256", "use": ["sig"]}, ] -RESPONSE_TYPES_SUPPORTED = [ - ["code"], - ["token"], - ["id_token"], - ["code", "token"], - ["code", "id_token"], - ["id_token", "token"], - ["code", "token", "id_token"], - ["none"], -] +# RESPONSE_TYPES_SUPPORTED = [ +# ["code"], +# ["id_token"], +# ["code", "id_token"], +# ] -CAPABILITIES = { -} +CAPABILITIES = {} AUTH_REQ = AuthorizationRequest( client_id="client_1", @@ -81,7 +73,7 @@ def create_endpoint(self): "issuer": "https://example.com/", "httpc_params": {"verify": False, "timeout": 1}, "subject_types_supported": ["public", "pairwise", "ephemeral"], - 'claims_supported': [ + "claims_supported": [ "address", "birthdate", "email", @@ -102,7 +94,8 @@ def create_endpoint(self): "sub", "updated_at", "website", - "zoneinfo"], + "zoneinfo", + ], "grant_types_supported": [ "authorization_code", "implicit", @@ -214,8 +207,16 @@ def create_endpoint(self): "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", "token_endpoint_auth_method": "client_secret_post", - "response_types": ["code", "token", "code id_token", "id_token"], - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access", "research_and_scholarship"] + "response_types_supported": ["code", "code id_token", "id_token"], + "allowed_scopes": [ + "openid", + "profile", + "email", + "address", + "phone", + "offline_access", + "research_and_scholarship", + ], } self.endpoint = self.server.get_endpoint("userinfo") self.session_manager = self.context.session_manager @@ -256,31 +257,29 @@ def _mint_token(self, token_class, grant, session_id, token_ref=None): def test_init(self): assert self.endpoint - assert set( - self.endpoint.upstream_get("context").provider_info["claims_supported"] - ) == { - "address", - "birthdate", - "email", - "email_verified", - "eduperson_scoped_affiliation", - "family_name", - "gender", - "given_name", - "locale", - "middle_name", - "name", - "nickname", - "phone_number", - "phone_number_verified", - "picture", - "preferred_username", - "profile", - "sub", - "updated_at", - "website", - "zoneinfo", - } + assert set(self.endpoint.upstream_get("context").provider_info["claims_supported"]) == { + "address", + "birthdate", + "email", + "email_verified", + "eduperson_scoped_affiliation", + "family_name", + "gender", + "given_name", + "locale", + "middle_name", + "name", + "nickname", + "phone_number", + "phone_number_verified", + "picture", + "preferred_username", + "profile", + "sub", + "updated_at", + "website", + "zoneinfo", + } def test_parse(self): session_id = self._create_session(AUTH_REQ) @@ -468,7 +467,7 @@ def test_allowed_scopes(self): "email", "family_name", "name", - "sub" + "sub", } def test_allowed_scopes_per_client(self): @@ -634,11 +633,15 @@ def test_process_request_absent_userinfo_conf(self): ec = self.endpoint.upstream_get("context") ec.userinfo = None - session_id = self._create_session(AUTH_REQ) + _auth_req = AUTH_REQ.copy() + _auth_req["scope"] = ["openid", "email"] + + session_id = self._create_session(_auth_req) grant = self.session_manager[session_id] + code = self._mint_code(grant, session_id) with pytest.raises(ImproperlyConfigured): - code = self._mint_code(grant, session_id) + self._mint_token("access_token", grant, session_id, code) def test_userinfo_policy(self): _auth_req = AUTH_REQ.copy() @@ -675,10 +678,7 @@ def _custom_validate_userinfo_policy(request, token, response_info, **kwargs): return {"custom": "policy"} self.context.cdb["client_1"]["userinfo"] = { - "policy": { - "function": _custom_validate_userinfo_policy, - "kwargs": {} - } + "policy": {"function": _custom_validate_userinfo_policy, "kwargs": {}} } _req = self.endpoint.parse_request({}, http_info=http_info) @@ -687,4 +687,3 @@ def _custom_validate_userinfo_policy(request, token, response_info, **kwargs): res = self.endpoint.do_response(request=_req, **args) _response = json.loads(res["response"]) assert "custom" in _response - diff --git a/tests/test_server_30_oidc_end_session.py b/tests/test_server_30_oidc_end_session.py index b8fc9f7a..59487b74 100644 --- a/tests/test_server_30_oidc_end_session.py +++ b/tests/test_server_30_oidc_end_session.py @@ -45,18 +45,9 @@ KEYJAR = build_keyjar(KEYDEFS) KEYJAR.import_jwks(KEYJAR.export_jwks(private=True), ISS) -RESPONSE_TYPES_SUPPORTED = [ - ["code"], - ["token"], - ["id_token"], - ["code", "token"], - ["code", "id_token"], - ["id_token", "token"], - ["code", "token", "id_token"], - ["none"], -] +RESPONSE_TYPES_SUPPORTED = [["code"], ["id_token"], ["code", "id_token"]] -CAPABILITIES = { +PREFRERENCES = { "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED], "token_endpoint_auth_methods_supported": [ "client_secret_post", @@ -106,7 +97,7 @@ def create_endpoint(self): "issuer": ISS, "password": "mycket hemlig zebra", "verify_ssl": False, - "capabilities": CAPABILITIES, + "preferences": PREFRERENCES, "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS}, "endpoint": { "provider_config": { @@ -207,18 +198,32 @@ def create_endpoint(self): "redirect_uris": [("{}cb".format(CLI1), None)], "client_salt": "salted", "token_endpoint_auth_method": "client_secret_post", - "response_types": ["code", "token", "code id_token", "id_token"], + "response_types_supported": ["code", "code id_token", "id_token"], "post_logout_redirect_uri": [f"{CLI1}logout_cb", ""], - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] + "allowed_scopes": [ + "openid", + "profile", + "email", + "address", + "phone", + "offline_access", + ], }, "client_2": { "client_secret": "hemligare", "redirect_uris": [("{}cb".format(CLI2), None)], "client_salt": "saltare", "token_endpoint_auth_method": "client_secret_post", - "response_types": ["code", "token", "code id_token", "id_token"], + "response_types_supported": ["code", "code id_token", "id_token"], "post_logout_redirect_uri": [f"{CLI2}logout_cb", ""], - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] + "allowed_scopes": [ + "openid", + "profile", + "email", + "address", + "phone", + "offline_access", + ], }, } self.context = context @@ -336,7 +341,7 @@ def test_end_session_endpoint_with_cookie_id_token_and_unknown_sid(self): http_info = {"cookie": [cookie]} msg = Message(id_token=id_token) - verify_id_token(msg, keyjar=self.session_endpoint.upstream_get("attribute",'keyjar')) + verify_id_token(msg, keyjar=self.session_endpoint.upstream_get("attribute", "keyjar")) msg2 = Message(id_token_hint=id_token) msg2[verified_claim_name("id_token_hint")] = msg[verified_claim_name("id_token")] @@ -403,7 +408,7 @@ def test_end_session_endpoint_with_wrong_post_logout_redirect_uri(self): post_logout_redirect_uri = "https://demo.example.com/log_out" msg = Message(id_token=id_token) - verify_id_token(msg, keyjar=self.session_endpoint.upstream_get("attribute",'keyjar')) + verify_id_token(msg, keyjar=self.session_endpoint.upstream_get("attribute", "keyjar")) with pytest.raises(RedirectURIError): self.session_endpoint.process_request( @@ -492,9 +497,7 @@ def test_logout_from_client_bc(self): self.session_endpoint.upstream_get("context").cdb["client_1"][ "backchannel_logout_uri" ] = "https://example.com/bc_logout" - self.session_endpoint.upstream_get("context").cdb["client_1"][ - "client_id" - ] = "client_1" + self.session_endpoint.upstream_get("context").cdb["client_1"]["client_id"] = "client_1" res = self.session_endpoint.logout_from_client(_session_info["branch_id"]) assert set(res.keys()) == {"blu"} @@ -522,9 +525,7 @@ def test_logout_from_client_fc(self): self.session_endpoint.upstream_get("context").cdb["client_1"][ "frontchannel_logout_uri" ] = "https://example.com/fc_logout" - self.session_endpoint.upstream_get("context").cdb["client_1"][ - "client_id" - ] = "client_1" + self.session_endpoint.upstream_get("context").cdb["client_1"]["client_id"] = "client_1" res = self.session_endpoint.logout_from_client(_session_info["branch_id"]) assert set(res.keys()) == {"flu"} @@ -557,15 +558,11 @@ def test_logout_from_client(self): self.session_endpoint.upstream_get("context").cdb["client_1"][ "backchannel_logout_uri" ] = "https://example.com/bc_logout" - self.session_endpoint.upstream_get("context").cdb["client_1"][ - "client_id" - ] = "client_1" + self.session_endpoint.upstream_get("context").cdb["client_1"]["client_id"] = "client_1" self.session_endpoint.upstream_get("context").cdb["client_2"][ "frontchannel_logout_uri" ] = "https://example.com/fc_logout" - self.session_endpoint.upstream_get("context").cdb["client_2"][ - "client_id" - ] = "client_2" + self.session_endpoint.upstream_get("context").cdb["client_2"]["client_id"] = "client_2" res = self.session_endpoint.logout_all_clients(_session_info["branch_id"]) @@ -635,15 +632,11 @@ def test_logout_from_client_no_session(self): self.session_endpoint.upstream_get("context").cdb["client_1"][ "backchannel_logout_uri" ] = "https://example.com/bc_logout" - self.session_endpoint.upstream_get("context").cdb["client_1"][ - "client_id" - ] = "client_1" + self.session_endpoint.upstream_get("context").cdb["client_1"]["client_id"] = "client_1" self.session_endpoint.upstream_get("context").cdb["client_2"][ "frontchannel_logout_uri" ] = "https://example.com/fc_logout" - self.session_endpoint.upstream_get("context").cdb["client_2"][ - "client_id" - ] = "client_2" + self.session_endpoint.upstream_get("context").cdb["client_2"]["client_id"] = "client_2" _uid, _cid, _gid = self.session_manager.decrypt_session_id(_session_info["branch_id"]) self.session_endpoint.upstream_get("context").session_manager.delete([_uid, _cid]) diff --git a/tests/test_server_31_oauth2_introspection.py b/tests/test_server_31_oauth2_introspection.py index ab5e6985..5e28c632 100644 --- a/tests/test_server_31_oauth2_introspection.py +++ b/tests/test_server_31_oauth2_introspection.py @@ -61,7 +61,7 @@ "claim_types_supported": ["normal", "aggregated", "distributed"], "claims_parameter_supported": True, "request_parameter_supported": True, - "request_uri_parameter_supported": True, + # "request_uri_parameter_supported": True, } AUTH_REQ = AuthorizationRequest( @@ -91,7 +91,6 @@ def full_path(local_file): @pytest.mark.parametrize("jwt_token", [True, False]) class TestEndpoint: - @pytest.fixture(autouse=True) def create_endpoint(self, jwt_token): conf = { @@ -133,6 +132,7 @@ def create_endpoint(self, jwt_token): "kwargs": { "client_authn_method": ["client_secret_post"], "enable_claims_per_client": False, + "enforce_audience_restriction": True, }, }, "token": { @@ -205,7 +205,7 @@ def create_endpoint(self, jwt_token): }, "by_scope": {}, }, - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], } server.keyjar.import_jwks_as_json( server.keyjar.export_jwks_as_json(private=True), context.issuer @@ -266,16 +266,14 @@ def test_parse_with_client_auth_in_req(self): ) assert isinstance(_req, TokenIntrospectionRequest) - assert set(_req.keys()) == {"token", "client_id", "client_secret", 'authenticated'} + assert set(_req.keys()) == {"token", "client_id", "client_secret", "authenticated"} def test_parse_with_wrong_client_authn(self): access_token = self._get_access_token(AUTH_REQ) _basic_token = "{}:{}".format( "client_1", - self.introspection_endpoint.upstream_get("context").cdb["client_1"][ - "client_secret" - ], + self.introspection_endpoint.upstream_get("context").cdb["client_1"]["client_secret"], ) _basic_token = as_unicode(base64.b64encode(as_bytes(_basic_token))) _basic_authz = "Basic {}".format(_basic_token) diff --git a/tests/test_server_32_oidc_read_registration.py b/tests/test_server_32_oidc_read_registration.py index da6b19f2..bac0e207 100644 --- a/tests/test_server_32_oidc_read_registration.py +++ b/tests/test_server_32_oidc_read_registration.py @@ -75,7 +75,6 @@ class TestEndpoint(object): - @pytest.fixture(autouse=True) def create_endpoint(self): conf = { @@ -128,7 +127,9 @@ def create_endpoint(self): server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) self.registration_endpoint = server.get_endpoint("registration") self.registration_api_endpoint = server.get_endpoint("registration_read") - server.context.cdb["client_1"] = {} + server.context.cdb["client_1"] = { + "redirect_uris": [("https://example.com/cb", ""), ("https://example.com/2nd_cb", "")] + } def test_do_response(self): _req = self.registration_endpoint.parse_request(CLI_REQ.to_json()) @@ -150,7 +151,7 @@ def test_do_response(self): "client_id={}".format(_resp["response_args"]["client_id"]), http_info=http_info, ) - assert set(_api_req.keys()) == {"client_id", 'authenticated'} + assert set(_api_req.keys()) == {"client_id", "authenticated"} _info = self.registration_api_endpoint.process_request(request=_api_req) assert set(_info.keys()) == {"response_args"} diff --git a/tests/test_server_33_oauth2_pkce.py b/tests/test_server_33_oauth2_pkce.py index fbb40d9d..78fcbe3a 100644 --- a/tests/test_server_33_oauth2_pkce.py +++ b/tests/test_server_33_oauth2_pkce.py @@ -17,11 +17,10 @@ from idpyoidc.server.configure import ASConfiguration from idpyoidc.server.configure import OPConfiguration from idpyoidc.server.cookie_handler import CookieHandler -from idpyoidc.server.oidc.add_on.pkce import CC_METHOD -from idpyoidc.server.oidc.add_on.pkce import add_pkce_support +from idpyoidc.server.oauth2.add_on.pkce import CC_METHOD +from idpyoidc.server.oauth2.add_on.pkce import add_support from idpyoidc.server.oidc.authorization import Authorization from idpyoidc.server.oidc.token import Token - from . import CRYPT_CONFIG from . import SESSION_PARAMS from . import full_path @@ -93,10 +92,8 @@ def full_path(local_file): 'token_endpoint_auth_method': 'client_secret_post' 'response_types': - 'code' - - 'token' - 'code id_token' - 'id_token' - - 'code id_token token' allowed_scopes: - 'openid' - 'profile' @@ -172,7 +169,7 @@ def conf(): "template_dir": "template", "add_on": { "pkce": { - "function": "idpyoidc.server.oidc.add_on.pkce.add_pkce_support", + "function": "idpyoidc.server.oauth2.add_on.pkce.add_support", "kwargs": {"essential": True}, } }, @@ -418,7 +415,7 @@ def test_missing_authz_endpoint(): } configuration = OPConfiguration(conf, base_path=BASEDIR, domain="127.0.0.1", port=443) server = Server(configuration) - add_pkce_support(server.get_endpoints()) + add_support(server.get_endpoints()) assert "pkce" not in server.get_context().args @@ -443,6 +440,6 @@ def test_missing_token_endpoint(): } configuration = OPConfiguration(conf, base_path=BASEDIR, domain="127.0.0.1", port=443) server = Server(configuration) - add_pkce_support(server.get_endpoints()) + add_support(server.get_endpoints()) assert "pkce" not in server.get_context().args diff --git a/tests/test_server_34_oidc_sso.py b/tests/test_server_34_oidc_sso.py index 6b4132f2..4090b511 100755 --- a/tests/test_server_34_oidc_sso.py +++ b/tests/test_server_34_oidc_sso.py @@ -199,9 +199,7 @@ def create_endpoint_context(self): context = server.context _clients = yaml.safe_load(io.StringIO(client_yaml)) context.cdb = _clients["oidc_clients"] - server.keyjar.import_jwks( - server.keyjar.export_jwks(True, ""), conf["issuer"] - ) + server.keyjar.import_jwks(server.keyjar.export_jwks(True, ""), conf["issuer"]) self.endpoint = server.get_endpoint("authorization") self.context = context self.rp_keyjar = KeyJar() @@ -272,9 +270,7 @@ def test_sso(self): # No valid login cookie so new session assert info["session_id"] != sid2 - user_session_info = self.endpoint.upstream_get("context").session_manager.get( - ["diana"] - ) + user_session_info = self.endpoint.upstream_get("context").session_manager.get(["diana"]) assert len(user_session_info.subordinate) == 3 assert set(user_session_info.subordinate) == { "diana;;client_1", @@ -285,15 +281,9 @@ def test_sso(self): # Should be one grant for each of client_2 and client_3 and # 2 grants for client_1 - csi1 = self.endpoint.upstream_get("context").session_manager.get( - ["diana", "client_1"] - ) - csi2 = self.endpoint.upstream_get("context").session_manager.get( - ["diana", "client_2"] - ) - csi3 = self.endpoint.upstream_get("context").session_manager.get( - ["diana", "client_3"] - ) + csi1 = self.endpoint.upstream_get("context").session_manager.get(["diana", "client_1"]) + csi2 = self.endpoint.upstream_get("context").session_manager.get(["diana", "client_2"]) + csi3 = self.endpoint.upstream_get("context").session_manager.get(["diana", "client_3"]) assert len(csi1.subordinate) == 2 assert len(csi2.subordinate) == 1 diff --git a/tests/test_server_35_oidc_token_endpoint.py b/tests/test_server_35_oidc_token_endpoint.py index c6141261..0f18fea7 100755 --- a/tests/test_server_35_oidc_token_endpoint.py +++ b/tests/test_server_35_oidc_token_endpoint.py @@ -198,7 +198,6 @@ def conf(): class TestEndpoint(_TestEndpoint): - @pytest.fixture(autouse=True) def create_endpoint(self, conf): self.server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) @@ -210,7 +209,7 @@ def create_endpoint(self, conf): "client_salt": "salted", "endpoint_auth_method": "client_secret_post", "response_types": ["code", "token", "code id_token", "id_token"], - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], } self.server.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") context.userinfo = USERINFO @@ -284,7 +283,7 @@ def test_parse(self): _token_request["code"] = code.value _req = self.token_endpoint.parse_request(_token_request) - assert set(_req.keys()).difference(set(_token_request.keys())) == {'authenticated'} + assert set(_req.keys()).difference(set(_token_request.keys())) == {"authenticated"} def test_process_request(self): session_id = self._create_session(AUTH_REQ) @@ -394,7 +393,7 @@ def test_do_refresh_access_token(self): "scope", } AuthorizationResponse().from_jwt( - _resp["response_args"]["id_token"], self.server.get_attribute('keyjar'), sender="" + _resp["response_args"]["id_token"], self.server.get_attribute("keyjar"), sender="" ) msg = self.token_endpoint.do_response(request=_req, **_resp) @@ -774,7 +773,7 @@ def test_new_refresh_token(self, conf): "client_salt": "salted", "endpoint_auth_method": "client_secret_post", "response_types": ["code", "token", "code id_token", "id_token"], - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], } areq = AUTH_REQ.copy() @@ -814,7 +813,7 @@ def test_revoke_on_issue_refresh_token(self, conf): "client_salt": "salted", "endpoint_auth_method": "client_secret_post", "response_types": ["code", "token", "code id_token", "id_token"], - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], } self.token_endpoint.revoke_refresh_on_issue = True areq = AUTH_REQ.copy() @@ -854,7 +853,7 @@ def test_revoke_on_issue_refresh_token_per_client(self, conf): "client_salt": "salted", "endpoint_auth_method": "client_secret_post", "response_types": ["code", "token", "code id_token", "id_token"], - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], } self.context.cdb[AUTH_REQ["client_id"]]["revoke_refresh_on_issue"] = True areq = AUTH_REQ.copy() @@ -938,8 +937,9 @@ def test_do_refresh_access_token_revoked(self): def test_configure_grant_types(self): conf = {"access_token": {"class": "idpyoidc.server.oidc.token.AccessTokenHelper"}} - _helper = self.token_endpoint.configure_types(conf, - self.token_endpoint.helper_by_grant_type) + _helper = self.token_endpoint.configure_types( + conf, self.token_endpoint.helper_by_grant_type + ) assert len(_helper) == 1 assert "access_token" in _helper @@ -1015,7 +1015,6 @@ def test_refresh_token_request_other_client(self): class TestOldTokens(object): - @pytest.fixture(autouse=True) def create_endpoint(self, conf): server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) @@ -1027,7 +1026,7 @@ def create_endpoint(self, conf): "client_salt": "salted", "endpoint_auth_method": "client_secret_post", "response_types": ["code", "token", "code id_token", "id_token"], - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], } server.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") self.session_manager = context.session_manager @@ -1121,7 +1120,7 @@ def test_old_jwt_token(self): # payload.update(kwargs) _context = _handler.upstream_get("context") signer = JWT( - key_jar=_handler.upstream_get('attribute', 'keyjar'), + key_jar=_handler.upstream_get("attribute", "keyjar"), iss=_handler.issuer, lifetime=300, sign_alg=_handler.alg, diff --git a/tests/test_server_35_oidc_token_endpoint_def_conf.py b/tests/test_server_35_oidc_token_endpoint_def_conf.py new file mode 100755 index 00000000..62a4bef9 --- /dev/null +++ b/tests/test_server_35_oidc_token_endpoint_def_conf.py @@ -0,0 +1,867 @@ +import os + +import pytest +from cryptojwt import JWT +from cryptojwt.key_jar import build_keyjar + +from idpyoidc.client.defaults import DEFAULT_KEY_DEFS +from idpyoidc.defaults import JWT_BEARER +from idpyoidc.message.oidc import AccessTokenRequest +from idpyoidc.message.oidc import AuthorizationRequest +from idpyoidc.message.oidc import AuthorizationResponse +from idpyoidc.message.oidc import RefreshAccessTokenRequest +from idpyoidc.message.oidc import TokenErrorResponse +from idpyoidc.server import Server +from idpyoidc.server.authn_event import create_authn_event +from idpyoidc.server.configure import OPConfiguration +from idpyoidc.server.exception import InvalidToken +from idpyoidc.time_util import utc_time_sans_frac + +CLIENT_KEYJAR = build_keyjar(DEFAULT_KEY_DEFS) + +AUTH_REQ = AuthorizationRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + scope=["openid"], + state="STATE", + response_type="code", +) + +TOKEN_REQ = AccessTokenRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + state="STATE", + grant_type="authorization_code", + client_secret="hemligt", +) + +REFRESH_TOKEN_REQ = RefreshAccessTokenRequest( + grant_type="refresh_token", client_id="client_1", client_secret="hemligt" +) + +TOKEN_REQ_DICT = TOKEN_REQ.to_dict() + +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +class TestEndpoint(): + + @pytest.fixture(autouse=True) + def create_endpoint(self): + conf = { + "issuer": "https://example.com/", + 'userinfo': { + "class": "idpyoidc.server.user_info.UserInfo", + "kwargs": {"db_file": full_path("users.json")}, + } + } + self.server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) + + context = self.server.context + context.cdb["client_1"] = { + "client_secret": "hemligt", + "redirect_uris": [("https://example.com/cb", None)], + "client_salt": "salted", + "endpoint_auth_method": "client_secret_post", + "response_types": ["code", "token", "code id_token", "id_token"], + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], + } + self.server.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") + self.session_manager = context.session_manager + self.token_endpoint = self.server.get_endpoint("token") + self.user_id = "diana" + self.context = context + + def test_init(self): + assert self.token_endpoint + + def _create_session(self, auth_req, sub_type="public", sector_identifier=""): + if sector_identifier: + authz_req = auth_req.copy() + authz_req["sector_identifier_uri"] = sector_identifier + else: + authz_req = auth_req + client_id = authz_req["client_id"] + ae = create_authn_event(self.user_id) + return self.session_manager.create_session( + ae, authz_req, self.user_id, client_id=client_id, sub_type=sub_type + ) + + def _mint_code(self, grant, client_id): + session_id = self.session_manager.encrypted_session_id(self.user_id, client_id, grant.id) + usage_rules = grant.usage_rules.get("authorization_code", {}) + _exp_in = usage_rules.get("expires_in") + + # Constructing an authorization code + _code = grant.mint_token( + session_id=session_id, + context=self.context, + token_class="authorization_code", + token_handler=self.session_manager.token_handler["authorization_code"], + usage_rules=usage_rules, + ) + + if _exp_in: + if isinstance(_exp_in, str): + _exp_in = int(_exp_in) + if _exp_in: + _code.expires_at = utc_time_sans_frac() + _exp_in + return _code + + def _mint_access_token(self, grant, session_id, token_ref=None): + _session_info = self.session_manager.get_session_info(session_id) + usage_rules = grant.usage_rules.get("access_token", {}) + _exp_in = usage_rules.get("expires_in", 0) + + _token = grant.mint_token( + _session_info, + context=self.context, + token_class="access_token", + token_handler=self.session_manager.token_handler["access_token"], + based_on=token_ref, # Means the token (tok) was used to mint this token + usage_rules=usage_rules, + ) + if isinstance(_exp_in, str): + _exp_in = int(_exp_in) + if _exp_in: + _token.expires_at = utc_time_sans_frac() + _exp_in + + return _token + + def test_parse(self): + session_id = self._create_session(AUTH_REQ) + grant = self.session_manager[session_id] + code = self._mint_code(grant, AUTH_REQ["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + + assert set(_req.keys()).difference(set(_token_request.keys())) == {"authenticated"} + + def test_process_request(self): + session_id = self._create_session(AUTH_REQ) + grant = self.session_manager[session_id] + code = self._mint_code(grant, AUTH_REQ["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _context = self.context + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req) + + assert _resp + assert set(_resp.keys()) == {"cookie", "http_headers", "response_args"} + assert "expires_in" in _resp["response_args"] + + def test_process_request_using_code_twice(self): + session_id = self._create_session(AUTH_REQ) + grant = self.session_manager[session_id] + code = self._mint_code(grant, AUTH_REQ["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _context = self.context + _token_request["code"] = code.value + + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req) + + # 2nd time used + _2nd_response = self.token_endpoint.parse_request(_token_request) + assert "error" in _2nd_response + + def test_do_response(self): + session_id = self._create_session(AUTH_REQ) + grant = self.session_manager[session_id] + code = self._mint_code(grant, AUTH_REQ["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + + _resp = self.token_endpoint.process_request(request=_req) + msg = self.token_endpoint.do_response(request=_req, **_resp) + assert isinstance(msg, dict) + + def test_process_request_using_private_key_jwt(self): + session_id = self._create_session(AUTH_REQ) + grant = self.session_manager[session_id] + code = self._mint_code(grant, AUTH_REQ["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + del _token_request["client_id"] + del _token_request["client_secret"] + _context = self.context + + _jwt = JWT(CLIENT_KEYJAR, iss=AUTH_REQ["client_id"], sign_alg="RS256") + _jwt.with_jti = True + _assertion = _jwt.pack({"aud": [self.token_endpoint.full_path]}) + _token_request.update({"client_assertion": _assertion, + "client_assertion_type": JWT_BEARER}) + _token_request["code"] = code.value + + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req) + + # 2nd time used + with pytest.raises(InvalidToken): + self.token_endpoint.parse_request(_token_request) + + def test_do_refresh_access_token(self): + areq = AUTH_REQ.copy() + areq["scope"] = ["openid", "offline_access"] + + session_id = self._create_session(areq) + grant = self.context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _cntx = self.context + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req) + + _request = REFRESH_TOKEN_REQ.copy() + _request["refresh_token"] = _resp["response_args"]["refresh_token"] + + _token_value = _resp["response_args"]["refresh_token"] + _session_info = self.session_manager.get_session_info_by_token( + _token_value, handler_key="refresh_token" + ) + _token = self.session_manager.find_token(_session_info["branch_id"], _token_value) + _token.usage_rules["supports_minting"] = [ + "access_token", + "refresh_token", + "id_token", + ] + + _req = self.token_endpoint.parse_request(_request.to_urlencoded()) + _resp = self.token_endpoint.process_request(request=_req) + assert set(_resp.keys()) == {"cookie", "response_args", "http_headers"} + assert set(_resp["response_args"].keys()) == { + "access_token", + "token_type", + "expires_in", + "refresh_token", + "id_token", + "scope", + } + AuthorizationResponse().from_jwt( + _resp["response_args"]["id_token"], self.server.get_attribute("keyjar"), sender="" + ) + + msg = self.token_endpoint.do_response(request=_req, **_resp) + assert isinstance(msg, dict) + + def test_do_2nd_refresh_access_token(self): + areq = AUTH_REQ.copy() + areq["scope"] = ["openid", "offline_access"] + + session_id = self._create_session(areq) + grant = self.context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + self.token_endpoint.revoke_refresh_on_issue = False + _cntx = self.context + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req) + + _request = REFRESH_TOKEN_REQ.copy() + _request["refresh_token"] = _resp["response_args"]["refresh_token"] + + # Make sure ID Tokens can also be used by this refresh token + _token_value = _resp["response_args"]["refresh_token"] + _session_info = self.session_manager.get_session_info_by_token( + _token_value, handler_key="refresh_token" + ) + + _req = self.token_endpoint.parse_request(_request.to_urlencoded()) + _resp = self.token_endpoint.process_request(request=_req) + + _2nd_request = REFRESH_TOKEN_REQ.copy() + _2nd_request["refresh_token"] = _resp["response_args"]["refresh_token"] + _2nd_req = self.token_endpoint.parse_request(_request.to_urlencoded()) + _2nd_resp = self.token_endpoint.process_request(request=_req) + + assert set(_2nd_resp.keys()) == {"cookie", "response_args", "http_headers"} + assert set(_2nd_resp["response_args"].keys()) == { + "access_token", + "token_type", + "expires_in", + "refresh_token", + "id_token", + "scope", + } + AuthorizationResponse().from_jwt( + _2nd_resp["response_args"]["id_token"], self.server.keyjar, sender="" + ) + + msg = self.token_endpoint.do_response(request=_req, **_resp) + assert isinstance(msg, dict) + + def test_invalid_refresh(self): + _request = REFRESH_TOKEN_REQ.copy() + _request["refresh_token"] = "invalid" + + _req = self.token_endpoint.parse_request(_request.to_urlencoded()) + + assert isinstance(_req, TokenErrorResponse) + assert _req.to_dict() == { + "error": "invalid_grant", + "error_description": "Invalid refresh token", + } + + def test_refresh_scopes(self): + areq = AUTH_REQ.copy() + areq["scope"] = ["openid", "offline_access", "profile"] + + session_id = self._create_session(areq) + grant = self.context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req) + + _request = REFRESH_TOKEN_REQ.copy() + _request["refresh_token"] = _resp["response_args"]["refresh_token"] + _request["scope"] = ["openid", "offline_access"] + + _token_value = _resp["response_args"]["refresh_token"] + _session_info = self.session_manager.get_session_info_by_token( + _token_value, handler_key="refresh_token" + ) + _token = self.session_manager.find_token(_session_info["branch_id"], _token_value) + _token.usage_rules["supports_minting"] = [ + "access_token", + "refresh_token", + "id_token", + ] + + _req = self.token_endpoint.parse_request(_request.to_urlencoded()) + _resp = self.token_endpoint.process_request(request=_req) + assert set(_resp.keys()) == {"cookie", "response_args", "http_headers"} + assert set(_resp["response_args"].keys()) == { + "access_token", + "token_type", + "expires_in", + "refresh_token", + "id_token", + "scope", + } + AuthorizationResponse().from_jwt( + _resp["response_args"]["id_token"], + self.server.keyjar, + sender="", + ) + + _token_value = _resp["response_args"]["access_token"] + _session_info = self.session_manager.get_session_info_by_token( + _token_value, handler_key="access_token" + ) + at = self.session_manager.find_token(_session_info["branch_id"], _token_value) + rt = self.session_manager.find_token( + _session_info["branch_id"], _resp["response_args"]["refresh_token"] + ) + + assert at.scope == rt.scope == _request["scope"] == _resp["response_args"]["scope"] + + def test_refresh_more_scopes(self): + areq = AUTH_REQ.copy() + areq["scope"] = ["openid", "offline_access"] + + session_id = self._create_session(areq) + grant = self.context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req) + + _request = REFRESH_TOKEN_REQ.copy() + _request["refresh_token"] = _resp["response_args"]["refresh_token"] + _request["scope"] = ["openid", "offline_access", "profile"] + + _token_value = _resp["response_args"]["refresh_token"] + _session_info = self.session_manager.get_session_info_by_token( + _token_value, handler_key="refresh_token" + ) + _token = self.session_manager.find_token(_session_info["branch_id"], _token_value) + _token.usage_rules["supports_minting"] = [ + "access_token", + "refresh_token", + "id_token", + ] + + _req = self.token_endpoint.parse_request(_request.to_urlencoded()) + assert isinstance(_req, TokenErrorResponse) + _resp = self.token_endpoint.process_request(request=_req) + + assert _resp.to_dict() == { + "error": "invalid_request", + "error_description": "Invalid refresh scopes", + } + + def test_refresh_more_scopes_2(self): + areq = AUTH_REQ.copy() + areq["scope"] = ["openid", "offline_access", "profile"] + + session_id = self._create_session(areq) + grant = self.context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req) + + _request = REFRESH_TOKEN_REQ.copy() + _request["refresh_token"] = _resp["response_args"]["refresh_token"] + _request["scope"] = ["openid", "offline_access"] + + _token_value = _resp["response_args"]["refresh_token"] + _session_info = self.session_manager.get_session_info_by_token( + _token_value, handler_key="refresh_token" + ) + _token = self.session_manager.find_token(_session_info["branch_id"], _token_value) + _token.usage_rules["supports_minting"] = [ + "access_token", + "refresh_token", + "id_token", + ] + + _req = self.token_endpoint.parse_request(_request.to_urlencoded()) + _resp = self.token_endpoint.process_request(request=_req) + + _token_value = _resp["response_args"]["refresh_token"] + _session_info = self.session_manager.get_session_info_by_token( + _token_value, handler_key="refresh_token" + ) + _token = self.session_manager.find_token(_session_info["branch_id"], _token_value) + _token.usage_rules["supports_minting"] = [ + "access_token", + "refresh_token", + "id_token", + ] + _request["refresh_token"] = _token_value + # We should be able to request the original requests scopes + _request["scope"] = ["openid", "offline_access", "profile"] + + _req = self.token_endpoint.parse_request(_request.to_urlencoded()) + _resp = self.token_endpoint.process_request(request=_req) + + assert set(_resp.keys()) == {"cookie", "response_args", "http_headers"} + assert set(_resp["response_args"].keys()) == { + "access_token", + "token_type", + "expires_in", + "refresh_token", + "id_token", + "scope", + } + AuthorizationResponse().from_jwt( + _resp["response_args"]["id_token"], + self.server.keyjar, + sender="", + ) + + _token_value = _resp["response_args"]["access_token"] + _session_info = self.session_manager.get_session_info_by_token( + _token_value, handler_key="access_token" + ) + at = self.session_manager.find_token(_session_info["branch_id"], _token_value) + rt = self.session_manager.find_token( + _session_info["branch_id"], _resp["response_args"]["refresh_token"] + ) + + assert at.scope == rt.scope == _request["scope"] == _resp["response_args"]["scope"] + + def test_refresh_less_scopes(self): + areq = AUTH_REQ.copy() + areq["scope"] = ["openid", "offline_access", "email"] + + self.session_manager.token_handler.handler["id_token"].kwargs["add_claims_by_scope"] = True + session_id = self._create_session(areq) + grant = self.context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req) + idtoken = AuthorizationResponse().from_jwt( + _resp["response_args"]["id_token"], + self.server.keyjar, + sender="", + ) + + assert "email" in idtoken + + _request = REFRESH_TOKEN_REQ.copy() + _request["refresh_token"] = _resp["response_args"]["refresh_token"] + _request["scope"] = ["openid", "offline_access"] + + _token_value = _resp["response_args"]["refresh_token"] + _session_info = self.session_manager.get_session_info_by_token( + _token_value, handler_key="refresh_token" + ) + _token = self.session_manager.find_token(_session_info["branch_id"], _token_value) + _token.usage_rules["supports_minting"] = [ + "access_token", + "refresh_token", + "id_token", + ] + + _req = self.token_endpoint.parse_request(_request.to_urlencoded()) + _resp = self.token_endpoint.process_request(request=_req) + idtoken = AuthorizationResponse().from_jwt( + _resp["response_args"]["id_token"], + self.server.keyjar, + sender="", + ) + + assert "email" not in idtoken + assert _resp["response_args"]["scope"] == ["openid", "offline_access"] + + def test_refresh_no_openid_scope(self): + areq = AUTH_REQ.copy() + areq["scope"] = ["openid", "offline_access"] + + session_id = self._create_session(areq) + grant = self.context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req) + + _request = REFRESH_TOKEN_REQ.copy() + _request["refresh_token"] = _resp["response_args"]["refresh_token"] + _request["scope"] = ["offline_access"] + + _token_value = _resp["response_args"]["refresh_token"] + _session_info = self.session_manager.get_session_info_by_token( + _token_value, handler_key="refresh_token" + ) + _token = self.session_manager.find_token(_session_info["branch_id"], _token_value) + _token.usage_rules["supports_minting"] = [ + "access_token", + "refresh_token", + "id_token", + ] + + _req = self.token_endpoint.parse_request(_request.to_urlencoded()) + _resp = self.token_endpoint.process_request(request=_req) + + assert set(_resp.keys()) == {"cookie", "response_args", "http_headers"} + assert set(_resp["response_args"].keys()) == { + "access_token", + "token_type", + "expires_in", + "refresh_token", + "scope", + } + assert _resp["response_args"]["scope"] == ["offline_access"] + + def test_refresh_no_offline_access_scope(self): + areq = AUTH_REQ.copy() + areq["scope"] = ["openid", "offline_access"] + + session_id = self._create_session(areq) + grant = self.context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req) + + _request = REFRESH_TOKEN_REQ.copy() + _request["refresh_token"] = _resp["response_args"]["refresh_token"] + _request["scope"] = ["openid"] + + _token_value = _resp["response_args"]["refresh_token"] + _session_info = self.session_manager.get_session_info_by_token( + _token_value, handler_key="refresh_token" + ) + _token = self.session_manager.find_token(_session_info["branch_id"], _token_value) + _token.usage_rules["supports_minting"] = [ + "access_token", + "refresh_token", + "id_token", + ] + + _req = self.token_endpoint.parse_request(_request.to_urlencoded()) + _resp = self.token_endpoint.process_request(request=_req) + + assert set(_resp.keys()) == {"cookie", "response_args", "http_headers"} + assert set(_resp["response_args"].keys()) == { + "access_token", + "token_type", + "expires_in", + "id_token", + "scope", + } + AuthorizationResponse().from_jwt( + _resp["response_args"]["id_token"], + self.server.keyjar, + sender="", + ) + assert _resp["response_args"]["scope"] == ["openid"] + + def test_new_refresh_token(self): + self.context.cdb["client_1"] = { + "client_secret": "hemligt", + "redirect_uris": [("https://example.com/cb", None)], + "client_salt": "salted", + "endpoint_auth_method": "client_secret_post", + "response_types": ["code", "token", "code id_token", "id_token"], + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], + } + + areq = AUTH_REQ.copy() + areq["scope"] = ["openid", "offline_access"] + + session_id = self._create_session(areq) + grant = self.context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req) + assert "refresh_token" in _resp["response_args"] + first_refresh_token = _resp["response_args"]["refresh_token"] + + _refresh_request = REFRESH_TOKEN_REQ.copy() + _refresh_request["refresh_token"] = first_refresh_token + _2nd_req = self.token_endpoint.parse_request(_refresh_request.to_urlencoded()) + _2nd_resp = self.token_endpoint.process_request(request=_2nd_req) + assert "refresh_token" in _2nd_resp["response_args"] + second_refresh_token = _2nd_resp["response_args"]["refresh_token"] + + _2d_refresh_request = REFRESH_TOKEN_REQ.copy() + _2d_refresh_request["refresh_token"] = second_refresh_token + _3rd_req = self.token_endpoint.parse_request(_2d_refresh_request.to_urlencoded()) + _3rd_resp = self.token_endpoint.process_request(request=_3rd_req) + assert "access_token" in _3rd_resp["response_args"] + assert "refresh_token" in _3rd_resp["response_args"] + + assert first_refresh_token != second_refresh_token + + def test_revoke_on_issue_refresh_token(self): + self.context.cdb["client_1"] = { + "client_secret": "hemligt", + "redirect_uris": [("https://example.com/cb", None)], + "client_salt": "salted", + "endpoint_auth_method": "client_secret_post", + "response_types": ["code", "token", "code id_token", "id_token"], + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], + } + self.token_endpoint.revoke_refresh_on_issue = True + areq = AUTH_REQ.copy() + areq["scope"] = ["openid", "offline_access"] + + session_id = self._create_session(areq) + grant = self.context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req, issue_refresh=True) + assert "refresh_token" in _resp["response_args"] + first_refresh_token = _resp["response_args"]["refresh_token"] + + _refresh_request = REFRESH_TOKEN_REQ.copy() + _refresh_request["refresh_token"] = first_refresh_token + _2nd_req = self.token_endpoint.parse_request(_refresh_request.to_urlencoded()) + _2nd_resp = self.token_endpoint.process_request(request=_2nd_req, issue_refresh=True) + assert "refresh_token" in _2nd_resp["response_args"] + second_refresh_token = _2nd_resp["response_args"]["refresh_token"] + + _2d_refresh_request = REFRESH_TOKEN_REQ.copy() + _2d_refresh_request["refresh_token"] = second_refresh_token + + assert first_refresh_token != second_refresh_token + first_refresh_token = grant.get_token(first_refresh_token) + second_refresh_token = grant.get_token(second_refresh_token) + assert first_refresh_token.revoked is True + assert second_refresh_token.revoked is False + + def test_revoke_on_issue_refresh_token_per_client(self): + self.context.cdb["client_1"] = { + "client_secret": "hemligt", + "redirect_uris": [("https://example.com/cb", None)], + "client_salt": "salted", + "endpoint_auth_method": "client_secret_post", + "response_types": ["code", "token", "code id_token", "id_token"], + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], + } + self.context.cdb[AUTH_REQ["client_id"]]["revoke_refresh_on_issue"] = True + areq = AUTH_REQ.copy() + areq["scope"] = ["openid", "offline_access"] + + session_id = self._create_session(areq) + grant = self.context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req, issue_refresh=True) + assert "refresh_token" in _resp["response_args"] + first_refresh_token = _resp["response_args"]["refresh_token"] + + _refresh_request = REFRESH_TOKEN_REQ.copy() + _refresh_request["refresh_token"] = first_refresh_token + _2nd_req = self.token_endpoint.parse_request(_refresh_request.to_urlencoded()) + _2nd_resp = self.token_endpoint.process_request(request=_2nd_req, issue_refresh=True) + assert "refresh_token" in _2nd_resp["response_args"] + second_refresh_token = _2nd_resp["response_args"]["refresh_token"] + + _2d_refresh_request = REFRESH_TOKEN_REQ.copy() + _2d_refresh_request["refresh_token"] = second_refresh_token + + assert first_refresh_token != second_refresh_token + first_refresh_token = grant.get_token(first_refresh_token) + second_refresh_token = grant.get_token(second_refresh_token) + assert first_refresh_token.revoked is True + assert second_refresh_token.revoked is False + + def test_do_refresh_access_token_not_allowed(self): + areq = AUTH_REQ.copy() + areq["scope"] = ["openid", "offline_access"] + + session_id = self._create_session(areq) + grant = self.context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _cntx = self.token_endpoint.upstream_get("context") + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + # This is weird, issuing a refresh token that can't be used to mint anything + # but it's testing so anything goes. + grant.usage_rules["refresh_token"] = {"supports_minting": []} + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req) + + _request = REFRESH_TOKEN_REQ.copy() + _request["refresh_token"] = _resp["response_args"]["refresh_token"] + _req = self.token_endpoint.parse_request(_request.to_urlencoded()) + res = self.token_endpoint.process_request(_req) + assert "error" in res + + def test_do_refresh_access_token_revoked(self): + areq = AUTH_REQ.copy() + areq["scope"] = ["openid", "offline_access"] + + session_id = self._create_session(areq) + grant = self.context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _cntx = self.token_endpoint.upstream_get("context") + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req) + + _refresh_token = _resp["response_args"]["refresh_token"] + _cntx.session_manager.revoke_token(session_id, _refresh_token) + + _request = REFRESH_TOKEN_REQ.copy() + _request["refresh_token"] = _refresh_token + _req = self.token_endpoint.parse_request(_request.to_urlencoded()) + # A revoked token is caught already when parsing the query. + assert isinstance(_req, TokenErrorResponse) + + def test_configure_grant_types(self): + conf = {"access_token": {"class": "idpyoidc.server.oidc.token.AccessTokenHelper"}} + + _helper = self.token_endpoint.configure_types( + conf, self.token_endpoint.helper_by_grant_type + ) + + assert len(_helper) == 1 + assert "access_token" in _helper + assert "refresh_token" not in _helper + + def test_access_token_lifetime(self): + lifetime = 100 + session_id = self._create_session(AUTH_REQ) + grant = self.session_manager[session_id] + code = self._mint_code(grant, AUTH_REQ["client_id"]) + grant.usage_rules["access_token"] = {"expires_in": lifetime} + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req) + + access_token = AccessTokenRequest().from_jwt( + _resp["response_args"]["access_token"], + self.server.keyjar, + sender="", + ) + + assert access_token["exp"] - access_token["iat"] == lifetime + + def test_token_request_other_client(self): + _context = self.context + _context.cdb["client_2"] = _context.cdb["client_1"] + session_id = self._create_session(AUTH_REQ) + grant = self.session_manager[session_id] + code = self._mint_code(grant, AUTH_REQ["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["client_id"] = "client_2" + _token_request["code"] = code.value + + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req) + + assert isinstance(_resp, TokenErrorResponse) + assert _resp.to_dict() == {"error": "invalid_grant", "error_description": "Wrong client"} + + def test_refresh_token_request_other_client(self): + _context = self.context + _context.cdb["client_2"] = _context.cdb["client_1"] + session_id = self._create_session(AUTH_REQ) + grant = self.session_manager[session_id] + code = self._mint_code(grant, AUTH_REQ["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req, issue_refresh=True) + + _request = REFRESH_TOKEN_REQ.copy() + _request["client_id"] = "client_2" + _request["refresh_token"] = _resp["response_args"]["refresh_token"] + + _token_value = _resp["response_args"]["refresh_token"] + _session_info = self.session_manager.get_session_info_by_token( + _token_value, handler_key="refresh_token" + ) + _token = self.session_manager.find_token(_session_info["branch_id"], _token_value) + _token.usage_rules["supports_minting"] = ["access_token", "refresh_token"] + + _req = self.token_endpoint.parse_request(_request.to_urlencoded()) + _resp = self.token_endpoint.process_request( + request=_req, + ) + assert isinstance(_resp, TokenErrorResponse) + assert _resp.to_dict() == {"error": "invalid_grant", "error_description": "Wrong client"} diff --git a/tests/test_server_36_oauth2_token_exchange.py b/tests/test_server_36_oauth2_token_exchange.py index e1cf6615..5b3a5663 100644 --- a/tests/test_server_36_oauth2_token_exchange.py +++ b/tests/test_server_36_oauth2_token_exchange.py @@ -186,7 +186,7 @@ def create_endpoint(self): "authorization_code", "urn:ietf:params:oauth:grant-type:jwt-bearer", "refresh_token", - "urn:ietf:params:oauth:grant-type:token-exchange" + "urn:ietf:params:oauth:grant-type:token-exchange", ], "response_types": ["code", "token", "code id_token", "id_token"], "allowed_scopes": ["openid", "profile", "offline_access"], @@ -267,7 +267,7 @@ def test_token_exchange1(self, token): token_exchange_req = TokenExchangeRequest( grant_type="urn:ietf:params:oauth:grant-type:token-exchange", subject_token=_token_value, - subject_token_type=token[list(token.keys())[0]] + subject_token_type=token[list(token.keys())[0]], ) _req = self.endpoint.parse_request( @@ -275,7 +275,7 @@ def test_token_exchange1(self, token): {"headers": {"authorization": "Basic {}".format("Y2xpZW50XzI6aGVtbGlndA==")}}, ) _resp = self.endpoint.process_request(request=_req) - print(_resp['response_args']) + print(_resp["response_args"]) assert set(_resp["response_args"].keys()) == { "access_token", "token_type", @@ -411,14 +411,17 @@ def test_token_exchange_scopes_per_client(self): "policy": { "": { "function": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", - "kwargs": { - "scope": ["openid", "profile", "offline_access"] - }, + "kwargs": {"scope": ["openid", "profile", "offline_access"]}, } }, } - self.context.cdb["client_1"]["allowed_scopes"] = ["openid", "email", "profile", "offline_access"] + self.context.cdb["client_1"]["allowed_scopes"] = [ + "openid", + "email", + "profile", + "offline_access", + ] areq = AUTH_REQ.copy() areq["scope"].append("profile") @@ -440,7 +443,7 @@ def test_token_exchange_scopes_per_client(self): subject_token=_token_value, subject_token_type="urn:ietf:params:oauth:token-type:access_token", requested_token_type="urn:ietf:params:oauth:token-type:access_token", - scope="openid profile offline_access" + scope="openid profile offline_access", ) _req = self.endpoint.parse_request( @@ -469,12 +472,10 @@ def test_token_exchange_unsupported_scopes_per_client(self): "policy": { "": { "function": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", - "kwargs": { - "scope": ["openid", "profile", "offline_access"] - }, + "kwargs": {"scope": ["openid", "profile", "offline_access"]}, } }, - "allowed_scopes": ["openid", "email", "profile", "offline_access"] + "allowed_scopes": ["openid", "email", "profile", "offline_access"], } areq = AUTH_REQ.copy() @@ -496,7 +497,7 @@ def test_token_exchange_unsupported_scopes_per_client(self): subject_token=_token_value, subject_token_type="urn:ietf:params:oauth:token-type:access_token", requested_token_type="urn:ietf:params:oauth:token-type:access_token", - scope="email" + scope="email", ) _req = self.endpoint.parse_request( @@ -523,12 +524,10 @@ def test_token_exchange_no_scopes_requested(self): "policy": { "": { "function": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", - "kwargs": { - "scope": ["openid", "offline_access"] - }, + "kwargs": {"scope": ["openid", "offline_access"]}, } }, - "allowed_scopes": ["openid", "email", "profile", "offline_access"] + "allowed_scopes": ["openid", "email", "profile", "offline_access"], } areq = AUTH_REQ.copy() @@ -549,7 +548,7 @@ def test_token_exchange_no_scopes_requested(self): grant_type="urn:ietf:params:oauth:grant-type:token-exchange", subject_token=_token_value, subject_token_type="urn:ietf:params:oauth:token-type:access_token", - requested_token_type="urn:ietf:params:oauth:token-type:access_token" + requested_token_type="urn:ietf:params:oauth:token-type:access_token", ) _req = self.endpoint.parse_request( @@ -564,7 +563,9 @@ def test_additional_parameters(self): Test that a token exchange with additional parameters including scope, audience and subject_token_type works. """ - conf = self.endpoint.grant_type_helper["urn:ietf:params:oauth:grant-type:token-exchange"].config + conf = self.endpoint.grant_type_helper[ + "urn:ietf:params:oauth:grant-type:token-exchange" + ].config conf["policy"][""]["kwargs"] = {} conf["policy"][""]["kwargs"]["audience"] = ["https://example.com"] conf["policy"][""]["kwargs"]["resource"] = ["https://example.com"] @@ -612,10 +613,10 @@ def test_token_exchange_fails_if_disabled(self): grant_types_supported (that are set in its helper attribute). """ self.context.cdb["client_1"]["grant_types_supported"] = [ - 'authorization_code', - 'implicit', - 'urn:ietf:params:oauth:grant-type:jwt-bearer', - 'refresh_token' + "authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + "refresh_token", ] areq = AUTH_REQ.copy() @@ -645,15 +646,17 @@ def test_token_exchange_fails_if_disabled(self): _resp = self.endpoint.process_request(request=_req) assert _resp["error"] == "invalid_request" assert ( - _resp["error_description"] - == "Unsupported grant_type: urn:ietf:params:oauth:grant-type:token-exchange" + _resp["error_description"] + == "Unsupported grant_type: urn:ietf:params:oauth:grant-type:token-exchange" ) def test_wrong_resource(self): """ Test that requesting a token for an unknown resource fails. """ - conf = self.endpoint.grant_type_helper["urn:ietf:params:oauth:grant-type:token-exchange"].config + conf = self.endpoint.grant_type_helper[ + "urn:ietf:params:oauth:grant-type:token-exchange" + ].config conf["policy"][""]["kwargs"] = {} conf["policy"][""]["kwargs"]["resource"] = ["https://example.com"] areq = AUTH_REQ.copy() @@ -723,7 +726,9 @@ def test_wrong_audience(self): """ Test that requesting a token for an unknown audience fails. """ - conf = self.endpoint.grant_type_helper["urn:ietf:params:oauth:grant-type:token-exchange"].config + conf = self.endpoint.grant_type_helper[ + "urn:ietf:params:oauth:grant-type:token-exchange" + ].config conf["policy"][""]["kwargs"] = {} conf["policy"][""]["kwargs"]["audience"] = ["https://example.com"] areq = AUTH_REQ.copy() @@ -1042,9 +1047,7 @@ def test_token_exchange_unsupported_scope_requested_1(self): "policy": { "": { "function": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", - "kwargs": { - "scope": ["offline_access", "profile"] - }, + "kwargs": {"scope": ["offline_access", "profile"]}, } }, } @@ -1131,9 +1134,7 @@ def test_token_exchange_unsupported_scope_requested_2(self): "policy": { "": { "function": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", - "kwargs": { - "scope": ["profile"] - }, + "kwargs": {"scope": ["profile"]}, } }, } @@ -1219,17 +1220,15 @@ def test_token_exchange_unsupported_scope_requested_3(self): "policy": { "": { "function": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", - "kwargs": { - "scope": ["offline_access", "profile"] - }, + "kwargs": {"scope": ["offline_access", "profile"]}, } }, } self.context.cdb["client_1"]["grant_types_supported"] = [ - 'authorization_code', - 'implicit', - 'urn:ietf:params:oauth:grant-type:jwt-bearer', - 'urn:ietf:params:oauth:grant-type:token-exchange' + "authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + "urn:ietf:params:oauth:grant-type:token-exchange", ] areq = AUTH_REQ.copy() @@ -1327,17 +1326,15 @@ def test_token_exchange_unsupported_scope_requested_4(self): "policy": { "": { "function": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", - "kwargs": { - "scope": ["offline_access", "profile"] - }, + "kwargs": {"scope": ["offline_access", "profile"]}, } }, } self.context.cdb["client_1"]["grant_types_supported"] = [ - 'authorization_code', - 'implicit', - 'urn:ietf:params:oauth:grant-type:jwt-bearer', - 'urn:ietf:params:oauth:grant-type:token-exchange' + "authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + "urn:ietf:params:oauth:grant-type:token-exchange", ] areq = AUTH_REQ.copy() @@ -1377,8 +1374,7 @@ def test_token_exchange_unsupported_scope_requested_4(self): _resp = self.endpoint.process_request(request=_req) assert _resp["error"] == "invalid_request" assert ( - _resp["error_description"] - == "Exchanging this subject token to refresh token forbidden" + _resp["error_description"] == "Exchanging this subject token to refresh token forbidden" ) token_exchange_req["scope"] = "offline_access" @@ -1425,9 +1421,7 @@ def test_token_exchange_unsupported_scope_requested_5(self): "policy": { "": { "function": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", - "kwargs": { - "scope": ["profile"] - }, + "kwargs": {"scope": ["profile"]}, } }, } @@ -1460,8 +1454,7 @@ def test_token_exchange_unsupported_scope_requested_5(self): _resp = self.endpoint.process_request(request=_req) assert _resp["error"] == "invalid_request" assert ( - _resp["error_description"] - == "Exchanging this subject token to refresh token forbidden" + _resp["error_description"] == "Exchanging this subject token to refresh token forbidden" ) token_exchange_req["scope"] = "profile" @@ -1473,8 +1466,7 @@ def test_token_exchange_unsupported_scope_requested_5(self): _resp = self.endpoint.process_request(request=_req) assert _resp["error"] == "invalid_request" assert ( - _resp["error_description"] - == "Exchanging this subject token to refresh token forbidden" + _resp["error_description"] == "Exchanging this subject token to refresh token forbidden" ) token_exchange_req["scope"] = "offline_access" @@ -1485,10 +1477,7 @@ def test_token_exchange_unsupported_scope_requested_5(self): ) _resp = self.endpoint.process_request(request=_req) assert _resp["error"] == "invalid_scope" - assert ( - _resp["error_description"] - == "Invalid requested scopes" - ) + assert _resp["error_description"] == "Invalid requested scopes" token_exchange_req["scope"] = "offline_access profile" @@ -1499,7 +1488,5 @@ def test_token_exchange_unsupported_scope_requested_5(self): _resp = self.endpoint.process_request(request=_req) assert _resp["error"] == "invalid_request" assert ( - _resp["error_description"] - == "Exchanging this subject token to refresh token forbidden" + _resp["error_description"] == "Exchanging this subject token to refresh token forbidden" ) - diff --git a/tests/test_server_38_oauth2_revocation_endpoint.py b/tests/test_server_38_oauth2_revocation_endpoint.py index 73a0b199..ad83af19 100644 --- a/tests/test_server_38_oauth2_revocation_endpoint.py +++ b/tests/test_server_38_oauth2_revocation_endpoint.py @@ -33,13 +33,8 @@ RESPONSE_TYPES_SUPPORTED = [ ["code"], - ["token"], ["id_token"], - ["code", "token"], ["code", "id_token"], - ["id_token", "token"], - ["code", "token", "id_token"], - ["none"], ] CAPABILITIES = { @@ -61,7 +56,7 @@ "claim_types_supported": ["normal", "aggregated", "distributed"], "claims_parameter_supported": True, "request_parameter_supported": True, - "request_uri_parameter_supported": True, + # "request_uri_parameter_supported": True, } AUTH_REQ = AuthorizationRequest( @@ -91,7 +86,6 @@ def full_path(local_file): @pytest.mark.parametrize("jwt_token", [True, False]) class TestEndpoint: - @pytest.fixture(autouse=True) def create_endpoint(self, jwt_token): conf = { @@ -212,8 +206,15 @@ def create_endpoint(self, jwt_token): }, "by_scope": {}, }, - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access", - "research_and_scholarship"] + "allowed_scopes": [ + "openid", + "profile", + "email", + "address", + "phone", + "offline_access", + "research_and_scholarship", + ], } endpoint_context.keyjar.import_jwks_as_json( endpoint_context.keyjar.export_jwks_as_json(private=True), @@ -284,7 +285,7 @@ def test_parse_with_client_auth_in_req(self): ) assert isinstance(_req, TokenRevocationRequest) - assert set(_req.keys()) == {"token", "client_id", "client_secret", 'authenticated'} + assert set(_req.keys()) == {"token", "client_id", "client_secret", "authenticated"} def test_parse_with_wrong_client_authn(self): access_token = self._get_access_token(AUTH_REQ) @@ -318,7 +319,7 @@ def test_process_request(self): ) _resp = self.revocation_endpoint.process_request(_req) assert _resp - assert set(_resp.keys()) == {"response_args"} + assert set(_resp.keys()) == {"response_msg"} def test_do_response(self): access_token = self._get_access_token(AUTH_REQ) @@ -337,7 +338,7 @@ def test_do_response(self): assert isinstance(msg_info, dict) assert set(msg_info.keys()) == {"response", "http_headers"} assert msg_info["http_headers"] == [ - ("Content-type", "application/json; charset=utf-8"), + ("Content-type", "text/plain"), ("Pragma", "no-cache"), ("Cache-Control", "no-store"), ] @@ -366,16 +367,16 @@ def test_access_token(self): } ) _resp = self.revocation_endpoint.process_request(_req) - assert "response_args" in _resp + assert "response_msg" in _resp assert access_token.revoked def test_access_token_per_client(self): - def custom_token_revocation_policy(token, session_info, **kwargs): _token = token _token.revoke() - response_args = {"response_args": {"type": "custom"}} - return TokenRevocationResponse(**response_args) + # response_args = {"response_args": {"type": "custom"}} + # return TokenRevocationResponse(**response_args) + return {"response_msg": "OK"} access_token = self._get_access_token(AUTH_REQ) assert access_token.revoked is False @@ -390,7 +391,7 @@ def custom_token_revocation_policy(token, session_info, **kwargs): }, "access_token": { "function": custom_token_revocation_policy, - } + }, }, } _req = self.revocation_endpoint.parse_request( @@ -401,13 +402,11 @@ def custom_token_revocation_policy(token, session_info, **kwargs): } ) _resp = self.revocation_endpoint.process_request(_req) - assert "response_args" in _resp - assert "type" in _resp["response_args"] - assert _resp["response_args"]["type"] == "custom" + assert "response_msg" in _resp + assert _resp["response_msg"] == "OK" assert access_token.revoked def test_missing_token_policy_per_client(self): - def custom_token_revocation_policy(token, session_info, **kwargs): _token = token _token.revoke() @@ -427,7 +426,7 @@ def custom_token_revocation_policy(token, session_info, **kwargs): }, "refresh_token": { "function": custom_token_revocation_policy, - } + }, }, } _req = self.revocation_endpoint.parse_request( @@ -438,7 +437,7 @@ def custom_token_revocation_policy(token, session_info, **kwargs): } ) _resp = self.revocation_endpoint.process_request(_req) - assert "response_args" in _resp + assert "response_msg" in _resp assert access_token.revoked def test_code(self): @@ -460,7 +459,7 @@ def test_code(self): } ) _resp = self.revocation_endpoint.process_request(_req) - assert "response_args" in _resp + assert "response_msg" in _resp assert code.revoked def test_refresh_token(self): @@ -475,7 +474,7 @@ def test_refresh_token(self): } ) _resp = self.revocation_endpoint.process_request(_req) - assert "response_args" in _resp + assert "response_msg" in _resp assert refresh_token.revoked def test_expired_access_token(self): @@ -492,7 +491,7 @@ def test_expired_access_token(self): } ) _resp = self.revocation_endpoint.process_request(_req) - assert "response_args" in _resp + assert "response_msg" in _resp def test_revoked_access_token(self): access_token = self._get_access_token(AUTH_REQ) @@ -508,7 +507,7 @@ def test_revoked_access_token(self): } ) _resp = self.revocation_endpoint.process_request(_req) - assert "response_args" in _resp + assert "response_msg" in _resp def test_unsupported_token_type(self): self.revocation_endpoint.token_types_supported = ["access_token"] diff --git a/tests/test_server_40_oauth2_pushed_authorization.py b/tests/test_server_40_oauth2_pushed_authorization.py index fa1a6acd..323dd6d6 100644 --- a/tests/test_server_40_oauth2_pushed_authorization.py +++ b/tests/test_server_40_oauth2_pushed_authorization.py @@ -167,15 +167,11 @@ def create_endpoint(self): context = server.context _clients = yaml.safe_load(io.StringIO(client_yaml)) context.cdb = verify_oidc_client_information(_clients["oidc_clients"]) - server.keyjar.import_jwks( - server.keyjar.export_jwks(True, ""), conf["issuer"] - ) + server.keyjar.import_jwks(server.keyjar.export_jwks(True, ""), conf["issuer"]) self.rp_keyjar = init_key_jar(key_defs=KEYDEFS, issuer_id="s6BhdRkqt3") # Add RP's keys to the OP's keyjar - server.keyjar.import_jwks( - self.rp_keyjar.export_jwks(issuer_id="s6BhdRkqt3"), "s6BhdRkqt3" - ) + server.keyjar.import_jwks(self.rp_keyjar.export_jwks(issuer_id="s6BhdRkqt3"), "s6BhdRkqt3") self.pushed_authorization_endpoint = server.get_endpoint("pushed_authorization") self.authorization_endpoint = server.get_endpoint("authorization") @@ -199,7 +195,7 @@ def test_pushed_auth_urlencoded(self): "code_challenge_method", "client_id", "code_challenge", - 'authenticated' + "authenticated", } def test_pushed_auth_request(self): @@ -226,7 +222,7 @@ def test_pushed_auth_request(self): "code_challenge", "request", "__verified_request", - 'authenticated' + "authenticated", } def test_pushed_auth_urlencoded_process(self): @@ -245,7 +241,7 @@ def test_pushed_auth_urlencoded_process(self): "code_challenge_method", "client_id", "code_challenge", - 'authenticated' + "authenticated", } _resp = self.pushed_authorization_endpoint.process_request(_req) diff --git a/tests/test_server_50_persistence.py b/tests/test_server_50_persistence.py index a0202cfa..7cea5afc 100644 --- a/tests/test_server_50_persistence.py +++ b/tests/test_server_50_persistence.py @@ -205,16 +205,19 @@ def create_endpoint(self): # Both have to use the same keyjar _keyjar = init_key_jar(key_defs=KEYDEFS) - _keyjar.import_jwks_as_json(_keyjar.export_jwks_as_json(True, ""), - ENDPOINT_CONTEXT_CONFIG['issuer']) + _keyjar.import_jwks_as_json( + _keyjar.export_jwks_as_json(True, ""), ENDPOINT_CONTEXT_CONFIG["issuer"] + ) server1 = Server( - OPConfiguration(conf=ENDPOINT_CONTEXT_CONFIG, base_path=BASEDIR), cwd=BASEDIR, - keyjar=_keyjar + OPConfiguration(conf=ENDPOINT_CONTEXT_CONFIG, base_path=BASEDIR), + cwd=BASEDIR, + keyjar=_keyjar, ) server2 = Server( - OPConfiguration(conf=ENDPOINT_CONTEXT_CONFIG, base_path=BASEDIR), cwd=BASEDIR, - keyjar=_keyjar + OPConfiguration(conf=ENDPOINT_CONTEXT_CONFIG, base_path=BASEDIR), + cwd=BASEDIR, + keyjar=_keyjar, ) # The top most part (Server class instance) is not @@ -224,7 +227,15 @@ def create_endpoint(self): "client_salt": "salted", "token_endpoint_auth_method": "client_secret_post", "response_types": ["code", "token", "code id_token", "id_token"], - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access", "research_and_scholarship"] + "allowed_scopes": [ + "openid", + "profile", + "email", + "address", + "phone", + "offline_access", + "research_and_scholarship", + ], } # make server2 endpoint context a copy of server 1 endpoint context @@ -300,12 +311,13 @@ def _dump_restore(self, fro, to): def test_init(self): assert self.endpoint[1] - assert set( - self.endpoint[1].upstream_get("context").provider_info["scopes_supported"] - ) == {"openid"} - assert self.endpoint[1].upstream_get("context").provider_info[ - "claims_parameter_supported"] == \ - self.endpoint[2].upstream_get("context").provider_info["claims_parameter_supported"] + assert set(self.endpoint[1].upstream_get("context").provider_info["scopes_supported"]) == { + "openid" + } + assert ( + self.endpoint[1].upstream_get("context").provider_info["claims_parameter_supported"] + == self.endpoint[2].upstream_get("context").provider_info["claims_parameter_supported"] + ) def test_parse(self): session_id = self._create_session(AUTH_REQ, index=1) diff --git a/tests/test_server_60_dpop.py b/tests/test_server_60_dpop.py index 69eef704..e13b8a35 100644 --- a/tests/test_server_60_dpop.py +++ b/tests/test_server_60_dpop.py @@ -14,7 +14,7 @@ from idpyoidc.server.client_authn import verify_client from idpyoidc.server.configure import OPConfiguration from idpyoidc.server.oauth2.add_on.dpop import DPoPProof -from idpyoidc.server.oauth2.add_on.dpop import post_parse_request +from idpyoidc.server.oauth2.add_on.dpop import token_post_parse_request from idpyoidc.server.oauth2.authorization import Authorization from idpyoidc.server.oidc.token import Token from idpyoidc.server.user_authn.authn_context import INTERNETPROTOCOLPASSWORD @@ -66,13 +66,8 @@ def test_verify_header(): RESPONSE_TYPES_SUPPORTED = [ ["code"], - ["token"], ["id_token"], - ["code", "token"], ["code", "id_token"], - ["id_token", "token"], - ["code", "token", "id_token"], - ["none"], ] CAPABILITIES = { @@ -88,7 +83,7 @@ def test_verify_header(): "claim_types_supported": ["normal", "aggregated", "distributed"], "claims_parameter_supported": True, "request_parameter_supported": True, - "request_uri_parameter_supported": True, + # "request_uri_parameter_supported": True, } AUTH_REQ = AuthorizationRequest( @@ -188,7 +183,7 @@ def create_endpoint(self): "client_salt": "salted", "token_endpoint_auth_method": "client_secret_post", "response_types": ["code", "token", "code id_token", "id_token"], - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], } self.user_id = "diana" self.token_endpoint = server.get_endpoint("token") @@ -228,7 +223,7 @@ def _mint_code(self, grant, client_id): return _code def test_post_parse_request(self): - auth_req = post_parse_request( + auth_req = token_post_parse_request( AUTH_REQ, AUTH_REQ["client_id"], self.context, diff --git a/tests/test_server_61_add_on.py b/tests/test_server_61_add_on.py index c9513cf7..9af72062 100644 --- a/tests/test_server_61_add_on.py +++ b/tests/test_server_61_add_on.py @@ -28,13 +28,8 @@ RESPONSE_TYPES_SUPPORTED = [ ["code"], - ["token"], ["id_token"], - ["code", "token"], ["code", "id_token"], - ["id_token", "token"], - ["code", "token", "id_token"], - ["none"], ] CAPABILITIES = { @@ -55,7 +50,7 @@ "claim_types_supported": ["normal", "aggregated", "distributed"], "claims_parameter_supported": True, "request_parameter_supported": True, - "request_uri_parameter_supported": True, + # "request_uri_parameter_supported": True, } AUTH_REQ = AuthorizationRequest( @@ -143,7 +138,7 @@ def create_endpoint(self): "client_salt": "salted", "token_endpoint_auth_method": "client_secret_post", "response_types": ["code", "token", "code id_token", "id_token"], - "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"] + "allowed_scopes": ["openid", "profile", "email", "address", "phone", "offline_access"], } self.endpoint = server.get_endpoint("authorization") diff --git a/tests/test_tandem_oauth2_add_on.py b/tests/test_tandem_oauth2_add_on.py new file mode 100644 index 00000000..77b9391f --- /dev/null +++ b/tests/test_tandem_oauth2_add_on.py @@ -0,0 +1,346 @@ +import json +import os +from typing import List + +from cryptojwt.key_jar import build_keyjar + +from idpyoidc.client.oauth2 import Client +from idpyoidc.message.oauth2 import is_error_message +from idpyoidc.message.oidc import AccessTokenRequest +from idpyoidc.message.oidc import AuthorizationRequest +from idpyoidc.message.oidc import RefreshAccessTokenRequest +from idpyoidc.server import Server +from idpyoidc.server.authz import AuthzHandling +from idpyoidc.server.client_authn import verify_client +from idpyoidc.server.configure import ASConfiguration +from idpyoidc.server.user_authn.authn_context import INTERNETPROTOCOLPASSWORD +from idpyoidc.server.user_info import UserInfo +from idpyoidc.util import rndstr +from tests import CRYPT_CONFIG +from tests import SESSION_PARAMS + +KEYDEFS = [ + {"type": "RSA", "key": "", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +CLIENT_KEYJAR = build_keyjar(KEYDEFS) + +COOKIE_KEYDEFS = [ + {"type": "oct", "kid": "sig", "use": ["sig"]}, + {"type": "oct", "kid": "enc", "use": ["enc"]}, +] + +AUTH_REQ = AuthorizationRequest( + client_id="client", + redirect_uri="https://example.com/cb", + scope=["openid"], + state="STATE", + response_type="code", +) + +TOKEN_REQ = AccessTokenRequest( + client_id="client", + redirect_uri="https://example.com/cb", + state="STATE", + grant_type="authorization_code", + client_secret="hemligt", +) + +REFRESH_TOKEN_REQ = RefreshAccessTokenRequest( + grant_type="refresh_token", client_id="https://example.com/", client_secret="hemligt" +) + +TOKEN_REQ_DICT = TOKEN_REQ.to_dict() + +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +USERINFO = UserInfo(json.loads(open(full_path("users.json")).read())) + +_OAUTH2_SERVICES = { + "metadata": {"class": "idpyoidc.client.oauth2.server_metadata.ServerMetadata"}, + "authorization": {"class": "idpyoidc.client.oauth2.authorization.Authorization"}, + "access_token": {"class": "idpyoidc.client.oauth2.access_token.AccessToken"}, + "resource": {"class": "idpyoidc.client.oauth2.resource.Resource"}, +} + +SERVER_CONF = { + "issuer": "https://example.com/", + "httpc_params": {"verify": False, "timeout": 1}, + "subject_types_supported": ["public", "pairwise", "ephemeral"], + "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS}, + "endpoint": { + "metadata": { + "path": ".well-known/oauth-authorization-server", + "class": "idpyoidc.server.oauth2.server_metadata.ServerMetadata", + "kwargs": {}, + }, + "authorization": { + "path": "authorization", + "class": "idpyoidc.server.oauth2.authorization.Authorization", + "kwargs": {}, + }, + "token": { + "path": "token", + "class": "idpyoidc.server.oauth2.token.Token", + "kwargs": {}, + }, + }, + "authentication": { + "anon": { + "acr": INTERNETPROTOCOLPASSWORD, + "class": "idpyoidc.server.user_authn.user.NoAuthn", + "kwargs": {"user": "diana"}, + } + }, + "userinfo": {"class": UserInfo, "kwargs": {"db": {}}}, + "client_authn": verify_client, + "authz": { + "class": AuthzHandling, + "kwargs": { + "grant_config": { + "usage_rules": { + "authorization_code": { + "supports_minting": ["access_token", "refresh_token"], + "max_usage": 1, + }, + "access_token": { + "supports_minting": ["access_token", "refresh_token"], + "expires_in": 600, + }, + "refresh_token": { + "supports_minting": ["access_token"], + "audience": ["https://example.com", "https://example2.com"], + "expires_in": 43200, + }, + }, + "expires_in": 43200, + } + }, + }, + "token_handler_args": { + "jwks_file": "private/token_jwks.json", + "code": {"lifetime": 600, "kwargs": {"crypt_conf": CRYPT_CONFIG}}, + "token": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "add_claims_by_scope": True, + "aud": ["https://example.org/appl"], + }, + }, + "refresh": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "aud": ["https://example.org/appl"], + }, + }, + }, + "session_params": SESSION_PARAMS, + "add_ons": { + "pkce": { + "function": "idpyoidc.server.oauth2.add_on.pkce.add_support", + "kwargs": {}, + }, + }, +} + +CLIENT_CONFIG = { + "issuer": SERVER_CONF["issuer"], + "client_secret": "hemligtlösenord", + "client_id": "client", + "redirect_uris": ["https://example.com/cb"], + "client_salt": "salted_peanuts_cooking", + "token_endpoint_auth_methods_supported": ["client_secret_post"], + "response_types_supported": ["code"], + "add_ons": { + "pkce": { + "function": "idpyoidc.client.oauth2.add_on.pkce.add_support", + "kwargs": {"code_challenge_length": 64, "code_challenge_method": "S256"}, + }, + }, +} + + +class Flow(object): + + def __init__(self, client, server): + self.client = client + self.server = server + + def do_query(self, service_type, endpoint_type, request_args=None, msg=None): + if request_args is None: + request_args = {} + if msg is None: + msg = {} + + _client_service = self.client.get_service(service_type) + req_info = _client_service.get_request_parameters(request_args=request_args) + + areq = req_info.get("request") + headers = req_info.get("headers") + + _server_endpoint = self.server.get_endpoint(endpoint_type) + if headers: + argv = {"http_info": {"headers": headers}} + else: + argv = {} + + if areq: + if _server_endpoint.request_format == "json": + _pr_req = _server_endpoint.parse_request(areq.to_json(), **argv) + else: + _pr_req = _server_endpoint.parse_request(areq.to_urlencoded(), **argv) + else: + if areq is None: + _pr_req = _server_endpoint.parse_request(areq) + else: + _pr_req = _server_endpoint.parse_request(areq, **argv) + + if is_error_message(_pr_req): + return areq, _pr_req + + _resp = _server_endpoint.process_request(_pr_req) + if is_error_message(_resp): + return areq, _resp + + _response = _server_endpoint.do_response(**_resp) + + resp = _client_service.parse_response(_response["response"]) + _state = msg.get("state", "") + _client_service.update_service_context(_resp["response_args"], key=_state) + return {"request": areq, "response": resp} + + def server_metadata_request(self, msg): + return {} + + def authorization_request(self, msg): + # ***** Authorization Request ********** + _nonce = (rndstr(24),) + _context = self.client.get_service_context() + # Need a new state for a new authorization request + _state = _context.cstate.create_state(iss=_context.get("issuer")) + _context.cstate.bind_key(_nonce, _state) + + req_args = {"response_type": ["code"], "nonce": _nonce, "state": _state} + + scope = msg.get("scope") + if scope: + _scope = scope + else: + _scope = ["openid"] + + req_args["scope"] = _scope + + return req_args + + def accesstoken_request(self, msg): + # ***** Token Request ********** + _context = self.client.get_service_context() + + auth_resp = msg["authorization"]["response"] + req_args = { + "code": auth_resp["code"], + "state": auth_resp["state"], + "redirect_uri": msg["authorization"]["request"]["redirect_uri"], + "grant_type": "authorization_code", + "client_id": self.client.get_client_id(), + "client_secret": _context.get_usage("client_secret"), + } + + return req_args + + def __call__(self, request_responses: List[list], **kwargs): + msg = kwargs + for request, response in request_responses: + func = getattr(self, f"{request}_request") + req_args = func(msg) + msg[request] = self.do_query(request, response, req_args, msg) + return msg + + +def test_pkce(): + server_conf = SERVER_CONF.copy() + server_conf["add_ons"] = { + "pkce": { + "function": "idpyoidc.server.oauth2.add_on.pkce.add_support", + "kwargs": {}, + }, + } + server = Server(ASConfiguration(conf=server_conf, base_path=BASEDIR), cwd=BASEDIR) + + client_config = CLIENT_CONFIG.copy() + client_config["add_ons"] = { + "pkce": { + "function": "idpyoidc.client.oauth2.add_on.pkce.add_support", + "kwargs": {"code_challenge_length": 64, "code_challenge_method": "S256"}, + }, + } + + client = Client( + client_type="oauth2", + config=client_config, + keyjar=build_keyjar(KEYDEFS), + services=_OAUTH2_SERVICES, + ) + + server.context.cdb["client"] = CLIENT_CONFIG + server.context.keyjar.import_jwks(client.keyjar.export_jwks(), "client") + + server.context.set_provider_info() + + flow = Flow(client, server) + msg = flow( + [ + ["server_metadata", "server_metadata"], + ["authorization", "authorization"], + ["accesstoken", "token"], + ], + scope=["foobar"], + ) + assert msg + + +def test_jar(): + server_conf = SERVER_CONF.copy() + # server_conf['add_ons'] = { + # "jar": { + # "function": "idpyoidc.server.oauth2.add_on.jar.add_support", + # "kwargs": {}, + # }, + # } + server = Server(ASConfiguration(conf=server_conf, base_path=BASEDIR), cwd=BASEDIR) + + client_config = CLIENT_CONFIG.copy() + client_config["add_ons"] = { + "jar": { + "function": "idpyoidc.client.oauth2.add_on.jar.add_support", + "kwargs": {}, + }, + } + + client = Client( + client_type="oauth2", + config=client_config, + keyjar=build_keyjar(KEYDEFS), + services=_OAUTH2_SERVICES, + ) + + server.context.cdb["client"] = CLIENT_CONFIG + server.context.keyjar.import_jwks(client.keyjar.export_jwks(), "client") + + server.context.set_provider_info() + + flow = Flow(client, server) + msg = flow( + [["server_metadata", "server_metadata"], ["authorization", "authorization"]], + scope=["foobar"], + ) + + assert msg diff --git a/tests/test_tandem_08_oauth2_cc_ropc.py b/tests/test_tandem_oauth2_cc_ropc.py similarity index 90% rename from tests/test_tandem_08_oauth2_cc_ropc.py rename to tests/test_tandem_oauth2_cc_ropc.py index 30c2a967..4dd718aa 100644 --- a/tests/test_tandem_08_oauth2_cc_ropc.py +++ b/tests/test_tandem_oauth2_cc_ropc.py @@ -37,10 +37,8 @@ def full_path(local_file): CONFIG = { "issuer": "https://example.net/", "httpc_params": {"verify": False}, - "preference": { - "grant_types_supported": ["client_credentials", "password"] - }, - "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS, 'read_only': False}, + "preference": {"grant_types_supported": ["client_credentials", "password"]}, + "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS, "read_only": False}, "token_handler_args": { "jwks_defs": {"key_defs": KEYDEFS}, "token": { @@ -49,8 +47,8 @@ def full_path(local_file): "lifetime": 3600, "add_claims_by_scope": True, "aud": ["https://example.org/appl"], - } - } + }, + }, }, "endpoint": { "token": { @@ -96,11 +94,11 @@ def full_path(local_file): "kwargs": { "db_conf": { "class": "idpyoidc.server.util.JSONDictDB", - "kwargs": {"filename": full_path("passwd.json")} + "kwargs": {"filename": full_path("passwd.json")}, } - } + }, } - } + }, } CLIENT_BASE_URL = "https://example.com" @@ -108,7 +106,7 @@ def full_path(local_file): CLIENT_CONFIG = { "client_id": "client_1", "client_secret": "another password", - "base_url": CLIENT_BASE_URL + "base_url": CLIENT_BASE_URL, } CLIENT_SERVICES = { "resource_owner_password_credentials": { @@ -123,9 +121,10 @@ def test_ropc(): client = Client(config=CLIENT_CONFIG, services=CLIENT_SERVICES) client.get_service("resource_owner_password_credentials").endpoint = "https://example.com/token" - service = client.get_service('resource_owner_password_credentials') + service = client.get_service("resource_owner_password_credentials") client_request_info = service.get_request_parameters( - request_args={'username': 'diana', 'password': 'krall'}) + request_args={"username": "diana", "password": "krall"} + ) # Server side @@ -141,5 +140,5 @@ def test_ropc(): } token_endpoint = server.get_endpoint("token") - request = token_endpoint.parse_request(client_request_info['request']) + request = token_endpoint.parse_request(client_request_info["request"]) assert request diff --git a/tests/test_tandem_oauth2_code.py b/tests/test_tandem_oauth2_code.py new file mode 100644 index 00000000..5adec34a --- /dev/null +++ b/tests/test_tandem_oauth2_code.py @@ -0,0 +1,268 @@ +import json +import os + +from cryptojwt.key_jar import build_keyjar +import pytest + +from idpyoidc.client.oauth2 import Client +from idpyoidc.message.oauth2 import is_error_message +from idpyoidc.message.oidc import AccessTokenRequest +from idpyoidc.message.oidc import AuthorizationRequest +from idpyoidc.message.oidc import RefreshAccessTokenRequest +from idpyoidc.server import Server +from idpyoidc.server.authz import AuthzHandling +from idpyoidc.server.client_authn import verify_client +from idpyoidc.server.configure import ASConfiguration +from idpyoidc.server.cookie_handler import CookieHandler +from idpyoidc.server.user_authn.authn_context import INTERNETPROTOCOLPASSWORD +from idpyoidc.server.user_info import UserInfo +from idpyoidc.util import rndstr +from tests import CRYPT_CONFIG +from tests import SESSION_PARAMS + +KEYDEFS = [ + {"type": "RSA", "key": "", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +CLIENT_KEYJAR = build_keyjar(KEYDEFS) + +COOKIE_KEYDEFS = [ + {"type": "oct", "kid": "sig", "use": ["sig"]}, + {"type": "oct", "kid": "enc", "use": ["enc"]}, +] + +AUTH_REQ = AuthorizationRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + scope=["openid"], + state="STATE", + response_type="code", +) + +TOKEN_REQ = AccessTokenRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + state="STATE", + grant_type="authorization_code", + client_secret="hemligt", +) + +REFRESH_TOKEN_REQ = RefreshAccessTokenRequest( + grant_type="refresh_token", client_id="https://example.com/", client_secret="hemligt" +) + +TOKEN_REQ_DICT = TOKEN_REQ.to_dict() + +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +USERINFO = UserInfo(json.loads(open(full_path("users.json")).read())) + +_OAUTH2_SERVICES = { + "metadata": {"class": "idpyoidc.client.oauth2.server_metadata.ServerMetadata"}, + "authorization": {"class": "idpyoidc.client.oauth2.authorization.Authorization"}, + "access_token": {"class": "idpyoidc.client.oauth2.access_token.AccessToken"}, + "resource": {"class": "idpyoidc.client.oauth2.resource.Resource"}, +} + + +class TestFlow(object): + @pytest.fixture(autouse=True) + def create_entities(self): + server_conf = { + "issuer": "https://example.com/", + "httpc_params": {"verify": False, "timeout": 1}, + "subject_types_supported": ["public", "pairwise", "ephemeral"], + "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS}, + "endpoint": { + "metadata": { + "path": ".well-known/oauth-authorization-server", + "class": "idpyoidc.server.oauth2.server_metadata.ServerMetadata", + "kwargs": {}, + }, + "authorization": { + "path": "authorization", + "class": "idpyoidc.server.oauth2.authorization.Authorization", + "kwargs": {}, + }, + "token": { + "path": "token", + "class": "idpyoidc.server.oauth2.token.Token", + "kwargs": {}, + }, + }, + "authentication": { + "anon": { + "acr": INTERNETPROTOCOLPASSWORD, + "class": "idpyoidc.server.user_authn.user.NoAuthn", + "kwargs": {"user": "diana"}, + } + }, + "userinfo": {"class": UserInfo, "kwargs": {"db": {}}}, + "client_authn": verify_client, + "authz": { + "class": AuthzHandling, + "kwargs": { + "grant_config": { + "usage_rules": { + "authorization_code": { + "supports_minting": ["access_token", "refresh_token"], + "max_usage": 1, + }, + "access_token": { + "supports_minting": ["access_token", "refresh_token"], + "expires_in": 600, + }, + "refresh_token": { + "supports_minting": ["access_token"], + "audience": ["https://example.com", "https://example2.com"], + "expires_in": 43200, + }, + }, + "expires_in": 43200, + } + }, + }, + "token_handler_args": { + "jwks_file": "private/token_jwks.json", + "code": {"lifetime": 600, "kwargs": {"crypt_conf": CRYPT_CONFIG}}, + "token": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "add_claims_by_scope": True, + "aud": ["https://example.org/appl"], + }, + }, + "refresh": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "aud": ["https://example.org/appl"], + }, + }, + }, + "session_params": SESSION_PARAMS, + } + self.server = Server(ASConfiguration(conf=server_conf, base_path=BASEDIR), cwd=BASEDIR) + + client_1_config = { + "issuer": server_conf["issuer"], + "client_secret": "hemligtlösenord", + "client_id": "client_1", + "redirect_uris": ["https://example.com/cb"], + "client_salt": "salted_peanuts_cooking", + "token_endpoint_auth_methods_supported": ["client_secret_post"], + "response_types_supported": ["code"], + } + client_services = _OAUTH2_SERVICES + self.client = Client( + client_type="oauth2", + config=client_1_config, + keyjar=build_keyjar(KEYDEFS), + services=_OAUTH2_SERVICES, + ) + + self.context = self.server.context + self.context.cdb["client_1"] = client_1_config + self.context.keyjar.import_jwks(self.client.keyjar.export_jwks(), "client_1") + + self.context.set_provider_info() + self.session_manager = self.context.session_manager + self.user_id = "diana" + + def do_query(self, service_type, endpoint_type, request_args, state): + _client_service = self.client.get_service(service_type) + req_info = _client_service.get_request_parameters(request_args=request_args) + + areq = req_info.get("request") + headers = req_info.get("headers") + + _server_endpoint = self.server.get_endpoint(endpoint_type) + if areq: + if headers: + argv = {"http_info": {"headers": headers}} + else: + argv = {} + areq.lax = True + _req = areq.serialize(_server_endpoint.request_format) + _pr_resp = _server_endpoint.parse_request(_req, **argv) + else: + _pr_resp = _server_endpoint.parse_request(areq) + + if is_error_message(_pr_resp): + return areq, _pr_resp + + _resp = _server_endpoint.process_request(_pr_resp) + if is_error_message(_resp): + return areq, _resp + + _response = _server_endpoint.do_response(**_resp) + + resp = _client_service.parse_response(_response["response"]) + _client_service.update_service_context(_resp["response_args"], key=state) + return areq, resp + + def process_setup(self, token=None, scope=None): + # ***** Discovery ********* + + _req, _resp = self.do_query("server_metadata", "server_metadata", {}, "") + + # ***** Authorization Request ********** + _nonce = (rndstr(24),) + _context = self.client.get_service_context() + # Need a new state for a new authorization request + _state = _context.cstate.create_state(iss=_context.get("issuer")) + _context.cstate.bind_key(_nonce, _state) + + req_args = {"response_type": ["code"], "nonce": _nonce, "state": _state} + + if scope: + _scope = scope + else: + _scope = ["openid"] + + if token and list(token.keys())[0] == "refresh_token": + _scope = ["openid", "offline_access"] + + req_args["scope"] = _scope + + areq, auth_response = self.do_query("authorization", "authorization", req_args, _state) + + # ***** Token Request ********** + + req_args = { + "code": auth_response["code"], + "state": auth_response["state"], + "redirect_uri": areq["redirect_uri"], + "grant_type": "authorization_code", + "client_id": self.client.get_client_id(), + "client_secret": _context.get_usage("client_secret"), + } + + _token_request, resp = self.do_query("accesstoken", "token", req_args, _state) + + return resp, _state, _scope + + def test_flow(self): + """ + Test that token exchange requests work correctly + """ + + resp, _state, _scope = self.process_setup(token="access_token", scope=["foobar"]) + + # Construct the resource request + + _client_service = self.client.get_service("resource") + req_info = _client_service.get_request_parameters( + authn_method="bearer_header", state=_state, endpoint="https://resource.example.com" + ) + + assert req_info["url"] == "https://resource.example.com" + assert "Authorization" in req_info["headers"] + assert req_info["headers"]["Authorization"].startswith("Bearer") diff --git a/tests/test_tandem_10_oauth2_token_exchange.py b/tests/test_tandem_oauth2_token_exchange.py similarity index 90% rename from tests/test_tandem_10_oauth2_token_exchange.py rename to tests/test_tandem_oauth2_token_exchange.py index 773fb218..141b1333 100644 --- a/tests/test_tandem_10_oauth2_token_exchange.py +++ b/tests/test_tandem_oauth2_token_exchange.py @@ -81,9 +81,7 @@ def full_path(local_file): "refresh_access_token": { "class": "idpyoidc.client.oauth2.refresh_access_token.RefreshAccessToken" }, - "token_exchange": { - "class": "idpyoidc.client.oauth2.token_exchange.TokenExchange" - } + "token_exchange": {"class": "idpyoidc.client.oauth2.token_exchange.TokenExchange"}, } @@ -118,7 +116,7 @@ def create_endpoint(self): }, "token": { "path": "token", - "class": "idpyoidc.server.oidc.token.Token", + "class": "idpyoidc.server.oauth2.token.Token", "kwargs": {}, }, }, @@ -185,7 +183,7 @@ def create_endpoint(self): "redirect_uris": ["https://example.com/cb"], "client_salt": "salted_peanuts_cooking", "token_endpoint_auth_methods_supported": ["client_secret_post"], - "response_types_supported": ["code", "token", "code id_token", "id_token"], + "response_types_supported": ["code", "code id_token", "id_token"], "allowed_scopes": ["openid", "profile", "offline_access"], } client_2_config = { @@ -195,23 +193,27 @@ def create_endpoint(self): "redirect_uris": ["https://example.com/cb"], "client_salt": "salted_peanuts_cooking", "token_endpoint_auth_methods_supported": ["client_secret_post"], - "response_types_supported": ["code", "token", "code id_token", "id_token"], + "response_types_supported": ["code", "code id_token", "id_token"], "allowed_scopes": ["openid", "profile", "offline_access"], } - self.client_1 = Client(client_type='oauth2', config=client_1_config, - keyjar=build_keyjar(KEYDEFS), - services=_OAUTH2_SERVICES) - self.client_2 = Client(client_type='oauth2', config=client_2_config, - keyjar=build_keyjar(KEYDEFS), - services=_OAUTH2_SERVICES) + self.client_1 = Client( + client_type="oauth2", + config=client_1_config, + keyjar=build_keyjar(KEYDEFS), + services=_OAUTH2_SERVICES, + ) + self.client_2 = Client( + client_type="oauth2", + config=client_2_config, + keyjar=build_keyjar(KEYDEFS), + services=_OAUTH2_SERVICES, + ) self.context = self.server.context self.context.cdb["client_1"] = client_1_config self.context.cdb["client_2"] = client_2_config - self.context.keyjar.import_jwks( - self.client_1.keyjar.export_jwks(), "client_1") - self.context.keyjar.import_jwks( - self.client_2.keyjar.export_jwks(), "client_2") + self.context.keyjar.import_jwks(self.client_1.keyjar.export_jwks(), "client_1") + self.context.keyjar.import_jwks(self.client_2.keyjar.export_jwks(), "client_2") self.context.set_provider_info() @@ -234,7 +236,8 @@ def do_query(self, service_type, endpoint_type, request_args, state): else: argv = {} areq.lax = True - _pr_resp = _server.parse_request(areq.to_urlencoded(), **argv) + _req = areq.serialize(_server.request_format) + _pr_resp = _server.parse_request(_req, **argv) else: _pr_resp = _server.parse_request(areq) @@ -254,20 +257,16 @@ def do_query(self, service_type, endpoint_type, request_args, state): def process_setup(self, token=None, scope=None): # ***** Discovery ********* - _req, _resp = self.do_query('server_metadata', 'server_metadata', {}, '') + _req, _resp = self.do_query("server_metadata", "server_metadata", {}, "") # ***** Authorization Request ********** - _nonce = rndstr(24), + _nonce = (rndstr(24),) _context = self.client_1.get_service_context() # Need a new state for a new authorization request _state = _context.cstate.create_state(iss=_context.get("issuer")) _context.cstate.bind_key(_nonce, _state) - req_args = { - "response_type": ["code"], - "nonce": _nonce, - "state": _state - } + req_args = {"response_type": ["code"], "nonce": _nonce, "state": _state} if scope: _scope = scope @@ -279,7 +278,7 @@ def process_setup(self, token=None, scope=None): req_args["scope"] = _scope - areq, auth_response = self.do_query('authorization', 'authorization', req_args, _state) + areq, auth_response = self.do_query("authorization", "authorization", req_args, _state) # ***** Token Request ********** @@ -292,7 +291,7 @@ def process_setup(self, token=None, scope=None): "client_secret": _context.get_usage("client_secret"), } - _token_request, resp = self.do_query("accesstoken", 'token', req_args, _state) + _token_request, resp = self.do_query("accesstoken", "token", req_args, _state) return resp, _state, _scope @@ -316,12 +315,13 @@ def test_token_exchange(self, token): "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", "requested_token_type": token[list(token.keys())[0]], "subject_token": resp["access_token"], - "subject_token_type": 'urn:ietf:params:oauth:token-type:access_token', - "state": _state + "subject_token_type": "urn:ietf:params:oauth:token-type:access_token", + "state": _state, } - _token_exchange_request, _te_resp = self.do_query("token_exchange", "token", req_args, - _state) + _token_exchange_request, _te_resp = self.do_query( + "token_exchange", "token", req_args, _state + ) assert set(_te_resp.keys()) == { "access_token", @@ -356,8 +356,7 @@ def test_token_exchange_per_client(self, token): ], "policy": { "": { - "function": - "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", + "function": "idpyoidc.server.oauth2.token_helper.validate_token_exchange_policy", "kwargs": {"scope": ["openid", "offline_access"]}, } }, @@ -371,12 +370,13 @@ def test_token_exchange_per_client(self, token): "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", "requested_token_type": token[list(token.keys())[0]], "subject_token": resp["access_token"], - "subject_token_type": 'urn:ietf:params:oauth:token-type:access_token', - "state": _state + "subject_token_type": "urn:ietf:params:oauth:token-type:access_token", + "state": _state, } - _token_exchange_request, _te_resp = self.do_query("token_exchange", "token", req_args, - _state) + _token_exchange_request, _te_resp = self.do_query( + "token_exchange", "token", req_args, _state + ) assert set(_te_resp.keys()) == { "access_token", @@ -413,8 +413,9 @@ def test_additional_parameters(self): "resource": ["https://example.com"], } - _token_exchange_request, _te_resp = self.do_query("token_exchange", "token", req_args, - _state) + _token_exchange_request, _te_resp = self.do_query( + "token_exchange", "token", req_args, _state + ) assert set(_te_resp.keys()) == { "access_token", @@ -440,7 +441,7 @@ def test_token_exchange_fails_if_disabled(self): "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", "subject_token": resp["access_token"], "subject_token_type": "urn:ietf:params:oauth:token-type:access_token", - "resource": ["https://example.com/api"] + "resource": ["https://example.com/api"], } _te_request, _te_resp = self.do_query("token_exchange", "token", req_args, _state) @@ -481,7 +482,8 @@ def test_refresh_token_audience(self): """ resp, _state, _scope = self.process_setup( - {"refresh_token": "urn:ietf:params:oauth:token-type:refresh_token"}) + {"refresh_token": "urn:ietf:params:oauth:token-type:refresh_token"} + ) # ****** Token Exchange Request ********** @@ -529,7 +531,8 @@ def test_exchange_refresh_token_to_refresh_token(self): Test whether exchanging a refresh token to another refresh token works. """ resp, _state, _scope = self.process_setup( - {"refresh_token": "urn:ietf:params:oauth:token-type:refresh_token"}) + {"refresh_token": "urn:ietf:params:oauth:token-type:refresh_token"} + ) # ****** Token Exchange Request ********** diff --git a/tests/test_tandem_oauth2_token_revocation.py b/tests/test_tandem_oauth2_token_revocation.py new file mode 100644 index 00000000..92d44309 --- /dev/null +++ b/tests/test_tandem_oauth2_token_revocation.py @@ -0,0 +1,240 @@ +import os + +import pytest +from cryptojwt.key_jar import build_keyjar + +from idpyoidc.client.oauth2 import Client +from idpyoidc.message.oauth2 import is_error_message +from idpyoidc.server import ASConfiguration +from idpyoidc.server import Server +from idpyoidc.server.authz import AuthzHandling +from idpyoidc.server.client_authn import verify_client +from idpyoidc.server.user_authn.authn_context import INTERNETPROTOCOLPASSWORD +from idpyoidc.server.user_info import UserInfo +from idpyoidc.util import rndstr +from tests import CRYPT_CONFIG +from tests import SESSION_PARAMS + +KEYDEFS = [ + {"type": "RSA", "key": "", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + + +class TestClient(object): + @pytest.fixture(autouse=True) + def create_entities(self): + # -------------- Server ----------------------- + + server_conf = { + "issuer": "https://example.com/", + "httpc_params": {"verify": False, "timeout": 1}, + "subject_types_supported": ["public", "pairwise", "ephemeral"], + "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS}, + "endpoint": { + "discovery": { + "path": "/.well-known/oauth-authorization-server", + "class": "idpyoidc.server.oauth2.server_metadata.ServerMetadata", + "kwargs": {}, + }, + "authorization": { + "path": "authorization", + "class": "idpyoidc.server.oauth2.authorization.Authorization", + "kwargs": {}, + }, + "token": { + "path": "token", + "class": "idpyoidc.server.oauth2.token.Token", + "kwargs": {}, + }, + "token_revocation": { + "path": "revocation", + "class": "idpyoidc.server.oauth2.token_revocation.TokenRevocation", + "kwargs": {}, + }, + "introspection": { + "path": "introspection", + "class": "idpyoidc.server.oauth2.introspection.Introspection", + }, + }, + "authentication": { + "anon": { + "acr": INTERNETPROTOCOLPASSWORD, + "class": "idpyoidc.server.user_authn.user.NoAuthn", + "kwargs": {"user": "diana"}, + } + }, + "userinfo": {"class": UserInfo, "kwargs": {"db": {}}}, + "client_authn": verify_client, + "template_dir": "template", + "authz": { + "class": AuthzHandling, + "kwargs": { + "grant_config": { + "usage_rules": { + "authorization_code": { + "supports_minting": ["access_token", "refresh_token"], + "max_usage": 1, + }, + "access_token": { + "supports_minting": ["access_token", "refresh_token"], + "expires_in": 600, + }, + "refresh_token": { + "supports_minting": ["access_token"], + "audience": ["https://example.com", "https://example2.com"], + "expires_in": 43200, + }, + }, + "expires_in": 43200, + } + }, + }, + "token_handler_args": { + "jwks_file": "private/token_jwks.json", + "code": {"lifetime": 600, "kwargs": {"crypt_conf": CRYPT_CONFIG}}, + "token": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "add_claims_by_scope": True, + "aud": ["https://example.org/appl"], + }, + }, + "refresh": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "aud": ["https://example.org/appl"], + }, + }, + }, + "session_params": SESSION_PARAMS, + } + self.server = Server(ASConfiguration(conf=server_conf, base_path=BASEDIR), cwd=BASEDIR) + + # -------------- Client ----------------------- + + client_conf = { + "redirect_uris": ["https://example.com/cli/code_cb"], + "client_id": "client_1", + "client_secret": "abcdefghijklmnop", + "issuer": "https://example.com/", + "response_types_supported": ["code"], + } + services = { + "server_metadata": {"class": "idpyoidc.client.oauth2.server_metadata.ServerMetadata"}, + "authorization": {"class": "idpyoidc.client.oauth2.authorization.Authorization"}, + "access_token": {"class": "idpyoidc.client.oauth2.access_token.AccessToken"}, + "token_revocation": { + "class": "idpyoidc.client.oauth2.token_revocation.TokenRevocation" + }, + "introspection": {"class": "idpyoidc.client.oauth2.introspection.Introspection"}, + } + self.client = Client(config=client_conf, keyjar=build_keyjar(KEYDEFS), services=services) + + # ------- tell the server about the client ---------------- + self.context = self.server.context + self.context.cdb["client_1"] = client_conf + self.context.keyjar.import_jwks(self.client.keyjar.export_jwks(), "client_1") + + def do_query(self, service_type, endpoint_type, request_args, state): + _client = self.client.get_service(service_type) + req_info = _client.get_request_parameters(request_args=request_args) + + areq = req_info.get("request") + headers = req_info.get("headers") + + _server = self.server.get_endpoint(endpoint_type) + if areq: + if headers: + argv = {"http_info": {"headers": headers}} + else: + argv = {} + areq.lax = True + if _server.request_format == "json": + _pr_req = _server.parse_request(areq.to_json(), **argv) + else: + _pr_req = _server.parse_request(areq.to_urlencoded(), **argv) + else: + _pr_req = _server.parse_request(areq) + + if is_error_message(_pr_req): + return areq, _pr_req + + _resp = _server.process_request(_pr_req) + if is_error_message(_resp): + return areq, _resp + + _response = _server.do_response(**_resp) + + resp = _client.parse_response(_response["response"]) + if "response_args" in _resp: + _client.update_service_context(_resp["response_args"], key=state) + + return areq, resp + + def process_setup(self, token=None, scope=None): + # ***** Discovery ********* + + _req, _resp = self.do_query("server_metadata", "server_metadata", {}, "") + + # ***** Authorization Request ********** + _context = self.client.get_service_context() + # Need a new state for a new authorization request + _state = _context.cstate.create_state(iss=_context.get("issuer")) + _nonce = (rndstr(24),) + # bind nonce to state + _context.cstate.bind_key(_nonce, _state) + + req_args = {"response_type": ["code"], "nonce": _nonce, "state": _state} + + if scope: + _scope = scope + else: + _scope = ["foobar"] + + req_args["scope"] = _scope + + areq, auth_response = self.do_query("authorization", "authorization", req_args, _state) + + # ***** Token Request ********** + + req_args = { + "code": auth_response["code"], + "state": auth_response["state"], + "redirect_uri": areq["redirect_uri"], + # "grant_type": "authorization_code", + # "client_id": self.client_.get_client_id(), + # "client_secret": _context.get_usage("client_secret"), + } + + _token_request, resp = self.do_query("accesstoken", "token", req_args, _state) + + return resp, _state, _scope + + def test_revoke(self): + resp, _state, _scope = self.process_setup() + + _context = self.client.get_context() + _state = _context.cstate.get(_state) + + req_args = {"token": _state["access_token"], "token_type_hint": "access_token"} + + # Check that I have an active token + + _request, _resp = self.do_query("introspection", "introspection", req_args, _state) + + assert _resp["active"] == True + + # ****** Token Revocation Request ********** + + _request, _resp = self.do_query("token_revocation", "token_revocation", req_args, _state) + assert _resp == "OK" + + # Test if it's really revoked + + _request, _resp = self.do_query("introspection", "introspection", req_args, _state) + + assert _resp.to_dict() == {"active": False} diff --git a/tests/test_tandem_oidc_code.py b/tests/test_tandem_oidc_code.py new file mode 100644 index 00000000..ad40a46e --- /dev/null +++ b/tests/test_tandem_oidc_code.py @@ -0,0 +1,294 @@ +import json +import os + +from cryptojwt.key_jar import build_keyjar + +import pytest + +from idpyoidc.client.oidc import RP +from idpyoidc.message.oauth2 import is_error_message +from idpyoidc.message.oidc import AccessTokenRequest +from idpyoidc.message.oidc import AuthorizationRequest +from idpyoidc.message.oidc import RefreshAccessTokenRequest +from idpyoidc.server import Server +from idpyoidc.server.authz import AuthzHandling +from idpyoidc.server.client_authn import verify_client +from idpyoidc.server.configure import OPConfiguration +from idpyoidc.server.user_authn.authn_context import INTERNETPROTOCOLPASSWORD +from idpyoidc.server.user_info import UserInfo +from idpyoidc.util import rndstr +from tests import CRYPT_CONFIG +from tests import SESSION_PARAMS + +KEYDEFS = [ + {"type": "RSA", "key": "", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +CLIENT_KEYJAR = build_keyjar(KEYDEFS) + +COOKIE_KEYDEFS = [ + {"type": "oct", "kid": "sig", "use": ["sig"]}, + {"type": "oct", "kid": "enc", "use": ["enc"]}, +] + +AUTH_REQ = AuthorizationRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + scope=["openid"], + state="STATE", + response_type="code", +) + +TOKEN_REQ = AccessTokenRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + state="STATE", + grant_type="authorization_code", + client_secret="hemligt", +) + +REFRESH_TOKEN_REQ = RefreshAccessTokenRequest( + grant_type="refresh_token", client_id="https://example.com/", client_secret="hemligt" +) + +TOKEN_REQ_DICT = TOKEN_REQ.to_dict() + +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +USERINFO = UserInfo(json.loads(open(full_path("users.json")).read())) + +_OIDC_SERVICES = { + "provider_info": { + "class": "idpyoidc.client.oidc.provider_info_discovery.ProviderInfoDiscovery" + }, + "registration": {"class": "idpyoidc.client.oidc.registration.Registration"}, + "authorization": {"class": "idpyoidc.client.oidc.authorization.Authorization"}, + "access_token": {"class": "idpyoidc.client.oidc.access_token.AccessToken"}, + "userinfo": {"class": "idpyoidc.client.oidc.userinfo.UserInfo"}, +} + + +class TestFlow(object): + @pytest.fixture(autouse=True) + def create_entities(self): + server_conf = { + "issuer": "https://op.example.com/", + "httpc_params": {"verify": False, "timeout": 1}, + "subject_types_supported": ["public", "pairwise", "ephemeral"], + "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS}, + "endpoint": { + "provider_info": { + "path": ".well-known/openid-configuration", + "class": "idpyoidc.server.oidc.provider_config.ProviderConfiguration", + "kwargs": {}, + }, + "register": { + "path": "authorization", + "class": "idpyoidc.server.oidc.registration.Registration", + "kwargs": {}, + }, + "authorization": { + "path": "authorization", + "class": "idpyoidc.server.oidc.authorization.Authorization", + "kwargs": {}, + }, + "token": { + "path": "token", + "class": "idpyoidc.server.oidc.token.Token", + "kwargs": {}, + }, + "userinfo": { + "path": "user", + "class": "idpyoidc.server.oidc.userinfo.UserInfo", + "kwargs": {}, + }, + }, + "authentication": { + "anon": { + "acr": INTERNETPROTOCOLPASSWORD, + "class": "idpyoidc.server.user_authn.user.NoAuthn", + "kwargs": {"user": "diana"}, + } + }, + "userinfo": {"class": UserInfo, "kwargs": {"db_file": "users.json"}}, + "client_authn": verify_client, + "authz": { + "class": AuthzHandling, + "kwargs": { + "grant_config": { + "usage_rules": { + "authorization_code": { + "supports_minting": ["access_token", "refresh_token"], + "max_usage": 1, + }, + "access_token": { + "supports_minting": ["access_token", "refresh_token"], + "expires_in": 600, + }, + "refresh_token": { + "supports_minting": ["access_token"], + "audience": ["https://example.com", "https://example2.com"], + "expires_in": 43200, + }, + }, + "expires_in": 43200, + } + }, + }, + "token_handler_args": { + "jwks_file": "private/token_jwks.json", + "code": {"lifetime": 600, "kwargs": {"crypt_conf": CRYPT_CONFIG}}, + "token": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "add_claims_by_scope": True, + "aud": ["https://example.org/appl"], + }, + }, + "refresh": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "aud": ["https://example.org/appl"], + }, + }, + "id_token": { + "class": "idpyoidc.server.token.id_token.IDToken", + "kwargs": { + "base_claims": { + "email": {"essential": True}, + "email_verified": {"essential": True}, + } + }, + }, + }, + "session_params": SESSION_PARAMS, + } + self.server = Server(OPConfiguration(conf=server_conf, base_path=BASEDIR), cwd=BASEDIR) + + client_config = { + "issuer": server_conf["issuer"], + # "client_secret": "hemligtlösenord", + # "client_id": "client_1", + # "client_salt": "salted_peanuts_cooking", + "redirect_uris": ["https://example.com/cb"], + "token_endpoint_auth_methods_supported": ["client_secret_post"], + "response_types_supported": ["code", "id_token", "id_token token"], + } + self.rp = RP(config=client_config, keyjar=build_keyjar(KEYDEFS), services=_OIDC_SERVICES) + + self.context = self.server.context + # self.context.cdb["client_1"] = client_config + # self.context.keyjar.import_jwks(self.rp.keyjar.export_jwks(), "client_1") + + self.context.set_provider_info() + # self.session_manager = self.context.session_manager + # self.user_id = "diana" + + def do_query(self, service_type, endpoint_type, request_args, state): + _client_service = self.rp.get_service(service_type) + req_info = _client_service.get_request_parameters(request_args=request_args, state=state) + + areq = req_info.get("request") + headers = req_info.get("headers") + + _server_endpoint = self.server.get_endpoint(endpoint_type) + + if headers: + argv = {"http_info": {"headers": headers}} + else: + argv = {} + + if areq: + areq.lax = True + _req = areq.serialize(_server_endpoint.request_format) + _pr_req = _server_endpoint.parse_request(_req, **argv) + else: + _pr_req = _server_endpoint.parse_request(areq, **argv) + + if is_error_message(_pr_req): + return areq, _pr_req + + _resp = _server_endpoint.process_request(_pr_req) + if is_error_message(_resp): + return areq, _resp + + _response = _server_endpoint.do_response(**_resp) + + resp = _client_service.parse_response(_response["response"], state=state) + _client_service.update_service_context(_resp["response_args"], key=state) + # Fake key import + if service_type == "provider_info": + _client_service.upstream_get("attribute", "keyjar").import_jwks( + _server_endpoint.upstream_get("attribute", "keyjar").export_jwks(), + issuer_id=_server_endpoint.upstream_get("attribute", "issuer"), + ) + return areq, resp + + def process_setup(self, token=None, scope=None): + # ***** Discovery ********* + _req, _resp = self.do_query("provider_info", "provider_config", {}, "") + + # ***** Client Registration ********** + + _req, _resp = self.do_query("registration", "registration", {}, "") + + # ***** Authorization Request ********** + + _nonce = rndstr(24) + _context = self.rp.get_service_context() + # Need a new state for a new authorization request + _state = _context.cstate.create_state(iss=_context.get("issuer")) + _context.cstate.bind_key(_nonce, _state) + + req_args = {"response_type": ["code"], "nonce": _nonce, "state": _state} + + if scope: + _scope = scope + else: + _scope = ["openid"] + + if token and list(token.keys())[0] == "refresh_token": + _scope = ["openid", "offline_access"] + + req_args["scope"] = _scope + + areq, auth_response = self.do_query("authorization", "authorization", req_args, _state) + + # ***** Token Request ********** + + req_args = { + "code": auth_response["code"], + "state": auth_response["state"], + "redirect_uri": areq["redirect_uri"], + "grant_type": "authorization_code", + "client_id": self.rp.get_client_id(), + "client_secret": _context.get_usage("client_secret"), + } + + _token_request, resp = self.do_query("accesstoken", "token", req_args, _state) + + return resp, _state, _scope + + def test_flow(self): + """ + Test that token exchange requests work correctly + """ + + resp, _state, _scope = self.process_setup( + token="access_token", + scope=["openid", "profile", "email", "address", "phone", "offline_access"], + ) + + # The User Info request + + _request, resp = self.do_query("userinfo", "userinfo", {}, _state) + + assert resp