# How to fix an OpInfo test

## What is an OpInfo test?

PyTorch maintains a list of Python objects (OpInfos) that
describe how to test each op. This is useful to us because it
ensures that the ops we implement produce the same results
PyTorch would produce.
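Conceptually, the test harness iterates over each op's sample inputs, runs both the reference (PyTorch eager) implementation and ours, and diffs the results. Here is a minimal pure-Python sketch of that loop; the names and structure are illustrative, not the actual OpInfo API:

```python
# Illustrative sketch of an OpInfo-style check (NOT the real OpInfo API).
# Each entry pairs a reference implementation with sample inputs; the
# harness asserts our implementation matches the reference on all of them.

def reference_add(a, b):   # stand-in for the PyTorch eager op
    return a + b

def our_add(a, b):         # stand-in for our reimplementation of the op
    return b + a

op_infos = [
    {
        "name": "add",
        "reference": reference_add,
        "ours": our_add,
        "sample_inputs": [(1, 2), (-5, 5), (0, 0)],
    },
]

def run_opinfo_tests(op_infos):
    # Run every sample input through both implementations and compare.
    for info in op_infos:
        for args in info["sample_inputs"]:
            expected = info["reference"](*args)
            actual = info["ours"](*args)
            assert actual == expected, f"{info['name']} mismatch on {args}"
    return True

print(run_opinfo_tests(op_infos))  # True
```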

Context:

* https://dev-discuss.pytorch.org/t/opinfos-in-pytorch-1-10/253
* https://github.com/pytorch/pytorch/issues/54261

## How to fix one

### Remove one op from the skiplist

Open [test/test_ops.py](../test/test_ops.py) with your
favorite text editor and remove one line from the `skiplist` set.

For example:

```bash
(base) hanq-macbookpro:torch_xla2 hanq$ git diff
diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index 72a39ae85..2a156cbce 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -15,7 +15,6 @@ skiplist = {
     "_native_batch_norm_legit",
     "_segment_reduce",
     "_upsample_bilinear2d_aa",
-    "addbmm",
     "addmm",
     "addmv",
     "addr",
```
### Run the test to see the failure

The error we get:

```
E RuntimeError: ('No lowering found for\n\nTo execute this test, run the following from the base repo dir:\n python test/test_ops.py -k test_reference_eager_addbmm_cpu_int64\n\nThis message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0', 'aten::addbmm')
```

From here we have two strategies for fixing this test:

1. Implement the `aten::addbmm` operator using Jax ops; or,
2. Implement the `aten::addbmm` operator using torch ops (commonly known as a "decomposition").

Either way works for torch_xla2. For ops that are not "Core Aten" we sometimes
implement them in torch ops, with the goal of upstreaming the decomposition to
[pytorch decompositions](https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py)
so other projects can benefit from it.

For illustration purposes, let's implement this op in Jax.
(NOTE: this doesn't stop us from upstreaming a decomposition later if we want.)
### First Impl

To implement this op using jax ops, we first look up its exact
semantics on this page:
https://pytorch.org/docs/stable/generated/torch.addbmm.html

From its math formula, we can implement it as follows.
```python
@op(torch.ops.aten.addbmm.default)
def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1):
  mm = jnp.einsum('bxy, byz -> xz', batch1, batch2)
  return beta * input + alpha * mm
```
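The einsum spec `'bxy, byz -> xz'` both batch-multiplies and sums over the batch dimension `b` in one step. A quick check of that equivalence, using `numpy.einsum` as a stand-in for `jnp.einsum`:

```python
import numpy as np

rng = np.random.default_rng(0)
batch1 = rng.integers(-5, 5, size=(4, 3, 2))   # shape (b, x, y)
batch2 = rng.integers(-5, 5, size=(4, 2, 6))   # shape (b, y, z)

# einsum contracts y and sums over the batch dim b in one shot...
via_einsum = np.einsum('bxy, byz -> xz', batch1, batch2)

# ...which is the same as summing the per-batch matrix products,
# exactly what addbmm's formula asks for.
via_loop = sum(batch1[b] @ batch2[b] for b in range(batch1.shape[0]))

assert (via_einsum == via_loop).all()
print(via_einsum.shape)  # (3, 6)
```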
Now run the test again:

```
python test/test_ops.py -k test_reference_eager_addbmm_cpu_int64
```

(NOTE: the exact test command is printed out when we run
`pytest test/test_ops.py`, so we can run only the failed test instead of all tests.)
We now see this error:

```
FAIL: test_reference_eager_addbmm_cpu_int64 (__main__.TestOpInfoCPU) [torch_xla2_diff:0.001]
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/test/test_ops.py", line 654, in run_export_and_compare
    diff_output(
  File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/test/test_ops.py", line 617, in diff_output
    testcase.assertTrue(
AssertionError: False is not true
```
This is telling us that our implementation did not produce
the same result as the op in PyTorch.

To debug this, let's figure out what exact input caused the mismatch.
We can do so by setting a breakpoint [here](https://github.com/pytorch/xla/blob/master/experimental/torch_xla2/test/test_ops.py#L644), right before the diff, and
inspecting the values of `res` and `res2`, as well as the `sample_input`.
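For reference, the kind of tolerance-based diff a helper like `diff_output` performs can be sketched as follows; the signature and tolerance values here are illustrative, not torch_xla2's actual ones:

```python
import numpy as np

def diff_output(expected, actual, atol=1e-3, rtol=1e-5):
    """Illustrative stand-in for the harness's diff helper:
    compare two arrays elementwise within absolute/relative tolerance."""
    expected = np.asarray(expected)
    actual = np.asarray(actual)
    if expected.shape != actual.shape:
        return False
    return bool(np.allclose(expected, actual, atol=atol, rtol=rtol))

print(diff_output([0, 0, 0], [0, 0, 0]))   # True: results agree
print(diff_output([0, 0, 0], [1, -2, 3]))  # False: mismatch, test fails
```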
The sample input we get is:
```
SampleInput(input=tensor([[-3, -3, 9, 8, -8, -3, -4, 2, 2, 2],
[-5, 1, -9, 9, 1, -5, 6, 1, -4, -5],
[-2, -1, 5, -2, -3, 0, 5, -4, 9, -6],
[-1, -7, 6, 3, 8, 3, 8, 9, -5, 7],
[-3, -4, -9, 9, 7, -3, -8, 2, 5, -3]]), args=(tensor([[[-2, 4, -2, 5, 8],
[-6, -2, 5, 7, 7],
[-8, -3, 2, 5, -3],
[-4, 7, 0, -9, 8],
[ 3, 9, -9, -2, 0]],
[[-7, 1, -3, 7, -4],
[ 3, 5, 4, 6, 5],
[-2, 8, 3, 5, 7],
[ 8, -2, -8, 2, 0],
[ 6, 1, -8, 8, 0]],
[[ 2, -1, -5, -8, -9],
[ 5, 0, -4, -1, -6],
[-6, 2, -5, -2, -5],
[-5, -3, -5, -4, 9],
[-3, 4, -9, -9, 7]],
[[ 2, 5, -7, -3, 8],
[-5, -7, -8, -4, 4],
[-4, -6, -3, 0, 6],
[ 8, 0, -3, -8, 2],
[-4, 3, -9, -6, 7]],
[[ 2, 1, -6, 2, 8],
[ 2, 6, 4, 1, 8],
[-9, 9, -5, 8, 3],
[-5, 0, -2, 4, 0],
[ 5, 8, -4, 9, 7]]]), tensor([[[-1, -8, 3, 5, -8, 2, -5, 0, -9, -5],
[-4, -7, 2, 2, 1, -9, 2, 7, -1, -1],
[ 1, 8, -6, -4, -6, -8, -7, -9, 7, 4],
[-4, 1, -9, 3, 4, 6, 0, -2, -2, -7],
[ 5, 5, 0, 8, -3, 7, -7, 8, 3, 5]],
[[ 8, -4, -9, 9, 5, 0, 5, 0, -5, 5],
[-5, -3, -2, 8, 1, -2, 4, -7, 5, 3],
[-4, 4, 1, -4, -8, 2, -5, 2, 9, -7],
[ 9, 6, -8, -3, 3, 1, 4, 6, -5, -4],
[-2, 1, 5, 5, 2, 6, 7, -3, -7, 3]],
[[ 9, -8, 5, -3, -1, 2, -9, -5, -1, -3],
[-3, 3, -9, -7, -9, -8, 1, -3, 7, -2],
[ 8, -1, 8, -8, -7, 4, 8, 8, 5, -7],
[-1, 6, -8, 7, -1, -5, -8, 6, -2, 8],
[-5, -5, 8, 6, 0, 1, 3, -2, -3, -9]],
[[ 7, -2, 6, -8, -5, 3, 2, -1, -5, 8],
[-6, -4, 3, 9, -9, -8, -7, 3, 9, 0],
[ 1, 3, 4, 4, -5, -2, -4, -2, 3, -7],
[-6, 9, 5, -1, 7, 7, 8, -3, -8, 0],
[-1, -6, -3, 3, 3, -8, -4, 9, -5, 7]],
[[-5, -3, -9, 6, -1, -7, 9, -8, 1, -8],
[-8, -8, -2, -5, -7, -8, 1, 0, 0, -6],
[ 7, -5, 2, 2, 0, -9, -5, -7, 1, 8],
[-4, 0, 9, 6, -1, -6, 6, -6, -2, -1],
[ 7, 3, 0, 1, 1, -9, 5, -8, -1, -7]]])), kwargs={'beta': 0.6, 'alpha': 0.2}, broadcasts_input=False, name='')
```
And the `res` from torch is

```
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
```
So a few observations:

1. The input tensors are of type int64.
2. `alpha` and `beta` are both floats.

So one can suspect that this has to do with rounding.
Reading the doc more carefully, we find this sentence:

> For inputs of type FloatTensor or DoubleTensor, arguments beta and alpha must be real numbers, otherwise they should be integers.

So torch likely first casts the float `alpha` and `beta` to integers, which yields 0, then uses them in the math to get a matrix of all zeros.
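We can reproduce that casting behavior directly. With integer inputs, casting `beta=0.6` and `alpha=0.2` to int64 truncates both to 0, so every term in `beta * input + alpha * mm` vanishes. NumPy is used here to mirror the jnp `astype` calls:

```python
import numpy as np

beta, alpha = 0.6, 0.2
input_ = np.array([[1, 2], [3, 4]], dtype=np.int64)  # toy stand-ins
mm = np.array([[5, 6], [7, 8]], dtype=np.int64)

# Casting the float scalars to the integer dtype truncates them to 0...
beta_i = np.int64(beta)
alpha_i = np.int64(alpha)
print(beta_i, alpha_i)  # 0 0

# ...so the whole expression collapses to a zero matrix,
# matching the all-zeros `res` PyTorch produced above.
print(beta_i * input_ + alpha_i * mm)
```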
### Second Impl

```python
@op(torch.ops.aten.addbmm.default)
def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1):
  alpha = jnp.array(alpha).astype(batch1.dtype)
  beta = jnp.array(beta).astype(batch1.dtype)
  mm = jnp.einsum('bxy, byz -> xz', batch1, batch2)
  return jax.lax.cond(beta == 0,
                      lambda: alpha * mm,
                      lambda: beta * input + alpha * mm)
```
Adding the type casts makes the test pass.
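A note on the `jax.lax.cond` branch: guarding on `beta == 0` skips the `beta * input` term entirely in that case. One plausible reason this matters is that `0 * inf` and `0 * nan` are `nan` in IEEE floating-point arithmetic, so an unconditional `beta * input` could propagate `nan`s from `input` that PyTorch would not (its docs for similar ops such as `torch.addmm` state that `input` is ignored when `beta` is 0). A plain-Python illustration of the hazard:

```python
import math

inf = float("inf")

# Unconditionally computing beta * input propagates nan when input is inf:
beta = 0.0
print(beta * inf)  # nan

# Branching on beta == 0 (as jax.lax.cond does above) never touches input:
result = 42.0 if beta == 0 else beta * inf + 42.0
print(result)  # 42.0
```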
### Submit

Now, let's remove the pdb breakpoints and prints we added, and submit the fix as a PR: https://github.com/pytorch/xla/pull/6993