diff --git a/src/ell/api/__main__.py b/src/ell/api/__main__.py index de1c2b87..c4adce41 100644 --- a/src/ell/api/__main__.py +++ b/src/ell/api/__main__.py @@ -15,9 +15,12 @@ def main(): help="PostgreSQL connection string (default: None)") parser.add_argument("--mqtt-connection-string", default=None, help="MQTT connection string (default: None)") - parser.add_argument("--host", default="127.0.0.1", help="Host to run the server on") - parser.add_argument("--port", type=int, default=8080, help="Port to run the server on") - parser.add_argument("--dev", action="store_true", help="Run in development mode") + parser.add_argument("--host", default="0.0.0.0", + help="Host to run the server on") + parser.add_argument("--port", type=int, default=8080, + help="Port to run the server on") + parser.add_argument("--dev", action="store_true", + help="Run in development mode") args = parser.parse_args() config = Config( @@ -30,12 +33,18 @@ def main(): loop = asyncio.new_event_loop() - config = uvicorn.Config(app=app, port=args.port, loop=loop) + config = uvicorn.Config( + app=app, + host=args.host, + port=args.port, + loop=loop # type: ignore + ) server = uvicorn.Server(config) loop.create_task(server.serve()) loop.run_forever() + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/ell/api/config.py b/src/ell/api/config.py index 6116a825..37086c14 100644 --- a/src/ell/api/config.py +++ b/src/ell/api/config.py @@ -20,6 +20,7 @@ class Config(BaseModel): storage_dir: Optional[str] = None pg_connection_string: Optional[str] = None mqtt_connection_string: Optional[str] = None + log_level: int = logging.INFO def __init__(self, **kwargs: Any): super().__init__(**kwargs) diff --git a/src/ell/api/logger.py b/src/ell/api/logger.py new file mode 100644 index 00000000..2983125f --- /dev/null +++ b/src/ell/api/logger.py @@ -0,0 +1,40 @@ +import logging +from colorama import Fore, Style, init + +initialized = False + +def setup_logging(level: int = logging.INFO): + global initialized + if initialized: + return + # Initialize colorama for cross-platform colored output + init(autoreset=True) + + # Create a custom formatter + class ColoredFormatter(logging.Formatter): + FORMATS = { + logging.DEBUG: Fore.CYAN + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL, + logging.INFO: Fore.GREEN + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL, + logging.WARNING: Fore.YELLOW + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL, + logging.ERROR: Fore.RED + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL, + logging.CRITICAL: Fore.RED + Style.BRIGHT + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL + } + + def format(self, record: logging.LogRecord) -> str: + log_fmt = self.FORMATS.get(record.levelno) + formatter = logging.Formatter(log_fmt, datefmt="%Y-%m-%d %H:%M:%S") + return formatter.format(record) + + # Create and configure the logger + logger = logging.getLogger("ell") + logger.setLevel(level) + + # Create console handler and set formatter + console_handler = logging.StreamHandler() + console_handler.setFormatter(ColoredFormatter()) + + # Add the handler to the logger + logger.addHandler(console_handler) + initialized = True + + return logger \ No newline at end of file diff --git a/src/ell/api/publisher.py b/src/ell/api/publisher.py new file mode 100644 index 00000000..cace87e4 --- /dev/null +++ b/src/ell/api/publisher.py @@ -0,0 +1,22 @@ +from abc import ABC, abstractmethod + +import aiomqtt + + +class Publisher(ABC): + @abstractmethod + async def publish(self, topic: str, message: str) -> None: + pass + + +class MqttPub(Publisher): + def __init__(self, conn: aiomqtt.Client): + self.mqtt_client = conn + + async def publish(self, topic: str, message: str) -> None: + await self.mqtt_client.publish(topic, message) + + +class NoopPublisher(Publisher): + async def publish(self, topic: str, message: str) -> None: + pass \ No newline at end of file diff --git a/src/ell/api/server.py b/src/ell/api/server.py index 38c55e67..c1ada1df 100644 --- a/src/ell/api/server.py +++ b/src/ell/api/server.py @@ -1,4 +1,3 @@ -from abc import ABC, abstractmethod import asyncio from contextlib import asynccontextmanager import json @@ -9,35 +8,18 @@ from fastapi import Depends, FastAPI, HTTPException from sqlmodel import Session from ell.api.config import Config -from ell.api.types import GetLMPResponse, WriteLMPInput, LMP +from ell.api.publisher import MqttPub, NoopPublisher, Publisher +from ell.api.types import GetLMPResponse, WriteInvocationInput, WriteLMPInput, LMP from ell.store import Store from ell.stores.sql import PostgresStore, SQLStore, SQLiteStore +from ell.studio.logger import setup_logging from ell.types import Invocation, SerializedLStr logger = logging.getLogger(__name__) -class Publisher(ABC): - @abstractmethod - async def publish(self, topic: str, message: str) -> None: - pass - - -class MqttPub(Publisher): - def __init__(self, conn: aiomqtt.Client): - self.mqtt_client = conn - - async def publish(self, topic: str, message: str) -> None: - await self.mqtt_client.publish(topic, message) - - -class NoopPublisher(Publisher): - async def publish(self, topic: str, message: str) -> None: - pass - - -publisher = None +publisher: Publisher | None = None async def get_publisher(): @@ -51,7 +33,7 @@ def init_serializer(config: Config) -> SQLStore: if serializer is not None: return serializer elif config.pg_connection_string: - return PostgresStore(config.pg_connection_string) + return PostgresStore(config.pg_connection_string) elif config.storage_dir: return SQLiteStore(config.storage_dir) else: @@ -72,12 +54,12 @@ def get_session(): def create_app(config: Config): + setup_logging(config.log_level) + app = FastAPI( title="ELL API", description="API server for ELL", version="0.1.0", - # dependencies=[Depends(get_publisher), - # Depends(get_serializer)] ) @asynccontextmanager @@ -132,29 +114,31 @@ async def write_lmp( ) ) - @app.post("/invocation") + @app.post("/invocation", response_model=WriteInvocationInput) async def write_invocation( - invocation: Invocation, - results: List[SerializedLStr], - consumes: Set[str], + input: WriteInvocationInput, publisher: Publisher = Depends(get_publisher), serializer: Store = Depends(get_serializer) ): + ser_input = input.to_serialized_invocation_input() serializer.write_invocation( - invocation, - results, - consumes + invocation=ser_input['invocation'], + results=ser_input['results'], + consumes=ser_input['consumes'] ) loop = asyncio.get_event_loop() loop.create_task( publisher.publish( - f"lmp/{invocation.lmp_id}/invoked", + f"lmp/{input.invocation.lmp_id}/invoked", json.dumps({ - "invocation": invocation, - "results": results, - "consumes": consumes + 'foo': 'bar' + # "not json serializable lol" + # "invocation": input.invocation, + # "results": results, + # "consumes": consumes }) ) ) + return input return app diff --git a/src/ell/api/types.py b/src/ell/api/types.py index d3e059e9..af589619 100644 --- a/src/ell/api/types.py +++ b/src/ell/api/types.py @@ -1,27 +1,13 @@ -from typing import Annotated, Any, Dict, Optional, cast -from datetime import datetime, timezone +from typing import Any, Dict, List, Optional, Set, cast +from datetime import datetime +from numpy import ndarray from openai import BaseModel -from pydantic import AwareDatetime, BeforeValidator, Field +from pydantic import AwareDatetime, Field +from ell.lstr import lstr -from ell.types import SerializedLMP, utc_now - - -def iso_timestamp_to_utc_datetime(v: datetime) -> datetime: - if isinstance(v, str): - return datetime.fromisoformat(v).replace(tzinfo=timezone.utc) - # elif isinstance(v, datetime): - # if v.tzinfo is not timezone.utc: - # raise ValueError(f"Invalid value for UTCTimestampField: {v}") - # return v - elif v is None: - return None - raise ValueError(f"Invalid value for UTCTimestampField: {v}") - - -# todo. does pydantic compose optional with this or do we have to in the before validator...? -UTCTimestamp = Annotated[AwareDatetime, - BeforeValidator(iso_timestamp_to_utc_datetime)] +from ell.types import SerializedLMP, SerializedLStr, utc_now +import ell.types class WriteLMPInput(BaseModel): @@ -33,12 +19,12 @@ class WriteLMPInput(BaseModel): source: str dependencies: str is_lm: bool - lm_kwargs: Optional[Dict[str, Any]] - initial_free_vars: Optional[Dict[str, Any]] - initial_global_vars: Optional[Dict[str, Any]] + lm_kwargs: Optional[Dict[str, Any]] = None + initial_free_vars: Optional[Dict[str, Any]] = None + initial_global_vars: Optional[Dict[str, Any]] = None # num_invocations: Optional[int] - commit_message: Optional[str] - version_number: Optional[int] + commit_message: Optional[str] = None + version_number: Optional[int] = None created_at: Optional[AwareDatetime] = Field(default_factory=utc_now) def to_serialized_lmp(self): @@ -92,7 +78,72 @@ def from_serialized_lmp(serialized: SerializedLMP): # lmp: LMP # uses: List[str] + GetLMPResponse = LMP # class LMPCreatedEvent(BaseModel): # lmp: LMP # uses: List[str] + + +class Invocation(BaseModel): + """ + An invocation of an LMP. + """ + id: Optional[str] = None + lmp_id: str + args: List[Any] + kwargs: Dict[str, Any] + global_vars: Dict[str, Any] + free_vars: Dict[str, Any] + latency_ms: int + invocation_kwargs: Dict[str, Any] + prompt_tokens: Optional[int] = None + completion_tokens: Optional[int] = None + state_cache: Optional[str] = None + created_at: AwareDatetime = Field(default_factory=utc_now) + # used_by_id: Optional[str] = None + + def to_serialized_invocation(self): + return ell.types.Invocation( + **self.model_dump() + ) + + +class WriteInvocationInputLStr(BaseModel): + id: Optional[str] = None + content: str + logits: Optional[List[float]] = None + + +def lstr_to_serialized_lstr(ls: lstr) -> SerializedLStr: + return SerializedLStr( + content=str(ls), + logits=ls.logits if ls.logits is not None else None + ) + + +class WriteInvocationInput(BaseModel): + """ + Arguments to write an invocation. + """ + invocation: Invocation + results: List[WriteInvocationInputLStr] + consumes: List[str] + + def to_serialized_invocation_input(self): + results = [ + SerializedLStr( + id=ls.id, + content=ls.content, + logits=ndarray( + ls.logits) if ls.logits is not None else None + ) + for ls in self.results] + + sinvo = self.invocation.to_serialized_invocation() + return { + 'invocation': sinvo, + 'results': results, + # todo. is this a list or a set? + 'consumes': self.consumes + } diff --git a/src/ell/studio/logger.py b/src/ell/studio/logger.py index 58a493e6..ca291177 100644 --- a/src/ell/studio/logger.py +++ b/src/ell/studio/logger.py @@ -3,6 +3,9 @@ initialized = False def setup_logging(level: int = logging.INFO): + global initialized + if initialized: + return # Initialize colorama for cross-platform colored output init(autoreset=True) @@ -31,7 +34,7 @@ def format(self, record): # Add the handler to the logger logger.addHandler(console_handler) - global initialized + initialized = True return logger \ No newline at end of file diff --git a/src/ell/types.py b/src/ell/types.py index 907434a9..e90f90d8 100644 --- a/src/ell/types.py +++ b/src/ell/types.py @@ -1,9 +1,7 @@ # Let's define the core types. from dataclasses import dataclass -from typing import Annotated, Callable, Dict, List, Union, Any, Optional -from pydantic import BeforeValidator +from typing import Callable, Dict, List, Union, Any, Optional -from sqlalchemy.engine.interfaces import Dialect from ell.lstr import lstr from ell.util.dict_sync_meta import DictSyncMeta @@ -70,12 +68,6 @@ class UTCTimestamp(types.TypeDecorator[datetime]): impl = types.TIMESTAMP def process_result_value(self, value: datetime, dialect:Any): return value.replace(tzinfo=timezone.utc) - # def process_bind_param(self, value: str|datetime, dialect:Any): - # if isinstance(value, str): - # return datetime.fromisoformat(value).replace(tzinfo=timezone.utc) - # elif isinstance(value, datetime): - # return value.replace(tzinfo=timezone.utc) - # raise ValueError(f"Invalid value for UTCTimestamp: {value}") def UTCTimestampField(index:bool=False, **kwargs:Any): diff --git a/tests/api/test_api.py b/tests/api/test_api.py index b90bbb41..4d54c68c 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -11,13 +11,14 @@ from ell.stores.sql import SQLStore, SQLiteStore from ell.studio.logger import setup_logging -from ell.types import SerializedLMP, utc_now +from ell.types import SerializedLMP, utc_now @pytest.fixture def sql_store() -> SQLStore: return SQLiteStore(":memory:") + def test_construct_serialized_lmp(): serialized_lmp = SerializedLMP( lmp_id="test_lmp_id", @@ -41,8 +42,9 @@ def test_construct_serialized_lmp(): assert serialized_lmp.version_number == 1 assert serialized_lmp.created_at is not None + def test_write_lmp_input(): - ## Should be able to construct a WriteLMPInput from data + # Should be able to construct a WriteLMPInput from data input = WriteLMPInput( lmp_id="test_lmp_id", name="Test LMP", @@ -60,7 +62,7 @@ def test_write_lmp_input(): assert input.created_at is not None assert input.created_at.tzinfo == timezone.utc - ## Should be able to construct a SerializedLMP from a WriteLMPInput + # Should be able to construct a SerializedLMP from a WriteLMPInput model = SerializedLMP(**input.model_dump()) assert model.created_at == input.created_at @@ -76,7 +78,7 @@ def test_write_lmp_input(): commit_message="Initial commit", version_number=1, # should work with an isoformat string - created_at=utc_now().isoformat() # type: ignore + created_at=utc_now().isoformat() # type: ignore ) model2 = SerializedLMP(**input2.model_dump()) assert model2.created_at == input2.created_at @@ -84,20 +86,20 @@ def test_write_lmp_input(): assert input2.created_at.tzinfo == timezone.utc - -def test_write_lmp(sql_store: SQLStore): +def create_test_app(sql_store: SQLStore): setup_logging(DEBUG) config = Config(storage_dir=":memory:") app = create_app(config) publisher = NoopPublisher() + async def get_publisher_override(): yield publisher async def get_session_override(): with Session(sql_store.engine) as session: yield session - + def get_serializer_override(): return sql_store @@ -107,8 +109,13 @@ def get_serializer_override(): client = TestClient(app) - - lmp_data:Dict[str, Any] = { + return app, client, publisher, config + + +def test_write_lmp(sql_store: SQLStore): + _app, client, *_ = create_test_app(sql_store) + + lmp_data: Dict[str, Any] = { "lmp_id": uuid4().hex, "name": "Test LMP", "source": "def test_function(): pass", @@ -122,7 +129,7 @@ def get_serializer_override(): "commit_message": "Initial commit", "created_at": utc_now().isoformat().replace("+00:00", "Z") } - uses:Dict[str, Any] = { + uses: Dict[str, Any] = { "used_lmp_1": {}, "used_lmp_2": {} } @@ -131,7 +138,7 @@ def get_serializer_override(): "/lmp", json={ "lmp": lmp_data, - "uses": uses + "uses": uses } ) @@ -142,36 +149,59 @@ def get_serializer_override(): del lmp_data["uses"] assert lmp.json() == {**lmp_data, "num_invocations": 0} + def test_write_invocation(sql_store: SQLStore): - config = Config(storage_dir=":memory:") - app = create_app(config) - client = TestClient(app) + _app, client, *_ = create_test_app(sql_store) + + lmp_id = uuid4().hex + lmp_data: Dict[str, Any] = { + "lmp_id": lmp_id, + "name": "Test LMP", + "source": "def test_function(): pass", + "dependencies": str(["dep1", "dep2"]), + "is_lm": True, + } + response = client.post( + "/lmp", + json={'lmp': lmp_data, 'uses': {}} + ) + assert response.status_code == 200 invocation_data = { - "lmp_id": "test_lmp_id", - "name": "Test Invocation", - "description": "This is a test invocation" + "id": uuid4().hex, + "lmp_id": lmp_id, + "args": ["arg1", "arg2"], + "kwargs": {"kwarg1": "value1"}, + "global_vars": {"global_var1": "value1"}, + "free_vars": {"free_var1": "value2"}, + "latency_ms": 100.0, + "invocation_kwargs": {"model": "gpt-4o", "messages": [{"role": "system", "content": "You are a JSON parser. You respond only in JSON. Do not format using markdown."}, {"role": "user", "content": "You are given the following task: \"What is two plus two?\"\n Parse the task into the following type:\n {'$defs': {'Add': {'properties': {'op': {'const': '+', 'enum': ['+'], 'title': 'Op', 'type': 'string'}, 'a': {'title': 'A', 'type': 'number'}, 'b': {'title': 'B', 'type': 'number'}}, 'required': ['op', 'a', 'b'], 'title': 'Add', 'type': 'object'}, 'Div': {'properties': {'op': {'const': '/', 'enum': ['/'], 'title': 'Op', 'type': 'string'}, 'a': {'title': 'A', 'type': 'number'}, 'b': {'title': 'B', 'type': 'number'}}, 'required': ['op', 'a', 'b'], 'title': 'Div', 'type': 'object'}, 'Mul': {'properties': {'op': {'const': '*', 'enum': ['*'], 'title': 'Op', 'type': 'string'}, 'a': {'title': 'A', 'type': 'number'}, 'b': {'title': 'B', 'type': 'number'}}, 'required': ['op', 'a', 'b'], 'title': 'Mul', 'type': 'object'}, 'Sub': {'properties': {'op': {'const': '-', 'enum': ['-'], 'title': 'Op', 'type': 'string'}, 'a': {'title': 'A', 'type': 'number'}, 'b': {'title': 'B', 'type': 'number'}}, 'required': ['op', 'a', 'b'], 'title': 'Sub', 'type': 'object'}}, 'anyOf': [{'$ref': '#/$defs/Add'}, {'$ref': '#/$defs/Sub'}, {'$ref': '#/$defs/Mul'}, {'$ref': '#/$defs/Div'}]}\n "}], "lm_kwargs": {"temperature": 0.1}, "client": None} } results_data = [ { - "result_id": "test_result_id", - "name": "Test Result", - "description": "This is a test result" + "content": """{ + "op": "+", + "a": 2, + "b": 2 +}""" } ] - consumes_data = ["test_consumes_id"] + consumes_data = [] + input = { + "invocation": invocation_data, + "results": results_data, + "consumes": consumes_data + } response = client.post( "/invocation", - json={ - "invocation": invocation_data, - "results": results_data, - "consumes": consumes_data - } + json=input ) + print(response.json()) assert response.status_code == 200 - assert response.json() == {"message": "Invocation written successfully"} + # assert response.json() == input + if __name__ == "__main__": - pytest.main() \ No newline at end of file + pytest.main()