I am training with the mosaic-bert-base-uncased.yaml recipe on 8xA40s, with data created using the Mosaic-provided C4 script. I consistently get a loss spike, after which the loss stays stuck, around 10k-15k steps into training. The only change I made is using fp32 instead of bfloat16.
[Loss curve: spike at ~10k steps]
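For reference, this is the only diff from the stock recipe (a sketch only; the key name and the bf16 default follow the Composer-style YAML in the examples repo, so treat them as assumptions and check your copy):

```yaml
# mosaic-bert-base-uncased.yaml (excerpt) -- key name is an assumption
# based on the Composer-style recipes in the examples repo
precision: fp32  # changed from the recipe's bf16 default (amp_bf16)
```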
We did most of our benchmarking and testing in bf16, and chose hyperparameters that worked well with this setting.
Regarding fp32, we would recommend lowering the learning rate slightly as a first line of defense (barring potential environment-related issues); see the sketch below. If you are also experimenting with a larger architecture (e.g. BERT-Large), we would likewise recommend lowering the learning rate.
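As a concrete illustration, a lowered-LR override might look like the following (a sketch only; the optimizer name and the 5.0e-4 baseline are assumptions about the stock recipe, so check the values in your copy of the YAML):

```yaml
# Hypothetical excerpt: lower the peak LR when switching to fp32.
optimizer:
  name: decoupled_adamw  # assumed default optimizer in the recipe
  lr: 2.0e-4             # reduced from an assumed 5.0e-4 baseline
```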
Hi @jacobfulano. Thank you for the advice! I actually realized that my flash attention installation was broken; I am resolving that in another issue.