
How do you correctly restart training? #83

Closed
vsumaria opened this issue Feb 13, 2021 · 11 comments
Assignees
Labels
enhancement New feature or request

Comments

@vsumaria

I am trying to restart training using pre-trained weights from the checkpoint. I load them using the load_pretrained function, but the training seems to start from the beginning. What am I doing wrong?

@zulissi zulissi added the enhancement New feature or request label Feb 13, 2021
@zulissi
Member

zulissi commented Feb 13, 2021

You're looking for either a skorch regressor with warm_start=True or using partial_fit instead of fit. As written in train it always re-initializes.

This is a good suggestion; maybe we should have an option to set warm_start in the skorch model or as another config to pass around?
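
For reference, a minimal sketch of the two skorch options mentioned above, using a toy module and random data rather than anything from amptorch:

import numpy as np
import torch.nn as nn
from skorch import NeuralNetRegressor

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))

    def forward(self, X):
        return self.layers(X)

X = np.random.rand(100, 10).astype("float32")
y = np.random.rand(100, 1).astype("float32")

# Option 1: warm_start=True makes repeated fit() calls continue from the
# current weights instead of re-initializing the module.
net = NeuralNetRegressor(TinyNet, max_epochs=10, lr=1e-3, warm_start=True)
net.fit(X, y)
net.fit(X, y)  # continues training instead of starting over

# Option 2: partial_fit never re-initializes, regardless of warm_start.
net.partial_fit(X, y)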

@zulissi zulissi assigned zulissi and mshuaibii and unassigned zulissi Feb 13, 2021
@mshuaibii
Collaborator

mshuaibii commented Feb 13, 2021

@zulissi This shouldn't be an issue. The skorch regressor loads the weights, optimizer, and history after it is reinitialized:

self.net.initialize()
if gpu2cpu:
    # Convert a checkpoint saved from a DataParallel/GPU run to CPU by
    # stripping the leading "module." from each parameter name, and cache
    # the result as params_cpu.pt.
    params_path = os.path.join(checkpoint_path, "params_cpu.pt")
    if not os.path.exists(params_path):
        params = torch.load(
            os.path.join(checkpoint_path, "params.pt"),
            map_location=torch.device("cpu"),
        )
        new_dict = OrderedDict()
        for k, v in params.items():
            name = k[7:]  # drop the "module." prefix
            new_dict[name] = v
        torch.save(new_dict, params_path)
else:
    params_path = os.path.join(checkpoint_path, "params.pt")
try:
    # Restore weights, optimizer state, criterion, and training history.
    self.net.load_params(
        f_params=params_path,
        f_optimizer=os.path.join(checkpoint_path, "optimizer.pt"),
        f_criterion=os.path.join(checkpoint_path, "criterion.pt"),
        f_history=os.path.join(checkpoint_path, "history.json"),
    )
    normalizers = torch.load(os.path.join(checkpoint_path, "normalizers.pt"))
    self.feature_scaler = normalizers["feature"]
    self.target_scaler = normalizers["target"]
except NotImplementedError:
    print("Unable to load checkpoint!")

Additionally, we have tests in place to test this specific functionality: https://github.com/ulissigroup/amptorch/blob/master/amptorch/tests/pretrained_test.py

I ran a quick test using the train_example.py we have:

No retraining:
[screenshot: training loss output from a fresh run]

Retraining:
[screenshot: training loss output after restarting from the checkpoint]

We see that the losses are indeed starting from a point close to the end of the previous run.

@vsumaria Can you share how you're attempting to retrain? Maybe the tests linked above can help provide guidance.

All I've done to generate the above images was:

trainer = AtomsTrainer(config)  # using the same config used earlier
trainer.load_pretrained(checkpoint_path)
trainer.train()

@zulissi
Member

zulissi commented Feb 13, 2021

Thanks @mshuaibii!

@vsumaria
Author

I am doing exactly that: loading the AtomsTrainer with the same config, then loading the pretrained model using the checkpoint_path and calling trainer.train(). But my last iteration in the previous training with decreasing loss is this:

[screenshot: final epochs of the previous training run]

And the one after I restart is this:
[screenshot: first epochs after restarting]

That's what is confusing me.

@vsumaria
Author

My learning rate also resets if I am using a LR scheduler.

@vsumaria
Author

You're looking for either a skorch regressor with warm_start=True or using partial_fit instead of fit. As written in train it always re-initializes.

This is a good suggestion; maybe we should have an option to set warm_start in the skorch model or as another config to pass around?

I tried warm_start=True; it starts correctly, but the loss function rises at first and then decreases. I am not 100% sure how the gradients get initialized when using this setting, but at least it is not restarting the whole training.

@mshuaibii
Collaborator

mshuaibii commented Feb 13, 2021

I am doing exactly that: loading the AtomsTrainer with the same config, then loading the pretrained model using the checkpoint_path and calling trainer.train(). But my last iteration in the previous training with decreasing loss is this:

[screenshot: final epochs of the previous training run]

And the one after I restart is this:
[screenshot: first epochs after restarting]

That's what is confusing me.

Yeah, this is strange... I'm unable to replicate the warm_start issue. Both give me the same results, which is expected given that we manually initialize/override settings when we call load_pretrained. Can you try these tests out on the trainer_example.py and see if you get similar issues? That may make it easier to debug. Side note - the checkpoint that gets stored is marked by a + in the cp column of the output. For instance, line 934 of your output wasn't checkpointed, so we shouldn't be comparing the retrained numbers to that value but rather to the last +. Given how high that number is, I suspect the last checkpoint was still much lower than the retrained first epoch, but maybe you can confirm.
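
For reference, a small sketch of how you could check which epochs actually hit a checkpoint, assuming skorch's default Checkpoint event name "event_cp" and the history.json saved alongside the checkpoint (the path below is a placeholder):

import json
import os

checkpoint_path = "checkpoints/example-run"  # placeholder path
with open(os.path.join(checkpoint_path, "history.json")) as f:
    history = json.load(f)

# Epochs where the Checkpoint callback fired are the "+" rows in the cp column.
checkpointed = [row["epoch"] for row in history if row.get("event_cp")]
print("Checkpointed epochs:", checkpointed)
print("Weights on disk correspond to epoch:", checkpointed[-1] if checkpointed else None)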

My learning rate also resets if I am using a LR scheduler.

Yeah, currently the learning rate scheduler reinitializes when you resume training. Skorch doesn't save LR information in the history, which makes this a little tricky: skorch-dev/skorch#739. I'll take a look and see if I can find a workaround.
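
One manual workaround sketch (not an existing amptorch option, and assuming the trainer exposes the underlying skorch net as trainer.net with an initialized optimizer_): set the learning rate by hand after loading the checkpoint and before resuming training.

# Sketch of a manual workaround; the value 1e-4 is a placeholder for
# wherever your schedule left off.
trainer = AtomsTrainer(config)
trainer.load_pretrained(checkpoint_path)

# Override the restored optimizer's learning rate before resuming training.
for group in trainer.net.optimizer_.param_groups:
    group["lr"] = 1e-4

trainer.train()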

@zulissi
Member

zulissi commented Feb 13, 2021

I tried warm_start=True; it starts correctly, but the loss function rises at first and then decreases. I am not 100% sure how the gradients get initialized when using this setting, but at least it is not restarting the whole training.

This sounds like the learning rate got reset to the original learning rate, so the loss goes up then goes back down as learning rate drops.

@vsumaria
Author

vsumaria commented Feb 14, 2021

@mshuaibii you are right. I tried to restart training and it is behaving correctly. I didn't realize the "+" sign meant writing of new checkpoints. (AMP generated checkpoints at equal epoch intervals.)

Btw, how do you decide which epoch should be used as the checkpoint? Is it based on just the validation loss?

Sorry for the confusion. I think the LR resetting also might have been for the same reason.

@zulissi
Member

zulissi commented Feb 14, 2021

It's a skorch callback for the best loss; see the callbacks.append( call in the trainer.

For larger datasets saving the checkpoint with the lowest validation error is probably better, but most of our tests have been with such small datasets that any train/validation split isn't very helpful.
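
For reference, a sketch of roughly what that callback looks like in skorch; switching the monitor to "valid_loss_best" gives the validation-based behavior described above (the dirname is a placeholder, and the file names just mirror the checkpoint files mentioned earlier):

from skorch.callbacks import Checkpoint

callbacks = []
callbacks.append(
    Checkpoint(
        monitor="train_loss_best",    # or "valid_loss_best" for larger datasets
        dirname="checkpoints/example",  # placeholder directory
        f_params="params.pt",
        f_optimizer="optimizer.pt",
        f_criterion="criterion.pt",
        f_history="history.json",
    )
)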

@vsumaria
Author

I think that's a good rule of thumb to also avoid overfitting the potential.

Closing this issue; I think the retraining is working fine, I just didn't realize which checkpoint was being used for the retraining. (I'd update the main README file, which is still using the "load_checkpoint" function @mshuaibii.)
