Skip to content

Commit

Permalink
v2.7.1
Browse files Browse the repository at this point in the history
  • Loading branch information
BBC-Esq authored Dec 11, 2023
1 parent c798067 commit 468895f
Show file tree
Hide file tree
Showing 14 changed files with 347 additions and 137 deletions.
16 changes: 12 additions & 4 deletions src/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ Supported_CTranslate2_Quantizations:
- int8_bfloat16
- int8
database:
chunk_overlap: 300
chunk_overlap: 200
chunk_size: 600
contexts: 25
contexts: 8
similarity: 0.9
embedding-models:
bge:
Expand All @@ -38,19 +38,27 @@ server:
model_max_tokens: -1
model_temperature: 0.1
prefix: '[INST]'
prefix_chat_ml: <|im_start|>
prefix_llama2_and_mistral: '[INST]'
prefix_neural_chat: '### User:'
prefix_orca2: <|im_start|>user
prompt_format_disabled: false
suffix: '[/INST]'
suffix_chat_ml: <|im_end|>
suffix_llama2_and_mistral: '[/INST]'
suffix_neural_chat: '### Assistant:'
suffix_orca2: <|im_end|><|im_start|>assistant
styles:
button: 'background-color: #323842; color: light gray; font: 10pt "Segoe UI Historic";
width: 29;'
frame: 'background-color: #161b22;'
input: 'background-color: #2e333b; color: light gray; font: 13pt "Segoe UI Historic";'
text: 'background-color: #092327; color: light gray; font: 12pt "Segoe UI Historic";'
test_embeddings: false
test_embeddings: true
transcribe_file:
device: cpu
file: null
model: small.en
model: large-v2
quant: float32
timestamps: true
transcriber:
Expand Down
2 changes: 0 additions & 2 deletions src/document_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
UnstructuredMarkdownLoader
)

# Import DOCUMENT_LOADERS from constants.py
from constants import DOCUMENT_LOADERS

ENABLE_PRINT = True
Expand All @@ -35,7 +34,6 @@ def my_cprint(*args, **kwargs):
SOURCE_DIRECTORY = f"{ROOT_DIRECTORY}/Docs_for_DB"
INGEST_THREADS = os.cpu_count() or 8

# Replace class names in DOCUMENT_LOADERS with actual classes
for ext, loader_name in DOCUMENT_LOADERS.items():
DOCUMENT_LOADERS[ext] = globals()[loader_name]

Expand Down
49 changes: 42 additions & 7 deletions src/gui.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from PySide6.QtWidgets import (
QApplication, QWidget, QPushButton, QVBoxLayout, QTabWidget,
QTextEdit, QSplitter, QFrame, QStyleFactory, QLabel, QGridLayout, QMenuBar, QCheckBox
QTextEdit, QSplitter, QFrame, QStyleFactory, QLabel, QGridLayout, QMenuBar, QCheckBox, QHBoxLayout
)
from PySide6.QtCore import Qt, QThread, Signal, QUrl
from PySide6.QtWebEngineWidgets import QWebEngineView
from PySide6.QtCore import Qt
import os
import torch
import yaml
import sys
from initialize import main as initialize_system
Expand All @@ -15,7 +15,7 @@
import create_database
from gui_tabs import create_tabs
from gui_threads import CreateDatabaseThread, SubmitButtonThread
from button_module import create_button_row
import voice_recorder_module
from utilities import list_theme_files, make_theme_changer, load_stylesheet

class DocQA_GUI(QWidget):
Expand All @@ -29,6 +29,12 @@ def __init__(self):
self.init_menu()
self.load_config()

def is_nvidia_gpu(self):
    """Report whether CUDA device 0 identifies itself as an NVIDIA GPU.

    Returns:
        True when torch reports CUDA as available and the name of device 0
        contains "nvidia" (compared case-insensitively); False otherwise.
    """
    # Without a usable CUDA runtime there is no device name to inspect.
    if not torch.cuda.is_available():
        return False
    device_name = torch.cuda.get_device_name(0)
    return "nvidia" in device_name.lower()

def load_config(self):
script_dir = os.path.dirname(os.path.realpath(__file__))
config_path = os.path.join(script_dir, 'config.yaml')
Expand All @@ -52,7 +58,7 @@ def init_ui(self):
# Buttons data
button_data = [
("Download Embedding Model", lambda: download_embedding_model(self)),
("Set Embedding Model Directory", select_embedding_model_directory),
("Choose Embedding Model Directory", select_embedding_model_directory),
("Choose Documents for Database", choose_documents_directory),
("Create Vector Database", self.on_create_button_clicked)
]
Expand Down Expand Up @@ -88,16 +94,18 @@ def init_ui(self):
right_vbox.addWidget(self.test_embeddings_checkbox)

# Create and add button row
button_row_widget = create_button_row(self.on_submit_button_clicked, self)
button_row_widget = self.create_button_row(self.on_submit_button_clicked)
right_vbox.addWidget(button_row_widget)

right_frame.setLayout(right_vbox)
main_splitter.addWidget(right_frame)

main_layout = QVBoxLayout(self)
main_layout.addWidget(main_splitter)

# Metrics bar
main_layout.addWidget(self.metrics_bar)
self.metrics_bar.setMaximumHeight(75 if self.is_nvidia_gpu() else 30)

def init_menu(self):
self.menu_bar = QMenuBar(self)
Expand Down Expand Up @@ -143,9 +151,36 @@ def update_transcription(self, text):
self.text_input.setPlainText(text)

def closeEvent(self, event):
self.metrics_bar.stop_monitors()
self.metrics_bar.stop_metrics_collector()
event.accept()

def create_button_row(self, submit_handler):
    """Build the row widget holding the two voice-recording buttons.

    A VoiceRecorder is created with this object as its argument, and the
    "Start Recording" / "Stop Recording" buttons are wired to the
    recorder's start_recording() / stop_recording() methods.

    Args:
        submit_handler: Accepted as part of the signature; this method
            does not use it.

    Returns:
        A QWidget whose horizontal layout contains both buttons, each
        given a stretch factor of 3.
    """
    recorder = voice_recorder_module.VoiceRecorder(self)

    # The callbacks close over the recorder rather than storing it on self.
    def on_start():
        recorder.start_recording()

    def on_stop():
        recorder.stop_recording()

    record_button = QPushButton("Start Recording")
    halt_button = QPushButton("Stop Recording")
    record_button.clicked.connect(on_start)
    halt_button.clicked.connect(on_stop)

    button_layout = QHBoxLayout()
    for button in (record_button, halt_button):
        button_layout.addWidget(button)
    # Equal stretch factors so both buttons share the row evenly.
    for button in (record_button, halt_button):
        button_layout.setStretchFactor(button, 3)

    container = QWidget()
    container.setLayout(button_layout)
    return container

if __name__ == '__main__':
app = QApplication(sys.argv)
app.setStyle(QStyleFactory.create('fusion'))
Expand Down
15 changes: 10 additions & 5 deletions src/gui_tabs_settings.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from PySide6.QtWidgets import QVBoxLayout, QGroupBox, QPushButton, QHBoxLayout, QWidget, QMessageBox
from PySide6.QtWidgets import QVBoxLayout, QGroupBox, QPushButton, QHBoxLayout, QWidget, QMessageBox, QLabel
from gui_tabs_settings_server import ServerSettingsTab
# from gui_tabs_settings_models import ModelsSettingsTab
# Commented out unless/until modifying BGE and Instructor settings becomes useful
from gui_tabs_settings_whisper import TranscriberSettingsTab
from gui_tabs_settings_database import DatabaseSettingsTab
# from gui_tabs_settings_chunks import ChunkSettingsTab

def update_all_configs(configs):
updated = any(config.update_config() for config in configs.values())
Expand All @@ -28,8 +27,7 @@ def __init__(self):
classes = {
"Server/LLM Settings": (ServerSettingsTab, 3),
"Voice Recorder Settings": (TranscriberSettingsTab, 1),
"Database Settings": (DatabaseSettingsTab, 4),
# "Chunking Settings": (ChunkSettingsTab, 4),
"Database Settings": (DatabaseSettingsTab, 3),
}

self.groups = {}
Expand Down Expand Up @@ -61,6 +59,13 @@ def __init__(self):
center_button_layout.addWidget(self.update_all_button)
center_button_layout.addStretch(1)

tip_label_1 = QLabel("<b><u>Must</u> 'Update Settings' before any settings take effect.</b>")
tip_label_2 = QLabel("<b><u>Must recreate</u> database if changing Chunk Size/Overlap settings</b>")
self.layout.addWidget(tip_label_1)
self.layout.addWidget(tip_label_2)

self.setLayout(self.layout)

self.layout.addLayout(center_button_layout)
self.setLayout(self.layout)
adjust_stretch(self.groups, self.layout)
adjust_stretch(self.groups, self.layout)
7 changes: 0 additions & 7 deletions src/gui_tabs_settings_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,6 @@ def __init__(self):
v_layout.addLayout(h_layout1)
v_layout.addLayout(h_layout2)
self.setLayout(v_layout)

tip_label_1 = QLabel("<b><u>Must</u> 'Update Settings' before any settings take effect.</b>")
tip_label_2 = QLabel("<b><u>RECREATE</u> database if changing Chunk Size/Overlap settings</b>")
v_layout.addWidget(tip_label_1)
v_layout.addWidget(tip_label_2)

self.setLayout(v_layout)

def update_config(self):
with open('config.yaml', 'r') as f:
Expand Down
42 changes: 31 additions & 11 deletions src/gui_tabs_settings_server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from PySide6.QtWidgets import QWidget, QLabel, QLineEdit, QGridLayout, QMessageBox, QSizePolicy, QCheckBox
from PySide6.QtWidgets import QWidget, QLabel, QLineEdit, QGridLayout, QMessageBox, QSizePolicy, QCheckBox, QComboBox
from PySide6.QtGui import QIntValidator, QDoubleValidator
import yaml

Expand All @@ -7,14 +7,14 @@ def __init__(self):
super(ServerSettingsTab, self).__init__()

with open('config.yaml', 'r') as file:
config_data = yaml.safe_load(file)
self.connection_str = config_data.get('server', {}).get('connection_str', '')
self.config_data = yaml.safe_load(file)
self.connection_str = self.config_data.get('server', {}).get('connection_str', '')
self.current_port = self.connection_str.split(":")[-1].split("/")[0]
self.current_max_tokens = config_data.get('server', {}).get('model_max_tokens', '')
self.current_temperature = config_data.get('server', {}).get('model_temperature', '')
self.current_prefix = config_data.get('server', {}).get('prefix', '')
self.current_suffix = config_data.get('server', {}).get('suffix', '')
self.prompt_format_disabled = config_data.get('server', {}).get('prompt_format_disabled', False)
self.current_max_tokens = self.config_data.get('server', {}).get('model_max_tokens', '')
self.current_temperature = self.config_data.get('server', {}).get('model_temperature', '')
self.current_prefix = self.config_data.get('server', {}).get('prefix', '')
self.current_suffix = self.config_data.get('server', {}).get('suffix', '')
self.prompt_format_disabled = self.config_data.get('server', {}).get('prompt_format_disabled', False)

settings_dict = {
'port': {"placeholder": "Enter new port...", "validator": QIntValidator(), "current": self.current_port},
Expand All @@ -36,8 +36,13 @@ def __init__(self):

prompt_format_label = QLabel("Prompt Format:")
layout.addWidget(prompt_format_label, 2, 0)

self.prompt_format_combobox = QComboBox()
self.prompt_format_combobox.addItems(["", "ChatML", "Llama2/Mistral", "Neural Chat", "Orca2"])
layout.addWidget(self.prompt_format_combobox, 2, 1)
self.prompt_format_combobox.currentIndexChanged.connect(self.update_prefix_suffix)

disable_label = QLabel("Disable")
disable_label = QLabel("Disable:")
layout.addWidget(disable_label, 2, 2)

self.disable_checkbox = QCheckBox()
Expand All @@ -47,8 +52,8 @@ def __init__(self):

layout.addWidget(self.create_label('prefix', settings_dict), 3, 0)
layout.addWidget(self.create_edit('prefix', settings_dict), 3, 1)
layout.addWidget(self.create_label('suffix', settings_dict), 3, 2)
layout.addWidget(self.create_edit('suffix', settings_dict), 3, 3)
layout.addWidget(self.create_label('suffix', settings_dict), 4, 0)
layout.addWidget(self.create_edit('suffix', settings_dict), 4, 1)

self.setLayout(layout)

Expand All @@ -70,6 +75,21 @@ def create_edit(self, setting, settings_dict):
self.widgets[setting]['edit'] = edit
return edit

def update_prefix_suffix(self, index):
    """Fill the prefix/suffix edits from the selected prompt format.

    The combobox's current text is mapped to a pair of keys in the
    'server' section of the loaded config, and the values stored under
    those keys are written into the prefix and suffix line edits. A
    selection with no mapping (such as the blank entry) resolves to empty
    keys, so both edits receive '' unless the config defines an
    empty-string key.

    Args:
        index: Index passed by the combobox signal; not used, because the
            current text is read from the combobox directly.
    """
    format_keys = {
        "ChatML": ("prefix_chat_ml", "suffix_chat_ml"),
        "Llama2/Mistral": ("prefix_llama2_and_mistral", "suffix_llama2_and_mistral"),
        "Neural Chat": ("prefix_neural_chat", "suffix_neural_chat"),
        "Orca2": ("prefix_orca2", "suffix_orca2"),
    }

    selected_format = self.prompt_format_combobox.currentText()
    prefix_key, suffix_key = format_keys.get(selected_format, ("", ""))

    server_settings = self.config_data.get('server', {})
    prefix_text = server_settings.get(prefix_key, '')
    suffix_text = server_settings.get(suffix_key, '')

    self.widgets['prefix']['edit'].setText(prefix_text)
    self.widgets['suffix']['edit'].setText(suffix_text)

def update_config(self):
with open('config.yaml', 'r') as file:
config_data = yaml.safe_load(file)
Expand Down
6 changes: 0 additions & 6 deletions src/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,8 @@ def get_supported_quantizations(device_type):
types = ctranslate2.get_supported_compute_types(device_type)
filtered_types = [q for q in types if q != 'int16']

# Define the desired order of quantizations
desired_order = ['float32', 'float16', 'bfloat16', 'int8_float32', 'int8_float16', 'int8_bfloat16', 'int8']

# Sort the filtered_types based on the desired order
sorted_types = [q for q in desired_order if q in filtered_types]

return sorted_types

def update_config_file(**system_info):
Expand All @@ -44,13 +40,11 @@ def update_config_file(**system_info):
config_data['Compute_Device'].setdefault('database_creation', 'cpu')
config_data['Compute_Device'].setdefault('database_query', 'cpu')

# Add supported quantizations for CPU and GPU
config_data['Supported_CTranslate2_Quantizations'] = {
'CPU': get_supported_quantizations('cpu'),
'GPU': get_supported_quantizations('cuda') if torch.cuda.is_available() else []
}

# Update other keys
for key, value in system_info.items():
if key != 'Compute_Device' and key != 'Supported_CTranslate2_Quantizations':
config_data[key] = value
Expand Down
Loading

0 comments on commit 468895f

Please sign in to comment.