Skip to content

Commit

Permalink
Introduce Stateful Callbacks (#29666)
Browse files Browse the repository at this point in the history
* Introduce saveable callbacks

* Add note

* Test for non-present and flag

* Support early stopping and refusing to train further

* Update docstring

* More saving

* Import oopsie

* Apply suggestions from code review

Co-authored-by: amyeroberts <[email protected]>

* Make it go through TrainerArguments

* Document

* Fix test

* Apply suggestions from code review

Co-authored-by: amyeroberts <[email protected]>

* Rework to allow for duplicates

* CLean

* Fix failing tests

---------

Co-authored-by: amyeroberts <[email protected]>
  • Loading branch information
muellerzr and amyeroberts authored Apr 25, 2024
1 parent 86f2569 commit ad697f1
Show file tree
Hide file tree
Showing 4 changed files with 337 additions and 7 deletions.
58 changes: 55 additions & 3 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
from .trainer_callback import (
CallbackHandler,
DefaultFlowCallback,
ExportableState,
PrinterCallback,
ProgressCallback,
TrainerCallback,
Expand Down Expand Up @@ -649,12 +650,15 @@ def __init__(
else:
self.label_smoother = None

self.control = TrainerControl()

self.state = TrainerState(
is_local_process_zero=self.is_local_process_zero(),
is_world_process_zero=self.is_world_process_zero(),
stateful_callbacks=[
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
],
)

self.control = TrainerControl()
# Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
# returned to 0 every time flos need to be logged
self.current_flos = 0
Expand Down Expand Up @@ -1499,6 +1503,8 @@ def _tune_save_checkpoint(self, checkpoint_dir: str):
output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
self.save_model(output_dir, _internal_call=True)
if self.args.should_save:
# Update the `TrainerControl` state to where we are currently
self.state.stateful_callbacks["TrainerControl"] = self.control.state()
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
Expand Down Expand Up @@ -1970,7 +1976,11 @@ def _inner_training_loop(
if not delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)

self.state = TrainerState()
self.state = TrainerState(
stateful_callbacks=[
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
]
)
self.state.is_hyper_param_search = trial is not None
self.state.train_batch_size = self._train_batch_size

Expand Down Expand Up @@ -2079,6 +2089,7 @@ def _inner_training_loop(
):
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
self.compare_trainer_and_checkpoint_args(self.args, self.state)
self._load_callback_state()
epochs_trained = self.state.global_step // num_update_steps_per_epoch
if not args.ignore_data_skip:
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
Expand Down Expand Up @@ -2786,6 +2797,8 @@ def _save_checkpoint(self, model, trial, metrics=None):

# Save the Trainer state
if self.args.should_save:
# Update the `TrainerControl` state to where we are currently
self.state.stateful_callbacks["TrainerControl"] = self.control.state()
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

if self.args.push_to_hub:
Expand Down Expand Up @@ -2970,6 +2983,45 @@ def opt_load_hook(mod, opt):
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
reissue_pt_warnings(caught_warnings)

def _load_callback_state(self):
"""If callback states exist and were passed in, restore their states if enabled"""
if not self.args.restore_callback_states_from_checkpoint:
return
# Callback states are stored in stateful_callbacks
not_found = []
new_callbacks = []
original_callbacks = self.callback_handler.callbacks + [self.control]
for stored_callback, data in self.state.stateful_callbacks.items():
if not isinstance(data, list):
data = [data]
if any(callback.__class__.__name__ == stored_callback for callback in original_callbacks):
# We can load/restore from multiple callbacks of the same type.
duplicates = [
callback for callback in original_callbacks if callback.__class__.__name__ == stored_callback
]
for callback, callback_data in zip(duplicates, data):
args = callback_data.get("args", {})
attributes = callback_data.get("attributes", {})
new_callback = type(callback)(**args)
for attribute, value in attributes.items():
setattr(new_callback, attribute, value)
if isinstance(callback, TrainerControl):
# Specifically for restoring the `control` state
self.control = new_callback
else:
new_callbacks.append(new_callback)
# We remove the existing callback and add it to the list of new callbacks
self.callback_handler.remove_callback(type(new_callback))
logger.info("Continuing training from checkpoint, restoring any callbacks that were passed in")
else:
not_found.append(stored_callback)
if len(not_found) > 0:
logger.warning(
f"Checkpoint included callbacks not included in current configuration. Ignoring. ({', '.join(not_found)})"
)
for callback in new_callbacks:
self.callback_handler.add_callback(callback)

def hyperparameter_search(
self,
hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
Expand Down
98 changes: 96 additions & 2 deletions src/transformers/trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ class TrainerState:
is_hyper_param_search (`bool`, *optional*, defaults to `False`):
Whether we are in the process of a hyper parameter search using Trainer.hyperparameter_search. This will
impact the way data will be logged in TensorBoard.
stateful_callbacks (`List[StatefulTrainerCallback]`, *optional*):
Callbacks attached to the `Trainer` that should have their states be saved or restored.
Relevent callbacks should implement a `state` and `from_state` function.
"""

epoch: Optional[float] = None
Expand All @@ -104,10 +107,34 @@ class TrainerState:
is_hyper_param_search: bool = False
trial_name: str = None
trial_params: Dict[str, Union[str, float, int, bool]] = None
stateful_callbacks: List["TrainerCallback"] = None

def __post_init__(self):
if self.log_history is None:
self.log_history = []
if self.stateful_callbacks is None:
self.stateful_callbacks = {}
elif isinstance(self.stateful_callbacks, dict):
# We are loading the callbacks in from the state file, no need to process them
pass
else:
# Saveable callbacks get stored as dict of kwargs
stateful_callbacks = {}
for callback in self.stateful_callbacks:
if not isinstance(callback, (ExportableState)):
raise TypeError(
f"All callbacks passed to be saved must inherit `ExportableState`, but received {type(callback)}"
)
name = callback.__class__.__name__
if name in stateful_callbacks:
# We can have multiple versions of the same callback
# if so, we store them as a list of states to restore
if not isinstance(stateful_callbacks[name], list):
stateful_callbacks[name] = [stateful_callbacks[name]]
stateful_callbacks[name].append(callback.state())
else:
stateful_callbacks[name] = callback.state()
self.stateful_callbacks = stateful_callbacks

def save_to_json(self, json_path: str):
"""Save the content of this instance in JSON format inside `json_path`."""
Expand All @@ -123,8 +150,52 @@ def load_from_json(cls, json_path: str):
return cls(**json.loads(text))


class ExportableState:
"""
A class for objects that include the ability to have its state
be saved during `Trainer._save_checkpoint` and loaded back in during
`Trainer._load_from_checkpoint`.
These must implement a `state` function that gets called during the respective
Trainer function call. It should only include parameters and attributes needed to
recreate the state at a particular time, to avoid utilizing pickle/maintain standard
file IO writing.
Example:
```python
class EarlyStoppingCallback(TrainerCallback, ExportableState):
def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: Optional[float] = 0.0):
self.early_stopping_patience = early_stopping_patience
self.early_stopping_threshold = early_stopping_threshold
# early_stopping_patience_counter denotes the number of times validation metrics failed to improve.
self.early_stopping_patience_counter = 0
def state(self) -> dict:
return {
"args": {
"early_stopping_patience": self.early_stopping_patience,
"early_stopping_threshold": self.early_stopping_threshold,
},
"attributes": {
"early_stopping_patience_counter": self.early_stopping_patience_counter,
}
}
```"""

def state(self) -> dict:
raise NotImplementedError("You must implement a `state` function to utilize this class.")

@classmethod
def from_state(cls, state):
instance = cls(**state["args"])
for k, v in state["attributes"].items():
setattr(instance, k, v)
return instance


@dataclass
class TrainerControl:
class TrainerControl(ExportableState):
"""
A class that handles the [`Trainer`] control flow. This class is used by the [`TrainerCallback`] to activate some
switches in the training loop.
Expand Down Expand Up @@ -172,6 +243,18 @@ def _new_step(self):
self.should_evaluate = False
self.should_log = False

def state(self) -> dict:
return {
"args": {
"should_training_stop": self.should_training_stop,
"should_epoch_stop": self.should_epoch_stop,
"should_save": self.should_save,
"should_evaluate": self.should_evaluate,
"should_log": self.should_log,
},
"attributes": {},
}


class TrainerCallback:
# no-format
Expand Down Expand Up @@ -546,7 +629,7 @@ def on_log(self, args, state, control, logs=None, **kwargs):
print(logs)


class EarlyStoppingCallback(TrainerCallback):
class EarlyStoppingCallback(TrainerCallback, ExportableState):
"""
A [`TrainerCallback`] that handles early stopping.
Expand Down Expand Up @@ -605,3 +688,14 @@ def on_evaluate(self, args, state, control, metrics, **kwargs):
self.check_metric_value(args, state, control, metric_value)
if self.early_stopping_patience_counter >= self.early_stopping_patience:
control.should_training_stop = True

def state(self) -> dict:
return {
"args": {
"early_stopping_patience": self.early_stopping_patience,
"early_stopping_threshold": self.early_stopping_threshold,
},
"attributes": {
"early_stopping_patience_counter": self.early_stopping_patience_counter,
},
}
9 changes: 9 additions & 0 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,9 @@ class TrainingArguments:
Note that when this is true, you won't be able to resume training from checkpoint.
This enables you to save storage by not storing the optimizer, scheduler & rng state.
You can only load the model using `from_pretrained` with this option set to `True`.
restore_callback_states_from_checkpoint (`bool`, *optional*, defaults to `False`):
Whether to restore the callback states from the checkpoint. If `True`, will override
callbacks passed to the `Trainer` if they exist in the checkpoint."
use_cpu (`bool`, *optional*, defaults to `False`):
Whether or not to use cpu. If set to False, we will use cuda or mps device if available.
seed (`int`, *optional*, defaults to 42):
Expand Down Expand Up @@ -951,6 +954,12 @@ class TrainingArguments:
)
},
)
restore_callback_states_from_checkpoint: bool = field(
default=False,
metadata={
"help": "Whether to restore the callback states from the checkpoint. If `True`, will override callbacks passed to the `Trainer` if they exist in the checkpoint."
},
)
no_cuda: bool = field(
default=False,
metadata={"help": "This argument is deprecated. It will be removed in version 5.0 of 🤗 Transformers."},
Expand Down
Loading

0 comments on commit ad697f1

Please sign in to comment.