Skip to content

Commit

Permalink
Add barriers and fix pyright
Browse files (browse the repository at this point in the history)
  • Loading branch information
irenedea committed Mar 31, 2024
1 parent c9c22a7 commit dff1334
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -308,6 +308,8 @@ def _save_checkpoint(self, state: State, logger: Logger):
if isinstance(v, torch.Tensor):
state_dict[k] = v.to(dtype=self.dtype)

new_model_instance = None # Need this for pyright because variable could be unbound

if dist.get_global_rank() == 0:
log.debug('Saving Hugging Face checkpoint in global rank 0')

Expand Down Expand Up @@ -371,6 +373,9 @@ def _save_checkpoint(self, state: State, logger: Logger):
overwrite=self.overwrite,
)

dist.barrier()

if dist.get_global_rank() == 0:
if self.mlflow_registered_model_name and self._is_last_batch(
state):
components = {'model': new_model_instance}
Expand Down Expand Up @@ -435,3 +440,4 @@ def _save_checkpoint(self, state: State, logger: Logger):
})
process.start()
self.child_processes.append(process)
dist.barrier()

0 comments on commit dff1334

Please sign in to comment.