diff --git a/mwdb/app.py b/mwdb/app.py index 88b260ed6..fc3f9a0dc 100755 --- a/mwdb/app.py +++ b/mwdb/app.py @@ -412,3 +412,6 @@ def apply_rate_limit(): plugin_context = PluginAppContext() with app.app_context(): load_plugins(plugin_context) + +# Register blueprint +api.register() diff --git a/mwdb/core/app.py b/mwdb/core/app.py index 4530439d8..4199596a6 100644 --- a/mwdb/core/app.py +++ b/mwdb/core/app.py @@ -1,4 +1,4 @@ -from flask import Blueprint, Flask +from flask import Flask from werkzeug.middleware.proxy_fix import ProxyFix from mwdb.core.config import app_config @@ -7,9 +7,7 @@ app = Flask(__name__, static_folder=None) app.config["MAX_CONTENT_LENGTH"] = app_config.mwdb.max_upload_size -api_blueprint = Blueprint("api", __name__, url_prefix="/api") -api = Service(app, api_blueprint) -app.register_blueprint(api_blueprint) +api = Service(app) if app_config.mwdb.use_x_forwarded_for: app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1) diff --git a/mwdb/core/rate_limit.py b/mwdb/core/rate_limit.py index 2a0cc9a00..164043f19 100644 --- a/mwdb/core/rate_limit.py +++ b/mwdb/core/rate_limit.py @@ -50,7 +50,10 @@ def apply_rate_limit_for_request() -> bool: ): return False # Split blueprint name and resource name from endpoint - _, resource_name = request.endpoint.split(".", 2) + if request.endpoint: + _, resource_name = request.endpoint.split(".", 2) + else: + resource_name = "None" method = request.method.lower() user = g.auth_user.login if g.auth_user is not None else request.remote_addr # Limit keys from most specific to the least specific diff --git a/mwdb/core/service.py b/mwdb/core/service.py index be6228e6f..021ff0883 100644 --- a/mwdb/core/service.py +++ b/mwdb/core/service.py @@ -1,11 +1,12 @@ +import re import sys import textwrap -from apispec import APISpec +from apispec import APISpec, yaml_utils from apispec.ext.marshmallow import MarshmallowPlugin -from apispec_webframeworks.flask import FlaskPlugin from flask import Blueprint, Flask, jsonify, request from flask.typing import ResponseReturnValue +from flask.views import MethodView from sqlalchemy.exc import OperationalError from werkzeug.exceptions import ( HTTPException, @@ -19,72 +20,69 @@ from .log import getLogger -SUPPORTED_METHODS = ["head", "get", "post", "put", "delete", "patch"] - logger = getLogger() -class Resource: - def __init__(self): - self.available_methods = [ - method.upper() for method in SUPPORTED_METHODS if hasattr(self, method) - ] +def flaskpath2openapi(path: str) -> str: + """Convert a Flask URL rule to an OpenAPI-compliant path. - def __call__(self, *args, **kwargs): - """ - Acts as view function, calling appropriate method and - jsonifying response - """ - if request.method not in self.available_methods: + Got from https://github.com/marshmallow-code/apispec-webframeworks/ + + :param str path: Flask path template. + """ + # from flask-restplus + re_url = re.compile(r"<(?:[^:<>]+:)?([^<>]+)>") + return re_url.sub(r"{\1}", path) + + +class Resource(MethodView): + init_every_request = False + + def dispatch_request(self, *args, **kwargs): + method = request.method.lower() + if not hasattr(self, method): raise MethodNotAllowed( - valid_methods=self.available_methods, + valid_methods=self.methods, description="Method is not allowed for this endpoint", ) - method = request.method.lower() response = getattr(self, method)(*args, **kwargs) if isinstance(response, Response): return response return jsonify(response) - def get_methods(self): - """ - Returns available methods for this resource + +class Service: + description = textwrap.dedent( """ - return [getattr(self, method) for method in self.available_methods] + MWDB API documentation. + If you want to automate things, we recommend using + + mwdblib library + + """ + ) + servers = [ + { + "url": "{scheme}://{host}", + "description": "MWDB API endpoint", + "variables": { + "scheme": {"enum": ["http", "https"], "default": "https"}, + "host": {"default": "mwdb.cert.pl"}, + }, + } + ] -class Service: - def __init__(self, app: Flask, blueprint: Blueprint) -> None: + def __init__(self, app: Flask) -> None: self.app = app - self.blueprint = blueprint - self.blueprint.register_error_handler(Exception, self.error_handler) + self.blueprint = Blueprint("api", __name__, url_prefix="/api") self.spec = APISpec( title="MWDB Core", version=app_version, openapi_version="3.0.2", - plugins=[FlaskPlugin(), MarshmallowPlugin()], - info={ - "description": textwrap.dedent( - """ - MWDB API documentation. - - If you want to automate things, we recommend using - - mwdblib library - - """ - ) - }, - servers=[ - { - "url": "{scheme}://{host}", - "description": "MWDB API endpoint", - "variables": { - "scheme": {"enum": ["http", "https"], "default": "https"}, - "host": {"default": "mwdb.cert.pl"}, - }, - } - ], + plugins=[MarshmallowPlugin()], + info={"description": self.description}, + servers=self.servers, ) self.spec.components.security_scheme( "bearerAuth", {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"} @@ -98,20 +96,52 @@ def error_handler(self, exc: Exception) -> ResponseReturnValue: return self._make_error_response(exc) elif isinstance(exc, OperationalError): return self._make_error_response( - ServiceUnavailable("Request canceled due to statement timeout") + ServiceUnavailable( + description="Request canceled due to statement timeout" + ) ) else: # Unknown exception, return ISE 500 logger.exception("Internal server error", exc_info=sys.exc_info()) return self._make_error_response( - InternalServerError("Internal server error") + InternalServerError(description="Internal server error") ) def add_resource( self, resource: Resource, *urls: str, undocumented: bool = False ) -> None: + view = resource.as_view(resource.__name__) + endpoint = view.__name__.lower() for url in urls: - endpoint = f"{self.blueprint.name}.{resource.__name__.lower()}" - self.blueprint.add_url_rule(rule=url, endpoint=endpoint, view_func=resource) + self.blueprint.add_url_rule(rule=url, endpoint=endpoint, view_func=view) if not undocumented: - self.spec.path(view=resource, app=self.app) + resource_doc = resource.__doc__ or "" + operations = yaml_utils.load_operations_from_docstring(resource_doc) + for method in resource.methods: + method_name = method.lower() + method_doc = getattr(resource, method_name).__doc__ + if method_doc: + operations[method_name] = yaml_utils.load_yaml_from_docstring( + method_doc + ) + for url in urls: + self.spec.path(path=flaskpath2openapi(url), operations=operations) + + def register(self): + """ + Registers service and its blueprint to the app. + + This must be done after adding all resources. + """ + # This handler is intentionally set on app and not blueprint + # to catch routing errors as well. The side effect is that + # it will return jsonified error messages for static endpoints + # but static files should be handled by separate server anyway... + self.app.register_error_handler(Exception, self.error_handler) + self.app.register_blueprint(self.blueprint) + + def relative_url_for(self, resource, **values): + # TODO: Remove this along with legacy download endpoint + endpoint = self.blueprint.name + "." + resource.__name__.lower() + path = self.app.url_for(endpoint, **values) + return path[len(self.blueprint.url_prefix) :] diff --git a/requirements.txt b/requirements.txt index 10da632fe..0dec94876 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,14 @@ Werkzeug==3.0.1 gunicorn==20.1.0 alembic==1.4.2 -Flask==2.3.2 +Flask==2.3.3 Flask-SQLAlchemy==2.5.1 Flask-Migrate==3.1.0 SQLAlchemy==1.3.18 marshmallow==3.20.2 psycopg2-binary==2.8.5 requests==2.31.0 -apispec[marshmallow]==6.4.0 -apispec-webframeworks==1.0.0 +apispec[marshmallow,yaml]==6.4.0 bcrypt==3.1.4 python-magic==0.4.18 luqum==0.13.0