
Loss spike when training mosaic-bert (fp32) #236

Closed
sameerreddy13 opened this issue Mar 14, 2023 · 3 comments

sameerreddy13 commented Mar 14, 2023

I am training with the mosaic-bert-base-uncased.yaml recipe on 8xA40s, with data created via the Mosaic-provided C4 script. I consistently get a loss spike, with the loss then getting stuck, around 10k-15k steps into training. The only change from the recipe is using fp32 instead of bfloat16.
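Concretely, the precision field is the only delta from the stock recipe. A minimal sketch of the change (assuming the recipe uses Composer-style precision values, with amp_bf16 as the default):

```yaml
# Sketch of the single config change; the amp_bf16 default is an
# assumption about the stock mosaic-bert-base-uncased.yaml recipe.
# Stock recipe:
#   precision: amp_bf16
# My run:
precision: fp32   # everything else left at recipe defaults
```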

Below, the loss spikes at ~10k steps:
[Screenshots: training loss curves showing the spike at ~10k steps]

sameerreddy13 (Author) commented

For my environment, you can refer to my other issue (#237 (comment)).

jacobfulano (Contributor) commented

Hey @sameerreddy13,

We did most of our benchmarking and testing in bf16, and chose hyperparameters that worked well with this setting.

With regard to fp32, we would recommend lowering the learning rate slightly as a first line of defense (barring potential environment issues). The same advice applies if you are experimenting with a larger architecture (e.g., BERT-Large); see the sketch below.
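A rough sketch of such an override (the specific values here are illustrative assumptions, not tested recommendations):

```yaml
# Hypothetical fp32 override: halve the peak learning rate relative to
# the bf16 recipe and leave the rest of the optimizer/schedule untouched.
precision: fp32
optimizer:
  name: decoupled_adamw   # assumed to match the recipe's optimizer
  lr: 2.5e-4              # e.g., half of an assumed 5.0e-4 bf16 default
```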

sameerreddy13 (Author) commented

Hi @jacobfulano. Thank you for the advice! I actually realized the issue was with my FlashAttention installation being broken, and I am resolving that in another issue.
