Skip to content

Commit

Permalink
Add some helpful debugging notes to TROUBLESHOOTING docs
Browse files Browse the repository at this point in the history
  • Loading branch information
changm committed Feb 1, 2024
1 parent 010b6f0 commit 61914a8
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions TROUBLESHOOTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,11 @@ We don't expect users to use tools in this section to debug their models. But we
them when you submit a bug report since they provide additional information that metrics report
doesn't have.

* ```print(torch_xla._XLAC._get_xla_tensors_text([res]))``` where `res` is the result tensor prints out the IR.
* ```print(torch_xla._XLAC._get_xla_tensors_hlo([res]))``` where `res` is the result tensor prints out the generated XLA HLO.

Note these functions must be called prior to `mark_step()`, otherwise the tensor will already be materialized.

### Environment Variables

There are also a number of environment variables which control the behavior of the _PyTorch/XLA_
Expand Down Expand Up @@ -371,3 +376,16 @@ only be enabled for debugging.
```
TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE="xla_graph_executor=5,pjrt_computation_client=3"
```

### Reproducing PyTorch/XLA CI/CD unit test failures.

You may see some test failures for a PR such as:

```
To execute this test, run the following from the base repo dir:
PYTORCH_TEST_WITH_SLOW=1 python ../test/test_torch.py -k test_put_xla_uint8
```

Running this directly in the command line does not work. You need to set the environment variable `TORCH_TEST_DEVICES` to your local `pytorch/xla/test/pytorch_test_base.py`. For example:

`TORCH_TEST_DEVICES='/path/to/pytorch/xla/test/pytorch_test_base.py PYTORCH_TEST_WITH_SLOW=1 python ../test/test_torch.py -k test_put_xla_uint8` should work.

0 comments on commit 61914a8

Please sign in to comment.