From 92b96637833f03103a13b0b79f0095fcff105b46 Mon Sep 17 00:00:00 2001 From: granawkins Date: Thu, 25 Apr 2024 10:23:44 -0700 Subject: [PATCH] implement summarizer annotator --- ragdaemon/annotators/__init__.py | 2 + ragdaemon/annotators/summarizer.py | 80 +++++++++++++++++++++++++++++ ragdaemon/graph.py | 1 + tests/annotators/test_chunker.py | 12 ++++- tests/annotators/test_summarizer.py | 29 +++++++++++ tests/conftest.py | 11 +--- 6 files changed, 123 insertions(+), 12 deletions(-) create mode 100644 ragdaemon/annotators/summarizer.py create mode 100644 tests/annotators/test_summarizer.py diff --git a/ragdaemon/annotators/__init__.py b/ragdaemon/annotators/__init__.py index 6126a70..b45e7ad 100644 --- a/ragdaemon/annotators/__init__.py +++ b/ragdaemon/annotators/__init__.py @@ -5,6 +5,7 @@ from ragdaemon.annotators.diff import Diff from ragdaemon.annotators.hierarchy import Hierarchy from ragdaemon.annotators.layout_hierarchy import LayoutHierarchy +from ragdaemon.annotators.summarizer import Summarizer annotators_map = { "hierarchy": Hierarchy, @@ -13,4 +14,5 @@ "chunker_line": ChunkerLine, "diff": Diff, "layout_hierarchy": LayoutHierarchy, + "summarizer": Summarizer, } diff --git a/ragdaemon/annotators/summarizer.py b/ragdaemon/annotators/summarizer.py new file mode 100644 index 0000000..6e484aa --- /dev/null +++ b/ragdaemon/annotators/summarizer.py @@ -0,0 +1,80 @@ +""" +Add a 1-sentence text summary to each file or chunk node +""" + +import asyncio +from typing import Any, Coroutine + +from tqdm.asyncio import tqdm + +from ragdaemon.annotators.base_annotator import Annotator +from ragdaemon.database import Database +from ragdaemon.graph import KnowledgeGraph +from ragdaemon.errors import RagdaemonError +from spice import SpiceMessage + +summarizer_prompt = """\ +Generate a 1-sentence summary of the provided code. Follow conventions of docstrings: +write in the imperative voice and start with a verb. Do not include any preamble or +asides. 
+ +It may be useful to name specific functions from the target repository (not built-in +Python functions) which are integral to the functioning of the target code. Include a +maximum of two (2) such named functions, but err on the side of brevity. +""" + + +semaphore = asyncio.Semaphore(50) + + +class Summarizer(Annotator): + name = "summarizer" + + def is_complete(self, graph: KnowledgeGraph, db: Database) -> bool: + return all( + data.get("summary") is not None + for _, data in graph.nodes(data=True) + if data is not None and data.get("checksum") is not None + ) + + async def get_llm_response(self, document: str) -> str: + if self.spice_client is None: + raise RagdaemonError("Spice client is not initialized.") + global semaphore + async with semaphore: + messages: list[SpiceMessage] = [ + {"role": "system", "content": summarizer_prompt}, + {"role": "user", "content": document}, + ] + response = await self.spice_client.get_response( + messages=messages, + ) + return response.text + + async def get_summary(self, data: dict[str, Any], db: Database): + """Asynchronously generate summary and update graph and db""" + record = db.get(data["checksum"]) + document = record["documents"][0] + metadatas = record["metadatas"][0] + summary = await self.get_llm_response(document) + metadatas["summary"] = summary + db.update(data["checksum"], metadatas=metadatas) + data["summary"] = summary + + async def annotate( + self, graph: KnowledgeGraph, db: Database, refresh: bool = False + ) -> KnowledgeGraph: + # Generate/add summaries to nodes with checksums (file, chunk, diff) + tasks = [] + for _, data in graph.nodes(data=True): + if data is None or data.get("checksum") is None: + continue + if data.get("summary") is not None and not refresh: + continue + tasks.append(self.get_summary(data, db)) + if len(tasks) > 0: + if self.verbose: + await tqdm.gather(*tasks, desc="Summarizing code...") + else: + await asyncio.gather(*tasks) + return graph diff --git a/ragdaemon/graph.py 
b/ragdaemon/graph.py index 6cef67b..6aa4838 100644 --- a/ragdaemon/graph.py +++ b/ragdaemon/graph.py @@ -16,6 +16,7 @@ class NodeMetadata(TypedDict): chunks: Optional[ list[dict[str, str]] ] # For files, func/class/method. For diff, by file/hunk + summary: Optional[str] # Generated summary of the node class EdgeMetadata(TypedDict): diff --git a/tests/annotators/test_chunker.py b/tests/annotators/test_chunker.py index 9e1009f..cd61ebf 100644 --- a/tests/annotators/test_chunker.py +++ b/tests/annotators/test_chunker.py @@ -1,5 +1,5 @@ from pathlib import Path -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, patch import pytest @@ -8,6 +8,15 @@ from ragdaemon.graph import KnowledgeGraph +@pytest.fixture +def mock_get_llm_response(): + with patch( + "ragdaemon.annotators.chunker_llm.ChunkerLLM.get_llm_response", + return_value={"chunks": []}, + ) as mock: + yield mock + + def test_chunker_is_complete(cwd, mock_db): chunker = Chunker() @@ -45,7 +54,6 @@ async def test_chunker_llm_annotate(cwd, mock_get_llm_response, mock_db): daemon = Daemon( cwd=cwd, annotators={"hierarchy": {}}, - graph_path=(Path.cwd() / "tests/data/hierarchy_graph.json"), ) chunker = ChunkerLLM(spice_client=AsyncMock()) actual = await chunker.annotate(daemon.graph, mock_db) diff --git a/tests/annotators/test_summarizer.py b/tests/annotators/test_summarizer.py new file mode 100644 index 0000000..c057e12 --- /dev/null +++ b/tests/annotators/test_summarizer.py @@ -0,0 +1,29 @@ +from unittest.mock import AsyncMock, patch + +import pytest + +from ragdaemon.annotators import Summarizer +from ragdaemon.daemon import Daemon + + +@pytest.fixture +def mock_get_llm_response(): + with patch( + "ragdaemon.annotators.summarizer.Summarizer.get_llm_response", + return_value="summary of", + ) as mock: + yield mock + + +@pytest.mark.asyncio +async def test_summarizer_annotate(cwd, mock_get_llm_response): + daemon = Daemon( + cwd=cwd, + annotators={"hierarchy": {}}, + ) + await 
daemon.update(refresh=True) + summarizer = Summarizer(spice_client=AsyncMock()) + actual = await summarizer.annotate(daemon.graph, daemon.db) + for _, data in actual.nodes(data=True): + if data.get("checksum") is not None: + assert data.get("summary") == "summary of" diff --git a/tests/conftest.py b/tests/conftest.py index c5a3305..d6443f8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,7 @@ import subprocess import tempfile from pathlib import Path -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock import pytest @@ -22,15 +22,6 @@ def mock_db(cwd): ) -@pytest.fixture -def mock_get_llm_response(): - with patch( - "ragdaemon.annotators.chunker_llm.ChunkerLLM.get_llm_response", - return_value={"chunks": []}, - ) as mock: - yield mock - - @pytest.fixture(scope="function") def git_history(cwd): with tempfile.TemporaryDirectory() as tmpdir: