diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index 94f1b6e47f..b5c3726854 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -435,42 +435,7 @@ def _save_checkpoint(self, state: State, logger: Logger):
             cpu_offload = True
 
-            def dtensor_to_tensor_hook(
-                module: nn.Module,
-                state_dict: Dict[str, Any],
-                prefix: str,
-                *args: Any,
-            ) -> Dict[str, Any]:
-                dtensor_fqns = []
-                for fqn in state_dict.keys():
-                    tensor = state_dict[fqn]
-                    if isinstance(tensor, DTensor):
-                        dtensor_fqns.append(fqn)
-                        tensor = tensor.full_tensor()  # type: ignore
-                        if dist.get_global_rank() == 0:
-                            if cpu_offload:
-                                tensor = tensor.cpu()
-                            state_dict[fqn] = tensor
-                if dist.get_global_rank() != 0:
-                    for fqn in dtensor_fqns:
-                        del state_dict[fqn]
-                return state_dict
-
-            # def tensor_dtype_hook(
-            #     module: nn.Module,
-            #     state_dict: Dict[str, Any],
-            #     prefix: str,
-            #     *args: Any,
-            # ) -> Dict[str, Any]:
-            #     for fqn in state_dict.keys():
-            #         tensor = state_dict[fqn]
-            #         if isinstance(tensor, torch.Tensor):
-            #             state_dict[fqn] = tensor.to(dtype=self.dtype)
-            #         del tensor
-            #     return state_dict
-
-            # # Add hook to move tensors to cpu to avoid CUDA OOM
-            # def tensor_hook(
+            # def dtensor_to_tensor_hook(
             #     module: nn.Module,
             #     state_dict: Dict[str, Any],
             #     prefix: str,
@@ -483,24 +448,59 @@ def dtensor_to_tensor_hook(
             #             dtensor_fqns.append(fqn)
             #             tensor = tensor.full_tensor()  # type: ignore
             #             if dist.get_global_rank() == 0:
-            #                 # Offload any DTensors to CPU
             #                 if cpu_offload:
             #                     tensor = tensor.cpu()
             #                 state_dict[fqn] = tensor
-            #             else:
-            #                 state_dict[fqn] = None
-            #         # Convert the state dict to the requested precision
+            #     if dist.get_global_rank() != 0:
+            #         for fqn in dtensor_fqns:
+            #             del state_dict[fqn]
+            #     return state_dict
+
+            # def tensor_dtype_hook(
+            #     module: nn.Module,
+            #     state_dict: Dict[str, Any],
+            #     prefix: str,
+            #     *args: Any,
+            # ) -> Dict[str, Any]:
+            #     for fqn in state_dict.keys():
+            #         tensor = state_dict[fqn]
             #         if isinstance(tensor, torch.Tensor):
             #             state_dict[fqn] = tensor.to(dtype=self.dtype)
             #         del tensor
-            #     if dist.get_global_rank() != 0:
-            #         state_dict = {}
             #     return state_dict
 
+            # Add hook to move tensors to cpu to avoid CUDA OOM
+            def tensor_hook(
+                module: nn.Module,
+                state_dict: Dict[str, Any],
+                prefix: str,
+                *args: Any,
+            ) -> Dict[str, Any]:
+                dtensor_fqns = []
+                for fqn in state_dict.keys():
+                    tensor = state_dict[fqn]
+                    if isinstance(tensor, DTensor):
+                        dtensor_fqns.append(fqn)
+                        tensor = tensor.full_tensor()  # type: ignore
+                        if dist.get_global_rank() == 0:
+                            # Offload any DTensors to CPU
+                            if cpu_offload:
+                                tensor = tensor.cpu()
+                            state_dict[fqn] = tensor
+                    # Convert the state dict to the requested precision
+                    if isinstance(tensor, torch.Tensor):
+                        state_dict[fqn] = tensor.to(dtype=self.dtype)
+                    del tensor
+                if dist.get_global_rank() != 0:
+                    for fqn in dtensor_fqns:
+                        del state_dict[fqn]
+                return state_dict
+
             hooks = []
             for _, module in state_dict_model.named_modules():
-                if isinstance(module, FSDP):
-                    hooks.append(module._register_state_dict_hook(dtensor_to_tensor_hook),)
+                # if isinstance(module, FSDP):
+                hooks.append(module._register_state_dict_hook(tensor_hook),)
+
             state_dict = get_model_state_dict(
                 state_dict_model,
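
For context, the diff relies on PyTorch's post-state-dict hook mechanism: a hook registered via nn.Module._register_state_dict_hook() runs after state_dict() is populated and may rewrite entries in place. The following is a minimal single-process sketch (not part of the PR) illustrating that mechanism with only the dtype-cast and CPU-offload parts of the new tensor_hook; the DTensor gathering and rank-dependent deletion are omitted because they require an initialized distributed process group. The model and target dtype here are arbitrary choices for illustration.

    from typing import Any, Dict

    import torch
    import torch.nn as nn


    def cast_and_offload_hook(
        module: nn.Module,
        state_dict: Dict[str, Any],
        prefix: str,
        *args: Any,
    ) -> Dict[str, Any]:
        # Runs after state_dict() is built for `module`; rewrite entries in place.
        for fqn in list(state_dict.keys()):
            tensor = state_dict[fqn]
            if isinstance(tensor, torch.Tensor):
                # Cast to the requested precision and keep the copy on CPU.
                state_dict[fqn] = tensor.to(dtype=torch.bfloat16).cpu()
        return state_dict


    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))

    # Register on every module, mirroring the loop over named_modules() in the diff.
    handles = [
        m._register_state_dict_hook(cast_and_offload_hook) for m in model.modules()
    ]

    sd = model.state_dict()
    assert all(t.dtype == torch.bfloat16 for t in sd.values())

    # Remove the hooks once the state dict has been captured.
    for handle in handles:
        handle.remove()

The hooks list collected in the diff presumably serves the same purpose as `handles` above: _register_state_dict_hook() returns a RemovableHandle, so the hooks can be detached after get_model_state_dict() has produced the converted state dict.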