diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index bad9084235..2f858dd186 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -435,40 +435,6 @@ def _save_checkpoint(self, state: State, logger: Logger):
             cpu_offload = True
 
-            # def dtensor_to_tensor_hook(
-            #     module: nn.Module,
-            #     state_dict: Dict[str, Any],
-            #     prefix: str,
-            #     *args: Any,
-            # ) -> Dict[str, Any]:
-            #     dtensor_fqns = []
-            #     for fqn in state_dict.keys():
-            #         tensor = state_dict[fqn]
-            #         if isinstance(tensor, DTensor):
-            #             dtensor_fqns.append(fqn)
-            #             tensor = tensor.full_tensor()  # type: ignore
-            #             if dist.get_global_rank() == 0:
-            #                 if cpu_offload:
-            #                     tensor = tensor.cpu()
-            #         state_dict[fqn] = tensor
-            #     if dist.get_global_rank() != 0:
-            #         for fqn in dtensor_fqns:
-            #             del state_dict[fqn]
-            #     return state_dict
-
-            # def tensor_dtype_hook(
-            #     module: nn.Module,
-            #     state_dict: Dict[str, Any],
-            #     prefix: str,
-            #     *args: Any,
-            # ) -> Dict[str, Any]:
-            #     for fqn in state_dict.keys():
-            #         tensor = state_dict[fqn]
-            #         if isinstance(tensor, torch.Tensor):
-            #             state_dict[fqn] = tensor.to(dtype=self.dtype)
-            #             del tensor
-            #     return state_dict
-
             # Add hook to move tensors to cpu to avoid CUDA OOM
             def tensor_hook(
                 module: nn.Module,
@@ -486,49 +452,22 @@ def tensor_hook(
                             # Offload any DTensors to CPU
                            if cpu_offload:
                                tensor = tensor.cpu()
-                            tensor = tensor.to(dtype=self.dtype)
                            state_dict[fqn] = tensor
                        else:
                            state_dict[fqn] = None
-                elif isinstance(tensor, torch.Tensor):
-                    state_dict[fqn] = tensor.to(dtype=self.dtype)
+
+                if isinstance(state_dict[fqn], torch.Tensor):
+                    state_dict[fqn] = state_dict[fqn].to(dtype=self.dtype)
                 del tensor
                if dist.get_global_rank() != 0:
                    state_dict = {}
                return state_dict
-
-            # def tensor_hook(
-            #     module: nn.Module,
-            #     state_dict: Dict[str, Any],
-            #     prefix: str,
-            #     *args: Any,
-            # ) -> Dict[str, Any]:
-            #     dtensor_fqns = []
-            #     for fqn in state_dict.keys():
-            #         tensor = state_dict[fqn]
-            #         if isinstance(tensor, DTensor):
-            #             dtensor_fqns.append(fqn)
-            #             tensor = tensor.full_tensor()  # type: ignore
-            #             if dist.get_global_rank() == 0:
-            #                 if cpu_offload:
-            #                     tensor = tensor.cpu()
-            #         state_dict[fqn] = tensor
-            #     if dist.get_global_rank() != 0:
-            #         for fqn in dtensor_fqns:
-            #             del state_dict[fqn]
-
-            #     for fqn in state_dict.keys():
-            #         if isinstance(state_dict[fqn], torch.Tensor):
-            #             state_dict[fqn] = state_dict[fqn].to(dtype=self.dtype)
-
-            #     return state_dict
 
             hooks = []
             for _, module in state_dict_model.named_modules():
                 if isinstance(module, FSDP):
                     hooks.append(module._register_state_dict_hook(tensor_hook),)
-
             state_dict = get_model_state_dict(
                 state_dict_model,
                 options=StateDictOptions(
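
For reviewers, here is a minimal single-process sketch of what the consolidated `tensor_hook` now does: DTensors are gathered to a full tensor (kept only on rank 0, optionally offloaded to CPU), and the dtype cast is applied uniformly to whatever ends up in the state dict, instead of once inside the DTensor branch and once in a separate `elif`. The function name `consolidate_state_dict`, the explicit `rank` argument, and the `torch.bfloat16` default are illustrative placeholders only; the real hook is registered on each FSDP module via `module._register_state_dict_hook(tensor_hook)` and reads `self.dtype` and `dist.get_global_rank()` instead.

```python
# Minimal, single-process sketch of the consolidated hook logic above.
# `consolidate_state_dict`, `rank`, and the bfloat16 default are illustrative,
# not part of the patch.
from typing import Any, Dict

import torch

try:
    from torch.distributed._tensor import DTensor
except ImportError:  # older torch without DTensor
    DTensor = ()  # isinstance(x, ()) is always False


def consolidate_state_dict(
    state_dict: Dict[str, Any],
    rank: int = 0,
    cpu_offload: bool = True,
    dtype: torch.dtype = torch.bfloat16,
) -> Dict[str, Any]:
    for fqn in list(state_dict.keys()):
        tensor = state_dict[fqn]
        if isinstance(tensor, DTensor):
            # Gather the sharded tensor; only rank 0 keeps a copy.
            tensor = tensor.full_tensor()
            if rank == 0:
                if cpu_offload:
                    tensor = tensor.cpu()
                state_dict[fqn] = tensor
            else:
                state_dict[fqn] = None

        # The dtype cast now applies to every surviving entry, whether it
        # started life as a DTensor or as a plain tensor.
        if isinstance(state_dict[fqn], torch.Tensor):
            state_dict[fqn] = state_dict[fqn].to(dtype=dtype)
        del tensor
    return state_dict if rank == 0 else {}


if __name__ == '__main__':
    sd = {'weight': torch.randn(2, 2), 'bias': torch.zeros(2)}
    out = consolidate_state_dict(sd)
    print({k: v.dtype for k, v in out.items()})  # both entries cast to bfloat16
```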