Skip to content

Commit

Permalink
Merge pull request #15 from AbanteAI/summarizer
Browse files Browse the repository at this point in the history
impelment summarizer annotator
  • Loading branch information
granawkins authored Apr 25, 2024
2 parents b42896e + 92b9663 commit 1ab8b67
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 12 deletions.
2 changes: 2 additions & 0 deletions ragdaemon/annotators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ragdaemon.annotators.diff import Diff
from ragdaemon.annotators.hierarchy import Hierarchy
from ragdaemon.annotators.layout_hierarchy import LayoutHierarchy
from ragdaemon.annotators.summarizer import Summarizer

annotators_map = {
"hierarchy": Hierarchy,
Expand All @@ -13,4 +14,5 @@
"chunker_line": ChunkerLine,
"diff": Diff,
"layout_hierarchy": LayoutHierarchy,
"summarizer": Summarizer,
}
80 changes: 80 additions & 0 deletions ragdaemon/annotators/summarizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""
Add a 1-sentence text summary to each file or chunk node
"""

import asyncio
from typing import Any, Coroutine

from tqdm.asyncio import tqdm

from ragdaemon.annotators.base_annotator import Annotator
from ragdaemon.database import Database
from ragdaemon.graph import KnowledgeGraph
from ragdaemon.errors import RagdaemonError
from spice import SpiceMessage

summarizer_prompt = """\
Generate a 1-sentence summary of the provided code. Follow conventions of docstrings:
write in the imerative voice and start with a verb. Do not include any preamble or
asides.
It may be useful to name specific fucntions from the target repository (not built-in
Python functions) which are integral to the functioning of the target code. Include a
maximum of two (2) such named functions, but err on the side of brevity.
"""


semaphore = asyncio.Semaphore(50)


class Summarizer(Annotator):
name = "summarizer"

def is_complete(self, graph: KnowledgeGraph, db: Database) -> bool:
return all(
data.get("summary") is not None
for _, data in graph.nodes(data=True)
if data is not None and data.get("checksum") is not None
)

async def get_llm_response(self, document: str) -> str:
if self.spice_client is None:
raise RagdaemonError("Spice client is not initialized.")
global semaphore
async with semaphore:
messages: list[SpiceMessage] = [
{"role": "system", "content": summarizer_prompt},
{"role": "user", "content": document},
]
response = await self.spice_client.get_response(
messages=messages,
)
return response.text

async def get_summary(self, data: dict[str, Any], db: Database):
"""Asynchronously generate summary and update graph and db"""
record = db.get(data["checksum"])
document = record["documents"][0]
metadatas = record["metadatas"][0]
summary = await self.get_llm_response(document)
metadatas["summary"] = summary
db.update(data["checksum"], metadatas=metadatas)
data["summary"] = summary

async def annotate(
self, graph: KnowledgeGraph, db: Database, refresh: bool = False
) -> KnowledgeGraph:
# Generate/add summaries to nodes with checksums (file, chunk, diff)
tasks = []
for _, data in graph.nodes(data=True):
if data is None or data.get("checksum") is None:
continue
if data.get("summary") is not None and not refresh:
continue
tasks.append(self.get_summary(data, db))
if len(tasks) > 0:
if self.verbose:
await tqdm.gather(*tasks, desc="Summarizing code...")
else:
await asyncio.gather(*tasks)
return graph
1 change: 1 addition & 0 deletions ragdaemon/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class NodeMetadata(TypedDict):
chunks: Optional[
list[dict[str, str]]
] # For files, func/class/method. For diff, by file/hunk
summary: Optional[str] # Generated summary of the node


class EdgeMetadata(TypedDict):
Expand Down
12 changes: 10 additions & 2 deletions tests/annotators/test_chunker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, patch

import pytest

Expand All @@ -8,6 +8,15 @@
from ragdaemon.graph import KnowledgeGraph


@pytest.fixture
def mock_get_llm_response():
with patch(
"ragdaemon.annotators.chunker_llm.ChunkerLLM.get_llm_response",
return_value={"chunks": []},
) as mock:
yield mock


def test_chunker_is_complete(cwd, mock_db):
chunker = Chunker()

Expand Down Expand Up @@ -45,7 +54,6 @@ async def test_chunker_llm_annotate(cwd, mock_get_llm_response, mock_db):
daemon = Daemon(
cwd=cwd,
annotators={"hierarchy": {}},
graph_path=(Path.cwd() / "tests/data/hierarchy_graph.json"),
)
chunker = ChunkerLLM(spice_client=AsyncMock())
actual = await chunker.annotate(daemon.graph, mock_db)
Expand Down
29 changes: 29 additions & 0 deletions tests/annotators/test_summarizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from unittest.mock import AsyncMock, patch

import pytest

from ragdaemon.annotators import Summarizer
from ragdaemon.daemon import Daemon


@pytest.fixture
def mock_get_llm_response():
with patch(
"ragdaemon.annotators.summarizer.Summarizer.get_llm_response",
return_value="summary of",
) as mock:
yield mock


@pytest.mark.asyncio
async def test_summarizer_annotate(cwd, mock_get_llm_response):
daemon = Daemon(
cwd=cwd,
annotators={"hierarchy": {}},
)
await daemon.update(refresh=True)
summarizer = Summarizer(spice_client=AsyncMock())
actual = await summarizer.annotate(daemon.graph, daemon.db)
for _, data in actual.nodes(data=True):
if data.get("checksum") is not None:
assert data.get("summary") == "summary of"
11 changes: 1 addition & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import subprocess
import tempfile
from pathlib import Path
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock

import pytest

Expand All @@ -22,15 +22,6 @@ def mock_db(cwd):
)


@pytest.fixture
def mock_get_llm_response():
with patch(
"ragdaemon.annotators.chunker_llm.ChunkerLLM.get_llm_response",
return_value={"chunks": []},
) as mock:
yield mock


@pytest.fixture(scope="function")
def git_history(cwd):
with tempfile.TemporaryDirectory() as tmpdir:
Expand Down

0 comments on commit 1ab8b67

Please sign in to comment.