diff --git a/server/LORA.md b/server/LORA.md
deleted file mode 100644
index 445b929..0000000
--- a/server/LORA.md
+++ /dev/null
@@ -1,167 +0,0 @@
-# Fine-Tuning with LoRA or QLoRA
-
-You can use the `mlx-lm` package to fine-tune an LLM with low-rank
-adaptation (LoRA) for a target task.[^lora] The example also supports quantized
-LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
-
-- Mistral
-- Llama
-- Phi2
-- Mixtral
-- Qwen2
-- OLMo
-
-## Contents
-
-* [Run](#Run)
-  * [Fine-tune](#Fine-tune)
-  * [Evaluate](#Evaluate)
-  * [Generate](#Generate)
-* [Fuse and Upload](#Fuse-and-Upload)
-* [Data](#Data)
-* [Memory Issues](#Memory-Issues)
-
-## Run
-
-The main command is `mlx_lm.lora`. To see a full list of options run:
-
-```shell
-python -m mlx_lm.lora --help
-```
-
-Note that in the following, the `--model` argument can be any compatible
-Hugging Face repo or a local path to a converted model.
-
-### Fine-tune
-
-To fine-tune a model use:
-
-```shell
-python -m mlx_lm.lora \
-    --model <path_to_model> \
-    --train \
-    --data <path_to_data> \
-    --iters 600
-```
-
-When using `--train`, the `--data` argument must point to a directory
-containing `train.jsonl` and `valid.jsonl`; when using `--test`, it must point
-to a directory containing `test.jsonl`. For more details on the data format
-see the section on [Data](#Data).
-
-For example, to fine-tune a Mistral 7B you can use `--model
-mistralai/Mistral-7B-v0.1`.
-
-If `--model` points to a quantized model, then the training will use QLoRA,
-otherwise it will use regular LoRA.
-
-By default, the adapter weights are saved in `adapters.npz`. You can specify
-the output location with `--adapter-file`.
-
-You can resume fine-tuning with an existing adapter with
-`--resume-adapter-file <path_to_adapters.npz>`.
-
-### Evaluate
-
-To compute test set perplexity use:
-
-```shell
-python -m mlx_lm.lora \
-    --model <path_to_model> \
-    --adapter-file <path_to_adapters.npz> \
-    --data <path_to_data> \
-    --test
-```
-
-## Fuse and Upload
-
-You can generate a model fused with the low-rank adapters using the
-`mlx_lm.fuse` command. This command also allows you to upload the fused model
-to the Hugging Face Hub.
-
-To see supported options run:
-
-```shell
-python -m mlx_lm.fuse --help
-```
-
-To generate the fused model run:
-
-```shell
-python -m mlx_lm.fuse --model <path_to_model>
-```
-
-This will by default load the adapters from `adapters.npz`, and save the fused
-model in the path `lora_fused_model/`. All of these are configurable.
-
-To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
-to `mlx_lm.fuse`. The latter is the repo name of the original model, which is
-useful for the sake of attribution and model versioning.
-
-For example, to fuse and upload a model derived from Mistral-7B-v0.1, run:
-
-```shell
-python -m mlx_lm.fuse \
-    --model mistralai/Mistral-7B-v0.1 \
-    --upload-repo mlx-community/my-4bit-lora-mistral \
-    --hf-path mistralai/Mistral-7B-v0.1
-```
-
-## Data
-
-The LoRA command expects you to provide a dataset with `--data`. The MLX
-Examples GitHub repo has an [example of the WikiSQL
-data](https://github.com/ml-explore/mlx-examples/tree/main/lora/data) in the
-correct format.
-
-For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
-`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
-loader expects a `test.jsonl` in the data directory. Each line in the `*.jsonl`
-file should look like:
-
-```
-{"text": "This is an example for the model."}
-```
-
-Note that other keys will be ignored by the loader.
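As a quick, illustrative sketch (not part of the original docs), the snippet below writes a tiny `train.jsonl` and `valid.jsonl` in the expected format; the example texts and the `data/` directory name are placeholders:

```python
import json
from pathlib import Path

# Placeholder texts; in practice these come from your own corpus.
examples = {
    "train": ["This is an example for the model.", "Another training example."],
    "valid": ["A held-out validation example."],
}

data_dir = Path("data")
data_dir.mkdir(exist_ok=True)

for split, texts in examples.items():
    with open(data_dir / f"{split}.jsonl", "w") as f:
        for text in texts:
            # One JSON object per line; only the "text" key is read by the loader.
            f.write(json.dumps({"text": text}) + "\n")
```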
-
-## Memory Issues
-
-Fine-tuning a large model with LoRA requires a machine with a decent amount
-of memory. Here are some tips to reduce memory use should you need to do so:
-
-1. Try quantization (QLoRA). You can use QLoRA by generating a quantized model
-   with `convert.py` and the `-q` flag.
-
-2. Try using a smaller batch size with `--batch-size`. The default is `4`, so
-   setting this to `2` or `1` will reduce memory consumption at the cost of
-   slightly slower training.
-
-3. Reduce the number of layers to fine-tune with `--lora-layers`. The default
-   is `16`, so you can try `8` or `4`. This reduces the amount of memory
-   needed for back propagation. It may also reduce the quality of the
-   fine-tuned model if you are fine-tuning with a lot of data.
-
-4. Longer examples require more memory. If it makes sense for your data, break
-   long examples into smaller sequences when making the
-   `{train, valid, test}.jsonl` files.
-
-For example, for a machine with 32 GB the following should run reasonably fast:
-
-```shell
-python -m mlx_lm.lora \
-    --model mistralai/Mistral-7B-v0.1 \
-    --train \
-    --batch-size 1 \
-    --lora-layers 4 \
-    --data wikisql
-```
-
-The above command on an M1 Max with 32 GB runs at about 250 tokens per second,
-using the MLX Examples
-[`wikisql`](https://github.com/ml-explore/mlx-examples/tree/main/lora/data)
-data set.
-
-
-[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
-[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314).
diff --git a/server/MERGE.md b/server/MERGE.md
deleted file mode 100644
index 2ee2414..0000000
--- a/server/MERGE.md
+++ /dev/null
@@ -1,50 +0,0 @@
-# Model Merging
-
-You can use `mlx-lm` to merge models and upload them to the Hugging Face Hub
-or save them locally for LoRA fine-tuning.
-
-The main command is `mlx_lm.merge`:
-
-```shell
-python -m mlx_lm.merge --config config.yaml
-```
-
-The merged model will be saved by default in `mlx_merged_model`. To see a
-full list of options run:
-
-```shell
-python -m mlx_lm.merge --help
-```
-
-Here is an example `config.yaml`:
-
-```yaml
-models:
-  - OpenPipe/mistral-ft-optimized-1218
-  - mlabonne/NeuralHermes-2.5-Mistral-7B
-method: slerp
-parameters:
-  t:
-    - filter: self_attn
-      value: [0, 0.5, 0.3, 0.7, 1]
-    - filter: mlp
-      value: [1, 0.5, 0.7, 0.3, 0]
-    - value: 0.5
-```
-
-The `models` field is a list of Hugging Face repo ids. The first model in the
-list is treated as the base model into which the remaining models are merged.
-
-The `method` field is the merging method. Right now `slerp` is the only
-supported method.
-
-The `parameters` are the corresponding parameters for the given `method`.
-Each parameter is a list with `filter` determining which layers the parameter
-applies to and `value` determining the actual value used. The last item in
-the list without a `filter` field is the default.
-
-If `value` is a list, it specifies the start and end values for the
-corresponding segments of blocks. In the example above, the models have 32
-blocks. For blocks 1-8, the layers with `self_attn` in the name will use the
-values `np.linspace(0, 0.5, 8)`, the same layers in the next 8 blocks (9-16)
-will use `np.linspace(0.5, 0.3, 8)`, and so on.
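As a rough sketch of how a `value` list expands into one interpolation weight per block (the helper name below is illustrative, not part of the package), each adjacent pair of anchor values covers an equal-sized segment of blocks:

```python
import numpy as np


def expand_schedule(values, num_blocks=32):
    # A scalar applies uniformly to every block.
    if isinstance(values, (int, float)):
        return np.full(num_blocks, values)
    # Each adjacent pair of anchors covers an equal share of the blocks,
    # with any remainder absorbed by the last segment.
    segments = len(values) - 1
    sizes = [num_blocks // segments] * segments
    sizes[-1] = num_blocks - sum(sizes[:-1])
    return np.concatenate(
        [np.linspace(a, b, n) for a, b, n in zip(values[:-1], values[1:], sizes)]
    )


# Blocks 1-8 ramp from 0 to 0.5, blocks 9-16 from 0.5 to 0.3, and so on.
t_self_attn = expand_schedule([0, 0.5, 0.3, 0.7, 1])
print(t_self_attn[:8])
```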
diff --git a/server/README.md b/server/README.md
deleted file mode 100644
index 66f2b5e..0000000
--- a/server/README.md
+++ /dev/null
@@ -1,10 +0,0 @@
-## Generate Text with MLX and :hugs: Hugging Face
-
-This is an example of large language model text generation that can pull
-models from the Hugging Face Hub.
-
-For more information on this example, see the [README](../README.md) in the
-parent directory.
-
-This package also supports fine-tuning with LoRA or QLoRA. For more information
-see the [LoRA documentation](LORA.md).
diff --git a/server/SERVER.md b/server/SERVER.md
deleted file mode 100644
index 1176951..0000000
--- a/server/SERVER.md
+++ /dev/null
@@ -1,63 +0,0 @@
-# HTTP Model Server
-
-You can use `mlx-lm` to serve an HTTP API for generating text with any
-supported model. The HTTP API is intended to be similar to the [OpenAI chat
-API](https://platform.openai.com/docs/api-reference).
-
-Start the server with:
-
-```shell
-python -m mlx_lm.server --model <path_to_model>
-```
-
-For example:
-
-```shell
-python -m mlx_lm.server --model mistralai/Mistral-7B-Instruct-v0.1
-```
-
-This will start a text generation server on `localhost` port `8080` using
-Mistral 7B Instruct. The model will be downloaded from the provided Hugging
-Face repo if it is not already in the local cache.
-
-To see a full list of options run:
-
-```shell
-python -m mlx_lm.server --help
-```
-
-You can make a request to the model by running:
-
-```shell
-curl localhost:8080/v1/chat/completions \
-  -H "Content-Type: application/json" \
-  -d '{
-     "messages": [{"role": "user", "content": "Say this is a test!"}],
-     "temperature": 0.7
-   }'
-```
-
-### Request Fields
-
-- `messages`: An array of message objects representing the conversation
-  history. Each message object should have a role (e.g. user, assistant) and
-  content (the message text).
-
-- `role_mapping`: (Optional) A dictionary to customize the role prefixes in
-  the generated prompt. If not provided, the default mappings are used.
-
-- `stop`: (Optional) An array of strings or a single string. These are
-  sequences of tokens on which the generation should stop.
-
-- `max_tokens`: (Optional) An integer specifying the maximum number of tokens
-  to generate. Defaults to `100`.
-
-- `stream`: (Optional) A boolean indicating if the response should be
-  streamed. If true, responses are sent as they are generated. Defaults to
-  false.
-
-- `temperature`: (Optional) A float specifying the sampling temperature.
-  Defaults to `1.0`.
-
-- `top_p`: (Optional) A float specifying the nucleus sampling parameter.
-  Defaults to `1.0`.
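As an illustrative alternative to `curl` (assuming the server is running locally on the default host and port), the same request can be made from Python with the `requests` library; the response fields match the completion object described above:

```python
import requests

# Assumes a server started with `python -m mlx_lm.server --model ...`.
response = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Say this is a test!"}],
        "temperature": 0.7,
        "max_tokens": 100,
    },
    timeout=60,
)
result = response.json()
# Print the assistant's reply from the first choice.
print(result["choices"][0]["message"]["content"])
```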
diff --git a/server/UPLOAD.md b/server/UPLOAD.md deleted file mode 100644 index f5de365..0000000 --- a/server/UPLOAD.md +++ /dev/null @@ -1,37 +0,0 @@ -### Packaging for PyPI - -Install `build` and `twine`: - -``` -pip install --user --upgrade build -pip install --user --upgrade twine -``` - -Generate the source distribution and wheel: - -``` -python -m build -``` - -> [!warning] -> Use a test server first - -#### Test Upload - -Upload to test server: - -``` -python -m twine upload --repository testpypi dist/* -``` - -Install from test server and check that it works: - -``` -python -m pip install --index-url https://test.pypi.org/simple/ --no-deps mlx-lm -``` - -#### Upload - -``` -python -m twine upload dist/* -``` diff --git a/server/examples/merge_config.yaml b/server/examples/merge_config.yaml deleted file mode 100644 index 98701e5..0000000 --- a/server/examples/merge_config.yaml +++ /dev/null @@ -1,11 +0,0 @@ -models: - - OpenPipe/mistral-ft-optimized-1218 - - mlabonne/NeuralHermes-2.5-Mistral-7B -method: slerp -parameters: - t: - - filter: self_attn - value: [0, 0.5, 0.3, 0.7, 1] - - filter: mlp - value: [1, 0.5, 0.7, 0.3, 0] - - value: 0.5 diff --git a/server/flaskserver.py b/server/flaskserver.py deleted file mode 100644 index 25bb5e6..0000000 --- a/server/flaskserver.py +++ /dev/null @@ -1,262 +0,0 @@ -import argparse -import uuid -import time -import json -import numpy as np -import mlx.core as mx -import mlx.nn as nn - -from collections import namedtuple -from typing import List, Optional -from flask import Flask, request, Response -from transformers import PreTrainedTokenizer - -from .utils import load - -app = Flask(__name__) - - -_model: Optional[nn.Module] = None -_tokenizer: Optional[PreTrainedTokenizer] = None - - -def load_model(model_path: str, adapter_file: Optional[str] = None) -> None: - global _model - global _tokenizer - _model, _tokenizer = load(model_path, adapter_file=adapter_file) - - -StopCondition = namedtuple('StopCondition', ['stop_met', 'trim_length']) - - -def stopping_criteria( - tokens: List[int], - stop_id_sequences: List[np.ndarray], - eos_token_id: int, -) -> StopCondition: - ''' - Determines whether the token generation should stop based on predefined conditions. - - Args: - tokens (List[int]): The current sequence of generated tokens. - stop_id_sequences (List[np.ndarray]): A list of numpy arrays, each representing a sequence of token IDs. - If the end of the `tokens` list matches any of these sequences, the generation should stop. - eos_token_id (int): The token ID that represents the end-of-sequence. If the last token in `tokens` matches this, - the generation should stop. - - Returns: - StopCondition: A named tuple indicating whether the stop condition has been met (`stop_met`) - and how many tokens should be trimmed from the end if it has (`trim_length`). 
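
    Example (illustrative):
        >>> stopping_criteria([5, 6, 7, 8], [np.array([7, 8])], eos_token_id=2)
        StopCondition(stop_met=True, trim_length=2)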
- ''' - if tokens and tokens[-1] == eos_token_id: - return StopCondition(stop_met=True, trim_length=0) - - for stop_ids in stop_id_sequences: - if len(tokens) >= len(stop_ids): - if np.array_equal(tokens[-len(stop_ids):], stop_ids): - return StopCondition(stop_met=True, trim_length=len(stop_ids)) - - return StopCondition(stop_met=False, trim_length=0) - - -def generate(prompt: mx.array, model: nn.Module, temp: float = 0.0, top_p: float = 1.0): - def sample(logits): - if temp == 0: - return mx.argmax(logits, axis=-1) - else: - if top_p > 0 and top_p < 1.0: - if ( - logits.dtype == mx.bfloat16 - ): # workdaround for unable to load kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16 - logits = logits.astype(mx.float32) - probs = mx.softmax(logits / temp, axis=-1) - - sorted_probs = mx.sort(probs)[::-1] - sorted_indices = mx.argsort(probs)[::-1] - cumulative_probs = mx.cumsum(sorted_probs, axis=-1) - - top_probs = mx.where( - cumulative_probs > 1 - top_p, - sorted_probs, - mx.zeros_like(sorted_probs), - ) - sorted_tok = mx.random.categorical(mx.log(top_probs)) - tok = sorted_indices.squeeze(0)[sorted_tok] - return tok - return mx.random.categorical(logits * (1 / temp)) - - y = prompt - cache = None - - while True: - logits, cache = model(y[None], cache=cache) - logits = logits[:, -1, :] - - y = sample(logits) - token = y.item() - - yield token - - -def convert_chat(messages: any, role_mapping: Optional[dict] = None) -> str: - default_role_mapping = { - 'system_prompt': 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.', - 'system': "ASSISTANT's RULE: ", - 'user': 'USER: ', - 'assistant': 'ASSISTANT: ', - 'stop': '\n', - } - role_mapping = role_mapping if role_mapping is not None else default_role_mapping - - prompt = '' - for line in messages: - role_prefix = role_mapping.get(line['role'], '') - stop = role_mapping.get('stop', '') - content = line.get('content', '') - prompt += f'{role_prefix}{content}{stop}' - - prompt += role_mapping.get('assistant', '') - return prompt.rstrip() - - -def create_response(chat_id, requested_model, prompt, tokens, text): - response = { - 'id': chat_id, - 'object': 'chat.completion', - 'created': int(time.time()), - 'model': requested_model, - 'system_fingerprint': f'fp_{uuid.uuid4()}', - 'choices': [ - { - 'index': 0, - 'message': { - 'role': 'assistant', - 'content': text, - }, - 'logprobs': None, - 'finish_reason': None, - } - ], - 'usage': { - 'prompt_tokens': len(prompt), - 'completion_tokens': len(tokens), - 'total_tokens': len(prompt) + len(tokens), - }, - } - - return response - - -class APIHandler: - @staticmethod - def _set_headers(response, status_code=200): - response.status_code = status_code - response.headers["Content-type"] = "application/json" - response.headers["Access-Control-Allow-Origin"] = "*" - response.headers["Access-Control-Allow-Methods"] = "*" - response.headers["Access-Control-Allow-Headers"] = "*" - - @staticmethod - def handle_post_request(post_data): - body = json.loads(post_data.decode("utf-8")) - chat_id = f"chatcmpl-{uuid.uuid4()}" - if hasattr(_tokenizer, "apply_chat_template") and _tokenizer.chat_template: - prompt = _tokenizer.apply_chat_template( - body["messages"], - tokenize=True, - add_generation_prompt=True, - return_tensors="np", - ) - else: - prompt = convert_chat(body["messages"], body.get("role_mapping")) - prompt = _tokenizer.encode(prompt, return_tensors="np") - - prompt = mx.array(prompt[0]) - stop_words = body.get("stop", []) - 
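        # `stop` may be a single string or a list of strings; it is normalized
        # to a list below and each entry is encoded into a token-id sequence so
        # stopping_criteria() can match it against the tail of the generation.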
stop_words = [stop_words] if isinstance( - stop_words, str) else stop_words - stop_id_sequences = [ - _tokenizer.encode(stop_word, return_tensors="np", - add_special_tokens=False)[0] - for stop_word in stop_words - ] - eos_token_id = _tokenizer.eos_token_id - max_tokens = body.get("max_tokens", 100) - stream = body.get("stream", False) - requested_model = body.get("model", "default_model") - temperature = body.get("temperature", 1.0) - top_p = body.get("top_p", 1.0) - if not stream: - tokens = [] - for token, _ in zip( - generate(prompt, _model, temperature, top_p=top_p), - range(max_tokens), - ): - tokens.append(token) - stop_condition = stopping_criteria( - tokens, stop_id_sequences, eos_token_id - ) - if stop_condition.stop_met: - if stop_condition.trim_length: - tokens = tokens[: -stop_condition.trim_length] - break - - text = _tokenizer.decode(tokens) - return create_response(chat_id, requested_model, prompt, tokens, text) - else: - pass - - @app.route('/v1/chat/completions', methods=['POST', 'OPTIONS']) - def chat_completions(): - try: - if request.method == 'OPTIONS': - response = Response() - APIHandler._set_headers(response, 204) - return response - - elif request.method == 'POST': - post_data = request.data - response = Response() - - APIHandler._set_headers(response, 200) - - response_data = APIHandler.handle_post_request(post_data) - - response.data = json.dumps(response_data) - return response - except Exception as e: - return Response(json.dumps({"error": f"An unexpected error occurred. {e}"}), - status=500, content_type="application/json") - - -if __name__ == '__main__': - import argparse - parser = argparse.ArgumentParser(description='MLX Http Server.') - parser.add_argument( - '--model', - type=str, - required=True, - help='The path to the MLX model weights, tokenizer, and config', - ) - parser.add_argument( - '--adapter-file', - type=str, - help='Optional path for the trained adapter weights.', - ) - parser.add_argument( - '--host', - type=str, - default='127.0.0.1', - help='Host for the HTTP server (default: 127.0.0.1)', - ) - parser.add_argument( - '--port', - type=int, - default=8080, - help='Port for the HTTP server (default: 8080)', - ) - args = parser.parse_args() - - load_model(args.model, adapter_file=args.adapter_file) - - app.run(host=args.host, port=args.port) diff --git a/server/fuse.py b/server/fuse.py deleted file mode 100644 index 132982d..0000000 --- a/server/fuse.py +++ /dev/null @@ -1,105 +0,0 @@ -import argparse -import glob -import json -import shutil -from pathlib import Path -from typing import Any, Dict, Union - -from mlx.utils import tree_flatten, tree_unflatten - -from .tuner.lora import LoRALinear -from .tuner.utils import apply_lora_layers, dequantize -from .utils import fetch_from_hub, get_model_path, save_weights, upload_to_hub - - -def parse_arguments() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.") - parser.add_argument( - "--model", - default="mlx_model", - help="The path to the local model directory or Hugging Face repo.", - ) - parser.add_argument( - "--save-path", - default="lora_fused_model", - help="The path to save the fused model.", - ) - parser.add_argument( - "--adapter-file", - type=str, - default="adapters.npz", - help="Path to the trained adapter weights (npz or safetensors).", - ) - parser.add_argument( - "--hf-path", - type=str, - default=None, - help="Path to the original Hugging Face model. 
Required for upload if --model is a local directory.", - ) - parser.add_argument( - "--upload-repo", - help="The Hugging Face repo to upload the model to.", - type=str, - default=None, - ) - parser.add_argument( - "--de-quantize", - help="Generate a de-quantized model.", - action="store_true", - ) - return parser.parse_args() - - -def main() -> None: - print("Loading pretrained model") - args = parse_arguments() - - model_path = get_model_path(args.model) - model, config, tokenizer = fetch_from_hub(model_path) - - model.freeze() - model = apply_lora_layers(model, args.adapter_file) - - fused_linears = [ - (n, m.to_linear()) - for n, m in model.named_modules() - if isinstance(m, LoRALinear) - ] - - model.update_modules(tree_unflatten(fused_linears)) - - if args.de_quantize: - print("De-quantizing model") - model = dequantize(model) - - weights = dict(tree_flatten(model.parameters())) - - save_path = Path(args.save_path) - - save_weights(save_path, weights) - - py_files = glob.glob(str(model_path / "*.py")) - for file in py_files: - shutil.copy(file, save_path) - - tokenizer.save_pretrained(save_path) - - if args.de_quantize: - config.pop("quantization", None) - - with open(save_path / "config.json", "w") as fid: - json.dump(config, fid, indent=4) - - if args.upload_repo is not None: - hf_path = args.hf_path or ( - args.model if not Path(args.model).exists() else None - ) - if hf_path is None: - raise ValueError( - "Must provide original Hugging Face repo to upload local model." - ) - upload_to_hub(args.save_path, args.upload_repo, hf_path) - - -if __name__ == "__main__": - main() diff --git a/server/generate.py b/server/generate.py deleted file mode 100644 index a5fdd3d..0000000 --- a/server/generate.py +++ /dev/null @@ -1,125 +0,0 @@ -import argparse - -import mlx.core as mx - -from .utils import generate, load - -DEFAULT_MODEL_PATH = "mlx_model" -DEFAULT_PROMPT = "hello" -DEFAULT_MAX_TOKENS = 100 -DEFAULT_TEMP = 0.6 -DEFAULT_TOP_P = 1.0 -DEFAULT_SEED = 0 - - -def setup_arg_parser(): - """Set up and return the argument parser.""" - parser = argparse.ArgumentParser(description="LLM inference script") - parser.add_argument( - "--model", - type=str, - default="mlx_model", - help="The path to the local model directory or Hugging Face repo.", - ) - parser.add_argument( - "--trust-remote-code", - action="store_true", - help="Enable trusting remote code for tokenizer", - ) - parser.add_argument( - "--eos-token", - type=str, - default=None, - help="End of sequence token for tokenizer", - ) - parser.add_argument( - "--prompt", default=DEFAULT_PROMPT, help="Message to be processed by the model" - ) - parser.add_argument( - "--max-tokens", - "-m", - type=int, - default=DEFAULT_MAX_TOKENS, - help="Maximum number of tokens to generate", - ) - parser.add_argument( - "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature" - ) - parser.add_argument( - "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p" - ) - parser.add_argument("--seed", type=int, - default=DEFAULT_SEED, help="PRNG seed") - parser.add_argument( - "--ignore-chat-template", - action="store_true", - help="Use the raw prompt without the tokenizer's chat template.", - ) - parser.add_argument( - "--colorize", - action="store_true", - help="Colorize output based on T[0] probability", - ) - return parser - - -def colorprint(color, s): - color_codes = { - "black": 30, - "red": 31, - "green": 32, - "yellow": 33, - "blue": 34, - "magenta": 35, - "cyan": 36, - "white": 39, - } - ccode = color_codes.get(color, 30) - 
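    # Wrap the text in an ANSI escape sequence: bold plus the selected
    # foreground color code, then reset formatting afterwards.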
print(f"\033[1m\033[{ccode}m{s}\033[0m", end="", flush=True) - - -def colorprint_by_t0(s, t0): - if t0 > 0.95: - color = "white" - elif t0 > 0.70: - color = "green" - elif t0 > 0.30: - color = "yellow" - else: - color = "red" - colorprint(color, s) - - -def main(args): - mx.random.seed(args.seed) - - # Building tokenizer_config - tokenizer_config = { - "trust_remote_code": True if args.trust_remote_code else None} - if args.eos_token is not None: - tokenizer_config["eos_token"] = args.eos_token - - model, tokenizer = load(args.model, tokenizer_config=tokenizer_config) - - if not args.ignore_chat_template and ( - hasattr(tokenizer, "apply_chat_template") - and tokenizer.chat_template is not None - ): - messages = [{"role": "user", "content": args.prompt}] - prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) - else: - prompt = args.prompt - - formatter = colorprint_by_t0 if args.colorize else None - - generate( - model, tokenizer, prompt, args.temp, args.max_tokens, True, formatter=formatter, top_p=args.top_p - ) - - -if __name__ == "__main__": - parser = setup_arg_parser() - args = parser.parse_args() - main(args) diff --git a/server/lora.py b/server/lora.py deleted file mode 100644 index a8a2912..0000000 --- a/server/lora.py +++ /dev/null @@ -1,248 +0,0 @@ -import argparse -import json -import math -from pathlib import Path - -import mlx.optimizers as optim -import numpy as np -from mlx.utils import tree_flatten - -from .tuner.trainer import TrainingArgs, evaluate, train -from .tuner.utils import linear_to_lora_layers -from .utils import generate, load - - -def build_parser(): - parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.") - parser.add_argument( - "--model", - default="mlx_model", - help="The path to the local model directory or Hugging Face repo.", - ) - # Generation args - parser.add_argument( - "--max-tokens", - "-m", - type=int, - default=100, - help="The maximum number of tokens to generate", - ) - parser.add_argument( - "--temp", type=float, default=0.8, help="The sampling temperature" - ) - parser.add_argument( - "--prompt", - "-p", - type=str, - help="The prompt for generation", - default=None, - ) - - # Training args - parser.add_argument( - "--train", - action="store_true", - help="Do training", - ) - parser.add_argument( - "--data", - type=str, - default="data/", - help="Directory with {train, valid, test}.jsonl files", - ) - parser.add_argument( - "--lora-layers", - type=int, - default=16, - help="Number of layers to fine-tune", - ) - parser.add_argument("--batch-size", type=int, default=4, help="Minibatch size.") - parser.add_argument( - "--iters", type=int, default=1000, help="Iterations to train for." - ) - parser.add_argument( - "--val-batches", - type=int, - default=25, - help="Number of validation batches, -1 uses the entire validation set.", - ) - parser.add_argument( - "--learning-rate", type=float, default=1e-5, help="Adam learning rate." 
- ) - parser.add_argument( - "--steps-per-report", - type=int, - default=10, - help="Number of training steps between loss reporting.", - ) - parser.add_argument( - "--steps-per-eval", - type=int, - default=200, - help="Number of training steps between validations.", - ) - parser.add_argument( - "--resume-adapter-file", - type=str, - default=None, - help="Load path to resume training with the given adapter weights.", - ) - parser.add_argument( - "--adapter-file", - type=str, - default="adapters.npz", - help="Save/load path for the trained adapter weights.", - ) - parser.add_argument( - "--save-every", - type=int, - default=100, - help="Save the model every N iterations.", - ) - parser.add_argument( - "--test", - action="store_true", - help="Evaluate on the test set after training", - ) - parser.add_argument( - "--test-batches", - type=int, - default=500, - help="Number of test set batches, -1 uses the entire test set.", - ) - parser.add_argument( - "--max-seq-length", - type=int, - default=2048, - help="Maximum sequence length.", - ) - parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") - return parser - - -class Dataset: - """ - Light-weight wrapper to hold lines from a jsonl file - """ - - def __init__(self, path: Path, key: str = "text"): - if not path.exists(): - self._data = None - else: - with open(path, "r") as fid: - self._data = [json.loads(l) for l in fid] - self._key = key - - def __getitem__(self, idx: int): - return self._data[idx][self._key] - - def __len__(self): - if self._data is None: - return 0 - return len(self._data) - - -def load_dataset(args): - names = ("train", "valid", "test") - train, valid, test = (Dataset(Path(args.data) / f"{n}.jsonl") for n in names) - if args.train and len(train) == 0: - raise ValueError( - "Training set not found or empty. Must provide training set for fine-tuning." - ) - if args.train and len(valid) == 0: - raise ValueError( - "Validation set not found or empty. Must provide validation set for fine-tuning." - ) - if args.test and len(test) == 0: - raise ValueError( - "Test set not found or empty. Must provide test set for evaluation." - ) - return train, valid, test - - -if __name__ == "__main__": - parser = build_parser() - args = parser.parse_args() - - np.random.seed(args.seed) - - print("Loading pretrained model") - model, tokenizer = load(args.model) - - # Freeze all layers - model.freeze() - # Convert linear layers to lora layers and unfreeze in the process - linear_to_lora_layers(model, args.lora_layers) - - p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6 - print(f"Total parameters {p:.3f}M") - p = sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6 - print(f"Trainable parameters {p:.3f}M") - - print("Loading datasets") - train_set, valid_set, test_set = load_dataset(args) - - # Resume training the given adapters. 
- if args.resume_adapter_file is not None: - print(f"Loading pretrained adapters from {args.resume_adapter_file}") - model.load_weights(args.resume_adapter_file, strict=False) - # init training args - trainingArgs = TrainingArgs( - batch_size=args.batch_size, - iters=args.iters, - val_batches=args.val_batches, - steps_per_report=args.steps_per_report, - steps_per_eval=args.steps_per_eval, - steps_per_save=args.save_every, - adapter_file=args.adapter_file, - max_seq_length=args.max_seq_length, - ) - if args.train: - print("Training") - model.train() - opt = optim.Adam(learning_rate=args.learning_rate) - # Train model - train( - model=model, - tokenizer=tokenizer, - args=trainingArgs, - optimizer=opt, - train_dataset=train_set, - val_dataset=valid_set, - ) - - # Load the LoRA adapter weights which we assume should exist by this point - if not Path(args.adapter_file).is_file(): - raise ValueError( - f"Adapter file {args.adapter_file} missing. " - "Use --train to learn and save the adapters.npz." - ) - model.load_weights(args.adapter_file, strict=False) - - if args.test: - print("Testing") - model.eval() - - test_loss = evaluate( - model=model, - dataset=test_set, - tokenizer=tokenizer, - batch_size=args.batch_size, - num_batches=args.test_batches, - ) - - test_ppl = math.exp(test_loss) - - print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.") - - if args.prompt is not None: - print("Generating") - model.eval() - generate( - model=model, - tokenizer=tokenizer, - temp=args.temp, - max_tokens=args.max_tokens, - prompt=args.prompt, - verbose=True, - ) diff --git a/server/merge.py b/server/merge.py deleted file mode 100644 index 2603653..0000000 --- a/server/merge.py +++ /dev/null @@ -1,159 +0,0 @@ -import argparse -import glob -import json -from pathlib import Path - -import mlx.core as mx -import numpy as np -import yaml -from mlx.utils import tree_flatten, tree_map - -from .utils import fetch_from_hub, get_model_path, save_weights, upload_to_hub - - -def configure_parser() -> argparse.ArgumentParser: - """ - Configures and returns the argument parser for the script. - - Returns: - argparse.ArgumentParser: Configured argument parser. 
-    """
-    parser = argparse.ArgumentParser(description="Merge multiple models.")
-
-    parser.add_argument("--config", type=str, help="Path to the YAML config.")
-    parser.add_argument(
-        "--mlx-path",
-        type=str,
-        default="mlx_merged_model",
-        help="Path to save the MLX model.",
-    )
-    parser.add_argument(
-        "--upload-repo",
-        help="The Hugging Face repo to upload the model to.",
-        type=str,
-        default=None,
-    )
-    return parser
-
-
-def slerp(t, w1, w2, eps=1e-5):
-    """
-    Spherical linear interpolation
-
-    Args:
-        t (float): Interpolation weight in [0.0, 1.0]
-        w1 (mx.array): First input
-        w2 (mx.array): Second input
-        eps (float): Constant for numerical stability
-    Returns:
-        mx.array: Interpolated result
-    """
-    t = float(t)
-    if t == 0:
-        return w1
-    elif t == 1:
-        return w2
-    # Normalize
-    v1 = w1 / mx.linalg.norm(w1)
-    v2 = w2 / mx.linalg.norm(w2)
-    # Angle
-    dot = mx.clip((v1 * v2).sum(), 0.0, 1.0)
-    theta = mx.arccos(dot)
-    sin_theta = mx.sin(theta + eps)
-    s1 = mx.sin(theta * (1 - t)) / sin_theta
-    s2 = mx.sin(theta * t) / sin_theta
-    return s1 * w1 + s2 * w2
-
-
-def merge_models(base_model, model, config):
-    method = config.get("method", None)
-    if method != "slerp":
-        raise ValueError(f"Merge method {method} not supported")
-
-    num_layers = len(model.layers)
-
-    def unpack_values(vals):
-        if isinstance(vals, (int, float)):
-            return np.full(num_layers, vals)
-        bins = len(vals) - 1
-        sizes = [num_layers // bins] * bins
-        sizes[-1] = num_layers - sum(sizes[:-1])
-        return np.concatenate(
-            [np.linspace(v1, v2, s) for v1, v2, s in zip(vals[:-1], vals[1:], sizes)]
-        )
-
-    param_list = config["parameters"]["t"]
-    params = {}
-    filter_keys = set()
-    for pl in param_list[:-1]:
-        params[pl["filter"]] = unpack_values(pl["value"])
-        filter_keys.add(pl["filter"])
-    default = unpack_values(param_list[-1]["value"])
-
-    for e in range(num_layers):
-        bl = base_model.layers[e]
-        l = model.layers[e]
-        base_weights = bl.parameters()
-        weights = l.parameters()
-        for k, w1 in base_weights.items():
-            w2 = weights[k]
-            t = params.get(k, default)[e]
-            base_weights[k] = tree_map(lambda x, y: slerp(t, x, y), w1, w2)
-        bl.update(base_weights)
-
-
-def merge(
-    config: str,
-    mlx_path: str = "mlx_model",
-    upload_repo: str = None,
-):
-    with open(config, "r") as fid:
-        merge_conf = yaml.safe_load(fid)
-    print("[INFO] Loading")
-
-    model_paths = merge_conf.get("models", [])
-    if len(model_paths) < 2:
-        raise ValueError(f"Expected at least 2 models, got {len(model_paths)}.")
-
-    # Load all models
-    base_hf_path = model_paths[0]
-    base_path = get_model_path(base_hf_path)
-    base_model, base_config, tokenizer = fetch_from_hub(base_path, lazy=True)
-    models = []
-    for mp in model_paths[1:]:
-        model, config, _ = fetch_from_hub(get_model_path(mp), lazy=True)
-        base_type = base_config["model_type"]
-        model_type = config["model_type"]
-        if base_type != model_type:
-            raise ValueError(
-                f"Can only merge models of the same type,"
-                f" but got {base_type} and {model_type}."
- ) - models.append(model) - - # Merge models into base model - for m in models: - merge_models(base_model, m, merge_conf) - - # Save base model - mlx_path = Path(mlx_path) - weights = dict(tree_flatten(base_model.parameters())) - del models, base_model - save_weights(mlx_path, weights, donate_weights=True) - py_files = glob.glob(str(base_path / "*.py")) - for file in py_files: - shutil.copy(file, mlx_path) - - tokenizer.save_pretrained(mlx_path) - - with open(mlx_path / "config.json", "w") as fid: - json.dump(base_config, fid, indent=4) - - if upload_repo is not None: - upload_to_hub(mlx_path, upload_repo, base_hf_path) - - -if __name__ == "__main__": - parser = configure_parser() - args = parser.parse_args() - merge(**vars(args)) diff --git a/server/models/gemma.py b/server/models/gemma.py index 2bc782b..01b5b81 100644 --- a/server/models/gemma.py +++ b/server/models/gemma.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from functools import partial -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Tuple import mlx.core as mx import mlx.nn as nn diff --git a/server/models/mixtral.py b/server/models/mixtral.py deleted file mode 100644 index c2ddcb7..0000000 --- a/server/models/mixtral.py +++ /dev/null @@ -1,252 +0,0 @@ -from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union - -import mlx.core as mx -import mlx.nn as nn -import numpy as np - -from .base import BaseModelArgs -from .layers import RMSNorm - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - vocab_size: int = 32000 - max_position_embeddings: int = 4096 * 32 - hidden_size: int = 4096 - intermediate_size: int = 14336 - num_hidden_layers: int = 32 - num_attention_heads: int = 32 - num_experts_per_tok: int = 2 - num_key_value_heads: int = 8 - num_local_experts: int = 8 - rms_norm_eps: float = 1e-5 - rope_theta: float = 1e6 - rope_traditional: bool = False - rope_scaling: Optional[Dict[str, Union[float, str]]] = None - - def __post_init__(self): - if self.num_key_value_heads is None: - self.num_key_value_heads = self.num_attention_heads - - -class MixtralAttention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.hidden_size = args.hidden_size - self.num_heads = args.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = args.num_key_value_heads - self.max_position_embeddings = args.max_position_embeddings - self.rope_theta = args.rope_theta - - self.repeats = self.num_heads // self.num_key_value_heads - - self.scale = self.head_dim**-0.5 - - self.q_proj = nn.Linear( - self.hidden_size, self.num_heads * self.head_dim, bias=False - ) - self.k_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False - ) - self.v_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False - ) - self.o_proj = nn.Linear( - self.num_heads * self.head_dim, self.hidden_size, bias=False - ) - - self.rope = nn.RoPE( - self.head_dim, - traditional=args.rope_traditional, - base=args.rope_theta, - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> mx.array: - B, L, D = x.shape - - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.num_key_value_heads, 
-1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.num_key_value_heads, -1).transpose( - 0, 2, 1, 3 - ) - - if self.repeats > 1: - keys = mx.repeat(keys, self.repeats, axis=1) - values = mx.repeat(values, self.repeats, axis=1) - - if cache is not None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) - if mask is not None: - scores += mask - scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) - output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output), (keys, values) - - -class MixtralBLockSparseTop2MLP(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.ffn_dim = args.intermediate_size - self.hidden_dim = args.hidden_size - - self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) - self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - - self.act_fn = nn.silu - - def __call__(self, x: mx.array) -> mx.array: - current_hidden_states = self.act_fn(self.w1(x)) * self.w3(x) - current_hidden_states = self.w2(current_hidden_states) - return current_hidden_states - - -class MixtralSparseMoeBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.hidden_dim = args.hidden_size - self.ffn_dim = args.intermediate_size - self.num_experts = args.num_local_experts - self.num_experts_per_tok = args.num_experts_per_tok - - # gating - self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) - - self.experts = [ - MixtralBLockSparseTop2MLP(args=args) for _ in range(self.num_experts) - ] - - def __call__(self, x: mx.array) -> mx.array: - ne = self.num_experts_per_tok - orig_shape = x.shape - x = x.reshape(-1, x.shape[-1]) - - gates = self.gate(x) - - inds = mx.stop_gradient( - mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne] - ) # TODO remove it once we figure out how to fine tune TopK in MOE - - scores = mx.softmax( - mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32), - axis=-1, - ).astype(gates.dtype) - - if self.training: - mx.eval(inds) - inds = np.array(inds) - y = mx.zeros((x.shape[0], ne, x.shape[-1])) - for e, expert in enumerate(self.experts): - idx1, idx2 = map(mx.array, np.where(inds == e)) - if idx1.size == 0: - continue - y[idx1, idx2] = expert(x[idx1]) - - y = (y * scores[:, :, None]).sum(axis=1) - else: - y = [] - for xt, st, it in zip(x, scores, inds.tolist()): - yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1) - yt = (yt * st).sum(axis=-1) - y.append(yt[None, :]) - y = mx.concatenate(y) - - return y.reshape(orig_shape) - - -class MixtralDecoderLayer(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.hidden_size = args.hidden_size - - self.self_attn = MixtralAttention(args) - - self.block_sparse_moe = MixtralSparseMoeBlock(args) - self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> mx.array: - r, cache = self.self_attn(self.input_layernorm(x), mask, 
cache) - h = x + r - r = self.block_sparse_moe(self.post_attention_layernorm(h)) - out = h + r - return out, cache - - -class MixtralModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.vocab_size = args.vocab_size - self.num_hidden_layers = args.num_hidden_layers - - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ - MixtralDecoderLayer(args=args) for _ in range(args.num_hidden_layers) - ] - self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - - def __call__( - self, - inputs: mx.array, - cache=None, - ): - h = self.embed_tokens(inputs) - - mask = None - T = h.shape[1] - if T > 1: - mask = nn.MultiHeadAttention.create_additive_causal_mask(T) - mask = mask.astype(h.dtype) - - if cache is None: - cache = [None] * len(self.layers) - - for e, layer in enumerate(self.layers): - h, cache[e] = layer(h, mask, cache[e]) - - return self.norm(h), cache - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.model_type = args.model_type - self.model = MixtralModel(args) - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__( - self, - inputs: mx.array, - cache=None, - ): - out, cache = self.model(inputs, cache) - return self.lm_head(out), cache - - @property - def layers(self): - return self.model.layers diff --git a/server/models/olmo.py b/server/models/olmo.py deleted file mode 100644 index f97ce6f..0000000 --- a/server/models/olmo.py +++ /dev/null @@ -1,180 +0,0 @@ -from dataclasses import dataclass -from sys import exit -from typing import Dict, Optional, Tuple, Union - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs -from .layers import LayerNorm - -try: - import hf_olmo -except ImportError: - print("To run olmo install ai2-olmo: pip install ai2-olmo") - exit(1) - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - d_model: int - n_layers: int - mlp_hidden_size: int - n_heads: int - vocab_size: int - embedding_size: int - model_type: str - rope_theta: float = 10000 - rope_traditional: bool = False - mlp_ratio: int = 4 - weight_tying: bool = False - - def __post_init__(self): - self.mlp_hidden_size = ( - self.mlp_hidden_size - if self.mlp_hidden_size is not None - else self.mlp_ratio * self.d_model - ) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.n_heads = args.n_heads - dim = args.d_model - - self.ff_proj = nn.Linear(dim, args.mlp_hidden_size, bias=False) - self.ff_out = nn.Linear(args.mlp_hidden_size // 2, dim, bias=False) - - self.att_norm = LayerNorm(dim, affine=False) - self.ff_norm = LayerNorm(dim, affine=False) - - head_dim = dim // self.n_heads - self.scale = head_dim**-0.5 - - self.att_proj = nn.Linear(dim, 3 * dim, bias=False) - self.attn_out = nn.Linear(dim, dim, bias=False) - - self.rope = nn.RoPE( - head_dim, - traditional=args.rope_traditional, - base=args.rope_theta, - ) - - self.args = args - - def attend( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> mx.array: - B, L, D = x.shape - - queries, keys, values = mx.split(self.att_proj(x), 3, axis=-1) - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - - if cache is not 
None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) - if mask is not None: - scores += mask - scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) - output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.attn_out(output), (keys, values) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> mx.array: - r, cache = self.attend(self.att_norm(x), mask, cache) - h = x + r - - x1, x2 = mx.split(self.ff_proj(self.ff_norm(h)), 2, axis=-1) - - out = h + self.ff_out(nn.silu(x2) * x1) - return out, cache - - -class Transformer(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.n_layers = args.n_layers - self.weight_tying = args.weight_tying - - self.wte = nn.Embedding(args.embedding_size, args.d_model) - self.blocks = [TransformerBlock(args=args) for _ in range(args.n_layers)] - if not self.weight_tying: - self.ff_out = nn.Linear(args.d_model, args.embedding_size, bias=False) - self.norm = LayerNorm(args.d_model, affine=False) - - def __call__( - self, - inputs: mx.array, - cache=None, - ): - h = self.wte(inputs) - - mask = None - if h.shape[1] > 1: - mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) - mask = mask.astype(h.dtype) - - if cache is None: - cache = [None] * len(self.blocks) - - for e, block in enumerate(self.blocks): - h, cache[e] = block(h, mask, cache[e]) - - h = self.norm(h) - - if self.weight_tying: - return h @ self.wte.weight.T, cache - - return self.ff_out(h), cache - - -class OlmoModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.transformer = Transformer(args) - - def __call__( - self, - inputs: mx.array, - cache=None, - ): - return self.transformer(inputs, cache) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.model_type = args.model_type - self.model = OlmoModel(args) - - def __call__( - self, - inputs: mx.array, - cache=None, - ): - return self.model(inputs, cache) - - @property - def layers(self): - return self.model.transformer.blocks diff --git a/server/models/phi.py b/server/models/phi.py deleted file mode 100644 index 85d1675..0000000 --- a/server/models/phi.py +++ /dev/null @@ -1,180 +0,0 @@ -import math -from dataclasses import dataclass -from typing import Tuple - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs -from .layers import LayerNorm - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - max_position_embeddings: int = 2048 - vocab_size: int = 51200 - hidden_size: int = 2560 - num_attention_heads: int = 32 - num_hidden_layers: int = 32 - num_key_value_heads: int = 32 - partial_rotary_factor: float = 0.4 - intermediate_size: int = 10240 - layer_norm_eps: float = 1e-5 - rope_theta: float = 10000.0 - - def __post_init__(self): - if self.num_key_value_heads is None: - self.num_key_value_heads = self.num_attention_heads - - -class PhiAttention(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // 
self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.repeats = self.num_heads // self.num_key_value_heads - self.rope_theta = config.rope_theta - self.partial_rotary_factor = config.partial_rotary_factor - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - - self.q_proj = nn.Linear( - self.hidden_size, self.num_heads * self.head_dim, bias=True - ) - self.k_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True - ) - self.v_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True - ) - self.dense = nn.Linear( - self.num_heads * self.head_dim, self.hidden_size, bias=True - ) - - self.rope = nn.RoPE( - int(self.partial_rotary_factor * self.head_dim), - traditional=False, - base=self.rope_theta, - ) - - def __call__(self, x, mask=None, cache=None): - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - # Extract some shapes - B, L, D = queries.shape - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.num_heads, self.head_dim).transpose( - 0, 2, 1, 3 - ) - keys = keys.reshape(B, L, self.num_key_value_heads, self.head_dim).transpose( - 0, 2, 1, 3 - ) - values = values.reshape( - B, L, self.num_key_value_heads, self.head_dim - ).transpose(0, 2, 1, 3) - - if self.repeats > 1: - keys = mx.repeat(keys, self.repeats, axis=1) - values = mx.repeat(values, self.repeats, axis=1) - - # Add RoPE to the queries and keys and combine them with the cache - if cache is not None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - queries = queries.astype(mx.float32) - keys = keys.astype(mx.float32) - - # Finally perform the attention computation - scale = math.sqrt(1 / queries.shape[-1]) - scores = (queries * scale) @ keys.transpose(0, 1, 3, 2) - if mask is not None: - scores = scores + mask - - scores = mx.softmax(scores, axis=-1).astype(values.dtype) - values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) - - return self.dense(values_hat), (keys, values) - - -class PhiMLP(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) - self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) - self.act = nn.GELU(approx="precise") - - def __call__(self, x) -> mx.array: - return self.fc2(self.act(self.fc1(x))) - - -class PhiDecoderLayer(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.self_attn = PhiAttention(config=config) - self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.mlp = PhiMLP(config) - - def __call__(self, x, mask, cache): - h = self.input_layernorm(x) - attn_h, cache = self.self_attn(h, mask, cache) - ff_h = self.mlp(h) - return attn_h + ff_h + x, cache - - -class PhiModel(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) - self.layers = [PhiDecoderLayer(config) for i in range(config.num_hidden_layers)] - self.final_layernorm = 
LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - def __call__(self, x, mask, cache): - x = self.embed_tokens(x) - if cache is None: - cache = [None] * len(self.layers) - - for e, layer in enumerate(self.layers): - x, cache[e] = layer(x, mask, cache[e]) - return self.final_layernorm(x), cache - - -class Model(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.model_type = config.model_type - self.model = PhiModel(config) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True) - - def __call__( - self, - x: mx.array, - mask: mx.array = None, - cache: mx.array = None, - ) -> Tuple[mx.array, mx.array]: - mask = None - if x.shape[1] > 1: - mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) - mask = mask.astype(x.dtype) - - y, cache = self.model(x, mask, cache) - return self.lm_head(y), cache - - @property - def layers(self): - return self.model.layers diff --git a/server/models/phixtral.py b/server/models/phixtral.py deleted file mode 100644 index 8537645..0000000 --- a/server/models/phixtral.py +++ /dev/null @@ -1,213 +0,0 @@ -import inspect -import math -from dataclasses import dataclass -from typing import Tuple - -import mlx.core as mx -import mlx.nn as nn -import numpy as np - -from .layers import LayerNorm - - -@dataclass -class ModelArgs: - model_type: str - max_sequence_length: int = 2048 - num_vocab: int = 51200 - model_dim: int = 2560 - num_heads: int = 32 - num_layers: int = 32 - rotary_dim: int = 32 - num_experts_per_tok: int = 2 - num_local_experts: int = 4 - - @classmethod - def from_dict(cls, params): - return cls( - **{ - k: v - for k, v in params.items() - if k in inspect.signature(cls).parameters - } - ) - - -class RoPEAttention(nn.Module): - def __init__(self, dims: int, num_heads: int, rotary_dim: int): - super().__init__() - - self.num_heads = num_heads - - self.rope = nn.RoPE(rotary_dim, traditional=False) - self.Wqkv = nn.Linear(dims, 3 * dims) - self.out_proj = nn.Linear(dims, dims) - - def __call__(self, x, mask=None, cache=None): - qkv = self.Wqkv(x) - queries, keys, values = mx.split(qkv, 3, axis=-1) - - # Extract some shapes - num_heads = self.num_heads - B, L, D = queries.shape - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) - - # Add RoPE to the queries and keys and combine them with the cache - if cache is not None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - queries = queries.astype(mx.float32) - keys = keys.astype(mx.float32) - - # Finally perform the attention computation - scale = math.sqrt(1 / queries.shape[-1]) - scores = (queries * scale) @ keys.transpose(0, 1, 3, 2) - if mask is not None: - scores = scores + mask - - scores = mx.softmax(scores, axis=-1).astype(values.dtype) - values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) - - return self.out_proj(values_hat), (keys, values) - - -class MLP(nn.Module): - def __init__(self, dim, hidden_dim): - super().__init__() - self.fc1 = nn.Linear(dim, hidden_dim) - self.fc2 = nn.Linear(hidden_dim, dim) - self.act = 
nn.GELU(approx="precise") - - def __call__(self, x) -> mx.array: - return self.fc2(self.act(self.fc1(x))) - - -class MOE(nn.Module): - def __init__(self, args: ModelArgs, dim: int, hidden_dim: int): - super().__init__() - self.dim = dim - self.hidden_dim = hidden_dim - self.num_experts = args.num_local_experts - self.num_experts_per_tok = args.num_experts_per_tok - self.mlp = [MLP(self.dim, self.hidden_dim) for _ in range(self.num_experts)] - self.gate = nn.Linear(args.model_dim, self.num_experts, bias=False) - - def __call__(self, x: mx.array) -> mx.array: - ne = self.num_experts_per_tok - orig_shape = x.shape - x = x.reshape(-1, x.shape[-1]) - - gates = self.gate(x) - inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne, axis=-1))[:, :ne] - scores = mx.softmax( - mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32), - axis=-1, - ).astype(gates.dtype) - - if self.training: - ys = [] - y = mx.zeros((x.shape[0], ne, x.shape[-1])) - for e, expert in enumerate(self.mlp): - idx1, idx2 = map(mx.array, np.where(inds == e)) - if idx1.size == 0: - continue - y[idx1, idx2] = expert(x[idx1]) - - y = (y * scores[..., None]).sum(axis=1) - else: - y = [] - for xt, st, it in zip(x, scores, inds.tolist()): - yt = mx.concatenate([self.mlp[e](xt)[:, None] for e in it], axis=-1) - yt = (yt * st).sum(axis=-1) - y.append(yt[None, :]) - y = mx.concatenate(y) - - return y.reshape(orig_shape) - - -class ParallelBlock(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - dims = config.model_dim - mlp_dims = dims * 4 - self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim) - self.ln = LayerNorm(dims) - self.moe = MOE(config, dims, mlp_dims) - - def __call__(self, x, mask, cache): - h = self.ln(x) - attn_h, cache = self.mixer(h, mask, cache) - ff_h = self.moe(h) - return attn_h + ff_h + x, cache - - -class TransformerDecoder(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.embd = Embd(config) - self.h = [ParallelBlock(config) for i in range(config.num_layers)] - - def __call__(self, x, mask, cache): - x = self.embd(x) - if cache is None: - cache = [None] * len(self.h) - - for e, layer in enumerate(self.h): - x, cache[e] = layer(x, mask, cache[e]) - return x, cache - - -class Embd(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.wte = nn.Embedding(config.num_vocab, config.model_dim) - - def __call__(self, x): - return self.wte(x) - - -class OutputHead(nn.Module): - def __init__(self, config: ModelArgs) -> None: - super().__init__() - self.ln = LayerNorm(config.model_dim) - self.linear = nn.Linear(config.model_dim, config.num_vocab) - - def __call__(self, inputs): - return self.linear(self.ln(inputs)) - - -class Model(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.model_type = config.model_type - self.transformer = TransformerDecoder(config) - self.lm_head = OutputHead(config) - - def __call__( - self, - x: mx.array, - mask: mx.array = None, - cache: mx.array = None, - ) -> Tuple[mx.array, mx.array]: - mask = None - if x.shape[1] > 1: - mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) - mask = mask.astype(x.dtype) - - y, cache = self.transformer(x, mask, cache) - return self.lm_head(y), cache - - @property - def layers(self): - return self.transformer.h diff --git a/server/models/plamo.py b/server/models/plamo.py deleted file mode 100644 index ba02633..0000000 --- a/server/models/plamo.py +++ /dev/null @@ -1,224 +0,0 @@ -from dataclasses import dataclass 
-from typing import Any, Dict, List, Optional, Tuple, Union - -import mlx.core as mx -import mlx.nn as nn -import numpy as np - -from .base import BaseModelArgs -from .layers import RMSNorm - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int - rms_norm_eps: float - vocab_size: int - n_shared_head: int = (8,) - rope_theta: float = 10000 - rope_traditional: bool = False - - -class Attention(nn.Module): - def __init__(self, config: ModelArgs) -> None: - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - head_dim = self.hidden_size // config.num_attention_heads - - self.q_num_heads = config.num_attention_heads - self.qk_dim = self.v_dim = head_dim - self.k_num_heads = self.v_num_heads = int( - np.ceil(self.q_num_heads / config.n_shared_head) - ) - - self.scale = head_dim**-0.5 - - self.q_proj = nn.Linear( - self.hidden_size, self.q_num_heads * self.qk_dim, bias=False - ) - self.k_proj = nn.Linear( - self.hidden_size, self.k_num_heads * self.qk_dim, bias=False - ) - self.v_proj = nn.Linear( - self.hidden_size, self.v_num_heads * self.v_dim, bias=False - ) - self.o_proj = nn.Linear( - self.q_num_heads * self.v_dim, self.hidden_size, bias=False - ) - self.rotary_emb = nn.RoPE( - head_dim, - traditional=config.rope_traditional, - base=config.rope_theta, - scale=1.0, - ) - - def __call__( - self, - hidden_states: mx.array, - attention_mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]: - bsz, q_len, _ = hidden_states.shape - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Prepare the queries, keys and values for the attention computation - query_states = query_states.reshape( - bsz, q_len, self.q_num_heads, self.qk_dim - ).transpose(0, 2, 1, 3) - key_states = key_states.reshape( - bsz, q_len, self.k_num_heads, self.qk_dim - ).transpose(0, 2, 1, 3) - value_states = value_states.reshape( - bsz, q_len, self.v_num_heads, self.v_dim - ).transpose(0, 2, 1, 3) - - def _expand_kv(a: mx.array) -> mx.array: - a = mx.concatenate( - [mx.expand_dims(a, 1)] * self.config.n_shared_head, axis=1 - ) - return a.reshape([bsz, self.q_num_heads, q_len, -1]) - - # expand shared kv - assert self.k_num_heads == self.v_num_heads - key_states = _expand_kv(key_states) - value_states = _expand_kv(value_states) - - kv_seq_len = 0 - if cache is not None: - kv_seq_len += cache[0].shape[-2] - query_states = self.rotary_emb(query_states, offset=kv_seq_len) - key_states = self.rotary_emb(key_states, offset=kv_seq_len) - - if cache is not None: - # reuse k, v, self_attention - key_states = mx.concatenate([cache[0], key_states], axis=2) - value_states = mx.concatenate([cache[1], value_states], axis=2) - - scores = (query_states * self.scale) @ key_states.transpose(0, 1, 3, 2) - if attention_mask is not None: - scores += attention_mask - scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) - output = (scores @ value_states).transpose(0, 2, 1, 3).reshape(bsz, q_len, -1) - - return self.o_proj(output), (key_states, value_states) - - -class MLP(nn.Module): - def __init__(self, config: ModelArgs) -> None: - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, 
self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - - def __call__(self, x: mx.array) -> mx.array: - return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) # type: ignore - - -class PlamoDecoderLayer(nn.Module): - def __init__(self, config: ModelArgs) -> None: - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.self_attn = Attention(config) - self.mlp = MLP(config) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def __call__( - self, - hidden_states: mx.array, - attention_mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> Tuple[Any, ...]: - # from LlamaDecoder - residual = hidden_states - - hidden_states = self.norm(hidden_states) - - # Self Attention - hidden_states_sa, cache = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - cache=cache, - ) - - # Fully Connected - hidden_states_mlp = self.mlp(hidden_states) - - hidden_states = residual + hidden_states_sa + hidden_states_mlp - return hidden_states, cache - - -class PlamoDecoder(nn.Module): - def __init__(self, config: ModelArgs) -> None: - super().__init__() - self.layers = [ - PlamoDecoderLayer(config) for _ in range(config.num_hidden_layers) - ] - - -class PlamoModel(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.config = config - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) - self.layers = PlamoDecoder(config) # type: ignore - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def __call__( - self, - inputs: mx.array, - cache: Optional[List[Union[Tuple[mx.array, mx.array], None]]] = None, - ) -> Tuple[mx.array, Optional[List[Union[Tuple[mx.array, mx.array], None]]]]: - h = self.embed_tokens(inputs) - - mask = None - if h.shape[1] > 1: - mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) - mask = mask.astype(self.embed_tokens.weight.dtype) - - if cache is None: - past_key_values_length = 0 - cache = [None for _ in range(len(self.layers.layers))] - else: - if cache[0] is not None: - past_key_values_length = cache[0][0].shape[2] - - for e, layer in enumerate(self.layers.layers): - h, c = layer(h, mask, cache[e]) - if cache is not None: - cache[e] = c - else: - cache.append(c) - - return self.norm(h), cache - - -class Model(nn.Module): - def __init__(self, args: ModelArgs) -> None: - super().__init__() - self.model_type = args.model_type - self.model = PlamoModel(args) - self.lm_head: nn.Module = nn.Linear( - args.hidden_size, args.vocab_size, bias=False - ) - - def __call__( - self, - inputs: mx.array, - cache: Optional[List[Tuple[mx.array, mx.array]]] = None, - ) -> Tuple[mx.array, mx.array]: - out, cache = self.model(inputs, cache) - return self.lm_head(out), cache diff --git a/server/models/qwen.py b/server/models/qwen.py deleted file mode 100644 index 1660941..0000000 --- a/server/models/qwen.py +++ /dev/null @@ -1,164 +0,0 @@ -from dataclasses import dataclass -from typing import Tuple - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs -from .layers import RMSNorm - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int = 2048 - num_attention_heads: int = 16 - num_hidden_layers: int = 24 - kv_channels: int = 128 - max_position_embeddings: int = 
8192 - layer_norm_epsilon: float = 1e-6 - intermediate_size: int = 11008 - no_bias: bool = True - vocab_size: int = 151936 - num_key_value_heads = None - - def __post_init__(self): - if self.num_key_value_heads is None: - self.num_key_value_heads = self.num_attention_heads - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - hidden_size = args.hidden_size - self.num_attention_heads = args.num_attention_heads - - hidden_size_per_attention_head = hidden_size // self.num_attention_heads - - self.rotary_emb = nn.RoPE(hidden_size_per_attention_head, traditional=False) - - proj_size = args.kv_channels * self.num_attention_heads - - self.c_attn = nn.Linear(hidden_size, proj_size * 3, bias=True) - self.c_proj = nn.Linear(hidden_size, proj_size, bias=not args.no_bias) - - self.scale = hidden_size_per_attention_head**-0.5 - - def __call__(self, x, mask=None, cache=None): - qkv = self.c_attn(x) - - q, k, v = mx.split(qkv, 3, axis=-1) - - B, L, _ = q.shape - - q = q.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3) - k = k.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3) - v = v.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3) - - if cache is not None: - k_cache, v_cache = cache - q = self.rotary_emb(q, offset=k_cache.shape[2]) - k = self.rotary_emb(k, offset=k_cache.shape[2]) - k = mx.concatenate([k_cache, k], axis=2) - v = mx.concatenate([v_cache, v], axis=2) - - else: - q = self.rotary_emb(q) - k = self.rotary_emb(k) - - scores = (q * self.scale) @ k.transpose(0, 1, 3, 2) - - if mask is not None: - scores = scores + mask - - scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) - v_hat = (scores @ v).transpose(0, 2, 1, 3).reshape(B, L, -1) - - return self.c_proj(v_hat), (k, v) - - -class MLP(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - self.w1 = nn.Linear( - args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias - ) - self.w2 = nn.Linear( - args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias - ) - self.c_proj = nn.Linear( - args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias - ) - - def __call__(self, x): - a1 = self.w1(x) - a2 = self.w2(x) - return self.c_proj(a1 * nn.silu(a2)) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - self.ln_1 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) - self.attn = Attention(args) - self.ln_2 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) - self.mlp = MLP(args) - - def __call__(self, x, mask=None, cache=None): - residual = x - x = self.ln_1(x) - x, cache = self.attn(x, mask=mask, cache=cache) - residual = x + residual - x = self.ln_2(residual) - x = self.mlp(x) - x = x + residual - - return x, cache - - -class QwenModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.wte = nn.Embedding(args.vocab_size, args.hidden_size) - self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)] - self.ln_f = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) - - def __call__(self, inputs, mask=None, cache=None): - x = self.wte(inputs) - - mask = None - T = x.shape[1] - if T > 1: - mask = nn.MultiHeadAttention.create_additive_causal_mask(T) - mask = mask.astype(x.dtype) - - if cache is None: - cache = [None] * len(self.h) - - for e, layer in enumerate(self.h): - x, cache[e] = layer(x, mask, cache[e]) - - x = self.ln_f(x[:, T - 1 : T, :]) - return x, cache - - -class 
Model(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.model_type = config.model_type - self.transformer = QwenModel(config) - self.lm_head = nn.Linear( - config.hidden_size, config.vocab_size, bias=not config.no_bias - ) - - def __call__( - self, - x: mx.array, - mask: mx.array = None, - cache: mx.array = None, - ) -> Tuple[mx.array, mx.array]: - y, cache = self.transformer(x, mask, cache) - return self.lm_head(y), cache diff --git a/server/models/qwen2.py b/server/models/qwen2.py deleted file mode 100644 index f0c1917..0000000 --- a/server/models/qwen2.py +++ /dev/null @@ -1,198 +0,0 @@ -from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs -from .layers import RMSNorm - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int - rms_norm_eps: float - vocab_size: int - num_key_value_heads: int = None - rope_theta: float = 1000000 - rope_traditional: bool = False - rope_scaling: Optional[Dict[str, Union[float, str]]] = None - - def __post_init__(self): - if self.num_key_value_heads is None: - self.num_key_value_heads = self.num_attention_heads - - if self.rope_scaling: - required_keys = {"factor", "type"} - if not all(key in self.rope_scaling for key in required_keys): - raise ValueError(f"rope_scaling must contain keys {required_keys}") - - if self.rope_scaling["type"] != "linear": - raise ValueError("rope_scaling 'type' currently only supports 'linear'") - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - dim = args.hidden_size - self.n_heads = n_heads = args.num_attention_heads - self.n_kv_heads = n_kv_heads = args.num_key_value_heads - - self.repeats = n_heads // n_kv_heads - - head_dim = args.hidden_size // n_heads - self.scale = head_dim**-0.5 - - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True) - self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) - self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) - self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) - - rope_scale = ( - 1 / args.rope_scaling["factor"] - if args.rope_scaling is not None and args.rope_scaling["type"] == "linear" - else 1 - ) - self.rope = nn.RoPE( - head_dim, - traditional=args.rope_traditional, - base=args.rope_theta, - scale=rope_scale, - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> mx.array: - B, L, D = x.shape - - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - - if self.repeats > 1: - keys = mx.repeat(keys, self.repeats, axis=1) - values = mx.repeat(values, self.repeats, axis=1) - - if cache is not None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) - 
if mask is not None: - scores += mask - scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) - output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output), (keys, values) - - -class MLP(nn.Module): - def __init__(self, dim, hidden_dim): - super().__init__() - self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) - self.down_proj = nn.Linear(hidden_dim, dim, bias=False) - self.up_proj = nn.Linear(dim, hidden_dim, bias=False) - - def __call__(self, x) -> mx.array: - return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.num_attention_heads = args.num_attention_heads - self.hidden_size = args.hidden_size - self.self_attn = Attention(args) - self.mlp = MLP(args.hidden_size, args.intermediate_size) - self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.args = args - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> mx.array: - r, cache = self.self_attn(self.input_layernorm(x), mask, cache) - h = x + r - r = self.mlp(self.post_attention_layernorm(h)) - out = h + r - return out, cache - - -class Qwen2Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.vocab_size = args.vocab_size - self.num_hidden_layers = args.num_hidden_layers - assert self.vocab_size > 0 - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ - TransformerBlock(args=args) for _ in range(args.num_hidden_layers) - ] - self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - - def __call__( - self, - inputs: mx.array, - cache=None, - ): - h = self.embed_tokens(inputs) - - mask = None - if h.shape[1] > 1: - mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) - mask = mask.astype(h.dtype) - - if cache is None: - cache = [None] * len(self.layers) - - for e, layer in enumerate(self.layers): - h, cache[e] = layer(h, mask, cache[e]) - - return self.norm(h), cache - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.model_type = args.model_type - self.model = Qwen2Model(args) - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__( - self, - inputs: mx.array, - cache=None, - ): - out, cache = self.model(inputs, cache) - return self.lm_head(out), cache - - @staticmethod - def sanitize(weights): - # Remove unused precomputed rotary freqs - return { - k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k - } - - @property - def layers(self): - return self.model.layers diff --git a/server/models/stablelm_epoch.py b/server/models/stablelm_epoch.py deleted file mode 100644 index 6b13012..0000000 --- a/server/models/stablelm_epoch.py +++ /dev/null @@ -1,186 +0,0 @@ -import math -from dataclasses import dataclass -from typing import Tuple - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs -from .layers import LayerNorm - - -@dataclass -class ModelArgs(BaseModelArgs): - max_position_embeddings: int - model_type: str - vocab_size: int - hidden_size: int - num_attention_heads: int - num_hidden_layers: int - num_key_value_heads: int - rope_pct: float - intermediate_size: int - norm_eps: float - rope_theta: float - use_qkv_bias: bool - - -class 
Attention(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.repeats = self.num_heads // self.num_key_value_heads - self.rope_theta = config.rope_theta - self.rope_pct = config.rope_pct - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - - self.q_proj = nn.Linear( - self.hidden_size, self.num_heads * self.head_dim, bias=config.use_qkv_bias - ) - self.k_proj = nn.Linear( - self.hidden_size, - self.num_key_value_heads * self.head_dim, - bias=config.use_qkv_bias, - ) - self.v_proj = nn.Linear( - self.hidden_size, - self.num_key_value_heads * self.head_dim, - bias=config.use_qkv_bias, - ) - self.o_proj = nn.Linear( - self.num_heads * self.head_dim, self.hidden_size, bias=False - ) - - self.rope = nn.RoPE( - int(self.rope_pct * self.head_dim), - traditional=False, - base=self.rope_theta, - ) - - def __call__(self, x, mask=None, cache=None): - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - # Extract some shapes - B, L, D = queries.shape - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.num_heads, self.head_dim).transpose( - 0, 2, 1, 3 - ) - keys = keys.reshape(B, L, self.num_key_value_heads, self.head_dim).transpose( - 0, 2, 1, 3 - ) - values = values.reshape( - B, L, self.num_key_value_heads, self.head_dim - ).transpose(0, 2, 1, 3) - - if self.repeats > 1: - keys = mx.repeat(keys, self.repeats, axis=1) - values = mx.repeat(values, self.repeats, axis=1) - - # Add RoPE to the queries and keys and combine them with the cache - if cache is not None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - queries = queries.astype(mx.float32) - keys = keys.astype(mx.float32) - - # Finally perform the attention computation - scale = math.sqrt(1 / queries.shape[-1]) - scores = (queries * scale) @ keys.transpose(0, 1, 3, 2) - if mask is not None: - scores = scores + mask - - scores = mx.softmax(scores, axis=-1).astype(values.dtype) - values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) - - return self.o_proj(values_hat), (keys, values) - - -class MLP(nn.Module): - def __init__(self, dim, hidden_dim): - super().__init__() - self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) - self.down_proj = nn.Linear(hidden_dim, dim, bias=False) - self.up_proj = nn.Linear(dim, hidden_dim, bias=False) - - def __call__(self, x) -> mx.array: - return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) - - -class DecoderLayer(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.self_attn = Attention(config=config) - self.input_layernorm = LayerNorm(config.hidden_size, eps=config.norm_eps) - self.mlp = MLP(config.hidden_size, config.intermediate_size) - self.input_layernorm = LayerNorm(config.hidden_size, eps=config.norm_eps) - self.post_attention_layernorm = LayerNorm( - config.hidden_size, eps=config.norm_eps - ) - - def 
__call__(self, x, mask, cache): - r, cache = self.self_attn(self.input_layernorm(x), mask, cache) - h = x + r - r = self.mlp(self.post_attention_layernorm(h)) - out = h + r - return out, cache - - -class StableLM(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) - self.layers = [DecoderLayer(config) for i in range(config.num_hidden_layers)] - self.norm = LayerNorm(config.hidden_size, eps=config.norm_eps) - - def __call__(self, x, mask, cache): - x = self.embed_tokens(x) - if cache is None: - cache = [None] * len(self.layers) - - for e, layer in enumerate(self.layers): - x, cache[e] = layer(x, mask, cache[e]) - return self.norm(x), cache - - -class Model(nn.Module): - def __init__(self, config: ModelArgs): - super().__init__() - self.model_type = config.model_type - self.model = StableLM(config) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - def __call__( - self, - x: mx.array, - mask: mx.array = None, - cache: mx.array = None, - ) -> Tuple[mx.array, mx.array]: - mask = None - if x.shape[1] > 1: - mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) - mask = mask.astype(x.dtype) - - y, cache = self.model(x, mask, cache) - return self.lm_head(y), cache - - @property - def layers(self): - return self.model.layers diff --git a/server/requirements.txt b/server/requirements.txt index 049049e..06fc939 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -1,5 +1,6 @@ -mlx>=0.4 -numpy -transformers>=4.38.0 -protobuf -pyyaml +chromadb==0.4.23 +huggingface_hub==0.20.3 +mlx==0.4.0 +mlx_data==0.0.2 +transformers==4.38.1 +pyinstaller==6.4.0 diff --git a/server/tuner/__init__.py b/server/tuner/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/server/tuner/lora.py b/server/tuner/lora.py deleted file mode 100644 index adc1f8c..0000000 --- a/server/tuner/lora.py +++ /dev/null @@ -1,103 +0,0 @@ -import math - -import mlx.core as mx -import mlx.nn as nn - - -class LoRALinear(nn.Module): - @staticmethod - def from_linear( - linear: nn.Linear, - r: int = 8, - lora_alpha: float = 16, - lora_dropout: float = 0.05, - scale: float = 10.0, - ): - # TODO remove when input_dims and output_dims are attributes - # on linear and quantized linear - output_dims, input_dims = linear.weight.shape - if isinstance(linear, nn.QuantizedLinear): - input_dims *= 32 // linear.bits - lora_lin = LoRALinear( - input_dims=input_dims, - output_dims=output_dims, - r=r, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - scale=scale, - ) - lora_lin.linear = linear - return lora_lin - - def to_linear(self, de_quantize: bool = False): - linear = self.linear - bias = "bias" in linear - weight = linear.weight - is_quantized = isinstance(linear, nn.QuantizedLinear) - - # Use the same type as the linear weight if not quantized - dtype = weight.dtype - - if is_quantized: - dtype = mx.float16 - weight = mx.dequantize( - weight, - linear.scales, - linear.biases, - linear.group_size, - linear.bits, - ) - output_dims, input_dims = weight.shape - fused_linear = nn.Linear(input_dims, output_dims, bias=bias) - - lora_b = (self.scale * self.lora_b.T).astype(dtype) - lora_a = self.lora_a.T.astype(dtype) - fused_linear.weight = weight + lora_b @ lora_a - if bias: - fused_linear.bias = linear.bias - - if is_quantized and not de_quantize: - fused_linear = nn.QuantizedLinear.from_linear( - fused_linear, - linear.group_size, - linear.bits, - ) - - return fused_linear - 
- def __init__( - self, - input_dims: int, - output_dims: int, - r: int = 8, - lora_alpha: float = 16, - lora_dropout: float = 0.0, - scale: float = 10.0, - bias: bool = False, - ): - super().__init__() - - # Regular linear layer weights - self.linear = nn.Linear(input_dims, output_dims, bias=bias) - - self.lora_dropout = nn.Dropout(p=lora_dropout) - - # Scale for low-rank update - self.scale = scale * (lora_alpha / r) - - # Low rank lora weights - scale = 1 / math.sqrt(input_dims) - self.lora_a = mx.random.uniform( - low=-scale, - high=scale, - shape=(input_dims, r), - ) - self.lora_b = mx.zeros(shape=(r, output_dims)) - - def __call__(self, x): - dtype = self.linear.weight.dtype - if isinstance(self.linear, nn.QuantizedLinear): - dtype = self.linear.scales.dtype - y = self.linear(x.astype(dtype)) - z = (self.lora_dropout(x) @ self.lora_a) @ self.lora_b - return y + self.scale * z diff --git a/server/tuner/trainer.py b/server/tuner/trainer.py deleted file mode 100644 index bc3f281..0000000 --- a/server/tuner/trainer.py +++ /dev/null @@ -1,261 +0,0 @@ -import os -import time -from dataclasses import dataclass, field - -import mlx.core as mx -import mlx.nn as nn -import numpy as np -from mlx.utils import tree_flatten - - -@dataclass -class TrainingArgs: - lora_layers: int = field( - default=16, metadata={"help": "Number of layers to fine-tune"} - ) - batch_size: int = field(default=4, metadata={"help": "Minibatch size."}) - iters: int = field(default=100, metadata={"help": "Iterations to train for."}) - val_batches: int = field( - default=25, - metadata={ - "help": "Number of validation batches, -1 uses the entire validation set." - }, - ) - steps_per_report: int = field( - default=10, - metadata={"help": "Number of training steps between loss reporting."}, - ) - steps_per_eval: int = field( - default=200, metadata={"help": "Number of training steps between validations."} - ) - steps_per_save: int = field( - default=100, metadata={"help": "Save the model every number steps"} - ) - max_seq_length: int = field( - default=2048, metadata={"help": "Maximum sequence length."} - ) - adapter_file: str = field( - default="adapter.npz", - metadata={"help": "Save/load path for the trained adapter weights."}, - ) - - -def default_loss(model, inputs, targets, lengths): - logits, _ = model(inputs) - logits = logits.astype(mx.float32) - - length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None] - - ce = nn.losses.cross_entropy(logits, targets) * length_mask - ntoks = length_mask.sum() - ce = ce.sum() / ntoks - - return ce, ntoks - - -def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False): - while True: - # Shuffle indices - indices = np.arange(len(dataset)) - indices = np.random.permutation(indices) - # Collect batches from dataset - for i in range(0, len(indices) - batch_size + 1, batch_size): - # Encode batch - batch = [ - tokenizer.encode(dataset[indices[i + j]]) for j in range(batch_size) - ] - lengths = [len(x) for x in batch] - - if max(lengths) > max_seq_length: - print( - f"[WARNING] Some sequences are longer than {max_seq_length} tokens. " - f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. " - "Consider pre-splitting your data to save memory." 
- ) - - # Pad to the max length - max_length_in_batch = min(max(lengths), max_seq_length) - batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32) - - for j in range(batch_size): - truncated_length = min(lengths[j], max_seq_length) - batch_arr[j, :truncated_length] = batch[j][:truncated_length] - lengths[j] = ( - truncated_length # Update lengths to match truncated lengths - ) - batch = mx.array(batch_arr) - - yield batch[:, :-1], batch[:, 1:], mx.array(lengths) - - if not train: - break - - -def evaluate( - model, - dataset, - tokenizer, - batch_size, - num_batches, - max_seq_length=2048, - loss: callable = default_loss, - iterate_batches: callable = iterate_batches, -): - all_losses = [] - ntokens = 0 - for it, batch in zip( - range(num_batches), - iterate_batches( - dataset=dataset, - tokenizer=tokenizer, - batch_size=batch_size, - max_seq_length=max_seq_length, - ), - ): - losses, toks = loss(model, *batch) - all_losses.append((losses * toks).item()) - ntokens += toks.item() - - return np.sum(all_losses) / ntokens - - -class TrainingCallback: - - def on_train_loss_report(self, train_info: dict): - """Called to report training loss at specified intervals.""" - pass - - def on_val_loss_report(self, val_info: dict): - """Called to report validation loss at specified intervals or the beginning.""" - pass - - -def train( - model, - tokenizer, - optimizer, - train_dataset, - val_dataset, - args: TrainingArgs = TrainingArgs(), - loss: callable = default_loss, - iterate_batches: callable = iterate_batches, - training_callback: TrainingCallback = None, -): - print(f"Starting training..., iters: {args.iters}") - - # Create checkpoints directory if it does not exist - if not os.path.exists("checkpoints"): - os.makedirs("checkpoints") - - # Create value and grad function for loss - loss_value_and_grad = nn.value_and_grad(model, loss) - - losses = [] - n_tokens = 0 - trained_tokens = 0 - # Main training loop - start = time.perf_counter() - for it, batch in zip( - range(args.iters), - iterate_batches( - dataset=train_dataset, - tokenizer=tokenizer, - batch_size=args.batch_size, - max_seq_length=args.max_seq_length, - train=True, - ), - ): - # Forward and backward pass - (lvalue, toks), grad = loss_value_and_grad(model, *batch) - - # Model update - optimizer.update(model, grad) - - mx.eval(model.parameters(), optimizer.state, lvalue) - - # Record loss - losses.append(lvalue.item()) - n_tokens += toks.item() - - # Report training loss if needed - if (it + 1) % args.steps_per_report == 0: - train_loss = np.mean(losses) - - stop = time.perf_counter() - learning_rate = optimizer.learning_rate.item() - it_sec = args.steps_per_report / (stop - start) - tokens_sec = float(n_tokens) / (stop - start) - trained_tokens += n_tokens - print( - f"Iter {it + 1}: Train loss {train_loss:.3f}, " - f"Learning Rate {learning_rate:.3e}, " - f"It/sec {it_sec:.3f}, " - f"Tokens/sec {tokens_sec:.3f}, " - f"Trained Tokens {trained_tokens}" - ) - - if training_callback is not None: - train_info = { - "iteration": it + 1, - "train_loss": train_loss, - "learning_rate": learning_rate, - "iterations_per_second": it_sec, - "tokens_per_second": tokens_sec, - "trained_tokens": trained_tokens, - } - training_callback.on_train_loss_report(train_info) - - losses = [] - n_tokens = 0 - start = time.perf_counter() - - # Report validation loss if needed - if it == 0 or (it + 1) % args.steps_per_eval == 0: - stop = time.perf_counter() - val_loss = evaluate( - model=model, - dataset=val_dataset, - loss=loss, - tokenizer=tokenizer, 
- batch_size=args.batch_size, - num_batches=args.val_batches, - max_seq_length=args.max_seq_length, - iterate_batches=iterate_batches, - ) - val_time = time.perf_counter() - stop - print( - f"Iter {it + 1}: " - f"Val loss {val_loss:.3f}, " - f"Val took {val_time:.3f}s" - ) - - if training_callback is not None: - val_info = { - "iteration": it + 1, - "val_loss": val_loss, - "val_time": val_time, - } - training_callback.on_val_loss_report(val_info) - - start = time.perf_counter() - - # Save adapter weights if needed - if (it + 1) % args.steps_per_save == 0: - checkpoint_adapter_file = f"checkpoints/{it + 1}_{args.adapter_file}" - save_adapter(model=model, adapter_file=checkpoint_adapter_file) - print( - f"Iter {it + 1}: Saved adapter weights to {os.path.join(checkpoint_adapter_file)}." - ) - - # save final adapter weights - save_adapter(model=model, adapter_file=args.adapter_file) - print(f"Saved final adapter weights to {os.path.join(args.adapter_file)}.") - - -def save_adapter( - model: nn.Module, - adapter_file: str, -): - flattened_tree = tree_flatten(model.trainable_parameters()) - - mx.savez(adapter_file, **dict(flattened_tree)) diff --git a/server/tuner/utils.py b/server/tuner/utils.py deleted file mode 100644 index 579fca5..0000000 --- a/server/tuner/utils.py +++ /dev/null @@ -1,144 +0,0 @@ -import os - -import mlx.core as mx -import mlx.nn as nn -from mlx.utils import tree_unflatten - -from .lora import LoRALinear - - -def linear_to_lora_layers(model: nn.Module, num_lora_layers: int): - """ - Convert some of the models linear layers to lora layers. - - Args: - model (nn.Module): The neural network model. - num_lora_layers (int): The number of blocks to convert to lora layers - starting from the last layer. - """ - - def check_lora_layers(num_model): - if num_lora_layers > num_model: - raise ValueError( - f"Requested {num_lora_layers} LoRA layers " - f"but the model only has {num_model} layers." - ) - - if model.model_type in [ - "mistral", - "llama", - "phi", - "mixtral", - "stablelm_epoch", - "qwen2", - "gemma", - ]: - check_lora_layers(len(model.model.layers)) - - for l in model.model.layers[len(model.model.layers) - num_lora_layers :]: - l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj) - l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj) - if hasattr(l, "block_sparse_moe"): - l.block_sparse_moe.gate = LoRALinear.from_linear( - l.block_sparse_moe.gate - ) - elif model.model_type == "olmo": - check_lora_layers(len(model.model.transformer.blocks)) - - for l in model.model.transformer.blocks[ - len(model.model.transformer.blocks) - num_lora_layers : - ]: - l.att_proj = LoRALinear.from_linear(l.att_proj) - elif model.model_type == "phi-msft": - check_lora_layers(len(model.transformer.h)) - - for l in model.transformer.h[len(model.transformer.h) - num_lora_layers :]: - l.mixer.Wqkv = LoRALinear.from_linear(l.mixer.Wqkv) - l.moe.gate = LoRALinear.from_linear(l.moe.gate) - - else: - raise ValueError(f"Lora does not support {model.model_type}") - - -def apply_lora_layers(model: nn.Module, adapter_file: str) -> nn.Module: - """ - Apply LoRA layers to the model. - - Args: - model (nn.Module): The neural network model. - adapter_file (str): Path to the adapter configuration file. - - Returns: - nn.Module: The updated model with LoRA layers applied. 
- """ - if not os.path.exists(adapter_file): - raise FileNotFoundError(f"The adapter file does not exist: {adapter_file}") - - adapters = list(mx.load(adapter_file).items()) - - linear_replacements = [] - lora_layers = set( - [name.replace(".lora_a", "").replace(".lora_b", "") for name, _ in adapters] - ) - for name, module in model.named_modules(): - if name in lora_layers: - replacement_module = LoRALinear.from_linear(module) - linear_replacements.append((name, replacement_module)) - - model.update_modules(tree_unflatten(linear_replacements)) - - model.update(tree_unflatten(adapters)) - - return model - - -def dequantize(model: nn.Module) -> nn.Module: - """ - Dequantize the quantized linear layers in the model. - - Args: - model (nn.Module): The model with quantized linear layers. - - Returns: - nn.Module: The model with dequantized layers. - """ - de_quantize_layers = [] - for name, module in model.named_modules(): - if isinstance(module, nn.QuantizedLinear): - bias = "bias" in module - weight = module.weight - weight = mx.dequantize( - weight, - module.scales, - module.biases, - module.group_size, - module.bits, - ).astype(mx.float16) - output_dims, input_dims = weight.shape - linear = nn.Linear(input_dims, output_dims, bias=bias) - linear.weight = weight - if bias: - linear.bias = module.bias - de_quantize_layers.append((name, linear)) - if len(de_quantize_layers) > 0: - model.update_modules(tree_unflatten(de_quantize_layers)) - return model - - -def remove_lora_layers(model: nn.Module) -> nn.Module: - """ - Remove the LoRA layers from the model. - - Args: - model (nn.Module): The model with LoRA layers. - - Returns: - nn.Module: The model without LoRA layers. - """ - reset_layers = [] - for name, module in model.named_modules(): - if isinstance(module, LoRALinear): - reset_layers.append((name, module.linear)) - if len(reset_layers) > 0: - model.update_modules(tree_unflatten(reset_layers)) - return model diff --git a/server/utils.py b/server/utils.py index b55c7c1..c939fc1 100644 --- a/server/utils.py +++ b/server/utils.py @@ -7,7 +7,7 @@ import logging import time from pathlib import Path -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -16,9 +16,6 @@ from huggingface_hub import snapshot_download from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer -# Local imports -from .tuner.utils import apply_lora_layers - # Constants MODEL_REMAPPING = { "mistral": "llama", # mistral is compatible with llama @@ -382,7 +379,8 @@ def load( model = load_model(model_path, lazy) if adapter_file is not None: - model = apply_lora_layers(model, adapter_file) + # TODO: Apply LoRA layers + # model = apply_lora_layers(model, adapter_file) model.eval() tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)