From 2399e10f5426b7e0e194f2ae612e20774bb9b53c Mon Sep 17 00:00:00 2001
From: qihqi
Date: Tue, 30 Apr 2024 16:16:15 -0700
Subject: [PATCH] fix addbmm opinfo (#6993)

---
 .../torch_xla2/docs/fixing_op_info_test.md | 211 ++++++++++++++++++
 experimental/torch_xla2/test/test_ops.py   |   1 -
 experimental/torch_xla2/torch_xla2/_ops.py |   9 +
 3 files changed, 220 insertions(+), 1 deletion(-)
 create mode 100644 experimental/torch_xla2/docs/fixing_op_info_test.md

diff --git a/experimental/torch_xla2/docs/fixing_op_info_test.md b/experimental/torch_xla2/docs/fixing_op_info_test.md
new file mode 100644
index 00000000000..03624f9487e
--- /dev/null
+++ b/experimental/torch_xla2/docs/fixing_op_info_test.md
@@ -0,0 +1,211 @@
# How to fix an OpInfo test

## What is an OpInfo test?

PyTorch maintains a list of Python objects (OpInfo) that describe how to
test each op. This is useful to us because it lets us verify that the ops
we implement produce the same results PyTorch would produce.

Context:
* https://dev-discuss.pytorch.org/t/opinfos-in-pytorch-1-10/253
* https://github.com/pytorch/pytorch/issues/54261


## How to fix one

### Remove one op from the skiplist

Open [test/test_ops.py](../test/test_ops.py) with your
favorite text editor.
Remove one line from the `skiplist` set.

For example:

```bash
(base) hanq-macbookpro:torch_xla2 hanq$ git diff
diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index 72a39ae85..2a156cbce 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -15,7 +15,6 @@ skiplist = {
     "_native_batch_norm_legit",
     "_segment_reduce",
     "_upsample_bilinear2d_aa",
-    "addbmm",
     "addmm",
     "addmv",
     "addr",
```

### Run the test to see the failure

The error we get:

```
E RuntimeError: ('No lowering found for\n\nTo execute this test, run the following from the base repo dir:\n python test/test_ops.py -k test_reference_eager_addbmm_cpu_int64\n\nThis message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0', 'aten::addbmm')
```

From here we have two strategies for fixing this test:

1. Implement the `aten::addbmm` operator using Jax ops; or,
2. Implement the `aten::addbmm` operator using torch ops (this is commonly
   known as a "decomposition").

Either way works for torch_xla2. For ops that are not "Core Aten" we sometimes
implement them in torch ops, with the goal of upstreaming the decomposition to
[pytorch decompositions](https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py)
so other projects can benefit from it.

For illustration purposes, let's implement this op in Jax.

(NOTE: this doesn't stop us from upstreaming a decomposition later if we want.)

### First Impl

To implement this op using Jax ops, we first look up its exact semantics on
this page:
https://pytorch.org/docs/stable/generated/torch.addbmm.html

From its math formula, we can implement it as follows:

```
+@op(torch.ops.aten.addbmm.default)
+def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1):
+
+  mm = jnp.einsum('bxy, byz -> xz', batch1, batch2)
+  return beta * input + alpha * mm
```
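The `einsum` spec `'bxy, byz -> xz'` does two things at once: it contracts `y`
(a matrix multiply for each batch `b`), and, because `b` does not appear in the
output, it also sums over the batch dimension. That is exactly the reduction
`addbmm` performs. Here is a quick standalone check of that equivalence (the
shapes are made up for illustration):

```python
import jax
import jax.numpy as jnp

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
batch1 = jax.random.normal(key1, (4, 2, 5))  # (b, x, y)
batch2 = jax.random.normal(key2, (4, 5, 3))  # (b, y, z)

# einsum contracts y and sums out b in one step...
mm = jnp.einsum('bxy, byz -> xz', batch1, batch2)

# ...which should match per-batch matmuls followed by a sum over b.
reference = jnp.matmul(batch1, batch2).sum(axis=0)

assert jnp.allclose(mm, reference, atol=1e-4)
```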
Now run the test again:

```
python test/test_ops.py -k test_reference_eager_addbmm_cpu_int64
```

(NOTE: the exact test command is printed out when we run
`pytest test/test_ops.py`, so we can run only the failed test instead of
running all tests.)

We now see this error:

```
FAIL: test_reference_eager_addbmm_cpu_int64 (__main__.TestOpInfoCPU) [torch_xla2_diff:0.001]
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/test/test_ops.py", line 654, in run_export_and_compare
    diff_output(
  File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/test/test_ops.py", line 617, in diff_output
    testcase.assertTrue(
AssertionError: False is not true
```

This tells us that our implementation did not produce the same result as the
op does in PyTorch.

To debug this, let's figure out exactly which input caused the failure.
We can do so by setting a breakpoint [here](https://github.com/pytorch/xla/blob/master/experimental/torch_xla2/test/test_ops.py#L644), right before the diff. There we can
inspect the values of `res` and `res2`, as well as the `sample_input`.

The sample input we get is:

```
SampleInput(input=tensor([[-3, -3, 9, 8, -8, -3, -4, 2, 2, 2],
        [-5, 1, -9, 9, 1, -5, 6, 1, -4, -5],
        [-2, -1, 5, -2, -3, 0, 5, -4, 9, -6],
        [-1, -7, 6, 3, 8, 3, 8, 9, -5, 7],
        [-3, -4, -9, 9, 7, -3, -8, 2, 5, -3]]), args=(tensor([[[-2, 4, -2, 5, 8],
         [-6, -2, 5, 7, 7],
         [-8, -3, 2, 5, -3],
         [-4, 7, 0, -9, 8],
         [ 3, 9, -9, -2, 0]],

        [[-7, 1, -3, 7, -4],
         [ 3, 5, 4, 6, 5],
         [-2, 8, 3, 5, 7],
         [ 8, -2, -8, 2, 0],
         [ 6, 1, -8, 8, 0]],

        [[ 2, -1, -5, -8, -9],
         [ 5, 0, -4, -1, -6],
         [-6, 2, -5, -2, -5],
         [-5, -3, -5, -4, 9],
         [-3, 4, -9, -9, 7]],

        [[ 2, 5, -7, -3, 8],
         [-5, -7, -8, -4, 4],
         [-4, -6, -3, 0, 6],
         [ 8, 0, -3, -8, 2],
         [-4, 3, -9, -6, 7]],

        [[ 2, 1, -6, 2, 8],
         [ 2, 6, 4, 1, 8],
         [-9, 9, -5, 8, 3],
         [-5, 0, -2, 4, 0],
         [ 5, 8, -4, 9, 7]]]), tensor([[[-1, -8, 3, 5, -8, 2, -5, 0, -9, -5],
         [-4, -7, 2, 2, 1, -9, 2, 7, -1, -1],
         [ 1, 8, -6, -4, -6, -8, -7, -9, 7, 4],
         [-4, 1, -9, 3, 4, 6, 0, -2, -2, -7],
         [ 5, 5, 0, 8, -3, 7, -7, 8, 3, 5]],

        [[ 8, -4, -9, 9, 5, 0, 5, 0, -5, 5],
         [-5, -3, -2, 8, 1, -2, 4, -7, 5, 3],
         [-4, 4, 1, -4, -8, 2, -5, 2, 9, -7],
         [ 9, 6, -8, -3, 3, 1, 4, 6, -5, -4],
         [-2, 1, 5, 5, 2, 6, 7, -3, -7, 3]],

        [[ 9, -8, 5, -3, -1, 2, -9, -5, -1, -3],
         [-3, 3, -9, -7, -9, -8, 1, -3, 7, -2],
         [ 8, -1, 8, -8, -7, 4, 8, 8, 5, -7],
         [-1, 6, -8, 7, -1, -5, -8, 6, -2, 8],
         [-5, -5, 8, 6, 0, 1, 3, -2, -3, -9]],

        [[ 7, -2, 6, -8, -5, 3, 2, -1, -5, 8],
         [-6, -4, 3, 9, -9, -8, -7, 3, 9, 0],
         [ 1, 3, 4, 4, -5, -2, -4, -2, 3, -7],
         [-6, 9, 5, -1, 7, 7, 8, -3, -8, 0],
         [-1, -6, -3, 3, 3, -8, -4, 9, -5, 7]],

        [[-5, -3, -9, 6, -1, -7, 9, -8, 1, -8],
         [-8, -8, -2, -5, -7, -8, 1, 0, 0, -6],
         [ 7, -5, 2, 2, 0, -9, -5, -7, 1, 8],
         [-4, 0, 9, 6, -1, -6, 6, -6, -2, -1],
         [ 7, 3, 0, 1, 1, -9, 5, -8, -1, -7]]])), kwargs={'beta': 0.6, 'alpha': 0.2}, broadcasts_input=False, name='')
```

And the `res` from torch is:

```
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
```

A few observations:

1. The input tensors are of type int64.
2. `alpha` and `beta` are both floats.

So one can suspect that the discrepancy has to do with rounding.
Reading the doc more carefully, we find this sentence:

    For inputs of type FloatTensor or DoubleTensor, arguments beta and alpha must be real numbers, otherwise they should be integers.
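We can check this hypothesis directly in eager PyTorch. A minimal sketch with
made-up shapes; the truncation behavior is what we are inferring from the
all-zero `res` above:

```python
import torch

inp = torch.ones(2, 3, dtype=torch.int64)
batch1 = torch.ones(4, 2, 5, dtype=torch.int64)
batch2 = torch.ones(4, 5, 3, dtype=torch.int64)

# With integer tensors, the float beta/alpha appear to be truncated to
# integers (0.6 -> 0, 0.2 -> 0), so both terms vanish.
print(torch.addbmm(inp, batch1, batch2, beta=0.6, alpha=0.2))
# tensor([[0, 0, 0],
#         [0, 0, 0]])
```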
So torch likely first cast the float alpha and beta to integers, which yields
0 for both, and then used them in the math, producing a matrix of all zeros.

### Second Impl

```python
+@op(torch.ops.aten.addbmm.default)
+def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1):
+  alpha = jnp.array(alpha).astype(batch1.dtype)
+  beta = jnp.array(beta).astype(batch1.dtype)
+  mm = jnp.einsum('bxy, byz -> xz', batch1, batch2)
+  return jax.lax.cond(beta == 0,
+                      lambda: alpha * mm,
+                      lambda: beta*input + alpha*mm)
+
```

Adding the type casts makes the test pass. (The `jax.lax.cond` on `beta == 0`
mirrors PyTorch's documented behavior: when `beta` is 0, `input` is ignored
entirely, so `nan` and `inf` values in it are not propagated.)

### Submit

Now, let's remove the pdb breakpoints and prints we added, and submit the fix
as a PR: https://github.com/pytorch/xla/pull/6993

diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index 72a39ae8582..2a156cbce6c 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -15,7 +15,6 @@
     "_native_batch_norm_legit",
     "_segment_reduce",
     "_upsample_bilinear2d_aa",
-    "addbmm",
     "addmm",
     "addmv",
     "addr",
diff --git a/experimental/torch_xla2/torch_xla2/_ops.py b/experimental/torch_xla2/torch_xla2/_ops.py
index fe0f97a0f01..0eacf2d47a3 100644
--- a/experimental/torch_xla2/torch_xla2/_ops.py
+++ b/experimental/torch_xla2/torch_xla2/_ops.py
@@ -415,6 +415,15 @@ def _aten_addmm(self, mat1, mat2, *, beta=1.0, alpha=1.0):
   self += alpha * jnp.matmul(mat1, mat2)
   return self
 
+@op(torch.ops.aten.addbmm.default)
+def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1):
+  alpha = jnp.array(alpha).astype(batch1.dtype)
+  beta = jnp.array(beta).astype(batch1.dtype)
+  mm = jnp.einsum('bxy, byz -> xz', batch1, batch2)
+  return jax.lax.cond(beta == 0,
+                      lambda: alpha * mm,
+                      lambda: beta*input + alpha*mm)
+
 
 @op(torch.ops.aten.gelu)
 def _aten_gelu(self, *, approximate="none"):