diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 4c70a44..a97034c 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -26,7 +26,7 @@ jobs: python -m pip install --upgrade pip pip install -e . pip install -e .[dev] - + - name: Format check run: ruff format . diff --git a/pyproject.toml b/pyproject.toml index 6736684..bc55202 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,13 +7,14 @@ packages=["ragdaemon"] [project] name = "ragdaemon" -version = "0.7.8" +version = "0.8.0" description = "Generate and render a call graph for a Python project." readme = "README.md" dependencies = [ "astroid==3.2.2", "chromadb==0.4.24", "dict2xml==1.7.5", + "docker==7.1.0", "fastapi==0.109.2", "Jinja2==3.1.3", "networkx==3.2.1", @@ -41,10 +42,10 @@ ragdaemon = "ragdaemon.__main__:run" [project.optional-dependencies] dev = [ "ruff", - "pyright", + "pyright==1.1.372", "pytest", "pytest-asyncio" ] [tool.pyright] -ignore = ["tests/sample"] +ignore = ["tests/sample", "venv", ".venv"] diff --git a/ragdaemon/__init__.py b/ragdaemon/__init__.py index 894cebc..777f190 100644 --- a/ragdaemon/__init__.py +++ b/ragdaemon/__init__.py @@ -1 +1 @@ -__version__ = "0.7.8" +__version__ = "0.8.0" diff --git a/ragdaemon/annotators/base_annotator.py b/ragdaemon/annotators/base_annotator.py index ad9157d..e685947 100644 --- a/ragdaemon/annotators/base_annotator.py +++ b/ragdaemon/annotators/base_annotator.py @@ -6,6 +6,7 @@ from ragdaemon.database import Database from ragdaemon.graph import KnowledgeGraph +from ragdaemon.io import IO class Annotator: @@ -13,10 +14,12 @@ class Annotator: def __init__( self, + io: IO, verbose: int = 0, spice_client: Optional[Spice] = None, pipeline: Optional[dict[str, Annotator]] = None, ): + self.io = io self.verbose = verbose self.spice_client = spice_client pass diff --git a/ragdaemon/annotators/chunker/__init__.py b/ragdaemon/annotators/chunker/__init__.py index be76f2c..800090b 100644 --- 
a/ragdaemon/annotators/chunker/__init__.py +++ b/ragdaemon/annotators/chunker/__init__.py @@ -165,7 +165,7 @@ async def annotate( # Load chunks into graph for chunk in chunks: id, ref = chunk["id"], chunk["ref"] - document = get_document(ref, Path(graph.graph["cwd"])) + document = get_document(ref, self.io, type="chunk") checksum = hash_str(document) chunk_data = { "id": id, diff --git a/ragdaemon/annotators/diff.py b/ragdaemon/annotators/diff.py index a0c1adb..fc41025 100644 --- a/ragdaemon/annotators/diff.py +++ b/ragdaemon/annotators/diff.py @@ -1,11 +1,9 @@ import json import re from copy import deepcopy -from pathlib import Path from ragdaemon.annotators.base_annotator import Annotator from ragdaemon.database import Database, remove_add_to_db_duplicates -from ragdaemon.get_paths import get_git_root_for_path from ragdaemon.graph import KnowledgeGraph from ragdaemon.errors import RagdaemonError from ragdaemon.utils import ( @@ -74,19 +72,17 @@ def id(self) -> str: return "DEFAULT" if not self.diff_args else self.diff_args def is_complete(self, graph: KnowledgeGraph, db: Database) -> bool: - cwd = Path(graph.graph["cwd"]) - if not get_git_root_for_path(cwd, raise_error=False): + if not self.io.is_git_repo(): return True - document = get_document(self.diff_args, cwd, type="diff") + document = get_document(self.diff_args, self.io, type="diff") checksum = hash_str(document) return self.id in graph and graph.nodes[self.id]["checksum"] == checksum async def annotate( self, graph: KnowledgeGraph, db: Database, refresh: str | bool = False ) -> KnowledgeGraph: - cwd = Path(graph.graph["cwd"]) - if not get_git_root_for_path(cwd, raise_error=False): + if not self.io.is_git_repo(): return graph graph_nodes = { @@ -97,7 +93,7 @@ async def annotate( graph.remove_nodes_from(graph_nodes) checksums = dict[str, str]() - document = get_document(self.diff_args, cwd, type="diff") + document = get_document(self.diff_args, self.io, type="diff") checksum = hash_str(document) chunks = 
def files_checksum(io: IO, ignore_patterns: set[Path] = set()) -> str:
    """Return a hash of the last-modified times of every listed path.

    Args:
        io: Backend used to list paths and read their modification times.
        ignore_patterns: Patterns forwarded to ``io.get_paths_for_directory``
            as ``exclude_patterns``. The default set is never mutated here.

    Returns:
        ``hash_str`` of the concatenated modification times.
    """
    stamps: list[str] = []
    # The listing may be an unordered set of Paths. Set iteration order follows
    # string hashes, which change between interpreter runs, so the timestamps
    # are visited in sorted order to keep the checksum reproducible when it is
    # compared against a value stored by an earlier run.
    for path in sorted(io.get_paths_for_directory(exclude_patterns=ignore_patterns)):
        try:
            stamps.append(str(io.last_modified(path)))
        except FileNotFoundError:
            # Removed between listing and stat: contributes nothing.
            continue
    return hash_str("".join(stamps))
self.io.get_paths_for_directory(exclude_patterns=self.ignore_patterns) directories = set() edges = set() for path in paths: path_str = path.as_posix() - document = get_document(path_str, cwd) + document = get_document(path_str, self.io) checksum = hash_str(document) data = { "id": path_str, @@ -115,5 +114,5 @@ async def annotate( add_to_db = remove_add_to_db_duplicates(**add_to_db) db.add(**add_to_db) - graph.graph["files_checksum"] = files_checksum(cwd, self.ignore_patterns) + graph.graph["files_checksum"] = files_checksum(self.io, self.ignore_patterns) return graph diff --git a/ragdaemon/annotators/summarizer.py b/ragdaemon/annotators/summarizer.py index c9c7806..eddf402 100644 --- a/ragdaemon/annotators/summarizer.py +++ b/ragdaemon/annotators/summarizer.py @@ -12,6 +12,7 @@ from ragdaemon.database import Database, remove_update_db_duplicates from ragdaemon.graph import KnowledgeGraph from ragdaemon.errors import RagdaemonError +from ragdaemon.io import IO from ragdaemon.utils import ( DEFAULT_COMPLETION_MODEL, match_refresh, @@ -84,6 +85,7 @@ def build_filetree( def get_document_and_context( node: str, graph: KnowledgeGraph, + io: IO, summary_field_id: str = "summary", model: Optional[TextModel] = None, ) -> tuple[str, str]: @@ -98,12 +100,12 @@ def get_document_and_context( if data.get("type") == "directory": document = f"Directory: {node}" else: - cb = ContextBuilder(graph) + cb = ContextBuilder(graph, io) cb.add_id(node) document = cb.render() if data.get("type") == "chunk": - cb = ContextBuilder(graph) + cb = ContextBuilder(graph, io) # Parent chunks back to the file def get_hierarchical_parents(target: str, cb: ContextBuilder): @@ -253,7 +255,11 @@ async def generate_summary( or summary_checksum != data.get(self.checksum_field_id) ): document, context = get_document_and_context( - node, graph, summary_field_id=self.summary_field_id, model=self.model + node, + graph, + self.io, + summary_field_id=self.summary_field_id, + model=self.model, ) subprompt = 
"root" if node == "ROOT" else data.get("type", "") previous_summary = "" if _refresh else data.get(self.summary_field_id, "") diff --git a/ragdaemon/context.py b/ragdaemon/context.py index c1e4d02..a80c79f 100644 --- a/ragdaemon/context.py +++ b/ragdaemon/context.py @@ -8,6 +8,7 @@ from dict2xml import dict2xml from ragdaemon.errors import RagdaemonError from ragdaemon.graph import KnowledgeGraph +from ragdaemon.io import IO from ragdaemon.utils import get_document, parse_diff_id, parse_path_ref NestedStrDict = Union[str, Dict[str, "NestedStrDict"]] @@ -36,15 +37,16 @@ def render_comments(comments: list[Comment]) -> str: class ContextBuilder: """Renders items from a graph into an llm-readable string.""" - def __init__(self, graph: KnowledgeGraph, verbose: int = 0): + def __init__(self, graph: KnowledgeGraph, io: IO, verbose: int = 0): self.graph = graph + self.io = io self.verbose = verbose self.context = dict[ str, dict[str, Any] ]() # {path: {lines, tags, document, diff}} def copy(self): - duplicate = ContextBuilder(self.graph, self.verbose) + duplicate = ContextBuilder(self.graph, self.io, self.verbose) duplicate.context = deepcopy(self.context) return duplicate @@ -73,8 +75,7 @@ def _add_path(self, path_str: str): if document is None: # Truncated or deleted try: # TODO: Add ignored files to the graph/database - cwd = Path(self.graph.graph["cwd"]) - document = get_document(path_str, cwd, type="file") + document = get_document(path_str, self.io, type="file") except FileNotFoundError: # Or could be deleted but have a diff document = f"{path_str}\n[DELETED]" diff --git a/ragdaemon/daemon.py b/ragdaemon/daemon.py index cc51d86..49fc841 100644 --- a/ragdaemon/daemon.py +++ b/ragdaemon/daemon.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import Any, Dict, Iterable, Optional +from docker.models.containers import Container from networkx.readwrite import json_graph from spice import Spice from spice.models import Model, TextModel @@ -14,8 +15,8 @@ from 
ragdaemon.context import ContextBuilder from ragdaemon.database import DEFAULT_EMBEDDING_MODEL, Database, get_db from ragdaemon.errors import RagdaemonError -from ragdaemon.get_paths import get_paths_for_directory from ragdaemon.graph import KnowledgeGraph +from ragdaemon.io import DockerIO, IO, LocalIO from ragdaemon.locate import locate from ragdaemon.utils import DEFAULT_COMPLETION_MODEL, match_refresh, mentat_dir_path @@ -39,22 +40,23 @@ def __init__( cwd: Path, annotators: Optional[dict[str, dict]] = None, verbose: bool | int = 0, - graph_path: Optional[Path] = None, spice_client: Optional[Spice] = None, logging_dir: Optional[Path | str] = None, model: str = DEFAULT_EMBEDDING_MODEL, provider: Optional[str] = None, + container: Optional[Container] = None, ): self.cwd = cwd + if container is not None: + self.io: IO = DockerIO(cwd, container) + else: + self.io: IO = LocalIO(cwd) if isinstance(verbose, bool): verbose = 1 if verbose else 0 self.verbose = verbose - if graph_path is not None: - self.graph_path = (cwd / graph_path).resolve() - else: - self.graph_path = ( - mentat_dir_path / "ragdaemon" / f"ragdaemon-{self.cwd.name}.json" - ) + self.graph_path = ( + mentat_dir_path / "ragdaemon" / f"ragdaemon-{self.cwd.name}.json" + ) self.graph_path.parent.mkdir(parents=True, exist_ok=True) if spice_client is None: spice_client = Spice( @@ -82,6 +84,7 @@ def set_annotators(self, annotators: Optional[Dict[str, Dict]] = None): self.pipeline = {} for ann, kwargs in annotators.items(): self.pipeline[ann] = annotators_map[ann]( + io=self.io, **kwargs, verbose=self.verbose, spice_client=self.spice_client, @@ -92,7 +95,6 @@ def set_annotators(self, annotators: Optional[Dict[str, Dict]] = None): def db(self) -> Database: if not hasattr(self, "_db"): self._db = get_db( - self.cwd, spice_client=self.spice_client, embedding_model=self.embedding_model, embedding_provider=self.embedding_provider, @@ -130,13 +132,13 @@ async def update(self, refresh: str | bool = False): async def 
watch(self, interval=2, debounce=5): """Calls self.update interval debounce seconds after a file is modified.""" - paths = get_paths_for_directory(self.cwd) + paths = self.io.get_paths_for_directory() last_updated = 0 _update_task = None while True: await asyncio.sleep(interval) - paths = get_paths_for_directory(self.cwd) - _last_updated = max((self.cwd / path).stat().st_mtime for path in paths) + paths = self.io.get_paths_for_directory() + _last_updated = max(self.io.last_modified(path) for path in paths) if ( _last_updated > last_updated and (time.time() - _last_updated) > debounce @@ -171,7 +173,7 @@ def get_context( model: Model | str = DEFAULT_COMPLETION_MODEL, ) -> ContextBuilder: if context_builder is None: - context = ContextBuilder(self.graph, self.verbose) + context = ContextBuilder(self.graph, self.io, self.verbose) else: # TODO: Compare graph hashes, reconcile changes context = context_builder diff --git a/ragdaemon/database/__init__.py b/ragdaemon/database/__init__.py index 3df5006..7ef778e 100644 --- a/ragdaemon/database/__init__.py +++ b/ragdaemon/database/__init__.py @@ -1,5 +1,4 @@ import os # noqa: F401 -from pathlib import Path from typing import Optional from spice import Spice @@ -21,7 +20,6 @@ def get_db( - cwd: Path, spice_client: Spice, embedding_model: str | None = None, embedding_provider: Optional[str] = None, @@ -32,7 +30,6 @@ def get_db( # if embedding_model is not None and "PYTEST_CURRENT_TEST" not in os.environ: # try: # # db = ChromaDB( - # # cwd=cwd, # # db_path=db_path, # # spice_client=spice_client, # # embedding_model=embedding_model, @@ -42,7 +39,7 @@ def get_db( # # # In case the api key is wrong, try to embed something to trigger an error. 
# # _ = db.add(ids="test", documents="test doc") # # db.delete(ids="test") - # db = PGDB(cwd=cwd, db_path=db_path, verbose=verbose) + # db = PGDB(db_path=db_path, verbose=verbose) # return db # except Exception as e: # if verbose > 1: @@ -50,4 +47,4 @@ def get_db( # f"Failed to initialize Postgres Database: {e}. Falling back to LiteDB." # ) # pass - return LiteDB(cwd=cwd, db_path=db_path, verbose=verbose) + return LiteDB(db_path=db_path, verbose=verbose) diff --git a/ragdaemon/database/chroma_database.py b/ragdaemon/database/chroma_database.py index 0fd473a..5d70c9f 100644 --- a/ragdaemon/database/chroma_database.py +++ b/ragdaemon/database/chroma_database.py @@ -43,14 +43,12 @@ def remove_update_db_duplicates( class ChromaDB(Database): def __init__( self, - cwd: Path, db_path: Path, spice_client: Spice, embedding_model: str, embedding_provider: Optional[str] = None, verbose: int = 0, ) -> None: - self.cwd = cwd self.db_path = db_path self.embedding_model = embedding_model self.verbose = verbose diff --git a/ragdaemon/database/database.py b/ragdaemon/database/database.py index a44ecfc..6d4d306 100644 --- a/ragdaemon/database/database.py +++ b/ragdaemon/database/database.py @@ -8,7 +8,7 @@ class Database: embedding_model: str | None = None _collection = None # Collection | LiteDB - def __init__(self, cwd: Path, db_path: Path) -> None: + def __init__(self, db_path: Path) -> None: raise NotImplementedError def __getattr__(self, name): diff --git a/ragdaemon/database/lite_database.py b/ragdaemon/database/lite_database.py index dc4179e..c31c36f 100644 --- a/ragdaemon/database/lite_database.py +++ b/ragdaemon/database/lite_database.py @@ -11,8 +11,7 @@ def tokenize(document: str) -> list[str]: class LiteDB(Database): - def __init__(self, cwd: Path, db_path: Path, verbose: int = 0): - self.cwd = cwd + def __init__(self, db_path: Path, verbose: int = 0): self.db_path = db_path self.verbose = verbose self._collection = LiteCollection(self.verbose) diff --git 
a/ragdaemon/io/__init__.py b/ragdaemon/io/__init__.py new file mode 100644 index 0000000..b4bbd5c --- /dev/null +++ b/ragdaemon/io/__init__.py @@ -0,0 +1,7 @@ +from typing import Union + +from ragdaemon.io.docker_io import DockerIO +from ragdaemon.io.local_io import LocalIO + + +IO = Union[LocalIO, DockerIO] diff --git a/ragdaemon/io/docker_io.py b/ragdaemon/io/docker_io.py new file mode 100644 index 0000000..7db6d8f --- /dev/null +++ b/ragdaemon/io/docker_io.py @@ -0,0 +1,155 @@ +from contextlib import contextmanager +from pathlib import Path +from typing import Iterator, Optional, Set + +from docker.models.containers import Container + +from ragdaemon.errors import RagdaemonError +from ragdaemon.get_paths import match_path_with_patterns +from ragdaemon.io.file_like import FileLike + + +class FileInDocker(FileLike): + def __init__(self, container, path, mode): + self.container = container + self.path = path + self.mode = mode + self._content = None + + if "r" in mode: + result = self.container.exec_run(f"cat /{self.path}") + if result.exit_code != 0: + if "No such file or directory" in result.output.decode("utf-8"): + raise FileNotFoundError(f"No such file exists: {self.path}") + else: + raise IOError( + f"Failed to read file {self.path} in container: {result.stderr.decode('utf-8')}" + ) + self._content = result.output.decode("utf-8") + + def read(self, size: int = -1) -> str: + if self._content is None: + raise IOError("File not opened in read mode") + return self._content if size == -1 else self._content[:size] + + def write(self, data: str) -> int: + if "w" not in self.mode: + raise IOError("File not opened in write mode") + result = self.container.exec_run(f"sh -c 'printf \"%s\" > {self.path}'" % data) + if result.exit_code != 0: + raise IOError( + f"Failed to write file {self.path} in container: {result.stderr.decode('utf-8')}" + ) + return len(data) + + def __enter__(self) -> "FileInDocker": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> 
None: + pass + + +class DockerIO: + def __init__(self, cwd: Path | str, container: Container): + self.cwd = Path(cwd) + self.container = container + + @contextmanager + def open(self, path: Path | str, mode: str = "r") -> Iterator[FileLike]: + path = Path(path) + file_path = self.cwd / path + docker_file = FileInDocker(self.container, file_path, mode) + yield docker_file + + def get_paths_for_directory( + self, path: Optional[Path | str] = None, exclude_patterns: Set[Path] = set() + ) -> Set[Path]: + root = self.cwd if path is None else self.cwd / path + if not self.is_git_repo(path): + raise RagdaemonError( + f"Path {root} is not a git repo. Ragdaemon DockerIO only supports git repos." + ) + + def get_non_gitignored_files(root: Path) -> Set[Path]: + return set( + Path(p) + for p in filter( + lambda p: p != "", + self.container.exec_run( + ["git", "ls-files", "-c", "-o", "--exclude-standard"], + workdir=f"/{root.as_posix()}", + ) + .output.decode("utf-8") + .split("\n"), + ) + ) + + files = set[Path]() + for file in get_non_gitignored_files(root): + if exclude_patterns: + abs_path = ( + self.container.exec_run(f"realpath {file}") + .output.decode("utf-8") + .strip() + ) + if match_path_with_patterns(abs_path, exclude_patterns): + continue + try: + with self.open(file) as f: + f.read() + except FileNotFoundError: + continue # File was deleted + except UnicodeDecodeError: + continue # File is not text-encoded + files.add(file) + return files + + def is_git_repo(self, path: Optional[Path | str] = None): + root = self.cwd if path is None else self.cwd / path + args = ["git", "ls-files", "--error-unmatch"] + try: + result = self.container.exec_run(args, workdir=f"/{root.as_posix()}") + return result.exit_code == 0 + except Exception: + return False + + def last_modified(self, path: Path | str) -> float: + path = self.cwd / path + result = self.container.exec_run(f"stat -c %Y {path}") + if result.exit_code != 0: + raise FileNotFoundError(f"No such file exists: {path}") + 
return float(result.output.decode("utf-8")) + + def get_git_diff(self, diff_args: Optional[str] = None) -> str: + args = ["git", "diff", "-U1"] + if diff_args and diff_args != "DEFAULT": + args += diff_args.split(" ") + result = self.container.exec_run(args, workdir=f"/{self.cwd}") + if result.exit_code != 0: + raise IOError(f"Failed to get git diff: {result.output.decode('utf-8')}") + return result.output.decode("utf-8") + + def mkdir(self, path: Path | str, parents: bool = False, exist_ok: bool = False): + result = self.container.exec_run(f"mkdir -p {self.cwd / path}") + if result.exit_code != 0: + raise IOError( + f"Failed to make directory {self.cwd / path} in container: {result.output.decode('utf-8')}" + ) + + def unlink(self, path: Path | str): + result = self.container.exec_run(f"rm {self.cwd / path}") + if result.exit_code != 0: + raise IOError( + f"Failed to unlink {self.cwd / path} in container: {result.output.decode('utf-8')}" + ) + + def rename(self, src: Path | str, dst: Path | str): + result = self.container.exec_run(f"mv {self.cwd / src} {self.cwd / dst}") + if result.exit_code != 0: + raise IOError( + f"Failed to rename {self.cwd / src} to {self.cwd / dst} in container: {result.output.decode('utf-8')}" + ) + + def exists(self, path: Path | str) -> bool: + result = self.container.exec_run(f"test -e {self.cwd / path}") + return result.exit_code == 0 diff --git a/ragdaemon/io/file_like.py b/ragdaemon/io/file_like.py new file mode 100644 index 0000000..262838c --- /dev/null +++ b/ragdaemon/io/file_like.py @@ -0,0 +1,11 @@ +from typing import Protocol + + +class FileLike(Protocol): + def read(self) -> str: ... + + def write(self, data: str) -> int: ... + + def __enter__(self) -> "FileLike": ... + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: ... 
diff --git a/ragdaemon/io/local_io.py b/ragdaemon/io/local_io.py new file mode 100644 index 0000000..cdd5bf9 --- /dev/null +++ b/ragdaemon/io/local_io.py @@ -0,0 +1,79 @@ +import subprocess +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Iterator, Optional, Set, Union +from types import TracebackType + +from ragdaemon.get_paths import get_paths_for_directory +from ragdaemon.io.file_like import FileLike + + +class FileWrapper: + def __init__(self, file: Any): + self._file = file + + def read(self, size: int = -1) -> str: + return self._file.read(size) + + def write(self, data: str) -> int: + return self._file.write(data) + + def __enter__(self) -> "FileWrapper": + return self + + def __exit__( + self, + exc_type: Union[type, None], + exc_val: Union[BaseException, None], + exc_tb: Union[TracebackType, None], + ) -> None: + self._file.__exit__(exc_type, exc_val, exc_tb) + + +class LocalIO: + def __init__(self, cwd: Path | str): + self.cwd = Path(cwd) + + @contextmanager + def open(self, path: Path | str, mode: str = "r") -> Iterator[FileLike]: + path = Path(path) + with open(self.cwd / path, mode) as file: + yield FileWrapper(file) + + def get_paths_for_directory( + self, path: Optional[Path | str] = None, exclude_patterns: Set[Path] = set() + ): + path = self.cwd if path is None else self.cwd / path + return get_paths_for_directory(path, exclude_patterns=exclude_patterns) + + def is_git_repo(self, path: Optional[Path | str] = None): + args = ["git", "ls-files", "--error-unmatch"] + if path: + args.append(Path(path).as_posix()) + try: + output = subprocess.run(args, cwd=self.cwd) + return output.returncode == 0 + except subprocess.CalledProcessError: + return False + + def last_modified(self, path: Path | str) -> float: + return (self.cwd / path).stat().st_mtime + + def get_git_diff(self, diff_args: Optional[str] = None) -> str: + args = ["git", "diff", "-U1"] + if diff_args and diff_args != "DEFAULT": + args += 
diff_args.split(" ") + diff = subprocess.check_output(args, cwd=self.cwd, text=True) + return diff + + def mkdir(self, path: Path | str, parents: bool = False, exist_ok: bool = False): + (self.cwd / path).mkdir(parents=parents, exist_ok=exist_ok) + + def unlink(self, path: Path | str): + (self.cwd / path).unlink() + + def rename(self, src: Path | str, dst: Path | str): + (self.cwd / src).rename(self.cwd / dst) + + def exists(self, path: Path | str) -> bool: + return (self.cwd / path).exists() diff --git a/ragdaemon/utils.py b/ragdaemon/utils.py index 25960fc..767af30 100644 --- a/ragdaemon/utils.py +++ b/ragdaemon/utils.py @@ -1,7 +1,6 @@ import asyncio import hashlib import re -import subprocess from base64 import b64encode from pathlib import Path @@ -10,7 +9,7 @@ from spice.spice import get_model_from_name from ragdaemon.errors import RagdaemonError -from ragdaemon.get_paths import get_paths_for_directory +from ragdaemon.io import IO mentat_dir_path = Path.home() / ".mentat" @@ -54,14 +53,6 @@ def basic_auth(username: str, password: str): return f"Basic {token}" -def get_git_diff(diff_args: str, cwd: str) -> str: - args = ["git", "diff", "-U1"] - if diff_args and diff_args != "DEFAULT": - args += diff_args.split(" ") - diff = subprocess.check_output(args, cwd=cwd, text=True) - return diff - - def parse_lines_ref(ref: str) -> set[int] | None: lines = set() for ref in ref.split(","): @@ -96,7 +87,7 @@ def parse_diff_id(id: str) -> tuple[str, Path | None, set[int] | None]: def get_document( - ref: str, cwd: Path, type: str = "file", ignore_patterns: set[Path] = set() + ref: str, io: IO, type: str = "file", ignore_patterns: set[Path] = set() ) -> str: if type == "diff": if ":" in ref: @@ -104,7 +95,7 @@ def get_document( lines = parse_lines_ref(lines_ref) else: diff_ref, lines = ref, None - diff = get_git_diff(diff_ref, str(cwd)) + diff = io.get_git_diff(diff_ref) if lines: text = "\n".join( [line for i, line in enumerate(diff.split("\n")) if i + 1 in lines] @@ 
-114,11 +105,13 @@ def get_document( ref = f"git diff{'' if diff_ref == 'DEFAULT' else f' {diff_ref}'}" elif type == "directory": - path = cwd if ref == "ROOT" else cwd / ref + path = None if ref == "ROOT" else Path(ref) paths = sorted( [ p.as_posix() - for p in get_paths_for_directory(path, exclude_patterns=ignore_patterns) + for p in io.get_paths_for_directory( + path=path, exclude_patterns=ignore_patterns + ) ] ) text = "\n".join(paths) @@ -127,7 +120,7 @@ def get_document( path, lines = parse_path_ref(ref) if lines: text = "" - with open(cwd / path, "r") as f: + with io.open(path, "r") as f: file_lines = f.read().split("\n") if max(lines) > len(file_lines): raise RagdaemonError(f"{type} {ref} has invalid line numbers") @@ -135,7 +128,7 @@ def get_document( text += f"{file_lines[line - 1]}\n" else: try: - with open(cwd / path, "r") as f: + with io.open(path, "r") as f: text = f.read() except UnicodeDecodeError: raise RagdaemonError(f"Not a text file: {path}") diff --git a/tests/annotators/test_chunker.py b/tests/annotators/test_chunker.py index d154899..6db0267 100644 --- a/tests/annotators/test_chunker.py +++ b/tests/annotators/test_chunker.py @@ -9,8 +9,8 @@ from ragdaemon.graph import KnowledgeGraph -def test_chunker_is_complete(cwd, mock_db): - chunker = Chunker() +def test_chunker_is_complete(io, mock_db): + chunker = Chunker(io) empty_graph = KnowledgeGraph() assert chunker.is_complete(empty_graph, mock_db), "Empty graph is complete." 
@@ -68,7 +68,7 @@ def expected_chunks(): @pytest.mark.asyncio -async def test_chunker_astroid(cwd, expected_chunks): +async def test_chunker_astroid(expected_chunks): text = Path("tests/data/hard_to_chunk.txt").read_text() document = f"src/calculator.py\n{text}" actual_chunks = await chunk_astroid(document) diff --git a/tests/annotators/test_diff.py b/tests/annotators/test_diff.py index bfedfd2..2a347c1 100644 --- a/tests/annotators/test_diff.py +++ b/tests/annotators/test_diff.py @@ -7,11 +7,12 @@ from ragdaemon.context import ContextBuilder from ragdaemon.daemon import Daemon from ragdaemon.graph import KnowledgeGraph -from ragdaemon.utils import get_git_diff +from ragdaemon.io import LocalIO -def test_diff_get_chunks_from_diff(git_history): - diff = get_git_diff("HEAD", cwd=git_history) +def test_diff_get_chunks_from_diff(cwd_git_diff): + io = LocalIO(cwd_git_diff) + diff = io.get_git_diff("HEAD") actual = get_chunks_from_diff("HEAD", diff) expected = { "HEAD:main.py": "HEAD:5-28", @@ -36,10 +37,11 @@ def test_diff_parse_diff_id(): @pytest.mark.asyncio -async def test_diff_annotate(git_history, mock_db): +async def test_diff_annotate(cwd_git_diff, mock_db): graph = KnowledgeGraph.load("tests/data/hierarchy_graph.json") - graph.graph["cwd"] = git_history.as_posix() - annotator = Diff() + graph.graph["cwd"] = cwd_git_diff.as_posix() + io = LocalIO(cwd_git_diff) + annotator = Diff(io) actual = await annotator.annotate(graph, mock_db) actual_nodes = {n for n, d in actual.nodes(data=True) if d and d["type"] == "diff"} @@ -54,12 +56,12 @@ async def test_diff_annotate(git_history, mock_db): @pytest.mark.asyncio -async def test_diff_render(git_history, mock_db): - daemon = Daemon(cwd=git_history) +async def test_diff_render(cwd_git_diff, mock_db): + daemon = Daemon(cwd=cwd_git_diff) await daemon.update(refresh=True) # Only diffs - context = ContextBuilder(daemon.graph) + context = ContextBuilder(daemon.graph, daemon.io) context.add_diff("DEFAULT:main.py") 
context.add_diff("DEFAULT:src/operations.py:1-5") context.add_diff("DEFAULT:src/operations.py:8-10") diff --git a/tests/annotators/test_hierarchy.py b/tests/annotators/test_hierarchy.py index 6fee45c..4a3d904 100644 --- a/tests/annotators/test_hierarchy.py +++ b/tests/annotators/test_hierarchy.py @@ -7,10 +7,10 @@ from ragdaemon.graph import KnowledgeGraph -def test_hierarchy_is_complete(cwd, mock_db): +def test_hierarchy_is_complete(cwd, io, mock_db): empty_graph = KnowledgeGraph() empty_graph.graph["cwd"] = cwd.as_posix() - hierarchy = Hierarchy() + hierarchy = Hierarchy(io) assert not hierarchy.is_complete( empty_graph, mock_db @@ -25,10 +25,10 @@ def test_hierarchy_is_complete(cwd, mock_db): @pytest.mark.asyncio -async def test_hierarchy_annotate(cwd, mock_db): +async def test_hierarchy_annotate(cwd, io, mock_db): graph = KnowledgeGraph() graph.graph["cwd"] = cwd.as_posix() - hierarchy = Hierarchy() + hierarchy = Hierarchy(io) actual = await hierarchy.annotate(graph, mock_db) # Load the template graph diff --git a/tests/annotators/test_layout_hierarchy.py b/tests/annotators/test_layout_hierarchy.py index 437464a..707c709 100644 --- a/tests/annotators/test_layout_hierarchy.py +++ b/tests/annotators/test_layout_hierarchy.py @@ -4,8 +4,8 @@ from ragdaemon.graph import KnowledgeGraph -def test_layout_hierarchy_is_complete(cwd, mock_db): - layout_hierarchy = LayoutHierarchy() +def test_layout_hierarchy_is_complete(io, mock_db): + layout_hierarchy = LayoutHierarchy(io) empty_graph = KnowledgeGraph() assert layout_hierarchy.is_complete( @@ -35,9 +35,9 @@ def test_layout_hierarchy_is_complete(cwd, mock_db): @pytest.mark.asyncio -async def test_layout_hierarchy_annotate(cwd, mock_db): +async def test_layout_hierarchy_annotate(io, mock_db): hierarchy_graph = KnowledgeGraph.load("tests/data/hierarchy_graph.json") - actual = await LayoutHierarchy().annotate(hierarchy_graph, mock_db) + actual = await LayoutHierarchy(io).annotate(hierarchy_graph, mock_db) all_coordinates = 
set() for node, data in actual.nodes(data=True): diff --git a/tests/annotators/test_summarizer.py b/tests/annotators/test_summarizer.py index 8305917..9db4812 100644 --- a/tests/annotators/test_summarizer.py +++ b/tests/annotators/test_summarizer.py @@ -41,15 +41,15 @@ async def test_build_filetree(cwd): @pytest.mark.asyncio -async def test_get_document_and_context(cwd): +async def test_get_document_and_context(io): graph = KnowledgeGraph.load("tests/data/summarizer_graph.json") # Chunk data for _, data in graph.nodes(data=True): - document = get_document(data["ref"], cwd=cwd, type=data["type"]) + document = get_document(data["ref"], io, type=data["type"]) data["document"] = document # A chunk document, context = get_document_and_context( - "src/interface.py:parse_arguments", graph + "src/interface.py:parse_arguments", graph, io ) assert ( document @@ -103,7 +103,7 @@ async def test_get_document_and_context(cwd): ) # A file - document, context = get_document_and_context("src/interface.py", graph) + document, context = get_document_and_context("src/interface.py", graph, io) assert document.startswith("src/interface.py\n") assert ( context @@ -127,7 +127,7 @@ async def test_get_document_and_context(cwd): ) # A directory - document, context = get_document_and_context("src", graph) + document, context = get_document_and_context("src", graph, io) assert document == "Directory: src" assert ( context diff --git a/tests/conftest.py b/tests/conftest.py index d6443f8..0ebd695 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,19 @@ +import io as _io import os +import platform import shutil import subprocess +import tarfile import tempfile from pathlib import Path from unittest.mock import AsyncMock +import docker +from docker.errors import DockerException import pytest from ragdaemon.database import DEFAULT_EMBEDDING_MODEL, get_db +from ragdaemon.io import LocalIO @pytest.fixture @@ -16,14 +22,17 @@ def cwd(): @pytest.fixture -def mock_db(cwd): - return 
@pytest.fixture(scope="function")
def cwd_git_diff(cwd_git):
    """Yield the ``cwd_git`` directory with uncommitted changes applied.

    Three kinds of change are made on top of whatever ``cwd_git`` provides:
    some lines of ``src/operations.py`` are modified, ``main.py`` is removed
    and ``hello.py`` is added.
    """
    operations_path = cwd_git / "src" / "operations.py"

    # Modify: rewrite the selected (0-based) lines, stripped of surrounding
    # whitespace, with a marker comment appended.
    with open(operations_path, "r") as source:
        content = source.readlines()
    for index in (1, 2, 3, 8):
        content[index] = content[index].strip() + " #modified\n"
    with open(operations_path, "w") as target:
        target.writelines(content)

    # Remove
    (cwd_git / "main.py").unlink()

    # Add
    with open(cwd_git / "hello.py", "w") as added:
        added.write("print('Hello, world!')\n")

    yield cwd_git
@pytest.fixture
def container(cwd, docker_client, path="tests/sample"):
    """Yield a running ``python:3.10`` container holding a copy of ``cwd``.

    The contents of ``cwd`` are copied into ``path`` inside the container and
    a git repository with a dummy identity is initialised there. The container
    is stopped and removed once the test is done, and also when any setup step
    after its creation raises.

    Args:
        cwd: Local directory whose contents are copied into the container.
        docker_client: Docker client used to start the container.
        path: Directory inside the container that receives the copy.
    """
    image = "python:3.10"
    container = fail_silently_on_macos_and_windows(
        docker_client.containers.run, image, detach=True, tty=True, command="sh"
    )

    # The container exists from here on, so every remaining setup step runs
    # inside the try block: a failing copy or git command must not leave a
    # running container behind.
    try:
        # Create the target directory in the container.
        # NOTE(review): ``path`` is relative here but absolute (``/{path}``)
        # for the git commands below - assumes the container's default
        # working directory is ``/``; confirm for other images.
        container.exec_run(f"mkdir -p {path}")

        # Copy everything in cwd into the docker container at the same location
        tarstream = _io.BytesIO()
        with tarfile.open(fileobj=tarstream, mode="w") as tar:
            tar.add(cwd, arcname=".")
        tarstream.seek(0)
        container.put_archive(path, tarstream)

        workdir = f"/{path}"
        container.exec_run("git init", workdir=workdir)
        container.exec_run("git config user.email you@example.com", workdir=workdir)
        container.exec_run("git config user.name 'Your Name'", workdir=workdir)

        yield container
    finally:
        container.stop()
        container.remove()
context.add_comment( "src/operations.py", {"comment": "What is this file for?"}, tags=["test-flag"] diff --git a/tests/test_context.py b/tests/test_context.py index 89003fc..39af1c2 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -8,17 +8,17 @@ from ragdaemon.utils import get_document -def test_daemon_render_context(cwd): +def test_daemon_render_context(io): path_str = Path("src/interface.py").as_posix() ref = path_str # Base Chunk - context = ContextBuilder(KnowledgeGraph()) + context = ContextBuilder(KnowledgeGraph(), io) context.context = { path_str: { "lines": set([1, 2, 3, 4, 15]), "tags": ["test-flag"], - "document": get_document(ref, cwd), + "document": get_document(ref, io), "diffs": set(), "comments": dict(), } @@ -43,7 +43,7 @@ def test_daemon_render_context(cwd): path_str: { "lines": set([5, 6, 7, 8, 9, 10, 11, 12, 13, 14]), "tags": ["test-flag"], - "document": get_document(ref, cwd), + "document": get_document(ref, io), "diffs": set(), "comments": dict(), } @@ -94,17 +94,17 @@ async def test_context_builder_methods(cwd, mock_db): assert combined_context.context["src/operations.py"]["lines"] == set(range(1, 23)) -def test_to_refs(cwd, mock_db): +def test_to_refs(io, mock_db): path_str = Path("src/interface.py").as_posix() ref = path_str # Setup Context - context = ContextBuilder(KnowledgeGraph()) + context = ContextBuilder(KnowledgeGraph(), io) context.context = { path_str: { "lines": set([1, 2, 3, 4, 15]), "tags": ["test-flag"], - "document": get_document(ref, cwd), + "document": get_document(ref, io), "diffs": set(), "comments": dict(), } diff --git a/tests/test_database.py b/tests/test_database.py index 1d107f5..afb6c63 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -3,6 +3,6 @@ from ragdaemon.database import DEFAULT_EMBEDDING_MODEL, LiteDB, get_db -def test_mock_database(cwd): - db = get_db(cwd, AsyncMock(), embedding_model=DEFAULT_EMBEDDING_MODEL) +def test_mock_database(): + db = get_db(AsyncMock(), 
embedding_model=DEFAULT_EMBEDDING_MODEL) assert isinstance(db, LiteDB) diff --git a/tests/test_get_paths.py b/tests/test_get_paths.py index 51b3eea..c5493cd 100644 --- a/tests/test_get_paths.py +++ b/tests/test_get_paths.py @@ -46,16 +46,16 @@ def add_permissions(func, path, exc_info): raise -def test_get_paths_for_directory_without_git(git_history): - # Using the 'git_history' fixture because it sets up a tempdir. - git_history = git_history.resolve() - git_dir = git_history / ".git" +def test_get_paths_for_directory_without_git(cwd_git_diff): + # Using the 'cwd_git_diff' fixture because it sets up a tempdir. + cwd_git_diff = cwd_git_diff.resolve() + git_dir = cwd_git_diff / ".git" shutil.rmtree(git_dir, onerror=add_permissions) - is_git_repo = get_git_root_for_path(git_history, raise_error=False) + is_git_repo = get_git_root_for_path(cwd_git_diff, raise_error=False) assert not is_git_repo, "Is a git repository" - paths = get_paths_for_directory(git_history) + paths = get_paths_for_directory(cwd_git_diff) assert paths == { Path(".gitignore"), Path("README.md"), diff --git a/tests/test_io.py b/tests/test_io.py new file mode 100644 index 0000000..6e04354 --- /dev/null +++ b/tests/test_io.py @@ -0,0 +1,129 @@ +from pathlib import Path + +import pytest + +from ragdaemon.daemon import Daemon +from ragdaemon.io import DockerIO, IO, LocalIO + + +def all_io_methods(io: IO): + text = "Hello, world!" 
def get_message_chunk_set(message):
    """Split a rendered message into an order-independent set of chunks.

    Chunks are separated by a blank line. Splitting on ``"\\n\\n"`` strips the
    newline that ended every chunk but the last, so it is added back to those
    chunks; this way the same chunk compares equal regardless of whether it
    was rendered last or not.

    Args:
        message: The rendered message text.

    Returns:
        The set of chunk strings, so two messages that contain the same
        chunks in a different order compare equal.
    """
    # str.split always yields at least one element, so ``last`` is always bound.
    *leading, last = message.split("\n\n")
    return {f"{chunk}\n" for chunk in leading} | {last}
+""" + ) + + +""" +'diff --git a/main.py b/main.py +deleted file mode 100644 +index fcabfbe..0000000 +--- a/main.py ++++ /dev/null +@@ -1,23 +0,0 @@ +-from src.interface import parse_arguments, render_response +-from src.operations import add, divide, multiply, subtract +- +- +-def main(): +- a, op, b = parse_arguments() +- +- if op == "+": +- result = add(a, b) +- elif op == "-": +- result = subtract(a, b) +- elif op == "*": +- result = multiply(a, b) +- elif op == "/": +- result = divide(a, b) +- else: +- raise ValueError("Unsupported operation") +- +- render_response(result) +- +- +-if __name__ == "__main__": +- main() +diff --git a/src/operations.py b/src/operations.py +index 9f1facd..073af81 100644 +--- a/src/operations.py ++++ b/src/operations.py +@@ -1,5 +1,5 @@ + import math +- +- +-def add(a, b): ++ #modified ++ #modified ++def add(a, b): #modified + return a + b +@@ -8,3 +8,3 @@ def add(a, b): + def subtract(a, b): +- return a - b\n+return a - b #modified\n \n' +"""