Merge pull request #22 from mlx-chat/MLC-26
[MLC-26] model: mistralai/Mistral-7B-Instruct-v0.2 & google/gemma-7b-it + prompt/instruction engineering
stockeh authored Mar 5, 2024
2 parents 23b1060 + cd31bac commit 136bf01
Showing 9 changed files with 59 additions and 43 deletions.
2 changes: 1 addition & 1 deletion app/main/main.ts
@@ -186,7 +186,7 @@ const store = new Store({
},
model: {
type: 'string',
default: 'mlx-community/quantized-gemma-7b-it',
default: 'google/gemma-7b-it',
},
personalization: {
type: 'string',
2 changes: 1 addition & 1 deletion app/src/app/settings/page.tsx
@@ -81,7 +81,7 @@ function GeneralSettings() {
const [model, setModel] = React.useState<string>(
typeof window !== 'undefined'
? window.electronAPI.fetchSetting('model')
: 'mlx-community/quantized-gemma-7b-it',
: 'google/gemma-7b-it',
);

useEffect(() => {
6 changes: 4 additions & 2 deletions app/src/components/chat/Chat.tsx
@@ -43,9 +43,11 @@ const Chat = ({
},
body: JSON.stringify({
messages: selectedDirectory ? [{ role: 'user', content: message }] : newHistory,
temperature: 1.0,
temperature: 0.7,
// eslint-disable-next-line @typescript-eslint/naming-convention
max_tokens: 256,
top_p: 0.95,
// eslint-disable-next-line @typescript-eslint/naming-convention
max_tokens: 200,
directory: selectedDirectory,
instructions: {
personalization: typeof window !== 'undefined'
4 changes: 2 additions & 2 deletions app/src/components/options/SelectModel.tsx
@@ -27,8 +27,8 @@ const SelectModel = ({
<SelectContent>
<SelectGroup>
<SelectLabel>AI Model</SelectLabel>
<SelectItem value='llama'>LLama</SelectItem>
<SelectItem value='mlx-community/quantized-gemma-7b-it'>Gemma</SelectItem>
<SelectItem value='mistralai/Mistral-7B-Instruct-v0.2'>Mistral7B</SelectItem>
<SelectItem value='google/gemma-7b-it'>Gemma7B</SelectItem>
</SelectGroup>
</SelectContent>
</Select>
2 changes: 1 addition & 1 deletion runner.py
@@ -3,7 +3,7 @@
# Example Usage:
#
# pyinstaller --onefile --collect-all mlx --copy-metadata opentelemetry-sdk \
# --hidden-import server.models --hidden-import server.models.gemma --hidden-import server.models.bert \
# --hidden-import server.models --hidden-import server.models.gemma --hidden-import server.models.bert --hidden-import server.models.llama \
# runner.py

from server import server
1 change: 1 addition & 0 deletions runner.sh
@@ -9,6 +9,7 @@ hidden_imports=(
"server.models"
"server.models.gemma"
"server.models.bert"
"server.models.llama"
)

exclude_modules=(
8 changes: 4 additions & 4 deletions server/retriever/embeddings.py
@@ -2,7 +2,7 @@
import mlx.core as mx
import mlx.nn as nn

from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizer
from transformers import PreTrainedTokenizer
from abc import ABC, abstractmethod
from typing import Any, List

@@ -26,10 +26,10 @@ class E5Embeddings(Embeddings):
model: Any = None
tokenizer: PreTrainedTokenizer = None

def __init__(self, hf_path: str = 'intfloat/multilingual-e5-small'):
mlx_path = get_mlx_path(hf_path)
def __init__(self, hf_path: str = 'intfloat/multilingual-e5-small', quantize: bool = False):
mlx_path = get_mlx_path(hf_path, quantize=quantize)
if not os.path.isdir(mlx_path):
convert(hf_path, mlx_path)
convert(hf_path, mlx_path, quantize=quantize)
self.model, self.tokenizer = load(mlx_path)

def _average_pool(self, last_hidden_states: mx.array,
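The quantize flag added above changes how the embedding model is prepared on first use. A minimal sketch of the assumed usage, mirroring the E5Embeddings(quantize=True) call that index_directory in server/server.py now makes:

from server.retriever.embeddings import E5Embeddings

# First construction converts the Hugging Face checkpoint to a quantized MLX
# copy under the local cache (see get_mlx_path in server/utils.py), then loads it.
embedding = E5Embeddings(quantize=True)  # defaults to 'intfloat/multilingual-e5-small'

Subsequent constructions find the cached "-q" directory and skip the conversion step.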
65 changes: 39 additions & 26 deletions server/server.py
@@ -1,16 +1,17 @@
import os
import sys
import json
import time
import uuid
import sys

import mlx.core as mx
import mlx.nn as nn

from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Optional
from typing import List, Dict, Optional
from transformers import PreTrainedTokenizer

from .utils import load, generate_step
from .utils import load, generate_step, get_mlx_path, convert

from .retriever.loader import directory_loader
from .retriever.splitter import RecursiveCharacterTextSplitter
@@ -25,7 +26,15 @@
def load_model(model_path: str, adapter_file: Optional[str] = None):
global _model
global _tokenizer
_model, _tokenizer = load(model_path, adapter_file=adapter_file)

models_to_quantize = ['mistral', 'llama', 'gemma']
quantize = any(variable in model_path for variable in models_to_quantize)

mlx_path = get_mlx_path(model_path, quantize=quantize)
if not os.path.isdir(mlx_path):
convert(model_path, mlx_path, quantize=quantize)

_model, _tokenizer = load(mlx_path, adapter_file=adapter_file)


def index_directory(directory: str, use_embedding: bool = True):
@@ -35,7 +44,7 @@ def index_directory(directory: str, use_embedding: bool = True):
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=512, chunk_overlap=32, add_start_index=True
)
embedding = E5Embeddings() if use_embedding else ChatEmbeddings(
embedding = E5Embeddings(quantize=True) if use_embedding else ChatEmbeddings(
model=_model.model, tokenizer=_tokenizer)
splits = text_splitter.split_documents(raw_docs)
_database = Chroma.from_documents(
@@ -77,7 +86,7 @@ def format_messages(messages, context):
failedString = "ERROR"
if context:
messages[-1]['content'] = f"""
Only using the documents in the index, answer the following, respond with just the answer without "The answer is:" or "Answer:" or anything like that.
Only using the documents in the index, answer the following, respond with just the answer and never "The answer is:" or "Answer:" or anything like that.
<BEGIN_QUESTION>
{messages[-1]['content']}
@@ -87,37 +96,37 @@
{context}
</END_INDEX>
Remember, if you do not know the answer, just say "{failedString}",
Try to give as much detail as possible, but only from what is provided within the index.
If steps are given, you MUST ALWAYS use bullet points to list each of them and you MUST use markdown when applicable.
Only use information you can find in the index, do not make up knowledge.
Remember, use bullet points or numbered steps to better organize your answer if applicable.
NEVER try to make up the answer, always return "{failedString}" if you do not know the answer or it's not provided in the index.
Never say "is not provided in the index", use "{failedString}" instead.
""".strip()
""".strip()
return messages


def add_instructions(messages, instructions):
personalization = instructions.get('personalization', '')
response = instructions.get('response', '')
if len(personalization) > 0:
messages[-1]['content'] = f"""
You are an assistant who knows the following about me:
{personalization}
def add_instructions(messages: List[Dict], instructions: Optional[Dict]):
personalization = instructions.get('personalization', '').strip()
response = instructions.get('response', '').strip()

{messages[-1]['content']}
""".strip()
if not personalization and not response:
return

if len(response) > 0:
messages[-1]['content'] = f"""
You are an assistant who responds based on the following specifications:
{response}
# content = '<BEGIN_INST>\n'
content = ''
if personalization:
content += f"You are an assistant who knows the following about me:\n{
personalization}\n\n"
if response:
content += f"You are an assistant who responds based on the following specifications:\n{
response}\n\n"
# content += 'Never explicitly reiterate this information.\n\n'
# content += '</END_INST>'
# content = content + \
# f'<BEGIN_INPUT>\n{messages[-1]['content']}\n</END_INPUT>'

{messages[-1]['content']}
""".strip()
content = content + messages[-1]['content']

return messages
messages[-1]['content'] = content
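For reference, a rough worked example of what the reworked add_instructions above produces for the last user message (whitespace approximate; the exact wording comes from the f-strings in the function):

messages = [{'role': 'user', 'content': 'What should I eat for lunch?'}]
instructions = {'personalization': 'I am vegetarian.',
                'response': 'Answer in one short sentence.'}

add_instructions(messages, instructions)

# messages[-1]['content'] is now roughly:
#   You are an assistant who knows the following about me:
#   I am vegetarian.
#
#   You are an assistant who responds based on the following specifications:
#   Answer in one short sentence.
#
#   What should I eat for lunch?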


class APIHandler(BaseHTTPRequestHandler):
@@ -158,6 +167,10 @@ def do_POST(self):
repetition_context_size: int,
temperature: float,
top_p: float,
instructions: {
personalization: str,
response: str
},
directory: str
}
"""
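A hedged sketch of a client request matching the body documented in the do_POST docstring above and the updated Chat.tsx payload; the localhost address, port, and root path are assumptions, not taken from this diff:

import json
from urllib import request

payload = {
    'messages': [{'role': 'user', 'content': 'Summarize the indexed notes.'}],
    'temperature': 0.7,
    'top_p': 0.95,
    'max_tokens': 200,
    'instructions': {'personalization': '', 'response': ''},
    'directory': '/path/to/notes',  # optional; enables retrieval over the index
}

req = request.Request(
    'http://localhost:8080/',  # assumed host, port, and path
    data=json.dumps(payload).encode('utf-8'),
    headers={'Content-Type': 'application/json'},
)
with request.urlopen(req) as resp:
    print(resp.read().decode('utf-8'))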
12 changes: 6 additions & 6 deletions server/utils.py
@@ -549,10 +549,10 @@ def quantize_model(
return quantized_weights, quantized_config


def get_mlx_path(hf_path: str) -> str:
def get_mlx_path(hf_path: str, quantize: bool = False) -> str:
default_home = os.path.join(os.path.expanduser("~"), ".cache")
return os.path.join(
default_home, 'huggingface', 'hub', f'models--{hf_path.replace("/", "--")}-mlx')
default_home, 'huggingface', 'hub', f'models--{hf_path.replace("/", "--")}-mlx{"-q" if quantize else ""}')


def convert(
@@ -565,7 +565,7 @@
upload_repo: str = None,
delete_old: bool = True,
):
print("[INFO] Loading")
print("[INFO] Loading", flush=True)
model_path = get_model_path(hf_path)
print(model_path, flush=True)
model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
@@ -575,17 +575,17 @@
weights = {k: v.astype(dtype) for k, v in weights.items()}

if quantize:
print("[INFO] Quantizing")
print("[INFO] Quantizing", flush=True)
model.load_weights(list(weights.items()))
weights, config = quantize_model(model, config, q_group_size, q_bits)

if mlx_path is None:
mlx_path = get_mlx_path(hf_path)
mlx_path = get_mlx_path(hf_path, quantize)

if isinstance(mlx_path, str):
mlx_path = Path(mlx_path)

print(f"[INFO] Saving to {mlx_path}")
print(f"[INFO] Saving to {mlx_path}", flush=True)

del model
save_weights(mlx_path, weights, donate_weights=True)
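To illustrate the new cache-path suffix from get_mlx_path, a quick check (paths expand under the current user's home directory):

from server.utils import get_mlx_path

# quantize=True appends "-q", so quantized and full-precision conversions
# of the same checkpoint land in separate cache directories.
print(get_mlx_path('google/gemma-7b-it', quantize=True))
# .../.cache/huggingface/hub/models--google--gemma-7b-it-mlx-q
print(get_mlx_path('google/gemma-7b-it'))
# .../.cache/huggingface/hub/models--google--gemma-7b-it-mlx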
