Skip to content

Commit

Permalink
v3.0.1
Browse files Browse the repository at this point in the history
Fixed a bug when searching by document type.

Consolidated three vision model scripts into the new loader_images.py script.
  • Loading branch information
BBC-Esq authored Jan 24, 2024
1 parent 87ad685 commit 2d5e752
Show file tree
Hide file tree
Showing 11 changed files with 412 additions and 61 deletions.
2 changes: 0 additions & 2 deletions src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,6 @@
".html": "UnstructuredHTMLLoader",
}

WHISPER_MODEL_NAMES = ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large-v2"]

CHUNKS_ONLY_TOOLTIP = "Only return relevant chunks without connecting to the LLM. Extremely useful to test the chunk size/overlap settings."

SPEAK_RESPONSE_TOOLTIP = "Only click this after the LLM's entire response is received otherwise your computer might explode."
Expand Down
29 changes: 29 additions & 0 deletions src/database_interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,32 @@ def get_embeddings(self, EMBEDDING_MODEL_NAME, config_data):
if __name__ == "__main__":
create_vector_db = CreateVectorDB()
create_vector_db.run()

# To delete entries based on the "hash" metadata attribute, use the as_retriever method to create a retriever that filters documents by their metadata. Once you have retrieved the documents matching a specific hash, extract their IDs and pass them to the delete method to remove them from the vectorstore.

# Here is how you might implement this in your CreateVectorDB class:

# python

# class CreateVectorDB:
# # ... [other methods] ...

# def delete_entries_by_hash(self, target_hash):
# my_cprint(f"Deleting entries with hash: {target_hash}", "red")

# # Initialize the retriever with a filter for the specific hash
# retriever = self.db.as_retriever(search_kwargs={'filter': {'hash': target_hash}})

# # Retrieve documents with the specific hash
# documents = retriever.search("")

# # Extract IDs from the documents
# ids_to_delete = [doc.id for doc in documents]

# # Delete entries with the extracted IDs
# if ids_to_delete:
# self.db.delete(ids=ids_to_delete)
# my_cprint(f"Deleted {len(ids_to_delete)} entries from the database.", "green")
# else:
# my_cprint("No entries found with the specified hash.", "yellow")

37 changes: 13 additions & 24 deletions src/document_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@
)

from constants import DOCUMENT_LOADERS
from loader_vision_llava import llava_process_images
from loader_vision_cogvlm import cogvlm_process_images
from loader_salesforce import salesforce_process_images
from loader_images import loader_cogvlm, loader_llava, loader_salesforce
from extract_metadata import extract_document_metadata
from utilities import my_cprint

Expand All @@ -34,15 +32,18 @@
for ext, loader_name in DOCUMENT_LOADERS.items():
DOCUMENT_LOADERS[ext] = globals()[loader_name]

def choose_image_loader(config):
    """Dispatch image processing to the vision model named in the config.

    Reads config["vision"]["chosen_model"] and runs the matching loader's
    image-processing routine. Both 'llava' and 'bakllava' are served by the
    llava loader. An unrecognized model name yields an empty list so the
    caller can safely extend its document list either way.
    """
    chosen_model = config["vision"]["chosen_model"]

    if chosen_model in ("llava", "bakllava"):
        return loader_llava().llava_process_images()
    if chosen_model == "cogvlm":
        return loader_cogvlm().cogvlm_process_images()
    if chosen_model == "salesforce":
        return loader_salesforce().salesforce_process_images()
    # Unknown model: contribute no image documents.
    return []

Expand Down Expand Up @@ -76,19 +77,16 @@ def load_single_document(file_path: Path) -> Document:

document = loader.load()[0]

metadata = extract_document_metadata(file_path) # get metadata
metadata = extract_document_metadata(file_path)
document.metadata.update(metadata)

# with open("output_load_single_document.txt", "w", encoding="utf-8") as output_file:
# output_file.write(document.page_content)

return document

def load_document_batch(filepaths):
    """Load a batch of documents concurrently, one worker thread per file.

    Each path in `filepaths` is passed to load_single_document in its own
    thread; any exception raised there propagates via future.result().

    Returns a tuple (data_list, filepaths) where data_list holds the
    Document objects in the same order as the input paths.
    """
    # ThreadPoolExecutor raises ValueError when max_workers is 0, so an
    # empty batch must be short-circuited before creating the pool.
    if not filepaths:
        return ([], filepaths)
    with ThreadPoolExecutor(len(filepaths)) as exe:
        futures = [exe.submit(load_single_document, name) for name in filepaths]
        data_list = [future.result() for future in futures]
        return (data_list, filepaths)

def load_documents(source_dir: Path) -> list[Document]:
all_files = list(source_dir.iterdir())
Expand Down Expand Up @@ -118,9 +116,9 @@ def load_documents(source_dir: Path) -> list[Document]:
with open("config.yaml", "r") as config_file:
config = yaml.safe_load(config_file)

# Use ProcessPoolExecutor to process images
# ProcessPoolExecutor to process images
with ProcessPoolExecutor(1) as executor:
future = executor.submit(process_images_wrapper, config)
future = executor.submit(choose_image_loader, config)
processed_docs = future.result()
additional_docs = processed_docs if processed_docs is not None else []

Expand All @@ -137,10 +135,6 @@ def split_documents(documents):

text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
texts = text_splitter.split_documents(documents)

# Add 'text' attribute to metadata of each split document
#for document in texts:
#document.metadata["text"] = document.page_content

my_cprint(f"Number of Chunks: {len(texts)}", "white")

Expand All @@ -156,9 +150,4 @@ def split_documents(documents):
count = sum(lower_bound <= size <= upper_bound for size in chunk_sizes)
my_cprint(f"Chunks between {lower_bound} and {upper_bound} characters: {count}", "white")

return texts

'''
# document object structure: Document(page_content="[ALL TEXT EXTRACTED]", metadata={'source': '[FULL FILE PATH WITH DOUBLE BACKSLASHES'})
# list structure: [Document(page_content="...", metadata={'source': '...'}), Document(page_content="...", metadata={'source': '...'})]
'''
return texts
19 changes: 15 additions & 4 deletions src/extract_metadata.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
import os
import datetime
import hashlib

def extract_image_metadata(file_path, file_name):
def compute_file_hash(file_path):
    """Return the SHA-256 hex digest of the file at `file_path`.

    The file is read in 4 KiB chunks so arbitrarily large files can be
    hashed without loading them fully into memory.
    """
    digest = hashlib.sha256()
    with open(file_path, "rb") as fh:
        while chunk := fh.read(4096):
            digest.update(chunk)
    return digest.hexdigest()

def extract_image_metadata(file_path, file_name):
file_type = os.path.splitext(file_name)[1]
file_size = os.path.getsize(file_path)
creation_date = datetime.datetime.fromtimestamp(os.path.getctime(file_path)).isoformat()
modification_date = datetime.datetime.fromtimestamp(os.path.getmtime(file_path)).isoformat()
file_hash = compute_file_hash(file_path)

return {
"file_path": file_path,
Expand All @@ -15,14 +23,16 @@ def extract_image_metadata(file_path, file_name):
"file_size": file_size,
"creation_date": creation_date,
"modification_date": modification_date,
"image": "True"
"document_type": "image",
"hash": file_hash
}

def extract_document_metadata(file_path):
file_type = os.path.splitext(file_path)[1]
file_size = os.path.getsize(file_path)
creation_date = datetime.datetime.fromtimestamp(os.path.getctime(file_path)).isoformat()
modification_date = datetime.datetime.fromtimestamp(os.path.getmtime(file_path)).isoformat()
file_hash = compute_file_hash(file_path)

return {
"file_path": str(file_path),
Expand All @@ -31,5 +41,6 @@ def extract_document_metadata(file_path):
"file_size": file_size,
"creation_date": creation_date,
"modification_date": modification_date,
"image": "False"
}
"document_type": "document",
"hash": file_hash
}
23 changes: 17 additions & 6 deletions src/gui_tabs_settings_database_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def __init__(self):
self.database_creation_device = config_data['Compute_Device']['database_creation']
self.database_query_device = config_data['Compute_Device']['database_query']
self.search_term = config_data['database'].get('search_term', '')
self.document_type = config_data['database'].get('document_types', '')

v_layout = QVBoxLayout()
h_layout_device = QHBoxLayout()
Expand Down Expand Up @@ -64,7 +65,17 @@ def __init__(self):
h_layout_search_term.addWidget(self.filter_button)

self.file_type_combo = QComboBox()
self.file_type_combo.addItems(["All Files", "Images Only", "Non-Images Only"])
file_type_items = ["All Files", "Images Only", "Documents Only"]
self.file_type_combo.addItems(file_type_items)

if self.document_type == 'image':
default_index = file_type_items.index("Images Only")
elif self.document_type == 'document':
default_index = file_type_items.index("Documents Only")
else:
default_index = file_type_items.index("All Files")
self.file_type_combo.setCurrentIndex(default_index)

h_layout_search_term.addWidget(self.file_type_combo)

v_layout.addLayout(h_layout_search_term)
Expand Down Expand Up @@ -106,16 +117,16 @@ def update_config(self):

file_type_map = {
"All Files": '',
"Images Only": True,
"Non-Images Only": False
"Images Only": 'image',
"Documents Only": 'document'
}

file_type_selection = self.file_type_combo.currentText()
images_only_value = file_type_map[file_type_selection]
document_type_value = file_type_map[file_type_selection]

if images_only_value != config_data['database'].get('images_only', ''):
if document_type_value != config_data['database'].get('document_types', ''):
settings_changed = True
config_data['database']['images_only'] = images_only_value
config_data['database']['document_types'] = document_type_value

if settings_changed:
with open('config.yaml', 'w') as f:
Expand Down
11 changes: 1 addition & 10 deletions src/gui_tabs_settings_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def create_layout(self):
model_label = QLabel("Model")
layout.addWidget(model_label, 0, 0)
self.model_combo = QComboBox()
self.model_combo.addItems(["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large-v2"])
self.model_combo.addItems(["whisper-tiny.en", "whisper-base.en", "whisper-small.en", "whisper-medium.en", "whisper-large-v2"])
layout.addWidget(self.model_combo, 0, 1)

# Quantization
Expand Down Expand Up @@ -116,12 +116,3 @@ def update_config(self):
yaml.dump(config_data, f)

return settings_changed

if __name__ == "__main__":
from PySide6.QtWidgets import QApplication
import sys

app = QApplication(sys.argv)
transcriber_settings_tab = TranscriberSettingsTab()
transcriber_settings_tab.show()
sys.exit(app.exec())
6 changes: 2 additions & 4 deletions src/gui_tabs_tools_transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from PySide6.QtCore import Qt
import yaml
from pathlib import Path
from constants import WHISPER_MODEL_NAMES
from transcribe_module import TranscribeFile
import threading

Expand Down Expand Up @@ -38,7 +37,7 @@ def create_layout(self):
hbox1 = QHBoxLayout()
hbox1.addWidget(QLabel("Model"))
self.model_combo = QComboBox()
self.model_combo.addItems([model for model in WHISPER_MODEL_NAMES if model not in ["tiny", "tiny.en", "base", "base.en"]])
self.model_combo.addItems(["whisper-small.en", "whisper-medium.en", "whisper-large-v2"])
self.model_combo.setCurrentText(self.default_model)
self.model_combo.currentTextChanged.connect(self.update_model_in_config)
hbox1.addWidget(self.model_combo)
Expand Down Expand Up @@ -73,7 +72,7 @@ def create_layout(self):

main_layout.addLayout(hbox2)

# Third row of widgets (Select Audio File and Transcribe buttons)
# Third row of widgets
hbox3 = QHBoxLayout()
self.select_file_button = QPushButton("Select Audio File")
self.select_file_button.clicked.connect(self.select_audio_file)
Expand All @@ -85,7 +84,6 @@ def create_layout(self):

main_layout.addLayout(hbox3)

# Label for displaying the selected file path
self.file_path_label = QLabel("No file currently selected")
main_layout.addWidget(self.file_path_label)

Expand Down
Loading

0 comments on commit 2d5e752

Please sign in to comment.