Replies: 8 comments 3 replies
-
@zeyuwang615 I can't tell if your GPU setup uses the same global batch size; you need to scale the LR per global batch size ... the deit code (based on timm for the augs, sched, opt, etc.) uses a linear scaling formula for this (a sketch follows below). You should use at least 300 epochs and a cosine LR schedule. RepeatAug definitely works better with more epochs and is often quite a bit worse earlier on. I adapted repeat aug from the DeiT impl by Hugo and it has been tuned for 4 x GPU. I think it should work okay distributed across 8 processes, but I haven't done many experiments; it's a bit touchy. The deit code uses an opt-eps of 1e-8, but I've often found 1e-6 better, so that could be something to try.
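For reference, a rough sketch of the DeiT-style linear LR scaling (reconstructed from memory; the reference batch size of 512 and the helper name are assumptions, so verify against `main.py` in the DeiT repo):

```python
# Hypothetical helper illustrating DeiT-style linear LR scaling: the base LR
# is defined for a reference global batch size (512 in DeiT, from memory)
# and scaled linearly with the actual global batch size.
def scale_lr(base_lr, batch_size_per_proc, world_size, base_batch=512):
    global_batch = batch_size_per_proc * world_size
    return base_lr * global_batch / base_batch

# Both setups discussed here use a global batch of 1024 (4 x 256 on GPU,
# 8 x 128 on TPU), so a base LR of 5e-4 would scale to 1e-3.
print(scale_lr(5e-4, 128, 8))  # 0.001
```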
-
@rwightman Yes, for my GPU setup I use a batch size of 256 per GPU with 4 GPUs in total, so I'd say the global batch size is the same for the GPU and TPU settings. Thanks for mentioning the LR scaling strategy used in the deit codebase! That's a vital difference in implementation between the deit code and this codebase; sorry, I missed it. I'll try using the same formula. Yeah, I know that the RAS sampler works better with more epochs; I only trained the model for 100 epochs as a quick check. Also, I can't see why the RAS sampler in the deit implementation cannot work with 8 processes; just setting num_replicas=8 does the job. Since I've also included experiment results without the RAS sampler, I think the problem lies elsewhere. As for other strategies like setting lr to …
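A minimal usage sketch of the num_replicas point (the import path and constructor arguments are from memory of the DeiT codebase and may differ in your checkout; `dataset_train` stands in for the ImageNet training set):

```python
import torch
from samplers import RASampler  # assumed module path inside the DeiT codebase

# In distributed training, num_replicas should equal the total number of
# processes (the world size) and rank the index of the current process.
# With 8 TPU cores that means num_replicas=8 and rank in 0..7; in real code
# use torch.distributed.get_world_size() / get_rank() instead of constants.
sampler_train = RASampler(dataset_train, num_replicas=8, rank=0, shuffle=True)

loader_train = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=128,
    sampler=sampler_train,
    num_workers=12,
    pin_memory=True,
    drop_last=True,
)
```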
-
@zeyuwang615 the GPU and TPU are different enough that things will not be quite the same, but any differences should be far smaller than what you saw. The multiplies in matmul, conv, etc. that get done in the TPU MXU are always done internally at bfloat16 regardless of the precision set in PyTorch, which is quite a bit different from the way it works on GPU. One thing I've not been 100% clear on is what that means for elementwise divide ops; if those are done in bfloat16 it would impact ops like normalization, so the eps value you use could be important (you probably want it a bit larger than the usual GPU value). I have noted some sensitivities to eps values in the past, but usually I get things working well on TPU. The only timm-specific issue I can think of right now is random erasing. I can't do random erasing on the device with PyTorch XLA, so I had to make a special path (added a few months ago) that pushes it down into the CPU transforms; this changes where normalization happens, etc. I've tested it a fair bit and it seems to work okay, but I haven't done extensive CPU vs GPU runs with the same hparams (I think I did a quick ResNet50 check). If none of the other ideas above help, running with reprob = 0. and maybe adding some classifier dropout would be worth trying.
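A small illustration of why the eps value can matter more at bfloat16 precision (magnitudes chosen purely for illustration; real activation variances will differ):

```python
import torch

# bfloat16 keeps float32's exponent range but only has a 7-bit mantissa,
# so a tiny eps can be rounded away entirely when added to a larger value.
var = torch.tensor(1e-4, dtype=torch.bfloat16)
print(torch.equal(var + 1e-8, var))  # True: an eps of 1e-8 vanishes here
print(torch.equal(var + 1e-6, var))  # False: 1e-6 still shifts the value
```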
-
One other question: are you running both GPU and TPU from the bits_and_tpu branch? In case I have some default args that differ between the two...
-
Yes, I've noticed the issue with random erasing. In fact, cutmix also needs to be done on CPU; I think direct indexing on a PyTorch tensor is currently not supported on TPU. I will run the sanity check; that might take several days, and I will report my experiment results here when I have them all. Sorry I didn't make it clear at first: on TPU I'm using code from the bits_and_tpu branch, while on GPU I'm using code directly from the original DeiT codebase, so every result from GPU is reliable.
-
@zeyuwang615 in that case, the LR scaling difference will definitely have a big impact, especially early on. I always do that calculation manually myself, since the optimal scaling does change with the optimizer, but some codebases like DeiT do it for you relative to a base batch size such as 512 or 256. A good detail to check. Double check the args again to see if any defaults differ; LR plus any other default args would be a much higher priority than fiddling with random erasing, as I feel it is working okay (despite the slowdown). I do mixup + cutmix during the transfer to device (as part of the collate), so it's still done on CPU. Looking forward to seeing your next set of results. I'm moving this to a discussion as I feel it's likely not a bug and will be helpful to persist rather than close.
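A minimal sketch of the mixup-in-collate idea (not the timm implementation; the alpha value, the one-hot label handling, and the omission of cutmix are simplifications for illustration):

```python
import numpy as np
import torch
import torch.nn.functional as F

# Mixing happens inside the DataLoader's collate_fn, i.e. on CPU workers,
# before the batch is transferred to the XLA/GPU device.
def mixup_collate(batch, alpha=0.8, num_classes=1000):
    images = torch.stack([img for img, _ in batch])         # (B, C, H, W)
    targets = torch.tensor([label for _, label in batch])   # (B,)
    onehot = F.one_hot(targets, num_classes).float()
    lam = float(np.random.beta(alpha, alpha))
    perm = torch.randperm(images.size(0))
    mixed_images = lam * images + (1.0 - lam) * images[perm]
    mixed_targets = lam * onehot + (1.0 - lam) * onehot[perm]
    return mixed_images, mixed_targets

# Usage: torch.utils.data.DataLoader(dataset, batch_size=128,
#                                    collate_fn=mixup_collate)
```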
-
Have you had any success in reproducing the results? If so, could you please provide the hyperparameters you've used?
-
I am seeing a similar issue for DeiT training, where bf16 models are in general a few points off from the reported AMP fp16 SOTA. I'm not sure if this is a general phenomenon or something to do with the lower precision of the data type compared to fp16.
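For what it's worth, the two types do differ in how much precision they keep: bfloat16 has a 7-bit mantissa versus fp16's 10 bits (but a much wider exponent range). A toy comparison; whether that gap alone explains a few points of top-1 is a separate question:

```python
import torch

# Rounding 0.1 to each half-precision type; bfloat16 loses roughly one
# decimal digit of precision relative to fp16.
x = torch.tensor(0.1)
print(x.to(torch.bfloat16).item())  # ~0.10009765625   (relative error ~1e-3)
print(x.to(torch.float16).item())   # ~0.0999755859375 (relative error ~2e-4)
```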
-
Describe the bug
Hi, I am trying to train a deit_small_patch16_224 model on a Google Cloud TPU v3 VM instance using the bits_and_tpu branch. The strange thing is, though I am able to run the code, the final accuracy of the model is non-trivially lower than that of the counterpart trained on GPU: 73.1 vs 74.6 with Repeated Augmentation, and 74.7 vs over 76 without Repeated Augmentation (I manually merged the RAS sampler from the master branch). Is there any way to locate the problem and fix it?
To Reproduce
The training command I used is as follows (note that AMP is disabled, since AMP is currently not supported on TPU):
```
python3 launch_xla.py --num-devices 8 train.py /mnt/disks/persist/ImageNet --val-split val --model deit_small_patch16_224 --batch-size 128 --opt adamw --opt-eps 1e-8 --momentum 0.9 --weight-decay 0.05 --sched cosine --lr 5e-4 --lr-cycle-decay 1.0 --warmup-lr 1e-6 --epochs 100 --decay-epochs 30 --cooldown-epochs 0 --aa rand-m9-mstd0.5-inc1 (--aug-repeats 3) --reprob 0.25 --mixup 0.8 --cutmix 1.0 --smoothing 0.1 --train-interpolation bicubic --drop-path 0.1 --workers 12 --seed 0 --pin-mem --output output --experiment deit_small_patch16_224_default_noamp_epoch100 --log-interval 200 --recovery-interval 12500
```
Training Log
Here is the log for repeat-aug-training on TPU:
deit_small_patch16_224_repeated_aug_epoch100.csv
Here is the log for repeat-aug-training on GPU:
deit_small_patch16_224_repeat_aug_epoch100.txt
Here is the log for no-repeat-aug-training on TPU:
deit_small_patch16_224_no_repeated_aug_epoch100.csv
Sorry, I don't have an exact log for no-repeat-aug training on GPU, but it should closely resemble this:
similar_deit_small_patch16_224_no_repeat_aug_epoch100.txt
Server
Additional Comment
Maybe there is some data augmentation or optimization strategy I am missing here. But I've also tried modifying the original DeiT codebase using the same PyTorch/XLA technique (so that the data augs and optimization are exactly the same as in the original DeiT), and the training results are still consistently lower than the GPU counterpart for 100-epoch training (around 2% absolute). For 300-epoch training, the TPU model's accuracy is lower than the GPU model's for the first 80-100 epochs, but catches up by around the 150th epoch.