diff --git a/pyproject.toml b/pyproject.toml
index c545614..4dfe7a5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,18 +7,20 @@ packages=["ragdaemon"]
 
 [project]
 name = "ragdaemon"
-version = "0.8.3"
+version = "0.9.0"
 description = "Generate and render a call graph for a Python project."
 readme = "README.md"
 dependencies = [
     "astroid==3.2.2",
-    "chromadb==0.4.24",
+    "asyncpg==0.29.0",
     "dict2xml==1.7.5",
     "docker==7.1.0",
     "fastapi==0.109.2",
     "Jinja2==3.1.3",
     "networkx==3.2.1",
+    "pgvector==0.3.2",
     "psycopg2-binary==2.9.9",
+    "python-dotenv",
     "rank_bm25==0.2.2",
     "sqlalchemy==2.0.30",
     "spiceai~=0.3.0",
diff --git a/ragdaemon/__init__.py b/ragdaemon/__init__.py
index 732155f..3e2f46a 100644
--- a/ragdaemon/__init__.py
+++ b/ragdaemon/__init__.py
@@ -1 +1 @@
-__version__ = "0.8.3"
+__version__ = "0.9.0"
diff --git a/ragdaemon/annotators/call_graph.py b/ragdaemon/annotators/call_graph.py
index aefb655..8a25958 100644
--- a/ragdaemon/annotators/call_graph.py
+++ b/ragdaemon/annotators/call_graph.py
@@ -9,7 +9,7 @@
 from tqdm.asyncio import tqdm
 
 from ragdaemon.annotators.base_annotator import Annotator
-from ragdaemon.database import Database, remove_update_db_duplicates
+from ragdaemon.database import Database
 from ragdaemon.errors import RagdaemonError
 from ragdaemon.graph import KnowledgeGraph
 from ragdaemon.utils import (
@@ -229,7 +229,6 @@ async def annotate(
             update_db["ids"].append(data["checksum"])
             metadatas = {self.call_field_id: json.dumps(data[self.call_field_id])}
             update_db["metadatas"].append(metadatas)
-        update_db = remove_update_db_duplicates(**update_db)
         db.update(**update_db)
 
         # Add call edges to graph. Each call should have only ONE source; if there are
diff --git a/ragdaemon/annotators/chunker/__init__.py b/ragdaemon/annotators/chunker/__init__.py
index 800090b..d8d75c7 100644
--- a/ragdaemon/annotators/chunker/__init__.py
+++ b/ragdaemon/annotators/chunker/__init__.py
@@ -1,6 +1,5 @@
 import asyncio
 import json
-from copy import deepcopy
 from functools import partial
 from pathlib import Path
 from typing import Optional, Set
@@ -13,11 +12,7 @@
 from ragdaemon.annotators.chunker.chunk_line import chunk_document as chunk_line
 from ragdaemon.annotators.chunker.chunk_llm import chunk_document as chunk_llm
 from ragdaemon.annotators.chunker.utils import resolve_chunk_parent
-from ragdaemon.database import (
-    Database,
-    remove_add_to_db_duplicates,
-    remove_update_db_duplicates,
-)
+from ragdaemon.database import Database
 from ragdaemon.errors import RagdaemonError
 from ragdaemon.graph import KnowledgeGraph
 from ragdaemon.utils import (
@@ -145,7 +140,6 @@ async def annotate(
             update_db["ids"].append(data["checksum"])
             metadatas = {self.chunk_field_id: json.dumps(data[self.chunk_field_id])}
             update_db["metadatas"].append(metadatas)
-        update_db = remove_update_db_duplicates(**update_db)
         db.update(**update_db)
 
         # Process chunks
@@ -189,22 +183,19 @@ async def annotate(
         ids = list(set(checksums.values()))
         response = db.get(ids=ids, include=["metadatas"])
         db_data = {id: data for id, data in zip(response["ids"], response["metadatas"])}
-        add_to_db = {"ids": [], "documents": [], "metadatas": []}
+        add_to_db = {"ids": [], "documents": []}
         for node, checksum in checksums.items():
             if checksum in db_data:
                 data = db_data[checksum]
                 graph.nodes[node].update(data)
             else:
-                data = deepcopy(graph.nodes[node])
-                document = data.pop("document")
+                document = graph.nodes[node].get("document")
                 document, truncate_ratio = truncate(document, db.embedding_model)
                 if truncate_ratio > 0 and self.verbose > 1:
print(f"Truncated {node} by {truncate_ratio:.2%}") add_to_db["ids"].append(checksum) add_to_db["documents"].append(document) - add_to_db["metadatas"].append(data) if len(add_to_db["ids"]) > 0: - add_to_db = remove_add_to_db_duplicates(**add_to_db) db.add(**add_to_db) return graph diff --git a/ragdaemon/annotators/diff.py b/ragdaemon/annotators/diff.py index fc41025..1f753e7 100644 --- a/ragdaemon/annotators/diff.py +++ b/ragdaemon/annotators/diff.py @@ -1,9 +1,8 @@ import json import re -from copy import deepcopy from ragdaemon.annotators.base_annotator import Annotator -from ragdaemon.database import Database, remove_add_to_db_duplicates +from ragdaemon.database import Database from ragdaemon.graph import KnowledgeGraph from ragdaemon.errors import RagdaemonError from ragdaemon.utils import ( @@ -150,10 +149,11 @@ async def annotate( for id, checksum in checksums.items(): if checksum in db_data: continue - data = deepcopy(graph.nodes[id]) - document = data.pop("document") - if "chunks" in data: - data["chunks"] = json.dumps(data["chunks"]) + data = {} + document = graph.nodes[id].get("document") + chunks = graph.nodes[id].get("chunks") + if chunks: + data["chunks"] = json.dumps(chunks) document, truncate_ratio = truncate(document, db.embedding_model) if self.verbose > 1 and truncate_ratio > 0: print(f"Truncated {id} by {truncate_ratio:.2%}") @@ -161,7 +161,6 @@ async def annotate( add_to_db["documents"].append(document) add_to_db["metadatas"].append(data) if len(add_to_db["ids"]) > 0: - add_to_db = remove_add_to_db_duplicates(**add_to_db) db.add(**add_to_db) return graph diff --git a/ragdaemon/annotators/hierarchy.py b/ragdaemon/annotators/hierarchy.py index 2580b44..d5ce6ae 100644 --- a/ragdaemon/annotators/hierarchy.py +++ b/ragdaemon/annotators/hierarchy.py @@ -1,8 +1,7 @@ -from copy import deepcopy from pathlib import Path from ragdaemon.annotators.base_annotator import Annotator -from ragdaemon.database import Database, remove_add_to_db_duplicates +from ragdaemon.database import Database from ragdaemon.graph import KnowledgeGraph from ragdaemon.errors import RagdaemonError from ragdaemon.utils import get_document, hash_str, truncate @@ -93,22 +92,19 @@ async def annotate( ids = list(set(checksums.values())) response = db.get(ids=ids, include=["metadatas"]) db_data = {id: data for id, data in zip(response["ids"], response["metadatas"])} - add_to_db = {"ids": [], "documents": [], "metadatas": []} + add_to_db = {"ids": [], "documents": []} for path, checksum in checksums.items(): if checksum in db_data: data = db_data[checksum] graph.nodes[path.as_posix()].update(data) else: - data = deepcopy(graph.nodes[path.as_posix()]) - document = data.pop("document") + document = graph.nodes[path.as_posix()]["document"] document, truncate_ratio = truncate(document, db.embedding_model) if self.verbose > 1 and truncate_ratio > 0: print(f"Truncated {path} by {truncate_ratio:.2%}") add_to_db["ids"].append(checksum) add_to_db["documents"].append(document) - add_to_db["metadatas"].append(data) if len(add_to_db["ids"]) > 0: - add_to_db = remove_add_to_db_duplicates(**add_to_db) db.add(**add_to_db) return graph diff --git a/ragdaemon/annotators/layout_hierarchy.py b/ragdaemon/annotators/layout_hierarchy.py index be59c6b..554a34f 100644 --- a/ragdaemon/annotators/layout_hierarchy.py +++ b/ragdaemon/annotators/layout_hierarchy.py @@ -85,6 +85,10 @@ def iterate(iteration: int): class LayoutHierarchy(Annotator): name = "layout_hierarchy" + def __init__(self, *args, iterations: int = 40, **kwargs): + 
+        super().__init__(*args, **kwargs)
+        self.iterations = iterations
+
     def is_complete(self, graph: KnowledgeGraph, db: Database) -> bool:
         # Check that they have data.layout.hierarchy
         for node, data in graph.nodes(data=True):
@@ -99,15 +103,14 @@ async def annotate(
         graph: KnowledgeGraph,
         db: Database,
         refresh: str | bool = False,
-        iterations: int = 40,
     ) -> KnowledgeGraph:
         """
        a. Regenerate x/y/z for all nodes
        b. Update all nodes
-       c. Save to chroma
+       c. Save to db
         """
         pos = fruchterman_reingold_3d(
-            graph, iterations=iterations, verbose=self.verbose
+            graph, iterations=self.iterations, verbose=self.verbose
         )
         for node_id, coordinates in pos.items():
             node = graph.nodes[node_id]
diff --git a/ragdaemon/annotators/summarizer.py b/ragdaemon/annotators/summarizer.py
index eddf402..7f1eb76 100644
--- a/ragdaemon/annotators/summarizer.py
+++ b/ragdaemon/annotators/summarizer.py
@@ -9,7 +9,7 @@
 
 from ragdaemon.annotators.base_annotator import Annotator
 from ragdaemon.context import ContextBuilder
-from ragdaemon.database import Database, remove_update_db_duplicates
+from ragdaemon.database import Database
 from ragdaemon.graph import KnowledgeGraph
 from ragdaemon.errors import RagdaemonError
 from ragdaemon.io import IO
@@ -198,7 +198,6 @@ def get_chunk_summaries(target: str) -> list[str]:
 class Summarizer(Annotator):
     name = "summarizer"
     summary_field_id = "summary"
-    checksum_field_id = "summary_checksum"
 
     def __init__(
         self,
@@ -226,13 +225,6 @@ def is_complete(self, graph: KnowledgeGraph, db: Database) -> bool:
                 raise RagdaemonError(f"Node {node} missing checksum.")
             if data.get(self.summary_field_id) is None:
                 return False
-            # Checksum used to be hash_str(document + context) using the above method. This is
-            # technically more correct, because the summary context includes adjacent summaries
-            # so the whole system updates iteratively. In practice it was just too much looping
-            # so for now we just reuse the checksum generated in hierarchy (hash_str(document)).
-            summary_checksum = data["checksum"]
-            if summary_checksum != data.get(self.checksum_field_id):
-                return False
         return True
 
     async def generate_summary(
@@ -247,13 +239,8 @@
             raise RagdaemonError("Spice client not initialized")
 
         data = graph.nodes[node]
-        summary_checksum = data["checksum"]
         _refresh = match_refresh(refresh, node)
-        if (
-            _refresh
-            or data.get(self.summary_field_id) is None
-            or summary_checksum != data.get(self.checksum_field_id)
-        ):
+        if _refresh or data.get(self.summary_field_id) is None:
             document, context = get_document_and_context(
                 node,
                 graph,
@@ -283,7 +270,6 @@ async def generate_summary(
 
         if summary != "PASS":
             data[self.summary_field_id] = summary
-            data[self.checksum_field_id] = summary_checksum
 
         if loading_bar is not None:
             loading_bar.update(1)
@@ -311,31 +297,27 @@
     async def annotate(
         self, graph: KnowledgeGraph, db: Database, refresh: str | bool = False
     ) -> KnowledgeGraph:
         """Asynchronously generate or fetch summaries and add to graph/db"""
-        summaries = dict[str, str]()
+        nodes_to_summarize: set[str] = set()
         for node, data in graph.nodes(data=True):
             if data is not None and data.get("type") in self.summarize_nodes:
-                summaries[node] = data.get(self.checksum_field_id, "")
+                nodes_to_summarize.add(node)
 
         if self.verbose > 1:
-            loading_bar = tqdm(total=len(summaries), desc="Summarizing code...")
+            loading_bar = tqdm(
+                total=len(nodes_to_summarize), desc="Summarizing code..."
+            )
         else:
             loading_bar = None
         await self.dfs("ROOT", graph, loading_bar, refresh)
 
         update_db = {"ids": [], "metadatas": []}
-        for node, summary_checksum in summaries.items():
-            if graph.nodes[node].get(self.checksum_field_id) != summary_checksum:
-                data = graph.nodes[node]
-                update_db["ids"].append(data["checksum"])
-                update_db["metadatas"].append(
-                    {
-                        self.summary_field_id: data[self.summary_field_id],
-                        self.checksum_field_id: data[self.checksum_field_id],
-                    }
-                )
+        for node in nodes_to_summarize:
+            data = graph.nodes[node]
+            update_db["ids"].append(data["checksum"])
+            metadatas = {self.summary_field_id: data[self.summary_field_id]}
+            update_db["metadatas"].append(metadatas)
         if len(update_db["ids"]) > 1:
-            update_db = remove_update_db_duplicates(**update_db)
             db.update(**update_db)
 
         if loading_bar is not None:
diff --git a/ragdaemon/app.py b/ragdaemon/app.py
index 42da668..ff3a524 100644
--- a/ragdaemon/app.py
+++ b/ragdaemon/app.py
@@ -1,7 +1,6 @@
 import argparse
 import asyncio
 import socket
-import webbrowser
 from contextlib import asynccontextmanager
 from pathlib import Path
 from typing import Any
@@ -40,9 +39,8 @@
 annotators = {
     "hierarchy": {},
     "chunker": {"use_llm": True},
-    # "summarizer": {},
-    # "clusterer_binary": {},
-    # "call_graph": {"call_extensions": code_extensions},
+    "call_graph": {"call_extensions": code_extensions},
+    "summarizer": {},
     "diff": {"diff": diff},
     "layout_hierarchy": {},
 }
@@ -117,9 +115,9 @@ async def main():
     print(f"Starting server on port {port}...")
     server = uvicorn.Server(config)
 
-    async def _wait_1s_then_open_browser():
-        await asyncio.sleep(1)
-        webbrowser.open(f"http://localhost:{port}")
+    # async def _wait_1s_then_open_browser():
+    #     await asyncio.sleep(1)
+    #     webbrowser.open(f"http://localhost:{port}")
 
-    asyncio.create_task(_wait_1s_then_open_browser())
+    # asyncio.create_task(_wait_1s_then_open_browser())
     await server.serve()
diff --git a/ragdaemon/daemon.py b/ragdaemon/daemon.py
index 49fc841..adcb5d3 100644
--- a/ragdaemon/daemon.py
+++ b/ragdaemon/daemon.py
@@ -13,12 +13,17 @@
 from ragdaemon.annotators import annotators_map
 from ragdaemon.cerebrus import cerebrus
 from ragdaemon.context import ContextBuilder
-from ragdaemon.database import DEFAULT_EMBEDDING_MODEL, Database, get_db
+from ragdaemon.database import Database, get_db
 from ragdaemon.errors import RagdaemonError
 from ragdaemon.graph import KnowledgeGraph
 from ragdaemon.io import DockerIO, IO, LocalIO
 from ragdaemon.locate import locate
-from ragdaemon.utils import DEFAULT_COMPLETION_MODEL, match_refresh, mentat_dir_path
+from ragdaemon.utils import (
+    DEFAULT_COMPLETION_MODEL,
+    DEFAULT_EMBEDDING_MODEL,
+    match_refresh,
+    mentat_dir_path,
+)
 
 
 def default_annotators():
@@ -61,7 +66,7 @@ def __init__(
         if spice_client is None:
             spice_client = Spice(
                 default_text_model=DEFAULT_COMPLETION_MODEL,
-                default_embeddings_model=model,
+                default_embeddings_model=DEFAULT_EMBEDDING_MODEL,
                 logging_dir=logging_dir,
             )
         self.spice_client = spice_client
diff --git a/ragdaemon/database/__init__.py b/ragdaemon/database/__init__.py
index 7ef778e..f5a3363 100644
--- a/ragdaemon/database/__init__.py
+++ b/ragdaemon/database/__init__.py
@@ -3,20 +3,9 @@
 
 from spice import Spice
 
-from ragdaemon.database.chroma_database import (
-    # ChromaDB,
-    remove_add_to_db_duplicates,  # noqa: F401
-    remove_update_db_duplicates,  # noqa: F401
-)
 from ragdaemon.database.database import Database
-
-# from ragdaemon.database.chroma_database import ChromaDB
 from ragdaemon.database.lite_database import LiteDB
-
-# from ragdaemon.database.pg_database import PGDB
-from ragdaemon.utils import mentat_dir_path
-
-DEFAULT_EMBEDDING_MODEL = "text-embedding-3-large"
+from ragdaemon.database.pg_database import PGDB
 
 
 def get_db(
@@ -25,26 +14,16 @@
     embedding_provider: Optional[str] = None,
     verbose: int = 0,
 ) -> Database:
-    db_path = mentat_dir_path / "chroma"
-    db_path.mkdir(parents=True, exist_ok=True)
-    # if embedding_model is not None and "PYTEST_CURRENT_TEST" not in os.environ:
-    #     try:
-    #         # db = ChromaDB(
-    #         #     db_path=db_path,
-    #         #     spice_client=spice_client,
-    #         #     embedding_model=embedding_model,
-    #         #     embedding_provider=embedding_provider,
-    #         #     verbose=verbose,
-    #         # )
-    #         # # In case the api key is wrong, try to embed something to trigger an error.
-    #         # _ = db.add(ids="test", documents="test doc")
-    #         # db.delete(ids="test")
-    #         db = PGDB(db_path=db_path, verbose=verbose)
-    #         return db
-    #     except Exception as e:
-    #         if verbose > 1:
-    #             print(
-    #                 f"Failed to initialize Postgres Database: {e}. Falling back to LiteDB."
-    #             )
-    #         pass
-    return LiteDB(db_path=db_path, verbose=verbose)
+    if embedding_model is not None and "PYTEST_CURRENT_TEST" not in os.environ:
+        try:
+            db = PGDB(
+                spice_client, embedding_model, embedding_provider, verbose=verbose
+            )
+            return db
+        except Exception as e:
+            if verbose > 1:
+                print(
+                    f"Failed to initialize Postgres Database: {e}. Falling back to LiteDB."
+                )
+            pass
+    return LiteDB(verbose=verbose)
diff --git a/ragdaemon/database/chroma_database.py b/ragdaemon/database/chroma_database.py
deleted file mode 100644
index 5d70c9f..0000000
--- a/ragdaemon/database/chroma_database.py
+++ /dev/null
@@ -1,129 +0,0 @@
-import os
-from pathlib import Path
-from typing import Any, Optional, cast
-
-import dotenv
-from spice import Spice
-
-from ragdaemon import __version__
-from ragdaemon.database.database import Database
-from ragdaemon.errors import RagdaemonError
-from ragdaemon.utils import basic_auth
-
-MAX_INPUTS_PER_CALL = 2048
-
-
-def remove_add_to_db_duplicates(
-    ids: list[str], documents: list[str], metadatas: list[dict]
-) -> dict[str, Any]:
-    seen = set()
-    output = {"ids": [], "documents": [], "metadatas": []}
-    for id, document, metadata in zip(ids, documents, metadatas):
-        if id not in seen:
-            output["ids"].append(id)
-            output["documents"].append(document)
-            output["metadatas"].append(metadata)
-            seen.add(id)
-    return output
-
-
-def remove_update_db_duplicates(
-    ids: list[str], metadatas: list[dict]
-) -> dict[str, Any]:
-    seen = set()
-    output = {"ids": [], "metadatas": []}
-    for id, metadata in zip(ids, metadatas):
-        if id not in seen:
-            output["ids"].append(id)
-            output["metadatas"].append(metadata)
-            seen.add(id)
-    return output
-
-
-class ChromaDB(Database):
-    def __init__(
-        self,
-        db_path: Path,
-        spice_client: Spice,
-        embedding_model: str,
-        embedding_provider: Optional[str] = None,
-        verbose: int = 0,
-    ) -> None:
-        self.db_path = db_path
-        self.embedding_model = embedding_model
-        self.verbose = verbose
-
-        import chromadb  # Imports are slow so do it lazily
-        from chromadb.api.types import (
-            Embeddable,
-            EmbeddingFunction,
-            Embeddings,
-        )
-        from chromadb.config import Settings
-
-        class SpiceEmbeddingFunction(EmbeddingFunction[Embeddable]):
-            def __call__(self, input_texts: Embeddable) -> Embeddings:
-                if not all(isinstance(item, str) for item in input_texts):
-                    raise RagdaemonError("SpiceEmbeddings only enabled for text files.")
-                input_texts = cast(list[str], input_texts)
-                # Embed in batches
-                n_batches = (len(input_texts) - 1) // MAX_INPUTS_PER_CALL + 1
-                output: Embeddings = []
-                for batch in range(n_batches):
-                    start = batch * MAX_INPUTS_PER_CALL
-                    end = min((batch + 1) * MAX_INPUTS_PER_CALL, len(input_texts))
-                    embeddings = spice_client.get_embeddings_sync(
-                        input_texts=input_texts[start:end],
-                        model=embedding_model,
-                        provider=embedding_provider,
-                    ).embeddings
-                    output.extend(embeddings)
-                return output
-
-        embedding_function = SpiceEmbeddingFunction()
-
-        dotenv.load_dotenv()
-
-        try:
-            host = os.environ["CHROMA_SERVER_HOST"]
-            port = int(os.environ.get("CHROMA_SERVER_HTTP_PORT", 443))
-            username = os.environ["CHROMA_SERVER_USERNAME"]
-            password = os.environ["CHROMA_SERVER_PASSWORD"]
-            _client = chromadb.HttpClient(
-                host=host,
-                port=port,
-                ssl=port == 443,
-                headers={"Authorization": basic_auth(username, password)},
-                settings=Settings(allow_reset=True, anonymized_telemetry=False),
-            )
-        except KeyError:
-            if self.verbose > 0:
-                print(
-                    "No Chroma HTTP client environment variables found. Defaulting to PersistentClient."
-                )
-            _client = chromadb.PersistentClient(path=str(db_path))
-
-        minor_version = ".".join(__version__.split(".")[:2])
-        name = f"ragdaemon-{minor_version}-{self.embedding_model}"
-        self._collection = _client.get_or_create_collection(
-            name=name,
-            embedding_function=embedding_function,
-        )
-
-    def query(self, query: str, active_checksums: list[str]) -> list[dict]:
-        response = self._collection.query(
-            query_texts=query,
-            where={"checksum": {"$in": active_checksums}},  # type: ignore
-            n_results=len(active_checksums),
-            include=["distances"],
-        )
-        ids = response["ids"]
-        distances = response["distances"]
-        if not ids or not distances:
-            return []
-        results = [
-            {"checksum": id, "distance": distance}
-            for id, distance in zip(ids[0], distances[0])
-        ]
-        results = sorted(results, key=lambda x: x["distance"])
-        return results
diff --git a/ragdaemon/database/database.py b/ragdaemon/database/database.py
index 6d4d306..6ea3d97 100644
--- a/ragdaemon/database/database.py
+++ b/ragdaemon/database/database.py
@@ -6,16 +6,27 @@
 
 class Database:
     embedding_model: str | None = None
-    _collection = None  # Collection | LiteDB
 
     def __init__(self, db_path: Path) -> None:
         raise NotImplementedError
 
-    def __getattr__(self, name):
-        """Delegate attribute access to the collection."""
-        return getattr(self._collection, name)
+    def add(
+        self,
+        ids: list[str],
+        documents: list[str],
+        metadatas: Optional[list[dict]] = None,
+    ):
+        # NOTE: In the past we had issues with duplicates. LiteDB doesn't mind, but PGDB might.
+        raise NotImplementedError
+
+    def update(self, ids: list[str], metadatas: list[dict]):
+        # NOTE: Same as above re: duplicates
+        raise NotImplementedError
+
+    def get(self, ids: list[str], include: Optional[list[str]] = None) -> dict:
+        raise NotImplementedError
 
-    def query(self, query: str, active_checksums: list[str]) -> list[dict]:
+    def query(self, query: str, active_checksums: set[str]) -> list[dict]:
         raise NotImplementedError
 
     def query_graph(
@@ -43,7 +54,7 @@ def query_graph(
             for node, data in graph.nodes(data=True)
             if data and "checksum" in data and data["type"] in node_types
         }
-        response = self.query(query, list(checksum_index.keys()))
+        response = self.query(query, set(checksum_index.keys()))
 
         # Add (local) metadata to results
         results = list[dict[str, Any]]()
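
With remove_add_to_db_duplicates and remove_update_db_duplicates gone, callers are now responsible for keeping each batch free of repeated ids (per the NOTE above: LiteDB tolerates duplicates, but PGDB may not, since its ids are primary keys). A minimal sketch of the kind of guard a call site could apply before db.add or db.update — a hypothetical helper, not part of this patch, assuming ids and metadatas are parallel lists:

    def dedupe_batch(
        ids: list[str], metadatas: list[dict]
    ) -> tuple[list[str], list[dict]]:
        # Keep only the first occurrence of each id; drop later duplicates.
        seen: set[str] = set()
        out_ids: list[str] = []
        out_metadatas: list[dict] = []
        for id, metadata in zip(ids, metadatas):
            if id not in seen:
                seen.add(id)
                out_ids.append(id)
                out_metadatas.append(metadata)
        return out_ids, out_metadatas
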
diff --git a/ragdaemon/database/lite_database.py b/ragdaemon/database/lite_database.py
index c31c36f..3df0318 100644
--- a/ragdaemon/database/lite_database.py
+++ b/ragdaemon/database/lite_database.py
@@ -1,5 +1,4 @@
-from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Optional, TypedDict
 
 from rank_bm25 import BM25Okapi
 
@@ -10,35 +9,24 @@ def tokenize(document: str) -> list[str]:
     return document.split()
 
 
-class LiteDB(Database):
-    def __init__(self, db_path: Path, verbose: int = 0):
-        self.db_path = db_path
-        self.verbose = verbose
-        self._collection = LiteCollection(self.verbose)
-
-    def query(self, query: str, active_checksums: list[str]) -> list[dict]:
-        return self._collection.query(query, active_checksums)
-
+class Document(TypedDict):
+    checksum: str
+    chunks: Optional[list[dict[str, str]]]
+    summary: Optional[str]
+    embedding: Optional[list[float]]
 
 
-class LiteCollection:
-    """A fast alternative to ChromaDB for testing (and anything else).
-
-    Matches the chroma Collection API except:
-    - No embeddings
-    - In-memory
-    - Query returns all distances=1
-    """
+class LiteDB(Database):
+    """A fast alternative to an embeddings DB for testing (and anything else)."""
 
     bm25: BM25Okapi
     bm25_index: list[str]
 
     def __init__(self, verbose: int = 0):
-        self.data = dict[str, dict[str, Any]]()  # {id: {metadatas, document}}
         self.verbose = verbose
+        self.data = dict[str, dict[str, Any]]()  # {id: {metadatas, document}}
 
-    def get(self, ids: list[str] | str, include: Optional[list[str]] = None) -> dict:
-        if isinstance(ids, str):
-            ids = [ids]
+    def get(self, ids: list[str], include: Optional[list[str]] = None) -> dict:
         output = {"ids": [], "metadatas": [], "documents": []}
         for id in ids:
             if id in self.data:
@@ -52,15 +40,13 @@ def get(self, ids: list[str] | str, include: Optional[list[str]] = None) -> dict:
     def count(self) -> int:
         return len(self.data)
 
-    def update(self, ids: list[str] | str, metadatas: list[dict] | dict):
-        ids = [ids] if isinstance(ids, str) else ids
-        metadatas = [metadatas] if isinstance(metadatas, dict) else metadatas
+    def update(self, ids: list[str], metadatas: list[dict]):
         for checksum, metadata in zip(ids, metadatas):
             if checksum not in self.data:
                 raise ValueError(f"Record {checksum} does not exist.")
             self.data[checksum]["metadatas"] = metadata
 
-    def query(self, query: str, active_checksums: list[str]) -> list[dict]:
+    def query(self, query: str, active_checksums: set[str]) -> list[dict]:
         scores = self.bm25.get_scores(tokenize(query))
         max_score = max(scores)
         if max_score > 0:
@@ -76,13 +62,12 @@ def query(self, query: str, active_checksums: list[str]) -> list[dict]:
 
     def add(
         self,
-        ids: list[str] | str,
-        metadatas: list[dict] | dict,
-        documents: list[str] | str,
-    ) -> list[str]:
-        ids = [ids] if isinstance(ids, str) else ids
-        metadatas = [metadatas] if isinstance(metadatas, dict) else metadatas
-        documents = [documents] if isinstance(documents, str) else documents
+        ids: list[str],
+        documents: list[str],
+        metadatas: Optional[list[dict]] = None,
+    ):
+        if metadatas is None:
+            metadatas = [{} for _ in range(len(ids))]
         for checksum, metadata, document in zip(ids, metadatas, documents):
             existing_metadata = self.data.get(checksum, {}).get("metadatas", {})
             metadata = {**existing_metadata, **metadata}
@@ -95,5 +80,3 @@ def add(
             documents.append(data["document"])
         self.bm25 = BM25Okapi([tokenize(document) for document in documents])
         self.bm25_index = ids
-
-        return ids
diff --git a/ragdaemon/database/pg_database.py b/ragdaemon/database/pg_database.py
index d4c96e7..bfc2afa 100644
--- a/ragdaemon/database/pg_database.py
+++ b/ragdaemon/database/pg_database.py
@@ -1,26 +1,14 @@
-import json
-import os
-from collections import defaultdict
-from typing import Dict, Optional
+from typing import Any, Optional
 
-from sqlalchemy import create_engine
-from sqlalchemy.exc import OperationalError
-from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
-from typing_extensions import override
+from pgvector.sqlalchemy import Vector
+from psycopg2 import OperationalError
+from spice import Spice
+from sqlalchemy import select, func
 
-from ragdaemon.database.lite_database import LiteCollection, LiteDB
-
-
-class Base(DeclarativeBase):
-    pass
-
-
-class DocumentMetadata(Base):
-    __tablename__ = "document_metadata"
-
-    id: Mapped[str] = mapped_column(primary_key=True)
-    # We serialize whatever we get, which can be 'null', so we need Optional
-    chunks: Mapped[Optional[str]]
+from ragdaemon.database.database import Database
+from ragdaemon.database.postgres import DocumentMetadata, get_database_session_sync
+from ragdaemon.errors import RagdaemonError
+from ragdaemon.utils import MAX_INPUTS_PER_CALL
 
 
 def retry_on_exception(retries: int = 3, exceptions={OperationalError}):
@@ -39,148 +27,103 @@ def wrapper(*args, **kwargs):
     return decorator
 
 
-class Engine:
-    def __init__(self, verbose: int = 0):
-        database = "ragdaemon"
-        host = os.environ.get("RAGDAEMON_DB_ENDPOINT", None)
-        port = os.environ.get("RAGDAEMON_DB_PORT", 5432)
-        username = os.environ.get("RAGDAEMON_DB_USERNAME", None)
-        password = os.environ.get("RAGDAEMON_DB_PASSWORD", None)
-
-        if host is None or username is None or password is None:
-            raise ValueError(
-                "Missing ragdaemon environment variables: cannot use PGDB."
-            )
-
-        url = f"postgresql+psycopg2://{username}:{password}@{host}:{port}/{database}"
-        self.engine = create_engine(url)
-        if verbose > 1:
-            print("Connected to PGDB.")
+class PGDB(Database):
+    """Implementation of Database with embeddings search using PostgreSQL."""
 
-    def migrate(self):
-        if input(
-            "Migrating will clear the database. ALL DATA WILL BE LOST. Proceed (Y/n)? "
" - ).lower().strip() in [ - "", - "y", - ]: - Base.metadata.drop_all(self.engine) - Base.metadata.create_all(self.engine) - print("PGDB migrated successfully.") + def __init__( + self, + spice_client: Spice, + embedding_model: str | None = None, + embedding_provider: Optional[str] = None, + verbose: int = 0, + ): + self.verbose = verbose + SessionLocal = get_database_session_sync() + with SessionLocal() as session: + query = select(func.count(DocumentMetadata.id)) + count = session.execute(query).scalar() + if self.verbose > 0: + print(f"Initialized PGDB with {count} documents.") + + def embed_documents(input_texts: list[str]) -> list[list[float]]: + if not all(isinstance(item, str) for item in input_texts): + raise RagdaemonError("SpiceEmbeddings only enabled for text files.") + # Embed in batches + n_batches = (len(input_texts) - 1) // MAX_INPUTS_PER_CALL + 1 + output: list[list[float]] = [] + for batch in range(n_batches): + start = batch * MAX_INPUTS_PER_CALL + end = min((batch + 1) * MAX_INPUTS_PER_CALL, len(input_texts)) + embeddings = spice_client.get_embeddings_sync( + input_texts=input_texts[start:end], + model=embedding_model, + provider=embedding_provider, + ).embeddings + output.extend(embeddings) + return output + + self.embed_documents = embed_documents @retry_on_exception() - def add_document_metadata(self, ids: str | list[str], metadatas: Dict | list[Dict]): - ids = ids if isinstance(ids, list) else [ids] - metadatas = metadatas if isinstance(metadatas, list) else [metadatas] - if len(ids) != len(metadatas): - raise ValueError("ids and metadatas must have the same length.") - with Session(self.engine) as session: + def add( + self, + ids: list[str], + documents: list[str], + metadatas: Optional[list[dict]] = None, + ): + if metadatas is None: + metadatas = [{} for _ in range(len(ids))] + embeddings = self.embed_documents(documents) + metadatas = [ + {**meta, "embedding": emb} for meta, emb in zip(metadatas, embeddings) + ] + SessionLocal = get_database_session_sync() + with SessionLocal() as session: for id, metadata in zip(ids, metadatas): - serialized_metadata = {} - for k, v in metadata.items(): - if not isinstance(v, str): - v = json.dumps(v) - serialized_metadata[k] = v - metadata_object = DocumentMetadata(id=id, **serialized_metadata) - session.add(metadata_object) + session.add(DocumentMetadata(id=id, **metadata)) session.commit() @retry_on_exception() - def update_document_metadata( - self, ids: str | list[str], metadatas: Dict | list[Dict] - ): - ids = ids if isinstance(ids, list) else [ids] - metadatas = metadatas if isinstance(metadatas, list) else [metadatas] - if len(ids) != len(metadatas): - raise ValueError("ids and metadatas must have the same length.") - with Session(self.engine) as session: + def update(self, ids: list[str], metadatas: list[dict]): + SessionLocal = get_database_session_sync() + with SessionLocal() as session: for id, metadata in zip(ids, metadatas): - metadata_object = session.get(DocumentMetadata, id) - if metadata_object is None: - metadata_object = DocumentMetadata(id=id) - session.add(metadata_object) - for k, v in metadata.items(): - if not isinstance(v, str): - v = json.dumps(v) - setattr(metadata_object, k, v) + session.query(DocumentMetadata).filter( + DocumentMetadata.id == id + ).update(metadata) session.commit() @retry_on_exception() - def get_document_metadata(self, ids: str | list[str]) -> Dict[str, Dict]: - if not isinstance(ids, list): - ids = [ids] - with Session(self.engine) as session: - metadata_objects = ( - 
-                session.query(DocumentMetadata)
-                .filter(DocumentMetadata.id.in_(ids))
-                .all()
-            )
-            result = dict[str, Dict]()
-            for object in metadata_objects:
-                id = object.id
-                serialized_metadata = object.__dict__.copy()
-                del serialized_metadata["_sa_instance_state"]
-                del serialized_metadata["id"]
-                result[id] = dict(serialized_metadata)  # Deserialization logic is elsewhere
-            return result
-
-
-class PGDB(LiteDB):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._collection = PGCollection(self.verbose)
-
-
-class PGCollection(LiteCollection):
-    """Wraps a LiteDB and adds/gets targeted fields from a remote Postgres Database."""
-
-    def __init__(self, *args, fields: list[str] = ["chunks"], **kwargs):
-        super().__init__(*args, **kwargs)
-        self.engine = Engine(self.verbose)
-        self.fields = fields
-
-    @override
-    def update(self, ids: list[str] | str, metadatas: list[dict] | dict):
-        remote_records = defaultdict(dict)
-        for id, metadata in zip(ids, metadatas):
-            for k, v in metadata.items():
-                if k in self.fields:
-                    remote_records[id][k] = v
-        self.engine.update_document_metadata(
-            ids=list(remote_records.keys()), metadatas=list(remote_records.values())
-        )
-        super().update(ids, metadatas)
-
-    @override
-    def add(
-        self,
-        ids: list[str] | str,
-        metadatas: list[dict] | dict,
-        documents: list[str] | str,
-    ) -> list[str]:
-        remote_metadatas = self.engine.get_document_metadata(ids)
-        for id, metadata in zip(ids, metadatas):
-            if id in remote_metadatas:
-                metadata.update(remote_metadatas[id])
-        return super().add(ids, metadatas, documents)
-
-    @override
     def get(
-        self,
-        ids: list[str] | str,
-        include: list[str] | None = None,
-    ):
-        response = super().get(ids, include)
-        response_ids = response.get("ids", [])
-        if response_ids and include is not None and "metadatas" in include:
-            remote_metadatas = self.engine.get_document_metadata(response_ids)
-            for id, metadata in zip(
-                response.get("ids", []), response.get("metadatas", [])
-            ):
-                if id in remote_metadatas:
-                    metadata.update(remote_metadatas[id])
-        return response
-
+        self, ids: list[str], include: Optional[list[str]] = None
+    ) -> dict[str, list[str] | list[dict] | list[Vector]]:
+        SessionLocal = get_database_session_sync()
+        with SessionLocal() as session:
+            query = select(DocumentMetadata).filter(DocumentMetadata.id.in_(ids))
+            result = session.execute(query).scalars().all()
+        output: dict[str, list[str] | list[dict] | list[Vector]] = {
+            "ids": [doc.id for doc in result]
+        }
+        if include is None or "metadatas" in include:
+            output["metadatas"] = [
+                doc.to_dict(exclude=["id", "embedding"]) for doc in result
+            ]
+        if include is None or "embeddings" in include:
+            output["embeddings"] = [doc.embedding for doc in result]
+        return output
 
-if __name__ == "__main__":
-    Engine().migrate()
+    @retry_on_exception()
+    def query(self, query: str, active_checksums: set[str]) -> list[dict[str, Any]]:
+        query_embedding = self.embed_documents([query])[0]
+        SessionLocal = get_database_session_sync()
+        with SessionLocal() as session:
+            emb_query = select(
+                DocumentMetadata.id,
+                DocumentMetadata.embedding.cosine_distance(query_embedding),
+            ).where(DocumentMetadata.id.in_(active_checksums))
+            result = session.execute(emb_query).all()
+        ordered = sorted(result, key=lambda x: x[1])
+        return [
+            {"checksum": checksum, "distance": distance}
+            for checksum, distance in ordered
+        ]
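
PGDB.query above fetches every candidate row and sorts by distance in Python. If that ever becomes a bottleneck, the ordering could be pushed into Postgres instead: pgvector's cosine_distance comparator compiles to the `<=>` operator, which works directly in an ORDER BY. A sketch of an alternative body for PGDB.query (untested against this schema; query_embedding and active_checksums as in the method above):

    from sqlalchemy import select
    from ragdaemon.database.postgres import DocumentMetadata

    # Same candidates as PGDB.query, but ordered by Postgres rather than
    # sorted in Python after the fetch.
    emb_query = (
        select(
            DocumentMetadata.id,
            DocumentMetadata.embedding.cosine_distance(query_embedding).label(
                "distance"
            ),
        )
        .where(DocumentMetadata.id.in_(active_checksums))
        .order_by("distance")
    )
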
diff --git a/ragdaemon/database/postgres.py b/ragdaemon/database/postgres.py
new file mode 100644
index 0000000..dd4e480
--- /dev/null
+++ b/ragdaemon/database/postgres.py
@@ -0,0 +1,116 @@
+import os
+from functools import cache
+from typing import Optional
+
+from dotenv import load_dotenv
+from pgvector.sqlalchemy import Vector
+from sqlalchemy import Engine, create_engine, text
+from sqlalchemy.ext.asyncio import (
+    AsyncEngine,
+    AsyncSession,
+    async_sessionmaker,
+    create_async_engine,
+)
+from sqlalchemy.orm import (
+    DeclarativeBase,
+    Session,
+    sessionmaker,
+    class_mapper,
+    Mapped,
+    mapped_column,
+)
+
+from ragdaemon.utils import EMBEDDING_DIMENSIONS
+
+
+load_dotenv()
+
+
+class Base(DeclarativeBase):
+    def to_dict(self, exclude: list[str] = []) -> dict:
+        result = {}
+        for key in class_mapper(self.__class__).columns.keys():
+            if key not in exclude:
+                result[key] = getattr(self, key)
+        return result
+
+
+class DocumentMetadata(Base):
+    __tablename__ = "document_metadata"
+
+    id: Mapped[str] = mapped_column(primary_key=True)  # Checksum of the document
+    embedding: Mapped[Vector] = mapped_column(Vector(EMBEDDING_DIMENSIONS))
+    chunks: Mapped[Optional[str]]
+    calls: Mapped[Optional[str]]
+    summary: Mapped[Optional[str]]
+
+
+@cache
+def get_database_url(sync: bool = False) -> str:
+    database = "ragdaemon"
+    host = os.environ.get("RAGDAEMON_DB_ENDPOINT", None)
+    port = os.environ.get("RAGDAEMON_DB_PORT", 5432)
+    username = os.environ.get("RAGDAEMON_DB_USERNAME", None)
+    password = os.environ.get("RAGDAEMON_DB_PASSWORD", None)
+
+    if host is None or username is None or password is None:
+        raise ValueError("Missing ragdaemon environment variables: cannot use PGDB.")
+
+    if sync:
+        sync_string = "+psycopg2"
+    else:
+        sync_string = "+asyncpg"
+
+    return f"postgresql{sync_string}://{username}:{password}@{host}:{port}/{database}"
+
+
+@cache
+def get_database_engine() -> AsyncEngine:
+    url = get_database_url()
+    return create_async_engine(url, echo=False)
+
+
+@cache
+def get_database_engine_sync() -> Engine:
+    url = get_database_url(sync=True)
+    return create_engine(url, echo=False)
+
+
+def get_database_session() -> async_sessionmaker[AsyncSession]:
+    engine = get_database_engine()
+    return async_sessionmaker(autocommit=False, bind=engine, class_=AsyncSession)
+
+
+def get_database_session_sync() -> sessionmaker[Session]:
+    engine = get_database_engine_sync()
+    return sessionmaker(autocommit=False, bind=engine, class_=Session)
+
+
+if __name__ == "__main__":
+    if input(
+        "Migrating will clear the database. ALL DATA WILL BE LOST. Proceed (Y/n)? "
+    ).lower().strip() in [
+        "",
+        "y",
+    ]:
+        SessionLocal = get_database_session_sync()
+        # Check if vector extension is installed
+        with SessionLocal() as session:
+            query = text("SELECT * FROM pg_extension WHERE extname = 'vector'")
+            result = session.execute(query).fetchone()
+            if result is None:
+                try:
+                    session.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
+                except Exception as e:
+                    raise Exception(
+                        f"""\
+Failed to install pgvector extension: {e}
+1. Install `pgvector` on your device: https://github.com/pgvector/pgvector
+2. Enable the `vector` extension for the ragdaemon database:
+https://github.com/pgvector/pgvector-python?tab=readme-ov-file#sqlalchemy
+"""
+                    )
+        engine = get_database_engine_sync()
+        Base.metadata.drop_all(engine)
+        Base.metadata.create_all(engine)
+        print("PGDB migrated successfully.")
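
The migration entrypoint above assumes RAGDAEMON_DB_ENDPOINT, RAGDAEMON_DB_PORT, RAGDAEMON_DB_USERNAME, and RAGDAEMON_DB_PASSWORD are set, either in the environment or in a .env file picked up by load_dotenv(). After migrating, a quick round-trip check of the connection and the vector extension might look like this (a sketch reusing the session factory defined in this file, assuming the database is reachable):

    from sqlalchemy import text

    from ragdaemon.database.postgres import get_database_session_sync

    SessionLocal = get_database_session_sync()
    with SessionLocal() as session:
        # pg_extension lists installed extensions; None means pgvector is missing.
        version = session.execute(
            text("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
        ).scalar()
    print(f"pgvector version: {version or 'not installed'}")
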
diff --git a/ragdaemon/graph.py b/ragdaemon/graph.py
index af775f9..2ac84fb 100644
--- a/ragdaemon/graph.py
+++ b/ragdaemon/graph.py
@@ -11,7 +11,7 @@ class NodeMetadata(TypedDict):
     ref: Optional[
         str
     ]  # Used to fetch document: path/to/file:start-end, diff_ref:start-end
-    checksum: Optional[str]  # Unique identifier for chroma; sha256 of the document
+    checksum: Optional[str]  # sha256 of the document
     chunks: Optional[
         list[dict[str, str]]
     ]  # For files, func/class/method. For diff, by file/hunk
diff --git a/ragdaemon/io/file_like.py b/ragdaemon/io/file_like.py
index 262838c..986877d 100644
--- a/ragdaemon/io/file_like.py
+++ b/ragdaemon/io/file_like.py
@@ -2,10 +2,14 @@
 
 
 class FileLike(Protocol):
-    def read(self) -> str: ...
+    def read(self) -> str:
+        ...
 
-    def write(self, data: str) -> int: ...
+    def write(self, data: str) -> int:
+        ...
 
-    def __enter__(self) -> "FileLike": ...
+    def __enter__(self) -> "FileLike":
+        ...
 
-    def __exit__(self, exc_type, exc_val, exc_tb) -> None: ...
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
+        ...
diff --git a/ragdaemon/utils.py b/ragdaemon/utils.py
index 767af30..7e5868f 100644
--- a/ragdaemon/utils.py
+++ b/ragdaemon/utils.py
@@ -5,7 +5,7 @@
 from pathlib import Path
 
 from spice import Spice
-from spice.models import GPT_4o_2024_05_13, Model, UnknownModel
+from spice.models import GPT_4o_mini, Model, UnknownModel
 from spice.spice import get_model_from_name
 
 from ragdaemon.errors import RagdaemonError
@@ -40,7 +40,10 @@
 ]
 
 
-DEFAULT_COMPLETION_MODEL = GPT_4o_2024_05_13
+DEFAULT_COMPLETION_MODEL = GPT_4o_mini
+DEFAULT_EMBEDDING_MODEL = "text-embedding-3-large"
+EMBEDDING_DIMENSIONS = 3072
+MAX_INPUTS_PER_CALL = 2048
 
 
 def hash_str(string: str) -> str:
diff --git a/tests/conftest.py b/tests/conftest.py
index 477a075..d5861e3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -12,8 +12,9 @@
 from docker.errors import DockerException
 import pytest
 
-from ragdaemon.database import DEFAULT_EMBEDDING_MODEL, get_db
+from ragdaemon.database import get_db
 from ragdaemon.io import LocalIO
+from ragdaemon.utils import DEFAULT_EMBEDDING_MODEL
 
 
 @pytest.fixture
diff --git a/tests/test_comments.py b/tests/test_comments.py
index 43d306c..bab7bc7 100644
--- a/tests/test_comments.py
+++ b/tests/test_comments.py
@@ -37,7 +37,8 @@ async def test_comment_render(cwd_git_diff, mock_db):
     )
     context.add_comment("src/operations.py", "Comments can just be strings", line=12)
     actual = context.render()
-    assert actual == dedent("""\
+    assert actual == dedent(
+        """\
         src/operations.py
         What is this file for? (test-flag)
         1:import math
@@ -77,10 +78,12 @@ async def test_comment_render(cwd_git_diff, mock_db):
         21:    return math.sqrt(a)
         22:
 
-        """)
+        """
+    )
     context.remove_comments("src/operations.py", tags=["test-flag"])
     actual = context.render()
-    assert actual == dedent("""\
+    assert actual == dedent(
+        """\
         src/operations.py
 
         1:import math
@@ -120,10 +123,12 @@ async def test_comment_render(cwd_git_diff, mock_db):
         21:    return math.sqrt(a)
         22:
 
-        """)
+        """
+    )
     context.remove_comments("src/operations.py")
     actual = context.render()
-    assert actual == dedent("""\
+    assert actual == dedent(
+        """\
         src/operations.py
         1:import math
         2: #modified
@@ -147,4 +152,5 @@ async def test_comment_render(cwd_git_diff, mock_db):
         20:def sqrt(a):
         21:    return math.sqrt(a)
         22:
-        """)
+        """
+    )
diff --git a/tests/test_database.py b/tests/test_database.py
index afb6c63..9d3d90a 100644
--- a/tests/test_database.py
+++ b/tests/test_database.py
@@ -1,6 +1,7 @@
 from unittest.mock import AsyncMock
 
-from ragdaemon.database import DEFAULT_EMBEDDING_MODEL, LiteDB, get_db
+from ragdaemon.database import LiteDB, get_db
+from ragdaemon.utils import DEFAULT_EMBEDDING_MODEL
 
 
 def test_mock_database():
diff --git a/tutorial.ipynb b/tutorial.ipynb
index a2f44a4..c48f3c5 100644
--- a/tutorial.ipynb
+++ b/tutorial.ipynb
@@ -46,7 +46,7 @@
     " - For files and chunks, relative path + lines, e.g. `mentat/config.py:10-15`\n",
     " - For diffs, the diff target (\"DEFAULT\" if none provided) + lines in diff, e.g. `DEFAULT:4-10`\n",
     "4. `document`: The embedded content. Always f\"{id}\\n{content}\" so the path / filename is also embedded, and there are no duplicates.\n",
-    "5. `checksum`: an md5 hash of the document, used as the chroma index."
+    "5. `checksum`: an md5 hash of the document"
    ]
   },
  {
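
End to end, the new storage path can be exercised roughly as follows. This is a sketch, not a tested snippet: it assumes the get_db(spice_client, embedding_model, ...) signature shown above, and get_db falls back to LiteDB (in-memory, BM25-scored) whenever the RAGDAEMON_DB_* variables are missing or Postgres is unreachable:

    from spice import Spice

    from ragdaemon.database import get_db
    from ragdaemon.utils import DEFAULT_EMBEDDING_MODEL

    # PGDB if Postgres is configured and reachable, otherwise LiteDB.
    db = get_db(Spice(), DEFAULT_EMBEDDING_MODEL)
    db.add(ids=["abc123"], documents=["example.py\nprint('hello')"])
    db.update(ids=["abc123"], metadatas=[{"summary": "Prints a greeting."}])
    print(db.query("hello greeting", active_checksums={"abc123"}))
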