Skip to content

Commit

Permalink
destroy process group in end_training (#3012)
Browse files Browse the repository at this point in the history
* destroy process group

* rephrase

* style

* fix on_main_process
  • Loading branch information
SunMarc authored Aug 15, 2024
1 parent 99c69aa commit 589fddd
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2708,11 +2708,10 @@ def log(self, values: dict, step: int | None = None, log_kwargs: dict | None = {
for tracker in self.trackers:
tracker.log(values, step=step, **log_kwargs.get(tracker.name, {}))

@on_main_process
def end_training(self):
"""
Runs any special end training behaviors, such as stopping trackers on the main process only. Should always be
called at the end of your script if using experiment tracking.
Runs any special end training behaviors, such as stopping trackers on the main process only or destoying
process group. Should always be called at the end of your script if using experiment tracking.
Example:
Expand All @@ -2728,6 +2727,10 @@ def end_training(self):
for tracker in self.trackers:
tracker.finish()

if torch.distributed.is_initialized():
# needed when using torch.distributed.init_process_group
torch.distributed.destroy_process_group()

def save(self, obj, f, safe_serialization=False):
"""
Save the object passed to disk once per machine. Use in place of `torch.save`.
Expand Down

0 comments on commit 589fddd

Please sign in to comment.