
Add basic device APIs to the top-level torch_xla module. #6571

Merged: 6 commits merged into master on Feb 21, 2024

Conversation

@will-cromar (Collaborator) commented Feb 20, 2024

Implements the basic device APIs from #6399.

  • Add device, devices, real_devices, and device_count as described in the RFC (see the usage sketch below).
  • We already do a substantial amount of setup in __init__.py. Add torch_xla.py to hold public functions on the torch_xla module.
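
A minimal usage sketch of the new top-level APIs, assuming the signatures described in RFC #6399 (the example tensor and printed values are illustrative, not part of this PR):

import torch
import torch_xla

dev = torch_xla.device()           # torch.device for this process's default XLA device
devs = torch_xla.devices()         # all addressable XLA devices visible to this process
hw = torch_xla.real_devices()      # underlying runtime devices, e.g. ['TPU:0', 'TPU:1', ...]
n = torch_xla.device_count()       # number of addressable devices, i.e. len(devs)

t = torch.zeros(3, 3, device=dev)  # tensors can be placed directly on the returned device
print(dev, n, hw)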

Follow up:

  • Update documentation to use new APIs.
  • Start deprecating or discouraging usage of old APIs like xm.xla_device and runtime.xla_device.

@will-cromar added the runtime and usability labels on Feb 20, 2024
@will-cromar requested a review from JackCaoG on February 20, 2024 at 22:05
@will-cromar marked this pull request as ready for review on February 20, 2024 at 22:05

std::vector<std::string> xla_devices;
{
  NoGilSection nogil;
Collaborator:

Do you know why we release the GIL in this block? Is it for TPU v2/v3, where we allow multiple threads to do runtime work such as GetXlaDevices(*devices)?

@will-cromar (Collaborator, Author) commented Feb 21, 2024

I'm not sure to be honest. Maybe GetXlaDevices was a blocking call in XRT?

Collaborator:

oh lol it is that torch_xla

Comment on lines +11 to +13:

def setUpClass():
  xr.set_device_type('CPU')
  os.environ['CPU_NUM_DEVICES'] = '4'
Collaborator:

hmm, it is OK for now but shouldn't we also test it on GPU and TPU?

@will-cromar (Collaborator, Author):

This is sufficient IMO. We're really just testing the integration of this module with the runtime client, which has the same API regardless of the underlying device.

As we switch to using these functions by convention, they'll be exercised by almost every other test.
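
For reference, a test in roughly this shape (the class name and assertions are illustrative, not necessarily the PR's actual test file) exercises the new APIs against the CPU runtime configured above:

import os
import unittest

import torch_xla
import torch_xla.runtime as xr


class TestDevices(unittest.TestCase):

  @classmethod
  def setUpClass(cls):
    # Use the CPU runtime with four virtual devices, as in the snippet above.
    xr.set_device_type('CPU')
    os.environ['CPU_NUM_DEVICES'] = '4'

  def test_device_count(self):
    # All four virtual CPU devices should be visible through the top-level APIs.
    self.assertEqual(torch_xla.device_count(), 4)
    self.assertEqual(len(torch_xla.devices()), 4)


if __name__ == '__main__':
  unittest.main()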

@@ -182,3 +182,5 @@ def _init_xla_lazy_backend():
 if os.getenv('XLA_REGISTER_INSTALLED_PLUGINS') == '1':
   plugins.use_dynamic_plugins()
   plugins.register_installed_plugins()
+
+from .torch_xla import *
Collaborator:

why is this needed?

@will-cromar (Collaborator, Author):

This imports the contents of torch_xla.py into the torch_xla package's top-level scope. Otherwise the new functions would only be reachable as torch_xla.torch_xla.<name>; the star import re-exports them as torch_xla.<name>.
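
In other words (an illustrative before/after, not code from the PR):

# Without the star import, the new functions live only on the submodule:
import torch_xla.torch_xla
torch_xla.torch_xla.device_count()

# With `from .torch_xla import *` in __init__.py, they are available at the top level:
import torch_xla
torch_xla.device_count()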

@will-cromar merged commit 0ec5b91 into master on Feb 21, 2024
17 of 18 checks passed
amithrm pushed a commit to amithrm/xla that referenced this pull request Mar 1, 2024