Skip to content

Commit

Permalink
v3.2
Browse files Browse the repository at this point in the history
  • Loading branch information
BBC-Esq authored Jan 20, 2024
1 parent 25b3d07 commit 32e624a
Show file tree
Hide file tree
Showing 18 changed files with 307 additions and 308 deletions.
15 changes: 2 additions & 13 deletions src/bark_module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# Import necessary libraries including tqdm
import warnings
import threading
import queue
Expand All @@ -9,20 +8,12 @@
import pyaudio
import gc
import yaml
from termcolor import cprint
import platform
from tqdm import tqdm
from utilities import my_cprint

warnings.filterwarnings("ignore", message="torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.")

ENABLE_PRINT = True

def my_cprint(*args, **kwargs):
if ENABLE_PRINT:
filename = "bark_module.py"
modified_message = f"{filename}: {args[0]}"
cprint(modified_message, *args[1:], **kwargs)

class BarkAudio:
def __init__(self):
self.load_config()
Expand All @@ -43,12 +34,10 @@ def initialize_model_and_processor(self):
if torch.cuda.is_available():
if torch.version.hip and os_name == 'linux':
self.device = "cuda:0"
elif torch.version.cuda and os_name == 'windows':
elif torch.version.cuda:
self.device = "cuda:0"
elif torch.version.hip and os_name == 'windows':
self.device = "cpu"
else:
self.device = "cpu"
elif torch.backends.mps.is_available():
self.device = "mps"
elif os_name == 'darwin':
Expand Down
86 changes: 86 additions & 0 deletions src/choose_documents_and_vector_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import subprocess
import os
import yaml
from pathlib import Path
from PySide6.QtWidgets import QFileDialog, QDialog, QVBoxLayout, QTextEdit, QPushButton, QHBoxLayout

def choose_documents_directory():
allowed_extensions = ['.pdf', '.docx', '.epub', '.txt', '.enex', '.eml', '.msg', '.csv', '.xls', '.xlsx', '.rtf', '.odt',
'.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tif', '.tiff', '.html', '.htm', '.md', '.doc']
current_dir = Path(__file__).parent.resolve()
file_dialog = QFileDialog()
file_dialog.setFileMode(QFileDialog.ExistingFiles)
file_paths, _ = file_dialog.getOpenFileNames(None, "Choose Documents and Images for Database", str(current_dir))

if file_paths:
incompatible_files = []
compatible_files = []

for file_path in file_paths:
extension = Path(file_path).suffix.lower()
if extension in allowed_extensions:
if extension in ['.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tif', '.tiff']:
target_folder = current_dir / "Images_for_DB"
else:
target_folder = current_dir / "Docs_for_DB"

# Check and unlink existing symlink if necessary
symlink_target = target_folder / Path(file_path).name
if symlink_target.exists():
symlink_target.unlink()

# Create new symlink
symlink_target.symlink_to(file_path)
else:
incompatible_files.append(Path(file_path).name)

if incompatible_files:
dialog = QDialog()
dialog.setWindowTitle("Incompatible Files Detected")
layout = QVBoxLayout()

text_edit = QTextEdit()
text_edit.setReadOnly(True)
text_edit.setText("One or more files selected are not compatible to be put into the database. Click 'Ok' to only add compatible documents or 'cancel' to back out::\n\n" + "\n".join(incompatible_files))
layout.addWidget(text_edit)

button_box = QHBoxLayout()
ok_button = QPushButton("OK")
cancel_button = QPushButton("Cancel")
button_box.addWidget(ok_button)
button_box.addWidget(cancel_button)
layout.addLayout(button_box)

dialog.setLayout(layout)

ok_button.clicked.connect(dialog.accept)
cancel_button.clicked.connect(dialog.reject)

user_choice = dialog.exec()

if user_choice == QDialog.Rejected:
return

def load_config():
with open(Path("config.yaml"), 'r') as stream:
return yaml.safe_load(stream)

def select_embedding_model_directory():
initial_dir = Path('Embedding_Models') if Path('Embedding_Models').exists() else Path.home()
chosen_directory = QFileDialog.getExistingDirectory(None, "Select Embedding Model Directory", str(initial_dir))

if chosen_directory:
config_file_path = Path("config.yaml")
if config_file_path.exists():
try:
with open(config_file_path, 'r') as file:
config_data = yaml.safe_load(file)
except Exception as e:
config_data = {}

config_data["EMBEDDING_MODEL_NAME"] = chosen_directory

with open(config_file_path, 'w') as file:
yaml.dump(config_data, file)

print(f"Selected directory: {chosen_directory}")
19 changes: 9 additions & 10 deletions src/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ Compute_Device:
- cpu
database_creation: cpu
database_query: cpu
gpu_brand:
EMBEDDING_MODEL_NAME: null
gpu_brand:
EMBEDDING_MODEL_NAME:
Platform_Info:
os:
os:
Supported_CTranslate2_Quantizations:
CPU:
- float32
Expand All @@ -22,13 +22,13 @@ Supported_CTranslate2_Quantizations:
- int8
bark:
enable_cpu_offload: false
model_precision: float32
size: small
model_precision: float16
size: normal
speaker: v2/en_speaker_6
use_better_transformer: true
database:
chunk_overlap: 250
chunk_size: 750
chunk_overlap: 200
chunk_size: 800
contexts: 6
similarity: 0.9
embedding-models:
Expand Down Expand Up @@ -65,11 +65,10 @@ styles:
frame: 'background-color: #161b22;'
input: 'background-color: #2e333b; color: light gray; font: 13pt "Segoe UI Historic";'
text: 'background-color: #092327; color: light gray; font: 12pt "Segoe UI Historic";'
test_embeddings: true
transcribe_file:
device: cpu
file: null
model: medium.en
model: small.en
quant: float32
timestamps: true
transcriber:
Expand Down Expand Up @@ -108,4 +107,4 @@ vision:
- float32
available_sizes:
- 470m
test_image: null
test_image:
21 changes: 6 additions & 15 deletions src/create_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
from document_processor import load_documents, split_documents
import torch
from utilities import validate_symbolic_links
from termcolor import cprint
from pathlib import Path
import os
from utilities import backup_database
from utilities import backup_database, my_cprint
import logging

logging.basicConfig(
Expand All @@ -20,10 +19,6 @@
)
logging.getLogger('chromadb.db.duckdb').setLevel(logging.WARNING)

def my_cprint(*args, **kwargs):
modified_message = f"create_database.py: {args[0]}"
cprint(modified_message, *args[1:], **kwargs)

ROOT_DIRECTORY = Path(__file__).resolve().parent
SOURCE_DIRECTORY = ROOT_DIRECTORY / "Docs_for_DB"
PERSIST_DIRECTORY = ROOT_DIRECTORY / "Vector_DB"
Expand All @@ -45,7 +40,7 @@ def main():
my_cprint(f"Loading documents.", "white")
documents = load_documents(SOURCE_DIRECTORY) # invoke document_processor.py; returns a list of document objects
if documents is None or len(documents) == 0:
cprint("No documents to load.", "red")
my_cprint("No documents to load.", "red")
return
my_cprint(f"Successfully loaded documents.", "white")

Expand Down Expand Up @@ -78,7 +73,7 @@ def main():
gc.collect()
my_cprint("Embedding model removed from memory.", "red")

def get_embeddings(EMBEDDING_MODEL_NAME, config_data, normalize_embeddings=False):
def get_embeddings(EMBEDDING_MODEL_NAME, config_data):
my_cprint("Creating embeddings.", "white")

compute_device = config_data['Compute_Device']['database_creation']
Expand All @@ -90,9 +85,8 @@ def get_embeddings(EMBEDDING_MODEL_NAME, config_data, normalize_embeddings=False
return HuggingFaceInstructEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": compute_device},
encode_kwargs={"normalize_embeddings": normalize_embeddings},
embed_instruction=embed_instruction,
query_instruction=query_instruction
query_instruction=query_instruction # cache_folder=, encode_kwargs=
)

elif "bge" in EMBEDDING_MODEL_NAME:
Expand All @@ -101,16 +95,13 @@ def get_embeddings(EMBEDDING_MODEL_NAME, config_data, normalize_embeddings=False
return HuggingFaceBgeEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": compute_device},
query_instruction=query_instruction,
encode_kwargs={"normalize_embeddings": normalize_embeddings}
query_instruction=query_instruction # encode_kwargs=, cache_folder=
)

else:

return HuggingFaceEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": compute_device},
encode_kwargs={"normalize_embeddings": normalize_embeddings}
model_kwargs={"device": compute_device} # encode_kwargs=, cache_folder=, multi_process=
)

if __name__ == "__main__":
Expand Down
34 changes: 13 additions & 21 deletions src/document_processor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import yaml
from termcolor import cprint
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed, ProcessPoolExecutor
from pathlib import Path
from langchain.docstore.document import Document
Expand All @@ -24,27 +24,16 @@
from loader_vision_llava import llava_process_images
from loader_vision_cogvlm import cogvlm_process_images
from loader_salesforce import salesforce_process_images
from extract_metadata import extract_document_metadata
from utilities import my_cprint

ENABLE_PRINT = True
ROOT_DIRECTORY = Path(__file__).parent
SOURCE_DIRECTORY = ROOT_DIRECTORY / "Docs_for_DB"
INGEST_THREADS = os.cpu_count() or 8

def my_cprint(*args, **kwargs):
if ENABLE_PRINT:
filename = "document_processor.py"
modified_message = f"{filename}: {args[0]}"
cprint(modified_message, *args[1:], **kwargs)

for ext, loader_name in DOCUMENT_LOADERS.items():
DOCUMENT_LOADERS[ext] = globals()[loader_name]

from langchain.document_loaders import (
UnstructuredEPubLoader, UnstructuredRTFLoader,
UnstructuredODTLoader, UnstructuredMarkdownLoader,
UnstructuredExcelLoader, UnstructuredCSVLoader
)

def process_images_wrapper(config):
chosen_model = config["vision"]["chosen_model"]

Expand All @@ -61,6 +50,7 @@ def load_single_document(file_path: Path) -> Document:
file_extension = file_path.suffix.lower()
loader_class = DOCUMENT_LOADERS.get(file_extension)

# specific loader parameters
if loader_class:
if file_extension == ".txt":
loader = loader_class(str(file_path), encoding='utf-8', autodetect_encoding=True)
Expand All @@ -76,7 +66,7 @@ def load_single_document(file_path: Path) -> Document:
loader = UnstructuredMarkdownLoader(str(file_path), mode="single", strategy="fast")
elif file_extension == ".xlsx" or file_extension == ".xlsd":
loader = UnstructuredExcelLoader(str(file_path), mode="single")
elif file_extension == ".html" or file_extension == ".htm":
elif file_extension == ".html":
loader = UnstructuredHTMLLoader(str(file_path), mode="single", strategy="fast")
elif file_extension == ".csv":
loader = UnstructuredCSVLoader(str(file_path), mode="single")
Expand All @@ -87,13 +77,14 @@ def load_single_document(file_path: Path) -> Document:

document = loader.load()[0]

# with open("output_load_single_document.txt", "w", encoding="utf-8") as output_file:
# output_file.write(document.page_content)
metadata = extract_document_metadata(file_path) # get metadata
document.metadata.update(metadata)

# text extracted before metadata added
# with open("output_load_single_document.txt", "w", encoding="utf-8") as output_file:
# output_file.write(document.page_content)

return document


def load_document_batch(filepaths):
with ThreadPoolExecutor(len(filepaths)) as exe:
futures = [exe.submit(load_single_document, name) for name in filepaths]
Expand All @@ -102,7 +93,8 @@ def load_document_batch(filepaths):

def load_documents(source_dir: Path) -> list[Document]:
all_files = list(source_dir.iterdir())
paths = [f for f in all_files if f.suffix in DOCUMENT_LOADERS.keys()]
# Adjust for case-insensitive extension matching
paths = [f for f in all_files if f.suffix.lower() in (key.lower() for key in DOCUMENT_LOADERS.keys())]

docs = []

Expand All @@ -128,7 +120,7 @@ def load_documents(source_dir: Path) -> list[Document]:
with open("config.yaml", "r") as config_file:
config = yaml.safe_load(config_file)

# Use ProcessPoolExecutor to run the selected image processing function in a separate process
# Use ProcessPoolExecutor for processing images
with ProcessPoolExecutor(1) as executor:
future = executor.submit(process_images_wrapper, config)
processed_docs = future.result()
Expand Down
Loading

0 comments on commit 32e624a

Please sign in to comment.