diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index 4de7f9f2c6..2ade458bb4 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -18,7 +18,6 @@
 import torch
 import torch.nn as nn
 from composer.core import Callback, Event, Precision, State, Time, TimeUnit
-from composer.core.state import fsdp_state_dict_type_context
 from composer.loggers import Logger, MLFlowLogger
 from composer.models import HuggingFaceModel
 from composer.utils import (
@@ -29,7 +28,12 @@
 )
 from composer.utils.misc import create_interval_scheduler
 from mlflow.transformers import _fetch_model_card, _write_license_information
-from packaging import version
+from torch.distributed._tensor import DTensor
+from torch.distributed.checkpoint.state_dict import (
+    StateDictOptions,
+    get_model_state_dict,
+)
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from transformers import (
     PretrainedConfig,
     PreTrainedModel,
@@ -179,6 +183,7 @@ def __init__(
             'bfloat16': torch.bfloat16,
         }[precision]
         self.flatten_imports = flatten_imports
+        self.using_peft = False
 
         # mlflow config setup
         self.mlflow_registered_model_name = mlflow_registered_model_name
@@ -274,6 +279,15 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
                 mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set(
                     '1GB',
                 )
+
+            # Check if the model is using PEFT
+            if state.is_model_ddp:
+                composer_model = state.model.module
+            elif isinstance(state.model.model, FSDP):
+                composer_model = state.model
+            else:
+                composer_model = state.model
+            self.using_peft = composer_model.using_peft
         elif event == Event.FIT_END:
             # Wait for all child processes spawned by the callback to finish.
             timeout = 3600
@@ -362,6 +376,23 @@ def transform_config(
             copied_config.ffn_config['moe_world_size'] = 1
         return copied_config
 
+    def transform_model_pre_registration(
+        self,
+        model: PreTrainedModel,
+    ) -> PreTrainedModel:
+        """Transform the model before registering with MLflow.
+
+        This allows a subclass to modify the model before registering with
+        MLflow. The base class implementation will make no modifications.
+
+        Args:
+            model (PreTrainedModel): The model to be transformed.
+
+        Returns:
+            PreTrainedModel: The transformed model.
+ """ + return model + def _save_checkpoint(self, state: State, logger: Logger): del logger # unused @@ -388,82 +419,62 @@ def _save_checkpoint(self, state: State, logger: Logger): temp_save_dir = tempfile.mkdtemp() if use_temp_dir else save_dir log.debug('Gathering state dict') - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP if state.is_model_ddp: - composer_model = state.model.module original_model: PreTrainedModel = state.model.module.model state_dict_model = state.model.module.model original_tokenizer = state.model.module.tokenizer elif isinstance(state.model.model, FSDP): - composer_model = state.model original_model: PreTrainedModel = state.model.model.module state_dict_model = state.model.model original_tokenizer = state.model.tokenizer else: - composer_model = state.model original_model: PreTrainedModel = state.model.model state_dict_model = state.model.model original_tokenizer = state.model.tokenizer - if version.parse(torch.__version__) > version.parse('2.2.9'): - from torch.distributed._tensor import DTensor - from torch.distributed.checkpoint.state_dict import ( - StateDictOptions, - get_model_state_dict, - ) - cpu_offload = True - - # Add a dtensor->cpu tensor hook to avoid CUDA OOM - def dtensor_to_tensor_hook( - module: nn.Module, - state_dict: Dict[str, Any], - prefix: str, - *args: Any, - ) -> Dict[str, Any]: - dtensor_fqns = [] - for fqn in state_dict.keys(): - tensor = state_dict[fqn] - if isinstance(tensor, DTensor): - dtensor_fqns.append(fqn) - tensor = tensor.full_tensor() # type: ignore - if dist.get_global_rank() == 0: - if cpu_offload: - tensor = tensor.cpu() - state_dict[fqn] = tensor - if dist.get_global_rank() != 0: - for fqn in dtensor_fqns: - del state_dict[fqn] - return state_dict - - hooks = [] - for _, module in state_dict_model.named_modules(): - if isinstance(module, FSDP): - hooks.append( - module. 
-                        _register_state_dict_hook(dtensor_to_tensor_hook),
-                    )
+        cpu_offload = True
+
+        # Add a dtensor->cpu tensor hook to avoid CUDA OOM
+        def dtensor_to_tensor_hook(
+            module: nn.Module,
+            state_dict: Dict[str, Any],
+            prefix: str,
+            *args: Any,
+        ) -> Dict[str, Any]:
+            dtensor_fqns = []
+            for fqn in state_dict.keys():
+                tensor = state_dict[fqn]
+                if isinstance(tensor, DTensor):
+                    dtensor_fqns.append(fqn)
+                    tensor = tensor.full_tensor()  # type: ignore
+                    if dist.get_global_rank() == 0:
+                        if cpu_offload:
+                            tensor = tensor.cpu()
+                        state_dict[fqn] = tensor
+            if dist.get_global_rank() != 0:
+                for fqn in dtensor_fqns:
+                    del state_dict[fqn]
+            return state_dict
+
+        hooks = []
+        for _, module in state_dict_model.named_modules():
+            if isinstance(module, FSDP):
+                hooks.append(
+                    module._register_state_dict_hook(dtensor_to_tensor_hook),
+                )
 
-            state_dict = get_model_state_dict(
-                state_dict_model,
-                options=StateDictOptions(
-                    full_state_dict=True,
-                    cpu_offload=cpu_offload,
-                ),
-            )
-            for hook in hooks:
-                hook.remove()
-        else:
-            state_dict_context = fsdp_state_dict_type_context(
-                original_model,
-                state_dict_type='full',
-            ) if ((not state.is_model_ddp) and
-                  isinstance(state_dict_model,
-                             FSDP)) else contextlib.nullcontext()
-
-            with state_dict_context:
-                state_dict = state_dict_model.state_dict()
-
-        # Convert the state dict to the requested precision
+        state_dict = get_model_state_dict(
+            state_dict_model,
+            options=StateDictOptions(
+                full_state_dict=True,
+                cpu_offload=cpu_offload,
+            ),
+        )
+        for hook in hooks:
+            hook.remove()
+
+        # Convert the state dict to the requested precision
         for k, v in state_dict.items():
             if isinstance(v, torch.Tensor):
                 state_dict[k] = v.to(dtype=self.dtype)
@@ -480,22 +491,19 @@ def dtensor_to_tensor_hook(
 
             log.debug(f'Creating new model instance')
 
-            if composer_model.using_peft:
-                # We don't use meta here because the state dict does not contain the full
-                # model, only the adapter weights.
-                active_adapter = original_model.active_adapter
-                base_model = original_model.get_base_model()
-                new_base_model_instance = type(base_model)(new_config)
-
-                new_model_instance = type(original_model)(
-                    new_base_model_instance,
-                    original_model.peft_config[active_adapter],
-                )
-                new_model_instance.to(dtype=self.dtype)
-            else:
-                # First create the model instance on meta device to avoid the
-                # initialization cost.
-                with init_empty_weights():
+            # First create the model instance on meta device to avoid the
+            # initialization cost.
+            with init_empty_weights():
+                if self.using_peft:
+                    active_adapter = original_model.active_adapter
+                    base_model = original_model.get_base_model()
+                    new_base_model_instance = type(base_model)(new_config)
+
+                    new_model_instance = type(original_model)(
+                        new_base_model_instance,
+                        original_model.peft_config[active_adapter],
+                    )
+                else:
                     new_model_instance = type(original_model)(new_config)
                 new_model_instance.generation_config.update(
                     **original_model.generation_config.to_dict(),
@@ -556,6 +564,11 @@ def dtensor_to_tensor_hook(
 
         if dist.get_global_rank() == 0:
             if self.mlflow_registered_model_name and self._is_last_batch(state):
+
+                new_model_instance = self.transform_model_pre_registration(
+                    new_model_instance,
+                )
+
                 components = {'model': new_model_instance}
                 if original_tokenizer is not None:
                     components['tokenizer'] = original_tokenizer
@@ -575,7 +588,7 @@ def dtensor_to_tensor_hook(
                     model_saving_kwargs: Dict[str, Any] = {
                        'path': local_save_path,
                    }
-                    if composer_model.using_peft:
+                    if self.using_peft:
                        model_saving_kwargs['flavor'] = 'peft'
                        model_saving_kwargs['save_pretrained_dir'
                                           ] = temp_save_dir
diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py
index 2ef458fece..68dc855154 100644
--- a/tests/a_scripts/inference/test_convert_composer_to_hf.py
+++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -383,6 +383,9 @@ def test_huggingface_conversion_callback_interval(
     mlflow_logger_mock.model_registry_prefix = ''
     mlflow_logger_mock._experiment_id = 'mlflow-experiment-id'
     mlflow_logger_mock._run_id = 'mlflow-run-id'
+    checkpointer_callback.transform_model_pre_registration = MagicMock(
+        wraps=checkpointer_callback.transform_model_pre_registration,
+    )
     trainer = Trainer(
         model=original_model,
         device='gpu',
@@ -407,8 +410,10 @@ def test_huggingface_conversion_callback_interval(
             input_example=ANY,
             metadata={},
         )
+        assert checkpointer_callback.transform_model_pre_registration.call_count == 1
         assert mlflow_logger_mock.register_model_with_run_id.call_count == 1
     else:
+        assert checkpointer_callback.transform_model_pre_registration.call_count == 0
         assert mlflow_logger_mock.save_model.call_count == 0
         assert mlflow_logger_mock.register_model_with_run_id.call_count == 0
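For context, a minimal sketch of how a downstream subclass might use the new `transform_model_pre_registration` hook, e.g. to fold PEFT adapter weights into the base model before the artifact is registered with MLflow. This is not part of the diff: the subclass name is hypothetical, and it assumes the callback class in `hf_checkpointer.py` is `HuggingFaceCheckpointer` and that a PEFT-wrapped model exposes `merge_and_unload()`.

```python
# Hypothetical subclass sketch -- not part of this diff. Assumes the callback
# class is HuggingFaceCheckpointer and that a PEFT-wrapped model exposes
# merge_and_unload().
from transformers import PreTrainedModel

from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer


class MergedAdapterCheckpointer(HuggingFaceCheckpointer):
    """Registers a merged (adapter-free) model instead of the PEFT wrapper."""

    def transform_model_pre_registration(
        self,
        model: PreTrainedModel,
    ) -> PreTrainedModel:
        # Fold adapter weights into the base model so the artifact registered
        # with MLflow is a plain transformers model.
        if hasattr(model, 'merge_and_unload'):
            model = model.merge_and_unload()
        return model
```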