Add basic device APIs to the top-level torch_xla module. #6571

Conversation
std::vector<std::string> xla_devices;
{
  NoGilSection nogil;
Do you know why we release the GIL in this block? Is it for TPU v2/v3, where we allow multiple threads to do runtime work such as GetXlaDevices(*devices)?
I'm not sure, to be honest. Maybe GetXlaDevices was a blocking call in XRT?
oh lol it is that torch_xla
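Neither guess is confirmed in the thread, but the usual motivation for a NoGilSection is that releasing the GIL around a blocking native call lets other Python threads keep running while the call is in flight. A minimal pure-Python sketch of the effect, with time.sleep standing in for a blocking, GIL-releasing runtime call like GetXlaDevices:

import threading
import time

ticks = 0
done = threading.Event()

def background():
  # Simulates other work that needs the interpreter.
  global ticks
  while not done.is_set():
    ticks += 1
    time.sleep(0.01)

t = threading.Thread(target=background)
t.start()

# Stand-in for a blocking native call made under NoGilSection: because the
# GIL is released while we block, the background thread keeps running. If
# the native call held the GIL instead, `ticks` would stay at 0 until the
# call returned.
time.sleep(0.5)

done.set()
t.join()
assert ticks > 0
print(f"background thread ran {ticks} times during the blocking call")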
def setUpClass():
  xr.set_device_type('CPU')
  os.environ['CPU_NUM_DEVICES'] = '4'
hmm, it is OK for now but shouldn't we also test it on GPU and TPU?
This is sufficient IMO. We're really just testing the integration of this module with the runtime client, which has the same API regardless of the underlying device.
As we switch to using these functions by convention, they'll be exercised by almost every other test.
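For illustration, a hedged sketch of the kind of CPU-only test being discussed. torch_xla.devices() and torch_xla.device_count() are the new top-level APIs from this PR; the class and test names here are made up:

import os
import unittest

import torch_xla
import torch_xla.runtime as xr


class TestTopLevelDeviceApis(unittest.TestCase):

  @classmethod
  def setUpClass(cls):
    # Force the CPU runtime with four virtual devices, as in the diff above.
    # This must run before the runtime client is initialized.
    xr.set_device_type('CPU')
    os.environ['CPU_NUM_DEVICES'] = '4'

  def test_device_count(self):
    self.assertEqual(torch_xla.device_count(), 4)

  def test_devices(self):
    self.assertEqual(len(torch_xla.devices()), 4)


if __name__ == '__main__':
  unittest.main()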
@@ -182,3 +182,5 @@ def _init_xla_lazy_backend():
if os.getenv('XLA_REGISTER_INSTALLED_PLUGINS') == '1':
  plugins.use_dynamic_plugins()
  plugins.register_installed_plugins()

from .torch_xla import *
why is this needed?
This imports the contents of torch_xla.py into torch_xla/'s module scope. Otherwise, the functions would be torch_xla.torch_xla.etc; this assigns them to torch_xla.etc instead.
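A runnable illustration of that re-export pattern, using a throwaway package built on disk (the mypkg names are hypothetical, standing in for torch_xla):

import os
import sys
import tempfile

# Build a tiny package mirroring torch_xla's layout:
#   mypkg/__init__.py does `from .mypkg import *`
#   mypkg/mypkg.py defines the public functions
root = tempfile.mkdtemp()
pkg_dir = os.path.join(root, 'mypkg')
os.makedirs(pkg_dir)
with open(os.path.join(pkg_dir, 'mypkg.py'), 'w') as f:
  f.write("def device():\n  return 'xla:0'\n")
with open(os.path.join(pkg_dir, '__init__.py'), 'w') as f:
  f.write('from .mypkg import *\n')

sys.path.insert(0, root)
import mypkg

print(mypkg.device())        # re-exported at the package level
print(mypkg.mypkg.device())  # the nested spelling still exists too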
Implemented basic device APIs from #6399:

- Add device, devices, real_devices, and device_count as described in the RFC.
- Add torch_xla.py for public functions on the torch_xla module, imported into __init__.py.

Follow up:

- Replace usages of xm.xla_device and runtime.xla_device.
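A hedged usage sketch of the four new top-level APIs; exact device names and counts depend on the configured runtime (the values in the comments are how they might look on a single CPU device):

import torch
import torch_xla

print(torch_xla.device())        # default XLA device, e.g. xla:0
print(torch_xla.devices())       # all addressable XLA devices
print(torch_xla.real_devices())  # underlying runtime devices, e.g. ['CPU:0']
print(torch_xla.device_count())  # number of addressable devices

# The returned device can be used like any other torch.device:
t = torch.zeros(2, 2, device=torch_xla.device())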