Skip to content

Commit

Permalink
slight cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-dixon committed Aug 12, 2024
1 parent 03b8d6a commit b228fb0
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 149 deletions.
19 changes: 0 additions & 19 deletions src/ell/stores/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,22 +294,3 @@ def __init__(self, storage_dir: str):
class PostgresStore(SQLStore):
def __init__(self, db_uri: str):
super().__init__(db_uri)



class SQLStorePublisher(SQLStore):
def __init__(self, db_uri: str, pubsub: PubSub):
self.pubsub = pubsub
super().__init__(db_uri)

def write_lmp(self, serialized_lmp: SerializedLMP, uses: Dict[str, Any]) -> Optional[Any]:
super().write_lmp(serialized_lmp, uses)
# todo. return result from write lmp so we can check if it was created or alredy exists
asyncio.create_task(self.pubsub.publish(f"lmp/{serialized_lmp.lmp_id}/created", serialized_lmp))
return None

def write_invocation(self, invocation: Invocation, results: List[SerializedLStr], consumes: Set[str]) -> Optional[Any]:
super().write_invocation(invocation, results, consumes)
asyncio.create_task(self.pubsub.publish(f"lmp/{invocation.lmp_id}/invoked", invocation))
return None

6 changes: 3 additions & 3 deletions src/ell/studio/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from functools import lru_cache
import json
import os
from typing import Optional
from typing import Any, Optional
from pydantic import BaseModel

import logging
Expand All @@ -20,10 +20,10 @@ class Config(BaseModel):
pg_connection_string: Optional[str] = None
storage_dir: Optional[str] = None
mqtt_connection_string: Optional[str] = None
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)

def model_post_init(self, __context):
def model_post_init(self, __context: Any):
# Storage
self.pg_connection_string = self.pg_connection_string or os.getenv(
"ELL_PG_CONNECTION_STRING")
Expand Down
124 changes: 1 addition & 123 deletions src/ell/studio/pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,128 +13,6 @@

logger = logging.getLogger(__name__)

# ################################## from https://github.com/empicano/aiomqtt/blob/bd91349f9c75482824022bcf1a8c0b1bd50f1349/aiomqtt/client.py#L1
# # SPDX-License-Identifier: BSD-3-Clause
# import dataclasses
# import sys
# from typing import Any

# from fastapi import WebSocket

# if sys.version_info >= (3, 10):
# from typing import TypeAlias
# else:
# from typing_extensions import TypeAlias


# MAX_TOPIC_LENGTH = 65535


# @dataclasses.dataclass(frozen=True)
# class Wildcard:
# """MQTT wildcard that can be subscribed to, but not published to.

# A wildcard is similar to a topic, but can optionally contain ``+`` and ``#``
# placeholders. You can access the ``value`` attribute directly to perform ``str``
# operations on a wildcard.

# Args:
# value: The wildcard string.

# Attributes:
# value: The wildcard string.
# """

# value: str

# def __str__(self) -> str:
# return self.value

# def __post_init__(self) -> None:
# """Validate the wildcard."""
# if not isinstance(self.value, str):
# msg = "Wildcard must be of type str"
# raise TypeError(msg)
# if (
# len(self.value) == 0
# or len(self.value) > MAX_TOPIC_LENGTH
# or "#/" in self.value
# or any(
# "+" in level or "#" in level
# for level in self.value.split("/")
# if len(level) > 1
# )
# ):
# msg = f"Invalid wildcard: {self.value}"
# raise ValueError(msg)


# WildcardLike: TypeAlias = "str | Wildcard"


# @dataclasses.dataclass(frozen=True)
# class Topic(Wildcard):
# """MQTT topic that can be published and subscribed to.

# Args:
# value: The topic string.

# Attributes:
# value: The topic string.
# """

# def __post_init__(self) -> None:
# """Validate the topic."""
# if not isinstance(self.value, str):
# msg = "Topic must be of type str"
# raise TypeError(msg)
# if (
# len(self.value) == 0
# or len(self.value) > MAX_TOPIC_LENGTH
# or "+" in self.value
# or "#" in self.value
# ):
# msg = f"Invalid topic: {self.value}"
# raise ValueError(msg)

# def matches(self, wildcard: WildcardLike) -> bool:
# """Check if the topic matches a given wildcard.

# Args:
# wildcard: The wildcard to match against.

# Returns:
# True if the topic matches the wildcard, False otherwise.
# """
# if not isinstance(wildcard, Wildcard):
# wildcard = Wildcard(wildcard)
# # Split topics into levels to compare them one by one
# topic_levels = self.value.split("/")
# wildcard_levels = str(wildcard).split("/")
# if wildcard_levels[0] == "$share":
# # Shared subscriptions use the topic structure: $share/<group_id>/<topic>
# wildcard_levels = wildcard_levels[2:]

# def recurse(tl: list[str], wl: list[str]) -> bool:
# """Recursively match topic levels with wildcard levels."""
# if not tl:
# if not wl or wl[0] == "#":
# return True
# return False
# if not wl:
# return False
# if wl[0] == "#":
# return True
# if tl[0] == wl[0] or wl[0] == "+":
# return recurse(tl[1:], wl[1:])
# return False

# return recurse(topic_levels, wildcard_levels)


# TopicLike: TypeAlias = "str | Topic"
# ##################################

Subscriber = WebSocket

class PubSub(ABC):
Expand Down Expand Up @@ -189,7 +67,7 @@ def unsubscribe_from_all(self, subscriber: Subscriber):
for topic in self.subscribers.copy():
self.unsubscribe(topic, subscriber)

class MqttPubSub(WebSocketPubSub):
class MqttWebSocketPubSub(WebSocketPubSub):
mqtt_client: aiomqtt.Client
def __init__(self, conn: aiomqtt.Client):
self.mqtt_client = conn
Expand Down
7 changes: 3 additions & 4 deletions src/ell/studio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from fastapi import FastAPI, Query, HTTPException, Depends, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from ell.studio.datamodels import SerializedLMPWithUses,InvocationsAggregate
from ell.studio.pubsub import MqttPubSub, NoOpPubSub, WebSocketPubSub
from ell.studio.pubsub import MqttWebSocketPubSub, NoOpPubSub, WebSocketPubSub
from ell.studio.config import Config

from ell.types import SerializedLMP
Expand Down Expand Up @@ -66,7 +66,7 @@ async def lifespan(app: FastAPI):
try:
async with aiomqtt.Client(config.mqtt_connection_string) as mqtt:
logger.info("Connected to MQTT")
pubsub = MqttPubSub(mqtt)
pubsub = MqttWebSocketPubSub(mqtt)
loop = asyncio.get_event_loop()
task = pubsub.listen(loop)

Expand Down Expand Up @@ -107,7 +107,7 @@ async def lifespan(app: FastAPI):


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, pubsub: MqttPubSub = Depends(get_pubsub)):
async def websocket_endpoint(websocket: WebSocket, pubsub: MqttWebSocketPubSub = Depends(get_pubsub)):
await websocket.accept()
await pubsub.subscribe_async("all", websocket)
try:
Expand All @@ -116,7 +116,6 @@ async def websocket_endpoint(websocket: WebSocket, pubsub: MqttPubSub = Depends(
# Handle incoming WebSocket messages if needed
except WebSocketDisconnect:
pubsub.unsubscribe_from_all(websocket)
# manager.disconnect(websocket)


@app.get("/api/latest/lmps", response_model=list[SerializedLMPWithUses])
Expand Down

0 comments on commit b228fb0

Please sign in to comment.