Merge branch 'main' into notie_embd

vchiley authored Nov 13, 2023
2 parents d1df05c + d11ba82 commit 70b40ec
Showing 4 changed files with 45 additions and 8 deletions.
25 changes: 21 additions & 4 deletions llmfoundry/utils/checkpoint_conversion_helpers.py
@@ -19,7 +19,8 @@

 import numpy as np
 import sentencepiece as spm
-from transformers import AutoTokenizer, PreTrainedTokenizer
+from transformers import (AutoTokenizer, PreTrainedTokenizer,
+                          PreTrainedTokenizerFast)

 log = logging.getLogger(__name__)

@@ -35,8 +36,9 @@ def _get_weight_data_type(data_type: str):

 # TODO: move this functionality to composer once the bug fixes are upstreamed
 def get_hf_tokenizer_from_composer_state_dict(
-    state_dict: Dict[str, Any],
-    tokenizer_save_dir: Optional[str] = None
+    state_dict: Dict[str, Any],
+    trust_remote_code: bool,
+    tokenizer_save_dir: Optional[str] = None,
 ) -> Optional[PreTrainedTokenizer]:
     if 'state' not in state_dict:
         raise RuntimeError(
@@ -85,7 +87,8 @@ def get_hf_tokenizer_from_composer_state_dict(
         with open(tokenizer_file_path, 'wb') as _tmp_file:
             _tmp_file.write(s.serialized_model_proto())

-    hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer_save_dir)
+    hf_tokenizer = load_tokenizer(tokenizer_save_dir,
+                                  trust_remote_code=trust_remote_code)

     # remove 'name_or_path'
     hf_tokenizer.name_or_path = ''
@@ -94,6 +97,20 @@ def get_hf_tokenizer_from_composer_state_dict(
     return hf_tokenizer


+def load_tokenizer(
+    tokenizer_save_dir: str, trust_remote_code: bool
+) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+    try:
+        return AutoTokenizer.from_pretrained(
+            tokenizer_save_dir, trust_remote_code=trust_remote_code)
+    except ValueError as e:
+        raise ValueError(
+            f'Got error while loading tokenizer with trust_remote_code={trust_remote_code}: {e}. '
+            +
+            'If accessing a tokenizer defined outside of the transformers module,'
+            + ' please use --trust_remote_code.')
+
+
 def _write_zero_bias(weight_name: str, weight_file_path: str,
                      bias_shape: Union[Tuple[int, ...], int]) -> None:
     """Write zeros for bias when converting MPT to FasterTransformer weights.
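
A minimal, hypothetical usage sketch of the new load_tokenizer helper (the tokenizer directory below is a placeholder, not a path from this repository):

from llmfoundry.utils.checkpoint_conversion_helpers import load_tokenizer

# Hypothetical directory holding a tokenizer extracted from a Composer checkpoint.
tokenizer_save_dir = './tmp_tokenizer_save_dir'

# Tokenizers bundled with transformers load without trusting remote code.
tokenizer = load_tokenizer(tokenizer_save_dir, trust_remote_code=False)

# Tokenizers that ship custom Python code must be loaded with trust_remote_code=True;
# otherwise AutoTokenizer raises a ValueError, which load_tokenizer re-raises with a
# hint to pass --trust_remote_code to the conversion scripts.
tokenizer = load_tokenizer(tokenizer_save_dir, trust_remote_code=True)
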
11 changes: 9 additions & 2 deletions scripts/inference/convert_composer_mpt_to_ft.py
@@ -67,6 +67,7 @@ def write_ft_checkpoint_from_composer_checkpoint(
         checkpoint_path: Union[Path, str],
         infer_gpu_num: int,
         save_dir: str,
+        trust_remote_code: bool,
         output_precision: str = 'fp32',
         local_checkpoint_save_location: Optional[Union[Path,
                                                        str]] = None) -> None:
@@ -79,6 +80,7 @@ def write_ft_checkpoint_from_composer_checkpoint(
         checkpoint_path (Union[Path, str]): Path to the composer checkpoint, can be a local path, or a remote path beginning with ``s3://``, or another backend
             supported by Composer.
         infer_gpu_num (int): The number of gpus you are planning to use for inference.
+        trust_remote_code (bool): Whether or not to use code outside of the transformers module.
         save_dir (str): Path of the directory to save the checkpoint in FT format.
         output_precision (str, optional): The precision of the output weights saved to the FasterTransformer model. Can be either ``fp32`` or ``fp16``.
         local_checkpoint_save_location (Optional[Union[Path, str]], optional): If specified, where to save the checkpoint file to locally.
@@ -125,7 +127,7 @@ def write_ft_checkpoint_from_composer_checkpoint(
     print('#' * 30)
     print('Extracting HF Tokenizer...')
     hf_tokenizer = get_hf_tokenizer_from_composer_state_dict(
-        composer_state_dict)
+        composer_state_dict, trust_remote_code)
     if hf_tokenizer is None:
         print('Warning! No HF Tokenizer found!')

@@ -206,6 +208,10 @@ def parse_args() -> Namespace:
         'Data type of weights in the FasterTransformer output model. Input checkpoint weights will be converted to this dtype.',
         choices=['fp32', 'fp16'],
         default='fp32')
+    parser.add_argument(
+        '--trust_remote_code',
+        action='store_true',
+        help='Whether or not to use code outside of transformers module.')

     return parser.parse_args()

@@ -229,4 +235,5 @@ def parse_args() -> Namespace:
         infer_gpu_num=args.infer_gpu_num,
         save_dir=save_dir,
         output_precision=args.output_precision,
-        local_checkpoint_save_location=args.local_checkpoint_save_location)
+        local_checkpoint_save_location=args.local_checkpoint_save_location,
+        trust_remote_code=args.trust_remote_code)
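
A rough usage sketch of the updated FasterTransformer converter entry point with the new keyword. The import assumes the working directory is scripts/inference/, and all paths and the GPU count are placeholders:

from convert_composer_mpt_to_ft import write_ft_checkpoint_from_composer_checkpoint

write_ft_checkpoint_from_composer_checkpoint(
    checkpoint_path='s3://my-bucket/my-run/ep0-ba1000-rank0.pt',  # placeholder
    infer_gpu_num=1,  # placeholder
    save_dir='./mpt_ft_checkpoint',  # placeholder
    output_precision='fp16',
    local_checkpoint_save_location=None,
    trust_remote_code=True,  # new argument introduced by this change
)
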
14 changes: 12 additions & 2 deletions scripts/inference/convert_composer_to_hf.py
@@ -16,13 +16,15 @@

 from llmfoundry import MPTConfig, MPTForCausalLM
 from llmfoundry.utils import get_hf_tokenizer_from_composer_state_dict
+from llmfoundry.utils.checkpoint_conversion_helpers import load_tokenizer
 from llmfoundry.utils.huggingface_hub_utils import \
     edit_files_for_hf_compatibility


 def write_huggingface_pretrained_from_composer_checkpoint(
     checkpoint_path: Union[Path, str],
     output_path: Union[Path, str],
+    trust_remote_code: bool,
     output_precision: str = 'fp32',
     local_checkpoint_save_location: Optional[Union[Path, str]] = None
 ) -> Tuple[PretrainedConfig, Optional[PreTrainedTokenizerBase]]:
@@ -63,6 +65,7 @@ def write_huggingface_pretrained_from_composer_checkpoint(
         checkpoint_path (Union[Path, str]): Path to the composer checkpoint, can be a local path, or a remote path beginning with ``s3://``, or another backend
             supported by :meth:`composer.utils.maybe_create_object_store_from_uri`.
         output_path (Union[Path, str]): Path to the folder to write the output to.
+        trust_remote_code (bool): Whether or not to use code outside of the transformers module.
         output_precision (str, optional): The precision of the output weights saved to `pytorch_model.bin`. Can be one of ``fp32``, ``fp16``, or ``bf16``.
         local_checkpoint_save_location (Optional[Union[Path, str]], optional): If specified, where to save the checkpoint file to locally.
             If the input ``checkpoint_path`` is already a local path, this will be a symlink.
@@ -110,7 +113,7 @@ def write_huggingface_pretrained_from_composer_checkpoint(
     print('#' * 30)
     print('Saving HF Tokenizer...')
     hf_tokenizer = get_hf_tokenizer_from_composer_state_dict(
-        composer_state_dict)
+        composer_state_dict, trust_remote_code)
     if hf_tokenizer is not None:
         hf_tokenizer.save_pretrained(output_path)
         print(hf_tokenizer)
@@ -157,6 +160,10 @@ def parse_args() -> Namespace:
         default='fp32')
     parser.add_argument('--hf_repo_for_upload', type=str, default=None)
     parser.add_argument('--test_uploaded_model', action='store_true')
+    parser.add_argument(
+        '--trust_remote_code',
+        action='store_true',
+        help='Whether or not to use code outside of transformers module.')

     return parser.parse_args()

@@ -179,6 +186,7 @@ def convert_composer_to_hf(args: Namespace) -> None:
     config, tokenizer = write_huggingface_pretrained_from_composer_checkpoint(
         checkpoint_path=args.composer_path,
         output_path=local_folder_path,
+        trust_remote_code=args.trust_remote_code,
         output_precision=args.output_precision,
         local_checkpoint_save_location=args.local_checkpoint_save_location)

@@ -206,7 +214,9 @@ def convert_composer_to_hf(args: Namespace) -> None:
     loaded_hf_model.save_pretrained(local_folder_path)

     print(f'Loading tokenizer from {local_folder_path}')
-    tokenizer = transformers.AutoTokenizer.from_pretrained(local_folder_path)
+
+    tokenizer = load_tokenizer(local_folder_path,
+                               trust_remote_code=args.trust_remote_code)
     tokenizer.save_pretrained(local_folder_path)

     # Only need to edit files for MPT because it has custom code
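
Similarly, a hedged sketch of calling the HF conversion entry point directly with the new argument; the import again assumes the working directory is scripts/inference/, and the paths are placeholders:

from convert_composer_to_hf import write_huggingface_pretrained_from_composer_checkpoint

config, tokenizer = write_huggingface_pretrained_from_composer_checkpoint(
    checkpoint_path='s3://my-bucket/my-run/ep0-ba1000-rank0.pt',  # placeholder
    output_path='./hf-output-folder',  # placeholder
    trust_remote_code=True,  # new argument introduced by this change
    output_precision='bf16',
    local_checkpoint_save_location=None,
)
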
3 changes: 3 additions & 0 deletions tests/test_hf_conversion_script.py
@@ -541,6 +541,7 @@ def test_convert_and_generate(model: str, tie_word_embeddings: bool,
                      output_precision='fp32',
                      local_checkpoint_save_location=None,
                      hf_repo_for_upload=None,
+                     trust_remote_code=False,
                      test_uploaded_model=False)
     convert_composer_to_hf(args)

@@ -591,6 +592,7 @@ def test_convert_and_generate_triton(tie_word_embeddings: str,
                      output_precision='fp32',
                      local_checkpoint_save_location=None,
                      hf_repo_for_upload=None,
+                     trust_remote_code=False,
                      test_uploaded_model=False)
     convert_composer_to_hf(args)

@@ -648,6 +650,7 @@ def test_convert_and_generate_meta(tie_word_embeddings: str,
                      output_precision='fp32',
                      local_checkpoint_save_location=None,
                      hf_repo_for_upload=None,
+                     trust_remote_code=False,
                      test_uploaded_model=False)
     convert_composer_to_hf(args)
