fix addbmm opinfo
qihqi committed Apr 29, 2024
1 parent 6443e59 commit dcc78fc
Showing 3 changed files with 223 additions and 1 deletion.
214 changes: 214 additions & 0 deletions experimental/torch_xla2/docs/fixing_op_info_test.md
# How to fix an op info test.

## What is an OpInfo test

PyTorch maintains a list of Python objects (OpInfo) that keep
track of how to test each op. This is useful to us because it
ensures that the ops we implement produce the same results
PyTorch would produce. (A minimal sketch of how these objects
drive a test follows the context links below.)

Context:
* https://dev-discuss.pytorch.org/t/opinfos-in-pytorch-1-10/253
* https://github.com/pytorch/pytorch/issues/54261
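
As a rough sketch (not the actual harness in `test/test_ops.py`, and using
PyTorch's internal `common_methods_invocations` module, which is subject to
change), an OpInfo-driven test works like this: each OpInfo entry can generate
sample inputs for its op, so a runner can execute the same samples on two
backends and diff the results.

```python
import torch
from torch.testing._internal.common_methods_invocations import op_db

# Replay the OpInfo-provided sample inputs for a single op.
for info in (i for i in op_db if i.name == "addbmm"):
    for sample in info.sample_inputs(device="cpu", dtype=torch.int64):
        expected = info.op(sample.input, *sample.args, **sample.kwargs)
        # A real harness would also run the same sample through the
        # backend under test (torch_xla2) and diff the two results.
        print(expected.shape)
```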


## How to fix one

### Remove one op from skiplist

Open [test/test_ops.py](../test/test_ops.py) with your
favorite text editor.
Remove one line from the `skiplist` set.

For example:

```bash
(base) hanq-macbookpro:torch_xla2 hanq$ git diff
diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index 72a39ae85..2a156cbce 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -15,7 +15,6 @@ skiplist = {
"_native_batch_norm_legit",
"_segment_reduce",
"_upsample_bilinear2d_aa",
- "addbmm",
"addmm",
"addmv",
"addr",
```
### Run the test to see the failure
The error we get:
```
E RuntimeError: ('No lowering found for\n\nTo execute this test, run the following from the base repo dir:\n python test/test_ops.py -k test_reference_eager_addbmm_cpu_int64\n\nThis message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0', 'aten::addbmm')
```
From here we have two strategies for fixing this test:
1. Add an implementation of the `aten::addbmm` operator, or,
2. Add a decomposition of the `aten::addbmm` operator.

Usually we prefer decompositions, especially for ops that are
not "Core Aten".
To find out whether an op is Core Aten, we can either check
[`native_functions.yaml`]() for the `core` tag, or print the op's tags at runtime:
```python
In [8]: torch.ops.aten.addbmm.default.tags
Out[8]: [<Tag.pt2_compliant_tag: 9>]
```
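The output above carries no `core` tag, so `addbmm` is not a Core Aten op.
Assuming a recent PyTorch where the `torch.Tag` enum is available, the same
check can be scripted:

```python
import torch

# Core Aten ops such as aten::mm carry the `core` tag; addbmm does not.
print(torch.Tag.core in torch.ops.aten.mm.default.tags)       # expected: True
print(torch.Tag.core in torch.ops.aten.addbmm.default.tags)   # expected: False
```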
### First Impl
To implement this op using JAX ops, we first look up its exact
semantics on this page:
https://pytorch.org/docs/stable/generated/torch.addbmm.html
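That page gives the formula (for a batch of $b$ matrix pairs):

$$
\mathrm{out} = \beta\,\mathrm{input} + \alpha\left(\sum_{i=0}^{b-1} \mathrm{batch1}_i \mathbin{@} \mathrm{batch2}_i\right)
$$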
From this formula, we can implement it as follows.
```python
+@op(torch.ops.aten.addbmm.default)
+def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1):
+
+ mm = jnp.einsum('bxy, byz -> xz', batch1, batch2)
+ return beta * input + alpha * mm
```
The einsum spec `'bxy, byz -> xz'` multiplies each pair of batch matrices and
sums the products over the batch dimension, which matches the summation in the
formula above.

Now run the test again:
```
python test/test_ops.py -k test_reference_eager_addbmm_cpu_int64
```
(NOTE: the exact test command is printed out when we run
`pytest test/test_ops.py`, so we can rerun only the failed test instead of running all tests.)
We now see this error:
```
FAIL: test_reference_eager_addbmm_cpu_int64 (__main__.TestOpInfoCPU) [torch_xla2_diff:0.001]
----------------------------------------------------------------------
Traceback (most recent call last):
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/test/test_ops.py", line 654, in run_export_and_compare
diff_output(
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/test/test_ops.py", line 617, in diff_output
testcase.assertTrue(
AssertionError: False is not true
```
This tells us that our implementation did not produce
the same result as the op in PyTorch.
To debug this, let's figure out exactly which input caused the mismatch.
We can do this by setting a breakpoint [here](https://github.com/pytorch/xla/blob/master/experimental/torch_xla2/test/test_ops.py#L644), right before the diff, and
inspecting the values of `res` and `res2`, as well as the `sample_input`.
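One hypothetical way to stop there (a temporary edit, not part of the
checked-in test):

```python
# Temporarily inserted in run_export_and_compare, just before diff_output():
import pdb; pdb.set_trace()
# At the (Pdb) prompt, inspect:
#   res, res2     -> the two results being compared (our impl vs. eager PyTorch)
#   sample_input  -> the OpInfo-generated SampleInput that triggered the failure
```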
The sample input we get is
```
SampleInput(input=tensor([[-3, -3, 9, 8, -8, -3, -4, 2, 2, 2],
[-5, 1, -9, 9, 1, -5, 6, 1, -4, -5],
[-2, -1, 5, -2, -3, 0, 5, -4, 9, -6],
[-1, -7, 6, 3, 8, 3, 8, 9, -5, 7],
[-3, -4, -9, 9, 7, -3, -8, 2, 5, -3]]), args=(tensor([[[-2, 4, -2, 5, 8],
[-6, -2, 5, 7, 7],
[-8, -3, 2, 5, -3],
[-4, 7, 0, -9, 8],
[ 3, 9, -9, -2, 0]],
[[-7, 1, -3, 7, -4],
[ 3, 5, 4, 6, 5],
[-2, 8, 3, 5, 7],
[ 8, -2, -8, 2, 0],
[ 6, 1, -8, 8, 0]],
[[ 2, -1, -5, -8, -9],
[ 5, 0, -4, -1, -6],
[-6, 2, -5, -2, -5],
[-5, -3, -5, -4, 9],
[-3, 4, -9, -9, 7]],
[[ 2, 5, -7, -3, 8],
[-5, -7, -8, -4, 4],
[-4, -6, -3, 0, 6],
[ 8, 0, -3, -8, 2],
[-4, 3, -9, -6, 7]],
[[ 2, 1, -6, 2, 8],
[ 2, 6, 4, 1, 8],
[-9, 9, -5, 8, 3],
[-5, 0, -2, 4, 0],
[ 5, 8, -4, 9, 7]]]), tensor([[[-1, -8, 3, 5, -8, 2, -5, 0, -9, -5],
[-4, -7, 2, 2, 1, -9, 2, 7, -1, -1],
[ 1, 8, -6, -4, -6, -8, -7, -9, 7, 4],
[-4, 1, -9, 3, 4, 6, 0, -2, -2, -7],
[ 5, 5, 0, 8, -3, 7, -7, 8, 3, 5]],
[[ 8, -4, -9, 9, 5, 0, 5, 0, -5, 5],
[-5, -3, -2, 8, 1, -2, 4, -7, 5, 3],
[-4, 4, 1, -4, -8, 2, -5, 2, 9, -7],
[ 9, 6, -8, -3, 3, 1, 4, 6, -5, -4],
[-2, 1, 5, 5, 2, 6, 7, -3, -7, 3]],
[[ 9, -8, 5, -3, -1, 2, -9, -5, -1, -3],
[-3, 3, -9, -7, -9, -8, 1, -3, 7, -2],
[ 8, -1, 8, -8, -7, 4, 8, 8, 5, -7],
[-1, 6, -8, 7, -1, -5, -8, 6, -2, 8],
[-5, -5, 8, 6, 0, 1, 3, -2, -3, -9]],
[[ 7, -2, 6, -8, -5, 3, 2, -1, -5, 8],
[-6, -4, 3, 9, -9, -8, -7, 3, 9, 0],
[ 1, 3, 4, 4, -5, -2, -4, -2, 3, -7],
[-6, 9, 5, -1, 7, 7, 8, -3, -8, 0],
[-1, -6, -3, 3, 3, -8, -4, 9, -5, 7]],
[[-5, -3, -9, 6, -1, -7, 9, -8, 1, -8],
[-8, -8, -2, -5, -7, -8, 1, 0, 0, -6],
[ 7, -5, 2, 2, 0, -9, -5, -7, 1, 8],
[-4, 0, 9, 6, -1, -6, 6, -6, -2, -1],
[ 7, 3, 0, 1, 1, -9, 5, -8, -1, -7]]])), kwargs={'beta': 0.6, 'alpha': 0.2}, broadcasts_input=False, name='')
```
And the `res` from torch is
```
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
```
A few observations:
1. The input tensors are of type int64.
2. `alpha` and `beta` are both floats.

So one can suspect that the issue has to do with rounding.
Reading the doc more carefully, we find this sentence:

> For inputs of type FloatTensor or DoubleTensor, arguments beta and alpha must be real numbers, otherwise they should be integers.

So PyTorch likely first casts the float `alpha` and `beta` to integers, which yields 0, then uses them in the computation, producing an all-zero matrix.
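We can sanity-check this hypothesis in eager PyTorch with a standalone
snippet (the tensor values here are arbitrary; only the dtypes matter):

```python
import torch

inp = torch.randint(-9, 10, (5, 10))
b1 = torch.randint(-9, 10, (5, 5, 5))
b2 = torch.randint(-9, 10, (5, 5, 10))

# With int64 inputs, the float beta/alpha get truncated toward zero,
# so beta=0.6 -> 0 and alpha=0.2 -> 0, and the result is all zeros.
print(torch.addbmm(inp, b1, b2, beta=0.6, alpha=0.2))
```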
### Second Impl
```python
+@op(torch.ops.aten.addbmm.default)
+def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1):
+ alpha = jnp.array(alpha).astype(batch1.dtype)
+ beta = jnp.array(beta).astype(batch1.dtype)
+ mm = jnp.einsum('bxy, byz -> xz', batch1, batch2)
+ return jax.lax.cond(beta == 0,
+ lambda: alpha * mm,
+ lambda: beta*input + alpha*mm)
+
```
Adding the type casts makes the test pass. The `jax.lax.cond` handles `beta == 0` as a separate branch so that `input` is skipped entirely in that case; PyTorch documents that when `beta` is 0, `input` is ignored and `nan`/`inf` in it are not propagated.
### Submit
Now, let's remove the `pdb` breakpoint and prints we added, and submit the fix as a PR:
1 change: 0 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
```diff
@@ -15,7 +15,6 @@
     "_native_batch_norm_legit",
     "_segment_reduce",
     "_upsample_bilinear2d_aa",
-    "addbmm",
     "addmm",
     "addmv",
     "addr",
```
9 changes: 9 additions & 0 deletions experimental/torch_xla2/torch_xla2/_ops.py
```diff
@@ -415,6 +415,15 @@ def _aten_addmm(self, mat1, mat2, *, beta=1.0, alpha=1.0):
   self += alpha * jnp.matmul(mat1, mat2)
   return self

+@op(torch.ops.aten.addbmm.default)
+def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1):
+  alpha = jnp.array(alpha).astype(batch1.dtype)
+  beta = jnp.array(beta).astype(batch1.dtype)
+  mm = jnp.einsum('bxy, byz -> xz', batch1, batch2)
+  return jax.lax.cond(beta == 0,
+                      lambda: alpha * mm,
+                      lambda: beta*input + alpha*mm)


 @op(torch.ops.aten.gelu)
 def _aten_gelu(self, *, approximate="none"):
```
