From 5d1e8ace1e2c595a3d4f7a3b1d283220084015ab Mon Sep 17 00:00:00 2001
From: Grant <50287275+granawkins@users.noreply.github.com>
Date: Sun, 26 May 2024 11:56:49 +0700
Subject: [PATCH] Deterministic Chunking with Astroid

* implement python-specific chunker using astroid
* use LiteDB by default
* minor version bump
* format fixes
---
 pyproject.toml                                |   3 +-
 ragdaemon/__init__.py                         |   2 +-
 ragdaemon/annotators/__init__.py              |   4 -
 .../{chunker.py => chunker/__init__.py}       |  91 +++---
 ragdaemon/annotators/chunker/chunk_astroid.py |  35 +++
 ragdaemon/annotators/chunker/chunk_line.py    |  28 ++
 ragdaemon/annotators/chunker/chunk_llm.py     | 214 ++++++++++++++
 ragdaemon/annotators/chunker/utils.py         |  82 ++++++
 ragdaemon/annotators/chunker_line.py          |  45 ---
 ragdaemon/annotators/chunker_llm.py           | 264 ------------------
 ragdaemon/app.py                              |   2 +-
 ragdaemon/daemon.py                           |   4 +-
 ragdaemon/database/__init__.py                |  49 ++--
 ragdaemon/database/pg_database.py             |  21 +-
 .../{chunker_llm.toml => chunk_llm.toml}      |   0
 tests/annotators/test_chunker.py              |  86 ++++--
 tests/annotators/test_chunker_llm.py          |  56 ----
 tests/data/summarizer_graph.json              |   8 +-
 18 files changed, 509 insertions(+), 485 deletions(-)
 rename ragdaemon/annotators/{chunker.py => chunker/__init__.py} (76%)
 create mode 100644 ragdaemon/annotators/chunker/chunk_astroid.py
 create mode 100644 ragdaemon/annotators/chunker/chunk_line.py
 create mode 100644 ragdaemon/annotators/chunker/chunk_llm.py
 create mode 100644 ragdaemon/annotators/chunker/utils.py
 delete mode 100644 ragdaemon/annotators/chunker_line.py
 delete mode 100644 ragdaemon/annotators/chunker_llm.py
 rename ragdaemon/prompts/{chunker_llm.toml => chunk_llm.toml} (100%)
 delete mode 100644 tests/annotators/test_chunker_llm.py

diff --git a/pyproject.toml b/pyproject.toml
index 725be2d..fd4a34f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,10 +7,11 @@ packages=["ragdaemon"]
 
 [project]
 name = "ragdaemon"
-version = "0.6.2"
+version = "0.7.0"
 description = "Generate and render a call graph for a Python project."
readme = "README.md" dependencies = [ + "astroid==3.2.2", "chromadb==0.4.24", "dict2xml==1.7.5", "fastapi==0.109.2", diff --git a/ragdaemon/__init__.py b/ragdaemon/__init__.py index 22049ab..49e0fc1 100644 --- a/ragdaemon/__init__.py +++ b/ragdaemon/__init__.py @@ -1 +1 @@ -__version__ = "0.6.2" +__version__ = "0.7.0" diff --git a/ragdaemon/annotators/__init__.py b/ragdaemon/annotators/__init__.py index 90098ce..bd3c3e8 100644 --- a/ragdaemon/annotators/__init__.py +++ b/ragdaemon/annotators/__init__.py @@ -1,8 +1,6 @@ from ragdaemon.annotators.base_annotator import Annotator # noqa: F401 from ragdaemon.annotators.call_graph import CallGraph # noqa: F401 from ragdaemon.annotators.chunker import Chunker -from ragdaemon.annotators.chunker_line import ChunkerLine -from ragdaemon.annotators.chunker_llm import ChunkerLLM from ragdaemon.annotators.diff import Diff from ragdaemon.annotators.hierarchy import Hierarchy from ragdaemon.annotators.layout_hierarchy import LayoutHierarchy @@ -11,8 +9,6 @@ annotators_map = { "call_graph": CallGraph, "chunker": Chunker, - "chunker_line": ChunkerLine, - "chunker_llm": ChunkerLLM, "diff": Diff, "hierarchy": Hierarchy, "layout_hierarchy": LayoutHierarchy, diff --git a/ragdaemon/annotators/chunker.py b/ragdaemon/annotators/chunker/__init__.py similarity index 76% rename from ragdaemon/annotators/chunker.py rename to ragdaemon/annotators/chunker/__init__.py index 2dfa107..e274535 100644 --- a/ragdaemon/annotators/chunker.py +++ b/ragdaemon/annotators/chunker/__init__.py @@ -1,26 +1,10 @@ -""" -Chunk data a list of objects following [ - {id: path/to/file:class.method, start_line: int, end_line: int} -] - -It's stored on the file node as data['chunks'] and json.dumped into the database. - -A chunker annotator: -1. Is complete when all files (with matching extensions) have a 'chunks' field -2. Generates chunks using a subclass method (llm, ctags..) -3. Adds that data to each file's graph node and database record -4. Add graph nodes (and db records) for each of those chunks -5. Add hierarchy edges connecting everything back to cwd - -The Chunker base class below handles everything except step 2. -""" - import asyncio import json from copy import deepcopy +from functools import partial from pathlib import Path -from typing import Any, Optional +from astroid.exceptions import AstroidSyntaxError from tqdm.asyncio import tqdm from ragdaemon.annotators.base_annotator import Annotator @@ -29,6 +13,11 @@ remove_add_to_db_duplicates, remove_update_db_duplicates, ) +from ragdaemon.annotators.chunker.utils import resolve_chunk_parent +from ragdaemon.annotators.chunker.chunk_astroid import chunk_document as chunk_astroid +from ragdaemon.annotators.chunker.chunk_llm import chunk_document as chunk_llm +from ragdaemon.annotators.chunker.chunk_line import chunk_document as chunk_line + from ragdaemon.errors import RagdaemonError from ragdaemon.graph import KnowledgeGraph from ragdaemon.utils import ( @@ -40,34 +29,39 @@ ) -def resolve_chunk_parent(id: str, nodes: set[str]) -> str | None: - file, chunk_str = id.split(":") - if chunk_str == "BASE": - return file - elif "." not in chunk_str: - return f"{file}:BASE" - else: - parts = chunk_str.split(".") - while True: - parent = f"{file}:{'.'.join(parts[:-1])}" - if parent in nodes: - return parent - parent_str = parent.split(":")[1] - if "." 
not in parent_str: - return None - # If intermediate parents are missing, skip them - parts = parent_str.split(".") - - class Chunker(Annotator): name = "chunker" chunk_field_id = "chunks" - def __init__(self, *args, chunk_extensions: Optional[list[str]] = None, **kwargs): + def __init__(self, *args, use_llm: bool = False, **kwargs): super().__init__(*args, **kwargs) - if chunk_extensions is None: - chunk_extensions = DEFAULT_CODE_EXTENSIONS - self.chunk_extensions = chunk_extensions + + # By default, use either the LLM chunker or a basic line chunker. + if use_llm and self.spice_client is not None: + default_chunk_fn = partial( + chunk_llm, spice_client=self.spice_client, verbose=self.verbose + ) + else: + default_chunk_fn = chunk_line + + # For python files, try to use astroid. If that fails, fall back to the default chunker. + async def python_chunk_fn(document: str): + try: + return await chunk_astroid(document) + except AstroidSyntaxError: + if self.verbose > 0: + file = document.split("\n")[0] + print( + f"Error chunking {file} with astroid; falling back to default chunker." + ) + return await default_chunk_fn(document) + + self.chunk_extensions_map = {} + for extension in DEFAULT_CODE_EXTENSIONS: + if extension == ".py": + self.chunk_extensions_map[extension] = python_chunk_fn + else: + self.chunk_extensions_map[extension] = default_chunk_fn def is_complete(self, graph: KnowledgeGraph, db: Database) -> bool: for node, data in graph.nodes(data=True): @@ -77,10 +71,10 @@ def is_complete(self, graph: KnowledgeGraph, db: Database) -> bool: continue chunks = data.get(self.chunk_field_id, None) if chunks is None: - if self.chunk_extensions is None: + if self.chunk_extensions_map is None: return False extension = Path(data["ref"]).suffix - if extension in self.chunk_extensions: + if extension in self.chunk_extensions_map: return False else: if not isinstance(chunks, list): @@ -90,15 +84,12 @@ def is_complete(self, graph: KnowledgeGraph, db: Database) -> bool: return False return True - async def chunk_document(self, document: str) -> list[dict[str, Any]]: - """Return a list of {id, ref} chunks for the given document.""" - raise NotImplementedError() - async def get_file_chunk_data(self, node, data): """Generate and save chunk data for a file node to graph and db""" document = data["document"] + extension = Path(data["ref"]).suffix try: - chunks = await self.chunk_document(document) + chunks = await self.chunk_extensions_map[extension](document) except RagdaemonError: if self.verbose > 0: print(f"Error chunking {node}; skipping.") @@ -118,11 +109,11 @@ async def annotate( if data.get("type") == "chunk": graph.remove_node(node) elif data.get("type") == "file": - if self.chunk_extensions is None: + if self.chunk_extensions_map is None: files_with_chunks.append((node, data)) else: extension = Path(data["ref"]).suffix - if extension in self.chunk_extensions: + if extension in self.chunk_extensions_map: files_with_chunks.append((node, data)) # Generate/add chunk data for nodes that don't have it diff --git a/ragdaemon/annotators/chunker/chunk_astroid.py b/ragdaemon/annotators/chunker/chunk_astroid.py new file mode 100644 index 0000000..159c02b --- /dev/null +++ b/ragdaemon/annotators/chunker/chunk_astroid.py @@ -0,0 +1,35 @@ +import astroid + +from ragdaemon.annotators.chunker.utils import Chunk, RawChunk, resolve_raw_chunks +from ragdaemon.errors import RagdaemonError + + +async def chunk_document(document: str) -> list[Chunk]: + # Parse the code into an astroid AST + lines = 
document.split("\n") + file_path = lines[0].strip() + code = "\n".join(lines[1:]) + + tree = astroid.parse(code) + + chunks = list[RawChunk]() + + def extract_chunks(node, parent_path=None): + if isinstance(node, (astroid.FunctionDef, astroid.ClassDef)): + delimiter = ":" if parent_path == file_path else "." + current_path = f"{parent_path}{delimiter}{node.name}" + start_line, end_line = node.lineno, node.end_lineno + if start_line is None or end_line is None: + raise RagdaemonError(f"Function {node.name} has no line numbers.") + chunks.append( + RawChunk(id=current_path, start_line=start_line, end_line=end_line) + ) + # Recursively handle nested functions + for child in node.body: + extract_chunks(child, parent_path=current_path) + + # Recursively extract chunks from the AST + for node in tree.body: + extract_chunks(node, parent_path=file_path) + + return resolve_raw_chunks(document, chunks) diff --git a/ragdaemon/annotators/chunker/chunk_line.py b/ragdaemon/annotators/chunker/chunk_line.py new file mode 100644 index 0000000..f04cfd2 --- /dev/null +++ b/ragdaemon/annotators/chunker/chunk_line.py @@ -0,0 +1,28 @@ +async def chunk_document( + document: str, lines_per_chunk: int = 100 +) -> list[dict[str, str]]: + lines = document.split("\n") + file = lines[0] + file_lines = lines[1:] + if not file_lines or not any(line for line in file_lines): + return [] + + chunks = list[dict[str, str]]() + if len(file_lines) > lines_per_chunk: + chunks.append( + { + "id": f"{file}:BASE", + "ref": f"{file}:1-{lines_per_chunk}", + } + ) # First N lines is always the base chunk + for i, start_line in enumerate( + range(lines_per_chunk + 1, len(file_lines), lines_per_chunk) + ): + end_line = min(start_line + lines_per_chunk - 1, len(file_lines)) + chunks.append( + { + "id": f"{file}:chunk_{i + 1}", + "ref": f"{file}:{start_line}-{end_line}", + } + ) + return chunks diff --git a/ragdaemon/annotators/chunker/chunk_llm.py b/ragdaemon/annotators/chunker/chunk_llm.py new file mode 100644 index 0000000..022ef1c --- /dev/null +++ b/ragdaemon/annotators/chunker/chunk_llm.py @@ -0,0 +1,214 @@ +import json +from collections import Counter, defaultdict +from functools import partial +from json.decoder import JSONDecodeError +from typing import List, Optional + +from spice import Spice, SpiceMessages +from spice.models import GPT_4o + +from ragdaemon.annotators.chunker.utils import ( + Chunk, + RawChunk, + resolve_chunk_parent, + resolve_raw_chunks, +) +from ragdaemon.errors import RagdaemonError +from ragdaemon.utils import semaphore + + +class ChunkErrorInPreviousBatch(RagdaemonError): + pass + + +def validate( + response: str, + file: str, + max_line: int, + file_chunks: Optional[set[str]], + last_chunk: Optional[RawChunk], +): + try: + chunks = json.loads(response).get("chunks") + except JSONDecodeError: + return False + if not isinstance(chunks, list): + return False + if not all(isinstance(chunk, dict) for chunk in chunks): + return False + + for chunk in chunks: + if not set(chunk.keys()) == {"id", "start_line", "end_line"}: + return False + + halves = chunk["id"].split(":") + if len(halves) != 2 or not halves[0] or not halves[1]: + return False + if halves[0] != file: + return False + + start, end = chunk.get("start_line"), chunk.get("end_line") + if start is None or end is None: + return False + + # Sometimes output is int, sometimes string. This accomodates either. 
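    # (Illustrative, with hypothetical values: a chunk such as
    #  {"id": "foo.py:bar", "start_line": 3, "end_line": "9"} is normalized by
    #  the str()/isdigit() checks below to the ints 3 and 9.)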
+ start, end = str(start), str(end) + if not start.isdigit() or not end.isdigit(): + return False + start, end = int(start), int(end) + + if not 1 <= start <= end <= max_line: + return False + + if last_chunk is not None: + if not any(chunk["id"] == last_chunk["id"] for chunk in chunks): + return False + + """ + The LLM sometimes returns invalid parents (i.e. path/to/file.ext:parent.chunk). + There are 3 cases why they might be invalid: + A) The LLM made a typo here. In that case, return False to try again. + B) The LLM made a typo when parsing the parent in a previous batch. In that case, + go back and redo the previous batch. We distinguish this from case A) by checking + if multiple chunks reference the same invalid parent. + C) An edge case where our schema breaks down, e.g. Javascript event handlers + usually try to set "document" as their parent, but that won't be a node. + + Case A) should be resolved by Spice's validator loop, i.e. this function returning + "False". For Case B), raise a special exception and step back one batch in the + chunk_document loop. Any chunks still referencing invalid parents after these two + loops are exhausted (including case C)) will just be accepted and linked to + path/to/file.ext:BASE. + """ + if file_chunks: # else, loops exhausted or Case C) + valid_parents = file_chunks.copy() + chunks_shortest_first = sorted(chunks, key=lambda x: len(x["id"])) + chunks_missing_parents = set() + for chunk in chunks_shortest_first: + if not resolve_chunk_parent(chunk["id"], valid_parents): + chunks_missing_parents.add(chunk["id"]) + valid_parents.add(chunk["id"]) + + if len(chunks_missing_parents) > 1: + missing_parents = [] + for chunk in chunks_missing_parents: + file, chunk_str = chunk.split(":") + parts = chunk_str.split(".") + missing_parents.append(f"{file}:{'.'.join(parts[:-1])}") + mp_counts = Counter(missing_parents) + parent, count = mp_counts.most_common(1)[0] + if count > 1: + raise ChunkErrorInPreviousBatch(parent) # Case B) + return False # Case A) + + return True + + +async def get_llm_response( + spice_client: Spice, + file: str, + file_lines: list[str], + file_chunks: Optional[set[str]] = None, + last_chunk: Optional[RawChunk] = None, + verbose: int = 0, +) -> List[RawChunk]: + """Get one chunking response from the LLM model.""" + messages = SpiceMessages(spice_client) + messages.add_system_prompt(name="chunk_llm.base") + if last_chunk is not None: + messages.add_system_prompt("chunk_llm.continuation", last_chunk=last_chunk) + messages.add_user_prompt("chunk_llm.user", path=file, code="\n".join(file_lines)) + + max_line = int(file_lines[-1].split(":")[0]) # Extract line number + validator = partial( + validate, + file=file, + max_line=max_line, + file_chunks=file_chunks, + last_chunk=last_chunk, + ) + async with semaphore: + try: + response = await spice_client.get_response( + messages=messages, + model=GPT_4o, + response_format={"type": "json_object"}, + validator=validator, + retries=2, + ) + return json.loads(response.text).get("chunks") + except ValueError: + pass + validator = partial( + validate, + file=file, + max_line=max_line, + file_chunks=None, # Skip parent chunk validation + last_chunk=last_chunk, + ) + try: + response = await spice_client.get_response( + messages=messages, + model=GPT_4o, + response_format={"type": "json_object"}, + validator=validator, + retries=1, + ) + return json.loads(response.text).get("chunks") + except ValueError: + if verbose > 0: + print( + f"Failed to get chunks for {file} batch ending at line {max_line}." 
+ ) + return [] + + +async def chunk_document( + document: str, + spice_client: Spice, + retries=1, + batch_size: int = 800, + verbose: int = 0, +) -> list[Chunk]: + """Parse file_lines into a list of {id, ref} chunks.""" + lines = document.split("\n") + file = lines[0] + file_lines = lines[1:] + if not file_lines or not any(line for line in file_lines): + return [] + file_lines = [f"{i+1}:{line}" for i, line in enumerate(file_lines)] + + # Get raw llm output: {id, start_line, end_line} + chunks = list[RawChunk]() + n_batches = (len(file_lines) + batch_size - 1) // batch_size + retries_by_batch = {i: retries for i in range(n_batches)} + chunk_index_by_batch = defaultdict(int) + i = 0 + while i < n_batches: + while retries_by_batch[i] >= 0: + batch_lines = file_lines[i * batch_size : (i + 1) * batch_size] + chunk_index_by_batch[i] = len(chunks) + last_chunk = chunks.pop() if chunks else None + if retries_by_batch[i] > 0: + file_chunks = {c["id"] for c in chunks} + else: + file_chunks = None # Skip parent chunk validation + try: + _chunks = await get_llm_response( + spice_client, + file, + batch_lines, + file_chunks, + last_chunk, + ) + chunks.extend(_chunks) + i += 1 + break + except ChunkErrorInPreviousBatch as e: + if verbose > 1: + print(f"Chunker missed parent {e} in file {file}, retrying.") + retries_by_batch[i] -= 1 + chunks = chunks[: chunk_index_by_batch[i]] + i = max(0, i - 1) + + return resolve_raw_chunks(document, chunks) diff --git a/ragdaemon/annotators/chunker/utils.py b/ragdaemon/annotators/chunker/utils.py new file mode 100644 index 0000000..ca06e8f --- /dev/null +++ b/ragdaemon/annotators/chunker/utils.py @@ -0,0 +1,82 @@ +from typing import TypedDict + +from ragdaemon.utils import lines_set_to_ref + + +class RawChunk(TypedDict): + id: str + start_line: int + end_line: int + + +class Chunk(TypedDict): + id: str + ref: str + + +def resolve_chunk_parent(id: str, nodes: set[str]) -> str | None: + file, chunk_str = id.split(":") + if chunk_str == "BASE": + return file + elif "." not in chunk_str: + return f"{file}:BASE" + else: + parts = chunk_str.split(".") + while True: + parent = f"{file}:{'.'.join(parts[:-1])}" + if parent in nodes: + return parent + parent_str = parent.split(":")[1] + if "." 
not in parent_str: + return None + # If intermediate parents are missing, skip them + parts = parent_str.split(".") + + +def resolve_raw_chunks(document: str, chunks: list[RawChunk]) -> list[Chunk]: + """Take a list of {id, start_line, end_line} and return a corrected list of {id, ref}.""" + + # Convert to {id: set(lines)} for easier manipulation + id_sets = {c["id"]: set(range(c["start_line"], c["end_line"] + 1)) for c in chunks} + + def update_parent_nodes(id: str, _id_sets: dict[str, set[int]]): + parent_lines = _id_sets[id] + child_chunks = {k: v for k, v in _id_sets.items() if k.startswith(id + ".")} + if child_chunks: + # Make sure end_line of each 'parent' chunk covers all children + start_line = min(parent_lines) + end_line = start_line + for child_lines in child_chunks.values(): + if not child_lines: + continue + end_line = max(end_line, max(child_lines)) + parent_lines = set(range(start_line, end_line + 1)) + # Remove child lines from parent lines + for child_lines in child_chunks.values(): + parent_lines -= child_lines + _id_sets[id] = parent_lines + return _id_sets + + ids_longest_first = sorted(id_sets.keys(), key=lambda x: len(x), reverse=True) + for id in ids_longest_first: + id_sets = update_parent_nodes(id, id_sets) + + file_lines = document.split("\n") + file = file_lines[0] + output = list[Chunk]() + if id_sets: + # Generate a 'BASE chunk' with all lines not already part of a chunk + base_chunk_lines = set(range(1, len(file_lines))) + for lines in id_sets.values(): + base_chunk_lines -= lines + lines_ref = lines_set_to_ref(base_chunk_lines) + ref = f"{file}:{lines_ref}" if lines_ref else file + base_chunk = Chunk(id=f"{file}:BASE", ref=ref) + output.append(base_chunk) + + # Convert to refs and return + for id, lines in id_sets.items(): + lines_ref = lines_set_to_ref(lines) + ref = f"{file}:{lines_ref}" if lines_ref else file + output.append(Chunk(id=id, ref=ref)) + return output diff --git a/ragdaemon/annotators/chunker_line.py b/ragdaemon/annotators/chunker_line.py deleted file mode 100644 index adcf51d..0000000 --- a/ragdaemon/annotators/chunker_line.py +++ /dev/null @@ -1,45 +0,0 @@ -from typing import Any - -from ragdaemon.annotators.chunker import Chunker - - -class ChunkerLine(Chunker): - name = "chunker_line" - chunk_field_id = "chunks_line" - - def __init__(self, *args, lines_per_chunk=50, **kwargs): - super().__init__(*args, **kwargs) - self.n = lines_per_chunk - - async def chunk_document(self, document: str) -> list[dict[str, Any]]: - lines = document.split("\n") - file = lines[0] - file_lines = lines[1:] - if not file_lines or not any(line for line in file_lines): - return [] - - chunks = list[dict[str, Any]]() - if len(file_lines) > self.n: - chunks.append( - { - "id": f"{file}:BASE", - "start_line": "1", - "end_line": str(self.n), - } - ) # First N lines is always the base chunk - for i, start_line in enumerate(range(self.n + 1, len(file_lines), self.n)): - chunks.append( - { - "id": f"{file}:chunk_{i + 1}", - "start_line": str(start_line), - "end_line": str(min(start_line + self.n - 1, len(file_lines))), - } - ) - # Convert start/end to refs - return [ - { - "id": chunk["id"], - "ref": f"{file}:{chunk['start_line']}-{chunk['end_line']}", - } - for chunk in chunks - ] diff --git a/ragdaemon/annotators/chunker_llm.py b/ragdaemon/annotators/chunker_llm.py deleted file mode 100644 index e0e06bd..0000000 --- a/ragdaemon/annotators/chunker_llm.py +++ /dev/null @@ -1,264 +0,0 @@ -import json -from collections import Counter, defaultdict -from functools import 
partial -from json.decoder import JSONDecodeError -from typing import Any, Dict, List, Optional - -from spice import SpiceMessages -from spice.models import TextModel - -from ragdaemon.annotators.chunker import Chunker, resolve_chunk_parent -from ragdaemon.errors import RagdaemonError -from ragdaemon.utils import DEFAULT_COMPLETION_MODEL, lines_set_to_ref, semaphore - - -class ChunkErrorInPreviousBatch(RagdaemonError): - pass - - -def validate( - response: str, - file: str, - max_line: int, - file_chunks: Optional[set[str]], - last_chunk: Optional[dict[str, Any]], -): - try: - chunks = json.loads(response).get("chunks") - except JSONDecodeError: - return False - if not isinstance(chunks, list): - return False - if not all(isinstance(chunk, dict) for chunk in chunks): - return False - - for chunk in chunks: - if not set(chunk.keys()) == {"id", "start_line", "end_line"}: - return False - - halves = chunk["id"].split(":") - if len(halves) != 2 or not halves[0] or not halves[1]: - return False - if halves[0] != file: - return False - - start, end = chunk.get("start_line"), chunk.get("end_line") - if start is None or end is None: - return False - - # Sometimes output is int, sometimes string. This accomodates either. - start, end = str(start), str(end) - if not start.isdigit() or not end.isdigit(): - return False - start, end = int(start), int(end) - - if not 1 <= start <= end <= max_line: - return False - - if last_chunk is not None: - if not any(chunk["id"] == last_chunk["id"] for chunk in chunks): - return False - - """ - The LLM sometimes returns invalid parents (i.e. path/to/file.ext:parent.chunk). - There are 3 cases why they might be invalid: - A) The LLM made a typo here. In that case, return False to try again. - B) The LLM made a typo when parsing the parent in a previous batch. In that case, - go back and redo the previous batch. We distinguish this from case A) by checking - if multiple chunks reference the same invalid parent. - C) An edge case where our schema breaks down, e.g. Javascript event handlers - usually try to set "document" as their parent, but that won't be a node. - - Case A) should be resolved by Spice's validator loop, i.e. this function returning - "False". For Case B), raise a special exception and step back one batch in the - chunk_document loop. Any chunks still referencing invalid parents after these two - loops are exhausted (including case C)) will just be accepted and linked to - path/to/file.ext:BASE. 
- """ - if file_chunks: # else, loops exhausted or Case C) - valid_parents = file_chunks.copy() - chunks_shortest_first = sorted(chunks, key=lambda x: len(x["id"])) - chunks_missing_parents = set() - for chunk in chunks_shortest_first: - if not resolve_chunk_parent(chunk["id"], valid_parents): - chunks_missing_parents.add(chunk["id"]) - valid_parents.add(chunk["id"]) - - if len(chunks_missing_parents) > 1: - missing_parents = [] - for chunk in chunks_missing_parents: - file, chunk_str = chunk.split(":") - parts = chunk_str.split(".") - missing_parents.append(f"{file}:{'.'.join(parts[:-1])}") - mp_counts = Counter(missing_parents) - parent, count = mp_counts.most_common(1)[0] - if count > 1: - raise ChunkErrorInPreviousBatch(parent) # Case B) - return False # Case A) - - return True - - -class ChunkerLLM(Chunker): - name = "chunker_llm" - chunk_field_id = "chunks_llm" - - def __init__( - self, - *args, - batch_size: int = 800, - model: Optional[TextModel | str] = DEFAULT_COMPLETION_MODEL, - **kwargs, - ): - super().__init__(*args, **kwargs) - self.batch_size = batch_size - self.model = model - - async def get_llm_response( - self, - file: str, - file_lines: list[str], - file_chunks: Optional[set[str]] = None, - last_chunk: Optional[dict[str, Any]] = None, - ) -> List[Dict[str, Any]]: - """Get one chunking response from the LLM model.""" - if self.spice_client is None: - raise RagdaemonError("Spice client is not initialized.") - - messages = SpiceMessages(self.spice_client) - messages.add_system_prompt(name="chunker_llm.base") - if last_chunk is not None: - messages.add_system_prompt( - "chunker_llm.continuation", last_chunk=last_chunk - ) - messages.add_user_prompt( - "chunker_llm.user", path=file, code="\n".join(file_lines) - ) - - max_line = int(file_lines[-1].split(":")[0]) # Extract line number - validator = partial( - validate, - file=file, - max_line=max_line, - file_chunks=file_chunks, - last_chunk=last_chunk, - ) - async with semaphore: - try: - response = await self.spice_client.get_response( - messages=messages, - model=self.model, - response_format={"type": "json_object"}, - validator=validator, - retries=2, - ) - return json.loads(response.text).get("chunks") - except ValueError: - pass - validator = partial( - validate, - file=file, - max_line=max_line, - file_chunks=None, # Skip parent chunk validation - last_chunk=last_chunk, - ) - try: - response = await self.spice_client.get_response( - messages=messages, - model=self.model, - response_format={"type": "json_object"}, - validator=validator, - retries=1, - ) - return json.loads(response.text).get("chunks") - except ValueError: - if self.verbose > 0: - print( - f"Failed to get chunks for {file} batch ending at line {max_line}." 
- ) - return [] - - async def chunk_document(self, document: str, retries=1) -> list[dict[str, Any]]: - """Parse file_lines into a list of {id, ref} chunks.""" - lines = document.split("\n") - file = lines[0] - file_lines = lines[1:] - if not file_lines or not any(line for line in file_lines): - return [] - file_lines = [f"{i+1}:{line}" for i, line in enumerate(file_lines)] - - # Get raw llm output: {id, start_line, end_line} - chunks = list[dict[str, Any]]() - n_batches = (len(file_lines) + self.batch_size - 1) // self.batch_size - retries_by_batch = {i: retries for i in range(n_batches)} - chunk_index_by_batch = defaultdict(int) - i = 0 - while i < n_batches: - while retries_by_batch[i] >= 0: - batch_lines = file_lines[ - i * self.batch_size : (i + 1) * self.batch_size - ] - chunk_index_by_batch[i] = len(chunks) - last_chunk = chunks.pop() if chunks else None - if retries_by_batch[i] > 0: - file_chunks = {c["id"] for c in chunks} - else: - file_chunks = None # Skip parent chunk validation - try: - _chunks = await self.get_llm_response( - file, batch_lines, file_chunks, last_chunk - ) - chunks.extend(_chunks) - i += 1 - break - except ChunkErrorInPreviousBatch as e: - if self.verbose > 1: - print(f"Chunker missed parent {e} in file {file}, retrying.") - retries_by_batch[i] -= 1 - chunks = chunks[: chunk_index_by_batch[i]] - i = max(0, i - 1) - - # Convert to {id: set(lines)} for easier manipulation - chunks = { - c["id"]: set(range(c["start_line"], c["end_line"] + 1)) for c in chunks - } - - def update_parent_nodes(id: str, _chunks: dict[str, set[int]]): - parent_lines = _chunks[id] - child_chunks = {k: v for k, v in _chunks.items() if k.startswith(id + ".")} - if child_chunks: - # Make sure end_line of each 'parent' chunk covers all children - start_line = min(parent_lines) - end_line = start_line - for child_lines in child_chunks.values(): - if not child_lines: - continue - end_line = max(end_line, max(child_lines)) - parent_lines = set(range(start_line, end_line + 1)) - # Remove child lines from parent lines - for child_lines in child_chunks.values(): - parent_lines -= child_lines - _chunks[id] = parent_lines - return _chunks - - ids_longest_first = sorted(chunks, key=lambda x: len(x), reverse=True) - for id in ids_longest_first: - chunks = update_parent_nodes(id, chunks) - - output = [] - if chunks: - # Generate a 'BASE chunk' with all lines not already part of a chunk - base_chunk_lines = set(range(1, len(file_lines) + 1)) - for lines in chunks.values(): - base_chunk_lines -= lines - lines_ref = lines_set_to_ref(base_chunk_lines) - ref = f"{file}:{lines_ref}" if lines_ref else file - base_chunk = {"id": f"{file}:BASE", "ref": ref} - output.append(base_chunk) - - # Convert to refs and return - for id, lines in chunks.items(): - lines_ref = lines_set_to_ref(lines) - ref = f"{file}:{lines_ref}" if lines_ref else file - output.append({"id": id, "ref": ref}) - return output diff --git a/ragdaemon/app.py b/ragdaemon/app.py index c5b2912..42da668 100644 --- a/ragdaemon/app.py +++ b/ragdaemon/app.py @@ -39,7 +39,7 @@ diff = args.diff annotators = { "hierarchy": {}, - "chunker_llm": {"chunk_extensions": code_extensions}, + "chunker": {"use_llm": True}, # "summarizer": {}, # "clusterer_binary": {}, # "call_graph": {"call_extensions": code_extensions}, diff --git a/ragdaemon/daemon.py b/ragdaemon/daemon.py index 190a39f..0fbf490 100644 --- a/ragdaemon/daemon.py +++ b/ragdaemon/daemon.py @@ -22,7 +22,7 @@ def default_annotators(): return { "hierarchy": {}, - "chunker_line": {"lines_per_chunk": 
30}, + "chunker": {"use_llm": False}, "diff": {}, } @@ -109,7 +109,7 @@ async def update(self, refresh: str | bool = False): Refresh can be - boolean to refresh all annotators/nodes - - string matching annotator names / node ids, e.g. ("chunker_llm") + - string matching annotator names / node ids, e.g. ("chunker") - string with wildcard operators to fuzzy-match annotators/nodes, e.g. ("*diff*") """ _graph = self.graph.copy() diff --git a/ragdaemon/database/__init__.py b/ragdaemon/database/__init__.py index 148e346..3df5006 100644 --- a/ragdaemon/database/__init__.py +++ b/ragdaemon/database/__init__.py @@ -1,4 +1,4 @@ -import os +import os # noqa: F401 from pathlib import Path from typing import Optional @@ -10,8 +10,11 @@ remove_update_db_duplicates, # noqa: F401 ) from ragdaemon.database.database import Database + +# from ragdaemon.database.chroma_database import ChromaDB from ragdaemon.database.lite_database import LiteDB -from ragdaemon.database.pg_database import PGDB + +# from ragdaemon.database.pg_database import PGDB from ragdaemon.utils import mentat_dir_path DEFAULT_EMBEDDING_MODEL = "text-embedding-3-large" @@ -26,25 +29,25 @@ def get_db( ) -> Database: db_path = mentat_dir_path / "chroma" db_path.mkdir(parents=True, exist_ok=True) - if embedding_model is not None and "PYTEST_CURRENT_TEST" not in os.environ: - try: - # db = ChromaDB( - # cwd=cwd, - # db_path=db_path, - # spice_client=spice_client, - # embedding_model=embedding_model, - # embedding_provider=embedding_provider, - # verbose=verbose, - # ) - # # In case the api key is wrong, try to embed something to trigger an error. - # _ = db.add(ids="test", documents="test doc") - # db.delete(ids="test") - db = PGDB(cwd=cwd, db_path=db_path, verbose=verbose) - return db - except Exception as e: - if verbose > 1: - print( - f"Failed to initialize Postgres Database: {e}. Falling back to LiteDB." - ) - pass + # if embedding_model is not None and "PYTEST_CURRENT_TEST" not in os.environ: + # try: + # # db = ChromaDB( + # # cwd=cwd, + # # db_path=db_path, + # # spice_client=spice_client, + # # embedding_model=embedding_model, + # # embedding_provider=embedding_provider, + # # verbose=verbose, + # # ) + # # # In case the api key is wrong, try to embed something to trigger an error. + # # _ = db.add(ids="test", documents="test doc") + # # db.delete(ids="test") + # db = PGDB(cwd=cwd, db_path=db_path, verbose=verbose) + # return db + # except Exception as e: + # if verbose > 1: + # print( + # f"Failed to initialize Postgres Database: {e}. Falling back to LiteDB." 
+ # ) + # pass return LiteDB(cwd=cwd, db_path=db_path, verbose=verbose) diff --git a/ragdaemon/database/pg_database.py b/ragdaemon/database/pg_database.py index 407ed1c..d4c96e7 100644 --- a/ragdaemon/database/pg_database.py +++ b/ragdaemon/database/pg_database.py @@ -20,10 +20,10 @@ class DocumentMetadata(Base): id: Mapped[str] = mapped_column(primary_key=True) # We serialize whatever we get, which can be 'null', so we need Optional - chunks_llm: Mapped[Optional[str]] + chunks: Mapped[Optional[str]] -def retry_on_exception(retries: int=3, exceptions={OperationalError}): +def retry_on_exception(retries: int = 3, exceptions={OperationalError}): def decorator(func): def wrapper(*args, **kwargs): for i in range(retries): @@ -33,7 +33,9 @@ def wrapper(*args, **kwargs): print(f"Caught exception: {e}") if i == retries - 1: raise e + return wrapper + return decorator @@ -132,7 +134,7 @@ def __init__(self, *args, **kwargs): class PGCollection(LiteCollection): """Wraps a LiteDB and adds/gets targeted fields from a remote Postgres Database.""" - def __init__(self, *args, fields: list[str] = ["chunks_llm"], **kwargs): + def __init__(self, *args, fields: list[str] = ["chunks"], **kwargs): super().__init__(*args, **kwargs) self.engine = Engine(self.verbose) self.fields = fields @@ -169,21 +171,14 @@ def get( include: list[str] | None = None, ): response = super().get(ids, include) - if include is not None and "metadatas" in include: - if not isinstance(ids, list): - ids = [ids] - remote_metadatas = self.engine.get_document_metadata(ids) - seen = set() + response_ids = response.get("ids", []) + if response_ids and include is not None and "metadatas" in include: + remote_metadatas = self.engine.get_document_metadata(response_ids) for id, metadata in zip( response.get("ids", []), response.get("metadatas", []) ): if id in remote_metadatas: metadata.update(remote_metadatas[id]) - seen.add(id) - for k, v in remote_metadatas.items(): - if k not in seen: - response["ids"].append(k) - response["metadatas"].append(v) return response diff --git a/ragdaemon/prompts/chunker_llm.toml b/ragdaemon/prompts/chunk_llm.toml similarity index 100% rename from ragdaemon/prompts/chunker_llm.toml rename to ragdaemon/prompts/chunk_llm.toml diff --git a/tests/annotators/test_chunker.py b/tests/annotators/test_chunker.py index 595a2c4..d154899 100644 --- a/tests/annotators/test_chunker.py +++ b/tests/annotators/test_chunker.py @@ -1,22 +1,14 @@ from pathlib import Path -from unittest.mock import AsyncMock, patch import pytest -from ragdaemon.annotators import Chunker, ChunkerLLM +from ragdaemon.annotators import Chunker +from ragdaemon.annotators.chunker.chunk_llm import chunk_document as chunk_llm +from ragdaemon.annotators.chunker.chunk_astroid import chunk_document as chunk_astroid from ragdaemon.daemon import Daemon from ragdaemon.graph import KnowledgeGraph -@pytest.fixture -def mock_get_llm_response(): - with patch( - "ragdaemon.annotators.chunker_llm.ChunkerLLM.get_llm_response", - return_value={"chunks_llm": []}, - ) as mock: - yield mock - - def test_chunker_is_complete(cwd, mock_db): chunker = Chunker() @@ -49,16 +41,68 @@ def test_chunker_is_complete(cwd, mock_db): ), "Chunker graph should be complete." 
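The fixture and tests below exercise the new chunkers directly on tests/data/hard_to_chunk.txt. As a quick orientation (a minimal sketch with a made-up path and snippet, not taken from this patch): every chunker receives a "document" whose first line is the file's relative path and whose remaining lines are the file contents, and returns {id, ref} dicts:

import asyncio

from ragdaemon.annotators.chunker.chunk_astroid import chunk_document

# First line is the relative path; the rest is the file's source code.
document = "src/example.py\nclass Foo:\n    def bar(self):\n        return 1\n"
chunks = asyncio.run(chunk_document(document))
# Roughly: [{"id": "src/example.py:BASE", "ref": "src/example.py:..."},
#           {"id": "src/example.py:Foo", "ref": "src/example.py:..."},
#           {"id": "src/example.py:Foo.bar", "ref": "src/example.py:..."}]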
+@pytest.fixture +def expected_chunks(): + return [ + {"id": "src/calculator.py:BASE", "ref": "src/calculator.py:1-4,29,42-45"}, + { + "id": "src/calculator.py:Calculator", + "ref": "src/calculator.py:5,10,13,16,19", + }, + {"id": "src/calculator.py:Calculator.__init__", "ref": "src/calculator.py:6-9"}, + { + "id": "src/calculator.py:Calculator.add_numbers", + "ref": "src/calculator.py:11-12", + }, + { + "id": "src/calculator.py:Calculator.subtract_numbers", + "ref": "src/calculator.py:14-15", + }, + { + "id": "src/calculator.py:Calculator.exp_numbers", + "ref": "src/calculator.py:17-18", + }, + {"id": "src/calculator.py:Calculator.call", "ref": "src/calculator.py:20-28"}, + {"id": "src/calculator.py:main", "ref": "src/calculator.py:30-41"}, + ] + + @pytest.mark.asyncio -async def test_chunker_llm_annotate(cwd, mock_get_llm_response, mock_db): - daemon = Daemon( - cwd=cwd, - annotators={"hierarchy": {}}, +async def test_chunker_astroid(cwd, expected_chunks): + text = Path("tests/data/hard_to_chunk.txt").read_text() + document = f"src/calculator.py\n{text}" + actual_chunks = await chunk_astroid(document) + + assert len(actual_chunks) == len(expected_chunks) + actual_chunks = sorted(actual_chunks, key=lambda x: x["ref"]) + expected_chunks = sorted(expected_chunks, key=lambda x: x["ref"]) + for actual, expected in zip(actual_chunks, expected_chunks): + assert actual == expected + + +@pytest.mark.skip(reason="This test requires calling an API") +@pytest.mark.asyncio +async def test_chunk_llm(cwd, expected_chunks): + # NOTE: TO RUN THIS YOU HAVE TO COMMENT_OUT tests/conftest.py/mock_openai_api_key + daemon = Daemon(cwd, annotators={"hierarchy": {}}) + + # One example with all the edge cases (when batch_size = 10 lines): + # - First batch ends mid-class, so second batch needs 'call path' + # - Second batch ends mid-function, third batch needs to pickup where it left off + # - Third batch is all inside one function, so needs to pass call forward. 
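    # (Illustrative arithmetic, assuming the sample file is roughly 45 lines:
    # with batch_size=10 the batches cover lines 1-10, 11-20, 21-30, 31-40 and
    # 41-45, and the last chunk of each batch is popped and re-sent as
    # `last_chunk` so the next batch can extend it.)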
+ text = Path("tests/data/hard_to_chunk.txt").read_text() + document = f"src/calculator.py\n{text}" + actual_chunks = await chunk_llm( + spice_client=daemon.spice_client, + document=document, + batch_size=10, + verbose=2, ) - chunker = ChunkerLLM(spice_client=AsyncMock()) - actual = await chunker.annotate(daemon.graph, mock_db) - for node, data in actual.nodes(data=True): - assert data, f"Node {node} is missing data" - if data["type"] == "file" and Path(node).suffix in chunker.chunk_extensions: - assert "chunks_llm" in data, f"File {node} is missing chunks" + print(actual_chunks) + + assert len(actual_chunks) == len(expected_chunks) + actual_chunks = sorted(actual_chunks, key=lambda x: x["ref"]) + expected_chunks = sorted(expected_chunks, key=lambda x: x["ref"]) + for actual, expected in zip(actual_chunks, expected_chunks): + assert actual == expected diff --git a/tests/annotators/test_chunker_llm.py b/tests/annotators/test_chunker_llm.py deleted file mode 100644 index 5eee24a..0000000 --- a/tests/annotators/test_chunker_llm.py +++ /dev/null @@ -1,56 +0,0 @@ -from pathlib import Path - -import pytest - -from ragdaemon.annotators.chunker_llm import ChunkerLLM -from ragdaemon.daemon import Daemon - - -@pytest.fixture -def expected_chunks(): - return [ - {"id": "src/calculator.py:BASE", "ref": "src/calculator.py:1-4,29,42-45"}, - { - "id": "src/calculator.py:Calculator", - "ref": "src/calculator.py:5,10,13,16,19", - }, - {"id": "src/calculator.py:Calculator.__init__", "ref": "src/calculator.py:6-9"}, - { - "id": "src/calculator.py:Calculator.add_numbers", - "ref": "src/calculator.py:11-12", - }, - { - "id": "src/calculator.py:Calculator.subtract_numbers", - "ref": "src/calculator.py:14-15", - }, - { - "id": "src/calculator.py:Calculator.exp_numbers", - "ref": "src/calculator.py:17-18", - }, - {"id": "src/calculator.py:Calculator.call", "ref": "src/calculator.py:20-28"}, - {"id": "src/calculator.py:main", "ref": "src/calculator.py:30-41"}, - ] - - -@pytest.mark.skip(reason="This test requires calling an API") -@pytest.mark.asyncio -async def test_chunker_llm_edge_cases(cwd, expected_chunks): - # NOTE: TO RUN THIS YOU HAVE TO COMMENT_OUT tests/conftest.py/mock_openai_api_key - daemon = Daemon(cwd, annotators={"hierarchy": {}}) - chunker = ChunkerLLM(spice_client=daemon.spice_client, batch_size=10) - - # One example with all the edge cases (when batch_size = 10 lines): - # - First batch ends mid-class, so second batch needs 'call path' - # - Second batch ends mid-function, third batch needs to pickup where it left off - # - Third batch is all inside one function, so needs to pass call forward. 
- text = Path("tests/data/hard_to_chunk.txt").read_text() - document = f"src/calculator.py\n{text}" - actual_chunks = await chunker.chunk_document(document) - - print(actual_chunks) - - assert len(actual_chunks) == len(expected_chunks) - actual_chunks = sorted(actual_chunks, key=lambda x: x["ref"]) - expected_chunks = sorted(expected_chunks, key=lambda x: x["ref"]) - for actual, expected in zip(actual_chunks, expected_chunks): - assert actual == expected diff --git a/tests/data/summarizer_graph.json b/tests/data/summarizer_graph.json index e054c12..75fac61 100644 --- a/tests/data/summarizer_graph.json +++ b/tests/data/summarizer_graph.json @@ -9,7 +9,7 @@ { "calls": "{}", "checksum": "0d78297d1a17a762d876be21cc8692cb", - "chunks_llm": [ + "chunks": [ { "id": "src/interface.py:BASE", "ref": "src/interface.py:1-4,15-16,19" @@ -42,7 +42,7 @@ { "calls": "{}", "checksum": "b9ba74388a4d956f0aff968bfc165db3", - "chunks_llm": [], + "chunks": [], "id": "src/__init__.py", "ref": "src/__init__.py", "summary": "Establish the 'src' as a Python package to organize related modules concerning command-line based arithmetic operations, without adding any explicit functionality.", @@ -62,7 +62,7 @@ { "calls": "{}", "checksum": "cfe1b2f9cda812d0e1f68eac86539e94", - "chunks_llm": [ + "chunks": [ { "id": "src/operations.py:BASE", "ref": "src/operations.py:1-3,6-7,10-11,14-15,18-19,22" @@ -98,7 +98,7 @@ { "calls": "{\"src/interface.py:parse_arguments\": [6], \"src/interface.py:render_response\": [19], \"src/operations.py:add\": [9], \"src/operations.py:subtract\": [11], \"src/operations.py:multiply\": [13], \"src/operations.py:divide\": [15]}", "checksum": "30a15283b0f5d5ac17a2d890a00675d9", - "chunks_llm": [ + "chunks": [ { "id": "main.py:BASE", "ref": "main.py:1-4,20-24"
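Taken together, the daemon.py and app.py changes above reduce chunker selection to a single flag. A minimal configuration sketch (the working directory and annotator values here are illustrative, mirroring the new defaults rather than code from this patch):

from pathlib import Path

from ragdaemon.daemon import Daemon

annotators = {
    "hierarchy": {},
    # use_llm=False: astroid for .py files, line-based chunking for other code files.
    # use_llm=True: astroid for .py files, LLM chunking via Spice for the rest
    # (when a Spice client is available).
    "chunker": {"use_llm": False},
    "diff": {},
}
daemon = Daemon(cwd=Path("."), annotators=annotators)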