diff --git a/ell-studio/src/components/LMPDetailsSidePanel.js b/ell-studio/src/components/LMPDetailsSidePanel.js index a24ac983..1fefb23c 100644 --- a/ell-studio/src/components/LMPDetailsSidePanel.js +++ b/ell-studio/src/components/LMPDetailsSidePanel.js @@ -28,7 +28,7 @@ function VersionItem({ version, index, totalVersions, currentLmpId }) { {isLatest && Latest}
- {getTimeAgo(new Date(version.created_at + "Z"))} + {getTimeAgo(new Date(version.created_at))}
diff --git a/src/ell/decorators/track.py b/src/ell/decorators/track.py index 8dc7aa1b..fe726240 100644 --- a/src/ell/decorators/track.py +++ b/src/ell/decorators/track.py @@ -1,5 +1,5 @@ import logging -from ell.types import SerializedLStr +from ell.types import SerializedLStr, utc_now import ell.util.closure from ell.configurator import config from ell.lstr import lstr @@ -86,7 +86,7 @@ def wrapper(*fn_args, **fn_kwargs) -> str: logger.info(f"Attempted to use cache on {func_to_track.__qualname__} but it was not cached, or did not exist in the store. Refreshing cache...") - _start_time = datetime.now() + _start_time = utc_now() # XXX: thread saftey note, if I prevent yielding right here and get the global context I should be fine re: cache key problem @@ -96,7 +96,7 @@ def wrapper(*fn_args, **fn_kwargs) -> str: if not lmp else fn(*fn_args, _invocation_origin=invocation_id, **fn_kwargs, ) ) - latency_ms = (datetime.now() - _start_time).total_seconds() * 1000 + latency_ms = (utc_now() - _start_time).total_seconds() * 1000 usage = metadata.get("usage", {}) prompt_tokens=usage.get("prompt_tokens", 0) completion_tokens=usage.get("completion_tokens", 0) @@ -145,7 +145,7 @@ def _serialize_lmp(func, name, fn_closure, is_lmp, lm_kwargs): config._store.write_lmp( lmp_id=func.__ell_hash__, name=name, - created_at=datetime.now(), + created_at=utc_now(), source=fn_closure[0], dependencies=fn_closure[1], commit_message=commit, @@ -162,7 +162,7 @@ def _write_invocation(func, invocation_id, latency_ms, prompt_tokens, completion config._store.write_invocation( id=invocation_id, lmp_id=func.__ell_hash__, - created_at=datetime.now(), + created_at=utc_now(), global_vars=get_immutable_vars(func.__ell_closure__[2]), free_vars=get_immutable_vars(func.__ell_closure__[3]), latency_ms=latency_ms, diff --git a/src/ell/store.py b/src/ell/store.py index 02a46ffc..11348a35 100644 --- a/src/ell/store.py +++ b/src/ell/store.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from contextlib import contextmanager +from datetime import datetime from typing import Any, Optional, Dict, List, Set, Union from ell.lstr import lstr from ell.types import InvocableLM @@ -15,7 +16,7 @@ def write_lmp(self, lmp_id: str, name: str, source: str, dependencies: List[str] version_number: int, uses: Dict[str, Any], commit_message: Optional[str] = None, - created_at: Optional[float]=None) -> Optional[Any]: + created_at: Optional[datetime]=None) -> Optional[Any]: """ Write an LMP (Language Model Package) to the storage. @@ -33,7 +34,7 @@ def write_lmp(self, lmp_id: str, name: str, source: str, dependencies: List[str] @abstractmethod def write_invocation(self, id: str, lmp_id: str, args: str, kwargs: str, result: Union[lstr, List[lstr]], invocation_kwargs: Dict[str, Any], - created_at: Optional[float], consumes: Set[str], prompt_tokens: Optional[int] = None, + created_at: Optional[datetime], consumes: Set[str], prompt_tokens: Optional[int] = None, completion_tokens: Optional[int] = None, latency_ms: Optional[float] = None, state_cache_key: Optional[str] = None, cost_estimate: Optional[float] = None) -> Optional[Any]: diff --git a/src/ell/stores/sql.py b/src/ell/stores/sql.py index 2bcbac6e..7e32e18b 100644 --- a/src/ell/stores/sql.py +++ b/src/ell/stores/sql.py @@ -1,4 +1,4 @@ -import datetime +from datetime import datetime import json import os from typing import Any, Optional, Dict, List, Set, Union @@ -7,7 +7,7 @@ import cattrs import numpy as np from sqlalchemy.sql import text -from ell.types import InvocationTrace, SerializedLMP, Invocation, SerializedLMPUses, SerializedLStr +from ell.types import InvocationTrace, SerializedLMP, Invocation, SerializedLMPUses, SerializedLStr, utc_now from ell.lstr import lstr from sqlalchemy import or_, func, and_ @@ -26,7 +26,7 @@ def write_lmp(self, lmp_id: str, name: str, source: str, dependencies: List[str] global_vars: Dict[str, Any], free_vars: Dict[str, Any], commit_message: Optional[str] = None, - created_at: Optional[float]=None) -> Optional[Any]: + created_at: Optional[datetime]=None) -> Optional[Any]: with Session(self.engine) as session: lmp = session.query(SerializedLMP).filter(SerializedLMP.lmp_id == lmp_id).first() @@ -42,7 +42,7 @@ def write_lmp(self, lmp_id: str, name: str, source: str, dependencies: List[str] dependencies=dependencies, initial_global_vars=global_vars, initial_free_vars=free_vars, - created_at= created_at or datetime.datetime.utcnow(), + created_at= created_at or utc_now(), is_lm=is_lmp, lm_kwargs=lm_kwargs, commit_message=commit_message diff --git a/src/ell/types.py b/src/ell/types.py index ee5bb44d..eb5a4ead 100644 --- a/src/ell/types.py +++ b/src/ell/types.py @@ -1,14 +1,15 @@ # Let's define the core types. from dataclasses import dataclass -from typing import Callable, Dict, List, Union +from typing import Callable, Dict, List, Union, Any, Optional -from typing import Any from ell.lstr import lstr from ell.util.dict_sync_meta import DictSyncMeta -from datetime import datetime +from datetime import datetime, timezone from typing import Any, List, Optional -from sqlmodel import Field, SQLModel, Relationship, JSON, ARRAY, Column, Float +from sqlmodel import Field, SQLModel, Relationship, JSON, Column +from sqlalchemy import func +import sqlalchemy.types as types _lstr_generic = Union[lstr, str] @@ -43,6 +44,14 @@ class Message(dict, metaclass=DictSyncMeta): InvocableLM = Callable[..., _lstr_generic] +def utc_now() -> datetime: + """ + Returns the current UTC timestamp. + Serializes to ISO-8601. + """ + return datetime.now(tz=timezone.utc) + + class SerializedLMPUses(SQLModel, table=True): """ Represents the many-to-many relationship between SerializedLMPs. @@ -54,6 +63,16 @@ class SerializedLMPUses(SQLModel, table=True): lmp_using_id: Optional[str] = Field(default=None, foreign_key="serializedlmp.lmp_id", primary_key=True, index=True) # ID of the LMP that is using the other LMP +class UTCTimestamp(types.TypeDecorator[datetime]): + impl = types.TIMESTAMP + def process_result_value(self, value: datetime, dialect:Any): + return value.replace(tzinfo=timezone.utc) + +def UTCTimestampField(index:bool=False, **kwargs:Any): + return Field( + sa_column= Column(UTCTimestamp(timezone=True),index=index, **kwargs)) + + class SerializedLMP(SQLModel, table=True): """ @@ -65,7 +84,11 @@ class SerializedLMP(SQLModel, table=True): name: str = Field(index=True) # Name of the LMP source: str # Source code or reference for the LMP dependencies: str # List of dependencies for the LMP, stored as a string - created_at: datetime = Field(default_factory=datetime.utcnow, index=True) # Timestamp of when the LMP was created + # Timestamp of when the LMP was created + created_at: datetime = UTCTimestampField( + index=True, + nullable=False + ) is_lm: bool # Boolean indicating if it is an LM (Language Model) or an LMP lm_kwargs: dict = Field(sa_column=Column(JSON)) # Additional keyword arguments for the LMP @@ -131,8 +154,8 @@ class Invocation(SQLModel, table=True): completion_tokens: Optional[int] = Field(default=None) state_cache_key: Optional[str] = Field(default=None) - - created_at: datetime = Field(default_factory=datetime.utcnow) # Timestamp of when the invocation was created + # Timestamp of when the invocation was created + created_at: datetime = UTCTimestampField(default=func.now(), nullable=False) invocation_kwargs: dict = Field(default_factory=dict, sa_column=Column(JSON)) # Additional keyword arguments for the invocation # Relationships diff --git a/tests/test_sql_store.py b/tests/test_sql_store.py new file mode 100644 index 00000000..a8b854ae --- /dev/null +++ b/tests/test_sql_store.py @@ -0,0 +1,88 @@ +import pytest +from datetime import datetime, timezone +from sqlmodel import Session, select +from ell.stores.sql import SQLStore, SerializedLMP +from sqlalchemy import Engine, create_engine + +from ell.types import utc_now + +@pytest.fixture +def in_memory_db(): + return create_engine("sqlite:///:memory:") + +@pytest.fixture +def sql_store(in_memory_db: Engine) -> SQLStore: + store = SQLStore("sqlite:///:memory:") + store.engine = in_memory_db + SerializedLMP.metadata.create_all(in_memory_db) + return store + +def test_write_lmp(sql_store: SQLStore): + # Arrange + lmp_id = "test_lmp_1" + name = "Test LMP" + source = "def test_function(): pass" + dependencies = str(["dep1", "dep2"]) + is_lmp = True + lm_kwargs = '{"param1": "value1"}' + version_number = 1 + uses = {"used_lmp_1": {}, "used_lmp_2": {}} + global_vars = {"global_var1": "value1"} + free_vars = {"free_var1": "value2"} + commit_message = "Initial commit" + created_at = utc_now() + assert created_at.tzinfo is not None + + # Act + sql_store.write_lmp( + lmp_id=lmp_id, + name=name, + source=source, + dependencies=dependencies, + is_lmp=is_lmp, + lm_kwargs=lm_kwargs, + version_number=version_number, + uses=uses, + global_vars=global_vars, + free_vars=free_vars, + commit_message=commit_message, + created_at=created_at + ) + + # Assert + with Session(sql_store.engine) as session: + result = session.exec(select(SerializedLMP).where(SerializedLMP.lmp_id == lmp_id)).first() + + assert result is not None + assert result.lmp_id == lmp_id + assert result.name == name + assert result.source == source + assert result.dependencies == str(dependencies) + assert result.is_lm == is_lmp + assert result.lm_kwargs == lm_kwargs + assert result.version_number == version_number + assert result.initial_global_vars == global_vars + assert result.initial_free_vars == free_vars + assert result.commit_message == commit_message + # we want to assert created_at has timezone information + assert result.created_at.tzinfo is not None + + # Test that writing the same LMP again doesn't create a duplicate + sql_store.write_lmp( + lmp_id=lmp_id, + name=name, + source=source, + dependencies=dependencies, + is_lmp=is_lmp, + lm_kwargs=lm_kwargs, + version_number=version_number, + uses=uses, + global_vars=global_vars, + free_vars=free_vars, + commit_message=commit_message, + created_at=created_at + ) + + with Session(sql_store.engine) as session: + count = session.query(SerializedLMP).where(SerializedLMP.lmp_id == lmp_id).count() + assert count == 1