From a5e3197cd247c5468d8739ef9de811cd2a1cbc2f Mon Sep 17 00:00:00 2001 From: Juan Pablo Vega Date: Wed, 11 Dec 2024 17:35:27 +0100 Subject: [PATCH 01/11] initial commit --- agenta-cli/agenta/__init__.py | 1 + agenta-cli/agenta/sdk/__init__.py | 1 + agenta-cli/agenta/sdk/agenta_init.py | 14 +- agenta-cli/agenta/sdk/context/exporting.py | 25 + agenta-cli/agenta/sdk/context/routing.py | 24 +- agenta-cli/agenta/sdk/context/tracing.py | 27 +- agenta-cli/agenta/sdk/decorators/routing.py | 577 ++++++++---------- agenta-cli/agenta/sdk/decorators/tracing.py | 88 +-- agenta-cli/agenta/sdk/managers/config.py | 124 +--- agenta-cli/agenta/sdk/managers/vault.py | 25 + agenta-cli/agenta/sdk/middleware/auth.py | 231 ++++--- agenta-cli/agenta/sdk/middleware/cache.py | 43 -- agenta-cli/agenta/sdk/middleware/config.py | 212 +++++++ agenta-cli/agenta/sdk/middleware/cors.py | 27 + agenta-cli/agenta/sdk/middleware/otel.py | 34 ++ agenta-cli/agenta/sdk/middleware/vault.py | 61 ++ agenta-cli/agenta/sdk/tracing/context.py | 24 - agenta-cli/agenta/sdk/tracing/exporters.py | 52 +- agenta-cli/agenta/sdk/tracing/processors.py | 14 +- agenta-cli/agenta/sdk/tracing/tracing.py | 26 +- agenta-cli/agenta/sdk/utils/exceptions.py | 24 +- agenta-cli/tests/baggage/_main.py | 8 + agenta-cli/tests/baggage/agenta | 1 + agenta-cli/tests/baggage/app.py | 20 + agenta-cli/tests/baggage/config.toml | 4 + .../tests/baggage/specs/check_generate.py | 82 +++ .../tests/baggage/specs/check_openapi.py | 51 ++ agenta-cli/tests/run_pytest.sh | 44 ++ agenta-cli/tests/run_tests.sh | 56 ++ agenta-cli/tests/start_server.sh | 40 ++ agenta-cli/tests/stop_server.sh | 45 ++ 31 files changed, 1321 insertions(+), 684 deletions(-) create mode 100644 agenta-cli/agenta/sdk/context/exporting.py create mode 100644 agenta-cli/agenta/sdk/managers/vault.py delete mode 100644 agenta-cli/agenta/sdk/middleware/cache.py create mode 100644 agenta-cli/agenta/sdk/middleware/config.py create mode 100644 agenta-cli/agenta/sdk/middleware/cors.py create mode 100644 agenta-cli/agenta/sdk/middleware/otel.py create mode 100644 agenta-cli/agenta/sdk/middleware/vault.py delete mode 100644 agenta-cli/agenta/sdk/tracing/context.py create mode 100644 agenta-cli/tests/baggage/_main.py create mode 120000 agenta-cli/tests/baggage/agenta create mode 100644 agenta-cli/tests/baggage/app.py create mode 100644 agenta-cli/tests/baggage/config.toml create mode 100644 agenta-cli/tests/baggage/specs/check_generate.py create mode 100644 agenta-cli/tests/baggage/specs/check_openapi.py create mode 100755 agenta-cli/tests/run_pytest.sh create mode 100755 agenta-cli/tests/run_tests.sh create mode 100755 agenta-cli/tests/start_server.sh create mode 100755 agenta-cli/tests/stop_server.sh diff --git a/agenta-cli/agenta/__init__.py b/agenta-cli/agenta/__init__.py index 53c65db70f..53600a4a1e 100644 --- a/agenta-cli/agenta/__init__.py +++ b/agenta-cli/agenta/__init__.py @@ -28,6 +28,7 @@ from .sdk.utils.costs import calculate_token_usage from .sdk.client import Agenta from .sdk.litellm import litellm as callbacks +from .sdk.managers.vault import VaultManager from .sdk.managers.config import ConfigManager from .sdk.managers.variant import VariantManager from .sdk.managers.deployment import DeploymentManager diff --git a/agenta-cli/agenta/sdk/__init__.py b/agenta-cli/agenta/sdk/__init__.py index c1e40757c4..4fc475ef45 100644 --- a/agenta-cli/agenta/sdk/__init__.py +++ b/agenta-cli/agenta/sdk/__init__.py @@ -27,6 +27,7 @@ from .decorators.routing import entrypoint, app, route from .agenta_init import Config, AgentaSingleton, init as _init from .utils.costs import calculate_token_usage +from .managers.vault import VaultManager from .managers.config import ConfigManager from .managers.variant import VariantManager from .managers.deployment import DeploymentManager diff --git a/agenta-cli/agenta/sdk/agenta_init.py b/agenta-cli/agenta/sdk/agenta_init.py index c2180457c2..dbc5734378 100644 --- a/agenta-cli/agenta/sdk/agenta_init.py +++ b/agenta-cli/agenta/sdk/agenta_init.py @@ -59,9 +59,7 @@ def init( ValueError: If `app_id` is not specified either as an argument, in the config file, or in the environment variables. """ - log.info("---------------------------") - log.info("Agenta SDK - using version: %s", version("agenta")) - log.info("---------------------------") + log.info("Agenta - SDK version: %s", version("agenta")) config = {} if config_fname: @@ -86,6 +84,13 @@ def init( self.api_key = api_key or getenv("AGENTA_API_KEY") or config.get("api_key") + self.base_id = getenv("AGENTA_BASE_ID") + + self.service_id = getenv("AGENTA_SERVICE_ID") or self.base_id + + log.info("Agenta - Service ID: %s", self.service_id) + log.info("Agenta - Application ID: %s", self.app_id) + self.tracing = Tracing( url=f"{self.host}/api/observability/v1/otlp/traces", # type: ignore redact=redact, @@ -94,6 +99,7 @@ def init( self.tracing.configure( api_key=self.api_key, + service_id=self.service_id, # DEPRECATING app_id=self.app_id, ) @@ -108,8 +114,6 @@ def init( api_key=self.api_key if self.api_key else "", ) - self.base_id = getenv("AGENTA_BASE_ID") - self.config = Config( host=self.host, base_id=self.base_id, diff --git a/agenta-cli/agenta/sdk/context/exporting.py b/agenta-cli/agenta/sdk/context/exporting.py new file mode 100644 index 0000000000..2fe03a09cd --- /dev/null +++ b/agenta-cli/agenta/sdk/context/exporting.py @@ -0,0 +1,25 @@ +from typing import Optional + +from contextlib import contextmanager +from contextvars import ContextVar + +from pydantic import BaseModel + + +class ExportingContext(BaseModel): + credentials: Optional[str] = None + + +exporting_context = ContextVar("exporting_context", default=ExportingContext()) + + +@contextmanager +def exporting_context_manager( + *, + context: Optional[ExportingContext] = None, +): + token = exporting_context.set(context) + try: + yield + finally: + exporting_context.reset(token) diff --git a/agenta-cli/agenta/sdk/context/routing.py b/agenta-cli/agenta/sdk/context/routing.py index 1d716a69ec..c47ef74712 100644 --- a/agenta-cli/agenta/sdk/context/routing.py +++ b/agenta-cli/agenta/sdk/context/routing.py @@ -1,24 +1,24 @@ +from typing import Any, Dict, Optional + from contextlib import contextmanager from contextvars import ContextVar -from typing import Any, Dict, Optional -routing_context = ContextVar("routing_context", default={}) +from pydantic import BaseModel + + +class RoutingContext(BaseModel): + parameters: Optional[Dict[str, Any]] = None + secrets: Optional[Dict[str, Any]] = None + + +routing_context = ContextVar("routing_context", default=RoutingContext()) @contextmanager def routing_context_manager( *, - config: Optional[Dict[str, Any]] = None, - application: Optional[Dict[str, Any]] = None, - variant: Optional[Dict[str, Any]] = None, - environment: Optional[Dict[str, Any]] = None, + context: Optional[RoutingContext] = None, ): - context = { - "config": config, - "application": application, - "variant": variant, - "environment": environment, - } token = routing_context.set(context) try: yield diff --git a/agenta-cli/agenta/sdk/context/tracing.py b/agenta-cli/agenta/sdk/context/tracing.py index 0585a014ad..3bebe13dc1 100644 --- a/agenta-cli/agenta/sdk/context/tracing.py +++ b/agenta-cli/agenta/sdk/context/tracing.py @@ -1,3 +1,28 @@ +from typing import Any, Dict, Optional + +from contextlib import contextmanager from contextvars import ContextVar -tracing_context = ContextVar("tracing_context", default={}) +from pydantic import BaseModel + + +class TracingContext(BaseModel): + credentials: Optional[str] = None + parameters: Optional[Dict[str, Any]] = None + references: Optional[Dict[str, Any]] = None + link: Optional[Dict[str, Any]] = None + + +tracing_context = ContextVar("tracing_context", default=TracingContext()) + + +@contextmanager +def tracing_context_manager( + *, + context: Optional[TracingContext] = None, +): + token = tracing_context.set(context) + try: + yield + finally: + tracing_context.reset(token) diff --git a/agenta-cli/agenta/sdk/decorators/routing.py b/agenta-cli/agenta/sdk/decorators/routing.py index 6be0c1c309..e7fda7d22c 100644 --- a/agenta-cli/agenta/sdk/decorators/routing.py +++ b/agenta-cli/agenta/sdk/decorators/routing.py @@ -1,7 +1,6 @@ from typing import Type, Any, Callable, Dict, Optional, Tuple, List from annotated_types import Ge, Le, Gt, Lt from pydantic import BaseModel, HttpUrl, ValidationError -from json import dumps from inspect import signature, iscoroutinefunction, Signature, Parameter, _empty from argparse import ArgumentParser from functools import wraps @@ -9,16 +8,27 @@ from traceback import format_exc, format_exception from pathlib import Path from tempfile import NamedTemporaryFile -from os import environ -from fastapi.middleware.cors import CORSMiddleware -from fastapi import Body, FastAPI, UploadFile, HTTPException -from agenta.sdk.middleware.auth import AuthorizationMiddleware -from agenta.sdk.context.routing import routing_context_manager, routing_context -from agenta.sdk.context.tracing import tracing_context +from fastapi import Body, FastAPI, UploadFile, HTTPException, Request + +from agenta.sdk.middleware.auth import AuthMiddleware +from agenta.sdk.middleware.otel import OTelMiddleware +from agenta.sdk.middleware.config import ConfigMiddleware +from agenta.sdk.middleware.vault import VaultMiddleware +from agenta.sdk.middleware.cors import CORSMiddleware + +from agenta.sdk.context.routing import ( + routing_context_manager, + routing_context, + RoutingContext, +) +from agenta.sdk.context.tracing import ( + tracing_context_manager, + tracing_context, + TracingContext, +) from agenta.sdk.router import router -from agenta.sdk.utils import helpers from agenta.sdk.utils.exceptions import suppress from agenta.sdk.utils.logging import log from agenta.sdk.types import ( @@ -39,19 +49,10 @@ import agenta as ag -AGENTA_USE_CORS = str(environ.get("AGENTA_USE_CORS", "true")).lower() in ( - "true", - "1", - "t", -) - app = FastAPI() log.setLevel("DEBUG") -_MIDDLEWARES = True - - app.include_router(router, prefix="") @@ -65,7 +66,11 @@ class route: # the @entrypoint decorator, which has certain limitations. By using @route(), we can create new # routes without altering the main workflow entrypoint. This helps in modularizing the services # and provides flexibility in how we expose different functionalities as APIs. - def __init__(self, path, config_schema: BaseModel): + def __init__( + self, + path: Optional[str] = "/", + config_schema: Optional[BaseModel] = None, + ): self.config_schema: BaseModel = config_schema path = "/" + path.strip("/").strip() path = "" if path == "/" else path @@ -73,9 +78,13 @@ def __init__(self, path, config_schema: BaseModel): self.route_path = path + self.e = None + def __call__(self, f): self.e = entrypoint( - f, route_path=self.route_path, config_schema=self.config_schema + f, + route_path=self.route_path, + config_schema=self.config_schema, ) return f @@ -113,232 +122,202 @@ async def chain_of_prompts_llm(prompt: str): """ routes = list() + _middleware = False + _run_path = "/run" + _test_path = "/test" + # LEGACY + _legacy_playground_run_path = "/playground/run" + _legacy_generate_path = "/generate" + _legacy_generate_deployed_path = "/generate_deployed" def __init__( self, func: Callable[..., Any], - route_path="", + route_path: str = "", config_schema: Optional[BaseModel] = None, ): - ### --- Update Middleware --- # - try: - global _MIDDLEWARES # pylint: disable=global-statement - - if _MIDDLEWARES: - app.add_middleware( - AuthorizationMiddleware, - host=ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.host, - resource_id=ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.app_id, - resource_type="application", - ) + self.func = func + self.route_path = route_path + self.config_schema = config_schema - if AGENTA_USE_CORS: - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_methods=["*"], - allow_headers=["*"], - allow_credentials=True, - ) + signature_parameters = signature(func).parameters + ingestible_files = self.extract_ingestible_files() + config, default_parameters = self.parse_config() - _MIDDLEWARES = False + ### --- Middleware --- # + if not entrypoint._middleware: + entrypoint._middleware = True - except: # pylint: disable=bare-except - log.warning("Agenta SDK - failed to secure route: %s", route_path) - ### --- Update Middleware --- # + app.add_middleware(VaultMiddleware) + app.add_middleware(ConfigMiddleware) + app.add_middleware(AuthMiddleware) + app.add_middleware(OTelMiddleware) + app.add_middleware(CORSMiddleware) + ### ------------------ # - DEFAULT_PATH = "generate" - PLAYGROUND_PATH = "/playground" - RUN_PATH = "/run" - func_signature = signature(func) - try: - config = ( - config_schema() if config_schema else None - ) # we initialize the config object to be able to use it - except ValidationError as e: - raise ValueError( - f"Error initializing config_schema. Please ensure all required fields have default values: {str(e)}" - ) from e - except Exception as e: - raise ValueError( - f"Unexpected error initializing config_schema: {str(e)}" - ) from e - - config_params = config.dict() if config else ag.config.all() - ingestible_files = self.extract_ingestible_files(func_signature) + ### --- Run --- # + @wraps(func) + async def run_wrapper(request: Request, *args, **kwargs) -> Any: + arguments = { + k: v + for k, v in kwargs.items() + if k not in ["config", "environment", "app"] + } - self.route_path = route_path + return await self.execute_wrapper( + request, + False, + *args, + **arguments, + ) - ### --- Playground --- # + self.update_deployed_function_signature( + run_wrapper, + ingestible_files, + ) + + run_route = f"{entrypoint._run_path}{route_path}" + app.post(run_route, response_model=BaseResponse)(run_wrapper) + + # LEGACY + if route_path == "": + run_route = entrypoint._legacy_generate_deployed_path + app.post(run_route, response_model=BaseResponse)(run_wrapper) + # LEGACY + ### ----------- # + + ### --- Test --- # @wraps(func) - async def wrapper(*args, **kwargs) -> Any: - func_params, api_config_params = self.split_kwargs(kwargs, config_params) - self.ingest_files(func_params, ingestible_files) - if not config_schema: - ag.config.set(**api_config_params) - - with routing_context_manager( - config=api_config_params, - ): - entrypoint_result = await self.execute_function( - func, - True, # inline trace: True - *args, - params=func_params, - config_params=config_params, - ) + async def test_wrapper(request: Request, *args, **kwargs) -> Any: + arguments, _ = self.split_kwargs( + kwargs, + default_parameters, + ) - return entrypoint_result + self.ingest_files( + arguments, + ingestible_files, + ) - self.update_function_signature( - wrapper=wrapper, - func_signature=func_signature, + return await self.execute_wrapper( + request, + True, + *args, + **arguments, + ) + + self.update_test_wrapper_signature( + wrapper=test_wrapper, config_class=config, - config_dict=config_params, + config_dict=default_parameters, ingestible_files=ingestible_files, ) - # + test_route = f"{entrypoint._test_path}{route_path}" + app.post(test_route, response_model=BaseResponse)(test_wrapper) + + # LEGACY if route_path == "": - route = f"/{DEFAULT_PATH}" - app.post(route, response_model=BaseResponse)(wrapper) - entrypoint.routes.append( - { - "func": func.__name__, - "endpoint": route, - "params": ( - {**config_params, **func_signature.parameters} - if not config - else func_signature.parameters - ), - "config": config, - } - ) + test_route = entrypoint._legacy_generate_path + app.post(test_route, response_model=BaseResponse)(test_wrapper) + # LEGACY + ### ------------ # - route = f"{PLAYGROUND_PATH}{RUN_PATH}{route_path}" - app.post(route, response_model=BaseResponse)(wrapper) + ### --- OpenAPI --- # + test_route = f"{entrypoint._test_path}{route_path}" entrypoint.routes.append( { "func": func.__name__, - "endpoint": route, + "endpoint": test_route, "params": ( - {**config_params, **func_signature.parameters} + {**default_parameters, **signature_parameters} if not config - else func_signature.parameters + else signature_parameters ), "config": config, } ) - ### ---------------------------- # - ### --- Deployed --- # - @wraps(func) - async def wrapper_deployed(*args, **kwargs) -> Any: - func_params = { - k: v - for k, v in kwargs.items() - if k not in ["config", "environment", "app"] - } - if not config_schema: - if "environment" in kwargs and kwargs["environment"] is not None: - ag.config.pull(environment_name=kwargs["environment"]) - elif "config" in kwargs and kwargs["config"] is not None: - ag.config.pull(config_name=kwargs["config"]) - else: - ag.config.pull(config_name="default") - - app_id = environ.get("AGENTA_APP_ID") - - with routing_context_manager( - application={ - "id": app_id, - "slug": kwargs.get("app"), - }, - variant={ - "slug": kwargs.get("config"), - }, - environment={ - "slug": kwargs.get("environment"), - }, - ): - entrypoint_result = await self.execute_function( - func, - False, # inline trace: False - *args, - params=func_params, - config_params=config_params, - ) - - return entrypoint_result - - self.update_deployed_function_signature( - wrapper_deployed, - func_signature, - ingestible_files, - ) + # LEGACY if route_path == "": - route_deployed = f"/{DEFAULT_PATH}_deployed" - app.post(route_deployed, response_model=BaseResponse)(wrapper_deployed) - - route_deployed = f"{RUN_PATH}{route_path}" - app.post(route_deployed, response_model=BaseResponse)(wrapper_deployed) - ### ---------------- # + test_route = entrypoint._legacy_generate_path + entrypoint.routes.append( + { + "func": func.__name__, + "endpoint": test_route, + "params": ( + {**default_parameters, **signature_parameters} + if not config + else signature_parameters + ), + "config": config, + } + ) + # LEGACY - ### --- Update OpenAPI --- # app.openapi_schema = None # Forces FastAPI to re-generate the schema openapi_schema = app.openapi() - # Inject the current version of the SDK into the openapi_schema - openapi_schema["agenta_sdk"] = {"version": helpers.get_current_version()} - - for route in entrypoint.routes: + for _route in entrypoint.routes: self.override_schema( openapi_schema=openapi_schema, - func_name=route["func"], - endpoint=route["endpoint"], - params=route["params"], + func_name=_route["func"], + endpoint=_route["endpoint"], + params=_route["params"], ) - if route["config"] is not None: # new SDK version + + if _route["config"] is not None: # new SDK version self.override_config_in_schema( openapi_schema=openapi_schema, - func_name=route["func"], - endpoint=route["endpoint"], - config=route["config"], + func_name=_route["func"], + endpoint=_route["endpoint"], + config=_route["config"], ) + ### --------------- # - if self.is_main_script(func) and route_path == "": - self.handle_terminal_run( - func, - func_signature.parameters, # type: ignore - config_params, - ingestible_files, - ) - - def extract_ingestible_files( - self, - func_signature: Signature, - ) -> Dict[str, Parameter]: + def extract_ingestible_files(self) -> Dict[str, Parameter]: """Extract parameters annotated as InFile from function signature.""" return { name: param - for name, param in func_signature.parameters.items() + for name, param in signature(self.func).parameters.items() if param.annotation is InFile } + def parse_config(self) -> Dict[str, Any]: + config = None + default_parameters = ag.config.all() + + if self.config_schema: + try: + config = self.config_schema() if self.config_schema else None + default_parameters = config.dict() if config else default_parameters + except ValidationError as e: + raise ValueError( + f"Error initializing config_schema. Please ensure all required fields have default values: {str(e)}" + ) from e + except Exception as e: + raise ValueError( + f"Unexpected error initializing config_schema: {str(e)}" + ) from e + + return config, default_parameters + def split_kwargs( - self, kwargs: Dict[str, Any], config_params: Dict[str, Any] + self, kwargs: Dict[str, Any], default_parameters: Dict[str, Any] ) -> Tuple[Dict[str, Any], Dict[str, Any]]: """Split keyword arguments into function parameters and API configuration parameters.""" - func_params = {k: v for k, v in kwargs.items() if k not in config_params} - api_config_params = {k: v for k, v in kwargs.items() if k in config_params} + func_params = {k: v for k, v in kwargs.items() if k not in default_parameters} + api_config_params = {k: v for k, v in kwargs.items() if k in default_parameters} + return func_params, api_config_params def ingest_file(self, upfile: UploadFile): temp_file = NamedTemporaryFile(delete=False) temp_file.write(upfile.file.read()) temp_file.close() + return InFile(file_name=upfile.filename, file_path=temp_file.name) def ingest_files( @@ -352,51 +331,80 @@ def ingest_files( if name in func_params and func_params[name] is not None: func_params[name] = self.ingest_file(func_params[name]) - async def execute_function( + async def execute_wrapper( self, - func: Callable[..., Any], - inline_trace, + request: Request, + inline: bool, *args, - **func_params, + **kwargs, ): - log.info("Agenta SDK - handling route: %s", repr(self.route_path or "/")) + if not request: + raise HTTPException(status_code=500, detail="Missing 'request'.") + + state = request.state + credentials = state.auth.get("credentials") if state.auth else None + parameters = state.config.get("parameters") if state.config else None + references = state.config.get("references") if state.config else None + secrets = state.vault.get("secrets") if state.vault else None + + with routing_context_manager( + context=RoutingContext( + parameters=parameters, + secrets=secrets, + ) + ): + with tracing_context_manager( + context=TracingContext( + credentials=credentials, + parameters=parameters, + references=references, + ) + ): + result = await self.execute_function(inline, *args, **kwargs) - tracing_context.set(routing_context.get()) + return result + + async def execute_function( + self, + inline: bool, + *args, + **kwargs, + ): + log.info("Agenta - Handling: '%s'", repr(self.route_path or "/")) try: result = ( - await func(*args, **func_params["params"]) - if iscoroutinefunction(func) - else func(*args, **func_params["params"]) + await self.func(*args, **kwargs) + if iscoroutinefunction(self.func) + else self.func(*args, **kwargs) ) - return await self.handle_success(result, inline_trace) + return await self.handle_success(result, inline) - except Exception as error: + except Exception as error: # pylint: disable=broad-except self.handle_failure(error) - async def handle_success(self, result: Any, inline_trace: bool): + async def handle_success(self, result: Any, inline: bool): data = None tree = None with suppress(): data = self.patch_result(result) - if inline_trace: - tree = await self.fetch_inline_trace(inline_trace) - - log.info(f"----------------------------------") - log.info(f"Agenta SDK - exiting with success: 200") - log.info(f"----------------------------------") + if inline: + tree = await self.fetch_inline_trace(inline) - return BaseResponse(data=data, tree=tree) + try: + return BaseResponse(data=data, tree=tree) + except: + return BaseResponse(data=data) def handle_failure(self, error: Exception): - log.warning("--------------------------------------------------") - log.warning("Agenta SDK - handling application exception below:") - log.warning("--------------------------------------------------") + log.warning("-------------------------------") + log.warning("Agenta - Application Exception:") + log.warning("-------------------------------") log.warning(format_exc().strip("\n")) - log.warning("--------------------------------------------------") + log.warning("-------------------------------") status_code = 500 message = str(error) @@ -442,7 +450,7 @@ def patch_result(self, result: Any): return data - async def fetch_inline_trace(self, inline_trace): + async def fetch_inline_trace(self, inline): WAIT_FOR_SPANS = True TIMEOUT = 1 TIMESTEP = 0.1 @@ -451,12 +459,14 @@ async def fetch_inline_trace(self, inline_trace): trace = None - root_context: Dict[str, Any] = tracing_context.get().get("root") + context = tracing_context.get() + + link = context.link - trace_id = root_context.get("trace_id") if root_context else None + trace_id = link.get("tree_id") if link else None if trace_id is not None: - if inline_trace: + if inline: if WAIT_FOR_SPANS: remaining_steps = NOFSTEPS @@ -476,6 +486,27 @@ async def fetch_inline_trace(self, inline_trace): return trace + # --- OpenAPI --- # + + def add_request_to_signature( + self, + wrapper: Callable[..., Any], + ): + original_sig = signature(wrapper) + parameters = [ + Parameter( + "request", + kind=Parameter.POSITIONAL_OR_KEYWORD, + annotation=Request, + ), + *original_sig.parameters.values(), + ] + new_sig = Signature( + parameters, + return_annotation=original_sig.return_annotation, + ) + wrapper.__signature__ = new_sig + def update_wrapper_signature( self, wrapper: Callable[..., Any], updated_params: List ): @@ -492,10 +523,9 @@ def update_wrapper_signature( wrapper_signature = wrapper_signature.replace(parameters=updated_params) wrapper.__signature__ = wrapper_signature # type: ignore - def update_function_signature( + def update_test_wrapper_signature( self, wrapper: Callable[..., Any], - func_signature: Signature, config_class: Type[BaseModel], # TODO: change to our type config_dict: Dict[str, Any], ingestible_files: Dict[str, Parameter], @@ -507,19 +537,19 @@ def update_function_signature( self.add_config_params_to_parser(updated_params, config_class) else: self.deprecated_add_config_params_to_parser(updated_params, config_dict) - self.add_func_params_to_parser(updated_params, func_signature, ingestible_files) + self.add_func_params_to_parser(updated_params, ingestible_files) self.update_wrapper_signature(wrapper, updated_params) + self.add_request_to_signature(wrapper) def update_deployed_function_signature( self, wrapper: Callable[..., Any], - func_signature: Signature, ingestible_files: Dict[str, Parameter], ) -> None: """Update the function signature to include new parameters.""" updated_params: List[Parameter] = [] - self.add_func_params_to_parser(updated_params, func_signature, ingestible_files) + self.add_func_params_to_parser(updated_params, ingestible_files) for param in [ "config", "environment", @@ -533,6 +563,7 @@ def update_deployed_function_signature( ) ) self.update_wrapper_signature(wrapper, updated_params) + self.add_request_to_signature(wrapper) def add_config_params_to_parser( self, updated_params: list, config_class: Type[BaseModel] @@ -573,11 +604,10 @@ def deprecated_add_config_params_to_parser( def add_func_params_to_parser( self, updated_params: list, - func_signature: Signature, ingestible_files: Dict[str, Parameter], ) -> None: """Add function parameters to function signature.""" - for name, param in func_signature.parameters.items(): + for name, param in signature(self.func).parameters.items(): if name in ingestible_files: updated_params.append( Parameter(name, param.kind, annotation=UploadFile) @@ -599,115 +629,6 @@ def add_func_params_to_parser( ) ) - def is_main_script(self, func: Callable) -> bool: - """ - Check if the script containing the function is the main script being run. - - Args: - func (Callable): The function object to check. - - Returns: - bool: True if the script containing the function is the main script, False otherwise. - - Example: - if is_main_script(my_function): - print("This is the main script.") - """ - return func.__module__ == "__main__" - - def handle_terminal_run( - self, - func: Callable, - func_params: Dict[str, Parameter], - config_params: Dict[str, Any], - ingestible_files: Dict, - ): - """ - Parses command line arguments and sets configuration when script is run from the terminal. - - Args: - func_params (dict): A dictionary containing the function parameters and their annotations. - config_params (dict): A dictionary containing the configuration parameters. - ingestible_files (dict): A dictionary containing the files that should be ingested. - """ - - # For required parameters, we add them as arguments - parser = ArgumentParser() - for name, param in func_params.items(): - if name in ingestible_files: - parser.add_argument(name, type=str) - else: - parser.add_argument(name, type=param.annotation) - - for name, param in config_params.items(): - if type(param) is MultipleChoiceParam: - parser.add_argument( - f"--{name}", - type=str, - default=param.default, - choices=param.choices, # type: ignore - ) - else: - parser.add_argument( - f"--{name}", - type=type(param), - default=param, - ) - - args = parser.parse_args() - - # split the arg list into the arg in the app_param and - # the args from the sig.parameter - args_config_params = {k: v for k, v in vars(args).items() if k in config_params} - args_func_params = { - k: v for k, v in vars(args).items() if k not in config_params - } - for name in ingestible_files: - args_func_params[name] = InFile( - file_name=Path(args_func_params[name]).stem, - file_path=args_func_params[name], - ) - - # Update args_config_params with default values from config_params if not provided in command line arguments - args_config_params.update( - { - key: value - for key, value in config_params.items() - if key not in args_config_params - } - ) - - loop = get_event_loop() - - with routing_context_manager(config=args_config_params): - result = loop.run_until_complete( - self.execute_function( - func, - True, # inline trace: True - **{"params": args_func_params, "config_params": args_config_params}, - ) - ) - - if result.trace: - log.info("\n========= Result =========\n") - - log.info(f"trace_id: {result.trace['trace_id']}") - log.info(f"latency: {result.trace.get('latency')}") - log.info(f"cost: {result.trace.get('cost')}") - log.info(f"usage: {list(result.trace.get('usage', {}).values())}") - - log.info(" ") - log.info("data:") - log.info(dumps(result.data, indent=2)) - - log.info(" ") - log.info("trace:") - log.info("----------------") - log.info(dumps(result.trace.get("spans", []), indent=2)) - log.info("----------------") - - log.info("\n==========================\n") - def override_config_in_schema( self, openapi_schema: dict, diff --git a/agenta-cli/agenta/sdk/decorators/tracing.py b/agenta-cli/agenta/sdk/decorators/tracing.py index 68f707b694..44891355a7 100644 --- a/agenta-cli/agenta/sdk/decorators/tracing.py +++ b/agenta-cli/agenta/sdk/decorators/tracing.py @@ -1,8 +1,12 @@ from typing import Callable, Optional, Any, Dict, List, Union + from functools import wraps from itertools import chain from inspect import iscoroutinefunction, getfullargspec +from opentelemetry import baggage as baggage +from opentelemetry.context import attach, detach + from agenta.sdk.utils.exceptions import suppress from agenta.sdk.context.tracing import tracing_context from agenta.sdk.tracing.conventions import parse_span_kind @@ -39,10 +43,12 @@ def __call__(self, func: Callable[..., Any]): is_coroutine_function = iscoroutinefunction(func) @wraps(func) - async def async_wrapper(*args, **kwargs): - async def _async_auto_instrumented(*args, **kwargs): + async def awrapper(*args, **kwargs): + async def aauto_instrumented(*args, **kwargs): self._parse_type_and_kind() + token = self._attach_baggage() + with ag.tracer.start_as_current_span(func.__name__, kind=self.kind): self._pre_instrument(func, *args, **kwargs) @@ -52,13 +58,17 @@ async def _async_auto_instrumented(*args, **kwargs): return result - return await _async_auto_instrumented(*args, **kwargs) + self._detach_baggage(token) + + return await aauto_instrumented(*args, **kwargs) @wraps(func) - def sync_wrapper(*args, **kwargs): - def _sync_auto_instrumented(*args, **kwargs): + def wrapper(*args, **kwargs): + def auto_instrumented(*args, **kwargs): self._parse_type_and_kind() + token = self._attach_baggage() + with ag.tracer.start_as_current_span(func.__name__, kind=self.kind): self._pre_instrument(func, *args, **kwargs) @@ -68,9 +78,11 @@ def _sync_auto_instrumented(*args, **kwargs): return result - return _sync_auto_instrumented(*args, **kwargs) + self._detach_baggage(token) + + return auto_instrumented(*args, **kwargs) - return async_wrapper if is_coroutine_function else sync_wrapper + return awrapper if is_coroutine_function else wrapper def _parse_type_and_kind(self): if not ag.tracing.get_current_span().is_recording(): @@ -78,6 +90,25 @@ def _parse_type_and_kind(self): self.kind = parse_span_kind(self.type) + def _attach_baggage(self): + context = tracing_context.get() + + references = context.references + + token = None + if references: + for k, v in references: + token = attach(baggage.set_baggage(f"ag.refs.{k}", v)) + + return token + + def _detach_baggage( + self, + token, + ): + if token: + detach(token) + def _pre_instrument( self, func, @@ -86,29 +117,21 @@ def _pre_instrument( ): span = ag.tracing.get_current_span() + context = tracing_context.get() + with suppress(): + trace_id = span.context.trace_id + + ag.tracing.credentials[trace_id] = context.credentials + span.set_attributes( attributes={"node": self.type}, namespace="type", ) if span.parent is None: - rctx = tracing_context.get() - - span.set_attributes( - attributes={"configuration": rctx.get("config", {})}, - namespace="meta", - ) - span.set_attributes( - attributes={"environment": rctx.get("environment", {})}, - namespace="meta", - ) span.set_attributes( - attributes={"version": rctx.get("version", {})}, - namespace="meta", - ) - span.set_attributes( - attributes={"variant": rctx.get("variant", {})}, + attributes={"configuration": context.parameters or {}}, namespace="meta", ) @@ -118,6 +141,7 @@ def _pre_instrument( io=self._parse(func, *args, **kwargs), ignore=self.ignore_inputs, ) + span.set_attributes( attributes={"inputs": _inputs}, namespace="data", @@ -161,6 +185,7 @@ def _post_instrument( io=self._patch(result), ignore=self.ignore_outputs, ) + span.set_attributes( attributes={"outputs": _outputs}, namespace="data", @@ -171,15 +196,12 @@ def _post_instrument( with suppress(): if hasattr(span, "parent") and span.parent is None: - tracing_context.set( - tracing_context.get() - | { - "root": { - "trace_id": span.get_span_context().trace_id, - "span_id": span.get_span_context().span_id, - } - } - ) + context = tracing_context.get() + context.link = { + "tree_id": span.get_span_context().trace_id, + "node_id": span.get_span_context().span_id, + } + tracing_context.set(context) def _parse( self, @@ -224,9 +246,7 @@ def _redact( not in ( ignore if isinstance(ignore, list) - else io.keys() - if ignore is True - else [] + else io.keys() if ignore is True else [] ) } diff --git a/agenta-cli/agenta/sdk/managers/config.py b/agenta-cli/agenta/sdk/managers/config.py index edadbaedc0..f433fa5242 100644 --- a/agenta-cli/agenta/sdk/managers/config.py +++ b/agenta-cli/agenta/sdk/managers/config.py @@ -20,7 +20,7 @@ class ConfigManager: @staticmethod def get_from_route( schema: Optional[Type[T]] = None, - ) -> Union[Dict[str, Any], T]: + ) -> Optional[Union[Dict[str, Any], T]]: """ Retrieves the configuration from the route context and returns a config object. @@ -47,125 +47,15 @@ def get_from_route( context = routing_context.get() - parameters = None - - if "config" in context and context["config"]: - parameters = context["config"] - - else: - app_id: Optional[str] = None - app_slug: Optional[str] = None - variant_id: Optional[str] = None - variant_slug: Optional[str] = None - variant_version: Optional[int] = None - environment_id: Optional[str] = None - environment_slug: Optional[str] = None - environment_version: Optional[int] = None - - if "application" in context: - app_id = context["application"].get("id") - app_slug = context["application"].get("slug") - - if "variant" in context: - variant_id = context["variant"].get("id") - variant_slug = context["variant"].get("slug") - variant_version = context["variant"].get("version") - - if "environment" in context: - environment_id = context["environment"].get("id") - environment_slug = context["environment"].get("slug") - environment_version = context["environment"].get("version") - - parameters = ConfigManager.get_from_registry( - app_id=app_id, - app_slug=app_slug, - variant_id=variant_id, - variant_slug=variant_slug, - variant_version=variant_version, - environment_id=environment_id, - environment_slug=environment_slug, - environment_version=environment_version, - ) + parameters = context.parameters - if schema: - return schema(**parameters) - - return parameters + if not parameters: + return None - @staticmethod - async def aget_from_route( - schema: Optional[Type[T]] = None, - ) -> Union[Dict[str, Any], T]: - """ - Asynchronously retrieves the configuration from the route context and returns a config object. + if not schema: + return parameters - This method checks the route context for configuration information and returns - an instance of the specified schema based on the available context data. - - Args: - schema (Type[T]): A Pydantic model class that defines the structure of the configuration. - - Returns: - T: An instance of the specified schema populated with the configuration data. - - Raises: - ValueError: If conflicting configuration sources are provided or if no valid - configuration source is found in the context. - - Note: - The method prioritizes the inputs in the following way: - 1. 'config' (i.e. when called explicitly from the playground) - 2. 'environment' - 3. 'variant' - Only one of these should be provided. - """ - - context = routing_context.get() - - parameters = None - - if "config" in context and context["config"]: - parameters = context["config"] - - else: - app_id: Optional[str] = None - app_slug: Optional[str] = None - variant_id: Optional[str] = None - variant_slug: Optional[str] = None - variant_version: Optional[int] = None - environment_id: Optional[str] = None - environment_slug: Optional[str] = None - environment_version: Optional[int] = None - - if "application" in context: - app_id = context["application"].get("id") - app_slug = context["application"].get("slug") - - if "variant" in context: - variant_id = context["variant"].get("id") - variant_slug = context["variant"].get("slug") - variant_version = context["variant"].get("version") - - if "environment" in context: - environment_id = context["environment"].get("id") - environment_slug = context["environment"].get("slug") - environment_version = context["environment"].get("version") - - parameters = await ConfigManager.async_get_from_registry( - app_id=app_id, - app_slug=app_slug, - variant_id=variant_id, - variant_slug=variant_slug, - variant_version=variant_version, - environment_id=environment_id, - environment_slug=environment_slug, - environment_version=environment_version, - ) - - if schema: - return schema(**parameters) - - return parameters + return schema(**parameters) @staticmethod def get_from_registry( diff --git a/agenta-cli/agenta/sdk/managers/vault.py b/agenta-cli/agenta/sdk/managers/vault.py new file mode 100644 index 0000000000..bf3c1afe0e --- /dev/null +++ b/agenta-cli/agenta/sdk/managers/vault.py @@ -0,0 +1,25 @@ +from typing import Optional, Type, TypeVar, Dict, Any, Union + +from pydantic import BaseModel + +from agenta.sdk.decorators.routing import routing_context + +T = TypeVar("T", bound=BaseModel) + + +class VaultManager: + @staticmethod + def get_from_route( + schema: Optional[Type[T]] = None, + ) -> Optional[Union[Dict[str, Any], T]]: + context = routing_context.get() + + secrets = context.secrets + + if not secrets: + return None + + if not schema: + return secrets + + return schema(**secrets) diff --git a/agenta-cli/agenta/sdk/middleware/auth.py b/agenta-cli/agenta/sdk/middleware/auth.py index c02e46322a..a0cda77af0 100644 --- a/agenta-cli/agenta/sdk/middleware/auth.py +++ b/agenta-cli/agenta/sdk/middleware/auth.py @@ -1,82 +1,86 @@ -from typing import Callable, Optional -from os import environ -from uuid import UUID -from json import dumps +from typing import Callable, Dict, Optional + +from os import getenv from traceback import format_exc import httpx from starlette.middleware.base import BaseHTTPMiddleware -from fastapi import FastAPI, Request, Response +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse -from agenta.sdk.utils.logging import log -from agenta.sdk.middleware.cache import TTLLRUCache -AGENTA_SDK_AUTH_CACHE_CAPACITY = environ.get( - "AGENTA_SDK_AUTH_CACHE_CAPACITY", - 512, -) +from agenta.sdk.utils.logging import log -AGENTA_SDK_AUTH_CACHE_TTL = environ.get( - "AGENTA_SDK_AUTH_CACHE_TTL", - 15 * 60, # 15 minutes -) +import agenta as ag -AGENTA_SDK_AUTH_CACHE = str(environ.get("AGENTA_SDK_AUTH_CACHE", True)).lower() in ( - "true", - "1", - "t", +_TRUTHY = {"true", "1", "t", "y", "yes", "on", "enable", "enabled"} +_ALLOW_UNAUTHORIZED = ( + getenv("AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED", "false").lower() in _TRUTHY ) -AGENTA_SDK_AUTH_CACHE = False -AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED = str( - environ.get("AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED", False) -).lower() in ("true", "1", "t") +class DenyResponse(JSONResponse): + def __init__( + self, + status_code: int = 401, + detail: str = "Unauthorized", + ) -> None: + super().__init__( + status_code=status_code, + content={"detail": detail}, + ) -class Deny(Response): - def __init__(self) -> None: - super().__init__(status_code=401, content="Unauthorized") +class DenyException(Exception): + def __init__( + self, + status_code: int = 401, + content: str = "Unauthorized", + ) -> None: + super().__init__() -cache = TTLLRUCache( - capacity=AGENTA_SDK_AUTH_CACHE_CAPACITY, - ttl=AGENTA_SDK_AUTH_CACHE_TTL, -) + self.status_code = status_code + self.content = content -class AuthorizationMiddleware(BaseHTTPMiddleware): - def __init__( - self, - app: FastAPI, - host: str, - resource_id: UUID, - resource_type: str, - ): +class AuthMiddleware(BaseHTTPMiddleware): + def __init__(self, app: FastAPI): super().__init__(app) - self.host = host - self.resource_id = resource_id - self.resource_type = resource_type + self.host = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.host + self.resource_id = ( + # STATELESS + ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.service_id + # LEGACY OR STATEFUL + or ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.app_id + ) + self.resource_type = "application" async def dispatch( self, request: Request, call_next: Callable, ): - if AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED: + print("--- agenta/sdk/middleware/auth.py ---") + request.state.auth = None + + if _ALLOW_UNAUTHORIZED: return await call_next(request) try: - authorization = ( - request.headers.get("Authorization") - or request.headers.get("authorization") - or None - ) + authorization = request.headers.get("Authorization", None) headers = {"Authorization": authorization} if authorization else None - cookies = {"sAccessToken": request.cookies.get("sAccessToken")} + baggage = request.state.otel.get("baggage") if request.state.otel else {} + + project_id = ( + # CLEANEST + baggage.get("project_id") + # ALTERNATIVE + or request.query_params.get("project_id") + ) params = { "action": "run_service", @@ -84,62 +88,107 @@ async def dispatch( "resource_id": self.resource_id, } - project_id = request.query_params.get("project_id") - if project_id: params["project_id"] = project_id - _hash = dumps( - { - "headers": headers, - "cookies": cookies, - "params": params, - }, - sort_keys=True, + credentials = await self._get_credentials( + # credentials = await self._mock_get_credentials( + params=params, + headers=headers, ) - policy = None - if AGENTA_SDK_AUTH_CACHE: - policy = cache.get(_hash) - - if not policy: - async with httpx.AsyncClient() as client: - response = await client.get( - f"{self.host}/api/permissions/verify", - headers=headers, - cookies=cookies, - params=params, - ) + request.state.auth = {"credentials": credentials} - if response.status_code != 200: - cache.put(_hash, {"effect": "deny"}) - return Deny() + print(request.state.auth) - auth = response.json() + return await call_next(request) - if auth.get("effect") != "allow": - cache.put(_hash, {"effect": "deny"}) - return Deny() + except DenyException as deny: + log.warning("-----------------------------------") + log.warning("Agenta - Auth Middleware Exception:") + log.warning("-----------------------------------") + log.warning(format_exc().strip("\n")) + log.warning("-----------------------------------") - policy = { - "effect": "allow", - "credentials": auth.get("credentials"), - } + return DenyResponse( + status_code=deny.status_code, + detail=deny.content, + ) - cache.put(_hash, policy) + except: # pylint: disable=bare-except + log.warning("-----------------------------------") + log.warning("Agenta - Auth Middleware Exception:") + log.warning("-----------------------------------") + log.warning(format_exc().strip("\n")) + log.warning("-----------------------------------") - if not policy or policy.get("effect") == "deny": - return Deny() + return DenyResponse( + status_code=500, + detail="Internal Server Error: auth middleware.", + ) - request.state.credentials = policy.get("credentials") + async def _mock_get_credentials( + self, + params: Dict[str, str], + headers: Dict[str, str], + ): + if not headers: + raise DenyException(content="Missing 'authorization' header.") - return await call_next(request) + return headers.get("Authorization") - except: # pylint: disable=bare-except - log.warning("------------------------------------------------------") - log.warning("Agenta SDK - handling auth middleware exception below:") - log.warning("------------------------------------------------------") + async def _get_credentials( + self, + params: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, str]] = None, + ): + if not headers: + raise DenyException(content="Missing 'authorization' header.") + + try: + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.host}/api/permissions/verify", + headers=headers, + params=params, + ) + + if response.status_code == 401: + raise DenyException( + status_code=401, content="Invalid 'authorization' header." + ) + elif response.status_code == 403: + raise DenyException( + status_code=403, content="Service execution not allowed." + ) + elif response.status_code != 200: + raise DenyException( + status_code=400, + content="Internal Server Error: auth middleware.", + ) + + auth = response.json() + + if auth.get("effect") != "allow": + raise DenyException( + status_code=403, content="Service execution not allowed." + ) + + credentials = auth.get("credentials") + # --- # + + return credentials + + except DenyException as deny: + raise deny + + except Exception as exc: # pylint: disable=bare-except + log.warning("------------------------------------------------") + log.warning("Agenta - Auth Middleware Exception (suppressed):") + log.warning("------------------------------------------------") log.warning(format_exc().strip("\n")) - log.warning("------------------------------------------------------") + log.warning("------------------------------------------------") - return Deny() + raise DenyException( + status_code=500, content="Internal Server Error: auth middleware." + ) from exc diff --git a/agenta-cli/agenta/sdk/middleware/cache.py b/agenta-cli/agenta/sdk/middleware/cache.py deleted file mode 100644 index 5445b1fafc..0000000000 --- a/agenta-cli/agenta/sdk/middleware/cache.py +++ /dev/null @@ -1,43 +0,0 @@ -from time import time -from collections import OrderedDict - - -class TTLLRUCache: - def __init__(self, capacity: int, ttl: int): - self.cache = OrderedDict() - self.capacity = capacity - self.ttl = ttl - - def get(self, key): - # CACHE - if key not in self.cache: - return None - - value, expiry = self.cache[key] - # ----- - - # TTL - if time() > expiry: - del self.cache[key] - - return None - # --- - - # LRU - self.cache.move_to_end(key) - # --- - - return value - - def put(self, key, value): - # CACHE - if key in self.cache: - del self.cache[key] - # CACHE & LRU - elif len(self.cache) >= self.capacity: - self.cache.popitem(last=False) - # ----------- - - # TTL - self.cache[key] = (value, time() + self.ttl) - # --- diff --git a/agenta-cli/agenta/sdk/middleware/config.py b/agenta-cli/agenta/sdk/middleware/config.py new file mode 100644 index 0000000000..c91e5cf11f --- /dev/null +++ b/agenta-cli/agenta/sdk/middleware/config.py @@ -0,0 +1,212 @@ +from typing import Callable, Optional, Dict + +from uuid import UUID +from pydantic import BaseModel + +from starlette.middleware.base import BaseHTTPMiddleware +from fastapi import Request, FastAPI + +import httpx + +from agenta.sdk.utils.exceptions import suppress + +import agenta as ag + + +class Reference(BaseModel): + id: Optional[str] = None + slug: Optional[str] = None + version: Optional[str] = None + + +async def _parse_application_ref( + request: Request, +) -> Optional[Reference]: + baggage = request.state.otel.get("baggage") if request.state.otel else {} + + application_id = ( + # CLEANEST + baggage.get("application_id") + # ALTERNATIVE + or request.query_params.get("application_id") + # LEGACY + or request.query_params.get("app_id") + ) + application_slug = ( + # CLEANEST + baggage.get("application_slug") + # ALTERNATIVE + or request.query_params.get("application_slug") + # LEGACY + or request.query_params.get("app_slug") + or request.query_params.get("app") + ) + + if not any([application_id, application_slug]): + return None + + return Reference( + id=application_id, + slug=application_slug, + ) + + +async def _parse_variant_ref( + request: Request, +) -> Optional[Reference]: + baggage = request.state.otel.get("baggage") if request.state.otel else {} + + variant_id = ( + # CLEANEST + baggage.get("variant_id") + # ALTERNATIVE + or request.query_params.get("variant_id") + ) + variant_slug = ( + # CLEANEST + baggage.get("variant_slug") + # ALTERNATIVE + or request.query_params.get("variant_slug") + # LEGACY + or request.query_params.get("config") + ) + variant_version = ( + # CLEANEST + baggage.get("variant_version") + # ALTERNATIVE + or request.query_params.get("variant_version") + ) + + if not any([variant_id, variant_slug, variant_version]): + return None + + return Reference( + id=variant_id, + slug=variant_slug, + version=variant_version, + ) + + +async def _parse_environment_ref( + request: Request, +) -> Optional[Reference]: + baggage = request.state.otel.get("baggage") if request.state.otel else {} + + environment_id = ( + # CLEANEST + baggage.get("environment_id") + # ALTERNATIVE + or request.query_params.get("environment_id") + ) + environment_slug = ( + # CLEANEST + baggage.get("environment_slug") + # ALTERNATIVE + or request.query_params.get("environment_slug") + # LEGACY + or request.query_params.get("environment") + ) + environment_version = ( + # CLEANEST + baggage.get("environment_version") + # ALTERNATIVE + or request.query_params.get("environment_version") + ) + + if not any([environment_id, environment_slug, environment_version]): + return None + + return Reference( + id=environment_id, + slug=environment_slug, + version=environment_version, + ) + + +class ConfigMiddleware(BaseHTTPMiddleware): + def __init__(self, app: FastAPI): + super().__init__(app) + + self.host = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.host + + async def dispatch( + self, + request: Request, + call_next: Callable, + ): + print("--- agenta/sdk/middleware/config.py ---") + request.state.config = None + + with suppress(): + application_ref = await _parse_application_ref(request) + variant_ref = await _parse_variant_ref(request) + environment_ref = await _parse_environment_ref(request) + + auth = request.state.auth or {} + + headers = { + "Authorization": auth.get("credentials"), + } + + refs = {} + if application_ref: + refs["application_ref"] = application_ref.model_dump() + if variant_ref: + refs["variant_ref"] = variant_ref.model_dump() + if environment_ref: + refs["environment_ref"] = environment_ref.model_dump() + + config = await self._get_config( + headers=headers, + refs=refs, + ) + + if config: + parameters = config.get("params") + + references = {} + + for ref_key in ["application_ref", "variant_ref", "environment_ref"]: + refs = config.get(ref_key) + ref_prefix = ref_key.split("_", maxsplit=1)[0] + + for ref_part_key in ["id", "slug", "version"]: + ref_part = refs.get(ref_part_key) + + if ref_part: + references[ref_prefix + "." + ref_part_key] = ref_part + + request.state.config = { + "parameters": parameters, + "references": references, + } + + print(request.state.config) + + return await call_next(request) + + async def _get_config( + self, + headers: Dict[str, str], + refs: Dict[str, str], + ): + if not refs: + return None + + try: + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.host}/api/variants/configs/fetch", + headers=headers, + json=refs, + ) + + if response.status_code != 200: + return None + + config = response.json() + + return config + + except: # pylint: disable=bare-except + return None diff --git a/agenta-cli/agenta/sdk/middleware/cors.py b/agenta-cli/agenta/sdk/middleware/cors.py new file mode 100644 index 0000000000..80f0a30fc5 --- /dev/null +++ b/agenta-cli/agenta/sdk/middleware/cors.py @@ -0,0 +1,27 @@ +from os import getenv + +from starlette.types import ASGIApp, Receive, Scope, Send +from fastapi.middleware.cors import CORSMiddleware as _CORSMiddleware + +_TRUTHY = {"true", "1", "t", "y", "yes", "on", "enable", "enabled"} +_USE_CORS = getenv("AGENTA_USE_CORS", "enable").lower() in _TRUTHY + + +class CORSMiddleware(_CORSMiddleware): + def __init__(self, app: ASGIApp): + if _USE_CORS: + super().__init__( + app=app, + allow_origins=["*"], + allow_methods=["*"], + allow_headers=["*"], + allow_credentials=True, + expose_headers=None, + max_age=None, + ) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if _USE_CORS: + return await super().__call__(scope, receive, send) + + return await self.app(scope, receive, send) diff --git a/agenta-cli/agenta/sdk/middleware/otel.py b/agenta-cli/agenta/sdk/middleware/otel.py new file mode 100644 index 0000000000..1e7352e82d --- /dev/null +++ b/agenta-cli/agenta/sdk/middleware/otel.py @@ -0,0 +1,34 @@ +from typing import Callable + +from starlette.middleware.base import BaseHTTPMiddleware +from fastapi import Request, FastAPI + +from agenta.sdk.utils.exceptions import suppress + +from opentelemetry.baggage.propagation import W3CBaggagePropagator + + +class OTelMiddleware(BaseHTTPMiddleware): + def __init__(self, app: FastAPI): + super().__init__(app) + + async def dispatch( + self, + request: Request, + call_next: Callable, + ): + request.state.otel = None + + with suppress(): + baggage = {"baggage": request.headers.get("Baggage", "")} + + context = W3CBaggagePropagator().extract(baggage) + + if context: + request.state.otel = {"baggage": {}} + + for _, partial in context.values(): + for key, value in partial.items(): + request.state.otel["baggage"][key] = value + + return await call_next(request) diff --git a/agenta-cli/agenta/sdk/middleware/vault.py b/agenta-cli/agenta/sdk/middleware/vault.py new file mode 100644 index 0000000000..ec7f73165a --- /dev/null +++ b/agenta-cli/agenta/sdk/middleware/vault.py @@ -0,0 +1,61 @@ +from typing import Callable, Dict + +import httpx +from starlette.middleware.base import BaseHTTPMiddleware +from fastapi import FastAPI, Request + +from agenta.sdk.utils.exceptions import suppress + +import agenta as ag + + +class VaultMiddleware(BaseHTTPMiddleware): + def __init__(self, app: FastAPI): + super().__init__(app) + + self.host = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.host + + async def dispatch( + self, + request: Request, + call_next: Callable, + ): + print("--- agenta/sdk/middleware/vault.py ---") + request.state.vault = None + + with suppress(): + headers = { + "Authorization": request.state.auth.get("credentials"), + } + + secrets = await self._get_secrets( + headers=headers, + ) + + if secrets: + request.state.vault = { + "secrets": secrets, + } + + print(request.state.vault) + + return await call_next(request) + + async def _get_secrets( + self, + headers: Dict[str, str], + ): + try: + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.host}/api/vault/v1/secrets", + headers=headers, + ) + + vault = response.json() + + secrets = vault.get("secrets") + + return secrets + except: # pylint: disable=bare-except + return None diff --git a/agenta-cli/agenta/sdk/tracing/context.py b/agenta-cli/agenta/sdk/tracing/context.py deleted file mode 100644 index 23925db01d..0000000000 --- a/agenta-cli/agenta/sdk/tracing/context.py +++ /dev/null @@ -1,24 +0,0 @@ -from contextvars import ContextVar -from contextlib import contextmanager -from traceback import format_exc - -from agenta.sdk.utils.logging import log - -tracing_context = ContextVar("tracing_context", default={}) - - -@contextmanager -def tracing_context_manager(): - _tracing_context = {"health": {"status": "ok"}} - - token = tracing_context.set(_tracing_context) - try: - yield - except: # pylint: disable=bare-except - log.warning("----------------------------------------------") - log.warning("Agenta SDK - handling tracing exception below:") - log.warning("----------------------------------------------") - log.warning(format_exc().strip("\n")) - log.warning("----------------------------------------------") - finally: - tracing_context.reset(token) diff --git a/agenta-cli/agenta/sdk/tracing/exporters.py b/agenta-cli/agenta/sdk/tracing/exporters.py index 62f03a10b5..c713811eca 100644 --- a/agenta-cli/agenta/sdk/tracing/exporters.py +++ b/agenta-cli/agenta/sdk/tracing/exporters.py @@ -9,6 +9,11 @@ ) from agenta.sdk.utils.exceptions import suppress +from agenta.sdk.context.exporting import ( + exporting_context_manager, + exporting_context, + ExportingContext, +) class InlineTraceExporter(SpanExporter): @@ -58,8 +63,51 @@ def fetch( return trace -OTLPSpanExporter._MAX_RETRY_TIMEOUT = 2 # pylint: disable=protected-access +class OTLPExporter(OTLPSpanExporter): + _MAX_RETRY_TIMEOUT = 2 + + def __init__(self, *args, credentials: Dict[int, str] = None, **kwargs): + super().__init__(*args, **kwargs) + + self.credentials = credentials + + def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult: + credentials = None + + # --- DEBUG + for span in spans: + print(span.name, span.attributes) + # --------- + + if self.credentials: + trace_ids = set(span.get_span_context().trace_id for span in spans) + + if len(trace_ids) == 1: + trace_id = trace_ids.pop() + + if trace_id in self.credentials: + credentials = self.credentials.pop(trace_id) + + with exporting_context_manager( + context=ExportingContext( + credentials=credentials, + ) + ): + return super().export(spans) + + def _export(self, serialized_data: bytes): + credentials = exporting_context.get().credentials + + if credentials: + self._session.headers.update({"Authorization": credentials}) + + # --- DEBUG + auth = {"Authorization": self._session.headers.get("Authorization")} + print(" ", auth) + # --------- + + return super()._export(serialized_data) + ConsoleExporter = ConsoleSpanExporter InlineExporter = InlineTraceExporter -OTLPExporter = OTLPSpanExporter diff --git a/agenta-cli/agenta/sdk/tracing/processors.py b/agenta-cli/agenta/sdk/tracing/processors.py index b5d04d8085..2c612220cc 100644 --- a/agenta-cli/agenta/sdk/tracing/processors.py +++ b/agenta-cli/agenta/sdk/tracing/processors.py @@ -1,5 +1,6 @@ from typing import Optional, Dict, List +from opentelemetry.baggage import get_all as get_baggage from opentelemetry.context import Context from opentelemetry.sdk.trace import Span from opentelemetry.sdk.trace.export import ( @@ -11,8 +12,7 @@ ) from agenta.sdk.utils.logging import log - -# LOAD CONTEXT, HERE ! +from agenta.sdk.tracing.conventions import Reference class TraceProcessor(BatchSpanProcessor): @@ -43,9 +43,17 @@ def on_start( span: Span, parent_context: Optional[Context] = None, ) -> None: + baggage = get_baggage(parent_context) + for key in self.references.keys(): span.set_attribute(f"ag.refs.{key}", self.references[key]) + for key in baggage.keys(): + if key.startswith("ag.refs."): + _key = key.replace("ag.refs.", "") + if _key in [_.value for _ in Reference.__members__.values()]: + span.set_attribute(key, baggage[key]) + if span.context.trace_id not in self._registry: self._registry[span.context.trace_id] = dict() @@ -89,7 +97,7 @@ def force_flush( ret = super().force_flush(timeout_millis) if not ret: - log.warning("Agenta SDK - skipping export due to timeout.") + log.warning("Agenta - Skipping export due to timeout.") def is_ready( self, diff --git a/agenta-cli/agenta/sdk/tracing/tracing.py b/agenta-cli/agenta/sdk/tracing/tracing.py index 809c864936..0e92bb9d19 100644 --- a/agenta-cli/agenta/sdk/tracing/tracing.py +++ b/agenta-cli/agenta/sdk/tracing/tracing.py @@ -41,6 +41,8 @@ def __init__( self.headers: Dict[str, str] = dict() # REFERENCES self.references: Dict[str, str] = dict() + # CREDENTIALS + self.credentials: Dict[int, str] = dict() # TRACER PROVIDER self.tracer_provider: Optional[TracerProvider] = None @@ -60,13 +62,16 @@ def __init__( def configure( self, api_key: Optional[str] = None, + service_id: Optional[str] = None, # DEPRECATING app_id: Optional[str] = None, ): # HEADERS (OTLP) if api_key: - self.headers["Authorization"] = api_key + self.headers["Authorization"] = f"ApiKey {api_key}" # REFERENCES + if service_id: + self.references["service.id"] = service_id if app_id: self.references["application.id"] = app_id @@ -84,31 +89,28 @@ def configure( self.tracer_provider.add_span_processor(self.inline) # TRACE PROCESSORS -- OTLP try: - log.info("--------------------------------------------") log.info( - "Agenta SDK - connecting to otlp receiver at: %s", + "Agenta - OLTP URL: %s", self.otlp_url, ) - log.info("--------------------------------------------") - check( - self.otlp_url, - headers=self.headers, - timeout=1, - ) + # check( + # self.otlp_url, + # headers=self.headers, + # timeout=1, + # ) _otlp = TraceProcessor( OTLPExporter( endpoint=self.otlp_url, headers=self.headers, + credentials=self.credentials, ), references=self.references, ) self.tracer_provider.add_span_processor(_otlp) - log.info("Success: traces will be exported.") - log.info("--------------------------------------------") except: # pylint: disable=bare-except - log.warning("Agenta SDK - traces will not be exported.") + log.warning("Agenta - OLTP unreachable, skipping exports.") # GLOBAL TRACER PROVIDER -- INSTRUMENTATION LIBRARIES set_tracer_provider(self.tracer_provider) diff --git a/agenta-cli/agenta/sdk/utils/exceptions.py b/agenta-cli/agenta/sdk/utils/exceptions.py index a451b1de78..2376b6f912 100644 --- a/agenta-cli/agenta/sdk/utils/exceptions.py +++ b/agenta-cli/agenta/sdk/utils/exceptions.py @@ -17,11 +17,11 @@ def __exit__(self, exc_type, exc_value, exc_tb): if exc_type is None: return True else: - log.warning("-------------------------------------------------") - log.warning("Agenta SDK - suppressing tracing exception below:") - log.warning("-------------------------------------------------") + log.warning("--------------------------------") + log.warning("Agenta - Exception (suppressed):") + log.warning("--------------------------------") log.warning(format_exc().strip("\n")) - log.warning("-------------------------------------------------") + log.warning("--------------------------------") return True @@ -34,11 +34,11 @@ async def async_wrapper(*args, **kwargs): try: return await func(*args, **kwargs) except Exception as e: - log.warning("------------------------------------------") - log.warning("Agenta SDK - intercepting exception below:") - log.warning("------------------------------------------") + log.warning("-------------------") + log.warning("Agenta - Exception:") + log.warning("-------------------") log.warning(format_exc().strip("\n")) - log.warning("------------------------------------------") + log.warning("-------------------") raise e @wraps(func) @@ -46,11 +46,11 @@ def sync_wrapper(*args, **kwargs): try: return func(*args, **kwargs) except Exception as e: - log.warning("------------------------------------------") - log.warning("Agenta SDK - intercepting exception below:") - log.warning("------------------------------------------") + log.warning("-------------------") + log.warning("Agenta - Exception:") + log.warning("-------------------") log.warning(format_exc().strip("\n")) - log.warning("------------------------------------------") + log.warning("-------------------") raise e return async_wrapper if is_coroutine_function else sync_wrapper diff --git a/agenta-cli/tests/baggage/_main.py b/agenta-cli/tests/baggage/_main.py new file mode 100644 index 0000000000..4040d2adb5 --- /dev/null +++ b/agenta-cli/tests/baggage/_main.py @@ -0,0 +1,8 @@ +from uvicorn import run + +import app # pylint: disable=unused-import + +import agenta # pylint: disable=unused-import + +if __name__ == "__main__": + run("agenta:app", host="0.0.0.0", port=8888, reload=True) diff --git a/agenta-cli/tests/baggage/agenta b/agenta-cli/tests/baggage/agenta new file mode 120000 index 0000000000..d77f00d8ae --- /dev/null +++ b/agenta-cli/tests/baggage/agenta @@ -0,0 +1 @@ +/Users/junaway/Agenta/github/agenta-sandbox/baggage/agenta \ No newline at end of file diff --git a/agenta-cli/tests/baggage/app.py b/agenta-cli/tests/baggage/app.py new file mode 100644 index 0000000000..166ab8c8c9 --- /dev/null +++ b/agenta-cli/tests/baggage/app.py @@ -0,0 +1,20 @@ +import agenta as ag + +ag.init(config_fname="config.toml") + +ag.config.default( + flag=ag.BinaryParam(value=False), +) + +# XELnjVve.c1f177c87250b603cf1ed2a69ebdfc1cec3124776058e7afcbba93890c515e74 + + +@ag.route() +@ag.instrument() +def main(aloha: str = "Aloha") -> str: + + print(ag.ConfigManager.get_from_route()) + print(ag.VaultManager.get_from_route()) + print(ag.config.flag) + + return aloha diff --git a/agenta-cli/tests/baggage/config.toml b/agenta-cli/tests/baggage/config.toml new file mode 100644 index 0000000000..f32346649b --- /dev/null +++ b/agenta-cli/tests/baggage/config.toml @@ -0,0 +1,4 @@ +app_name = "baggage" +app_id = "0193b67a-b673-7919-85c2-0b5b0a2183d3" +backend_host = "http://localhost" +api_key = "XELnjVve.c1f177c87250b603cf1ed2a69ebdfc1cec3124776058e7afcbba93890c515e74" diff --git a/agenta-cli/tests/baggage/specs/check_generate.py b/agenta-cli/tests/baggage/specs/check_generate.py new file mode 100644 index 0000000000..e07277b74e --- /dev/null +++ b/agenta-cli/tests/baggage/specs/check_generate.py @@ -0,0 +1,82 @@ +import pytest +import httpx +import os + +BASE_URL = os.getenv("BASE_URL", None) or None +API_KEY = os.getenv("API_KEY", None) or None + +# 200 +# 401 +# 403 +# 405 +# 422 +# 500 + + +def test_unauth_generate(): + """Test /generate without credentials for status 401.""" + + assert ( + BASE_URL is not None + ), "BASE_URL environment variable must be set to run this test" + + response = httpx.get(f"{BASE_URL}/generate") + + assert ( + response.status_code == 401 + ), f"Expected status 401, got {response.status_code}" + + data = response.json() + + assert ( + data["detail"] == "Missing 'authorization' header." + ), f'Expected "Missing \'authorization\' header.", got "{data["detail"]}"' + + +# REQUIRES +# - a valid API key -> API endpoint to create a new API key +# - a valid APP_ID -> API endpoint to create a an app from hooks +def test_auth_generate(): + """Test /generate with credentials for status 401.""" + + assert ( + BASE_URL is not None + ), "BASE_URL environment variable must be set to run this test" + + assert ( + API_KEY is not None + ), "API KEY environment variable must be set to run this test" + + response = httpx.post( + f"{BASE_URL}/generate", + headers={"Authorization": API_KEY}, + json={ + "aloha": "mahalo", + }, + ) + + assert ( + response.status_code == 200 + ), f"Expected status 200, got {response.status_code}" + + data = response.json() + + assert "data" in data, "Expected 'data' key in response JSON" + + assert "mahalo" in data["data"], "Expected data:'mahalo' in response JSON" + + assert "tree" in data, "Expected 'tree' key in response JSON" + + assert "nodes" in data["tree"], "Expected tree:'nodes' in response JSON" + + assert ( + len(data["tree"]["nodes"]) == 1 + ), "Expected tree:'nodes' length 1 in response JSON" + + assert ( + "inputs" in data["tree"]["nodes"][0]["data"] + ), "Expected tree:'nodes':'inputs' in response JSON" + + assert ( + "outputs" in data["tree"]["nodes"][0]["data"] + ), "Expected tree:'nodes':'outputs' in response JSON" diff --git a/agenta-cli/tests/baggage/specs/check_openapi.py b/agenta-cli/tests/baggage/specs/check_openapi.py new file mode 100644 index 0000000000..7531313800 --- /dev/null +++ b/agenta-cli/tests/baggage/specs/check_openapi.py @@ -0,0 +1,51 @@ +import pytest +import httpx +import os + +BASE_URL = os.getenv("BASE_URL", None) or None +API_KEY = os.getenv("API_KEY", None) or None + + +def test_unauth_openapi(): + """Test /openapi.json without credentials for status 401.""" + + assert ( + BASE_URL is not None + ), "BASE_URL environment variable must be set to run this test" + + response = httpx.get(f"{BASE_URL}/openapi.json") + + assert ( + response.status_code == 401 + ), f"Expected status 401, got {response.status_code}" + + data = response.json() + + assert ( + data["detail"] == "Missing 'authorization' header." + ), f'Expected "Missing \'authorization\' header.", got "{data["detail"]}"' + + +# REQUIRES +# - a valid API key -> API endpoint to create a new API key +# - a valid APP_ID -> API endpoint to create a an app from hooks +def test_auth_openapi(): + """Test /openapi.json with credentials for status 401.""" + + assert ( + BASE_URL is not None + ), "BASE_URL environment variable must be set to run this test" + + assert ( + API_KEY is not None + ), "API KEY environment variable must be set to run this test" + + response = httpx.get(f"{BASE_URL}/openapi.json", headers={"Authorization": API_KEY}) + + assert ( + response.status_code == 200 + ), f"Expected status 200, got {response.status_code}" + + data = response.json() + + assert "openapi" in data, "Expected 'openapi' key in response JSON" diff --git a/agenta-cli/tests/run_pytest.sh b/agenta-cli/tests/run_pytest.sh new file mode 100755 index 0000000000..19bdf1163b --- /dev/null +++ b/agenta-cli/tests/run_pytest.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +# Define default values +TEST_TARGET="specs/*" +PYTEST_OPTIONS="" +MARKERS="" +APP="" + +# Function to display usage +usage() { + echo "Usage: $0 [-t test_target] [-o pytest_options] [-m markers] [-a app]" + echo " -t test_target Specify the pytest test target to run. Default is 'specs/'." + echo " -o pytest_options Pass additional options to pytest." + echo " -m markers Specify marker expressions (e.g., 'smoke or integration')." + echo " -a app Specify the FastAPI app to run." + exit 1 +} + +# Parse command-line arguments +while getopts "t:o:m:a:" opt; do + case ${opt} in + t) TEST_TARGET="$OPTARG" ;; + o) PYTEST_OPTIONS="$OPTARG" ;; + m) MARKERS="$OPTARG" ;; + a) APP="$OPTARG" ;; + *) usage ;; + esac +done + +if [[ -z "$APP" ]]; then + echo "Error: Please specify the FastAPI app to run with the -a option." + usage +fi + +TEST_TARGET="./apps/${APP}/${TEST_TARGET}" + +# Build marker expression if markers are specified +if [[ -n "$MARKERS" ]]; then + MARKER_EXPR="-m \"$MARKERS\"" +fi + +# Run pytest with the specified options +echo "Running pytest tests in $TEST_TARGET with options: $PYTEST_OPTIONS $MARKER_EXPR" +eval pytest "$TEST_TARGET" $PYTEST_OPTIONS $MARKER_EXPR \ No newline at end of file diff --git a/agenta-cli/tests/run_tests.sh b/agenta-cli/tests/run_tests.sh new file mode 100755 index 0000000000..a3f5e82f15 --- /dev/null +++ b/agenta-cli/tests/run_tests.sh @@ -0,0 +1,56 @@ +#!/bin/bash + +# Define default values for the server +HOST="127.0.0.1" +PORT="8888" +TEST_TARGET="specs/*" +PYTEST_OPTIONS="" +MARKERS="" +APP="" + +# Function to display usage +usage() { + echo "Usage: $0 [-h host] [-p port] [-t test_target] [-o pytest_options] [-m markers] [-a app] [-k key]" + echo " -h host Specify the FastAPI server host. Default is 127.0.0.1." + echo " -p port Specify the FastAPI server port. Default is 8000." + echo " -t test_target Specify the pytest test target to run. Default is 'specs'." + echo " -o pytest_options Pass additional options to pytest." + echo " -m markers Specify marker expressions (e.g., 'smoke or integration')." + echo " -a app Specify the FastAPI app to run." + echo " -k key Specify the API key." + exit 1 +} + +# Parse command-line arguments +while getopts "h:p:t:o:m:a:k:" opt; do + case ${opt} in + h) HOST="$OPTARG" ;; + p) PORT="$OPTARG" ;; + t) TEST_TARGET="$OPTARG" ;; + o) PYTEST_OPTIONS="$OPTARG" ;; + m) MARKERS="$OPTARG" ;; + a) APP="$OPTARG" ;; + k) API_KEY="$OPTARG" ;; + *) usage ;; + esac +done + +if [[ -z "$APP" ]]; then + echo "Error: Please specify the FastAPI app to run with the -a option." + usage +fi + +# Start the FastAPI server +./start_server.sh -h "$HOST" -p "$PORT" -a "$APP" + +# Export the base URL as an environment variable +export BASE_URL="http://${HOST}:${PORT}" + +# Export the API key as an environment variable +export API_KEY="$API_KEY" + +# Run pytest tests with markers +./run_pytest.sh -t "$TEST_TARGET" -o "$PYTEST_OPTIONS" -m "$MARKERS" -a "$APP" + +# Stop the FastAPI server +./stop_server.sh \ No newline at end of file diff --git a/agenta-cli/tests/start_server.sh b/agenta-cli/tests/start_server.sh new file mode 100755 index 0000000000..f6b4129ef9 --- /dev/null +++ b/agenta-cli/tests/start_server.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +# Define default values +HOST="127.0.0.1" +PORT="8888" +APP= + +# Function to display usage +usage() { + echo "Usage: $0 [-h host] [-p port] [-a app]" + echo " -h host Specify the FastAPI server host. Default is 127.0.0.1." + echo " -p port Specify the FastAPI server port. Default is 8000." + echo " -a app Specify the FastAPI app to run. Default is 'baggage'." + exit 1 +} + +# Parse command-line arguments +while getopts "h:p:a:" opt; do + case ${opt} in + h) HOST="$OPTARG" ;; + p) PORT="$OPTARG" ;; + a) APP="$OPTARG" ;; + *) usage ;; + esac +done + +if [[ -z "$APP" ]]; then + echo "Error: Please specify the FastAPI app to run with the -a option." + usage +fi + +# Start the FastAPI server +echo "Starting FastAPI server at http://${HOST}:${PORT}..." +#uvicorn app:app --host "${HOST}" --port "${PORT}" & +cd ./apps/${APP} +python3 _main.py & +SERVER_PID=$! +echo "Server PID: $SERVER_PID" +echo $SERVER_PID > ../../server.pid # Save PID to a file for later use +sleep 3 # Wait for the server to start \ No newline at end of file diff --git a/agenta-cli/tests/stop_server.sh b/agenta-cli/tests/stop_server.sh new file mode 100755 index 0000000000..1678d850db --- /dev/null +++ b/agenta-cli/tests/stop_server.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +# Function to display usage +usage() { + echo "Usage: $0" + echo "Stops the FastAPI server started with start_server.sh." + exit 1 +} + +# Check if the PID file exists +if [[ ! -f server.pid ]]; then + echo "Error: PID file 'server.pid' not found. Is the server running?" + exit 1 +fi + +# Read the PID from the file +SERVER_PID=$(cat server.pid) + +# Validate that the PID is a running process +if ps -p "$SERVER_PID" > /dev/null 2>&1; then + echo "Stopping FastAPI server (PID: $SERVER_PID)..." + kill "$SERVER_PID" # Send the termination signal + + # Wait for the process to terminate + sleep 2 + + # Double-check if the process is still running + if ps -p "$SERVER_PID" > /dev/null 2>&1; then + echo "Error: Failed to stop the server. Attempting to force stop..." + kill -9 "$SERVER_PID" # Force kill the process + sleep 1 + + if ps -p "$SERVER_PID" > /dev/null 2>&1; then + echo "Error: Unable to stop the server process even with force. Manual intervention required." + exit 1 + fi + fi + + echo "FastAPI server stopped successfully." + rm -f server.pid # Remove the PID file +else + echo "Error: No process found with PID $SERVER_PID. Removing stale PID file." + rm -f server.pid + exit 1 +fi \ No newline at end of file From 7a0008eb4e0f24e611ed854405f923ab79397b6e Mon Sep 17 00:00:00 2001 From: Juan Pablo Vega Date: Wed, 11 Dec 2024 17:36:44 +0100 Subject: [PATCH 02/11] black . --- agenta-cli/agenta/sdk/decorators/tracing.py | 4 +++- agenta-cli/agenta/sdk/middleware/auth.py | 1 - agenta-cli/tests/baggage/app.py | 3 --- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/agenta-cli/agenta/sdk/decorators/tracing.py b/agenta-cli/agenta/sdk/decorators/tracing.py index 44891355a7..7b8e3c6881 100644 --- a/agenta-cli/agenta/sdk/decorators/tracing.py +++ b/agenta-cli/agenta/sdk/decorators/tracing.py @@ -246,7 +246,9 @@ def _redact( not in ( ignore if isinstance(ignore, list) - else io.keys() if ignore is True else [] + else io.keys() + if ignore is True + else [] ) } diff --git a/agenta-cli/agenta/sdk/middleware/auth.py b/agenta-cli/agenta/sdk/middleware/auth.py index a0cda77af0..89e17a6005 100644 --- a/agenta-cli/agenta/sdk/middleware/auth.py +++ b/agenta-cli/agenta/sdk/middleware/auth.py @@ -20,7 +20,6 @@ class DenyResponse(JSONResponse): - def __init__( self, status_code: int = 401, diff --git a/agenta-cli/tests/baggage/app.py b/agenta-cli/tests/baggage/app.py index 166ab8c8c9..9a4ebc1004 100644 --- a/agenta-cli/tests/baggage/app.py +++ b/agenta-cli/tests/baggage/app.py @@ -6,13 +6,10 @@ flag=ag.BinaryParam(value=False), ) -# XELnjVve.c1f177c87250b603cf1ed2a69ebdfc1cec3124776058e7afcbba93890c515e74 - @ag.route() @ag.instrument() def main(aloha: str = "Aloha") -> str: - print(ag.ConfigManager.get_from_route()) print(ag.VaultManager.get_from_route()) print(ag.config.flag) From 15c6b85f3f37479399d3ce2517783512d97d15b0 Mon Sep 17 00:00:00 2001 From: Juan Pablo Vega Date: Thu, 12 Dec 2024 11:10:32 +0100 Subject: [PATCH 03/11] display exception improvements --- agenta-cli/agenta/sdk/decorators/routing.py | 8 +-- agenta-cli/agenta/sdk/middleware/auth.py | 63 ++++++++++----------- agenta-cli/agenta/sdk/middleware/otel.py | 8 +-- agenta-cli/agenta/sdk/utils/exceptions.py | 39 ++++++------- agenta-cli/pyproject.toml | 2 +- 5 files changed, 54 insertions(+), 66 deletions(-) diff --git a/agenta-cli/agenta/sdk/decorators/routing.py b/agenta-cli/agenta/sdk/decorators/routing.py index e7fda7d22c..a0bddfa72a 100644 --- a/agenta-cli/agenta/sdk/decorators/routing.py +++ b/agenta-cli/agenta/sdk/decorators/routing.py @@ -29,7 +29,7 @@ TracingContext, ) from agenta.sdk.router import router -from agenta.sdk.utils.exceptions import suppress +from agenta.sdk.utils.exceptions import suppress, display_exception from agenta.sdk.utils.logging import log from agenta.sdk.types import ( DictInput, @@ -400,11 +400,7 @@ async def handle_success(self, result: Any, inline: bool): return BaseResponse(data=data) def handle_failure(self, error: Exception): - log.warning("-------------------------------") - log.warning("Agenta - Application Exception:") - log.warning("-------------------------------") - log.warning(format_exc().strip("\n")) - log.warning("-------------------------------") + display_exception("Application Exception") status_code = 500 message = str(error) diff --git a/agenta-cli/agenta/sdk/middleware/auth.py b/agenta-cli/agenta/sdk/middleware/auth.py index 89e17a6005..2175616dd0 100644 --- a/agenta-cli/agenta/sdk/middleware/auth.py +++ b/agenta-cli/agenta/sdk/middleware/auth.py @@ -10,6 +10,7 @@ from agenta.sdk.utils.logging import log +from agenta.sdk.utils.exceptions import display_exception import agenta as ag @@ -48,13 +49,17 @@ def __init__(self, app: FastAPI): super().__init__(app) self.host = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.host - self.resource_id = ( - # STATELESS - ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.service_id - # LEGACY OR STATEFUL - or ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.app_id - ) - self.resource_type = "application" + + self.resource_id = None + self.resource_type = None + + if ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.service_id: + self.resource_id = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.service_id + self.resource_type = "service" + + elif ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.app_id: + self.resource_id = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.app_id + self.resource_type = "application" async def dispatch( self, @@ -68,10 +73,14 @@ async def dispatch( return await call_next(request) try: - authorization = request.headers.get("Authorization", None) + authorization = request.headers.get("authorization", None) headers = {"Authorization": authorization} if authorization else None + access_token = request.cookies.get("sAccessToken", None) + + cookies = {"sAccessToken": access_token} if access_token else None + baggage = request.state.otel.get("baggage") if request.state.otel else {} project_id = ( @@ -90,10 +99,16 @@ async def dispatch( if project_id: params["project_id"] = project_id + print("-----------------------------------") + print(headers) + print(cookies) + print(params) + print("-----------------------------------") + credentials = await self._get_credentials( - # credentials = await self._mock_get_credentials( params=params, headers=headers, + cookies=cookies, ) request.state.auth = {"credentials": credentials} @@ -103,11 +118,7 @@ async def dispatch( return await call_next(request) except DenyException as deny: - log.warning("-----------------------------------") - log.warning("Agenta - Auth Middleware Exception:") - log.warning("-----------------------------------") - log.warning(format_exc().strip("\n")) - log.warning("-----------------------------------") + display_exception("Auth Middleware Exception") return DenyResponse( status_code=deny.status_code, @@ -115,31 +126,18 @@ async def dispatch( ) except: # pylint: disable=bare-except - log.warning("-----------------------------------") - log.warning("Agenta - Auth Middleware Exception:") - log.warning("-----------------------------------") - log.warning(format_exc().strip("\n")) - log.warning("-----------------------------------") + display_exception("Auth Middleware Exception") return DenyResponse( status_code=500, detail="Internal Server Error: auth middleware.", ) - async def _mock_get_credentials( - self, - params: Dict[str, str], - headers: Dict[str, str], - ): - if not headers: - raise DenyException(content="Missing 'authorization' header.") - - return headers.get("Authorization") - async def _get_credentials( self, params: Optional[Dict[str, str]] = None, headers: Optional[Dict[str, str]] = None, + cookies: Optional[str] = None, ): if not headers: raise DenyException(content="Missing 'authorization' header.") @@ -150,6 +148,7 @@ async def _get_credentials( f"{self.host}/api/permissions/verify", headers=headers, params=params, + cookies=cookies, ) if response.status_code == 401: @@ -182,11 +181,7 @@ async def _get_credentials( raise deny except Exception as exc: # pylint: disable=bare-except - log.warning("------------------------------------------------") - log.warning("Agenta - Auth Middleware Exception (suppressed):") - log.warning("------------------------------------------------") - log.warning(format_exc().strip("\n")) - log.warning("------------------------------------------------") + display_exception("Auth Middleware Exception (suppressed)") raise DenyException( status_code=500, content="Internal Server Error: auth middleware." diff --git a/agenta-cli/agenta/sdk/middleware/otel.py b/agenta-cli/agenta/sdk/middleware/otel.py index 1e7352e82d..bca6acc7ab 100644 --- a/agenta-cli/agenta/sdk/middleware/otel.py +++ b/agenta-cli/agenta/sdk/middleware/otel.py @@ -12,11 +12,7 @@ class OTelMiddleware(BaseHTTPMiddleware): def __init__(self, app: FastAPI): super().__init__(app) - async def dispatch( - self, - request: Request, - call_next: Callable, - ): + async def dispatch(self, request: Request, call_next: Callable): request.state.otel = None with suppress(): @@ -27,7 +23,7 @@ async def dispatch( if context: request.state.otel = {"baggage": {}} - for _, partial in context.values(): + for partial in context.values(): for key, value in partial.items(): request.state.otel["baggage"][key] = value diff --git a/agenta-cli/agenta/sdk/utils/exceptions.py b/agenta-cli/agenta/sdk/utils/exceptions.py index 2376b6f912..a1d5cb3793 100644 --- a/agenta-cli/agenta/sdk/utils/exceptions.py +++ b/agenta-cli/agenta/sdk/utils/exceptions.py @@ -6,6 +6,17 @@ from agenta.sdk.utils.logging import log +def display_exception(message: str): + _len = len("Agenta - ") + len(message) + len(":") + _bar = "-" * _len + + log.warning(_bar) + log.warning("Agenta - %s:", message) + log.warning(_bar) + log.warning(format_exc().strip("\n")) + log.warning(_bar) + + class suppress(AbstractContextManager): # pylint: disable=invalid-name def __init__(self): pass @@ -14,15 +25,10 @@ def __enter__(self): pass def __exit__(self, exc_type, exc_value, exc_tb): - if exc_type is None: - return True - else: - log.warning("--------------------------------") - log.warning("Agenta - Exception (suppressed):") - log.warning("--------------------------------") - log.warning(format_exc().strip("\n")) - log.warning("--------------------------------") - return True + if exc_type is not None: + display_exception("Exception (suppressed)") + + return True def handle_exceptions(): @@ -33,12 +39,10 @@ def decorator(func): async def async_wrapper(*args, **kwargs): try: return await func(*args, **kwargs) + except Exception as e: - log.warning("-------------------") - log.warning("Agenta - Exception:") - log.warning("-------------------") - log.warning(format_exc().strip("\n")) - log.warning("-------------------") + display_exception("Exception") + raise e @wraps(func) @@ -46,11 +50,8 @@ def sync_wrapper(*args, **kwargs): try: return func(*args, **kwargs) except Exception as e: - log.warning("-------------------") - log.warning("Agenta - Exception:") - log.warning("-------------------") - log.warning(format_exc().strip("\n")) - log.warning("-------------------") + display_exception("Exception") + raise e return async_wrapper if is_coroutine_function else sync_wrapper diff --git a/agenta-cli/pyproject.toml b/agenta-cli/pyproject.toml index fc823fc410..4c16e50376 100644 --- a/agenta-cli/pyproject.toml +++ b/agenta-cli/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "agenta" -version = "0.29.0" +version = "0.29.1a1" description = "The SDK for agenta is an open-source LLMOps platform." readme = "README.md" authors = ["Mahmoud Mabrouk "] From 372b3d757cf63f72815227626eee89fe6470f286 Mon Sep 17 00:00:00 2001 From: Juan Pablo Vega Date: Thu, 12 Dec 2024 16:56:07 +0100 Subject: [PATCH 04/11] Fix middleware and add positive caching. --- agenta-cli/agenta/sdk/middleware/auth.py | 125 +++++++++---------- agenta-cli/agenta/sdk/middleware/cache.py | 43 +++++++ agenta-cli/agenta/sdk/middleware/config.py | 137 ++++++++++++--------- agenta-cli/agenta/sdk/middleware/otel.py | 32 +++-- agenta-cli/agenta/sdk/middleware/vault.py | 73 ++++++----- agenta-cli/agenta/sdk/utils/timing.py | 58 +++++++++ 6 files changed, 305 insertions(+), 163 deletions(-) create mode 100644 agenta-cli/agenta/sdk/middleware/cache.py create mode 100644 agenta-cli/agenta/sdk/utils/timing.py diff --git a/agenta-cli/agenta/sdk/middleware/auth.py b/agenta-cli/agenta/sdk/middleware/auth.py index 2175616dd0..6b0b9a1bad 100644 --- a/agenta-cli/agenta/sdk/middleware/auth.py +++ b/agenta-cli/agenta/sdk/middleware/auth.py @@ -1,16 +1,16 @@ from typing import Callable, Dict, Optional from os import getenv -from traceback import format_exc +from json import dumps import httpx from starlette.middleware.base import BaseHTTPMiddleware from fastapi import FastAPI, Request from fastapi.responses import JSONResponse - -from agenta.sdk.utils.logging import log +from agenta.sdk.middleware.cache import TTLLRUCache from agenta.sdk.utils.exceptions import display_exception +from agenta.sdk.utils.timing import atimeit import agenta as ag @@ -18,6 +18,13 @@ _ALLOW_UNAUTHORIZED = ( getenv("AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED", "false").lower() in _TRUTHY ) +_SHARED_SERVICE = getenv("AGENTA_SHARED_SERVICE", "true").lower() in _TRUTHY +_CACHE_ENABLED = getenv("AGENTA_MIDDLEWARE_CACHE_ENABLED", "true").lower() in _TRUTHY + +_CACHE_CAPACITY = int(getenv("AGENTA_MIDDLEWARE_CACHE_CAPACITY", "512")) +_CACHE_TTL = int(getenv("AGENTA_MIDDLEWARE_CACHE_TTL", str(5 * 60))) # 5 minutes + +_cache = TTLLRUCache(capacity=_CACHE_CAPACITY, ttl=_CACHE_TTL) class DenyResponse(JSONResponse): @@ -49,29 +56,42 @@ def __init__(self, app: FastAPI): super().__init__(app) self.host = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.host + self.resource_id = ( + ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.service_id + if not _SHARED_SERVICE + else None + ) - self.resource_id = None - self.resource_type = None - - if ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.service_id: - self.resource_id = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.service_id - self.resource_type = "service" + async def dispatch(self, request: Request, call_next: Callable): + try: + if _ALLOW_UNAUTHORIZED: + request.state.auth = None - elif ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.app_id: - self.resource_id = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.app_id - self.resource_type = "application" + else: + credentials = await self._get_credentials(request) - async def dispatch( - self, - request: Request, - call_next: Callable, - ): - print("--- agenta/sdk/middleware/auth.py ---") - request.state.auth = None + request.state.auth = {"credentials": credentials} - if _ALLOW_UNAUTHORIZED: return await call_next(request) + except DenyException as deny: + display_exception("Auth Middleware Exception") + + return DenyResponse( + status_code=deny.status_code, + detail=deny.content, + ) + + except: # pylint: disable=bare-except + display_exception("Auth Middleware Exception") + + return DenyResponse( + status_code=500, + detail="Internal Server Error: auth middleware.", + ) + + # @atimeit + async def _get_credentials(self, request: Request) -> Optional[str]: try: authorization = request.headers.get("authorization", None) @@ -90,65 +110,35 @@ async def dispatch( or request.query_params.get("project_id") ) - params = { - "action": "run_service", - "resource_type": self.resource_type, - "resource_id": self.resource_id, - } + params = {"action": "run_service", "resource_type": "service"} + + if self.resource_id: + params["resource_id"] = self.resource_id if project_id: params["project_id"] = project_id - print("-----------------------------------") - print(headers) - print(cookies) - print(params) - print("-----------------------------------") - - credentials = await self._get_credentials( - params=params, - headers=headers, - cookies=cookies, + _hash = dumps( + { + "headers": headers, + "cookies": cookies, + "params": params, + }, + sort_keys=True, ) - request.state.auth = {"credentials": credentials} + if _CACHE_ENABLED: + credentials = _cache.get(_hash) - print(request.state.auth) + if credentials: + return credentials - return await call_next(request) - - except DenyException as deny: - display_exception("Auth Middleware Exception") - - return DenyResponse( - status_code=deny.status_code, - detail=deny.content, - ) - - except: # pylint: disable=bare-except - display_exception("Auth Middleware Exception") - - return DenyResponse( - status_code=500, - detail="Internal Server Error: auth middleware.", - ) - - async def _get_credentials( - self, - params: Optional[Dict[str, str]] = None, - headers: Optional[Dict[str, str]] = None, - cookies: Optional[str] = None, - ): - if not headers: - raise DenyException(content="Missing 'authorization' header.") - - try: async with httpx.AsyncClient() as client: response = await client.get( f"{self.host}/api/permissions/verify", headers=headers, - params=params, cookies=cookies, + params=params, ) if response.status_code == 401: @@ -173,7 +163,8 @@ async def _get_credentials( ) credentials = auth.get("credentials") - # --- # + + _cache.put(_hash, credentials) return credentials diff --git a/agenta-cli/agenta/sdk/middleware/cache.py b/agenta-cli/agenta/sdk/middleware/cache.py new file mode 100644 index 0000000000..5445b1fafc --- /dev/null +++ b/agenta-cli/agenta/sdk/middleware/cache.py @@ -0,0 +1,43 @@ +from time import time +from collections import OrderedDict + + +class TTLLRUCache: + def __init__(self, capacity: int, ttl: int): + self.cache = OrderedDict() + self.capacity = capacity + self.ttl = ttl + + def get(self, key): + # CACHE + if key not in self.cache: + return None + + value, expiry = self.cache[key] + # ----- + + # TTL + if time() > expiry: + del self.cache[key] + + return None + # --- + + # LRU + self.cache.move_to_end(key) + # --- + + return value + + def put(self, key, value): + # CACHE + if key in self.cache: + del self.cache[key] + # CACHE & LRU + elif len(self.cache) >= self.capacity: + self.cache.popitem(last=False) + # ----------- + + # TTL + self.cache[key] = (value, time() + self.ttl) + # --- diff --git a/agenta-cli/agenta/sdk/middleware/config.py b/agenta-cli/agenta/sdk/middleware/config.py index c91e5cf11f..119a206424 100644 --- a/agenta-cli/agenta/sdk/middleware/config.py +++ b/agenta-cli/agenta/sdk/middleware/config.py @@ -1,6 +1,8 @@ -from typing import Callable, Optional, Dict +from typing import Callable, Optional, Tuple, Dict + +from os import getenv +from json import dumps -from uuid import UUID from pydantic import BaseModel from starlette.middleware.base import BaseHTTPMiddleware @@ -8,10 +10,20 @@ import httpx +from agenta.sdk.middleware.cache import TTLLRUCache from agenta.sdk.utils.exceptions import suppress +from agenta.sdk.utils.timing import atimeit import agenta as ag +_TRUTHY = {"true", "1", "t", "y", "yes", "on", "enable", "enabled"} +_CACHE_ENABLED = getenv("AGENTA_MIDDLEWARE_CACHE_ENABLED", "true").lower() in _TRUTHY + +_CACHE_CAPACITY = int(getenv("AGENTA_MIDDLEWARE_CACHE_CAPACITY", "512")) +_CACHE_TTL = int(getenv("AGENTA_MIDDLEWARE_CACHE_TTL", str(5 * 60))) # 5 minutes + +_cache = TTLLRUCache(capacity=_CACHE_CAPACITY, ttl=_CACHE_TTL) + class Reference(BaseModel): id: Optional[str] = None @@ -134,79 +146,90 @@ async def dispatch( request: Request, call_next: Callable, ): - print("--- agenta/sdk/middleware/config.py ---") request.state.config = None with suppress(): - application_ref = await _parse_application_ref(request) - variant_ref = await _parse_variant_ref(request) - environment_ref = await _parse_environment_ref(request) - - auth = request.state.auth or {} + parameters, references = await self._get_config(request) - headers = { - "Authorization": auth.get("credentials"), + request.state.config = { + "parameters": parameters, + "references": references, } - refs = {} - if application_ref: - refs["application_ref"] = application_ref.model_dump() - if variant_ref: - refs["variant_ref"] = variant_ref.model_dump() - if environment_ref: - refs["environment_ref"] = environment_ref.model_dump() + return await call_next(request) - config = await self._get_config( - headers=headers, - refs=refs, - ) + # @atimeit + async def _get_config(self, request: Request) -> Optional[Tuple[Dict, Dict]]: + application_ref = await _parse_application_ref(request) + variant_ref = await _parse_variant_ref(request) + environment_ref = await _parse_environment_ref(request) - if config: - parameters = config.get("params") + auth = request.state.auth or {} - references = {} + headers = { + "Authorization": auth.get("credentials"), + } - for ref_key in ["application_ref", "variant_ref", "environment_ref"]: - refs = config.get(ref_key) - ref_prefix = ref_key.split("_", maxsplit=1)[0] + refs = {} + if application_ref: + refs["application_ref"] = application_ref.model_dump() + if variant_ref: + refs["variant_ref"] = variant_ref.model_dump() + if environment_ref: + refs["environment_ref"] = environment_ref.model_dump() - for ref_part_key in ["id", "slug", "version"]: - ref_part = refs.get(ref_part_key) + if not refs: + return None, None - if ref_part: - references[ref_prefix + "." + ref_part_key] = ref_part + _hash = dumps( + { + "headers": headers, + "refs": refs, + }, + sort_keys=True, + ) - request.state.config = { - "parameters": parameters, - "references": references, - } + if _CACHE_ENABLED: + config_cache = _cache.get(_hash) - print(request.state.config) + if config_cache: + parameters = config_cache.get("parameters") + references = config_cache.get("references") - return await call_next(request) + return parameters, references - async def _get_config( - self, - headers: Dict[str, str], - refs: Dict[str, str], - ): - if not refs: - return None + config = None + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.host}/api/variants/configs/fetch", + headers=headers, + json=refs, + ) + + if response.status_code != 200: + return None + + config = response.json() + + if not config: + _cache.put(_hash, {"parameters": None, "references": None}) + + return None, None + + parameters = config.get("params") + + references = {} - try: - async with httpx.AsyncClient() as client: - response = await client.post( - f"{self.host}/api/variants/configs/fetch", - headers=headers, - json=refs, - ) + for ref_key in ["application_ref", "variant_ref", "environment_ref"]: + refs = config.get(ref_key) + ref_prefix = ref_key.split("_", maxsplit=1)[0] - if response.status_code != 200: - return None + for ref_part_key in ["id", "slug", "version"]: + ref_part = refs.get(ref_part_key) - config = response.json() + if ref_part: + references[ref_prefix + "." + ref_part_key] = ref_part - return config + _cache.put(_hash, {"parameters": parameters, "references": references}) - except: # pylint: disable=bare-except - return None + return parameters, references diff --git a/agenta-cli/agenta/sdk/middleware/otel.py b/agenta-cli/agenta/sdk/middleware/otel.py index bca6acc7ab..51f3154e16 100644 --- a/agenta-cli/agenta/sdk/middleware/otel.py +++ b/agenta-cli/agenta/sdk/middleware/otel.py @@ -3,10 +3,11 @@ from starlette.middleware.base import BaseHTTPMiddleware from fastapi import Request, FastAPI -from agenta.sdk.utils.exceptions import suppress - from opentelemetry.baggage.propagation import W3CBaggagePropagator +from agenta.sdk.utils.exceptions import suppress +from agenta.sdk.utils.timing import atimeit + class OTelMiddleware(BaseHTTPMiddleware): def __init__(self, app: FastAPI): @@ -16,15 +17,26 @@ async def dispatch(self, request: Request, call_next: Callable): request.state.otel = None with suppress(): - baggage = {"baggage": request.headers.get("Baggage", "")} + baggage = await self._get_baggage(request) - context = W3CBaggagePropagator().extract(baggage) + request.state.otel = {"baggage": baggage} - if context: - request.state.otel = {"baggage": {}} + return await call_next(request) - for partial in context.values(): - for key, value in partial.items(): - request.state.otel["baggage"][key] = value + # @atimeit + async def _get_baggage( + self, + request, + ): + _baggage = {"baggage": request.headers.get("Baggage", "")} - return await call_next(request) + context = W3CBaggagePropagator().extract(_baggage) + + baggage = {} + + if context: + for partial in context.values(): + for key, value in partial.items(): + baggage[key] = value + + return baggage diff --git a/agenta-cli/agenta/sdk/middleware/vault.py b/agenta-cli/agenta/sdk/middleware/vault.py index ec7f73165a..a69135b188 100644 --- a/agenta-cli/agenta/sdk/middleware/vault.py +++ b/agenta-cli/agenta/sdk/middleware/vault.py @@ -1,13 +1,26 @@ -from typing import Callable, Dict +from typing import Callable, Dict, Optional + +from os import getenv +from json import dumps import httpx from starlette.middleware.base import BaseHTTPMiddleware from fastapi import FastAPI, Request +from agenta.sdk.middleware.cache import TTLLRUCache from agenta.sdk.utils.exceptions import suppress +from agenta.sdk.utils.timing import atimeit import agenta as ag +_TRUTHY = {"true", "1", "t", "y", "yes", "on", "enable", "enabled"} +_CACHE_ENABLED = getenv("AGENTA_MIDDLEWARE_CACHE_ENABLED", "true").lower() in _TRUTHY + +_CACHE_CAPACITY = int(getenv("AGENTA_MIDDLEWARE_CACHE_CAPACITY", "512")) +_CACHE_TTL = int(getenv("AGENTA_MIDDLEWARE_CACHE_TTL", str(5 * 60))) # 5 minutes + +_cache = TTLLRUCache(capacity=_CACHE_CAPACITY, ttl=_CACHE_TTL) + class VaultMiddleware(BaseHTTPMiddleware): def __init__(self, app: FastAPI): @@ -20,42 +33,44 @@ async def dispatch( request: Request, call_next: Callable, ): - print("--- agenta/sdk/middleware/vault.py ---") request.state.vault = None with suppress(): - headers = { - "Authorization": request.state.auth.get("credentials"), - } + secrets = await self._get_secrets(request) - secrets = await self._get_secrets( - headers=headers, - ) - - if secrets: - request.state.vault = { - "secrets": secrets, - } - - print(request.state.vault) + request.state.vault = {"secrets": secrets} return await call_next(request) - async def _get_secrets( - self, - headers: Dict[str, str], - ): - try: - async with httpx.AsyncClient() as client: - response = await client.get( - f"{self.host}/api/vault/v1/secrets", - headers=headers, - ) + # @atimeit + async def _get_secrets(self, request: Request) -> Optional[Dict]: + headers = {"Authorization": request.state.auth.get("credentials")} + + _hash = dumps( + { + "headers": headers, + }, + sort_keys=True, + ) - vault = response.json() + if _CACHE_ENABLED: + secrets_cache = _cache.get(_hash) - secrets = vault.get("secrets") + if secrets_cache: + secrets = secrets_cache.get("secrets") return secrets - except: # pylint: disable=bare-except - return None + + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.host}/api/vault/v1/secrets", + headers=headers, + ) + + vault = response.json() + + secrets = vault.get("secrets") + + _cache.put(_hash, {"secrets": secrets}) + + return secrets diff --git a/agenta-cli/agenta/sdk/utils/timing.py b/agenta-cli/agenta/sdk/utils/timing.py new file mode 100644 index 0000000000..c73b5f210d --- /dev/null +++ b/agenta-cli/agenta/sdk/utils/timing.py @@ -0,0 +1,58 @@ +import time +from functools import wraps + +from agenta.sdk.utils.logging import log + + +def timeit(func): + @wraps(func) + def wrapper(*args, **kwargs): + start_time = time.time() + result = func(*args, **kwargs) + end_time = time.time() + + execution_time = end_time - start_time + + if execution_time < 1e-3: + time_value = execution_time * 1e6 + unit = "us" + elif execution_time < 1: + time_value = execution_time * 1e3 + unit = "ms" + else: + time_value = execution_time + unit = "s" + + class_name = args[0].__class__.__name__ if args else None + + log.info(f"'{class_name}.{func.__name__}' executed in {time_value:.4f} {unit}.") + return result + + return wrapper + + +def atimeit(func): + @wraps(func) + async def wrapper(*args, **kwargs): + start_time = time.time() + result = await func(*args, **kwargs) + end_time = time.time() + + execution_time = end_time - start_time + + if execution_time < 1e-3: + time_value = execution_time * 1e6 + unit = "us" + elif execution_time < 1: + time_value = execution_time * 1e3 + unit = "ms" + else: + time_value = execution_time + unit = "s" + + class_name = args[0].__class__.__name__ if args else None + + log.info(f"'{class_name}.{func.__name__}' executed in {time_value:.4f} {unit}.") + return result + + return wrapper From 1d59833218aceef9d4658597f39e5f32b23d5077 Mon Sep 17 00:00:00 2001 From: Juan Pablo Vega Date: Thu, 12 Dec 2024 16:56:44 +0100 Subject: [PATCH 05/11] Add shared service and services in permissions router --- .../agenta_backend/routers/permissions_router.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/agenta-backend/agenta_backend/routers/permissions_router.py b/agenta-backend/agenta_backend/routers/permissions_router.py index 9e7532b975..3dfd0bdba5 100644 --- a/agenta-backend/agenta_backend/routers/permissions_router.py +++ b/agenta-backend/agenta_backend/routers/permissions_router.py @@ -51,7 +51,7 @@ async def verify_permissions( if isOss(): return Allow(None) - if not action or not resource_type or not resource_id: + if not action or not resource_type: raise Deny() if isCloudEE(): @@ -70,8 +70,8 @@ async def verify_permissions( # CHECK PERMISSION 2/2: RESOURCE allow_resource = await check_resource_access( project_id=UUID(request.state.project_id), - resource_id=resource_id, resource_type=resource_type, + resource_id=resource_id, ) if not allow_resource: @@ -80,13 +80,14 @@ async def verify_permissions( return Allow(request.state.credentials) except Exception as exc: # pylint: disable=bare-except + print(exc) raise Deny() from exc async def check_resource_access( project_id: UUID, - resource_id: UUID, resource_type: str, + resource_id: Optional[UUID] = None, ) -> bool: resource_project_id = None @@ -95,6 +96,15 @@ async def check_resource_access( resource_project_id = app.project_id + if resource_type == "service": + if resource_id is None: + resource_project_id = project_id + + else: + base = await db_manager.fetch_base_by_id(base_id=str(resource_id)) + + resource_project_id = base.project_id + allow_resource = resource_project_id == project_id return allow_resource From 86f65cdf361094426189b1b11235a8d242aa5a54 Mon Sep 17 00:00:00 2001 From: Juan Pablo Vega Date: Fri, 13 Dec 2024 10:13:52 +0100 Subject: [PATCH 06/11] Add local secrets and mrge with vault secrets --- agenta-cli/agenta/sdk/managers/vault.py | 15 +--- agenta-cli/agenta/sdk/middleware/vault.py | 92 ++++++++++++++++++++--- 2 files changed, 85 insertions(+), 22 deletions(-) diff --git a/agenta-cli/agenta/sdk/managers/vault.py b/agenta-cli/agenta/sdk/managers/vault.py index bf3c1afe0e..ffe596a678 100644 --- a/agenta-cli/agenta/sdk/managers/vault.py +++ b/agenta-cli/agenta/sdk/managers/vault.py @@ -1,17 +1,11 @@ -from typing import Optional, Type, TypeVar, Dict, Any, Union - -from pydantic import BaseModel +from typing import Optional, Dict, Any from agenta.sdk.decorators.routing import routing_context -T = TypeVar("T", bound=BaseModel) - class VaultManager: @staticmethod - def get_from_route( - schema: Optional[Type[T]] = None, - ) -> Optional[Union[Dict[str, Any], T]]: + def get_from_route() -> Optional[Dict[str, Any]]: context = routing_context.get() secrets = context.secrets @@ -19,7 +13,4 @@ def get_from_route( if not secrets: return None - if not schema: - return secrets - - return schema(**secrets) + return secrets diff --git a/agenta-cli/agenta/sdk/middleware/vault.py b/agenta-cli/agenta/sdk/middleware/vault.py index a69135b188..a5eca6a03f 100644 --- a/agenta-cli/agenta/sdk/middleware/vault.py +++ b/agenta-cli/agenta/sdk/middleware/vault.py @@ -1,18 +1,54 @@ from typing import Callable, Dict, Optional +from enum import Enum from os import getenv from json import dumps +from pydantic import BaseModel + import httpx from starlette.middleware.base import BaseHTTPMiddleware from fastapi import FastAPI, Request from agenta.sdk.middleware.cache import TTLLRUCache -from agenta.sdk.utils.exceptions import suppress +from agenta.sdk.utils.exceptions import suppress, display_exception from agenta.sdk.utils.timing import atimeit import agenta as ag + +# TODO: Move these four to backend client types + + +class SecretKind(str, Enum): + PROVIDER_KEY = "provider_key" + + +class ProviderKind(str, Enum): + OPENAI = "openai" + COHERE = "cohere" + ANYSCALE = "anyscale" + DEEPINFRA = "deepinfra" + ALEPHALPHA = "alephalpha" + GROQ = "groq" + MISTRALAI = "mistralai" + ANTHROPIC = "anthropic" + PERPLEXITYAI = "perplexityai" + TOGETHERAI = "togetherai" + OPENROUTER = "openrouter" + GEMINI = "gemini" + + +class ProviderKeyDTO(BaseModel): + provider: ProviderKind + key: str + + +class SecretDTO(BaseModel): + kind: SecretKind = "provider_key" + data: ProviderKeyDTO + + _TRUTHY = {"true", "1", "t", "y", "yes", "on", "enable", "enabled"} _CACHE_ENABLED = getenv("AGENTA_MIDDLEWARE_CACHE_ENABLED", "true").lower() in _TRUTHY @@ -61,16 +97,52 @@ async def _get_secrets(self, request: Request) -> Optional[Dict]: return secrets - async with httpx.AsyncClient() as client: - response = await client.get( - f"{self.host}/api/vault/v1/secrets", - headers=headers, - ) + local_secrets = [] + + try: + for provider_kind in ProviderKind: + provider = provider_kind.value + key = f"{provider.upper()}_API_KEY" + + secret = SecretDTO( + kind=SecretKind.PROVIDER_KEY, + data=ProviderKeyDTO( + provider=provider, + key=key, + ), + ) + + local_secrets.append(secret.model_dump()) + except: # pylint: disable=bare-except + display_exception("Vault: Local Secrets Exception") + + vault_secrets = [] + + try: + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.host}/api/vault/v1/secrets", + headers=headers, + ) + + vault = response.json() + + vault_secrets = vault.get("secrets") + except: # pylint: disable=bare-except + display_exception("Vault: Vault Secrets Exception") + + merged_secrets = {} + + for secret in local_secrets: + provider = secret["data"]["provider"] + merged_secrets[provider] = secret - vault = response.json() + for secret in vault_secrets: + provider = secret["data"]["provider"] + merged_secrets[provider] = secret - secrets = vault.get("secrets") + secrets = list(merged_secrets.values()) - _cache.put(_hash, {"secrets": secrets}) + _cache.put(_hash, {"secrets": secrets}) - return secrets + return secrets From 041e9bb907c6033e98a58969ec7356450283a705 Mon Sep 17 00:00:00 2001 From: Juan Pablo Vega Date: Fri, 13 Dec 2024 15:29:25 +0100 Subject: [PATCH 07/11] integration fixes --- agenta-cli/agenta/client/client.py | 2 +- agenta-cli/agenta/sdk/agenta_init.py | 148 +++--------- agenta-cli/agenta/sdk/context/routing.py | 4 +- agenta-cli/agenta/sdk/decorators/routing.py | 102 ++++---- agenta-cli/agenta/sdk/decorators/tracing.py | 6 +- agenta-cli/agenta/sdk/managers/config.py | 2 +- agenta-cli/agenta/sdk/managers/vault.py | 2 +- agenta-cli/agenta/sdk/middleware/auth.py | 41 ++-- agenta-cli/agenta/sdk/middleware/cache.py | 4 + agenta-cli/agenta/sdk/middleware/config.py | 251 +++++++++++--------- agenta-cli/agenta/sdk/middleware/otel.py | 2 - agenta-cli/agenta/sdk/middleware/vault.py | 34 ++- agenta-cli/agenta/sdk/tracing/exporters.py | 10 - agenta-cli/agenta/sdk/tracing/inline.py | 4 +- agenta-cli/agenta/sdk/utils/constants.py | 1 + agenta-cli/agenta/sdk/utils/globals.py | 12 +- 16 files changed, 282 insertions(+), 343 deletions(-) create mode 100644 agenta-cli/agenta/sdk/utils/constants.py diff --git a/agenta-cli/agenta/client/client.py b/agenta-cli/agenta/client/client.py index d5e4547f74..17dc1ac460 100644 --- a/agenta-cli/agenta/client/client.py +++ b/agenta-cli/agenta/client/client.py @@ -559,5 +559,5 @@ def run_evaluation(app_name: str, host: str, api_key: str = None) -> str: raise APIRequestError( f"Request to run evaluations failed with status code {response.status_code} and error message: {error_message}." ) - print(response.json()) + return response.json() diff --git a/agenta-cli/agenta/sdk/agenta_init.py b/agenta-cli/agenta/sdk/agenta_init.py index dbc5734378..52daec5b50 100644 --- a/agenta-cli/agenta/sdk/agenta_init.py +++ b/agenta-cli/agenta/sdk/agenta_init.py @@ -6,8 +6,9 @@ from agenta.sdk.utils.logging import log from agenta.sdk.utils.globals import set_global from agenta.client.backend.client import AgentaApi, AsyncAgentaApi + from agenta.sdk.tracing import Tracing -from agenta.client.exceptions import APIRequestError +from agenta.sdk.context.routing import routing_context class AgentaSingleton: @@ -124,28 +125,40 @@ def init( class Config: def __init__( self, - host: str, + # LEGACY + host: Optional[str] = None, base_id: Optional[str] = None, - api_key: Optional[str] = "", + api_key: Optional[str] = None, + # LEGACY + **kwargs, ): - self.host = host + self.default_parameters = {**kwargs} + + def set_default(self, **kwargs): + self.default_parameters.update(kwargs) + + def get_default(self): + return self.default_parameters + + def __getattr__(self, key): + context = routing_context.get() + + parameters = context.parameters - self.base_id = base_id + if key in parameters: + value = parameters[key] - if self.base_id is None: - # print( - # "Warning: Your configuration will not be saved permanently since base_id is not provided.\n" - # ) - pass + if isinstance(value, dict): + nested_config = Config() + nested_config.set_default(**value) - if base_id is None or host is None: - self.persist = False - else: - self.persist = True - self.client = AgentaApi( - base_url=self.host + "/api", - api_key=api_key if api_key else "", - ) + return nested_config + + return value + + return None + + ### --- LEGACY --- ### def register_default(self, overwrite=False, **kwargs): """alias for default""" @@ -157,104 +170,13 @@ def default(self, overwrite=False, **kwargs): overwrite: Whether to overwrite the existing configuration or not **kwargs: A dict containing the parameters """ - self.set( - **kwargs - ) # In case there is no connectivity, we still can use the default values - try: - self.push(config_name="default", overwrite=overwrite, **kwargs) - except Exception as ex: - log.warning( - "Unable to push the default configuration to the server. %s", str(ex) - ) - - def push(self, config_name: str, overwrite=True, **kwargs): - """Pushes the parameters for the app variant to the server - Args: - config_name: Name of the configuration to push to - overwrite: Whether to overwrite the existing configuration or not - **kwargs: A dict containing the parameters - """ - if not self.persist: - return - try: - self.client.configs.save_config( - base_id=self.base_id, - config_name=config_name, - parameters=kwargs, - overwrite=overwrite, - ) - except Exception as ex: - log.warning( - "Failed to push the configuration to the server with error: %s", ex - ) - - def pull( - self, config_name: str = "default", environment_name: Optional[str] = None - ): - """Pulls the parameters for the app variant from the server and sets them to the config""" - if not self.persist and ( - config_name != "default" or environment_name is not None - ): - raise ValueError( - "Cannot pull the configuration from the server since the app_name and base_name are not provided." - ) - if self.persist: - try: - if environment_name: - config = self.client.configs.get_config( - base_id=self.base_id, environment_name=environment_name - ) - - else: - config = self.client.configs.get_config( - base_id=self.base_id, - config_name=config_name, - ) - except Exception as ex: - log.warning( - "Failed to pull the configuration from the server with error: %s", - str(ex), - ) - try: - self.set(**{"current_version": config.current_version, **config.parameters}) - except Exception as ex: - log.warning("Failed to set the configuration with error: %s", str(ex)) + self.set(**kwargs) - def all(self): - """Returns all the parameters for the app variant""" - return { - k: v - for k, v in self.__dict__.items() - if k - not in [ - "app_name", - "base_name", - "host", - "base_id", - "api_key", - "persist", - "client", - ] - } - - # function to set the parameters for the app variant def set(self, **kwargs): - """Sets the parameters for the app variant - - Args: - **kwargs: A dict containing the parameters - """ - for key, value in kwargs.items(): - setattr(self, key, value) - - def dump(self): - """Returns all the information about the current version in the configuration. - - Raises: - NotImplementedError: _description_ - """ + self.set_default(**kwargs) - raise NotImplementedError() + def all(self): + return self.default_parameters def init( diff --git a/agenta-cli/agenta/sdk/context/routing.py b/agenta-cli/agenta/sdk/context/routing.py index c47ef74712..1284898289 100644 --- a/agenta-cli/agenta/sdk/context/routing.py +++ b/agenta-cli/agenta/sdk/context/routing.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from contextlib import contextmanager from contextvars import ContextVar @@ -8,7 +8,7 @@ class RoutingContext(BaseModel): parameters: Optional[Dict[str, Any]] = None - secrets: Optional[Dict[str, Any]] = None + secrets: Optional[List[Any]] = None routing_context = ContextVar("routing_context", default=RoutingContext()) diff --git a/agenta-cli/agenta/sdk/decorators/routing.py b/agenta-cli/agenta/sdk/decorators/routing.py index a0bddfa72a..7587cd7d52 100644 --- a/agenta-cli/agenta/sdk/decorators/routing.py +++ b/agenta-cli/agenta/sdk/decorators/routing.py @@ -1,14 +1,12 @@ from typing import Type, Any, Callable, Dict, Optional, Tuple, List -from annotated_types import Ge, Le, Gt, Lt -from pydantic import BaseModel, HttpUrl, ValidationError from inspect import signature, iscoroutinefunction, Signature, Parameter, _empty -from argparse import ArgumentParser from functools import wraps -from asyncio import sleep, get_event_loop -from traceback import format_exc, format_exception -from pathlib import Path -from tempfile import NamedTemporaryFile +from traceback import format_exception +from asyncio import sleep +from tempfile import NamedTemporaryFile +from annotated_types import Ge, Le, Gt, Lt +from pydantic import BaseModel, HttpUrl, ValidationError from fastapi import Body, FastAPI, UploadFile, HTTPException, Request @@ -20,7 +18,6 @@ from agenta.sdk.context.routing import ( routing_context_manager, - routing_context, RoutingContext, ) from agenta.sdk.context.tracing import ( @@ -60,7 +57,7 @@ class PathValidator(BaseModel): url: HttpUrl -class route: +class route: # pylint: disable=invalid-name # This decorator is used to expose specific stages of a workflow (embedding, retrieval, summarization, etc.) # as independent endpoints. It is designed for backward compatibility with existing code that uses # the @entrypoint decorator, which has certain limitations. By using @route(), we can create new @@ -122,6 +119,7 @@ async def chain_of_prompts_llm(prompt: str): """ routes = list() + _middleware = False _run_path = "/run" _test_path = "/test" @@ -158,28 +156,34 @@ def __init__( ### --- Run --- # @wraps(func) async def run_wrapper(request: Request, *args, **kwargs) -> Any: - arguments = { + # LEGACY + # TODO: Removing this implies breaking changes in : + # - calls to /generate_deployed + kwargs = { k: v for k, v in kwargs.items() if k not in ["config", "environment", "app"] } + # LEGACY - return await self.execute_wrapper( - request, - False, - *args, - **arguments, - ) + kwargs, _ = self.split_kwargs(kwargs, default_parameters) - self.update_deployed_function_signature( - run_wrapper, - ingestible_files, + # TODO: Why is this not used in the run_wrapper? + # self.ingest_files(kwargs, ingestible_files) + + return await self.execute_wrapper(request, False, *args, **kwargs) + + self.update_run_wrapper_signature( + wrapper=run_wrapper, + ingestible_files=ingestible_files, ) run_route = f"{entrypoint._run_path}{route_path}" app.post(run_route, response_model=BaseResponse)(run_wrapper) # LEGACY + # TODO: Removing this implies breaking changes in : + # - calls to /generate_deployed must be replaced with calls to /run if route_path == "": run_route = entrypoint._legacy_generate_deployed_path app.post(run_route, response_model=BaseResponse)(run_wrapper) @@ -189,34 +193,28 @@ async def run_wrapper(request: Request, *args, **kwargs) -> Any: ### --- Test --- # @wraps(func) async def test_wrapper(request: Request, *args, **kwargs) -> Any: - arguments, _ = self.split_kwargs( - kwargs, - default_parameters, - ) + kwargs, parameters = self.split_kwargs(kwargs, default_parameters) - self.ingest_files( - arguments, - ingestible_files, - ) + request.state.config["parameters"] = parameters - return await self.execute_wrapper( - request, - True, - *args, - **arguments, - ) + # TODO: Why is this only used in the test_wrapper? + self.ingest_files(kwargs, ingestible_files) + + return await self.execute_wrapper(request, True, *args, **kwargs) self.update_test_wrapper_signature( wrapper=test_wrapper, + ingestible_files=ingestible_files, config_class=config, config_dict=default_parameters, - ingestible_files=ingestible_files, ) test_route = f"{entrypoint._test_path}{route_path}" app.post(test_route, response_model=BaseResponse)(test_wrapper) # LEGACY + # TODO: Removing this implies breaking changes in : + # - calls to /generate must be replaced with calls to /test if route_path == "": test_route = entrypoint._legacy_generate_path app.post(test_route, response_model=BaseResponse)(test_wrapper) @@ -306,14 +304,15 @@ def parse_config(self) -> Dict[str, Any]: def split_kwargs( self, kwargs: Dict[str, Any], default_parameters: Dict[str, Any] ) -> Tuple[Dict[str, Any], Dict[str, Any]]: - """Split keyword arguments into function parameters and API configuration parameters.""" + arguments = {k: v for k, v in kwargs.items() if k not in default_parameters} + parameters = {k: v for k, v in kwargs.items() if k in default_parameters} - func_params = {k: v for k, v in kwargs.items() if k not in default_parameters} - api_config_params = {k: v for k, v in kwargs.items() if k in default_parameters} + return arguments, parameters - return func_params, api_config_params - - def ingest_file(self, upfile: UploadFile): + def ingest_file( + self, + upfile: UploadFile, + ): temp_file = NamedTemporaryFile(delete=False) temp_file.write(upfile.file.read()) temp_file.close() @@ -384,7 +383,11 @@ async def execute_function( except Exception as error: # pylint: disable=broad-except self.handle_failure(error) - async def handle_success(self, result: Any, inline: bool): + async def handle_success( + self, + result: Any, + inline: bool, + ): data = None tree = None @@ -399,7 +402,10 @@ async def handle_success(self, result: Any, inline: bool): except: return BaseResponse(data=data) - def handle_failure(self, error: Exception): + def handle_failure( + self, + error: Exception, + ): display_exception("Application Exception") status_code = 500 @@ -409,7 +415,10 @@ def handle_failure(self, error: Exception): raise HTTPException(status_code=status_code, detail=detail) - def patch_result(self, result: Any): + def patch_result( + self, + result: Any, + ): """ Patch the result to only include the message if the result is a FuncResponse-style dictionary with message, cost, and usage keys. @@ -446,7 +455,10 @@ def patch_result(self, result: Any): return data - async def fetch_inline_trace(self, inline): + async def fetch_inline_trace( + self, + inline, + ): WAIT_FOR_SPANS = True TIMEOUT = 1 TIMESTEP = 0.1 @@ -537,7 +549,7 @@ def update_test_wrapper_signature( self.update_wrapper_signature(wrapper, updated_params) self.add_request_to_signature(wrapper) - def update_deployed_function_signature( + def update_run_wrapper_signature( self, wrapper: Callable[..., Any], ingestible_files: Dict[str, Parameter], diff --git a/agenta-cli/agenta/sdk/decorators/tracing.py b/agenta-cli/agenta/sdk/decorators/tracing.py index 7b8e3c6881..6d31243963 100644 --- a/agenta-cli/agenta/sdk/decorators/tracing.py +++ b/agenta-cli/agenta/sdk/decorators/tracing.py @@ -97,7 +97,7 @@ def _attach_baggage(self): token = None if references: - for k, v in references: + for k, v in references.items(): token = attach(baggage.set_baggage(f"ag.refs.{k}", v)) return token @@ -246,9 +246,7 @@ def _redact( not in ( ignore if isinstance(ignore, list) - else io.keys() - if ignore is True - else [] + else io.keys() if ignore is True else [] ) } diff --git a/agenta-cli/agenta/sdk/managers/config.py b/agenta-cli/agenta/sdk/managers/config.py index f433fa5242..d3ec7b97cb 100644 --- a/agenta-cli/agenta/sdk/managers/config.py +++ b/agenta-cli/agenta/sdk/managers/config.py @@ -7,7 +7,7 @@ from pydantic import BaseModel from agenta.sdk.managers.shared import SharedManager -from agenta.sdk.decorators.routing import routing_context +from agenta.sdk.context.routing import routing_context T = TypeVar("T", bound=BaseModel) diff --git a/agenta-cli/agenta/sdk/managers/vault.py b/agenta-cli/agenta/sdk/managers/vault.py index ffe596a678..f559af19d2 100644 --- a/agenta-cli/agenta/sdk/managers/vault.py +++ b/agenta-cli/agenta/sdk/managers/vault.py @@ -1,6 +1,6 @@ from typing import Optional, Dict, Any -from agenta.sdk.decorators.routing import routing_context +from agenta.sdk.context.routing import routing_context class VaultManager: diff --git a/agenta-cli/agenta/sdk/middleware/auth.py b/agenta-cli/agenta/sdk/middleware/auth.py index 6b0b9a1bad..60f86f2c05 100644 --- a/agenta-cli/agenta/sdk/middleware/auth.py +++ b/agenta-cli/agenta/sdk/middleware/auth.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, Optional +from typing import Callable, Optional from os import getenv from json import dumps @@ -8,23 +8,21 @@ from fastapi import FastAPI, Request from fastapi.responses import JSONResponse -from agenta.sdk.middleware.cache import TTLLRUCache +from agenta.sdk.middleware.cache import TTLLRUCache, CACHE_CAPACITY, CACHE_TTL +from agenta.sdk.utils.constants import TRUTHY from agenta.sdk.utils.exceptions import display_exception -from agenta.sdk.utils.timing import atimeit import agenta as ag -_TRUTHY = {"true", "1", "t", "y", "yes", "on", "enable", "enabled"} -_ALLOW_UNAUTHORIZED = ( - getenv("AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED", "false").lower() in _TRUTHY -) -_SHARED_SERVICE = getenv("AGENTA_SHARED_SERVICE", "true").lower() in _TRUTHY -_CACHE_ENABLED = getenv("AGENTA_MIDDLEWARE_CACHE_ENABLED", "true").lower() in _TRUTHY -_CACHE_CAPACITY = int(getenv("AGENTA_MIDDLEWARE_CACHE_CAPACITY", "512")) -_CACHE_TTL = int(getenv("AGENTA_MIDDLEWARE_CACHE_TTL", str(5 * 60))) # 5 minutes +_SHARED_SERVICE = getenv("AGENTA_SHARED_SERVICE", "false").lower() in TRUTHY +_CACHE_ENABLED = getenv("AGENTA_MIDDLEWARE_CACHE_ENABLED", "true").lower() in TRUTHY +_UNAUTHORIZED_ALLOWED = ( + getenv("AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED", "false").lower() in TRUTHY +) +_ALWAYS_ALLOW_LIST = ["/health"] -_cache = TTLLRUCache(capacity=_CACHE_CAPACITY, ttl=_CACHE_TTL) +_cache = TTLLRUCache(capacity=CACHE_CAPACITY, ttl=CACHE_TTL) class DenyResponse(JSONResponse): @@ -64,7 +62,7 @@ def __init__(self, app: FastAPI): async def dispatch(self, request: Request, call_next: Callable): try: - if _ALLOW_UNAUTHORIZED: + if _UNAUTHORIZED_ALLOWED or request.url.path in _ALWAYS_ALLOW_LIST: request.state.auth = None else: @@ -87,10 +85,9 @@ async def dispatch(self, request: Request, call_next: Callable): return DenyResponse( status_code=500, - detail="Internal Server Error: auth middleware.", + detail="Auth: Unexpected Error.", ) - # @atimeit async def _get_credentials(self, request: Request) -> Optional[str]: try: authorization = request.headers.get("authorization", None) @@ -143,23 +140,26 @@ async def _get_credentials(self, request: Request) -> Optional[str]: if response.status_code == 401: raise DenyException( - status_code=401, content="Invalid 'authorization' header." + status_code=401, + content="Invalid credentials", ) elif response.status_code == 403: raise DenyException( - status_code=403, content="Service execution not allowed." + status_code=403, + content="Service execution not allowed.", ) elif response.status_code != 200: raise DenyException( status_code=400, - content="Internal Server Error: auth middleware.", + content="Auth: Unexpected Error.", ) auth = response.json() if auth.get("effect") != "allow": raise DenyException( - status_code=403, content="Service execution not allowed." + status_code=403, + content="Service execution not allowed.", ) credentials = auth.get("credentials") @@ -175,5 +175,6 @@ async def _get_credentials(self, request: Request) -> Optional[str]: display_exception("Auth Middleware Exception (suppressed)") raise DenyException( - status_code=500, content="Internal Server Error: auth middleware." + status_code=500, + content="Auth: Unexpected Error.", ) from exc diff --git a/agenta-cli/agenta/sdk/middleware/cache.py b/agenta-cli/agenta/sdk/middleware/cache.py index 5445b1fafc..641f4f802d 100644 --- a/agenta-cli/agenta/sdk/middleware/cache.py +++ b/agenta-cli/agenta/sdk/middleware/cache.py @@ -1,6 +1,10 @@ +from os import getenv from time import time from collections import OrderedDict +CACHE_CAPACITY = int(getenv("AGENTA_MIDDLEWARE_CACHE_CAPACITY", "512")) +CACHE_TTL = int(getenv("AGENTA_MIDDLEWARE_CACHE_TTL", str(5 * 60))) # 5 minutes + class TTLLRUCache: def __init__(self, capacity: int, ttl: int): diff --git a/agenta-cli/agenta/sdk/middleware/config.py b/agenta-cli/agenta/sdk/middleware/config.py index 119a206424..ec8ff7f43f 100644 --- a/agenta-cli/agenta/sdk/middleware/config.py +++ b/agenta-cli/agenta/sdk/middleware/config.py @@ -10,19 +10,16 @@ import httpx -from agenta.sdk.middleware.cache import TTLLRUCache +from agenta.sdk.middleware.cache import TTLLRUCache, CACHE_CAPACITY, CACHE_TTL +from agenta.sdk.utils.constants import TRUTHY from agenta.sdk.utils.exceptions import suppress -from agenta.sdk.utils.timing import atimeit import agenta as ag -_TRUTHY = {"true", "1", "t", "y", "yes", "on", "enable", "enabled"} -_CACHE_ENABLED = getenv("AGENTA_MIDDLEWARE_CACHE_ENABLED", "true").lower() in _TRUTHY -_CACHE_CAPACITY = int(getenv("AGENTA_MIDDLEWARE_CACHE_CAPACITY", "512")) -_CACHE_TTL = int(getenv("AGENTA_MIDDLEWARE_CACHE_TTL", str(5 * 60))) # 5 minutes +_CACHE_ENABLED = getenv("AGENTA_MIDDLEWARE_CACHE_ENABLED", "true").lower() in TRUTHY -_cache = TTLLRUCache(capacity=_CACHE_CAPACITY, ttl=_CACHE_TTL) +_cache = TTLLRUCache(capacity=CACHE_CAPACITY, ttl=CACHE_TTL) class Reference(BaseModel): @@ -31,122 +28,19 @@ class Reference(BaseModel): version: Optional[str] = None -async def _parse_application_ref( - request: Request, -) -> Optional[Reference]: - baggage = request.state.otel.get("baggage") if request.state.otel else {} - - application_id = ( - # CLEANEST - baggage.get("application_id") - # ALTERNATIVE - or request.query_params.get("application_id") - # LEGACY - or request.query_params.get("app_id") - ) - application_slug = ( - # CLEANEST - baggage.get("application_slug") - # ALTERNATIVE - or request.query_params.get("application_slug") - # LEGACY - or request.query_params.get("app_slug") - or request.query_params.get("app") - ) - - if not any([application_id, application_slug]): - return None - - return Reference( - id=application_id, - slug=application_slug, - ) - - -async def _parse_variant_ref( - request: Request, -) -> Optional[Reference]: - baggage = request.state.otel.get("baggage") if request.state.otel else {} - - variant_id = ( - # CLEANEST - baggage.get("variant_id") - # ALTERNATIVE - or request.query_params.get("variant_id") - ) - variant_slug = ( - # CLEANEST - baggage.get("variant_slug") - # ALTERNATIVE - or request.query_params.get("variant_slug") - # LEGACY - or request.query_params.get("config") - ) - variant_version = ( - # CLEANEST - baggage.get("variant_version") - # ALTERNATIVE - or request.query_params.get("variant_version") - ) - - if not any([variant_id, variant_slug, variant_version]): - return None - - return Reference( - id=variant_id, - slug=variant_slug, - version=variant_version, - ) - - -async def _parse_environment_ref( - request: Request, -) -> Optional[Reference]: - baggage = request.state.otel.get("baggage") if request.state.otel else {} - - environment_id = ( - # CLEANEST - baggage.get("environment_id") - # ALTERNATIVE - or request.query_params.get("environment_id") - ) - environment_slug = ( - # CLEANEST - baggage.get("environment_slug") - # ALTERNATIVE - or request.query_params.get("environment_slug") - # LEGACY - or request.query_params.get("environment") - ) - environment_version = ( - # CLEANEST - baggage.get("environment_version") - # ALTERNATIVE - or request.query_params.get("environment_version") - ) - - if not any([environment_id, environment_slug, environment_version]): - return None - - return Reference( - id=environment_id, - slug=environment_slug, - version=environment_version, - ) - - class ConfigMiddleware(BaseHTTPMiddleware): def __init__(self, app: FastAPI): super().__init__(app) self.host = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.host + self.application_id = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.app_id async def dispatch( self, request: Request, call_next: Callable, ): - request.state.config = None + request.state.config = {} with suppress(): parameters, references = await self._get_config(request) @@ -160,9 +54,9 @@ async def dispatch( # @atimeit async def _get_config(self, request: Request) -> Optional[Tuple[Dict, Dict]]: - application_ref = await _parse_application_ref(request) - variant_ref = await _parse_variant_ref(request) - environment_ref = await _parse_environment_ref(request) + application_ref = await self._parse_application_ref(request) + variant_ref = await self._parse_variant_ref(request) + environment_ref = await self._parse_environment_ref(request) auth = request.state.auth or {} @@ -207,7 +101,7 @@ async def _get_config(self, request: Request) -> Optional[Tuple[Dict, Dict]]: ) if response.status_code != 200: - return None + return None, None config = response.json() @@ -233,3 +127,128 @@ async def _get_config(self, request: Request) -> Optional[Tuple[Dict, Dict]]: _cache.put(_hash, {"parameters": parameters, "references": references}) return parameters, references + + async def _parse_application_ref( + self, + request: Request, + ) -> Optional[Reference]: + baggage = request.state.otel.get("baggage") if request.state.otel else {} + + body = {} + try: + body = await request.json() + except: # pylint: disable=bare-except + pass + + application_id = ( + # CLEANEST + baggage.get("application_id") + # ALTERNATIVE + or request.query_params.get("application_id") + # LEGACY + or request.query_params.get("app_id") + or self.application_id + ) + application_slug = ( + # CLEANEST + baggage.get("application_slug") + # ALTERNATIVE + or request.query_params.get("application_slug") + # LEGACY + or request.query_params.get("app_slug") + or body.get("app") + ) + + if not any([application_id, application_slug]): + return None + + return Reference( + id=application_id, + slug=application_slug, + ) + + async def _parse_variant_ref( + self, + request: Request, + ) -> Optional[Reference]: + baggage = request.state.otel.get("baggage") if request.state.otel else {} + + body = {} + try: + body = await request.json() + except: # pylint: disable=bare-except + pass + + variant_id = ( + # CLEANEST + baggage.get("variant_id") + # ALTERNATIVE + or request.query_params.get("variant_id") + ) + variant_slug = ( + # CLEANEST + baggage.get("variant_slug") + # ALTERNATIVE + or request.query_params.get("variant_slug") + # LEGACY + or request.query_params.get("config") + or body.get("config") + ) + variant_version = ( + # CLEANEST + baggage.get("variant_version") + # ALTERNATIVE + or request.query_params.get("variant_version") + ) + + if not any([variant_id, variant_slug, variant_version]): + return None + + return Reference( + id=variant_id, + slug=variant_slug, + version=variant_version, + ) + + async def _parse_environment_ref( + self, + request: Request, + ) -> Optional[Reference]: + baggage = request.state.otel.get("baggage") if request.state.otel else {} + + body = {} + try: + body = await request.json() + except: # pylint: disable=bare-except + pass + + environment_id = ( + # CLEANEST + baggage.get("environment_id") + # ALTERNATIVE + or request.query_params.get("environment_id") + ) + environment_slug = ( + # CLEANEST + baggage.get("environment_slug") + # ALTERNATIVE + or request.query_params.get("environment_slug") + # LEGACY + or request.query_params.get("environment") + or body.get("environment") + ) + environment_version = ( + # CLEANEST + baggage.get("environment_version") + # ALTERNATIVE + or request.query_params.get("environment_version") + ) + + if not any([environment_id, environment_slug, environment_version]): + return None + + return Reference( + id=environment_id, + slug=environment_slug, + version=environment_version, + ) diff --git a/agenta-cli/agenta/sdk/middleware/otel.py b/agenta-cli/agenta/sdk/middleware/otel.py index 51f3154e16..e1be195ae6 100644 --- a/agenta-cli/agenta/sdk/middleware/otel.py +++ b/agenta-cli/agenta/sdk/middleware/otel.py @@ -6,7 +6,6 @@ from opentelemetry.baggage.propagation import W3CBaggagePropagator from agenta.sdk.utils.exceptions import suppress -from agenta.sdk.utils.timing import atimeit class OTelMiddleware(BaseHTTPMiddleware): @@ -23,7 +22,6 @@ async def dispatch(self, request: Request, call_next: Callable): return await call_next(request) - # @atimeit async def _get_baggage( self, request, diff --git a/agenta-cli/agenta/sdk/middleware/vault.py b/agenta-cli/agenta/sdk/middleware/vault.py index a5eca6a03f..538bad3dd6 100644 --- a/agenta-cli/agenta/sdk/middleware/vault.py +++ b/agenta-cli/agenta/sdk/middleware/vault.py @@ -10,20 +10,19 @@ from starlette.middleware.base import BaseHTTPMiddleware from fastapi import FastAPI, Request -from agenta.sdk.middleware.cache import TTLLRUCache +from agenta.sdk.middleware.cache import TTLLRUCache, CACHE_CAPACITY, CACHE_TTL +from agenta.sdk.utils.constants import TRUTHY from agenta.sdk.utils.exceptions import suppress, display_exception -from agenta.sdk.utils.timing import atimeit import agenta as ag -# TODO: Move these four to backend client types - - +# TODO: Move to backend client types class SecretKind(str, Enum): PROVIDER_KEY = "provider_key" +# TODO: Move to backend client types class ProviderKind(str, Enum): OPENAI = "openai" COHERE = "cohere" @@ -39,23 +38,21 @@ class ProviderKind(str, Enum): GEMINI = "gemini" +# TODO: Move to backend client types class ProviderKeyDTO(BaseModel): provider: ProviderKind key: str +# TODO: Move to backend client types class SecretDTO(BaseModel): kind: SecretKind = "provider_key" data: ProviderKeyDTO -_TRUTHY = {"true", "1", "t", "y", "yes", "on", "enable", "enabled"} -_CACHE_ENABLED = getenv("AGENTA_MIDDLEWARE_CACHE_ENABLED", "true").lower() in _TRUTHY - -_CACHE_CAPACITY = int(getenv("AGENTA_MIDDLEWARE_CACHE_CAPACITY", "512")) -_CACHE_TTL = int(getenv("AGENTA_MIDDLEWARE_CACHE_TTL", str(5 * 60))) # 5 minutes +_CACHE_ENABLED = getenv("AGENTA_MIDDLEWARE_CACHE_ENABLED", "true").lower() in TRUTHY -_cache = TTLLRUCache(capacity=_CACHE_CAPACITY, ttl=_CACHE_TTL) +_cache = TTLLRUCache(capacity=CACHE_CAPACITY, ttl=CACHE_TTL) class VaultMiddleware(BaseHTTPMiddleware): @@ -78,7 +75,6 @@ async def dispatch( return await call_next(request) - # @atimeit async def _get_secrets(self, request: Request) -> Optional[Dict]: headers = {"Authorization": request.state.auth.get("credentials")} @@ -133,13 +129,15 @@ async def _get_secrets(self, request: Request) -> Optional[Dict]: merged_secrets = {} - for secret in local_secrets: - provider = secret["data"]["provider"] - merged_secrets[provider] = secret + if local_secrets: + for secret in local_secrets: + provider = secret["data"]["provider"] + merged_secrets[provider] = secret - for secret in vault_secrets: - provider = secret["data"]["provider"] - merged_secrets[provider] = secret + if vault_secrets: + for secret in vault_secrets: + provider = secret["data"]["provider"] + merged_secrets[provider] = secret secrets = list(merged_secrets.values()) diff --git a/agenta-cli/agenta/sdk/tracing/exporters.py b/agenta-cli/agenta/sdk/tracing/exporters.py index c713811eca..7a38201d5a 100644 --- a/agenta-cli/agenta/sdk/tracing/exporters.py +++ b/agenta-cli/agenta/sdk/tracing/exporters.py @@ -74,11 +74,6 @@ def __init__(self, *args, credentials: Dict[int, str] = None, **kwargs): def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult: credentials = None - # --- DEBUG - for span in spans: - print(span.name, span.attributes) - # --------- - if self.credentials: trace_ids = set(span.get_span_context().trace_id for span in spans) @@ -101,11 +96,6 @@ def _export(self, serialized_data: bytes): if credentials: self._session.headers.update({"Authorization": credentials}) - # --- DEBUG - auth = {"Authorization": self._session.headers.get("Authorization")} - print(" ", auth) - # --------- - return super()._export(serialized_data) diff --git a/agenta-cli/agenta/sdk/tracing/inline.py b/agenta-cli/agenta/sdk/tracing/inline.py index 6905ad5cf0..3bf55cdf82 100644 --- a/agenta-cli/agenta/sdk/tracing/inline.py +++ b/agenta-cli/agenta/sdk/tracing/inline.py @@ -101,8 +101,8 @@ class NodeDTO(BaseModel): Data = Dict[str, Any] Metrics = Dict[str, Any] Metadata = Dict[str, Any] -Tags = Dict[str, str] -Refs = Dict[str, str] +Tags = Dict[str, Any] +Refs = Dict[str, Any] class LinkDTO(BaseModel): diff --git a/agenta-cli/agenta/sdk/utils/constants.py b/agenta-cli/agenta/sdk/utils/constants.py new file mode 100644 index 0000000000..fc2e1ae25d --- /dev/null +++ b/agenta-cli/agenta/sdk/utils/constants.py @@ -0,0 +1 @@ +TRUTHY = {"true", "1", "t", "y", "yes", "on", "enable", "enabled"} diff --git a/agenta-cli/agenta/sdk/utils/globals.py b/agenta-cli/agenta/sdk/utils/globals.py index f05141e089..ceae076427 100644 --- a/agenta-cli/agenta/sdk/utils/globals.py +++ b/agenta-cli/agenta/sdk/utils/globals.py @@ -1,14 +1,10 @@ -import agenta +import agenta as ag def set_global(config=None, tracing=None): - """Allows usage of agenta.config and agenta.tracing in the user's code. + """Allows usage of agenta.config and agenta.tracing in the user's code.""" - Args: - config: _description_. Defaults to None. - tracing: _description_. Defaults to None. - """ if config is not None: - agenta.config = config + ag.config = config if tracing is not None: - agenta.tracing = tracing + ag.tracing = tracing From a73afc4443fba2c1da42c50a987246c4f938fa4d Mon Sep 17 00:00:00 2001 From: Juan Pablo Vega Date: Fri, 13 Dec 2024 18:29:43 +0100 Subject: [PATCH 08/11] fix integrations issues --- agenta-cli/agenta/sdk/agenta_init.py | 3 +++ agenta-cli/agenta/sdk/decorators/routing.py | 17 ++++++++++------ agenta-cli/agenta/sdk/decorators/tracing.py | 4 +++- agenta-cli/agenta/sdk/middleware/auth.py | 2 +- agenta-cli/agenta/sdk/middleware/config.py | 12 +++++------ agenta-cli/agenta/sdk/middleware/otel.py | 2 +- agenta-cli/agenta/sdk/middleware/vault.py | 22 ++++++++++++++++----- 7 files changed, 42 insertions(+), 20 deletions(-) diff --git a/agenta-cli/agenta/sdk/agenta_init.py b/agenta-cli/agenta/sdk/agenta_init.py index 52daec5b50..06659f4f4d 100644 --- a/agenta-cli/agenta/sdk/agenta_init.py +++ b/agenta-cli/agenta/sdk/agenta_init.py @@ -145,6 +145,9 @@ def __getattr__(self, key): parameters = context.parameters + if not parameters: + return None + if key in parameters: value = parameters[key] diff --git a/agenta-cli/agenta/sdk/decorators/routing.py b/agenta-cli/agenta/sdk/decorators/routing.py index 7587cd7d52..45229ea4bc 100644 --- a/agenta-cli/agenta/sdk/decorators/routing.py +++ b/agenta-cli/agenta/sdk/decorators/routing.py @@ -219,6 +219,13 @@ async def test_wrapper(request: Request, *args, **kwargs) -> Any: test_route = entrypoint._legacy_generate_path app.post(test_route, response_model=BaseResponse)(test_wrapper) # LEGACY + + # LEGACY + # TODO: Removing this implies no breaking changes + if route_path == "": + test_route = entrypoint._legacy_playground_run_path + app.post(test_route, response_model=BaseResponse)(test_wrapper) + # LEGACY ### ------------ # ### --- OpenAPI --- # @@ -341,10 +348,10 @@ async def execute_wrapper( raise HTTPException(status_code=500, detail="Missing 'request'.") state = request.state - credentials = state.auth.get("credentials") if state.auth else None - parameters = state.config.get("parameters") if state.config else None - references = state.config.get("references") if state.config else None - secrets = state.vault.get("secrets") if state.vault else None + credentials = state.auth.get("credentials") + parameters = state.config.get("parameters") + references = state.config.get("references") + secrets = state.vault.get("secrets") with routing_context_manager( context=RoutingContext( @@ -369,8 +376,6 @@ async def execute_function( *args, **kwargs, ): - log.info("Agenta - Handling: '%s'", repr(self.route_path or "/")) - try: result = ( await self.func(*args, **kwargs) diff --git a/agenta-cli/agenta/sdk/decorators/tracing.py b/agenta-cli/agenta/sdk/decorators/tracing.py index 6d31243963..f368509fc6 100644 --- a/agenta-cli/agenta/sdk/decorators/tracing.py +++ b/agenta-cli/agenta/sdk/decorators/tracing.py @@ -246,7 +246,9 @@ def _redact( not in ( ignore if isinstance(ignore, list) - else io.keys() if ignore is True else [] + else io.keys() + if ignore is True + else [] ) } diff --git a/agenta-cli/agenta/sdk/middleware/auth.py b/agenta-cli/agenta/sdk/middleware/auth.py index 60f86f2c05..fd82198d05 100644 --- a/agenta-cli/agenta/sdk/middleware/auth.py +++ b/agenta-cli/agenta/sdk/middleware/auth.py @@ -63,7 +63,7 @@ def __init__(self, app: FastAPI): async def dispatch(self, request: Request, call_next: Callable): try: if _UNAUTHORIZED_ALLOWED or request.url.path in _ALWAYS_ALLOW_LIST: - request.state.auth = None + request.state.auth = {} else: credentials = await self._get_credentials(request) diff --git a/agenta-cli/agenta/sdk/middleware/config.py b/agenta-cli/agenta/sdk/middleware/config.py index ec8ff7f43f..8ea9eb9ffe 100644 --- a/agenta-cli/agenta/sdk/middleware/config.py +++ b/agenta-cli/agenta/sdk/middleware/config.py @@ -54,16 +54,16 @@ async def dispatch( # @atimeit async def _get_config(self, request: Request) -> Optional[Tuple[Dict, Dict]]: + credentials = request.state.auth.get("credentials") + + headers = None + if credentials: + headers = {"Authorization": credentials} + application_ref = await self._parse_application_ref(request) variant_ref = await self._parse_variant_ref(request) environment_ref = await self._parse_environment_ref(request) - auth = request.state.auth or {} - - headers = { - "Authorization": auth.get("credentials"), - } - refs = {} if application_ref: refs["application_ref"] = application_ref.model_dump() diff --git a/agenta-cli/agenta/sdk/middleware/otel.py b/agenta-cli/agenta/sdk/middleware/otel.py index e1be195ae6..0a6396f979 100644 --- a/agenta-cli/agenta/sdk/middleware/otel.py +++ b/agenta-cli/agenta/sdk/middleware/otel.py @@ -13,7 +13,7 @@ def __init__(self, app: FastAPI): super().__init__(app) async def dispatch(self, request: Request, call_next: Callable): - request.state.otel = None + request.state.otel = {} with suppress(): baggage = await self._get_baggage(request) diff --git a/agenta-cli/agenta/sdk/middleware/vault.py b/agenta-cli/agenta/sdk/middleware/vault.py index 538bad3dd6..c7b6a8877f 100644 --- a/agenta-cli/agenta/sdk/middleware/vault.py +++ b/agenta-cli/agenta/sdk/middleware/vault.py @@ -66,7 +66,7 @@ async def dispatch( request: Request, call_next: Callable, ): - request.state.vault = None + request.state.vault = {} with suppress(): secrets = await self._get_secrets(request) @@ -76,7 +76,11 @@ async def dispatch( return await call_next(request) async def _get_secrets(self, request: Request) -> Optional[Dict]: - headers = {"Authorization": request.state.auth.get("credentials")} + credentials = request.state.auth.get("credentials") + + headers = None + if credentials: + headers = {"Authorization": credentials} _hash = dumps( { @@ -98,7 +102,11 @@ async def _get_secrets(self, request: Request) -> Optional[Dict]: try: for provider_kind in ProviderKind: provider = provider_kind.value - key = f"{provider.upper()}_API_KEY" + key_name = f"{provider.upper()}_API_KEY" + key = getenv(key_name) + + if not key: + continue secret = SecretDTO( kind=SecretKind.PROVIDER_KEY, @@ -121,9 +129,13 @@ async def _get_secrets(self, request: Request) -> Optional[Dict]: headers=headers, ) - vault = response.json() + if response.status_code != 200: + vault_secrets = [] + + else: + vault = response.json() - vault_secrets = vault.get("secrets") + vault_secrets = vault.get("secrets") except: # pylint: disable=bare-except display_exception("Vault: Vault Secrets Exception") From 6708afa68a0e5cba2ab5e52f29792231645c5e8e Mon Sep 17 00:00:00 2001 From: Juan Pablo Vega Date: Tue, 17 Dec 2024 13:39:41 +0100 Subject: [PATCH 09/11] Add get_api_key_from_model() --- agenta-cli/agenta/sdk/assets.py | 103 ++++++++++++---------- agenta-cli/agenta/sdk/managers/vault.py | 22 +++++ agenta-cli/agenta/sdk/middleware/vault.py | 21 ++--- 3 files changed, 90 insertions(+), 56 deletions(-) diff --git a/agenta-cli/agenta/sdk/assets.py b/agenta-cli/agenta/sdk/assets.py index c62cc9dd97..86ade46687 100644 --- a/agenta-cli/agenta/sdk/assets.py +++ b/agenta-cli/agenta/sdk/assets.py @@ -1,23 +1,9 @@ supported_llm_models = { - "Mistral AI": [ - "mistral/mistral-tiny", - "mistral/mistral-small", - "mistral/mistral-medium", - "mistral/mistral-large-latest", - ], - "Open AI": [ - "gpt-3.5-turbo-1106", - "gpt-3.5-turbo", - "gpt-4", - "gpt-4o", - "gpt-4o-mini", - "gpt-4-1106-preview", - ], - "Gemini": ["gemini/gemini-1.5-pro-latest", "gemini/gemini-1.5-flash"], - "Cohere": [ - "cohere/command-light", - "cohere/command-r-plus", - "cohere/command-nightly", + "Aleph Alpha": [ + "luminous-base", + "luminous-base-control", + "luminous-extended-control", + "luminous-supreme", ], "Anthropic": [ "anthropic/claude-3-5-sonnet-20240620", @@ -33,11 +19,10 @@ "anyscale/meta-llama/Llama-2-13b-chat-hf", "anyscale/meta-llama/Llama-2-70b-chat-hf", ], - "Perplexity AI": [ - "perplexity/pplx-7b-chat", - "perplexity/pplx-70b-chat", - "perplexity/pplx-7b-online", - "perplexity/pplx-70b-online", + "Cohere": [ + "cohere/command-light", + "cohere/command-r-plus", + "cohere/command-nightly", ], "DeepInfra": [ "deepinfra/meta-llama/Llama-2-70b-chat-hf", @@ -46,6 +31,46 @@ "deepinfra/mistralai/Mistral-7B-Instruct-v0.1", "deepinfra/jondurbin/airoboros-l2-70b-gpt4-1.4.1", ], + "Gemini": [ + "gemini/gemini-1.5-pro-latest", + "gemini/gemini-1.5-flash", + ], + "Groq": [ + "groq/llama3-8b-8192", + "groq/llama3-70b-8192", + "groq/llama2-70b-4096", + "groq/mixtral-8x7b-32768", + "groq/gemma-7b-it", + ], + "Mistral AI": [ + "mistral/mistral-tiny", + "mistral/mistral-small", + "mistral/mistral-medium", + "mistral/mistral-large-latest", + ], + "Open AI": [ + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo", + "gpt-4", + "gpt-4o", + "gpt-4o-mini", + "gpt-4-1106-preview", + ], + "OpenRouter": [ + "openrouter/openai/gpt-3.5-turbo", + "openrouter/openai/gpt-3.5-turbo-16k", + "openrouter/anthropic/claude-instant-v1", + "openrouter/google/palm-2-chat-bison", + "openrouter/google/palm-2-codechat-bison", + "openrouter/meta-llama/llama-2-13b-chat", + "openrouter/meta-llama/llama-2-70b-chat", + ], + "Perplexity AI": [ + "perplexity/pplx-7b-chat", + "perplexity/pplx-70b-chat", + "perplexity/pplx-7b-online", + "perplexity/pplx-70b-online", + ], "Together AI": [ "together_ai/togethercomputer/llama-2-70b-chat", "together_ai/togethercomputer/llama-2-70b", @@ -59,26 +84,12 @@ "together_ai/NousResearch/Nous-Hermes-Llama2-13b", "together_ai/Austism/chronos-hermes-13b", ], - "Aleph Alpha": [ - "luminous-base", - "luminous-base-control", - "luminous-extended-control", - "luminous-supreme", - ], - "OpenRouter": [ - "openrouter/openai/gpt-3.5-turbo", - "openrouter/openai/gpt-3.5-turbo-16k", - "openrouter/anthropic/claude-instant-v1", - "openrouter/google/palm-2-chat-bison", - "openrouter/google/palm-2-codechat-bison", - "openrouter/meta-llama/llama-2-13b-chat", - "openrouter/meta-llama/llama-2-70b-chat", - ], - "Groq": [ - "groq/llama3-8b-8192", - "groq/llama3-70b-8192", - "groq/llama2-70b-4096", - "groq/mixtral-8x7b-32768", - "groq/gemma-7b-it", - ], +} + +providers_list = list(supported_llm_models.keys()) + +model_to_provider_mapping = { + model: provider + for provider, models in supported_llm_models.items() + for model in models } diff --git a/agenta-cli/agenta/sdk/managers/vault.py b/agenta-cli/agenta/sdk/managers/vault.py index f559af19d2..88f1908a00 100644 --- a/agenta-cli/agenta/sdk/managers/vault.py +++ b/agenta-cli/agenta/sdk/managers/vault.py @@ -2,6 +2,8 @@ from agenta.sdk.context.routing import routing_context +from agenta.sdk.assets import model_to_provider_mapping + class VaultManager: @staticmethod @@ -14,3 +16,23 @@ def get_from_route() -> Optional[Dict[str, Any]]: return None return secrets + + @staticmethod + def get_api_key_for_model(model: str) -> str: + secrets = VaultManager.get_from_route() + + if not secrets: + return None + + provider = model_to_provider_mapping.get(model) + + if not provider: + return None + + provider = provider.lower().replace(" ", "") + + for secret in secrets: + if secret["data"]["provider"] == provider: + return secret["data"]["key"] + + return None diff --git a/agenta-cli/agenta/sdk/middleware/vault.py b/agenta-cli/agenta/sdk/middleware/vault.py index c7b6a8877f..b19836d88d 100644 --- a/agenta-cli/agenta/sdk/middleware/vault.py +++ b/agenta-cli/agenta/sdk/middleware/vault.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, Optional +from typing import Callable, Dict, Optional, List from enum import Enum from os import getenv @@ -24,18 +24,18 @@ class SecretKind(str, Enum): # TODO: Move to backend client types class ProviderKind(str, Enum): - OPENAI = "openai" - COHERE = "cohere" + ALEPHALPHA = "alephalpha" + ANTHROPIC = "anthropic" ANYSCALE = "anyscale" + COHERE = "cohere" DEEPINFRA = "deepinfra" - ALEPHALPHA = "alephalpha" + GEMINI = "gemini" GROQ = "groq" MISTRALAI = "mistralai" - ANTHROPIC = "anthropic" + OPENAI = "openai" + OPENROUTER = "openrouter" PERPLEXITYAI = "perplexityai" TOGETHERAI = "togetherai" - OPENROUTER = "openrouter" - GEMINI = "gemini" # TODO: Move to backend client types @@ -97,7 +97,7 @@ async def _get_secrets(self, request: Request) -> Optional[Dict]: return secrets - local_secrets = [] + local_secrets: List[SecretDTO] = [] try: for provider_kind in ProviderKind: @@ -120,7 +120,7 @@ async def _get_secrets(self, request: Request) -> Optional[Dict]: except: # pylint: disable=bare-except display_exception("Vault: Local Secrets Exception") - vault_secrets = [] + vault_secrets: List[SecretDTO] = [] try: async with httpx.AsyncClient() as client: @@ -135,7 +135,8 @@ async def _get_secrets(self, request: Request) -> Optional[Dict]: else: vault = response.json() - vault_secrets = vault.get("secrets") + vault_secrets = [secret.secret.model_dump() for secret in vault] + except: # pylint: disable=bare-except display_exception("Vault: Vault Secrets Exception") From b3fea92a0f26b997338dda1de5fd50b18563b00e Mon Sep 17 00:00:00 2001 From: Abram Date: Tue, 17 Dec 2024 13:52:10 +0100 Subject: [PATCH 10/11] refactor (sdk): - resolve AttributeError in vault middleware - added function to transform secrets response to secret dto - replace types in vault middleware with backend client types - fixed error when iterating ProviderKind --- TypeError: '_UnionGenericAlias' object is not iterable --- agenta-cli/agenta/sdk/middleware/vault.py | 87 +++++++++++------------ 1 file changed, 43 insertions(+), 44 deletions(-) diff --git a/agenta-cli/agenta/sdk/middleware/vault.py b/agenta-cli/agenta/sdk/middleware/vault.py index b19836d88d..d7c1793af2 100644 --- a/agenta-cli/agenta/sdk/middleware/vault.py +++ b/agenta-cli/agenta/sdk/middleware/vault.py @@ -1,53 +1,41 @@ -from typing import Callable, Dict, Optional, List - -from enum import Enum from os import getenv from json import dumps - -from pydantic import BaseModel +from typing import Callable, Dict, Optional, List, Any import httpx -from starlette.middleware.base import BaseHTTPMiddleware from fastapi import FastAPI, Request +from starlette.middleware.base import BaseHTTPMiddleware -from agenta.sdk.middleware.cache import TTLLRUCache, CACHE_CAPACITY, CACHE_TTL from agenta.sdk.utils.constants import TRUTHY +from agenta.client.backend.types.provider_kind import ProviderKind from agenta.sdk.utils.exceptions import suppress, display_exception +from agenta.client.backend.types.secret_dto import SecretDto as SecretDTO +from agenta.client.backend.types.provider_key_dto import ( + ProviderKeyDto as ProviderKeyDTO, +) +from agenta.sdk.middleware.cache import TTLLRUCache, CACHE_CAPACITY, CACHE_TTL import agenta as ag -# TODO: Move to backend client types -class SecretKind(str, Enum): - PROVIDER_KEY = "provider_key" - +# ProviderKind (agenta.client.backend.types.provider_kind import ProviderKind) defines a type hint that allows \ +# for a fixed set of string literals representing various provider names, alongside `typing.Any`. +PROVIDER_KINDS = [] -# TODO: Move to backend client types -class ProviderKind(str, Enum): - ALEPHALPHA = "alephalpha" - ANTHROPIC = "anthropic" - ANYSCALE = "anyscale" - COHERE = "cohere" - DEEPINFRA = "deepinfra" - GEMINI = "gemini" - GROQ = "groq" - MISTRALAI = "mistralai" - OPENAI = "openai" - OPENROUTER = "openrouter" - PERPLEXITYAI = "perplexityai" - TOGETHERAI = "togetherai" +# Rationale behind the following: +# ------------------------------- +# You cannot loop directly over the values in `typing.Literal` because: +# - `Literal` is not iterable. +# - `ProviderKind.__args__` includes `Literal` and `Any`, but the actual string values +# are nested within the `Literal`'s own `__args__` attribute. - -# TODO: Move to backend client types -class ProviderKeyDTO(BaseModel): - provider: ProviderKind - key: str - - -# TODO: Move to backend client types -class SecretDTO(BaseModel): - kind: SecretKind = "provider_key" - data: ProviderKeyDTO +# To solve this, we programmatically extract the values from `Literal` while retaining +# the structure of ProviderKind. This ensures: +# 1. We don't modify the original `ProviderKind` type definition. +# 2. We dynamically access the literal values for use at runtime when necessary. +for arg in ProviderKind.__args__: # type: ignore + if hasattr(arg, "__args__"): + PROVIDER_KINDS.extend(arg.__args__) _CACHE_ENABLED = getenv("AGENTA_MIDDLEWARE_CACHE_ENABLED", "true").lower() in TRUTHY @@ -61,6 +49,18 @@ def __init__(self, app: FastAPI): self.host = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.host + def _transform_secrets_response_to_secret_dto( + self, secrets_list: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + secrets_dto_dict = [ + { + "kind": secret.get("secret", {}).get("kind"), + "data": secret.get("secret", {}).get("data", {}), + } + for secret in secrets_list + ] + return secrets_dto_dict + async def dispatch( self, request: Request, @@ -100,16 +100,15 @@ async def _get_secrets(self, request: Request) -> Optional[Dict]: local_secrets: List[SecretDTO] = [] try: - for provider_kind in ProviderKind: - provider = provider_kind.value + for provider_kind in PROVIDER_KINDS: + provider = provider_kind key_name = f"{provider.upper()}_API_KEY" key = getenv(key_name) if not key: continue - secret = SecretDTO( - kind=SecretKind.PROVIDER_KEY, + secret = SecretDTO( # 'kind' attribute in SecretDTO defaults to 'provider_kind' data=ProviderKeyDTO( provider=provider, key=key, @@ -133,10 +132,10 @@ async def _get_secrets(self, request: Request) -> Optional[Dict]: vault_secrets = [] else: - vault = response.json() - - vault_secrets = [secret.secret.model_dump() for secret in vault] - + secrets = response.json() + vault_secrets = self._transform_secrets_response_to_secret_dto( + secrets + ) except: # pylint: disable=bare-except display_exception("Vault: Vault Secrets Exception") From 2eda1a027f06fd87126387e400f4cfdfd7a12b9c Mon Sep 17 00:00:00 2001 From: Juan Pablo Vega Date: Wed, 18 Dec 2024 13:25:00 +0100 Subject: [PATCH 11/11] VaultManager to SecretsManager --- agenta-cli/agenta/__init__.py | 2 +- agenta-cli/agenta/sdk/__init__.py | 2 +- agenta-cli/agenta/sdk/managers/{vault.py => secrets.py} | 4 ++-- agenta-cli/tests/baggage/app.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) rename agenta-cli/agenta/sdk/managers/{vault.py => secrets.py} (91%) diff --git a/agenta-cli/agenta/__init__.py b/agenta-cli/agenta/__init__.py index 53600a4a1e..59629b8dde 100644 --- a/agenta-cli/agenta/__init__.py +++ b/agenta-cli/agenta/__init__.py @@ -28,7 +28,7 @@ from .sdk.utils.costs import calculate_token_usage from .sdk.client import Agenta from .sdk.litellm import litellm as callbacks -from .sdk.managers.vault import VaultManager +from .sdk.managers.secrets import SecretsManager from .sdk.managers.config import ConfigManager from .sdk.managers.variant import VariantManager from .sdk.managers.deployment import DeploymentManager diff --git a/agenta-cli/agenta/sdk/__init__.py b/agenta-cli/agenta/sdk/__init__.py index 4fc475ef45..9f3dc86628 100644 --- a/agenta-cli/agenta/sdk/__init__.py +++ b/agenta-cli/agenta/sdk/__init__.py @@ -27,7 +27,7 @@ from .decorators.routing import entrypoint, app, route from .agenta_init import Config, AgentaSingleton, init as _init from .utils.costs import calculate_token_usage -from .managers.vault import VaultManager +from .managers.secrets import SecretsManager from .managers.config import ConfigManager from .managers.variant import VariantManager from .managers.deployment import DeploymentManager diff --git a/agenta-cli/agenta/sdk/managers/vault.py b/agenta-cli/agenta/sdk/managers/secrets.py similarity index 91% rename from agenta-cli/agenta/sdk/managers/vault.py rename to agenta-cli/agenta/sdk/managers/secrets.py index 88f1908a00..aca5988811 100644 --- a/agenta-cli/agenta/sdk/managers/vault.py +++ b/agenta-cli/agenta/sdk/managers/secrets.py @@ -5,7 +5,7 @@ from agenta.sdk.assets import model_to_provider_mapping -class VaultManager: +class SecretsManager: @staticmethod def get_from_route() -> Optional[Dict[str, Any]]: context = routing_context.get() @@ -19,7 +19,7 @@ def get_from_route() -> Optional[Dict[str, Any]]: @staticmethod def get_api_key_for_model(model: str) -> str: - secrets = VaultManager.get_from_route() + secrets = SecretsManager.get_from_route() if not secrets: return None diff --git a/agenta-cli/tests/baggage/app.py b/agenta-cli/tests/baggage/app.py index 9a4ebc1004..465b65b0ac 100644 --- a/agenta-cli/tests/baggage/app.py +++ b/agenta-cli/tests/baggage/app.py @@ -11,7 +11,7 @@ @ag.instrument() def main(aloha: str = "Aloha") -> str: print(ag.ConfigManager.get_from_route()) - print(ag.VaultManager.get_from_route()) + print(ag.SecretsManager.get_from_route()) print(ag.config.flag) return aloha