From 2d5e752e37baa3f74ef4451839a6bb7295d39bdc Mon Sep 17 00:00:00 2001 From: BBC-Esq Date: Wed, 24 Jan 2024 14:35:22 -0500 Subject: [PATCH] v3.0.1 Fixed a bug when searching by document type. Consolidated three vision model scripts into the new loader_images.py script. --- src/constants.py | 2 - src/database_interactions.py | 29 +++ src/document_processor.py | 37 +-- src/extract_metadata.py | 19 +- src/gui_tabs_settings_database_query.py | 23 +- src/gui_tabs_settings_whisper.py | 11 +- src/gui_tabs_tools_transcribe.py | 6 +- src/loader_images.py | 322 ++++++++++++++++++++++++ src/server_connector.py | 18 +- src/transcribe_module.py | 4 +- src/voice_recorder_module.py | 2 +- 11 files changed, 412 insertions(+), 61 deletions(-) create mode 100644 src/loader_images.py diff --git a/src/constants.py b/src/constants.py index 10975faa..9d9dc6ce 100644 --- a/src/constants.py +++ b/src/constants.py @@ -307,8 +307,6 @@ ".html": "UnstructuredHTMLLoader", } -WHISPER_MODEL_NAMES = ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large-v2"] - CHUNKS_ONLY_TOOLTIP = "Only return relevant chunks without connecting to the LLM. Extremely useful to test the chunk size/overlap settings." SPEAK_RESPONSE_TOOLTIP = "Only click this after the LLM's entire response is received otherwise your computer might explode." diff --git a/src/database_interactions.py b/src/database_interactions.py index 775294f5..dba6cd2e 100644 --- a/src/database_interactions.py +++ b/src/database_interactions.py @@ -108,3 +108,32 @@ def get_embeddings(self, EMBEDDING_MODEL_NAME, config_data): if __name__ == "__main__": create_vector_db = CreateVectorDB() create_vector_db.run() + +# To delete entries based on the "hash" metadata attribute, you can use this as_retriever method to create a retriever that filters documents based on their metadata. Once you retrieve the documents with the specific hash, you can then extract their IDs and use the delete method to remove them from the vectorstore. + +# Here is how you might implement this in your CreateVectorDB class: + +# python + +# class CreateVectorDB: + # # ... [other methods] ... + + # def delete_entries_by_hash(self, target_hash): + # my_cprint(f"Deleting entries with hash: {target_hash}", "red") + + # # Initialize the retriever with a filter for the specific hash + # retriever = self.db.as_retriever(search_kwargs={'filter': {'hash': target_hash}}) + + # # Retrieve documents with the specific hash + # documents = retriever.search("") + + # # Extract IDs from the documents + # ids_to_delete = [doc.id for doc in documents] + + # # Delete entries with the extracted IDs + # if ids_to_delete: + # self.db.delete(ids=ids_to_delete) + # my_cprint(f"Deleted {len(ids_to_delete)} entries from the database.", "green") + # else: + # my_cprint("No entries found with the specified hash.", "yellow") + diff --git a/src/document_processor.py b/src/document_processor.py index 524e8d93..e420dc4d 100644 --- a/src/document_processor.py +++ b/src/document_processor.py @@ -21,9 +21,7 @@ ) from constants import DOCUMENT_LOADERS -from loader_vision_llava import llava_process_images -from loader_vision_cogvlm import cogvlm_process_images -from loader_salesforce import salesforce_process_images +from loader_images import loader_cogvlm, loader_llava, loader_salesforce from extract_metadata import extract_document_metadata from utilities import my_cprint @@ -34,15 +32,18 @@ for ext, loader_name in DOCUMENT_LOADERS.items(): DOCUMENT_LOADERS[ext] = globals()[loader_name] -def process_images_wrapper(config): +def choose_image_loader(config): chosen_model = config["vision"]["chosen_model"] if chosen_model == 'llava' or chosen_model == 'bakllava': - return llava_process_images() + image_loader = loader_llava() + return image_loader.llava_process_images() elif chosen_model == 'cogvlm': - return cogvlm_process_images() + image_loader = loader_cogvlm() + return image_loader.cogvlm_process_images() elif chosen_model == 'salesforce': - return salesforce_process_images() + image_loader = loader_salesforce() + return image_loader.salesforce_process_images() else: return [] @@ -76,11 +77,8 @@ def load_single_document(file_path: Path) -> Document: document = loader.load()[0] - metadata = extract_document_metadata(file_path) # get metadata + metadata = extract_document_metadata(file_path) document.metadata.update(metadata) - - # with open("output_load_single_document.txt", "w", encoding="utf-8") as output_file: - # output_file.write(document.page_content) return document @@ -88,7 +86,7 @@ def load_document_batch(filepaths): with ThreadPoolExecutor(len(filepaths)) as exe: futures = [exe.submit(load_single_document, name) for name in filepaths] data_list = [future.result() for future in futures] - return (data_list, filepaths) # "data_list" = list of all document objects created by load single document + return (data_list, filepaths) def load_documents(source_dir: Path) -> list[Document]: all_files = list(source_dir.iterdir()) @@ -118,9 +116,9 @@ def load_documents(source_dir: Path) -> list[Document]: with open("config.yaml", "r") as config_file: config = yaml.safe_load(config_file) - # Use ProcessPoolExecutor to process images + # ProcessPoolExecutor to process images with ProcessPoolExecutor(1) as executor: - future = executor.submit(process_images_wrapper, config) + future = executor.submit(choose_image_loader, config) processed_docs = future.result() additional_docs = processed_docs if processed_docs is not None else [] @@ -137,10 +135,6 @@ def split_documents(documents): text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap) texts = text_splitter.split_documents(documents) - - # Add 'text' attribute to metadata of each split document - #for document in texts: - #document.metadata["text"] = document.page_content my_cprint(f"Number of Chunks: {len(texts)}", "white") @@ -156,9 +150,4 @@ def split_documents(documents): count = sum(lower_bound <= size <= upper_bound for size in chunk_sizes) my_cprint(f"Chunks between {lower_bound} and {upper_bound} characters: {count}", "white") - return texts - -''' -# document object structure: Document(page_content="[ALL TEXT EXTRACTED]", metadata={'source': '[FULL FILE PATH WITH DOUBLE BACKSLASHES'}) -# list structure: [Document(page_content="...", metadata={'source': '...'}), Document(page_content="...", metadata={'source': '...'})] -''' \ No newline at end of file + return texts \ No newline at end of file diff --git a/src/extract_metadata.py b/src/extract_metadata.py index 5c7b6e9c..2f6f919d 100644 --- a/src/extract_metadata.py +++ b/src/extract_metadata.py @@ -1,12 +1,20 @@ import os import datetime +import hashlib -def extract_image_metadata(file_path, file_name): +def compute_file_hash(file_path): + hash_sha256 = hashlib.sha256() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_sha256.update(chunk) + return hash_sha256.hexdigest() +def extract_image_metadata(file_path, file_name): file_type = os.path.splitext(file_name)[1] file_size = os.path.getsize(file_path) creation_date = datetime.datetime.fromtimestamp(os.path.getctime(file_path)).isoformat() modification_date = datetime.datetime.fromtimestamp(os.path.getmtime(file_path)).isoformat() + file_hash = compute_file_hash(file_path) return { "file_path": file_path, @@ -15,7 +23,8 @@ def extract_image_metadata(file_path, file_name): "file_size": file_size, "creation_date": creation_date, "modification_date": modification_date, - "image": "True" + "document_type": "image", + "hash": file_hash } def extract_document_metadata(file_path): @@ -23,6 +32,7 @@ def extract_document_metadata(file_path): file_size = os.path.getsize(file_path) creation_date = datetime.datetime.fromtimestamp(os.path.getctime(file_path)).isoformat() modification_date = datetime.datetime.fromtimestamp(os.path.getmtime(file_path)).isoformat() + file_hash = compute_file_hash(file_path) return { "file_path": str(file_path), @@ -31,5 +41,6 @@ def extract_document_metadata(file_path): "file_size": file_size, "creation_date": creation_date, "modification_date": modification_date, - "image": "False" - } \ No newline at end of file + "document_type": "document", + "hash": file_hash + } diff --git a/src/gui_tabs_settings_database_query.py b/src/gui_tabs_settings_database_query.py index 1a81cc94..98715507 100644 --- a/src/gui_tabs_settings_database_query.py +++ b/src/gui_tabs_settings_database_query.py @@ -13,6 +13,7 @@ def __init__(self): self.database_creation_device = config_data['Compute_Device']['database_creation'] self.database_query_device = config_data['Compute_Device']['database_query'] self.search_term = config_data['database'].get('search_term', '') + self.document_type = config_data['database'].get('document_types', '') v_layout = QVBoxLayout() h_layout_device = QHBoxLayout() @@ -64,7 +65,17 @@ def __init__(self): h_layout_search_term.addWidget(self.filter_button) self.file_type_combo = QComboBox() - self.file_type_combo.addItems(["All Files", "Images Only", "Non-Images Only"]) + file_type_items = ["All Files", "Images Only", "Documents Only"] + self.file_type_combo.addItems(file_type_items) + + if self.document_type == 'image': + default_index = file_type_items.index("Images Only") + elif self.document_type == 'document': + default_index = file_type_items.index("Documents Only") + else: + default_index = file_type_items.index("All Files") + self.file_type_combo.setCurrentIndex(default_index) + h_layout_search_term.addWidget(self.file_type_combo) v_layout.addLayout(h_layout_search_term) @@ -106,16 +117,16 @@ def update_config(self): file_type_map = { "All Files": '', - "Images Only": True, - "Non-Images Only": False + "Images Only": 'image', + "Documents Only": 'document' } file_type_selection = self.file_type_combo.currentText() - images_only_value = file_type_map[file_type_selection] + document_type_value = file_type_map[file_type_selection] - if images_only_value != config_data['database'].get('images_only', ''): + if document_type_value != config_data['database'].get('document_types', ''): settings_changed = True - config_data['database']['images_only'] = images_only_value + config_data['database']['document_types'] = document_type_value if settings_changed: with open('config.yaml', 'w') as f: diff --git a/src/gui_tabs_settings_whisper.py b/src/gui_tabs_settings_whisper.py index 70c176ad..6cc63525 100644 --- a/src/gui_tabs_settings_whisper.py +++ b/src/gui_tabs_settings_whisper.py @@ -57,7 +57,7 @@ def create_layout(self): model_label = QLabel("Model") layout.addWidget(model_label, 0, 0) self.model_combo = QComboBox() - self.model_combo.addItems(["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large-v2"]) + self.model_combo.addItems(["whisper-tiny.en", "whisper-base.en", "whisper-small.en", "whisper-medium.en", "whisper-large-v2"]) layout.addWidget(self.model_combo, 0, 1) # Quantization @@ -116,12 +116,3 @@ def update_config(self): yaml.dump(config_data, f) return settings_changed - -if __name__ == "__main__": - from PySide6.QtWidgets import QApplication - import sys - - app = QApplication(sys.argv) - transcriber_settings_tab = TranscriberSettingsTab() - transcriber_settings_tab.show() - sys.exit(app.exec()) diff --git a/src/gui_tabs_tools_transcribe.py b/src/gui_tabs_tools_transcribe.py index e7e45d7d..529ef682 100644 --- a/src/gui_tabs_tools_transcribe.py +++ b/src/gui_tabs_tools_transcribe.py @@ -5,7 +5,6 @@ from PySide6.QtCore import Qt import yaml from pathlib import Path -from constants import WHISPER_MODEL_NAMES from transcribe_module import TranscribeFile import threading @@ -38,7 +37,7 @@ def create_layout(self): hbox1 = QHBoxLayout() hbox1.addWidget(QLabel("Model")) self.model_combo = QComboBox() - self.model_combo.addItems([model for model in WHISPER_MODEL_NAMES if model not in ["tiny", "tiny.en", "base", "base.en"]]) + self.model_combo.addItems(["whisper-small.en", "whisper-medium.en", "whisper-large-v2"]) self.model_combo.setCurrentText(self.default_model) self.model_combo.currentTextChanged.connect(self.update_model_in_config) hbox1.addWidget(self.model_combo) @@ -73,7 +72,7 @@ def create_layout(self): main_layout.addLayout(hbox2) - # Third row of widgets (Select Audio File and Transcribe buttons) + # Third row of widgets hbox3 = QHBoxLayout() self.select_file_button = QPushButton("Select Audio File") self.select_file_button.clicked.connect(self.select_audio_file) @@ -85,7 +84,6 @@ def create_layout(self): main_layout.addLayout(hbox3) - # Label for displaying the selected file path self.file_path_label = QLabel("No file currently selected") main_layout.addWidget(self.file_path_label) diff --git a/src/loader_images.py b/src/loader_images.py new file mode 100644 index 00000000..3157dfa4 --- /dev/null +++ b/src/loader_images.py @@ -0,0 +1,322 @@ +import datetime +import gc +import os +import platform +import time +import torch +import yaml +from PIL import Image +from tqdm import tqdm +from transformers import ( + AutoModelForCausalLM, + AutoProcessor, + BlipForConditionalGeneration, + BlipProcessor, + LlamaTokenizer, + LlavaForConditionalGeneration +) +from langchain.docstore.document import Document +from extract_metadata import extract_image_metadata +from utilities import my_cprint + +def get_best_device(): + if torch.cuda.is_available(): + return 'cuda' + elif torch.backends.mps.is_available(): + return 'mps' + elif hasattr(torch.version, 'hip') and torch.version.hip and platform.system() == 'Linux': + return 'cuda' + else: + return 'cpu' + +class loader_cogvlm: + def initialize_model_and_tokenizer(self, config): + # Initialization logic for the model and tokenizer + tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5') + + if config['vision']['chosen_model'] == 'cogvlm' and config['vision']['chosen_quant'] == '4-bit': + model = AutoModelForCausalLM.from_pretrained( + 'THUDM/cogvlm-chat-hf', + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + trust_remote_code=True, + load_in_4bit=True, + resume_download=True + ) + chosen_quant = "4-bit" + + elif config['vision']['chosen_model'] == 'cogvlm' and config['vision']['chosen_quant'] == '8-bit': + model = AutoModelForCausalLM.from_pretrained( + 'THUDM/cogvlm-chat-hf', + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + trust_remote_code=True, + load_in_8bit=True, + resume_download=True + ) + chosen_quant = "8-bit" + + print("Selected model: cogvlm") + print(f"Selected quant: {chosen_quant}") + my_cprint(f"Vision model loaded.", "green") + + return model, tokenizer + + def cogvlm_process_images(self): + script_dir = os.path.dirname(__file__) + image_dir = os.path.join(script_dir, "Images_for_DB") + documents = [] + + if not os.listdir(image_dir): + print("No files detected in the 'Images_for_DB' directory.") + return + + with open('config.yaml', 'r') as file: + config = yaml.safe_load(file) + + device = get_best_device() + print(f"Using device: {device}") + model, tokenizer = self.initialize_model_and_tokenizer(config) + + total_start_time = time.time() + + with tqdm(total=len(os.listdir(image_dir)), unit="image") as progress_bar: + for file_name in os.listdir(image_dir): + full_path = os.path.join(image_dir, file_name) + prompt = "Describe in detail what this image depicts in as much detail as possible." + + try: + with Image.open(full_path).convert('RGB') as raw_image: + inputs = model.build_conversation_input_ids(tokenizer, query=prompt, history=[], images=[raw_image]) + if config['vision']['chosen_quant'] == '4-bit': + inputs = { + 'input_ids': inputs['input_ids'].unsqueeze(0).to(device), + 'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to(device), + 'attention_mask': inputs['attention_mask'].unsqueeze(0).to(device), + 'images': [[inputs['images'][0].to('cuda').to(torch.bfloat16)]], + } + elif config['vision']['chosen_quant'] == '8-bit': + inputs = { + 'input_ids': inputs['input_ids'].unsqueeze(0).to(device), + 'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to(device), + 'attention_mask': inputs['attention_mask'].unsqueeze(0).to(device), + 'images': [[inputs['images'][0].to(device).to(torch.float16)]], + } + + gen_kwargs = {"max_length": 2048, "do_sample": False} + with torch.no_grad(): + output = model.generate(**inputs, **gen_kwargs) + output = output[:, inputs['input_ids'].shape[1]:] + model_response = tokenizer.decode(output[0], skip_special_tokens=True).split("ASSISTANT: ")[-1] + + extracted_text = model_response + extracted_metadata = extract_image_metadata(full_path, file_name) + document = Document(page_content=extracted_text, metadata=extracted_metadata) + documents.append(document) + + except Exception as e: + print(f"{file_name}: Error processing image. Details: {e}") + + progress_bar.update(1) + + total_end_time = time.time() + total_time_taken = total_end_time - total_start_time + print(f"Total image processing time: {total_time_taken:.2f} seconds") + + del model + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + my_cprint(f"Vision model removed from memory.", "red") + + return documents + +class loader_llava: + def llava_process_images(self): + script_dir = os.path.dirname(__file__) + image_dir = os.path.join(script_dir, "Images_for_DB") + documents = [] + + if not os.listdir(image_dir): + print("No files detected in the 'Images_for_DB' directory.") + return + + with open('config.yaml', 'r') as file: + config = yaml.safe_load(file) + + chosen_model = config['vision']['chosen_model'] + chosen_size = config['vision']['chosen_size'] + chosen_quant = config['vision']['chosen_quant'] + + model_id = "" + if chosen_model == 'llava' and chosen_size == '7b': + model_id = "llava-hf/llava-1.5-7b-hf" + elif chosen_model == 'bakllava': + model_id = "llava-hf/bakLlava-v1-hf" + elif chosen_model == 'llava' and chosen_size == '13b': + model_id = "llava-hf/llava-1.5-13b-hf" + + print(f"Selected model: {chosen_model}") + print(f"Selected size: {chosen_size}") + print(f"Selected quant: {chosen_quant}") + + device = get_best_device() + print(f"Using device: {device}") + + if chosen_model == 'llava' and chosen_quant == 'float16': + model = LlavaForConditionalGeneration.from_pretrained( + "llava-hf/llava-1.5-7b-hf", + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + resume_download=True + ).to(device) + + elif chosen_model == 'llava' and chosen_quant == '8-bit': + model = LlavaForConditionalGeneration.from_pretrained( + model_id, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + load_in_8bit=True, + resume_download=True + ) + + elif chosen_model == 'llava' and chosen_quant == '4-bit': + model = LlavaForConditionalGeneration.from_pretrained( + model_id, + torch_dtype=torch.float32, + low_cpu_mem_usage=True, + load_in_4bit=True, + resume_download=True + ) + + elif chosen_model == 'bakllava' and chosen_quant == 'float16': + model = LlavaForConditionalGeneration.from_pretrained( + model_id, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + resume_download=True + ).to(device) + + elif chosen_model == 'bakllava' and chosen_quant == '8-bit': + model = LlavaForConditionalGeneration.from_pretrained( + model_id, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + load_in_8bit=True, + resume_download=True + ) + + elif chosen_model == 'bakllava' and chosen_quant == '4-bit': + model = LlavaForConditionalGeneration.from_pretrained( + model_id, + torch_dtype=torch.float32, + low_cpu_mem_usage=True, + load_in_4bit=True, + resume_download=True + ) + + my_cprint(f"Vision model loaded.", "green") + + processor = AutoProcessor.from_pretrained(model_id, resume_download=True) + + total_start_time = time.time() + total_tokens = 0 + + with tqdm(total=len(os.listdir(image_dir)), unit="image") as progress_bar: + for file_name in os.listdir(image_dir): + full_path = os.path.join(image_dir, file_name) + prompt = "USER: \nDescribe in detail what this image depicts in as much detail as possible.\nASSISTANT:" + + try: + with Image.open(full_path) as raw_image: + if chosen_quant == 'bfloat16' and chosen_model == 'bakllava': + inputs = processor(prompt, raw_image, return_tensors='pt').to(device, torch.bfloat16) + elif chosen_quant == 'float16': + inputs = processor(prompt, raw_image, return_tensors='pt').to(device, torch.float16) + elif chosen_quant == '8-bit': + if chosen_model == 'llava': + inputs = processor(prompt, raw_image, return_tensors='pt').to(device, torch.float16) + elif chosen_model == 'bakllava': + inputs = processor(prompt, raw_image, return_tensors='pt').to(device, torch.bfloat16) + elif chosen_quant == '4-bit': + inputs = processor(prompt, raw_image, return_tensors='pt').to(device, torch.float32) + + output = model.generate(**inputs, max_new_tokens=200, do_sample=True) + full_response = processor.decode(output[0][2:], skip_special_tokens=True, do_sample=True) # can add num_beams=5 + model_response = full_response.split("ASSISTANT: ")[-1] + + extracted_text = model_response + extracted_metadata = extract_image_metadata(full_path, file_name) + document = Document(page_content=extracted_text, metadata=extracted_metadata) + documents.append(document) + + total_tokens += output[0].shape[0] + progress_bar.update(1) + + except Exception as e: + print(f"{file_name}: Error processing image - {e}") + + total_end_time = time.time() + total_time_taken = total_end_time - total_start_time + print(f"Total image processing time: {total_time_taken:.2f} seconds") + print(f"Tokens per second: {total_tokens / total_time_taken:.2f}") + + del model + del processor + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + my_cprint(f"Vision model removed from memory.", "red") + + return documents + +class loader_salesforce: + def salesforce_process_images(self): + script_dir = os.path.dirname(__file__) + image_dir = os.path.join(script_dir, "Images_for_DB") + documents = [] + + if not os.listdir(image_dir): + print("No files detected in the 'Images_for_DB' directory.") + return + + device = get_best_device() + processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large") + model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device) + + total_tokens = 0 + total_start_time = time.time() + + with tqdm(total=len(os.listdir(image_dir)), unit="image") as progress_bar: + for file_name in os.listdir(image_dir): + full_path = os.path.join(image_dir, file_name) + try: + with Image.open(full_path) as raw_image: + inputs = processor(raw_image, return_tensors="pt").to(device) + output = model.generate(**inputs, max_new_tokens=50) + caption = processor.decode(output[0], skip_special_tokens=True) + total_tokens += output[0].shape[0] + + extracted_metadata = extract_image_metadata(full_path, file_name) + document = Document(page_content=caption, metadata=extracted_metadata) + documents.append(document) + + progress_bar.update(1) + + except Exception as e: + print(f"{file_name}: Error processing image - {e}") + + total_end_time = time.time() + total_time_taken = total_end_time - total_start_time + print(f"Total image processing time: {total_time_taken:.2f} seconds") + print(f"Tokens per second: {total_tokens / total_time_taken:.2f}") + + del model + del processor + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + return documents diff --git a/src/server_connector.py b/src/server_connector.py index 09c2e949..4859846a 100644 --- a/src/server_connector.py +++ b/src/server_connector.py @@ -100,7 +100,7 @@ def ask_local_chatgpt(query, chunks_only, persist_directory=str(PERSIST_DIRECTOR try: EMBEDDING_MODEL_NAME = config['EMBEDDING_MODEL_NAME'] search_term = config['database'].get('search_term', '').lower() - images_only = config['database'].get('images_only', False) # Read images_only setting + document_types = config['database'].get('document_types', '') # Change here except KeyError: msg_box = QMessageBox() msg_box.setText("Configuration error: Missing required keys in config.yaml") @@ -111,8 +111,6 @@ def ask_local_chatgpt(query, chunks_only, persist_directory=str(PERSIST_DIRECTOR score_threshold = float(config['database']['similarity']) k = int(config['database']['contexts']) - model_kwargs = {"device": compute_device} - my_cprint("Embedding model loaded.", "green") if "instructor" in EMBEDDING_MODEL_NAME: @@ -121,7 +119,7 @@ def ask_local_chatgpt(query, chunks_only, persist_directory=str(PERSIST_DIRECTOR embeddings = HuggingFaceInstructEmbeddings( model_name=EMBEDDING_MODEL_NAME, - model_kwargs=model_kwargs, + model_kwargs={"device": compute_device}, embed_instruction=embed_instruction, query_instruction=query_instruction ) @@ -131,14 +129,14 @@ def ask_local_chatgpt(query, chunks_only, persist_directory=str(PERSIST_DIRECTOR embeddings = HuggingFaceBgeEmbeddings( model_name=EMBEDDING_MODEL_NAME, - model_kwargs=model_kwargs, + model_kwargs={"device": compute_device}, query_instruction=query_instruction ) else: embeddings = HuggingFaceEmbeddings( model_name=EMBEDDING_MODEL_NAME, - model_kwargs=model_kwargs + model_kwargs={"device": compute_device} ) tokenizer_path = "./Tokenizer" @@ -152,10 +150,15 @@ def ask_local_chatgpt(query, chunks_only, persist_directory=str(PERSIST_DIRECTOR my_cprint("Database initialized.", "white") + if document_types: + search_filter = {'document_type': document_types} + else: + search_filter = {} + retriever = db.as_retriever(search_kwargs={ 'score_threshold': score_threshold, 'k': k, - 'filter': {'image': str(images_only)} + 'filter': search_filter }) my_cprint("Querying database.", "white") @@ -225,7 +228,6 @@ def stop_interaction(): global stop_streaming stop_streaming = True - if __name__ == "__main__": user_input = "Your query here" ask_local_chatgpt(user_input) diff --git a/src/transcribe_module.py b/src/transcribe_module.py index 21b11a81..fb533826 100644 --- a/src/transcribe_module.py +++ b/src/transcribe_module.py @@ -14,7 +14,7 @@ def __init__(self, audio_file, config_path='config.yaml'): self.cpu_threads = max(4, os.cpu_count() - 4) self.model_config = { - 'model_name': f"ctranslate2-4you/whisper-{config['model']}-ct2-{config['quant']}", + 'model_name': f"ctranslate2-4you/{config['model']}-ct2-{config['quant']}", 'device': config['device'], 'compute_type': config['quant'], 'cpu_threads': self.cpu_threads @@ -42,7 +42,7 @@ def transcribe(self): ) my_cprint("Whisper model loaded.", 'green') - segments_generator, _ = model.transcribe(self.audio_file, beam_size=1) + segments_generator, _ = model.transcribe(self.audio_file, beam_size=5) segments = [] for segment in segments_generator: diff --git a/src/voice_recorder_module.py b/src/voice_recorder_module.py index 3fbbb868..5c0608c5 100644 --- a/src/voice_recorder_module.py +++ b/src/voice_recorder_module.py @@ -73,7 +73,7 @@ def start_recording(self): config_data = yaml.safe_load(stream) transcriber_config = config_data['transcriber'] - model_string = f"ctranslate2-4you/whisper-{transcriber_config['model']}-ct2-{transcriber_config['quant']}" + model_string = f"ctranslate2-4you/{transcriber_config['model']}-ct2-{transcriber_config['quant']}" cpu_threads = max(4, os.cpu_count() - 6)