Update for Spice v0.1.11 with @PCSwingle
granawkins committed Apr 11, 2024
2 parents 698d6dc + 5748fbe commit eb89fc6
Showing 12 changed files with 31 additions and 32 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -7,15 +7,15 @@ packages=["ragdaemon"]
 
 [project]
 name = "ragdaemon"
-version = "0.1.3"
+version = "0.1.4"
 description = "Generate and render a call graph for a Python project."
 readme = "README.md"
 dependencies = [
     "chromadb==0.4.24",
     "fastapi==0.109.2",
     "Jinja2==3.1.3",
     "networkx==3.2.1",
-    "spiceai==0.1.9",
+    "spiceai==0.1.11",
     "starlette==0.36.3",
     "tiktoken==0.6.0",
     "tqdm==4.66.2",
2 changes: 1 addition & 1 deletion ragdaemon/__init__.py
@@ -1 +1 @@
__version__ = "0.1.3"
__version__ = "0.1.4"
2 changes: 1 addition & 1 deletion ragdaemon/annotators/__init__.py
@@ -1,7 +1,7 @@
 from ragdaemon.annotators.base_annotator import Annotator # noqa: F401
 from ragdaemon.annotators.chunker import Chunker
-from ragdaemon.annotators.chunker_llm import ChunkerLLM
 from ragdaemon.annotators.chunker_line import ChunkerLine
+from ragdaemon.annotators.chunker_llm import ChunkerLLM
 from ragdaemon.annotators.diff import Diff
 from ragdaemon.annotators.hierarchy import Hierarchy
 from ragdaemon.annotators.layout_hierarchy import LayoutHierarchy
8 changes: 3 additions & 5 deletions ragdaemon/annotators/diff.py
@@ -3,15 +3,15 @@
 from pathlib import Path
 
 import networkx as nx
+from spice import Spice
 
 from ragdaemon.annotators.base_annotator import Annotator
 from ragdaemon.database import (
     DEFAULT_EMBEDDING_MODEL,
-    Database,
     MAX_TOKENS_PER_EMBEDDING,
+    Database,
 )
 from ragdaemon.errors import RagdaemonError
-from ragdaemon.llm import token_counter
 from ragdaemon.utils import get_document, hash_str, parse_path_ref
 
 
@@ -109,9 +109,7 @@ async def annotate(
 
         # If the full diff is too long to embed, it is truncated. Anything
         # removed will be captured in chunks.
-        tokens = token_counter(
-            document, model=DEFAULT_EMBEDDING_MODEL, full_message=False
-        )
+        tokens = Spice().count_tokens(document, model=DEFAULT_EMBEDDING_MODEL)
         if tokens > MAX_TOKENS_PER_EMBEDDING:
             truncate_ratio = (MAX_TOKENS_PER_EMBEDDING / tokens) * 0.99
             document = document[: int(len(document) * truncate_ratio)]
5 changes: 3 additions & 2 deletions ragdaemon/annotators/hierarchy.py
@@ -2,11 +2,12 @@
 from pathlib import Path
 
 import networkx as nx
+from spice import Spice
 
 from ragdaemon.annotators.base_annotator import Annotator
 from ragdaemon.database import MAX_TOKENS_PER_EMBEDDING, Database
 from ragdaemon.errors import RagdaemonError
-from ragdaemon.llm import token_counter
+from ragdaemon.llm import DEFAULT_COMPLETION_MODEL
 from ragdaemon.utils import get_document, get_non_gitignored_files, hash_str
 
 
@@ -56,7 +57,7 @@ def get_active_checksums(
         path_str = path.as_posix()
         ref = path_str
         document = get_document(ref, cwd)
-        tokens = token_counter(document)
+        tokens = Spice().count_tokens(document, DEFAULT_COMPLETION_MODEL)
         if tokens > MAX_TOKENS_PER_EMBEDDING: # e.g. package-lock.json
             continue
         checksum = hash_str(document)
4 changes: 2 additions & 2 deletions ragdaemon/app.py
@@ -12,9 +12,9 @@
 from spice import Spice
 from starlette.templating import Jinja2Templates
 
-from ragdaemon.llm import DEFAULT_COMPLETION_MODEL
-from ragdaemon.database import DEFAULT_EMBEDDING_MODEL
 from ragdaemon.daemon import Daemon
+from ragdaemon.database import DEFAULT_EMBEDDING_MODEL
+from ragdaemon.llm import DEFAULT_COMPLETION_MODEL
 
 # Load daemon with command line arguments and visualization annotators
 parser = argparse.ArgumentParser(description="Start the ragdaemon server.")
12 changes: 8 additions & 4 deletions ragdaemon/daemon.py
@@ -9,8 +9,8 @@
 
 from ragdaemon.annotators import Annotator, annotators_map
 from ragdaemon.context import ContextBuilder
-from ragdaemon.database import Database, DEFAULT_EMBEDDING_MODEL, get_db
-from ragdaemon.llm import DEFAULT_COMPLETION_MODEL, token_counter
+from ragdaemon.database import DEFAULT_EMBEDDING_MODEL, Database, get_db
+from ragdaemon.llm import DEFAULT_COMPLETION_MODEL
 from ragdaemon.utils import get_non_gitignored_files
 
 
@@ -128,7 +128,9 @@ def get_context(
             # TODO: Compare graph hashes, reconcile changes
             context = context_builder
         include_context_message = context.render()
-        include_tokens = token_counter(include_context_message)
+        include_tokens = self.spice_client.count_tokens(
+            include_context_message, DEFAULT_COMPLETION_MODEL
+        )
         if not auto_tokens or include_tokens >= max_tokens:
             return context
 
@@ -140,7 +142,9 @@
             else:
                 context.add_ref(node["ref"], tags=["search-result"])
             next_context_message = context.render()
-            next_tokens = token_counter(next_context_message)
+            next_tokens = self.spice_client.count_tokens(
+                next_context_message, DEFAULT_COMPLETION_MODEL
+            )
             if (next_tokens - include_tokens) > auto_tokens:
                 if node["type"] == "diff":
                     context.remove_diff(node["id"])
5 changes: 3 additions & 2 deletions ragdaemon/database/__init__.py
@@ -1,10 +1,11 @@
 import os
 from pathlib import Path
 
-from spice import Spice, SpiceError
+from spice import Spice
+from spice.errors import SpiceError
 
-from ragdaemon.database.database import Database
 from ragdaemon.database.chroma_database import ChromaDB
+from ragdaemon.database.database import Database
 from ragdaemon.database.lite_database import LiteDB
 from ragdaemon.utils import mentat_dir_path
 
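Besides import reordering, the change above tracks spiceai 0.1.11 itself: SpiceError has moved from the spice top-level module to spice.errors. A small, hypothetical illustration of the relocated import follows; how and where ragdaemon actually handles SpiceError is not part of this hunk.

# Hypothetical usage sketch; only the import locations are taken from this diff.
from spice import Spice
from spice.errors import SpiceError  # previously: from spice import Spice, SpiceError


def make_client() -> Spice:
    try:
        return Spice()
    except SpiceError:
        # Application-specific handling would go here; this sketch only
        # demonstrates the new import path.
        raise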
11 changes: 0 additions & 11 deletions ragdaemon/llm.py
@@ -1,12 +1 @@
-import tiktoken
-
-
 DEFAULT_COMPLETION_MODEL = "gpt-4-0125-preview"
-
-
-def token_counter(
-    text: str, model: str | None = None, full_message: bool = False
-) -> int:
-    if model is None:
-        model = DEFAULT_COMPLETION_MODEL
-    return len(tiktoken.encoding_for_model(model).encode(text))
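Taken together, the llm.py deletion and the call-site changes above replace ragdaemon's local tiktoken-based token_counter with spiceai's counter. The following is a minimal sketch of the resulting call pattern, assembled only from the changed lines in this commit; the sample document string is made up, and the count_tokens signature and both model constants are taken from the hunks above rather than from spiceai's documentation.

# Illustrative sketch, not a file in this commit: counting tokens with spiceai
# 0.1.11 in place of the removed ragdaemon.llm.token_counter helper.
from spice import Spice

from ragdaemon.database import DEFAULT_EMBEDDING_MODEL
from ragdaemon.llm import DEFAULT_COMPLETION_MODEL  # "gpt-4-0125-preview"

client = Spice()
document = "print('Hello, world!')\n"  # made-up stand-in for a file or git diff

# hierarchy.py / daemon.py style: count against the completion model.
completion_tokens = client.count_tokens(document, DEFAULT_COMPLETION_MODEL)

# diff.py style: count against the embedding model before deciding whether to truncate.
embedding_tokens = client.count_tokens(document, model=DEFAULT_EMBEDDING_MODEL)

print(completion_tokens, embedding_tokens)

As the new conftest.py fixture below notes, counting tokens for an OpenAI model loads the OpenAI client, so OPENAI_API_KEY must be set in the environment; the test suite sets a fake value for exactly this reason.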
1 change: 0 additions & 1 deletion ragdaemon/utils.py
@@ -6,7 +6,6 @@
 
 from ragdaemon.errors import RagdaemonError
 
-
 mentat_dir_path = Path.home() / ".mentat"
 
 
2 changes: 1 addition & 1 deletion requirements.txt
@@ -5,7 +5,7 @@ Jinja2==3.1.3
 networkx==3.2.1
 pytest==8.0.2
 pytest-asyncio==0.23.5
-spiceai==0.1.9
+spiceai==0.1.11
 starlette==0.36.3
 tiktoken==0.6.0
 tqdm==4.66.2
7 changes: 7 additions & 0 deletions tests/conftest.py
@@ -1,3 +1,4 @@
+import os
 import shutil
 import subprocess
 import tempfile
@@ -64,3 +65,9 @@ def git_history(cwd):
         f.write("print('Hello, world!')\n")
 
     yield tmpdir_path
+
+
+# We have to set the key since counting tokens with an openai model loads the openai client
+@pytest.fixture(autouse=True)
+def mock_openai_api_key():
+    os.environ["OPENAI_API_KEY"] = "fake_key"