Skip to content

Commit

Permalink
Add files via upload
Browse files (browse the repository at this point in the history)
  • Loading branch information
BBC-Esq authored Aug 12, 2024
1 parent c0cb10c commit c9c4774
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 12 deletions.
44 changes: 33 additions & 11 deletions src/database_interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,11 @@ def initialize_vector_model(self, embedding_model_name, config_data):
't5-xl': 1,
't5-large': 2,
'instructor-xl': 2,
't5-base': 4,
'bge-large': 4,
'instructor-large': 4,
'e5-large': 4,
'gte-large': 4,
'stella': 2,
't5-base': 3,
'bge-large': 3,
'instructor-large': 3,
'e5-large': 3,
'gte-large': 3,
'instructor-base': 8,
'mpnet': 8,
'e5-base': 8,
Expand Down Expand Up @@ -103,6 +102,21 @@ def initialize_vector_model(self, embedding_model_name, config_data):
query_instruction=query_instruction,
encode_kwargs=encode_kwargs
)

elif "Alibaba" in embedding_model_name:
model_kwargs["tokenizer_kwargs"] = {
"max_length": 8192,
"padding": True,
"truncation": True
}
# encode_kwargs['show_progress_bar'] = True
model = HuggingFaceEmbeddings(
model_name=embedding_model_name,
show_progress=True,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs
)

else:
model = HuggingFaceEmbeddings(
model_name=embedding_model_name,
Expand Down Expand Up @@ -241,7 +255,7 @@ def run(self):
config_data = self.load_config(self.ROOT_DIRECTORY)
EMBEDDING_MODEL_NAME = config_data.get("EMBEDDING_MODEL_NAME")

# create a list to hold langchain "document objects"
# create a list to hold langchain "document objects"
# langchain_core.documents.base.Document
documents = []

Expand All @@ -268,7 +282,7 @@ def run(self):
if len(audio_documents) > 0:
print(f"Loaded {len(audio_documents)} audio transcription(s)...")

texts = [] # listed created to hold split documents
texts = [] # list created to hold split documents

# split documents
if isinstance(documents, list) and documents:
Expand Down Expand Up @@ -332,11 +346,19 @@ def initialize_vector_model(self):
query_instruction=query_instruction,
encode_kwargs=encode_kwargs
)
elif "stella" in model_path:
encode_kwargs["prompt_name"] = "s2p_query"
elif "Alibaba" in model_path:
return HuggingFaceEmbeddings(
model_name=model_path,
model_kwargs={"device": compute_device, "trust_remote_code": True},
model_kwargs={
"device": compute_device,
"trust_remote_code": True,
"tokenizer_kwargs": {
"max_length": 8192,
"padding": True,
"truncation": True
}
},
# model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs
)
else:
Expand Down
2 changes: 1 addition & 1 deletion src/gui_tabs_databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def populate_model_combobox(self):
for folder in vector_dir.iterdir():
if folder.is_dir():
model_found = True
display_name = folder.name.split('--')[-1]
display_name = folder.name
full_path = str(folder)
self.model_combobox.addItem(display_name, full_path)

Expand Down

0 comments on commit c9c4774

Please sign in to comment.