[MLC-28] server: added Bert MLX model with conversions for e5 models
stockeh committed Mar 1, 2024
1 parent 40c1a78 commit cb4f103
Showing 6 changed files with 1,012 additions and 91 deletions.
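The visible changes refactor the conversion entry point: the standalone logic in server/convert.py is removed and convert is now re-exported from server.utils alongside generate and load. A minimal usage sketch, assuming the relocated convert keeps the signature of the function removed below; the e5 checkpoint name is an illustrative placeholder, not taken from the diff:

from server import convert  # re-exported via server/__init__.py after this commit

# Illustrative arguments; "intfloat/e5-base-v2" is a placeholder e5 model id.
convert(hf_path="intfloat/e5-base-v2", mlx_path="mlx_model")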
3 changes: 1 addition & 2 deletions server/__init__.py
@@ -1,4 +1,3 @@
-from .convert import convert
-from .utils import generate, load
+from .utils import generate, load, convert
 
 __version__ = "0.0.14"
88 changes: 3 additions & 85 deletions server/convert.py
@@ -1,22 +1,6 @@
 import argparse
-import copy
-import glob
-import json
-import shutil
-from pathlib import Path
-from typing import Tuple
 
-import mlx.core as mx
-import mlx.nn as nn
-from mlx.utils import tree_flatten
-
-from .utils import (
-    fetch_from_hub,
-    get_model_path,
-    linear_class_predicate,
-    save_weights,
-    upload_to_hub,
-)
+from .utils import convert
 
 
 def configure_parser() -> argparse.ArgumentParser:
@@ -30,7 +14,8 @@ def configure_parser() -> argparse.ArgumentParser:
         description="Convert Hugging Face model to MLX format"
     )
 
-    parser.add_argument("--hf-path", type=str, help="Path to the Hugging Face model.")
+    parser.add_argument("--hf-path", type=str,
+                        help="Path to the Hugging Face model.")
     parser.add_argument(
         "--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model."
     )
@@ -59,73 +44,6 @@ def configure_parser() -> argparse.ArgumentParser:
     return parser
 
 
-def quantize_model(
-    model: nn.Module, config: dict, q_group_size: int, q_bits: int
-) -> Tuple:
-    """
-    Applies quantization to the model weights.
-
-    Args:
-        model (nn.Module): The model to be quantized.
-        config (dict): Model configuration.
-        q_group_size (int): Group size for quantization.
-        q_bits (int): Bits per weight for quantization.
-
-    Returns:
-        Tuple: Tuple containing quantized weights and config.
-    """
-    quantized_config = copy.deepcopy(config)
-
-    nn.QuantizedLinear.quantize_module(
-        model, q_group_size, q_bits, linear_class_predicate=linear_class_predicate
-    )
-    quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
-    quantized_weights = dict(tree_flatten(model.parameters()))
-
-    return quantized_weights, quantized_config
-
-
-def convert(
-    hf_path: str,
-    mlx_path: str = "mlx_model",
-    quantize: bool = False,
-    q_group_size: int = 64,
-    q_bits: int = 4,
-    dtype: str = "float16",
-    upload_repo: str = None,
-):
-    print("[INFO] Loading")
-    model_path = get_model_path(hf_path)
-    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
-
-    weights = dict(tree_flatten(model.parameters()))
-    dtype = mx.float16 if quantize else getattr(mx, dtype)
-    weights = {k: v.astype(dtype) for k, v in weights.items()}
-
-    if quantize:
-        print("[INFO] Quantizing")
-        model.load_weights(list(weights.items()))
-        weights, config = quantize_model(model, config, q_group_size, q_bits)
-
-    if isinstance(mlx_path, str):
-        mlx_path = Path(mlx_path)
-
-    del model
-    save_weights(mlx_path, weights, donate_weights=True)
-
-    py_files = glob.glob(str(model_path / "*.py"))
-    for file in py_files:
-        shutil.copy(file, mlx_path)
-
-    tokenizer.save_pretrained(mlx_path)
-
-    with open(mlx_path / "config.json", "w") as fid:
-        json.dump(config, fid, indent=4)
-
-    if upload_repo is not None:
-        upload_to_hub(mlx_path, upload_repo, hf_path)
-
-
 if __name__ == "__main__":
     parser = configure_parser()
     args = parser.parse_args()
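The removed convert signature above also suggests how the relocated helper can be driven with quantization options. A hedged sketch, assuming server.utils.convert keeps the same parameters and defaults as the deleted function; the checkpoint name is again a placeholder:

from server.utils import convert

# Assumes the relocated function keeps the removed signature shown above.
convert(
    hf_path="intfloat/e5-base-v2",  # placeholder e5 model id, not from the diff
    mlx_path="mlx_model",
    quantize=True,  # would use the old defaults of q_group_size=64, q_bits=4
    dtype="float16",
)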
