swin transformer no speed up with torch.cuda.amp #623
Replies: 2 comments
-
@thomas0809 I measure gains of approx 50% for both inference and training on my 3090 card w/ NGC 21.03. I haven't tried a newer NGC container yet; perhaps float32 has also improved, so there is less of a gap? Or maybe float32 on the A100 is even stronger relative to float16 on the A100 arch than on the 3090. Pure float16 is a 2x gain, so there could also be a lot of ops that aren't being autocast to float16... that would work for inference but could be unstable for training.
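For context, here is a minimal benchmarking sketch of the kind of float32 vs autocast vs pure-float16 comparison described above, assuming timm is installed; the model name, batch size, and iteration counts are illustrative assumptions, not the commenter's exact setup.

```python
# Sketch: compare float32, autocast (AMP), and pure float16 inference throughput.
# Model name, batch size, and iteration counts are illustrative placeholders.
import time
import torch
import timm

def bench(model, x, autocast=False, iters=50, warmup=10):
    model.eval()
    with torch.no_grad():
        for i in range(warmup + iters):
            if i == warmup:
                torch.cuda.synchronize()
                start = time.time()
            with torch.cuda.amp.autocast(enabled=autocast):
                model(x)
        torch.cuda.synchronize()
    return iters * x.shape[0] / (time.time() - start)  # images / sec

model = timm.create_model('swin_base_patch4_window7_224', pretrained=False).cuda()
x = torch.randn(64, 3, 224, 224, device='cuda')

print('float32 :', bench(model, x))
print('autocast:', bench(model, x, autocast=True))
print('float16 :', bench(model.half(), x.half()))  # pure half, inference only
```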
-
@thomas0809
-
I am trying to use `torch.cuda.amp` to speed up training. However, I found that on A100 machines amp couldn't speed up Swin Transformers, while it worked pretty well for other models such as ResNet. See a detailed example here. Based on the profile, a lot of time was spent in the `copy_device_to_device` operator. I think there should be a way to get rid of these operations. Hope someone more familiar with the implementation can help!
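For reference, a minimal sketch of the standard `torch.cuda.amp` training pattern together with a short `torch.profiler` run to see which kernels dominate (e.g. whether copy kernels show up near the top); the model, optimizer, and synthetic data below are placeholders, not the original script.

```python
# Standard torch.cuda.amp training pattern plus a brief profiler run (a sketch;
# model, optimizer, and data are placeholder assumptions).
import torch
import timm

model = timm.create_model('swin_base_patch4_window7_224', num_classes=10).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()

images = torch.randn(32, 3, 224, 224, device='cuda')
targets = torch.randint(0, 10, (32,), device='cuda')

# Profile a few steps to see where GPU time goes.
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
    for _ in range(5):
        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast():        # ops run in float16 where safe
            loss = criterion(model(images), targets)
        scaler.scale(loss).backward()          # scale loss to avoid fp16 underflow
        scaler.step(optimizer)
        scaler.update()

print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=15))
```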