From 02d9d68eb1b8f63143759a2410710cacce79dff1 Mon Sep 17 00:00:00 2001 From: fynnfluegge Date: Sat, 13 Jan 2024 12:09:08 +0100 Subject: [PATCH] fix: VectorCache from Json parse error in sync command --- codeqai/app.py | 3 +++ codeqai/cache.py | 10 +++++++--- tests/vector_store_test.py | 7 ------- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/codeqai/app.py b/codeqai/app.py index c53dfaa..8fd3804 100644 --- a/codeqai/app.py +++ b/codeqai/app.py @@ -140,6 +140,7 @@ def run(): vector_store.sync_documents(documents) save_vector_cache(vector_store.vector_cache, f"{repo_name}.json") spinner.stop() + print("⚙️ Vector store synced with current git checkout.") llm = LLM( llm_host=LlmHost[config["llm-host"].upper().replace("-", "_")], @@ -156,6 +157,8 @@ def run(): console = Console() while True: choice = None + if args.action == "sync": + break if args.action == "search": search_pattern = input("🔎 Enter a search pattern: ") spinner = yaspin(text="🤖 Processing...", color="green") diff --git a/codeqai/cache.py b/codeqai/cache.py index 1881933..4f25d2d 100644 --- a/codeqai/cache.py +++ b/codeqai/cache.py @@ -2,6 +2,7 @@ import os import platform from pathlib import Path +from typing import Dict class VectorCache: @@ -11,7 +12,7 @@ def __init__(self, filename, vector_ids, commit_hash): self.commit_hash = commit_hash @classmethod - def from_json(cls, json_data): + def from_json(cls, json_data) -> "VectorCache": filename = json_data.get("filename") vector_ids = json_data.get("vector_ids", []) commit_hash = json_data.get("commit_hash", "") @@ -25,9 +26,12 @@ def to_json(self): } -def load_vector_cache(filename): +def load_vector_cache(filename) -> Dict[str, VectorCache]: with open(get_cache_path() + "/" + filename, "r") as vector_cache_file: - vector_cache = json.load(vector_cache_file, object_hook=VectorCache.from_json) + vector_cache_json = json.load(vector_cache_file) + vector_cache = {} + for key, value in vector_cache_json.items(): + vector_cache[key] = VectorCache.from_json(value) return vector_cache diff --git a/tests/vector_store_test.py b/tests/vector_store_test.py index 32a6efc..5fcf571 100644 --- a/tests/vector_store_test.py +++ b/tests/vector_store_test.py @@ -37,19 +37,12 @@ def test_sync_documents(modified_vector_entries, vector_entries, vector_cache): assert len(vector_store.db.index_to_docstore_id) == 4 vector_store.vector_cache = vector_cache - for id in vector_store.db.index_to_docstore_id.values(): - print(vector_store.db.docstore.search(id)) - for vector_id in vector_store.db.index_to_docstore_id.values(): vector_store.vector_cache[ vector_store.db.docstore.search(vector_id).metadata["filename"] ].vector_ids.append(vector_id) vector_store.sync_documents(modified_vector_entries) - print("After sync") - for id in vector_store.db.index_to_docstore_id.values(): - print(vector_store.db.docstore.search(id)) - assert len(vector_store.db.index_to_docstore_id) == 5 for vector_id in vector_store.db.index_to_docstore_id.values():