diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index e8866b9f93..bad9084235 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -470,33 +470,6 @@ def _save_checkpoint(self, state: State, logger: Logger):
         # return state_dict
 
         # Add hook to move tensors to cpu to avoid CUDA OOM
-        # def tensor_hook(
-        #     module: nn.Module,
-        #     state_dict: Dict[str, Any],
-        #     prefix: str,
-        #     *args: Any,
-        # ) -> Dict[str, Any]:
-        #     dtensor_fqns = []
-        #     for fqn in state_dict.keys():
-        #         tensor = state_dict[fqn]
-        #         if isinstance(tensor, DTensor):
-        #             dtensor_fqns.append(fqn)
-        #             tensor = tensor.full_tensor()  # type: ignore
-        #             if dist.get_global_rank() == 0:
-        #                 # Offload any DTensors to CPU
-        #                 if cpu_offload:
-        #                     tensor = tensor.cpu()
-        #                 tensor = tensor.to(dtype=self.dtype)
-        #                 state_dict[fqn] = tensor
-        #             else:
-        #                 state_dict[fqn] = None
-        #         elif isinstance(tensor, torch.Tensor):
-        #             state_dict[fqn] = tensor.to(dtype=self.dtype)
-        #         del tensor
-        #     if dist.get_global_rank() != 0:
-        #         state_dict = {}
-        #     return state_dict
-
         def tensor_hook(
             module: nn.Module,
             state_dict: Dict[str, Any],
@@ -510,18 +483,45 @@ def tensor_hook(
                     dtensor_fqns.append(fqn)
                     tensor = tensor.full_tensor()  # type: ignore
                     if dist.get_global_rank() == 0:
+                        # Offload any DTensors to CPU
                         if cpu_offload:
                             tensor = tensor.cpu()
+                        tensor = tensor.to(dtype=self.dtype)
                         state_dict[fqn] = tensor
+                    else:
+                        state_dict[fqn] = None
+                elif isinstance(tensor, torch.Tensor):
+                    state_dict[fqn] = tensor.to(dtype=self.dtype)
+                del tensor
             if dist.get_global_rank() != 0:
-                for fqn in dtensor_fqns:
-                    del state_dict[fqn]
+                state_dict = {}
+            return state_dict
+
+        # def tensor_hook(
+        #     module: nn.Module,
+        #     state_dict: Dict[str, Any],
+        #     prefix: str,
+        #     *args: Any,
+        # ) -> Dict[str, Any]:
+        #     dtensor_fqns = []
+        #     for fqn in state_dict.keys():
+        #         tensor = state_dict[fqn]
+        #         if isinstance(tensor, DTensor):
+        #             dtensor_fqns.append(fqn)
+        #             tensor = tensor.full_tensor()  # type: ignore
+        #             if dist.get_global_rank() == 0:
+        #                 if cpu_offload:
+        #                     tensor = tensor.cpu()
+        #                 state_dict[fqn] = tensor
+        #     if dist.get_global_rank() != 0:
+        #         for fqn in dtensor_fqns:
+        #             del state_dict[fqn]
 
-            for fqn in state_dict.keys():
-                if isinstance(state_dict[fqn], torch.Tensor):
-                    state_dict[fqn] = state_dict[fqn].to(dtype=self.dtype)
+        #     for fqn in state_dict.keys():
+        #         if isinstance(state_dict[fqn], torch.Tensor):
+        #             state_dict[fqn] = state_dict[fqn].to(dtype=self.dtype)
 
-            return state_dict
+        #     return state_dict
 
         hooks = []
         for _, module in state_dict_model.named_modules():
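
Note on the change above: the new `tensor_hook` body materializes each `DTensor` via `full_tensor()` (a collective, so every rank participates), then on global rank 0 optionally offloads the result to CPU and casts it to `self.dtype`; non-zero ranks now return an empty state dict instead of deleting only the DTensor keys, and plain tensors are downcast inline rather than in a second pass. The sketch below is a minimal, single-process illustration of the same state-dict-hook pattern, without DTensor or a distributed init (so it behaves like global rank 0); it is not the callback itself, and the `target_dtype` and `cpu_offload` values are assumptions.

    # Standalone sketch of a state-dict hook that offloads and downcasts
    # tensors, mirroring the pattern in the diff. Single process, so this
    # behaves like global rank 0.
    from typing import Any, Dict

    import torch
    import torch.nn as nn

    target_dtype = torch.bfloat16  # assumed; the callback uses self.dtype
    cpu_offload = True             # assumed; the callback takes this as a flag

    def tensor_hook(
        module: nn.Module,
        state_dict: Dict[str, Any],
        prefix: str,
        *args: Any,
    ) -> Dict[str, Any]:
        for fqn in list(state_dict.keys()):
            tensor = state_dict[fqn]
            if isinstance(tensor, torch.Tensor):
                if cpu_offload:
                    tensor = tensor.cpu()
                # Cast after the offload so the conversion runs in host memory.
                state_dict[fqn] = tensor.to(dtype=target_dtype)
            del tensor
        return state_dict

    model = nn.Linear(8, 8)
    # _register_state_dict_hook is a private torch.nn API; the callback
    # registers the hook on every submodule and removes the handles once
    # state_dict() has been materialized.
    handle = model._register_state_dict_hook(tensor_hook)
    sd = model.state_dict()
    handle.remove()
    assert all(v.dtype == target_dtype for v in sd.values())
    assert all(v.device.type == 'cpu' for v in sd.values())

Because the hook's return value replaces the destination dict, returning `{}` on non-zero ranks (as the new code does) keeps only rank 0 holding the full, CPU-resident, downcast copy, which is what avoids the CUDA OOM during checkpoint save.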