diff --git a/pyproject.toml b/pyproject.toml index beb3757..8c23ade 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ packages=["ragdaemon"] [project] name = "ragdaemon" -version = "0.1.3" +version = "0.1.4" description = "Generate and render a call graph for a Python project." readme = "README.md" dependencies = [ @@ -15,7 +15,7 @@ dependencies = [ "fastapi==0.109.2", "Jinja2==3.1.3", "networkx==3.2.1", - "spiceai==0.1.9", + "spiceai==0.1.11", "starlette==0.36.3", "tiktoken==0.6.0", "tqdm==4.66.2", diff --git a/ragdaemon/__init__.py b/ragdaemon/__init__.py index ae73625..bbab024 100644 --- a/ragdaemon/__init__.py +++ b/ragdaemon/__init__.py @@ -1 +1 @@ -__version__ = "0.1.3" +__version__ = "0.1.4" diff --git a/ragdaemon/annotators/__init__.py b/ragdaemon/annotators/__init__.py index d99c25f..6126a70 100644 --- a/ragdaemon/annotators/__init__.py +++ b/ragdaemon/annotators/__init__.py @@ -1,7 +1,7 @@ from ragdaemon.annotators.base_annotator import Annotator # noqa: F401 from ragdaemon.annotators.chunker import Chunker -from ragdaemon.annotators.chunker_llm import ChunkerLLM from ragdaemon.annotators.chunker_line import ChunkerLine +from ragdaemon.annotators.chunker_llm import ChunkerLLM from ragdaemon.annotators.diff import Diff from ragdaemon.annotators.hierarchy import Hierarchy from ragdaemon.annotators.layout_hierarchy import LayoutHierarchy diff --git a/ragdaemon/annotators/diff.py b/ragdaemon/annotators/diff.py index 72fd601..cf7d139 100644 --- a/ragdaemon/annotators/diff.py +++ b/ragdaemon/annotators/diff.py @@ -3,15 +3,15 @@ from pathlib import Path import networkx as nx +from spice import Spice from ragdaemon.annotators.base_annotator import Annotator from ragdaemon.database import ( DEFAULT_EMBEDDING_MODEL, - Database, MAX_TOKENS_PER_EMBEDDING, + Database, ) from ragdaemon.errors import RagdaemonError -from ragdaemon.llm import token_counter from ragdaemon.utils import get_document, hash_str, parse_path_ref @@ -109,9 +109,7 @@ async def annotate( # If the full diff is too long to embed, it is truncated. Anything # removed will be captured in chunks. - tokens = token_counter( - document, model=DEFAULT_EMBEDDING_MODEL, full_message=False - ) + tokens = Spice().count_tokens(document, model=DEFAULT_EMBEDDING_MODEL) if tokens > MAX_TOKENS_PER_EMBEDDING: truncate_ratio = (MAX_TOKENS_PER_EMBEDDING / tokens) * 0.99 document = document[: int(len(document) * truncate_ratio)] diff --git a/ragdaemon/annotators/hierarchy.py b/ragdaemon/annotators/hierarchy.py index 62b12f5..7aeb08d 100644 --- a/ragdaemon/annotators/hierarchy.py +++ b/ragdaemon/annotators/hierarchy.py @@ -2,11 +2,12 @@ from pathlib import Path import networkx as nx +from spice import Spice from ragdaemon.annotators.base_annotator import Annotator from ragdaemon.database import MAX_TOKENS_PER_EMBEDDING, Database from ragdaemon.errors import RagdaemonError -from ragdaemon.llm import token_counter +from ragdaemon.llm import DEFAULT_COMPLETION_MODEL from ragdaemon.utils import get_document, get_non_gitignored_files, hash_str @@ -56,7 +57,7 @@ def get_active_checksums( path_str = path.as_posix() ref = path_str document = get_document(ref, cwd) - tokens = token_counter(document) + tokens = Spice().count_tokens(document, DEFAULT_COMPLETION_MODEL) if tokens > MAX_TOKENS_PER_EMBEDDING: # e.g. package-lock.json continue checksum = hash_str(document) diff --git a/ragdaemon/app.py b/ragdaemon/app.py index 97554e0..78fcb52 100644 --- a/ragdaemon/app.py +++ b/ragdaemon/app.py @@ -12,9 +12,9 @@ from spice import Spice from starlette.templating import Jinja2Templates -from ragdaemon.llm import DEFAULT_COMPLETION_MODEL -from ragdaemon.database import DEFAULT_EMBEDDING_MODEL from ragdaemon.daemon import Daemon +from ragdaemon.database import DEFAULT_EMBEDDING_MODEL +from ragdaemon.llm import DEFAULT_COMPLETION_MODEL # Load daemon with command line arguments and visualization annotators parser = argparse.ArgumentParser(description="Start the ragdaemon server.") diff --git a/ragdaemon/daemon.py b/ragdaemon/daemon.py index f71b295..5097a9a 100644 --- a/ragdaemon/daemon.py +++ b/ragdaemon/daemon.py @@ -9,8 +9,8 @@ from ragdaemon.annotators import Annotator, annotators_map from ragdaemon.context import ContextBuilder -from ragdaemon.database import Database, DEFAULT_EMBEDDING_MODEL, get_db -from ragdaemon.llm import DEFAULT_COMPLETION_MODEL, token_counter +from ragdaemon.database import DEFAULT_EMBEDDING_MODEL, Database, get_db +from ragdaemon.llm import DEFAULT_COMPLETION_MODEL from ragdaemon.utils import get_non_gitignored_files @@ -128,7 +128,9 @@ def get_context( # TODO: Compare graph hashes, reconcile changes context = context_builder include_context_message = context.render() - include_tokens = token_counter(include_context_message) + include_tokens = self.spice_client.count_tokens( + include_context_message, DEFAULT_COMPLETION_MODEL + ) if not auto_tokens or include_tokens >= max_tokens: return context @@ -140,7 +142,9 @@ def get_context( else: context.add_ref(node["ref"], tags=["search-result"]) next_context_message = context.render() - next_tokens = token_counter(next_context_message) + next_tokens = self.spice_client.count_tokens( + next_context_message, DEFAULT_COMPLETION_MODEL + ) if (next_tokens - include_tokens) > auto_tokens: if node["type"] == "diff": context.remove_diff(node["id"]) diff --git a/ragdaemon/database/__init__.py b/ragdaemon/database/__init__.py index f706896..3246ca1 100644 --- a/ragdaemon/database/__init__.py +++ b/ragdaemon/database/__init__.py @@ -1,10 +1,11 @@ import os from pathlib import Path -from spice import Spice, SpiceError +from spice import Spice +from spice.errors import SpiceError -from ragdaemon.database.database import Database from ragdaemon.database.chroma_database import ChromaDB +from ragdaemon.database.database import Database from ragdaemon.database.lite_database import LiteDB from ragdaemon.utils import mentat_dir_path diff --git a/ragdaemon/llm.py b/ragdaemon/llm.py index 57f8e46..1e813f5 100644 --- a/ragdaemon/llm.py +++ b/ragdaemon/llm.py @@ -1,12 +1 @@ -import tiktoken - - DEFAULT_COMPLETION_MODEL = "gpt-4-0125-preview" - - -def token_counter( - text: str, model: str | None = None, full_message: bool = False -) -> int: - if model is None: - model = DEFAULT_COMPLETION_MODEL - return len(tiktoken.encoding_for_model(model).encode(text)) diff --git a/ragdaemon/utils.py b/ragdaemon/utils.py index de885f4..bda9075 100644 --- a/ragdaemon/utils.py +++ b/ragdaemon/utils.py @@ -6,7 +6,6 @@ from ragdaemon.errors import RagdaemonError - mentat_dir_path = Path.home() / ".mentat" diff --git a/requirements.txt b/requirements.txt index 67c5c98..8ff9dac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ Jinja2==3.1.3 networkx==3.2.1 pytest==8.0.2 pytest-asyncio==0.23.5 -spiceai==0.1.9 +spiceai==0.1.11 starlette==0.36.3 tiktoken==0.6.0 tqdm==4.66.2 diff --git a/tests/conftest.py b/tests/conftest.py index c9fd000..e0df8ba 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import os import shutil import subprocess import tempfile @@ -64,3 +65,9 @@ def git_history(cwd): f.write("print('Hello, world!')\n") yield tmpdir_path + + +# We have to set the key since counting tokens with an openai model loads the openai client +@pytest.fixture(autouse=True) +def mock_openai_api_key(): + os.environ["OPENAI_API_KEY"] = "fake_key"