
Add improved earlystopping #632

Merged

Conversation

@AMHermansen (Collaborator) commented Nov 22, 2023

Adds an early-stopping callback which also saves and loads the best weights, along with the model config.
For the design I've added an outdir to save the state_dict/model config to, so I'm not sure whether it can replace the default early stopping.

I couldn't come up with a good descriptive name for the callback that isn't overly verbose, but I'm more than open to suggestions.

@RasmusOrsoe (Collaborator) commented:

Hey @AMHermansen! Thanks for giving this a crack.

After spending some time thinking about this, I realized that we can solve this without adding a custom callback class.

In StandardModel.fit we process the callbacks argument, and we have functionality that creates the callbacks in case none is given. This part of the code base is old, and when I reviewed it today I realized it's unnecessarily complicated: the complexity arises from trying to infer whether early stopping was given and whether a validation loader is present. I think this is an ideal place to add a ModelCheckpoint callback on behalf of the user while giving this part of the code a little love.

We introduce minor changes to .fit like so:

def fit(
    self,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader] = None,
    *,
    max_epochs: int = 10,
    early_stopping_patience: int = 5,
    gpus: Optional[Union[List[int], int]] = None,
    callbacks: Optional[List[Callback]] = None,
    ckpt_path: Optional[str] = None,
    logger: Optional[LightningLogger] = None,
    log_every_n_steps: int = 1,
    gradient_clip_val: Optional[float] = None,
    distribution_strategy: Optional[str] = "ddp",
    **trainer_kwargs: Any,
) -> None:
    """Fit `StandardModel` using `pytorch_lightning.Trainer`."""
    # Checks
    if callbacks is None:
        # We create the bare-minimum callbacks for you.
        callbacks = self._create_default_callbacks(
            val_dataloader=val_dataloader,
            early_stopping_patience=early_stopping_patience,
        )
    else:
        # You are on your own!
        # We just add the progress bar if you forgot it.
        has_progress_bar = any(
            isinstance(callback, ProgressBar) for callback in callbacks
        )
        if not has_progress_bar:
            callbacks.append(ProgressBar())

    has_early_stopping = self._has_early_stopping(callbacks)
    has_model_checkpoint = self._has_model_checkpoint(callbacks)

    self.train(mode=True)
    trainer = self._construct_trainer(
        max_epochs=max_epochs,
        gpus=gpus,
        callbacks=callbacks,
        logger=logger,
        log_every_n_steps=log_every_n_steps,
        gradient_clip_val=gradient_clip_val,
        distribution_strategy=distribution_strategy,
        **trainer_kwargs,
    )

    try:
        trainer.fit(
            self, train_dataloader, val_dataloader, ckpt_path=ckpt_path
        )
    except KeyboardInterrupt:
        self.warning("[ctrl+c] Exiting gracefully.")

    # Load weights from the best-fit model after training, if possible
    if has_early_stopping and has_model_checkpoint:
        for callback in callbacks:
            if isinstance(callback, ModelCheckpoint):
                checkpoint_callback = callback
        self.load_state_dict(
            torch.load(checkpoint_callback.best_model_path)["state_dict"]
        )
    else:
        # Raise an informative warning
        self.warning(
            "Weights from the best-fit epoch were not loaded back into the "
            "model; this requires both an EarlyStopping and a "
            "ModelCheckpoint callback."
        )

The idea here is to toggle between two cases: either the user gave no callbacks, or the user did. If none are given, we infer bare-minimum callbacks based on the other arguments. Notice this introduces the new argument early_stopping_patience. If callbacks are given, the user is "on their own". This allows us to simplify the code a little bit.

After the training is finished, we check whether has_early_stopping and has_model_checkpoint and load the best-fit model parameters if possible. This would result in the "expected" behavior of model.fit when a validation loader is given, and would allow us to shave off a few lines of code in the example scripts because specifying callbacks is not needed. @AMHermansen what do you think?
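
The _has_early_stopping and _has_model_checkpoint helpers referenced in the snippet above are not spelled out here; a minimal sketch, assuming they simply scan the callback list, could be:

def _has_early_stopping(self, callbacks: List[Callback]) -> bool:
    # Sketch: True if any EarlyStopping callback (or subclass) is present.
    return any(isinstance(callback, EarlyStopping) for callback in callbacks)

def _has_model_checkpoint(self, callbacks: List[Callback]) -> bool:
    # Sketch: True if any ModelCheckpoint callback is present.
    return any(isinstance(callback, ModelCheckpoint) for callback in callbacks)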

The default callback function could then be simplified to:

def _create_default_callbacks(
    self,
    val_dataloader: DataLoader,
    early_stopping_patience: int,
) -> List:
    """Create default callbacks.

    Used in cases where no callbacks are specified by the user in .fit.
    """
    callbacks = [ProgressBar()]
    if val_dataloader is not None:
        # Add early stopping
        callbacks.append(
            EarlyStopping(monitor="val_loss", patience=early_stopping_patience)
        )
        # Add model checkpoint
        callbacks.append(
            ModelCheckpoint(
                save_top_k=1,
                monitor="val_loss",
                mode="min",
                filename=f"{self._gnn.__class__.__name__}"
                + "-{epoch}-{val_loss:.2f}-{train_loss:.2f}",
            )
        )
        self.info(
            "EarlyStopping has been added with a patience of "
            f"{early_stopping_patience}."
        )
    return callbacks
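
For reference, a hypothetical training script under this proposed API could then drop the explicit callback setup entirely (the names model, train_dataloader and val_dataloader are assumed here):

# Hypothetical usage: no callbacks passed, so the progress bar, EarlyStopping
# and ModelCheckpoint are added automatically because a validation loader is given.
model.fit(
    train_dataloader,
    val_dataloader,
    max_epochs=50,
    early_stopping_patience=5,
)
# After training, the weights from the best validation epoch are loaded back
# into `model` via the ModelCheckpoint's `best_model_path`.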

@AMHermansen (Collaborator, Author) commented:

Hello @RasmusOrsoe, thank you for your input.

I think your solution also works. I'm personally not a big fan of "enforcing" non-essential default callbacks, but I understand if you would prefer to have default callbacks to reduce "boilerplate-y" code in the training scripts.

If you end up adding a combination of EarlyStopping and ModelCheckpoint as default callbacks, then you should make sure the model logs the checkpoint file path used.
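
For instance, logging the path could be as simple as the following sketch (reusing the checkpoint_callback variable from the .fit snippet above):

# Sketch: after trainer.fit(...) returns and the best-fit weights are restored,
# log which checkpoint file they came from so the user can find it on disk.
self.info(
    f"Best-fit weights loaded from: {checkpoint_callback.best_model_path}"
)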

@RasmusOrsoe (Collaborator) left a comment:

@AMHermansen as mentioned in the call today, I think we should offer this callback as a complement to the changes I proposed earlier. I left a few minor comments. I'll introduce my proposed changes in a separate PR.

Can you confirm that the callback works as intended?

**kwargs: Keyword arguments to pass to `EarlyStopping`. See
`pytorch_lightning.callbacks.EarlyStopping` for details.
"""
self.save_dir = save_dir
Collaborator commented:
Do we need to specify the save_dir here? The documentation for ModelCheckpoint says it will automatically default to the directory of the Trainer.

@AMHermansen (Collaborator, Author) commented Nov 29, 2023:

I could probably change the callback to get the directory from the trainer, or maybe have that be a fallback option.

This callback just subclasses EarlyStopping and adds functionality such that whenever the "best_score" changes it saves the state_dict to disk; it also saves the model_config to disk at the start of the run.

The idea behind this behavior, as opposed to using ModelCheckpoint, is to be more in line with how models are best saved within GraphNeT. ModelCheckpoints store more data and are, in my experience, not straightforward to load when the hyperparameters of a Lightning model are not saved with a call to save_hyperparameters().

ModelCheckpoints, to my understanding, are used more to "interrupt" a training and resume it at a later point, which is useful when running on clusters where you might reserve a node for less time than it takes to finish a training.
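
To make that behavior concrete, here is a minimal sketch of such an EarlyStopping subclass (an illustration only, not this PR's actual implementation; the class name and file names are placeholders, and the save_config call assumes the GraphNeT Model API):

import os
from typing import Any, Optional

import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping


class EarlyStoppingWithBestWeights(EarlyStopping):  # hypothetical name
    """Sketch: early stopping that also persists the best weights and the model config."""

    def __init__(self, save_dir: str, **kwargs: Any) -> None:
        self.save_dir = save_dir
        super().__init__(**kwargs)

    def setup(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        stage: Optional[str] = None,
    ) -> None:
        # Keep EarlyStopping's own setup logic, then save the model config
        # once at the start of the run (assumes GraphNeT's `Model.save_config`).
        super().setup(trainer, pl_module, stage)
        os.makedirs(self.save_dir, exist_ok=True)
        pl_module.save_config(os.path.join(self.save_dir, "model_config.yml"))

    def on_validation_end(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> None:
        # Let EarlyStopping update `best_score`, then save the state_dict
        # whenever the monitored quantity improved.
        previous_best = self.best_score
        super().on_validation_end(trainer, pl_module)
        if self.best_score != previous_best:
            torch.save(
                pl_module.state_dict(),
                os.path.join(self.save_dir, "best_model_state_dict.pth"),
            )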

Let me know what you think.

@@ -152,3 +158,92 @@ def on_train_epoch_end(
h.setLevel(logging.ERROR)
logger.info(str(super().train_progress_bar))
h.setLevel(level)


class GraphnetEarlyStopping(EarlyStopping):
Collaborator commented:
GraphNeTEarlyStopping

"""Early stopping callback for graphnet."""

def __init__(self, save_dir: str, **kwargs: Dict[str, Any]) -> None:
"""Construct `GraphnetEarlyStopping` Callback.
@RasmusOrsoe (Collaborator) commented Nov 29, 2023:

GraphNeTEarlyStopping

@AMHermansen (Collaborator, Author) commented Nov 29, 2023:

@AMHermansen as mentioned in the call today, I think we should offer this callback as a complement to the changes I proposed earlier. I left a few minor comments. I'll introduce my proposed changes in a separate PR.

Can you confirm that the callback works as intended?

I can confirm that it works as intended. I tried a slightly modified version that also logged verbosely whenever it saved something to disk, and it was doing so correctly (i.e. only at epochs where it achieved an improved validation loss).

@RasmusOrsoe RasmusOrsoe self-requested a review December 1, 2023 08:15
@AMHermansen AMHermansen merged commit 800ebd9 into graphnet-team:main Dec 1, 2023
12 checks passed