From e54005570b8a6ed732caccb832f84862e44fd6f3 Mon Sep 17 00:00:00 2001
From: Jason Stock
Date: Mon, 4 Mar 2024 16:08:36 -0500
Subject: [PATCH 1/3] [MLC-26] model: default to
 mistralai/Mistral-7B-Instruct-v0.2 with 4bit quantization, also quantizing e5
 embedding model

---
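Notes: the sketch below is a review aid, not part of the commit. It restates
how get_mlx_path() in server/utils.py derives the conversion cache directory
after this patch: a "-q" suffix keeps quantized and unquantized conversions of
the same model from colliding. The helper name mlx_cache_path is illustrative
only, and the cache root assumes the default Hugging Face home.

    import os

    def mlx_cache_path(hf_path: str, quantize: bool = False) -> str:
        # Mirrors get_mlx_path() as changed in server/utils.py below.
        default_home = os.path.join(os.path.expanduser("~"), ".cache")
        name = f'models--{hf_path.replace("/", "--")}-mlx{"-q" if quantize else ""}'
        return os.path.join(default_home, 'huggingface', 'hub', name)

    print(mlx_cache_path('mistralai/Mistral-7B-Instruct-v0.2', quantize=True))
    # -> ~/.cache/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.2-mlx-q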
 app/main/main.ts                           |  2 +-
 app/src/app/settings/page.tsx              |  2 +-
 app/src/components/options/SelectModel.tsx |  1 +
 runner.py                                  |  2 +-
 runner.sh                                  |  1 +
 server/retriever/embeddings.py             |  8 ++++----
 server/server.py                           | 17 +++++++++++++----
 server/utils.py                            | 12 ++++++------
 8 files changed, 28 insertions(+), 17 deletions(-)

diff --git a/app/main/main.ts b/app/main/main.ts
index 089e738..5b76d2a 100644
--- a/app/main/main.ts
+++ b/app/main/main.ts
@@ -186,7 +186,7 @@ const store = new Store({
   },
   model: {
     type: 'string',
-    default: 'mlx-community/quantized-gemma-7b-it',
+    default: 'mistralai/Mistral-7B-Instruct-v0.2',
   },
   personalization: {
     type: 'string',
diff --git a/app/src/app/settings/page.tsx b/app/src/app/settings/page.tsx
index 1c3ac09..e836b86 100644
--- a/app/src/app/settings/page.tsx
+++ b/app/src/app/settings/page.tsx
@@ -81,7 +81,7 @@ function GeneralSettings() {
   const [model, setModel] = React.useState(
     typeof window !== 'undefined'
       ? window.electronAPI.fetchSetting('model')
-      : 'mlx-community/quantized-gemma-7b-it',
+      : 'mistralai/Mistral-7B-Instruct-v0.2',
   );
 
   useEffect(() => {
diff --git a/app/src/components/options/SelectModel.tsx b/app/src/components/options/SelectModel.tsx
index 2099619..08bf8b7 100644
--- a/app/src/components/options/SelectModel.tsx
+++ b/app/src/components/options/SelectModel.tsx
@@ -28,6 +28,7 @@ const SelectModel = ({
         AI Model
         LLama
+        Mistral
         Gemma
diff --git a/runner.py b/runner.py
index 1047c2e..e1c86b4 100644
--- a/runner.py
+++ b/runner.py
@@ -3,7 +3,7 @@
 # Example Usage:
 #
 # pyinstaller --onefile --collect-all mlx --copy-metadata opentelemetry-sdk \
-#   --hidden-import server.models --hidden-import server.models.gemma --hidden-import server.models.bert \
+#   --hidden-import server.models --hidden-import server.models.gemma --hidden-import server.models.bert --hidden-import server.models.llama \
 #   runner.py
 
 from server import server
diff --git a/runner.sh b/runner.sh
index 158f2ea..136981a 100755
--- a/runner.sh
+++ b/runner.sh
@@ -9,6 +9,7 @@ hidden_imports=(
   "server.models"
   "server.models.gemma"
   "server.models.bert"
+  "server.models.llama"
 )
 
 exclude_modules=(
diff --git a/server/retriever/embeddings.py b/server/retriever/embeddings.py
index 0352a92..af0220b 100644
--- a/server/retriever/embeddings.py
+++ b/server/retriever/embeddings.py
@@ -2,7 +2,7 @@
 import mlx.core as mx
 import mlx.nn as nn
-from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizer
+from transformers import PreTrainedTokenizer
 
 from abc import ABC, abstractmethod
 from typing import Any, List
 
@@ -26,10 +26,10 @@ class E5Embeddings(Embeddings):
     model: Any = None
     tokenizer: PreTrainedTokenizer = None
 
-    def __init__(self, hf_path: str = 'intfloat/multilingual-e5-small'):
-        mlx_path = get_mlx_path(hf_path)
+    def __init__(self, hf_path: str = 'intfloat/multilingual-e5-small', quantize: bool = False):
+        mlx_path = get_mlx_path(hf_path, quantize=quantize)
         if not os.path.isdir(mlx_path):
-            convert(hf_path, mlx_path)
+            convert(hf_path, mlx_path, quantize=quantize)
         self.model, self.tokenizer = load(mlx_path)
 
     def _average_pool(self, last_hidden_states: mx.array,
diff --git a/server/server.py b/server/server.py
index 4885b22..fa75033 100644
--- a/server/server.py
+++ b/server/server.py
@@ -1,7 +1,8 @@
+import os
+import sys
 import json
 import time
 import uuid
-import sys
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -10,7 +11,7 @@
 from typing import Optional
 from transformers import PreTrainedTokenizer
 
-from .utils import load, generate_step
+from .utils import load, generate_step, get_mlx_path, convert
 from .retriever.loader import directory_loader
 from .retriever.splitter import RecursiveCharacterTextSplitter
 
@@ -25,7 +26,15 @@ def load_model(model_path: str, adapter_file: Optional[str] = None):
     global _model
     global _tokenizer
-    _model, _tokenizer = load(model_path, adapter_file=adapter_file)
+
+    models_to_quantize = ['mistral', 'llama']
+    quantize = any(variable in model_path for variable in models_to_quantize)
+
+    mlx_path = get_mlx_path(model_path, quantize=quantize)
+    if not os.path.isdir(mlx_path):
+        convert(model_path, mlx_path, quantize=quantize)
+
+    _model, _tokenizer = load(mlx_path, adapter_file=adapter_file)
 
 
 def index_directory(directory: str, use_embedding: bool = True):
@@ -35,7 +44,7 @@ def index_directory(directory: str, use_embedding: bool = True):
     text_splitter = RecursiveCharacterTextSplitter(
         chunk_size=512, chunk_overlap=32, add_start_index=True
     )
-    embedding = E5Embeddings() if use_embedding else ChatEmbeddings(
+    embedding = E5Embeddings(quantize=True) if use_embedding else ChatEmbeddings(
         model=_model.model, tokenizer=_tokenizer)
     splits = text_splitter.split_documents(raw_docs)
     _database = Chroma.from_documents(
diff --git a/server/utils.py b/server/utils.py
index 0013d82..b55c7c1 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -549,10 +549,10 @@ def quantize_model(
     return quantized_weights, quantized_config
 
 
-def get_mlx_path(hf_path: str) -> str:
+def get_mlx_path(hf_path: str, quantize: bool = False) -> str:
     default_home = os.path.join(os.path.expanduser("~"), ".cache")
     return os.path.join(
-        default_home, 'huggingface', 'hub', f'models--{hf_path.replace("/", "--")}-mlx')
+        default_home, 'huggingface', 'hub', f'models--{hf_path.replace("/", "--")}-mlx{"-q" if quantize else ""}')
 
 
 def convert(
@@ -565,7 +565,7 @@ def convert(
     upload_repo: str = None,
     delete_old: bool = True,
 ):
-    print("[INFO] Loading")
+    print("[INFO] Loading", flush=True)
     model_path = get_model_path(hf_path)
     print(model_path, flush=True)
     model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
@@ -575,17 +575,17 @@ def convert(
     weights = {k: v.astype(dtype) for k, v in weights.items()}
 
     if quantize:
-        print("[INFO] Quantizing")
+        print("[INFO] Quantizing", flush=True)
         model.load_weights(list(weights.items()))
         weights, config = quantize_model(model, config, q_group_size, q_bits)
 
     if mlx_path is None:
-        mlx_path = get_mlx_path(hf_path)
+        mlx_path = get_mlx_path(hf_path, quantize)
 
     if isinstance(mlx_path, str):
         mlx_path = Path(mlx_path)
 
-    print(f"[INFO] Saving to {mlx_path}")
+    print(f"[INFO] Saving to {mlx_path}", flush=True)
 
     del model
     save_weights(mlx_path, weights, donate_weights=True)

From d5aa14ff254ec7d49e03cad071db7317d21476a4 Mon Sep 17 00:00:00 2001
From: Jason Stock
Date: Mon, 4 Mar 2024 21:24:45 -0500
Subject: [PATCH 2/3] [MLC-26] model: default to Gemma with prelim instruction
 tuning / prompt engineering

---
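Notes: a sketch of the reworked add_instructions() below, not part of the
commit. The function now prepends the user's saved settings to the last user
message instead of re-templating it, and it mutates `messages` in place
rather than returning the list, so callers must not rely on a return value.
The settings strings here are invented for illustration.

    messages = [{'role': 'user', 'content': 'What is MLX?'}]
    instructions = {'personalization': 'I am an ML researcher.',
                    'response': 'Answer in one short paragraph.'}

    # Same flow as the new add_instructions() in server/server.py:
    personalization = instructions.get('personalization', '').strip()
    response = instructions.get('response', '').strip()
    content = ''
    if personalization:
        content += f"You are an assistant who knows the following about me:\n{personalization}\n\n"
    if response:
        content += f"You are an assistant who responds based on the following specifications:\n{response}\n\n"
    messages[-1]['content'] = content + messages[-1]['content']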
 app/src/components/chat/Chat.tsx           |  6 ++-
 app/src/components/options/SelectModel.tsx |  5 +--
 server/server.py                           | 50 ++++++++++++----------
 3 files changed, 33 insertions(+), 28 deletions(-)

diff --git a/app/src/components/chat/Chat.tsx b/app/src/components/chat/Chat.tsx
index 3432cb2..0795720 100644
--- a/app/src/components/chat/Chat.tsx
+++ b/app/src/components/chat/Chat.tsx
@@ -43,9 +43,11 @@ const Chat = ({
       },
       body: JSON.stringify({
         messages: selectedDirectory ? [{ role: 'user', content: message }] : newHistory,
-        temperature: 1.0,
+        temperature: 0.7,
         // eslint-disable-next-line @typescript-eslint/naming-convention
-        max_tokens: 256,
+        top_p: 0.95,
+        // eslint-disable-next-line @typescript-eslint/naming-convention
+        max_tokens: 200,
         directory: selectedDirectory,
         instructions: {
           personalization: typeof window !== 'undefined'
diff --git a/app/src/components/options/SelectModel.tsx b/app/src/components/options/SelectModel.tsx
index 08bf8b7..c812fbd 100644
--- a/app/src/components/options/SelectModel.tsx
+++ b/app/src/components/options/SelectModel.tsx
@@ -27,9 +27,8 @@ const SelectModel = ({
         AI Model
-        LLama
-        Mistral
-        Gemma
+        Mistral7B
+        Gemma7B
diff --git a/server/server.py b/server/server.py
index fa75033..627a26d 100644
--- a/server/server.py
+++ b/server/server.py
@@ -8,7 +8,7 @@
 import mlx.nn as nn
 
 from http.server import BaseHTTPRequestHandler, HTTPServer
-from typing import Optional
+from typing import List, Dict, Optional
 from transformers import PreTrainedTokenizer
 
 from .utils import load, generate_step, get_mlx_path, convert
@@ -27,7 +27,7 @@ def load_model(model_path: str, adapter_file: Optional[str] = None):
     global _model
     global _tokenizer
 
-    models_to_quantize = ['mistral', 'llama']
+    models_to_quantize = ['mistral', 'llama', 'gemma']
     quantize = any(variable in model_path for variable in models_to_quantize)
 
     mlx_path = get_mlx_path(model_path, quantize=quantize)
@@ -86,7 +86,7 @@ def format_messages(messages, context):
     failedString = "ERROR"
     if context:
         messages[-1]['content'] = f"""
-Only using the documents in the index, answer the following, respond with just the answer without "The answer is:" or "Answer:" or anything like that.
+Only using the documents in the index, answer the following, respond with just the answer and never "The answer is:" or "Answer:" or anything like that.
 
 {messages[-1]['content']}
 
@@ -96,37 +96,37 @@ def format_messages(messages, context):
 
 {context}
 
-Remember, if you do not know the answer, just say "{failedString}", Try to give as much detail as possible, but only from what is provided within the index. If steps are given, you MUST ALWAYS use bullet points to list each of them them and you MUST use markdown when applicable. Only use information you can find in the index, do not make up knowledge.
-Remember, use bullet points or numbered steps to better organize your answer if applicable. NEVER try to make up the answer, always return "{failedString}" if you do not know the answer or it's not provided in the index.
-Never say "is not provided in the index", use "{failedString}" instead.
-    """.strip()
+""".strip()
     return messages
 
 
-def add_instructions(messages, instructions):
-    personalization = instructions.get('personalization', '')
-    response = instructions.get('response', '')
-    if len(personalization) > 0:
-        messages[-1]['content'] = f"""
-You are an assistant who knows the following about me:
-{personalization}
+def add_instructions(messages: List[Dict], instructions: Optional[Dict]):
+    personalization = instructions.get('personalization', '').strip()
+    response = instructions.get('response', '').strip()
 
-{messages[-1]['content']}
-""".strip()
+    if not personalization and not response:
+        return
 
-    if len(response) > 0:
-        messages[-1]['content'] = f"""
-You are an assistant who responds based on the following specifications:
-{response}
+    # content = '\n'
+    content = ''
+    if personalization:
+        content += f"You are an assistant who knows the following about me:\n{
+            personalization}\n\n"
+    if response:
+        content += f"You are an assistant who responds based on the following specifications:\n{
+            response}\n\n"
+    # content += 'Never explicitly reiterate this information.\n\n'
+    # content += ''
+    # content = content + \
+    #     f'\n{messages[-1]['content']}\n'
 
-{messages[-1]['content']}
-""".strip()
+    content = content + messages[-1]['content']
 
-    return messages
+    messages[-1]['content'] = content
 
 
 class APIHandler(BaseHTTPRequestHandler):
@@ -167,6 +167,10 @@ def do_POST(self):
             repetition_context_size: int,
             temperature: float,
             top_p: float,
+            instructions: {
+                personalization: str,
+                response: str
+            },
             directory: str
         }
         """

From cd31bacc7ff23a0209f4dc164e6860a123897acd Mon Sep 17 00:00:00 2001
From: Jason Stock
Date: Mon, 4 Mar 2024 21:28:39 -0500
Subject: [PATCH 3/3] [MLC-26] model: default to Gemma7b

---
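Notes: not part of the commit. Because load_model() (patch 2) decides whether
to quantize by substring match against the model path, the new default below
is converted with quantization (4-bit, per patch 1's subject) on first load.
A sketch of that check:

    models_to_quantize = ['mistral', 'llama', 'gemma']
    model_path = 'google/gemma-7b-it'  # the new default set by this patch
    quantize = any(variable in model_path for variable in models_to_quantize)
    print(quantize)  # True -> cached under the '-q' conversion directory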
 app/main/main.ts              | 2 +-
 app/src/app/settings/page.tsx | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/app/main/main.ts b/app/main/main.ts
index 5b76d2a..fd965da 100644
--- a/app/main/main.ts
+++ b/app/main/main.ts
@@ -186,7 +186,7 @@ const store = new Store({
   },
   model: {
     type: 'string',
-    default: 'mistralai/Mistral-7B-Instruct-v0.2',
+    default: 'google/gemma-7b-it',
   },
   personalization: {
     type: 'string',
diff --git a/app/src/app/settings/page.tsx b/app/src/app/settings/page.tsx
index e836b86..24ff7c9 100644
--- a/app/src/app/settings/page.tsx
+++ b/app/src/app/settings/page.tsx
@@ -81,7 +81,7 @@ function GeneralSettings() {
   const [model, setModel] = React.useState(
     typeof window !== 'undefined'
       ? window.electronAPI.fetchSetting('model')
-      : 'mistralai/Mistral-7B-Instruct-v0.2',
+      : 'google/gemma-7b-it',
   );
 
   useEffect(() => {