From 77a0b1b93e26c6907d4bcca60d2c81078ba5c7de Mon Sep 17 00:00:00 2001
From: moon
Date: Thu, 10 Aug 2023 20:15:26 -0700
Subject: [PATCH] 0.4.3 - fix postgres issues

---
 agentmemory/events.py       |   2 +-
 agentmemory/main.py         |   2 +
 agentmemory/postgres.py     | 126 +++++++++++++++++++++++-------------
 agentmemory/tests/events.py |   3 +-
 agentmemory/tests/main.py   |   9 +++
 setup.py                    |   2 +-
 6 files changed, 95 insertions(+), 49 deletions(-)

diff --git a/agentmemory/events.py b/agentmemory/events.py
index 7b4d55a..4d8f60e 100644
--- a/agentmemory/events.py
+++ b/agentmemory/events.py
@@ -35,7 +35,7 @@ def increment_epoch():
     epoch = get_epoch()
     epoch = epoch + 1
     create_memory("epoch", str(epoch))
-    return epoch
+    return int(epoch)


 def get_epoch():
diff --git a/agentmemory/main.py b/agentmemory/main.py
index 22e90d9..0d6a13a 100644
--- a/agentmemory/main.py
+++ b/agentmemory/main.py
@@ -336,6 +336,8 @@ def update_memory(category, id, text=None, metadata=None, embedding=None):
             if isinstance(value, bool):
                 debug_log(f"WARNING: Boolean metadata field {key} converted to string")
                 metadata[key] = str(value)
+    else:
+        metadata = {}

     metadata["updated_at"] = datetime.datetime.now().timestamp()

diff --git a/agentmemory/postgres.py b/agentmemory/postgres.py
index 849bfba..29b50fc 100644
--- a/agentmemory/postgres.py
+++ b/agentmemory/postgres.py
@@ -40,6 +40,7 @@ def get(
     ):
         category = self.category
         table_name = self.client._table_name(category)
+
         if not ids:
             if limit is None:
                 limit = 100  # or another default value
@@ -60,9 +61,7 @@
         if offset is None:
             offset = 0

-        table_name = self.client._table_name(category)
         ids = [int(i) for i in ids]
-
         query = f"SELECT * FROM {table_name} WHERE id=ANY(%s) LIMIT %s OFFSET %s"
         params = (ids, limit, offset)

@@ -71,7 +70,14 @@
         # Convert rows to list of dictionaries
         columns = [desc[0] for desc in self.client.cur.description]
-        result = [dict(zip(columns, row)) for row in rows]
+        metadata_columns = [col for col in columns if col not in ["id", "document", "embedding"]]
+
+        result = []
+        for row in rows:
+            item = dict(zip(columns, row))
+            metadata = {col: item[col] for col in metadata_columns}
+            item["metadata"] = metadata
+            result.append(item)

         return {
             "ids": [row["id"] for row in result],
             "metadatas": [row["metadata"] for row in result],
         }
+
@@ -79,6 +85,7 @@
     def peek(self, limit=10):
         return self.get(limit=limit)

@@ -151,14 +158,21 @@ class PostgresCategory:
     def __init__(self, name):
         self.name = name

+
 default_model_path = str(Path.home() / ".cache" / "onnx_models")

+
 class PostgresClient:
-    def __init__(self, connection_string, model_name = "all-MiniLM-L6-v2", model_path = default_model_path):
+    def __init__(
+        self,
+        connection_string,
+        model_name="all-MiniLM-L6-v2",
+        model_path=default_model_path,
+    ):
         self.connection = psycopg2.connect(connection_string)
         self.cur = self.connection.cursor()
         from pgvector.psycopg2 import register_vector
-
+
         register_vector(self.cur)  # Register PGVector functions
         full_model_path = check_model(model_name=model_name, model_path=model_path)
         self.model_path = full_model_path
@@ -173,7 +187,6 @@ def ensure_table_exists(self, category):
             CREATE TABLE IF NOT EXISTS {table_name} (
                 id SERIAL PRIMARY KEY,
                 document TEXT NOT NULL,
-                metadata JSONB,
                 embedding VECTOR(384)
             )
             """
@@ -226,7 +239,6 @@ def insert_memory(self, category, document, metadata={}, embedding=None, id=None
         self._ensure_metadata_columns_exist(category, metadata)

         table_name = self._table_name(category)
-        metadata_string = json.dumps(metadata)  # Convert the dict to a JSON string

         if embedding is None:
             embedding = self.create_embedding(document)
@@ -234,11 +246,16 @@ def insert_memory(self, category, document, metadata={}, embedding=None, id=None
         if id is None:
             id = self.get_or_create_collection(category).count()

+        # Extracting the keys and values from metadata to insert them into respective columns
+        columns = ["id", "document", "embedding"] + list(metadata.keys())
+        placeholders = ["%s"] * len(columns)
+        values = [id, document, embedding] + list(metadata.values())
+
         query = f"""
-        INSERT INTO {table_name} (id, document, metadata, embedding) VALUES (%s, %s, %s, %s)
+        INSERT INTO {table_name} ({', '.join(columns)}) VALUES ({', '.join(placeholders)})
         RETURNING id;
         """
-        self.cur.execute(query, (id, document, metadata_string, embedding))
+        self.cur.execute(query, tuple(values))
         self.connection.commit()
         return self.cur.fetchone()[0]

@@ -252,18 +269,21 @@ def add(self, category, documents, metadatas, ids):
         with self.connection.cursor() as cur:
             for document, metadata, id_ in zip(documents, metadatas, ids):
                 self._ensure_metadata_columns_exist(category, metadata)
+
+                columns = ["id", "document", "embedding"] + list(metadata.keys())
+                placeholders = ["%s"] * len(columns)
                 embedding = self.create_embedding(document)
-                cur.execute(
-                    f"""
-                    INSERT INTO {table_name} (id, document, metadata, embedding)
-                    VALUES (%s, %s %s, %s)
-                    """,
-                    (id_, document, metadata, embedding),
-                )
+                values = [id_, document, embedding] + list(metadata.values())
+
+                query = f"""
+                INSERT INTO {table_name} ({', '.join(columns)}) VALUES ({', '.join(placeholders)});
+                """
+                cur.execute(query, tuple(values))
         self.connection.commit()

     def query(self, category, query_texts, n_results=5):
-        embeddings = [self.create_embedding(q) for q in query_texts]
+        self.ensure_table_exists(category)
+        table_name = self._table_name(category)
         results = {
             "ids": [],
             "documents": [],
@@ -271,26 +291,34 @@
             "metadatas": [],
             "embeddings": [],
             "distances": [],
         }
-        self.ensure_table_exists(category)
-        table_name = self._table_name(category)
         with self.connection.cursor() as cur:
-            for emb in embeddings:
+            for emb in query_texts:
+                query_emb = self.create_embedding(emb)
                 cur.execute(
                     f"""
-                    SELECT id, document, metadata, embedding, embedding <-> %s AS distance
+                    SELECT id, document, embedding, embedding <-> %s AS distance, *
                     FROM {table_name}
                     ORDER BY embedding <-> %s
                     LIMIT %s
                     """,
-                    (emb, emb, n_results),
+                    (query_emb, query_emb, n_results),
                 )
                 rows = cur.fetchall()
+                columns = [desc[0] for desc in cur.description]
+                metadata_columns = [
+                    col
+                    for col in columns
+                    if col not in ["id", "document", "embedding", "distance"]
+                ]
                 for row in rows:
                     results["ids"].append(row[0])
                     results["documents"].append(row[1])
-                    results["metadatas"].append(row[2])
-                    results["embeddings"].append(row[3])
-                    results["distances"].append(row[4])
+                    results["embeddings"].append(row[2])
+                    results["distances"].append(row[3])
+                    metadata = {
+                        col: row[columns.index(col)] for col in metadata_columns
+                    }
+                    results["metadatas"].append(metadata)
         return results

     def update(self, category, id_, document=None, metadata=None, embedding=None):
@@ -298,30 +326,36 @@
         table_name = self._table_name(category)
         with self.connection.cursor() as cur:
             if document:
-                # if metadata is a dict, convert it to a JSON string
-                if isinstance(metadata, dict):
-                    metadata = json.dumps(metadata)
                 if embedding is None:
                     embedding = self.create_embedding(document)
-                cur.execute(
-                    f"""
-                    UPDATE {table_name}
-                    SET document=%s, embedding=%s, metadata=%s
-                    WHERE id=%s
-                    """,
-                    (document, embedding, metadata, id_),
-                )
-            else:
-                cur.execute(
-                    f"""
-                    UPDATE {table_name}
-                    SET metadata=%s
-                    WHERE id=%s
-                    """,
-                    (metadata, id_),
-                )
+                if metadata:
+                    self._ensure_metadata_columns_exist(category, metadata)
+                    columns = ["document=%s", "embedding=%s"] + [
+                        f"{key}=%s" for key in metadata.keys()
+                    ]
+                    values = [document, embedding] + list(metadata.values())
+                else:
+                    columns = ["document=%s", "embedding=%s"]
+                    values = [document, embedding]
+
+                query = f"""
+                UPDATE {table_name}
+                SET {', '.join(columns)}
+                WHERE id=%s
+                """
+                cur.execute(query, tuple(values) + (id_,))
+            elif metadata:
+                self._ensure_metadata_columns_exist(category, metadata)
+                columns = [f"{key}=%s" for key in metadata.keys()]
+                values = list(metadata.values())
+                query = f"""
+                UPDATE {table_name}
+                SET {', '.join(columns)}
+                WHERE id=%s
+                """
+                cur.execute(query, tuple(values) + (id_,))
         self.connection.commit()

     def close(self):
         self.cur.close()
-        self.connection.close()
+        self.connection.close()
diff --git a/agentmemory/tests/events.py b/agentmemory/tests/events.py
index c21d82d..511bbe5 100644
--- a/agentmemory/tests/events.py
+++ b/agentmemory/tests/events.py
@@ -34,7 +34,8 @@ def test_create_event():
     event = get_events()[0]
     assert event["document"] == "test event"
     assert event["metadata"]["test"] == "test"
-    assert event["metadata"]["epoch"] == 1
+    print(event["metadata"])
+    assert int(event["metadata"]["epoch"]) == 1

     wipe_category("events")
     wipe_category("epoch")
diff --git a/agentmemory/tests/main.py b/agentmemory/tests/main.py
index f9bb8d9..a87659a 100644
--- a/agentmemory/tests/main.py
+++ b/agentmemory/tests/main.py
@@ -53,8 +53,17 @@ def test_memory_update():
     memories = get_memories("test")
     memory_id = memories[0]["id"]

+    update_memory("test", memory_id, "doc 1 updated no", metadata={"test": "test"})
     update_memory("test", memory_id, "doc 1 updated", metadata={"test": "test"})
+    assert get_memory("test", memory_id)["document"] == "doc 1 updated"
+
+    create_memory("test", "new memory test", metadata={"test": "test"})
+    memories = get_memories("test")
+    memory_id = memories[0]["id"]
+    update_memory("test", memory_id, "doc 2 updated", metadata={"test": "test"})
+    assert get_memory("test", memory_id)["document"] == "doc 2 updated"
+
     wipe_category("test")

diff --git a/setup.py b/setup.py
index a986637..f0c69b7 100644
--- a/setup.py
+++ b/setup.py
@@ -16,7 +16,7 @@
 setup(
     name='agentmemory',
-    version='0.4.2',
+    version='0.4.3',
     description='Easy-to-use memory for agents, document search, knowledge graphing and more.',
     long_description=long_description,  # added this line
     long_description_content_type="text/markdown",  # and this line