diff --git a/requirements.txt b/requirements.txt index 57b9356..ac2d622 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,3 +16,4 @@ fastapi-auth0==0.3.0 httpx structlog sentry-sdk +slowapi diff --git a/src/gsp.py b/src/gsp.py index 298baff..4b8a4bc 100644 --- a/src/gsp.py +++ b/src/gsp.py @@ -24,7 +24,7 @@ LocationWithGSPYields, OneDatetimeManyForecastValues, ) -from utils import format_datetime +from utils import N_CALLS_PER_HOUR, format_datetime, limiter GSP_TOTAL = 317 @@ -44,6 +44,7 @@ dependencies=[Depends(get_auth_implicit_scheme())], ) @cache_response +@limiter.limit(f"{N_CALLS_PER_HOUR}/hour") def get_all_available_forecasts( request: Request, historic: Optional[bool] = True, @@ -111,6 +112,7 @@ def get_all_available_forecasts( responses={status.HTTP_204_NO_CONTENT: {"model": None}}, ) @cache_response +@limiter.limit(f"{N_CALLS_PER_HOUR}/hour") def get_forecasts_for_a_specific_gsp_old_route( request: Request, gsp_id: int, @@ -135,6 +137,7 @@ def get_forecasts_for_a_specific_gsp_old_route( responses={status.HTTP_204_NO_CONTENT: {"model": None}}, ) @cache_response +@limiter.limit(f"{N_CALLS_PER_HOUR}/hour") def get_forecasts_for_a_specific_gsp( request: Request, gsp_id: int, @@ -203,6 +206,7 @@ def get_forecasts_for_a_specific_gsp( dependencies=[Depends(get_auth_implicit_scheme())], ) @cache_response +@limiter.limit(f"{N_CALLS_PER_HOUR}/hour") def get_truths_for_all_gsps( request: Request, regime: Optional[str] = None, @@ -257,6 +261,7 @@ def get_truths_for_all_gsps( responses={status.HTTP_204_NO_CONTENT: {"model": None}}, ) @cache_response +@limiter.limit(f"{N_CALLS_PER_HOUR}/hour") def get_truths_for_a_specific_gsp_old_route( request: Request, gsp_id: int, @@ -282,6 +287,7 @@ def get_truths_for_a_specific_gsp_old_route( responses={status.HTTP_204_NO_CONTENT: {"model": None}}, ) @cache_response +@limiter.limit(f"{N_CALLS_PER_HOUR}/hour") def get_truths_for_a_specific_gsp( request: Request, gsp_id: int, diff --git a/src/main.py b/src/main.py index d1fd820..2532e77 100644 --- a/src/main.py +++ b/src/main.py @@ -9,13 +9,15 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.openapi.utils import get_openapi from fastapi.responses import FileResponse +from slowapi import _rate_limit_exceeded_handler +from slowapi.errors import RateLimitExceeded from gsp import router as gsp_router from national import router as national_router from redoc_theme import get_redoc_html_with_theme from status import router as status_router from system import router as system_router -from utils import traces_sampler +from utils import limiter, traces_sampler # flake8: noqa E501 @@ -183,6 +185,9 @@ allow_headers=["*"], ) +app.state.limiter = limiter +app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) + @app.middleware("http") async def add_process_time_header(request: Request, call_next): diff --git a/src/national.py b/src/national.py index 3d47e74..92996de 100644 --- a/src/national.py +++ b/src/national.py @@ -16,7 +16,7 @@ get_truth_values_for_a_specific_gsp_from_database, ) from pydantic_models import NationalForecast, NationalForecastValue, NationalYield -from utils import filter_forecast_values, format_datetime, format_plevels +from utils import N_CALLS_PER_HOUR, filter_forecast_values, format_datetime, format_plevels, limiter logger = structlog.stdlib.get_logger() @@ -33,6 +33,7 @@ dependencies=[Depends(get_auth_implicit_scheme())], ) @cache_response +@limiter.limit(f"{N_CALLS_PER_HOUR}/hour") def get_national_forecast( request: Request, session: Session = Depends(get_session), @@ -156,6 +157,7 @@ def get_national_forecast( dependencies=[Depends(get_auth_implicit_scheme())], ) @cache_response +@limiter.limit(f"{N_CALLS_PER_HOUR}/hour") def get_national_pvlive( request: Request, regime: Optional[str] = None, diff --git a/src/status.py b/src/status.py index 336556e..548ccbf 100644 --- a/src/status.py +++ b/src/status.py @@ -10,6 +10,7 @@ from cache import cache_response from database import get_latest_status_from_database, get_session, save_api_call_to_db +from utils import N_CALLS_PER_HOUR, limiter logger = structlog.stdlib.get_logger() @@ -20,6 +21,7 @@ @router.get("/status", response_model=Status) @cache_response +@limiter.limit(f"{N_CALLS_PER_HOUR}/hour") def get_status(request: Request, session: Session = Depends(get_session)) -> Status: """### Get status for the database and forecasts @@ -32,6 +34,7 @@ def get_status(request: Request, session: Session = Depends(get_session)) -> Sta @router.get("/check_last_forecast_run", include_in_schema=False) +@limiter.limit(f"{N_CALLS_PER_HOUR}/hour") def check_last_forecast(request: Request, session: Session = Depends(get_session)) -> datetime: """Check to that a forecast has run with in the last 2 hours""" diff --git a/src/system.py b/src/system.py index 4e9029f..4e6da2a 100644 --- a/src/system.py +++ b/src/system.py @@ -13,6 +13,7 @@ from auth_utils import get_auth_implicit_scheme, get_user from cache import cache_response from database import get_gsp_system, get_session +from utils import N_CALLS_PER_HOUR, limiter # flake8: noqa: E501 logger = structlog.stdlib.get_logger() @@ -43,6 +44,7 @@ def get_gsp_boundaries_from_eso_wgs84() -> gpd.GeoDataFrame: dependencies=[Depends(get_auth_implicit_scheme())], ) @cache_response +@limiter.limit(f"{N_CALLS_PER_HOUR}/hour") def get_gsp_boundaries( request: Request, session: Session = Depends(get_session), @@ -75,6 +77,7 @@ def get_gsp_boundaries( dependencies=[Depends(get_auth_implicit_scheme())], ) @cache_response +@limiter.limit(f"{N_CALLS_PER_HOUR}/hour") def get_system_details( request: Request, session: Session = Depends(get_session), diff --git a/src/utils.py b/src/utils.py index a633574..61b421c 100644 --- a/src/utils.py +++ b/src/utils.py @@ -7,6 +7,8 @@ import structlog from nowcasting_datamodel.models import Forecast from pytz import timezone +from slowapi import Limiter +from slowapi.util import get_remote_address from pydantic_models import NationalForecastValue @@ -14,6 +16,8 @@ europe_london_tz = timezone("Europe/London") utc = timezone("UTC") +limiter = Limiter(key_func=get_remote_address) +N_CALLS_PER_HOUR = 3600 def floor_30_minutes_dt(dt):