From 589fddd317f008e704073c133bc2cb8958f287e6 Mon Sep 17 00:00:00 2001
From: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Date: Thu, 15 Aug 2024 14:31:21 +0200
Subject: [PATCH] destroy process group in `end_training` (#3012)

* destroy process group

* rephrase

* style

* fix on_main_process
---
 src/accelerate/accelerator.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 72ff911588b..91006258efa 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -2708,11 +2708,10 @@ def log(self, values: dict, step: int | None = None, log_kwargs: dict | None = {
         for tracker in self.trackers:
             tracker.log(values, step=step, **log_kwargs.get(tracker.name, {}))
 
-    @on_main_process
     def end_training(self):
         """
-        Runs any special end training behaviors, such as stopping trackers on the main process only. Should always be
-        called at the end of your script if using experiment tracking.
+        Runs any special end training behaviors, such as stopping trackers on the main process only or destoying
+        process group. Should always be called at the end of your script if using experiment tracking.
 
         Example:
 
@@ -2728,6 +2727,10 @@ def end_training(self):
         for tracker in self.trackers:
             tracker.finish()
 
+        if torch.distributed.is_initialized():
+            # needed when using torch.distributed.init_process_group
+            torch.distributed.destroy_process_group()
+
     def save(self, obj, f, safe_serialization=False):
         """
         Save the object passed to disk once per machine. Use in place of `torch.save`.
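
For context, here is a minimal usage sketch of the patched behavior. The tracker backend, project name, and metric values below are illustrative assumptions, not part of the patch; the point is that `end_training()` is no longer decorated with `@on_main_process`, so every rank reaches the new `torch.distributed.destroy_process_group()` call when a process group was initialized.

```python
from accelerate import Accelerator

# Illustrative setup: TensorBoard tracker and project name are assumptions.
accelerator = Accelerator(log_with="tensorboard", project_dir="runs")
accelerator.init_trackers("my_project")

# ... training loop would go here ...
accelerator.log({"loss": 0.0}, step=0)  # placeholder metric

# Called once at the end of the script, on all processes. After this patch it
# finishes the trackers and, if torch.distributed was initialized, also tears
# down the process group so the script exits cleanly under torchrun.
accelerator.end_training()
```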