Skip to content

Commit

Permalink
Update - improve execute_function and entrypoint functions
Browse files Browse the repository at this point in the history
  • Loading branch information
aybruhm committed Dec 10, 2023
1 parent 8be6ab6 commit e9356b3
Showing 1 changed file with 34 additions and 11 deletions.
45 changes: 34 additions & 11 deletions agenta-cli/agenta/sdk/agenta_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@
import inspect
import os
import sys
import time
import traceback
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any, Callable, Dict, Optional, Tuple
from typing import Any, Callable, Dict, Optional, Tuple, Union, TypeVar

import agenta
from fastapi.responses import JSONResponse
from fastapi import Body, FastAPI, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

import agenta
from .context import save_context
from .router import router as router
from .types import (
Expand All @@ -26,9 +27,11 @@
TextParam,
MessagesInput,
FileInputURL,
FuncResponse,
)

app = FastAPI()
T = TypeVar("T")

origins = [
"*",
Expand All @@ -52,7 +55,7 @@ def ingest_file(upfile: UploadFile):
return InFile(file_name=upfile.filename, file_path=temp_file.name)


def entrypoint(func: Callable[..., Any]) -> Callable[..., Any]:
def entrypoint(func: Callable[..., T]) -> Callable[..., Dict[str, T]]:
"""
Decorator to wrap a function for HTTP POST and terminal exposure.
Expand All @@ -68,14 +71,14 @@ def entrypoint(func: Callable[..., Any]) -> Callable[..., Any]:
ingestible_files = extract_ingestible_files(func_signature)

@functools.wraps(func)
def wrapper(*args, **kwargs) -> Any:
def wrapper(*args, **kwargs) -> Dict[str, T]:
func_params, api_config_params = split_kwargs(kwargs, config_params)
ingest_files(func_params, ingestible_files)
agenta.config.set(**api_config_params)
return execute_function(func, *args, **func_params)

@functools.wraps(func)
def wrapper_deployed(*args, **kwargs) -> Any:
def wrapper_deployed(*args, **kwargs) -> FuncResponse:
func_params = {
k: v for k, v in kwargs.items() if k not in ["config", "environment"]
}
Expand All @@ -89,15 +92,15 @@ def wrapper_deployed(*args, **kwargs) -> Any:

update_function_signature(wrapper, func_signature, config_params, ingestible_files)
route = f"/{endpoint_name}"
app.post(route)(wrapper)
app.post(route, response_model=FuncResponse)(wrapper)

update_deployed_function_signature(
wrapper_deployed,
func_signature,
ingestible_files,
)
route_deployed = f"/{endpoint_name}_deployed"
app.post(route_deployed)(wrapper_deployed)
app.post(route_deployed, response_model=FuncResponse)(wrapper_deployed)
override_schema(
openapi_schema=app.openapi(),
func_name=func.__name__,
Expand Down Expand Up @@ -142,13 +145,33 @@ def ingest_files(
func_params[name] = ingest_file(func_params[name])


def execute_function(func: Callable[..., Any], *args, **func_params) -> Any:
"""Execute the function and handle any exceptions."""
def execute_function(
func: Callable[..., Any], *args, **func_params
) -> Union[Dict[str, Any], JSONResponse]:
"""
Execute the given function and handle any exceptions.
Parameters:
- func: The function to be executed.
- args: Positional arguments for the function.
- func_params: Keyword arguments for the function.
Returns:
Either a dictionary or a JSONResponse object.
"""

try:
start_time = time.time()
result = func(*args, **func_params)
end_time = time.time()
latency = end_time - start_time

if isinstance(result, Context):
save_context(result)
return result
if isinstance(result, FuncResponse):
return FuncResponse(**result, latency=str(latency)).dict()
if isinstance(result, str):
return FuncResponse(message=result, latency=str(latency)).dict()
except Exception as e:
return handle_exception(e)

Expand Down

0 comments on commit e9356b3

Please sign in to comment.