Remove extra call to .to and load_state_dict in hf checkpointer (#939)
dakinggg authored and bigning committed Feb 5, 2024
1 parent 969c000 commit a97cf66
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -251,15 +251,13 @@ def _save_checkpoint(self, state: State, logger: Logger):
                 new_model_instance = type(original_model)(
                     new_base_model_instance,
                     original_model.peft_config[active_adapter])
+                new_model_instance.to(dtype=self.dtype)
             else:
                 # First create the model instance on meta device to avoid the
                 # initialization cost.
                 with init_empty_weights():
                     new_model_instance = type(original_model)(copied_config)
 
-            new_model_instance.to(dtype=self.dtype)
-            new_model_instance.load_state_dict(state_dict)
-
             # Then load the state dict in with "assign" so that the state dict
             # is loaded properly even though the model is initially on meta device.
             new_model_instance.load_state_dict(state_dict, assign=True)
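For context, a minimal sketch (illustrative only, not code from this repo) of the pattern the checkpointer relies on: a model constructed under accelerate's init_empty_weights() lives on the meta device, so the removed .to(dtype=...) and plain load_state_dict() calls could not populate its weights anyway; load_state_dict(state_dict, assign=True) (PyTorch 2.1+) assigns the checkpoint tensors directly and carries their dtype with them. The make_model helper and tensor sizes below are made up for the example.

# Minimal sketch (illustrative; not llm-foundry code): why a meta-initialized model
# is populated with load_state_dict(..., assign=True) instead of .to(dtype) +
# load_state_dict(). Requires torch>=2.1 and accelerate.
import torch
from accelerate import init_empty_weights


def make_model() -> torch.nn.Module:
    # Stand-in for type(original_model)(copied_config) in the callback.
    return torch.nn.Linear(4, 4)


# Stand-in for the gathered state dict, already cast to the target dtype.
state_dict = make_model().to(dtype=torch.bfloat16).state_dict()

# Create the destination model on the meta device: no allocation, no init cost.
with init_empty_weights():
    new_model = make_model()
assert new_model.weight.device.type == 'meta'

# A plain .to(dtype=...) or load_state_dict(state_dict) cannot materialize meta
# tensors, which is why the extra calls were removed. assign=True swaps the
# checkpoint tensors in directly, so dtype (and device) come from the state dict.
new_model.load_state_dict(state_dict, assign=True)
print(new_model.weight.dtype, new_model.weight.device)  # torch.bfloat16 cpu

The PEFT branch keeps an explicit .to(dtype=self.dtype), presumably because that path builds the model with real (non-meta) weights and so still needs the dtype cast.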
