diff --git a/docs/source/concept_guides/deferring_execution.md b/docs/source/concept_guides/deferring_execution.md
index 22e6a21cb0b..c2cc55a4239 100644
--- a/docs/source/concept_guides/deferring_execution.md
+++ b/docs/source/concept_guides/deferring_execution.md
@@ -111,16 +111,17 @@ with accelerator.main_process_first():
 
 ## Applying checks such as Early Stopping
 
-To have a check that works with a flag set by a particular process, the `check` and `set` breakpoint API should be used.
+To have a check that works with a flag set by a particular process, the `check` and `set` breakpoint API should be used. Useful examples
+for doing so can include situations such as using early stopping and monitoring the loss (as each loss slightly differs on each process).
 
-Call [`Accelerator.set_breakpoint`] when your condition has been met, and [`Accelerator.check_breakpoint`] when checking if that condition has been met in any process:
+Call [`Accelerator.set_trigger`] when your condition has been met, and [`Accelerator.check_trigger`] when checking if that condition has been met in any process:
 
 ```python
-# Assume `should_do_breakpoint` is a custom defined function that returns a conditional
-if should_do_breakpoint(loss):
-    accelerator.set_breakpoint()
+# Assume `should_do_early_stopping` is a custom defined function that returns a conditional
+if should_do_early_stopping(loss):
+    accelerator.set_trigger()
 
 # Later in the training script when we need to check for the breakpoint
-if accelerator.check_breakpoint():
+if accelerator.check_trigger():
     break
-```
+```
\ No newline at end of file
diff --git a/examples/by_feature/early_stopping.py b/examples/by_feature/early_stopping.py
index cfa6528ec40..e9c5e2ccbd0 100644
--- a/examples/by_feature/early_stopping.py
+++ b/examples/by_feature/early_stopping.py
@@ -197,6 +197,15 @@ def training_function(config, args):
             lr_scheduler.step()
             optimizer.zero_grad()
 
+            # New code
+            # Check if we should stop the training on any processes
+            if callback.check_early_stopping(loss.item()):
+                accelerator.set_trigger()
+
+            # If so, we break the loop
+            if accelerator.check_trigger():
+                break
+
         model.eval()
         for step, batch in enumerate(eval_dataloader):
             # We could avoid this line since we set the accelerator with `device_placement=True`.
@@ -212,15 +221,6 @@ def training_function(config, args):
 
         eval_metric = metric.compute()
 
-        # New code
-        # Check if we should stop the training on any processes
-        if callback.check_early_stopping(outputs.loss.item()):
-            accelerator.set_breakpoint()
-
-        # If so, we break the loop
-        if accelerator.check_breakpoint():
-            break
-
         # Use accelerator.print to print only on the main process.
         accelerator.print(f"epoch {epoch}:", eval_metric)
 
diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index a911a6a6637..2d20ba716b8 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -1966,10 +1966,10 @@ def backward(self, loss, **kwargs):
         else:
             loss.backward(**kwargs)
 
-    def set_breakpoint(self):
+    def set_trigger(self):
         """
-        Sets the internal flag tensor to 1 on the current process. A latter check of `Accelerator().check_breakpoint()`
-        should follow using this which will check across all processes.
+        Sets the internal trigger tensor to 1 on the current process. A latter check should follow using this which
+        will check across all processes.
 
         Note:
             Does not require `wait_for_everyone()`
@@ -1984,7 +1984,7 @@ def set_breakpoint(self):
         >>> # `should_do_breakpoint` is a custom function to monitor when to break,
         >>> # e.g. when the loss is NaN
         >>> if should_do_breakpoint(loss):
-        ...     accelerator.set_breakpoint()
+        ...     accelerator.set_trigger()
         >>> # Assume later in the training script
         >>> if accelerator.check_breakpoint():
         ...     break
@@ -1992,10 +1992,10 @@ def set_breakpoint(self):
         """
         self.flag_tensor = torch.tensor(1, device=self.device)
 
-    def check_breakpoint(self):
+    def check_trigger(self):
         """
-        Checks if `self.flag_tensor` has been set to 1 in any of the processes. If so, will return `True` and reset the
-        flag tensor to 0.
+        Checks if the internal trigger tensor has been set to 1 in any of the processes. If so, will return `True` and
+        reset the trigger tensor to 0.
 
         Note:
             Does not require `wait_for_everyone()`
@@ -2010,9 +2010,9 @@ def check_breakpoint(self):
         >>> # `should_do_breakpoint` is a custom function to monitor when to break,
         >>> # e.g. when the loss is NaN
         >>> if should_do_breakpoint(loss):
-        ...     accelerator.set_breakpoint()
+        ...     accelerator.set_trigger()
         >>> # Assume later in the training script
-        >>> if accelerator.check_breakpoint():
+        >>> if accelerator.check_trigger():
         ...     break
         ```
         """
diff --git a/src/accelerate/test_utils/scripts/test_script.py b/src/accelerate/test_utils/scripts/test_script.py
index 93b6ce660e0..82dbe2cfb80 100644
--- a/src/accelerate/test_utils/scripts/test_script.py
+++ b/src/accelerate/test_utils/scripts/test_script.py
@@ -544,15 +544,15 @@ def test_split_between_processes_tensor():
 def test_breakpoint():
     accelerator = Accelerator()
     # should start with being false
-    assert accelerator.check_breakpoint() is False
+    assert accelerator.check_trigger() is False
 
     # set a breakpoint on the main process
     if accelerator.is_main_process:
-        accelerator.set_breakpoint()
+        accelerator.set_trigger()
 
     # check it's been activated across all processes
     # calls `all_reduce` and triggers a sync
-    assert accelerator.check_breakpoint() is True
+    assert accelerator.check_trigger() is True
 
 
 def main():
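
For context, here is a minimal, self-contained sketch of how the renamed `set_trigger`/`check_trigger` API introduced in this diff can be used to stop every process once any single process meets a local condition. The toy model, data, and `should_stop` helper are illustrative assumptions and are not part of this PR:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator

accelerator = Accelerator()

# Toy setup so the loop below is runnable; any real model/optimizer/dataloader works the same way.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataset = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
dataloader = DataLoader(dataset, batch_size=8)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)


def should_stop(loss):
    # Illustrative early-stopping condition: a non-finite or sufficiently small loss.
    return not torch.isfinite(loss) or loss.item() < 0.05


for inputs, targets in dataloader:
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()

    # Flag this process when its local condition is met (losses differ per process) ...
    if should_stop(loss):
        accelerator.set_trigger()

    # ... and break on every process as soon as any process has set the trigger.
    if accelerator.check_trigger():
        break
```

As the updated docstrings and the test above note, `check_trigger()` performs an `all_reduce` across processes, so no explicit `wait_for_everyone()` call is needed around it.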