-
Notifications
You must be signed in to change notification settings - Fork 429
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix MPS with sequence loss #2834
Conversation
Upon further inspection, |
31d92e8
to
b945c5a
Compare
b945c5a
to
1dd8bc0
Compare
@mvpatel2000 this is a relatively small change but is fairly annoying for running stuff on MPS, could it be merged in please? Thank you! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the fix! I left one small comment, and happy to approve if that makes sense to you!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
@JAEarly apologies for the delay in review! Next time please feel free to tag me as a reviewer.... I will see if we can setup auto-tagging as well. We didn't see this earlier :( |
* Add MPS support for list outputs in training eval loop * Add error for invalid state outputs type in trainer * Remove raise ValueError in trainer eval loop --------- Co-authored-by: Daniel King <[email protected]>
* Add MPS support for list outputs in training eval loop * Add error for invalid state outputs type in trainer * Remove raise ValueError in trainer eval loop --------- Co-authored-by: Daniel King <[email protected]>
What does this PR do?
TLDR
Fixes the trainer eval loop for MPS devices when the model output is a list of tensors rather than a
torch.Tensor
orMapping
.Related
Similar issue but with dict: #2632
PR for above issue: #2706
Details
In the trainer evaluation loop, model outputs (
self.state.outputs
) are moved to the CPU if they are on an MPS device (to avoid torchvision numerical errors on MPS devices). Currently this works fine ifself.state.outputs
is a Mapping or a torch.Tensor. However, it fails for Sequence types (e.g. a list of torch.Tensors).According to the State class,
outputs
is expected to be of typetorch.Tensor | Sequence[torch.Tensor]
. So Sequence types should be supported in this process in the eval loop. Strictly,outputs
should not be aMapping
, despite the eval operation supporting this. I have left theMapping
support in place for now but happy to revisit.As it was a little difficult to debug this issue, I have added an error message which should make it clearer if an invalid output type is used (such that it cannot correctly be mapped from MPS to CPU).
Before submitting
pre-commit
on your change? (see thepre-commit
section of prerequisites)