From 3de4088677e47f5c19f9587ad10c68e043f7d377 Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Wed, 24 Jul 2024 17:33:54 -0700
Subject: [PATCH] go back

---
 llmfoundry/callbacks/hf_checkpointer.py | 77 +++++++++++++++----------
 1 file changed, 45 insertions(+), 32 deletions(-)

diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index fbed36d1f0..94f1b6e47f 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -435,29 +435,7 @@ def _save_checkpoint(self, state: State, logger: Logger):
 
             cpu_offload = True
 
-            # def dtensor_to_tensor_hook(
-            #     module: nn.Module,
-            #     state_dict: Dict[str, Any],
-            #     prefix: str,
-            #     *args: Any,
-            # ) -> Dict[str, Any]:
-            #     dtensor_fqns = []
-            #     for fqn in state_dict.keys():
-            #         tensor = state_dict[fqn]
-            #         if isinstance(tensor, DTensor):
-            #             dtensor_fqns.append(fqn)
-            #             tensor = tensor.full_tensor()  # type: ignore
-            #             if dist.get_global_rank() == 0:
-            #                 if cpu_offload:
-            #                     tensor = tensor.cpu()
-            #                 state_dict[fqn] = tensor
-            #     if dist.get_global_rank() != 0:
-            #         for fqn in dtensor_fqns:
-            #             del state_dict[fqn]
-            #     return state_dict
-
-            # Add hook to move tensors to cpu to avoid CUDA OOM
-            def tensor_hook(
+            def dtensor_to_tensor_hook(
                 module: nn.Module,
                 state_dict: Dict[str, Any],
                 prefix: str,
@@ -470,24 +448,59 @@ def tensor_hook(
                         dtensor_fqns.append(fqn)
                         tensor = tensor.full_tensor()  # type: ignore
                         if dist.get_global_rank() == 0:
-                            # Offload any DTensors to CPU
                             if cpu_offload:
                                 tensor = tensor.cpu()
                             state_dict[fqn] = tensor
-                        else:
-                            state_dict[fqn] = None
-                    # Convert the state dict to the requested precision
-                    if isinstance(tensor, torch.Tensor):
-                        state_dict[fqn] = tensor.to(dtype=self.dtype)
-                    del tensor
                 if dist.get_global_rank() != 0:
-                    state_dict = {}
+                    for fqn in dtensor_fqns:
+                        del state_dict[fqn]
                 return state_dict
+
+            # def tensor_dtype_hook(
+            #     module: nn.Module,
+            #     state_dict: Dict[str, Any],
+            #     prefix: str,
+            #     *args: Any,
+            # ) -> Dict[str, Any]:
+            #     for fqn in state_dict.keys():
+            #         tensor = state_dict[fqn]
+            #         if isinstance(tensor, torch.Tensor):
+            #             state_dict[fqn] = tensor.to(dtype=self.dtype)
+            #         del tensor
+            #     return state_dict
+
+            # # Add hook to move tensors to cpu to avoid CUDA OOM
+            # def tensor_hook(
+            #     module: nn.Module,
+            #     state_dict: Dict[str, Any],
+            #     prefix: str,
+            #     *args: Any,
+            # ) -> Dict[str, Any]:
+            #     dtensor_fqns = []
+            #     for fqn in state_dict.keys():
+            #         tensor = state_dict[fqn]
+            #         if isinstance(tensor, DTensor):
+            #             dtensor_fqns.append(fqn)
+            #             tensor = tensor.full_tensor()  # type: ignore
+            #             if dist.get_global_rank() == 0:
+            #                 # Offload any DTensors to CPU
+            #                 if cpu_offload:
+            #                     tensor = tensor.cpu()
+            #                 state_dict[fqn] = tensor
+            #             else:
+            #                 state_dict[fqn] = None
+            #         # Convert the state dict to the requested precision
+            #         if isinstance(tensor, torch.Tensor):
+            #             state_dict[fqn] = tensor.to(dtype=self.dtype)
+            #         del tensor
+            #     if dist.get_global_rank() != 0:
+            #         state_dict = {}
+            #     return state_dict
 
             hooks = []
             for _, module in state_dict_model.named_modules():
                 if isinstance(module, FSDP):
-                    hooks.append(module._register_state_dict_hook(tensor_hook),)
+                    hooks.append(module._register_state_dict_hook(dtensor_to_tensor_hook),)
 
             state_dict = get_model_state_dict(
                 state_dict_model,
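
Note on the mechanism this patch restores: dtensor_to_tensor_hook is a state-dict hook that materializes each DTensor shard as a full tensor (a collective call every rank must run), keeps a CPU copy on global rank 0 to avoid CUDA OOM, and deletes the gathered entries on every other rank before the full state dict is returned. The sketch below shows roughly how that fits together outside the callback. It is a minimal sketch, not the checkpointer itself: the gather_full_state_dict wrapper, the try/finally hook cleanup, and the exact StateDictOptions arguments are assumptions, since the excerpt above is cut off before that point; dist here is Composer's composer.utils.dist, as in hf_checkpointer.py.

# Minimal sketch (assumptions noted above): gather a full, CPU-offloaded state
# dict from an FSDP-wrapped model using a hook like the one restored by this patch.
from typing import Any, Dict

import torch.nn as nn
from composer.utils import dist  # the callback uses Composer's dist helpers
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def dtensor_to_tensor_hook(
    module: nn.Module,
    state_dict: Dict[str, Any],
    prefix: str,
    *args: Any,
) -> Dict[str, Any]:
    # Materialize each DTensor as a full tensor (collective: every rank
    # participates), keep a CPU copy on rank 0, and drop it elsewhere.
    dtensor_fqns = []
    for fqn in state_dict.keys():
        tensor = state_dict[fqn]
        if isinstance(tensor, DTensor):
            dtensor_fqns.append(fqn)
            tensor = tensor.full_tensor()  # type: ignore
            if dist.get_global_rank() == 0:
                state_dict[fqn] = tensor.cpu()
    if dist.get_global_rank() != 0:
        for fqn in dtensor_fqns:
            del state_dict[fqn]
    return state_dict


def gather_full_state_dict(model: nn.Module) -> Dict[str, Any]:
    # Register the hook on every FSDP submodule so it runs while the state
    # dict is collected, then remove the hooks afterwards (hypothetical
    # wrapper; the real callback manages this inline).
    hooks = [
        module._register_state_dict_hook(dtensor_to_tensor_hook)
        for _, module in model.named_modules()
        if isinstance(module, FSDP)
    ]
    try:
        return get_model_state_dict(
            model,
            options=StateDictOptions(full_state_dict=True, cpu_offload=True),
        )
    finally:
        for hook in hooks:
            hook.remove()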