Skip to content

Commit

Permalink
add pubsub
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-dixon committed Aug 12, 2024
1 parent d95c242 commit 97455fc
Show file tree
Hide file tree
Showing 12 changed files with 489 additions and 63 deletions.
1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.12.2
2 changes: 1 addition & 1 deletion examples/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def get_html_of_url(url: str) -> str:
get_html_of_url
]

@ell.l(model="gpt-4o", temperature=0.1)
@ell.lm(model="gpt-4o", temperature=0.1)
def tool_user(task: str) -> List[Any]:
return [
ell.system(
Expand Down
33 changes: 31 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ include = [
]

[tool.poetry.dependencies]
python = ">=3.9"
python = ">=3.9,<4"
fastapi = "^0.111.1"
numpy = "^2.0.1"
dill = "^0.3.8"
Expand All @@ -38,6 +38,8 @@ typing-extensions = "^4.12.2"


black = "^24.8.0"
psycopg2 = "^2.9.9"
aiomqtt = "^2.3.0"
[tool.poetry.group.dev.dependencies]
pytest = "^8.3.2"

Expand Down
10 changes: 5 additions & 5 deletions src/ell/__version__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
try:
from importlib.metadata import version
except ImportError:
from importlib_metadata import version
from importlib.metadata import version, PackageNotFoundError

__version__ = version("ell")
try:
__version__ = version("ell")
except PackageNotFoundError:
__version__ = "unknown"
43 changes: 32 additions & 11 deletions src/ell/stores/sql.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from datetime import datetime
import json
import os
from typing import Any, Optional, Dict, List, Set, Union
from sqlmodel import Session, SQLModel, create_engine, select
import asyncio
import ell.store
import cattrs
import numpy as np
import os
from ell.studio.pubsub import PubSub
from ell.types import InvocationTrace, SerializedLMP, Invocation, SerializedLStr
from sqlalchemy import func, and_
from sqlalchemy.sql import text
from ell.types import InvocationTrace, SerializedLMP, Invocation, SerializedLMPUses, SerializedLStr, utc_now
from ell.lstr import lstr
from sqlalchemy import or_, func, and_
from sqlmodel import Session, SQLModel, create_engine, select
from typing import Any, Optional, Dict, List, Set


class SQLStore(ell.store.Store):
def __init__(self, db_uri: str):
Expand Down Expand Up @@ -243,4 +241,27 @@ class SQLiteStore(SQLStore):
def __init__(self, storage_dir: str):
os.makedirs(storage_dir, exist_ok=True)
db_path = os.path.join(storage_dir, 'ell.db')
super().__init__(f'sqlite:///{db_path}')
super().__init__(f'sqlite:///{db_path}')

class PostgresStore(SQLStore):
def __init__(self, db_uri: str):
super().__init__(db_uri)



class SQLStorePublisher(SQLStore):
def __init__(self, db_uri: str, pubsub: PubSub):
self.pubsub = pubsub
super().__init__(db_uri)

def write_lmp(self, serialized_lmp: SerializedLMP, uses: Dict[str, Any]) -> Optional[Any]:
super().write_lmp(serialized_lmp, uses)
# todo. return result from write lmp so we can check if it was created or alredy exists
asyncio.create_task(self.pubsub.publish(f"lmp/{serialized_lmp.lmp_id}/created", serialized_lmp))
return None

def write_invocation(self, invocation: Invocation, results: List[SerializedLStr], consumes: Set[str]) -> Optional[Any]:
super().write_invocation(invocation, results, consumes)
asyncio.create_task(self.pubsub.publish(f"lmp/{invocation.lmp_id}/invoked", invocation))
return None

28 changes: 23 additions & 5 deletions src/ell/studio/__main__.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,36 @@
import asyncio
import os
from fastapi import FastAPI
import uvicorn
from argparse import ArgumentParser
from ell.studio.config import Config
from ell.studio.logger import setup_logging
from ell.studio.server import create_app
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from watchfiles import awatch
import time

def main():
setup_logging()
parser = ArgumentParser(description="ELL Studio Data Server")
parser.add_argument("--storage-dir", default=os.getcwd(),
help="Directory for filesystem serializer storage (default: current directory)")
parser.add_argument("--pg-connection-string", default=None,
help="PostgreSQL connection string (default: None)")
parser.add_argument("--mqtt-connection-string", default=None,
help="MQTT connection string (default: None)")
parser.add_argument("--host", default="127.0.0.1", help="Host to run the server on")
parser.add_argument("--port", type=int, default=8080, help="Port to run the server on")
parser.add_argument("--dev", action="store_true", help="Run in development mode")
args = parser.parse_args()

app = create_app(args.storage_dir)
config = Config(
storage_dir=args.storage_dir,
pg_connection_string=args.pg_connection_string,
mqtt_connection_string=args.mqtt_connection_string
)

app = create_app(config)

if not args.dev:
# In production mode, serve the built React app
Expand All @@ -30,7 +43,7 @@ async def serve_react_app(full_path: str):

db_path = os.path.join(args.storage_dir, "ell.db")

async def db_watcher(db_path, app):
async def db_watcher(db_path: str, app: FastAPI):
last_stat = None

while True:
Expand Down Expand Up @@ -70,8 +83,13 @@ async def db_watcher(db_path, app):

config = uvicorn.Config(app=app, port=args.port, loop=loop)
server = uvicorn.Server(config)
loop.create_task(server.serve())
loop.create_task(db_watcher(db_path, app))

tasks = []
tasks.append(loop.create_task(server.serve()))

if args.storage_dir:
tasks.append(loop.create_task(db_watcher(db_path, app)))

loop.run_forever()

if __name__ == "__main__":
Expand Down
47 changes: 47 additions & 0 deletions src/ell/studio/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@

from functools import lru_cache
import json
import os
from typing import Optional
from pydantic import BaseModel

import logging

logger = logging.getLogger(__name__)


# todo. maybe we default storage dir and other things in the future to a well-known location
# like ~/.ell or something
@lru_cache(maxsize=1)
def ell_home() -> str:
return os.path.join(os.path.expanduser("~"), ".ell")


class Config(BaseModel):
pg_connection_string: Optional[str] = None
storage_dir: Optional[str] = None
mqtt_connection_string: Optional[str] = None
def __init__(self, **kwargs):
super().__init__(**kwargs)

def model_post_init(self, __context):
# Storage
self.pg_connection_string = self.pg_connection_string or os.getenv(
"ELL_PG_CONNECTION_STRING")
self.storage_dir = self.storage_dir or os.getenv("ELL_STORAGE_DIR")

# Enforce that we use either sqlite or postgres, but not both
if self.pg_connection_string is not None and self.storage_dir is not None:
raise ValueError("Cannot use both sqlite and postgres")

# For now, fall back to sqlite if no PostgreSQL connection string is provided
if self.pg_connection_string is None and self.storage_dir is None:
# This intends to honor the default we had set in the CLI
# todo. better default?
self.storage_dir = os.getcwd()

# Pubsub
self.mqtt_connection_string = self.mqtt_connection_string or os.getenv("ELL_MQTT_CONNECTION_STRING")

logger.info(f"Resolved config: {json.dumps(self.model_dump(), indent=2)}")

18 changes: 0 additions & 18 deletions src/ell/studio/connection_manager.py

This file was deleted.

34 changes: 34 additions & 0 deletions src/ell/studio/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import logging
from colorama import Fore, Style, init

def setup_logging(level: int = logging.INFO):
# Initialize colorama for cross-platform colored output
init(autoreset=True)

# Create a custom formatter
class ColoredFormatter(logging.Formatter):
FORMATS = {
logging.DEBUG: Fore.CYAN + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL,
logging.INFO: Fore.GREEN + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL,
logging.WARNING: Fore.YELLOW + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL,
logging.ERROR: Fore.RED + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL,
logging.CRITICAL: Fore.RED + Style.BRIGHT + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL
}

def format(self, record):
log_fmt = self.FORMATS.get(record.levelno)
formatter = logging.Formatter(log_fmt, datefmt="%Y-%m-%d %H:%M:%S")
return formatter.format(record)

# Create and configure the logger
logger = logging.getLogger("ell")
logger.setLevel(level)

# Create console handler and set formatter
console_handler = logging.StreamHandler()
console_handler.setFormatter(ColoredFormatter())

# Add the handler to the logger
logger.addHandler(console_handler)

return logger
Loading

0 comments on commit 97455fc

Please sign in to comment.