Skip to content

Commit

Permalink
Fix middleware and add positive caching.
Browse files Browse the repository at this point in the history
  • Loading branch information
jp-agenta committed Dec 12, 2024
1 parent 15c6b85 commit 372b3d7
Show file tree
Hide file tree
Showing 6 changed files with 305 additions and 163 deletions.
125 changes: 58 additions & 67 deletions agenta-cli/agenta/sdk/middleware/auth.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,30 @@
from typing import Callable, Dict, Optional

from os import getenv
from traceback import format_exc
from json import dumps

import httpx
from starlette.middleware.base import BaseHTTPMiddleware
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse


from agenta.sdk.utils.logging import log
from agenta.sdk.middleware.cache import TTLLRUCache
from agenta.sdk.utils.exceptions import display_exception
from agenta.sdk.utils.timing import atimeit

import agenta as ag

_TRUTHY = {"true", "1", "t", "y", "yes", "on", "enable", "enabled"}
_ALLOW_UNAUTHORIZED = (
getenv("AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED", "false").lower() in _TRUTHY
)
_SHARED_SERVICE = getenv("AGENTA_SHARED_SERVICE", "true").lower() in _TRUTHY
_CACHE_ENABLED = getenv("AGENTA_MIDDLEWARE_CACHE_ENABLED", "true").lower() in _TRUTHY

_CACHE_CAPACITY = int(getenv("AGENTA_MIDDLEWARE_CACHE_CAPACITY", "512"))
_CACHE_TTL = int(getenv("AGENTA_MIDDLEWARE_CACHE_TTL", str(5 * 60))) # 5 minutes

_cache = TTLLRUCache(capacity=_CACHE_CAPACITY, ttl=_CACHE_TTL)


class DenyResponse(JSONResponse):
Expand Down Expand Up @@ -49,29 +56,42 @@ def __init__(self, app: FastAPI):
super().__init__(app)

self.host = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.host
self.resource_id = (
ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.service_id
if not _SHARED_SERVICE
else None
)

self.resource_id = None
self.resource_type = None

if ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.service_id:
self.resource_id = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.service_id
self.resource_type = "service"
async def dispatch(self, request: Request, call_next: Callable):
try:
if _ALLOW_UNAUTHORIZED:
request.state.auth = None

elif ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.app_id:
self.resource_id = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.app_id
self.resource_type = "application"
else:
credentials = await self._get_credentials(request)

async def dispatch(
self,
request: Request,
call_next: Callable,
):
print("--- agenta/sdk/middleware/auth.py ---")
request.state.auth = None
request.state.auth = {"credentials": credentials}

if _ALLOW_UNAUTHORIZED:
return await call_next(request)

except DenyException as deny:
display_exception("Auth Middleware Exception")

return DenyResponse(
status_code=deny.status_code,
detail=deny.content,
)

except: # pylint: disable=bare-except
display_exception("Auth Middleware Exception")

return DenyResponse(
status_code=500,
detail="Internal Server Error: auth middleware.",
)

# @atimeit
async def _get_credentials(self, request: Request) -> Optional[str]:
try:
authorization = request.headers.get("authorization", None)

Expand All @@ -90,65 +110,35 @@ async def dispatch(
or request.query_params.get("project_id")
)

params = {
"action": "run_service",
"resource_type": self.resource_type,
"resource_id": self.resource_id,
}
params = {"action": "run_service", "resource_type": "service"}

if self.resource_id:
params["resource_id"] = self.resource_id

if project_id:
params["project_id"] = project_id

print("-----------------------------------")
print(headers)
print(cookies)
print(params)
print("-----------------------------------")

credentials = await self._get_credentials(
params=params,
headers=headers,
cookies=cookies,
_hash = dumps(
{
"headers": headers,
"cookies": cookies,
"params": params,
},
sort_keys=True,
)

request.state.auth = {"credentials": credentials}
if _CACHE_ENABLED:
credentials = _cache.get(_hash)

print(request.state.auth)
if credentials:
return credentials

return await call_next(request)

except DenyException as deny:
display_exception("Auth Middleware Exception")

return DenyResponse(
status_code=deny.status_code,
detail=deny.content,
)

except: # pylint: disable=bare-except
display_exception("Auth Middleware Exception")

return DenyResponse(
status_code=500,
detail="Internal Server Error: auth middleware.",
)

async def _get_credentials(
self,
params: Optional[Dict[str, str]] = None,
headers: Optional[Dict[str, str]] = None,
cookies: Optional[str] = None,
):
if not headers:
raise DenyException(content="Missing 'authorization' header.")

try:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.host}/api/permissions/verify",
headers=headers,
params=params,
cookies=cookies,
params=params,
)

if response.status_code == 401:
Expand All @@ -173,7 +163,8 @@ async def _get_credentials(
)

credentials = auth.get("credentials")
# --- #

_cache.put(_hash, credentials)

return credentials

Expand Down
43 changes: 43 additions & 0 deletions agenta-cli/agenta/sdk/middleware/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from time import time
from collections import OrderedDict


class TTLLRUCache:
def __init__(self, capacity: int, ttl: int):
self.cache = OrderedDict()
self.capacity = capacity
self.ttl = ttl

def get(self, key):
# CACHE
if key not in self.cache:
return None

value, expiry = self.cache[key]
# -----

# TTL
if time() > expiry:
del self.cache[key]

return None
# ---

# LRU
self.cache.move_to_end(key)
# ---

return value

def put(self, key, value):
# CACHE
if key in self.cache:
del self.cache[key]
# CACHE & LRU
elif len(self.cache) >= self.capacity:
self.cache.popitem(last=False)
# -----------

# TTL
self.cache[key] = (value, time() + self.ttl)
# ---
Loading

0 comments on commit 372b3d7

Please sign in to comment.