From d9a5d72176346518b156c75636d3d6bdb0220e73 Mon Sep 17 00:00:00 2001 From: Sylvia McLaughlin <85905333+sylviamclaughlin@users.noreply.github.com> Date: Wed, 5 Jun 2024 16:57:26 +0000 Subject: [PATCH 01/10] Adding rate limiting to the sre-bot --- app/requirements.txt | 1 + app/server/server.py | 29 ++++++++++++++++++++++++----- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/app/requirements.txt b/app/requirements.txt index cb0d8cf8..1cd75b98 100644 --- a/app/requirements.txt +++ b/app/requirements.txt @@ -20,6 +20,7 @@ python-dotenv==0.21.1 python-i18n==0.3.9 pytz==2023.4 requests==2.31.0 +slowapi==0.1.9 schedule==1.2.2 slack-bolt==1.18.1 trello==0.9.7.3 diff --git a/app/server/server.py b/app/server/server.py index ea33bafd..63449d85 100644 --- a/app/server/server.py +++ b/app/server/server.py @@ -13,6 +13,9 @@ from pydantic import BaseModel, Extra from fastapi.templating import Jinja2Templates from fastapi.staticfiles import StaticFiles +from slowapi import Limiter, _rate_limit_exceeded_handler +from slowapi.util import get_remote_address +from slowapi.errors import RateLimitExceeded from models import webhooks from server.utils import log_ops_message from integrations.sentinel import log_to_sentinel @@ -29,7 +32,6 @@ logging.basicConfig(level=logging.INFO) sns_message_validator = SNSMessageValidator() - class WebhookPayload(BaseModel): channel: str | None = None text: str | None = None @@ -68,9 +70,15 @@ class AwsSnsPayload(BaseModel): class Config: extra = Extra.forbid +# initialize the limiter +limiter = Limiter(key_func=get_remote_address) handler = FastAPI() +# add the limiter to the handler +handler.state.limiter = limiter +handler.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) + # Set up the templates directory and static folder for the frontend with the build folder for production if os.path.exists("../frontend/build"): # Sets the templates directory to the React build folder @@ -123,6 +131,7 @@ class Config: # Logout route. If you log out of the application, you will be redirected to the homepage @handler.route("/logout") +@limiter.limit("5/minute") async def logout(request: Request): request.session.pop("user", None) return RedirectResponse(url="/") @@ -130,6 +139,7 @@ async def logout(request: Request): # Login route. You will be redirected to the google login page @handler.get("/login") +@limiter.limit("5/minute") async def login(request: Request): # get the current environment (ie dev or prod) environment = os.environ.get("ENVIRONMENT") @@ -146,6 +156,7 @@ async def login(request: Request): # Authenticate route. This is the route that will be called after the user logs in and you are redirected to the /home page @handler.route("/auth") +@limiter.limit("5/minute") async def auth(request: Request): try: access_token = await oauth.google.authorize_access_token(request) @@ -159,6 +170,7 @@ async def auth(request: Request): # User route. Returns the user's first name that is currently logged into the application @handler.route("/user") +@limiter.limit("5/minute") async def user(request: Request): user = request.session.get("user") if user: @@ -168,7 +180,11 @@ async def user(request: Request): @handler.get("/geolocate/{ip}") -def geolocate(ip): +@limiter.limit("15/minute") +def geolocate(ip, request: Request): + print("Request: ", request) + print("Headers:", request.headers) + print("Query parameters:", request.query_params) reader = maxmind.geolocate(ip) if isinstance(reader, str): raise HTTPException(status_code=404, detail=reader) @@ -183,6 +199,7 @@ def geolocate(ip): @handler.post("/hook/{id}") +@limiter.limit("15/minute") def handle_webhook(id: str, payload: WebhookPayload | str, request: Request): webhook = webhooks.get_webhook(id) if webhook: @@ -291,7 +308,8 @@ def handle_webhook(id: str, payload: WebhookPayload | str, request: Request): @handler.get("/version") -def get_version(): +@limiter.limit("5/minute") +def get_version(request: Request): return {"version": os.environ.get("GIT_SHA", "unknown")} @@ -325,5 +343,6 @@ def append_incident_buttons(payload, webhook_id): # Defines a route handler for `/*` essentially. @handler.get("/{rest_of_path:path}") -async def react_app(req: Request, rest_of_path: str): - return templates.TemplateResponse("index.html", {"request": req}) +@limiter.limit("5/minute") +async def react_app(request: Request, rest_of_path: str): + return templates.TemplateResponse("index.html", {"request": request}) From 2431047316af243749a928573ef7118f3d8a47f6 Mon Sep 17 00:00:00 2001 From: Sylvia McLaughlin <85905333+sylviamclaughlin@users.noreply.github.com> Date: Wed, 5 Jun 2024 21:33:37 +0000 Subject: [PATCH 02/10] Adding function to extract header --- app/server/server.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/app/server/server.py b/app/server/server.py index 63449d85..b74db6ac 100644 --- a/app/server/server.py +++ b/app/server/server.py @@ -79,6 +79,10 @@ class Config: handler.state.limiter = limiter handler.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) + +def sentinel_hader_key(request: Request): + return request.headers.get("X-Sentinel-Auth") + # Set up the templates directory and static folder for the frontend with the build folder for production if os.path.exists("../frontend/build"): # Sets the templates directory to the React build folder From 5d41c6c44e46e643f6979aaa41d888256dd14b58 Mon Sep 17 00:00:00 2001 From: Sylvia McLaughlin <85905333+sylviamclaughlin@users.noreply.github.com> Date: Thu, 6 Jun 2024 04:04:54 +0000 Subject: [PATCH 03/10] Adding rate limiting to the SRE bot API endpoints --- app/server/server.py | 24 ++-- app/tests/server/test_server.py | 192 +++++++++++++++++++++++++++++++- 2 files changed, 207 insertions(+), 9 deletions(-) diff --git a/app/server/server.py b/app/server/server.py index b74db6ac..6818d0a9 100644 --- a/app/server/server.py +++ b/app/server/server.py @@ -80,9 +80,6 @@ class Config: handler.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) -def sentinel_hader_key(request: Request): - return request.headers.get("X-Sentinel-Auth") - # Set up the templates directory and static folder for the frontend with the build folder for production if os.path.exists("../frontend/build"): # Sets the templates directory to the React build folder @@ -132,6 +129,19 @@ def sentinel_hader_key(request: Request): client_kwargs={"scope": "openid email profile"}, ) +def sentinel_key_func(request: Request): + # Check if the 'X-Sentinel-Source' exists and is not empty + if request.headers.get('X-Sentinel-Source'): + return None # Skip rate limiting if the header exists and is not empty + return get_remote_address(request) + + +@handler.exception_handler(RateLimitExceeded) +async def rate_limit_handler(request: Request, exc: RateLimitExceeded): + return JSONResponse( + status_code=429, + content={"message": "Rate limit exceeded"} + ) # Logout route. If you log out of the application, you will be redirected to the homepage @handler.route("/logout") @@ -184,11 +194,8 @@ async def user(request: Request): @handler.get("/geolocate/{ip}") -@limiter.limit("15/minute") +@limiter.limit("10/minute", key_func=sentinel_key_func) def geolocate(ip, request: Request): - print("Request: ", request) - print("Headers:", request.headers) - print("Query parameters:", request.query_params) reader = maxmind.geolocate(ip) if isinstance(reader, str): raise HTTPException(status_code=404, detail=reader) @@ -202,8 +209,9 @@ def geolocate(ip, request: Request): } + @handler.post("/hook/{id}") -@limiter.limit("15/minute") +@limiter.limit("30/minute") # since some slack channels use this for alerting, we want to be generous with the rate limiting on this one def handle_webhook(id: str, payload: WebhookPayload | str, request: Request): webhook = webhooks.get_webhook(id) if webhook: diff --git a/app/tests/server/test_server.py b/app/tests/server/test_server.py index 420435af..1b3bfb97 100644 --- a/app/tests/server/test_server.py +++ b/app/tests/server/test_server.py @@ -1,11 +1,17 @@ from unittest import mock +from unittest.mock import patch, AsyncMock from server import bot_middleware, server import urllib.parse +from slowapi.errors import RateLimitExceeded +from starlette.responses import JSONResponse +from httpx import AsyncClient + import os import pytest from fastapi.testclient import TestClient -from unittest.mock import ANY, call, MagicMock, patch, PropertyMock +from fastapi import Request +from unittest.mock import ANY, call, MagicMock, patch, PropertyMock, Mock app = server.handler app.add_middleware(bot_middleware.BotMiddleware, bot=MagicMock()) @@ -442,3 +448,187 @@ def test_user_endpoint_with_no_logged_in_user(): response = client.get("/user") assert response.status_code == 200 assert response.json() == {"error": "Not logged in"} + + +def test_header_exists_and_not_empty(): + # Create a mock request with the header 'X-Sentinel-Source' + mock_request = Mock(spec=Request) + mock_request.headers = {'X-Sentinel-Source': 'some_value'} + + # Call the function + result = server.sentinel_key_func(mock_request) + + # Assert that the result is None (no rate limiting) + assert result is None + +def test_header_not_present(): + # Create a mock request without the header 'X-Sentinel-Source' + mock_request = Mock(spec=Request) + mock_request.headers = {} + + # Mock the client attribute to return the expected IP address + mock_request.client.host = '192.168.1.1' + + # Mock the get_remote_address function to return a specific value + with patch('slowapi.util.get_remote_address', return_value='192.168.1.1'): + result = server.sentinel_key_func(mock_request) + # Assert that the result is the IP address (rate limiting applied) + assert result == '192.168.1.1' + +def test_header_empty(): + # Create a mock request with an empty 'X-Sentinel-Source' header + mock_request = Mock(spec=Request) + mock_request.headers = {'X-Sentinel-Source': ''} + + # Mock the client attribute to return the expected IP address + mock_request.client.host = '192.168.1.1' + + # Mock the get_remote_address function to return a specific value + with patch('slowapi.util.get_remote_address', return_value='192.168.1.1'): + result = server.sentinel_key_func(mock_request) + + # Assert that the result is the IP address (rate limiting applied) + assert result == '192.168.1.1' + +@pytest.mark.asyncio +async def test_rate_limit_handler(): + # Create a mock request + mock_request = Mock(spec=Request) + + # Create a mock exception + mock_exception = Mock(spec=RateLimitExceeded) + + # Call the handler function + response = await server.rate_limit_handler(mock_request, mock_exception) + + # Assert the response is a JSONResponse + assert isinstance(response, JSONResponse) + + # Assert the status code is 429 + assert response.status_code == 429 + + # Assert the content of the response + assert response.body.decode('utf-8') == '{"message":"Rate limit exceeded"}' + + +@pytest.mark.asyncio +async def logout_test_rate_limiting(): + async with AsyncClient(app=app, base_url="http://test") as client: + # Make 5 requests to the logout endpoint + for _ in range(5): + response = await client.get("/logout") + assert response.status_code == 200 + assert response.url.path == "/" + + # The 6th request should be rate limited + response = await client.get("/logout") + assert response.status_code == 429 + assert response.json() == {"message": "Rate limit exceeded"} + + +@pytest.mark.asyncio +async def login_test_rate_limiting(): + async with AsyncClient(app=app, base_url="http://test") as client: + # Set the environment variable for the test + os.environ["ENVIRONMENT"] = "dev" + + # Make 5 requests to the login endpoint + for _ in range(5): + response = await client.get("/login") + assert response.status_code == 302 + + # The 6th request should be rate limited + response = await client.get("/login") + assert response.status_code == 429 + assert response.json() == {"message": "Rate limit exceeded"} + +@pytest.mark.asyncio +async def auth_test_rate_limiting(): + async with AsyncClient(app=app, base_url="http://test") as client: + # Mock the OAuth process + with patch('oauth.google.authorize_access_token', new_callable=AsyncMock) as mock_auth: + mock_auth.return_value = {"userinfo": {"name": "Test User"}} + + # Make 5 requests to the auth endpoint + for _ in range(5): + response = await client.get("/auth") + assert response.status_code == 200 + + # The 6th request should be rate limited + response = await client.get("/auth") + assert response.status_code == 429 + assert response.json() == {"message": "Rate limit exceeded"} + +@pytest.mark.asyncio +async def user_test_rate_limiting(): + async with AsyncClient(app=app, base_url="http://test") as client: + # Create a mock session with a user + user_data = {"given_name": "John"} + session = {"user": user_data} + + # Mock the request object to include the session + with patch("starlette.requests.Request.session", new_callable=Mock) as mock_session: + mock_session.return_value = session + + # Make 5 requests to the user endpoint + for _ in range(5): + response = await client.get("/user") + assert response.status_code == 200 + assert response.json() == {"name": "John"} + + # The 6th request should be rate limited + response = await client.get("/user") + assert response.status_code == 429 + assert response.json() == {"message": "Rate limit exceeded"} + +@pytest.mark.asyncio +async def geolocate_test_rate_limiting(): + async with AsyncClient(app=app, base_url="http://test") as client: + # Mock the maxmind.geolocate function + with patch('your_module.maxmind.geolocate', return_value=("Country", "City", 12.34, 56.78)): + # Make 10 requests to the geolocate endpoint + for _ in range(10): + response = await client.get("/geolocate/8.8.8.8") + assert response.status_code == 200 + assert response.json() == { + "country": "Country", + "city": "City", + "latitude": 12.34, + "longitude": 56.78, + } + + # The 11th request should be rate limited + response = await client.get("/geolocate/8.8.8.8") + assert response.status_code == 429 + assert response.json() == {"message": "Rate limit exceeded"} + +@pytest.mark.asyncio +async def webhook_test_rate_limiting(): + async with AsyncClient(app=app, base_url="http://test") as client: + # Mock the webhooks.get_webhook function + with patch('your_module.webhooks.get_webhook', return_value={"channel": {"S": "test-channel"}}): + with patch('your_module.webhooks.is_active', return_value=True): + with patch('your_module.webhooks.increment_invocation_count'): + with patch('your_module.sns_message_validator.validate_message'): + # Make 30 requests to the handle_webhook endpoint + for _ in range(30): + response = await client.post("/hook/test-id", json={"Type": "Notification"}) + assert response.status_code == 200 + + # The 31st request should be rate limited + response = await client.post("/hook/test-id", json={"Type": "Notification"}) + assert response.status_code == 429 + assert response.json() == {"message": "Rate limit exceeded"} + +@pytest.mark.asyncio +async def version_test_rate_limiting(): + async with AsyncClient(app=app, base_url="http://test") as client: + # Make 5 requests to the version endpoint + for _ in range(5): + response = await client.get("/version") + assert response.status_code == 200 + + # The 6th request should be rate limited + response = await client.get("/version") + assert response.status_code == 429 + assert response.json() == {"message": "Rate limit exceeded"} \ No newline at end of file From 45f04262a176bb08cf814c0fe6285fa65198946e Mon Sep 17 00:00:00 2001 From: Sylvia McLaughlin <85905333+sylviamclaughlin@users.noreply.github.com> Date: Thu, 6 Jun 2024 04:08:55 +0000 Subject: [PATCH 04/10] Adding a bigger reate limit for the rest of path --- app/server/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/server/server.py b/app/server/server.py index 6818d0a9..39925b50 100644 --- a/app/server/server.py +++ b/app/server/server.py @@ -355,6 +355,6 @@ def append_incident_buttons(payload, webhook_id): # Defines a route handler for `/*` essentially. @handler.get("/{rest_of_path:path}") -@limiter.limit("5/minute") +@limiter.limit("10/minute") async def react_app(request: Request, rest_of_path: str): return templates.TemplateResponse("index.html", {"request": request}) From 09d28b44df7a1642b2d8436e885c2bb8c61a7732 Mon Sep 17 00:00:00 2001 From: Sylvia McLaughlin <85905333+sylviamclaughlin@users.noreply.github.com> Date: Thu, 6 Jun 2024 13:50:55 +0000 Subject: [PATCH 05/10] Increasing rate limiting of version endpoint --- app/server/server.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/app/server/server.py b/app/server/server.py index 39925b50..6bc2870b 100644 --- a/app/server/server.py +++ b/app/server/server.py @@ -318,9 +318,10 @@ def handle_webhook(id: str, payload: WebhookPayload | str, request: Request): else: raise HTTPException(status_code=404, detail="Webhook not found") - +# Route53 uses this as a healthcheck every 30 seconds and the alb uses this as a checkpoint every 10 seconds. +# As a result, we are giving a generous rate limit of so that we don't run into any issues with the healthchecks @handler.get("/version") -@limiter.limit("5/minute") +@limiter.limit("15/minute") def get_version(request: Request): return {"version": os.environ.get("GIT_SHA", "unknown")} From c3aaf4e84541addc468fd146848e0e35ffd71e07 Mon Sep 17 00:00:00 2001 From: Sylvia McLaughlin <85905333+sylviamclaughlin@users.noreply.github.com> Date: Thu, 6 Jun 2024 13:53:57 +0000 Subject: [PATCH 06/10] Reformatting --- app/server/server.py | 17 +++--- app/tests/server/test_server.py | 95 ++++++++++++++++++++------------- 2 files changed, 68 insertions(+), 44 deletions(-) diff --git a/app/server/server.py b/app/server/server.py index 6bc2870b..0c23c04e 100644 --- a/app/server/server.py +++ b/app/server/server.py @@ -32,6 +32,7 @@ logging.basicConfig(level=logging.INFO) sns_message_validator = SNSMessageValidator() + class WebhookPayload(BaseModel): channel: str | None = None text: str | None = None @@ -70,6 +71,7 @@ class AwsSnsPayload(BaseModel): class Config: extra = Extra.forbid + # initialize the limiter limiter = Limiter(key_func=get_remote_address) @@ -129,19 +131,18 @@ class Config: client_kwargs={"scope": "openid email profile"}, ) + def sentinel_key_func(request: Request): # Check if the 'X-Sentinel-Source' exists and is not empty - if request.headers.get('X-Sentinel-Source'): + if request.headers.get("X-Sentinel-Source"): return None # Skip rate limiting if the header exists and is not empty return get_remote_address(request) @handler.exception_handler(RateLimitExceeded) async def rate_limit_handler(request: Request, exc: RateLimitExceeded): - return JSONResponse( - status_code=429, - content={"message": "Rate limit exceeded"} - ) + return JSONResponse(status_code=429, content={"message": "Rate limit exceeded"}) + # Logout route. If you log out of the application, you will be redirected to the homepage @handler.route("/logout") @@ -209,9 +210,10 @@ def geolocate(ip, request: Request): } - @handler.post("/hook/{id}") -@limiter.limit("30/minute") # since some slack channels use this for alerting, we want to be generous with the rate limiting on this one +@limiter.limit( + "30/minute" +) # since some slack channels use this for alerting, we want to be generous with the rate limiting on this one def handle_webhook(id: str, payload: WebhookPayload | str, request: Request): webhook = webhooks.get_webhook(id) if webhook: @@ -318,6 +320,7 @@ def handle_webhook(id: str, payload: WebhookPayload | str, request: Request): else: raise HTTPException(status_code=404, detail="Webhook not found") + # Route53 uses this as a healthcheck every 30 seconds and the alb uses this as a checkpoint every 10 seconds. # As a result, we are giving a generous rate limit of so that we don't run into any issues with the healthchecks @handler.get("/version") diff --git a/app/tests/server/test_server.py b/app/tests/server/test_server.py index 1b3bfb97..278a93d4 100644 --- a/app/tests/server/test_server.py +++ b/app/tests/server/test_server.py @@ -1,5 +1,5 @@ from unittest import mock -from unittest.mock import patch, AsyncMock +from unittest.mock import ANY, call, MagicMock, patch, PropertyMock, Mock, AsyncMock from server import bot_middleware, server import urllib.parse from slowapi.errors import RateLimitExceeded @@ -11,7 +11,6 @@ import pytest from fastapi.testclient import TestClient from fastapi import Request -from unittest.mock import ANY, call, MagicMock, patch, PropertyMock, Mock app = server.handler app.add_middleware(bot_middleware.BotMiddleware, bot=MagicMock()) @@ -453,62 +452,65 @@ def test_user_endpoint_with_no_logged_in_user(): def test_header_exists_and_not_empty(): # Create a mock request with the header 'X-Sentinel-Source' mock_request = Mock(spec=Request) - mock_request.headers = {'X-Sentinel-Source': 'some_value'} - + mock_request.headers = {"X-Sentinel-Source": "some_value"} + # Call the function result = server.sentinel_key_func(mock_request) - + # Assert that the result is None (no rate limiting) assert result is None + def test_header_not_present(): # Create a mock request without the header 'X-Sentinel-Source' mock_request = Mock(spec=Request) mock_request.headers = {} - + # Mock the client attribute to return the expected IP address - mock_request.client.host = '192.168.1.1' + mock_request.client.host = "192.168.1.1" # Mock the get_remote_address function to return a specific value - with patch('slowapi.util.get_remote_address', return_value='192.168.1.1'): + with patch("slowapi.util.get_remote_address", return_value="192.168.1.1"): result = server.sentinel_key_func(mock_request) # Assert that the result is the IP address (rate limiting applied) - assert result == '192.168.1.1' + assert result == "192.168.1.1" + def test_header_empty(): # Create a mock request with an empty 'X-Sentinel-Source' header mock_request = Mock(spec=Request) - mock_request.headers = {'X-Sentinel-Source': ''} - + mock_request.headers = {"X-Sentinel-Source": ""} + # Mock the client attribute to return the expected IP address - mock_request.client.host = '192.168.1.1' - + mock_request.client.host = "192.168.1.1" + # Mock the get_remote_address function to return a specific value - with patch('slowapi.util.get_remote_address', return_value='192.168.1.1'): + with patch("slowapi.util.get_remote_address", return_value="192.168.1.1"): result = server.sentinel_key_func(mock_request) - + # Assert that the result is the IP address (rate limiting applied) - assert result == '192.168.1.1' + assert result == "192.168.1.1" + @pytest.mark.asyncio async def test_rate_limit_handler(): # Create a mock request mock_request = Mock(spec=Request) - + # Create a mock exception mock_exception = Mock(spec=RateLimitExceeded) - + # Call the handler function response = await server.rate_limit_handler(mock_request, mock_exception) - + # Assert the response is a JSONResponse assert isinstance(response, JSONResponse) - + # Assert the status code is 429 assert response.status_code == 429 - + # Assert the content of the response - assert response.body.decode('utf-8') == '{"message":"Rate limit exceeded"}' + assert response.body.decode("utf-8") == '{"message":"Rate limit exceeded"}' @pytest.mark.asyncio @@ -531,24 +533,27 @@ async def login_test_rate_limiting(): async with AsyncClient(app=app, base_url="http://test") as client: # Set the environment variable for the test os.environ["ENVIRONMENT"] = "dev" - + # Make 5 requests to the login endpoint for _ in range(5): response = await client.get("/login") - assert response.status_code == 302 + assert response.status_code == 302 # The 6th request should be rate limited response = await client.get("/login") assert response.status_code == 429 assert response.json() == {"message": "Rate limit exceeded"} + @pytest.mark.asyncio async def auth_test_rate_limiting(): async with AsyncClient(app=app, base_url="http://test") as client: # Mock the OAuth process - with patch('oauth.google.authorize_access_token', new_callable=AsyncMock) as mock_auth: + with patch( + "oauth.google.authorize_access_token", new_callable=AsyncMock + ) as mock_auth: mock_auth.return_value = {"userinfo": {"name": "Test User"}} - + # Make 5 requests to the auth endpoint for _ in range(5): response = await client.get("/auth") @@ -559,15 +564,18 @@ async def auth_test_rate_limiting(): assert response.status_code == 429 assert response.json() == {"message": "Rate limit exceeded"} + @pytest.mark.asyncio async def user_test_rate_limiting(): async with AsyncClient(app=app, base_url="http://test") as client: # Create a mock session with a user user_data = {"given_name": "John"} session = {"user": user_data} - + # Mock the request object to include the session - with patch("starlette.requests.Request.session", new_callable=Mock) as mock_session: + with patch( + "starlette.requests.Request.session", new_callable=Mock + ) as mock_session: mock_session.return_value = session # Make 5 requests to the user endpoint @@ -581,11 +589,15 @@ async def user_test_rate_limiting(): assert response.status_code == 429 assert response.json() == {"message": "Rate limit exceeded"} + @pytest.mark.asyncio async def geolocate_test_rate_limiting(): async with AsyncClient(app=app, base_url="http://test") as client: # Mock the maxmind.geolocate function - with patch('your_module.maxmind.geolocate', return_value=("Country", "City", 12.34, 56.78)): + with patch( + "your_module.maxmind.geolocate", + return_value=("Country", "City", 12.34, 56.78), + ): # Make 10 requests to the geolocate endpoint for _ in range(10): response = await client.get("/geolocate/8.8.8.8") @@ -601,25 +613,34 @@ async def geolocate_test_rate_limiting(): response = await client.get("/geolocate/8.8.8.8") assert response.status_code == 429 assert response.json() == {"message": "Rate limit exceeded"} - + + @pytest.mark.asyncio async def webhook_test_rate_limiting(): async with AsyncClient(app=app, base_url="http://test") as client: # Mock the webhooks.get_webhook function - with patch('your_module.webhooks.get_webhook', return_value={"channel": {"S": "test-channel"}}): - with patch('your_module.webhooks.is_active', return_value=True): - with patch('your_module.webhooks.increment_invocation_count'): - with patch('your_module.sns_message_validator.validate_message'): + with patch( + "your_module.webhooks.get_webhook", + return_value={"channel": {"S": "test-channel"}}, + ): + with patch("your_module.webhooks.is_active", return_value=True): + with patch("your_module.webhooks.increment_invocation_count"): + with patch("your_module.sns_message_validator.validate_message"): # Make 30 requests to the handle_webhook endpoint for _ in range(30): - response = await client.post("/hook/test-id", json={"Type": "Notification"}) + response = await client.post( + "/hook/test-id", json={"Type": "Notification"} + ) assert response.status_code == 200 # The 31st request should be rate limited - response = await client.post("/hook/test-id", json={"Type": "Notification"}) + response = await client.post( + "/hook/test-id", json={"Type": "Notification"} + ) assert response.status_code == 429 assert response.json() == {"message": "Rate limit exceeded"} + @pytest.mark.asyncio async def version_test_rate_limiting(): async with AsyncClient(app=app, base_url="http://test") as client: @@ -631,4 +652,4 @@ async def version_test_rate_limiting(): # The 6th request should be rate limited response = await client.get("/version") assert response.status_code == 429 - assert response.json() == {"message": "Rate limit exceeded"} \ No newline at end of file + assert response.json() == {"message": "Rate limit exceeded"} From e676d6e13f5038599ddbf13562aa070c4d2e3368 Mon Sep 17 00:00:00 2001 From: Sylvia McLaughlin <85905333+sylviamclaughlin@users.noreply.github.com> Date: Thu, 6 Jun 2024 13:55:56 +0000 Subject: [PATCH 07/10] Formatting --- app/server/server.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/app/server/server.py b/app/server/server.py index 0c23c04e..76557509 100644 --- a/app/server/server.py +++ b/app/server/server.py @@ -139,6 +139,7 @@ def sentinel_key_func(request: Request): return get_remote_address(request) +# Rate limit handler for RateLimitExceeded exceptions @handler.exception_handler(RateLimitExceeded) async def rate_limit_handler(request: Request, exc: RateLimitExceeded): return JSONResponse(status_code=429, content={"message": "Rate limit exceeded"}) @@ -194,6 +195,8 @@ async def user(request: Request): return JSONResponse({"error": "Not logged in"}) +# Geolocate route. Returns the country, city, latitude, and longitude of the IP address. +# If we have a custom header of 'X-Sentinel-Source', then we skip rate limiting so that Sentinel is not rate limited @handler.get("/geolocate/{ip}") @limiter.limit("10/minute", key_func=sentinel_key_func) def geolocate(ip, request: Request): From e83f3a77e406d8b674eba88ba943e6edce38a710 Mon Sep 17 00:00:00 2001 From: Sylvia McLaughlin <85905333+sylviamclaughlin@users.noreply.github.com> Date: Thu, 6 Jun 2024 18:28:23 +0000 Subject: [PATCH 08/10] Updating unit tests --- app/tests/server/test_server.py | 123 ++++++++++++++++++++++++-------- 1 file changed, 92 insertions(+), 31 deletions(-) diff --git a/app/tests/server/test_server.py b/app/tests/server/test_server.py index 278a93d4..f3dde12b 100644 --- a/app/tests/server/test_server.py +++ b/app/tests/server/test_server.py @@ -514,13 +514,13 @@ async def test_rate_limit_handler(): @pytest.mark.asyncio -async def logout_test_rate_limiting(): +async def test_logout_rate_limiting(): async with AsyncClient(app=app, base_url="http://test") as client: # Make 5 requests to the logout endpoint for _ in range(5): response = await client.get("/logout") - assert response.status_code == 200 - assert response.url.path == "/" + assert response.status_code == 307 + assert response.url.path == "/logout" # The 6th request should be rate limited response = await client.get("/logout") @@ -529,7 +529,7 @@ async def logout_test_rate_limiting(): @pytest.mark.asyncio -async def login_test_rate_limiting(): +async def test_login_rate_limiting(): async with AsyncClient(app=app, base_url="http://test") as client: # Set the environment variable for the test os.environ["ENVIRONMENT"] = "dev" @@ -546,18 +546,18 @@ async def login_test_rate_limiting(): @pytest.mark.asyncio -async def auth_test_rate_limiting(): +async def test_auth_rate_limiting(): async with AsyncClient(app=app, base_url="http://test") as client: # Mock the OAuth process with patch( - "oauth.google.authorize_access_token", new_callable=AsyncMock + "server.server.oauth.google.authorize_access_token", new_callable=AsyncMock ) as mock_auth: mock_auth.return_value = {"userinfo": {"name": "Test User"}} # Make 5 requests to the auth endpoint for _ in range(5): response = await client.get("/auth") - assert response.status_code == 200 + assert response.status_code == 307 # The 6th request should be rate limited response = await client.get("/auth") @@ -566,36 +566,96 @@ async def auth_test_rate_limiting(): @pytest.mark.asyncio -async def user_test_rate_limiting(): +async def test_user_rate_limiting(): async with AsyncClient(app=app, base_url="http://test") as client: # Create a mock session with a user - user_data = {"given_name": "John"} - session = {"user": user_data} + # user_data = {"given_name": "John"} - # Mock the request object to include the session - with patch( - "starlette.requests.Request.session", new_callable=Mock - ) as mock_session: - mock_session.return_value = session + # session_data = {"user": {"given_name": "FirstName"}} + # headers = {"Cookie": f"session={session_data}"} + # session_data = {"user": {"given_name": "FirstName"}} + # headers = {"Cookie": f"session={session_data}"} + # response = client.get("/user", headers=headers) + # assert response.status_code == 200 - # Make 5 requests to the user endpoint - for _ in range(5): - response = await client.get("/user") - assert response.status_code == 200 - assert response.json() == {"name": "John"} + # Make 5 requests to the user endpoint + for _ in range(5): + session_data = {"user": {"given_name": "FirstName"}} + headers = {"Cookie": f"session={session_data}"} + response = await client.get("/user", headers=headers) + assert response.status_code == 200 + assert response.json() == {'error': 'Not logged in'} # The 6th request should be rate limited response = await client.get("/user") assert response.status_code == 429 assert response.json() == {"message": "Rate limit exceeded"} + # Patch the session middleware to include the user session + # with patch("starlette.middleware.sessions.SessionMiddleware.__call__", new_callable=AsyncMock): + # with patch("starlette.requests.Request.session", new_callable=AsyncMock) as mock_session: + # mock_session.return_value = {"user": user_data} + + # # Make 5 requests to the user endpoint + # for _ in range(5): + # response = await client.get("/user", headers=headers) + # assert response.status_code == 200 + # assert response.json() == {"name": "John"} + + # # The 6th request should be rate limited + # response = await client.get("/user") + # assert response.status_code == 429 + # assert response.json() == {"message": "Rate limit exceeded"} +# @pytest.mark.asyncio +# async def test_user_rate_limiting(): +# async with AsyncClient(app=app, base_url="http://test") as client: +# # Create a mock session with a user +# user_data = {"given_name": "John"} + +# # Mock the session +# with patch("starlette.middleware.sessions.SessionMiddleware.__call__", new_callable=Mock): +# with patch("starlette.requests.Request.session", return_value={"user": user_data}) as mock_session: +# # Make 5 requests to the user endpoint +# for _ in range(5): +# response = await client.get("/user") +# assert response.status_code == 200 +# assert response.json() == {"name": "John"} + +# # The 6th request should be rate limited +# response = await client.get("/user") +# assert response.status_code == 429 +# assert response.json() == {"message": "Rate limit exceeded"} + +# @pytest.mark.asyncio +# async def test_user_rate_limiting(): +# async with AsyncClient(app=app, base_url="http://test") as client: +# # Create a mock session with a user +# user_data = {"given_name": "John"} +# session = {"user": user_data} + +# # Mock the request object to include the session +# with patch( +# "starlette.requests.Request.session", return_value={"user":user_data}) as mock_session: +# mock_session.return_value = session + +# # Make 5 requests to the user endpoint +# for _ in range(5): +# response = await client.get("/user") +# assert response.status_code == 200 +# assert response.json() == {"name": "John"} + +# # The 6th request should be rate limited +# response = await client.get("/user") +# assert response.status_code == 429 +# assert response.json() == {"message": "Rate limit exceeded"} + @pytest.mark.asyncio -async def geolocate_test_rate_limiting(): +async def test_geolocate_rate_limiting(): async with AsyncClient(app=app, base_url="http://test") as client: # Mock the maxmind.geolocate function with patch( - "your_module.maxmind.geolocate", + "server.server.maxmind.geolocate", return_value=("Country", "City", 12.34, 56.78), ): # Make 10 requests to the geolocate endpoint @@ -616,36 +676,37 @@ async def geolocate_test_rate_limiting(): @pytest.mark.asyncio -async def webhook_test_rate_limiting(): +async def test_webhooks_rate_limiting(): async with AsyncClient(app=app, base_url="http://test") as client: # Mock the webhooks.get_webhook function with patch( - "your_module.webhooks.get_webhook", + "server.server.webhooks.get_webhook", return_value={"channel": {"S": "test-channel"}}, ): - with patch("your_module.webhooks.is_active", return_value=True): - with patch("your_module.webhooks.increment_invocation_count"): - with patch("your_module.sns_message_validator.validate_message"): + with patch("server.server.webhooks.is_active", return_value=True): + with patch("server.server.webhooks.increment_invocation_count"): + with patch("server.server.sns_message_validator.validate_message"): # Make 30 requests to the handle_webhook endpoint + payload = '{"Type": "Notification"}' for _ in range(30): response = await client.post( - "/hook/test-id", json={"Type": "Notification"} + "/hook/test-id", json=payload ) assert response.status_code == 200 # The 31st request should be rate limited response = await client.post( - "/hook/test-id", json={"Type": "Notification"} + "/hook/test-id", json=payload ) assert response.status_code == 429 assert response.json() == {"message": "Rate limit exceeded"} @pytest.mark.asyncio -async def version_test_rate_limiting(): +async def test_version_rate_limiting(): async with AsyncClient(app=app, base_url="http://test") as client: # Make 5 requests to the version endpoint - for _ in range(5): + for _ in range(15): response = await client.get("/version") assert response.status_code == 200 From f22e328bef5cd940d0b6166287d47a7768ce38a3 Mon Sep 17 00:00:00 2001 From: Sylvia McLaughlin <85905333+sylviamclaughlin@users.noreply.github.com> Date: Thu, 6 Jun 2024 18:56:31 +0000 Subject: [PATCH 09/10] Updating unit tests --- app/tests/server/test_server.py | 94 ++++----------------------------- 1 file changed, 11 insertions(+), 83 deletions(-) diff --git a/app/tests/server/test_server.py b/app/tests/server/test_server.py index f3dde12b..84964362 100644 --- a/app/tests/server/test_server.py +++ b/app/tests/server/test_server.py @@ -519,7 +519,7 @@ async def test_logout_rate_limiting(): # Make 5 requests to the logout endpoint for _ in range(5): response = await client.get("/logout") - assert response.status_code == 307 + assert response.status_code == 307 assert response.url.path == "/logout" # The 6th request should be rate limited @@ -557,7 +557,7 @@ async def test_auth_rate_limiting(): # Make 5 requests to the auth endpoint for _ in range(5): response = await client.get("/auth") - assert response.status_code == 307 + assert response.status_code == 307 # The 6th request should be rate limited response = await client.get("/auth") @@ -568,86 +568,18 @@ async def test_auth_rate_limiting(): @pytest.mark.asyncio async def test_user_rate_limiting(): async with AsyncClient(app=app, base_url="http://test") as client: - # Create a mock session with a user - # user_data = {"given_name": "John"} - - # session_data = {"user": {"given_name": "FirstName"}} - # headers = {"Cookie": f"session={session_data}"} - # session_data = {"user": {"given_name": "FirstName"}} - # headers = {"Cookie": f"session={session_data}"} - # response = client.get("/user", headers=headers) - # assert response.status_code == 200 - + # Simulate a logged in session + session_data = {"user": {"given_name": "FirstName"}} + headers = {"Cookie": f"session={session_data}"} # Make 5 requests to the user endpoint for _ in range(5): - session_data = {"user": {"given_name": "FirstName"}} - headers = {"Cookie": f"session={session_data}"} response = await client.get("/user", headers=headers) assert response.status_code == 200 - assert response.json() == {'error': 'Not logged in'} - # The 6th request should be rate limited - response = await client.get("/user") - assert response.status_code == 429 - assert response.json() == {"message": "Rate limit exceeded"} - - # Patch the session middleware to include the user session - # with patch("starlette.middleware.sessions.SessionMiddleware.__call__", new_callable=AsyncMock): - # with patch("starlette.requests.Request.session", new_callable=AsyncMock) as mock_session: - # mock_session.return_value = {"user": user_data} - - # # Make 5 requests to the user endpoint - # for _ in range(5): - # response = await client.get("/user", headers=headers) - # assert response.status_code == 200 - # assert response.json() == {"name": "John"} - - # # The 6th request should be rate limited - # response = await client.get("/user") - # assert response.status_code == 429 - # assert response.json() == {"message": "Rate limit exceeded"} -# @pytest.mark.asyncio -# async def test_user_rate_limiting(): -# async with AsyncClient(app=app, base_url="http://test") as client: -# # Create a mock session with a user -# user_data = {"given_name": "John"} - -# # Mock the session -# with patch("starlette.middleware.sessions.SessionMiddleware.__call__", new_callable=Mock): -# with patch("starlette.requests.Request.session", return_value={"user": user_data}) as mock_session: -# # Make 5 requests to the user endpoint -# for _ in range(5): -# response = await client.get("/user") -# assert response.status_code == 200 -# assert response.json() == {"name": "John"} - -# # The 6th request should be rate limited -# response = await client.get("/user") -# assert response.status_code == 429 -# assert response.json() == {"message": "Rate limit exceeded"} - -# @pytest.mark.asyncio -# async def test_user_rate_limiting(): -# async with AsyncClient(app=app, base_url="http://test") as client: -# # Create a mock session with a user -# user_data = {"given_name": "John"} -# session = {"user": user_data} - -# # Mock the request object to include the session -# with patch( -# "starlette.requests.Request.session", return_value={"user":user_data}) as mock_session: -# mock_session.return_value = session - -# # Make 5 requests to the user endpoint -# for _ in range(5): -# response = await client.get("/user") -# assert response.status_code == 200 -# assert response.json() == {"name": "John"} - -# # The 6th request should be rate limited -# response = await client.get("/user") -# assert response.status_code == 429 -# assert response.json() == {"message": "Rate limit exceeded"} + # The 6th request should be rate limited + response = await client.get("/user", headers=headers) + assert response.status_code == 429 + assert response.json() == {"message": "Rate limit exceeded"} @pytest.mark.asyncio @@ -689,15 +621,11 @@ async def test_webhooks_rate_limiting(): # Make 30 requests to the handle_webhook endpoint payload = '{"Type": "Notification"}' for _ in range(30): - response = await client.post( - "/hook/test-id", json=payload - ) + response = await client.post("/hook/test-id", json=payload) assert response.status_code == 200 # The 31st request should be rate limited - response = await client.post( - "/hook/test-id", json=payload - ) + response = await client.post("/hook/test-id", json=payload) assert response.status_code == 429 assert response.json() == {"message": "Rate limit exceeded"} From 74a9f9b200864c5d1d334b4917fdf4ee0f49593f Mon Sep 17 00:00:00 2001 From: Sylvia McLaughlin <85905333+sylviamclaughlin@users.noreply.github.com> Date: Thu, 6 Jun 2024 19:00:55 +0000 Subject: [PATCH 10/10] Adding additional unit test for rate limiting --- app/tests/server/test_server.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/app/tests/server/test_server.py b/app/tests/server/test_server.py index 84964362..1ed52fea 100644 --- a/app/tests/server/test_server.py +++ b/app/tests/server/test_server.py @@ -642,3 +642,17 @@ async def test_version_rate_limiting(): response = await client.get("/version") assert response.status_code == 429 assert response.json() == {"message": "Rate limit exceeded"} + + +@pytest.mark.asyncio +async def test_react_app_rate_limiting(): + async with AsyncClient(app=app, base_url="http://test") as client: + # Make 10 requests to the react_app endpoint + for _ in range(10): + response = await client.get("/some-path") + assert response.status_code == 200 + + # The 11th request should be rate limited + response = await client.get("/some-path") + assert response.status_code == 429 + assert response.json() == {"message": "Rate limit exceeded"}