
Automatically move CUDA non XLA Tensors to XLA Device and back to CUDA device #6644

Merged: 3 commits merged into master from changm/automove on Mar 13, 2024

Conversation

changm (Collaborator) commented Feb 29, 2024

Currently only works for inference. The assumptions don't hold for training with Autograd yet.
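To make the intended inference flow concrete, here is a minimal sketch (not taken from this PR or its tests). It assumes a CUDA-enabled PyTorch/XLA build and the `openxla` torch.compile backend that torch_xla provides; the function, shapes, and names below are illustrative only.

```python
# Minimal sketch of the behavior this PR targets, under the assumptions above:
# the caller passes plain CUDA tensors, the dynamo bridge moves them to the XLA
# device for execution, and the outputs come back on the original CUDA device.
import torch
import torch_xla  # assumed to make the 'openxla' backend available to torch.compile


def fn(x, y):
    return torch.sin(x) + torch.cos(y)


compiled = torch.compile(fn, backend='openxla')

x = torch.randn(8, 8, device='cuda')  # ordinary CUDA tensors, not XLA tensors
y = torch.randn(8, 8, device='cuda')

with torch.no_grad():  # inference only; training with Autograd is not covered yet
    out = compiled(x, y)

print(out.device)  # expected: a CUDA device, i.e. the outputs are moved back automatically
```

The point of the change is that, in this path, neither the inputs nor the outputs need manual device transfers between CUDA and the XLA device.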

@changm changm self-assigned this Feb 29, 2024
@changm changm requested a review from vanbasten23 February 29, 2024 00:24
@vanbasten23 vanbasten23 requested a review from JackCaoG March 5, 2024 21:49
@changm changm requested a review from golechwierowicz March 6, 2024 15:11
@changm changm force-pushed the changm/automove branch from 476edaf to 2f7c0cc Compare March 7, 2024 14:51
@changm changm requested a review from vanbasten23 March 7, 2024 14:52
JackCaoG (Collaborator) commented Mar 7, 2024

Is this an experimental PR, or do you want to merge this?

@changm changm force-pushed the changm/automove branch from f50e471 to 195978d Compare March 7, 2024 20:38
changm (Collaborator, Author) commented Mar 7, 2024

> Is this an experimental PR, or do you want to merge this?

Ideally we would merge this, or is there a reason not to?

@changm changm changed the title from "Automatically move non XLA Tensors to XLA Device and back to original device." to "Automatically move CUDA non XLA Tensors to XLA Device and back to CUDA device" on Mar 11, 2024
torch_xla/core/dynamo_bridge.py
@@ -387,6 +446,12 @@ def optimized_mod(*args):
nonlocal xla_args_need_update
nonlocal skip_checking_input_sharding_threashold

original_device: torch.device = _get_input_arg_device(args)
is_cuda_args: bool = _args_on_cuda(args)
Review comment (Collaborator):

_args_on_cuda will call _get_input_arg_device, which is redundant.

Reply (Collaborator, Author):

I still think it's a little cleaner to keep the redundant call, but I removed it here.
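For context, here is a hypothetical sketch of the two helpers discussed above; the real implementations live in torch_xla/core/dynamo_bridge.py and may differ. It only illustrates the review point: if `_args_on_cuda` is built on top of `_get_input_arg_device`, then calling both in `optimized_mod` scans the input devices twice.

```python
# Hypothetical shapes for the helpers, for illustration only (not the PR's code).
from typing import Optional, Sequence

import torch


def _get_input_arg_device(args: Sequence[object]) -> Optional[torch.device]:
    """Return the device of the first tensor argument, or None if there is none."""
    for arg in args:
        if isinstance(arg, torch.Tensor):
            return arg.device
    return None


def _args_on_cuda(args: Sequence[object]) -> bool:
    """True when the tensor arguments live on a CUDA device."""
    device = _get_input_arg_device(args)  # a second scan over the same args
    return device is not None and device.type == 'cuda'
```

Under that shape, deriving the CUDA check from the already computed `original_device` (for example, `original_device is not None and original_device.type == 'cuda'`) avoids the duplicate lookup, which appears to be the change the reply above refers to.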

torch_xla/core/dynamo_bridge.py: three additional review threads, resolved.
@changm changm requested review from ysiraichi and JackCaoG March 11, 2024 23:02
@changm changm merged commit d13ae1b into master Mar 13, 2024
18 checks passed
@changm changm deleted the changm/automove branch March 13, 2024 16:41
vanbasten23 (Collaborator) commented:

Sorry for being late. It's looking good!

yitongh pushed a commit to AlibabaPAI/xla that referenced this pull request Oct 11, 2024
yitongh pushed a commit to AlibabaPAI/xla that referenced this pull request Dec 11, 2024
yitongh pushed a commit to AlibabaPAI/xla that referenced this pull request Dec 11, 2024
5 participants