diff --git a/ell-studio/src/components/LMPDetailsSidePanel.js b/ell-studio/src/components/LMPDetailsSidePanel.js
index a24ac983..1fefb23c 100644
--- a/ell-studio/src/components/LMPDetailsSidePanel.js
+++ b/ell-studio/src/components/LMPDetailsSidePanel.js
@@ -28,7 +28,7 @@ function VersionItem({ version, index, totalVersions, currentLmpId }) {
{isLatest && Latest}
- {getTimeAgo(new Date(version.created_at + "Z"))}
+ {getTimeAgo(new Date(version.created_at))}
diff --git a/src/ell/decorators/track.py b/src/ell/decorators/track.py
index 8dc7aa1b..fe726240 100644
--- a/src/ell/decorators/track.py
+++ b/src/ell/decorators/track.py
@@ -1,5 +1,5 @@
import logging
-from ell.types import SerializedLStr
+from ell.types import SerializedLStr, utc_now
import ell.util.closure
from ell.configurator import config
from ell.lstr import lstr
@@ -86,7 +86,7 @@ def wrapper(*fn_args, **fn_kwargs) -> str:
logger.info(f"Attempted to use cache on {func_to_track.__qualname__} but it was not cached, or did not exist in the store. Refreshing cache...")
- _start_time = datetime.now()
+ _start_time = utc_now()
# XXX: thread saftey note, if I prevent yielding right here and get the global context I should be fine re: cache key problem
@@ -96,7 +96,7 @@ def wrapper(*fn_args, **fn_kwargs) -> str:
if not lmp
else fn(*fn_args, _invocation_origin=invocation_id, **fn_kwargs, )
)
- latency_ms = (datetime.now() - _start_time).total_seconds() * 1000
+ latency_ms = (utc_now() - _start_time).total_seconds() * 1000
usage = metadata.get("usage", {})
prompt_tokens=usage.get("prompt_tokens", 0)
completion_tokens=usage.get("completion_tokens", 0)
@@ -145,7 +145,7 @@ def _serialize_lmp(func, name, fn_closure, is_lmp, lm_kwargs):
config._store.write_lmp(
lmp_id=func.__ell_hash__,
name=name,
- created_at=datetime.now(),
+ created_at=utc_now(),
source=fn_closure[0],
dependencies=fn_closure[1],
commit_message=commit,
@@ -162,7 +162,7 @@ def _write_invocation(func, invocation_id, latency_ms, prompt_tokens, completion
config._store.write_invocation(
id=invocation_id,
lmp_id=func.__ell_hash__,
- created_at=datetime.now(),
+ created_at=utc_now(),
global_vars=get_immutable_vars(func.__ell_closure__[2]),
free_vars=get_immutable_vars(func.__ell_closure__[3]),
latency_ms=latency_ms,
diff --git a/src/ell/store.py b/src/ell/store.py
index 02a46ffc..11348a35 100644
--- a/src/ell/store.py
+++ b/src/ell/store.py
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
+from datetime import datetime
from typing import Any, Optional, Dict, List, Set, Union
from ell.lstr import lstr
from ell.types import InvocableLM
@@ -15,7 +16,7 @@ def write_lmp(self, lmp_id: str, name: str, source: str, dependencies: List[str]
version_number: int,
uses: Dict[str, Any],
commit_message: Optional[str] = None,
- created_at: Optional[float]=None) -> Optional[Any]:
+ created_at: Optional[datetime]=None) -> Optional[Any]:
"""
Write an LMP (Language Model Package) to the storage.
@@ -33,7 +34,7 @@ def write_lmp(self, lmp_id: str, name: str, source: str, dependencies: List[str]
@abstractmethod
def write_invocation(self, id: str, lmp_id: str, args: str, kwargs: str, result: Union[lstr, List[lstr]], invocation_kwargs: Dict[str, Any],
- created_at: Optional[float], consumes: Set[str], prompt_tokens: Optional[int] = None,
+ created_at: Optional[datetime], consumes: Set[str], prompt_tokens: Optional[int] = None,
completion_tokens: Optional[int] = None, latency_ms: Optional[float] = None,
state_cache_key: Optional[str] = None,
cost_estimate: Optional[float] = None) -> Optional[Any]:
diff --git a/src/ell/stores/sql.py b/src/ell/stores/sql.py
index 2bcbac6e..7e32e18b 100644
--- a/src/ell/stores/sql.py
+++ b/src/ell/stores/sql.py
@@ -1,4 +1,4 @@
-import datetime
+from datetime import datetime
import json
import os
from typing import Any, Optional, Dict, List, Set, Union
@@ -7,7 +7,7 @@
import cattrs
import numpy as np
from sqlalchemy.sql import text
-from ell.types import InvocationTrace, SerializedLMP, Invocation, SerializedLMPUses, SerializedLStr
+from ell.types import InvocationTrace, SerializedLMP, Invocation, SerializedLMPUses, SerializedLStr, utc_now
from ell.lstr import lstr
from sqlalchemy import or_, func, and_
@@ -26,7 +26,7 @@ def write_lmp(self, lmp_id: str, name: str, source: str, dependencies: List[str]
global_vars: Dict[str, Any],
free_vars: Dict[str, Any],
commit_message: Optional[str] = None,
- created_at: Optional[float]=None) -> Optional[Any]:
+ created_at: Optional[datetime]=None) -> Optional[Any]:
with Session(self.engine) as session:
lmp = session.query(SerializedLMP).filter(SerializedLMP.lmp_id == lmp_id).first()
@@ -42,7 +42,7 @@ def write_lmp(self, lmp_id: str, name: str, source: str, dependencies: List[str]
dependencies=dependencies,
initial_global_vars=global_vars,
initial_free_vars=free_vars,
- created_at= created_at or datetime.datetime.utcnow(),
+ created_at= created_at or utc_now(),
is_lm=is_lmp,
lm_kwargs=lm_kwargs,
commit_message=commit_message
diff --git a/src/ell/types.py b/src/ell/types.py
index ee5bb44d..eb5a4ead 100644
--- a/src/ell/types.py
+++ b/src/ell/types.py
@@ -1,14 +1,15 @@
# Let's define the core types.
from dataclasses import dataclass
-from typing import Callable, Dict, List, Union
+from typing import Callable, Dict, List, Union, Any, Optional
-from typing import Any
from ell.lstr import lstr
from ell.util.dict_sync_meta import DictSyncMeta
-from datetime import datetime
+from datetime import datetime, timezone
from typing import Any, List, Optional
-from sqlmodel import Field, SQLModel, Relationship, JSON, ARRAY, Column, Float
+from sqlmodel import Field, SQLModel, Relationship, JSON, Column
+from sqlalchemy import func
+import sqlalchemy.types as types
_lstr_generic = Union[lstr, str]
@@ -43,6 +44,14 @@ class Message(dict, metaclass=DictSyncMeta):
InvocableLM = Callable[..., _lstr_generic]
+def utc_now() -> datetime:
+    """
+    Return the current moment as a timezone-aware ``datetime`` in UTC.
+
+    Because the value carries an explicit UTC offset, its ``isoformat()``
+    output is an ISO-8601 string ending in ``+00:00``.
+    """
+    return datetime.now(timezone.utc)
+
+
class SerializedLMPUses(SQLModel, table=True):
"""
Represents the many-to-many relationship between SerializedLMPs.
@@ -54,6 +63,16 @@ class SerializedLMPUses(SQLModel, table=True):
lmp_using_id: Optional[str] = Field(default=None, foreign_key="serializedlmp.lmp_id", primary_key=True, index=True) # ID of the LMP that is using the other LMP
+class UTCTimestamp(types.TypeDecorator[datetime]):
+    """
+    TIMESTAMP column type whose Python-side values are normalised to UTC.
+
+    Values loaded from the database are always returned as timezone-aware
+    datetimes in UTC. Backends that do not preserve timezone information
+    return naive datetimes; those are assumed to already be expressed in UTC
+    and are simply tagged as such.
+    """
+    impl = types.TIMESTAMP
+    # Only hashable constructor arguments (e.g. ``timezone=True``) are used
+    # with this type in this module, so it can safely participate in
+    # SQLAlchemy's statement cache (and avoids the "cache_ok" warning).
+    cache_ok = True
+
+    def process_bind_param(self, value: Optional[datetime], dialect: Any):
+        """
+        Convert aware datetimes to UTC before they are sent to the database.
+
+        Without this, a backend that drops tzinfo on write would store the
+        wall-clock time of a non-UTC value, which would then be read back
+        as if it were UTC. Naive values and ``None`` pass through unchanged.
+        """
+        if value is not None and value.tzinfo is not None:
+            return value.astimezone(timezone.utc)
+        return value
+
+    def process_result_value(self, value: Optional[datetime], dialect: Any):
+        """
+        Return the loaded value as a timezone-aware UTC datetime.
+
+        ``None`` (SQL NULL) is passed through. Naive values are tagged as
+        UTC; aware values are converted to UTC so the instant is preserved
+        rather than relabelled.
+        """
+        if value is None:
+            return None
+        if value.tzinfo is None:
+            return value.replace(tzinfo=timezone.utc)
+        return value.astimezone(timezone.utc)
+
+def UTCTimestampField(index:bool=False, **kwargs:Any):
+    """
+    Build a SQLModel ``Field`` backed by a ``UTCTimestamp`` column.
+
+    ``index`` and any extra keyword arguments (for example ``nullable`` or
+    ``default``) are forwarded to the underlying SQLAlchemy ``Column``,
+    not to ``Field`` itself.
+    """
+    utc_column = Column(UTCTimestamp(timezone=True), index=index, **kwargs)
+    return Field(sa_column=utc_column)
+
+
class SerializedLMP(SQLModel, table=True):
"""
@@ -65,7 +84,11 @@ class SerializedLMP(SQLModel, table=True):
name: str = Field(index=True) # Name of the LMP
source: str # Source code or reference for the LMP
dependencies: str # List of dependencies for the LMP, stored as a string
- created_at: datetime = Field(default_factory=datetime.utcnow, index=True) # Timestamp of when the LMP was created
+ # Timestamp of when the LMP was created
+ created_at: datetime = UTCTimestampField(
+ index=True,
+ nullable=False
+ )
is_lm: bool # Boolean indicating if it is an LM (Language Model) or an LMP
lm_kwargs: dict = Field(sa_column=Column(JSON)) # Additional keyword arguments for the LMP
@@ -131,8 +154,8 @@ class Invocation(SQLModel, table=True):
completion_tokens: Optional[int] = Field(default=None)
state_cache_key: Optional[str] = Field(default=None)
-
- created_at: datetime = Field(default_factory=datetime.utcnow) # Timestamp of when the invocation was created
+ # Timestamp of when the invocation was created
+ created_at: datetime = UTCTimestampField(default=func.now(), nullable=False)
invocation_kwargs: dict = Field(default_factory=dict, sa_column=Column(JSON)) # Additional keyword arguments for the invocation
# Relationships
diff --git a/tests/test_sql_store.py b/tests/test_sql_store.py
new file mode 100644
index 00000000..a8b854ae
--- /dev/null
+++ b/tests/test_sql_store.py
@@ -0,0 +1,88 @@
+import pytest
+from datetime import datetime, timezone
+from sqlmodel import Session, select
+from ell.stores.sql import SQLStore, SerializedLMP
+from sqlalchemy import Engine, create_engine
+
+from ell.types import utc_now
+
+@pytest.fixture
+def in_memory_db():
+    """Provide a SQLAlchemy engine pointed at an in-memory SQLite database."""
+    engine = create_engine("sqlite:///:memory:")
+    return engine
+
+@pytest.fixture
+def sql_store(in_memory_db: Engine) -> SQLStore:
+    """
+    Provide a ``SQLStore`` whose engine is the ``in_memory_db`` fixture.
+
+    The store is constructed normally and its engine is then swapped for the
+    fixture engine, on which the model tables are created.
+    """
+    backing_store = SQLStore("sqlite:///:memory:")
+    backing_store.engine = in_memory_db
+    SerializedLMP.metadata.create_all(in_memory_db)
+    return backing_store
+
+def test_write_lmp(sql_store: SQLStore):
+    """
+    Round-trip one LMP through ``write_lmp`` and check what was stored.
+
+    Verifies that every written field is read back unchanged, that
+    ``created_at`` comes back with timezone information, and that writing
+    the same ``lmp_id`` a second time does not add another row.
+    """
+    # Arrange: a single set of keyword arguments reused for both writes.
+    lmp_id = "test_lmp_1"
+    lmp_fields = dict(
+        lmp_id=lmp_id,
+        name="Test LMP",
+        source="def test_function(): pass",
+        dependencies=str(["dep1", "dep2"]),
+        is_lmp=True,
+        lm_kwargs='{"param1": "value1"}',
+        version_number=1,
+        uses={"used_lmp_1": {}, "used_lmp_2": {}},
+        global_vars={"global_var1": "value1"},
+        free_vars={"free_var1": "value2"},
+        commit_message="Initial commit",
+        created_at=utc_now(),
+    )
+    assert lmp_fields["created_at"].tzinfo is not None
+
+    # Act
+    sql_store.write_lmp(**lmp_fields)
+
+    # Assert
+    with Session(sql_store.engine) as session:
+        lookup = select(SerializedLMP).where(SerializedLMP.lmp_id == lmp_id)
+        stored = session.exec(lookup).first()
+
+        assert stored is not None
+        assert stored.lmp_id == lmp_id
+        assert stored.name == lmp_fields["name"]
+        assert stored.source == lmp_fields["source"]
+        assert stored.dependencies == str(lmp_fields["dependencies"])
+        assert stored.is_lm == lmp_fields["is_lmp"]
+        assert stored.lm_kwargs == lmp_fields["lm_kwargs"]
+        assert stored.version_number == lmp_fields["version_number"]
+        assert stored.initial_global_vars == lmp_fields["global_vars"]
+        assert stored.initial_free_vars == lmp_fields["free_vars"]
+        assert stored.commit_message == lmp_fields["commit_message"]
+        # we want to assert created_at has timezone information
+        assert stored.created_at.tzinfo is not None
+
+    # Test that writing the same LMP again doesn't create a duplicate
+    sql_store.write_lmp(**lmp_fields)
+
+    with Session(sql_store.engine) as session:
+        matching_rows = session.query(SerializedLMP).where(SerializedLMP.lmp_id == lmp_id)
+        count = matching_rows.count()
+        assert count == 1