From d54c40c141b8384f751474a054fd0dad437bd9f9 Mon Sep 17 00:00:00 2001 From: rchan Date: Thu, 28 Nov 2024 19:08:58 +0000 Subject: [PATCH 1/2] update requirements for ollama --- requirements.txt | 3 ++- requirements_mac.txt | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 9301497..d308b6e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,5 @@ faster-whisper>=1.0.3 modelscope>=1.17.1 deepfilternet>=0.5.6 openai>=1.40.1 -useful-moonshine @ git+https://github.com/andimarafioti/moonshine.git \ No newline at end of file +useful-moonshine @ git+https://github.com/andimarafioti/moonshine.git +ollama>=0.3.3 \ No newline at end of file diff --git a/requirements_mac.txt b/requirements_mac.txt index c6c7a5b..b521008 100644 --- a/requirements_mac.txt +++ b/requirements_mac.txt @@ -11,4 +11,5 @@ faster-whisper>=1.0.3 modelscope>=1.17.1 deepfilternet>=0.5.6 openai>=1.40.1 -useful-moonshine @ git+https://github.com/andimarafioti/moonshine.git \ No newline at end of file +useful-moonshine @ git+https://github.com/andimarafioti/moonshine.git +ollama>=0.3.3 \ No newline at end of file From e6e66cadde84e162226127e0a13c482bf7877764 Mon Sep 17 00:00:00 2001 From: rchan Date: Thu, 28 Nov 2024 19:09:25 +0000 Subject: [PATCH 2/2] add ollama to llm models --- LLM/ollama_language_model.py | 87 +++++++++++++++++++ arguments_classes/module_arguments.py | 4 +- .../ollama_language_model_arguments.py | 53 +++++++++++ .../open_api_language_model_arguments.py | 1 - s2s_pipeline.py | 31 +++++-- 5 files changed, 167 insertions(+), 9 deletions(-) create mode 100644 LLM/ollama_language_model.py create mode 100644 arguments_classes/ollama_language_model_arguments.py diff --git a/LLM/ollama_language_model.py b/LLM/ollama_language_model.py new file mode 100644 index 0000000..8a83c1d --- /dev/null +++ b/LLM/ollama_language_model.py @@ -0,0 +1,87 @@ +from threading import Thread +from ollama import Client +import torch + +from 
LLM.chat import Chat +from baseHandler import BaseHandler +from rich.console import Console +import logging +from nltk import sent_tokenize + +logger = logging.getLogger(__name__) + +console = Console() + + +WHISPER_LANGUAGE_TO_LLM_LANGUAGE = { + "en": "english", + "fr": "french", + "es": "spanish", + "zh": "chinese", + "ja": "japanese", + "ko": "korean", + "hi": "hindi", +} + +class OllamaLanguageModelHandler(BaseHandler): + """ + Handles the language model part. + """ + + def setup( + self, + model_name="hf.co/HuggingFaceTB/SmolLM2-1.7B-Instruct-GGUF", + device="", + torch_dtype="", + gen_kwargs={}, + api_endpoint=None, + user_role="user", + chat_size=1, + init_chat_role=None, + init_chat_prompt="You are a helpful AI assistant.", + ): + self.model_name = model_name + self.client = Client(host=api_endpoint) + + self.gen_kwargs = gen_kwargs + + self.chat = Chat(chat_size) + if init_chat_role: + if not init_chat_prompt: + raise ValueError( + "An initial prompt needs to be specified when setting init_chat_role." + ) + self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt}) + self.user_role = user_role + + self.warmup() + + def warmup(self): + logger.info(f"Warming up {self.__class__.__name__}") + self.client.chat(model=self.model_name, messages=[]) + + def process(self, prompt): + logger.debug("inferring language model...") + language_code = None + + if isinstance(prompt, tuple): + prompt, language_code = prompt + if language_code[-5:] == "-auto": + language_code = language_code[:-5] + prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. 
" + prompt + + self.chat.append({"role": self.user_role, "content": prompt}) + + stream = self.client.chat( + model=self.model_name, + messages=self.chat.to_list(), + stream=True, + ) + + generated_text = "" + for chunk in stream: + chunk_text = chunk['message']['content'] + generated_text += chunk_text + print(chunk_text, end='', flush=True) + + self.chat.append({"role": "assistant", "content": generated_text}) diff --git a/arguments_classes/module_arguments.py b/arguments_classes/module_arguments.py index d2f00e1..68e013d 100644 --- a/arguments_classes/module_arguments.py +++ b/arguments_classes/module_arguments.py @@ -23,13 +23,13 @@ class ModuleArguments: stt: Optional[str] = field( default="whisper", metadata={ - "help": "The STT to use. Either 'whisper', 'whisper-mlx', 'faster-whisper', and 'paraformer'. Default is 'whisper'." + "help": "The STT to use. Either 'moonshine', 'whisper', 'whisper-mlx', 'faster-whisper', and 'paraformer'. Default is 'whisper'." }, ) llm: Optional[str] = field( default="transformers", metadata={ - "help": "The LLM to use. Either 'transformers' or 'mlx-lm'. Default is 'transformers'" + "help": "The LLM to use. Either 'transformers', 'mlx-lm', 'openai' or 'ollama'. Default is 'transformers'" }, ) tts: Optional[str] = field( diff --git a/arguments_classes/ollama_language_model_arguments.py b/arguments_classes/ollama_language_model_arguments.py new file mode 100644 index 0000000..e6cf37d --- /dev/null +++ b/arguments_classes/ollama_language_model_arguments.py @@ -0,0 +1,53 @@ +from dataclasses import dataclass, field + + +@dataclass +class OllamaLanguageModelHandlerArguments: + ollama_model_name: str = field( + default="hf.co/HuggingFaceTB/SmolLM2-1.7B-Instruct-GGUF", + metadata={ + "help": "The pretrained language model to use. Default is 'hf.co/HuggingFaceTB/SmolLM2-1.7B-Instruct-GGUF'." + }, + ) + ollama_user_role: str = field( + default="user", + metadata={ + "help": "Role assigned to the user in the chat context. 
Default is 'user'." + }, + ) + ollama_init_chat_role: str = field( + default="system", + metadata={ + "help": "Initial role for setting up the chat context. Default is 'system'." + }, + ) + ollama_init_chat_prompt: str = field( + default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.", + metadata={ + "help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'" + }, + ) + ollama_gen_max_new_tokens: int = field( + default=128, + metadata={ + "help": "Maximum number of new tokens to generate in a single completion. Default is 128." + }, + ) + ollama_gen_temperature: float = field( + default=0.0, + metadata={ + "help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0." + }, + ) + ollama_api_endpoint: str = field( + default="http://localhost:11434", + metadata={ + "help": "Ollama endpoint. Default is 'http://localhost:11434'" + }, + ) + ollama_chat_size: int = field( + default=2, + metadata={ + "help": "Number of interactions assistant-user to keep for the chat. None for no limitations." + }, + ) diff --git a/arguments_classes/open_api_language_model_arguments.py b/arguments_classes/open_api_language_model_arguments.py index 2f07afa..d65cfac 100644 --- a/arguments_classes/open_api_language_model_arguments.py +++ b/arguments_classes/open_api_language_model_arguments.py @@ -29,7 +29,6 @@ class OpenApiLanguageModelHandlerArguments: "help": "The initial chat prompt to establish context for the language model. 
Default is 'You are a helpful AI assistant.'" }, ) - open_api_chat_size: int = field( default=2, metadata={ diff --git a/s2s_pipeline.py b/s2s_pipeline.py index 6ca0b8a..1f66f71 100644 --- a/s2s_pipeline.py +++ b/s2s_pipeline.py @@ -13,6 +13,9 @@ from arguments_classes.mlx_language_model_arguments import ( MLXLanguageModelHandlerArguments, ) +from arguments_classes.ollama_language_model_arguments import ( + OllamaLanguageModelHandlerArguments, +) from arguments_classes.module_arguments import ModuleArguments from arguments_classes.paraformer_stt_arguments import ParaformerSTTHandlerArguments from arguments_classes.parler_tts_arguments import ParlerTTSHandlerArguments @@ -84,6 +87,7 @@ def parse_arguments(): LanguageModelHandlerArguments, OpenApiLanguageModelHandlerArguments, MLXLanguageModelHandlerArguments, + OllamaLanguageModelHandlerArguments, ParlerTTSHandlerArguments, MeloTTSHandlerArguments, ChatTTSHandlerArguments, @@ -173,6 +177,7 @@ def prepare_all_args( language_model_handler_kwargs, open_api_language_model_handler_kwargs, mlx_language_model_handler_kwargs, + ollama_language_model_handler_kwargs, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs, @@ -186,6 +191,7 @@ def prepare_all_args( language_model_handler_kwargs, open_api_language_model_handler_kwargs, mlx_language_model_handler_kwargs, + ollama_language_model_handler_kwargs, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs, @@ -197,6 +203,7 @@ def prepare_all_args( rename_args(paraformer_stt_handler_kwargs, "paraformer_stt") rename_args(language_model_handler_kwargs, "lm") rename_args(mlx_language_model_handler_kwargs, "mlx_lm") + rename_args(ollama_language_model_handler_kwargs, "ollama") rename_args(open_api_language_model_handler_kwargs, "open_api") rename_args(parler_tts_handler_kwargs, "tts") rename_args(melo_tts_handler_kwargs, "melo") @@ -227,6 +234,7 @@ def build_pipeline( language_model_handler_kwargs, 
open_api_language_model_handler_kwargs, mlx_language_model_handler_kwargs, + ollama_language_model_handler_kwargs, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs, @@ -278,7 +286,7 @@ def build_pipeline( ) stt = get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_queue, whisper_stt_handler_kwargs, faster_whisper_stt_handler_kwargs, paraformer_stt_handler_kwargs) - lm = get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, open_api_language_model_handler_kwargs, mlx_language_model_handler_kwargs) + lm = get_llm_handler(module_kwargs, stop_event, text_prompt_queue, lm_response_queue, language_model_handler_kwargs, open_api_language_model_handler_kwargs, mlx_language_model_handler_kwargs, ollama_language_model_handler_kwargs) tts = get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs, facebook_mms_tts_handler_kwargs) return ThreadManager([*comms_handlers, vad, stt, lm, tts]) @@ -326,7 +334,7 @@ def get_stt_handler(module_kwargs, stop_event, spoken_prompt_queue, text_prompt_ setup_kwargs=vars(faster_whisper_stt_handler_kwargs), ) else: - raise ValueError("The STT should be either whisper, whisper-mlx, or paraformer.") + raise ValueError("The STT should be either moonshine, whisper, whisper-mlx, paraformer or faster-whisper.") def get_llm_handler( @@ -336,7 +344,8 @@ def get_llm_handler( lm_response_queue, language_model_handler_kwargs, open_api_language_model_handler_kwargs, - mlx_language_model_handler_kwargs + mlx_language_model_handler_kwargs, + ollama_language_model_handler_kwargs, ): if module_kwargs.llm == "transformers": from LLM.language_model import LanguageModelHandler @@ -354,7 +363,6 @@ def get_llm_handler( queue_out=lm_response_queue, setup_kwargs=vars(open_api_language_model_handler_kwargs), ) - elif module_kwargs.llm == 
"mlx-lm": from LLM.mlx_language_model import MLXLanguageModelHandler return MLXLanguageModelHandler( @@ -363,9 +371,17 @@ def get_llm_handler( queue_out=lm_response_queue, setup_kwargs=vars(mlx_language_model_handler_kwargs), ) + elif module_kwargs.llm == "ollama": + from LLM.ollama_language_model import OllamaLanguageModelHandler + return OllamaLanguageModelHandler( + stop_event, + queue_in=text_prompt_queue, + queue_out=lm_response_queue, + setup_kwargs=vars(ollama_language_model_handler_kwargs), + ) else: - raise ValueError("The LLM should be either transformers or mlx-lm") + raise ValueError("The LLM should be either transformers, open_ai, mlx-lm or ollama") def get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chunks_queue, should_listen, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs, facebook_mms_tts_handler_kwargs): @@ -416,7 +432,7 @@ def get_tts_handler(module_kwargs, stop_event, lm_response_queue, send_audio_chu setup_kwargs=vars(facebook_mms_tts_handler_kwargs), ) else: - raise ValueError("The TTS should be either parler, melo or chatTTS") + raise ValueError("The TTS should be either parler, melo, chatTTS or facebookMMS") def main(): @@ -431,6 +447,7 @@ def main(): language_model_handler_kwargs, open_api_language_model_handler_kwargs, mlx_language_model_handler_kwargs, + ollama_language_model_handler_kwargs, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs, @@ -447,6 +464,7 @@ def main(): language_model_handler_kwargs, open_api_language_model_handler_kwargs, mlx_language_model_handler_kwargs, + ollama_language_model_handler_kwargs, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs, @@ -466,6 +484,7 @@ def main(): language_model_handler_kwargs, open_api_language_model_handler_kwargs, mlx_language_model_handler_kwargs, + ollama_language_model_handler_kwargs, parler_tts_handler_kwargs, melo_tts_handler_kwargs, chat_tts_handler_kwargs,