Commit: Trigger
muellerzr committed Sep 11, 2023
1 parent 00ee7db commit 9297483
Showing 4 changed files with 29 additions and 28 deletions.
15 changes: 8 additions & 7 deletions docs/source/concept_guides/deferring_execution.md
@@ -111,16 +111,17 @@ with accelerator.main_process_first():

## Applying checks such as Early Stopping

-To have a check that works with a flag set by a particular process, the `check` and `set` breakpoint API should be used.
+To have a check that works with a flag set by a particular process, the `check` and `set` breakpoint API should be used. Useful examples
+for doing so can include situations such as using early stopping and monitoring the loss (as each loss slightly differs on each process).

-Call [`Accelerator.set_breakpoint`] when your condition has been met, and [`Accelerator.check_breakpoint`] when checking if that condition has been met in any process:
+Call [`Accelerator.set_trigger`] when your condition has been met, and [`Accelerator.check_trigger`] when checking if that condition has been met in any process:

```python
-# Assume `should_do_breakpoint` is a custom defined function that returns a conditional
-if should_do_breakpoint(loss):
-    accelerator.set_breakpoint()
+# Assume `should_do_early_stopping` is a custom defined function that returns a conditional
+if should_do_early_stopping(loss):
+    accelerator.set_trigger()

# Later in the training script when we need to check for the breakpoint
-if accelerator.check_breakpoint():
+if accelerator.check_trigger():
    break
```
```
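
Since the snippet above leaves the surrounding training loop implicit, here is a self-contained sketch of the same pattern added for illustration: the toy linear model, synthetic data, and `loss_threshold` stopping condition are placeholders and not part of Accelerate; only `set_trigger` / `check_trigger` come from the library.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()

# Toy model and data, only to make the sketch runnable.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
inputs = torch.randn(64, 4)
targets = inputs.sum(dim=1, keepdim=True)  # a learnable target so the loss actually drops
dataloader = DataLoader(TensorDataset(inputs, targets), batch_size=8)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

loss_threshold = 0.05  # hypothetical per-process stopping condition

for epoch in range(10):
    for batch_inputs, batch_targets in dataloader:
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(batch_inputs), batch_targets)
        accelerator.backward(loss)
        optimizer.step()

        # Each process raises the flag when *its* loss satisfies the condition
        if loss.item() < loss_threshold:
            accelerator.set_trigger()

    # All processes stop together if the flag was raised on any one of them
    if accelerator.check_trigger():
        break
```

Launched with `accelerate launch` on several processes, every worker exits the epoch loop together as soon as any one of them raises the trigger.
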
18 changes: 9 additions & 9 deletions examples/by_feature/early_stopping.py
@@ -197,6 +197,15 @@ def training_function(config, args):
lr_scheduler.step()
optimizer.zero_grad()

+# New code
+# Check if we should stop the training on any processes
+if callback.check_early_stopping(loss.item()):
+    accelerator.set_trigger()
+
+# If so, we break the loop
+if accelerator.check_trigger():
+    break

model.eval()
for step, batch in enumerate(eval_dataloader):
# We could avoid this line since we set the accelerator with `device_placement=True`.
@@ -212,15 +221,6 @@ def training_function(config, args):

eval_metric = metric.compute()

-# New code
-# Check if we should stop the training on any processes
-if callback.check_early_stopping(outputs.loss.item()):
-    accelerator.set_breakpoint()
-
-# If so, we break the loop
-if accelerator.check_breakpoint():
-    break

# Use accelerator.print to print only on the main process.
accelerator.print(f"epoch {epoch}:", eval_metric)
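
The example relies on a `callback.check_early_stopping(...)` helper that is defined elsewhere in `early_stopping.py` and is not shown in this diff. A hypothetical minimal version of such a callback (class name and `min_delta` / `patience` parameters are assumptions, not the script's actual code) might look like:

```python
class EarlyStoppingCallback:
    """Illustrative sketch: reports when the loss has stopped improving."""

    def __init__(self, min_delta: float = 0.0, patience: int = 3):
        self.min_delta = min_delta      # smallest decrease that counts as an improvement
        self.patience = patience        # how many non-improving checks to tolerate
        self.lowest_loss = float("inf")
        self.counter = 0

    def check_early_stopping(self, loss: float) -> bool:
        if self.lowest_loss - loss >= self.min_delta:
            # The loss improved: remember it and reset the patience counter.
            self.lowest_loss = loss
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience


callback = EarlyStoppingCallback(min_delta=0.0, patience=3)
```

Note that each process tracks its own counter; the `set_trigger` / `check_trigger` pair is what turns that per-process decision into a group-wide stop.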

18 changes: 9 additions & 9 deletions src/accelerate/accelerator.py
@@ -1966,10 +1966,10 @@ def backward(self, loss, **kwargs):
else:
    loss.backward(**kwargs)

-def set_breakpoint(self):
+def set_trigger(self):
"""
-Sets the internal flag tensor to 1 on the current process. A latter check of `Accelerator().check_breakpoint()`
-should follow using this which will check across all processes.
+Sets the internal trigger tensor to 1 on the current process. A latter check should follow using this which
+will check across all processes.
Note:
Does not require `wait_for_everyone()`
@@ -1984,18 +1984,18 @@ def set_breakpoint(self):
>>> # `should_do_breakpoint` is a custom function to monitor when to break,
>>> # e.g. when the loss is NaN
>>> if should_do_breakpoint(loss):
-... accelerator.set_breakpoint()
+... accelerator.set_trigger()
>>> # Assume later in the training script
>>> if accelerator.check_breakpoint():
... break
```
"""
self.flag_tensor = torch.tensor(1, device=self.device)

-def check_breakpoint(self):
+def check_trigger(self):
"""
-Checks if `self.flag_tensor` has been set to 1 in any of the processes. If so, will return `True` and reset the
-flag tensor to 0.
+Checks if the internal trigger tensor has been set to 1 in any of the processes. If so, will return `True` and
+reset the trigger tensor to 0.
Note:
Does not require `wait_for_everyone()`
@@ -2010,9 +2010,9 @@ def check_breakpoint(self):
>>> # `should_do_breakpoint` is a custom function to monitor when to break,
>>> # e.g. when the loss is NaN
>>> if should_do_breakpoint(loss):
-... accelerator.set_breakpoint()
+... accelerator.set_trigger()
>>> # Assume later in the training script
->>> if accelerator.check_breakpoint():
+>>> if accelerator.check_trigger():
... break
```
"""
6 changes: 3 additions & 3 deletions src/accelerate/test_utils/scripts/test_script.py
@@ -544,15 +544,15 @@ def test_split_between_processes_tensor():
def test_breakpoint():
accelerator = Accelerator()
# should start with being false
-assert accelerator.check_breakpoint() is False
+assert accelerator.check_trigger() is False

# set a breakpoint on the main process
if accelerator.is_main_process:
-    accelerator.set_breakpoint()
+    accelerator.set_trigger()

# check it's been activated across all processes
# calls `all_reduce` and triggers a sync
-assert accelerator.check_breakpoint() is True
+assert accelerator.check_trigger() is True


def main():
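
As a quick sanity check of the behaviour documented in the docstrings (the check resets the trigger to 0 after reporting it), the renamed API can also be exercised in a single process. This is an illustrative snippet added here, not part of the test file:

```python
from accelerate import Accelerator

accelerator = Accelerator()

assert accelerator.check_trigger() is False  # nothing has been set yet
accelerator.set_trigger()
assert accelerator.check_trigger() is True   # the trigger is observed...
assert accelerator.check_trigger() is False  # ...and reset by the check itself
```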
