diff --git a/aikido_zen/__init__.py b/aikido_zen/__init__.py index 7b2d5e34..c983cb5e 100644 --- a/aikido_zen/__init__.py +++ b/aikido_zen/__init__.py @@ -6,8 +6,9 @@ from dotenv import load_dotenv -# Re-export set_current_user : +# Re-export functions : from aikido_zen.context.users import set_user +from aikido_zen.middleware import should_block_request # Import logger from aikido_zen.helpers.logging import logger diff --git a/aikido_zen/context/__init__.py b/aikido_zen/context/__init__.py index 4d5a6ce5..251221e6 100644 --- a/aikido_zen/context/__init__.py +++ b/aikido_zen/context/__init__.py @@ -63,6 +63,8 @@ def __init__(self, context_obj=None, body=None, req=None, source=None): self.route_params = extract_route_params(self.url) self.subdomains = get_subdomains_from_url(self.url) + self.executed_middleware = False + def __reduce__(self): return ( self.__class__, @@ -81,6 +83,7 @@ def __reduce__(self): "user": self.user, "xml": self.xml, "outgoing_req_redirects": self.outgoing_req_redirects, + "executed_middleware": self.executed_middleware, "route_params": self.route_params, }, None, diff --git a/aikido_zen/context/init_test.py b/aikido_zen/context/init_test.py index f1eb7cf3..f8a9650c 100644 --- a/aikido_zen/context/init_test.py +++ b/aikido_zen/context/init_test.py @@ -68,6 +68,7 @@ def test_wsgi_context_1(): "parsed_userinput": {}, "xml": {}, "outgoing_req_redirects": [], + "executed_middleware": False, "route_params": [], } @@ -95,6 +96,7 @@ def test_wsgi_context_2(): "parsed_userinput": {}, "xml": {}, "outgoing_req_redirects": [], + "executed_middleware": False, "route_params": [], } diff --git a/aikido_zen/context/users.py b/aikido_zen/context/users.py index b693ecd6..29e51adc 100644 --- a/aikido_zen/context/users.py +++ b/aikido_zen/context/users.py @@ -20,6 +20,12 @@ def set_user(user): if not context: logger.debug("No context set, returning") return + if context.executed_middleware is True: + # Middleware to rate-limit/check for users ran already. Could be misconfiguration. + logger.warning( + "set_user(...) must be called before the Zen middleware is executed." + ) + validated_user["lastIpAddress"] = context.remote_address context.user = validated_user diff --git a/aikido_zen/context/users_test.py b/aikido_zen/context/users_test.py index 8e876ba9..608416e6 100644 --- a/aikido_zen/context/users_test.py +++ b/aikido_zen/context/users_test.py @@ -1,6 +1,41 @@ +from lib2to3.fixes.fix_input import context + import pytest +from . import current_context, Context from .users import validate_user, set_user +from .. import should_block_request + + +@pytest.fixture(autouse=True) +def run_around_tests(): + yield + # Make sure to reset context and cache after every test so it does not + # interfere with other tests + current_context.set(None) + + +def set_context_and_lifecycle(): + wsgi_request = { + "REQUEST_METHOD": "GET", + "HTTP_HEADER_1": "header 1 value", + "HTTP_HEADER_2": "Header 2 value", + "RANDOM_VALUE": "Random value", + "HTTP_COOKIE": "sessionId=abc123xyz456;", + "wsgi.url_scheme": "http", + "HTTP_HOST": "localhost:8080", + "PATH_INFO": "/hello", + "QUERY_STRING": "user=JohnDoe&age=30&age=35", + "CONTENT_TYPE": "application/json", + "REMOTE_ADDR": "198.51.100.23", + } + context = Context( + req=wsgi_request, + body=None, + source="flask", + ) + context.set_as_current_context() + return context def test_validate_user_valid_input(): @@ -67,3 +102,57 @@ def test_validate_user_invalid_user_type_dict_without_id(caplog): def test_set_user_with_none(caplog): result = set_user(None) assert "expects a dict with 'id' and 'name' properties" in caplog.text + + +def test_set_valid_user(): + context1 = set_context_and_lifecycle() + assert context1.user is None + + user = {"id": 456, "name": "Bob"} + set_user(user) + + assert context1.user == { + "id": "456", + "name": "Bob", + "lastIpAddress": "198.51.100.23", + } + + +def test_re_set_valid_user(): + context1 = set_context_and_lifecycle() + assert context1.user is None + + user = {"id": 456, "name": "Bob"} + set_user(user) + + assert context1.user == { + "id": "456", + "name": "Bob", + "lastIpAddress": "198.51.100.23", + } + + user = {"id": "1000", "name": "Alice"} + set_user(user) + + assert context1.user == { + "id": "1000", + "name": "Alice", + "lastIpAddress": "198.51.100.23", + } + + +def test_after_middleware(caplog): + context1 = set_context_and_lifecycle() + assert context1.user is None + + should_block_request() + + user = {"id": 456, "name": "Bob"} + set_user(user) + + assert "must be called before the Zen middleware is executed" in caplog.text + assert context1.user == { + "id": "456", + "name": "Bob", + "lastIpAddress": "198.51.100.23", + } diff --git a/aikido_zen/middleware/__init__.py b/aikido_zen/middleware/__init__.py new file mode 100644 index 00000000..3afe3664 --- /dev/null +++ b/aikido_zen/middleware/__init__.py @@ -0,0 +1,7 @@ +"""Re-exports middleware""" + +from .asgi import AikidoASGIMiddleware as AikidoQuartMiddleware +from .asgi import AikidoASGIMiddleware as AikidoStarletteMiddleware +from .flask import AikidoFlaskMiddleware +from .django import AikidoDjangoMiddleware +from .should_block_request import should_block_request diff --git a/aikido_zen/middleware/asgi.py b/aikido_zen/middleware/asgi.py new file mode 100644 index 00000000..5b7e04be --- /dev/null +++ b/aikido_zen/middleware/asgi.py @@ -0,0 +1,47 @@ +"""Exports ratelimiting and user blocking middleware for ASGI""" + +from aikido_zen.helpers.logging import logger +from .should_block_request import should_block_request + + +class AikidoASGIMiddleware: + """Ratelimiting and user blocking middleware for ASGI""" + + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + result = should_block_request() + if result["block"] is not True: + return await self.app(scope, receive, send) + + if result["type"] == "ratelimited": + message = "You are rate limited by Zen." + if result["trigger"] == "ip" and result["ip"]: + message += " (Your IP: " + result["ip"] + ")" + return await send_status_code_and_text(send, (message, 429)) + elif result["type"] == "blocked": + return await send_status_code_and_text( + send, ("You are blocked by Zen.", 403) + ) + + logger.debug("Unknown type for blocking request: %s", result["type"]) + return await self.app(scope, receive, send) + + +async def send_status_code_and_text(send, pre_response): + """Sends a status code and text""" + await send( + { + "type": "http.response.start", + "status": pre_response[1], + "headers": [(b"content-type", b"text/plain")], + } + ) + await send( + { + "type": "http.response.body", + "body": pre_response[0].encode("utf-8"), + "more_body": False, + } + ) diff --git a/aikido_zen/middleware/django.py b/aikido_zen/middleware/django.py new file mode 100644 index 00000000..4adaad11 --- /dev/null +++ b/aikido_zen/middleware/django.py @@ -0,0 +1,37 @@ +"""Exports AikidoDjangoMiddleware""" + +from aikido_zen.helpers.logging import logger +from .should_block_request import should_block_request + + +class AikidoDjangoMiddleware: + """Middleware for rate-limiting and user blocking for django""" + + def __init__(self, get_response): + logger.critical("Django middleware ised") + self.get_response = get_response + try: + from django.http import HttpResponse + + self.HttpResponse = HttpResponse + except ImportError: + logger.warning( + "django.http import not working, aikido rate-limiting middleware not running." + ) + + def __call__(self, request): + result = should_block_request() + if result["block"] is not True or self.HttpResponse is None: + return self.get_response(request) + + if result["type"] == "ratelimited": + message = "You are rate limited by Zen." + if result["trigger"] == "ip" and result["ip"]: + message += " (Your IP: " + result["ip"] + ")" + return self.HttpResponse(message, content_type="text/plain", status=429) + elif result["type"] == "blocked": + return self.HttpResponse( + "You are blocked by Zen.", content_type="text/plain", status=403 + ) + logger.debug("Unknown type for blocking request: %s", result["type"]) + return self.get_response(request) diff --git a/aikido_zen/middleware/flask.py b/aikido_zen/middleware/flask.py new file mode 100644 index 00000000..e34fef79 --- /dev/null +++ b/aikido_zen/middleware/flask.py @@ -0,0 +1,38 @@ +"""Exports ratelimiting and user blocking middleware for Flask""" + +from aikido_zen.helpers.logging import logger +from .should_block_request import should_block_request + + +class AikidoFlaskMiddleware: + """Ratelimiting and user blocking middleware for Flask""" + + def __init__(self, app): + self.app = app + try: + from werkzeug.wrappers import Response + + self.Response = Response + except ImportError: + logger.warning( + "Something went wrong whilst importing werkzeug.wrappers, middleware does not work" + ) + + def __call__(self, environ, start_response): + result = should_block_request() + if result["block"] is not True or self.Response is None: + return self.app(environ, start_response) + + if result["type"] == "ratelimited": + message = "You are rate limited by Zen." + if result["trigger"] == "ip" and result["ip"]: + message += " (Your IP: " + result["ip"] + ")" + res = self.Response(message, mimetype="text/plain", status=429) + return res(environ, start_response) + elif result["type"] == "blocked": + res = self.Response( + "You are blocked by Zen.", mimetype="text/plain", status=403 + ) + return res(environ, start_response) + logger.debug("Unknown type for blocking request: %s", result["type"]) + return self.app(environ, start_response) diff --git a/aikido_zen/middleware/init_test.py b/aikido_zen/middleware/init_test.py new file mode 100644 index 00000000..a274386a --- /dev/null +++ b/aikido_zen/middleware/init_test.py @@ -0,0 +1,164 @@ +from unittest.mock import patch, MagicMock + +import pytest +from aikido_zen.context import current_context, Context, get_current_context +from aikido_zen.thread.thread_cache import ThreadCache, threadlocal_storage +from . import should_block_request + + +@pytest.fixture(autouse=True) +def run_around_tests(): + yield + # Make sure to reset context and cache after every test so it does not + # interfere with other tests + current_context.set(None) + + +def test_without_context(): + current_context.set(None) + assert should_block_request() == {"block": False} + + +def set_context(user=None, executed_middleware=False): + Context( + context_obj={ + "remote_address": "::1", + "method": "POST", + "url": "http://localhost:4000", + "query": { + "abc": "def", + }, + "headers": {}, + "body": None, + "cookies": {}, + "source": "flask", + "route": "/posts/:id", + "user": user, + "executed_middleware": executed_middleware, + } + ).set_as_current_context() + + +class MyThreadCache(ThreadCache): + def renew_if_ttl_expired(self): + return + + +def test_with_context_without_cache(): + set_context() + threadlocal_storage.cache = None + assert should_block_request() == {"block": False} + + +def test_with_context_with_cache(): + set_context(user={"id": "123"}) + threadCache = MyThreadCache() + + threadCache.blocked_uids = ["123"] + assert get_current_context().executed_middleware == False + assert should_block_request() == { + "block": True, + "trigger": "user", + "type": "blocked", + } + assert get_current_context().executed_middleware == True + + threadCache.blocked_uids = [] + assert should_block_request() == {"block": False} + + threadCache.blocked_uids = ["23", "234", "456"] + assert should_block_request() == {"block": False} + assert get_current_context().executed_middleware == True + + +def test_cache_comms_with_endpoints(): + set_context(user={"id": "456"}) + threadCache = MyThreadCache() + threadCache.blocked_uids = ["123"] + threadCache.endpoints = [ + { + "method": "POST", + "route": "/login", + "forceProtectionOff": False, + "rateLimiting": { + "enabled": True, + "maxRequests": 3, + "windowSizeInMS": 1000, + }, + } + ] + assert get_current_context().executed_middleware == False + + with patch("aikido_zen.background_process.comms.get_comms") as mock_get_comms: + mock_get_comms.return_value = None # Set the return value of get_comms + assert should_block_request() == {"block": False} + assert get_current_context().executed_middleware == True + + with patch("aikido_zen.background_process.comms.get_comms") as mock_get_comms: + mock_comms = MagicMock() + mock_get_comms.return_value = mock_comms # Set the return value of get_comms + + # No matching endpoints : + assert should_block_request() == {"block": False} + mock_comms.send_data_to_bg_process.assert_not_called() + + threadCache.endpoints.append( + { + "method": "POST", + "route": "/posts/:id", + "forceProtectionOff": False, + "rateLimiting": { + "enabled": False, + "maxRequests": 3, + "windowSizeInMS": 1000, + }, + } + ) + + with patch("aikido_zen.background_process.comms.get_comms") as mock_get_comms: + mock_comms = MagicMock() + mock_get_comms.return_value = mock_comms # Set the return value of get_comms + + # Rate-limiting disabled: + assert should_block_request() == {"block": False} + mock_comms.send_data_to_bg_process.assert_not_called() + + # Enable ratelimiting + threadCache.endpoints[1]["rateLimiting"]["enabled"] = True + + with patch("aikido_zen.background_process.comms.get_comms") as mock_get_comms: + mock_comms = MagicMock() + mock_get_comms.return_value = mock_comms # Set the return value of get_comms + mock_comms.send_data_to_bg_process.return_value = {"success": False} + + assert should_block_request() == {"block": False} + mock_comms.send_data_to_bg_process.assert_called_with( + action="SHOULD_RATELIMIT", + obj={ + "route_metadata": { + "method": "POST", + "route": "/posts/:id", + "url": "http://localhost:4000", + }, + "user": {"id": "456"}, + "remote_address": "::1", + }, + receive=True, + ) + + mock_comms.send_data_to_bg_process.return_value = { + "success": True, + "data": {"block": False, "trigger": "my_trigger"}, + } + assert should_block_request() == {"block": False} + + mock_comms.send_data_to_bg_process.return_value = { + "success": True, + "data": {"block": True, "trigger": "my_trigger"}, + } + assert should_block_request() == { + "block": True, + "ip": "::1", + "type": "ratelimited", + "trigger": "my_trigger", + } diff --git a/aikido_zen/middleware/should_block_request.py b/aikido_zen/middleware/should_block_request.py new file mode 100644 index 00000000..289e3c6a --- /dev/null +++ b/aikido_zen/middleware/should_block_request.py @@ -0,0 +1,58 @@ +"""Exports function should_block_request""" + +from aikido_zen.helpers.logging import logger +from aikido_zen.context import get_current_context +from aikido_zen.thread.thread_cache import get_cache +import aikido_zen.background_process.comms as c +from aikido_zen.ratelimiting.get_ratelimited_endpoint import get_ratelimited_endpoint +from aikido_zen.helpers.match_endpoints import match_endpoints + + +def should_block_request(): + """ + Checks for rate-limiting and checks if the current user is blocked. + """ + try: + context = get_current_context() + cache = get_cache() + if not context or not cache: + return {"block": False} + context.executed_middleware = ( + True # Update context with middleware execution set to true + ) + context.set_as_current_context() + + # Blocked users: + if context.user and cache.is_user_blocked(context.user["id"]): + return {"block": True, "type": "blocked", "trigger": "user"} + + route_metadata = context.get_route_metadata() + endpoints = getattr(cache, "endpoints", None) + comms = c.get_comms() + if not comms or not endpoints: + return {"block": False} + matched_endpoints = match_endpoints(route_metadata, endpoints) + # Ratelimiting : + if matched_endpoints and get_ratelimited_endpoint( + matched_endpoints, context.route + ): + # As an optimization check if the route is rate limited before sending over IPC + ratelimit_res = comms.send_data_to_bg_process( + action="SHOULD_RATELIMIT", + obj={ + "route_metadata": route_metadata, + "user": context.user, + "remote_address": context.remote_address, + }, + receive=True, + ) + if ratelimit_res["success"] and ratelimit_res["data"]["block"]: + return { + "block": True, + "type": "ratelimited", + "trigger": ratelimit_res["data"]["trigger"], + "ip": context.remote_address, + } + except Exception as e: + logger.debug("Exception occured in should_block_request: %s", e) + return {"block": False} diff --git a/aikido_zen/sources/functions/request_handler.py b/aikido_zen/sources/functions/request_handler.py index 5623a1c6..51853083 100644 --- a/aikido_zen/sources/functions/request_handler.py +++ b/aikido_zen/sources/functions/request_handler.py @@ -35,8 +35,6 @@ def pre_response(): """ This is executed at the end of the middleware chain before a response is present - IP Allowlist - - Blocked users - - Ratelimiting """ context = ctx.get_current_context() comms = communications.get_comms() @@ -44,11 +42,7 @@ def pre_response(): logger.debug("Request was not complete, not running any pre_response code") return - # Blocked users: - if context.user and get_cache() and get_cache().is_user_blocked(context.user["id"]): - return ("You are blocked by Aikido Firewall.", 403) - - # Fetch endpoints for IP Allowlist and ratelimiting : + # Fetch endpoints for IP Allowlist : route_metadata = context.get_route_metadata() endpoints = getattr(get_cache(), "endpoints", None) if not endpoints: @@ -66,24 +60,6 @@ def pre_response(): message += f" (Your IP: {context.remote_address})" return (message, 403) - # Ratelimiting : - if get_ratelimited_endpoint(matched_endpoints, context.route): - # As an optimization check if the route is rate limited before sending over IPC - ratelimit_res = comms.send_data_to_bg_process( - action="SHOULD_RATELIMIT", - obj={ - "route_metadata": route_metadata, - "user": context.user, - "remote_address": context.remote_address, - }, - receive=True, - ) - if ratelimit_res["success"] and ratelimit_res["data"]["block"]: - message = "You are rate limited by Zen" - if ratelimit_res["data"]["trigger"] == "ip": - message += f" (Your IP: {context.remote_address})" - return (message, 429) - def post_response(status_code): """Checks if the current route is useful""" diff --git a/docs/django.md b/docs/django.md index ad785047..ae2ab1c1 100644 --- a/docs/django.md +++ b/docs/django.md @@ -34,6 +34,24 @@ AIKIDO_BLOCKING=true It's recommended to enable this on your staging environment for a considerable amount of time before enabling it on your production environment (e.g. one week). +## Rate limiting and user blocking +If you want to add the rate limiting feature to your app, modify your code like this: +```py +# settings.py file : + +MIDDLEWARE = [ + # Authorization middleware here (Make sure aikido middleware runs after this) + "aikido_zen.middleware.AikidoDjangoMiddleware", + # ... +] +``` +As soon as you identify the user in you authorization middleware, pass the identity info to Aikido. +```py +from aikido_zen import set_user + +# Set a user (presumably in middleware) : +set_user({"id": "123", "name": "John Doe"}) +``` ## Debug mode If you need to debug the firewall, you can run your code with the environment variable `AIKIDO_DEBUG` set to `true`: diff --git a/docs/flask.md b/docs/flask.md index 10392e63..d29e5b21 100644 --- a/docs/flask.md +++ b/docs/flask.md @@ -33,6 +33,25 @@ AIKIDO_BLOCKING=true It's recommended to enable this on your staging environment for a considerable amount of time before enabling it on your production environment (e.g. one week). +## Rate limiting and user blocking +If you want to add the rate limiting feature to your app, modify your code like this: +```py +from aikido_zen.middleware import AikidoFlaskMiddleware + +app = Flask(__name__) +# ... +app.wsgi_app = AikidoFlaskMiddleware(app.wsgi_app) +# ... +# Authorization middleware here (Make sure aikido middleware runs after this) +# ... +``` +As soon as you identify the user in you authorization middleware, pass the identity info to Aikido. +```py +from aikido_zen import set_user + +# Set a user (presumably in middleware) : +set_user({"id": "123", "name": "John Doe"}) +``` ## Debug mode If you need to debug the firewall, you can run your code with the environment variable `AIKIDO_DEBUG` set to `true`: diff --git a/docs/quart.md b/docs/quart.md index 92f732e9..0aacf947 100644 --- a/docs/quart.md +++ b/docs/quart.md @@ -49,6 +49,24 @@ AIKIDO_BLOCKING=true It's recommended to enable this on your staging environment for a considerable amount of time before enabling it on your production environment (e.g. one week). +## Rate limiting and user blocking +If you want to add the rate limiting feature to your app, modify your code like this: +```py +from aikido_zen.middleware import AikidoQuartMiddleware + +app = Quart(__name__) +app.asgi_app = AikidoQuartMiddleware(app.asgi_app) +# Authorization middleware here (Make sure aikido middleware runs after this) +``` + +As soon as you identify the user in you authorization middleware, pass the identity info to Aikido. +```py +from aikido_zen import set_user + +# Set a user (presumably in middleware) : +set_user({"id": "123", "name": "John Doe"}) +``` + ## Debug mode If you need to debug the firewall, you can run your code with the environment variable `AIKIDO_DEBUG` set to `true`: diff --git a/docs/starlette.md b/docs/starlette.md index 3b8d1f02..20dec290 100644 --- a/docs/starlette.md +++ b/docs/starlette.md @@ -30,6 +30,32 @@ AIKIDO_BLOCKING=true It's recommended to enable this on your staging environment for a considerable amount of time before enabling it on your production environment (e.g. one week). +## Rate limiting and user blocking +If you want to add the rate limiting feature to your app, modify your code like this: +```py +... +from starlette.middleware import Middleware +from aikido_zen.middleware import AikidoStarletteMiddleware + +app = Starlette(routes=[ + ... +], middleware=[ + ... + # Authorization middleware here (Make sure aikido middleware runs after this) + ... + Middleware(AikidoStarletteMiddleware), + ... +]) +``` + +As soon as you identify the user in you authorization middleware, pass the identity info to Aikido. +```py +from aikido_zen import set_user + +# Set a user (presumably in middleware) : +set_user({"id": "123", "name": "John Doe"}) +``` + ## Debug mode If you need to debug the firewall, you can run your code with the environment variable `AIKIDO_DEBUG` set to `true`: diff --git a/end2end/flask_mysql_test.py b/end2end/flask_mysql_test.py index 3b0e09af..d7e83e12 100644 --- a/end2end/flask_mysql_test.py +++ b/end2end/flask_mysql_test.py @@ -37,16 +37,16 @@ def test_dangerous_response_with_firewall(): assert len(attacks) == 1 del attacks[0]["attack"]["stack"] - assert attacks[0]["attack"] == { - "blocked": True, - "kind": "sql_injection", - 'metadata': {'sql': 'INSERT INTO dogs (dog_name, isAdmin) VALUES ("Dangerous bobby", 1); -- ", 0)'}, - 'operation': 'pymysql.Cursor.execute', - 'pathToPayload': '.dog_name', - 'payload': '"Dangerous bobby\\", 1); -- "', - 'source': "body", - 'user': None - } + assert attacks[0]["attack"]["blocked"] == True + assert attacks[0]["attack"]["kind"] == "sql_injection" + assert attacks[0]["attack"]["metadata"]["sql"] == 'INSERT INTO dogs (dog_name, isAdmin) VALUES ("Dangerous bobby", 1); -- ", 0)' + assert attacks[0]["attack"]["operation"] == 'pymysql.Cursor.execute' + assert attacks[0]["attack"]["pathToPayload"] == '.dog_name' + assert attacks[0]["attack"]["payload"] == '"Dangerous bobby\\", 1); -- "' + assert attacks[0]["attack"]["source"] == "body" + assert attacks[0]["attack"]["user"]["id"] == "123" + assert attacks[0]["attack"]["user"]["name"] == "John Doe" + def test_dangerous_response_with_firewall_route_params(): events = fetch_events_from_mock("http://localhost:5000") @@ -60,17 +60,15 @@ def test_dangerous_response_with_firewall_route_params(): assert len(attacks) == 2 del attacks[0] - del attacks[0]["attack"]["stack"] - assert attacks[0]["attack"] == { - "blocked": True, - "kind": "shell_injection", - 'metadata': {'command': 'ls -la'}, - 'operation': 'subprocess.Popen', - 'pathToPayload': '.command', - 'payload': '"ls -la"', - 'source': "route_params", - 'user': None - } + assert attacks[0]["attack"]["blocked"] == True + assert attacks[0]["attack"]["kind"] == "shell_injection" + assert attacks[0]["attack"]['metadata']['command'] == 'ls -la' + assert attacks[0]["attack"]["operation"] == 'subprocess.Popen' + assert attacks[0]["attack"]["pathToPayload"] == '.command' + assert attacks[0]["attack"]["payload"] == '"ls -la"' + assert attacks[0]["attack"]["source"] == "route_params" + assert attacks[0]["attack"]["user"]["id"] == "123" + assert attacks[0]["attack"]["user"]["name"] == "John Doe" def test_dangerous_response_without_firewall(): diff --git a/end2end/quart_postgres_uvicorn_test.py b/end2end/quart_postgres_uvicorn_test.py index ca8c804d..ebea4a9a 100644 --- a/end2end/quart_postgres_uvicorn_test.py +++ b/end2end/quart_postgres_uvicorn_test.py @@ -35,16 +35,16 @@ def test_dangerous_response_with_firewall(): assert len(attacks) == 1 del attacks[0]["attack"]["stack"] - assert attacks[0]["attack"] == { - "blocked": True, - "kind": "sql_injection", - 'metadata': {'sql': "INSERT INTO dogs (dog_name, isAdmin) VALUES ('Dangerous Bobby', TRUE); -- ', FALSE)"}, - 'operation': "asyncpg.connection.Connection.execute", - 'pathToPayload': '.dog_name', - 'payload': "\"Dangerous Bobby', TRUE); -- \"", - 'source': "body", - 'user': None - } + assert attacks[0]["attack"]["blocked"] == True + assert attacks[0]["attack"]["kind"] == "sql_injection" + assert attacks[0]["attack"]["metadata"]["sql"] == "INSERT INTO dogs (dog_name, isAdmin) VALUES ('Dangerous Bobby', TRUE); -- ', FALSE)" + assert attacks[0]["attack"]["operation"] == "asyncpg.connection.Connection.execute" + assert attacks[0]["attack"]["pathToPayload"] == '.dog_name' + assert attacks[0]["attack"]["payload"] == "\"Dangerous Bobby', TRUE); -- \"" + assert attacks[0]["attack"]["source"] == "body" + assert attacks[0]["attack"]["user"]["id"] == "user123" + assert attacks[0]["attack"]["user"]["name"] == "John Doe" + def test_dangerous_response_without_firewall(): dog_name = "Dangerous Bobby', TRUE); -- " diff --git a/end2end/starlette_postgres_uvicorn_test.py b/end2end/starlette_postgres_uvicorn_test.py index e8e719c7..331bedcd 100644 --- a/end2end/starlette_postgres_uvicorn_test.py +++ b/end2end/starlette_postgres_uvicorn_test.py @@ -13,7 +13,7 @@ def test_firewall_started_okay(): events = fetch_events_from_mock("http://localhost:5000") started_events = filter_on_event_type(events, "started") assert len(started_events) == 1 - validate_started_event(started_events[0], []) + validate_started_event(started_events[0], None) # Don't assert stack def test_safe_response_with_firewall(): dog_name = "Bobby Tables" @@ -37,16 +37,16 @@ def test_dangerous_response_with_firewall(): assert len(attacks) == 1 del attacks[0]["attack"]["stack"] - assert attacks[0]["attack"] == { - "blocked": True, - "kind": "sql_injection", - 'metadata': {'sql': "INSERT INTO dogs (dog_name, isAdmin) VALUES ('Dangerous Bobby', TRUE); -- ', FALSE)"}, - 'operation': "asyncpg.connection.Connection.execute", - 'pathToPayload': '.dog_name', - 'payload': "\"Dangerous Bobby', TRUE); -- \"", - 'source': "body", - 'user': None - } + assert attacks[0]["attack"]["blocked"] == True + assert attacks[0]["attack"]["kind"] == "sql_injection" + assert attacks[0]["attack"]["metadata"]["sql"] == "INSERT INTO dogs (dog_name, isAdmin) VALUES ('Dangerous Bobby', TRUE); -- ', FALSE)" + assert attacks[0]["attack"]["operation"] == "asyncpg.connection.Connection.execute" + assert attacks[0]["attack"]["pathToPayload"] == ".dog_name" + assert attacks[0]["attack"]["payload"] == "\"Dangerous Bobby', TRUE); -- \"" + assert attacks[0]["attack"]["source"] == "body" + assert attacks[0]["attack"]["user"]["id"] == "user123" + assert attacks[0]["attack"]["user"]["name"] == "John Doe" + def test_dangerous_response_without_firewall(): dog_name = "Dangerous Bobby', TRUE); -- " diff --git a/sample-apps/django-mysql/sample-django-mysql-app/settings.py b/sample-apps/django-mysql/sample-django-mysql-app/settings.py index b163baad..9eb6d817 100644 --- a/sample-apps/django-mysql/sample-django-mysql-app/settings.py +++ b/sample-apps/django-mysql/sample-django-mysql-app/settings.py @@ -9,7 +9,7 @@ For the full list of settings and their values, see https://docs.djangoproject.com/en/5.0/ref/settings/ """ - +import os from pathlib import Path from decouple import config @@ -58,6 +58,11 @@ 'django.middleware.clickjacking.XFrameOptionsMiddleware', ] +firewall_disabled = os.getenv("FIREWALL_DISABLED") +if firewall_disabled is not None: + if firewall_disabled.lower() != "1": + MIDDLEWARE = ["aikido_zen.middleware.AikidoDjangoMiddleware"] + MIDDLEWARE + ROOT_URLCONF = 'sample-django-mysql-app.urls' TEMPLATES = [ diff --git a/sample-apps/flask-mysql-uwsgi/app.py b/sample-apps/flask-mysql-uwsgi/app.py index 28158058..4c1d1f2c 100644 --- a/sample-apps/flask-mysql-uwsgi/app.py +++ b/sample-apps/flask-mysql-uwsgi/app.py @@ -11,6 +11,8 @@ from flaskext.mysql import MySQL app = Flask(__name__) + + if __name__ == '__main__': app.run() mysql = MySQL() diff --git a/sample-apps/flask-mysql/.env.benchmark b/sample-apps/flask-mysql/.env.benchmark index cc291cab..ed226d53 100644 --- a/sample-apps/flask-mysql/.env.benchmark +++ b/sample-apps/flask-mysql/.env.benchmark @@ -1,3 +1,4 @@ AIKIDO_DEBUG=false AIKIDO_TOKEN="AIK_secret_token" AIKIDO_BLOCKING=true +DONT_ADD_MIDDLEWARE=1 diff --git a/sample-apps/flask-mysql/app.py b/sample-apps/flask-mysql/app.py index eeb2c34b..26c9c9d6 100644 --- a/sample-apps/flask-mysql/app.py +++ b/sample-apps/flask-mysql/app.py @@ -1,5 +1,6 @@ import os firewall_disabled = os.getenv("FIREWALL_DISABLED") +dont_add_middleware = os.getenv("DONT_ADD_MIDDLEWARE") if firewall_disabled is not None: if firewall_disabled.lower() != "1": import aikido_zen # Aikido package import @@ -12,6 +13,22 @@ import subprocess app = Flask(__name__) +if firewall_disabled is not None: + if firewall_disabled.lower() != "1" and (dont_add_middleware is None or dont_add_middleware.lower() != "1"): + # Use DONT_ADD_MIDDLEWARE so we don't add this middleware during e.g. benchmarks. + import aikido_zen + from aikido_zen.middleware import AikidoFlaskMiddleware + class SetUserMiddleware: + def __init__(self, app): + self.app = app + def __call__(self, environ, start_response): + aikido_zen.set_user({"id": "123", "name": "John Doe"}) + return self.app(environ, start_response) + app.wsgi_app = AikidoFlaskMiddleware(app.wsgi_app) + app.wsgi_app = SetUserMiddleware(app.wsgi_app) + + + if __name__ == '__main__': app.run() mysql = MySQL() diff --git a/sample-apps/quart-postgres-uvicorn/app.py b/sample-apps/quart-postgres-uvicorn/app.py index f670ebd3..e5196aba 100644 --- a/sample-apps/quart-postgres-uvicorn/app.py +++ b/sample-apps/quart-postgres-uvicorn/app.py @@ -11,6 +11,20 @@ aikido_zen.protect() app = Quart(__name__) +if firewall_disabled is not None: + if firewall_disabled.lower() != "1": + import aikido_zen # Aikido package import + from aikido_zen.middleware import AikidoQuartMiddleware + class SetUserMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + aikido_zen.set_user({"id": "user123", "name": "John Doe"}) + return await self.app(scope, receive, send) + + app.asgi_app = AikidoQuartMiddleware(app.asgi_app) + app.asgi_app = SetUserMiddleware(app.asgi_app) async def get_db_connection(): return await asyncpg.connect( diff --git a/sample-apps/starlette-postgres-uvicorn/.env.benchmark b/sample-apps/starlette-postgres-uvicorn/.env.benchmark index cc291cab..ed226d53 100644 --- a/sample-apps/starlette-postgres-uvicorn/.env.benchmark +++ b/sample-apps/starlette-postgres-uvicorn/.env.benchmark @@ -1,3 +1,4 @@ AIKIDO_DEBUG=false AIKIDO_TOKEN="AIK_secret_token" AIKIDO_BLOCKING=true +DONT_ADD_MIDDLEWARE=1 diff --git a/sample-apps/starlette-postgres-uvicorn/app.py b/sample-apps/starlette-postgres-uvicorn/app.py index 34d7a885..0a1d702d 100644 --- a/sample-apps/starlette-postgres-uvicorn/app.py +++ b/sample-apps/starlette-postgres-uvicorn/app.py @@ -2,6 +2,7 @@ import os load_dotenv() firewall_disabled = os.getenv("FIREWALL_DISABLED") +dont_add_middleware = os.getenv("DONT_ADD_MIDDLEWARE") if firewall_disabled is not None: if firewall_disabled.lower() != "1": import aikido_zen # Aikido package import @@ -13,6 +14,7 @@ from starlette.routing import Route from starlette.templating import Jinja2Templates from starlette.requests import Request +from starlette.middleware import Middleware templates = Jinja2Templates(directory="templates") @@ -65,8 +67,25 @@ async def delayed_route(request: Request): def sync_route(request): data = {"message": "This is a non-async route!"} return JSONResponse(data) +middleware = [] +if firewall_disabled is not None: + if firewall_disabled.lower() != "1" and (dont_add_middleware is None or dont_add_middleware.lower() != "1"): + # Use DONT_ADD_MIDDLEWARE so we don't add this middleware during e.g. benchmarks. + import aikido_zen + from aikido_zen.middleware import AikidoStarletteMiddleware # Aikido package import + class SetUserMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + aikido_zen.set_user({"id": "user123", "name": "John Doe"}) + return await self.app(scope, receive, send) + middleware.append(Middleware(SetUserMiddleware)) + middleware.append(Middleware(AikidoStarletteMiddleware)) + -app = Starlette(routes=[ + +routes = [ Route("/", homepage), Route("/dogpage/{dog_id:int}", get_dogpage), Route("/create", show_create_dog_form, methods=["GET"]), @@ -74,4 +93,9 @@ def sync_route(request): Route("/sync_route", sync_route), Route("/just", just, methods=["GET"]), Route("/delayed_route", delayed_route, methods=["GET"]) -]) +] +if len(middleware) != 0: + app = Starlette(routes=routes, middleware=middleware) +else: + app = Starlette(routes=routes) +