
Could there be a bug in mixed precision? #101

Closed
voldemortX opened this issue Jun 7, 2021 · 24 comments

@voldemortX

voldemortX commented Jun 7, 2021

When I use torch 1.6.0 & accelerate 0.3.0 and set mixed precision to yes in accelerate config, nothing happens (still full-precision training). If I set Accelerator(fp16=True) in the code, AMP is triggered, but the loss becomes inf right away.

But if I use the PyTorch way (i.e., autocast in the code myself), training is normal and AMP is enabled.

So I wonder if there is a possible bug in accelerate.

My environment is a single 2080 Ti on a local machine.
The code with this problem is here.
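For reference, by "the PyTorch way" I mean the usual autocast + GradScaler loop, roughly like this (a minimal sketch with a stand-in model, not my actual code):

```python
import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast

# Stand-in model/optimizer, just to make the sketch self-contained.
model = nn.Linear(32, 10).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = GradScaler()

for _ in range(3):
    inputs = torch.randn(8, 32, device="cuda")
    targets = torch.randint(0, 10, (8,), device="cuda")
    optimizer.zero_grad()
    with autocast():                   # forward + loss run in mixed precision
        loss = criterion(model(inputs), targets)
    scaler.scale(loss).backward()      # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)             # unscales grads; skips the step on inf/nan
    scaler.update()
```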

@sgugger
Collaborator

sgugger commented Jun 7, 2021

Hi there, how do you know the script is not launched in mixed precision and uses full precision instead? I just tried on my side and it runs properly in mixed precision.

@voldemortX
Author

Well, I know because the training speed and memory usage are the same as in full precision.

@voldemortX
Author

I think maybe this is due to my code structure (training may be too complex?).

@sgugger
Collaborator

sgugger commented Jun 7, 2021

Memory usage won't be very different unless you are using a very large batch size. For the speed, you have to make sure all the dimensions of the tensors you are using are multiples of 8.
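Concretely, the rule of thumb is something like this (a toy check with made-up names, just to illustrate):

```python
def tensor_core_friendly(*dims: int) -> bool:
    """Rule of thumb: fp16 Tensor Cores want the matmul/conv channel sizes
    (and ideally the batch size) to be multiples of 8."""
    return all(d % 8 == 0 for d in dims)

print(tensor_core_friendly(8, 512, 768))  # True
print(tensor_core_friendly(8, 500, 768))  # False: 500 % 8 != 0
```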

@voldemortX
Author

voldemortX commented Jun 7, 2021

I did use a batch size of 8. If I use PyTorch's own autocast, the speed-up is normal (the mixed-precision gain is very obvious on a Turing card).

Also, in my task there is almost a 40% memory reduction.

@sgugger
Collaborator

sgugger commented Jun 7, 2021

It's hard to say what's going wrong without seeing any code.

@voldemortX
Author

voldemortX commented Jun 7, 2021

Sorry, I don't really have a simple reproducible code sample...
It's a rather big project; if you're interested, the PyTorch autocast version (which works) is here:
https://github.com/voldemortX/DST-CBC/blob/master/segmentation/main.py

When I set fp16 in accelerate config and run my code without --mixed-precision, it behaves just like full-precision training.

@voldemortX
Author

voldemortX commented Jun 7, 2021

On second thought, it's probably because I wrapped my code in a with autocast(False) block?
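If that is the case, it would explain why nothing seemed to happen: a nested autocast(enabled=False) region overrides the outer autocast context that Accelerate sets up. A quick standalone check (not my actual code):

```python
import torch
from torch.cuda.amp import autocast

a = torch.randn(8, 8, device="cuda")
b = torch.randn(8, 8, device="cuda")

with autocast():                     # outer context (what Accelerate enables)
    print((a @ b).dtype)             # torch.float16: matmul is autocast to half
    with autocast(enabled=False):    # inner region switches autocast back off
        print((a @ b).dtype)         # torch.float32: full precision again
```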

But the loss did explode when I used only fp16=True.

@sgugger
Collaborator

sgugger commented Jun 7, 2021

You are not letting Accelerate handle mixed precision here; you are doing it yourself in your script: when the is_mixed_precision flag is True, you are also scaling the loss, which means it gets scaled twice.
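In condensed form, the difference looks roughly like this (a sketch with a stand-in model and tensors, not your actual script):

```python
import torch
from torch import nn
from torch.cuda.amp import GradScaler
from accelerate import Accelerator

accelerator = Accelerator(fp16=True)
model = nn.Linear(32, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
model, optimizer = accelerator.prepare(model, optimizer)
scaler = GradScaler()  # the script's own scaler

inputs = torch.randn(8, 32, device=accelerator.device)
targets = torch.randint(0, 10, (8,), device=accelerator.device)
loss = nn.functional.cross_entropy(model(inputs), targets)

# Problematic: the loss gets scaled by the local scaler *and* again inside
# accelerator.backward(), so the gradients end up scaled twice:
#     accelerator.backward(scaler.scale(loss))

# Letting Accelerate own mixed precision means a single, internal scaling:
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
```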

@voldemortX
Author

Thanks for your answers! I'll try removing my own mixed-precision handling and see if the loss explosion issue can be reproduced.

@voldemortX
Author

@sgugger I removed my own mixed precision handling:
https://github.com/voldemortX/DST-CBC/blob/test-fp16/segmentation/main.py

And if I set mixed precision to yes in accelerate config, GPU memory usage is ~8.9 GB, the same as fp32 (training speed is also the same). With mixed precision actually working, memory usage should be ~5.6 GB.

If I set fp16=True directly in the code, the speed and memory usage are normal, but the losses are all inf.

Could you help me locate the problem here? Thanks in advance.

@sgugger
Collaborator

sgugger commented Jun 8, 2021

It's really weird that you get different behavior between the config and setting fp16=True in the code. Could you retry this on master?
I'm also trying to see why your loss is NaN. AFAICT, the same code is executed as with your commented-out lines for native AMP.

@voldemortX
Author

AFAICT, the same code is executed as with your commented-out lines for native AMP.

Yes, it is indeed weird. Could it be that somehow the GradScaler is not activated? I'll try again with master tomorrow.

@voldemortX
Author

@sgugger I installed from master following these instructions:
https://huggingface.co/docs/accelerate/installation.html#editable-install

And the behavior is still the same.
BTW, I think cd transformers should be cd accelerate?

@edornd

edornd commented Aug 1, 2021

Hi!
I'm also experiencing this weird NaN issue on the loss with mixed precision activated. With standard "full precision" GPU training, everything is fine and converges as expected, using a custom model and nn.CrossEntropyLoss:

998/998 [04:22<00:00,  3.80batch/s, loss=0.248] 

Once I turn on FP16 (it doesn't matter whether from the config or from the args), it doesn't break, but the loss stays nan. It also becomes noticeably faster, which I expected given the AMP setting.

... | 88/998 [00:15<02:27,  6.18batch/s, loss=nan]

I'm writing to ask whether you have perhaps found a solution. I also don't have a minimal working example, but I'm working on uploading the code to a repository if required.

If it is of any help, I'm running on a Linux machine with this package configuration:

CUDA 11.1
python==3.8.10
torch==1.9.0+cu111
torchvision==0.10.0+cu111
accelerate==0.3.0

@voldemortX
Author

@edornd My solution for now is switching back to torch amp.

@sgugger
Collaborator

sgugger commented Aug 2, 2021

If you do get a simple reproducer, I'm happy to investigate more. I have just not been able to reproduce this error on my side.

@edornd

edornd commented Aug 2, 2021

Hi @sgugger, thanks for the quick reply! Unfortunately, I haven't had time to build a proper minimal working example yet, but I managed to adapt the CV example to a segmentation task while minimizing changes to the code; here it is.
In my case, this reproduces the problem.

I apologize for the use of a custom dataset and decoder; however, if you check the code, there's nothing particularly weird about them, just standard PyTorch stuff. The dataset is also nothing out of the ordinary, as you can see here.
I tested the "standard" cross-entropy with both reduction="sum" and reduction="mean": in the first case I get inf losses, in the latter nan (it converges as expected without fp16).

Reading around, I suspect this has little to do with accelerate and is instead linked to underflow and the log transformations in the loss (?). I'll try to adapt the same script to manual AMP and see if the same issue arises; otherwise, I'll see what I can do to make it self-contained so that it can be launched without too much configuration trouble.
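For what it's worth, the float16 range alone could already explain the reduction="sum" case (a quick standalone illustration, not tied to my script):

```python
import torch

# float16 maxes out around 65504; anything above that overflows to inf.
print(torch.finfo(torch.float16).max)              # 65504.0
print(torch.tensor(70000.0, dtype=torch.float16))  # tensor(inf, dtype=torch.float16)

# A reduction="sum" cross-entropy over a 512x512 segmentation map adds up
# ~260k per-pixel terms, so a summed loss blows past that limit very easily
# if it is ever represented (or scaled) in float16.
```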

Cheers!

@edornd

edornd commented Aug 4, 2021

Just a quick follow-up: no, still no self-contained min. working example, sorry :)
But I've discovered a few things: first, I effectively solved the problem by manually overriding the loss backward call with AMP, as @voldemortX suggested. I still kept accelerate for the multi-device goodies, but the AMP part was done manually, and interestingly it works (my limited knowledge of the subject keeps me from knowing why). Here's the manual AMP version.

I believe it's working, for two reasons: training became noticeably faster and memory usage dropped by 25-30%, without, of course, any nan loss issues.

However, while debugging and printing things around the training step, the only difference I could notice was that the loss tensor in my code was float16 with accelerator.backward(), while it became float32 with native AMP, which is weird. I honestly expected the same dtype. I may be doing something wrong; nevertheless, it apparently does its job.
I also quickly checked the fp16 part of accelerate, and it looks to me like a simple autocast wrapper around the model's forward, so it should be absolutely identical, at least in theory. In the end, I'm clueless.
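The only concrete thing I can show is that the dtype difference reproduces in isolation: if I read PyTorch's autocast op lists correctly, cross_entropy runs in float32 inside an autocast region but stays float16 when applied outside it, e.g. on half-precision logits coming out of a forward that was the only thing wrapped in autocast. A quick check (standalone, not my code):

```python
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast

logits = torch.randn(4, 10, device="cuda", dtype=torch.float16)
targets = torch.randint(0, 10, (4,), device="cuda")

# Outside autocast: fp16 logits give an fp16 loss.
print(F.cross_entropy(logits, targets).dtype)      # torch.float16

# Inside autocast: cross_entropy is on the "autocast to float32" list,
# so the loss comes back as float32 even with fp16 logits.
with autocast():
    print(F.cross_entropy(logits, targets).dtype)  # torch.float32
```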

@sgugger
Collaborator

sgugger commented Aug 4, 2021

Thanks for the analysis and the example you provided. I'll try to dig more into the differences tomorrow.

@sgugger
Collaborator

sgugger commented Aug 5, 2021

I was able to investigate this more, and I think I found the problem. The PR above should fix the issue; would you mind giving it a try?

@edornd

edornd commented Aug 5, 2021

Sorry for the late response. I managed to give it a try, and I can confirm that the behaviour is now the same as the "manual" AMP version! I still notice some nan issues after a good number of epochs, when the loss is getting smaller, but that's just the risk of AMP I guess (it happens with the manual override as well), so there's nothing to do about that apart from going full precision.

This looks like it's working properly: I noticed a ~50% speed increase, GPU memory went down from 8753MiB / 11019MiB to 5781MiB / 11019MiB, and tqdm prints this lovely string: loss=1.9387, type=torch.float32.

Just out of curiosity/ignorance on my part: in your comments on the PR, I didn't get the bit about "computing the loss in the model". Is that a thing, or is it limited to transformer models? And why is it more stable? If you can point me to an example that does this, I'd be glad to give it a try as well!

Thank you very much once again @sgugger!

@sgugger
Collaborator

sgugger commented Aug 5, 2021

Just out of curiosity/ignorance on my part: in your comments on the PR, I didn't get the bit about "computing the loss in the model". Is that a thing, or is it limited to transformer models?

What I meant in the documentation is that ideally, for Accelerate to work best, the loss should be computed directly by the model (if you look at any Transformers model, it returns the loss along with the logits), but in 90% of cases, what you have, with just a cross-entropy loss applied after the model, should work perfectly. This comment was meant for more complicated loss functions.
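For a model outside Transformers, that pattern would look roughly like this (a hypothetical wrapper, just to illustrate the idea, not something from the library):

```python
from torch import nn

class ModelWithLoss(nn.Module):
    """Hypothetical wrapper: when labels are passed, the loss is computed
    inside forward, so it runs under the same autocast context that
    Accelerate wraps around the model's forward."""

    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, inputs, labels=None):
        logits = self.backbone(inputs)
        if labels is None:
            return logits
        return self.criterion(logits, labels), logits

# Usage sketch: loss, logits = model(inputs, labels); accelerator.backward(loss)
```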

Glad to know it solved the issue!

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
