fix addbmm opinfo #6993

Merged · 1 commit · Apr 30, 2024
211 changes: 211 additions & 0 deletions experimental/torch_xla2/docs/fixing_op_info_test.md
@@ -0,0 +1,211 @@
# How to fix an OpInfo test

## What is an OpInfo test?

PyTorch maintains a list of Python objects (OpInfo) that describe
how to test each op. This is useful to us because it lets us verify
that the ops we implement produce the same results PyTorch would
produce. (A minimal sketch of how such a test works follows the links
below.)

Context:
* https://dev-discuss.pytorch.org/t/opinfos-in-pytorch-1-10/253
* https://github.com/pytorch/pytorch/issues/54261
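
To get a concrete feel for what this looks like, here is a minimal sketch of an
OpInfo-driven check. This is not the actual torch_xla2 test harness; it only
uses the real `op_db` from PyTorch to show where the sample inputs and
reference results come from:

```python
import torch
from torch.testing._internal.common_methods_invocations import op_db

# Each OpInfo in op_db knows how to build sample inputs for its op.
for opinfo in op_db:
    if opinfo.name != "addbmm":
        continue
    for sample in opinfo.sample_inputs("cpu", torch.float32):
        # Reference result from eager PyTorch; a harness like
        # test_ops.py would compare our implementation's output
        # against this and assert the numeric diff is small.
        expected = opinfo.op(sample.input, *sample.args, **sample.kwargs)
```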


## How to fix one

### Remove one op from the skiplist

Open [test/test_ops.py](../test/test_ops.py) with your
favorite text editor.
Remove one line from the `skiplist` set.

For example:

```bash
(base) hanq-macbookpro:torch_xla2 hanq$ git diff
diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index 72a39ae85..2a156cbce 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -15,7 +15,6 @@ skiplist = {
"_native_batch_norm_legit",
"_segment_reduce",
"_upsample_bilinear2d_aa",
- "addbmm",
"addmm",
"addmv",
"addr",
```

### Run the test to see the failure

The error we get:

```
E RuntimeError: ('No lowering found for\n\nTo execute this test, run the following from the base repo dir:\n python test/test_ops.py -k test_reference_eager_addbmm_cpu_int64\n\nThis message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0', 'aten::addbmm')
```

From here we have two strategies for fixing this test:

1. Implement the `aten::addbmm` operator using JAX ops, or
2. Implement the `aten::addbmm` operator using torch ops (commonly
   known as a "decomposition").

Either way works for torch_xla2. For ops that are not "Core Aten", we
sometimes implement them in torch ops, with the goal of upstreaming the
decomposition to [pytorch decompositions](https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py)
so other projects can benefit from it.
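
As a concrete example of option 2, a decomposition could look something
like the sketch below (hypothetical; this is not the code this PR adds):

```python
import torch

def addbmm_decomp(input, batch1, batch2, *, beta=1, alpha=1):
    # addbmm is a batched matmul whose per-batch results are summed
    # over the batch dimension, scaled, and added to the scaled input.
    mm = torch.bmm(batch1, batch2).sum(dim=0)
    return beta * input + alpha * mm
```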

For illustration purposes, let's implement this op in JAX.

(NOTE: this doesn't stop us from upstreaming a decomposition later if we want to.)

### First Impl

To implement this op using JAX ops, we first look up its exact
semantics on this page:
https://pytorch.org/docs/stable/generated/torch.addbmm.html

From its math formula, we can implement it as follows:

```python
+@op(torch.ops.aten.addbmm.default)
+def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1):
+
+ mm = jnp.einsum('bxy, byz -> xz', batch1, batch2)
+ return beta * input + alpha * mm
```
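
As an aside, the einsum spec `'bxy, byz -> xz'` both multiplies the
batched matrices and sums over the batch dimension `b`, which is exactly
the reduction addbmm needs. A quick sanity check (not part of the PR):

```python
import jax.numpy as jnp

batch1 = jnp.ones((4, 2, 3))
batch2 = jnp.ones((4, 3, 5))
mm = jnp.einsum('bxy, byz -> xz', batch1, batch2)
assert mm.shape == (2, 5)        # the batch dim b is gone
assert bool((mm == 12.0).all())  # summed over b (4) and y (3): 4 * 3 = 12
```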

Now run the test again:

```
python test/test_ops.py -k test_reference_eager_addbmm_cpu_int64
```

(NOTE: the exact failing test command is printed out when we run
`pytest test/test_ops.py`, so we can run only the failed test instead of
running all tests.)

We now see this error:

```
FAIL: test_reference_eager_addbmm_cpu_int64 (__main__.TestOpInfoCPU) [torch_xla2_diff:0.001]
----------------------------------------------------------------------
Traceback (most recent call last):
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/test/test_ops.py", line 654, in run_export_and_compare
diff_output(
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/test/test_ops.py", line 617, in diff_output
testcase.assertTrue(
AssertionError: False is not true
```

This tells us that our implementation did not produce
the same result as the op in PyTorch.

To debug this, let's figure out exactly which input caused the failure.
We can do this by setting a breakpoint [here](https://github.com/pytorch/xla/blob/master/experimental/torch_xla2/test/test_ops.py#L644), right before the diff, and
inspecting the values of `res` and `res2`, as well as the `sample_input`.
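
For example, a temporary debugging aid (to be removed before submitting
the fix) can be as simple as:

```python
# Drop into the debugger right before the comparison in test/test_ops.py,
# then inspect res, res2 and sample_input interactively.
import pdb; pdb.set_trace()
```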

The sample input we get is
```
SampleInput(input=tensor([[-3, -3, 9, 8, -8, -3, -4, 2, 2, 2],
[-5, 1, -9, 9, 1, -5, 6, 1, -4, -5],
[-2, -1, 5, -2, -3, 0, 5, -4, 9, -6],
[-1, -7, 6, 3, 8, 3, 8, 9, -5, 7],
[-3, -4, -9, 9, 7, -3, -8, 2, 5, -3]]), args=(tensor([[[-2, 4, -2, 5, 8],
[-6, -2, 5, 7, 7],
[-8, -3, 2, 5, -3],
[-4, 7, 0, -9, 8],
[ 3, 9, -9, -2, 0]],

[[-7, 1, -3, 7, -4],
[ 3, 5, 4, 6, 5],
[-2, 8, 3, 5, 7],
[ 8, -2, -8, 2, 0],
[ 6, 1, -8, 8, 0]],

[[ 2, -1, -5, -8, -9],
[ 5, 0, -4, -1, -6],
[-6, 2, -5, -2, -5],
[-5, -3, -5, -4, 9],
[-3, 4, -9, -9, 7]],

[[ 2, 5, -7, -3, 8],
[-5, -7, -8, -4, 4],
[-4, -6, -3, 0, 6],
[ 8, 0, -3, -8, 2],
[-4, 3, -9, -6, 7]],

[[ 2, 1, -6, 2, 8],
[ 2, 6, 4, 1, 8],
[-9, 9, -5, 8, 3],
[-5, 0, -2, 4, 0],
[ 5, 8, -4, 9, 7]]]), tensor([[[-1, -8, 3, 5, -8, 2, -5, 0, -9, -5],
[-4, -7, 2, 2, 1, -9, 2, 7, -1, -1],
[ 1, 8, -6, -4, -6, -8, -7, -9, 7, 4],
[-4, 1, -9, 3, 4, 6, 0, -2, -2, -7],
[ 5, 5, 0, 8, -3, 7, -7, 8, 3, 5]],

[[ 8, -4, -9, 9, 5, 0, 5, 0, -5, 5],
[-5, -3, -2, 8, 1, -2, 4, -7, 5, 3],
[-4, 4, 1, -4, -8, 2, -5, 2, 9, -7],
[ 9, 6, -8, -3, 3, 1, 4, 6, -5, -4],
[-2, 1, 5, 5, 2, 6, 7, -3, -7, 3]],

[[ 9, -8, 5, -3, -1, 2, -9, -5, -1, -3],
[-3, 3, -9, -7, -9, -8, 1, -3, 7, -2],
[ 8, -1, 8, -8, -7, 4, 8, 8, 5, -7],
[-1, 6, -8, 7, -1, -5, -8, 6, -2, 8],
[-5, -5, 8, 6, 0, 1, 3, -2, -3, -9]],

[[ 7, -2, 6, -8, -5, 3, 2, -1, -5, 8],
[-6, -4, 3, 9, -9, -8, -7, 3, 9, 0],
[ 1, 3, 4, 4, -5, -2, -4, -2, 3, -7],
[-6, 9, 5, -1, 7, 7, 8, -3, -8, 0],
[-1, -6, -3, 3, 3, -8, -4, 9, -5, 7]],

[[-5, -3, -9, 6, -1, -7, 9, -8, 1, -8],
[-8, -8, -2, -5, -7, -8, 1, 0, 0, -6],
[ 7, -5, 2, 2, 0, -9, -5, -7, 1, 8],
[-4, 0, 9, 6, -1, -6, 6, -6, -2, -1],
[ 7, 3, 0, 1, 1, -9, 5, -8, -1, -7]]])), kwargs={'beta': 0.6, 'alpha': 0.2}, broadcasts_input=False, name='')
```

And the `res` from torch is

```
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
```

A few observations:

1. The input tensors are of type int64.
2. `alpha` and `beta` are both floats.

So one can suspect that the failure has to do with rounding.
Reading the doc more carefully, we find this sentence:

> For inputs of type FloatTensor or DoubleTensor, arguments beta and alpha must be real numbers, otherwise they should be integers.

So torch likely first cast the float `alpha` and `beta` to integers
(which yields 0 for both), then used them in the math, producing a
matrix of all zeros.
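
We can confirm this suspicion with a quick eager-mode experiment (a
sketch with made-up shapes, mirroring the failing sample above):

```python
import torch

inp = torch.ones(2, 2, dtype=torch.int64)
b1 = torch.ones(3, 2, 4, dtype=torch.int64)
b2 = torch.ones(3, 4, 2, dtype=torch.int64)

# For integer inputs, beta=0.6 and alpha=0.2 are truncated to 0,
# so both terms vanish and the result is all zeros.
print(torch.addbmm(inp, b1, b2, beta=0.6, alpha=0.2))
# tensor([[0, 0],
#         [0, 0]])
```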

### Second Impl

```python
+@op(torch.ops.aten.addbmm.default)
+def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1):
+ alpha = jnp.array(alpha).astype(batch1.dtype)
+ beta = jnp.array(beta).astype(batch1.dtype)
+ mm = jnp.einsum('bxy, byz -> xz', batch1, batch2)
+ return jax.lax.cond(beta == 0,
+ lambda: alpha * mm,
+ lambda: beta*input + alpha*mm)
+
```

Adding the type casts makes the test pass. (The `jax.lax.cond` on
`beta == 0` mirrors PyTorch's documented behavior that `input` is
ignored when `beta` is 0, so `nan`s and `inf`s in it are not
propagated.)

### Submit
Now, let's remove the pdb breakpoints and prints we added, and submit the fix as a PR: https://github.com/pytorch/xla/pull/6993

1 change: 0 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
@@ -15,7 +15,6 @@
"_native_batch_norm_legit",
"_segment_reduce",
"_upsample_bilinear2d_aa",
"addbmm",
"addmm",
"addmv",
"addr",
9 changes: 9 additions & 0 deletions experimental/torch_xla2/torch_xla2/_ops.py
@@ -415,6 +415,15 @@ def _aten_addmm(self, mat1, mat2, *, beta=1.0, alpha=1.0):
  self += alpha * jnp.matmul(mat1, mat2)
  return self

+@op(torch.ops.aten.addbmm.default)
+def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1):
+  alpha = jnp.array(alpha).astype(batch1.dtype)
+  beta = jnp.array(beta).astype(batch1.dtype)
+  mm = jnp.einsum('bxy, byz -> xz', batch1, batch2)
+  return jax.lax.cond(beta == 0,
+                      lambda: alpha * mm,
+                      lambda: beta * input + alpha * mm)


@op(torch.ops.aten.gelu)
def _aten_gelu(self, *, approximate="none"):