Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for other pytorch device types, including MPS #1445

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

adamobeng
Copy link

Fixes (#1441)

Change list

  1. Add command line arguments --device_type and --device_ids which allow torch backend and device ordinals to be specified
  2. Make code specific to GPUs/cuda device-agnostic (in particular by using a list of torch devices rather than GPU ids)
  3. Maintain support for --gpu_ids argument with some special logic (it would be cleaner but non-backwards compatible to remove it)
  4. Add some tests of the argument parsing

Testing

  • Unit tests pass
  • Results generated with python train.py --dataroot ./datasets/maps --name maps --model pix2pix --direction AtoB --device_type mps seem reasonable.
  • Suggestions on more rigorous testing are welcomed!

NB: On my specific setup, loading a model trained with MPS fails with RuntimeError: don't know how to restore data location of torch.storage._UntypedStorage (tagged with mps:0), but it seems like this is a known and intermittent issue.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant