diff --git a/src/ell/api/__init__.py b/src/ell/api/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/ell/api/__main__.py b/src/ell/api/__main__.py
new file mode 100644
index 00000000..de1c2b87
--- /dev/null
+++ b/src/ell/api/__main__.py
@@ -0,0 +1,41 @@
+import asyncio
+import uvicorn
+from argparse import ArgumentParser
+from ell.api.config import Config
+from ell.api.server import create_app
+from ell.studio.logger import setup_logging
+
+
+def main():
+    setup_logging()
+    parser = ArgumentParser(description="ELL API Server")
+    parser.add_argument("--storage-dir", default=None,
+                        help="Storage directory (default: None)")
+    parser.add_argument("--pg-connection-string", default=None,
+                        help="PostgreSQL connection string (default: None)")
+    parser.add_argument("--mqtt-connection-string", default=None,
+                        help="MQTT connection string (default: None)")
+    parser.add_argument("--host", default="127.0.0.1", help="Host to run the server on")
+    parser.add_argument("--port", type=int, default=8080, help="Port to run the server on")
+    parser.add_argument("--dev", action="store_true", help="Run in development mode")
+    args = parser.parse_args()
+
+    config = Config(
+        storage_dir=args.storage_dir,
+        pg_connection_string=args.pg_connection_string,
+        mqtt_connection_string=args.mqtt_connection_string,
+    )
+
+    app = create_app(config)
+
+    loop = asyncio.new_event_loop()
+
+    server_config = uvicorn.Config(app=app, host=args.host, port=args.port, loop=loop)
+    server = uvicorn.Server(server_config)
+
+    loop.create_task(server.serve())
+
+    loop.run_forever()
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/src/ell/api/config.py b/src/ell/api/config.py
new file mode 100644
index 00000000..6116a825
--- /dev/null
+++ b/src/ell/api/config.py
@@ -0,0 +1,46 @@
+from functools import lru_cache
+import json
+import os
+from typing import Any, Optional
+from pydantic import BaseModel
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+# todo. maybe we default storage dir and other things in the future to a well-known location
+# like ~/.ell or something
+@lru_cache(maxsize=1)
+def ell_home() -> str:
+    return os.path.join(os.path.expanduser("~"), ".ell")
+
+
+class Config(BaseModel):
+    storage_dir: Optional[str] = None
+    pg_connection_string: Optional[str] = None
+    mqtt_connection_string: Optional[str] = None
+
+    def __init__(self, **kwargs: Any):
+        super().__init__(**kwargs)
+
+    def model_post_init(self, __context: Any):
+        # Storage
+        self.pg_connection_string = self.pg_connection_string or os.getenv(
+            "ELL_PG_CONNECTION_STRING")
+        self.storage_dir = self.storage_dir or os.getenv("ELL_STORAGE_DIR")
+
+        # Enforce that we use either sqlite or postgres, but not both
+        if self.pg_connection_string is not None and self.storage_dir is not None:
+            raise ValueError("Cannot use both sqlite and postgres")
+
+        # For now, fall back to sqlite if no PostgreSQL connection string is provided
+        if self.pg_connection_string is None and self.storage_dir is None:
+            # This intends to honor the default we had set in the CLI
+            # todo. better default?
+            self.storage_dir = os.getcwd()
+
+        # Pubsub
+        self.mqtt_connection_string = self.mqtt_connection_string or os.getenv("ELL_MQTT_CONNECTION_STRING")
+
+        logger.info(f"Resolved config: {json.dumps(self.model_dump(), indent=2)}")
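Note for reviewers: the resolution order in Config.model_post_init means an explicit argument wins over the corresponding ELL_* environment variable, with the current working directory as the final SQLite fallback. A minimal sketch of that behavior, using only the variable names read above (paths are illustrative):

    import os
    from ell.api.config import Config

    # An explicit argument beats the environment variable...
    os.environ["ELL_STORAGE_DIR"] = "/tmp/ell-env"
    assert Config(storage_dir="/tmp/ell-cli").storage_dir == "/tmp/ell-cli"

    # ...and with neither set, storage falls back to the current directory.
    del os.environ["ELL_STORAGE_DIR"]
    assert Config().storage_dir == os.getcwd()
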
diff --git a/src/ell/api/server.py b/src/ell/api/server.py
new file mode 100644
index 00000000..38c55e67
--- /dev/null
+++ b/src/ell/api/server.py
@@ -0,0 +1,162 @@
+from abc import ABC, abstractmethod
+import asyncio
+from contextlib import asynccontextmanager
+import json
+import logging
+from typing import Any, Dict, List, Set
+
+import aiomqtt
+from fastapi import Depends, FastAPI, HTTPException
+from sqlmodel import Session
+
+from ell.api.config import Config
+from ell.api.types import GetLMPResponse, WriteLMPInput, LMP
+from ell.store import Store
+from ell.stores.sql import PostgresStore, SQLStore, SQLiteStore
+from ell.types import Invocation, SerializedLStr
+
+logger = logging.getLogger(__name__)
+
+
+class Publisher(ABC):
+    @abstractmethod
+    async def publish(self, topic: str, message: str) -> None:
+        pass
+
+
+class MqttPub(Publisher):
+    def __init__(self, conn: aiomqtt.Client):
+        self.mqtt_client = conn
+
+    async def publish(self, topic: str, message: str) -> None:
+        await self.mqtt_client.publish(topic, message)
+
+
+class NoopPublisher(Publisher):
+    async def publish(self, topic: str, message: str) -> None:
+        pass
+
+
+publisher: Publisher | None = None
+
+
+async def get_publisher():
+    yield publisher
+
+
+serializer: SQLStore | None = None
+
+
+def init_serializer(config: Config) -> SQLStore:
+    global serializer
+    if serializer is not None:
+        return serializer
+    elif config.pg_connection_string:
+        return PostgresStore(config.pg_connection_string)
+    elif config.storage_dir:
+        return SQLiteStore(config.storage_dir)
+    else:
+        raise ValueError("No storage configuration found")
+
+
+def get_serializer():
+    if serializer is None:
+        raise ValueError("Serializer not initialized")
+    return serializer
+
+
+def get_session():
+    if serializer is None:
+        raise ValueError("Serializer not initialized")
+    with Session(serializer.engine) as session:
+        yield session
+
+
+def create_app(config: Config):
+    @asynccontextmanager
+    async def lifespan(app: FastAPI):
+        global serializer
+        global publisher
+
+        logger.info("Starting lifespan")
+
+        serializer = init_serializer(config)
+
+        if config.mqtt_connection_string is not None:
+            try:
+                async with aiomqtt.Client(config.mqtt_connection_string) as mqtt:
+                    logger.info("Connected to MQTT")
+                    publisher = MqttPub(mqtt)
+                    yield  # Allow the app to run
+            except aiomqtt.MqttError as e:
+                logger.error("Failed to connect to MQTT", exc_info=e)
+                publisher = None
+                raise  # Fail startup rather than run without a publisher
+        else:
+            publisher = NoopPublisher()
+            yield  # Allow the app to run
+
+    # The lifespan must be registered with the app or it will never run
+    app = FastAPI(
+        title="ELL API",
+        description="API server for ELL",
+        version="0.1.0",
+        lifespan=lifespan,
+    )
+
+    @app.get("/lmp/{lmp_id}", response_model=GetLMPResponse)
+    async def get_lmp(lmp_id: str,
+                      serializer: Store = Depends(get_serializer),
+                      session: Session = Depends(get_session)):
+        lmp = serializer.get_lmp(lmp_id, session=session)
+        if lmp is None:
+            raise HTTPException(status_code=404, detail="LMP not found")
+
+        return LMP.from_serialized_lmp(lmp)
+
+    @app.post("/lmp")
+    async def write_lmp(
+        lmp: WriteLMPInput,
+        uses: Dict[str, Any],  # SerializedLMPUses,
+        publisher: Publisher = Depends(get_publisher),
+        serializer: Store = Depends(get_serializer)
+    ):
+        serializer.write_lmp(lmp.to_serialized_lmp(), uses)
+
+        # Fire-and-forget: don't block the response on the publish
+        asyncio.create_task(
+            publisher.publish(
+                f"lmp/{lmp.lmp_id}/created",
+                json.dumps({
+                    "lmp": lmp.model_dump(),
+                    "uses": uses
+                }, default=str)
+            )
+        )

+    @app.post("/invocation")
+    async def write_invocation(
+        invocation: Invocation,
+        results: List[SerializedLStr],
+        consumes: Set[str],
+        publisher: Publisher = Depends(get_publisher),
+        serializer: Store = Depends(get_serializer)
+    ):
+        serializer.write_invocation(
+            invocation,
+            results,
+            consumes
+        )
+        asyncio.create_task(
+            publisher.publish(
+                f"lmp/{invocation.lmp_id}/invoked",
+                json.dumps({
+                    "invocation": invocation,
+                    "results": results,
+                    "consumes": consumes
+                }, default=str)
+            )
+        )
+        return {"message": "Invocation written successfully"}
+
+    return app
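Note: the Publisher ABC above keeps the transport pluggable. As a sketch of what that buys us, a reviewer could swap in an in-memory publisher for tests or local runs (this class is illustrative, not part of the patch):

    import asyncio
    from ell.api.server import Publisher

    class InMemoryPublisher(Publisher):
        """Collects published messages instead of sending them anywhere."""
        def __init__(self):
            self.messages: list[tuple[str, str]] = []

        async def publish(self, topic: str, message: str) -> None:
            self.messages.append((topic, message))

    pub = InMemoryPublisher()
    asyncio.run(pub.publish("lmp/123/created", "{}"))
    assert pub.messages == [("lmp/123/created", "{}")]
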
diff --git a/src/ell/api/types.py b/src/ell/api/types.py
new file mode 100644
index 00000000..d3e059e9
--- /dev/null
+++ b/src/ell/api/types.py
@@ -0,0 +1,98 @@
+from datetime import datetime, timezone
+from typing import Annotated, Any, Dict, Optional, cast
+
+from pydantic import AwareDatetime, BaseModel, BeforeValidator, Field
+
+from ell.types import SerializedLMP, utc_now
+
+
+def iso_timestamp_to_utc_datetime(v: Any) -> Optional[datetime]:
+    if isinstance(v, str):
+        return datetime.fromisoformat(v).replace(tzinfo=timezone.utc)
+    # elif isinstance(v, datetime):
+    #     if v.tzinfo is not timezone.utc:
+    #         raise ValueError(f"Invalid value for UTCTimestampField: {v}")
+    #     return v
+    elif v is None:
+        return None
+    raise ValueError(f"Invalid value for UTCTimestampField: {v}")
+
+
+# todo. does pydantic compose optional with this or do we have to in the before validator...?
+UTCTimestamp = Annotated[AwareDatetime,
+                         BeforeValidator(iso_timestamp_to_utc_datetime)]
+
+
+class WriteLMPInput(BaseModel):
+    """
+    Arguments to write an LMP.
+    """
+    lmp_id: str
+    name: str
+    source: str
+    dependencies: str
+    is_lm: bool
+    lm_kwargs: Optional[Dict[str, Any]]
+    initial_free_vars: Optional[Dict[str, Any]]
+    initial_global_vars: Optional[Dict[str, Any]]
+    # num_invocations: Optional[int]
+    commit_message: Optional[str]
+    version_number: Optional[int]
+    created_at: Optional[AwareDatetime] = Field(default_factory=utc_now)
+
+    def to_serialized_lmp(self):
+        return SerializedLMP(
+            lmp_id=self.lmp_id,
+            name=self.name,
+            source=self.source,
+            dependencies=self.dependencies,
+            is_lm=self.is_lm,
+            lm_kwargs=self.lm_kwargs,
+            version_number=self.version_number,
+            initial_global_vars=self.initial_global_vars,
+            initial_free_vars=self.initial_free_vars,
+            commit_message=self.commit_message,
+            created_at=cast(datetime, self.created_at)
+        )
+
+
+class LMP(BaseModel):
+    lmp_id: str
+    name: str
+    source: str
+    dependencies: str
+    is_lm: bool
+    lm_kwargs: Optional[Dict[str, Any]]
+    initial_free_vars: Optional[Dict[str, Any]]
+    initial_global_vars: Optional[Dict[str, Any]]
+    created_at: AwareDatetime
+    version_number: int
+    commit_message: Optional[str]
+    num_invocations: int
+
+    @staticmethod
+    def from_serialized_lmp(serialized: SerializedLMP):
+        return LMP(
+            lmp_id=cast(str, serialized.lmp_id),
+            name=serialized.name,
+            source=serialized.source,
+            dependencies=serialized.dependencies,
+            is_lm=serialized.is_lm,
+            lm_kwargs=serialized.lm_kwargs,
+            initial_free_vars=serialized.initial_free_vars,
+            initial_global_vars=serialized.initial_global_vars,
+            created_at=serialized.created_at,
+            version_number=cast(int, serialized.version_number),
+            commit_message=serialized.commit_message,
+            num_invocations=cast(int, serialized.num_invocations),
+        )
+
+
+# class GetLMPResponse(BaseModel):
+#     lmp: LMP
+#     uses: List[str]
+GetLMPResponse = LMP
+
+# class LMPCreatedEvent(BaseModel):
+#     lmp: LMP
+#     uses: List[str]
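Note: the two models above are the write and read halves of the same record. A quick round trip using only fields defined in this file (values are illustrative; num_invocations is normally populated by the store, so it is defaulted by hand here):

    from ell.api.types import LMP, WriteLMPInput

    payload = WriteLMPInput(
        lmp_id="abc123",
        name="hello",
        source="def hello(): pass",
        dependencies="[]",
        is_lm=True,
        lm_kwargs=None,
        initial_free_vars=None,
        initial_global_vars=None,
        commit_message=None,
        version_number=1,
    )
    serialized = payload.to_serialized_lmp()
    serialized.num_invocations = 0  # the store increments this on writes
    api_view = LMP.from_serialized_lmp(serialized)
    assert api_view.lmp_id == payload.lmp_id
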
diff --git a/src/ell/store.py b/src/ell/store.py
index 36f8c492..458267ab 100644
--- a/src/ell/store.py
+++ b/src/ell/store.py
@@ -1,6 +1,8 @@
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from typing import Any, Optional, Dict, List, Set
+
+from sqlmodel import Session
 from ell.types import InvocableLM, SerializedLMP, Invocation, SerializedLStr
 
 
@@ -9,6 +11,16 @@ class Store(ABC):
     Abstract base class for serializers. Defines the interface for serializing
     and deserializing LMPs and invocations.
     """
+    @abstractmethod
+    def get_lmp(self, lmp_id: str, session: Optional[Session] = None) -> Optional[SerializedLMP]:
+        """
+        Get an LMP by its ID.
+
+        :param lmp_id: ID of the LMP to retrieve.
+        :return: SerializedLMP object containing all LMP details, or None if the LMP does not exist.
+        """
+        pass
+
     @abstractmethod
     def write_lmp(self, serialized_lmp: SerializedLMP, uses: Dict[str, Any]) -> Optional[Any]:
         """
diff --git a/src/ell/stores/sql.py b/src/ell/stores/sql.py
index 1c581dea..4a3a18f8 100644
--- a/src/ell/stores/sql.py
+++ b/src/ell/stores/sql.py
@@ -1,35 +1,45 @@
-import asyncio
 import ell.store
 import os
-from ell.studio.pubsub import PubSub
 from ell.types import InvocationTrace, SerializedLMP, Invocation, SerializedLStr
-from sqlalchemy import func, and_
+from sqlalchemy import Engine, func, and_
 from sqlalchemy.sql import text
 from sqlmodel import Session, SQLModel, create_engine, select
 from typing import Any, Optional, Dict, List, Set
 from datetime import datetime, timedelta
-import cattrs
-import numpy as np
 from sqlalchemy.sql import text
-from ell.types import InvocationTrace, SerializedLMP, Invocation, SerializedLMPUses, SerializedLStr, utc_now
-from ell.lstr import lstr
+from ell.types import InvocationTrace, SerializedLMP, Invocation, SerializedLStr
+import logging
+
+logger = logging.getLogger(__name__)
+
 
 class SQLStore(ell.store.Store):
-    def __init__(self, db_uri: str):
-        self.engine = create_engine(db_uri)
-        SQLModel.metadata.create_all(self.engine)
-
+    def __init__(self, db_uri: Optional[str] = None, engine: Optional[Engine] = None):
+        if engine is not None:
+            self.engine = engine
+        elif db_uri is None:
+            raise ValueError(
+                "db_uri cannot be None when engine is not provided as an argument")
+        else:
+            self.engine = create_engine(db_uri)
 
-        self.open_files: Dict[str, Dict[str, Any]] = {}
+        SQLModel.metadata.create_all(self.engine)
 
+    def get_lmp(self, lmp_id: str, session: Optional[Session] = None) -> Optional[SerializedLMP]:
+        if session is None:
+            with Session(self.engine) as session:
+                return session.exec(select(SerializedLMP).where(SerializedLMP.lmp_id == lmp_id)).first()
+        else:
+            return session.exec(select(SerializedLMP).where(SerializedLMP.lmp_id == lmp_id)).first()
 
     def write_lmp(self, serialized_lmp: SerializedLMP, uses: Dict[str, Any]) -> Optional[Any]:
         with Session(self.engine) as session:
             # Bind the serialized_lmp to the session
-            lmp = session.query(SerializedLMP).filter(SerializedLMP.lmp_id == serialized_lmp.lmp_id).first()
+            lmp = session.exec(select(SerializedLMP).where(SerializedLMP.lmp_id == serialized_lmp.lmp_id)).first()
 
             if lmp:
                 # Already added to the DB.
+                logger.debug(f"LMP {serialized_lmp.lmp_id} already exists in the DB. Skipping write.")
                 return lmp
             else:
                 session.add(serialized_lmp)
@@ -40,11 +50,12 @@ def write_lmp(self, serialized_lmp: SerializedLMP, uses: Dict[str, Any]) -> Opti
                     serialized_lmp.uses.append(used_lmp)
 
             session.commit()
+            logger.debug(f"Wrote new LMP {serialized_lmp.lmp_id} to the DB.")
         return None
 
     def write_invocation(self, invocation: Invocation, results: List[SerializedLStr], consumes: Set[str]) -> Optional[Any]:
         with Session(self.engine) as session:
-            lmp = session.query(SerializedLMP).filter(SerializedLMP.lmp_id == invocation.lmp_id).first()
+            lmp = session.exec(select(SerializedLMP).where(SerializedLMP.lmp_id == invocation.lmp_id)).first()
 
             assert lmp is not None, f"LMP with id {invocation.lmp_id} not found. Writing invocation erroneously"
 
             # Increment num_invocations
@@ -285,11 +296,23 @@ def get_invocations_aggregate(self, session: Session, lmp_filters: Dict[str, Any
             "graph_data": graph_data
         }
 
+
 class SQLiteStore(SQLStore):
     def __init__(self, storage_dir: str):
-        os.makedirs(storage_dir, exist_ok=True)
-        db_path = os.path.join(storage_dir, 'ell.db')
-        super().__init__(f'sqlite:///{db_path}')
+        if ":memory:" not in storage_dir:
+            os.makedirs(storage_dir, exist_ok=True)
+            db_path = os.path.join(storage_dir, 'ell.db')
+            return super().__init__(f'sqlite:///{db_path}')
+        else:
+            from sqlalchemy.pool import StaticPool
+            engine = create_engine(
+                'sqlite://',
+                connect_args={'check_same_thread': False},
+                poolclass=StaticPool
+            )
+
+            return super().__init__(engine=engine)
+
 
 class PostgresStore(SQLStore):
     def __init__(self, db_uri: str):
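Note: the new engine= parameter on SQLStore.__init__ is what makes the ":memory:" branch of SQLiteStore work; a StaticPool pins every connection to the same in-memory database, so data written in one session is visible in the next. A quick round trip under that assumption (field values are illustrative, optional fields set to None):

    from ell.stores.sql import SQLiteStore
    from ell.types import SerializedLMP, utc_now

    store = SQLiteStore(":memory:")  # in-memory engine shared via StaticPool
    store.write_lmp(
        SerializedLMP(
            lmp_id="demo",
            name="demo",
            source="def demo(): pass",
            dependencies="[]",
            is_lm=False,
            lm_kwargs=None,
            initial_global_vars=None,
            initial_free_vars=None,
            commit_message=None,
            version_number=1,
            created_at=utc_now(),
        ),
        uses={},
    )
    assert store.get_lmp("demo") is not None
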
Skipping write.") return lmp else: session.add(serialized_lmp) @@ -40,11 +50,12 @@ def write_lmp(self, serialized_lmp: SerializedLMP, uses: Dict[str, Any]) -> Opti serialized_lmp.uses.append(used_lmp) session.commit() + logger.debug(f"Wrote new LMP {serialized_lmp.lmp_id} to the DB.") return None def write_invocation(self, invocation: Invocation, results: List[SerializedLStr], consumes: Set[str]) -> Optional[Any]: with Session(self.engine) as session: - lmp = session.query(SerializedLMP).filter(SerializedLMP.lmp_id == invocation.lmp_id).first() + lmp = session.exec(select(SerializedLMP).where(SerializedLMP.lmp_id == invocation.lmp_id)).first() assert lmp is not None, f"LMP with id {invocation.lmp_id} not found. Writing invocation erroneously" # Increment num_invocations @@ -285,11 +296,22 @@ def get_invocations_aggregate(self, session: Session, lmp_filters: Dict[str, Any "graph_data": graph_data } + class SQLiteStore(SQLStore): def __init__(self, storage_dir: str): - os.makedirs(storage_dir, exist_ok=True) - db_path = os.path.join(storage_dir, 'ell.db') - super().__init__(f'sqlite:///{db_path}') + if ":memory:" not in storage_dir: + db_path = os.path.join(storage_dir, 'ell.db') + return super().__init__(f'sqlite:///{db_path}') + else: + from sqlalchemy.pool import StaticPool + engine = create_engine( + 'sqlite://', + connect_args={'check_same_thread': False}, + poolclass=StaticPool + ) + + return super().__init__(engine=engine) + class PostgresStore(SQLStore): def __init__(self, db_uri: str): diff --git a/src/ell/studio/logger.py b/src/ell/studio/logger.py index 325b190a..58a493e6 100644 --- a/src/ell/studio/logger.py +++ b/src/ell/studio/logger.py @@ -1,6 +1,7 @@ import logging from colorama import Fore, Style, init +initialized = False def setup_logging(level: int = logging.INFO): # Initialize colorama for cross-platform colored output init(autoreset=True) @@ -30,5 +31,7 @@ def format(self, record): # Add the handler to the logger logger.addHandler(console_handler) + global initialized + initialized = True return logger \ No newline at end of file diff --git a/src/ell/types.py b/src/ell/types.py index 53d597ad..907434a9 100644 --- a/src/ell/types.py +++ b/src/ell/types.py @@ -1,6 +1,9 @@ # Let's define the core types. from dataclasses import dataclass -from typing import Callable, Dict, List, Union, Any, Optional +from typing import Annotated, Callable, Dict, List, Union, Any, Optional +from pydantic import BeforeValidator + +from sqlalchemy.engine.interfaces import Dialect from ell.lstr import lstr from ell.util.dict_sync_meta import DictSyncMeta @@ -50,6 +53,7 @@ def utc_now() -> datetime: Serializes to ISO-8601. """ return datetime.now(tz=timezone.utc) + class SerializedLMPUses(SQLModel, table=True): """ Represents the many-to-many relationship between SerializedLMPs. 
@@ -66,13 +70,19 @@ class UTCTimestamp(types.TypeDecorator[datetime]):
     impl = types.TIMESTAMP
     def process_result_value(self, value: datetime, dialect:Any):
         return value.replace(tzinfo=timezone.utc)
+    # def process_bind_param(self, value: str|datetime, dialect:Any):
+    #     if isinstance(value, str):
+    #         return datetime.fromisoformat(value).replace(tzinfo=timezone.utc)
+    #     elif isinstance(value, datetime):
+    #         return value.replace(tzinfo=timezone.utc)
+    #     raise ValueError(f"Invalid value for UTCTimestamp: {value}")
+
 
 def UTCTimestampField(index:bool=False, **kwargs:Any):
     return Field(
         sa_column=Column(UTCTimestamp(timezone=True), index=index, **kwargs))
 
-
 class SerializedLMPBase(SQLModel):
     lmp_id: Optional[str] = Field(default=None, primary_key=True)
     name: str = Field(index=True)
     source: str
@@ -111,6 +121,7 @@ class Config:
         table_name = "serializedlmp"
         unique_together = [("version_number", "name")]
 
+
 class InvocationTrace(SQLModel, table=True):
     invocation_consumer_id: str = Field(foreign_key="invocation.id", primary_key=True, index=True)
     invocation_consuming_id: str = Field(foreign_key="invocation.id", primary_key=True, index=True)
diff --git a/tests/api/test_api.py b/tests/api/test_api.py
new file mode 100644
index 00000000..b90bbb41
--- /dev/null
+++ b/tests/api/test_api.py
@@ -0,0 +1,181 @@
+from datetime import timezone
+from logging import DEBUG
+from uuid import uuid4
+
+import pytest
+from typing import Any, Dict
+from fastapi.testclient import TestClient
+from sqlmodel import Session
+
+from ell.api.server import NoopPublisher, create_app, get_publisher, get_serializer, get_session
+from ell.api.config import Config
+from ell.api.types import WriteLMPInput
+from ell.stores.sql import SQLStore, SQLiteStore
+from ell.studio.logger import setup_logging
+from ell.types import SerializedLMP, utc_now
+
+
+@pytest.fixture
+def sql_store() -> SQLStore:
+    return SQLiteStore(":memory:")
+
+
+def test_construct_serialized_lmp():
+    serialized_lmp = SerializedLMP(
+        lmp_id="test_lmp_id",
+        name="Test LMP",
+        source="def test_function(): pass",
+        dependencies=str(["dep1", "dep2"]),
+        lm_kwargs={"param1": "value1"},
+        is_lm=True,
+        version_number=1,
+        # uses={"used_lmp_1": {}, "used_lmp_2": {}},
+        initial_global_vars={"global_var1": "value1"},
+        initial_free_vars={"free_var1": "value2"},
+        commit_message="Initial commit",
+        created_at=utc_now()
+    )
+    assert serialized_lmp.lmp_id == "test_lmp_id"
+    assert serialized_lmp.name == "Test LMP"
+    assert serialized_lmp.source == "def test_function(): pass"
+    assert serialized_lmp.dependencies == str(["dep1", "dep2"])
+    assert serialized_lmp.lm_kwargs == {"param1": "value1"}
+    assert serialized_lmp.version_number == 1
+    assert serialized_lmp.created_at is not None
+
+
+def test_write_lmp_input():
+    ## Should be able to construct a WriteLMPInput from data
+    input = WriteLMPInput(
+        lmp_id="test_lmp_id",
+        name="Test LMP",
+        source="def test_function(): pass",
+        dependencies=str(["dep1", "dep2"]),
+        is_lm=True,
+        lm_kwargs={"param1": "value1"},
+        initial_global_vars={"global_var1": "value1"},
+        initial_free_vars={"free_var1": "value2"},
+        commit_message="Initial commit",
+        version_number=1,
+    )
+
+    # Should default a created_at to utc_now
+    assert input.created_at is not None
+    assert input.created_at.tzinfo == timezone.utc
+
+    ## Should be able to construct a SerializedLMP from a WriteLMPInput
+    model = SerializedLMP(**input.model_dump())
+    assert model.created_at == input.created_at
+
+    input2 = WriteLMPInput(
+        lmp_id="test_lmp_id",
+        name="Test LMP",
+        source="def test_function(): pass",
+        dependencies=str(["dep1", "dep2"]),
+        is_lm=True,
+        lm_kwargs={"param1": "value1"},
+        initial_global_vars={"global_var1": "value1"},
+        initial_free_vars={"free_var1": "value2"},
+        commit_message="Initial commit",
+        version_number=1,
+        # should work with an isoformat string
+        created_at=utc_now().isoformat()  # type: ignore
+    )
+    model2 = SerializedLMP(**input2.model_dump())
+    assert model2.created_at == input2.created_at
+    assert input2.created_at is not None
+    assert input2.created_at.tzinfo == timezone.utc
+
+
+def test_write_lmp(sql_store: SQLStore):
+    setup_logging(DEBUG)
+    config = Config(storage_dir=":memory:")
+    app = create_app(config)
+
+    publisher = NoopPublisher()
+
+    async def get_publisher_override():
+        yield publisher
+
+    async def get_session_override():
+        with Session(sql_store.engine) as session:
+            yield session
+
+    def get_serializer_override():
+        return sql_store
+
+    app.dependency_overrides[get_publisher] = get_publisher_override
+    app.dependency_overrides[get_session] = get_session_override
+    app.dependency_overrides[get_serializer] = get_serializer_override
+
+    client = TestClient(app)
+
+    lmp_data: Dict[str, Any] = {
+        "lmp_id": uuid4().hex,
+        "name": "Test LMP",
+        "source": "def test_function(): pass",
+        "dependencies": str(["dep1", "dep2"]),
+        "is_lm": True,
+        "lm_kwargs": {"param1": "value1"},
+        "version_number": 1,
+        "uses": {"used_lmp_1": {}, "used_lmp_2": {}},
+        "initial_global_vars": {"global_var1": "value1"},
+        "initial_free_vars": {"free_var1": "value2"},
+        "commit_message": "Initial commit",
+        "created_at": utc_now().isoformat().replace("+00:00", "Z")
+    }
+    uses: Dict[str, Any] = {
+        "used_lmp_1": {},
+        "used_lmp_2": {}
+    }
+
+    response = client.post(
+        "/lmp",
+        json={
+            "lmp": lmp_data,
+            "uses": uses
+        }
+    )
+
+    assert response.status_code == 200
+
+    lmp = client.get(f"/lmp/{lmp_data['lmp_id']}")
+    assert lmp.status_code == 200
+    del lmp_data["uses"]
+    assert lmp.json() == {**lmp_data, "num_invocations": 0}
+
+
+def test_write_invocation(sql_store: SQLStore):
+    config = Config(storage_dir=":memory:")
+    app = create_app(config)
+    client = TestClient(app)
+
+    invocation_data = {
+        "lmp_id": "test_lmp_id",
+        "name": "Test Invocation",
+        "description": "This is a test invocation"
+    }
+    results_data = [
+        {
+            "result_id": "test_result_id",
+            "name": "Test Result",
+            "description": "This is a test result"
+        }
+    ]
+    consumes_data = ["test_consumes_id"]
+
+    response = client.post(
+        "/invocation",
+        json={
+            "invocation": invocation_data,
+            "results": results_data,
+            "consumes": consumes_data
+        }
+    )
+
+    assert response.status_code == 200
+    assert response.json() == {"message": "Invocation written successfully"}
+
+
+if __name__ == "__main__":
+    pytest.main()
\ No newline at end of file