Deterministic Chunking with Astroid
* implement python-specific chunker using astroid

* use LiteDB by default

* minor version bump

* format fixes
granawkins authored May 26, 2024
1 parent 593b805 commit 5d1e8ac
Showing 18 changed files with 509 additions and 485 deletions.
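The idea behind the change, in brief: astroid parses Python source into an AST whose nodes carry exact start and end line numbers, so function and class boundaries can be recovered deterministically instead of being estimated by an LLM. The snippet below is an illustrative sketch (not part of the diff) of the core astroid calls the new chunker builds on:

import astroid

source = '''
class Greeter:
    def greet(self, name):
        return f"hello {name}"
'''

tree = astroid.parse(source)
for node in tree.body:
    if isinstance(node, (astroid.FunctionDef, astroid.ClassDef)):
        # lineno/end_lineno give the exact span a chunker can record
        print(node.name, node.lineno, node.end_lineno)
        for child in node.body:
            if isinstance(child, (astroid.FunctionDef, astroid.ClassDef)):
                print(child.name, child.lineno, child.end_lineno)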
pyproject.toml (2 additions & 1 deletion)
@@ -7,10 +7,11 @@ packages=["ragdaemon"]
 
 [project]
 name = "ragdaemon"
-version = "0.6.2"
+version = "0.7.0"
 description = "Generate and render a call graph for a Python project."
 readme = "README.md"
 dependencies = [
+    "astroid==3.2.2",
     "chromadb==0.4.24",
     "dict2xml==1.7.5",
     "fastapi==0.109.2",
ragdaemon/__init__.py (1 addition & 1 deletion)
@@ -1 +1 @@
-__version__ = "0.6.2"
+__version__ = "0.7.0"
ragdaemon/annotators/__init__.py (0 additions & 4 deletions)
@@ -1,8 +1,6 @@
 from ragdaemon.annotators.base_annotator import Annotator  # noqa: F401
 from ragdaemon.annotators.call_graph import CallGraph  # noqa: F401
 from ragdaemon.annotators.chunker import Chunker
-from ragdaemon.annotators.chunker_line import ChunkerLine
-from ragdaemon.annotators.chunker_llm import ChunkerLLM
 from ragdaemon.annotators.diff import Diff
 from ragdaemon.annotators.hierarchy import Hierarchy
 from ragdaemon.annotators.layout_hierarchy import LayoutHierarchy
@@ -11,8 +9,6 @@
 annotators_map = {
     "call_graph": CallGraph,
     "chunker": Chunker,
-    "chunker_line": ChunkerLine,
-    "chunker_llm": ChunkerLLM,
     "diff": Diff,
     "hierarchy": Hierarchy,
     "layout_hierarchy": LayoutHierarchy,
@@ -1,26 +1,10 @@
-"""
-Chunk data a list of objects following [
-    {id: path/to/file:class.method, start_line: int, end_line: int}
-]
-It's stored on the file node as data['chunks'] and json.dumped into the database.
-A chunker annotator:
-1. Is complete when all files (with matching extensions) have a 'chunks' field
-2. Generates chunks using a subclass method (llm, ctags..)
-3. Adds that data to each file's graph node and database record
-4. Add graph nodes (and db records) for each of those chunks
-5. Add hierarchy edges connecting everything back to cwd
-The Chunker base class below handles everything except step 2.
-"""
-
 import asyncio
 import json
 from copy import deepcopy
 from functools import partial
 from pathlib import Path
 from typing import Any, Optional
 
+from astroid.exceptions import AstroidSyntaxError
 from tqdm.asyncio import tqdm
 
 from ragdaemon.annotators.base_annotator import Annotator
@@ -29,6 +13,11 @@
     remove_add_to_db_duplicates,
     remove_update_db_duplicates,
 )
+from ragdaemon.annotators.chunker.utils import resolve_chunk_parent
+from ragdaemon.annotators.chunker.chunk_astroid import chunk_document as chunk_astroid
+from ragdaemon.annotators.chunker.chunk_llm import chunk_document as chunk_llm
+from ragdaemon.annotators.chunker.chunk_line import chunk_document as chunk_line
+
 from ragdaemon.errors import RagdaemonError
 from ragdaemon.graph import KnowledgeGraph
 from ragdaemon.utils import (
@@ -40,34 +29,39 @@
 )
 
 
-def resolve_chunk_parent(id: str, nodes: set[str]) -> str | None:
-    file, chunk_str = id.split(":")
-    if chunk_str == "BASE":
-        return file
-    elif "." not in chunk_str:
-        return f"{file}:BASE"
-    else:
-        parts = chunk_str.split(".")
-        while True:
-            parent = f"{file}:{'.'.join(parts[:-1])}"
-            if parent in nodes:
-                return parent
-            parent_str = parent.split(":")[1]
-            if "." not in parent_str:
-                return None
-            # If intermediate parents are missing, skip them
-            parts = parent_str.split(".")
-
-
 class Chunker(Annotator):
     name = "chunker"
     chunk_field_id = "chunks"
 
-    def __init__(self, *args, chunk_extensions: Optional[list[str]] = None, **kwargs):
+    def __init__(self, *args, use_llm: bool = False, **kwargs):
         super().__init__(*args, **kwargs)
-        if chunk_extensions is None:
-            chunk_extensions = DEFAULT_CODE_EXTENSIONS
-        self.chunk_extensions = chunk_extensions
+
+        # By default, use either the LLM chunker or a basic line chunker.
+        if use_llm and self.spice_client is not None:
+            default_chunk_fn = partial(
+                chunk_llm, spice_client=self.spice_client, verbose=self.verbose
+            )
+        else:
+            default_chunk_fn = chunk_line
+
+        # For python files, try to use astroid. If that fails, fall back to the default chunker.
+        async def python_chunk_fn(document: str):
+            try:
+                return await chunk_astroid(document)
+            except AstroidSyntaxError:
+                if self.verbose > 0:
+                    file = document.split("\n")[0]
+                    print(
+                        f"Error chunking {file} with astroid; falling back to default chunker."
+                    )
+                return await default_chunk_fn(document)
+
+        self.chunk_extensions_map = {}
+        for extension in DEFAULT_CODE_EXTENSIONS:
+            if extension == ".py":
+                self.chunk_extensions_map[extension] = python_chunk_fn
+            else:
+                self.chunk_extensions_map[extension] = default_chunk_fn
 
     def is_complete(self, graph: KnowledgeGraph, db: Database) -> bool:
         for node, data in graph.nodes(data=True):
@@ -77,10 +71,10 @@ def is_complete(self, graph: KnowledgeGraph, db: Database) -> bool:
                 continue
             chunks = data.get(self.chunk_field_id, None)
             if chunks is None:
-                if self.chunk_extensions is None:
+                if self.chunk_extensions_map is None:
                     return False
                 extension = Path(data["ref"]).suffix
-                if extension in self.chunk_extensions:
+                if extension in self.chunk_extensions_map:
                     return False
             else:
                 if not isinstance(chunks, list):
@@ -90,15 +84,12 @@ def is_complete(self, graph: KnowledgeGraph, db: Database) -> bool:
                     return False
         return True
 
-    async def chunk_document(self, document: str) -> list[dict[str, Any]]:
-        """Return a list of {id, ref} chunks for the given document."""
-        raise NotImplementedError()
-
     async def get_file_chunk_data(self, node, data):
         """Generate and save chunk data for a file node to graph and db"""
         document = data["document"]
+        extension = Path(data["ref"]).suffix
         try:
-            chunks = await self.chunk_document(document)
+            chunks = await self.chunk_extensions_map[extension](document)
         except RagdaemonError:
             if self.verbose > 0:
                 print(f"Error chunking {node}; skipping.")
@@ -118,11 +109,11 @@ async def annotate(
             if data.get("type") == "chunk":
                 graph.remove_node(node)
             elif data.get("type") == "file":
-                if self.chunk_extensions is None:
+                if self.chunk_extensions_map is None:
                     files_with_chunks.append((node, data))
                 else:
                     extension = Path(data["ref"]).suffix
-                    if extension in self.chunk_extensions:
+                    if extension in self.chunk_extensions_map:
                         files_with_chunks.append((node, data))
 
         # Generate/add chunk data for nodes that don't have it
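For reference, a worked example of resolve_chunk_parent, which this diff moves into ragdaemon/annotators/chunker/utils.py. The behavior shown follows directly from the removed body above; the node names are hypothetical:

nodes = {"app.py:BASE", "app.py:Greeter", "app.py:Greeter.greet"}

resolve_chunk_parent("app.py:Greeter.greet", nodes)  # -> "app.py:Greeter"
resolve_chunk_parent("app.py:Greeter", nodes)        # -> "app.py:BASE"
resolve_chunk_parent("app.py:BASE", nodes)           # -> "app.py"

# Missing intermediate parents are skipped: with only "app.py:Greeter" present,
# "app.py:Greeter.Inner.method" still resolves to "app.py:Greeter".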
ragdaemon/annotators/chunker/chunk_astroid.py (35 additions & 0 deletions)
@@ -0,0 +1,35 @@
+import astroid
+
+from ragdaemon.annotators.chunker.utils import Chunk, RawChunk, resolve_raw_chunks
+from ragdaemon.errors import RagdaemonError
+
+
+async def chunk_document(document: str) -> list[Chunk]:
+    # Parse the code into an astroid AST
+    lines = document.split("\n")
+    file_path = lines[0].strip()
+    code = "\n".join(lines[1:])
+
+    tree = astroid.parse(code)
+
+    chunks = list[RawChunk]()
+
+    def extract_chunks(node, parent_path=None):
+        if isinstance(node, (astroid.FunctionDef, astroid.ClassDef)):
+            delimiter = ":" if parent_path == file_path else "."
+            current_path = f"{parent_path}{delimiter}{node.name}"
+            start_line, end_line = node.lineno, node.end_lineno
+            if start_line is None or end_line is None:
+                raise RagdaemonError(f"Function {node.name} has no line numbers.")
+            chunks.append(
+                RawChunk(id=current_path, start_line=start_line, end_line=end_line)
+            )
+            # Recursively handle nested functions
+            for child in node.body:
+                extract_chunks(child, parent_path=current_path)
+
+    # Recursively extract chunks from the AST
+    for node in tree.body:
+        extract_chunks(node, parent_path=file_path)
+
+    return resolve_raw_chunks(document, chunks)
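A usage sketch for the new astroid chunker (hypothetical file content; per the code above, the first line of a document is its file path, and resolve_raw_chunks turns the raw line spans into refs):

import asyncio

document = (
    "greeter.py\n"
    "class Greeter:\n"
    "    def greet(self, name):\n"
    "        return f'hello {name}'\n"
)
chunks = asyncio.run(chunk_document(document))
# Yields chunks with ids "greeter.py:Greeter" and "greeter.py:Greeter.greet",
# possibly plus a base chunk, depending on what resolve_raw_chunks adds.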
ragdaemon/annotators/chunker/chunk_line.py (28 additions & 0 deletions)
@@ -0,0 +1,28 @@
+async def chunk_document(
+    document: str, lines_per_chunk: int = 100
+) -> list[dict[str, str]]:
+    lines = document.split("\n")
+    file = lines[0]
+    file_lines = lines[1:]
+    if not file_lines or not any(line for line in file_lines):
+        return []
+
+    chunks = list[dict[str, str]]()
+    if len(file_lines) > lines_per_chunk:
+        chunks.append(
+            {
+                "id": f"{file}:BASE",
+                "ref": f"{file}:1-{lines_per_chunk}",
+            }
+        )  # First N lines is always the base chunk
+        for i, start_line in enumerate(
+            range(lines_per_chunk + 1, len(file_lines), lines_per_chunk)
+        ):
+            end_line = min(start_line + lines_per_chunk - 1, len(file_lines))
+            chunks.append(
+                {
+                    "id": f"{file}:chunk_{i + 1}",
+                    "ref": f"{file}:{start_line}-{end_line}",
+                }
+            )
+    return chunks
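A quick trace of the line chunker with its default lines_per_chunk=100: a 250-line file yields a BASE chunk covering lines 1-100, then fixed-size chunks for the remainder (the values follow from the arithmetic above; the file name is hypothetical):

import asyncio

document = "big.py\n" + "\n".join(f"x = {i}" for i in range(250))
chunks = asyncio.run(chunk_document(document))
# [{'id': 'big.py:BASE',    'ref': 'big.py:1-100'},
#  {'id': 'big.py:chunk_1', 'ref': 'big.py:101-200'},
#  {'id': 'big.py:chunk_2', 'ref': 'big.py:201-250'}]
# Files of 100 code lines or fewer produce no chunks at all.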