Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix MPS with sequence loss #2834

Merged
merged 5 commits into from
Jan 22, 2024
Merged

Fix MPS with sequence loss #2834

merged 5 commits into from
Jan 22, 2024

Conversation

JAEarly
Copy link
Contributor

@JAEarly JAEarly commented Jan 10, 2024

What does this PR do?

TLDR
Fixes the trainer eval loop for MPS devices when the model output is a list of tensors rather than a torch.Tensor or Mapping.

Related
Similar issue but with dict: #2632
PR for above issue: #2706

Details
In the trainer evaluation loop, model outputs (self.state.outputs) are moved to the CPU if they are on an MPS device (to avoid torchvision numerical errors on MPS devices). Currently this works fine if self.state.outputs is a Mapping or a torch.Tensor. However, it fails for Sequence types (e.g. a list of torch.Tensors).

According to the State class, outputs is expected to be of type torch.Tensor | Sequence[torch.Tensor]. So Sequence types should be supported in this process in the eval loop. Strictly, outputs should not be a Mapping, despite the eval operation supporting this. I have left the Mapping support in place for now but happy to revisit.

As it was a little difficult to debug this issue, I have added an error message which should make it clearer if an invalid output type is used (such that it cannot correctly be mapped from MPS to CPU).

Before submitting

  • Have you read the contributor guidelines?
  • Is this change a documentation change or typo fix? If so, skip the rest of this checklist.
  • Was this change discussed/approved in a GitHub issue first? It is much more likely to be merged if so.
  • Did you update any related docs and document your change?
  • Did you update any related tests and add any new tests related to your change? (see testing)
  • Did you run the tests locally to make sure they pass?
  • Did you run pre-commit on your change? (see the pre-commit section of prerequisites)

@JAEarly
Copy link
Contributor Author

JAEarly commented Jan 10, 2024

Upon further inspection, eval_forward can strictly return Any type, so keeping support for Mapping makes sense. Adding support for self.state.options == None in this PR may also be useful.

@JAEarly JAEarly force-pushed the eval_loop_mps branch 2 times, most recently from 31d92e8 to b945c5a Compare January 18, 2024 10:18
@JAEarly JAEarly marked this pull request as draft January 19, 2024 11:30
@JAEarly JAEarly marked this pull request as ready for review January 19, 2024 11:30
@JAEarly
Copy link
Contributor Author

JAEarly commented Jan 19, 2024

@mvpatel2000 this is a relatively small change but is fairly annoying for running stuff on MPS, could it be merged in please? Thank you!

Copy link
Contributor

@dakinggg dakinggg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix! I left one small comment, and happy to approve if that makes sense to you!

composer/trainer/trainer.py Outdated Show resolved Hide resolved
Copy link
Contributor

@dakinggg dakinggg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@mvpatel2000
Copy link
Contributor

@JAEarly apologies for the delay in review! Next time please feel free to tag me as a reviewer.... I will see if we can setup auto-tagging as well. We didn't see this earlier :(

@mvpatel2000 mvpatel2000 merged commit 1df5557 into mosaicml:dev Jan 22, 2024
17 checks passed
ShashankMosaicML pushed a commit to ShashankMosaicML/composer that referenced this pull request Feb 3, 2024
* Add MPS support for list outputs in training eval loop

* Add error for invalid state outputs type in trainer

* Remove raise ValueError in trainer eval loop

---------

Co-authored-by: Daniel King <[email protected]>
ShashankMosaicML pushed a commit to ShashankMosaicML/composer that referenced this pull request Feb 3, 2024
* Add MPS support for list outputs in training eval loop

* Add error for invalid state outputs type in trainer

* Remove raise ValueError in trainer eval loop

---------

Co-authored-by: Daniel King <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants