Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pg vector db #59

Merged
merged 5 commits into from
Aug 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -7,18 +7,20 @@ packages=["ragdaemon"]

[project]
name = "ragdaemon"
version = "0.8.3"
version = "0.9.0"
description = "Generate and render a call graph for a Python project."
readme = "README.md"
dependencies = [
"astroid==3.2.2",
"chromadb==0.4.24",
"asyncpg==0.29.0",
"dict2xml==1.7.5",
"docker==7.1.0",
"fastapi==0.109.2",
"Jinja2==3.1.3",
"networkx==3.2.1",
"pgvector==0.3.2",
"psycopg2-binary==2.9.9",
"python-dotenv",
"rank_bm25==0.2.2",
"sqlalchemy==2.0.30",
"spiceai~=0.3.0",
2 changes: 1 addition & 1 deletion ragdaemon/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.8.3"
__version__ = "0.9.0"
3 changes: 1 addition & 2 deletions ragdaemon/annotators/call_graph.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
from tqdm.asyncio import tqdm

from ragdaemon.annotators.base_annotator import Annotator
from ragdaemon.database import Database, remove_update_db_duplicates
from ragdaemon.database import Database
from ragdaemon.errors import RagdaemonError
from ragdaemon.graph import KnowledgeGraph
from ragdaemon.utils import (
@@ -229,7 +229,6 @@ async def annotate(
update_db["ids"].append(data["checksum"])
metadatas = {self.call_field_id: json.dumps(data[self.call_field_id])}
update_db["metadatas"].append(metadatas)
update_db = remove_update_db_duplicates(**update_db)
db.update(**update_db)

# Add call edges to graph. Each call should have only ONE source; if there are
15 changes: 3 additions & 12 deletions ragdaemon/annotators/chunker/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import json
from copy import deepcopy
from functools import partial
from pathlib import Path
from typing import Optional, Set
@@ -13,11 +12,7 @@
from ragdaemon.annotators.chunker.chunk_line import chunk_document as chunk_line
from ragdaemon.annotators.chunker.chunk_llm import chunk_document as chunk_llm
from ragdaemon.annotators.chunker.utils import resolve_chunk_parent
from ragdaemon.database import (
Database,
remove_add_to_db_duplicates,
remove_update_db_duplicates,
)
from ragdaemon.database import Database
from ragdaemon.errors import RagdaemonError
from ragdaemon.graph import KnowledgeGraph
from ragdaemon.utils import (
@@ -145,7 +140,6 @@ async def annotate(
update_db["ids"].append(data["checksum"])
metadatas = {self.chunk_field_id: json.dumps(data[self.chunk_field_id])}
update_db["metadatas"].append(metadatas)
update_db = remove_update_db_duplicates(**update_db)
db.update(**update_db)

# Process chunks
@@ -189,22 +183,19 @@ async def annotate(
ids = list(set(checksums.values()))
response = db.get(ids=ids, include=["metadatas"])
db_data = {id: data for id, data in zip(response["ids"], response["metadatas"])}
add_to_db = {"ids": [], "documents": [], "metadatas": []}
add_to_db = {"ids": [], "documents": []}
for node, checksum in checksums.items():
if checksum in db_data:
data = db_data[checksum]
graph.nodes[node].update(data)
else:
data = deepcopy(graph.nodes[node])
document = data.pop("document")
document = graph.nodes[node].get("document")
document, truncate_ratio = truncate(document, db.embedding_model)
if truncate_ratio > 0 and self.verbose > 1:
print(f"Truncated {node} by {truncate_ratio:.2%}")
add_to_db["ids"].append(checksum)
add_to_db["documents"].append(document)
add_to_db["metadatas"].append(data)
if len(add_to_db["ids"]) > 0:
add_to_db = remove_add_to_db_duplicates(**add_to_db)
db.add(**add_to_db)

return graph
13 changes: 6 additions & 7 deletions ragdaemon/annotators/diff.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import json
import re
from copy import deepcopy

from ragdaemon.annotators.base_annotator import Annotator
from ragdaemon.database import Database, remove_add_to_db_duplicates
from ragdaemon.database import Database
from ragdaemon.graph import KnowledgeGraph
from ragdaemon.errors import RagdaemonError
from ragdaemon.utils import (
@@ -150,18 +149,18 @@ async def annotate(
for id, checksum in checksums.items():
if checksum in db_data:
continue
data = deepcopy(graph.nodes[id])
document = data.pop("document")
if "chunks" in data:
data["chunks"] = json.dumps(data["chunks"])
data = {}
document = graph.nodes[id].get("document")
chunks = graph.nodes[id].get("chunks")
if chunks:
data["chunks"] = json.dumps(chunks)
document, truncate_ratio = truncate(document, db.embedding_model)
if self.verbose > 1 and truncate_ratio > 0:
print(f"Truncated {id} by {truncate_ratio:.2%}")
add_to_db["ids"].append(checksum)
add_to_db["documents"].append(document)
add_to_db["metadatas"].append(data)
if len(add_to_db["ids"]) > 0:
add_to_db = remove_add_to_db_duplicates(**add_to_db)
db.add(**add_to_db)

return graph
10 changes: 3 additions & 7 deletions ragdaemon/annotators/hierarchy.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from copy import deepcopy
from pathlib import Path

from ragdaemon.annotators.base_annotator import Annotator
from ragdaemon.database import Database, remove_add_to_db_duplicates
from ragdaemon.database import Database
from ragdaemon.graph import KnowledgeGraph
from ragdaemon.errors import RagdaemonError
from ragdaemon.utils import get_document, hash_str, truncate
@@ -93,22 +92,19 @@ async def annotate(
ids = list(set(checksums.values()))
response = db.get(ids=ids, include=["metadatas"])
db_data = {id: data for id, data in zip(response["ids"], response["metadatas"])}
add_to_db = {"ids": [], "documents": [], "metadatas": []}
add_to_db = {"ids": [], "documents": []}
for path, checksum in checksums.items():
if checksum in db_data:
data = db_data[checksum]
graph.nodes[path.as_posix()].update(data)
else:
data = deepcopy(graph.nodes[path.as_posix()])
document = data.pop("document")
document = graph.nodes[path.as_posix()]["document"]
document, truncate_ratio = truncate(document, db.embedding_model)
if self.verbose > 1 and truncate_ratio > 0:
print(f"Truncated {path} by {truncate_ratio:.2%}")
add_to_db["ids"].append(checksum)
add_to_db["documents"].append(document)
add_to_db["metadatas"].append(data)
if len(add_to_db["ids"]) > 0:
add_to_db = remove_add_to_db_duplicates(**add_to_db)
db.add(**add_to_db)

return graph
9 changes: 6 additions & 3 deletions ragdaemon/annotators/layout_hierarchy.py
Original file line number Diff line number Diff line change
@@ -85,6 +85,10 @@ def iterate(iteration: int):
class LayoutHierarchy(Annotator):
name = "layout_hierarchy"

def __init__(self, *args, iterations: int = 40, **kwargs):
super().__init__(*args, **kwargs)
self.iterations = iterations

def is_complete(self, graph: KnowledgeGraph, db: Database) -> bool:
# Check that they have data.layout.hierarchy
for node, data in graph.nodes(data=True):
@@ -99,15 +103,14 @@ async def annotate(
graph: KnowledgeGraph,
db: Database,
refresh: str | bool = False,
iterations: int = 40,
) -> KnowledgeGraph:
"""
a. Regenerate x/y/z for all nodes
b. Update all nodes
c. Save to chroma
c. Save to db
"""
pos = fruchterman_reingold_3d(
graph, iterations=iterations, verbose=self.verbose
graph, iterations=self.iterations, verbose=self.verbose
)
for node_id, coordinates in pos.items():
node = graph.nodes[node_id]
42 changes: 12 additions & 30 deletions ragdaemon/annotators/summarizer.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@

from ragdaemon.annotators.base_annotator import Annotator
from ragdaemon.context import ContextBuilder
from ragdaemon.database import Database, remove_update_db_duplicates
from ragdaemon.database import Database
from ragdaemon.graph import KnowledgeGraph
from ragdaemon.errors import RagdaemonError
from ragdaemon.io import IO
@@ -198,7 +198,6 @@ def get_chunk_summaries(target: str) -> list[str]:
class Summarizer(Annotator):
name = "summarizer"
summary_field_id = "summary"
checksum_field_id = "summary_checksum"

def __init__(
self,
@@ -226,13 +225,6 @@ def is_complete(self, graph: KnowledgeGraph, db: Database) -> bool:
raise RagdaemonError(f"Node {node} missing checksum.")
if data.get(self.summary_field_id) is None:
return False
# Checksum used to be hash_str(document + context) using the above method. This is
# technically more correct, because the summary context includes adjacent summaries
# so the whole system updates iteratively. In practice it was just too much looping
# so for now we just reuse the checksum generated in hierarchy (hash_str(document)).
summary_checksum = data["checksum"]
if summary_checksum != data.get(self.checksum_field_id):
return False
return True

async def generate_summary(
@@ -247,13 +239,8 @@ async def generate_summary(
raise RagdaemonError("Spice client not initialized")

data = graph.nodes[node]
summary_checksum = data["checksum"]
_refresh = match_refresh(refresh, node)
if (
_refresh
or data.get(self.summary_field_id) is None
or summary_checksum != data.get(self.checksum_field_id)
):
if _refresh or data.get(self.summary_field_id) is None:
document, context = get_document_and_context(
node,
graph,
@@ -283,7 +270,6 @@ async def generate_summary(

if summary != "PASS":
data[self.summary_field_id] = summary
data[self.checksum_field_id] = summary_checksum

if loading_bar is not None:
loading_bar.update(1)
@@ -311,31 +297,27 @@ async def annotate(
self, graph: KnowledgeGraph, db: Database, refresh: str | bool = False
) -> KnowledgeGraph:
"""Asynchronously generate or fetch summaries and add to graph/db"""
summaries = dict[str, str]()
nodes_to_summarize: set[str] = set()
for node, data in graph.nodes(data=True):
if data is not None and data.get("type") in self.summarize_nodes:
summaries[node] = data.get(self.checksum_field_id, "")
nodes_to_summarize.add(node)

if self.verbose > 1:
loading_bar = tqdm(total=len(summaries), desc="Summarizing code...")
loading_bar = tqdm(
total=len(nodes_to_summarize), desc="Summarizing code..."
)
else:
loading_bar = None

await self.dfs("ROOT", graph, loading_bar, refresh)

update_db = {"ids": [], "metadatas": []}
for node, summary_checksum in summaries.items():
if graph.nodes[node].get(self.checksum_field_id) != summary_checksum:
data = graph.nodes[node]
update_db["ids"].append(data["checksum"])
update_db["metadatas"].append(
{
self.summary_field_id: data[self.summary_field_id],
self.checksum_field_id: data[self.checksum_field_id],
}
)
for node in nodes_to_summarize:
data = graph.nodes[node]
update_db["ids"].append(data["checksum"])
metadatas = {self.summary_field_id: data[self.summary_field_id]}
update_db["metadatas"].append(metadatas)
if len(update_db["ids"]) > 1:
update_db = remove_update_db_duplicates(**update_db)
db.update(**update_db)

if loading_bar is not None:
14 changes: 6 additions & 8 deletions ragdaemon/app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import argparse
import asyncio
import socket
import webbrowser
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any
@@ -40,9 +39,8 @@
annotators = {
"hierarchy": {},
"chunker": {"use_llm": True},
# "summarizer": {},
# "clusterer_binary": {},
# "call_graph": {"call_extensions": code_extensions},
"call_graph": {"call_extensions": code_extensions},
"summarizer": {},
"diff": {"diff": diff},
"layout_hierarchy": {},
}
@@ -117,9 +115,9 @@ async def main():
print(f"Starting server on port {port}...")
server = uvicorn.Server(config)

async def _wait_1s_then_open_browser():
await asyncio.sleep(1)
webbrowser.open(f"http://localhost:{port}")
# async def _wait_1s_then_open_browser():
# await asyncio.sleep(1)
# webbrowser.open(f"http://localhost:{port}")

asyncio.create_task(_wait_1s_then_open_browser())
# asyncio.create_task(_wait_1s_then_open_browser())
Comment on lines -120 to +122
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

meant to leave these?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh whoops - ya it wasn't working right on my pc.

await server.serve()
11 changes: 8 additions & 3 deletions ragdaemon/daemon.py
Original file line number Diff line number Diff line change
@@ -13,12 +13,17 @@
from ragdaemon.annotators import annotators_map
from ragdaemon.cerebrus import cerebrus
from ragdaemon.context import ContextBuilder
from ragdaemon.database import DEFAULT_EMBEDDING_MODEL, Database, get_db
from ragdaemon.database import Database, get_db
from ragdaemon.errors import RagdaemonError
from ragdaemon.graph import KnowledgeGraph
from ragdaemon.io import DockerIO, IO, LocalIO
from ragdaemon.locate import locate
from ragdaemon.utils import DEFAULT_COMPLETION_MODEL, match_refresh, mentat_dir_path
from ragdaemon.utils import (
DEFAULT_COMPLETION_MODEL,
DEFAULT_EMBEDDING_MODEL,
match_refresh,
mentat_dir_path,
)


def default_annotators():
@@ -61,7 +66,7 @@ def __init__(
if spice_client is None:
spice_client = Spice(
default_text_model=DEFAULT_COMPLETION_MODEL,
default_embeddings_model=model,
default_embeddings_model=DEFAULT_EMBEDDING_MODEL,
logging_dir=logging_dir,
)
self.spice_client = spice_client
Loading
Loading