[MLC-26] model: mistralai/Mistral-7B-Instruct-v0.2 & google/gemma-7b-it + prompt/instruction engineering #22

Merged · 3 commits · Mar 5, 2024
2 changes: 1 addition & 1 deletion app/main/main.ts
@@ -186,7 +186,7 @@ const store = new Store({
     },
     model: {
       type: 'string',
-      default: 'mlx-community/quantized-gemma-7b-it',
+      default: 'google/gemma-7b-it',
     },
     personalization: {
       type: 'string',
2 changes: 1 addition & 1 deletion app/src/app/settings/page.tsx
@@ -81,7 +81,7 @@ function GeneralSettings() {
   const [model, setModel] = React.useState<string>(
     typeof window !== 'undefined'
       ? window.electronAPI.fetchSetting('model')
-      : 'mlx-community/quantized-gemma-7b-it',
+      : 'google/gemma-7b-it',
   );

   useEffect(() => {
6 changes: 4 additions & 2 deletions app/src/components/chat/Chat.tsx
@@ -43,9 +43,11 @@ const Chat = ({
       },
       body: JSON.stringify({
         messages: selectedDirectory ? [{ role: 'user', content: message }] : newHistory,
-        temperature: 1.0,
+        temperature: 0.7,
         // eslint-disable-next-line @typescript-eslint/naming-convention
-        max_tokens: 256,
+        top_p: 0.95,
+        // eslint-disable-next-line @typescript-eslint/naming-convention
+        max_tokens: 200,
         directory: selectedDirectory,
         instructions: {
           personalization: typeof window !== 'undefined'
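Note: with this change the renderer now sends `top_p` alongside a lower `temperature` and a smaller `max_tokens` budget. A minimal sketch for exercising the same request body against the local runner — the host, port, and path are assumptions, not taken from this repo; adjust them to however the Electron app actually spawns the server:

```python
# Hypothetical smoke test for the runner's chat endpoint; mirrors the payload
# Chat.tsx now builds. The URL below is an assumption, not part of this PR.
import requests

payload = {
    "messages": [{"role": "user", "content": "Summarize the indexed documents."}],
    "temperature": 0.7,   # lowered from 1.0
    "top_p": 0.95,        # newly sent nucleus-sampling cutoff
    "max_tokens": 200,    # reduced from 256
    "directory": None,
    "instructions": {"personalization": "", "response": ""},
}

resp = requests.post("http://localhost:8080/", json=payload, timeout=120)
print(resp.text)
```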
4 changes: 2 additions & 2 deletions app/src/components/options/SelectModel.tsx
@@ -27,8 +27,8 @@ const SelectModel = ({
       <SelectContent>
         <SelectGroup>
           <SelectLabel>AI Model</SelectLabel>
-          <SelectItem value='llama'>LLama</SelectItem>
-          <SelectItem value='mlx-community/quantized-gemma-7b-it'>Gemma</SelectItem>
+          <SelectItem value='mistralai/Mistral-7B-Instruct-v0.2'>Mistral7B</SelectItem>
+          <SelectItem value='google/gemma-7b-it'>Gemma7B</SelectItem>
         </SelectGroup>
       </SelectContent>
     </Select>
2 changes: 1 addition & 1 deletion runner.py
@@ -3,7 +3,7 @@
 # Example Usage:
 #
 # pyinstaller --onefile --collect-all mlx --copy-metadata opentelemetry-sdk \
-#     --hidden-import server.models --hidden-import server.models.gemma --hidden-import server.models.bert \
+#     --hidden-import server.models --hidden-import server.models.gemma --hidden-import server.models.bert --hidden-import server.models.llama \
 #     runner.py

 from server import server
1 change: 1 addition & 0 deletions runner.sh
@@ -9,6 +9,7 @@ hidden_imports=(
   "server.models"
   "server.models.gemma"
   "server.models.bert"
+  "server.models.llama"
 )

 exclude_modules=(
8 changes: 4 additions & 4 deletions server/retriever/embeddings.py
@@ -2,7 +2,7 @@
 import mlx.core as mx
 import mlx.nn as nn

-from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizer
+from transformers import PreTrainedTokenizer
 from abc import ABC, abstractmethod
 from typing import Any, List

@@ -26,10 +26,10 @@ class E5Embeddings(Embeddings):
     model: Any = None
     tokenizer: PreTrainedTokenizer = None

-    def __init__(self, hf_path: str = 'intfloat/multilingual-e5-small'):
-        mlx_path = get_mlx_path(hf_path)
+    def __init__(self, hf_path: str = 'intfloat/multilingual-e5-small', quantize: bool = False):
+        mlx_path = get_mlx_path(hf_path, quantize=quantize)
         if not os.path.isdir(mlx_path):
-            convert(hf_path, mlx_path)
+            convert(hf_path, mlx_path, quantize=quantize)
         self.model, self.tokenizer = load(mlx_path)

     def _average_pool(self, last_hidden_states: mx.array,
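For reference, a small sketch (not part of the diff) of how the new `quantize` flag is meant to be used: `index_directory()` in server.py now constructs the embedder with `quantize=True`, and the converted weights land in a separate `-q` cache directory so they never overwrite a full-precision copy.

```python
# Sketch only: shows where a quantized E5 conversion is cached relative to the
# full-precision one; the printed path is abbreviated (get_mlx_path expands ~).
from server.retriever.embeddings import E5Embeddings
from server.utils import get_mlx_path

emb = E5Embeddings(quantize=True)  # as index_directory() now calls it
print(get_mlx_path('intfloat/multilingual-e5-small', quantize=True))
# .../huggingface/hub/models--intfloat--multilingual-e5-small-mlx-q
```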
65 changes: 39 additions & 26 deletions server/server.py
@@ -1,16 +1,17 @@
+import os
+import sys
 import json
 import time
 import uuid
-import sys

 import mlx.core as mx
 import mlx.nn as nn

 from http.server import BaseHTTPRequestHandler, HTTPServer
-from typing import Optional
+from typing import List, Dict, Optional
 from transformers import PreTrainedTokenizer

-from .utils import load, generate_step
+from .utils import load, generate_step, get_mlx_path, convert

 from .retriever.loader import directory_loader
 from .retriever.splitter import RecursiveCharacterTextSplitter
@@ -25,7 +26,15 @@
 def load_model(model_path: str, adapter_file: Optional[str] = None):
     global _model
     global _tokenizer
-    _model, _tokenizer = load(model_path, adapter_file=adapter_file)
+
+    models_to_quantize = ['mistral', 'llama', 'gemma']
+    quantize = any(variable in model_path for variable in models_to_quantize)
+
+    mlx_path = get_mlx_path(model_path, quantize=quantize)
+    if not os.path.isdir(mlx_path):
+        convert(model_path, mlx_path, quantize=quantize)
+
+    _model, _tokenizer = load(mlx_path, adapter_file=adapter_file)


 def index_directory(directory: str, use_embedding: bool = True):
@@ -35,7 +44,7 @@ def index_directory(directory: str, use_embedding: bool = True):
     text_splitter = RecursiveCharacterTextSplitter(
         chunk_size=512, chunk_overlap=32, add_start_index=True
     )
-    embedding = E5Embeddings() if use_embedding else ChatEmbeddings(
+    embedding = E5Embeddings(quantize=True) if use_embedding else ChatEmbeddings(
         model=_model.model, tokenizer=_tokenizer)
     splits = text_splitter.split_documents(raw_docs)
     _database = Chroma.from_documents(
@@ -77,7 +86,7 @@ def format_messages(messages, context):
     failedString = "ERROR"
     if context:
         messages[-1]['content'] = f"""
-Only using the documents in the index, answer the following, respond with just the answer without "The answer is:" or "Answer:" or anything like that.
+Only using the documents in the index, answer the following, respond with just the answer and never "The answer is:" or "Answer:" or anything like that.
 <BEGIN_QUESTION>
 {messages[-1]['content']}
@@ -87,37 +96,37 @@ def format_messages(messages, context):
 {context}
 </END_INDEX>
 Remember, if you do not know the answer, just say "{failedString}",
-Try to give as much detail as possible, but only from what is provided within the index.
-If steps are given, you MUST ALWAYS use bullet points to list each of them them and you MUST use markdown when applicable.
-Only use information you can find in the index, do not make up knowledge.
+Remember, use bullet points or numbered steps to better organize your answer if applicable.
+NEVER try to make up the answer, always return "{failedString}" if you do not know the answer or it's not provided in the index.
+Never say "is not provided in the index", use "{failedString}" instead.
-""".strip()
+""".strip()
     return messages


-def add_instructions(messages, instructions):
-    personalization = instructions.get('personalization', '')
-    response = instructions.get('response', '')
-    if len(personalization) > 0:
-        messages[-1]['content'] = f"""
-You are an assistant who knows the following about me:
-{personalization}
+def add_instructions(messages: List[Dict], instructions: Optional[Dict]):
+    personalization = instructions.get('personalization', '').strip()
+    response = instructions.get('response', '').strip()

-{messages[-1]['content']}
-""".strip()
+    if not personalization and not response:
+        return

-    if len(response) > 0:
-        messages[-1]['content'] = f"""
-You are an assistant who responds based on the following specifications:
-{response}
+    # content = '<BEGIN_INST>\n'
+    content = ''
+    if personalization:
+        content += f"You are an assistant who knows the following about me:\n{
+            personalization}\n\n"
+    if response:
+        content += f"You are an assistant who responds based on the following specifications:\n{
+            response}\n\n"
+    # content += 'Never explicitly reiterate this information.\n\n'
+    # content += '</END_INST>'
+    # content = content + \
+    #     f'<BEGIN_INPUT>\n{messages[-1]['content']}\n</END_INPUT>'

-{messages[-1]['content']}
-""".strip()
+    content = content + messages[-1]['content']

-    return messages
+    messages[-1]['content'] = content


 class APIHandler(BaseHTTPRequestHandler):
@@ -158,6 +167,10 @@ def do_POST(self):
             repetition_context_size: int,
             temperature: float,
             top_p: float,
+            instructions: {
+                personalization: str,
+                response: str
+            },
             directory: str
         }
         """
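In short, `load_model()` no longer expects a pre-converted MLX repo: it decides whether to quantize from the model path, converts the Hugging Face checkpoint into the MLX cache on first use, and loads from there. A rough sketch of that flow, mirroring the diff above rather than adding behaviour (not the actual server module):

```python
# Minimal sketch of the new load_model() flow. Both models wired up by this PR
# match the substring check, so they are converted with quantization on first run.
import os
from server.utils import load, get_mlx_path, convert

def load_model_sketch(model_path: str):
    models_to_quantize = ['mistral', 'llama', 'gemma']
    quantize = any(name in model_path for name in models_to_quantize)

    # e.g. google/gemma-7b-it -> ~/.cache/huggingface/hub/models--google--gemma-7b-it-mlx-q
    mlx_path = get_mlx_path(model_path, quantize=quantize)
    if not os.path.isdir(mlx_path):
        convert(model_path, mlx_path, quantize=quantize)  # one-time HF -> MLX conversion

    return load(mlx_path)

model, tokenizer = load_model_sketch('mistralai/Mistral-7B-Instruct-v0.2')
```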
12 changes: 6 additions & 6 deletions server/utils.py
@@ -549,10 +549,10 @@ def quantize_model(
     return quantized_weights, quantized_config


-def get_mlx_path(hf_path: str) -> str:
+def get_mlx_path(hf_path: str, quantize: bool = False) -> str:
     default_home = os.path.join(os.path.expanduser("~"), ".cache")
     return os.path.join(
-        default_home, 'huggingface', 'hub', f'models--{hf_path.replace("/", "--")}-mlx')
+        default_home, 'huggingface', 'hub', f'models--{hf_path.replace("/", "--")}-mlx{"-q" if quantize else ""}')


 def convert(
@@ -565,7 +565,7 @@ def convert(
     upload_repo: str = None,
     delete_old: bool = True,
 ):
-    print("[INFO] Loading")
+    print("[INFO] Loading", flush=True)
     model_path = get_model_path(hf_path)
     print(model_path, flush=True)
     model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
@@ -575,17 +575,17 @@ def convert(
     weights = {k: v.astype(dtype) for k, v in weights.items()}

     if quantize:
-        print("[INFO] Quantizing")
+        print("[INFO] Quantizing", flush=True)
         model.load_weights(list(weights.items()))
         weights, config = quantize_model(model, config, q_group_size, q_bits)

     if mlx_path is None:
-        mlx_path = get_mlx_path(hf_path)
+        mlx_path = get_mlx_path(hf_path, quantize)

     if isinstance(mlx_path, str):
         mlx_path = Path(mlx_path)

-    print(f"[INFO] Saving to {mlx_path}")
+    print(f"[INFO] Saving to {mlx_path}", flush=True)

     del model
     save_weights(mlx_path, weights, donate_weights=True)
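The `-q` suffix is what keeps the new quantized conversions from colliding with any full-precision conversion already sitting in the Hugging Face cache. For illustration (home directory abbreviated in the comments):

```python
from server.utils import get_mlx_path

print(get_mlx_path('google/gemma-7b-it'))
# ~/.cache/huggingface/hub/models--google--gemma-7b-it-mlx
print(get_mlx_path('google/gemma-7b-it', quantize=True))
# ~/.cache/huggingface/hub/models--google--gemma-7b-it-mlx-q
```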