Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: replace Flask-Restful with own, simple implementation #916

Merged
merged 5 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions mwdb/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ class HashConverter(BaseConverter):
app.register_blueprint(static_blueprint)


@app.before_request
@api.blueprint.before_request
def assign_request_id():
g.request_id = token_hex(16)
g.request_start_time = datetime.utcnow()
Expand All @@ -152,7 +152,7 @@ def assign_request_id():
)


@app.after_request
@api.blueprint.after_request
def log_request(response):
if hasattr(g, "request_start_time"):
response_time = datetime.utcnow() - g.request_start_time
Expand Down Expand Up @@ -186,7 +186,7 @@ def log_request(response):
return response


@app.before_request
@api.blueprint.before_request
def require_auth():
if request.method == "OPTIONS":
return
Expand Down Expand Up @@ -221,7 +221,7 @@ def require_auth():
raise Forbidden("User has been disabled.")


@app.before_request
@api.blueprint.before_request
def apply_rate_limit():
apply_rate_limit_for_request()

Expand Down Expand Up @@ -412,3 +412,6 @@ def apply_rate_limit():
plugin_context = PluginAppContext()
with app.app_context():
load_plugins(plugin_context)

# Register blueprint
api.register()
91 changes: 0 additions & 91 deletions mwdb/core/apispec_utils.py

This file was deleted.

6 changes: 2 additions & 4 deletions mwdb/core/app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from flask import Blueprint, Flask
from flask import Flask
from werkzeug.middleware.proxy_fix import ProxyFix

from mwdb.core.config import app_config
Expand All @@ -7,9 +7,7 @@

app = Flask(__name__, static_folder=None)
app.config["MAX_CONTENT_LENGTH"] = app_config.mwdb.max_upload_size
api_blueprint = Blueprint("api", __name__, url_prefix="/api")
api = Service(app, api_blueprint)
app.register_blueprint(api_blueprint)
api = Service(app)

if app_config.mwdb.use_x_forwarded_for:
app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1)
5 changes: 4 additions & 1 deletion mwdb/core/rate_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ def apply_rate_limit_for_request() -> bool:
):
return False
# Split blueprint name and resource name from endpoint
_, resource_name = request.endpoint.split(".", 2)
if request.endpoint:
_, resource_name = request.endpoint.split(".", 2)
else:
resource_name = "None"
method = request.method.lower()
user = g.auth_user.login if g.auth_user is not None else request.remote_addr
# Limit keys from most specific to the least specific
Expand Down
209 changes: 131 additions & 78 deletions mwdb/core/service.py
Original file line number Diff line number Diff line change
@@ -1,97 +1,150 @@
import re
import sys
import textwrap
from functools import partial

from apispec import APISpec
from apispec import APISpec, yaml_utils
from apispec.ext.marshmallow import MarshmallowPlugin
from flask_restful import Api
from flask import Blueprint, Flask, jsonify, request
from flask.typing import ResponseReturnValue
from flask.views import MethodView
from sqlalchemy.exc import OperationalError
from werkzeug.exceptions import HTTPException, ServiceUnavailable
from werkzeug.exceptions import (
HTTPException,
InternalServerError,
MethodNotAllowed,
ServiceUnavailable,
)
from werkzeug.wrappers import Response

from mwdb.version import app_version

from . import log
from .apispec_utils import ApispecFlaskRestful


class Service(Api):
def __init__(self, flask_app, *args, **kwargs):
self.spec = self._create_spec()
self.flask_app = flask_app
super().__init__(*args, **kwargs)

def _init_app(self, app):
# I want to log exceptions on my own
def dont_log(*_, **__):
pass

app.log_exception = dont_log
if (
isinstance(app.handle_exception, partial)
and app.handle_exception.func is self.error_router
):
# Prevent double-initialization
return
super()._init_app(app)

def _create_spec(self):
spec = APISpec(
title="MWDB",
from .log import getLogger

logger = getLogger()


def flaskpath2openapi(path: str) -> str:
"""Convert a Flask URL rule to an OpenAPI-compliant path.

Got from https://github.com/marshmallow-code/apispec-webframeworks/

:param str path: Flask path template.
"""
# from flask-restplus
re_url = re.compile(r"<(?:[^:<>]+:)?([^<>]+)>")
return re_url.sub(r"{\1}", path)


class Resource(MethodView):
init_every_request = False

def dispatch_request(self, *args, **kwargs):
method = request.method.lower()
if not hasattr(self, method):
raise MethodNotAllowed(
valid_methods=self.methods,
description="Method is not allowed for this endpoint",
)
response = getattr(self, method)(*args, **kwargs)
if isinstance(response, Response):
return response
return jsonify(response)


class Service:
description = textwrap.dedent(
"""
MWDB API documentation.

If you want to automate things, we recommend using
<a href="https://github.com/CERT-Polska/mwdblib">
mwdblib library
</a>
"""
)
servers = [
{
"url": "{scheme}://{host}",
"description": "MWDB API endpoint",
"variables": {
"scheme": {"enum": ["http", "https"], "default": "https"},
"host": {"default": "mwdb.cert.pl"},
},
}
]

def __init__(self, app: Flask) -> None:
self.app = app
self.blueprint = Blueprint("api", __name__, url_prefix="/api")
self.spec = APISpec(
title="MWDB Core",
version=app_version,
openapi_version="3.0.2",
plugins=[ApispecFlaskRestful(), MarshmallowPlugin()],
plugins=[MarshmallowPlugin()],
info={"description": self.description},
servers=self.servers,
)

spec.components.security_scheme(
self.spec.components.security_scheme(
"bearerAuth", {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}
)
spec.options["info"] = {
"description": textwrap.dedent(
"""
MWDB API documentation.

If you want to automate things, we recommend using
<a href="http://github.com/CERT-Polska/mwdblib">mwdblib library</a>"""
def _make_error_response(self, exc: HTTPException) -> ResponseReturnValue:
return jsonify({"message": exc.description}), exc.code

def error_handler(self, exc: Exception) -> ResponseReturnValue:
if isinstance(exc, HTTPException):
return self._make_error_response(exc)
elif isinstance(exc, OperationalError):
return self._make_error_response(
ServiceUnavailable(
description="Request canceled due to statement timeout"
)
)
}
spec.options["servers"] = [
{
"url": "{scheme}://{host}",
"description": "MWDB API endpoint",
"variables": {
"scheme": {"enum": ["http", "https"], "default": "https"},
"host": {"default": "mwdb.cert.pl"},
},
}
]
return spec

def error_router(self, original_handler, e):
logger = log.getLogger()
if isinstance(e, HTTPException):
logger.error(str(e))
elif isinstance(e, OperationalError):
logger.error(str(e))
raise ServiceUnavailable("Request canceled due to statement timeout")
else:
logger.exception("Unhandled exception occurred")

# Handle all exceptions using handle_error, not only for owned routes
try:
return self.handle_error(e)
except Exception:
logger.exception("Exception from handle_error occurred")
pass
# If something went wrong - fallback to original behavior
return super().error_router(original_handler, e)

def add_resource(self, resource, *urls, undocumented=False, **kwargs):
super().add_resource(resource, *urls, **kwargs)
# Unknown exception, return ISE 500
logger.exception("Internal server error", exc_info=sys.exc_info())
return self._make_error_response(
InternalServerError(description="Internal server error")
)

def add_resource(
self, resource: Resource, *urls: str, undocumented: bool = False
) -> None:
view = resource.as_view(resource.__name__)
endpoint = view.__name__.lower()
for url in urls:
self.blueprint.add_url_rule(rule=url, endpoint=endpoint, view_func=view)
if not undocumented:
self.spec.path(resource=resource, api=self, app=self.flask_app)
resource_doc = resource.__doc__ or ""
operations = yaml_utils.load_operations_from_docstring(resource_doc)
for method in resource.methods:
method_name = method.lower()
method_doc = getattr(resource, method_name).__doc__
if method_doc:
operations[method_name] = yaml_utils.load_yaml_from_docstring(
method_doc
)
for url in urls:
prefixed_url = self.blueprint.url_prefix + "/" + url.lstrip("/")
self.spec.path(
path=flaskpath2openapi(prefixed_url), operations=operations
)

def register(self):
"""
Registers service and its blueprint to the app.

This must be done after adding all resources.
"""
# This handler is intentionally set on app and not blueprint
# to catch routing errors as well. The side effect is that
# it will return jsonified error messages for static endpoints
# but static files should be handled by separate server anyway...
self.app.register_error_handler(Exception, self.error_handler)
self.app.register_blueprint(self.blueprint)

def relative_url_for(self, resource, **values):
path = self.url_for(resource, **values)
# TODO: Remove this along with legacy download endpoint
endpoint = self.blueprint.name + "." + resource.__name__.lower()
path = self.app.url_for(endpoint, **values)
return path[len(self.blueprint.url_prefix) :]

def endpoint_for(self, resource):
return f"{self.blueprint.name}.{resource}"
Loading
Loading