Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Ollama option to choice of LLMs #143

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions LLM/ollama_language_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from threading import Thread
from ollama import Client
import torch

from LLM.chat import Chat
from baseHandler import BaseHandler
from rich.console import Console
import logging
from nltk import sent_tokenize

# Module-level logger; configured by the application's logging setup.
logger = logging.getLogger(__name__)

# Rich console for formatted terminal output.
console = Console()


# Maps Whisper's ISO-639-1 language codes to the full language names used
# when instructing the LLM to reply in a specific language.
WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
    "en": "english",
    "fr": "french",
    "es": "spanish",
    "zh": "chinese",
    "ja": "japanese",
    "ko": "korean",
    "hi": "hindi",
}

class OllamaLanguageModelHandler(BaseHandler):
    """
    Handles the language model part of the pipeline via an Ollama server.

    Streams chat completions from the configured Ollama endpoint while
    keeping a bounded conversation history in a `Chat` instance.
    """

    def setup(
        self,
        model_name="hf.co/HuggingFaceTB/SmolLM2-1.7B-Instruct-GGUF",
        device="",
        torch_dtype="",
        gen_kwargs=None,
        api_endpoint=None,
        user_role="user",
        chat_size=1,
        init_chat_role=None,
        init_chat_prompt="You are a helpful AI assistant.",
    ):
        """
        Configure the Ollama client and chat history, then warm up the model.

        Args:
            model_name: Ollama model tag used for chat completions.
            device, torch_dtype: unused here; kept for signature parity with
                the project's other LLM handlers.
            gen_kwargs: generation parameters collected from the CLI args.
                Defaults to an empty dict (was a mutable `{}` default, which
                Python shares across calls — replaced with a None sentinel).
                NOTE(review): these are stored but never forwarded to
                `client.chat`, so temperature / max-token settings currently
                have no effect — confirm whether they should be passed via
                the `options` parameter.
            api_endpoint: base URL of the Ollama server.
            user_role: role assigned to incoming user prompts.
            chat_size: number of interactions to retain in the history.
            init_chat_role: optional role (e.g. "system") for an initial
                context message.
            init_chat_prompt: content of the initial message; required when
                `init_chat_role` is set.

        Raises:
            ValueError: if `init_chat_role` is set without `init_chat_prompt`.
        """
        self.model_name = model_name
        self.client = Client(host=api_endpoint)

        self.gen_kwargs = gen_kwargs if gen_kwargs is not None else {}

        self.chat = Chat(chat_size)
        if init_chat_role:
            if not init_chat_prompt:
                raise ValueError(
                    "An initial prompt needs to be specified when setting init_chat_role."
                )
            self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
        self.user_role = user_role

        self.warmup()

    def warmup(self):
        """Issue an empty chat request so the server loads the model weights
        before the first real prompt arrives."""
        logger.info(f"Warming up {self.__class__.__name__}")
        self.client.chat(model=self.model_name, messages=[])

    def process(self, prompt):
        """
        Send the prompt to the Ollama server and stream the reply to stdout.

        `prompt` is either a plain string or a ``(text, language_code)``
        tuple from the STT stage; a ``-auto`` suffix on the code marks an
        auto-detected language, in which case the prompt is prefixed with an
        instruction to reply in that language.

        NOTE(review): unlike the other LLM handlers, nothing is yielded
        downstream (the `sent_tokenize` import is unused), so the TTS stage
        receives no text from this handler — confirm whether streamed chunks
        should be yielded as sentences.
        """
        logger.debug("infering language model...")
        language_code = None

        if isinstance(prompt, tuple):
            prompt, language_code = prompt
            if language_code[-5:] == "-auto":
                language_code = language_code[:-5]
                prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt

        self.chat.append({"role": self.user_role, "content": prompt})

        stream = self.client.chat(
            model=self.model_name,
            messages=self.chat.to_list(),
            stream=True,
        )

        generated_text = ""
        for chunk in stream:
            chunk_text = chunk["message"]["content"]
            generated_text += chunk_text
            print(chunk_text, end="", flush=True)

        self.chat.append({"role": "assistant", "content": generated_text})
4 changes: 2 additions & 2 deletions arguments_classes/module_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ class ModuleArguments:
stt: Optional[str] = field(
default="whisper",
metadata={
"help": "The STT to use. Either 'whisper', 'whisper-mlx', 'faster-whisper', and 'paraformer'. Default is 'whisper'."
"help": "The STT to use. Either 'moonshine', 'whisper', 'whisper-mlx', 'faster-whisper', and 'paraformer'. Default is 'whisper'."
},
)
llm: Optional[str] = field(
default="transformers",
metadata={
"help": "The LLM to use. Either 'transformers' or 'mlx-lm'. Default is 'transformers'"
"help": "The LLM to use. Either 'transformers', 'mlx-lm', 'openai' or 'ollama'. Default is 'transformers'"
},
)
tts: Optional[str] = field(
Expand Down
53 changes: 53 additions & 0 deletions arguments_classes/ollama_language_model_arguments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from dataclasses import dataclass, field


@dataclass
class OllamaLanguageModelHandlerArguments:
    """CLI arguments for the Ollama-backed language model handler.

    Every field carries the "ollama_" prefix, which the pipeline strips
    (via `rename_args`) before passing the values to the handler's `setup`.
    """

    ollama_model_name: str = field(
        default="hf.co/HuggingFaceTB/SmolLM2-1.7B-Instruct-GGUF",
        metadata={
            "help": "The pretrained language model to use. Default is 'hf.co/HuggingFaceTB/SmolLM2-1.7B-Instruct-GGUF'."
        },
    )
    ollama_user_role: str = field(
        default="user",
        metadata={
            "help": "Role assigned to the user in the chat context. Default is 'user'."
        },
    )
    ollama_init_chat_role: str = field(
        default="system",
        metadata={
            "help": "Initial role for setting up the chat context. Default is 'system'."
        },
    )
    ollama_init_chat_prompt: str = field(
        default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.",
        metadata={
            # Help text previously claimed a different default than the actual
            # field default; now describes the real one.
            "help": "The initial chat prompt to establish context for the language model. Defaults to a polite, concise assistant persona."
        },
    )
    ollama_gen_max_new_tokens: int = field(
        default=128,
        metadata={
            "help": "Maximum number of new tokens to generate in a single completion. Default is 128."
        },
    )
    ollama_gen_temperature: float = field(
        default=0.0,
        metadata={
            "help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."
        },
    )
    ollama_api_endpoint: str = field(
        default="http://localhost:11434",
        metadata={
            "help": "Ollama endpoint. Default is 'http://localhost:11434'"
        },
    )
    ollama_chat_size: int = field(
        default=2,
        metadata={
            # "assitant" typo fixed.
            "help": "Number of interactions assistant-user to keep for the chat. None for no limitations."
        },
    )
1 change: 0 additions & 1 deletion arguments_classes/open_api_language_model_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ class OpenApiLanguageModelHandlerArguments:
"help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'"
},
)

open_api_chat_size: int = field(
default=2,
metadata={
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ faster-whisper>=1.0.3
modelscope>=1.17.1
deepfilternet>=0.5.6
openai>=1.40.1
useful-moonshine @ git+https://github.com/andimarafioti/moonshine.git
useful-moonshine @ git+https://github.com/andimarafioti/moonshine.git
ollama>=0.3.3
3 changes: 2 additions & 1 deletion requirements_mac.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ faster-whisper>=1.0.3
modelscope>=1.17.1
deepfilternet>=0.5.6
openai>=1.40.1
useful-moonshine @ git+https://github.com/andimarafioti/moonshine.git
useful-moonshine @ git+https://github.com/andimarafioti/moonshine.git
ollama>=0.3.3
31 changes: 25 additions & 6 deletions s2s_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from arguments_classes.mlx_language_model_arguments import (
MLXLanguageModelHandlerArguments,
)
from arguments_classes.ollama_language_model_arguments import (
OllamaLanguageModelHandlerArguments,
)
from arguments_classes.module_arguments import ModuleArguments
from arguments_classes.paraformer_stt_arguments import ParaformerSTTHandlerArguments
from arguments_classes.parler_tts_arguments import ParlerTTSHandlerArguments
Expand Down Expand Up @@ -84,6 +87,7 @@ def parse_arguments():
LanguageModelHandlerArguments,
OpenApiLanguageModelHandlerArguments,
MLXLanguageModelHandlerArguments,
OllamaLanguageModelHandlerArguments,
ParlerTTSHandlerArguments,
MeloTTSHandlerArguments,
ChatTTSHandlerArguments,
Expand Down Expand Up @@ -173,6 +177,7 @@ def prepare_all_args(
language_model_handler_kwargs,
open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
ollama_language_model_handler_kwargs,
parler_tts_handler_kwargs,
melo_tts_handler_kwargs,
chat_tts_handler_kwargs,
Expand All @@ -186,6 +191,7 @@ def prepare_all_args(
language_model_handler_kwargs,
open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
ollama_language_model_handler_kwargs,
parler_tts_handler_kwargs,
melo_tts_handler_kwargs,
chat_tts_handler_kwargs,
Expand All @@ -197,6 +203,7 @@ def prepare_all_args(
rename_args(paraformer_stt_handler_kwargs, "paraformer_stt")
rename_args(language_model_handler_kwargs, "lm")
rename_args(mlx_language_model_handler_kwargs, "mlx_lm")
rename_args(ollama_language_model_handler_kwargs, "ollama")
rename_args(open_api_language_model_handler_kwargs, "open_api")
rename_args(parler_tts_handler_kwargs, "tts")
rename_args(melo_tts_handler_kwargs, "melo")
Expand Down Expand Up @@ -227,6 +234,7 @@ def build_pipeline(
language_model_handler_kwargs,
open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
ollama_language_model_handler_kwargs,
parler_tts_handler_kwargs,
melo_tts_handler_kwargs,
chat_tts_handler_kwargs,
Expand Down Expand Up @@ -278,7 +286,7 @@ def build_pipeline(
)

stt = get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, faster_whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs)
lm = get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, open_api_language_model_handler_kwargs, mlx_language_model_handler_kwargs)
lm = get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, open_api_language_model_handler_kwargs, mlx_language_model_handler_kwargs, ollama_language_model_handler_kwargs)
tts = get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs, facebook_mms_tts_handler_kwargs)

return ThreadManager([*comms_handlers, vad, stt, lm, tts])
Expand Down Expand Up @@ -326,7 +334,7 @@ def get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_
setup_kwargs=vars(faster_whisper_stt_handler_kwargs),
)
else:
raise ValueError("The STT should be either whisper, whisper-mlx, or paraformer.")
raise ValueError("The STT should be either moonshine, whisper, whisper-mlx, paraformer or faster-whisper.")


def get_llm_handler(
Expand All @@ -336,7 +344,8 @@ def get_llm_handler(
lm_response_queue,
language_model_handler_kwargs,
open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs
mlx_language_model_handler_kwargs,
ollama_language_model_handler_kwargs,
):
if module_kwargs.llm == "transformers":
from LLM.language_model import LanguageModelHandler
Expand All @@ -354,7 +363,6 @@ def get_llm_handler(
queue_out=lm_response_queue,
setup_kwargs=vars(open_api_language_model_handler_kwargs),
)

elif module_kwargs.llm == "mlx-lm":
from LLM.mlx_language_model import MLXLanguageModelHandler
return MLXLanguageModelHandler(
Expand All @@ -363,9 +371,17 @@ def get_llm_handler(
queue_out=lm_response_queue,
setup_kwargs=vars(mlx_language_model_handler_kwargs),
)
elif module_kwargs.llm == "ollama":
from LLM.ollama_language_model import OllamaLanguageModelHandler
return OllamaLanguageModelHandler(
stop_event,
queue_in=text_prompt_queue,
queue_out=lm_response_queue,
setup_kwargs=vars(ollama_language_model_handler_kwargs),
)

else:
raise ValueError("The LLM should be either transformers or mlx-lm")
raise ValueError("The LLM should be either transformers, open_ai, mlx-lm or ollama")


def get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs, facebook_mms_tts_handler_kwargs):
Expand Down Expand Up @@ -416,7 +432,7 @@ def get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chu
setup_kwargs=vars(facebook_mms_tts_handler_kwargs),
)
else:
raise ValueError("The TTS should be either parler, melo or chatTTS")
raise ValueError("The TTS should be either parler, melo, chatTTS or facebookMMS")


def main():
Expand All @@ -431,6 +447,7 @@ def main():
language_model_handler_kwargs,
open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
ollama_language_model_handler_kwargs,
parler_tts_handler_kwargs,
melo_tts_handler_kwargs,
chat_tts_handler_kwargs,
Expand All @@ -447,6 +464,7 @@ def main():
language_model_handler_kwargs,
open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
ollama_language_model_handler_kwargs,
parler_tts_handler_kwargs,
melo_tts_handler_kwargs,
chat_tts_handler_kwargs,
Expand All @@ -466,6 +484,7 @@ def main():
language_model_handler_kwargs,
open_api_language_model_handler_kwargs,
mlx_language_model_handler_kwargs,
ollama_language_model_handler_kwargs,
parler_tts_handler_kwargs,
melo_tts_handler_kwargs,
chat_tts_handler_kwargs,
Expand Down