diff --git a/invenio_oauthclient/config.py b/invenio_oauthclient/config.py index 726a4554..b267d17b 100644 --- a/invenio_oauthclient/config.py +++ b/invenio_oauthclient/config.py @@ -200,7 +200,7 @@ class CustomOAuthRemoteApp(OAuthRemoteApp): OAUTHCLIENT_REMOTE_APPS = {} """Configuration of remote applications.""" -OAUTHCLIENT_SESSION_KEY_PREFIX = 'oauth_token' +OAUTHCLIENT_SESSION_KEY_PREFIX = "oauth_token" """Session key prefix used when storing the access token for a remote app.""" OAUTHCLIENT_STATE_EXPIRES = 300 @@ -209,8 +209,14 @@ class CustomOAuthRemoteApp(OAuthRemoteApp): OAUTHCLIENT_STATE_ENABLED = True """Internal variable used to disable state validation during tests.""" -OAUTHCLIENT_SIGNUP_TEMPLATE = 'invenio_oauthclient/signup.html' +OAUTHCLIENT_SIGNUP_TEMPLATE = "invenio_oauthclient/signup.html" """Template for the signup page.""" OAUTHCLIENT_REST_REMOTE_APPS = {} """Configuration of remote rest applications.""" + +OAUTHCLIENT_REST_DEFAULT_RESPONSE_HANDLER = None +"""Default REST response handler""" + +OAUTHCLIENT_REST_DEFAULT_ERROR_HANDLER = None +"""Default REST error handler""" diff --git a/invenio_oauthclient/handlers/rest.py b/invenio_oauthclient/handlers/rest.py index 01341181..b113aadc 100644 --- a/invenio_oauthclient/handlers/rest.py +++ b/invenio_oauthclient/handlers/rest.py @@ -30,7 +30,7 @@ account_setup_received from ..utils import create_csrf_disabled_registrationform, \ create_registrationform, fill_form, oauth_authenticate, oauth_get_user, \ - oauth_register, rest_oauth_register + oauth_register, rest_oauth_register, obj_or_import_string from .utils import get_session_next_url, response_token_setter, token_getter, \ token_session_key, token_setter @@ -45,18 +45,16 @@ def response_handler_postmessage(remote, url, payload=dict()): def default_response_handler(remote, url, payload=dict()): """Default response handler.""" + default_handler = current_app.config.get( + "OAUTHCLIENT_REST_DEFAULT_RESPONSE_HANDLER" + ) + if default_handler: + return obj_or_import_string(default_handler)(remote, url, payload) if payload: return redirect( "{url}?{payload}".format(url=url, payload=urlencode(payload))) return redirect(url) - -def response_handler(remote, url, payload=dict()): - """Handle oauthclient rest response.""" - return current_oauthclient.response_handler[remote.name]( - url, payload) - - # # Error handling decorators # @@ -67,47 +65,68 @@ def inner(resp, remote, *args, **kwargs): # OAuthErrors should not happen, so they are not caught here. Hence # they will result in a 500 Internal Server Error which is what we # are interested in. - remote_app_config = current_app.config['OAUTHCLIENT_REST_REMOTE_APPS'][ - remote.name] try: + remote_app_config = current_app.config['OAUTHCLIENT_REST_REMOTE_APPS'][remote.name] return f(resp, remote, *args, **kwargs) except OAuthClientError as e: current_app.logger.warning(e.message, exc_info=True) - return oauth2_handle_error( + default_oauth_error_handler = ( + current_app.config["OAUTHCLIENT_REST_DEFAULT_ERROR_HANDLER"] + or oauth2_handle_error + ) + return default_oauth_error_handler( e.remote, e.response, e.code, e.uri, e.description ) except OAuthCERNRejectedAccountError as e: current_app.logger.warning(e.message, exc_info=True) return response_handler( remote, - remote_app_config['error_redirect_url'], - payload=dict( - message='CERN account not allowed.', - code=400) - ) + remote_app_config["error_redirect_url"], + payload=dict(message="CERN account not allowed.", code=400), + ) except OAuthRejectedRequestError: return response_handler( remote, - remote_app_config['error_redirect_url'], + remote_app_config["error_redirect_url"], payload=dict( - message='You rejected the authentication request.', - code=400) - ) + message="You rejected the authentication request.", code=400 + ), + ) except AlreadyLinkedError: - msg = 'External service is already linked to another account.' + msg = "External service is already linked to another account." return response_handler( remote, - remote_app_config['error_redirect_url'], - payload=dict( - message=msg, - code=400) - ) + remote_app_config["error_redirect_url"], + payload=dict(message=msg, code=400), + ) return inner +def oauth_response_error_handler(f): + """Decorator to handle exceptions.""" + @wraps(f) + def inner(remote, url, payload=None): + try: + return f(remote, url, payload) + except AttributeError as e: + current_app.logger.warning(e, exc_info=True) + default_oauth_error_handler = ( + current_app.config["OAUTHCLIENT_REST_DEFAULT_ERROR_HANDLER"] + or oauth2_handle_error + ) + return default_oauth_error_handler(remote, None, 500, url, str(e)) + + return inner # # Handlers # + +@oauth_response_error_handler +def response_handler(remote, url, payload=None, *args, **kwargs): + """Handle oauthclient rest response.""" + return current_oauthclient.response_handler[remote.name]( + url, payload) + @oauth_error_handler def authorized_default_handler(resp, remote, *args, **kwargs): """Store access token in session. @@ -382,16 +401,19 @@ def signup_handler(remote, *args, **kwargs): )) -def oauth2_handle_error(remote, resp, error_code, error_uri, - error_description): +def oauth2_handle_error(remote, resp, error_code, error_uri, error_description): """Handle errors during exchange of one-time code for an access tokens.""" - remote_app_config = current_app.config['OAUTHCLIENT_REST_REMOTE_APPS'][ - remote.name] - return response_handler( + error_handler = response_handler + url = error_uri + if not remote or not hasattr(remote, "name"): + error_handler = default_response_handler + else: + remote_app_config = current_app.config["OAUTHCLIENT_REST_REMOTE_APPS"][ + remote.name + ] + url = remote_app_config["error_redirect_url"] + return error_handler( remote, - remote_app_config['error_redirect_url'], - payload=dict( - message="Authorization with remote service failed.", - code=400 - ) + url, + payload=dict(message="Authorization with remote service failed.", code=400), ) diff --git a/tests/conftest.py b/tests/conftest.py index a8fb8dc5..06d92cbf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -139,7 +139,7 @@ def _init_app(app_): def _init_app_rest(app_): """Init OAuth rest app.""" FlaskOAuth(app_) - InvenioAccountsRest(app_) + InvenioAccountsREST(app_) InvenioOAuthClientREST(app_) app_.register_blueprint(blueprint_client) return app_ diff --git a/tests/test_handlers_rest.py b/tests/test_handlers_rest.py index 800d9afc..c8617af0 100644 --- a/tests/test_handlers_rest.py +++ b/tests/test_handlers_rest.py @@ -10,8 +10,9 @@ from __future__ import absolute_import, print_function +import mock import pytest -from flask import current_app, session, url_for +from flask import session, url_for from flask_login import current_user from flask_oauthlib.client import OAuth as FlaskOAuth from flask_security import login_user, logout_user @@ -20,11 +21,11 @@ from werkzeug.routing import BuildError from invenio_oauthclient import InvenioOAuthClientREST, current_oauthclient -from invenio_oauthclient.errors import AlreadyLinkedError, OAuthResponseError +from invenio_oauthclient.errors import AlreadyLinkedError from invenio_oauthclient.handlers import token_session_key, token_setter from invenio_oauthclient.handlers.rest import authorized_signup_handler, \ - disconnect_handler, oauth_error_handler, response_handler_postmessage, \ - signup_handler + default_response_handler, disconnect_handler, oauth2_handle_error, \ + oauth_error_handler, response_handler_postmessage, signup_handler from invenio_oauthclient.models import RemoteToken from invenio_oauthclient.utils import oauth_authenticate from invenio_oauthclient.views.client import blueprint as blueprint_client @@ -35,7 +36,7 @@ @pytest.fixture(scope="function") def remote(request, app_rest): """.""" - oauth = current_app.extensions['oauthlib.client'] + oauth = current_oauthclient.oauth return oauth.remote_apps[request.param] @@ -241,3 +242,26 @@ def test_response_handler_with_postmessage(remote, app_rest, models_fixture): assert expected_message in response assert expected_status in response assert 'window.opener.postMessage' in response + + +@mock.patch('invenio_oauthclient.handlers.rest.default_response_handler') +def test_error_handler_with_string_in_remote_regresion( + default_response_handler_mock, + app_rest +): + remote = 'orcid' + oauth2_handle_error(remote, None, None, None, None) + default_response_handler_mock.assert_called_once() + + +@mock.patch('invenio_oauthclient.handlers.rest.response_handler_postmessage') +def test_default_response_handler_from_config(response_handler_mock, app_rest): + default_handler = 'invenio_oauthclient.handlers.' \ + 'rest.response_handler_postmessage' + with mock.patch.dict( + app_rest.config, { + "OAUTHCLIENT_REST_DEFAULT_RESPONSE_HANDLER": default_handler + } + ): + default_response_handler(None, None, None) + response_handler_mock.assert_called_once()