
Oscillation in the train loss on multiple GPUs #2339

Closed
sinahmr opened this issue Nov 20, 2024 · 11 comments
sinahmr (Contributor) commented Nov 20, 2024

I'm running the command provided here for training a ViT, once using batch size 288 on two GPUs (as in the link), and once using batch size 576 on one GPU. As you can see in the plot below, the training loss for the single-GPU run is much smoother than the two-GPU run, which oscillates a lot (though with a similar decreasing trend) and sometimes makes training unstable.
Is this behaviour expected? If not, I suspect there may be an error in the multi-GPU implementation, but I couldn't find one. Could you please have a look?
Thanks!

To Reproduce
./distributed_train.sh {1 or 2} --data-dir /path/to/100class/data --num-classes 100 --model vit_small_patch16_224 --sched cosine --epochs 300 --opt adamw -j 8 --warmup-lr 1e-6 --mixup .2 --model-ema --model-ema-decay 0.99996 --aa rand-m9-mstd0.5-inc1 --remode pixel --reprob 0.25 --amp --lr 5e-4 --weight-decay .05 --drop 0.1 --drop-path .1 -b {576 or 288}

Expected behavior
Fairly similar train loss for the experiments.

Screenshots
[W&B chart (2024-11-19): train loss curves for the two runs]

Environment

  • OS: Ubuntu 22.04
  • Using timm v1.0.11
  • PyTorch 2.4.1 with CUDA 12.2

Additional context
I also tried an experiment using batch size 288 on one GPU but with --grad-accum-steps 2 (to get a global batch size of 576 like the other experiments) and saw no problem (no extreme oscillation) in the loss plot; it looked fine, like the other single-GPU run.
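
For context, a minimal PyTorch sketch of the accumulation pattern (illustrative only, not timm's actual train loop): two micro-batches of 288 are forwarded and backwarded before a single optimizer step, so the accumulated gradient matches a global batch of 576.

def train_step(model, optimizer, loss_fn, micro_batches, accum_steps=2):
    # micro_batches: iterable of (input, target) pairs, e.g. 2 micro-batches of 288
    optimizer.zero_grad()
    for input, target in micro_batches:
        loss = loss_fn(model(input), target)
        # scale so the accumulated gradient equals the average over the global batch
        (loss / accum_steps).backward()
    optimizer.step()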

sinahmr added the bug label on Nov 20, 2024
rwightman (Collaborator) commented:

@sinahmr Unrelated to the oscillations, but those hparams are a bit out of date; there are some better ones to base off here: https://gist.github.com/rwightman/fb37c339efd2334177ff99a8083ebbc4

On this, I don't think there's any bug; I've done a lot of training with these scripts without issue, on 1, 2, 4, and 8 GPUs and all the way up to 512 or so in the past, with plenty of 2-, 4-, and 8-GPU runs in recent months.

So the 576 vs 288 x 2 in the graph is without involving grad accum, right? Does 576 x 2 work? What is the dataset like: long-tailed with some challenging samples, an uneven class distribution, or more uniform? If you run 288 x 2 with a different seed, does it change? If you pull the LR back to 1e-4, do 288 x 2 and 576 match?

sinahmr (Contributor, Author) commented Nov 20, 2024

@rwightman Thanks for sharing the newer hparams!
Update: I took a look at the link and then at the registered models in models/vision_transformer.py. Some of the variants are a bit confusing to me, like little, medium, betwixt, wee, etc. Is there an article in which the new variants are explained, along with which of the original variants each should be compared to? I only found this, which is a general explanation.

Yes, in the graph above, --grad-accum-steps is 1 in all cases. If I had included 1 GPU x 288 batch x grad-accum-steps 2, it would be close to the blue curve, i.e., no severe oscillation.

The dataset is Mini-ImageNet, a subset of ImageNet-1K with the same image sizes (100 categories, 500 training and 100 test samples per category), so I don't think it should make a big difference.

Besides, I have experiments on ImageNet-1K using Swin Transformer (and a different set of hparams), and a similar phenomenon is visible there. More specifically, the train loss of 2 GPU x 512 batch x grad-accum-steps 1 (blue, "azure-bird-10") oscillates much more than that of 1 GPU x 512 batch x grad-accum-steps 2 (red, "laced-salad-21").

Based on your suggestions, I'll try the following new experiments and keep you posted on the results:
  1. LR = 1e-4 for both 2 GPU x 288 batch x grad-accum-steps 1 and 1 GPU x 576 batch x grad-accum-steps 1.
  2. 2 GPU x 288 batch x grad-accum-steps 1 with a different seed.

rwightman (Collaborator) commented:

@sinahmr

There are two things that could be having an impact:

  1. In train, the loss averaging is different: to keep it simple and avoid a reduction every step, I only do the loss reduction for train logging at each log interval. This is cosmetic only; eval loss and eval metrics should track the same.
  2. Your EMA decay isn't great for a small dataset; you should probably bring it down to the 0.999 range. The newer --model-ema-warmup arg might also be useful to mitigate this; it should work better than the original constant decay.

The second one won't have any impact on the train loss, but it could result in really slow apparent convergence of the eval metrics, and at the end of training the EMA weights may not have fully ramped up.
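
For example, adjusting only the EMA-related flags from the command above (decay value as suggested; other flags unchanged):

./distributed_train.sh 2 ... --model-ema --model-ema-decay 0.999 --model-ema-warmup ...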

Is there any difference of note in the eval loss & accuracy for the different runs?

There are two things going on with those SBB ViTs:

  1. Further refinement of the hparams for training from scratch on ImageNet.
  2. A different set of model 'shapes', i.e. different width / depth / MLP-ratio combos, hence the weird names. 'little' is closest to 'small', but I feel it works a bit better on smaller datasets (see the listing snippet below).
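
If it helps to find them, the registered names can be listed with timm's list_models; a small sketch (which patterns actually match depends on the installed timm version):

import timm

# Wildcard search over the model names registered in timm
# (the SBB ViT variants live in timm/models/vision_transformer.py).
for pattern in ('vit_wee*', 'vit_little*', 'vit_medium*', 'vit_betwixt*'):
    print(pattern, timm.list_models(pattern))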

rwightman (Collaborator) commented:

You can try changing the loss averaging to verify the train losses: pull the part inside the logging block out into an else in the main part of the train step. I don't normally train with distributed for datasets this small; the larger datasets I do use have more sampling points, so this doesn't have as big an impact on the logging...

So:

# Update the running loss meter every step; reduce across ranks when distributed
# so the logged average reflects all GPUs, not just the local rank.
if args.distributed:
    reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
    losses_m.update(reduced_loss.item() * accum_steps, input.size(0))
else:
    losses_m.update(loss.item() * accum_steps, input.size(0))

rwightman (Collaborator) commented Nov 20, 2024

#2340 will keep the running average of the loss updated every step (completely independently on each rank) and only sync it for the logs and the final return, so it's a better mix of the current behaviour and the verification proposed above.
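
A rough sketch of that pattern (illustrative only, not the actual code in #2340): each rank updates its own running average every step, and an all-reduce happens only when a synchronized value is needed for logging or the final return.

import torch
import torch.distributed as dist

class AverageMeter:
    # Local running average, updated every step with no cross-rank sync
    def __init__(self):
        self.sum, self.count = 0.0, 0
    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
    @property
    def avg(self):
        return self.sum / max(self.count, 1)

def global_avg(meter):
    # Sync only at log intervals / end of epoch: all-reduce the (sum, count)
    # pair, then divide. Assumes an initialized process group; with the NCCL
    # backend the tensor would need to live on the GPU.
    t = torch.tensor([meter.sum, float(meter.count)], dtype=torch.float64)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return (t[0] / t[1]).item()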

In any case, I'm not seeing any issues running with this dataset https://www.kaggle.com/datasets/ctrnngtrung/miniimagenet on 2 GPUs with the same hparams (though I did use --model-ema-decay 0.999 --model-ema-warmup); the train loss in the log was bouncing around as you say, but eval tracked fine. With #2340, eval tracks the same and the logged train loss is smoother since it includes all batches in the sampling.

sinahmr (Contributor, Author) commented Nov 20, 2024

I ran an experiment with the changes in #2340, and I confirm that the oscillating train loss problem on multi-GPU is gone.
Thanks a lot for taking the time.

(Not important anymore, just to note for the previous discussion: changing the LR or the seed didn't help the oscillation; #2340 does.)

rwightman (Collaborator) commented:

@sinahmr Thanks for the update. You mentioned training was unstable too... the loss fix in #2340 is purely cosmetic; it just stops sampling the loss sparsely for logging during distributed training. In either case the training progression and the evaluation results should be the same. Before I close, is that what you observe?

sinahmr (Contributor, Author) commented Nov 21, 2024

It happened to me once that a run on 2 GPUs became unstable in the middle of training (blue), while the same run on 1 GPU with --grad-accum-steps 2 did not (red).

But to be honest, now that I know the oscillations I previously observed were just an artifact of the logging, I'm thinking maybe I made some other mistake in the blue run. I have done many experiments on a custom ViT, and as far as I remember the eval loss and accuracy never looked problematic.

Also, the eval loss and accuracy of the run with #2340 look very close to the 1-GPU run. I think the slight difference is fine; when we use 2 GPUs the order of the data changes, is that right?

@rwightman
Copy link
Collaborator

@sinahmr Yes, the order of the data won't be preserved as you change the world size. I'll merge #2340 soon with another small change I wanted to make for the results output.
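
To illustrate, a standalone sketch with torch.utils.data.DistributedSampler (not timm code): each rank takes an interleaved slice of the shuffled index permutation, so changing the world size changes which samples land in which batch.

import torch
from torch.utils.data import DistributedSampler, TensorDataset

ds = TensorDataset(torch.arange(8))

# world_size = 1: the single rank iterates the full shuffled permutation
s = DistributedSampler(ds, num_replicas=1, rank=0, shuffle=True, seed=42)
s.set_epoch(0)
print("1 GPU :", list(s))

# world_size = 2: each rank gets an interleaved half of that permutation,
# so batch composition differs from the single-GPU run
for r in range(2):
    s = DistributedSampler(ds, num_replicas=2, rank=r, shuffle=True, seed=42)
    s.set_epoch(0)
    print(f"rank {r}:", list(s))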

sinahmr (Contributor, Author) commented Nov 21, 2024

Great, thank you!

rwightman (Collaborator) commented:

merged
