diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 2f858dd186..35508cc0c7 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -465,8 +465,7 @@ def tensor_hook( hooks = [] for _, module in state_dict_model.named_modules(): - if isinstance(module, FSDP): - hooks.append(module._register_state_dict_hook(tensor_hook),) + hooks.append(module._register_state_dict_hook(tensor_hook),) state_dict = get_model_state_dict( state_dict_model,