From 372b3d757cf63f72815227626eee89fe6470f286 Mon Sep 17 00:00:00 2001 From: Juan Pablo Vega Date: Thu, 12 Dec 2024 16:56:07 +0100 Subject: [PATCH] Fix middleware and add positive caching. --- agenta-cli/agenta/sdk/middleware/auth.py | 125 +++++++++---------- agenta-cli/agenta/sdk/middleware/cache.py | 43 +++++++ agenta-cli/agenta/sdk/middleware/config.py | 137 ++++++++++++--------- agenta-cli/agenta/sdk/middleware/otel.py | 32 +++-- agenta-cli/agenta/sdk/middleware/vault.py | 73 ++++++----- agenta-cli/agenta/sdk/utils/timing.py | 58 +++++++++ 6 files changed, 305 insertions(+), 163 deletions(-) create mode 100644 agenta-cli/agenta/sdk/middleware/cache.py create mode 100644 agenta-cli/agenta/sdk/utils/timing.py diff --git a/agenta-cli/agenta/sdk/middleware/auth.py b/agenta-cli/agenta/sdk/middleware/auth.py index 2175616dd0..6b0b9a1bad 100644 --- a/agenta-cli/agenta/sdk/middleware/auth.py +++ b/agenta-cli/agenta/sdk/middleware/auth.py @@ -1,16 +1,16 @@ from typing import Callable, Dict, Optional from os import getenv -from traceback import format_exc +from json import dumps import httpx from starlette.middleware.base import BaseHTTPMiddleware from fastapi import FastAPI, Request from fastapi.responses import JSONResponse - -from agenta.sdk.utils.logging import log +from agenta.sdk.middleware.cache import TTLLRUCache from agenta.sdk.utils.exceptions import display_exception +from agenta.sdk.utils.timing import atimeit import agenta as ag @@ -18,6 +18,13 @@ _ALLOW_UNAUTHORIZED = ( getenv("AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED", "false").lower() in _TRUTHY ) +_SHARED_SERVICE = getenv("AGENTA_SHARED_SERVICE", "true").lower() in _TRUTHY +_CACHE_ENABLED = getenv("AGENTA_MIDDLEWARE_CACHE_ENABLED", "true").lower() in _TRUTHY + +_CACHE_CAPACITY = int(getenv("AGENTA_MIDDLEWARE_CACHE_CAPACITY", "512")) +_CACHE_TTL = int(getenv("AGENTA_MIDDLEWARE_CACHE_TTL", str(5 * 60))) # 5 minutes + +_cache = TTLLRUCache(capacity=_CACHE_CAPACITY, ttl=_CACHE_TTL) class 
async def dispatch(self, request: Request, call_next: Callable):
    """Authenticate the request, then hand it to the next handler.

    When unauthorized execution is allowed, ``request.state.auth`` is set to
    ``None`` and no credentials are looked up. Otherwise the credentials
    returned by ``self._get_credentials`` are stored under
    ``request.state.auth["credentials"]``.

    Returns:
        The downstream response, or a ``DenyResponse`` when a
        ``DenyException`` or any other ``Exception`` is raised.
    """
    try:
        if _ALLOW_UNAUTHORIZED:
            request.state.auth = None
        else:
            credentials = await self._get_credentials(request)

            request.state.auth = {"credentials": credentials}

        return await call_next(request)

    except DenyException as deny:
        display_exception("Auth Middleware Exception")

        return DenyResponse(
            status_code=deny.status_code,
            detail=deny.content,
        )

    # `Exception` rather than a bare `except`: task cancellation
    # (asyncio.CancelledError), KeyboardInterrupt and SystemExit derive from
    # BaseException and must propagate instead of becoming a 500 response.
    except Exception:  # pylint: disable=broad-except
        display_exception("Auth Middleware Exception")

        return DenyResponse(
            status_code=500,
            detail="Internal Server Error: auth middleware.",
        )
class TTLLRUCache:
    """In-memory cache bounded by size (LRU eviction) and by age (TTL).

    Entries are stored in an ``OrderedDict`` as ``(value, expiry)`` pairs,
    ordered from least to most recently used.

    Args:
        capacity: Maximum number of entries kept. A value of zero or less
            disables storage entirely (``put`` becomes a no-op).
        ttl: Lifetime of an entry in seconds, counted from its last ``put``.
    """

    def __init__(self, capacity: int, ttl: int):
        self.cache = OrderedDict()
        self.capacity = capacity
        self.ttl = ttl

    def get(self, key):
        """Return the live value stored under ``key``, or ``None``.

        ``None`` is returned both for unknown keys and for expired entries;
        an expired entry is removed as a side effect. A hit marks the entry
        as most recently used.
        """
        entry = self.cache.get(key)

        if entry is None:
            return None

        value, expiry = entry

        if time() > expiry:
            del self.cache[key]

            return None

        self.cache.move_to_end(key)

        return value

    def put(self, key, value):
        """Store ``value`` under ``key`` with a fresh expiry time.

        Re-inserting an existing key replaces it and makes it the most
        recently used entry. When the cache is full, least recently used
        entries are evicted first.
        """
        # Without this guard a non-positive capacity would call popitem()
        # on an empty OrderedDict and raise KeyError.
        if self.capacity <= 0:
            return

        # Drop any previous entry so the key is re-appended at the MRU end.
        self.cache.pop(key, None)

        # A loop (not a single pop) keeps the bound even if `capacity` was
        # lowered after entries were stored.
        while len(self.cache) >= self.capacity:
            self.cache.popitem(last=False)

        self.cache[key] = (value, time() + self.ttl)
# @atimeit
async def _get_config(
    self, request: Request
) -> Tuple[Optional[Dict], Optional[Dict]]:
    """Resolve the configuration referenced by the request.

    The application / variant / environment references are parsed from the
    request and sent to ``{self.host}/api/variants/configs/fetch``. Results
    are cached (when caching is enabled) under a key built from the
    authorization header and the references.

    Returns:
        A ``(parameters, references)`` pair. Both items are ``None`` when
        the request carries no references, when the backend does not answer
        with HTTP 200, or when it returns an empty payload. ``references``
        maps keys such as ``"application.id"`` or ``"variant.slug"`` to the
        values reported by the backend.
    """
    application_ref = await _parse_application_ref(request)
    variant_ref = await _parse_variant_ref(request)
    environment_ref = await _parse_environment_ref(request)

    auth = request.state.auth or {}

    headers = {
        "Authorization": auth.get("credentials"),
    }

    refs = {}
    if application_ref:
        refs["application_ref"] = application_ref.model_dump()
    if variant_ref:
        refs["variant_ref"] = variant_ref.model_dump()
    if environment_ref:
        refs["environment_ref"] = environment_ref.model_dump()

    if not refs:
        return None, None

    _hash = dumps(
        {
            "headers": headers,
            "refs": refs,
        },
        sort_keys=True,
    )

    if _CACHE_ENABLED:
        config_cache = _cache.get(_hash)

        # Cached entries are always non-empty dicts, including the negative
        # entry {"parameters": None, "references": None}.
        if config_cache is not None:
            return config_cache.get("parameters"), config_cache.get("references")

    async with httpx.AsyncClient() as client:
        response = await client.post(
            f"{self.host}/api/variants/configs/fetch",
            headers=headers,
            json=refs,
        )

    if response.status_code != 200:
        # Always a pair: the caller unpacks two values.
        return None, None

    config = response.json()

    if not config:
        if _CACHE_ENABLED:
            _cache.put(_hash, {"parameters": None, "references": None})

        return None, None

    parameters = config.get("params")

    references = {}

    for ref_key in ["application_ref", "variant_ref", "environment_ref"]:
        # A reference may be absent or null in the payload.
        ref = config.get(ref_key) or {}
        ref_prefix = ref_key.split("_", maxsplit=1)[0]

        for ref_part_key in ["id", "slug", "version"]:
            ref_part = ref.get(ref_part_key)

            if ref_part:
                references[ref_prefix + "." + ref_part_key] = ref_part

    if _CACHE_ENABLED:
        _cache.put(_hash, {"parameters": parameters, "references": references})

    return parameters, references
Callable, Dict +from typing import Callable, Dict, Optional + +from os import getenv +from json import dumps import httpx from starlette.middleware.base import BaseHTTPMiddleware from fastapi import FastAPI, Request +from agenta.sdk.middleware.cache import TTLLRUCache from agenta.sdk.utils.exceptions import suppress +from agenta.sdk.utils.timing import atimeit import agenta as ag +_TRUTHY = {"true", "1", "t", "y", "yes", "on", "enable", "enabled"} +_CACHE_ENABLED = getenv("AGENTA_MIDDLEWARE_CACHE_ENABLED", "true").lower() in _TRUTHY + +_CACHE_CAPACITY = int(getenv("AGENTA_MIDDLEWARE_CACHE_CAPACITY", "512")) +_CACHE_TTL = int(getenv("AGENTA_MIDDLEWARE_CACHE_TTL", str(5 * 60))) # 5 minutes + +_cache = TTLLRUCache(capacity=_CACHE_CAPACITY, ttl=_CACHE_TTL) + class VaultMiddleware(BaseHTTPMiddleware): def __init__(self, app: FastAPI): @@ -20,42 +33,44 @@ async def dispatch( request: Request, call_next: Callable, ): - print("--- agenta/sdk/middleware/vault.py ---") request.state.vault = None with suppress(): - headers = { - "Authorization": request.state.auth.get("credentials"), - } + secrets = await self._get_secrets(request) - secrets = await self._get_secrets( - headers=headers, - ) - - if secrets: - request.state.vault = { - "secrets": secrets, - } - - print(request.state.vault) + request.state.vault = {"secrets": secrets} return await call_next(request) - async def _get_secrets( - self, - headers: Dict[str, str], - ): - try: - async with httpx.AsyncClient() as client: - response = await client.get( - f"{self.host}/api/vault/v1/secrets", - headers=headers, - ) + # @atimeit + async def _get_secrets(self, request: Request) -> Optional[Dict]: + headers = {"Authorization": request.state.auth.get("credentials")} + + _hash = dumps( + { + "headers": headers, + }, + sort_keys=True, + ) - vault = response.json() + if _CACHE_ENABLED: + secrets_cache = _cache.get(_hash) - secrets = vault.get("secrets") + if secrets_cache: + secrets = secrets_cache.get("secrets") return secrets 
def _log_duration(func, args, elapsed):
    """Log how long one call of ``func`` took, scaled to us / ms / s.

    Args:
        func: The wrapped callable; its ``__name__`` appears in the message.
        args: Positional arguments of the call. The class name in the
            message is the type name of ``args[0]`` (``None`` when the call
            had no positional arguments).
        elapsed: Duration of the call in seconds.
    """
    if elapsed < 1e-3:
        time_value, unit = elapsed * 1e6, "us"
    elif elapsed < 1:
        time_value, unit = elapsed * 1e3, "ms"
    else:
        time_value, unit = elapsed, "s"

    class_name = args[0].__class__.__name__ if args else None

    log.info(f"'{class_name}.{func.__name__}' executed in {time_value:.4f} {unit}.")


def timeit(func):
    """Decorate a synchronous callable so each successful call is timed.

    The duration is logged through ``log.info``; the wrapped callable's
    return value is passed through unchanged. Nothing is logged when the
    callable raises.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter() is monotonic, so the measured interval cannot be
        # skewed by system clock adjustments the way time.time() can.
        start = time.perf_counter()
        result = func(*args, **kwargs)
        _log_duration(func, args, time.perf_counter() - start)

        return result

    return wrapper


def atimeit(func):
    """Decorate a coroutine function so each successful await is timed.

    Async counterpart of ``timeit``: the wrapper awaits ``func`` and logs
    the elapsed time, which includes time spent suspended.
    """

    @wraps(func)
    async def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = await func(*args, **kwargs)
        _log_duration(func, args, time.perf_counter() - start)

        return result

    return wrapper