Automatically move CUDA non-XLA tensors to the XLA device and back to the CUDA device #6644
Conversation
Is this an experimental PR, or do you want to merge this?
Ideally we would merge this, or is there a reason not to?
torch_xla/core/dynamo_bridge.py
Outdated
```diff
@@ -387,6 +446,12 @@ def optimized_mod(*args):
     nonlocal xla_args_need_update
     nonlocal skip_checking_input_sharding_threashold

+    original_device: torch.device = _get_input_arg_device(args)
+    is_cuda_args: bool = _args_on_cuda(args)
```
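For context, a minimal sketch of what these two helpers might look like and how the redundancy noted below arises; the bodies here are assumptions for illustration, not the PR's verbatim code:

```python
from typing import Optional

import torch


def _get_input_arg_device(input_args) -> Optional[torch.device]:
  # Return the device of the first tensor found in the inputs
  # (None if there are no tensor arguments).
  for arg in input_args:
    if isinstance(arg, torch.Tensor):
      return arg.device
  return None


def _args_on_cuda(input_args) -> bool:
  # True when the tensor inputs live on a CUDA device.
  device = _get_input_arg_device(input_args)
  return device is not None and device.type == "cuda"
```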
_args_on_cuda will call _get_input_arg_device, which is redundant.
I still think it's a little cleaner to do the redundant call, but I removed the call here.
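Illustratively, the change being discussed could look like the following; both variants are sketches, not the exact committed code:

```python
# Before: _args_on_cuda(args) internally calls _get_input_arg_device(args) again,
# so the input device is computed twice.
original_device = _get_input_arg_device(args)
is_cuda_args = _args_on_cuda(args)

# After: derive the flag from the device that was already looked up.
original_device = _get_input_arg_device(args)
is_cuda_args = original_device is not None and original_device.type == "cuda"
```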
Sorry for being late. It's looking good!
Currently this only works for inference; the assumptions don't hold for training with autograd yet.
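A hypothetical end-to-end usage sketch of the behavior this PR adds, assuming a CUDA-enabled PyTorch/XLA build and the openxla dynamo backend; the model and shapes are made up:

```python
import torch
import torch_xla  # registers the "openxla" dynamo backend

model = torch.nn.Linear(16, 4).to("cuda").eval()
compiled = torch.compile(model, backend="openxla")

x = torch.randn(8, 16, device="cuda")
with torch.no_grad():
  y = compiled(x)  # CUDA inputs are moved to the XLA device for execution

print(y.device)  # outputs should come back on the original CUDA device
```

With the dynamo bridge change, the caller does not have to move tensors to the XLA device manually before invoking the compiled function.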