Skip to content

Commit

Permalink
Binary Clusterer (#38)
Browse files Browse the repository at this point in the history
* implement SummarizerAgglomerative annotator

* integrate it into the program

* format typing and tests

* efficiency improvements

* add scipy dependency

* rename agglomerative_summarizer to binary_clusterer

* move annotator initialization checks to __init__

* remove scipy from requirements and lazy-load in clusterer_binary

* clean up commit

* version bump

* ignore scipy import typecheck
  • Loading branch information
granawkins authored May 8, 2024
1 parent c42a285 commit 969cca7
Show file tree
Hide file tree
Showing 10 changed files with 295 additions and 37 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ packages=["ragdaemon"]

[project]
name = "ragdaemon"
version = "0.4.4"
version = "0.4.5"
description = "Generate and render a call graph for a Python project."
readme = "README.md"
dependencies = [
Expand Down
2 changes: 1 addition & 1 deletion ragdaemon/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.4.4"
__version__ = "0.4.5"
6 changes: 4 additions & 2 deletions ragdaemon/annotators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
from ragdaemon.annotators.hierarchy import Hierarchy
from ragdaemon.annotators.layout_hierarchy import LayoutHierarchy
from ragdaemon.annotators.summarizer import Summarizer
from ragdaemon.annotators.clusterer_binary import ClustererBinary

annotators_map = {
"hierarchy": Hierarchy,
"call_graph": CallGraph,
"chunker": Chunker,
"chunker_llm": ChunkerLLM,
"chunker_line": ChunkerLine,
"chunker_llm": ChunkerLLM,
"clusterer_binary": ClustererBinary,
"diff": Diff,
"hierarchy": Hierarchy,
"layout_hierarchy": LayoutHierarchy,
"summarizer": Summarizer,
}
9 changes: 8 additions & 1 deletion ragdaemon/annotators/base_annotator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Optional

from spice import Spice
Expand All @@ -9,7 +11,12 @@
class Annotator:
name: str = "base_annotator"

def __init__(self, verbose: bool = False, spice_client: Optional[Spice] = None):
def __init__(
self,
verbose: bool = False,
spice_client: Optional[Spice] = None,
pipeline: Optional[list[Annotator]] = None,
):
self.verbose = verbose
self.spice_client = spice_client
pass
Expand Down
12 changes: 9 additions & 3 deletions ragdaemon/annotators/call_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,22 @@ def __init__(
self,
*args,
call_extensions: Optional[list[str]] = None,
chunk_field_id: Optional[str] = None,
model: Optional[TextModel | str] = DEFAULT_COMPLETION_MODEL,
pipeline: list[Annotator] = [],
**kwargs,
):
super().__init__(*args, **kwargs)
if call_extensions is None:
call_extensions = DEFAULT_CODE_EXTENSIONS
self.call_extensions = call_extensions
if chunk_field_id is None:
raise RagdaemonError("Chunk field ID is required for call graph annotator.")
try:
chunk_field_id = next(
getattr(a, "chunk_field_id") for a in pipeline if "chunker" in a.name
)
except (StopIteration, AttributeError):
raise RagdaemonError(
"CallGraph annotator requires a 'chunker' annotator with chunk_field_id."
)
self.chunk_field_id = chunk_field_id
self.model = model

Expand Down
242 changes: 242 additions & 0 deletions ragdaemon/annotators/clusterer_binary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
import asyncio
import json
from typing import Optional, Any

import numpy as np
from spice import SpiceMessage
from spice.models import TextModel
from tqdm.asyncio import tqdm

from ragdaemon.annotators.base_annotator import Annotator
from ragdaemon.database import Database
from ragdaemon.graph import KnowledgeGraph
from ragdaemon.errors import RagdaemonError
from ragdaemon.utils import DEFAULT_COMPLETION_MODEL, hash_str, semaphore

# System prompt for the LLM call that merges two one-line summaries into a
# single one-line summary; it is sent as the "system" message, with the two
# child summaries supplied as the "user" message.
clusterer_binary_prompt = """\
You are building a hierarchical summary of a codebase using agglomerative clustering.
You will be given two one-line summaries of code chunks or existing summaries.
Combine the two summaries into a single one-line summary.
Your summary should concisely answer the question "What does this do?"
Don't aim to give an exhaustive report; instead, focus on what would distinguish this
particular code from other parts of the codebase.
"""


class ClustererBinary(Annotator):
    """Annotator that builds a binary tree of summaries over file/chunk nodes.

    Leaf nodes (files, or their chunks when a file has any) are merged pairwise
    with scipy's hierarchical ``linkage`` over their stored embeddings. Every
    merge becomes a ``cluster_binary`` node with two outgoing ``cluster_binary``
    edges, and its summary is written by an LLM from the summaries of its two
    children.
    """

    name = "clusterer_binary"

    def __init__(
        self,
        *args,
        pipeline: Optional[list[Annotator]] = None,
        linkage_method: str = "ward",
        model: Optional[TextModel | str] = DEFAULT_COMPLETION_MODEL,
        **kwargs,
    ):
        """Resolve the chunk and summary field names from the pipeline.

        Args:
            pipeline: Annotators to search. The first one whose name contains
                "chunker" supplies ``chunk_field_id``; the first one whose name
                contains "summarizer" supplies ``summary_field_id``.
            linkage_method: Passed as ``method`` to scipy's ``linkage``.
            model: Passed as ``model`` to the spice client when summarizing.

        Raises:
            RagdaemonError: If no matching chunker/summarizer is found, or the
                match lacks the required attribute.
        """
        super().__init__(*args, **kwargs)
        # A None default avoids sharing one mutable list between instances.
        annotators = pipeline if pipeline is not None else []
        try:
            chunk_field_id = next(
                getattr(a, "chunk_field_id") for a in annotators if "chunker" in a.name
            )
            summary_field_id = next(
                getattr(a, "summary_field_id")
                for a in annotators
                if "summarizer" in a.name
            )
        except (StopIteration, AttributeError) as err:
            raise RagdaemonError(
                "ClustererBinary annotator requires a 'chunker' and 'summarizer' annotator with chunk_field_id and summary_field_id."
            ) from err
        self.chunk_field_id = chunk_field_id
        self.summary_field_id = summary_field_id
        self.linkage_method = linkage_method
        self.model = model

    def select_leaf_nodes(self, graph: KnowledgeGraph) -> list[str]:
        """Return the ids that form the leaves of the cluster tree.

        For every node of type "file": the file itself when it has no chunks
        (field missing or empty), otherwise the "id" of each of its chunks.

        Raises:
            RagdaemonError: If any node in the graph has no data.
        """
        leaf_nodes = []
        for node, data in graph.nodes(data=True):
            if data is None:
                raise RagdaemonError(f"Node {node} has no data.")
            if data.get("type") != "file":
                continue

            # Determine whether to use the file itself or its chunks
            chunks = data.get(self.chunk_field_id)
            if chunks is None:
                leaf_nodes.append(node)
                continue
            if not isinstance(chunks, list):
                # Anything that is not already a list is decoded as JSON.
                chunks = json.loads(chunks)
            if len(chunks) == 0:
                leaf_nodes.append(node)
            else:
                leaf_nodes.extend(chunk["id"] for chunk in chunks)
        return leaf_nodes

    def is_complete(self, graph: KnowledgeGraph, db: Database) -> bool:
        """Check that the graph already holds a well-formed cluster tree.

        The tree is well-formed when every ``cluster_binary`` node has a
        summary and exactly two successors, exactly one of them has no
        predecessor (the root), all others have exactly one predecessor, and
        the non-cluster successors are exactly the expected leaf nodes.
        With fewer than two leaves there is nothing to cluster, so an absent
        tree counts as complete.
        """
        # Start with a list of all the summary nodes
        cluster_binary_nodes = [
            (node, graph.in_degree(node), graph.out_degree(node))
            for node, data in graph.nodes(data=True)
            if data is not None and data.get("type") == "cluster_binary"
        ]
        root = None
        leaves = set()
        for node, in_degree, out_degree in cluster_binary_nodes:
            if not graph.nodes[node].get("summary"):
                return False  # Each needs a summary
            if out_degree != 2:
                return False  # Each needs 2 successors
            if in_degree == 0:
                if root is not None:
                    return False  # Only one should have no predecessors
                root = node
            elif in_degree != 1:
                return False  # The rest need 1 predecessor
            for neighbor in graph.successors(node):
                if graph.nodes[neighbor].get("type") != "cluster_binary":
                    leaves.add(neighbor)
        if root is None:
            if cluster_binary_nodes:
                return False  # Cluster nodes exist but none of them is a root
            # No tree at all: complete only when there is nothing to cluster.
            return len(self.select_leaf_nodes(graph)) < 2
        expected_leaves = set(self.select_leaf_nodes(graph))
        return leaves == expected_leaves  # All leaves are accounted for

    async def get_llm_response(self, document: str) -> str:
        """Send ``document`` to the LLM with the clustering prompt; return its text.

        Raises:
            RagdaemonError: If no spice client was provided.
        """
        if self.spice_client is None:
            raise RagdaemonError("Spice client is not initialized.")
        messages: list[SpiceMessage] = [
            {"role": "system", "content": clusterer_binary_prompt},
            {"role": "user", "content": document},
        ]
        # The request is made while holding the shared module-level semaphore.
        async with semaphore:
            response = await self.spice_client.get_response(
                messages=messages,
                model=self.model,
            )
        return response.text

    async def get_summary(
        self,
        node: str,
        document: str,
        graph: KnowledgeGraph,
        loading_bar: Optional[tqdm] = None,
    ) -> dict[str, Any]:
        """Summarize ``document`` with the LLM and store the result on ``node``.

        Returns:
            A dict with "ids" (checksum of the document), "documents" and
            "metadatas" (the record written to the graph node), which the
            caller collects into a batched ``db.add``.
        """
        summary = await self.get_llm_response(document)
        checksum = hash_str(document)
        record = {
            "id": node,
            "type": "cluster_binary",
            "summary": summary,
            "checksum": checksum,
            "active": False,
        }
        graph.nodes[node].update(record)
        if loading_bar is not None:
            loading_bar.update(1)
        return {"ids": checksum, "documents": document, "metadatas": record}

    async def load_all_summary_nodes(
        self,
        new_nodes: list[str],
        graph: KnowledgeGraph,
        db: Database,
        refresh: bool = False,
    ):
        """Fill in summaries for ``new_nodes``, bottom-up, from the db or the LLM.

        Each pass handles the nodes whose two successors already carry a
        summary; nodes whose children are still pending wait for a later pass.
        A node's document is its children's summaries joined by a newline, and
        the document's checksum is the db key. Cached records are reused unless
        ``refresh`` is set.

        Raises:
            RagdaemonError: If a pass cannot resolve any remaining node.
        """
        loading_bar = (
            None
            if not self.verbose
            else tqdm(total=len(new_nodes), desc="Refreshing binary clusters")
        )
        pending = list(new_nodes)
        try:
            while pending:
                tasks = []
                resolved = set()
                for node in pending:
                    a, b = list(graph.successors(node))
                    # NOTE(review): children are read through the literal
                    # "summary" key rather than self.summary_field_id; the two
                    # agree only while the summarizer's field is "summary".
                    a_summary = graph.nodes[a].get("summary")
                    b_summary = graph.nodes[b].get("summary")
                    if a_summary is None or b_summary is None:
                        continue  # Children not summarized yet; retry next pass
                    resolved.add(node)
                    document = f"{a_summary}\n{b_summary}"
                    checksum = hash_str(document)
                    records = db.get(checksum)["metadatas"]
                    if refresh or len(records) == 0:
                        tasks.append(
                            self.get_summary(node, document, graph, loading_bar)
                        )
                    else:
                        graph.nodes[node].update(records[0])
                        # The cached record carries the id of whichever node
                        # produced it; keep the id in step with this node.
                        graph.nodes[node]["id"] = node
                        if loading_bar is not None:
                            loading_bar.update(1)

                if not resolved:
                    # No progress is possible. A pass answered purely from the
                    # cache still counts as progress and must not raise.
                    raise RagdaemonError(f"Stuck on nodes {pending}")
                pending = [node for node in pending if node not in resolved]

                if tasks:
                    results = await asyncio.gather(*tasks)
                    add_to_db = {"ids": [], "documents": [], "metadatas": []}
                    for result in results:
                        for key, value in result.items():
                            add_to_db[key].append(value)
                    db.add(**add_to_db)
        finally:
            if loading_bar is not None:
                loading_bar.close()

    async def annotate(
        self, graph: KnowledgeGraph, db: Database, refresh: bool = False
    ) -> KnowledgeGraph:
        """Rebuild the binary cluster tree on ``graph`` and return the graph.

        Existing ``cluster_binary`` nodes and edges are removed, the leaves'
        embeddings are fetched from ``db`` and linked with scipy, one node plus
        two edges are added per merge, and the new nodes are then summarized.
        With fewer than two leaves the graph is returned without a tree.

        Raises:
            RagdaemonError: If scipy is not installed.
        """
        try:
            # Scipy is intentionally excluded from package requirements because it's
            # a large package and this is an experimental feature.
            from scipy.cluster.hierarchy import linkage  # type: ignore
        except ImportError as err:
            raise RagdaemonError(
                "ClustererBinary requires scipy to be installed. Run 'pip install scipy'."
            ) from err

        # Remove any existing cluster_binary nodes and edges
        cluster_binary_nodes = [
            node
            for node, data in graph.nodes(data=True)
            if data is not None and data.get("type") == "cluster_binary"
        ]
        graph.remove_nodes_from(cluster_binary_nodes)
        cluster_binary_edges = [
            (e[0], e[1])
            for e in graph.edges(data=True)
            if e[-1].get("type") == "cluster_binary"
        ]
        graph.remove_edges_from(cluster_binary_edges)

        leaf_ids = self.select_leaf_nodes(graph)
        if len(leaf_ids) < 2:
            # linkage() needs at least two observations; nothing to cluster.
            return graph

        # Generate the linkage_list for active checksums
        leaf_checksums = [graph.nodes[leaf]["checksum"] for leaf in leaf_ids]
        # NOTE(review): this relies on db.get returning one embedding per
        # requested id, in request order - confirm against the Database
        # implementation (duplicate checksums would break the alignment).
        embeddings = db.get(ids=leaf_checksums, include=["embeddings"])["embeddings"]
        data = np.array([np.array(e) for e in embeddings])
        linkage_matrix = linkage(data, method=self.linkage_method)

        # Add empty nodes and edges to the graph. In each linkage row the
        # first two entries index into all_nodes: values below len(leaf_ids)
        # are leaves, larger ones are clusters created by earlier rows.
        all_nodes = leaf_ids.copy()
        for i, row in enumerate(linkage_matrix):
            node = f"summary_{i + len(leaf_ids)}"
            all_nodes.append(node)
            graph.add_node(node)
            graph.add_edge(node, all_nodes[int(row[0])], type="cluster_binary")
            graph.add_edge(node, all_nodes[int(row[1])], type="cluster_binary")

        # Generate/fetch summaries and add to graph/db.
        new_nodes = all_nodes[len(leaf_ids) :]
        await self.load_all_summary_nodes(new_nodes, graph, db, refresh=refresh)

        return graph
9 changes: 5 additions & 4 deletions ragdaemon/annotators/summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

class Summarizer(Annotator):
name = "summarizer"
summary_field_id = "summary"

def __init__(
self,
Expand All @@ -40,7 +41,7 @@ def __init__(

def is_complete(self, graph: KnowledgeGraph, db: Database) -> bool:
return all(
data.get("summary") is not None
data.get(self.summary_field_id) is not None
for _, data in graph.nodes(data=True)
if data is not None and data.get("checksum") is not None
)
Expand All @@ -65,9 +66,9 @@ async def get_summary(self, data: dict[str, Any], db: Database):
document = record["documents"][0]
metadatas = record["metadatas"][0]
summary = await self.get_llm_response(document)
metadatas["summary"] = summary
metadatas[self.summary_field_id] = summary
db.update(data["checksum"], metadatas=metadatas)
data["summary"] = summary
data[self.summary_field_id] = summary

async def annotate(
self, graph: KnowledgeGraph, db: Database, refresh: bool = False
Expand All @@ -77,7 +78,7 @@ async def annotate(
for _, data in graph.nodes(data=True):
if data is None or data.get("checksum") is None:
continue
if data.get("summary") is not None and not refresh:
if data.get(self.summary_field_id) is not None and not refresh:
continue
tasks.append(self.get_summary(data, db))
if len(tasks) > 0:
Expand Down
4 changes: 3 additions & 1 deletion ragdaemon/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@
annotators = {
"hierarchy": {},
"chunker_llm": {"chunk_extensions": code_extensions},
"call_graph": {"call_extensions": code_extensions},
# "summarizer": {},
# "clusterer_binary": {},
# "call_graph": {"call_extensions": code_extensions},
"diff": {"diff": diff},
"layout_hierarchy": {},
}
Expand Down
Loading

0 comments on commit 969cca7

Please sign in to comment.