Skip to content

Commit

Permalink
add write invocation api
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-dixon committed Aug 15, 2024
1 parent 9daa336 commit 5ed6825
Show file tree
Hide file tree
Showing 9 changed files with 237 additions and 105 deletions.
19 changes: 14 additions & 5 deletions src/ell/api/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@ def main():
help="PostgreSQL connection string (default: None)")
parser.add_argument("--mqtt-connection-string", default=None,
help="MQTT connection string (default: None)")
parser.add_argument("--host", default="127.0.0.1", help="Host to run the server on")
parser.add_argument("--port", type=int, default=8080, help="Port to run the server on")
parser.add_argument("--dev", action="store_true", help="Run in development mode")
parser.add_argument("--host", default="0.0.0.0",
help="Host to run the server on")
parser.add_argument("--port", type=int, default=8080,
help="Port to run the server on")
parser.add_argument("--dev", action="store_true",
help="Run in development mode")
args = parser.parse_args()

config = Config(
Expand All @@ -30,12 +33,18 @@ def main():

loop = asyncio.new_event_loop()

config = uvicorn.Config(app=app, port=args.port, loop=loop)
config = uvicorn.Config(
app=app,
host=args.host,
port=args.port,
loop=loop # type: ignore
)
server = uvicorn.Server(config)

loop.create_task(server.serve())

loop.run_forever()


if __name__ == "__main__":
main()
main()
1 change: 1 addition & 0 deletions src/ell/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class Config(BaseModel):
storage_dir: Optional[str] = None
pg_connection_string: Optional[str] = None
mqtt_connection_string: Optional[str] = None
log_level: int = logging.INFO
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)

Expand Down
40 changes: 40 additions & 0 deletions src/ell/api/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import logging
from colorama import Fore, Style, init

initialized = False

def setup_logging(level: int = logging.INFO):
global initialized
if initialized:
return
# Initialize colorama for cross-platform colored output
init(autoreset=True)

# Create a custom formatter
class ColoredFormatter(logging.Formatter):
FORMATS = {
logging.DEBUG: Fore.CYAN + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL,
logging.INFO: Fore.GREEN + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL,
logging.WARNING: Fore.YELLOW + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL,
logging.ERROR: Fore.RED + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL,
logging.CRITICAL: Fore.RED + Style.BRIGHT + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL
}

def format(self, record: logging.LogRecord) -> str:
log_fmt = self.FORMATS.get(record.levelno)
formatter = logging.Formatter(log_fmt, datefmt="%Y-%m-%d %H:%M:%S")
return formatter.format(record)

# Create and configure the logger
logger = logging.getLogger("ell")
logger.setLevel(level)

# Create console handler and set formatter
console_handler = logging.StreamHandler()
console_handler.setFormatter(ColoredFormatter())

# Add the handler to the logger
logger.addHandler(console_handler)
initialized = True

return logger
22 changes: 22 additions & 0 deletions src/ell/api/publisher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from abc import ABC, abstractmethod

import aiomqtt


class Publisher(ABC):
@abstractmethod
async def publish(self, topic: str, message: str) -> None:
pass


class MqttPub(Publisher):
def __init__(self, conn: aiomqtt.Client):
self.mqtt_client = conn

async def publish(self, topic: str, message: str) -> None:
await self.mqtt_client.publish(topic, message)


class NoopPublisher(Publisher):
async def publish(self, topic: str, message: str) -> None:
pass
56 changes: 20 additions & 36 deletions src/ell/api/server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from abc import ABC, abstractmethod
import asyncio
from contextlib import asynccontextmanager
import json
Expand All @@ -9,35 +8,18 @@
from fastapi import Depends, FastAPI, HTTPException
from sqlmodel import Session
from ell.api.config import Config
from ell.api.types import GetLMPResponse, WriteLMPInput, LMP
from ell.api.publisher import MqttPub, NoopPublisher, Publisher
from ell.api.types import GetLMPResponse, WriteInvocationInput, WriteLMPInput, LMP
from ell.store import Store
from ell.stores.sql import PostgresStore, SQLStore, SQLiteStore
from ell.studio.logger import setup_logging
from ell.types import Invocation, SerializedLStr


logger = logging.getLogger(__name__)


class Publisher(ABC):
@abstractmethod
async def publish(self, topic: str, message: str) -> None:
pass


class MqttPub(Publisher):
def __init__(self, conn: aiomqtt.Client):
self.mqtt_client = conn

async def publish(self, topic: str, message: str) -> None:
await self.mqtt_client.publish(topic, message)


class NoopPublisher(Publisher):
async def publish(self, topic: str, message: str) -> None:
pass


publisher = None
publisher: Publisher | None = None


async def get_publisher():
Expand All @@ -51,7 +33,7 @@ def init_serializer(config: Config) -> SQLStore:
if serializer is not None:
return serializer
elif config.pg_connection_string:
return PostgresStore(config.pg_connection_string)
return PostgresStore(config.pg_connection_string)
elif config.storage_dir:
return SQLiteStore(config.storage_dir)
else:
Expand All @@ -72,12 +54,12 @@ def get_session():


def create_app(config: Config):
setup_logging(config.log_level)

app = FastAPI(
title="ELL API",
description="API server for ELL",
version="0.1.0",
# dependencies=[Depends(get_publisher),
# Depends(get_serializer)]
)

@asynccontextmanager
Expand Down Expand Up @@ -132,29 +114,31 @@ async def write_lmp(
)
)

@app.post("/invocation")
@app.post("/invocation", response_model=WriteInvocationInput)
async def write_invocation(
invocation: Invocation,
results: List[SerializedLStr],
consumes: Set[str],
input: WriteInvocationInput,
publisher: Publisher = Depends(get_publisher),
serializer: Store = Depends(get_serializer)
):
ser_input = input.to_serialized_invocation_input()
serializer.write_invocation(
invocation,
results,
consumes
invocation=ser_input['invocation'],
results=ser_input['results'],
consumes=ser_input['consumes']
)
loop = asyncio.get_event_loop()
loop.create_task(
publisher.publish(
f"lmp/{invocation.lmp_id}/invoked",
f"lmp/{input.invocation.lmp_id}/invoked",
json.dumps({
"invocation": invocation,
"results": results,
"consumes": consumes
'foo': 'bar'
# "not json serializable lol"
# "invocation": input.invocation,
# "results": results,
# "consumes": consumes
})
)
)
return input

return app
103 changes: 77 additions & 26 deletions src/ell/api/types.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,13 @@
from typing import Annotated, Any, Dict, Optional, cast
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Set, cast
from datetime import datetime
from numpy import ndarray

from openai import BaseModel
from pydantic import AwareDatetime, BeforeValidator, Field
from pydantic import AwareDatetime, Field
from ell.lstr import lstr

from ell.types import SerializedLMP, utc_now


def iso_timestamp_to_utc_datetime(v: datetime) -> datetime:
if isinstance(v, str):
return datetime.fromisoformat(v).replace(tzinfo=timezone.utc)
# elif isinstance(v, datetime):
# if v.tzinfo is not timezone.utc:
# raise ValueError(f"Invalid value for UTCTimestampField: {v}")
# return v
elif v is None:
return None
raise ValueError(f"Invalid value for UTCTimestampField: {v}")


# todo. does pydantic compose optional with this or do we have to in the before validator...?
UTCTimestamp = Annotated[AwareDatetime,
BeforeValidator(iso_timestamp_to_utc_datetime)]
from ell.types import SerializedLMP, SerializedLStr, utc_now
import ell.types


class WriteLMPInput(BaseModel):
Expand All @@ -33,12 +19,12 @@ class WriteLMPInput(BaseModel):
source: str
dependencies: str
is_lm: bool
lm_kwargs: Optional[Dict[str, Any]]
initial_free_vars: Optional[Dict[str, Any]]
initial_global_vars: Optional[Dict[str, Any]]
lm_kwargs: Optional[Dict[str, Any]] = None
initial_free_vars: Optional[Dict[str, Any]] = None
initial_global_vars: Optional[Dict[str, Any]] = None
# num_invocations: Optional[int]
commit_message: Optional[str]
version_number: Optional[int]
commit_message: Optional[str] = None
version_number: Optional[int] = None
created_at: Optional[AwareDatetime] = Field(default_factory=utc_now)

def to_serialized_lmp(self):
Expand Down Expand Up @@ -92,7 +78,72 @@ def from_serialized_lmp(serialized: SerializedLMP):
# lmp: LMP
# uses: List[str]


GetLMPResponse = LMP
# class LMPCreatedEvent(BaseModel):
# lmp: LMP
# uses: List[str]


class Invocation(BaseModel):
"""
An invocation of an LMP.
"""
id: Optional[str] = None
lmp_id: str
args: List[Any]
kwargs: Dict[str, Any]
global_vars: Dict[str, Any]
free_vars: Dict[str, Any]
latency_ms: int
invocation_kwargs: Dict[str, Any]
prompt_tokens: Optional[int] = None
completion_tokens: Optional[int] = None
state_cache: Optional[str] = None
created_at: AwareDatetime = Field(default_factory=utc_now)
# used_by_id: Optional[str] = None

def to_serialized_invocation(self):
return ell.types.Invocation(
**self.model_dump()
)


class WriteInvocationInputLStr(BaseModel):
id: Optional[str] = None
content: str
logits: Optional[List[float]] = None


def lstr_to_serialized_lstr(ls: lstr) -> SerializedLStr:
return SerializedLStr(
content=str(ls),
logits=ls.logits if ls.logits is not None else None
)


class WriteInvocationInput(BaseModel):
"""
Arguments to write an invocation.
"""
invocation: Invocation
results: List[WriteInvocationInputLStr]
consumes: List[str]

def to_serialized_invocation_input(self):
results = [
SerializedLStr(
id=ls.id,
content=ls.content,
logits=ndarray(
ls.logits) if ls.logits is not None else None
)
for ls in self.results]

sinvo = self.invocation.to_serialized_invocation()
return {
'invocation': sinvo,
'results': results,
# todo. is this a list or a set?
'consumes': self.consumes
}
5 changes: 4 additions & 1 deletion src/ell/studio/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

initialized = False
def setup_logging(level: int = logging.INFO):
global initialized
if initialized:
return
# Initialize colorama for cross-platform colored output
init(autoreset=True)

Expand Down Expand Up @@ -31,7 +34,7 @@ def format(self, record):

# Add the handler to the logger
logger.addHandler(console_handler)
global initialized

initialized = True

return logger
10 changes: 1 addition & 9 deletions src/ell/types.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# Let's define the core types.
from dataclasses import dataclass
from typing import Annotated, Callable, Dict, List, Union, Any, Optional
from pydantic import BeforeValidator
from typing import Callable, Dict, List, Union, Any, Optional

from sqlalchemy.engine.interfaces import Dialect

from ell.lstr import lstr
from ell.util.dict_sync_meta import DictSyncMeta
Expand Down Expand Up @@ -70,12 +68,6 @@ class UTCTimestamp(types.TypeDecorator[datetime]):
impl = types.TIMESTAMP
def process_result_value(self, value: datetime, dialect:Any):
return value.replace(tzinfo=timezone.utc)
# def process_bind_param(self, value: str|datetime, dialect:Any):
# if isinstance(value, str):
# return datetime.fromisoformat(value).replace(tzinfo=timezone.utc)
# elif isinstance(value, datetime):
# return value.replace(tzinfo=timezone.utc)
# raise ValueError(f"Invalid value for UTCTimestamp: {value}")


def UTCTimestampField(index:bool=False, **kwargs:Any):
Expand Down
Loading

0 comments on commit 5ed6825

Please sign in to comment.