diff --git a/src/ell/store.py b/src/ell/store.py index 13861abd..97ede0cd 100644 --- a/src/ell/store.py +++ b/src/ell/store.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import Any, Optional, Dict, List, Set +from typing import Any, Optional, Dict, List, Set, Union from ell.lstr import lstr from ell.types import InvocableLM @@ -32,7 +32,7 @@ def write_lmp(self, lmp_id: str, name: str, source: str, dependencies: List[str] pass @abstractmethod - def write_invocation(self, id: str, lmp_id: str, args: str, kwargs: str, result: lstr | List[lstr], invocation_kwargs: Dict[str, Any], + def write_invocation(self, id: str, lmp_id: str, args: str, kwargs: str, result: Union[lstr, List[lstr]], invocation_kwargs: Dict[str, Any], created_at: Optional[float], consumes: Set[str], prompt_tokens: Optional[int] = None, completion_tokens: Optional[int] = None, latency_ms: Optional[float] = None, state_cache_key: Optional[str] = None, @@ -118,7 +118,7 @@ def get_latest_lmps(self) -> List[Dict[str, Any]]: @contextmanager - def freeze(self, *lmps : InvocableLM): + def freeze(self, *lmps: InvocableLM): """ A context manager for caching operations using a particular store. @@ -138,6 +138,7 @@ def freeze(self, *lmps : InvocableLM): finally: # TODO: Implement cache storage logic here for lmp in lmps: - lmp.__ell_use_cache__ = old_cache_values.get(lmp, None) - - + if lmp in old_cache_values: + setattr(lmp, '__ell_use_cache__', old_cache_values[lmp]) + else: + delattr(lmp, '__ell_use_cache__') \ No newline at end of file diff --git a/src/ell/stores/sql.py b/src/ell/stores/sql.py index 8152c037..77ae040b 100644 --- a/src/ell/stores/sql.py +++ b/src/ell/stores/sql.py @@ -1,7 +1,7 @@ import datetime import json import os -from typing import Any, Optional, Dict, List, Set +from typing import Any, Optional, Dict, List, Set, Union from sqlmodel import Session, SQLModel, create_engine, select import ell.store import cattrs @@ -57,7 +57,7 @@ def write_lmp(self, lmp_id: str, name: str, source: str, dependencies: List[str] session.commit() return None - def write_invocation(self, id: str, lmp_id: str, args: str, kwargs: str, result: lstr | List[lstr], invocation_kwargs: Dict[str, Any], + def write_invocation(self, id: str, lmp_id: str, args: str, kwargs: str, result: Union[lstr, List[lstr]], invocation_kwargs: Dict[str, Any], global_vars: Dict[str, Any], free_vars: Dict[str, Any], created_at: Optional[float], consumes: Set[str], prompt_tokens: Optional[int] = None, completion_tokens: Optional[int] = None, latency_ms: Optional[float] = None,