Skip to content

Commit

Permalink
stores
Browse files Browse the repository at this point in the history
  • Loading branch information
MadcowD committed Aug 4, 2024
1 parent 6156aec commit e6ce85f
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
13 changes: 7 additions & 6 deletions src/ell/store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Optional, Dict, List, Set
from typing import Any, Optional, Dict, List, Set, Union
from ell.lstr import lstr
from ell.types import InvocableLM

Expand Down Expand Up @@ -32,7 +32,7 @@ def write_lmp(self, lmp_id: str, name: str, source: str, dependencies: List[str]
pass

@abstractmethod
def write_invocation(self, id: str, lmp_id: str, args: str, kwargs: str, result: lstr | List[lstr], invocation_kwargs: Dict[str, Any],
def write_invocation(self, id: str, lmp_id: str, args: str, kwargs: str, result: Union[lstr, List[lstr]], invocation_kwargs: Dict[str, Any],
created_at: Optional[float], consumes: Set[str], prompt_tokens: Optional[int] = None,
completion_tokens: Optional[int] = None, latency_ms: Optional[float] = None,
state_cache_key: Optional[str] = None,
Expand Down Expand Up @@ -118,7 +118,7 @@ def get_latest_lmps(self) -> List[Dict[str, Any]]:


@contextmanager
def freeze(self, *lmps : InvocableLM):
def freeze(self, *lmps: InvocableLM):
"""
A context manager for caching operations using a particular store.
Expand All @@ -138,6 +138,7 @@ def freeze(self, *lmps : InvocableLM):
finally:
# TODO: Implement cache storage logic here
for lmp in lmps:
lmp.__ell_use_cache__ = old_cache_values.get(lmp, None)


if lmp in old_cache_values:
setattr(lmp, '__ell_use_cache__', old_cache_values[lmp])
else:
delattr(lmp, '__ell_use_cache__')
4 changes: 2 additions & 2 deletions src/ell/stores/sql.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime
import json
import os
from typing import Any, Optional, Dict, List, Set
from typing import Any, Optional, Dict, List, Set, Union
from sqlmodel import Session, SQLModel, create_engine, select
import ell.store
import cattrs
Expand Down Expand Up @@ -57,7 +57,7 @@ def write_lmp(self, lmp_id: str, name: str, source: str, dependencies: List[str]
session.commit()
return None

def write_invocation(self, id: str, lmp_id: str, args: str, kwargs: str, result: lstr | List[lstr], invocation_kwargs: Dict[str, Any],
def write_invocation(self, id: str, lmp_id: str, args: str, kwargs: str, result: Union[lstr, List[lstr]], invocation_kwargs: Dict[str, Any],
global_vars: Dict[str, Any],
free_vars: Dict[str, Any], created_at: Optional[float], consumes: Set[str], prompt_tokens: Optional[int] = None,
completion_tokens: Optional[int] = None, latency_ms: Optional[float] = None,
Expand Down

0 comments on commit e6ce85f

Please sign in to comment.