Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Rate limiting to the SRE bots endpoints #528

Merged
merged 10 commits into from
Jun 12, 2024
1 change: 1 addition & 0 deletions app/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ python-dotenv==0.21.1
python-i18n==0.3.9
pytz==2023.4
requests==2.31.0
slowapi==0.1.9
schedule==1.2.2
slack-bolt==1.18.1
trello==0.9.7.3
Expand Down
46 changes: 42 additions & 4 deletions app/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from pydantic import BaseModel, Extra
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from models import webhooks
from server.utils import log_ops_message
from integrations.sentinel import log_to_sentinel
Expand Down Expand Up @@ -69,8 +72,16 @@ class Config:
extra = Extra.forbid


# initialize the limiter
limiter = Limiter(key_func=get_remote_address)

handler = FastAPI()

# add the limiter to the handler
handler.state.limiter = limiter
handler.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)


# Set up the templates directory and static folder for the frontend with the build folder for production
if os.path.exists("../frontend/build"):
# Sets the templates directory to the React build folder
Expand Down Expand Up @@ -121,15 +132,30 @@ class Config:
)


def sentinel_key_func(request: Request):
# Check if the 'X-Sentinel-Source' exists and is not empty
if request.headers.get("X-Sentinel-Source"):
return None # Skip rate limiting if the header exists and is not empty
return get_remote_address(request)


# Rate limit handler for RateLimitExceeded exceptions
@handler.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
return JSONResponse(status_code=429, content={"message": "Rate limit exceeded"})


# Logout route. If you log out of the application, you will be redirected to the homepage
@handler.route("/logout")
@limiter.limit("5/minute")
async def logout(request: Request):
request.session.pop("user", None)
return RedirectResponse(url="/")


# Login route. You will be redirected to the google login page
@handler.get("/login")
@limiter.limit("5/minute")
async def login(request: Request):
# get the current environment (ie dev or prod)
environment = os.environ.get("ENVIRONMENT")
Expand All @@ -146,6 +172,7 @@ async def login(request: Request):

# Authenticate route. This is the route that will be called after the user logs in and you are redirected to the /home page
@handler.route("/auth")
@limiter.limit("5/minute")
async def auth(request: Request):
try:
access_token = await oauth.google.authorize_access_token(request)
Expand All @@ -159,6 +186,7 @@ async def auth(request: Request):

# User route. Returns the user's first name that is currently logged into the application
@handler.route("/user")
@limiter.limit("5/minute")
async def user(request: Request):
user = request.session.get("user")
if user:
Expand All @@ -167,8 +195,11 @@ async def user(request: Request):
return JSONResponse({"error": "Not logged in"})


# Geolocate route. Returns the country, city, latitude, and longitude of the IP address.
# If we have a custom header of 'X-Sentinel-Source', then we skip rate limiting so that Sentinel is not rate limited
@handler.get("/geolocate/{ip}")
def geolocate(ip):
@limiter.limit("10/minute", key_func=sentinel_key_func)
def geolocate(ip, request: Request):
reader = maxmind.geolocate(ip)
if isinstance(reader, str):
raise HTTPException(status_code=404, detail=reader)
Expand All @@ -183,6 +214,9 @@ def geolocate(ip):


@handler.post("/hook/{id}")
@limiter.limit(
"30/minute"
) # since some slack channels use this for alerting, we want to be generous with the rate limiting on this one
def handle_webhook(id: str, payload: WebhookPayload | str, request: Request):
webhook = webhooks.get_webhook(id)
if webhook:
Expand Down Expand Up @@ -290,8 +324,11 @@ def handle_webhook(id: str, payload: WebhookPayload | str, request: Request):
raise HTTPException(status_code=404, detail="Webhook not found")


# Route53 uses this as a healthcheck every 30 seconds and the alb uses this as a checkpoint every 10 seconds.
# As a result, we are giving a generous rate limit of so that we don't run into any issues with the healthchecks
@handler.get("/version")
def get_version():
@limiter.limit("15/minute")
def get_version(request: Request):
return {"version": os.environ.get("GIT_SHA", "unknown")}


Expand Down Expand Up @@ -325,5 +362,6 @@ def append_incident_buttons(payload, webhook_id):

# Defines a route handler for `/*` essentially.
@handler.get("/{rest_of_path:path}")
async def react_app(req: Request, rest_of_path: str):
return templates.TemplateResponse("index.html", {"request": req})
@limiter.limit("10/minute")
async def react_app(request: Request, rest_of_path: str):
return templates.TemplateResponse("index.html", {"request": request})
216 changes: 215 additions & 1 deletion app/tests/server/test_server.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from unittest import mock
from unittest.mock import ANY, call, MagicMock, patch, PropertyMock, Mock, AsyncMock
from server import bot_middleware, server
import urllib.parse
from slowapi.errors import RateLimitExceeded
from starlette.responses import JSONResponse
from httpx import AsyncClient


import os
import pytest
from fastapi.testclient import TestClient
from unittest.mock import ANY, call, MagicMock, patch, PropertyMock
from fastapi import Request

app = server.handler
app.add_middleware(bot_middleware.BotMiddleware, bot=MagicMock())
Expand Down Expand Up @@ -442,3 +447,212 @@ def test_user_endpoint_with_no_logged_in_user():
response = client.get("/user")
assert response.status_code == 200
assert response.json() == {"error": "Not logged in"}


def test_header_exists_and_not_empty():
# Create a mock request with the header 'X-Sentinel-Source'
mock_request = Mock(spec=Request)
mock_request.headers = {"X-Sentinel-Source": "some_value"}

# Call the function
result = server.sentinel_key_func(mock_request)

# Assert that the result is None (no rate limiting)
assert result is None


def test_header_not_present():
# Create a mock request without the header 'X-Sentinel-Source'
mock_request = Mock(spec=Request)
mock_request.headers = {}

# Mock the client attribute to return the expected IP address
mock_request.client.host = "192.168.1.1"

# Mock the get_remote_address function to return a specific value
with patch("slowapi.util.get_remote_address", return_value="192.168.1.1"):
result = server.sentinel_key_func(mock_request)
# Assert that the result is the IP address (rate limiting applied)
assert result == "192.168.1.1"


def test_header_empty():
# Create a mock request with an empty 'X-Sentinel-Source' header
mock_request = Mock(spec=Request)
mock_request.headers = {"X-Sentinel-Source": ""}

# Mock the client attribute to return the expected IP address
mock_request.client.host = "192.168.1.1"

# Mock the get_remote_address function to return a specific value
with patch("slowapi.util.get_remote_address", return_value="192.168.1.1"):
result = server.sentinel_key_func(mock_request)

# Assert that the result is the IP address (rate limiting applied)
assert result == "192.168.1.1"


@pytest.mark.asyncio
async def test_rate_limit_handler():
# Create a mock request
mock_request = Mock(spec=Request)

# Create a mock exception
mock_exception = Mock(spec=RateLimitExceeded)

# Call the handler function
response = await server.rate_limit_handler(mock_request, mock_exception)

# Assert the response is a JSONResponse
assert isinstance(response, JSONResponse)

# Assert the status code is 429
assert response.status_code == 429

# Assert the content of the response
assert response.body.decode("utf-8") == '{"message":"Rate limit exceeded"}'


@pytest.mark.asyncio
async def test_logout_rate_limiting():
async with AsyncClient(app=app, base_url="http://test") as client:
# Make 5 requests to the logout endpoint
for _ in range(5):
response = await client.get("/logout")
assert response.status_code == 307
assert response.url.path == "/logout"

# The 6th request should be rate limited
response = await client.get("/logout")
assert response.status_code == 429
assert response.json() == {"message": "Rate limit exceeded"}


@pytest.mark.asyncio
async def test_login_rate_limiting():
async with AsyncClient(app=app, base_url="http://test") as client:
# Set the environment variable for the test
os.environ["ENVIRONMENT"] = "dev"

# Make 5 requests to the login endpoint
for _ in range(5):
response = await client.get("/login")
assert response.status_code == 302

# The 6th request should be rate limited
response = await client.get("/login")
assert response.status_code == 429
assert response.json() == {"message": "Rate limit exceeded"}


@pytest.mark.asyncio
async def test_auth_rate_limiting():
async with AsyncClient(app=app, base_url="http://test") as client:
# Mock the OAuth process
with patch(
"server.server.oauth.google.authorize_access_token", new_callable=AsyncMock
) as mock_auth:
mock_auth.return_value = {"userinfo": {"name": "Test User"}}

# Make 5 requests to the auth endpoint
for _ in range(5):
response = await client.get("/auth")
assert response.status_code == 307

# The 6th request should be rate limited
response = await client.get("/auth")
assert response.status_code == 429
assert response.json() == {"message": "Rate limit exceeded"}


@pytest.mark.asyncio
async def test_user_rate_limiting():
async with AsyncClient(app=app, base_url="http://test") as client:
# Simulate a logged in session
session_data = {"user": {"given_name": "FirstName"}}
headers = {"Cookie": f"session={session_data}"}
# Make 5 requests to the user endpoint
for _ in range(5):
response = await client.get("/user", headers=headers)
assert response.status_code == 200

# The 6th request should be rate limited
response = await client.get("/user", headers=headers)
assert response.status_code == 429
assert response.json() == {"message": "Rate limit exceeded"}


@pytest.mark.asyncio
async def test_geolocate_rate_limiting():
async with AsyncClient(app=app, base_url="http://test") as client:
# Mock the maxmind.geolocate function
with patch(
"server.server.maxmind.geolocate",
return_value=("Country", "City", 12.34, 56.78),
):
# Make 10 requests to the geolocate endpoint
for _ in range(10):
response = await client.get("/geolocate/8.8.8.8")
assert response.status_code == 200
assert response.json() == {
"country": "Country",
"city": "City",
"latitude": 12.34,
"longitude": 56.78,
}

# The 11th request should be rate limited
response = await client.get("/geolocate/8.8.8.8")
assert response.status_code == 429
assert response.json() == {"message": "Rate limit exceeded"}


@pytest.mark.asyncio
async def test_webhooks_rate_limiting():
async with AsyncClient(app=app, base_url="http://test") as client:
# Mock the webhooks.get_webhook function
with patch(
"server.server.webhooks.get_webhook",
return_value={"channel": {"S": "test-channel"}},
):
with patch("server.server.webhooks.is_active", return_value=True):
with patch("server.server.webhooks.increment_invocation_count"):
with patch("server.server.sns_message_validator.validate_message"):
# Make 30 requests to the handle_webhook endpoint
payload = '{"Type": "Notification"}'
for _ in range(30):
response = await client.post("/hook/test-id", json=payload)
assert response.status_code == 200

# The 31st request should be rate limited
response = await client.post("/hook/test-id", json=payload)
assert response.status_code == 429
assert response.json() == {"message": "Rate limit exceeded"}


@pytest.mark.asyncio
async def test_version_rate_limiting():
async with AsyncClient(app=app, base_url="http://test") as client:
# Make 5 requests to the version endpoint
for _ in range(15):
response = await client.get("/version")
assert response.status_code == 200

# The 6th request should be rate limited
response = await client.get("/version")
assert response.status_code == 429
assert response.json() == {"message": "Rate limit exceeded"}


@pytest.mark.asyncio
async def test_react_app_rate_limiting():
async with AsyncClient(app=app, base_url="http://test") as client:
# Make 10 requests to the react_app endpoint
for _ in range(10):
response = await client.get("/some-path")
assert response.status_code == 200

# The 11th request should be rate limited
response = await client.get("/some-path")
assert response.status_code == 429
assert response.json() == {"message": "Rate limit exceeded"}
Loading