How do you correctly restart training? #83

I am trying to restart training using pre-trained weights from a checkpoint. I load them using the `load_pretrained` function, but the training seems to start from the beginning. What am I doing wrong?

Comments
You're looking for a skorch regressor with `warm_start` enabled. This is a good suggestion; maybe we should have an option to set `warm_start` in the skorch model, or pass it around as another config option?
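(Not from the original comment: a minimal sketch of what `warm_start` looks like on a plain skorch regressor. `TinyNet` and the random data are hypothetical stand-ins for amptorch's actual model and feature pipeline.)

```python
import torch
from skorch import NeuralNetRegressor

# Hypothetical stand-in for amptorch's network; the real model and
# descriptor pipeline are omitted here.
class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(10, 1)

    def forward(self, X):
        return self.fc(X)

net = NeuralNetRegressor(
    TinyNet,
    max_epochs=10,
    lr=1e-3,
    warm_start=True,  # later .fit() calls continue from the current weights
)

X = torch.randn(100, 10)
y = torch.randn(100, 1)

net.fit(X, y)  # initial training run
net.fit(X, y)  # resumes from the weights above instead of re-initializing
```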
@zulissi This shouldn't be an issue. The skorch regressor loads the weights, optimizer, and history after it is reinitialized (see lines 303 to 331 at commit 9b71f1f).
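(A rough sketch of the generic skorch mechanism for this, `save_params`/`load_params`, with illustrative file names and continuing from the `net` in the sketch above; the actual amptorch code is at the lines referenced above.)

```python
# Reusing `net` from the sketch above; file names are illustrative.
net.save_params(
    f_params="params.pt",
    f_optimizer="optimizer.pt",
    f_history="history.json",
)

# Later (e.g. in a fresh process): rebuild the net, re-initialize it,
# then reload weights, optimizer state, and history before refitting.
net.initialize()
net.load_params(
    f_params="params.pt",
    f_optimizer="optimizer.pt",
    f_history="history.json",
)
```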
Additionally, we have tests in place for this specific functionality: https://github.com/ulissigroup/amptorch/blob/master/amptorch/tests/pretrained_test.py

I ran a quick test using the linked setup, and we see that the losses do indeed start from a point close to the end of the previous run. @vsumaria, can you share how you're attempting to retrain? Maybe the tests linked above can provide guidance. All I did to generate the above images was:
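(The original snippet and plots are not reproduced above. As an illustration only, one way to compare the end of the first run against the start of a resumed run is through skorch's history, assuming the `net`, `X`, and `y` from the earlier sketches:)

```python
# Illustrative only; assumes the `net`, `X`, `y` from the sketches above.
losses_first_run = [row["train_loss"] for row in net.history]

net.fit(X, y)  # resumed run; with warm_start=True this continues instead of re-initializing
losses_all = [row["train_loss"] for row in net.history]

print("loss at end of first run:    ", losses_first_run[-1])
print("loss at start of resumed run:", losses_all[len(losses_first_run)])
```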
Thanks @mshuaibii!
My learning rate also resets if I am using an LR scheduler.
I tried "warm_start= True", it starts correctly, but you see the loss function starting to rise and then decrease. I am not 100% sure how the gradients get initialized when using this setting. But it at least not restarting the whole training. |
Yeah, this is strange... I'm unable to replicate the behavior.
Yeah, currently the learning rate scheduler reinitializes after training. Skorch doesn't save the scheduler state, so the learning rate starts over from its initial value when training resumes.
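(Not from the thread: one possible workaround is to checkpoint the underlying PyTorch scheduler yourself with `state_dict`/`load_state_dict`, outside of skorch's callback machinery. A minimal sketch with illustrative names:)

```python
import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# ... train, calling scheduler.step() once per epoch ...

# Save optimizer and scheduler state alongside the model weights.
torch.save(
    {
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),  # includes last_epoch, so the LR resumes correctly
    },
    "opt_sched.pt",
)

# On restart: rebuild the same objects, then restore their state before training further.
state = torch.load("opt_sched.pt")
optimizer.load_state_dict(state["optimizer"])
scheduler.load_state_dict(state["scheduler"])
```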
This sounds like the learning rate got reset to the original learning rate, so the loss goes up and then comes back down as the learning rate drops.
@mshuaibii you are right. I tried to restart training and it is behaving correctly. I didn't realize the "+" sign meant new checkpoints were being written (AMP generated checkpoints at equal epoch intervals). By the way, how do you decide which epoch should be used as the checkpoint? Is it based on just the validation loss? Sorry for the confusion. I think the LR resetting might also have been because of the same reason.
It's a skorch callback for the best loss (see line 129 at commit b18e3cb).
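(For context: skorch's `Checkpoint` callback saves state whenever a monitored quantity improves. A short sketch with illustrative file names and a `valid_loss_best` monitor, which may differ from the exact settings at the referenced line:)

```python
from skorch.callbacks import Checkpoint

# Save params/optimizer/history only when the monitored loss improves.
cp = Checkpoint(
    monitor="valid_loss_best",      # or "train_loss_best" when there is no validation split
    f_params="best_params.pt",
    f_optimizer="best_optimizer.pt",
    f_history="best_history.json",
)
# The callback is passed to the net via callbacks=[cp];
# net.load_params(checkpoint=cp) later reloads that best state.
```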
For larger datasets, saving the checkpoint with the lowest validation error is probably better, but most of our tests have been with such small datasets that a train/validation split isn't very helpful.
I think that's a good rule of thumb to also avoid overfitting the potential. Closing this issue; I think the retraining is working fine, I just didn't realize which checkpoint was being used for the retraining. (I'd update the main README file, which still uses the "load_checkpoint" function, @mshuaibii.)