From 04bbfb91c20ea894eda293aab646d1dd0fdb98f4 Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Wed, 24 Jul 2024 18:08:59 -0700
Subject: [PATCH] try

---
 llmfoundry/callbacks/hf_checkpointer.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index b5c3726854..6ec12de517 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -487,19 +487,20 @@ def tensor_hook(
                         if cpu_offload:
                             tensor = tensor.cpu()
                         state_dict[fqn] = tensor
+                    else:
+                        state_dict[fqn] = None
 
                 # Convert the state dict to the requested precision
-                if isinstance(tensor, torch.Tensor):
-                    state_dict[fqn] = tensor.to(dtype=self.dtype)
+                # if isinstance(tensor, torch.Tensor):
+                #     state_dict[fqn] = tensor.to(dtype=self.dtype)
                 del tensor
             if dist.get_global_rank() != 0:
-                for fqn in dtensor_fqns:
-                    del state_dict[fqn]
+                state_dict = {}
             return state_dict
 
         hooks = []
         for _, module in state_dict_model.named_modules():
-            # if isinstance(module, FSDP):
-            hooks.append(module._register_state_dict_hook(tensor_hook),)
+            if isinstance(module, FSDP):
+                hooks.append(module._register_state_dict_hook(tensor_hook),)
 
         state_dict = get_model_state_dict(
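
For reference, here is a minimal sketch of how `tensor_hook` reads once this hunk applies. Only the lines visible in the hunk and its context are taken from the source; the surrounding scaffolding (`cpu_offload`, `self.dtype`, `state_dict_model`, the `DTensor` branch header, and the import paths) is reconstructed from that context, so treat this as an approximation rather than the file's verbatim contents.

```python
# Sketch only: the patched tensor_hook, reconstructed from the hunk above.
# `cpu_offload`, `self.dtype`, and `state_dict_model` are assumed to be
# defined in the enclosing scope, as the hunk's context lines imply.
from typing import Any, Dict

import torch.nn as nn
from composer.utils import dist
from torch.distributed._tensor import DTensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def tensor_hook(
    module: nn.Module,
    state_dict: Dict[str, Any],
    prefix: str,
    *args: Any,
) -> Dict[str, Any]:
    for fqn in list(state_dict.keys()):
        tensor = state_dict[fqn]
        if isinstance(tensor, DTensor):
            tensor = tensor.full_tensor()  # gather shards into a full tensor
            if dist.get_global_rank() == 0:
                if cpu_offload:
                    tensor = tensor.cpu()
                state_dict[fqn] = tensor
            else:
                # New in this patch: non-zero ranks store a None placeholder
                # instead of holding onto the gathered tensor.
                state_dict[fqn] = None
        # The dtype cast (tensor.to(dtype=self.dtype)) is commented out by
        # this patch, so values keep their original precision here.
        del tensor
    if dist.get_global_rank() != 0:
        # New in this patch: non-zero ranks return an empty state dict
        # rather than deleting only the DTensor entries.
        state_dict = {}
    return state_dict

# The patch also restores the previously commented-out isinstance check,
# so the hook is registered only on FSDP-wrapped modules.
hooks = []
for _, module in state_dict_model.named_modules():
    if isinstance(module, FSDP):
        hooks.append(module._register_state_dict_hook(tensor_hook))
```

Net effect, as far as the hunk shows: rank 0 keeps the full (optionally CPU-offloaded) tensors, every other rank now returns an empty dict from the hook, and the hook no longer fires on non-FSDP modules.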