Skip to content

Commit

Permalink
v3.0.3
Browse files Browse the repository at this point in the history
  • Loading branch information
BBC-Esq authored Jan 1, 2024
1 parent e2d4f64 commit 6c670c8
Show file tree
Hide file tree
Showing 11 changed files with 252 additions and 281 deletions.
9 changes: 5 additions & 4 deletions src/bark_module.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Import necessary libraries including tqdm
import warnings
import threading
import queue
Expand All @@ -10,6 +11,7 @@
import yaml
from termcolor import cprint
import platform
from tqdm import tqdm # Importing tqdm for the progress bar

warnings.filterwarnings("ignore", message="torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.")

Expand Down Expand Up @@ -37,7 +39,6 @@ def load_config(self):

def initialize_model_and_processor(self):
os_name = platform.system().lower()

# set compute device
if torch.cuda.is_available():
if torch.version.hip and os_name == 'linux':
Expand Down Expand Up @@ -125,8 +126,8 @@ def process_text_thread(self):
break

sentences = re.split(r'[.!?;]+', text_prompt)

for sentence in sentences:
# Adding tqdm progress bar
for sentence in tqdm(sentences, desc="Processing Sentences"):
if sentence.strip():
voice_preset = self.config['speaker']
inputs = self.processor(text=sentence, voice_preset=voice_preset, return_tensors="pt")
Expand Down Expand Up @@ -179,4 +180,4 @@ def release_resources(self):

if __name__ == "__main__":
bark_audio = BarkAudio()
bark_audio.run()
bark_audio.run()
84 changes: 63 additions & 21 deletions src/check_gpu.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,74 @@
import sys
from PySide6.QtWidgets import QApplication, QMessageBox
import torch

def display_info():
app = QApplication(sys.argv)
info_message = ""
try:
import torch
except ImportError:
def display_info():
app = QApplication(sys.argv)
msg_box = QMessageBox(QMessageBox.Information, "PyTorch Not Installed", "PyTorch is not installed on this system.")
msg_box.exec()

if torch.cuda.is_available():
info_message += "CUDA is available!\n"
info_message += "CUDA version: {}\n\n".format(torch.version.cuda)
else:
info_message += "CUDA is not available.\n\n"
else:
def check_bitsandbytes():
try:
import bitsandbytes as bnb
p = torch.nn.Parameter(torch.rand(10, 10).cuda())
a = torch.rand(10, 10).cuda()

if torch.backends.mps.is_available():
info_message += "Metal/MPS is available!\n\n"
else:
info_message += "Metal/MPS is not available.\n\n"
p1 = p.data.sum().item()

info_message += "If you want to check the version of Metal and MPS on your macOS device, you can go to \"About This Mac\" -> \"System Report\" -> \"Graphics/Displays\" and look for information related to Metal and MPS.\n\n"
adam = bnb.optim.Adam([p])

if torch.version.hip is not None:
info_message += "ROCm is available!\n"
info_message += "ROCm version: {}\n".format(torch.version.hip)
else:
info_message += "ROCm is not available.\n"
out = a * p
loss = out.sum()
loss.backward()
adam.step()

msg_box = QMessageBox(QMessageBox.Information, "GPU Acceleration Available?", info_message)
msg_box.exec()
p2 = p.data.sum().item()

assert p1 != p2
return "SUCCESS!\nInstallation of bitsandbytes was successful!"
except ImportError:
return "bitsandbytes is not installed."
except AssertionError:
return "bitsandbytes is installed, but the installation seems incorrect."
except Exception as e:
return f"An error occurred while checking bitsandbytes: {e}"

def display_info():
app = QApplication(sys.argv)
info_message = ""

if torch.cuda.is_available():
info_message += "CUDA is available!\n"
info_message += "CUDA version: {}\n\n".format(torch.version.cuda)
else:
info_message += "CUDA is not available.\n\n"

if torch.backends.mps.is_available():
info_message += "Metal/MPS is available!\n\n"
else:
info_message += "Metal/MPS is not available.\n\n"
if not torch.backends.mps.is_built():
info_message += "MPS not available because the current PyTorch install was not built with MPS enabled.\n\n"
else:
info_message += "MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.\n\n"

info_message += "If you want to check the version of Metal and MPS on your macOS device, you can go to \"About This Mac\" -> \"System Report\" -> \"Graphics/Displays\" and look for information related to Metal and MPS.\n\n"

if torch.version.hip is not None:
info_message += "ROCm is available!\n"
info_message += "ROCm version: {}\n".format(torch.version.hip)
else:
info_message += "ROCm is not available.\n"

# Check for bitsandbytes
bitsandbytes_message = check_bitsandbytes()
info_message += "\n" + bitsandbytes_message

msg_box = QMessageBox(QMessageBox.Information, "GPU Acceleration and Library Check", info_message)
msg_box.exec()

if __name__ == "__main__":
display_info()
26 changes: 14 additions & 12 deletions src/choose_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@
import os
from pathlib import Path
from PySide6.QtWidgets import QApplication, QFileDialog, QDialog, QVBoxLayout, QTextEdit, QPushButton, QHBoxLayout
import sys
import platform

def choose_documents_directory():
allowed_extensions = ['.pdf', '.docx', '.epub', '.txt', '.enex', '.eml', '.msg', '.csv', '.xls', '.xlsx', '.rtf', '.odt',
'.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tif', '.tiff', '.html', '.htm', '.md']
current_dir = Path(__file__).parent.resolve()
docs_folder = current_dir / "Docs_for_DB"
images_folder = current_dir / "Images_for_DB"
file_dialog = QFileDialog()
file_dialog.setFileMode(QFileDialog.ExistingFiles)
file_paths, _ = file_dialog.getOpenFileNames(None, "Choose Documents and Images for Database", str(current_dir))
Expand All @@ -21,12 +19,18 @@ def choose_documents_directory():
for file_path in file_paths:
extension = Path(file_path).suffix.lower()
if extension in allowed_extensions:
# Determine target folder without creating it
if extension in ['.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tif', '.tiff']:
target_folder = images_folder
target_folder = current_dir / "Images_for_DB"
else:
target_folder = docs_folder
target_folder.mkdir(parents=True, exist_ok=True)
target_folder = current_dir / "Docs_for_DB"

# Check and unlink existing symlink if necessary
symlink_target = target_folder / Path(file_path).name
if symlink_target.exists():
symlink_target.unlink()

# Create new symlink
symlink_target.symlink_to(file_path)
else:
incompatible_files.append(Path(file_path).name)
Expand Down Expand Up @@ -62,14 +66,12 @@ def see_documents_directory():
current_dir = Path(__file__).parent.resolve()
docs_folder = current_dir / "Docs_for_DB"

docs_folder.mkdir(parents=True, exist_ok=True)

# Cross-platform directory opening
if os.name == 'nt': # Windows
os_name = platform.system()
if os_name == 'Windows':
subprocess.Popen(['explorer', str(docs_folder)])
elif sys.platform == 'darwin': # macOS
elif os_name == 'Darwin':
subprocess.Popen(['open', str(docs_folder)])
elif sys.platform.startswith('linux'): # Linux
elif os_name == 'Linux':
subprocess.Popen(['xdg-open', str(docs_folder)])

if __name__ == '__main__':
Expand Down
89 changes: 35 additions & 54 deletions src/create_database.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,24 @@
import gc
from mailbox import Message
import os
import shutil
from pathlib import Path
from typing import Self

import torch
import yaml
from chromadb.config import Settings
import gc
from langchain.docstore.document import Document
from langchain.embeddings import (
HuggingFaceBgeEmbeddings,
HuggingFaceEmbeddings,
HuggingFaceInstructEmbeddings,
)
from langchain.embeddings import HuggingFaceInstructEmbeddings, HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
from langchain.vectorstores import Chroma
from termcolor import cprint

from chromadb.config import Settings
from document_processor import load_documents, split_documents
import torch
from utilities import validate_symbolic_links
from termcolor import cprint
from pathlib import Path
import os

ENABLE_PRINT = True


def my_cprint(*args, **kwargs):
if ENABLE_PRINT:
modified_message = f"create_database.py: {args[0]}"
cprint(modified_message, *args[1:], **kwargs)


ROOT_DIRECTORY = Path(__file__).resolve().parent
SOURCE_DIRECTORY = ROOT_DIRECTORY / "Docs_for_DB"
PERSIST_DIRECTORY = ROOT_DIRECTORY / "Vector_DB"
Expand All @@ -37,93 +27,84 @@ def my_cprint(*args, **kwargs):
CHROMA_SETTINGS = Settings(
chroma_db_impl="duckdb+parquet",
persist_directory=str(PERSIST_DIRECTORY),
anonymized_telemetry=False,
anonymized_telemetry=False
)


def main():
with open(ROOT_DIRECTORY / "config.yaml", "r") as stream:

with open(ROOT_DIRECTORY / "config.yaml", 'r') as stream:
config_data = yaml.safe_load(stream)

EMBEDDING_MODEL_NAME = config_data.get("EMBEDDING_MODEL_NAME")

my_cprint(f"Loading documents.", "white")
documents = load_documents(
SOURCE_DIRECTORY
) # invoke document_processor.py; returns a list of document objects
if documents == None or len(documents) == 0:
cprint(f"No documents to load.")
documents = load_documents(SOURCE_DIRECTORY) # invoke document_processor.py; returns a list of document objects
if documents is None or len(documents) == 0:
cprint("No documents to load.", "red")
return

my_cprint(f"Successfully loaded documents.", "white")
texts = split_documents(
documents
) # invoke document_processor.py again; returns a list of split document objects


texts = split_documents(documents) # invoke document_processor.py again; returns a list of split document objects

embeddings = get_embeddings(EMBEDDING_MODEL_NAME, config_data)
my_cprint("Embedding model loaded.", "green")

if PERSIST_DIRECTORY.exists():
shutil.rmtree(PERSIST_DIRECTORY)
PERSIST_DIRECTORY.mkdir(parents=True, exist_ok=True)

my_cprint("Creating database.", "white")

db = Chroma.from_documents(
texts,
embeddings,
persist_directory=str(PERSIST_DIRECTORY),
texts, embeddings,
persist_directory=str(PERSIST_DIRECTORY),
client_settings=CHROMA_SETTINGS,
)

my_cprint("Persisting database.", "white")
db.persist()
my_cprint("Database persisted.", "white")


del embeddings.client
del embeddings
torch.cuda.empty_cache()
gc.collect()
my_cprint("Embedding model removed from memory.", "red")


def get_embeddings(EMBEDDING_MODEL_NAME, config_data, normalize_embeddings=False):
my_cprint("Creating embeddings.", "white")

compute_device = config_data["Compute_Device"]["database_creation"]

compute_device = config_data['Compute_Device']['database_creation']
if "instructor" in EMBEDDING_MODEL_NAME:
embed_instruction = config_data["embedding-models"]["instructor"].get(
"embed_instruction"
)
query_instruction = config_data["embedding-models"]["instructor"].get(
"query_instruction"
)
embed_instruction = config_data['embedding-models']['instructor'].get('embed_instruction')
query_instruction = config_data['embedding-models']['instructor'].get('query_instruction')

return HuggingFaceInstructEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": compute_device},
encode_kwargs={"normalize_embeddings": normalize_embeddings},
embed_instruction=embed_instruction,
query_instruction=query_instruction,
query_instruction=query_instruction
)

elif "bge" in EMBEDDING_MODEL_NAME:
query_instruction = config_data["embedding-models"]["bge"].get(
"query_instruction"
)
query_instruction = config_data['embedding-models']['bge'].get('query_instruction')

return HuggingFaceBgeEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": compute_device},
query_instruction=query_instruction,
encode_kwargs={"normalize_embeddings": normalize_embeddings},
encode_kwargs={"normalize_embeddings": normalize_embeddings}
)

else:

return HuggingFaceEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": compute_device},
encode_kwargs={"normalize_embeddings": normalize_embeddings},
encode_kwargs={"normalize_embeddings": normalize_embeddings}
)


if __name__ == "__main__":
main()
Loading

0 comments on commit 6c670c8

Please sign in to comment.