Commit 832bb70

[hf_modelzoo] Adds import rust model from Huggingface (#3125)
frankfliu committed Apr 26, 2024
1 parent 547c8cf · commit 832bb70
Showing 4 changed files with 171 additions and 5 deletions.
extensions/tokenizers/src/main/python/arg_parser.py (2 changes: 1 addition & 1 deletion)
@@ -26,7 +26,7 @@ def converter_args():
parser.add_argument("-f",
"--output-format",
default="PyTorch",
choices=["PyTorch", "OnnxRuntime"],
choices=["PyTorch", "OnnxRuntime", "Rust"],
help="Model output format")
parser.add_argument("-r",
"--retry-failed",
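The effect of the one-line change is easiest to see in isolation. Below is a standalone sketch, a hypothetical parser mirroring the flag above rather than the project's own entry point:

    import argparse

    # Hypothetical stand-in for converter_args(), reduced to the flag touched here
    parser = argparse.ArgumentParser()
    parser.add_argument("-f",
                        "--output-format",
                        default="PyTorch",
                        choices=["PyTorch", "OnnxRuntime", "Rust"],
                        help="Model output format")

    print(parser.parse_args(["-f", "Rust"]).output_format)  # Rust
    print(parser.parse_args([]).output_format)              # PyTorch (default)
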
extensions/tokenizers/src/main/python/huggingface_converter.py (72 changes: 68 additions & 4 deletions)
@@ -17,9 +17,10 @@
from argparse import Namespace

import onnx
+import safetensors_convert
import torch
-from huggingface_hub import hf_hub_download
-from transformers import pipeline, AutoTokenizer
+from huggingface_hub import hf_hub_download, HfApi
+from transformers import pipeline, AutoTokenizer, AutoConfig

from metadata import HuggingfaceMetadata
from shasum import sha1_sum
@@ -33,6 +34,12 @@ def __init__(self, tokenizer, model):
self.model = model


+class ModelHolder(object):
+
+    def __init__(self, config):
+        self.config = config


class HuggingfaceConverter:

def __init__(self):
@@ -43,10 +50,13 @@ def __init__(self):
self.translator = None
self.inputs = None
self.outputs = None
+        self.api = HfApi()

    def save_model(self, model_info, args: Namespace, temp_dir: str):
        if args.output_format == "OnnxRuntime":
            return self.save_onnx_model(model_info, args, temp_dir)
+        elif args.output_format == "Rust":
+            return self.save_rust_model(model_info, args, temp_dir)
        else:
            return self.save_pytorch_model(model_info, args, temp_dir)

@@ -71,13 +81,67 @@ def save_onnx_model(self, model_info, args: Namespace, temp_dir: str):
include_types = "token_type_id" in inputs

tokenizer = AutoTokenizer.from_pretrained(model_id)
-        hf_pipeline = PipelineHolder(tokenizer, model)
+        config = AutoConfig.from_pretrained(model_id)
+        hf_pipeline = PipelineHolder(tokenizer, ModelHolder(config))
size = self.save_to_model_zoo(model_info, args.output_dir,
"OnnxRuntime", temp_dir, hf_pipeline,
include_types)

return True, None, size

+    def save_rust_model(self, model_info, args: Namespace, temp_dir: str):
+        model_id = model_info.modelId
+
+        config = AutoConfig.from_pretrained(model_id)
+        if hasattr(config, "model_type"):
+            if config.model_type == "bert":
+                include_types = True
+            elif config.model_type == "distilbert":
+                include_types = False
+            else:
+                return False, f"Unsupported model_type: {config.model_type}", -1
+
+        logging.info(f"Saving rust model: {model_id} ...")
+
+        if not os.path.exists(temp_dir):
+            os.makedirs(temp_dir)
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        hf_pipeline = PipelineHolder(tokenizer, ModelHolder(config))
+        try:
+            # Save tokenizer.json to temp dir
+            self.save_tokenizer(hf_pipeline, temp_dir)
+        except Exception as e:
+            logging.warning(f"Failed to save tokenizer: {model_id}.")
+            logging.warning(e, exc_info=True)
+            return False, "Failed to save tokenizer", -1
+
+        target = os.path.join(temp_dir, "model.safetensors")
+        model = self.api.model_info(model_id, files_metadata=True)
+        has_sf_file = False
+        has_pt_file = False
+        for sibling in model.siblings:
+            if sibling.rfilename == "model.safetensors":
+                has_sf_file = True
+            elif sibling.rfilename == "pytorch_model.bin":
+                has_pt_file = True
+
+        if has_sf_file:
+            file = hf_hub_download(repo_id=model_id,
+                                   filename="model.safetensors")
+            shutil.copyfile(file, target)
+        elif has_pt_file:
+            file = hf_hub_download(repo_id=model_id,
+                                   filename="pytorch_model.bin")
+            safetensors_convert.convert_file(file, target)
+        else:
+            return False, f"No model file found for: {model_id}", -1
+
+        size = self.save_to_model_zoo(model_info, args.output_dir, "Rust",
+                                      temp_dir, hf_pipeline, include_types)
+
+        return True, None, size

def save_pytorch_model(self, model_info, args: Namespace, temp_dir: str):
model_id = model_info.modelId
if not os.path.exists(temp_dir):
@@ -134,7 +198,7 @@ def save_tokenizer(hf_pipeline, temp_dir: str):
hf_pipeline.tokenizer.save_pretrained(temp_dir)
# only keep tokenizer.json file
for path in os.listdir(temp_dir):
-            if path != "tokenizer.json":
+            if path != "tokenizer.json" and path != "tokenizer_config.json":
os.remove(os.path.join(temp_dir, path))

def jit_trace_model(self, hf_pipeline, model_id: str, temp_dir: str,
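The heart of the new save_rust_model is a fetch-or-convert decision: copy a ready-made model.safetensors from the Hub when the repository publishes one, otherwise convert pytorch_model.bin. A condensed standalone sketch of just that decision, with model_id and target as placeholder values and safetensors_convert being the new module added below:

    import shutil

    from huggingface_hub import HfApi, hf_hub_download

    import safetensors_convert

    model_id = "bert-base-uncased"       # placeholder model
    target = "/tmp/model.safetensors"    # placeholder output path

    # files_metadata=True fills .siblings with the repository's file listing
    info = HfApi().model_info(model_id, files_metadata=True)
    names = {s.rfilename for s in info.siblings}

    if "model.safetensors" in names:
        shutil.copyfile(hf_hub_download(model_id, "model.safetensors"), target)
    elif "pytorch_model.bin" in names:
        safetensors_convert.convert_file(
            hf_hub_download(model_id, "pytorch_model.bin"), target)
    else:
        raise ValueError(f"No model file found for: {model_id}")
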
extensions/tokenizers/src/main/python/requirements.txt (1 change: 1 addition & 0 deletions)
@@ -3,3 +3,4 @@ transformers
torch
protobuf==3.20.2
optimum[exporters,onnxruntime]
+safetensors
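The added dependency is what safetensors_convert.py below imports from; a quick hedged sanity check that it resolves in the current environment:

    import safetensors

    # any recent release works; the converter only needs safetensors.torch
    print(safetensors.__version__)
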
extensions/tokenizers/src/main/python/safetensors_convert.py (101 changes: 101 additions & 0 deletions)
@@ -0,0 +1,101 @@
#!/usr/bin/env python
#
# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import os
from collections import defaultdict
from typing import List, Dict

import torch
from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_file


def _remove_duplicate_names(
state_dict: Dict[str, torch.Tensor],
*,
preferred_names: List[str] = None,
discard_names: List[str] = None,
) -> Dict[str, List[str]]:
if preferred_names is None:
preferred_names = []
preferred_names = set(preferred_names)
if discard_names is None:
discard_names = []
discard_names = set(discard_names)

shareds = _find_shared_tensors(state_dict)
to_remove = defaultdict(list)
for shared in shareds:
complete_names = set(
[name for name in shared if _is_complete(state_dict[name])])
if not complete_names:
if len(shared) == 1:
# Force contiguous
name = list(shared)[0]
state_dict[name] = state_dict[name].clone()
complete_names = {name}
else:
                raise RuntimeError(
                    f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage. Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
                )

keep_name = sorted(list(complete_names))[0]

preferred = complete_names.difference(discard_names)
if preferred:
keep_name = sorted(list(preferred))[0]

if preferred_names:
preferred = preferred_names.intersection(complete_names)
if preferred:
keep_name = sorted(list(preferred))[0]
for name in sorted(shared):
if name != keep_name:
to_remove[keep_name].append(name)
return to_remove


def convert_file(pt_filename: str, sf_filename: str):
loaded = torch.load(pt_filename, map_location="cpu")
if "state_dict" in loaded:
loaded = loaded["state_dict"]
to_removes = _remove_duplicate_names(loaded)

metadata = {"format": "pt"}
for kept_name, to_remove_group in to_removes.items():
for to_remove in to_remove_group:
if to_remove not in metadata:
metadata[to_remove] = kept_name
del loaded[to_remove]
# Force tensors to be contiguous
loaded = {k: v.contiguous() for k, v in loaded.items()}

dir_name = os.path.dirname(sf_filename)
os.makedirs(dir_name, exist_ok=True)
save_file(loaded, sf_filename, metadata=metadata)
check_file_size(sf_filename, pt_filename)
reloaded = load_file(sf_filename)
for k in loaded:
pt_tensor = loaded[k]
sf_tensor = reloaded[k]
if not torch.equal(pt_tensor, sf_tensor):
raise RuntimeError(f"The output tensors do not match for key {k}")


def check_file_size(sf_filename: str, pt_filename: str):
sf_size = os.stat(sf_filename).st_size
pt_size = os.stat(pt_filename).st_size

if (sf_size - pt_size) / pt_size > 0.01:
raise RuntimeError(f"""The file size different is more than 1%:
- {sf_filename}: {sf_size}
- {pt_filename}: {pt_size}
""")
