From c23be4ab9e146ff1064758a83fbe57c7d7a8e2ba Mon Sep 17 00:00:00 2001
From: Charles Tang
Date: Mon, 17 Jun 2024 21:31:31 -0700
Subject: [PATCH] Fix TE HF checkpoint saving (#1280)

---
 llmfoundry/callbacks/hf_checkpointer.py      | 22 +++++-
 .../inference/test_convert_composer_to_hf.py | 78 +++++++++++++------
 2 files changed, 73 insertions(+), 27 deletions(-)

diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index 28b33b43d8..d80060d6f6 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -17,7 +17,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
-from composer.core import Callback, Event, State, Time, TimeUnit
+from composer.core import Callback, Event, Precision, State, Time, TimeUnit
 from composer.core.state import fsdp_state_dict_type_context
 from composer.loggers import Logger, MLFlowLogger
 from composer.models import HuggingFaceModel
@@ -37,6 +37,12 @@
 from llmfoundry.utils.huggingface_hub_utils import \
     edit_files_for_hf_compatibility
 
+try:
+    import transformer_engine.pytorch as te
+    is_te_imported = True
+except ModuleNotFoundError:
+    is_te_imported = False
+
 log = logging.getLogger(__name__)
 
 __all__ = ['HuggingFaceCheckpointer']
@@ -486,9 +492,19 @@ def dtensor_to_tensor_hook(
             )
 
             log.debug('Saving Hugging Face checkpoint to disk')
-            new_model_instance.save_pretrained(temp_save_dir)
+            # This context manager casts the TE extra state in io.BytesIO format to tensor format
+            # Needed for proper hf ckpt saving.
+            context_manager = te.onnx_export(
+                True,
+            ) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
+            )
+            with context_manager:
+                new_model_instance.save_pretrained(temp_save_dir)
             if original_tokenizer is not None:
-                assert isinstance(original_tokenizer, PreTrainedTokenizerBase)
+                assert isinstance(
+                    original_tokenizer,
+                    PreTrainedTokenizerBase,
+                )
                 original_tokenizer.save_pretrained(temp_save_dir)
 
             # Only need to edit files for MPT because it has custom code
diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py
index 0577e13a1f..1b2f791995 100644
--- a/tests/a_scripts/inference/test_convert_composer_to_hf.py
+++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -1,6 +1,7 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+import contextlib
 import json
 import math
 import os
@@ -468,6 +469,7 @@ def _get_model_and_tokenizer(
     model: str,
     max_seq_len: int,
     tie_word_embeddings: bool,
+    precision: str,
 ):
     if model == 'mpt':
         model_cfg = {
@@ -482,6 +484,7 @@ def _get_model_and_tokenizer(
             'attn_config': {
                 'attn_impl': 'torch',
             },
+            'fc_type': 'te' if precision == 'amp_fp8' else 'torch',
             'loss_fn': 'torch_crossentropy',
             'tie_word_embeddings': tie_word_embeddings,
         }
@@ -783,8 +786,9 @@ def _assert_checkpoint_equivalence(
 )
 @pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None])
 @pytest.mark.parametrize(
-    'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints',
-    [('1ba', '1ba', '1ba', 1, 1)],
+    'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints,trainer_precision',
+    [('1ba', '1ba', '1ba', 1, 1, 'amp_bf16'),
+     ('1ba', '1ba', '1ba', 1, 1, 'amp_fp8')],
 )
 @patch('os.cpu_count', MagicMock(return_value=1))
 @patch(
@@ -801,10 +805,30 @@ def test_huggingface_conversion_callback(
     max_duration: str,
     expected_hf_checkpoints: int,
     expected_normal_checkpoints: int,
+    trainer_precision: str,
     peft_config: Optional[dict],
 ):
     if model == 'mptmoe' and fsdp_state_dict_type is None:
         pytest.skip('mptmoe requires FSDP')
+    if trainer_precision == 'amp_fp8':
+        # Check if transformer-engine is installed for FP8.
+        try:
+            import transformer_engine.pytorch as te
+        except ImportError:
+            pytest.skip(
+                'Precision amp_fp8 requires transformer-engine to be installed',
+            )
+
+        # Check we are using mpt models only for FP8.
+        if (model == 'neo' or model == 'llama2'):
+            pytest.skip(
+                'Precision amp_fp8 works only for mpt models, not hf models',
+            )
+
+        # Check that we are using H100 or later for FP8.
+        if not (torch.cuda.get_device_capability() >= (8, 9)):
+            pytest.skip('Amp FP8 requires a GPU with compute capability >= 8.9')
+
     delete_transformers_cache()
 
     dist.initialize_dist(get_device('gpu'))
@@ -825,9 +849,10 @@ def test_huggingface_conversion_callback(
 
     # Get small version of each model
     model_cfg, tokenizer_name = _get_model_and_tokenizer(
-        model,
-        max_seq_len,
-        tie_word_embeddings,
+        model=model,
+        max_seq_len=max_seq_len,
+        tie_word_embeddings=tie_word_embeddings,
+        precision=trainer_precision,
     )
     assert model_cfg is not None
     assert tokenizer_name is not None
@@ -883,7 +908,7 @@ def test_huggingface_conversion_callback(
     trainer = Trainer(
         model=original_model,
         device='gpu',
-        precision='amp_bf16',
+        precision=trainer_precision,
         fsdp_config=fsdp_config if fsdp_state_dict_type is not None else None,
         train_dataloader=train_dataloader,
         save_folder=os.path.join(tmp_path, 'checkpoints'),
@@ -900,24 +925,29 @@ def test_huggingface_conversion_callback(
     # summon full params to check equivalence
     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-    with FSDP.summon_full_params(
-        trainer.state.model,
-        writeback=False,
-        recurse=True,
-    ):
-        _assert_checkpoint_equivalence(
-            tmp_path=tmp_path,
-            expected_normal_checkpoints=expected_normal_checkpoints,
-            expected_hf_checkpoints=expected_hf_checkpoints,
-            trainer=trainer,
-            batches_per_epoch=batches_per_epoch,
-            original_model=original_model,
-            precision=precision,
-            model=model,
-            tokenizer=tokenizer,
-            fsdp_state_dict_type=fsdp_state_dict_type,
-            peft_config=peft_config,
-        )
+
+    context_manager = te.onnx_export(  # type: ignore
+        True,
+    ) if trainer_precision == 'amp_fp8' else contextlib.nullcontext()
+    with context_manager:
+        with FSDP.summon_full_params(
+            trainer.state.model,
+            writeback=False,
+            recurse=True,
+        ):
+            _assert_checkpoint_equivalence(
+                tmp_path=tmp_path,
+                expected_normal_checkpoints=expected_normal_checkpoints,
+                expected_hf_checkpoints=expected_hf_checkpoints,
+                trainer=trainer,
+                batches_per_epoch=batches_per_epoch,
+                original_model=original_model,
+                precision=precision,
+                model=model,
+                tokenizer=tokenizer,
+                fsdp_state_dict_type=fsdp_state_dict_type,
+                peft_config=peft_config,
+            )
 
     dist.barrier()
 
     delete_transformers_cache()
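
The saving pattern this patch introduces can be read in one place as a minimal standalone sketch. The helper name save_hf_checkpoint and the using_fp8 flag below are illustrative and not part of the diff; te.onnx_export(True) and the contextlib.nullcontext() fallback mirror the change to hf_checkpointer.py above.

import contextlib

try:
    import transformer_engine.pytorch as te  # optional dependency, only present on FP8-capable setups
    is_te_imported = True
except ModuleNotFoundError:
    is_te_imported = False


def save_hf_checkpoint(model, save_dir, using_fp8):
    # te.onnx_export(True) makes Transformer Engine serialize its extra state
    # (normally an io.BytesIO blob) as tensors, which save_pretrained can write
    # to a Hugging Face checkpoint; without TE/FP8 the null context is a no-op.
    context_manager = te.onnx_export(
        True,
    ) if is_te_imported and using_fp8 else contextlib.nullcontext()
    with context_manager:
        model.save_pretrained(save_dir)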