v3.4.1
BBC-Esq authored Jan 31, 2024
1 parent 8f9edd6 commit 9c4e952
Showing 7 changed files with 387 additions and 101 deletions.
10 changes: 5 additions & 5 deletions src/choose_documents_and_vector_model.py
@@ -62,7 +62,7 @@ def choose_documents_directory():
     return
 
 def load_config():
-    with open(Path("config.yaml"), 'r') as stream:
+    with open(Path("config.yaml"), 'r', encoding='utf-8') as stream:
         return yaml.safe_load(stream)
 
 def select_embedding_model_directory():
@@ -73,14 +73,14 @@ def select_embedding_model_directory():
         config_file_path = Path("config.yaml")
         if config_file_path.exists():
             try:
-                with open(config_file_path, 'r') as file:
+                with open(config_file_path, 'r', encoding='utf-8') as file:
                     config_data = yaml.safe_load(file)
-            except Exception as e:
+            except Exception:
                 config_data = {}
 
             config_data["EMBEDDING_MODEL_NAME"] = chosen_directory
 
-            with open(config_file_path, 'w') as file:
+            with open(config_file_path, 'w', encoding='utf-8') as file:
                 yaml.dump(config_data, file)
 
-        print(f"Selected directory: {chosen_directory}")
+        print(f"Selected directory: {chosen_directory}")
87 changes: 41 additions & 46 deletions src/database_interactions.py
@@ -7,10 +7,9 @@
 from chromadb.config import Settings
 from document_processor import load_documents, split_documents
 import torch
-from utilities import validate_symbolic_links
+from utilities import validate_symbolic_links, backup_database, my_cprint
 from pathlib import Path
 import os
-from utilities import backup_database, my_cprint
 import logging
 
 logging.basicConfig(
@@ -19,35 +18,66 @@
 )
 logging.getLogger('chromadb.db.duckdb').setLevel(logging.WARNING)
 
+def load_config(root_directory):
+    with open(root_directory / "config.yaml", 'r', encoding='utf-8') as stream:
+        return yaml.safe_load(stream)
+
+def create_embeddings(embedding_model_name, config_data):
+    my_cprint("Creating embeddings.", "white")
+    compute_device = config_data['Compute_Device']['database_creation']
+
+    if "instructor" in embedding_model_name:
+        embed_instruction = config_data['embedding-models']['instructor'].get('embed_instruction')
+        query_instruction = config_data['embedding-models']['instructor'].get('query_instruction')
+
+        return HuggingFaceInstructEmbeddings(
+            model_name=embedding_model_name,
+            model_kwargs={"device": compute_device},
+            embed_instruction=embed_instruction,
+            query_instruction=query_instruction # cache_folder=, encode_kwargs=
+        )
+
+    elif "bge" in embedding_model_name:
+        query_instruction = config_data['embedding-models']['bge'].get('query_instruction')
+
+        return HuggingFaceBgeEmbeddings(
+            model_name=embedding_model_name,
+            model_kwargs={"device": compute_device},
+            query_instruction=query_instruction # encode_kwargs=, cache_folder=
+        )
+
+    else:
+        return HuggingFaceEmbeddings(
+            model_name=embedding_model_name,
+            model_kwargs={"device": compute_device} # encode_kwargs=, cache_folder=, multi_process=
+        )
+
 class CreateVectorDB:
     def __init__(self):
         self.ROOT_DIRECTORY = Path(__file__).resolve().parent
         self.SOURCE_DIRECTORY = self.ROOT_DIRECTORY / "Docs_for_DB"
         self.PERSIST_DIRECTORY = self.ROOT_DIRECTORY / "Vector_DB"
         self.INGEST_THREADS = os.cpu_count() or 8
 
         self.CHROMA_SETTINGS = Settings(
             chroma_db_impl="duckdb+parquet",
             persist_directory=str(self.PERSIST_DIRECTORY),
             anonymized_telemetry=False
         )
 
     def run(self):
-        with open(self.ROOT_DIRECTORY / "config.yaml", 'r') as stream:
-            config_data = yaml.safe_load(stream)
-
+        config_data = load_config(self.ROOT_DIRECTORY)
         EMBEDDING_MODEL_NAME = config_data.get("EMBEDDING_MODEL_NAME")
 
-        my_cprint(f"Loading documents.", "white")
-        documents = load_documents(self.SOURCE_DIRECTORY) # invoke document_processor.py; returns a list of document objects
+        my_cprint("Loading documents.", "white")
+        documents = load_documents(self.SOURCE_DIRECTORY) # returns a list of full-text document objects
         if documents is None or len(documents) == 0:
             my_cprint("No documents to load.", "red")
             return
-        my_cprint(f"Successfully loaded documents.", "white")
+        my_cprint("Successfully loaded documents.", "white")
 
-        texts = split_documents(documents) # invoke document_processor.py again; returns a list of split document objects
+        texts = split_documents(documents) # returns a list of chunked document objects
 
-        embeddings = self.get_embeddings(EMBEDDING_MODEL_NAME, config_data)
+        embeddings = create_embeddings(EMBEDDING_MODEL_NAME, config_data)
         my_cprint("Embedding model loaded.", "green")
 
         if self.PERSIST_DIRECTORY.exists():
@@ -74,41 +104,6 @@ def run(self):
         gc.collect()
         my_cprint("Embedding model removed from memory.", "red")
 
-    def get_embeddings(self, EMBEDDING_MODEL_NAME, config_data):
-        my_cprint("Creating embeddings.", "white")
-
-        compute_device = config_data['Compute_Device']['database_creation']
-
-        if "instructor" in EMBEDDING_MODEL_NAME:
-            embed_instruction = config_data['embedding-models']['instructor'].get('embed_instruction')
-            query_instruction = config_data['embedding-models']['instructor'].get('query_instruction')
-
-            return HuggingFaceInstructEmbeddings(
-                model_name=EMBEDDING_MODEL_NAME,
-                model_kwargs={"device": compute_device},
-                embed_instruction=embed_instruction,
-                query_instruction=query_instruction # cache_folder=, encode_kwargs=
-            )
-
-        elif "bge" in EMBEDDING_MODEL_NAME:
-            query_instruction = config_data['embedding-models']['bge'].get('query_instruction')
-
-            return HuggingFaceBgeEmbeddings(
-                model_name=EMBEDDING_MODEL_NAME,
-                model_kwargs={"device": compute_device},
-                query_instruction=query_instruction # encode_kwargs=, cache_folder=
-            )
-
-        else:
-            return HuggingFaceEmbeddings(
-                model_name=EMBEDDING_MODEL_NAME,
-                model_kwargs={"device": compute_device} # encode_kwargs=, cache_folder=, multi_process=
-            )
-
 if __name__ == "__main__":
     create_vector_db = CreateVectorDB()
     create_vector_db.run()
-
-# To delete entries based on the "hash" metadata attribute, you can use this as_retriever method to create a retriever that filters documents based on their metadata. Once you retrieve the documents with the specific hash, you can then extract their IDs and use the delete method to remove them from the vectorstore.
-
-# Here is how you might implement this in your CreateVectorDB class:
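The removed comment above stops short of the implementation it promises. As a rough illustration only, here is a minimal sketch of the idea it describes, assuming a LangChain Chroma vectorstore `db` (nothing in this diff shows how the store is held); it swaps the suggested `as_retriever()` approach for Chroma's `get()`/`delete()`, since `get()` returns entry IDs directly alongside the matching metadata:

# Hypothetical sketch (not part of this commit).
from langchain_community.vectorstores import Chroma  # langchain.vectorstores in older releases

def delete_entries_by_hash(db: Chroma, target_hash: str) -> None:
    # Fetch every stored entry whose "hash" metadata attribute matches.
    results = db.get(where={"hash": target_hash})
    ids_to_delete = results.get("ids", [])
    # Remove the matching entries from the vectorstore by ID.
    if ids_to_delete:
        db.delete(ids=ids_to_delete)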
3 changes: 1 addition & 2 deletions src/gui_tabs_settings.py
@@ -77,5 +77,4 @@ def __init__(self):
 
         self.layout.addLayout(center_button_layout)
         self.setLayout(self.layout)
-        adjust_stretch(self.groups, self.layout)
-
+        adjust_stretch(self.groups, self.layout)
3 changes: 1 addition & 2 deletions src/gui_tabs_tools_vision.py
@@ -94,7 +94,6 @@ def openFolderDialog(self):
         except Exception as e:
             config = {}
 
-        # Update only the 'test_image' key in the 'vision' section of the config
         vision_config = config.get('vision', {})
         vision_config['test_image'] = file_path
         config['vision'] = vision_config
@@ -110,4 +109,4 @@ def __init__(self, processing_function):
     def run(self):
         process = multiprocessing.Process(target=self.processing_function)
         process.start()
-        process.join()
+        process.join()
1 change: 1 addition & 0 deletions src/loader_images.py
@@ -123,6 +123,7 @@ def cogvlm_process_images(self):
         print(f"Total image processing time: {total_time_taken:.2f} seconds")
 
         del model
+        del tokenizer
        if torch.cuda.is_available():
             torch.cuda.empty_cache()
         gc.collect()
(2 of 7 changed files not shown)