From 4f65459c8bddf2158cf53378f90ab7ceb4192c99 Mon Sep 17 00:00:00 2001 From: Alex Dixon Date: Mon, 5 Aug 2024 22:07:19 -0700 Subject: [PATCH] ensure tz on deserialize timestamp this is needed to ensure we get a utc datetime when reading from sqlite or engines that don't support storing timestamps with a timezone --- src/ell/types.py | 29 ++++++++++++-- tests/test_sql_store.py | 88 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 113 insertions(+), 4 deletions(-) create mode 100644 tests/test_sql_store.py diff --git a/src/ell/types.py b/src/ell/types.py index 1203ea20..7514386a 100644 --- a/src/ell/types.py +++ b/src/ell/types.py @@ -8,7 +8,9 @@ from datetime import datetime, timezone from typing import Any, List, Optional -from sqlmodel import Field, SQLModel, Relationship, JSON, ARRAY, Column, Float +from sqlmodel import Field, SQLModel, Relationship, JSON, Column +from sqlalchemy import TIMESTAMP, func +import sqlalchemy.types as types _lstr_generic = Union[lstr, str] @@ -42,6 +44,10 @@ class Message(dict, metaclass=DictSyncMeta): LMP = Union[OneTurn, MultiTurnLMP, ChatLMP] InvocableLM = Callable[..., _lstr_generic] +from datetime import timezone +from sqlmodel import Field +from typing import Optional + def utc_now() -> datetime: """ @@ -62,6 +68,16 @@ class SerializedLMPUses(SQLModel, table=True): lmp_using_id: Optional[str] = Field(default=None, foreign_key="serializedlmp.lmp_id", primary_key=True, index=True) # ID of the LMP that is using the other LMP +class UTCTimestamp(types.TypeDecorator[datetime]): + impl = types.TIMESTAMP + def process_result_value(self, value: datetime, dialect:Any): + return value.replace(tzinfo=timezone.utc) + +def UTCTimestampField(index:bool=False, **kwargs:Any): + return Field( + sa_column= Column(UTCTimestamp(timezone=True),index=index, **kwargs)) + + class SerializedLMP(SQLModel, table=True): """ @@ -73,7 +89,12 @@ class SerializedLMP(SQLModel, table=True): name: str = Field(index=True) # Name of the LMP source: str # Source code or reference for the LMP dependencies: str # List of dependencies for the LMP, stored as a string - created_at: datetime = Field(default_factory=utc_now, index=True) # Timestamp of when the LMP was created + # Timestamp of when the LMP was created + created_at: datetime = UTCTimestampField( + index=True, + default=func.now(), + nullable=False + ) is_lm: bool # Boolean indicating if it is an LM (Language Model) or an LMP lm_kwargs: dict = Field(sa_column=Column(JSON)) # Additional keyword arguments for the LMP @@ -139,8 +160,8 @@ class Invocation(SQLModel, table=True): completion_tokens: Optional[int] = Field(default=None) state_cache_key: Optional[str] = Field(default=None) - - created_at: datetime = Field(default_factory=utc_now) # Timestamp of when the invocation was created + # Timestamp of when the invocation was created + created_at: datetime = UTCTimestampField(default=func.now(), nullable=False) invocation_kwargs: dict = Field(default_factory=dict, sa_column=Column(JSON)) # Additional keyword arguments for the invocation # Relationships diff --git a/tests/test_sql_store.py b/tests/test_sql_store.py new file mode 100644 index 00000000..a8b854ae --- /dev/null +++ b/tests/test_sql_store.py @@ -0,0 +1,88 @@ +import pytest +from datetime import datetime, timezone +from sqlmodel import Session, select +from ell.stores.sql import SQLStore, SerializedLMP +from sqlalchemy import Engine, create_engine + +from ell.types import utc_now + +@pytest.fixture +def in_memory_db(): + return create_engine("sqlite:///:memory:") + +@pytest.fixture +def sql_store(in_memory_db: Engine) -> SQLStore: + store = SQLStore("sqlite:///:memory:") + store.engine = in_memory_db + SerializedLMP.metadata.create_all(in_memory_db) + return store + +def test_write_lmp(sql_store: SQLStore): + # Arrange + lmp_id = "test_lmp_1" + name = "Test LMP" + source = "def test_function(): pass" + dependencies = str(["dep1", "dep2"]) + is_lmp = True + lm_kwargs = '{"param1": "value1"}' + version_number = 1 + uses = {"used_lmp_1": {}, "used_lmp_2": {}} + global_vars = {"global_var1": "value1"} + free_vars = {"free_var1": "value2"} + commit_message = "Initial commit" + created_at = utc_now() + assert created_at.tzinfo is not None + + # Act + sql_store.write_lmp( + lmp_id=lmp_id, + name=name, + source=source, + dependencies=dependencies, + is_lmp=is_lmp, + lm_kwargs=lm_kwargs, + version_number=version_number, + uses=uses, + global_vars=global_vars, + free_vars=free_vars, + commit_message=commit_message, + created_at=created_at + ) + + # Assert + with Session(sql_store.engine) as session: + result = session.exec(select(SerializedLMP).where(SerializedLMP.lmp_id == lmp_id)).first() + + assert result is not None + assert result.lmp_id == lmp_id + assert result.name == name + assert result.source == source + assert result.dependencies == str(dependencies) + assert result.is_lm == is_lmp + assert result.lm_kwargs == lm_kwargs + assert result.version_number == version_number + assert result.initial_global_vars == global_vars + assert result.initial_free_vars == free_vars + assert result.commit_message == commit_message + # we want to assert created_at has timezone information + assert result.created_at.tzinfo is not None + + # Test that writing the same LMP again doesn't create a duplicate + sql_store.write_lmp( + lmp_id=lmp_id, + name=name, + source=source, + dependencies=dependencies, + is_lmp=is_lmp, + lm_kwargs=lm_kwargs, + version_number=version_number, + uses=uses, + global_vars=global_vars, + free_vars=free_vars, + commit_message=commit_message, + created_at=created_at + ) + + with Session(sql_store.engine) as session: + count = session.query(SerializedLMP).where(SerializedLMP.lmp_id == lmp_id).count() + assert count == 1