[Core ATen Opset] Lower aten_randperm #5994
Comments
This should be a good first issue for Mason, as it requires lowering aten_randperm.
Great, thanks! I can take a look at this.
Thanks! Also, I added you as a collaborator to PyTorch/XLA, so you should be able to create a pull request. Let me know if you have any questions while working on this, thanks!
Thanks! I had a couple of naive questions; thanks for your patience.
The TensorFlow version of randperm is similar to the PyTorch CPU implementation. My current thinking would be a fill of a [0, N) array plus a TF-style shuffle.
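For what it's worth, here is a minimal sketch of that idea in plain Python (illustrative only, not the actual PyTorch/XLA lowering; the function name randperm_sketch is made up):

```python
import random

def randperm_sketch(n: int) -> list[int]:
    # Step 1: fill an array with the values 0..n-1.
    out = list(range(n))
    # Step 2: shuffle it in place (Fisher-Yates), similar in spirit
    # to a TF-style shuffle of the filled array.
    for i in range(n - 1, 0, -1):
        j = random.randint(0, i)
        out[i], out[j] = out[j], out[i]
    return out

print(randperm_sketch(8))  # e.g. [3, 6, 0, 5, 1, 7, 2, 4]
```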
Thanks! Also, please let me know if I'm looking at the wrong places in the code :).
Great questions! Here are my thoughts: ...
Do you mind taking a look at this work-in-progress PR here? I'm not quite sure what I'm missing. I'm following some example PRs like this one. It isn't supposed to fully work yet, but I was hoping to at least get some printfs. Whenever I run it, the log I'm getting is:
I feel like I'm missing something basic; do you see anything obvious? Thanks!
Hmm, I don't see anything obvious, but my guess would be that something is wrong with the dispatch (since we don't even see your prints, the op may not be getting properly dispatched to PyTorch/XLA) or with the unit test itself. I'd recommend writing a simpler, smaller unit test for this, just in plain Python code.
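For example, a minimal standalone check could look roughly like the sketch below (assuming a working PyTorch/XLA install; the exact assertions are just placeholders):

```python
import torch
import torch_xla.core.xla_model as xm

def check_randperm():
    device = xm.xla_device()
    t = torch.randperm(4, device=device)
    print("device:", t.device)             # expect an XLA device, e.g. xla:0
    values = sorted(t.cpu().tolist())
    assert values == [0, 1, 2, 3], values   # result should be a permutation of 0..3

if __name__ == "__main__":
    check_randperm()
```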
Ah, looking at the documentation for torch.randperm (https://pytorch.org/docs/stable/generated/torch.randperm.html), it looks like it accepts one non-optional parameter, n.
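For instance (plain CPU PyTorch, just to show the required argument):

```python
import torch

# n is the only required argument; generator, dtype, device, etc. are optional.
perm = torch.randperm(5)
print(perm)        # e.g. tensor([2, 4, 0, 3, 1])
print(perm.shape)  # torch.Size([5])
```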
Now looking back to ...
Thanks! You're right that it needed the n parameter.
It works! I see my printfs! However, if I copy/paste this code into a ...
Thanks again for your help!
Thanks for your help, I figured out (1). My environment was wonky, and PyTorch has to be built with ... Re (2): this was also just a vestige of wonkiness from my environment not reflecting code updates. I had a C++ code fallback that was happening. Deleting that made it all crash as expected. Thanks!
Nice! For (2), it seems like your PR just calls the CPU fallback for randperm: https://github.com/pytorch/xla/compare/master...changm:pytorch-xla:randperm?expand=1#diff-5e65c3c1d847191cb691d1874732e971f09fa1aad7a980a555c3b0504a5b6470R2484. Hence it's actually falling back to CPU to generate the 4 random numbers.
Ahh, yeah, you're right, thanks again! I'm making some decent progress; I think I have a basic implementation, but it's not correct. A couple of questions:
It seems like the HLO is being cached here somewhere, or am I missing something? I'm also noticing that ...
Thanks!
where ...
Thanks! Jack's answer should answer questions (1) and (3). For question (2), that behavior is not expected. I recommend looking at the IR/HLO to see what's happening. It might be possible that we're not passing the int ...
Had a chat with Wonjoo and I'm getting unexpected results. See my PR here. I'm using this test file:
Running via ...
Line #8 prints:
Line #9 prints:
I'm pretty sure I'm producing the wrong HLO, which is fine; I'm just wondering if there's a better way to dump the HLO, and why there's a difference between the HLO computation in xla_graph_executor.cpp versus the Python level. My hunch is ...
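For reference, one way to dump the IR/HLO at the Python level is via the internal torch_xla._XLAC helpers (a sketch; these are private APIs, so exact names may vary across versions):

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
t = torch.randperm(4, device=device)

# Dump the pending lazy-tensor IR and the corresponding HLO text
# *before* anything forces the tensor to be materialized.
print(torch_xla._XLAC._get_xla_tensors_text([t]))
print(torch_xla._XLAC._get_xla_tensors_hlo([t]))
```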
... prints the pre-optimized HLO.
... means it is just a device_data from the compiler's perspective. This usually means that there is a fallback to CPU and the op actually got executed on CPU. You can dump the metrics report to confirm this theory: https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#get-a-metrics-report. I expect you to see some aten:: counters in it.
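As a concrete example of that check, here is a sketch using the torch_xla.debug.metrics module referenced by TROUBLESHOOTING.md:

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
t = torch.randperm(4, device=device)
_ = t.cpu()  # force execution so counters get populated

# Counters prefixed with "aten::" indicate ops that fell back to CPU.
fallback_counters = [c for c in met.counter_names() if c.startswith("aten::")]
print("CPU fallbacks:", fallback_counters)
print(met.metrics_report())
```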
Thanks for the quick reply. I added a ...
I don't see any aten:: counters, though.
Actually, this is true since this is a CPU VM, just to get started. Maybe that's why? Does XLA:CPU not actually work yet?
A CPU VM shouldn't matter; this is about an XLA device vs. a non-XLA device... This is a bit weird then, I do see ...
OK, I see, you need to remove ... if you want to inspect the IR; otherwise it will materialize the output.
Ahh, I just went back to master and tried things at head. Interestingly, I think I just got very unlucky for some reason :). Doing this breaks and crashes:
However, creating ...
Thanks, Jack, for pitching in, very helpful as always. @changm, let us know if you need further help debugging.
Thanks! I have a PR that I think is ready to merge, but I'm still confused. All the tests pass and my printfs work; however, I still can't get the HLO text. When doing:
I get:
Questions:
Another theory I have is that since N is constant, we can precompute everything, and XLA doesn't actually compile or run anything since it's all optimized away. We saw some conditions like that in TensorFlow.
Gah, sorry for the noise; I found a bug and was able to print all the HLO as expected. Thank you!
Nice! Just curious, what was the bug?
The bug was incorrectly checking an ...
Usually, if there is a crash, then the entire program should crash and exit. It should fall back only if there is an explicit call to the fallback in case of a crash.
I was able to reproduce the build issue locally, but I'm actually very confused. I think this goes back to "materializing" an output. Reproducing via:
The test harness here calls ...
Does this both run the lowered HLO and materialize the tensor?
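To illustrate the lazy-execution/materialization point being asked about (a sketch; the exact call the harness makes is not shown above):

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
t = torch.randperm(4, device=device)   # only records IR; nothing executes yet

# While the computation is still pending, the IR can be inspected.
print(torch_xla._XLAC._get_xla_tensors_text([t]))

# Accessing the values (e.g. .cpu(), .item(), or printing the tensor)
# compiles the lowered HLO, executes it, and materializes the result on the host.
print(t.cpu())
```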
Thanks!
Closing as it is fixed with #6482. |
In order for PyTorch/XLA to support the PyTorch core ATen opset, each core ATen op needs to be lowered in PyTorch/XLA. This issue is used to track the PyTorch/XLA lowering of aten_randperm.
Here are some general guidelines for lowering this op:
- Remove the @unittest.skip or @unittest.expectFailure decorator and run the unit test at test_core_aten_ops.py. E.g.: pytest test/test_core_aten_ops.py -k test_aten_randperm_0
For any questions, feel free to leave a comment in this PR.