
Wrap constant arg in XLAExportInterpreter #6460

Merged: 4 commits merged into master from lsiyuan/interpreter-wrap-constant on Feb 6, 2024

Conversation

@lsy323 (Collaborator) commented on Feb 2, 2024

If the op's schema expects a Tensor input, we wrap the Python primitive value in a torch.tensor. The tensor's dtype respects torch.get_default_dtype().
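
For reference, a minimal standalone illustration of that default-dtype behavior (plain PyTorch, not code from this PR):

```python
import torch

# torch.tensor() on a Python scalar follows the global default dtype.
torch.set_default_dtype(torch.float64)
print(torch.tensor(1.0).dtype)  # torch.float64

torch.set_default_dtype(torch.float32)
print(torch.tensor(1.0).dtype)  # torch.float32
```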

Without this wrapping, the Python float is wrapped into a tensor somewhere before it enters the dispatcher, and that implicit wrapping does not respect the global default dtype.

With this change, aten.sub(1.0, x) and aten.sub(torch.tensor(1.0), x) become identical by the time they enter LTC.
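
As a rough end-to-end sketch of where this matters (hedged: whether 1.0 - x actually lowers to aten.sub with a bare Python float depends on the decompositions in use), one could inspect the exported StableHLO:

```python
import torch
from torch.export import export

from torch_xla.stablehlo import exported_program_to_stablehlo


class SubFromConst(torch.nn.Module):

  def forward(self, x):
    # Depending on the decompositions in use, 1.0 - x can lower to
    # aten.sub with a bare Python float argument, the case wrapped here.
    return 1.0 - x


exported = export(SubFromConst(), (torch.ones(4),))
shlo = exported_program_to_stablehlo(exported)
# With the wrapping in place, the constant's dtype in the StableHLO text
# should follow torch.get_default_dtype().
print(shlo.get_stablehlo_text())
```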

It is still unknown where exactly upstream wraps the scalar constant into a tensor; based on investigation and syncing with the PyTorch team, it happens before the op enters the dispatcher.
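
A minimal sketch of how an FX interpreter can do this wrapping. It assumes the op's schema is reachable via target._schema and checks parameter types with torch.TensorType; this is illustrative, not necessarily the exact code merged here:

```python
import torch
import torch.fx


class XLAExportInterpreter(torch.fx.Interpreter):

  def call_function(self, target, args, kwargs):
    # For ATen ops, consult the schema: if a parameter is declared as
    # Tensor but the FX node carries a bare Python scalar, wrap it in
    # torch.tensor() so its dtype follows torch.get_default_dtype().
    if isinstance(target, torch._ops.OpOverload):
      new_args = []
      for arg, spec in zip(args, target._schema.arguments):
        if isinstance(arg, (int, float)) and isinstance(
            spec.type, torch.TensorType):
          arg = torch.tensor(arg)
        new_args.append(arg)
      args = tuple(new_args)
    return super().call_function(target, args, kwargs)
```

Overriding call_function leaves the rest of graph execution to torch.fx.Interpreter; only scalar arguments in Tensor-typed positions are rewritten before dispatch.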

lsy323 requested review from qihqi and JackCaoG on Feb 2, 2024 22:40
Review comment on torch_xla/stablehlo.py (outdated; resolved)
lsy323 self-assigned this on Feb 2, 2024
lsy323 marked this pull request as ready for review on Feb 2, 2024 23:12
lsy323 requested a review from JackCaoG on Feb 3, 2024 00:01
lsy323 merged commit 5dbbb28 into master on Feb 6, 2024
18 checks passed
amithrm pushed a commit to amithrm/xla that referenced this pull request on Mar 1, 2024
lsy323 deleted the lsiyuan/interpreter-wrap-constant branch on Mar 4, 2024