Add improved earlystopping #632
Conversation
Hey @AMHermansen! Thanks for giving this a crack. After spending some time thinking about this, I realized that we can solve this without adding a custom callback class. We introduce minor changes to `fit`:

```python
def fit(
    self,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader] = None,
    *,
    max_epochs: int = 10,
    early_stopping_patience: int = 5,
    gpus: Optional[Union[List[int], int]] = None,
    callbacks: Optional[List[Callback]] = None,
    ckpt_path: Optional[str] = None,
    logger: Optional[LightningLogger] = None,
    log_every_n_steps: int = 1,
    gradient_clip_val: Optional[float] = None,
    distribution_strategy: Optional[str] = "ddp",
    **trainer_kwargs: Any,
) -> None:
"""Fit `StandardModel` using `pytorch_lightning.Trainer`."""
# Checks
if callbacks is None:
# We create the bare-minimum callbacks for you.
callbacks = self._create_default_callbacks(
val_dataloader=val_dataloader,
)
else:
# You are on your own!
# We just add the progressbar if you forgot it.
has_progress_bar = False
for callback in callbacks:
if isinstance(callback, ProgressBar):
has_progress_bar = True
if has_progress_bar is False:
callbacks.append(ProgressBar())
has_early_stopping = self._has_early_stopping(callbacks)
has_model_checkpoint = self._has_model_checkpoint(callbacks)
self.train(mode=True)
trainer = self._construct_trainer(
max_epochs=max_epochs,
gpus=gpus,
callbacks=callbacks,
logger=logger,
log_every_n_steps=log_every_n_steps,
gradient_clip_val=gradient_clip_val,
distribution_strategy=distribution_strategy,
**trainer_kwargs,
)
try:
trainer.fit(
self, train_dataloader, val_dataloader, ckpt_path=ckpt_path
)
except KeyboardInterrupt:
self.warning("[ctrl+c] Exiting gracefully.")
pass
# Load weights from best-fit model after training if possible
if has_early_stopping & has_model_checkpoint:
for callback in callbacks:
if isinstance(callback, ModelCheckpoint):
checkpoint_callback = callback
self.load_state_dict(torch.load(checkpoint_callback.best_model_path)['state_dict'])
else:
        # raise informative warning
```

The idea here is to toggle between two cases: either the user gave no callbacks, or they did. If none are given, we infer bare-minimum callbacks based on the other arguments. Notice this introduces the new argument `early_stopping_patience`. After training is finished, we check whether both an early-stopping and a model-checkpoint callback are present and, if so, load the weights from the best checkpoint. The default callback function could then be simplified to:

```python
def _create_default_callbacks(
    self,
    val_dataloader: DataLoader,
    early_stopping_patience: int,
) -> List:
""" Create default callbacks.
Used in cases where no callbacks are specified by the user in .fit"""
callbacks = [ProgressBar()]
if val_dataloader is not None:
# Add Early Stopping
callbacks.append(EarlyStopping(
monitor="val_loss",
patience=early_stopping_patience,
))
# Add Model Check Point
callbacks.append(ModelCheckpoint(save_top_k=1, monitor="val_loss", mode="min",
filename= f"{self._gnn.__class__.__name__}"+'-{epoch}-{val_loss:.2f}-{train_loss:.2f}'))
self.info(f'EarlyStopping has been added with a patience of {early_stopping_patience}.')
return callbacks |
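The `_has_early_stopping` and `_has_model_checkpoint` helpers referenced in `fit` above are not shown in the snippet. A minimal sketch of what they could look like, assuming they simply scan the callback list (the implementation is assumed, only the names come from the call sites):

```python
from typing import List

from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint


# These would live on the same model class as `fit`; shown standalone for brevity.
def _has_early_stopping(self, callbacks: List[Callback]) -> bool:
    """Return True if an `EarlyStopping` callback is present in `callbacks`."""
    return any(isinstance(callback, EarlyStopping) for callback in callbacks)


def _has_model_checkpoint(self, callbacks: List[Callback]) -> bool:
    """Return True if a `ModelCheckpoint` callback is present in `callbacks`."""
    return any(isinstance(callback, ModelCheckpoint) for callback in callbacks)
```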
Hello @RasmusOrsoe, thank you for your input. I think your solution also works. I'm personally not a big fan of "enforcing" non-essential default callbacks, but I understand if you would prefer to have default callbacks to reduce "boilerplate-y" code in the training scripts. I think if you end up going with adding a mixture of
@AMHermansen as mentioned in the call today, I think we should offer this callback as a complement to the changes I proposed earlier. I left a few minor comments. I'll introduce my proposed changes in a separate PR.
Can you confirm that the callback works as intended?
```python
        **kwargs: Keyword arguments to pass to `EarlyStopping`. See
            `pytorch_lightning.callbacks.EarlyStopping` for details.
        """
        self.save_dir = save_dir
```
Do we need to specify the `save_dir` here? From the documentation of `ModelCheckpoint`, it will automatically default to the directory of the `Trainer`.
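For context, a minimal sketch of the default behaviour being referred to: when no `dirpath` is given, `ModelCheckpoint` resolves its save location from the `Trainer` (the path below is illustrative):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# No `dirpath` given: the callback resolves its save location from the Trainer
# (its logger's directory, or `default_root_dir`) once training starts.
checkpoint = ModelCheckpoint(save_top_k=1, monitor="val_loss", mode="min")
trainer = Trainer(default_root_dir="/path/to/output", callbacks=[checkpoint])
```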
I could probably change the callback to get the directory from the trainer, or maybe have that be a fallback option.
This callback only subclasses `EarlyStopping` and adds functionality such that whenever the `best_score` changes it saves the state_dict to disk; it also saves the model_config to disk at the start of the run.
The idea behind this behavior, rather than using `ModelCheckpoint`, is to be more in line with how models are best saved within GraphNeT. ModelCheckpoints store more data and are, in my experience, not straightforward to load when the hyperparameters of a Lightning model have not been saved with a call to `save_hyperparameters()`.
ModelCheckpoints, to my understanding, are used more to "interrupt" a training run and resume it at a later point, which is useful when running on clusters where you might reserve a node for less time than it takes to finish a training.
Let me know what you think.
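To make the described behaviour concrete, here is a rough sketch (not the implementation in this PR) of an `EarlyStopping` subclass that persists the `state_dict` whenever the monitored score improves, and falls back to the trainer's directory when no `save_dir` is given. The class name, file names, and the config-saving step are assumptions:

```python
import os
from typing import Any, Dict, Optional

import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping


class EarlyStoppingWithBestWeights(EarlyStopping):  # hypothetical name
    """Early stopping that also saves the best `state_dict` to disk."""

    def __init__(self, save_dir: Optional[str] = None, **kwargs: Dict[str, Any]) -> None:
        super().__init__(**kwargs)
        self.save_dir = save_dir

    def setup(
        self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None
    ) -> None:
        # Fall back to the trainer's directory if no `save_dir` was given.
        if self.save_dir is None:
            self.save_dir = trainer.default_root_dir
        os.makedirs(self.save_dir, exist_ok=True)
        # The model config could be dumped here at the start of the run,
        # e.g. via GraphNeT's config-saving utilities (exact call assumed).

    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        previous_best = getattr(self, "best_score", None)
        super().on_validation_end(trainer, pl_module)
        # If `best_score` changed, the monitored metric improved: save weights.
        if previous_best is None or self.best_score != previous_best:
            torch.save(
                pl_module.state_dict(),
                os.path.join(self.save_dir, "best_model.pth"),  # file name assumed
            )
```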
```
@@ -152,3 +158,92 @@ def on_train_epoch_end(
        h.setLevel(logging.ERROR)
        logger.info(str(super().train_progress_bar))
        h.setLevel(level)


class GraphnetEarlyStopping(EarlyStopping):
```
GraphNeTEarlyStopping
"""Early stopping callback for graphnet.""" | ||
|
||
def __init__(self, save_dir: str, **kwargs: Dict[str, Any]) -> None: | ||
"""Construct `GraphnetEarlyStopping` Callback. |
GraphNeTEarlyStopping
I can confirm that it works as intended. I tried a slightly modified version that also logged, in a verbose manner, whenever it saved something to disk, and it was doing so correctly (i.e. only at epochs where it achieved an improved validation loss).
Adds an early-stopping callback which also saves and loads the best weights, along with the model config.
For the design I've added an outdir to save the state_dict/model_config to, so I'm not sure if it can replace the default early stopping. I couldn't come up with a good descriptive name for the callback that wasn't overly verbose, but I'm more than open to suggestions.
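A rough usage sketch of the workflow described in this PR; the constructor arguments, paths, file names, and the reload calls in the comments are illustrative assumptions rather than the exact API:

```python
# Hypothetical usage; names, paths and arguments are illustrative only.
callbacks = [
    GraphnetEarlyStopping(
        save_dir="/path/to/output",  # best state_dict + model config end up here
        monitor="val_loss",
        patience=5,
    ),
    ProgressBar(),
]

model.fit(
    train_dataloader,
    val_dataloader,
    max_epochs=50,
    callbacks=callbacks,
)

# Afterwards, the best weights and the model config could be reloaded, e.g.:
# model = Model.from_config("/path/to/output/model_config.yml")  # exact call assumed
# model.load_state_dict(torch.load("/path/to/output/best_model.pth"))
```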