diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index bf14dafc90..f11f488fba 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -41,6 +41,8 @@
 from llmfoundry.utils.huggingface_hub_utils import \
     edit_files_for_hf_compatibility
 
+from llmfoundry.callbacks.scheduled_gc_callback import gc_cuda
+
 try:
     import transformer_engine.pytorch as te
     is_te_imported = True
@@ -390,6 +392,8 @@ def _save_checkpoint(self, state: State, logger: Logger):
         log.debug('Gathering state dict')
         from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
+        gc_cuda()
+
         if state.is_model_ddp:
             composer_model = state.model.module
             original_model: PreTrainedModel = state.model.module.model
@@ -426,11 +430,18 @@ def dtensor_to_tensor_hook(
                 tensor = state_dict[fqn]
                 if isinstance(tensor, DTensor):
                     dtensor_fqns.append(fqn)
-                    tensor = tensor.full_tensor()  # type: ignore
+                    tensor = tensor.full_tensor()
                     if dist.get_global_rank() == 0:
                         if cpu_offload:
-                            tensor = tensor.cpu()
-                        state_dict[fqn] = tensor.to(dtype=self.dtype)
+                            tensor = tensor.to(dtype=self.dtype, device=torch.device('cpu'))
+                        state_dict[fqn] = tensor
+                    else:
+                        state_dict[fqn] = None
+                    del tensor
+                else:
+                    log.debug(f'Not a DTensor {fqn}')
+            gc_cuda()
+
             if dist.get_global_rank() != 0:
                 for fqn in dtensor_fqns:
                     del state_dict[fqn]
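Reviewer note (not part of the diff): the change leans on two ideas, explicitly releasing CUDA memory around the state-dict gather, and keeping the gathered full tensor only on global rank 0 as a CPU copy in the target dtype while other ranks drop their references. The sketch below is a minimal, self-contained illustration of that pattern, not the PR's code: `gc_cuda` here is a hypothetical stand-in for the helper imported from `scheduled_gc_callback` (whose real implementation may differ), and `offload_for_rank_zero` is an invented name that mirrors the branching added to `dtensor_to_tensor_hook`.

```python
import gc
from typing import Optional

import torch


def gc_cuda() -> None:
    """Free Python garbage and release cached CUDA blocks (illustrative only)."""
    # Collect unreachable Python objects first so tensors whose last reference
    # was just dropped become candidates for cache release.
    gc.collect()
    if torch.cuda.is_available():
        # empty_cache() does not free live tensors; it only returns
        # cached-but-unused allocator blocks, lowering reserved memory.
        torch.cuda.empty_cache()


def offload_for_rank_zero(
    tensor: torch.Tensor,
    is_rank_zero: bool,
    dtype: torch.dtype = torch.bfloat16,
) -> Optional[torch.Tensor]:
    """Keep a CPU copy of the gathered tensor on rank 0, drop it elsewhere."""
    if is_rank_zero:
        # A single .to() call casts and moves in one step, avoiding an extra
        # GPU-resident copy in the target dtype (the diff does the same).
        return tensor.to(dtype=dtype, device=torch.device('cpu'))
    # Non-zero ranks return None so the full tensor can be garbage collected.
    return None


if __name__ == '__main__':
    t = torch.randn(4, 4)
    print(offload_for_rank_zero(t, is_rank_zero=True).dtype)  # torch.bfloat16
    print(offload_for_rank_zero(t, is_rank_zero=False))       # None
    gc_cuda()
```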