Skip to content

Commit

Permalink
ensure tz on deserialize timestamp
Browse files Browse the repository at this point in the history
This is needed to ensure we get a UTC datetime when reading from SQLite or other engines that don't support storing timestamps with a timezone.
  • Loading branch information
alex-dixon committed Aug 6, 2024
1 parent f11149b commit 4f65459
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 4 deletions.
29 changes: 25 additions & 4 deletions src/ell/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

from datetime import datetime, timezone
from typing import Any, List, Optional
from sqlmodel import Field, SQLModel, Relationship, JSON, ARRAY, Column, Float
from sqlmodel import Field, SQLModel, Relationship, JSON, Column
from sqlalchemy import TIMESTAMP, func
import sqlalchemy.types as types

_lstr_generic = Union[lstr, str]

Expand Down Expand Up @@ -42,6 +44,10 @@ class Message(dict, metaclass=DictSyncMeta):
LMP = Union[OneTurn, MultiTurnLMP, ChatLMP]
InvocableLM = Callable[..., _lstr_generic]

from datetime import timezone
from sqlmodel import Field
from typing import Optional


def utc_now() -> datetime:
"""
Expand All @@ -62,6 +68,16 @@ class SerializedLMPUses(SQLModel, table=True):
lmp_using_id: Optional[str] = Field(default=None, foreign_key="serializedlmp.lmp_id", primary_key=True, index=True) # ID of the LMP that is using the other LMP


class UTCTimestamp(types.TypeDecorator[datetime]):
    """TIMESTAMP column type whose loaded values are always UTC-aware datetimes.

    Some engines (SQLite, for instance) cannot store a timezone alongside a
    timestamp and hand back naive ``datetime`` objects. This decorator attaches
    UTC on the way out of the database so callers never see a naive value.
    """

    impl = types.TIMESTAMP
    # The decorator adds no state of its own on top of ``impl``, so statements
    # that use it are safe for SQLAlchemy's compiled-statement cache.
    cache_ok = True

    def process_result_value(self, value: Optional[datetime], dialect: Any) -> Optional[datetime]:
        """Return *value* as a UTC-aware datetime.

        Args:
            value: The datetime produced by the underlying TIMESTAMP type, or
                ``None`` when the column is NULL.
            dialect: The dialect in use (unused).

        Returns:
            ``None`` for NULL; otherwise a datetime whose ``tzinfo`` is UTC.
        """
        # NULL columns arrive as None; there is nothing to tag.
        if value is None:
            return None
        # Naive value: the backend dropped the timezone, so label it as UTC.
        if value.utcoffset() is None:
            return value.replace(tzinfo=timezone.utc)
        # Already aware (backend kept the offset): convert rather than relabel
        # so the instant in time is preserved.
        return value.astimezone(timezone.utc)

def UTCTimestampField(index: bool = False, **kwargs: Any):
    """Build a model field backed by a ``UTCTimestamp(timezone=True)`` column.

    Args:
        index: Forwarded to ``Column`` as its ``index`` argument.
        **kwargs: Any further keyword arguments are forwarded to ``Column``
            unchanged (the visible callers pass ``default`` and ``nullable``).

    Returns:
        The result of ``Field(sa_column=...)`` wrapping the constructed column.
    """
    utc_column = Column(UTCTimestamp(timezone=True), index=index, **kwargs)
    return Field(sa_column=utc_column)



class SerializedLMP(SQLModel, table=True):
"""
Expand All @@ -73,7 +89,12 @@ class SerializedLMP(SQLModel, table=True):
name: str = Field(index=True) # Name of the LMP
source: str # Source code or reference for the LMP
dependencies: str # List of dependencies for the LMP, stored as a string
created_at: datetime = Field(default_factory=utc_now, index=True) # Timestamp of when the LMP was created
# Timestamp of when the LMP was created
created_at: datetime = UTCTimestampField(
index=True,
default=func.now(),
nullable=False
)
is_lm: bool # Boolean indicating if it is an LM (Language Model) or an LMP
lm_kwargs: dict = Field(sa_column=Column(JSON)) # Additional keyword arguments for the LMP

Expand Down Expand Up @@ -139,8 +160,8 @@ class Invocation(SQLModel, table=True):
completion_tokens: Optional[int] = Field(default=None)
state_cache_key: Optional[str] = Field(default=None)


created_at: datetime = Field(default_factory=utc_now) # Timestamp of when the invocation was created
# Timestamp of when the invocation was created
created_at: datetime = UTCTimestampField(default=func.now(), nullable=False)
invocation_kwargs: dict = Field(default_factory=dict, sa_column=Column(JSON)) # Additional keyword arguments for the invocation

# Relationships
Expand Down
88 changes: 88 additions & 0 deletions tests/test_sql_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import pytest
from datetime import datetime, timezone
from sqlmodel import Session, select
from ell.stores.sql import SQLStore, SerializedLMP
from sqlalchemy import Engine, create_engine

from ell.types import utc_now

@pytest.fixture
def in_memory_db():
    """Provide an engine created from the in-memory SQLite URL."""
    engine = create_engine("sqlite:///:memory:")
    return engine

@pytest.fixture
def sql_store(in_memory_db: Engine) -> SQLStore:
    """Provide a ``SQLStore`` whose engine is the ``in_memory_db`` fixture engine.

    The store is built from an in-memory SQLite URL, its ``engine`` attribute is
    then replaced with the fixture engine, and the model metadata is created on
    that engine before the store is returned.
    """
    backing_store = SQLStore("sqlite:///:memory:")
    # Point the store at the fixture's engine so the test can open its own
    # sessions against the same engine object.
    backing_store.engine = in_memory_db
    SerializedLMP.metadata.create_all(in_memory_db)
    return backing_store

def test_write_lmp(sql_store: SQLStore):
    """Write an LMP, read it back, and check that a second write adds no duplicate."""
    # Arrange: a single keyword set reused for both write_lmp calls.
    created_at = utc_now()
    assert created_at.tzinfo is not None

    lmp_fields = dict(
        lmp_id="test_lmp_1",
        name="Test LMP",
        source="def test_function(): pass",
        dependencies=str(["dep1", "dep2"]),
        is_lmp=True,
        lm_kwargs='{"param1": "value1"}',
        version_number=1,
        uses={"used_lmp_1": {}, "used_lmp_2": {}},
        global_vars={"global_var1": "value1"},
        free_vars={"free_var1": "value2"},
        commit_message="Initial commit",
        created_at=created_at,
    )
    target_id = lmp_fields["lmp_id"]

    # Act
    sql_store.write_lmp(**lmp_fields)

    # Assert
    with Session(sql_store.engine) as session:
        lookup = select(SerializedLMP).where(SerializedLMP.lmp_id == target_id)
        stored = session.exec(lookup).first()

        assert stored is not None
        assert stored.lmp_id == target_id
        assert stored.name == lmp_fields["name"]
        assert stored.source == lmp_fields["source"]
        assert stored.dependencies == lmp_fields["dependencies"]
        assert stored.is_lm == lmp_fields["is_lmp"]
        assert stored.lm_kwargs == lmp_fields["lm_kwargs"]
        assert stored.version_number == lmp_fields["version_number"]
        assert stored.initial_global_vars == lmp_fields["global_vars"]
        assert stored.initial_free_vars == lmp_fields["free_vars"]
        assert stored.commit_message == lmp_fields["commit_message"]
        # The timestamp read back from the database must carry timezone information.
        assert stored.created_at.tzinfo is not None

    # Writing the same LMP again must not create a duplicate row.
    sql_store.write_lmp(**lmp_fields)

    with Session(sql_store.engine) as session:
        matching_rows = session.query(SerializedLMP).where(SerializedLMP.lmp_id == target_id).count()
        assert matching_rows == 1

0 comments on commit 4f65459

Please sign in to comment.