From cdc9fc555da464dce68f2c93d9f60a2f25309c49 Mon Sep 17 00:00:00 2001 From: Alex Dixon Date: Fri, 9 Aug 2024 14:02:12 -0700 Subject: [PATCH] add postgres --- examples/hello_postgres.py | 34 ++++++++++++++++++++++++++++++++++ poetry.lock | 26 ++++++++++++++++++++++++-- pyproject.toml | 1 + src/ell/stores/sql.py | 8 +++++++- src/ell/types.py | 3 ++- 5 files changed, 68 insertions(+), 4 deletions(-) create mode 100644 examples/hello_postgres.py diff --git a/examples/hello_postgres.py b/examples/hello_postgres.py new file mode 100644 index 00000000..f71c46db --- /dev/null +++ b/examples/hello_postgres.py @@ -0,0 +1,34 @@ +import ell +import numpy as np + +from ell.stores.sql import PostgresStore + + +class MyPrompt: + x : int + +def get_random_length(): + return int(np.random.beta(2, 6) * 1500) + +@ell.lm(model="gpt-4o-mini") +def hello(world : str): + """Your goal is to be really meant to the other guy whiel say hello""" + name = world.capitalize() + number_of_chars_in_name = get_random_length() + + return f"Say hello to {name} in {number_of_chars_in_name} characters or more!" + + +if __name__ == "__main__": + ell.config.verbose = True + ell.set_store(PostgresStore('postgresql://postgres:postgres@localhost:5432/postgres'), autocommit=True) + + greeting = hello("sam altman") # > "hello sama! ... " + + + + # F_Theta: X -> Y + + # my_prompt_omega: Z -> X + + diff --git a/poetry.lock b/poetry.lock index da86d491..7763f817 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand. [[package]] name = "annotated-types" @@ -834,6 +834,28 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "psycopg2" +version = "2.9.9" +description = "psycopg2 - Python-PostgreSQL Database Adapter" +optional = false +python-versions = ">=3.7" +files = [ + {file = "psycopg2-2.9.9-cp310-cp310-win32.whl", hash = "sha256:38a8dcc6856f569068b47de286b472b7c473ac7977243593a288ebce0dc89516"}, + {file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"}, + {file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"}, + {file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"}, + {file = "psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"}, + {file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"}, + {file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"}, + {file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"}, + {file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"}, + {file = "psycopg2-2.9.9-cp38-cp38-win_amd64.whl", hash = "sha256:bac58c024c9922c23550af2a581998624d6e02350f4ae9c5f0bc642c633a2d5e"}, + {file = "psycopg2-2.9.9-cp39-cp39-win32.whl", hash = "sha256:c92811b2d4c9b6ea0285942b2e7cac98a59e166d59c588fe5cfe1eda58e72d59"}, + {file = "psycopg2-2.9.9-cp39-cp39-win_amd64.whl", hash = "sha256:de80739447af31525feddeb8effd640782cf5998e1a4e9192ebdf829717e3913"}, + {file = "psycopg2-2.9.9.tar.gz", hash = "sha256:d1454bde93fb1e224166811694d600e746430c006fbb031ea06ecc2ea41bf156"}, +] + [[package]] name = "pydantic" version = "2.8.2" @@ -1583,4 +1605,4 @@ npm-install = [] [metadata] lock-version = "2.0" python-versions = ">=3.9" -content-hash = "0f7033afc73feeffedbe46b95f08e51723520a396bdeb8e86326b5ee88110fed" +content-hash = "349392fd9d0b9a11a4a808fb3cf3105e815b4c8a243a706862387286c0482209" diff --git a/pyproject.toml b/pyproject.toml index 15c42dae..37882470 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ typing-extensions = "^4.12.2" black = "^24.8.0" +psycopg2 = "^2.9.9" [tool.poetry.group.dev.dependencies] pytest = "^8.3.2" diff --git a/src/ell/stores/sql.py b/src/ell/stores/sql.py index 53996530..5db6589d 100644 --- a/src/ell/stores/sql.py +++ b/src/ell/stores/sql.py @@ -243,4 +243,10 @@ class SQLiteStore(SQLStore): def __init__(self, storage_dir: str): os.makedirs(storage_dir, exist_ok=True) db_path = os.path.join(storage_dir, 'ell.db') - super().__init__(f'sqlite:///{db_path}') \ No newline at end of file + super().__init__(f'sqlite:///{db_path}') + +class PostgresStore(SQLStore): + def __init__(self, db_uri: str): + super().__init__(db_uri) + + diff --git a/src/ell/types.py b/src/ell/types.py index b6ac4b03..e9d5f106 100644 --- a/src/ell/types.py +++ b/src/ell/types.py @@ -62,6 +62,7 @@ class SerializedLMPUses(SQLModel, table=True): class UTCTimestamp(types.TypeDecorator[datetime]): + cache_ok = True impl = types.TIMESTAMP def process_result_value(self, value: datetime, dialect:Any): return value.replace(tzinfo=timezone.utc) @@ -118,7 +119,7 @@ class SerializedLStrBase(SQLModel): id: Optional[int] = Field(default=None, primary_key=True) content: str logits: List[float] = Field(default_factory=list, sa_column=Column(JSON)) - producer_invocation_id: Optional[int] = Field(default=None, foreign_key="invocation.id", index=True) + producer_invocation_id: Optional[str] = Field(default=None, foreign_key="invocation.id", index=True) class SerializedLStr(SerializedLStrBase, table=True): producer_invocation: Optional["Invocation"] = Relationship(back_populates="results")