
Commit

do dtype conversion in hook
irenedea committed Jul 19, 2024
1 parent cf710f3 commit 51c024d
Showing 1 changed file with 5 additions and 5 deletions.
llmfoundry/callbacks/hf_checkpointer.py (5 additions & 5 deletions)
@@ -430,7 +430,7 @@ def dtensor_to_tensor_hook(
                 if dist.get_global_rank() == 0:
                     if cpu_offload:
                         tensor = tensor.cpu()
-                    state_dict[fqn] = tensor
+                    state_dict[fqn] = tensor.to(dtype=self.dtype)
         if dist.get_global_rank() != 0:
             for fqn in dtensor_fqns:
                 del state_dict[fqn]
@@ -463,10 +463,10 @@ def dtensor_to_tensor_hook(
         with state_dict_context:
             state_dict = state_dict_model.state_dict()

-            # Convert the state dict to the requested precision
-            for k, v in state_dict.items():
-                if isinstance(v, torch.Tensor):
-                    state_dict[k] = v.to(dtype=self.dtype)
+            # # Convert the state dict to the requested precision
+            # for k, v in state_dict.items():
+            #     if isinstance(v, torch.Tensor):
+            #         state_dict[k] = v.to(dtype=self.dtype)

         new_model_instance = None  # Need this for pyright because variable could be unbound
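For context, the following is a minimal sketch of the pattern this commit switches to: casting each gathered tensor to the target dtype inside the DTensor-to-tensor state-dict hook, rather than in a separate pass over the full state dict afterwards. The factory function make_dtensor_to_tensor_hook and its target_dtype, cpu_offload, and is_rank_zero parameters are hypothetical stand-ins introduced only to make the sketch self-contained; the real hook, as the diff shows, is defined inline in the checkpointer and reads self.dtype and dist.get_global_rank() directly.

    from typing import Any, Dict

    import torch
    import torch.nn as nn
    from torch.distributed._tensor import DTensor


    def make_dtensor_to_tensor_hook(
        target_dtype: torch.dtype,
        cpu_offload: bool,
        is_rank_zero: bool,
    ):
        # Hypothetical factory for illustration only; the real callback defines
        # the hook inline and uses self.dtype / dist.get_global_rank() directly.

        def dtensor_to_tensor_hook(
            module: nn.Module,
            state_dict: Dict[str, Any],
            prefix: str,
            *args: Any,
        ) -> Dict[str, Any]:
            dtensor_fqns = []
            for fqn in state_dict.keys():
                tensor = state_dict[fqn]
                if isinstance(tensor, DTensor):
                    dtensor_fqns.append(fqn)
                    # Gather the sharded DTensor into a regular full tensor.
                    tensor = tensor.full_tensor()
                    if is_rank_zero:
                        if cpu_offload:
                            tensor = tensor.cpu()
                        # The change in this commit: cast to the target dtype
                        # here, per tensor, inside the hook.
                        state_dict[fqn] = tensor.to(dtype=target_dtype)
            if not is_rank_zero:
                # Non-zero ranks drop the gathered copies to save memory.
                for fqn in dtensor_fqns:
                    del state_dict[fqn]
            return state_dict

        return dtensor_to_tensor_hook

Because the cast now happens inside the hook as each DTensor is gathered, the follow-up loop over the full state dict (commented out in the second hunk above) becomes redundant.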
