Skip to content

Commit

Permalink
Merge branch 'main' of github.com:MadcowD/ell
Browse files Browse the repository at this point in the history
  • Loading branch information
MadcowD committed Aug 6, 2024
2 parents 2bce344 + 10c95ab commit c62be06
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 19 deletions.
2 changes: 1 addition & 1 deletion ell-studio/src/components/LMPDetailsSidePanel.js
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function VersionItem({ version, index, totalVersions, currentLmpId }) {
{isLatest && <span className="text-xs bg-green-500 text-white px-2 py-0.5 rounded">Latest</span>}
</div>
<div className="text-xs text-gray-500 mt-1">
{getTimeAgo(new Date(version.created_at + "Z"))}
{getTimeAgo(new Date(version.created_at))}
</div>
</Link>
</div>
Expand Down
10 changes: 5 additions & 5 deletions src/ell/decorators/track.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from ell.types import SerializedLStr
from ell.types import SerializedLStr, utc_now
import ell.util.closure
from ell.configurator import config
from ell.lstr import lstr
Expand Down Expand Up @@ -86,7 +86,7 @@ def wrapper(*fn_args, **fn_kwargs) -> str:
logger.info(f"Attempted to use cache on {func_to_track.__qualname__} but it was not cached, or did not exist in the store. Refreshing cache...")


_start_time = datetime.now()
_start_time = utc_now()

# XXX: thread saftey note, if I prevent yielding right here and get the global context I should be fine re: cache key problem

Expand All @@ -96,7 +96,7 @@ def wrapper(*fn_args, **fn_kwargs) -> str:
if not lmp
else fn(*fn_args, _invocation_origin=invocation_id, **fn_kwargs, )
)
latency_ms = (datetime.now() - _start_time).total_seconds() * 1000
latency_ms = (utc_now() - _start_time).total_seconds() * 1000
usage = metadata.get("usage", {})
prompt_tokens=usage.get("prompt_tokens", 0)
completion_tokens=usage.get("completion_tokens", 0)
Expand Down Expand Up @@ -145,7 +145,7 @@ def _serialize_lmp(func, name, fn_closure, is_lmp, lm_kwargs):
config._store.write_lmp(
lmp_id=func.__ell_hash__,
name=name,
created_at=datetime.now(),
created_at=utc_now(),
source=fn_closure[0],
dependencies=fn_closure[1],
commit_message=commit,
Expand All @@ -162,7 +162,7 @@ def _write_invocation(func, invocation_id, latency_ms, prompt_tokens, completion
config._store.write_invocation(
id=invocation_id,
lmp_id=func.__ell_hash__,
created_at=datetime.now(),
created_at=utc_now(),
global_vars=get_immutable_vars(func.__ell_closure__[2]),
free_vars=get_immutable_vars(func.__ell_closure__[3]),
latency_ms=latency_ms,
Expand Down
5 changes: 3 additions & 2 deletions src/ell/store.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from datetime import datetime
from typing import Any, Optional, Dict, List, Set, Union
from ell.lstr import lstr
from ell.types import InvocableLM
Expand All @@ -15,7 +16,7 @@ def write_lmp(self, lmp_id: str, name: str, source: str, dependencies: List[str]
version_number: int,
uses: Dict[str, Any],
commit_message: Optional[str] = None,
created_at: Optional[float]=None) -> Optional[Any]:
created_at: Optional[datetime]=None) -> Optional[Any]:
"""
Write an LMP (Language Model Package) to the storage.
Expand All @@ -33,7 +34,7 @@ def write_lmp(self, lmp_id: str, name: str, source: str, dependencies: List[str]

@abstractmethod
def write_invocation(self, id: str, lmp_id: str, args: str, kwargs: str, result: Union[lstr, List[lstr]], invocation_kwargs: Dict[str, Any],
created_at: Optional[float], consumes: Set[str], prompt_tokens: Optional[int] = None,
created_at: Optional[datetime], consumes: Set[str], prompt_tokens: Optional[int] = None,
completion_tokens: Optional[int] = None, latency_ms: Optional[float] = None,
state_cache_key: Optional[str] = None,
cost_estimate: Optional[float] = None) -> Optional[Any]:
Expand Down
8 changes: 4 additions & 4 deletions src/ell/stores/sql.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import datetime
from datetime import datetime
import json
import os
from typing import Any, Optional, Dict, List, Set, Union
Expand All @@ -7,7 +7,7 @@
import cattrs
import numpy as np
from sqlalchemy.sql import text
from ell.types import InvocationTrace, SerializedLMP, Invocation, SerializedLMPUses, SerializedLStr
from ell.types import InvocationTrace, SerializedLMP, Invocation, SerializedLMPUses, SerializedLStr, utc_now
from ell.lstr import lstr
from sqlalchemy import or_, func, and_

Expand All @@ -26,7 +26,7 @@ def write_lmp(self, lmp_id: str, name: str, source: str, dependencies: List[str]
global_vars: Dict[str, Any],
free_vars: Dict[str, Any],
commit_message: Optional[str] = None,
created_at: Optional[float]=None) -> Optional[Any]:
created_at: Optional[datetime]=None) -> Optional[Any]:
with Session(self.engine) as session:
lmp = session.query(SerializedLMP).filter(SerializedLMP.lmp_id == lmp_id).first()

Expand All @@ -42,7 +42,7 @@ def write_lmp(self, lmp_id: str, name: str, source: str, dependencies: List[str]
dependencies=dependencies,
initial_global_vars=global_vars,
initial_free_vars=free_vars,
created_at= created_at or datetime.datetime.utcnow(),
created_at= created_at or utc_now(),
is_lm=is_lmp,
lm_kwargs=lm_kwargs,
commit_message=commit_message
Expand Down
37 changes: 30 additions & 7 deletions src/ell/types.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# Let's define the core types.
from dataclasses import dataclass
from typing import Callable, Dict, List, Union
from typing import Callable, Dict, List, Union, Any, Optional

from typing import Any
from ell.lstr import lstr
from ell.util.dict_sync_meta import DictSyncMeta

from datetime import datetime
from datetime import datetime, timezone
from typing import Any, List, Optional
from sqlmodel import Field, SQLModel, Relationship, JSON, ARRAY, Column, Float
from sqlmodel import Field, SQLModel, Relationship, JSON, Column
from sqlalchemy import func
import sqlalchemy.types as types

_lstr_generic = Union[lstr, str]

Expand Down Expand Up @@ -43,6 +44,14 @@ class Message(dict, metaclass=DictSyncMeta):
InvocableLM = Callable[..., _lstr_generic]


def utc_now() -> datetime:
"""
Returns the current UTC timestamp.
Serializes to ISO-8601.
"""
return datetime.now(tz=timezone.utc)


class SerializedLMPUses(SQLModel, table=True):
"""
Represents the many-to-many relationship between SerializedLMPs.
Expand All @@ -54,6 +63,16 @@ class SerializedLMPUses(SQLModel, table=True):
lmp_using_id: Optional[str] = Field(default=None, foreign_key="serializedlmp.lmp_id", primary_key=True, index=True) # ID of the LMP that is using the other LMP


class UTCTimestamp(types.TypeDecorator[datetime]):
impl = types.TIMESTAMP
def process_result_value(self, value: datetime, dialect:Any):
return value.replace(tzinfo=timezone.utc)

def UTCTimestampField(index:bool=False, **kwargs:Any):
return Field(
sa_column= Column(UTCTimestamp(timezone=True),index=index, **kwargs))



class SerializedLMP(SQLModel, table=True):
"""
Expand All @@ -65,7 +84,11 @@ class SerializedLMP(SQLModel, table=True):
name: str = Field(index=True) # Name of the LMP
source: str # Source code or reference for the LMP
dependencies: str # List of dependencies for the LMP, stored as a string
created_at: datetime = Field(default_factory=datetime.utcnow, index=True) # Timestamp of when the LMP was created
# Timestamp of when the LMP was created
created_at: datetime = UTCTimestampField(
index=True,
nullable=False
)
is_lm: bool # Boolean indicating if it is an LM (Language Model) or an LMP
lm_kwargs: dict = Field(sa_column=Column(JSON)) # Additional keyword arguments for the LMP

Expand Down Expand Up @@ -131,8 +154,8 @@ class Invocation(SQLModel, table=True):
completion_tokens: Optional[int] = Field(default=None)
state_cache_key: Optional[str] = Field(default=None)


created_at: datetime = Field(default_factory=datetime.utcnow) # Timestamp of when the invocation was created
# Timestamp of when the invocation was created
created_at: datetime = UTCTimestampField(default=func.now(), nullable=False)
invocation_kwargs: dict = Field(default_factory=dict, sa_column=Column(JSON)) # Additional keyword arguments for the invocation

# Relationships
Expand Down
88 changes: 88 additions & 0 deletions tests/test_sql_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import pytest
from datetime import datetime, timezone
from sqlmodel import Session, select
from ell.stores.sql import SQLStore, SerializedLMP
from sqlalchemy import Engine, create_engine

from ell.types import utc_now

@pytest.fixture
def in_memory_db():
return create_engine("sqlite:///:memory:")

@pytest.fixture
def sql_store(in_memory_db: Engine) -> SQLStore:
store = SQLStore("sqlite:///:memory:")
store.engine = in_memory_db
SerializedLMP.metadata.create_all(in_memory_db)
return store

def test_write_lmp(sql_store: SQLStore):
# Arrange
lmp_id = "test_lmp_1"
name = "Test LMP"
source = "def test_function(): pass"
dependencies = str(["dep1", "dep2"])
is_lmp = True
lm_kwargs = '{"param1": "value1"}'
version_number = 1
uses = {"used_lmp_1": {}, "used_lmp_2": {}}
global_vars = {"global_var1": "value1"}
free_vars = {"free_var1": "value2"}
commit_message = "Initial commit"
created_at = utc_now()
assert created_at.tzinfo is not None

# Act
sql_store.write_lmp(
lmp_id=lmp_id,
name=name,
source=source,
dependencies=dependencies,
is_lmp=is_lmp,
lm_kwargs=lm_kwargs,
version_number=version_number,
uses=uses,
global_vars=global_vars,
free_vars=free_vars,
commit_message=commit_message,
created_at=created_at
)

# Assert
with Session(sql_store.engine) as session:
result = session.exec(select(SerializedLMP).where(SerializedLMP.lmp_id == lmp_id)).first()

assert result is not None
assert result.lmp_id == lmp_id
assert result.name == name
assert result.source == source
assert result.dependencies == str(dependencies)
assert result.is_lm == is_lmp
assert result.lm_kwargs == lm_kwargs
assert result.version_number == version_number
assert result.initial_global_vars == global_vars
assert result.initial_free_vars == free_vars
assert result.commit_message == commit_message
# we want to assert created_at has timezone information
assert result.created_at.tzinfo is not None

# Test that writing the same LMP again doesn't create a duplicate
sql_store.write_lmp(
lmp_id=lmp_id,
name=name,
source=source,
dependencies=dependencies,
is_lmp=is_lmp,
lm_kwargs=lm_kwargs,
version_number=version_number,
uses=uses,
global_vars=global_vars,
free_vars=free_vars,
commit_message=commit_message,
created_at=created_at
)

with Session(sql_store.engine) as session:
count = session.query(SerializedLMP).where(SerializedLMP.lmp_id == lmp_id).count()
assert count == 1

0 comments on commit c62be06

Please sign in to comment.