[Core ATen Opset] Lower aten_randperm #5994

Closed
wonjoolee95 opened this issue Dec 2, 2023 · 28 comments
@wonjoolee95 (Collaborator) commented Dec 2, 2023

For PyTorch/XLA to support the PyTorch core ATen opset, each core ATen op needs to be lowered in PyTorch/XLA. This issue tracks the PyTorch/XLA lowering of aten_randperm.

Here are some general guidelines for lowering this op:

  • Remove the @unittest.skip or @unittest.expectedFailure decorator and run the unit test at test_core_aten_ops.py, e.g.: pytest test/test_core_aten_ops.py -k test_aten_randperm_0 (see the sketch below).
  • Make code changes until the test passes. Read and follow fix_lowering_for_core_aten_ops.md for ideas on how to fix it.
    • There may be multiple unit tests for a single op. For this op, the corresponding unit tests are:
      • test_aten_randperm_0
    • Please also remove the skips for all of these tests and ensure they all pass.
    • Note that sometimes the fix may be to fix the unit test itself. Please take a look at the corresponding unit tests to make sure they are valid.
  • Submit the PR!

For any questions, feel free to leave a comment on this issue.
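
For illustration, removing the skip usually looks roughly like the following in test/test_core_aten_ops.py (a sketch only; the exact test body and argument values in the real file may differ):

# Approximate structure of a test in test/test_core_aten_ops.py.
# The skip decorator is deleted so the test actually runs.

# @unittest.skip   # <-- delete this decorator to enable the test
def test_aten_randperm_0(self):
    args = (10,)          # n; the value used in the real test may differ
    kwargs = dict()
    run_export_and_compare(self, torch.ops.aten.randperm, args, kwargs)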

@wonjoolee95 (Collaborator, Author):

This should be a good first issue for Mason, as it requires lowering randperm entirely from scratch and it is mostly a simpler op. Mason is not a collaborator on the PyTorch/XLA repo yet, so I'll assign this issue as soon as I get Mason's GitHub alias.

@changm (Collaborator) commented Jan 11, 2024

Great, thanks! I can take a look at this!

@wonjoolee95 (Collaborator, Author):

Thanks! I've also added you as a collaborator on PyTorch/XLA, so you should be able to create a pull request. Let me know if you have any questions while working on this, thanks!

@changm (Collaborator) commented Jan 12, 2024

Thanks! I had a couple of naive questions, thanks for your patience.

  1. IIUC, the CPU implementation of torch.randperm does a Fisher-Yates shuffle here. The GPU CUDA implementation does something else here.

     The TensorFlow version of randperm is similar to the PyTorch CPU implementation, using tf.random.shuffle here. Their TF -> XLA lowering is here. Any thoughts / preferences on whether the implementation should just copy the CPU version for this issue versus the PyTorch CUDA approach?

     My current thinking would be a fill of a [0, N) array + a TF-style shuffle (see the sketch at the end of this comment).

  2. The PyTorch docs for randperm list various input parameters, but the unit test only has a single test with the single input param n. Is that OK?

  3. The test infra expects equal outputs. Since this op by definition produces random output, I presume I also have to update the test harness to handle randomness (e.g. assert unique + sort + check equality)? Or is there a preferred way to test random output that I'm missing?

Thanks! Also please let me know if I'm looking in the wrong places in the code :).
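
For illustration, here is a minimal host-side Python sketch of what I mean by "fill [0, N) + shuffle" (conceptual only; the actual lowering would express this in HLO/XLA ops):

import random

def randperm_sketch(n: int) -> list:
    perm = list(range(n))          # the [0, N) "iota" fill
    for i in range(n - 1, 0, -1):  # Fisher-Yates: swap element i with a
        j = random.randint(0, i)   # uniformly chosen index j <= i
        perm[i], perm[j] = perm[j], perm[i]
    return perm

print(randperm_sketch(5))  # e.g. [3, 0, 4, 1, 2]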

@wonjoolee95 (Collaborator, Author):

Great questions! Here are my thoughts:

  1. The proposed fill of a [0, N) array + TF-style shuffle sounds good. Any of the different implementations should be fine, as long as we support all the inputs to torch.randperm properly.
  2. Good point. The purpose of this test (https://github.com/pytorch/xla/blob/master/test/test_core_aten_ops.py) is more of a quick smoke test to make sure all core ATen ops are supported by PyTorch/XLA. For more detailed testing (with various input parameters), we can use our C++ op unit tests, located here: https://github.com/pytorch/xla/blob/master/test/cpp/test_aten_xla_tensor_1.cpp. Usually we want to create these C++ unit tests when we lower a new op, so adding some new C++ tests with your PR would be great.
  3. Asserting unique + sort + check equality sounds good to me; a quick sketch follows below.
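
For example, a minimal sketch of that check (check_randperm_output is a hypothetical helper, not part of the existing test harness): a valid randperm(n) result, once sorted, is exactly 0..n-1.

import torch

def check_randperm_output(out, n):
    # Sorting a valid permutation of [0, n) must yield exactly arange(n).
    sorted_out, _ = torch.sort(out.cpu())
    return torch.equal(sorted_out, torch.arange(n, dtype=sorted_out.dtype))

assert check_randperm_output(torch.randperm(10), 10)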

@changm (Collaborator) commented Jan 17, 2024

Do you mind taking a look at this work-in-progress PR here? I'm not quite sure what I'm missing. I'm following some example PRs like the one here. It isn't supposed to fully work yet, but I was hoping to at least get some printfs.

Whenever I run python test/test_core_aten_ops.py -k test_aten_randperm_0, I'm still not getting any printfs anywhere.

The log I'm getting is:

python test/test_core_aten_ops.py -k test_aten_randperm_0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1705530486.251095 3391106 cpu_client.cc:370] TfrtCpuClient created.
[2024-01-17 22:28:06,255] torch._dynamo.eval_frame: [WARNING] could not determine __code__ for aten.randperm
[2024-01-17 22:28:06,321] torch._dynamo.eval_frame: [WARNING] could not determine __code__ for aten.randperm

======================================================================
FAIL: test_aten_randperm_0 (__main__.AtenOpTest) [torch_xla_diff:0.001]
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/test_core_aten_ops.py", line 62, in run_export_and_compare
    diff_output(
  File "test/test_core_aten_ops.py", line 33, in diff_output
    testcase.assertTrue(
AssertionError: False is not true

======================================================================
FAIL: test_aten_randperm_0 (__main__.AtenOpTest) [can_export]
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/test_core_aten_ops.py", line 65, in run_export_and_compare
    exported = torch.export.export(func, args, kwargs)
  File "/usr/local/google/home/masonchang/anaconda3/envs/py38/lib/python3.8/site-packages/torch/export/__init__.py", line 191, in export
    return _export(
  File "/usr/local/google/home/masonchang/anaconda3/envs/py38/lib/python3.8/site-packages/torch/export/exported_program.py", line 78, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/google/home/masonchang/anaconda3/envs/py38/lib/python3.8/site-packages/torch/export/_trace.py", line 719, in _export
    ep_non_strict = _export_non_strict(
  File "/usr/local/google/home/masonchang/anaconda3/envs/py38/lib/python3.8/site-packages/torch/export/_trace.py", line 455, in _export_non_strict
    tensor_constants = lift_constant_tensor_pass(gm, export_graph_signature)
  File "/usr/local/google/home/masonchang/anaconda3/envs/py38/lib/python3.8/site-packages/torch/_export/passes/lift_constant_tensor_pass.py", line 24, in lift_constant_tensor_pass
    assert fake_mode is not None
AssertionError

----------------------------------------------------------------------
Ran 1 test in 0.099s

FAILED (failures=2)
I0000 00:00:1705530486.596958 3391106 cpu_client.cc:373] TfrtCpuClient destroyed.

I feel like I'm missing something basic, do you see anything obvious? Thanks!

@wonjoolee95 (Collaborator, Author):

Hmm, I don't see anything obvious but my guess would be something is wrong with the dispatch (since we don't even see your prints, so the op may not be getting properly dispatched to PyTorch/XLA) or the unit test itself. I'd recommend writing a simpler and smaller unit test for testing this, just in plain Python code.

@wonjoolee95 (Collaborator, Author):

Ah, looking at the documentation for torch.randperm (https://pytorch.org/docs/stable/generated/torch.randperm.html), it looks like it accepts one non-optional parameter n. Usually, an op that takes a tensor as an input knows it is an XLA tensor and dispatches to PyTorch/XLA. For example, let's look at a very simple op, torch.add:

>>> import torch
>>> a = torch.randn(4)
>>> torch.add(a, 20) # this dispatches to PyTorch 

>>> import torch_xla.core.xla_model as xm
>>> xla_device = xm.xla_device()
>>> a_xla = a.to(xla_device)
>>> torch.add(a_xla, 20) # this dispatches to PyTorch/XLA, since the parameter is an XLA tensor

Now looking back at torch.randperm: with only its non-optional int parameter n, the dispatcher doesn't know which device to dispatch to. But it looks like it also accepts a device parameter, so maybe we need to pass xla_device to this optional parameter?
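
For example, something along these lines (a quick sketch, untested here) should route the op to the XLA backend via the device argument:

>>> import torch
>>> import torch_xla.core.xla_model as xm
>>> torch.randperm(5, device=xm.xla_device()) # this should dispatch to PyTorch/XLA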

@changm (Collaborator) commented Jan 18, 2024

Thanks! You're right that it needed the xla_device to see some printfs! Couple of questions though:

  1. If I build PyTorch/XLA and use an interactive shell, and do:
import torch
import torch_xla
import torch_xla.core.xla_model as xm

a = torch.randn(4)
print(torch.randperm(5, device=xm.xla_device()))

It works! I see my printfs!! However, if I copy/paste this code into a test.py and then do python test.py, I don't see the printfs anymore. Any ideas?

  2. Given the PR, this code shouldn't have worked at all, yet it's printing out 4 random numbers. Is there some fallback if XLA compilation fails, without explicit calls to the CPU fallback / something else?

Thanks again for your help!

@changm (Collaborator) commented Jan 18, 2024

Thanks for your help, I figured out (1). My environment was wonky and PyTorch has to be built with python setup.py develop. I updated the CONTRIBUTION doc to reflect that.

Re (2): this was also just a vestige of my environment not reflecting code updates. I had a C++ fallback that was being hit; deleting it made everything crash as expected. Thanks!

@wonjoolee95 (Collaborator, Author):

Nice! For 2, it seems like your PR just calls the CPU fallback for randperm -- https://github.com/pytorch/xla/compare/master...changm:pytorch-xla:randperm?expand=1#diff-5e65c3c1d847191cb691d1874732e971f09fa1aad7a980a555c3b0504a5b6470R2484. Hence it's actually falling back to CPU to generate the 4 random numbers.

@changm (Collaborator) commented Jan 19, 2024

Ahh yeah, you're right, thanks again! I'm making some decent progress; I think I have a basic implementation but it's not correct. A couple of questions:

  1. Are there some flags / somewhere to dump the generated HLO module for a specific program?
  2. Since this is a random operation, should this op ever be cached? If I have this program:
# Prints 5 numbers
print(torch.randperm(5, device=xm.xla_device()))

# Still prints 5 numbers even though the input is 10
print(torch.randperm(10, device=xm.xla_device()))

It seems like the HLO is being cached here somewhere or am I missing something? I'm also noticing that RandPerm::Lower() isn't called the second time.

  3. Is there a way to dump the generated dynamo trace and resulting XLA IR?

Thanks!

@JackCaoG (Collaborator):

print(torch_xla._XLAC._get_xla_tensors_text([res]))
print(torch_xla._XLAC._get_xla_tensors_hlo([res]))

where res is the result tensor. It will use res as the root and return the IR and HLO.

@wonjoolee95 (Collaborator, Author):

Thanks! Jack's answer should answer questions 1) and 3).

For question 2), that behavior is not expected. I recommend looking at the IR/HLO to see what's happening. It might be that we're not passing the int n variable correctly through the layers, but it shouldn't be cached.

@changm (Collaborator) commented Jan 23, 2024

Had a chat with Wonjoo and I'm getting unexpected results. See my PR here. I'm using this test file:

  1 import torch
  2 import torch_xla
  3 import torch_xla.core.xla_model as xm
  4
  5 xla_a = torch.randperm(3, device=xm.xla_device())
  6 print(xla_a)
  7
  8 print(torch_xla._XLAC._get_xla_tensors_text([xla_a]))
  9 print(torch_xla._XLAC._get_xla_tensors_hlo([xla_a]))

Running via python test.py. The changes to xla_graph_executor.cpp dump the following HLO, which is expected as it's the HLO built from RandPerm::Lower():

HloModule SyncTensorsGraph.9, entry_computation_layout={()->(s64[2]{0})}

ENTRY %SyncTensorsGraph.9 () -> (s64[2]) {
  %iota.1 = s64[2]{0} iota(), iota_dimension=0
  %constant.3 = s64[] constant(1)
  %dynamic-slice.5 = s64[1]{0} dynamic-slice(s64[2]{0} %iota.1, s64[] %constant.3), dynamic_slice_sizes={1}
  %constant.2 = s64[] constant(0)
  %dynamic-update-slice.6 = s64[2]{0} dynamic-update-slice(s64[2]{0} %iota.1, s64[1]{0} %dynamic-slice.5, s64[] %constant.2)
  %dynamic-slice.4 = s64[1]{0} dynamic-slice(s64[2]{0} %iota.1, s64[] %constant.2), dynamic_slice_sizes={1}
  %dynamic-update-slice.7 = s64[2]{0} dynamic-update-slice(s64[2]{0} %iota.1, s64[1]{0} %dynamic-slice.4, s64[] %constant.3)
  ROOT %tuple.8 = (s64[2]{0}) tuple(s64[2]{0} %dynamic-update-slice.7)
}

Line #8 prints:

IR {
  %0 = s64[3]{0} xla::device_data(), xla_shape=s64[3]{0}, ROOT=0
}

Line #9 prints:

HloModule IrToHlo.3, entry_computation_layout={(s64[3]{0})->(s64[3]{0})}

ENTRY %IrToHlo.3 (p0.1: s64[3]) -> (s64[3]) {
  %p0.1 = s64[3]{0} parameter(0)
  ROOT %tuple.2 = (s64[3]{0}) tuple(s64[3]{0} %p0.1)
}

I'm pretty sure I'm producing the wrong HLO, which is fine; I'm just wondering whether there's a better way to dump the HLO, and why there's a difference between the HLO computation in xla_graph_executor.cpp versus the Python-level print(torch_xla._XLAC._get_xla_tensors_hlo([xla_a])).

My hunch is that print(torch_xla._XLAC._get_xla_tensors_hlo([xla_a])) shows the optimized HLO versus the unoptimized one, but it's just a theory. Do either of you have any ideas? Thanks!

@JackCaoG (Collaborator):

_get_xla_tensors_hlo prints the pre-optimized HLO.

IR {
  %0 = s64[3]{0} xla::device_data(), xla_shape=s64[3]{0}, ROOT=0
}

An IR like this means it is just a device_data from the compiler's perspective. This usually means there was a fallback to CPU and the op actually got executed on CPU. You can dump the metrics report to confirm this theory: https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#get-a-metrics-report. I'd expect to see some aten::xxx entries in the metrics report after you execute randperm.
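
For reference, a minimal snippet for dumping the report (run it after executing the op; aten::* counters would point to a CPU fallback):

import torch_xla.debug.metrics as met

# e.g. after running torch.randperm(5, device=xm.xla_device())
print(met.metrics_report())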

@changm (Collaborator) commented Jan 24, 2024

Thanks for the quick reply. I added a print(met.metrics_report()) and got this:

Metric: DeviceLockWait
  TotalSamples: 2
  Accumulator: 013.940us
  ValueRate: 02s658ms550.535us / second
  Rate: 237812 / second
  Percentiles: 1%=001.320us; 5%=001.320us; 10%=001.320us; 20%=001.320us; 50%=012.620us; 80%=012.620us; 90%=012.620us; 95%=012.620us; 99%=012.620us
Metric: LazyTracing
  TotalSamples: 2
  Accumulator: 021ms860.457us
  ValueRate: 993ms582.913us / second
  Rate: 95.1641 / second
  Percentiles: 1%=127.680us; 5%=127.680us; 10%=127.680us; 20%=127.680us; 50%=021ms732.777us; 80%=021ms732.777us; 90%=021ms732.777us; 95%=021ms732.777us; 99%=021ms732.777us
Metric: TensorsGraphSize
  TotalSamples: 1
  Accumulator: 1.00
  Percentiles: 1%=1.00; 5%=1.00; 10%=1.00; 20%=1.00; 50%=1.00; 80%=1.00; 90%=1.00; 95%=1.00; 99%=1.00
Metric: UnwrapXlaData
  TotalSamples: 2
  Accumulator: 000.960us
  ValueRate: 007ms042.770us / second
  Rate: 14672.4 / second
  Percentiles: 1%=000.070us; 5%=000.070us; 10%=000.070us; 20%=000.070us; 50%=000.890us; 80%=000.890us; 90%=000.890us; 95%=000.890us; 99%=000.890us
Metric: WrapXlaData
  TotalSamples: 1
  Accumulator: 000.071us
  Percentiles: 1%=000.071us; 5%=000.071us; 10%=000.071us; 20%=000.071us; 50%=000.071us; 80%=000.071us; 90%=000.071us; 95%=000.071us; 99%=000.071us
Counter: CreateXlaTensor
  Value: 1
Counter: RegisterXLAFunctions
  Value: 1
Counter: UncachedCompile
  Value: 1
Counter: xla::_to_copy
  Value: 1
Counter: xla::randperm
  Value: 1
Metric: CompileTime
  TotalSamples: 1
  Accumulator: 010ms906.028us
  Percentiles: 1%=010ms906.028us; 5%=010ms906.028us; 10%=010ms906.028us; 20%=010ms906.028us; 50%=010ms906.028us; 80%=010ms906.028us; 90%=010ms906.028us; 95%=010ms906.028us; 99%=010ms906.028us
Metric: ExecuteTime
  TotalSamples: 1
  Accumulator: 043.690us
  Percentiles: 1%=043.690us; 5%=043.690us; 10%=043.690us; 20%=043.690us; 50%=043.690us; 80%=043.690us; 90%=043.690us; 95%=043.690us; 99%=043.690us
Metric: InboundData
  TotalSamples: 1
  Accumulator: 16.00B
  Percentiles: 1%=16.00B; 5%=16.00B; 10%=16.00B; 20%=16.00B; 50%=16.00B; 80%=16.00B; 90%=16.00B; 95%=16.00B; 99%=16.00B
Metric: TransferFromDeviceTime
  TotalSamples: 1
  Accumulator: 003ms270.323us
  Percentiles: 1%=003ms270.323us; 5%=003ms270.323us; 10%=003ms270.323us; 20%=003ms270.323us; 50%=003ms270.323us; 80%=003ms270.323us; 90%=003ms270.323us; 95%=003ms270.323us; 99%=003ms270.323us
Counter: CreateCompileHandles
  Value: 1
Counter: CreateDataHandles

I don't see any aten::xxx in the logs, unless I'm missing something? I'm also still curious why I would correctly see all the generated HLO that I produced in RandPerm::Lower() in xla_graph_executor.cpp.

@changm (Collaborator) commented Jan 24, 2024

"This usually means that there is a fallback to CPU and the op actually got executed on CPU. You can dump the metrics report to confirm this theory."

Actually, this could be relevant since this is a CPU VM just to get started. Maybe that's why? Does XLA:CPU not actually work yet?

@JackCaoG (Collaborator):

CPU VM shouldn't matter; this is about XLA device vs. non-XLA device... This is a bit weird then. I do see xla::randperm, which means your lowering is triggered, but the output appeared to be a computed value...

OK, I see: you need to remove

print(xla_a)

if you want to inspect the IR; otherwise it will materialize the output.
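
In other words, something like this ordering (a minimal sketch) keeps the pending IR inspectable, since printing the tensor forces execution:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

xla_a = torch.randperm(3, device=xm.xla_device())
print(torch_xla._XLAC._get_xla_tensors_text([xla_a]))  # dump the IR first
print(torch_xla._XLAC._get_xla_tensors_hlo([xla_a]))   # and the HLO
print(xla_a)  # materialize only after inspecting the graph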

@changm (Collaborator) commented Jan 24, 2024

Ahh I just went back to master and tried stuff at head. Interestingly I think I just got very unlucky for some reason :). Doing this breaks and crashes:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

a = torch.randperm(10)
xla_a = a.to(xm.xla_device())
xla_add = torch.mean(xla_a)

print(torch_xla._XLAC._get_xla_tensors_text([xla_add]))

However, creating a with a = torch.randn(10) works, whereas using torch.randperm(10) breaks even with upstream torch. Anyway, I guess I can continue debugging, thank you!

@wonjoolee95 (Collaborator, Author):

Thanks Jack for pitching in, very helpful as always. @changm, let us know if you need any further help debugging.

@changm (Collaborator) commented Jan 24, 2024

Thanks! I have a PR that I think is ready to merge, but I'm still confused. All the tests work and my printfs work; however, I still can't get the HLO text. When doing:

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

a = torch.randperm(10, device=xm.xla_device())

print(torch_xla._XLAC._get_xla_tensors_hlo([a]))
print(met.metrics_report())

I get:

IR {
  %0 = s64[10]{0} xla::device_data(), location=<module>@test.py:9, xla_shape=s64[10]{0}, ROOT=0
}

Metric: DeviceLockWait
  TotalSamples: 2
  Accumulator: 008.840us
  ValueRate: 03s080ms139.373us / second
  Rate: 696864 / second
  Percentiles: 1%=001.790us; 5%=001.790us; 10%=001.790us; 20%=001.790us; 50%=007.050us; 80%=007.050us; 90%=007.050us; 95%=007.050us; 99%=007.050us
Metric: IrValueTensorToXlaData
  TotalSamples: 1
  Accumulator: 049.490us
  Percentiles: 1%=049.490us; 5%=049.490us; 10%=049.490us; 20%=049.490us; 50%=049.490us; 80%=049.490us; 90%=049.490us; 95%=049.490us; 99%=049.490us
Metric: LazyTracing
  TotalSamples: 6
  Accumulator: 040ms928.038us
  ValueRate: 02s053ms536.832us / second
  Rate: 308.435 / second
  Percentiles: 1%=001.470us; 5%=001.470us; 10%=001.470us; 20%=010.810us; 50%=831.820us; 80%=018ms277.629us; 90%=020ms163.289us; 95%=020ms163.289us; 99%=020ms163.289us
Metric: TensorToData
  TotalSamples: 1
  Accumulator: 041.840us
  Percentiles: 1%=041.840us; 5%=041.840us; 10%=041.840us; 20%=041.840us; 50%=041.840us; 80%=041.840us; 90%=041.840us; 95%=041.840us; 99%=041.840us
Metric: TensorsGraphSize
  TotalSamples: 1
  Accumulator: 2.00
  Percentiles: 1%=2.00; 5%=2.00; 10%=2.00; 20%=2.00; 50%=2.00; 80%=2.00; 90%=2.00; 95%=2.00; 99%=2.00
Metric: UnwrapXlaData
  TotalSamples: 2
  Accumulator: 001.170us
  ValueRate: 008ms341.651us / second
  Rate: 14259.2 / second
  Percentiles: 1%=000.070us; 5%=000.070us; 10%=000.070us; 20%=000.070us; 50%=001.100us; 80%=001.100us; 90%=001.100us; 95%=001.100us; 99%=001.100us
Metric: WrapXlaData
  TotalSamples: 1
  Accumulator: 000.050us
  Percentiles: 1%=000.050us; 5%=000.050us; 10%=000.050us; 20%=000.050us; 50%=000.050us; 80%=000.050us; 90%=000.050us; 95%=000.050us; 99%=000.050us
Counter: CreateXlaTensor
  Value: 2
Counter: DestroyLtcTensor
  Value: 1
Counter: DestroyXlaTensor
  Value: 1
Counter: RegisterXLAFunctions
  Value: 1
Counter: UncachedCompile
  Value: 1
Counter: xla::_copy_from_and_resize
  Value: 1
Counter: xla::_propagate_xla_data
  Value: 1
Counter: xla::_to_cpu
  Value: 1
Counter: xla::empty_symint
  Value: 2
Counter: xla::randperm
  Value: 1
Metric: CompileTime
  TotalSamples: 1
  Accumulator: 013ms328.109us
  Percentiles: 1%=013ms328.109us; 5%=013ms328.109us; 10%=013ms328.109us; 20%=013ms328.109us; 50%=013ms328.109us; 80%=013ms328.109us; 90%=013ms328.109us; 95%=013ms328.109us; 99%=013ms328.109us
Metric: ExecuteTime
  TotalSamples: 1
  Accumulator: 053.080us
  Percentiles: 1%=053.080us; 5%=053.080us; 10%=053.080us; 20%=053.080us; 50%=053.080us; 80%=053.080us; 90%=053.080us; 95%=053.080us; 99%=053.080us
Metric: InboundData
  TotalSamples: 1
  Accumulator: 80.00B
  Percentiles: 1%=80.00B; 5%=80.00B; 10%=80.00B; 20%=80.00B; 50%=80.00B; 80%=80.00B; 90%=80.00B; 95%=80.00B; 99%=80.00B
Metric: OutboundData
  TotalSamples: 1
  Accumulator: 80.00B
  Percentiles: 1%=80.00B; 5%=80.00B; 10%=80.00B; 20%=80.00B; 50%=80.00B; 80%=80.00B; 90%=80.00B; 95%=80.00B; 99%=80.00B
Metric: TransferFromDeviceTime
  TotalSamples: 1
  Accumulator: 117.240us
  Percentiles: 1%=117.240us; 5%=117.240us; 10%=117.240us; 20%=117.240us; 50%=117.240us; 80%=117.240us; 90%=117.240us; 95%=117.240us; 99%=117.240us
Metric: TransferToDeviceTime
  TotalSamples: 1
  Accumulator: 011.820us
  Percentiles: 1%=011.820us; 5%=011.820us; 10%=011.820us; 20%=011.820us; 50%=011.820us; 80%=011.820us; 90%=011.820us; 95%=011.820us; 99%=011.820us
Counter: CreateCompileHandles
  Value: 1
Counter: CreateDataHandles
  Value: 2
Counter: aten::randperm
  Value: 1
Counter: aten::randperm.generator_out
  Value: 1

Questions:

  1. @JackCaoG previously suspected that we'd be falling back to CPU here. However, I see both an aten::randperm and an xla::randperm. The tests in test_core_aten_ops.py also show that the RandPerm::Lower method is being invoked and compiled. Does anything stand out here / is this falling back to CPU?
  2. Manually dumping the generated HLO in XlaGraphExecutor::Compile pretty much gives me what I expect. Does torch_xla._XLAC._get_xla_tensors_hlo always work?

There's another theory I have: since n is constant, we could precompute everything, and XLA wouldn't actually compile / run anything since it's all optimized away. We saw some conditions like that in TensorFlow.

@changm (Collaborator) commented Jan 24, 2024

Gah sorry for the noise, I found a bug and was able to print all the HLO as expected. Thank you!

@wonjoolee95 (Collaborator, Author):

Nice! Just curious, what was the bug?

@changm (Collaborator) commented Jan 24, 2024

The bug was reading a std::optional's value before checking std::optional::has_value(), which caused compilation to crash. It seems like if compilation doesn't work / crashes, PyTorch/XLA just falls back to upstream and runs on CPU?

@wonjoolee95 (Collaborator, Author):

Usually, if there is a crash, the entire program should crash and exit. It should fall back only if there is an explicit call to the fallback in case of a crash.

@changm (Collaborator) commented Jan 29, 2024

I was able to reproduce the build issue locally but I'm actually very confused. I think this goes back to "materializing" an output. Reproducing via:

PYTORCH_TEST_WITH_SLOW=1 python ../test/test_torch.py -k test_put_xla_uint8

The test harness here calls randperm. Adding a print(idx) makes the test always pass. Similarly, always forcing the CPU fallback in the randperm implementation here works as well. So maybe some back-to-basics questions:

  1. If I have a program such as:
import torch
import torch_xla
import torch_xla.core.xla_model as xm

a = torch.randperm(5, dtype=torch.int64, device=xm.xla_device())
print(a)

This both runs the lowered HLO and materializes the Tensor?

  2. Any hints or thoughts on whether the test harness is incorrectly materializing tensors from randperm now that they're being produced on an XLA device?

  3. Maybe I'm on a wild goose chase and totally off :). Any other hints?

Thanks!

@wonjoolee95 (Collaborator, Author):

Closing as it is fixed with #6482.
