Skip to content

Commit

Permalink
add postgres
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-dixon committed Aug 9, 2024
1 parent 0dc8728 commit cdc9fc5
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 4 deletions.
34 changes: 34 additions & 0 deletions examples/hello_postgres.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import ell
import numpy as np

from ell.stores.sql import PostgresStore


class MyPrompt:
x : int

def get_random_length():
return int(np.random.beta(2, 6) * 1500)

@ell.lm(model="gpt-4o-mini")
def hello(world : str):
"""Your goal is to be really meant to the other guy whiel say hello"""
name = world.capitalize()
number_of_chars_in_name = get_random_length()

return f"Say hello to {name} in {number_of_chars_in_name} characters or more!"


if __name__ == "__main__":
ell.config.verbose = True
ell.set_store(PostgresStore('postgresql://postgres:postgres@localhost:5432/postgres'), autocommit=True)

greeting = hello("sam altman") # > "hello sama! ... "



# F_Theta: X -> Y

# my_prompt_omega: Z -> X


26 changes: 24 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ typing-extensions = "^4.12.2"


black = "^24.8.0"
psycopg2 = "^2.9.9"
[tool.poetry.group.dev.dependencies]
pytest = "^8.3.2"

Expand Down
8 changes: 7 additions & 1 deletion src/ell/stores/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,4 +243,10 @@ class SQLiteStore(SQLStore):
def __init__(self, storage_dir: str):
os.makedirs(storage_dir, exist_ok=True)
db_path = os.path.join(storage_dir, 'ell.db')
super().__init__(f'sqlite:///{db_path}')
super().__init__(f'sqlite:///{db_path}')

class PostgresStore(SQLStore):
def __init__(self, db_uri: str):
super().__init__(db_uri)


3 changes: 2 additions & 1 deletion src/ell/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class SerializedLMPUses(SQLModel, table=True):


class UTCTimestamp(types.TypeDecorator[datetime]):
cache_ok = True
impl = types.TIMESTAMP
def process_result_value(self, value: datetime, dialect:Any):
return value.replace(tzinfo=timezone.utc)
Expand Down Expand Up @@ -118,7 +119,7 @@ class SerializedLStrBase(SQLModel):
id: Optional[int] = Field(default=None, primary_key=True)
content: str
logits: List[float] = Field(default_factory=list, sa_column=Column(JSON))
producer_invocation_id: Optional[int] = Field(default=None, foreign_key="invocation.id", index=True)
producer_invocation_id: Optional[str] = Field(default=None, foreign_key="invocation.id", index=True)

class SerializedLStr(SerializedLStrBase, table=True):
producer_invocation: Optional["Invocation"] = Relationship(back_populates="results")
Expand Down

0 comments on commit cdc9fc5

Please sign in to comment.