From 596dd9dcef1df60bee39c62868b48b2d82d7cb28 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Mon, 22 Jul 2024 18:39:35 -0700 Subject: [PATCH] Do dtype conversion in torch hook to save memory (#1384) * Do dtype conversion in torch hook to save memory * update code comment Co-authored-by: Saaketh Narayan --------- Co-authored-by: Saaketh Narayan --- llmfoundry/callbacks/hf_checkpointer.py | 26 +++++++++---------- .../inference/test_convert_composer_to_hf.py | 2 ++ 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 2ade458bb4..7127d37f40 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -435,8 +435,8 @@ def _save_checkpoint(self, state: State, logger: Logger): cpu_offload = True - # Add a dtensor->cpu tensor hook to avoid CUDA OOM - def dtensor_to_tensor_hook( + # Add hook to move tensors to cpu to avoid CUDA OOM + def tensor_hook( module: nn.Module, state_dict: Dict[str, Any], prefix: str, @@ -449,20 +449,23 @@ def dtensor_to_tensor_hook( dtensor_fqns.append(fqn) tensor = tensor.full_tensor() # type: ignore if dist.get_global_rank() == 0: + # Offload any DTensors to CPU if cpu_offload: tensor = tensor.cpu() state_dict[fqn] = tensor + else: + state_dict[fqn] = None + # Convert the state dict to the requested precision + if isinstance(tensor, torch.Tensor): + state_dict[fqn] = tensor.to(dtype=self.dtype) + del tensor if dist.get_global_rank() != 0: - for fqn in dtensor_fqns: - del state_dict[fqn] + state_dict = {} return state_dict hooks = [] for _, module in state_dict_model.named_modules(): - if isinstance(module, FSDP): - hooks.append( - module._register_state_dict_hook(dtensor_to_tensor_hook), - ) + hooks.append(module._register_state_dict_hook(tensor_hook),) state_dict = get_model_state_dict( state_dict_model, @@ -474,11 +477,6 @@ def dtensor_to_tensor_hook( for hook in hooks: hook.remove() - # Convert the state dict to the requested precision - for k, v in state_dict.items(): - if isinstance(v, torch.Tensor): - state_dict[k] = v.to(dtype=self.dtype) - new_model_instance = None # Need this for pyright because variable could be unbound if dist.get_global_rank() == 0: @@ -537,7 +535,7 @@ def dtensor_to_tensor_hook( original_tokenizer.save_pretrained(temp_save_dir) # Only need to edit files for MPT because it has custom code - if original_model.config.model_type == 'mpt': + if new_model_instance.config.model_type == 'mpt': log.debug('Editing MPT files for HuggingFace compatibility') edit_files_for_hf_compatibility( temp_save_dir, diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index 68dc855154..ffdb09ca98 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -383,6 +383,8 @@ def test_huggingface_conversion_callback_interval( mlflow_logger_mock.model_registry_prefix = '' mlflow_logger_mock._experiment_id = 'mlflow-experiment-id' mlflow_logger_mock._run_id = 'mlflow-run-id' + mlflow_logger_mock._enabled = True + mlflow_logger_mock.run_url = 'fake-url' checkpointer_callback.transform_model_pre_registration = MagicMock( wraps=checkpointer_callback.transform_model_pre_registration, )