Skip to content

Commit

Permalink
add local embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
fynnfluegge committed Sep 23, 2023
1 parent 3439c12 commit 4125e3d
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 71 deletions.
30 changes: 17 additions & 13 deletions codeqai/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@

from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.memory import ConversationSummaryMemory
from yaspin import yaspin

from codeqai import codeparser, repo
from codeqai.config import create_cache_dir, create_config, get_cache_path, load_config
from codeqai.config import (create_cache_dir, create_config, get_cache_path,
load_config)
from codeqai.constants import EmbeddingsModel
from codeqai.embeddings import Embeddings
from codeqai.vector_store import VectorStore


Expand All @@ -25,7 +27,7 @@ def run():
create_config()

# load config
config = None
config = {}
try:
config = load_config()
except FileNotFoundError:
Expand All @@ -36,20 +38,23 @@ def run():
# init cache
create_cache_dir()

embeddings_model = Embeddings(
local=True,
model=EmbeddingsModel[config["embeddings"].upper().replace("-", "_")],
)

# check if faiss.index exists
if not os.path.exists(os.path.join(get_cache_path(), f"{repo_name}.index")):
# sync repo
files = repo.load_files()
documents = codeparser.parse_code_files(files)
vector_store = VectorStore(
repo_name,
OpenAIEmbeddings(client=None, model="text-search-ada-doc-001"),
embeddings=embeddings_model.embeddings,
documents=documents,
)
else:
vector_store = VectorStore(
repo_name, OpenAIEmbeddings(client=None, model="text-search-ada-doc-001")
)
vector_store = VectorStore(repo_name, embeddings=embeddings_model.embeddings)

llm = ChatOpenAI(temperature=0.9, max_tokens=2048, model="gpt-3.5-turbo")
memory = ConversationSummaryMemory(
Expand All @@ -68,12 +73,9 @@ def run():
similarity_result = vector_store.similarity_search(search_pattern)
spinner.stop()
for doc in similarity_result:
# print(doc.metadata["file_name"])
# print(doc.metadata["method_name"])
# print(doc.page_content)
print(doc)
print(doc.page_content)

choice = input("(C)ontinue search or (E)xit [C]?").strip().lower()
choice = input("[?] (C)ontinue search or (E)xit [C]:").strip().lower()

elif args.action == "chat":
question = input("🤖 Ask me anything about the codebase: ")
Expand All @@ -84,7 +86,9 @@ def run():
print(result["answer"])

choice = (
input("(C)ontinue chat, (R)eset chat or (E)xit [C]?").strip().lower()
input("[?] (C)ontinue chat, (R)eset chat or (E)xit [C]:")
.strip()
.lower()
)

if choice == "r":
Expand Down
56 changes: 18 additions & 38 deletions codeqai/codeparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,8 @@


def parse_code_files(code_files: list[str]) -> list[Document]:
source_code_documents, docstring_documents = [], []
source_code_splitter = None
docstring_splitter = RecursiveCharacterTextSplitter(
chunk_size=1024, chunk_overlap=128
)
documents = []
code_splitter = None
for code_file in code_files:
with open(code_file, "r") as file:
file_bytes = file.read().encode()
Expand All @@ -29,9 +26,9 @@ def parse_code_files(code_files: list[str]) -> list[Document]:
langchain_language = utils.get_langchain_language(programming_language)

if langchain_language:
source_code_splitter = RecursiveCharacterTextSplitter.from_language(
code_splitter = RecursiveCharacterTextSplitter.from_language(
language=langchain_language,
chunk_size=1024,
chunk_size=512,
chunk_overlap=128,
)

Expand All @@ -42,39 +39,22 @@ def parse_code_files(code_files: list[str]) -> list[Document]:
for node in treesitterNodes:
method_source_code = node.method_source_code
filename = os.path.basename(code_file)
if programming_language == Language.PYTHON:
docstring_pattern = r"(\'\'\'(.*?)\'\'\'|\"\"\"(.*?)\"\"\")"
method_source_code = re.sub(
docstring_pattern, "", node.method_source_code, flags=re.DOTALL
)
source_code_documents.append(
Document(
page_content=method_source_code,

if node.doc_comment and programming_language != Language.PYTHON:
method_source_code = node.doc_comment + "\n" + method_source_code

splitted_documents = [method_source_code]
if code_splitter:
splitted_documents = code_splitter.split_text(method_source_code)

for splitted_document in splitted_documents:
document = Document(
page_content=splitted_document,
metadata={
"file_name": filename,
"filename": filename,
"method_name": node.name,
},
)
)
if node.doc_comment:
docstring_documents.append(
Document(
page_content=node.doc_comment,
metadata={
"file_name": filename,
"method_name": node.name,
},
)
)

splitted_source_code_documents = source_code_documents
if source_code_splitter:
splitted_source_code_documents = source_code_splitter.split_documents(
source_code_documents
)

splitted_docstring_documents = docstring_splitter.split_documents(
docstring_documents
)
documents.append(document)

return splitted_source_code_documents + splitted_docstring_documents
return documents
55 changes: 41 additions & 14 deletions codeqai/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,28 +62,55 @@ def create_config():
inquirer.Confirm(
"confirm", message="Do you want to use local models?", default=False
),
inquirer.List(
"embeddings",
message="Which embeddings do you want to use?",
choices=["USE", "BERT"],
default="USE",
),
inquirer.List(
"llm",
message="Which LLM do you want to use?",
choices=["GPT-2", "GPT-3"],
default="GPT-2",
),
]

confirm = inquirer.prompt(questions)

if confirm and confirm["confirm"]:
questions = [
inquirer.List(
"embeddings",
message="Which local embeddings model do you want to use?",
choices=[
"SentenceTransformers-all-mpnet-base-v2",
"Instructor-Large",
"Ollama",
],
default="SentenceTransformers-all-mpnet-base-v2",
),
inquirer.List(
"llm",
message="Which local LLM do you want to use?",
choices=["Llamacpp", "Ollama", "Huggingface"],
default="Llamacpp",
),
]
else:
questions = [
inquirer.List(
"embeddings",
message="Which embeddings do you want to use?",
choices=["OpenAI-text-embedding-ada-002", "Azure-OpenAI"],
default="OpenAI-text-embedding-ada-002",
),
inquirer.List(
"llm",
message="Which LLM do you want to use?",
choices=["GPT-3.5-Turbo", "GPT-4"],
default="GPT-3.5-Turbo",
),
]

answers = inquirer.prompt(questions)

if answers:
if confirm and answers:
config = {
"local": answers["confirm"],
"local": confirm["confirm"],
"embeddings": answers["embeddings"],
"llm": answers["llm"],
}
save_config(config)

return config

return {}
12 changes: 12 additions & 0 deletions codeqai/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,15 @@ class Language(Enum):
SCALA = "scala"
LUA = "lua"
UNKNOWN = "unknown"


class EmbeddingsModel(Enum):
    """Embeddings backends selectable via the config wizard.

    Values mirror the human-readable choice strings offered in config.py;
    app.py maps a stored config value back to a member with
    ``EmbeddingsModel[value.upper().replace("-", "_")]``, so member names
    must stay the upper-snake-case form of their values.
    """

    # Local models (run on the user's machine).
    SENTENCETRANSFORMERS_ALL_MPNET_BASE_V2 = "SentenceTransformers-all-mpnet-base-v2"
    INSTRUCTOR_LARGE = "Instructor-Large"
    OLLAMA = "Ollama"
    # Hosted / remote models.
    OPENAI_TEXT_EMBEDDING_ADA_002 = "OpenAI-text-embedding-ada-002"
    AZURE_OPENAI = "Azure-OpenAI"


class LocalLLMModel(Enum):
    """Local LLM backends selectable via the config wizard.

    Values mirror the choice strings offered for local LLMs in config.py
    ("Llamacpp", "Ollama", "Huggingface").
    """

    # NOTE(review): "gpt-3.5-turbo" is an OpenAI-hosted model, not a local
    # one — it does not belong in a *Local* enum; kept only for backward
    # compatibility with existing callers/configs.
    GPT_3_5_TURBO = "gpt-3.5-turbo"
    # Local backends actually offered by the config wizard (config.py).
    LLAMACPP = "Llamacpp"
    OLLAMA = "Ollama"
    HUGGINGFACE = "Huggingface"
93 changes: 91 additions & 2 deletions codeqai/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,94 @@
import inquirer
from langchain.embeddings import (HuggingFaceEmbeddings,
HuggingFaceInstructEmbeddings)
from langchain.embeddings.openai import OpenAIEmbeddings

from codeqai import utils
from codeqai.constants import EmbeddingsModel

def get_embeddings():
    """Unimplemented placeholder; performs no work and returns ``None``."""
    return None

class Embeddings:
    """Build a langchain embeddings object for the configured backend.

    Depending on ``local`` and ``model``, instantiates either a hosted
    OpenAI embeddings client or a local HuggingFace-based one, offering to
    pip-install missing optional dependencies interactively.
    """

    def __init__(
        self, local=False, model=EmbeddingsModel.OPENAI_TEXT_EMBEDDING_ADA_002
    ):
        """Create the embeddings backend.

        :param local: use a locally-running model instead of a hosted API.
        :param model: which :class:`EmbeddingsModel` backend to construct.
        """
        self.model = model

        if not local:
            if model == EmbeddingsModel.OPENAI_TEXT_EMBEDDING_ADA_002:
                # Bug fix: the OpenAI model id is hyphenated
                # ("text-embedding-ada-002"); the previous underscore form
                # "text_embedding_ada_002" is not a valid model name.
                self.embeddings = OpenAIEmbeddings(
                    client=None, model="text-embedding-ada-002"
                )
            # NOTE(review): AZURE_OPENAI falls through here and leaves
            # self.embeddings unset — later access raises AttributeError.
            # Confirm whether Azure support is intentionally deferred.
        else:
            if model == EmbeddingsModel.OLLAMA:
                # Ollama embeddings not implemented yet; self.embeddings
                # intentionally stays unset (NOTE(review): confirm).
                pass
            else:
                # Both local HuggingFace-based backends require the
                # sentence_transformers package.
                try:
                    import sentence_transformers  # noqa: F401
                except ImportError:
                    self._install_sentence_transformers()

                if model == EmbeddingsModel.SENTENCETRANSFORMERS_ALL_MPNET_BASE_V2:
                    self.embeddings = HuggingFaceEmbeddings()
                elif model == EmbeddingsModel.INSTRUCTOR_LARGE:
                    try:
                        from InstructorEmbedding import INSTRUCTOR  # noqa: F401
                    except ImportError:
                        self._install_instructor_embedding()
                    self.embeddings = HuggingFaceInstructEmbeddings()

    def _install_sentence_transformers(self):
        """Interactively pip-install ``sentence_transformers``."""
        self._install_package("sentence_transformers", "SentenceTransformers")

    def _install_instructor_embedding(self):
        """Interactively pip-install ``InstructorEmbedding``."""
        self._install_package("InstructorEmbedding", "InstructorEmbedding")

    def _install_package(self, pip_name, display_name):
        """Ask for confirmation, then pip-install *pip_name* into this env.

        Shared implementation for the two ``_install_*`` methods, which
        previously duplicated this prompt-and-install logic verbatim.

        :param pip_name: name passed to ``pip install``.
        :param display_name: human-readable name shown in the prompt.
        """
        question = [
            inquirer.Confirm(
                "confirm",
                message=f"{utils.get_bold_text(display_name)} not found in this python environment. Do you want to install it now?",
                default=True,
            ),
        ]

        answers = inquirer.prompt(question)
        if answers and answers["confirm"]:
            import subprocess
            import sys

            try:
                subprocess.run(
                    [
                        sys.executable,
                        "-m",
                        "pip",
                        "install",
                        pip_name,
                    ],
                    check=True,
                )
            except subprocess.CalledProcessError as e:
                # Bug fix: report the package that actually failed instead of
                # the copy-pasted "sentence_transformers" message that the
                # InstructorEmbedding path previously printed.
                print(f"Error during {pip_name} installation: {e}")
7 changes: 3 additions & 4 deletions codeqai/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,12 @@ def similarity_search(self, query: str):

def install_faiss(self):
try:
from faiss import FAISS_VERSION_MAJOR # noqa: F401
from faiss import FAISS_VERSION_MINOR
except: # noqa: E722
import faiss
except ImportError:
question = [
inquirer.Confirm(
"confirm",
message=f"{utils.get_bold_text('FAISS')} is not found in this python environment. Do you want to install it now?",
message=f"{utils.get_bold_text('FAISS')} not found in this python environment. Do you want to install it now?",
default=True,
),
]
Expand Down

0 comments on commit 4125e3d

Please sign in to comment.