
Commit

Add files via upload
BBC-Esq authored Aug 11, 2024
1 parent d37ee22 commit c0cb10c
Showing 5 changed files with 127 additions and 102 deletions.
85 changes: 32 additions & 53 deletions src/constants.py
@@ -7,26 +7,6 @@
DOWNLOAD_EMBEDDING_MODEL_TOOLTIP = "Remember, wait until downloading is complete!"

VECTOR_MODELS = {
'Alibaba-NLP': [
{
'name': 'gte-base-en-v1.5',
'dimensions': 768,
'max_sequence': 8192,
'size_mb': 547,
'repo_id': 'Alibaba-NLP/gte-base-en-v1.5',
'cache_dir': 'Alibaba-NLP--gte-base-en-v1.5',
'type': 'vector'
},
{
'name': 'gte-large-en-v1.5',
'dimensions': 1024,
'max_sequence': 8192,
'size_mb': 1740,
'repo_id': 'Alibaba-NLP/gte-large-en-v1.5',
'cache_dir': 'Alibaba-NLP--gte-large-en-v1.5',
'type': 'vector'
},
],
'BAAI': [
{
'name': 'bge-small-en-v1.5',
@@ -131,15 +111,26 @@
'cache_dir': 'sentence-transformers--sentence-t5-xl',
'type': 'vector'
},
# {
# 'name': 'sentence-t5-xxl',
# 'dimensions': 768,
# 'max_sequence': 256,
# 'size_mb': 9230,
# 'repo_id': 'sentence-transformers/sentence-t5-xxl',
# 'cache_dir': 'sentence-transformers--sentence-t5-xxl',
# 'type': 'vector'
# },
],
'Alibaba-NLP': [
{
'name': 'Alibaba-gte-base',
'dimensions': 768,
'max_sequence': 8192,
'size_mb': 547,
'repo_id': 'Alibaba-NLP/gte-base-en-v1.5',
'cache_dir': 'Alibaba-NLP--gte-base-en-v1.5',
'type': 'vector'
},
{
'name': 'Alibaba-gte-large',
'dimensions': 1024,
'max_sequence': 8192,
'size_mb': 1740,
'repo_id': 'Alibaba-NLP/gte-large-en-v1.5',
'cache_dir': 'Alibaba-NLP--gte-large-en-v1.5',
'type': 'vector'
},
],
'thenlper': [
{
@@ -199,18 +190,6 @@
'type': 'vector'
},
],
'dunzhang': [
{
'name': 'stella_en_1.5B_v5',
'dimensions': 1024,
'max_sequence': 512,
'size_mb': 547,
'repo_id': 'dunzhang/stella_en_1.5B_v5',
'cache_dir': 'dunzhang--stella_en_1.5B_v5',
'type': 'vector'
},
],

}
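Each VECTOR_MODELS entry now carries both a repo_id and a cache_dir, so downstream code can read the cache directory straight off the dict instead of deriving it from the repo name. A minimal lookup sketch (the find_vector_model helper is hypothetical, not part of this commit):

from constants import VECTOR_MODELS

def find_vector_model(name):
    # Flatten the vendor -> entries mapping and match on the display name.
    for vendor, entries in VECTOR_MODELS.items():
        for entry in entries:
            if entry['name'] == name:
                return entry
    raise KeyError(name)

entry = find_vector_model('Alibaba-gte-base')
print(entry['repo_id'])    # Alibaba-NLP/gte-base-en-v1.5
print(entry['cache_dir'])  # Alibaba-NLP--gte-base-en-v1.5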


@@ -224,6 +203,15 @@
'avg_vram_usage': '2.5 GB',
'function': 'Zephyr_1_6B'
},
'Internlm2_5 - 1.8b': {
'model': 'Internlm2_5 - 1.8b',
'repo_id': 'internlm/internlm2_5-1_8b-chat',
'cache_dir': 'internlm--internlm2_5-1_8b-chat',
'tokens_per_second': 55.51,
'context_length': 32768,
'avg_vram_usage': '2.8 GB',
'function': 'InternLM2_5_1_8b'
},
'Zephyr - 3b': {
'model': 'Zephyr - 3b',
'repo_id': 'stabilityai/stablelm-zephyr-3b',
@@ -279,15 +267,6 @@
'avg_vram_usage': '5.8 GB',
'function': 'Neural_Chat_7b'
},
'Internlm2 - 7b': {
'model': 'Internlm2 - 7b',
'repo_id': 'internlm/internlm2-chat-7b',
'cache_dir': 'internlm--internlm2-chat-7b',
'tokens_per_second': 36.83,
'context_length': 32768,
'avg_vram_usage': '6.7 GB',
'function': 'InternLM2_7b'
},
'Internlm2_5 - 7b': {
'model': 'Internlm2_5 - 7b',
'repo_id': 'internlm/internlm2_5-7b-chat',
@@ -390,12 +369,12 @@
},
'Internlm2 - 20b': {
'model': 'Internlm2 - 20b',
'repo_id': 'internlm/internlm2-chat-20b',
'cache_dir': 'internlm--internlm2-chat-20b',
'repo_id': 'internlm/internlm2_5-chat-20b',
'cache_dir': 'internlm--internlm2_5-chat-20b',
'tokens_per_second': 20.21,
'context_length': 32768,
'avg_vram_usage': '14.2 GB',
'function': 'InternLM2_20b'
'function': 'InternLM2_5_20b'
},
}
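Each CHAT_MODELS entry names its loader class in the 'function' field — e.g. the new 'Internlm2_5 - 1.8b' entry points at InternLM2_5_1_8b in module_chat.py. A sketch of how that string can be resolved to the class (getattr-based dispatch is an assumption; the actual dispatcher is not in this diff):

import module_chat
from constants import CHAT_MODELS

info = CHAT_MODELS['Internlm2_5 - 1.8b']
model_cls = getattr(module_chat, info['function'])  # -> InternLM2_5_1_8b
model = model_cls()  # loads internlm/internlm2_5-1_8b-chat per its settings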

19 changes: 13 additions & 6 deletions src/download_model.py
@@ -14,13 +14,16 @@ class ModelDownloadedSignal(QObject):
}

class ModelDownloader:
def __init__(self, model_name, model_type):
self.model_name = model_name
def __init__(self, model_info, model_type):
self.model_info = model_info
self.model_type = model_type
self._model_directory = None

def get_model_directory_name(self):
return self.model_name.replace("/", "--")
if isinstance(self.model_info, dict):
return self.model_info['cache_dir']
else:
return self.model_info.replace("/", "--")

def get_model_directory(self):
if not self._model_directory:
@@ -29,12 +32,15 @@ def get_model_directory(self):
return self._model_directory

def get_model_url(self):
return f"https://huggingface.co/{self.model_name}"
if isinstance(self.model_info, dict):
return f"https://huggingface.co/{self.model_info['repo_id']}"
else:
return f"https://huggingface.co/{self.model_info}"

def download_model(self):
model_url = self.get_model_url()
target_directory = self.get_model_directory()
print(f"Downloading {self.model_name}...")
print(f"Downloading {self.get_model_directory_name()}...")

env = os.environ.copy()
env["GIT_CLONE_PROTECTION_ACTIVE"] = "false"
@@ -46,6 +52,7 @@ def download_model(self):
env=env
)
print("\033[92mModel downloaded and ready to use.\033[0m")
model_downloaded_signal.downloaded.emit(self.model_name, self.model_type)
print(f"Emitting signal: {self.get_model_directory_name()}, {self.model_type}")
model_downloaded_signal.downloaded.emit(self.get_model_directory_name(), self.model_type)
except subprocess.CalledProcessError as e:
print(f"Command 'git clone' returned non-zero exit status {e.returncode}.")
39 changes: 24 additions & 15 deletions src/gui_tabs_databases.py
@@ -14,15 +14,16 @@

import database_interactions
from choose_documents_and_vector_model import select_embedding_model_directory, choose_documents_directory
from utilities import check_preconditions_for_db_creation, open_file, delete_file, backup_database
from utilities import check_preconditions_for_db_creation, open_file, delete_file, backup_database, get_pkl_file_path
from download_model import model_downloaded_signal

datasets_logger = logging.getLogger('datasets')
datasets_logger.setLevel(logging.WARNING)
# datasets_logger = logging.getLogger('datasets')
# datasets_logger.setLevel(logging.WARNING)

logging.getLogger("transformers").setLevel(logging.ERROR)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
logging.getLogger().setLevel(logging.WARNING)
# logging.getLogger("transformers").setLevel(logging.ERROR)
# warnings.filterwarnings("ignore", category=FutureWarning)
# warnings.filterwarnings("ignore", category=UserWarning)
# logging.getLogger().setLevel(logging.WARNING)

class CreateDatabaseThread(QThread):
creationComplete = Signal()
@@ -33,7 +34,7 @@ def __init__(self, database_name, parent=None):

def run(self):
create_vector_db = database_interactions.CreateVectorDB(database_name=self.database_name)
create_vector_db.run() # initiates database creation
create_vector_db.run() # INITIATES DB CREATION
self.update_config_with_database_name()
backup_database()

@@ -69,6 +70,7 @@ def __init__(self, parent=None):
def data(self, index, role=Qt.DisplayRole):
if role == Qt.DisplayRole and index.column() == 0:
file_path = super().filePath(index)
# opens the .pkl file and gets the file name from metadata
if file_path.endswith('.pkl'):
try:
with open(file_path, 'rb') as file:
@@ -82,6 +84,7 @@ def data(self, index, role=Qt.DisplayRole):
class DatabasesTab(QWidget):
def __init__(self):
super().__init__()
model_downloaded_signal.downloaded.connect(self.update_model_combobox)

self.layout = QVBoxLayout(self)
self.documents_group_box = self.create_group_box("Files To Add to Database", "Docs_for_DB")
@@ -122,7 +125,13 @@ def __init__(self):

self.sync_combobox_with_config()

def update_model_combobox(self, model_name, model_type):
if model_type == "vector":
self.populate_model_combobox()
self.sync_combobox_with_config()

def populate_model_combobox(self):
# 1. populates combobox when script loads
self.model_combobox.clear()
self.model_combobox.addItem("Select a model", None)

@@ -145,6 +154,7 @@ def populate_model_combobox(self):
print(f"Warning: No model directories found in {vector_dir}")

def sync_combobox_with_config(self):
# 2. after the script loads, sets the selected model to the one saved in the config
config_path = Path(__file__).resolve().parent / "config.yaml"
if config_path.exists():
with open(config_path, 'r', encoding='utf-8') as file:
@@ -164,6 +174,7 @@ def sync_combobox_with_config(self):
self.model_combobox.setCurrentIndex(0)

def on_model_selected(self, index):
# 3. updates the config when a user selects a different model
selected_path = self.model_combobox.itemData(index)
config_path = Path(__file__).resolve().parent / "config.yaml"
config_data = {}
@@ -221,15 +232,13 @@ def on_double_click(self, index):

if file_path.endswith('.pkl'):
try:
with open(file_path, 'rb') as file:
document = pickle.load(file)
internal_file_path = document.metadata.get('file_path')
if internal_file_path and Path(internal_file_path).exists():
internal_file_path = get_pkl_file_path(file_path)
if internal_file_path:
open_file(internal_file_path)
else:
QMessageBox.warning(self, "File Not Found", f"The file {internal_file_path} does not exist.")
except Exception as e:
QMessageBox.critical(self, "Error", f"Could not open the pickle file: {e}")
QMessageBox.warning(self, "File Not Found", f"The file from {file_path} does not exist.")
except ValueError as e:
QMessageBox.critical(self, "Error", str(e))
else:
open_file(file_path)

42 changes: 17 additions & 25 deletions src/module_chat.py
@@ -10,8 +10,8 @@
from constants import CHAT_MODELS, system_message
from utilities import my_cprint, bnb_bfloat16_settings, bnb_float16_settings, generate_settings_4096

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
# warnings.filterwarnings("ignore", category=FutureWarning)
# warnings.filterwarnings("ignore", category=UserWarning)

class BaseModel(ABC):
def __init__(self, model_info, settings, tokenizer_kwargs=None, model_kwargs=None, eos_token=None):
@@ -54,7 +54,7 @@ def generate_response(self, inputs):
inputs (dict): A dictionary of inputs prepared for the model.
Returns:
str: The full generated response as a string.
str: Chunks of generated response as a string.
"""
streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)

@@ -97,6 +97,19 @@
gc.collect()


class InternLM2_5_1_8b(BaseModel):
def __init__(self):
model_info = CHAT_MODELS['Internlm2_5 - 1.8b']
tokenizer_kwargs = {'trust_remote_code': True}
model_kwargs = {'trust_remote_code': True}
super().__init__(model_info, bnb_bfloat16_settings,
tokenizer_kwargs=tokenizer_kwargs,
model_kwargs=model_kwargs,
eos_token="<|im_end|>")

def create_prompt(self, augmented_query):
return f"<|begin_of_text|><|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{augmented_query}<|im_end|>\n<|im_start|>assistant\n"

class Dolphin_Llama3_8B(BaseModel):
def __init__(self):
model_info = CHAT_MODELS['Dolphin-Llama 3 - 8b']
@@ -143,7 +156,6 @@ def create_prompt(self, augmented_query):


class Dolphin_Qwen2_7b(BaseModel):
# Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
def __init__(self):
model_info = CHAT_MODELS['Dolphin-Qwen 2 - 7b']
super().__init__(model_info, bnb_bfloat16_settings)
@@ -153,9 +165,6 @@ def create_prompt(self, augmented_query):


class Dolphin_Qwen2_1_5b(BaseModel):
"""
Assistant: Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
"""
def __init__(self):
model_info = CHAT_MODELS['Dolphin-Qwen 2 - 1.5b']
super().__init__(model_info, bnb_bfloat16_settings)
@@ -173,7 +182,7 @@ def create_prompt(self, augmented_query):
return f"<|begin_of_text|><|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{augmented_query}<|im_end|>\n<|im_start|>assistant\n"


class InternLM2_20b(BaseModel):
class InternLM2_5_20b(BaseModel):
def __init__(self):
model_info = CHAT_MODELS['Internlm2 - 20b']
tokenizer_kwargs = {'trust_remote_code': True}
@@ -187,20 +196,6 @@ def create_prompt(self, augmented_query):
return f"<|begin_of_text|><|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{augmented_query}<|im_end|>\n<|im_start|>assistant\n"


class InternLM2_7b(BaseModel):
def __init__(self):
model_info = CHAT_MODELS['Internlm2 - 7b']
tokenizer_kwargs = {'trust_remote_code': True}
model_kwargs = {'trust_remote_code': True}
super().__init__(model_info, bnb_bfloat16_settings,
tokenizer_kwargs=tokenizer_kwargs,
model_kwargs=model_kwargs,
eos_token="<|im_end|>")

def create_prompt(self, augmented_query):
return f"<|begin_of_text|><|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{augmented_query}<|im_end|>\n<|im_start|>assistant\n"


class InternLM2_5_7b(BaseModel):
def __init__(self):
model_info = CHAT_MODELS['Internlm2_5 - 7b']
@@ -225,7 +220,6 @@ def create_prompt(self, augmented_query):


class Neural_Chat_7b(BaseModel):
# Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
def __init__(self):
model_info = CHAT_MODELS['Neural-Chat - 7b']
super().__init__(model_info, bnb_float16_settings)
@@ -271,7 +265,6 @@ def create_prompt(self, augmented_query):


class Zephyr_1_6B(BaseModel):
# Setting `pad_token_id` to `eos_token_id`:100257 for open-end generation.
def __init__(self):
model_info = CHAT_MODELS['Zephyr - 1.6b']
super().__init__(model_info, bnb_float16_settings)
@@ -281,7 +274,6 @@ def create_prompt(self, augmented_query):


class Zephyr_3B(BaseModel):
# Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
def __init__(self):
model_info = CHAT_MODELS['Zephyr - 3b']
super().__init__(model_info, bnb_bfloat16_settings)
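A usage sketch for the new InternLM2_5_1_8b class: generate_response streams chunks through TextIteratorStreamer rather than returning one finished string, so callers iterate over it. (Only create_prompt and generate_response appear in this diff; the input-preparation step below is an assumption.)

from module_chat import InternLM2_5_1_8b

model = InternLM2_5_1_8b()
prompt = model.create_prompt("What is a vector database?")
inputs = model.tokenizer(prompt, return_tensors="pt").to("cuda")  # prep step assumed
for chunk in model.generate_response(inputs):
    print(chunk, end="", flush=True)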
